## Imports


In [1]:
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"
from dotenv import load_dotenv
from os import getenv
from pathlib import Path

load_dotenv()

import findspark

findspark.init()

In [2]:
CWD: Path = Path("/app/")
EXAMPLE_INPUT_PATH: Path = CWD / Path("./coding_challenge_files/example_input.txt")

## TAST PART TWO MYSQL DB


### Database Config


First we setup the connector driver path


In [3]:
MYSQL_CONNECTOR_FILENAME: str = "mysql-connector-j-8.2.0.jar"
MYSQL_CONNECTOR_PATH: str = f"/app/mysql_connector/{MYSQL_CONNECTOR_FILENAME}"
TABLE_NAME: str = "INSTRUMENT_PRICE_MODIFIER"

We check if hte connector path is correct


In [4]:
Path(MYSQL_CONNECTOR_PATH).exists()

True

Then we move the connector driver in the proper place in order to be recognizable by pyspark


In [5]:
findspark.add_jars(MYSQL_CONNECTOR_PATH)

We setup the relevant credentials for the database connection


In [6]:
# database connection info
DB_CON_DICT = dict(
    user=getenv("MYSQL_ROOT_USER"),
    password=getenv("MYSQL_ROOT_PASSWORD"),
    host=getenv("HOST"),
    port=int(getenv("MYSQL_DOCKER_PORT")),  # type: ignore
    database=getenv("MYSQL_DATABASE"),
)

DB_CON_DICT

{'user': 'root',
 'password': 'example',
 'host': 'db',
 'port': 3306,
 'database': 'mydb'}

We also setup the pyspark specific format we need for the database connection


In [7]:
# Configure MySQL connection properties
MYSQL_PROPERTIES = {
    "driver": "com.mysql.cj.jdbc.Driver",
    "url": "jdbc:mysql://{host}:{port}/{database}".format(**DB_CON_DICT),  # type: ignore
    "user": DB_CON_DICT["user"],  # type: ignore
    "password": DB_CON_DICT["password"],  # type: ignore
}

MYSQL_PROPERTIES

{'driver': 'com.mysql.cj.jdbc.Driver',
 'url': 'jdbc:mysql://db:3306/mydb',
 'user': 'root',
 'password': 'example'}

We test the database connection


In [8]:
from tests.src import test_mysql_conx

In [9]:
# test database connection
test_mysql_conx(**DB_CON_DICT)  # type: ignore

Connection Success


We create a test table with some values for pyspark database test


In [10]:
from tests.src import table_preparation

In [11]:
table_preparation()

Table 'test_table' deleted successfully.
Table 'test_table' created successfully.
Sample data inserted successfully.


We test the pyspark session against this test_table


In [12]:
from tests.src import test_pyspark_db_conx

In [13]:
test_pyspark_db_conx()

MYSQL_PROPERTIES={'driver': 'com.mysql.cj.jdbc.Driver', 'url': 'jdbc:mysql://db:3306/mydb', 'user': 'root', 'password': 'example'}
MYSQL driver path existence: True


24/01/11 13:57:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


+---+-----+
| id| name|
+---+-----+
|  1| John|
|  2|Alice|
|  3|  Bob|
+---+-----+



We clean up the test_table


In [14]:
from tests.src import drop_table

In [15]:
drop_table()

Table 'test_table' deleted successfully.


### `So as part of your task we would like you to set up a database with only one table, called INSTRUMENT_PRICE_MODIFIER with the following columns:`

-   ID (primary key)
-   NAME (instrument name as read from the input file)
-   MULTIPLIER - double value


In [9]:
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("DatabaseConnection")
    .config("spark.jars", MYSQL_CONNECTOR_PATH)
    .config("spark.driver.extraClassPath", MYSQL_CONNECTOR_PATH)
    .config("spark.executor.extraClassPath", MYSQL_CONNECTOR_PATH)
    .getOrCreate()
)

Next we create the schema for the table


In [32]:
from pyspark.sql.types import (
    StructType,
    StructField,
    IntegerType,
    StringType,
    DoubleType,
)

# Define the schema
SCHEMA_DB = StructType(
    [
        StructField(
            "ID", IntegerType(), False
        ),  # False indicates that the field is not nullable
        StructField(
            "NAME", StringType(), True
        ),  # True indicates that the field is nullable
        StructField("MULTIPLIER", DoubleType(), True),
    ]
)

For testing purposes we generate some dummy data, WITHOUT INSTRUMENT1 for the sake of realistic example


In [149]:
# from random import uniform, randint

# Generate test data
# num_rows = 20
# data = [
#     (i, f"INSTRUMENT{choice([2,4,5,6])}", round(uniform(1.0, 10.0), 2))
#     for i in range(1, num_rows + 1)
# ]

