Replace unicodecsv with standard csv library (#31693)
unicodecsv appears to be missing a license, which can cause trouble (see jdunck/python-unicodecsv#80).

It also appears that the library is no longer required: on Python 3 the standard library csv module handles Unicode natively.
dstandish committed Jun 7, 2023
1 parent f7ed878 commit fbeb01c
Showing 13 changed files with 84 additions and 117 deletions.
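
The diffs below all apply the same basic migration. As a minimal sketch of the pattern (illustrative only, not code from this commit; the file path and rows are invented): unicodecsv wrote encoded bytes to a file opened in binary mode and took an encoding argument on the writer, whereas the standard library csv module writes str, so the file is opened in text mode and the encoding moves to open().

import csv

# Before (unicodecsv): binary handle, encoding passed to the writer.
#     with open(path, "wb") as f:
#         writer = csv.writer(f, delimiter=",", encoding="utf-8")
#
# After (standard library): text handle, encoding passed to open().
with open("/tmp/example.csv", "w", encoding="utf-8") as f:  # hypothetical path
    writer = csv.writer(f, delimiter=",", quoting=csv.QUOTE_MINIMAL)
    writer.writerow(["id", "name"])
    writer.writerow([1, "naïve"])

The csv documentation also recommends passing newline="" when opening a file for writing; the hunks in this commit keep the plain text-mode open calls shown above.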
6 changes: 3 additions & 3 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -37,7 +37,7 @@

raise AirflowOptionalProviderFeatureException(e)

import unicodecsv as csv
import csv

from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -989,8 +989,8 @@ def to_csv(
message = None

i = 0
with open(csv_filepath, "wb") as file:
writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding="utf-8")
with open(csv_filepath, "w", encoding="utf-8") as file:
writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator)
try:
if output_header:
self.log.debug("Cursor description is %s", header)
6 changes: 3 additions & 3 deletions airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -18,12 +18,12 @@
"""This module contains an operator to move data from MSSQL to Hive."""
from __future__ import annotations

import csv
from collections import OrderedDict
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

import pymssql
import unicodecsv as csv

from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
@@ -113,8 +113,8 @@ def execute(self, context: Context):
with mssql.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(self.sql)
with NamedTemporaryFile("w") as tmp_file:
csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding="utf-8")
with NamedTemporaryFile(mode="w", encoding="utf-8") as tmp_file:
csv_writer = csv.writer(tmp_file, delimiter=self.delimiter)
field_dict = OrderedDict()
for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
7 changes: 3 additions & 4 deletions airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -18,13 +18,13 @@
"""This module contains an operator to move data from MySQL to Hive."""
from __future__ import annotations

import csv
from collections import OrderedDict
from contextlib import closing
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

import MySQLdb
import unicodecsv as csv

from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
@@ -84,7 +84,7 @@ def __init__(
recreate: bool = False,
partition: dict | None = None,
delimiter: str = chr(1),
quoting: str | None = None,
quoting: int | None = None,
quotechar: str = '"',
escapechar: str | None = None,
mysql_conn_id: str = "mysql_default",
@@ -133,7 +133,7 @@ def execute(self, context: Context):
hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id, auth=self.hive_auth)
mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
self.log.info("Dumping MySQL query results to local file")
with NamedTemporaryFile("wb") as f:
with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
with closing(mysql.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(self.sql)
@@ -143,7 +143,6 @@
quoting=self.quoting,
quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
escapechar=self.escapechar,
encoding="utf-8",
)
field_dict = OrderedDict()
if cursor.description is not None:
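In the hunk above, the quoting parameter's annotation also changes from str | None to int | None; this matches the standard library, where the quoting constants are plain integers. A quick check (illustrative only):

import csv

# The stdlib quoting "strategies" are module-level int constants.
print(csv.QUOTE_MINIMAL, csv.QUOTE_ALL, csv.QUOTE_NONNUMERIC, csv.QUOTE_NONE)
# typically prints: 0 1 2 3
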
7 changes: 3 additions & 4 deletions airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -18,12 +18,11 @@
"""This module contains an operator to move data from Vertica to Hive."""
from __future__ import annotations

import csv
from collections import OrderedDict
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Sequence

import unicodecsv as csv

from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
from airflow.providers.vertica.hooks.vertica import VerticaHook
@@ -118,8 +117,8 @@ def execute(self, context: Context):
conn = vertica.get_conn()
cursor = conn.cursor()
cursor.execute(self.sql)
with NamedTemporaryFile("w") as f:
csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
csv_writer = csv.writer(f, delimiter=self.delimiter)
field_dict = OrderedDict()
for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
16 changes: 7 additions & 9 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -19,14 +19,14 @@
from __future__ import annotations

import abc
import csv
import json
import os
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

import pyarrow as pa
import pyarrow.parquet as pq
import unicodecsv as csv

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -286,12 +286,10 @@ def _write_local_data_files(self, cursor):
row = self.convert_types(schema, col_type_dict, row)
row_dict = dict(zip(schema, row))

tmp_file_handle.write(
json.dumps(row_dict, sort_keys=True, ensure_ascii=False).encode("utf-8")
)
json.dump(row_dict, tmp_file_handle, sort_keys=True, ensure_ascii=False)

# Append newline to make dumps BigQuery compatible.
tmp_file_handle.write(b"\n")
tmp_file_handle.write("\n")

# Stop if the file exceeds the file size limit.
fppos = tmp_file_handle.tell()
@@ -323,7 +321,7 @@ def _write_local_data_files(self, cursor):

def _get_file_to_upload(self, file_mime_type, file_no):
"""Returns a dictionary that represents the file to upload."""
tmp_file_handle = NamedTemporaryFile(delete=True)
tmp_file_handle = NamedTemporaryFile(mode="w", encoding="utf-8", delete=True)
return (
{
"file_name": self.filename.format(file_no),
@@ -347,7 +345,7 @@ def _configure_csv_file(self, file_handle, schema):
"""Configure a csv writer with the file_handle and write schema
as headers for the new file.
"""
csv_writer = csv.writer(file_handle, encoding="utf-8", delimiter=self.field_delimiter)
csv_writer = csv.writer(file_handle, delimiter=self.field_delimiter)
csv_writer.writerow(schema)
return csv_writer

@@ -436,8 +434,8 @@ def _write_local_schema_file(self, cursor):
self.log.info("Using schema for %s", self.schema_filename)
self.log.debug("Current schema: %s", schema)

tmp_schema_file_handle = NamedTemporaryFile(delete=True)
tmp_schema_file_handle.write(schema.encode("utf-8"))
tmp_schema_file_handle = NamedTemporaryFile(mode="w", encoding="utf-8", delete=True)
tmp_schema_file_handle.write(schema)
schema_file_to_upload = {
"file_name": self.schema_filename,
"file_handle": tmp_schema_file_handle,
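For the JSON branch of sql_to_gcs above, the change amounts to dumping each row straight into a text-mode temporary file instead of encoding to bytes first. A minimal sketch of the resulting pattern (the rows list is invented for illustration; the operator actually iterates a database cursor):

import json
from tempfile import NamedTemporaryFile

rows = [{"id": 1, "name": "naïve"}, {"id": 2, "name": "plain"}]  # stand-in for cursor rows

with NamedTemporaryFile(mode="w", encoding="utf-8", delete=True) as tmp:
    for row in rows:
        json.dump(row, tmp, sort_keys=True, ensure_ascii=False)
        # One JSON object per line keeps the dump BigQuery compatible.
        tmp.write("\n")
    tmp.flush()
    # tmp.name can now be handed to the GCS upload before the file is deleted on exit.
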
@@ -17,12 +17,11 @@
# under the License.
from __future__ import annotations

import csv
import os
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Sequence

import unicodecsv as csv

from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
from airflow.providers.oracle.hooks.oracle import OracleHook
@@ -46,7 +45,7 @@ class OracleToAzureDataLakeOperator(BaseOperator):
:param delimiter: field delimiter in the file.
:param encoding: encoding type for the file.
:param quotechar: Character to use in quoting.
:param quoting: Quoting strategy. See unicodecsv quoting for more information.
:param quoting: Quoting strategy. See csv library for more information.
"""

template_fields: Sequence[str] = ("filename", "sql", "sql_params")
@@ -65,7 +64,7 @@ def __init__(
delimiter: str = ",",
encoding: str = "utf-8",
quotechar: str = '"',
quoting: str = csv.QUOTE_MINIMAL,
quoting: int = csv.QUOTE_MINIMAL,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -83,11 +82,10 @@ def __init__(
self.quoting = quoting

def _write_temp_file(self, cursor: Any, path_to_save: str | bytes | int) -> None:
with open(path_to_save, "wb") as csvfile:
with open(path_to_save, "w", encoding=self.encoding) as csvfile:
csv_writer = csv.writer(
csvfile,
delimiter=self.delimiter,
encoding=self.encoding,
quotechar=self.quotechar,
quoting=self.quoting,
)
6 changes: 3 additions & 3 deletions airflow/providers/mysql/transfers/vertica_to_mysql.py
@@ -17,12 +17,12 @@
# under the License.
from __future__ import annotations

import csv
from contextlib import closing
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

import MySQLdb
import unicodecsv as csv

from airflow.models import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
@@ -125,11 +125,11 @@ def _bulk_load_transfer(self, mysql, vertica):
with closing(conn.cursor()) as cursor:
cursor.execute(self.sql)
selected_columns = [d.name for d in cursor.description]
with NamedTemporaryFile("w") as tmpfile:
with NamedTemporaryFile("w", encoding="utf-8") as tmpfile:
self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
self.log.info(self.sql)

csv_writer = csv.writer(tmpfile, delimiter="\t", encoding="utf-8")
csv_writer = csv.writer(tmpfile, delimiter="\t")
for row in cursor.iterate():
csv_writer.writerow(row)
count += 1
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -1554,6 +1554,7 @@ undead
Undeads
ungenerated
unicode
unicodecsv
unindent
unittest
unittests
1 change: 0 additions & 1 deletion setup.cfg
@@ -146,7 +146,6 @@ install_requires =
tenacity>=6.2.0,!=8.2.0
termcolor>=1.1.0
typing-extensions>=4.0.0
unicodecsv>=0.14.1
werkzeug>=2.0

[options.packages.find]
5 changes: 2 additions & 3 deletions tests/providers/apache/hive/transfers/test_mssql_to_hive.py
@@ -71,9 +71,8 @@ def test_execute(self, mock_hive_hook, mock_mssql_hook, mock_tmp_file, mock_csv)
mssql_to_hive_transfer.execute(context={})

mock_mssql_hook_cursor.return_value.execute.assert_called_once_with(mssql_to_hive_transfer.sql)
mock_csv.writer.assert_called_once_with(
mock_tmp_file, delimiter=mssql_to_hive_transfer.delimiter, encoding="utf-8"
)
mock_tmp_file.assert_called_with(mode="w", encoding="utf-8")
mock_csv.writer.assert_called_once_with(mock_tmp_file, delimiter=mssql_to_hive_transfer.delimiter)
field_dict = OrderedDict()
for field in mock_mssql_hook_cursor.return_value.description:
field_dict[field[0]] = mssql_to_hive_transfer.type_map(field[1])
5 changes: 1 addition & 4 deletions tests/providers/apache/hive/transfers/test_mysql_to_hive.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import csv
import textwrap
from collections import OrderedDict
from contextlib import closing
@@ -187,7 +188,6 @@ def baby_names_table(self):
)
@pytest.mark.usefixtures("baby_names_table")
def test_mysql_to_hive(self, spy_on_hive, params, expected, csv):

sql = "SELECT * FROM baby_names LIMIT 1000;"
op = MySqlToHiveOperator(
task_id="test_m2h",
@@ -247,7 +247,6 @@ def test_mysql_to_hive_type_conversion(self, spy_on_hive):
cursor.execute(f"DROP TABLE IF EXISTS {mysql_table}")

def test_mysql_to_hive_verify_csv_special_char(self, spy_on_hive):

mysql_table = "test_mysql_to_hive"
hive_table = "test_mysql_to_hive"

@@ -277,8 +276,6 @@ def test_mysql_to_hive_verify_csv_special_char(self, spy_on_hive):
)
conn.commit()

import unicodecsv as csv

op = MySqlToHiveOperator(
task_id="test_m2h",
hive_cli_conn_id="hive_cli_default",
