In [1]:
# import logfire

# logfire.install_auto_tracing(modules=['cynde'])
# logfire.configure(pydantic_plugin=logfire.PydanticPlugin(record='all'))
import os
import polars as pl
from typing import List, Optional, Tuple, Generator
import time
from cynde.functional.train.types import PredictConfig, BaseClassifierConfig,StratifiedConfig,Feature,FeatureSet,NumericalFeature, CategoricalFeature,EmbeddingFeature, InputConfig, ClassifierConfig, LogisticRegressionConfig, RandomForestClassifierConfig, HistGradientBoostingClassifierConfig, CVConfig, PipelineResults, PipelineInput, PurgedConfig
from cynde.functional.train.preprocess import convert_utf8_to_enum, check_add_cv_index

from sklearn.pipeline import Pipeline
from cynde.functional.train.train_modal import train_nested_cv_distributed
from cynde.functional.train.train_local import train_nested_cv
from cynde.functional.train.preprocess import preprocess_inputs, get_unique_columns


def load_tggate_grouped_data(data_path: Optional[str] = None) -> pl.DataFrame:
    """Load the grouped embeddings dataset from a parquet file.

    Args:
        data_path: Path to the parquet file. When omitted (or None), the path is
            read from the ``CYNDE_TGGATE_DATA_PATH`` environment variable, falling
            back to the original author's local cache location so that existing
            behaviour is unchanged when no override is configured.

    Returns:
        The contents of the parquet file as a polars DataFrame.
    """
    if data_path is None:
        # Original hard-coded location, kept only as the last-resort fallback;
        # set CYNDE_TGGATE_DATA_PATH to run this notebook on another machine.
        default_path = r"C:\Users\Tommaso\Documents\Dev\Cynde\cache\tgca_nogen_simplified_smiles_malformer_embeddings_ada02_3large_3small_embeddings_grouped.parquet"
        data_path = os.environ.get("CYNDE_TGGATE_DATA_PATH", default_path)
    return pl.read_parquet(data_path)

# Load the grouped data, convert utf8 columns to enums via convert_utf8_to_enum
# (threshold=0.2 — NOTE(review): confirm the threshold's exact meaning in cynde),
# then run check_add_cv_index, which presumably adds a `cv_index` column when missing.
df = check_add_cv_index(
    convert_utf8_to_enum(load_tggate_grouped_data(), threshold=0.2),
    strict=False,
)

