In [None]:
# Automatically reload imported modules (e.g. edits to the project code under ../src)
# before each cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np 
import pandas as pd 
import os
import torch
import scanpy as sc 

import sys
# Make project-local modules (util, interaction, plotting, models) importable.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if '../src' not in sys.path:
    sys.path.append('../src')

# Directory that caches raw PLM embeddings, one <name>.npy file per dataset
embedding_dir = '../data/pjm_models/embeddings/15epochs'

In [None]:
# # CLONAL EXPANSION
# name = 'jing_clonal_expansion'
# n_pcs = 16
# model_name = f'pjm_15epochs_pca{n_pcs}'
# y_path = '/ix/djishnu/Jane/SLIDESWING/jing_data/KIR+TEDDY/data/KIR+TEDDY_Yexpanded_filtered85.csv'
# slide_outs = '/ix/djishnu/Jane/SLIDESWING/jing_data/KIR+TEDDY/KIR+TEDDY_filtered85/KIR+TEDDY_filtered85_noint_output/0.01_0.5_out'
# y = pd.read_csv(y_path)['Y'].values
# sequences = pd.read_csv('/ix/djishnu/Jane/SLIDESWING/jing_data/KIR+TEDDY/data/KIR+TEDDY_betaseqs_raw.csv', index_col=0)

# # JING TUMOR
# name = 'jing_tumor'
# n_pcs = 16
# model_name = f'pjm_15epochs_pca{n_pcs}'
# y_path = '/ix/djishnu/alw399/SLIDE_PLM/data/jing_tumor/tumor_y2.csv'
# slide_outs = '/ix/djishnu/alw399/SLIDE_PLM/data/jing_tumor/0.05_0.5_out'
# y = pd.read_csv(y_path)['y'].values
# sequences = pd.read_csv('/ix/djishnu/alw399/SLIDE_PLM/data/jing_tumor/filtered_x2_cdr3_b.csv')

# # ANTIGEN SPECIFICITY
# name = 'alok_antigen_specificity'
# n_pcs = 16
# model_name = f'pjm_15epochs_pca{n_pcs}'
# slide_outs = '/ix/djishnu/Jane/SLIDESWING/alok_data/alok_data12_MRfilt_noint_out/0.01_2_out'
# x_path = '/ix/djishnu/Jane/SLIDESWING/alok_data/data/Ins1_InsChg2_rna_MRfilt_forSLIDE.csv'
# y_path = '/ix/djishnu/Jane/SLIDESWING/alok_data/data/Ins1_InsChg2_rna_MRfilt_antigens.csv' 
# y = pd.read_csv(y_path)['Antigen'].values - 1
# sequences = pd.read_csv('/ix/djishnu/Jane/SLIDESWING/alok_data/data/Ins1_InsChg2_seqs.csv', index_col=0)['beta']


# CONGA c2_gex_donor2
name = 'conga_c2_gex_donor2'
n_pcs = 16
model_name = 'pjm_15epochs_pca' + str(n_pcs)

# All CONGA inputs/outputs share one project root
conga_root = '/ix/djishnu/alw399/SLIDE_PLM/data/conga'
y_path = f'{conga_root}/slide/inputs/c2gex_donor2_y.csv'
slide_outs = f'{conga_root}/slide/outputs/c2gex_donor2/0.1_1_out'

# Response vector: the 'is_c2' column of the SLIDE y file
y = pd.read_csv(y_path)['is_c2'].values
# Sequences: the 'cdr3b' column of the AnnData obs table
sequences = sc.read_h5ad(f'{conga_root}/paper_data/10x_200k/donor2/donor2_conga.h5ad').obs['cdr3b']


In [None]:
from util import get_sigLFs, remove_empty_tcrs

# Significant latent factors reported by SLIDE; used below to select z-matrix columns
z1s = get_sigLFs(slide_outs)

z_matrix_path = os.path.join(slide_outs, 'z_matrix.csv')
z_matrix = pd.read_csv(z_matrix_path, index_col=0)[z1s]
z_matrix.shape

In [None]:
# Interactors did not contribute to signal
# z_matrix = pd.DataFrame(z_matrix['Z7'])

In [None]:
from util import remove_empty_tcrs

# Drop entries without a TCR sequence; presumably keeps sequences / y / z_matrix
# row-aligned (TODO confirm in util.remove_empty_tcrs).
filtered = remove_empty_tcrs(sequences, y, z_matrix)
sequences, y, z_matrix = filtered
tuple(obj.shape for obj in (sequences, y, z_matrix))

