In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import os
import sys

# Make the repository's own modules (e.g. data_loading, io_utils) importable here.
# The location can be overridden via the MM_2023_REPO_DIR environment variable;
# the default is the original lab path.
repo_dir = os.environ.get('MM_2023_REPO_DIR', '/home/labs/amit/noamsh/repos/MM_2023')
# Guard so re-running this cell does not keep appending duplicate sys.path entries.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from datetime import date
from pathlib import Path

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib.pyplot import rc_context
from omegaconf import OmegaConf

from data_loading.utils import load_dataframe_from_file
from io_utils import generate_path_in_output_dir

In [None]:
# The project configuration file sits at the repository root.
config_path = Path(repo_dir) / 'config.yaml'
conf = OmegaConf.load(config_path)

In [None]:
# Candidate target genes evaluated throughout this notebook.
markers = [
    'FCRL2',
    'BTLA',
    'PERP',
    'P2RX5',
    'RASGRP3',
    'FCRLA',
    'KCNN3',
    'TNFRSF13B',
    'CCR10',
]

In [None]:
# Version and timestamp that identify the input h5ad file name loaded below.
# (An earlier run of this notebook used load_ts_iso "2024-06-20".)
data_version = "20240619"
load_ts_iso = "2024-06-28"

In [None]:
# Figures produced by this run go into a directory stamped with today's date.
ts_iso = date.today().isoformat()

sc.set_figure_params(dpi=150, dpi_save=300)
figures_dir = Path(conf.outputs.output_dir, "figures", ts_iso)
# plt.savefig does not create missing directories; create it once here so the
# saving cells below work on a fresh date.
figures_dir.mkdir(parents=True, exist_ok=True)

### General evaluation of markers

In [None]:
# Load the `only_pc_annotated_filtered` AnnData matching data_version / load_ts_iso.
annotated_filtered_only_pc_path = Path(conf.outputs.output_dir) / f"adata_with_scvi_annot_pred_data_v_{data_version}_ts_{load_ts_iso}_only_pc_annotated_filtered.h5ad"
pc_adata = ad.read_h5ad(annotated_filtered_only_pc_path)
pc_adata

In [None]:
# Keep only patients contributing more than MIN_CELLS_PER_PATIENT cells.
MIN_CELLS_PER_PATIENT = 20
patients_with_more_than_20_pc = pc_adata.obs["Hospital.Code"].value_counts() > MIN_CELLS_PER_PATIENT
kept_patients = patients_with_more_than_20_pc.index[patients_with_more_than_20_pc.to_numpy()]
# `isin` yields a plain boolean mask regardless of the column dtype (`Series.map` on a
# categorical column maps its categories and may not return a plain bool Series).
# `.copy()` makes the subset a real AnnData instead of a view, so later cells that
# add/overwrite obs columns do not trigger implicit view-to-copy conversions.
pc_adata = pc_adata[pc_adata.obs["Hospital.Code"].isin(kept_patients).to_numpy()].copy()

In [None]:
# Cell-by-gene frames with the patient code attached as "pid" (aligned on obs_names):
# one for all cells, one restricted to cells annotated 'Malignant'.
adf = pc_adata.to_df().assign(pid=pc_adata.obs["Hospital.Code"])
malignant_mask = pc_adata.obs['pc_annotation'] == 'Malignant'
adf_mal = pc_adata[malignant_mask].to_df().assign(pid=pc_adata.obs["Hospital.Code"])

