In [None]:
import sys
repo_dir = '/home/labs/amit/noamsh/repos/MM_2023'
sys.path.append(repo_dir)

%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from omegaconf import OmegaConf

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import pyreadr

import anndata as ad
import scanpy as sc

from data_loading.utils import load_dataframe_from_file
from clinical_predictions.clinical_data_loading import load_and_process_clinical_data
from data_loading.utils import get_updated_disease_col
from data_loading.utils import extract_samples_metadata

In [None]:
# Project configuration file at the repository root.
config_path = Path(repo_dir) / 'config.yaml'
conf = OmegaConf.load(config_path)

## patient history data

In [None]:
# True -> one row per sample (indexed by SID); False -> clinical rows are aggregated per patient (PID).
# Selects the input files and branches in several cells below.
sample_level = True

### Sample data loading

In [None]:
def _rename_cart_code(h_code):
    # 'cart13' / 'cart21' are rewritten as 'cart_p13' / 'cart_p21' (presumably to match the clinical sheet's codes — confirm).
    if h_code in ('cart13', 'cart21'):
        return f"cart_p{h_code.replace('cart','')}"
    return h_code


adata = ad.read_h5ad(Path(conf.outputs.output_dir, "adata_with_scvi_annot_pred_data_v_20240813_ts_2024-08-18_only_pc_annotated_filtered.h5ad"))
adata.obs['Hospital.Code'] = adata.obs['Hospital.Code'].apply(_rename_cart_code)

# Sample metadata table built by the project helper (splitting / id-generation flags as named).
all_samples = extract_samples_metadata(
    adata,
    ['Disease', 'Project', 'Cohort'],
    split_by_method=True,
    split_by_sample=True,
    generate_architype_id=True,
    generate_hl_architype_id=True,
    generate_sample_level_ids=sample_level,
    code_lower_case=True,
)
all_samples

#### All samples with clinical history

In [None]:
new_hospital_path = Path('/home/labs/amit/noamsh/data/mm_2023/clinical_prediction/Anonymized_CRF_BP_01082024.xlsx')
clinical_data = load_and_process_clinical_data(new_hospital_path, 
                                               code_lower_case=True, get_treatment_history=True, get_pfs_data=True,
                                               get_hospital_stage=False, get_post_treatment=True, get_combination_exposure=False, 
                                              get_fish_data=True)
clinical_data.shape

In [None]:
all_samples_with_clinical_data = all_samples.merge(clinical_data, how='left', 
                                                   left_on=['Hospital.Code', 'Biopsy.Sequence'],
                                                   right_on=['Code', 'Biopsy sequence No.'])

all_samples_with_clinical_data.drop(columns=['Code', 'Biopsy sequence No.'], inplace=True)
all_samples_with_clinical_data.shape

In [None]:
## this aggregaton is relevant only if later PID encoding (z-score) is not the same as current sample level
### prefer later Biopsy.Sequence
### on post data: prefer sensitive over Reff
### on pre data: prefer Reff over exposed
if not sample_level:
    numeric_cols = all_samples_with_clinical_data.select_dtypes(include=np.number).columns.tolist()
    non_numric_cols = list(set(all_samples_with_clinical_data.columns) - set(numeric_cols))
    
    agg_gict = {non_numeric_col: pd.Series.mode for non_numeric_col in non_numric_cols}
    agg_gict.update({numeric_col: max for numeric_col in numeric_cols})  
    
    all_samples_with_clinical_data = all_samples_with_clinical_data.groupby('PID').agg(agg_gict)
    all_samples_with_clinical_data.drop(columns=["PID"], inplace=True)
else:
    all_samples_with_clinical_data = all_samples_with_clinical_data.set_index("SID")
all_samples_with_clinical_data

#### arch and z-score

In [None]:
if sample_level:
    # nmf_path = "/home/labs/amit/annaku/repos/MM_2024_AK/Shuang_scripts/outputs/arch_sample_v5_without_drivers.csv"
    nmf_path = "/home/labs/amit/annaku/repos/MM_2024_AK/Shuang_scripts/outputs/arch_sample_v5_with_drivers.csv"