# Reference list of column names, presumably the columns available in the grouped
# parquet: metadata, the MoLFormer SMILES embedding, and text-embedding columns for
# several field combinations and embedders (ada-002, 3-small, 3-large), plus `target`.
# NOTE(review): `cols` is not referenced anywhere in the visible cells — confirm
# whether it is still needed or can be removed.
cols = ['COMPOUND_NAME',
        'SACRIFICE_PERIOD',
        'DOSE_LEVEL',
        'num_samples',
        'lesion_prob',
        'weight_mean',
        'SMILES_CODE_MoLFormer-XL-both-10pct_embeddings',
        'COMPOUND_NAME_text-embedding-ada-002_embeddings',
        'SMILES_CODE_text-embedding-ada-002_embeddings',
        'SACRIFICE_PERIOD_text-embedding-ada-002_embeddings',
        'DOSE_LEVEL_text-embedding-ada-002_embeddings',
        'DOSE_DOSE_UNIT_text-embedding-ada-002_embeddings',
        'COMPOUND_NAME_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-ada-002_embeddings',
        'SMILES_CODE_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-ada-002_embeddings',
        'COMPOUND_NAME_SMILES_CODE_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-ada-002_embeddings',
        'COMPOUND_NAME_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-ada-002_embeddings',
        'SMILES_CODE_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-ada-002_embeddings',
        'COMPOUND_NAME_SMILES_CODE_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-ada-002_embeddings',
        'COMPOUND_NAME_text-embedding-3-small_embeddings', 
        'SMILES_CODE_text-embedding-3-small_embeddings',
        'SACRIFICE_PERIOD_text-embedding-3-small_embeddings', 
        'DOSE_LEVEL_text-embedding-3-small_embeddings',
        'DOSE_DOSE_UNIT_text-embedding-3-small_embeddings', 
        'COMPOUND_NAME_DOSE_LEVEL_text-embedding-3-small_embeddings',
        'SMILES_CODE_DOSE_LEVEL_text-embedding-3-small_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_DOSE_LEVEL_text-embedding-3-small_embeddings',
        'COMPOUND_NAME_DOSE_DOSE_UNIT_text-embedding-3-small_embeddings', 
        'SMILES_CODE_DOSE_DOSE_UNIT_text-embedding-3-small_embeddings',
        'COMPOUND_NAME_SMILES_CODE_DOSE_DOSE_UNIT_text-embedding-3-small_embeddings', 
        'COMPOUND_NAME_text-embedding-3-large_embeddings',
        'SMILES_CODE_text-embedding-3-large_embeddings', 
        'SACRIFICE_PERIOD_text-embedding-3-large_embeddings', 
        'DOSE_LEVEL_text-embedding-3-large_embeddings',
        'DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_DOSE_LEVEL_text-embedding-3-large_embeddings',
        'SMILES_CODE_DOSE_LEVEL_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_DOSE_LEVEL_text-embedding-3-large_embeddings',
        'COMPOUND_NAME_DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'SMILES_CODE_DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'SMILES_CODE_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-3-large_embeddings', 
        'SMILES_CODE_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-3-large_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-3-small_embeddings', 
        'SMILES_CODE_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-3-small_embeddings', 
        'COMPOUND_NAME_SACRIFICE_PERIOD_DOSE_DOSE_UNIT_text-embedding-3-small_embeddings', 
        'COMPOUND_NAME_SMILES_CODE_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-3-small_embeddings', 
        'SMILES_CODE_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-3-small_embeddings', 
        'COMPOUND_NAME_SACRIFICE_PERIOD_DOSE_LEVEL_text-embedding-3-small_embeddings', 
        'target']


# Every feature below uses text-embedding-3-large vectors; the name and size are
# shared, so they are defined once instead of being repeated in each feature spec.
LARGE_EMBEDDER = "text-embedding-3-large_embeddings"
LARGE_EMBEDDING_SIZE = 3072


def _large_embedding_feature(column_prefix: str, name: str) -> dict:
    """Return an embedding-feature spec for the `<column_prefix>_text-embedding-3-large_embeddings` column.

    Args:
        column_prefix: Source field name, e.g. "COMPOUND_NAME".
        name: Human-readable feature name.

    Returns:
        A fresh dict with the keys column_name, name, embedder and embedding_size,
        in the shape expected by InputConfig's feature sets.
    """
    return {"column_name": f"{column_prefix}_{LARGE_EMBEDDER}",
            "name": name,
            "embedder": LARGE_EMBEDDER,
            "embedding_size": LARGE_EMBEDDING_SIZE}


# The best model combines the compound name, dose level and sacrifice period
# embeddings, each used independently, all from the large embedder.
feature_set_best_model = {"embeddings": [
    _large_embedding_feature("COMPOUND_NAME", "embeddings of the compounds names"),
    _large_embedding_feature("DOSE_LEVEL", "embeddings of the dose levels"),
    _large_embedding_feature("SACRIFICE_PERIOD", "embeddings of the sacrifice periods"),
]}

# Comparison feature set: compound-name embeddings only.
feature_set_only_compound_name = {"embeddings": [
    _large_embedding_feature("COMPOUND_NAME", "embeddings of the compounds names"),
]}