### Get embeddings from pjm model

In [8]:
import sys

# Make the vendored pjm package importable. Guarded so that re-running this cell
# does not keep appending the same entry to sys.path.
if '../src/pjm' not in sys.path:
    sys.path.append('../src/pjm')

from pjm import from_pretrained, build_default_alphabet

alphabet = build_default_alphabet()

# Load encoder from the 15-epoch checkpoint.
# NOTE(review): confirm from_pretrained returns the model in eval mode; if it does not,
# call embedder.eval() so dropout-style layers do not perturb the embeddings.
embedder = from_pretrained(
    model_type="mmplm",
    alphabet=alphabet,
    checkpoint_path='../data/pjm_models/mmplm_15epochs_dim256_ckpt.pth',
)

In [None]:
tokenizer = alphabet.get_batch_converter()

# The batch converter consumes (label, sequence) pairs; each sequence string
# doubles as its own label.
seq_strs = sequences.astype(str)
raw_batch = [(seq, seq) for seq in seq_strs]

labels, strs, tokens = tokenizer(raw_batch)

In [None]:
# Inference only: run the encoder without autograd so no computation graph is kept
# for the whole dataset (large memory saving; the embedding values are unchanged).
with torch.no_grad():
    embeddings = embedder(tokens)
embeddings.shape

In [None]:
# Cache the raw embeddings so the analysis below can restart from disk.
# Create the cache directory on first use, and move the tensor to CPU before converting
# (.numpy() raises on CUDA tensors; .cpu() is a no-op for CPU tensors).
os.makedirs(embedding_dir, exist_ok=True)
np.save(os.path.join(embedding_dir, f'{name}.npy'), embeddings.detach().cpu().numpy())

### Run interactions

In [None]:
# Reload the cached embeddings and flatten all axes after the first into one feature axis
cached = np.load(os.path.join(embedding_dir, f'{name}.npy'))
embeddings = cached.reshape(cached.shape[0], -1)
embeddings.shape

In [None]:
from sklearn.decomposition import PCA

# Reduce the flattened PLM embeddings to n_pcs components.
# random_state pins the randomized SVD solver that PCA(svd_solver='auto') may select
# for large inputs, so the PCs (and everything computed from them) are reproducible.
# `model` is read later by the variance-explained plot (model.explained_variance_ratio_).
model = PCA(n_components=n_pcs, random_state=0)
model.fit(embeddings)
tcr_embeddings = model.transform(embeddings)
tcr_embeddings.shape

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ax_pcs, ax_var = ax

# Left panel: first two PCs of the TCR embeddings, coloured by y
scatter = ax_pcs.scatter(tcr_embeddings[:, 0], tcr_embeddings[:, 1], c=y, cmap='tab20')
legend1 = ax_pcs.legend(*scatter.legend_elements(), title="Classes")
ax_pcs.set_title('TCR Embeddings')
ax_pcs.set_xlabel('PC1')
ax_pcs.set_ylabel('PC2')

# Right panel: cumulative explained variance of the fitted PCA
cumsum_variance = np.cumsum(model.explained_variance_ratio_)
ax_var.plot(range(1, n_pcs + 1), cumsum_variance, marker='o')
ax_var.set_title('Cumulative Variance Explained')
ax_var.set_xlabel('Number of PCs')
ax_var.set_ylabel('Cumulative Variance Explained')

plt.tight_layout()
out_dir = f'../results/plm/{name}'
os.makedirs(out_dir, exist_ok=True)
plt.savefig(f'{out_dir}/{model_name}_pca.png')
plt.show()


In [None]:
import sys

# '../src' should already be on sys.path from the setup cell; only add it if missing
# so re-running this cell does not accumulate duplicate entries.
if '../src' not in sys.path:
    sys.path.append('../src')

from interaction import Interaction

# Combine the SLIDE z-matrix (restricted to the significant latent factors selected
# earlier) with the PCA-reduced TCR embeddings.
machop = Interaction(
    slide_outs,
    plm_embed=tcr_embeddings,
    y=y,
    z_matrix=z_matrix,
    interacts_only=False,
    model='LR',
)

In [None]:
# Interaction-search settings; they are also embedded in the output file names below.
# (An earlier exploratory run used thresh=0.7.)
fdr, thresh = 0.2, 0.8

machop.get_sig_interactions(fdr=fdr, n_iters=20, thresh=thresh)

In [None]:
from plotting import show_interactions

betas_png = f'../results/plm/{name}/{model_name}_betas_fdr{fdr}_thresh{thresh}.png'
show_interactions(machop, save_path=betas_png)