In [None]:
def count_patients_geneapearance(sc_gene_df, genes, patient_col, apearance_thresh=0, cell_percentage_thresh=None):
    """Count, per gene, how many patients have more than a given fraction of expressing cells.

    A cell "expresses" a gene when its value is strictly greater than ``apearance_thresh``.
    For every patient the fraction of expressing cells is computed per gene; the patient
    is counted for a threshold when that fraction is strictly greater than the threshold.

    Parameters
    ----------
    sc_gene_df : pd.DataFrame
        One row per cell; must contain the columns in ``genes`` and ``patient_col``.
    genes : list of str
        Gene columns to evaluate.
    patient_col : str
        Column identifying the patient of each cell.
    apearance_thresh : float, default 0
        Value a cell must exceed to count as expressing the gene.
    cell_percentage_thresh : float or iterable of float, optional
        Fraction-of-cells threshold(s). When None, uses
        ``[0.05, 0.07, 0.1, 0.15, 0.2, 0.3, 0.5]``.

    Returns
    -------
    pd.DataFrame
        Index: genes; columns: thresholds; values: number of patients whose fraction of
        expressing cells exceeds the threshold.
    """
    if cell_percentage_thresh is None:
        pp_thresholds = [0.05, 0.07, 0.1, 0.15, 0.2, 0.3, 0.5]
    elif np.ndim(cell_percentage_thresh) == 0:
        pp_thresholds = [cell_percentage_thresh]
    else:
        pp_thresholds = list(cell_percentage_thresh)

    # Fraction of expressing cells per (patient, gene). observed=True skips categorical
    # patient levels with no rows in this frame (e.g. patients without malignant cells),
    # which would otherwise form empty groups.
    expressing = sc_gene_df[genes].gt(apearance_thresh)
    frac_expressing = expressing.groupby(sc_gene_df[patient_col], observed=True).mean()

    patient_apearance = pd.DataFrame(
        {pp_thresh: frac_expressing.gt(pp_thresh).sum() for pp_thresh in pp_thresholds}
    )
    patient_apearance.columns.name = "percentage of cells expressing gene per patient"
    return patient_apearance
    

In [None]:
# Patient counts per marker over all cells.
count_patients_geneapearance(sc_gene_df=adf, genes=markers, patient_col='pid')

In [None]:
# Same patient counts, restricted to malignant cells.
count_patients_geneapearance(sc_gene_df=adf_mal, genes=markers, patient_col='pid')

In [None]:
# Dendrogram of patients computed on the marker genes, used to order the dot plot.
sc.tl.dendrogram(pc_adata, groupby="Hospital.Code", var_names=markers)
sc.pl.dotplot(pc_adata, var_names=markers, groupby="Hospital.Code", swap_axes=True, dendrogram=True)

In [None]:
# Marker violins per patient; a wide, short figure keeps all patient labels readable.
# Relies on the top-of-notebook `import matplotlib.pyplot as plt` (previously `plt`
# was only imported further down, so a fresh Run All failed here).
with rc_context({"figure.figsize": (16, 2)}):
    sc.pl.violin(pc_adata, markers, groupby="Hospital.Code", rotation=90, show=False)
    plt.savefig(Path(figures_dir, "target_vs_patient_violins.pdf"), bbox_inches="tight", format="pdf")

In [None]:
# Same marker dot plot, grouped and dendrogram-ordered by disease.
sc.tl.dendrogram(pc_adata, groupby="Disease", var_names=markers)
sc.pl.dotplot(pc_adata, var_names=markers, groupby="Disease", swap_axes=True, dendrogram=True)

In [None]:

# Clinical-trial labels kept per sequencing method; anything else maps to None.
allowed_MARS_trails = ["CART", "KPT", "KYDAR", "PPIA"]
allowed_SPID_trails =  ["CART", "BISE JnJ", "Transplantation"]

def get_clinical_trail(row):
    """Return the clinical-trial label for one obs row, or None.

    MARS rows take the label from 'Project', SPID rows from 'Cohort'; the label is kept
    only if it is in that method's allowed list. Rows with any other 'Method' return None
    (previously this raised UnboundLocalError and aborted the whole ``apply``).
    The misspelled name ("trail") is kept for compatibility with existing callers.
    """
    method = row['Method']
    if method == "MARS":
        candidate = row['Project']
        allowed = allowed_MARS_trails
    elif method == "SPID":
        candidate = row['Cohort']
        allowed = allowed_SPID_trails
    else:
        return None
    return candidate if candidate in allowed else None


In [None]:
# Derive the per-cell clinical-trial label row by row and store it as a categorical.
pc_adata.obs['Clinical.Trial'] = pc_adata.obs.apply(get_clinical_trail, axis=1).astype('category')

