In [None]:
import os 
# Show the current working directory: the data paths used in this notebook are relative to it.
print(os.getcwd())

In [2]:
import sys
from utils.pipeline import Context, Step
from utils.pipeline_lib import *
from tqdm import tqdm
from sqlalchemy import create_engine


# Geoapify API key: first line of the key file, with surrounding whitespace removed.
with open("data/geoapify.key", "r") as key_file:
    geoapify_api_key = key_file.readline().strip()

# Paths of the files generated by this notebook.
saved_context_ct_path = "data/h1n1/h1n1_ctx"                # stored pipeline context
export_seq_id_path = "data/h1n1/h1n1_08-09_dataset.csv"     # exported dataset (CSV)
db_url = 'sqlite:///data/h1n1/h1n1.sqlite'                  # SQLite database URL for the result tables

# RSCU, dinucleotides and other attributes

In [None]:
# Pipeline context passed as `ctx` to the steps below and stored to `saved_context_ct_path` at the end of this cell.
ct = Context()

# DATA
class Data(Step):
    """Template input step that supplies the raw data table and the names of its columns.

    The `...` placeholders below must be replaced with your own data before running.
    """
    def run(self):
        """Return a dict with the data ('output') and the metadata / data feature column names."""
        return {
            'output': pd.read_csv(...),         # replace with your data
            'metadata_feature_names': [...],    # the list of metadata column names from your file
            'data_feature_names': ["CDS"]       # presumably the column holding the coding sequence — TODO confirm
        }
# Wire the data steps together: each step receives its predecessor through `depends_on`.
S_global_data = Data()
S_global_data_1 = Input.AssignHostType(depends_on=S_global_data)
S_global_data_2 = Input.RSCU(depends_on=S_global_data_1, ctx=ct)
S_transformation_1 = DataTransform.LogBySynCount(depends_on=S_global_data_2)
S_transformation_2 = DataTransform.PlainAndLogDinucleotide(depends_on=S_transformation_1, ctx=ct)
# Restrict the data to the window '2008-01'..'2010-10' (range built with DateRange.from_iso_weeks); step aliased "NarrowGlobalData".
# NOTE(review): the same bounds are passed twice (date_range and start/end) — confirm WindowSelector needs both.
S_partition_data = WindowSelector(date_range=DateRange.from_iso_weeks('2008-01', '2010-10'), start="2008-01", end="2010-10", depends_on=S_transformation_2, name_alias="NarrowGlobalData", ctx=ct)
S_ordered_data = AbsoluteSortByCollectionDateAndIndex(depends_on=S_partition_data, ctx=ct)

# Remind the reader that geolocation is constrained to the North America bounding box (see the steps below).
print("!! Using filter boundary box of North America")
# start geolocate cache
# Reuse the parquet cache when the file exists, otherwise start from a fresh cache object.
geolocate_cache_path = "data/cache/cache.parquet"
geolocate_cache = pd.read_parquet(geolocate_cache_path) if os.path.exists(geolocate_cache_path) else Geolocation.GeoapifyBatchJobRequest.get_new_cache()

# geolocate
# After each geolocation step, `geolocate_cache` is rebound to that step's `.cache` and handed to the next step.
S_geolocate_1 = Geolocation.GeoapifyBatchJobRequest(geoapify_api_key, continent_bounding_box=Geoapify.continent_bounding_box['North America'], 
                                                    cache=geolocate_cache, 
                                                    depends_on=S_ordered_data, ctx=ct)
geolocate_cache = S_geolocate_1.cache
S_geolocate_2 = Geolocation.GeoapifyParseBatchRequestOutput(geoapify_api_key, cache=geolocate_cache, 
                                                            depends_on=[S_geolocate_1, S_ordered_data])
geolocate_cache = S_geolocate_2.cache

S_geolocate_3 = Geolocation.GuessCountryOrStateFromSequenceName(geoapify_api_key, continent_bounding_box=Geoapify.continent_bounding_box['North America'], 
                                                                cache=geolocate_cache, 
                                                                depends_on=S_geolocate_2, ctx=ct)
geolocate_cache = S_geolocate_3.cache
# NOTE(review): this cell never writes `geolocate_cache` back to `geolocate_cache_path` — confirm the steps persist it themselves.

# Annotation steps applied on top of the geolocated data.
S_global_data_sorted_bmc_annotated_1 = InputH1N1.AttachBMC_GenomicsCluster(depends_on=S_geolocate_3)
S_global_data_sorted_bmc_annotated_2 = AttachWeek(depends_on=S_global_data_sorted_bmc_annotated_1, ctx=ct)

# Materialize the last step of the chain; presumably this is what triggers execution of the upstream steps — TODO confirm.
result_input_data = S_global_data_sorted_bmc_annotated_2.materialize()