else:
    nmf_path = "/home/labs/amit/annaku/repos/MM_2024_AK/Shuang_scripts/outputs/arch_sample_v4_without_drivers.csv"

# load old shuang architypes to rename our new ones
path_sh = '/home/labs/amit/shuangyi/Project_MM3/Atlas/scvi_diff/z_v4_cl_clus.Rds'
result = pyreadr.read_r(path_sh)
df_sh = result[None]
df_sh['PID'] = df_sh['PID'].str.lower()

if not sample_level:
    arch_score_df = pd.read_csv(nmf_path)
    arch_score_df['index'] = arch_score_df['index'].str.lower()
    arch_score_df = arch_score_df.set_index('index').drop(columns=['Unnamed: 0', 'Row.names'])
    
    print("make sure that the following manual cluster map match the print")
    print(pd.concat([arch_score_df['Cluster'], df_sh.set_index("PID")['clus_new']], axis=1).value_counts().sort_index())
    arch_map = {'1': '2', '2': '8', '3': '7', '4':'6', '5':'5', '6':'1', '7':'4', '8':'3'}
    
    arch_score_df = arch_score_df.rename(columns=arch_map)
    arch_score_df = arch_score_df.drop(columns='Cluster')
    arch_score_df = arch_score_df.merge(df_sh[["PID", "clus_new"]], how='inner', left_index=True, right_on='PID')
    arch_score_df = arch_score_df.set_index('PID')
    arch_score_df = arch_score_df.rename(columns={'clus_new': 'architype'})
    arch_score_df['architype'] = arch_score_df['architype'].astype(str)
else:
    arch_score_df = arch_score_df = pd.read_csv(nmf_path)
    arch_score_df['index'] = arch_score_df['index'].str.lower()
    arch_score_df['SID'] = arch_score_df['index']
    arch_score_df = arch_score_df.set_index('SID').drop(columns=['Unnamed: 0', 'Row.names', 'index'])
    arch_score_df = arch_score_df.rename(columns={'Cluster': 'architype'})

    print("make sure that the following manual cluster map match the print")
    arch_score_df['PID'] = pd.Series([str(sid)[:-2] for sid in arch_score_df.index], index=arch_score_df.index)
    cm = confusion_matrix(arch_score_df.merge(df_sh, how='inner', on="PID")['clus_new'].astype(int),
                      arch_score_df.merge(df_sh, how='inner', on="PID")['architype'].astype(int))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=[int(i) for i in range(1,9)])
    disp.plot()
    disp.ax_.set_title('before renameing')
    
    # arch_map = {'1': '4', '2': '2', '3': '3', '4': '5', '5': '6', '6': '8',  '7': '1', '8': '7'} # for version with no MM drivers
    arch_map = {'1': '8', '2': '3', '3': '5', '4': '4', '5': '7', '6': '1', '7': '6', '8': '2'}
    
    arch_score_df = arch_score_df.rename(columns=arch_map)
    arch_score_df['architype'] = arch_score_df['architype'].apply(lambda x: arch_map[str(x)])
    
    cm = confusion_matrix(arch_score_df.merge(df_sh, how='inner', on="PID")['clus_new'].astype(int),
                      arch_score_df.merge(df_sh, how='inner', on="PID")['architype'].astype(int))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=[int(i) for i in range(1,9)])
    disp.plot()
    disp.ax_.set_title('after renameing')


In [None]:
arch_score_df

#### exploration

In [None]:
## exploration of patients with multiple clinical data points 
### some from MARS, SPID split
### some from multiple timepoints split

# Hospital codes that appear in more than one row, and all rows belonging to them.
patients_sample_count = all_samples_with_clinical_data['Hospital.Code'].value_counts()
duplicated_patients = patients_sample_count.index[patients_sample_count > 1].tolist()

is_multi_sample = all_samples_with_clinical_data["Hospital.Code"].isin(duplicated_patients)
multi_sample_patints = all_samples_with_clinical_data[is_multi_sample]
multi_sample_patints


## AnnData creation