data: list[tuple[int, str, float]] = [
    (1, "INSTRUMENT5", 5.19),
    (2, "INSTRUMENT4", 5.05),
    (3, "INSTRUMENT2", 4.4),
    (4, "INSTRUMENT4", 2.25),
    (5, "INSTRUMENT5", 1.75),
    (6, "INSTRUMENT6", 9.91),
    (7, "INSTRUMENT4", 4.5),
    (8, "INSTRUMENT2", 9.24),
    (9, "INSTRUMENT6", 5.83),
    (10, "INSTRUMENT5", 1.34),
    (11, "INSTRUMENT6", 8.89),
    (12, "INSTRUMENT2", 2.59),
    (13, "INSTRUMENT5", 4.04),
    (14, "INSTRUMENT4", 8.58),
    (15, "INSTRUMENT2", 8.64),
    (16, "INSTRUMENT2", 3.99),
    (17, "INSTRUMENT2", 6.82),
    (18, "INSTRUMENT2", 7.7),
    (19, "INSTRUMENT4", 4.44),
    (20, "INSTRUMENT4", 8.01),
]

In [150]:
df = spark.createDataFrame(data, schema=SCHEMA_DB).orderBy("ID")

df.printSchema()

df.show(5)

root
 |-- ID: integer (nullable = false)
 |-- NAME: string (nullable = true)
 |-- MULTIPLIER: double (nullable = true)

+---+-----------+----------+
| ID|       NAME|MULTIPLIER|
+---+-----------+----------+
|  1|INSTRUMENT5|      5.19|
|  2|INSTRUMENT4|      5.05|
|  3|INSTRUMENT2|       4.4|
|  4|INSTRUMENT4|      2.25|
|  5|INSTRUMENT5|      1.75|
+---+-----------+----------+
only showing top 5 rows



In [151]:
# Write the DataFrame to MySQL
df.write.jdbc(
    url=MYSQL_PROPERTIES["url"],
    table=TABLE_NAME,
    mode="overwrite",  # or "append" if needed
    properties=MYSQL_PROPERTIES,
)

### Determination of the final value


#### 1 `read the line from the input file;`


In [14]:
# Specify the path to the .txt file
txt_file_path: str = f"{EXAMPLE_INPUT_PATH}"

# Define the schema with StringType for DATE initially
schema = StructType(
    [
        StructField(name="INSTRUMENT_NAME", dataType=StringType(), nullable=True),
        StructField(name="DATE", dataType=StringType(), nullable=True),
        StructField(name="VALUE", dataType=DoubleType(), nullable=True),
    ]
)

# Read the .txt file into a PySpark DataFrame
extr = spark.read.option("delimiter", ",").csv(
    txt_file_path, header=False, schema=schema
)

# transform to dataframe
df_txt = extr.toDF("INSTRUMENT_NAME", "DATE", "VALUE")

In [16]:
from pyspark.sql.functions import to_date

date_format_str = "dd-MMM-yyyy"

# Convert the string to a DateType using to_date function
col_date_str = "DATE"
col_transformed_to_date = "DATE"  # "transformed_date"
col_formatted_Date = "DATE"  # "formatted_date"

df_txt = df_txt.withColumn(
    col_transformed_to_date, to_date(df_txt[col_date_str], date_format_str)
)

In [17]:
df_txt.printSchema()

df_txt.show(5)

root
 |-- INSTRUMENT_NAME: string (nullable = true)
 |-- DATE: date (nullable = true)
 |-- VALUE: double (nullable = true)

+---------------+----------+------+
|INSTRUMENT_NAME|      DATE| VALUE|
+---------------+----------+------+
|    INSTRUMENT1|1996-01-01|2.4655|
|    INSTRUMENT1|1996-01-02|2.4685|
|    INSTRUMENT1|1996-01-03| 2.473|
|    INSTRUMENT1|1996-01-04|2.4845|
|    INSTRUMENT1|1996-01-05|2.4868|
+---------------+----------+------+
only showing top 5 rows



Extract all rows


In [18]:
rows = df_txt.collect()
rows[:5]

[Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 1), VALUE=2.4655),
 Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 2), VALUE=2.4685),
 Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 3), VALUE=2.473),
 Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 4), VALUE=2.4845),
 Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 5), VALUE=2.4868)]

#### Constructing the Solution

`query the database to see if there is an entry for the <INSTRUMENT_NAME> you read in the 0st step;`

`if there is - multiply the original <VALUE> by the factor you found in the step 2;`

`if there is no entry - simply use the original <VALUE> from the file`


The rough idea is the following for a couple of rows. For our example we cherry pick a couple of rows from each INSTRUMENT name we have on the .txt file