# del ct['GeoapifyBatchJobRequest']
# Save the context to disk; `run()` reloads it from the same path with Context.load.
ct.store(saved_context_ct_path)

### write table of input sequences

In [51]:
# cnx.dispose()   # presumably kept for re-runs, to release the previous engine before creating a new one — TODO confirm
# SQLAlchemy engine for the database at `db_url`; the cells below write their tables through it.
cnx = create_engine(db_url)

Write `result_input_data` to the `input_data` table of the SQLite database.

In [37]:
# Columns of the prepared input data that are stored in the `input_data` table.
input_data_columns = [
    'Lineage', 'Clade', 'Location', 'Host', 'Host_Type',
    'Collection_Date', 'Submission_Date', 'Sort_Key',
    'lat', 'lon', 'country_or_state', 'country', 'state',
    'bmc_cluster_label', 'Week',
]

# Write the selection, replacing the table if it already exists.
with cnx.connect() as connection:
    input_data_table = result_input_data.output[input_data_columns]
    input_data_table.to_sql(name='input_data', con=connection, if_exists='replace')

Compute the outlier warnings for each (window size, k) combination and write them to the database.

In [None]:
def run(S_data, stray_window_size, stray_k):
    """Build and materialize the moving-window outlier-detection pipeline on top of `S_data`.

    Args:
        S_data: upstream pipeline step whose data is partitioned and evaluated.
        stray_window_size: window size; the partitioner is configured with
            `train_size = stray_window_size - 1` and `test_size = 1`.
        stray_k: value forwarded as `k` to every `Stray.OutlierDetection` step.

    Returns:
        The materialized result of the final `Stray.AnnotateOutliersCountInOriginalData` step.
    """
    # Reload the context stored at `saved_context_ct_path`; this local `ct` shadows the notebook-level one.
    ct = Context.load(saved_context_ct_path)
    # Moving window shifted by 1 at each step over the evaluation range '2008-31'..'2010-01'.
    S_partitioner = Stray.MovingWindowFixedSize(evaluation_date_range=DateRange.from_iso_weeks('2008-31', '2010-01'), 
                                                train_size=stray_window_size-1, test_size=1, window_shift=1,
                                                depends_on=S_data, ctx=ct)
    # GET WINDOW NAMES
    # The partitioner is materialized here because its range names drive how many downstream steps are built.
    result_partitioner = S_partitioner.materialize()
    train_plus_eval_index_range = result_partitioner.train_plus_test_range_names.values()
    eval_index_range = result_partitioner.test_range_names.values()

    # GET PARTITIONS
    # One partition step per (start, end) pair: train+test windows first, then the test-only windows.
    S_train_plus_test_partitions = [NumericIndexBasedPartition(partitioner_function_name='.next_train_plus_test_partition', 
                                                            range_start=s, range_end=e,
                                                            depends_on=S_partitioner, ctx=ct,
                                                            name_alias=f"TrainEvalPartition_{s}_{e}") for s,e in train_plus_eval_index_range]
    S_eval_partitions = [NumericIndexBasedPartition(partitioner_function_name='.next_test_partition', 
                                                    range_start=s, range_end=e,
                                                    depends_on=S_partitioner, ctx=ct,
                                                    name_alias=f"EvalPartition_{s}_{e}") for s,e in eval_index_range]
    
    # STRAY
    # Outlier detection runs on each train+test partition; the steps are named after the matching test range.
    # NOTE(review): this alias lacks the "_" separator used by the other aliases — confirm before changing, aliases may be used as keys.
    S_outlier_detection = [Stray.OutlierDetection(k=stray_k, depends_on=x, name_alias=f"OutlierDetection{s}_{e}") for x,(s,e) in zip(S_train_plus_test_partitions,eval_index_range)]
    # Combine each partition with its detection output.
    S_train_plus_eval_partitions_enriched = [Stray.MapOutliersToOriginalData(depends_on=[s1,s2], 
                                                                            name_alias=f"MapOutliersToOriginalData_{s}_{e}") for s1,s2,(s,e) in zip(S_train_plus_test_partitions,S_outlier_detection,eval_index_range)]
    # Filter each enriched train+test partition against its test partition, then apply the outlier filter.
    S_eval_partitions_enriched = [IndexBasedFilter(depends_on=[i1,i2], name_alias=f"IndexBasedFilter_{s}_{e}") for i1,i2,(s,e) in zip(S_train_plus_eval_partitions_enriched, S_eval_partitions, eval_index_range)]
    S_eval_outliers_enriched = [Stray.FilterOutliers(depends_on=x, name_alias=f"FilterOutliers_{s}_{e}") for x,(s,e) in zip(S_eval_partitions_enriched, eval_index_range)]

    # COLLECTIVE EVALUATION
    # Gather the per-window outliers and annotate them on the data selected for the evaluation range.
    S_all_outliers = Stray.CollectOutlierIdMultipleWindows(depends_on=S_eval_outliers_enriched)
    # NOTE(review): the range literals repeat those given to the partitioner above — keep the two in sync.
    S_tested_data = WindowSelector(date_range=DateRange.from_iso_weeks('2008-31', '2010-01'),start="2008-31",end="2010-01", inclusive="left", depends_on=S_data, name_alias="EvalData")
    S_global_data_annotated_1 = Stray.AnnotateOutliersinOriginalData(depends_on=[S_tested_data, S_all_outliers])

    S_outliers_count = Stray.CollectOutlierCountMultipleWindows(depends_on=S_global_data_annotated_1)
    S_global_data_annotated_3 = Stray.AnnotateOutliersCountInOriginalData(depends_on=[S_global_data_annotated_1, S_outliers_count])

    # ##########  RUN
    # Materialize the final step and return its result.
    head = S_global_data_annotated_3
    result = head.materialize()
    return result

