In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np 
import pandas as pd 
import scanpy as sc

import matplotlib.pyplot as plt 
import seaborn as sns

In [None]:
import sys, os
# Extend the module search path so the imports below can be resolved from ../src.
# The path is relative, so it depends on the directory the kernel was started in.
sys.path.append('../src')

# from interaction import Interaction
from util import compute_auc

In [None]:
# # JING CLONAL EXPANSION
# x_path = '/ix/djishnu/Jane/SLIDESWING/jing_data/KIR+TEDDY/data/KIR+TEDDY_rna_filtered85.csv'
# y_path = '/ix/djishnu/Jane/SLIDESWING/jing_data/KIR+TEDDY/data/KIR+TEDDY_Yexpanded_filtered85.csv'
# slide_outs = '/ix/djishnu/Jane/SLIDESWING/jing_data/KIR+TEDDY/KIR+TEDDY_filtered85/KIR+TEDDY_filtered85_noint_output/0.01_0.5_out'
# y = pd.read_csv(y_path)['Y'].values

# JING TUMOR TIL VS TEMRA
# Active dataset. The names defined here are consumed by later cells:
#   x_path     - CSV from which the latent-factor gene columns are loaded (pd.read_csv with usecols)
#   y_path     - CSV whose 'y' column holds the response
#   slide_outs - output directory passed to get_genes_from_slide_outs; z_matrix.csv is read from it
#   y          - response values as a NumPy array
x_path = '/ix/djishnu/alw399/SLIDE_PLM/data/jing_tumor/tumor_x2.csv'
y_path = '/ix/djishnu/alw399/SLIDE_PLM/data/jing_tumor/tumor_y2.csv'
slide_outs = '/ix/djishnu/alw399/SLIDE_PLM/data/jing_tumor/0.05_0.5_out'
y = pd.read_csv(y_path)['y'].values

# # ALOK ANTIGEN SPECIFICITY 
# x_path = '/ix/djishnu/Jane/SLIDESWING/alok_data/data/Ins1_InsChg2_rna_MRfilt_forSLIDE.csv'
# y_path = '/ix/djishnu/Jane/SLIDESWING/alok_data/data/Ins1_InsChg2_rna_MRfilt_antigens.csv'
# slide_outs = '/ix/djishnu/Jane/SLIDESWING/alok_data/alok_data12_MRfilt_noint_out/0.01_2_out'
# y = pd.read_csv(y_path)['Antigen'].values - 1 


In [None]:
from util import get_genes_from_slide_outs

# Gather the latent factors (LFs) from the SLIDE output directory.
# The rest of the notebook uses `lf_dict` as a mapping: its keys select columns of
# z_matrix.csv and its values are concatenated into one gene list.
# NOTE(review): presumably {LF name: array of gene names} — confirm in util.
lf_dict = get_genes_from_slide_outs(slide_outs)
lf_dict.keys()

In [None]:
# Pool the genes of every latent factor into one sorted, duplicate-free array.
all_genes = np.unique(np.concatenate(list(lf_dict.values())))
all_genes.size

In [None]:
from genept import GenePTEmbedder

# Fetch gene-level information (used as embeddings below) for every latent-factor gene.
# NOTE(review): later cells contract the gene axis of an (n_samples, n_genes) matrix with
# the first axis of `gene_embeddings`, so it is assumed to be (n_genes, embedding_dim)
# with rows in `all_genes` order — confirm against GenePTEmbedder.get_gene_info.
genept = GenePTEmbedder(species='human')
gene_embeddings = genept.get_gene_info(all_genes)
gene_embeddings.shape

In [None]:
# Create 0, 1 presence/absence matrix

# `pd.read_csv(..., usecols=...)` ignores the order of `usecols` and returns the columns
# in file order. The gene axis of this frame is later multiplied against
# `gene_embeddings`, which was requested in `all_genes` order, so the columns are
# explicitly re-ordered to `all_genes` to keep gene j of the expression matrix paired
# with embedding row j.
# NOTE(review): assumes get_gene_info returns one row per gene in the order given — confirm.
gene_order = list(all_genes)
gex_df = pd.read_csv(x_path, usecols=gene_order)[gene_order]

# Per-gene threshold: the mean expression of that gene across all samples.
gex_threshes = gex_df.mean(axis=0)

# 1 where a sample expresses the gene above its mean, 0 otherwise
# (same index/columns as gex_df).
mask_df = (gex_df > gex_threshes).astype(int)

mask_df.shape

In [None]:
# Broadcast the presence mask against the embeddings: entry [i, j, :] is gene j's
# embedding when mask[i, j] == 1 and all zeros otherwise.
genept_df = mask_df.values[:, :, None] * np.asarray(gene_embeddings)[None, :, :]

# Flatten the (gene, embedding) axes into a single feature axis per sample.
genept_df = genept_df.reshape(len(gex_df), -1)
genept_df.shape

In [None]:
# Expression-weighted embedding: each sample is the sum of its genes' embedding rows,
# every row scaled by that gene's expression value.
wgenept_df = gex_df.dot(gene_embeddings)
wgenept_df.shape