In [165]:
l_1 = [row for row in rows if row["INSTRUMENT_NAME"] == "INSTRUMENT1"][:2]
l_1

[Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 1), VALUE=2.4655),
 Row(INSTRUMENT_NAME='INSTRUMENT1', DATE=datetime.date(1996, 1, 2), VALUE=2.4685)]

In [166]:
l_2 = [row for row in rows if row["INSTRUMENT_NAME"] == "INSTRUMENT2"][:2]
l_2

[Row(INSTRUMENT_NAME='INSTRUMENT2', DATE=datetime.date(1996, 2, 22), VALUE=9.326787847),
 Row(INSTRUMENT_NAME='INSTRUMENT2', DATE=datetime.date(1996, 2, 23), VALUE=9.321527686)]

In [167]:
l_3 = [row for row in rows if row["INSTRUMENT_NAME"] == "INSTRUMENT3"][:2]
l_3

[Row(INSTRUMENT_NAME='INSTRUMENT3', DATE=datetime.date(2012, 5, 31), VALUE=78.5325),
 Row(INSTRUMENT_NAME='INSTRUMENT3', DATE=datetime.date(2012, 6, 1), VALUE=78.2655)]

We engineered our database to not include INSTRUMENT1 for this example to be realistic

In [175]:
# create an empty dataframe
result_df = spark.createDataFrame([], schema=SCHEMA_DB)

# let's see the first 10
for row in l_1 + l_2 + l_3:

    # check if the name is empty
    if name := row["INSTRUMENT_NAME"]:
        # if not we proceed with querying the database
        query: str = f"SELECT * FROM {TABLE_NAME} WHERE NAME = '{name}'"

        # Query the database
        result_df = (
            spark.read.format("jdbc").options(**MYSQL_PROPERTIES, query=query).load()
        )
        print(f"{name=}")
        print(f'Initial value: {row["VALUE"]}')

    # is there is match in the database
    if not result_df.isEmpty():
        print(f"Multiplier found: {result_df.collect()[0]['MULTIPLIER']}")
        final_value = row["VALUE"] * result_df.collect()[0]["MULTIPLIER"]
        print(f"{final_value=}")
    # if there is no match in the database
    else:
        print("No Multiplier found")
        final_value = row["VALUE"]
        print(f"{final_value=}")
    print("-----------------------------")

name='INSTRUMENT1'
Initial value: 2.4655
No Multiplier found
final_value=2.4655
-----------------------------
name='INSTRUMENT1'
Initial value: 2.4685
No Multiplier found
final_value=2.4685
-----------------------------
name='INSTRUMENT2'
Initial value: 9.326787847
Multiplier found: 4.4
final_value=41.0378665268
-----------------------------
name='INSTRUMENT2'
Initial value: 9.321527686
Multiplier found: 4.4
final_value=41.0147218184
-----------------------------
name='INSTRUMENT3'
Initial value: 78.5325
No Multiplier found
final_value=78.5325
-----------------------------
name='INSTRUMENT3'
Initial value: 78.2655
No Multiplier found
final_value=78.2655
-----------------------------


So the rough idea works just fine. Let us refine it by properly introducing a proper  query mechanism according to the specifications.

Since the query is pretty simple and the table is supposed to be GBs it is much more efficient to directly query the database with the instrument name in order to obtain our result. 
- This really depends on how powerful is the pyspark cluster or the database itself.
- We assume that direct query to the database consumes the least resources in this case.

Regarding the time frequency of our queries, we will use a guard condition that will return the results of the previous query if we are within the 5 seconds time-window and the previous instrument name is the same as the current instrument name. Otherwise a new query will issued on the database.

This construction will be realized using the notion of a closure. A closure very roughly is a function with memory.


In [122]:
from datetime import datetime, timezone

DB_MIN_UPDATE_TIME = 5  # secs


def handle_query(instrument_name):
    """helper function to query a specific instrument in the db"""
    query: str = f"SELECT * FROM {TABLE_NAME} WHERE NAME = '{instrument_name}'"

    # Query the database
    return spark.read.format("jdbc").options(**MYSQL_PROPERTIES, query=query).load()


