# Handle Duplicate Data

In [34]:
# Standard library
import typing as T
import json
import uuid
import textwrap
import random
import dataclasses
from pathlib import Path
from datetime import datetime

# Third Party Library
import boto3
from boto_session_manager import BotoSesManager
from s3pathlib import S3Path, context
import redshift_connector
import sqlalchemy as sa
import sqlalchemy.orm as orm
from rich import print as rprint

## Connect to Redshift Serverless Using IAM

In [2]:
@dataclasses.dataclass
class Config:
    """
    Settings this notebook needs to reach Redshift Serverless.
    """
    # name of the local AWS profile; passed to ``BotoSesManager`` and to
    # ``redshift_connector.connect(profile=...)``
    aws_profile: str
    # name of the Redshift Serverless workgroup to connect to
    workgroup: str

# the single configuration instance used by the cells that follow
config = Config(
    aws_profile="awshsh_app_dev_us_east_1",
    workgroup="sanhe-dev-workgroup",
)

In [3]:
# boto session manager bound to the configured AWS profile
bsm = BotoSesManager(profile_name=config.aws_profile)
# hand the same boto session to s3pathlib's global context
context.attach_boto_session(bsm.boto_ses)

In [60]:
def get_host_port(
    redshift_serverless_client,
    workgroup: str,
) -> T.Tuple[str, int]:
    """
    Look up the network endpoint of a Redshift Serverless workgroup.

    :param redshift_serverless_client: boto3 ``redshift-serverless`` client.
    :param workgroup: name of the workgroup to describe.

    :return: ``(host, port)`` read from the workgroup's ``endpoint`` entry.
    """
    endpoint = redshift_serverless_client.get_workgroup(
        workgroupName=workgroup,
    )["workgroup"]["endpoint"]
    return endpoint["address"], endpoint["port"]


def get_username_password(
    redshift_serverless_client,
    workgroup: str,
) -> T.Tuple[str, str]:
    """
    Request temporary database credentials for a workgroup.

    The credentials are requested for the ``dev`` database and are valid
    for 900 seconds.

    :param redshift_serverless_client: boto3 ``redshift-serverless`` client.
    :param workgroup: name of the workgroup to get credentials for.

    :return: ``(username, password)`` taken from the ``dbUser`` and
        ``dbPassword`` fields of the response.
    """
    credentials = redshift_serverless_client.get_credentials(
        workgroupName=workgroup,
        dbName="dev",
        durationSeconds=900,
    )
    return credentials["dbUser"], credentials["dbPassword"]


def create_redshift_connector_connection_using_iam(
    aws_profile: str,
    workgroup: str,
) -> redshift_connector.Connection:
    """
    Open a ``redshift_connector`` connection to the ``dev`` database of a
    Redshift Serverless workgroup, authenticating with IAM.

    :param aws_profile: local AWS profile passed as ``profile``.
    :param workgroup: workgroup passed as ``serverless_work_group``.

    :return: the connection returned by ``redshift_connector.connect``.
    """
    connect_kwargs = dict(
        iam=True,
        database="dev",
        profile=aws_profile,
        is_serverless=True,
        serverless_work_group=workgroup,
    )
    return redshift_connector.connect(**connect_kwargs)


def test_redshift_connector_connection(conn):
    """
    Smoke-test a DB API connection by selecting a random integer.

    :param conn: connection object exposing ``cursor()``; the cursor's
        ``execute`` must return an object with ``fetchone()``.
    """
    print("Test connection by running a query")
    db_cursor = conn.cursor()
    number = random.randint(1, 100)
    first_row = db_cursor.execute(f"SELECT {number};").fetchone()
    print(f"result = {first_row[0]}")
    print("Success!")