# save_folder defaults to the author's local folder; set CYNDE_SAVE_FOLDER to run
# elsewhere. NOTE(review): remote_folder is presumably the mount path used by the
# distributed (modal) trainer — verify.
input_config_data = {"feature_sets": [feature_set_best_model, feature_set_only_compound_name],
                     "target_column": "target",
                     "save_folder": os.environ.get("CYNDE_SAVE_FOLDER",
                                                   "C:/Users/Tommaso/Documents/Dev/Cynde/cynde_mount/"),
                     "remote_folder": "/cynde_mount"}

# NOTE(review): no random seed is set for the random forest, so results may vary
# between runs — check whether RandomForestClassifierConfig accepts a random_state.
classifiers_config = ClassifierConfig(classifiers=[RandomForestClassifierConfig(n_estimators=10, max_depth=5)])

# Validation receives the loaded frame as context.
input_config = InputConfig.model_validate(input_config_data, context={"df": df})

# Nested CV: the inner folds are stratified on dose, sacrifice period and target;
# the outer folds are purged on compound name.
inner_groups = ["DOSE_LEVEL", "SACRIFICE_PERIOD", "target"]
outer_groups = ["COMPOUND_NAME"]

cv_config = CVConfig(inner=StratifiedConfig(groups=inner_groups, k=2),
                     inner_replicas=1,
                     outer=PurgedConfig(groups=outer_groups, k=2),
                     outer_replicas=1)

task = PredictConfig(input_config=input_config, cv_config=cv_config, classifiers_config=classifiers_config)






# Keep only the columns the task references. Per the In [3] output, these are the
# feature embeddings, target, cv_index and the CV grouping columns.
df_filtered = df.select(get_unique_columns(task))

# Appears to print the selected columns per feature set (see cell output). Its
# return value is not used here.
# NOTE(review): confirm whether it also writes the inputs to save_folder.
preprocess_inputs(df_filtered, task.input_config)

# Train on the filtered frame, consistent with the distributed variant below.
# df_filtered already contains every column used for the CV groups.
results = train_nested_cv(df_filtered, task)

# Distributed alternative:
# results = train_nested_cv_distributed(df_filtered, task)

# TODO
# 1) fix the cv objects for purged (add a test) to the pydantic object, I guess
# 2) results saving and aggregation

