In [1]:
import sys
import seaborn as sns
import pandas as pd 
import numpy as np
from scipy.spatial.distance import squareform, pdist
import matplotlib.pyplot as plt

import torch
import anndata as an
import scanpy as sc
import os
import gc
from importlib import reload

from datasets import Dataset, load_from_disk
from datasets import load_dataset
from geneformer import EmbExtractor
import geneformer

# classifer tools
import xgboost
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import silhouette_samples, silhouette_score

# local imports
import geneformer_utils as gtu

# Use seaborn's 'white' style for the figures drawn below.
sns.set_style('white')
# Release any GPU memory PyTorch is caching but not using before a model is loaded.
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
"""Load the model"""
# Checkpoint directory of the fine-tuned model; loading goes through the local
# `geneformer_utils` helper.
# NOTE(review): the captured output of this cell warns that the `cls.predictions.*`
# weights are newly initialized when the checkpoint is loaded as BertForMaskedLM.
# Presumably harmless because only embeddings are extracted below -- confirm.
model_path = "/scratch/indikar_root/indikar1/cstansbu/hematokytos/finetuned_models/merged_adata/240923_geneformer_cellClassifier_hsc/ksplit1/"
model = gtu.load_model(model_path)
print('loaded!')

Some weights of BertForMaskedLM were not initialized from the model checkpoint at /scratch/indikar_root/indikar1/cstansbu/hematokytos/finetuned_models/merged_adata/240923_geneformer_cellClassifier_hsc/ksplit1/ and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loaded!


In [None]:
# Load a sample of cells from the tokenized dataset as a DataFrame (`df`), then
# convert it to a Hugging Face `Dataset` (`data`) for embedding extraction.
# `df` is kept because later cells copy it into the AnnData `.obs` tables.
# NOTE(review): `shuffle=True` is passed without any visible seed, so the sampled
# cells may differ between runs -- confirm whether the helper accepts a seed.
fpath = "/scratch/indikar_root/indikar1/cstansbu/hematokytos/tokenized_data/merged_adata.dataset"
sample_size = 50000  # forwarded to the helper as `num_cells`
df = gtu.load_data_as_dataframe(
    fpath, 
    num_cells=sample_size, 
    shuffle=True,
)
print(f"{df.shape=}")
data = Dataset.from_pandas(df)
# Bare last expression: the notebook displays the Dataset summary.
data

In [None]:
# Pick up any edits to the local helper module, then release cached GPU memory
# before running the currently loaded model over the sampled cells.
reload(gtu)
torch.cuda.empty_cache()

embs = gtu.extract_embedding_in_mem(model, data, layer_to_quant=-1)
print(f"{embs.shape=}")

# Translate the embeddings into an AnnData object; its per-cell metadata is a
# copy of the sampled DataFrame.
finetuned = gtu.embedding_to_adata(embs)
finetuned.obs = df.copy()

# scanpy pipeline, in order: PCA -> neighbour graph -> UMAP layout.
for scanpy_step in (sc.tl.pca, sc.pp.neighbors, sc.tl.umap):
    scanpy_step(finetuned)

# Expose the first two UMAP coordinates as obs columns for the seaborn plots.
for coord_idx, coord_name in enumerate(('UMAP 1', 'UMAP 2')):
    finetuned.obs[coord_name] = finetuned.obsm['X_umap'][:, coord_idx]

finetuned

In [None]:
# UMAP of the fine-tuned embeddings, coloured by cell type.
# NOTE(review): the pretrained UMAP and the silhouette cells use
# 'standardized_cell_type' on a copy of the same DataFrame -- confirm which
# column name actually exists.
plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (7, 7)})

# Arguments shared by both scatter layers.
point_kws = dict(data=finetuned.obs, x='UMAP 1', y='UMAP 2', ec='none')

# Layer 1: large black markers drawn underneath the coloured points.
ax = sns.scatterplot(c='k', s=35, **point_kws)

# Layer 2: small markers coloured by the label column, drawn on the same axes.
sns.scatterplot(hue='standard_cell_type', s=5, ax=ax, **point_kws)

# Move the legend outside the plotting area, to the right of the axes.
legend_kws = dict(
    title="",
    frameon=False,
    loc='center right',
    bbox_to_anchor=(1.4, 0.5),
    markerscale=2.5,
)
sns.move_legend(ax, **legend_kws)

# Remove the tick marks from both axes.
ax.set_xticks([])
ax.set_yticks([])
sns.despine(ax=ax)

In [None]:
# UMAP of the fine-tuned embeddings, coloured by the 'dataset_x' column.
# NOTE(review): the pretrained UMAP and the silhouette cells use 'dataset' on a
# copy of the same DataFrame -- confirm which column name actually exists.
plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (7, 7)})

# Arguments shared by both scatter layers.
point_kws = dict(data=finetuned.obs, x='UMAP 1', y='UMAP 2', ec='none')

# Layer 1: large black markers drawn underneath the coloured points.
ax = sns.scatterplot(c='k', s=35, **point_kws)

# Layer 2: small markers coloured by the label column, drawn on the same axes.
sns.scatterplot(hue='dataset_x', s=5, ax=ax, **point_kws)