def create_sqlalchemy_engine_using_iam(redshift_serverless_client, workgroup: str) -> sa.engine.Engine:
    """
    Build a SQLAlchemy engine for the ``dev`` database of a Redshift
    Serverless workgroup using temporary IAM credentials.

    :param redshift_serverless_client: boto3 ``redshift-serverless`` client,
        used to resolve the endpoint and to fetch the credentials.
    :param workgroup: name of the workgroup to connect to.

    :return: engine using the ``redshift+psycopg2`` driver.
    """
    host, port = get_host_port(redshift_serverless_client, workgroup)
    username, password = get_username_password(redshift_serverless_client, workgroup)
    # Build the URL from its parts instead of formatting a string: the
    # temporary username contains ":" and the password may contain other
    # URL-reserved characters; URL.create takes the raw values, so no
    # manual percent-encoding is needed.
    url = sa.engine.URL.create(
        drivername="redshift+psycopg2",
        username=username,
        password=password,
        host=host,
        port=port,
        database="dev",
    )
    return sa.create_engine(url)


def test_sqlalchemy_engine(engine: sa.engine.Engine):
    """
    Smoke-test a SQLAlchemy engine by selecting a random integer.

    :param engine: engine to open a connection from.
    """
    print("Test connection by running a query")
    with engine.connect() as conn:
        sql = f"SELECT {random.randint(1, 100)};"
        # wrap the statement in sa.text(); passing a plain string to
        # Connection.execute is deprecated in SQLAlchemy 1.4 and rejected in 2.0
        row = conn.execute(sa.text(sql)).first()
        print(f"result = {row[0]}")
        print("Success!")

# open a redshift_connector connection (IAM auth) and smoke-test it
conn = create_redshift_connector_connection_using_iam(config.aws_profile, config.workgroup)
test_redshift_connector_connection(conn)

# build a SQLAlchemy engine from temporary credentials and smoke-test it
engine = create_sqlalchemy_engine_using_iam(bsm.redshiftserverless_client, config.workgroup)
test_sqlalchemy_engine(engine)

Test connection by running a query
result = 19
Success!
Test connection by running a query
result = 32
Success!


## Define Data Model and Create Table

In [11]:
# declarative base class that the ORM model in this cell inherits from
Base = orm.declarative_base()


class Transaction(Base):
    """
    ORM model mapped to the ``transactions`` table.
    """
    __tablename__ = "transactions"

    # primary key; ``new()`` fills it with a UUID4 string
    id: str = sa.Column(sa.String, primary_key=True)
    # timestamps are stored as strings, not as a datetime type
    create_at: str = sa.Column(sa.String)
    update_at: str = sa.Column(sa.String)
    # optional free-text note
    note: str = sa.Column(sa.String, nullable=True)

    @classmethod
    def new(cls, note: T.Optional[str] = None):
        """
        Create a new, unsaved transaction with a random id.

        :param note: optional note to attach to the record.

        :return: instance whose ``create_at`` and ``update_at`` hold the
            same ISO-8601 UTC timestamp.
        """
        # Take the timestamp once so both columns are identical for a brand
        # new row. The value is a naive UTC datetime: its ISO form is at most
        # 26 characters, which fits the VARCHAR(26) columns of the table DDL.
        now = datetime.utcnow().isoformat()
        return cls(
            id=str(uuid.uuid4()),
            create_at=now,
            update_at=now,
            note=note,
        )

# the Table object SQLAlchemy generated for the Transaction model
t_transaction = Transaction.__table__

def create_table():
    """
    Drop the ``transactions`` table if it exists and create it again.

    Both statements run in one transaction on the module-level ``engine``,
    so a failed ``CREATE`` also rolls back the ``DROP``.
    """
    table = Transaction.__tablename__
    drop_sql = f"DROP TABLE IF EXISTS {table};"
    create_sql = textwrap.dedent(
        f"""
        CREATE TABLE {table}(
            id VARCHAR(36) DISTKEY NOT NULL,
            create_at VARCHAR(26) NOT NULL,
            update_at VARCHAR(26) NOT NULL,
            note VARCHAR
        )
        DISTSTYLE key
        COMPOUND SORTKEY(create_at);
        """
    )
    # engine.begin() commits on success and rolls back on error, so the DDL
    # does not rely on SQLAlchemy's deprecated implicit autocommit; sa.text()
    # replaces the deprecated raw-string form of Connection.execute
    with engine.begin() as connection:
        connection.execute(sa.text(drop_sql))
        connection.execute(sa.text(create_sql))