selected columns: ['cv_index', 'target', 'COMPOUND_NAME_text-embedding-3-large_embeddings', 'DOSE_LEVEL_text-embedding-3-large_embeddings', 'SACRIFICE_PERIOD_text-embedding-3-large_embeddings']
selected columns: ['cv_index', 'target', 'COMPOUND_NAME_text-embedding-3-large_embeddings']
using groups ['DOSE_LEVEL', 'target', 'COMPOUND_NAME', 'SACRIFICE_PERIOD'] to generate the cv folds
For outer replica 0, outer fold 0, inner replica 0, inner fold 0: train 423, val 413, test 838 samples.
Training pipeline with classifier RandomForestClassifier on feature set numerical=[] embeddings=[EmbeddingFeature(column_name='COMPOUND_NAME_text-embedding-3-large_embeddings', name='embeddings of the compounds names', description=None, scaler_type=<ScalerType.STANDARD_SCALER: 'StandardScaler'>, embedder='text-embedding-3-large_embeddings', embedding_size=3072), EmbeddingFeature(column_name='DOSE_LEVEL_text-embedding-3-large_embeddings', name='embeddings of the dose levels', description=None, scaler_type=

  sliced = sdf.select(pl.col("cv_index").map_elements(lambda s: hacky_list_relative_slice(s,k)).alias("hacky_cv_index")).unnest("hacky_cv_index")


Converting COMPOUND_NAME_text-embedding-3-large_embeddings of type List(Float64) to a list of columns.
Converting DOSE_LEVEL_text-embedding-3-large_embeddings of type List(Float64) to a list of columns.
Converting SACRIFICE_PERIOD_text-embedding-3-large_embeddings of type List(Float64) to a list of columns.
shape: (1_674, 9_218)
┌──────────┬────────┬────────────┬────────────┬───┬────────────┬───────────┬───────────┬───────────┐
│ cv_index ┆ target ┆ COMPOUND_N ┆ COMPOUND_N ┆ … ┆ SACRIFICE_ ┆ SACRIFICE ┆ SACRIFICE ┆ SACRIFICE │
│ ---      ┆ ---    ┆ AME_text-e ┆ AME_text-e ┆   ┆ PERIOD_tex ┆ _PERIOD_t ┆ _PERIOD_t ┆ _PERIOD_t │
│ u32      ┆ i32    ┆ mbedding-3 ┆ mbedding-3 ┆   ┆ t-embeddin ┆ ext-embed ┆ ext-embed ┆ ext-embed │
│          ┆        ┆ …          ┆ …          ┆   ┆ …          ┆ din…      ┆ din…      ┆ din…      │
│          ┆        ┆ ---        ┆ ---        ┆   ┆ ---        ┆ ---       ┆ ---       ┆ ---       │
│          ┆        ┆ f64        ┆ f64        ┆   ┆ f64        

  sliced = sdf.select(pl.col("cv_index").map_elements(lambda s: hacky_list_relative_slice(s,k)).alias("hacky_cv_index")).unnest("hacky_cv_index")


Converting COMPOUND_NAME_text-embedding-3-large_embeddings of type List(Float64) to a list of columns.
Converting DOSE_LEVEL_text-embedding-3-large_embeddings of type List(Float64) to a list of columns.
Converting SACRIFICE_PERIOD_text-embedding-3-large_embeddings of type List(Float64) to a list of columns.
shape: (1_674, 9_218)
┌──────────┬────────┬────────────┬────────────┬───┬────────────┬───────────┬───────────┬───────────┐
│ cv_index ┆ target ┆ COMPOUND_N ┆ COMPOUND_N ┆ … ┆ SACRIFICE_ ┆ SACRIFICE ┆ SACRIFICE ┆ SACRIFICE │
│ ---      ┆ ---    ┆ AME_text-e ┆ AME_text-e ┆   ┆ PERIOD_tex ┆ _PERIOD_t ┆ _PERIOD_t ┆ _PERIOD_t │
│ u32      ┆ i32    ┆ mbedding-3 ┆ mbedding-3 ┆   ┆ t-embeddin ┆ ext-embed ┆ ext-embed ┆ ext-embed │
│          ┆        ┆ …          ┆ …          ┆   ┆ …          ┆ din…      ┆ din…      ┆ din…      │
│          ┆        ┆ ---        ┆ ---        ┆   ┆ ---        ┆ ---       ┆ ---       ┆ ---       │
│          ┆        ┆ f64        ┆ f64        ┆   ┆ f64        

In [2]:
from cynde.functional.train.results import merge_results

# Combine the training results into a single frame keyed by cv_index. Per the cell
# output, it has prediction columns and train/val/test split labels per fold.
merged_results = merge_results(results)
merged_results

cv_index,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_0_0_0,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_0_0_0,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_0_0_0,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_0_0_0,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_0_0_1,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_0_0_1,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_0_0_1,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_0_0_1,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_1_0_0,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_1_0_0,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_1_0_0,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_1_0_0,predictions_RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_1_0_1,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_DOSE_LEVEL_text-embedding-3-large_embeddings_SACRIFICE_PERIOD_text-embedding-3-large_embeddings_0_1_0_1,predictions_RandomForestClassifie
r_COMPOUND_NAME_text-embedding-3-large_embeddings_0_1_0_1,RandomForestClassifier_COMPOUND_NAME_text-embedding-3-large_embeddings_0_1_0_1
u32,i32,str,i32,str,i32,str,i32,str,i32,str,i32,str,i32,str,i32,str
0,1,"""test""",0,"""test""",1,"""test""",1,"""test""",1,"""train""",0,"""train""",1,"""val""",0,"""val"""
1,0,"""val""",0,"""val""",0,"""train""",0,"""train""",1,"""test""",1,"""test""",1,"""test""",0,"""test"""
2,1,"""train""",1,"""train""",1,"""val""",1,"""val""",1,"""test""",1,"""test""",1,"""test""",0,"""test"""
3,0,"""train""",0,"""train""",0,"""val""",0,"""val""",1,"""test""",1,"""test""",1,"""test""",0,"""test"""
4,1,"""test""",1,"""test""",1,"""test""",1,"""test""",1,"""train""",1,"""train""",1,"""val""",0,"""val"""
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1669,0,"""train""",0,"""train""",0,"""val""",0,"""val""",1,"""test""",1,"""test""",0,"""test""",1,"""test"""
1670,0,"""train""",0,"""train""",0,"""val""",0,"""val""",1,"""test""",0,"""test""",1,"""test""",1,"""test"""
1671,1,"""train""",1,"""train""",0,"""val""",1,"""val""",1,"""test""",0,"""test""",0,"""test""",0,"""test"""
1672,1,"""train""",1,"""train""",1,"""val""",1,"""val""",0,"""test""",0,"""test""",1,"""test""",1,"""test"""


In [3]:
# Inspect df_filtered: only the task's columns, selected from df in the first cell.
df_filtered

DOSE_LEVEL,COMPOUND_NAME_text-embedding-3-large_embeddings,SACRIFICE_PERIOD_text-embedding-3-large_embeddings,DOSE_LEVEL_text-embedding-3-large_embeddings,COMPOUND_NAME,cv_index,target,SACRIFICE_PERIOD
enum,list[f64],list[f64],list[f64],enum,u32,i32,enum
"""High""","[-0.036714, -0.006798, … -0.006142]","[-0.015142, 0.0055, … 0.005091]","[-0.019848, 0.01349, … -0.001481]","""cisplatin""",0,1,"""29 day"""
"""High""","[-0.001246, -0.027476, … -0.034356]","[-0.015142, 0.0055, … 0.005091]","[-0.019848, 0.01349, … -0.001481]","""tacrine""",1,0,"""29 day"""
"""High""","[-0.00284, -0.01553, … -0.030346]","[-0.031135, -0.022565, … -0.005233]","[-0.019848, 0.01349, … -0.001481]","""nitrosodiethylamine""",2,1,"""15 day"""
"""High""","[-0.001246, -0.027476, … -0.034356]","[-0.031135, -0.022565, … -0.005233]","[-0.019848, 0.01349, … -0.001481]","""tacrine""",3,0,"""15 day"""
"""High""","[-0.026626, -0.043638, … -0.004134]","[-0.031135, -0.022565, … -0.005233]","[-0.019848, 0.01349, … -0.001481]","""monocrotaline""",4,1,"""15 day"""
…,…,…,…,…,…,…,…
"""Low""","[-0.020394, 0.028718, … -0.00965]","[-0.031135, -0.022565, … -0.005233]","[-0.031503, 0.019799, … -0.001511]","""perhexiline""",1669,0,"""15 day"""
"""High""","[-0.001497, -0.011862, … -0.026472]","[-0.015142, 0.0055, … 0.005091]","[-0.019848, 0.01349, … -0.001481]","""danazol""",1670,0,"""29 day"""
"""Middle""","[-0.029336, -0.003656, … -0.021362]","[-0.009677, 0.01172, … -0.013844]","[-0.001699, -0.001917, … -0.008718]","""ticlopidine""",1671,1,"""4 day"""
"""Middle""","[0.007756, -0.019126, … -0.007509]","[-0.028282, -0.012715, … -0.006453]","[-0.001699, -0.001917, … -0.008718]","""WY-14643""",1672,1,"""8 day"""