In [None]:
if sample_level:
    zstat_path = Path('/home/labs/amit/annaku/repos/MM_2024_AK/Shuang_scripts/outputs/zstat_Atlas_v5_full.txt')
    zstat_df = load_dataframe_from_file(zstat_path).rename(columns={'Unnamed: 0':"gene_name"}).set_index("gene_name").T
    zstat_df.index = zstat_df.index.str.lower()
    zstat_df
else:
    zstat_path = Path("/home/labs/amit/shuangyi/Project_MM3/Atlas/scvi_diff/zstat_Atlas_20240519_full.txt")
    zstat_df = load_dataframe_from_file(zstat_path).rename(columns={'Unnamed: 0':"gene_name"}).set_index("gene_name").T
    zstat_df.index = zstat_df.index.str.lower()
zstat_df

In [None]:
if sample_level:
    method = pd.Series(zstat_df.index).apply(lambda x: x.split("_")[0].split(".")[1])
    mal_or_healthy = pd.Series(zstat_df.index).apply(lambda x: x.split("_")[1])
    hospital_code = pd.Series(zstat_df.index).apply(lambda x: "".join(x[:-2].split("_")[-1]))
else:
    method = pd.Series(zstat_df.index).apply(lambda x: x.split("_")[0].split(".")[1])
    mal_or_healthy = pd.Series(zstat_df.index).apply(lambda x: x.split("_")[1])
    hospital_code = pd.Series(zstat_df.index).apply(lambda x: "".join(x.split("_")[-1]))
# hospital_code.value_counts()

In [None]:
def np_relu(x):
    """Element-wise ReLU: negative entries become 0, others are kept (works on scalars, arrays and DataFrames)."""
    return np.maximum(0, x)
    
# Keep genes with z >= 3 in more than one sample whose positive part (ReLU) has std > 0.5 across samples.
n_samples_high_z = (zstat_df >= 3).sum(axis=0)
relu_std = np_relu(zstat_df).std()
gene_mask = (n_samples_high_z > 1) & (relu_std > 0.5)
filtered_genes = zstat_df.columns[gene_mask].tolist()
# KCNN3 is kept regardless of the filter.
if 'KCNN3' not in filtered_genes:
    print("adding KCNN3")
    filtered_genes.append('KCNN3')
filtered_zstat_df = zstat_df[filtered_genes]
filtered_zstat_df

In [None]:

obs = pd.DataFrame(data=zip(mal_or_healthy, zstat_df.index),  columns=["mal_or_healthy", "id"]).set_index("id")
obs = obs.merge(all_samples_with_clinical_data,
                how="inner", left_index=True, right_index=True, validate="1:1", indicator=True)

adata_patients = ad.AnnData(zstat_df.loc[obs.index],  obs, zstat_df.columns.to_frame())
adata_patients.raw = adata_patients

adata_patients = adata_patients[:, list(filtered_zstat_df.columns)]

treatment_names = ["Bortezomib", "Ixazomib", "Carfilzomib", "Lenalidomide", "Thalidomide", "Pomalidomide",
                               "Cyclophosphamide", "Chemotherapy", "Venetoclax", "Dexamethasone", "Prednisone",
                               "Daratumumab", "Elotuzumab", "Belantamab", "Talquetamab", "Teclistamab", "Cevostamab",
                               "Selinexor", "Auto-SCT", "CART"]

adata_patients.obs['Project'] = adata_patients.obs['Project'].astype(str).astype('category')

In [None]:
adata_patients.obs = adata_patients.obs.merge(arch_score_df, how="left", right_index=True, left_index=True)
adata_patients

In [None]:
adata_patients = adata_patients[adata_patients.obs["Disease"] != 'Healthy']
adata_patients

#### Visualization of genes per archetype

In [None]:
import plotly.graph_objects as go

genes_of_interset = ['FCRLA', 'TNFRSF13B', 'KCNN3', 'CCR10']
box_df = pd.concat([adata_patients[:, genes_of_interset].to_df(), adata_patients.obs['architype']], axis=1)