# (re)create the table so the cells that follow start from an empty one
create_table()

In [16]:
# S3 folder where the JSON files loaded by the COPY commands are staged;
# the bucket name is derived from the AWS account id and region
s3dir_tmp = S3Path(
    f"s3://{bsm.aws_account_id}-{bsm.aws_region}-data"
    "/projects/learn_redshift/tmp/"
).to_dir()
# bare expression: shows the path as the notebook cell output
s3dir_tmp

S3Path('s3://807388292768-us-east-1-data/projects/learn_redshift/tmp/')

In [64]:
from sqlalchemy_redshift.dialect import RedshiftDialect

def delete_all_data():
    """
    Remove every row from the ``transactions`` table.

    Uses the module-level SQLAlchemy ``engine``.
    """
    sql = f"DELETE FROM {Transaction.__tablename__};"
    # engine.begin() commits on success and rolls back on error, so the
    # DELETE does not depend on SQLAlchemy's deprecated implicit autocommit.
    # The local name also avoids shadowing the module-level ``conn``.
    with engine.begin() as connection:
        connection.execute(sa.text(sql))


def insert_initial_data():
    """
    Seed the ``transactions`` table with five rows (``id-1`` .. ``id-5``).

    The rows are written to S3 as a JSON-lines file under ``s3dir_tmp`` and
    loaded with a ``COPY`` command through the module-level
    ``redshift_connector`` connection ``conn``.

    :raises Exception: whatever the driver raises when the ``COPY`` fails;
        the transaction is rolled back before the error is re-raised.
    """
    # One dict per row; each is serialized as one JSON line, e.g.
    # {"id": "id-1", "create_at": "001", "update_at": "001", "note": "note 1"}
    data = [
        {
            "id": f"id-{i}",
            "create_at": str(i).zfill(3),
            "update_at": str(i).zfill(3),
            "note": f"note {i}",
        }
        for i in range(1, 1 + 5)
    ]
    s3path = s3dir_tmp.joinpath(f"{uuid.uuid4().hex}.json")
    content = "\n".join(json.dumps(row) for row in data)
    s3path.write_text(content, content_type="application/json")
    print(s3path.console_url)

    sql = textwrap.dedent(f"""
    COPY {Transaction.__tablename__}
    FROM '{s3path.uri}'
    iam_role 'arn:aws:iam::{bsm.aws_account_id}:role/service-role/AmazonRedshift-CommandsAccessRole-20230630T231500'
    JSON 'auto';
    """)
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        conn.commit()
    except Exception:
        # do not leave the shared connection inside a failed transaction
        conn.rollback()
        raise
    finally:
        cursor.close()

# start from an empty table, then load the five seed rows
delete_all_data()
insert_initial_data()

https://console.aws.amazon.com/s3/object/807388292768-us-east-1-data?prefix=projects/learn_redshift/tmp/11e38ff5162841d5add5d6d2ea5f373b.json


