In [None]:
# DO NOT DELETE THIS CELL
# ADD YOUR PARAMETER DEFAULT VALUES HERE
# `None` is only a placeholder: when ENV_MODE is "dev" it is never parsed. Otherwise parse_params()
# expects it to be overridden (e.g. by papermill) with JSON that has a top-level "data" key.
params = None

In [None]:
import json
import os
from dotenv import load_dotenv
import pandas as pd
import pyodbc

def is_dev_mode(env_mode: str) -> bool:
    """
    Tell whether ``env_mode`` selects development mode.

    Only the exact string "dev" counts as development; every other value
    (including "prod", "DEV" or an empty string) is treated as production.
    """
    dev_mode_name = "dev"
    return env_mode == dev_mode_name

def debug(var):
    """
    Hand ``var`` back for display, but only when env_mode is "dev".

    Nothing is printed. When ENV_MODE is "dev" the value is returned, so a
    ``debug(x)`` call used as the last expression of a notebook cell is rendered
    by Jupyter. In any other mode ``None`` is returned and the cell shows no output.

    Reads the module-level ``ENV_MODE`` set in the configuration cell.

    :param var: any object to display in dev mode
    :returns: ``var`` in dev mode, otherwise ``None``
    """
    if is_dev_mode(ENV_MODE):
        return var
    return None

def get_available_sql_driver():
    """
    Pick an installed ODBC driver for SQL Server.

    Returns the first name reported by ``pyodbc.drivers()`` that ends with
    " for SQL Server", in the order pyodbc lists them.

    :raises ValueError: if no matching driver is installed
    """
    sql_server_suffix = ' for SQL Server'
    candidates = [name for name in pyodbc.drivers() if name.endswith(sql_server_suffix)]
    if not candidates:
        raise ValueError("Cannot connect. No suitable driver found.\nInstall driver from here: https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16\n\n")
    return candidates[0]

def parse_params(params):
    """
    Extract the "data" payload from ``params`` when env_mode is not "dev".

    In "dev" mode ``None`` is returned and ``params`` is ignored; data is then
    read from SQL by ``read_data`` instead.

    :param params: the injected parameter (e.g. by papermill). Either a JSON
        string/bytes or an already-decoded mapping with a top-level "data" key.
    :returns: ``params["data"]`` outside dev mode, otherwise ``None``
    :raises TypeError: if ``params`` is ``None`` (not injected) outside dev mode
    :raises KeyError: if the decoded params have no "data" key
    """
    if is_dev_mode(ENV_MODE):
        return None
    if params is None:
        # previously this surfaced as an opaque json.loads TypeError
        raise TypeError("params must be provided (e.g. injected by papermill) when ENV_MODE is not 'dev'")
    # the parameter may arrive as a JSON string or, if injected as a dict literal, already decoded
    if isinstance(params, (str, bytes, bytearray)):
        params = json.loads(params)
    return params["data"]
    
def get_dataset(name: str, data):
    """
    Look up the dataset called ``name`` in the params payload.

    :param name: key of the dataset inside ``data``
    :param data: the parsed params payload, or ``None`` when nothing was passed (dev mode)
    :returns: ``data[name]``, or ``None`` when ``data`` is ``None``
    :raises KeyError: if ``data`` is given but has no entry ``name``
    """
    if data is None:
        return None
    return data[name]

def read_data(data, *args, **kwargs):
    """
    Materialise a dataset as a DataFrame, from SQL in dev mode or from ``data`` otherwise.

    Outside dev mode ``data`` (the payload passed in via params) is wrapped with
    ``pd.DataFrame`` and the SQL arguments are ignored. In dev mode ``data`` is
    ignored and ``*args`` / ``**kwargs`` are forwarded unchanged to ``pd.read_sql``.
    """
    if not is_dev_mode(ENV_MODE):
        return pd.DataFrame(data)
    return pd.read_sql(*args, **kwargs)

load_dotenv()
ENV_MODE = os.environ["ENV_MODE"]
SQL_CONNECTION_STRING = os.environ["SQL_CONNECTION_STRING"]
driver = get_available_sql_driver()
# the driver name goes into the URL query string with its spaces replaced by "+"
formatted_driver = "+".join(driver.split())
# presumably a SQLAlchemy URL (e.g. mssql+pyodbc://...) — it is passed as `con` to pd.read_sql.
# If it already carries query parameters, append with "&". A second "?" would corrupt the URL.
query_separator = "&" if "?" in SQL_CONNECTION_STRING else "?"
SQL_CONNECTION_STRING_WITH_DRIVER = f"{SQL_CONNECTION_STRING}{query_separator}driver={formatted_driver}"
# None in dev mode; otherwise the "data" payload from the injected params
data = parse_params(params)

In [None]:
# GET DATA
# in dev mode the query below runs against SQL; otherwise the "sales_transactions" entry of the params payload is used
# NOTE(review): the 2020 cutoff is hard-coded, and YEAR() around the column may prevent index use on PostingDate
sales_transactions_payload = get_dataset("sales_transactions", data)
sales_transactions_query = "SELECT * FROM etl.SalesTransactions st WHERE YEAR(st.PostingDate) > 2020"
sales_transactions = read_data(
    data=sales_transactions_payload,
    sql=sales_transactions_query,
    con=SQL_CONNECTION_STRING_WITH_DRIVER,
)

In [None]:
# shows the raw frame only in dev mode: debug() returns it there (rendered as the cell's last expression) and None otherwise
debug(sales_transactions)

In [None]:
# parse PostingDate, derive calendar Year/Month from it, then total the measures per company and month
sales_transactions = sales_transactions.assign(
    PostingDate=lambda frame: pd.to_datetime(frame["PostingDate"]),
    # later assign() arguments see the already-converted PostingDate
    Year=lambda frame: frame["PostingDate"].dt.year,
    Month=lambda frame: frame["PostingDate"].dt.month,
)
measure_columns = ["Quantity", "NetSalesEUR", "GrossProfitEUR"]
grouped_sales = (
    sales_transactions
    .groupby(by=["CompanyCode", "Year", "Month"])[measure_columns]
    .sum()
)

In [None]:
# shows the per-company/year/month totals only in dev mode (debug() returns None otherwise)
debug(grouped_sales)

In [None]:
# DO NOT DELETE THIS CELL
# ASSIGN YOUR RETURN VALUE HERE
# reset_index() turns the CompanyCode/Year/Month group keys back into columns, so that each
# JSON record (one per group) carries its keys next to the summed measures
return_value = grouped_sales.reset_index().to_json(orient="records")

In [None]:
# DO NOT EDIT OR DELETE THIS CELL
return_value