# Archetypes whose median z-score reaches this threshold are drawn in black, the rest in gray.
color_thresh = 0.9
# Archetype '1' is not plotted; samples without an archetype (NaN) are skipped instead of producing empty boxes.
excluded_archs = {'1'}
plotted_archs = [arch for arch in sorted(box_df['architype'].dropna().unique()) if arch not in excluded_archs]
axis_style = dict(mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')

figs = []
for gene in genes_of_interset:
    fig = go.Figure()
    for arch in plotted_archs:
        scores = box_df.loc[box_df['architype'] == arch, gene]
        color = 'black' if np.median(scores) >= color_thresh else 'gray'
        fig.add_trace(go.Box(x=scores, name=arch, marker_color=color))
    fig.add_vline(x=color_thresh, line_width=1, line_dash="dash", line_color="black")
    fig.add_vline(x=0, line_width=1, line_dash="dash", line_color="black")

    fig.update_layout(title=gene,
                      width=300, height=500,
                      showlegend=False,
                      plot_bgcolor='white')
    fig.update_layout(yaxis=dict(title_text="Architype"),
                      xaxis=dict(title_text="Z score", dtick = 1.5))
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    figs.append(fig)
    fig.show()

In [None]:
from plotly.subplots import make_subplots

n_cols = len(genes_of_interset)
n_rows = 1
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=genes_of_interset)

# Copy each per-gene figure into its own column of a single row.
for col_idx, gene_fig in enumerate(figs, start=1):
    for trace in gene_fig.data:
        fig.add_trace(trace, row=1, col=col_idx)
    # Traces alone lose the dashed reference lines of the per-gene figures; re-add them on the matching subplot.
    for shape in gene_fig.layout.shapes:
        fig.add_vline(x=shape.x0, line_width=shape.line.width, line_dash=shape.line.dash,
                      line_color=shape.line.color, row=1, col=col_idx)

fig.update_layout(height=500, width=300*n_cols, showlegend=False, plot_bgcolor='white')
fig.update_layout(yaxis=dict(title_text="Architype"))
fig.update_xaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='lightgrey',
    dtick = 1.5,
    title_text="Z score"
)
fig.update_yaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='lightgrey'
)

fig.show()

## Label engineering - treatment history

In [None]:
y = adata_patients.obs[treatment_names]
labels_count = y.count(axis=1)
print(f"mean number of labels: {labels_count.mean()}")
labels_count.hist(bins=20)
y.count()

In [None]:
covered_treatments = list(y.count()[y.count() > 11].index)
y = y[covered_treatments]
adata_patients.obs["num_exposed_treatments"] = adata_patients.obs[covered_treatments].count(axis=1)
labels_count = y.count(axis=1)
print(f"mean number of labels: {labels_count.mean()}")
labels_count.hist(bins=20)
y.count()

In [None]:
# Drug -> molecular-target grouping used to collapse the per-drug history columns.
# Drug names must match the treatment column names in `treatment_names` exactly ('CART', 'Chemotherapy', ...):
# a name that matches no column is silently ignored when the targets are built.
target_2_treatments = {
   'Proteasome': ['Bortezomib', 'Carfilzomib', 'Ixazomib'],
   'IMIDs': ['Lenalidomide', 'Pomalidomide', 'Thalidomide'],
   'CD38_target': ['Daratumumab'],
   'BCMA': ['Belantamab', 'Teclistamab', 'CART'],
   # 'GPRC5D': ['Talquetamab'],
   # 'Nuclear_Export_Inhibitor': ['Selinexor'],
   # 'Steroid': ['Dexamethasone'], #would not even look at this….. most regimens have steroids, synergistic to other drugs
   'Chemo': ['Chemotherapy', 'Auto-SCT', 'Cyclophosphamide']
}

# Inverse mapping: drug -> target.
treatment_2_target = {}
for target, treatments in target_2_treatments.items():
    for treatment in treatments:
        treatment_2_target[treatment] = target