In [63]:
def upsert_data():
    """
    Upsert five rows (``id-4`` .. ``id-8``) into the ``transactions`` table
    without creating duplicates.

    The rows are staged to S3 as a JSON-lines file, copied into a temporary
    table, and merged into the target table in three steps: delete the target
    rows whose ``id`` also appears in the staging table, then insert every
    staged row. All statements run on the module-level ``redshift_connector``
    connection ``conn`` and are committed together.

    :raises Exception: whatever the driver raises when a statement fails;
        the transaction is rolled back before the error is re-raised.
    """
    data = [
        {"id": "id-4", "create_at": "004", "update_at": "006", "note": "note 444"},
        {"id": "id-5", "create_at": "005", "update_at": "006", "note": "note 555"},
        {"id": "id-6", "create_at": "006", "update_at": "006", "note": "note 6"},
        {"id": "id-7", "create_at": "007", "update_at": "006", "note": "note 7"},
        {"id": "id-8", "create_at": "008", "update_at": "006", "note": "note 8"},
    ]
    s3path = s3dir_tmp.joinpath(f"{uuid.uuid4().hex}.json")
    content = "\n".join(json.dumps(row) for row in data)
    s3path.write_text(content, content_type="application/json")
    print(s3path.console_url)

    table = Transaction.__tablename__
    tmp_table = f"{table}_tmp"

    # Statements are executed strictly in this order.
    # NOTE(review): the explicit begin/end are kept exactly as they were; the
    # driver may already open a transaction implicitly - confirm before
    # removing them.
    statements = [
        "begin transaction;",
        f"DROP TABLE IF EXISTS {tmp_table};",
        # staging table with the same columns as the target table
        textwrap.dedent(f"""
        CREATE TEMPORARY TABLE {tmp_table}
        (
            id VARCHAR(36) DISTKEY NOT NULL,
            create_at VARCHAR(26) NOT NULL,
            update_at VARCHAR(26) NOT NULL,
            note VARCHAR
        )
        DISTSTYLE key
        COMPOUND SORTKEY(create_at);
        """),
        # load the staged file into the staging table
        textwrap.dedent(f"""
        COPY {tmp_table}
        FROM '{s3path.uri}'
        iam_role 'arn:aws:iam::{bsm.aws_account_id}:role/service-role/AmazonRedshift-CommandsAccessRole-20230630T231500'
        JSON 'auto';
        """),
        # remove target rows that are about to be replaced
        textwrap.dedent(f"""
        DELETE FROM {table}
        USING {tmp_table}
        WHERE 
            {table}.id = {tmp_table}.id;
        """),
        # insert both the replacement rows and the brand new rows
        textwrap.dedent(f"""
        INSERT INTO {table}
        SELECT * FROM {tmp_table};
        """),
        "end transaction;",
    ]

    # discard any transaction left open on the shared connection
    conn.rollback()
    cursor = conn.cursor()
    try:
        for sql in statements:
            print(sql)
            cursor.execute(sql)
        conn.commit()
    except Exception:
        # a failure part-way through must not leave a half-applied merge or
        # an aborted transaction on the shared connection
        conn.rollback()
        raise
    finally:
        cursor.close()
    # -------------------------------------------------------------

# run the merge; the captured output lists every SQL statement executed
upsert_data()

https://console.aws.amazon.com/s3/object/807388292768-us-east-1-data?prefix=projects/learn_redshift/tmp/2fc7b6fe5f864d90a13714bf6a634ba4.json
begin transaction;
DROP TABLE IF EXISTS transactions_tmp;

CREATE TEMPORARY TABLE transactions_tmp
(
    id VARCHAR(36) DISTKEY NOT NULL,
    create_at VARCHAR(26) NOT NULL,
    update_at VARCHAR(26) NOT NULL,
    note VARCHAR
)
DISTSTYLE key
COMPOUND SORTKEY(create_at);


COPY transactions_tmp
FROM 's3://807388292768-us-east-1-data/projects/learn_redshift/tmp/2fc7b6fe5f864d90a13714bf6a634ba4.json'
iam_role 'arn:aws:iam::807388292768:role/service-role/AmazonRedshift-CommandsAccessRole-20230630T231500'
JSON 'auto';


DELETE FROM transactions
USING transactions_tmp
WHERE 
    transactions.id = transactions_tmp.id;


INSERT INTO transactions
SELECT * FROM transactions_tmp;

end transaction;


In [None]:
# -*- coding: utf-8 -*-

"""
This script shows how to do CRUD in Python using redshift_connector.

redshift_connector is a low-level, DB API 2.0 compatible driver for Amazon Redshift.

For more complicated data manipulation, you can use SQLAlchemy + sqlalchemy-redshift,
or awswrangler.
"""

