In [2]:
%%writefile dags/simple_etl_dag.py

import os

from functools import wraps

import pandas as pd
import numpy as np

from airflow.models import DAG
from airflow.utils.dates import days_ago
from airflow.operators.python import PythonOperator

from dotenv import dotenv_values
from sqlalchemy import create_engine, inspect

# Default arguments handed to the DAG below.
args = dict(owner="Airflow", start_date=days_ago(1))

# The DAG object that the tasks at the bottom of this module attach to.
# No schedule is configured (schedule_interval=None).
dag = DAG(
    dag_id="simple_etl_dag",
    default_args=args,
    schedule_interval=None,
)

def logger(func):
    """Decorate ``func`` so every call prints a start and a finish message.

    Both messages include the function's name and a timezone-aware UTC
    timestamp taken at the moment the message is printed.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged.  If ``func`` raises, the
        exception propagates and the finish message is not printed.
    """
    from datetime import datetime, timezone

    @wraps(func)
    def wrapper(*args, **kwargs):
        called_at = datetime.now(timezone.utc)
        print(f">>> Running {func.__name__!r} function. Logged at {called_at}")
        result = func(*args, **kwargs)
        # Take a fresh timestamp: reusing ``called_at`` would report the
        # start time as the completion time.
        finished_at = datetime.now(timezone.utc)
        print(f">>> Function: {func.__name__!r} executed. Logged at {finished_at}")
        return result

    return wrapper

# Settings come from a local ".env" file; when that yields nothing (empty
# result), fall back to the process environment.
CONFIG = dotenv_values(".env") or os.environ


@logger
def connect_db():
    """Create a SQLAlchemy engine for the configured PostgreSQL server.

    Connection settings are read from ``CONFIG`` under the keys
    ``POSTGRES_USER``, ``POSTGRES_PASSWORD``, ``POSTGRES_HOST`` and
    ``POSTGRES_PORT``.  No database name is included in the URI.

    Returns:
        The engine, after one connection has been opened successfully and
        released again.  The caller is responsible for calling ``dispose()``.

    Raises:
        KeyError: If one of the settings above is missing from ``CONFIG``.
        Any error raised by SQLAlchemy or the driver while connecting.
    """
    print("Connecting to DB")
    # NOTE(review): user and password are interpolated verbatim; values
    # containing URL-special characters (e.g. "@" or "/") would presumably
    # need escaping -- confirm against the credentials actually in use.
    connection_uri = "postgresql+psycopg2://{}:{}@{}:{}".format(
        CONFIG["POSTGRES_USER"],
        CONFIG["POSTGRES_PASSWORD"],
        CONFIG["POSTGRES_HOST"],
        CONFIG["POSTGRES_PORT"],
    )

    engine = create_engine(connection_uri, pool_pre_ping=True)
    # Open one connection to fail fast when the server is unreachable, and
    # close it right away so it goes back to the pool instead of leaking.
    with engine.connect():
        pass
    return engine


@logger
def extract(NationalNames_path):
    """Read the CSV file at ``NationalNames_path`` into a DataFrame.

    Prints the current working directory and the path being read, which
    helps when a relative path does not resolve as expected.

    Args:
        NationalNames_path: Path (or anything ``pandas.read_csv`` accepts)
            of the CSV file to load.

    Returns:
        The DataFrame produced by ``pandas.read_csv``.
    """
    print(f" workdir: {os.getcwd()}")
    print(f"Reading dataset from {NationalNames_path}")
    return pd.read_csv(NationalNames_path)

@logger
def sample_data(df, size=30):
    """Keep only the rows belonging to a random subset of names.

    Args:
        df: DataFrame with a ``Name`` column.
        size: Maximum number of distinct names to keep (default 30).  When
            ``df`` holds fewer distinct names, all of them are kept.

    Returns:
        The rows of ``df`` whose ``Name`` is one of the randomly chosen names.
    """
    unique_names = df["Name"].unique()
    # Sampling without replacement raises ValueError when asked for more
    # items than exist, so cap the request at the number of names available.
    n_picked = min(size, len(unique_names))
    random_names = np.random.choice(unique_names, size=n_picked, replace=False)
    return df.loc[df["Name"].isin(random_names), :]