In [None]:
# Read z_matrix.csv from the SLIDE output directory (first column is the row index)
# and keep only the columns named by the keys of lf_dict, as a plain NumPy array.
z_matrix = (
    pd.read_csv(os.path.join(slide_outs, 'z_matrix.csv'), index_col=0)
    .loc[:, list(lf_dict)]
    .values
)
z_matrix.shape

In [None]:
from models import Estimator
from sklearn.linear_model import Lasso, LinearRegression
from sklearn.neural_network import MLPClassifier


# Learner used for every feature set compared in the evaluation cell.
# Swap in one of the commented-out alternatives below to benchmark a different model.
a=0.1  # L1 regularisation strength, passed to Lasso as `alpha`
model = Lasso(alpha=a, max_iter=1000)

# model = LinearRegression()

# model = MLPClassifier(max_iter=1000)

In [None]:
# Evaluate the same model on five alternative feature representations of the samples.
# Each representation gets its own Estimator, and each AUC result is produced by the
# estimator created for it (the original cell evaluated `gex` with lasso0 and
# `wgenept` with lasso3, leaving lasso1 and lasso4 unused).
# NOTE(review): all estimators wrap the same `model` object; presumably Estimator
# re-fits or clones it per evaluation — confirm in models.Estimator.

# SLIDE z-matrix performance
lasso0 = Estimator(model=model)
auc0 = lasso0.evaluate(z_matrix, y)

# Lasso regression on LF gene expression matrix
lasso1 = Estimator(model=model)
auc1 = lasso1.evaluate(gex_df.values, y)

# Lasso regression on mean-thresholded gene expression
lasso2 = Estimator(model=model)
auc2 = lasso2.evaluate(mask_df.values, y)

# Lasso regression on semantic embeddings
lasso3 = Estimator(model=model)
auc3 = lasso3.evaluate(genept_df, y)

# Lasso regression on weighted semantic embeddings
lasso4 = Estimator(model=model)
auc4 = lasso4.evaluate(wgenept_df.values, y)

In [None]:
# Stack the per-feature-set AUC vectors (one row each), then reshape to long format:
# one row per (feature set, iteration) with columns 'index', 'iter' and 'auc'.
df = (
    pd.DataFrame(
        np.vstack([auc0, auc1, auc2, auc3, auc4]),
        index=['z-matrix', 'gex', 'mask_gex', 'genept', 'wgenept'],
    )
    .reset_index()
    .melt(id_vars='index', var_name='iter', value_name='auc')
)

In [None]:
def filter_pairs(pairs, df):
    """Keep only the pairs whose two groups differ in at least one AUC value.

    Parameters
    ----------
    pairs : iterable of 2-item sequences
        Group labels to compare, matched against the 'index' column of ``df``.
    df : pandas.DataFrame
        Long-format table with an 'index' (group label) and an 'auc' column.

    Returns
    -------
    list of tuple
        The ``(first, second)`` pairs, in input order, whose 'auc' arrays are not
        element-wise identical.
    """
    def _auc_of(label):
        # AUC values of one group, in row order.
        return df.loc[df['index'] == label, 'auc'].values

    return [
        (first, second)
        for first, second in pairs
        if not np.all(_auc_of(first) == _auc_of(second))
    ]


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from statannotations.Annotator import Annotator
import itertools

# Single source of truth for the category order on the x axis. The box plot, the strip
# plot, the statistical annotations and the mean labels all use it, so points, boxes
# and labels cannot end up under different categories.
order = list(np.unique(df['index']))

fig, ax = plt.subplots(figsize=(10,10), dpi=150)

sns.boxplot(data=df, x='index', y='auc', hue='index', palette='hls', ax=ax, showfliers=False, order=order)
sns.stripplot(data=df, x='index', y='auc', hue='index', ax=ax, palette='hls', legend=False, linewidth=1, edgecolor='black', jitter=True, order=order)

# Compare every pair of feature sets, skipping pairs with identical AUC vectors.
pairs = list(itertools.combinations(order, 2))
pairs = filter_pairs(pairs, df)

# Nothing to annotate when every pair was filtered out.
if pairs:
    annotator = Annotator(ax, pairs, data=df, x='index', y='auc', order=order)
    annotator.configure(test='Kruskal', text_format='star', loc='inside', verbose=2, hide_non_significant=True)
    annotator.apply_and_annotate()

# Write each group's mean just above the highest observed AUC. Categories sit at
# integer positions 0..n-1 in `order`, so the labels are placed by position rather
# than by passing the category string as an x coordinate.
means = df.groupby('index')['auc'].mean().reindex(order)
label_y = df['auc'].max() + 0.001
for pos, mean in enumerate(means):
    ax.text(pos, label_y, f'Mean: {mean:.2f}', ha='center', va='bottom', fontsize=8, color='black')

ax.set_title(f'{model.__class__.__name__} Performance')
fig.tight_layout()

In [None]:
# with open(os.path.join(slide_outs, 'standard_out.txt'), 'r') as f:
#     standard_out = f.readlines()
# slide_auc = standard_out[-1].split(' ')[-2]

# # slide_auc = '0.951218206396577'
# # slide_auc = '0.747932'

In [None]:
# pd.DataFrame({
#     'slide': slide_auc[:8],
#     'z_matrix': auc0,
#     'gex': auc1,
#     'mask_gex': auc2,
#     'wgenept': auc3
# }, index=['auc']).T
