55 changes: 29 additions & 26 deletions awswrangler/pandas.py
@@ -1,3 +1,4 @@
+from typing import Dict, List, Tuple, Optional, Any
 from io import BytesIO, StringIO
 import multiprocessing as mp
 import logging
@@ -854,20 +855,20 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression, fs, ca

     def to_redshift(
             self,
-            dataframe,
-            path,
-            connection,
-            schema,
-            table,
-            iam_role,
-            diststyle="AUTO",
-            distkey=None,
-            sortstyle="COMPOUND",
-            sortkey=None,
-            preserve_index=False,
-            mode="append",
-            cast_columns=None,
-    ):
+            dataframe: pd.DataFrame,
+            path: str,
+            connection: Any,
+            schema: str,
+            table: str,
+            iam_role: str,
+            diststyle: str = "AUTO",
+            distkey: Optional[str] = None,
+            sortstyle: str = "COMPOUND",
+            sortkey: Optional[str] = None,
+            preserve_index: bool = False,
+            mode: str = "append",
+            cast_columns: Optional[Dict[str, str]] = None,
+    ) -> None:
         """
         Load Pandas Dataframe as a Table on Amazon Redshift

@@ -888,28 +889,30 @@ def to_redshift(
         """
         if cast_columns is None:
             cast_columns = {}
-            cast_columns_parquet = {}
+            cast_columns_parquet: Dict = {}
         else:
-            cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena, schema=cast_columns)
+            cast_columns_tuples: List[Tuple[str, str]] = [(k, v) for k, v in cast_columns.items()]
+            cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena,
+                                                             schema=cast_columns_tuples)
         if path[-1] != "/":
             path += "/"
         self._session.s3.delete_objects(path=path)
-        num_rows = len(dataframe.index)
+        num_rows: int = len(dataframe.index)
         logger.debug(f"Number of rows: {num_rows}")
         if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions = 1
+            num_partitions: int = 1
         else:
-            num_slices = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
             logger.debug(f"Number of slices on Redshift: {num_slices}")
             num_partitions = num_slices
         logger.debug(f"Number of partitions calculated: {num_partitions}")
-        objects_paths = self.to_parquet(dataframe=dataframe,
-                                        path=path,
-                                        preserve_index=preserve_index,
-                                        mode="append",
-                                        procs_cpu_bound=num_partitions,
-                                        cast_columns=cast_columns_parquet)
-        manifest_path = f"{path}manifest.json"
+        objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
+                                                   path=path,
+                                                   preserve_index=preserve_index,
+                                                   mode="append",
+                                                   procs_cpu_bound=num_partitions,
+                                                   cast_columns=cast_columns_parquet)
+        manifest_path: str = f"{path}manifest.json"
         self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
         self._session.redshift.load_table(
             dataframe=dataframe,
47 changes: 44 additions & 3 deletions testing/test_awswrangler/test_redshift.py
@@ -1,9 +1,10 @@
 import json
 import logging
+from datetime import date, datetime

 import pytest
 import boto3
-import pandas
+import pandas as pd
 from pyspark.sql import SparkSession
 import pg8000

@@ -80,7 +81,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     dates = ["date"]
     if sample_name == "nano":
         dates = ["date", "time"]
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
     dataframe["date"] = dataframe["date"].dt.date
     con = Redshift.generate_connection(
         database="test",
@@ -113,6 +114,46 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     assert len(list(dataframe.columns)) + 1 == len(list(rows[0]))


+def test_to_redshift_pandas_cast(session, bucket, redshift_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3],
+        "name": ["name1", "name2", "name3"],
+        "foo": [None, None, None],
+        "boo": [date(2020, 1, 1), None, None],
+        "bar": [datetime(2021, 1, 1), None, None]})
+    schema = {
+        "id": "BIGINT",
+        "name": "VARCHAR",
+        "foo": "REAL",
+        "boo": "DATE",
+        "bar": "TIMESTAMP"}
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/redshift-load/"
+    session.pandas.to_redshift(dataframe=df,
+                               path=path,
+                               schema="public",
+                               table="test",
+                               connection=con,
+                               iam_role=redshift_parameters.get("RedshiftRole"),
+                               mode="overwrite",
+                               preserve_index=False,
+                               cast_columns=schema)
+    cursor = con.cursor()
+    cursor.execute("SELECT * from public.test")
+    rows = cursor.fetchall()
+    cursor.close()
+    con.close()
+    print(rows)
+    assert len(df.index) == len(rows)
+    assert len(list(df.columns)) == len(list(rows[0]))
+
+
 @pytest.mark.parametrize(
     "sample_name,mode,factor,diststyle,distkey,exc,sortstyle,sortkey",
     [
@@ -125,7 +166,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
 )
 def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters, sample_name, mode, factor, diststyle,
                                        distkey, sortstyle, sortkey, exc):
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv")
     con = Redshift.generate_connection(
         database="test",
         host=redshift_parameters.get("RedshiftAddress"),
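
As a closing note, here is a minimal usage sketch of the cast_columns path exercised by this PR. The bucket, host, password, port, and IAM role values are placeholders, and the imports assume the pre-1.0 awswrangler layout (awswrangler.Session plus Redshift.generate_connection) that the tests above rely on; this is an illustration of the call pattern, not part of the change itself.

import awswrangler
import pandas as pd
from awswrangler import Redshift  # assumed import, matching how the tests call Redshift.generate_connection

# cast_columns maps column names to Redshift types, so all-null or otherwise
# ambiguous columns still receive an explicit type when the staged Parquet
# files are loaded into the target table.
df = pd.DataFrame({"id": [1, 2, 3], "foo": [None, None, None]})

session = awswrangler.Session()
con = Redshift.generate_connection(database="test",
                                   host="<redshift-host>",          # placeholder
                                   port=5439,                        # default Redshift port
                                   user="test",
                                   password="<password>")           # placeholder
session.pandas.to_redshift(dataframe=df,
                           path="s3://<bucket>/redshift-load/",     # placeholder staging path
                           schema="public",
                           table="test",
                           connection=con,
                           iam_role="<redshift-iam-role-arn>",      # placeholder
                           mode="overwrite",
                           preserve_index=False,
                           cast_columns={"id": "BIGINT", "foo": "REAL"})
con.close()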