In [None]:
def transform_drugs_history_to_targets_history(y: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-drug history columns into per-target columns.

    For every target in the module-level ``target_2_treatments``, the target column is the row-wise
    maximum (NaN ignored) over those of its drugs that are columns of ``y``. Rows with no value for
    any of those drugs are NaN, and a target none of whose drugs is present is all-NaN.
    Drug columns not mapped to any target (absent from ``treatment_2_target``) are passed through
    unchanged under their own name, after the target columns.

    Parameters
    ----------
    y : pd.DataFrame
        One column per drug; presumably numeric history codes with NaN meaning "no record" — confirm.

    Returns
    -------
    pd.DataFrame
        Same index as ``y``; one column per target, then the unmapped drug columns.
    """
    target_history_dict = {}
    for target, treatments in target_2_treatments.items():
        present = [treatment for treatment in treatments if treatment in y.columns]
        if present:
            # DataFrame.max skips NaN and yields NaN for all-NaN rows (no RuntimeWarning, unlike np.nanmax).
            target_history_dict[target] = y[present].max(axis=1)
        else:
            target_history_dict[target] = pd.Series(np.nan, index=y.index)
    for treatment in y.columns:
        if treatment not in treatment_2_target:
            # Pass the unmapped drug's own column through (the loop variable, not an outer/global name).
            target_history_dict[treatment] = y[treatment]
    return pd.DataFrame(target_history_dict)


In [None]:
y_targets = transform_drugs_history_to_targets_history(y).dropna(axis=1, how='all')


In [None]:
labels_count = y_targets.count(axis=1)
print(f"mean number of labels: {labels_count.mean()}")
labels_count.hist(bins=20)
y_targets.count()

In [None]:
y_exposed_and_ref = ~y.isna()
y_ref = y == 2
y_any = ~ (y_exposed_and_ref == 0).all(axis=1)
y_any.value_counts()

In [None]:
import seaborn as sns
sns.clustermap(pd.DataFrame(np.corrcoef(y_exposed_and_ref.T),
                            index=covered_treatments,
                            columns=covered_treatments
                           ), annot=True)

In [None]:
y_target_exposed_and_ref = ~y_targets.isna()
y_target_ref = y_targets == 2
y_target_any = ~ (y_target_exposed_and_ref == 0).all(axis=1)
y_target_any.value_counts()

In [None]:
covered_targets = list(y_target_exposed_and_ref.columns)
sns.clustermap(pd.DataFrame(np.corrcoef(y_target_exposed_and_ref.T),
                            index=covered_targets,
                            columns=covered_targets
                           ), annot=True)

## Baseline - archetypes

In [None]:
nmf_emb = 'NMF_emb'
nmf_neighborhood = 'NMF_neighborhod'
adata_patients.obsm[nmf_emb] = adata_patients.obs[[str(i) for i in range(1,9)]]

In [None]:
# adata_patients[adata_patients.obsm[nmf_emb].isna().sum(axis=1) != 0].obs
adata_patients = adata_patients[~adata_patients.obs['architype'].isna()]

In [None]:
adata_patients

In [None]:
sc.pp.neighbors(adata_patients, use_rep=nmf_emb, n_neighbors=7, key_added=nmf_neighborhood, metric='l2')
sc.tl.umap(adata_patients, min_dist=0.5, neighbors_key=nmf_neighborhood)
sc.tl.leiden(adata_patients, resolution=0.5, neighbors_key=nmf_neighborhood)

In [None]:
sc.pl.umap(adata_patients, color=["Method", "leiden", 'Disease', 'architype', 'num_exposed_treatments', 'Biopsy.Sequence', 'mal_or_healthy'], ncols=3)

In [None]:
sc.pl.umap(adata_patients, color=[str(i) for i in range(1,9)], ncols=3)

In [None]:
# arch_programs_values = adata_patients.obs[[str(i) for i in range(1,9)]]
# arch_programs_values

In [None]:
program_thresh = 0.15
for i in range(1,9):
    adata_patients.obs[f"is_{i}"] = adata_patients.obs[str(i)] >= program_thresh

arch_programs_values = adata_patients.obs[[str(i) for i in range(1,9)]]

adata_patients.obs[f"is_low"] = (arch_programs_values <= program_thresh).all(axis=1)
has_multiprograms = (arch_programs_values >= program_thresh).sum(axis=1) > 1
adata_patients.obs[f"is_multiprograms"] = (has_multiprograms)

second_largest = arch_programs_values.apply(lambda row: row.nlargest(2).values[-1],axis=1)
adata_patients.obs[f"is_anbigius"] = (arch_programs_values.max(axis=1) - second_largest) < (0.15 *arch_programs_values.max(axis=1))

sc.pl.umap(adata_patients, color=[f"is_{i}" for i in range(1,9)], ncols=3)
sc.pl.umap(adata_patients, color="is_low")
sc.pl.umap(adata_patients, color="is_anbigius")

## Baseline - PCA

In [None]:
n_emb = 20
pca = PCA(n_components=n_emb)

zstat_pca = pca.fit_transform(adata_patients.X)

In [None]:
pca_emb = 'pca_emb'
pca_neighborhood = 'pca_neirborhod'
adata_patients.obsm[pca_emb] = zstat_pca

sc.pp.neighbors(adata_patients, use_rep=pca_emb, n_neighbors=7, key_added=pca_neighborhood)
sc.tl.umap(adata_patients, min_dist=0.4, neighbors_key=pca_neighborhood)
sc.tl.draw_graph(adata_patients, layout='fr', neighbors_key=pca_neighborhood)
sc.tl.leiden(adata_patients, resolution=0.5, neighbors_key=pca_neighborhood)


In [None]:
samples_clustered_by_pca_not = adata_patients[adata_patients.obs['leiden']=='4'].obs_names
# list(patietns_clustered_mix_arc)

samples_low_on_all_arch = adata_patients[adata_patients.obs['is_low']].obs_names
# list(samples_low_on_all_arch)

samples_low_on_arch_in_pca_arc = list(set(samples_clustered_by_pca_not).intersection(set(samples_low_on_all_arch)))
# samples_low_on_arch_in_pca_arc

samples_low_on_arch_not_in_pca_arc = list(set(samples_low_on_all_arch).difference(set(samples_clustered_by_pca_not)))
# samples_low_on_arch_not_in_pca_arc

In [None]:
sc.pl.umap(adata_patients, color=["Method", "leiden", 'Disease', 'architype', 'num_exposed_treatments', 'Biopsy.Sequence', 'mal_or_healthy'], ncols=3)
# sc.pl.draw_graph(adata_patients, color=["Method_x", "leiden", 'Disease', 'architype', 'num_exposed_treatments', 'Biopsy.Sequence', 'mal_or_healthy'], ncols=3)

In [None]:
sc.pl.umap(adata_patients, color=[str(i) for i in range(1,9)], ncols=4)
# sc.pl.draw_graph(adata_patients, color=[str(i) for i in range(1,9)], ncols=3)

In [None]:
# sc.pl.umap(adata_patients, color=[f"is_{i}" for i in range(1,9)], ncols=3)
sc.pl.umap(adata_patients, color=["is_low", "is_anbigius"])
adata_patients.obs[['mal_or_healthy', "is_low", "is_anbigius"]].value_counts()
# sc.pl.draw_graph(adata_patients, color=[f"is_{i}" for i in range(1,9)], ncols=3)
# sc.pl.draw_graph(adata_patients, color="is_low")

In [None]:
# adata_patients.obs[(adata_patients.obs['is_low'] & adata_patients.obs['is_anbigius'])]['architype'].value_counts()
adata_patients.obs[adata_patients.obs['is_anbigius']]['architype'].value_counts()
# adata_patients.obs[adata_patients.obs['is_low']]['architype'].value_counts()


In [None]:
for i in range(20):
    adata_patients.obs[f"PC_{i}"] = zstat_pca[:,i]
sc.pl.umap(adata_patients, color=[f"PC_{i}" for i in range(20)], ncols=4)

## prediction of clinical history

In [None]:
patient_emb = nmf_emb
# Feature matrix for the history classifier: the chosen embedding, indexed by sample id.
X = pd.DataFrame(adata_patients.obsm[patient_emb], index=adata_patients.obs_names)

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import train_test_split
import sklearn.metrics as metrics

In [None]:
covered_targets_no_dex = covered_targets.copy()
covered_targets_no_dex.remove('Dexamethasone')
covered_targets_no_dex

In [None]:
y_target_exposed_and_ref.loc[X.index]

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y_target_exposed_and_ref.loc[X.index][covered_targets_no_dex])
# X_train, X_test, y_train, y_test = train_test_split(X, y_any)

In [None]:
# clf = LogisticRegression()
clf = KNeighborsClassifier(n_neighbors=3)
# clf = RandomForestClassifier(max_depth=2,n_estimators=400, random_state=0)
# clf = RandomForestClassifier(max_depth=3, random_state=0, class_weight='balanced')

clf.fit(X_train, y_train)

In [None]:
if len(y_train.shape) >1 :
    y_pred_train = pd.DataFrame(clf.predict(X_train), index=X_train.index, columns=y_train.columns)
    y_proba_train = pd.DataFrame([probas[:,1] for probas in clf.predict_proba(X_train)]).T
    y_proba_train.index = X_train.index
    y_proba_train.columns = y_train.columns
    
    y_pred_test = pd.DataFrame(clf.predict(X_test) , index=X_test.index, columns=y_test.columns)
    y_proba_test = pd.DataFrame([probas[:,1] for probas in clf.predict_proba(X_test)]).T
    y_proba_test.index = X_test.index
    y_proba_test.columns = y_test.columns

else:
    y_pred_train = clf.predict(X_train)
    y_proba_train = clf.predict_proba(X_train)
    
    y_pred_test = clf.predict(X_test)
    y_proba_test = clf.predict_proba(X_test)


In [None]:
print(metrics.classification_report(y_true=y_train, y_pred=y_pred_train))
if len(y_train.shape) >1 :
    print(y_train.columns)
else:
    print(f"{sum(y_pred_train)} predicted positive")

In [None]:
if len(y_train.shape) >1 :
    for treat_name in y_train.columns:
        y_train_treat = y_train[treat_name]
        y_pred_train_treat = y_pred_train[treat_name]
        print(f"{treat_name}, {sum(y_pred_train_treat)} predicted positive")
        print(metrics.classification_report(y_true=y_train_treat, y_pred=y_pred_train_treat))

In [None]:
# multilabel metrics
if len(y_train.shape) >1 :
    print(metrics.classification_report(y_test, y_pred_test))
    
    print(metrics.label_ranking_loss(y_train, y_proba_train))
    print(metrics.label_ranking_loss(y_test, y_proba_test))
    
    print(metrics.label_ranking_average_precision_score(y_train, y_proba_train))
    print(metrics.label_ranking_average_precision_score(y_test, y_proba_test))
else:
    print(f"{sum(y_pred_test)} predicted positive")
    print(metrics.classification_report(y_test, y_pred_test))

In [None]:
sc.pl.umap(adata_patients, color=covered_treatments)

In [None]:
y_pred_train.columns = [f"{col}_pred_train" for col in y_train.columns]
y_pred_test.columns = [f"{col}_pred_test" for col in y_train.columns]
# pd.concat([y_pred_train, y_pred_test])

In [None]:
y_target_labels = y_target_exposed_and_ref[covered_targets_no_dex]
y_target_labels.columns = [f"{col}_label" for col in y_train.columns]

In [None]:
combined_obs = pd.concat([adata_patients.obs, y_targets.loc[adata_patients.obs.index]], axis=1)
adata_patients.obs = combined_obs.loc[:,~combined_obs.columns.duplicated()].copy()
sc.pl.umap(adata_patients, color=covered_targets)


In [None]:
combined_obs = pd.concat([adata_patients.obs, y_pred_train, y_pred_test, y_target_labels], axis=1)
adata_patients.obs = combined_obs.loc[:,~combined_obs.columns.duplicated()].copy()
sc.pl.umap(adata_patients, color=[f"{col}_pred_train" for col in y_train.columns] + [f"{col}_pred_test" for col in y_train.columns] + [f"{col}_label" for col in y_train.columns])

## AE

In [None]:
# NOTE(review): this section currently only displays the covered targets; no further code follows under it.
covered_targets

## Prediction of treatment outcome

In [None]:
# Malignant samples with Biopsy.Sequence == 1.
first_malignant_mask = (adata_patients.obs['Biopsy.Sequence'] == 1) & (adata_patients.obs['mal_or_healthy'] == 'malignant')
adata_post_treatment_pred = adata_patients[first_malignant_mask]
adata_post_treatment_pred

#### Feature selection

In [None]:
use_fish_features_for_prediction = False
use_genes = True 
use_nmf = False

feats = []

In [None]:
fish_feats = [col for col in adata_patients.obs.columns if "t(" in col or "del(" in col or col in ['1q21+', 'IGH rearrangement', 'Cytogenetics Risk (0=standard risk, 1=single hit, 2=2+ hits)']]
if use_fish_features_for_prediction:
    adata_patients.obs[fish_feats] = adata_patients.obs[fish_cols].fillna(-1) ## need to check this if want to use fish features
    feats += fish_feats

genes_feats = list(adata_post_treatment_pred.var_names)
if use_genes:
    feats += genes_feats

nmf_feats = [str(i) for i in range(1,9)]
if use_nmf:
    feats += nmf_feats

In [None]:
X_all_feats = pd.concat([adata_post_treatment_pred.to_df(), adata_post_treatment_pred.obs[nmf_feats + fish_feats]] , axis=1)
X_all_feats.shape

In [None]:
X_all = X_all_feats[feats]
X_all.shape

#### Label engineering

In [None]:
from clinical_predictions.clinical_data_loading import add_CART_response, add_Kydar_response, add_general_response

# Derive response labels for the samples in adata_post_treatment_pred.
# NOTE(review): this aliases adata_post_treatment_pred.obs (no copy); if the add_* helpers modify their
# input in place, the AnnData obs is modified as well — confirm.
all_metadata_df = adata_post_treatment_pred.obs
all_metadata_df = add_general_response(all_metadata_df, pfs_thresh_months=9)
all_metadata_df = add_Kydar_response(all_metadata_df, number_of_months=9)

# CAR-T response labels use a separate clinical spreadsheet and the "9M PFS" policy.
CAR_T_full_clinical_data_path = Path('/home/labs/amit/noamsh/data/mm_2023/clinical_prediction/CART MM responder vs non responder 160424.xlsx')
CAR_T_full_clinical_data = load_dataframe_from_file(CAR_T_full_clinical_data_path)
all_metadata_df = add_CART_response(all_metadata_df, full_clinical_df=CAR_T_full_clinical_data, pfs_policy="9M PFS")

# Re-attach the sample ids as index; assumes the add_* helpers keep the row count and order — TODO confirm.
all_metadata_df = all_metadata_df.set_index(adata_post_treatment_pred.obs_names)

In [None]:
response_cols = all_metadata_df.columns[all_metadata_df.columns.str.contains('_response')]

v_counts = []
for col in response_cols:
    v_counts.append(all_metadata_df[col].value_counts().rename(col))
pd.concat(v_counts, axis=1)

In [None]:
treatmetns_with_response = ['general', 'general_pfs', 'Kydar', 'CART']
datasets = {}
for treatment in treatmetns_with_response:
    y_treatment = all_metadata_df[f'{treatment}_response'].dropna()
    y_treatment = (y_treatment == 'R').astype(int)
    datasets[treatment] = (X_all.loc[y_treatment.index], y_treatment)


In [None]:
from sklearn.model_selection import train_test_split
from clinical_predictions.optuna_optimization import get_best_model_with_optuna
from clinical_predictions.evaluation import train_and_eval_model

In [None]:
monitors = {}
for treatment, (X, y) in datasets.items():
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    best_model, best_trail = get_best_model_with_optuna(X_train, y_train, precision_alpha = 0, n_trials=100)
    monitors[treatment] = train_and_eval_model(X_train, X_test, y_train, y_test, best_model, 
                                               extra_for_report=pd.DataFrame(best_trail.user_attrs["scores"]))

In [None]:
from clinical_predictions.evaluation import generate_datasets_summerization
generate_datasets_summerization(monitors)

In [None]:
monitors['general_pfs'].report['extra']