In [None]:
machop.get_joint_embed()  # presumably populates machop.joint_embed (read on the next line)
source_embed = machop.joint_embed
joint_embed = source_embed.copy()  # independent copy, decoupled from machop's state
joint_embed.shape

In [None]:
joint_embed_path = f'../results/plm/{name}/{model_name}_joint_embed_fdr{fdr}_thresh{thresh}.npy'
np.save(joint_embed_path, joint_embed)

In [None]:
# Baseline feature set: z-matrix and PLM features placed side by side
feature_blocks = (machop.z_matrix, machop.plm_embedding)
full_embed = np.hstack(feature_blocks)
full_embed.shape

In [None]:
from sklearn.linear_model import Lasso, LinearRegression
from models import Estimator

# Regularisation strength for the Lasso comparison
LASSO_ALPHA = 0.05
model = Lasso(alpha=LASSO_ALPHA)

In [None]:
estimator = Estimator(model=model)

# Score every feature set with the same estimator, in a fixed order:
# joint, full, z-matrix only, PLM only.
auc0, auc1, auc2, auc3 = [
    estimator.evaluate(features, y)
    for features in (joint_embed, full_embed, machop.z_matrix, machop.plm_embedding)
]

In [None]:
# One row per feature set; the labels become a regular column via reset_index()
df = (
    pd.DataFrame(
        np.vstack([auc0, auc1, auc2, auc3]),
        index=['joint', 'full', 'z-matrix', 'plm'],
    )
    .reset_index()
)

In [None]:
from plotting import show_performance

perf_png = (
    f'../results/plm/{name}/{model_name}_'
    f'{model.__class__.__name__}_performance_fdr{fdr}_thresh{thresh}.png'
)
show_performance(model, df, save_path=perf_png)

In [None]:
model = LinearRegression()
estimator = Estimator(model=model)

# Same four feature sets as the Lasso comparison, scored with ordinary least squares.
# Dict insertion order fixes both evaluation order and the row labels.
feature_sets = {
    'joint': joint_embed,
    'full': full_embed,
    'z-matrix': machop.z_matrix,
    'plm': machop.plm_embedding,
}
auc0, auc1, auc2, auc3 = [estimator.evaluate(X, y) for X in feature_sets.values()]

df = pd.DataFrame(
    np.vstack([auc0, auc1, auc2, auc3]),
    index=list(feature_sets),
).reset_index()

perf_png = (
    f'../results/plm/{name}/{model_name}_'
    f'{model.__class__.__name__}_performance_fdr{fdr}_thresh{thresh}.png'
)
show_performance(model, df, save_path=perf_png)

### Examine the joint embedding

In [None]:
# Reload a previously saved joint embedding; edit fdr/thresh to select another run
fdr, thresh = 0.2, 0.8
saved_path = f'../results/plm/{name}/{model_name}_joint_embed_fdr{fdr}_thresh{thresh}.npy'
joint_embed = np.load(saved_path)
joint_embed.shape

In [None]:
import math

import matplotlib.pyplot as plt

# Scatter consecutive pairs of joint-embedding dimensions (0 vs 1, 2 vs 3, ...),
# coloured by y, to eyeball how well the joint embedding separates the classes.
n_dims = joint_embed.shape[1]
if n_dims == 0:
    raise ValueError('joint_embed has no columns to plot')

n_plots = math.ceil(n_dims / 2)       # one panel per pair of dimensions
n_cols = min(n_plots, 4)              # cap panels per row so the figure stays readable
n_rows = math.ceil(n_plots / n_cols)  # grow the grid downwards, not sideways

# squeeze=False guarantees a 2-D axes array even for a 1x1 grid
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
axs = axs.flatten()

for panel, i in enumerate(range(0, n_dims, 2)):
    # With an odd number of dims the last one is paired with its predecessor;
    # with a single dim it is plotted against itself (never a "-1" label).
    j = i + 1 if i + 1 < n_dims else max(i - 1, 0)
    axs[panel].scatter(joint_embed[:, i], joint_embed[:, j], c=y, alpha=0.5, cmap='coolwarm', s=2)
    axs[panel].set_xlabel(f'joint embed {i}')
    axs[panel].set_ylabel(f'joint embed {j}')

# Hide grid cells left over when the panels do not fill the last row
for unused_ax in axs[n_plots:]:
    unused_ax.set_visible(False)

plt.suptitle('Separation of classes using joint embedding from Daniel\'s model', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
