In [None]:
from deltalake import DeltaTable, write_deltalake
from pathlib import Path
import datetime
import ray
import pandas as pd
import polars as pl
from tpot import TPOTClassifier, TPOTRegressor
from sklearn.model_selection import train_test_split

In [None]:
# Start the Ray runtime used by the per-entity remote tasks below.
# ignore_reinit_error=True is intended to let this cell be re-run while Ray is
# already initialised; the CPU count and object-store size are hard-coded here.
ray.init(ignore_reinit_error=True, num_cpus=4, object_store_memory=1_000_000_000)

In [3]:
# Path of the Delta table read by DeltaTable(...) and by process_entity.
# It is a relative path, so it resolves against the kernel's working directory.
input_data_path = r"../data/delta"

In [13]:
# Open the Delta table located at input_data_path.
dt = DeltaTable(input_data_path)
# List of partition dicts; the submit loop indexes each one by the partition
# column name (entity_col) to get the entity value.
# NOTE(review): the next cell overwrites this with a hard-coded list.
entity_partitions = dt.partitions()
# Fraction of each entity's rows passed to process_entity on the first attempt.
sample_fraction = 1.0

In [5]:
# Name of the partition column used to pull the entity value out of each
# partition dict.
entity_col = 'entity'

# NOTE(review): appears unused as set here -- the submit loop reassigns
# entity_partition before anything reads it.
entity_partition = "entity_19"

# Hard-coded subset of entities to process.
# NOTE(review): this replaces the full list obtained from dt.partitions();
# confirm the override is intentional before a full run.
entity_partitions = [
    {'entity': entity_name}
    for entity_name in ('entity_18', 'entity_13', 'entity_7')
]


In [14]:
# TPOT estimator that process_entity fits and scores; the function refers to
# it as a module-level global, so this cell must run before the tasks are
# submitted. The regressor variant is kept commented out as an alternative.
# tpot = TPOTRegressor(generations=2, population_size=3, verbosity=2, random_state=42)
tpot = TPOTClassifier(generations=2, population_size=3, verbosity=2, random_state=42)

In [15]:
@ray.remote
def process_entity(input_data_path, entity_partition, sample_fraction=1.0, seed=42):
    """Fit and score the module-level ``tpot`` estimator on one entity partition.

    Declared as a Ray remote function, so callers invoke it through
    ``process_entity.remote(...)`` and receive an object ref.

    Args:
        input_data_path: Path of the Delta table to read from.
        entity_partition: Value of the ``entity`` partition column to load.
        sample_fraction: Fraction of the loaded rows kept before the
            train/test split.
        seed: Seed for the row sample. The original sample was unseeded while
            the split and the estimator were pinned to 42; seeding it makes a
            given (partition, fraction) pair reproducible.

    Returns:
        Tuple ``(entity_partition, loaded_shape, sampled_shape, entity_score)``
        where ``loaded_shape`` is the shape of the partition as read,
        ``sampled_shape`` the shape after sampling, and ``entity_score`` the
        value returned by ``tpot.score`` on the held-out 25 %.
    """
    # Read only the requested partition of the Delta table.
    partition_filter = [("entity", "=", f"{entity_partition}")]
    entity_pldf = pl.read_delta(
        input_data_path, pyarrow_options={"partitions": partition_filter}
    )

    # Seeded down-sampling, then hand over to pandas for scikit-learn / TPOT.
    entity_pdf = entity_pldf.sample(fraction=sample_fraction, seed=seed).to_pandas()

    # Feature columns are those whose names consist only of digits.
    # NOTE(review): assumes every column name is a string -- confirm.
    features = [s for s in entity_pdf.columns.to_list() if s.isdigit()]
    target = 'target'

    # prepare data
    X_train, X_test, y_train, y_test = train_test_split(
        entity_pdf[features], entity_pdf[target], test_size=0.25, random_state=42
    )

    tpot.fit(X_train, y_train)
    entity_score = tpot.score(X_test, y_test)

    # Uncomment to simulate a failure and exercise the caller's retry path:
    # if entity_partition == 'entity_19' and sample_fraction > 0.8:
    #     raise ValueError
    return entity_partition, entity_pldf.shape, entity_pdf.shape, entity_score

In [16]:
# Start the wall-clock timer before the first task is submitted.
start_time_ray = datetime.datetime.now()

# Maps every in-flight Ray object ref to the argument tuple that produced it,
# so a finished (or failed) task can be traced back to its inputs.
output = {}
for part in entity_partitions:
    entity_partition = part[entity_col]
    task_args = (input_data_path, entity_partition, sample_fraction)
    entity_objref = process_entity.remote(*task_args)
    output[entity_objref] = task_args

In [None]:
# TODO: add an architecture diagram

In [None]:
# Collect results as tasks finish; a failed task is resubmitted on a smaller
# sample of its partition until the sample would drop below the floor.
RETRY_SHRINK_FACTOR = 0.9   # each retry keeps 90 % of the previous fraction
MIN_SAMPLE_FRACTION = 0.1   # below this the entity is recorded as failed

results = {}   # entity -> tuple returned by process_entity
failed = {}    # entity -> last exception, for entities that were given up on
while output:
    # Block until at least one in-flight task has finished.
    ready_refs, _ = ray.wait(list(output.keys()), num_returns=1)
    finished_ref = ready_refs[0]
    done_key = output.pop(finished_ref)
    print(f"original args ..... {done_key}")
    try:
        results[done_key[1]] = ray.get(finished_ref)
    except (ray.exceptions.RayTaskError, ray.exceptions.WorkerCrashedError) as e:
        print(e)
        # Unpack into local names: the original unpacked into the notebook
        # globals input_data_path / entity_partition / sample_fraction, so one
        # retry silently shrank sample_fraction for every later submission.
        failed_path, failed_entity, failed_fraction = done_key
        retry_fraction = failed_fraction * RETRY_SHRINK_FACTOR
        if retry_fraction < MIN_SAMPLE_FRACTION:
            # Bound the retries; otherwise a task that fails for a reason
            # unrelated to data size would be resubmitted forever.
            failed[failed_entity] = e
            print(f"giving up --------- {done_key}")
            continue
        done_key = (failed_path, failed_entity, retry_fraction)
        entity_objref = process_entity.remote(*done_key)
        output[entity_objref] = done_key
        print(f"updated args ------ {done_key}")

exec_duration_ray =  datetime.datetime.now() - start_time_ray

In [None]:
# Display the elapsed wall-clock time of the Ray run: the difference between
# two datetime.datetime.now() readings, taken before submission and after the
# result loop.
exec_duration_ray

In [102]:
# Shut down the Ray runtime that was started with ray.init at the top of the notebook.
ray.shutdown()