def format_detail_table(result: ResultType[Stray.AnnotateOutliersCountInOriginalData]):
    """Build the detail table of the rows flagged as outliers.

    Args:
        result: pipeline result whose `output` DataFrame has a boolean `outlier`
            column, a datetime `Collection_Date` column and the detail columns
            listed below.

    Returns:
        A new DataFrame restricted to the outlier rows and the detail columns,
        with `Week` (re)computed from `Collection_Date` as "%G-%V".
    """
    detail_columns = [
        'Lineage', 'Clade', 'Location', 'Host', 'Host_Type',
        'Collection_Date', 'Submission_Date', 'Flu_Season', 'Sort_Key',
        'lat', 'lon', 'country_or_state', 'country', 'state',
        'bmc_cluster_label', 'Week',
    ]
    annotated = result.output
    # `assign` returns a new frame, so the pipeline output itself is left untouched.
    flagged = annotated[annotated.outlier].assign(
        Week=lambda frame: frame.Collection_Date.dt.strftime("%G-%V")
    )
    return flagged[detail_columns]

def write_sqlite_table(df, name, cnx):
    """Store `df` as table `name` through the engine `cnx`, replacing the table if it exists.

    Args:
        df: object exposing `to_sql` (a pandas DataFrame).
        name: name of the destination table.
        cnx: engine-like object whose `connect()` returns a context-managed connection.
    """
    on_conflict = 'replace'
    with cnx.connect() as db_connection:
        df.to_sql(name=name, con=db_connection, if_exists=on_conflict)

def loop(window_sizes, ks):
    """Run the outlier pipeline for every valid (window size, k) pair and store each result.

    For each pair with `k < window_size`, calls `run` on the notebook-level step
    `S_global_data_sorted_bmc_annotated_2`, formats the flagged rows with
    `format_detail_table` and writes them to the table `window{window_size}_k{k}`
    through the notebook-level engine `cnx`. Pairs with `k >= window_size` are skipped.

    Args:
        window_sizes: window sizes to test (outer loop).
        ks: `k` values to test (inner loop).

    Raises:
        Exception: any error raised while running, formatting or writing a combination
            is re-raised unchanged, after printing which combination failed.
    """
    # Select the pairs that will actually run up front, so the progress total matches
    # the work done; counting skipped pairs left the bar permanently incomplete.
    combinations = [
        (window_size, k)
        for window_size in window_sizes
        for k in ks
        if not k >= window_size
    ]

    # A single bar, closed on exit even when a combination fails.
    with tqdm(total=len(combinations)) as progress:
        for comb_n, (window_size, k) in enumerate(combinations):
            try:
                result_unformatted = run(S_global_data_sorted_bmc_annotated_2, window_size, k)
            except Exception:
                print(f"Error while testing combination (comb_n) {comb_n}: {window_size} {k}")
                raise

            # Table of the sequences with at least one warning.
            try:
                detail_table = format_detail_table(result_unformatted)
            except Exception:
                print(f"Error while formatting combination (comb_n) {comb_n}: {window_size} {k}")
                raise

            try:
                write_sqlite_table(detail_table, name=f"window{window_size}_k{k}", cnx=cnx)
            except Exception:
                print(f"Error while writing warnings detail file of combination (comb_n) {comb_n}: {window_size} {k}")
                raise

            progress.update(1)
                               
# Sweep the window sizes and k values; pairs with k >= window size are skipped inside `loop`.
loop(window_sizes=(5,10,50,100), ks=(1,3,5,10,15))