import json
import uuid
import textwrap
import random
import dataclasses
from pathlib import Path
from datetime import datetime


import boto3
import boto_session_manager
import sqlalchemy as sa
import sqlalchemy.orm as orm


@dataclasses.dataclass
class DBConn:
    """
    Data model for database connection config file. It should be a json file
    like this::

        {
            "host": "redshift.host.com",
            "port":  5439,
            "database": "database",
            "username": "username",
            "password": "password"
        }
    """

    host: str
    port: int
    database: str
    username: str
    password: str

    @classmethod
    def read_config(cls, path: "Path | str | None" = None):
        """
        Load the connection settings from a JSON file.

        :param path: location of the config file. When omitted, the file
            ``config.json`` next to this script is used.

        :return: an instance of ``cls`` built from the JSON object's keys.

        :raises FileNotFoundError: if the config file does not exist.
        :raises TypeError: if the JSON keys do not match the fields.
        """
        if path is None:
            path = Path(__file__).absolute().parent.joinpath("config.json")
        raw = Path(path).read_text(encoding="utf-8")
        # build ``cls`` rather than ``DBConn`` so subclasses get their own type
        return cls(**json.loads(raw))


# declarative base class that the ORM model in this script inherits from
Base = orm.declarative_base()


class Transaction(Base):
    """
    ORM model mapped to the ``transactions`` table.
    """
    __tablename__ = "transactions"

    # primary key; ``new()`` fills it with a UUID4 string
    id: str = sa.Column(sa.String, primary_key=True)
    # timestamps are stored as strings, not as a datetime type
    create_at: str = sa.Column(sa.String)
    update_at: str = sa.Column(sa.String)
    # optional free-text note
    note: str = sa.Column(sa.String, nullable=True)

    @classmethod
    def new(cls, note: str = None):
        """
        Create a new, unsaved transaction with a random id.

        :param note: optional note to attach to the record.

        :return: instance whose ``create_at`` and ``update_at`` hold the
            same ISO-8601 UTC timestamp.
        """
        # Take the timestamp once so both columns are identical for a brand
        # new row. The value is a naive UTC datetime: its ISO form is at most
        # 26 characters, which fits the VARCHAR(26) columns of the table DDL.
        now = datetime.utcnow().isoformat()
        return cls(
            id=str(uuid.uuid4()),
            create_at=now,
            update_at=now,
            note=note,
        )


def create_engine_using_username_password(db_conn: DBConn):
    """
    Build a SQLAlchemy engine from the static username and password stored
    in the connection config.

    :param db_conn: connection settings (host, port, database, credentials).

    :return: engine using the ``redshift+psycopg2`` driver.
    """
    # URL.create takes the raw values, so a username or password containing
    # URL-reserved characters ("@", "/", ":", ...) cannot corrupt the URL
    url = sa.engine.URL.create(
        drivername="redshift+psycopg2",
        username=db_conn.username,
        password=db_conn.password,
        host=db_conn.host,
        port=db_conn.port,
        database=db_conn.database,
    )
    return sa.create_engine(url)

def create_engine_using_iam(
    db_conn: DBConn,
    boto_ses: boto3.session.Session,
    workgroup: str = "sanhe-dev-workgroup",
    duration_seconds: int = 900,
):
    """
    Build a SQLAlchemy engine using temporary Redshift Serverless credentials.

    Only ``host``, ``port`` and ``database`` of ``db_conn`` are used; its
    ``username`` and ``password`` are ignored in favour of the temporary
    credentials.

    :param db_conn: connection settings providing host, port and database.
    :param boto_ses: boto3 session used to create the
        ``redshift-serverless`` client.
    :param workgroup: workgroup to request credentials for.
    :param duration_seconds: lifetime requested for the credentials.

    :return: engine using the ``redshift+psycopg2`` driver.
    """
    redshift_serverless_client = boto_ses.client("redshift-serverless")
    res = redshift_serverless_client.get_credentials(
        # request credentials for the database we actually connect to
        dbName=db_conn.database,
        workgroupName=workgroup,
        durationSeconds=duration_seconds,
    )
    # URL.create takes the raw values: the temporary username contains ":"
    # and the password may contain other URL-reserved characters
    url = sa.engine.URL.create(
        drivername="redshift+psycopg2",
        username=res["dbUser"],
        password=res["dbPassword"],
        host=db_conn.host,
        port=db_conn.port,
        database=db_conn.database,
    )
    return sa.create_engine(url)


