In [None]:
import os
import sys
import gseapy
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import yaml
import numpy as np
import pandas as pd
import scanpy as sc
import random
import matplotlib.pyplot as plt
from IPython.display import Markdown,display, Image, SVG
import io 
import shap
import shap.maskers as maskers 
%matplotlib inline

In [None]:
# Put the parent of the current working directory on the import path; the
# project-local `AE` imports that follow presumably rely on it.
_working_dir = os.path.abspath('')
notebook_dir = os.path.dirname(_working_dir)
sys.path.append(notebook_dir)
from AE.AE import Autoencoder
from AE.AEclassifier import AEClassifier, ClassificationDataset

In [None]:
## Load parameters
# Autoencoder hyper-parameters.
with open('/app/chatbot/test_params.yaml', "r") as f:
    best_params = yaml.safe_load(f)
# Classifier hyper-parameters.
with open('/app/chatbot/classifier_params.yaml', "r") as f:
    classifier_params = yaml.safe_load(f)

# NOTE(review): the files below are unpickled; pickle can execute arbitrary
# code, so they must only ever come from a trusted location.

# Load label names and the number of classes.
with open('/app/chatbot/Data/training_classifier_data.pkl','rb') as f:
    _classifier_payload = pickle.load(f)
labels = _classifier_payload['labels_names']
num_classes = _classifier_payload['num_classes']
del _classifier_payload  # only the two entries above are needed

# Load gene names and the full training matrix.
with open('/app/chatbot/Data/training_data.pkl', 'rb') as f:
    _training_payload = pickle.load(f)
genes = _training_payload['genes']
full_data = _training_payload['full_dataset']
del _training_payload

# Fitted scaler applied to samples before prediction.
with open('/app/chatbot/models/standard_scaler.pkl','rb') as f:
    ss = pickle.load(f)

# Pre-built SHAP explainer.
with open('/app/chatbot/models/SHAP.pkl', 'rb') as f:
    explainer = pickle.load(f)

# Use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Recreate the model hyper-parameters from the loaded parameter dictionaries.

def _collect_hidden_dims(params):
    """Return the hidden-layer widths stored in a parameter mapping.

    Looks up ``h_dim_0`` ... ``h_dim_{n_hidden_layers - 1}`` in ``params`` and
    returns the values that are present, in layer order. A layer whose key is
    missing is skipped rather than raising.
    """
    return [
        params[f'h_dim_{layer}']
        for layer in range(params['n_hidden_layers'])
        if f'h_dim_{layer}' in params
    ]


# Autoencoder architecture.
best_hidden_dims = _collect_hidden_dims(best_params)
best_latent_dim = best_params['latent_dim']
best_dropout_rate = best_params['dropout_rate']
input_dim = len(genes)  # one input feature per gene

# Classifier head architecture.
classifier_hidden_dims = _collect_hidden_dims(classifier_params)
classifier_dropout_rate = classifier_params['dropout_rate']

# Kept so the name still ends up bound to the classifier's layer count.
n_hidden_layers = classifier_params['n_hidden_layers']

In [None]:
# Instantiate the autoencoder; its encoder is handed to the classifier below.
AE_arch = Autoencoder(
    input_dim,
    best_latent_dim,
    best_hidden_dims,
    best_dropout_rate,
)

# Build the classifier on top of the encoder and restore its saved weights.
_classifier_core = AEClassifier(
    AE_arch.encoder,
    num_classes=num_classes,
    latent_dim=best_latent_dim,
    hidden_dims=classifier_hidden_dims,
    dropout_rate=classifier_dropout_rate,
).to(device)
_checkpoint = torch.load('/app/chatbot/models/classifier_model.pth', map_location=device)
_classifier_core.load_state_dict(_checkpoint)

# Chain a softmax after the classifier so the combined model outputs
# per-class values normalised across dim 1.
classifier_model = nn.Sequential(
    _classifier_core,
    nn.Softmax(dim=1),
)

classifier_model.eval()
classifier_model.to(device)
print('')

In [None]:
def full_model_predict_proba(samples_features):
    """Run ``classifier_model`` on the given samples and return its output.

    Parameters
    ----------
    samples_features : torch.Tensor or array-like
        A single sample (1-D) or a batch (2-D). Tensors are forwarded as they
        are, so the caller is responsible for their device. Anything else
        (NumPy array, DataFrame, nested list) is converted to a float32
        tensor placed on the device of the model's first parameter.

    Returns
    -------
    torch.Tensor
        The output of ``classifier_model`` (the notebook wraps the classifier
        with a softmax, so these are class probabilities). The result is NOT
        moved to the CPU or converted to NumPy; callers do that themselves.
    """
    if not torch.is_tensor(samples_features):
        # A model without parameters has no device to follow; let torch
        # pick its default in that case.
        first_param = next(classifier_model.parameters(), None)
        target_device = None if first_param is None else first_param.device
        samples_features = torch.as_tensor(
            np.asarray(samples_features, dtype=np.float32), device=target_device
        )

    # Ensure input is 2D even if only one sample was given.
    if samples_features.ndim == 1:
        samples_features = samples_features.reshape(1, -1)

    classifier_model.eval()  # make sure the model is in eval mode

    # No gradients are needed for inference.
    with torch.no_grad():
        probabilities = classifier_model(samples_features)

    return probabilities