def query_db_closure(verbose: bool = False):
    result_df = spark.createDataFrame([], schema=SCHEMA_DB)
    last_time: datetime = datetime.now(timezone.utc)
    last_instrument_name: str = ""
    call_counter = 1

    def query_db(instrument_name: str):
        # this is to make the variables
        # available to the upper scope of
        # which is query_db_closure
        nonlocal result_df, last_time, last_instrument_name, call_counter

        current_time: datetime = datetime.now(timezone.utc)
        timediff: int = (current_time - last_time).seconds

        # printing logs
        if verbose:
            print("----query_db----")
            print(f"{last_instrument_name=}, current_instrument_name={instrument_name}")
            print(f"{call_counter=}")

        # this is a guard condition that ensures
        # we query the database a maximum of once
        # every 5 seconds for the same instrument
        if timediff <= DB_MIN_UPDATE_TIME and (last_instrument_name == instrument_name):
            if verbose:
                current_time_str: str = datetime.strftime(
                    current_time, "%d/%m/%Y, %H:%M:%S"
                )
                print(f"SAME Q {current_time_str}, {timediff=}")
                print(" ")
            last_time = current_time
            last_instrument_name = instrument_name
            call_counter += 1
            return result_df

        # printing logs
        if verbose:
            current_time_str: str = datetime.strftime(
                current_time, "%d/%m/%Y, %H:%M:%S"
            )
            print(f"NEW Q {current_time_str}, {timediff=}")
            print(" ")

        # query database
        result_df = handle_query(instrument_name)

        last_time = current_time
        last_instrument_name = instrument_name
        call_counter += 1
        return result_df

    return query_db

Let's test this first with the exact same instruments all over


In [124]:
# dummy_instrument_list = [f"INSTRUMENT{choice([1, 1])}" for i in range(1, 100)]

dummy_instrument_list: list[str] = [
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
    "INSTRUMENT1",
]

dummy_instrument_list

['INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1',
 'INSTRUMENT1']

We should encouter NEW Q (new querries) only when timediff > 5


In [125]:
import time

query_db = query_db_closure(verbose=True)

for instrument in dummy_instrument_list:
    time.sleep(randint(1, 10))
    _ = query_db(instrument)

----query_db----
last_instrument_name='', current_instrument_name=INSTRUMENT1
call_counter=1
NEW Q 11/01/2024, 16:22:44, timediff=6
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT1
call_counter=2
SAME Q 11/01/2024, 16:22:45, timediff=1
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT1
call_counter=3
SAME Q 11/01/2024, 16:22:49, timediff=4
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT1
call_counter=4
SAME Q 11/01/2024, 16:22:51, timediff=2
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT1
call_counter=5
SAME Q 11/01/2024, 16:22:54, timediff=3
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT1
call_counter=6
SAME Q 11/01/2024, 16:22:56, timediff=2
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT1
call_counter=7
NEW Q 11/01/2024, 16:23:02, time

everything seems to be working as it should, we get a new query only if timediff>5secs which means 5 seconds pass after the previous query


In the same manner for random instruments


In [126]:
# dummy_instrument_list = [f"INSTRUMENT{choice([1,2,3])}" for i in range(1, 10)]

dummy_instrument_list: list[str] = [
    "INSTRUMENT2",
    "INSTRUMENT2",
    "INSTRUMENT2",
    "INSTRUMENT1",
    "INSTRUMENT3",
    "INSTRUMENT1",
    "INSTRUMENT2",
    "INSTRUMENT3",
    "INSTRUMENT3",
]

dummy_instrument_list

['INSTRUMENT2',
 'INSTRUMENT2',
 'INSTRUMENT2',
 'INSTRUMENT1',
 'INSTRUMENT3',
 'INSTRUMENT1',
 'INSTRUMENT2',
 'INSTRUMENT3',
 'INSTRUMENT3']

In [127]:
query_db = query_db_closure(verbose=True)

for instrument in dummy_instrument_list:
    time.sleep(randint(1, 10))
    _ = query_db(instrument)

----query_db----
last_instrument_name='', current_instrument_name=INSTRUMENT2
call_counter=1
NEW Q 11/01/2024, 16:24:44, timediff=4
 
----query_db----
last_instrument_name='INSTRUMENT2', current_instrument_name=INSTRUMENT2
call_counter=2
NEW Q 11/01/2024, 16:24:52, timediff=8
 
----query_db----
last_instrument_name='INSTRUMENT2', current_instrument_name=INSTRUMENT2
call_counter=3
SAME Q 11/01/2024, 16:24:53, timediff=1
 
----query_db----
last_instrument_name='INSTRUMENT2', current_instrument_name=INSTRUMENT1
call_counter=4
NEW Q 11/01/2024, 16:25:03, timediff=10
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT3
call_counter=5
NEW Q 11/01/2024, 16:25:06, timediff=3
 
----query_db----
last_instrument_name='INSTRUMENT3', current_instrument_name=INSTRUMENT1
call_counter=6
NEW Q 11/01/2024, 16:25:13, timediff=7
 
----query_db----
last_instrument_name='INSTRUMENT1', current_instrument_name=INSTRUMENT2
call_counter=7
NEW Q 11/01/2024, 16:25:16, timedif

As we can see for different subsequent instrument names our closure works as intended without regard for the time constraint of 5 secs but it does respect it in case of same subsequent instruments.
