In [None]:
import os 
print(os.getcwd())  # relative paths used later in this notebook (e.g. "data/h5n1/...") resolve against this directory

In [None]:
import sys
from utils.pipeline import Context, Step
from utils.pipeline_lib import *
from tqdm import tqdm
from sqlalchemy import create_engine

# Paths of the artefacts this notebook generates, relative to the working directory.
# NOTE(review): the directory is "h5n1" while the file names say "h1n1" — confirm intended.
# NOTE(review): export_seq_id_path is not referenced in the cells visible here.
_h5n1_dir = "data/h5n1"
saved_context_ct_path = f"{_h5n1_dir}/h1n1_ctx"
export_seq_id_path = f"{_h5n1_dir}/h1n1_08-09_dataset.csv"
db_url = f"sqlite:///{_h5n1_dir}/h1n1.sqlite"

# RSCU, dinucleotides and other attributes

In [4]:
# Pipeline context: passed as `ctx=` to several steps below and later persisted to
# `saved_context_ct_path` with `ct.store(...)`.
ct = Context()

# DATA
class Data(Step):
    """Source step: load the raw sequence table that feeds the whole pipeline.

    Set the two class attributes below before materializing the pipeline. They
    default to ``None`` (instead of ``...`` placeholders) so that a forgotten edit
    fails here with a clear message. Otherwise ``pd.read_csv(Ellipsis)`` errors deep
    inside pandas, or ``[Ellipsis]`` is passed downstream as the metadata column list.
    """

    # Path (or buffer) of the input CSV — replace with your data.
    CSV_PATH = None
    # The list of metadata column names from your file.
    METADATA_FEATURE_NAMES = None

    def run(self):
        if self.CSV_PATH is None or self.METADATA_FEATURE_NAMES is None:
            raise ValueError(
                "Data.CSV_PATH and Data.METADATA_FEATURE_NAMES must be set "
                "before running the pipeline"
            )
        return {
            'output': pd.read_csv(self.CSV_PATH),
            # Copy so downstream steps cannot mutate the class-level setting.
            'metadata_feature_names': list(self.METADATA_FEATURE_NAMES),
            # Sequence column(s); presumably "CDS" holds the coding sequence — confirm.
            'data_feature_names': ["CDS"],
        }
# Linear chain of pipeline steps: each one depends on the previous step's output.
# Nothing is computed here; work happens when a step is materialized.
S_global_data = Data()
S_global_data_1 = Input.AssignHostType(depends_on=S_global_data)
S_global_data_2 = Input.RSCU(depends_on=S_global_data_1, ctx=ct)
# NOTE(review): unlike its neighbours, this step is built without ctx=ct — confirm intended.
S_transformation_1 = DataTransform.LogBySynCount(depends_on=S_global_data_2)
S_transformation_2 = DataTransform.PlainAndLogDinucleotide(depends_on=S_transformation_1, ctx=ct)
S_ordered_data = AbsoluteSortByCollectionDateAndIndex(depends_on=S_transformation_2, ctx=ct)
# NOTE(review): also built without ctx=ct — confirm intended.
S_global_data_sorted_bmc_annotated_1 = InputH5N1.AttachBMC_GenomicsCluster(depends_on=S_ordered_data)
# Final input step; used as the data source of the STRAY runs further down.
S_global_data_sorted_bmc_annotated_2 = AttachWeek(depends_on=S_global_data_sorted_bmc_annotated_1, ctx=ct)

In [None]:
result_input_data = S_global_data_sorted_bmc_annotated_2.materialize()
print(result_input_data.output.shape)
# Assert the DF is still sorted by Sort_Key. assert_series_equal also compares the
# index, so the reference must come from a *stable* sort. pandas' default (quicksort)
# may permute rows with equal Sort_Key values and fail even when the data is in order.
pd.testing.assert_series_equal(
    result_input_data.output.Sort_Key,
    result_input_data.output.Sort_Key.sort_values(kind="stable"),
)
" ".join(result_input_data.output.columns)

In [6]:
# Columns written to SQLite, both for the input table and for every warnings table.
exported_columns = [
    'Lineage', 'Clade', 'Location', 'Host',
    'Collection_Date', 'Submission_Date',
    'Host_Type', 'Sort_Key',
    'bmc_cluster_label', 'Week',
    'Pathogenicity',
]

In [7]:
# Persist the context so run_windows() can reload it with Context.load(saved_context_ct_path).
ct.store(saved_context_ct_path)

### write table of input sequences

In [9]:
# Dispose of the engine left over from a previous run of this cell (if any), so that
# re-running it does not leave that engine's pooled SQLite connections open.
if globals().get("cnx") is not None:
    cnx.dispose()
cnx = create_engine(db_url)

Write the exported columns of `result_input_data` (the input sequences) to the `input_data` table.

In [11]:
# engine.begin() yields a connection inside a transaction that is committed when the
# block exits normally (rolled back on error), so the table is persisted without
# relying on pandas/SQLAlchemy autocommit behaviour.
with cnx.begin() as connection:
    (result_input_data.output[exported_columns]
     .to_sql(name='input_data', con=connection, if_exists='replace'))

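Compute STRAY outlier warnings for every (window size, k) combination and write each result to its own `window{size}_k{k}` table.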

In [None]:
def run_windows(S_data, stray_window_size):
    """Build the moving-window partitioning for one STRAY window size.

    Reloads the saved pipeline context and creates a ``Stray.MovingWindowFixedSize``
    partitioner over the ISO weeks 2019-36 .. 2025-04. It uses
    ``stray_window_size - 1`` training units, 1 test unit and a shift of 1. The
    partitioner is materialized once to discover the index range of every window.
    One partition step is then created per train+test range and one per test range.

    Returns a 7-tuple, in the order ``run_stray4k`` expects after its ``S_data`` argument:
    ``(ctx, partitioner step, partitioner result, train+eval ranges, eval ranges,
    train+eval partition steps, eval partition steps)``.
    """
    ctx = Context.load(saved_context_ct_path)
    evaluation_weeks = DateRange.from_iso_weeks('2019-36', '2025-04')
    partitioner = Stray.MovingWindowFixedSize(
        evaluation_date_range=evaluation_weeks,
        train_size=stray_window_size - 1,
        test_size=1,
        window_shift=1,
        depends_on=S_data,
        ctx=ctx,
    )

    # Materialize once to learn the (start, end) index range of every window.
    partitioner_result = partitioner.materialize()
    train_plus_eval_ranges = partitioner_result.train_plus_test_range_names.values()
    eval_ranges = partitioner_result.test_range_names.values()

    # One partition step per window: first all train+eval slices, then all eval slices.
    train_plus_eval_parts = []
    for start, end in train_plus_eval_ranges:
        train_plus_eval_parts.append(NumericIndexBasedPartition(
            partitioner_function_name='.next_train_plus_test_partition',
            range_start=start,
            range_end=end,
            depends_on=partitioner,
            ctx=ctx,
            name_alias=f"TrainEvalPartition_{start}_{end}",
        ))

    eval_parts = []
    for start, end in eval_ranges:
        eval_parts.append(NumericIndexBasedPartition(
            partitioner_function_name='.next_test_partition',
            range_start=start,
            range_end=end,
            depends_on=partitioner,
            ctx=ctx,
            name_alias=f"EvalPartition_{start}_{end}",
        ))

    return (ctx, partitioner, partitioner_result,
            train_plus_eval_ranges, eval_ranges,
            train_plus_eval_parts, eval_parts)
    
def run_stray4k(S_data, ct, S_partitioner, result_partitioner, train_plus_eval_index_range, eval_index_range, S_train_plus_test_partitions, S_eval_partitions, stray_k):
    """Run STRAY outlier detection with ``k = stray_k`` over prepared windows and materialize the annotated data.

    Parameters ``ct`` .. ``S_eval_partitions`` are, in order, the 7-tuple returned by
    ``run_windows`` (the caller passes it with ``*args``). Of these, only
    ``eval_index_range``, ``S_train_plus_test_partitions`` and ``S_eval_partitions`` are
    read here. ``ct``, ``S_partitioner``, ``result_partitioner`` and
    ``train_plus_eval_index_range`` are accepted but unused.

    Returns the materialized result of ``Stray.AnnotateOutliersCountInOriginalData``.
    """
    # STRAY — per-window step chains. Windows are paired *by position* via zip(), so this
    # presumably relies on the train+eval partitions, eval partitions and eval_index_range
    # listing windows in the same order (TODO confirm in the partitioner).
    # NOTE(review): this alias lacks the "_" separator used by the other aliases; left as-is
    # because aliases may be used as step/cache identifiers.
    S_outlier_detection = [Stray.OutlierDetection(k=stray_k, depends_on=x, name_alias=f"OutlierDetection{s}_{e}") for x,(s,e) in zip(S_train_plus_test_partitions,eval_index_range)]
    S_train_plus_eval_partitions_enriched = [Stray.MapOutliersToOriginalData(depends_on=[s1,s2], 
                                                                            name_alias=f"MapOutliersToOriginalData_{s}_{e}") for s1,s2,(s,e) in zip(S_train_plus_test_partitions,S_outlier_detection,eval_index_range)]
    # Combine each enriched train+eval window with the matching eval partition via IndexBasedFilter.
    S_eval_partitions_enriched = [IndexBasedFilter(depends_on=[i1,i2], name_alias=f"IndexBasedFilter_{s}_{e}") for i1,i2,(s,e) in zip(S_train_plus_eval_partitions_enriched, S_eval_partitions, eval_index_range)]
    S_eval_outliers_enriched = [Stray.FilterOutliers(depends_on=x, name_alias=f"FilterOutliers_{s}_{e}") for x,(s,e) in zip(S_eval_partitions_enriched, eval_index_range)]

    # COLLECTIVE EVALUATION — merge all windows and annotate the evaluated slice of S_data.
    # NOTE(review): none of these steps receive ctx=ct, unlike the steps built in run_windows.
    S_all_outliers = Stray.CollectOutlierIdMultipleWindows(depends_on=S_eval_outliers_enriched)
    # The ISO-week range is hard-coded and must match the one used in run_windows.
    S_tested_data = WindowSelector(date_range=DateRange.from_iso_weeks('2019-36', '2025-04'),start="2019-36",end="2025-04", inclusive="left", depends_on=S_data, name_alias="EvalData")
    S_global_data_annotated_1 = Stray.AnnotateOutliersinOriginalData(depends_on=[S_tested_data, S_all_outliers])

    S_outliers_count = Stray.CollectOutlierCountMultipleWindows(depends_on=S_global_data_annotated_1)
    S_global_data_annotated_3 = Stray.AnnotateOutliersCountInOriginalData(depends_on=[S_global_data_annotated_1, S_outliers_count])

    # ##########  RUN — materializing the head step evaluates the whole graph built above.
    head = S_global_data_annotated_3
    result = head.materialize()
    return result

def format_detail_table(result: ResultType[Stray.AnnotateOutliersCountInOriginalData]):
    """Build the warnings table for one (window size, k) run.

    Keeps only the rows of ``result.output`` whose ``outlier`` flag is set. It then
    (re)computes ``Week`` from ``Collection_Date`` as an ISO ``%G-%V`` (year-week)
    string and returns just the ``exported_columns``, in that order. The returned
    frame is a new object; ``result.output`` is not modified.
    """
    outlier_rows = result.output[result.output.outlier]
    iso_week = outlier_rows.Collection_Date.dt.strftime("%G-%V")
    return outlier_rows.assign(Week=iso_week)[exported_columns]

def write_sqlite_table(df, name, cnx):
    """Write ``df`` to table ``name`` through ``cnx``, replacing any existing table.

    ``cnx.begin()`` yields a connection inside a transaction that is committed when the
    block exits normally (and rolled back on error). The write is therefore persisted
    explicitly instead of relying on pandas/SQLAlchemy autocommit behaviour.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to write; its index is written as a column too (``to_sql`` default).
    name : str
        Destination table name.
    cnx : engine
        Presumably a SQLAlchemy ``Engine``; anything whose ``begin()`` is a context
        manager yielding a connection pandas accepts will work.
    """
    with cnx.begin() as connection:
        df.to_sql(name=name, con=connection, if_exists='replace')

def loop(window_sizes, ks):
    """Run STRAY for every (window size, k) pair and store each warnings table in SQLite.

    Pairs with ``k >= window_size`` are skipped. For every remaining pair the
    detail table is written to ``window{window_size}_k{k}`` through the module-level
    engine ``cnx``.

    Parameters
    ----------
    window_sizes : iterable of int
        Window sizes passed to ``run_windows``.
    ks : iterable of int
        Values of ``k`` passed to ``run_stray4k``.

    Any exception raised while running, formatting or writing a combination is
    re-raised unchanged (original traceback kept) after printing which combination failed.
    """
    window_sizes = tuple(window_sizes)
    ks = tuple(ks)
    # Size the single progress bar by the combinations that actually run, so it reaches
    # 100 %. Skipped pairs were previously counted in the total but never completed.
    n_runnable = sum(1 for w in window_sizes for k in ks if k < w)
    comb_n = 0

    with tqdm(total=n_runnable) as progress:
        for window_size in window_sizes:
            runnable_ks = [k for k in ks if k < window_size]
            if not runnable_ks:
                # No k fits this window: don't build (and materialize) its partitions.
                continue
            args = run_windows(S_global_data_sorted_bmc_annotated_2, window_size)

            for k in runnable_ks:
                try:
                    result_unformatted = run_stray4k(S_global_data_sorted_bmc_annotated_2, *args, k)
                except Exception:
                    print(f"Error while testing combination (comb_n) {comb_n}: {window_size} {k}")
                    raise

                # Table of sequences with at least one warning.
                try:
                    detail_table = format_detail_table(result_unformatted)
                except Exception:
                    print(f"Error while formatting combination (comb_n) {comb_n}: {window_size} {k}")
                    raise

                try:
                    write_sqlite_table(detail_table, name=f"window{window_size}_k{k}", cnx=cnx)
                except Exception:
                    print(f"Error while writing warnings detail file of combination (comb_n) {comb_n}: {window_size} {k}")
                    raise

                comb_n += 1
                progress.update(1)
                               
# Parameter grid for the STRAY runs.
STRAY_WINDOW_SIZES = (5, 10, 50, 100)
STRAY_KS = (1, 3, 5, 10, 15)
# Derive the counts from the grid instead of hard-coding them; loop() skips pairs with
# k >= window size, so fewer combinations run than the full grid size.
n_runnable_combinations = sum(k < w for w in STRAY_WINDOW_SIZES for k in STRAY_KS)
print(f"{len(STRAY_WINDOW_SIZES) * len(STRAY_KS)} input combinations "
      f"({n_runnable_combinations} with k < window size)")
loop(window_sizes=STRAY_WINDOW_SIZES, ks=STRAY_KS)