def get_random_sample(genes,ss):
    """Pick one random observation from the annotated dataset.

    Parameters
    ----------
    genes : sequence of str
        Gene names used to select and order the returned features.
    ss : object
        Not used by this function; scaling is left to the caller.

    Returns
    -------
    tuple
        ``(sample, sample_gt, sample_name)``: a one-row DataFrame with one
        column per entry of ``genes`` (NaN where the dataset lacks the gene),
        the sample's ``classification`` value, and the sample's name.
    """
    annotated = sc.read('/app/chatbot/Data/dataset_annotated.h5ad')
    sample_name = random.choice(annotated.obs_names)

    # Value of the `classification` annotation for the chosen sample.
    classification = sc.get.obs_df(annotated, keys='classification')
    sample_gt = classification.loc[sample_name].values[0]

    # Per-gene values of the chosen sample, restricted to and ordered by `genes`.
    expression = sc.get.var_df(annotated, keys=sample_name)
    in_gene_list = expression.index.isin(genes)
    sample = expression[in_gene_list].reindex(genes).T
    return sample, sample_gt, sample_name

In [None]:
# Use example clinical values only when `clinical_data` has not already been
# defined before this cell runs.
if 'clinical_data' not in globals():
    clinical_data = dict(
        age=50,
        tumor_size=5,
        lymph_node='Positive',
        er_status='Positive',
        pgr_status='Negative',
        her2_status='Negative',
        ki67_status='NA',
        nhg='G2',
        pam50='NA',
    )

In [None]:
# One row per clinical field; fields whose value is the string 'NA' are left out.
_reported_fields = {
    field: value for field, value in clinical_data.items() if value != 'NA'
}
clinical_table = pd.DataFrame.from_dict(
    _reported_fields,
    orient='index',
    columns=['Clinical'],
)

In [None]:
# Render the clinical table as an HTML table under a Markdown heading.
Markdown(f"""
# Clinical Features:

{clinical_table.to_html()}
""")

In [None]:
# `sample` may already be defined before this cell as the path of a CSV file
# with `gene_id` and `expression` columns. Only the name lookup is guarded, so
# a NameError raised while loading is not mistaken for "no sample given".
try:
    _sample_source = sample
except NameError:
    _sample_source = None

if _sample_source is None:
    # Nothing was provided: fall back to a random sample from the dataset.
    _sample_frame, _, _ = get_random_sample(genes, ss)
else:
    _expression = pd.read_csv(_sample_source)[['gene_id', 'expression']]
    # One row, with columns selected and ordered by `genes`.
    _sample_frame = _expression.set_index('gene_id').T[genes]

# Scale with the loaded scaler, then move to the model's device.
sample = torch.Tensor(ss.transform(_sample_frame)).to(device)

In [None]:
# Predict with the full model; the class with the largest output for the
# first row is taken as the predicted cluster.
probabilities = full_model_predict_proba(sample)
_first_row_probs = probabilities.cpu().numpy()[0]
cluster_id = np.argmax(_first_row_probs)
cluster_name = labels[cluster_id]

In [None]:
# Show the name of the predicted cluster under a Markdown heading.
Markdown(f"""
# Predicted Cluster:

{cluster_name}
""")

In [None]:
# `classification` annotation of every observation in the annotated dataset.
cluster_samples = sc.read('/app/chatbot/Data/dataset_annotated.h5ad').obs['classification']

# Fixed-seed random subset of 1000 rows of the full dataset, as a tensor on
# the model's device.
background_data = full_data.sample(n=1000, random_state=42, axis=0)
_background_matrix = background_data.to_numpy()
background_features = torch.Tensor(_background_matrix).to(device)


In [None]:
# Rows of the full dataset whose `classification` equals the predicted
# cluster. This used to be hard-coded to 'Basal-G3', which ignored the
# prediction made above.
# NOTE(review): assumes `labels` uses the same names as the `classification`
# annotation - confirm.
cluster_mask = cluster_samples[cluster_samples == cluster_name].index
cluster_data = full_data[full_data.index.isin(cluster_mask)]
cluster_features = torch.Tensor(cluster_data.to_numpy()).to(device)