@logger
def transform(df):
    """Build per-name yearly trends and per-name totals from raw counts.

    Args:
        df: DataFrame with ``Name``, ``Year``, ``Gender`` and ``Count``
            columns.  ``Gender`` is expected to take the values ``F`` and
            ``M``.

    Returns:
        A tuple ``(trends, grouped)``:

        * ``trends`` has one row per (Name, Year), sorted by Year then Name,
          with the summed ``F`` and ``M`` counts, their row-wise ``Total``,
          and ``F_YoY_pct`` / ``M_YoY_pct`` / ``Total_YoY_pct`` holding the
          percentage change from that name's previous row.
        * ``grouped`` is indexed by Name and holds the summed ``F``, ``M``
          and ``Total`` plus ``FM_ratio`` = (F - M) / max(F, M).
    """
    print("Transforming data")
    df = pd.pivot_table(
        df,
        values="Count",
        index=["Name", "Year"],
        columns=["Gender"],
        aggfunc="sum",
    )

    # A sample may contain a single gender only; make sure both columns
    # exist so the selections below cannot raise KeyError.
    for gender in ("F", "M"):
        if gender not in df.columns:
            df[gender] = np.nan

    df.reset_index(inplace=True)
    df.index.name = None
    df.columns.name = None

    df = df.sort_values(["Year", "Name"], ascending=True)

    df.loc[:, "Total"] = df[["F", "M"]].sum(axis=1)

    # Rows are already in year order, so one grouped pct_change yields each
    # name's change versus its previous row without re-scanning the whole
    # frame once per name.
    value_columns = ["F", "M", "Total"]
    yearly_change = df.groupby("Name")[value_columns].pct_change()
    for column in value_columns:
        df[f"{column}_YoY_pct"] = yearly_change[column]

    df_grouped = df.groupby(["Name"])[value_columns].sum()
    df_grouped["FM_ratio"] = (
        (df_grouped["F"] - df_grouped["M"]) / df_grouped[["F", "M"]].max(axis=1)
    )

    return df, df_grouped

@logger
def check_table_exists(table_name, engine):
    """Print whether ``table_name`` is among the tables visible via ``engine``.

    Args:
        table_name: Name of the table to look for.
        engine: SQLAlchemy engine to inspect.

    Returns:
        None; the outcome is only printed.
    """
    known_tables = inspect(engine).get_table_names()
    if table_name not in known_tables:
        print(f"{table_name} does not exist in the DB!")
        return
    print(f"{table_name!r} exists in the DB!")

@logger
def load_to_db(df, table_name, engine):
    """Write ``df`` to the table ``table_name``, replacing it if it exists.

    Args:
        df: DataFrame to store.
        table_name: Name of the destination table.
        engine: SQLAlchemy engine used for the write.
    """
    message = f"Loading dataframe to DB on table: {table_name}"
    print(message)
    df.to_sql(name=table_name, con=engine, if_exists="replace")

@logger
def tables_exists():
    """Report whether the two tables written by the ETL exist in the DB.

    Opens its own engine via ``connect_db`` and always disposes of it, even
    when one of the checks raises.
    """
    db_engine = connect_db()
    try:
        print("Checking if tables exists")
        for table_name in ("m_national_names_detail", "m_national_names"):
            check_table_exists(table_name, db_engine)
    finally:
        # Release pooled connections whether or not the checks succeeded.
        db_engine.dispose()

@logger
def etl():
    """Run the whole pipeline: extract, sample, transform and load.

    Reads ``/opt/airflow/dags/NationalNames.csv``, keeps a random sample of
    names, and writes the per-year trends to ``m_national_names_detail`` and
    the per-name totals to ``m_national_names``.

    The engine is created first so an unreachable database fails the task
    before any data is read, and it is always disposed of, even when a later
    step raises.
    """
    db_engine = connect_db()
    try:
        raw_df = extract("/opt/airflow/dags/NationalNames.csv")
        sampled_df = sample_data(raw_df)
        trends_df, grouped_df = transform(sampled_df)

        load_to_db(trends_df, "m_national_names_detail", db_engine)
        load_to_db(grouped_df, "m_national_names", db_engine)
    finally:
        # Release pooled connections whether or not the pipeline succeeded.
        db_engine.dispose()
    
with dag:
    # Task that runs the full extract / transform / load pipeline.
    run_etl_task = PythonOperator(
        python_callable=etl,
        task_id="run_etl_task",
    )
    # Task that reports whether the target tables exist.
    run_tables_exists_task = PythonOperator(
        python_callable=tables_exists,
        task_id="run_tables_exists_task",
    )

    # The table check is downstream of the ETL task.
    run_tables_exists_task << run_etl_task

Overwriting dags/simple_etl_dag.py