# Move the legend outside the plotting area, to the right of the axes.
legend_kws = dict(
    title="",
    frameon=False,
    loc='center right',
    bbox_to_anchor=(1.5, 0.5),
    markerscale=2.5,
)
sns.move_legend(ax, **legend_kws)

# Remove the tick marks from both axes.
ax.set_xticks([])
ax.set_yticks([])
sns.despine(ax=ax)

In [None]:
# A bare `break` used to sit here as a manual "stop Run All" guard. Outside a loop
# that is a SyntaxError, so the notebook could never complete Restart & Run All,
# and the comparison cells further down need `pretrained` from the cells that
# follow. The guard is removed; interrupt the kernel by hand to stop here.

# How do these compare to the pretrained embeddings?

In [None]:
"""Load the model"""
# Directory of the pretrained model, loaded with the same helper as the
# fine-tuned one.
# NOTE: this rebinds `model_path` and `model`, so the fine-tuned model loaded
# earlier is no longer reachable through `model` after this cell runs.
model_path = "/scratch/indikar_root/indikar1/cstansbu/hematokytos/pretrained_model"
model = gtu.load_model(model_path)
print('loaded!')

In [None]:
# Pick up any edits to the local helper module, then release cached GPU memory
# before running the currently loaded model over the same sampled cells.
reload(gtu)
torch.cuda.empty_cache()

embs = gtu.extract_embedding_in_mem(model, data, layer_to_quant=-1)
print(f"{embs.shape=}")

# Translate the embeddings into an AnnData object; its per-cell metadata is a
# copy of the sampled DataFrame.
pretrained = gtu.embedding_to_adata(embs)
pretrained.obs = df.copy()

# scanpy pipeline, in order: PCA -> neighbour graph -> UMAP layout.
for scanpy_step in (sc.tl.pca, sc.pp.neighbors, sc.tl.umap):
    scanpy_step(pretrained)

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (7, 7)})

# One UMAP panel per metadata column, stacked in a single column of plots.
# NOTE(review): the fine-tuned plots colour by 'standard_cell_type' and
# 'dataset_x' from a copy of the same DataFrame -- confirm the column names.
umap_color_keys = ["standardized_cell_type", "dataset", "broad_type"]
sc.pl.umap(pretrained, color=umap_color_keys, palette='tab20', ncols=1, size=30)

pretrained

# Compare and Contrast

In [None]:
def compute_silhouette_score(adata, cluster_key, metric='euclidean',
                             sample_size=None, random_state=None):
    """Compute and print the average silhouette score of ``adata.X``.

    Args:
        adata: AnnData object; ``adata.X`` is the feature matrix and
            ``adata.obs`` holds the cluster assignments.
        cluster_key: Column in ``adata.obs`` containing cluster assignments.
        metric: Distance metric forwarded to ``silhouette_score``. Defaults to
            'euclidean', the value that used to be hard-coded.
        sample_size: Optional number of observations to score on a random
            subset instead of the full matrix. ``None`` (default) uses every
            observation, matching the previous behaviour.
        random_state: Seed for the subsampling; only relevant when
            ``sample_size`` is given.

    Returns:
        float: The average silhouette score.

    Raises:
        KeyError: If ``cluster_key`` is not a column of ``adata.obs``. The
            message lists the columns that are available.
    """
    # Fail early with an actionable message: a misspelled column name would
    # otherwise surface as a bare KeyError from pandas.
    if cluster_key not in adata.obs:
        available = ", ".join(map(str, adata.obs.columns))
        raise KeyError(
            f"'{cluster_key}' is not a column of adata.obs; "
            f"available columns: {available}"
        )

    labels = adata.obs[cluster_key].values

    silhouette_avg = silhouette_score(
        adata.X,
        labels,
        metric=metric,
        sample_size=sample_size,
        random_state=random_state,
    )

    print(f"Average silhouette score for '{cluster_key}': {silhouette_avg:.3f}")
    return silhouette_avg
    
    
# Silhouette score by cell-type label, pretrained vs. fine-tuned embeddings.
# The function prints each score; the notebook also displays the last return value.
# NOTE(review): the fine-tuned UMAP cell colours by 'standard_cell_type', while
# this cell reads 'standardized_cell_type' from a copy of the same DataFrame --
# confirm which column name actually exists.
compute_silhouette_score(pretrained, 'standardized_cell_type')
compute_silhouette_score(finetuned, 'standardized_cell_type')

In [None]:
# Silhouette score when cells are grouped by the 'dataset' column,
# pretrained vs. fine-tuned embeddings.
# NOTE(review): the fine-tuned UMAP cell colours by 'dataset_x', not 'dataset' --
# confirm which column name actually exists in the obs table.
compute_silhouette_score(pretrained, 'dataset')
compute_silhouette_score(finetuned, 'dataset')

In [None]:
# Silhouette score when cells are grouped by the 'broad_type' column,
# pretrained vs. fine-tuned embeddings.
compute_silhouette_score(pretrained, 'broad_type')
compute_silhouette_score(finetuned, 'broad_type')