In [None]:
# Compute SHAP values for the sample with the pre-loaded explainer.
# The trailing comment keeps two arguments that are currently disabled.
shap_values = explainer.shap_values(sample)#, max_evals=20001, batch_size=1)


In [None]:
# Position of the predicted cluster name within `labels`.
# NOTE(review): looks equivalent to `cluster_id` and is not referenced by the
# cells visible after this one - confirm whether it is still needed.
target_class_id=labels.index(cluster_name)

In [None]:
# For the predicted cluster: the explainer's expected value (presumably a
# scalar) and the per-feature SHAP values of the first sample.
base = explainer.expected_value[cluster_id]
vals = shap_values[0, :, cluster_id]
Markdown(f"""
# SHAP Values for cluster {cluster_name}
""")

In [None]:
# Bundle the SHAP values, base value and input features of the first sample
# into an Explanation object, labelled with the gene names.
_sample_features = sample.cpu().numpy()[0]
explanation = shap.Explanation(
    values=vals,
    base_values=base,
    data=_sample_features,
    feature_names=genes,
)

# Draw the waterfall plot without showing it, write it to an in-memory PNG
# and display that image.
ax = shap.plots.waterfall(explanation, show=False)
fig = ax.figure
buf = io.BytesIO()
fig.savefig(buf, format='png', bbox_inches='tight')
display(Image(data=buf.getvalue(), format='png'))

In [None]:
# Gene / SHAP-score table, ordered from the highest to the lowest score.
_gene_scores = pd.DataFrame({'gene_name': genes, 'score': vals})
shaps = _gene_scores.sort_values('score', ascending=False)

# HTML tables of the ten highest and the ten lowest scores.
_highest_html = shaps.sort_values(by='score', ascending=False).head(10).to_html()
_lowest_html = shaps.sort_values(by='score', ascending=True).head(10).to_html()

Markdown(f"""
## Top 10 SHAP values with higher value:

{_highest_html}

## Top 10 SHAP values with lowest value:

{_lowest_html}


""")

In [None]:
from gseapy import Msigdb
from gseapy import GSEA
from gseapy import dotplot
import gseapy as gp

In [None]:
# Retrieve the MSigDB `h.all` gene-set collection (release 2025.1.Hs); it is
# the gene-set library passed to the pre-ranked GSEA in `pathways` below.
msig = Msigdb()
gmt = msig.get_gmt(category='h.all', dbver="2025.1.Hs")


def pathways(expr,ax):
    """Run a pre-ranked GSEA against ``gmt`` and draw a dot plot of the result.

    Parameters
    ----------
    expr : DataFrame or path
        Ranking passed to ``gp.prerank`` as ``rnk``.
    ax : matplotlib axes
        Axes handed to ``dotplot`` to draw on.

    Returns
    -------
    tuple
        ``(ax, res2d)``: the object returned by ``dotplot``, or ``None`` when
        the plot could not be drawn, and the ``res2d`` result table of the
        analysis.
    """
    pre_res = gp.prerank(
        rnk=expr,  # DataFrame or path to .rnk file
        gene_sets=gmt,
        permutation_num=10000,  # recommended >= 1000
        seed=42,
        threads=4,  # parallelization
        outdir=None
    )
    try:
        ax = dotplot(pre_res.res2d,
                 column="FDR q-val",
                 cmap=plt.cm.viridis,
                 size=5, # adjust dot size
                 show_ring=False,ax=ax, figsize=(10,15))
    except Exception:
        # Best effort: a plotting failure (presumably when no term passes
        # dotplot's cut-off) is reported to the caller as ax=None. Unlike a
        # bare `except`, this lets KeyboardInterrupt/SystemExit propagate.
        ax = None
    return ax, pre_res.res2d


In [None]:
# Section heading for the GSEA results.
Markdown(f"""
# GSEA:

""")


In [None]:

# Run the GSEA on the SHAP ranking and display the dot plot when there is one.
fig, ax = plt.subplots(1, 1)
ax, res = pathways(shaps, ax)
if ax is None:
    # pathways() hands back None instead of axes when the dot plot could not
    # be drawn; discard the empty canvas rather than failing on `ax.figure`.
    plt.close(fig)
else:
    fig = ax.figure
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    display(Image(data=buf.read(), format='png'))

In [None]:
# Drop the last column and order the result table by FDR q-value. With a plot,
# keep only rows below 0.05; without one, keep the first 20 rows and say so.
_ranked = res.iloc[:, :-1].sort_values(by='FDR q-val')
if ax is not None:
    res = _ranked[_ranked['FDR q-val'] < 0.05]
else:
    res = _ranked.head(20)
    print("No significant enriched pathways")

In [None]:
# Render the GSEA result table kept by the previous cell as HTML.
Markdown(f"""

{res.to_html()}

""")