In [None]:
# Marker dot plot grouped and dendrogram-ordered by clinical trial.
sc.tl.dendrogram(pc_adata, groupby='Clinical.Trial', var_names=markers)
sc.pl.dotplot(pc_adata, var_names=markers, groupby='Clinical.Trial', figsize=(4,4), swap_axes=True, dendrogram=True)

### Clinical evaluation of markers

In [None]:
# Clinical metadata table; its file location comes from the project config.
new_hospital_path = Path(
    conf.annotation.clinical_data.clinical_data_file_path
)
new_hospital_dataset = load_dataframe_from_file(new_hospital_path)

In [None]:
# Cast the biopsy sequence to int so it can serve as a merge key below.
# NOTE(review): astype(int) raises if the column has missing values — confirm none are expected.
pc_adata.obs['Biopsy.Sequence'] = pc_adata.obs['Biopsy.Sequence'].astype(int)

In [None]:
# Normalise patient codes to lower case on both sides so the merge keys match.
new_hospital_dataset["Code"] = new_hospital_dataset["Code"].str.lower()
pc_adata.obs["Hospital.Code"] = pc_adata.obs["Hospital.Code"].str.lower()

In [None]:
# Attach clinical metadata (matched on patient code + biopsy sequence) to every cell.
# DataFrame.merge returns a new RangeIndex; assigning that to AnnData.obs would replace
# the cell barcodes (obs_names) with "0", "1", .... A left merge keeps the left rows in
# order and validate='m:1' rules out row duplication, so the original index can be
# re-attached positionally.
# Not idempotent: re-running this cell merges the clinical columns a second time.
obs_before_merge = pc_adata.obs
merged_obs = obs_before_merge.merge(new_hospital_dataset, how='left', left_on=['Hospital.Code', 'Biopsy.Sequence'], right_on=['Code', 'Biopsy sequence No.'], validate='m:1')
assert len(merged_obs) == len(obs_before_merge), "clinical merge changed the number of cells"
merged_obs.index = obs_before_merge.index
pc_adata.obs = merged_obs
pc_adata.obs.shape

In [None]:
# Inspect the AnnData summary after the clinical merge.
pc_adata

In [None]:
# Marker violins split by each "<drug> ref." column; one PDF per column.
ref_cols = ['Bortezomib ref.', 'Ixazomib ref.', 'Carfilzomib ref.', 'Lenalidomide ref.', 'Thalidomide ref.', 'Pomalidomide ref.', 'Cyclophosphamide ref.', 'Daratumumab ref.', 'Belantamab ref.', 'Talquetamab ref.', 'Cevostamab ref.', 'Selinexor ref.', 'Auto-SCT ref.', 'CART ref.']
for col in ref_cols:
    pc_adata.obs[col] = pc_adata.obs[col].astype("category")
    if pc_adata.obs[col].cat.categories.empty:
        # Column is entirely missing in this subset, so there are no groups to plot.
        print(f"Skipping {col}: no non-missing values")
        continue
    sc.pl.violin(pc_adata, markers, groupby=col, show=False)
    plt.savefig(Path(figures_dir, f"violin_of_{col}.pdf"), bbox_inches="tight", format="pdf")
    # Display this figure before moving on, instead of leaving one open figure per column.
    plt.show()
        

In [None]:
# Marker dot plot grouped by cytogenetic risk; the long column name is held once so all
# uses stay in sync.
cyto_risk_col = 'Cytogenetics Risk (1=standard risk, 2=single hit, 3=2+ hits)'
pc_adata.obs[cyto_risk_col] = pc_adata.obs[cyto_risk_col].astype("category")
sc.tl.dendrogram(pc_adata, var_names=markers, groupby=cyto_risk_col)
sc.pl.dotplot(pc_adata, markers, groupby=cyto_risk_col, dendrogram=True, swap_axes=True, figsize=(6,4), show=False)
plt.savefig(Path(figures_dir, "Cytogenetics_Risk_vs_targets.pdf"), bbox_inches="tight", format="pdf")