def test_connection(conn):
    """
    Smoke-test a DB API connection by selecting a random integer.

    :param conn: connection object exposing ``cursor()``; the cursor's
        ``execute`` must return an object with ``fetchone()``.
    """
    print("Test connection by running a query")
    db_cursor = conn.cursor()
    number = random.randint(1, 100)
    first_row = db_cursor.execute(f"SELECT {number};").fetchone()
    print(first_row[0])
    print("Success!")


# name of the table used by the SQL statements and log messages in this script
TABLE_NAME = "transactions"


def create_table(engine):
    """
    Drop the table named by ``TABLE_NAME`` if it exists and create it again.

    Both statements run in one transaction, so a failed ``CREATE`` also
    rolls back the ``DROP``.

    :param engine: SQLAlchemy engine to run the DDL on.
    """
    drop_sql = f"DROP TABLE IF EXISTS {TABLE_NAME};"
    # use TABLE_NAME here as well so DROP and CREATE always target the same
    # table
    create_sql = textwrap.dedent(
        f"""
        CREATE TABLE {TABLE_NAME}(
            id VARCHAR(36) DISTKEY NOT NULL,
            create_at VARCHAR(26) NOT NULL,
            update_at VARCHAR(26) NOT NULL,
            note VARCHAR
        )
        DISTSTYLE key
        COMPOUND SORTKEY(create_at);
        """
    )
    # engine.begin() commits on success and rolls back on error, so the DDL
    # does not rely on SQLAlchemy's deprecated implicit autocommit; sa.text()
    # replaces the deprecated raw-string form of Connection.execute
    with engine.begin() as conn:
        conn.execute(sa.text(drop_sql))
        conn.execute(sa.text(create_sql))


def insert_data(engine):
    """
    Insert one new ``Transaction`` row carrying a random note.

    :param engine: SQLAlchemy engine the ORM session is bound to.
    """
    print(f"Insert some data into {TABLE_NAME!r} table")
    with orm.Session(engine) as ses:
        note = f"note {random.randint(1, 1000000)}"
        ses.add(Transaction.new(note=note))
        ses.commit()


def select_data(engine):
    """
    Print every row of the table as a ``[id, create_at, update_at, note]``
    list, one line per row.

    :param engine: SQLAlchemy engine the ORM session is bound to.
    """
    print(f"Select data from {TABLE_NAME!r} table")

    # rows come back as ORM objects
    with orm.Session(engine) as ses:
        for row in ses.query(Transaction):
            print([row.id, row.create_at, row.update_at, row.note])

    # alternative that yields plain python dicts instead of ORM objects:
    # with engine.connect() as conn:
    #     for transaction in conn.execute(sa.select(Transaction)).mappings():
    #         print(transaction)


if __name__ == "__main__":
    # connection settings come from config.json next to this script
    db_conn = DBConn.read_config()
    aws_profile = "awshsh_app_dev_us_east_1"
    # build the session from the named profile so the credential request
    # runs under it instead of the default credential chain
    boto_ses = boto3.session.Session(profile_name=aws_profile)

    # pick one of the two authentication methods:
    # engine = create_engine_using_username_password(db_conn)
    engine = create_engine_using_iam(db_conn, boto_ses)
    # uncomment the steps to run:
    # create_table(engine)
    # insert_data(engine)
    # select_data(engine)
