In this notebook we look at how to fine-tune the scGPT and Geneformer models through the Helical package. This step matters for our experiments because it will let us compare the three models under the same conditions. This notebook was developed to run on Google Colab.

For now I still need to study how it works and apply it to the multiple sclerosis dataset. I would also still need to try it with scGPT.

In [None]:
#!pip install helical

Collecting causal-conv1d==1.4.0 (from helical[mamba-ssm])
  Using cached causal_conv1d-1.4.0.tar.gz (9.3 kB)
  Preparing metadata (setup.py) ... done
Collecting mamba-ssm==2.2.4 (from helical[mamba-ssm])
  Downloading mamba_ssm-2.2.4.tar.gz (91 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting ninja (from causal-conv1d==1.4.0->helical[mamba-ssm])
  Using cached ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Using cached ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
Building wheels for collected packages: causal-conv1d, mamba-ssm
  error: subprocess-exited-with-error

  × python setup.py bdist_wheel did n

In [None]:
from helical.utils import get_anndata_from_hf_dataset
from helical import GeneformerConfig, GeneformerFineTuningModel, scGPTConfig, scGPTFineTuningModel
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt
import logging, warnings
import umap
import pandas as pd
import seaborn as sns

logging.getLogger().setLevel(logging.ERROR)

warnings.filterwarnings("ignore")

INFO:numexpr.utils:NumExpr defaulting to 2 threads.
INFO:datasets:PyTorch version 2.5.1+cu124 available.
INFO:datasets:Polars version 1.9.0 available.
INFO:datasets:TensorFlow version 2.18.0 available.
INFO:datasets:JAX version 0.4.33 available.
INFO:helical:Caduceus not available: If you want to use this model, ensure you have a CUDA GPU and have installed the optional helical[mamba-ssm] dependencies.


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

from datasets import load_dataset
ds = load_dataset("helical-ai/yolksac_human",trust_remote_code=True, download_mode="reuse_cache_if_exists")


In [None]:
ds

DatasetDict({
    train: Dataset({
        features: ['raw_counts', 'rows', 'size', 'LVL1', 'LVL2', 'LVL3'],
        num_rows: 25344
    })
    test: Dataset({
        features: ['raw_counts', 'rows', 'size', 'LVL1', 'LVL2', 'LVL3'],
        num_rows: 6336
    })
})

The next cell handles the two splits separately because `load_dataset` returns the data already divided into train and test at download time. Both splits have the same columns, but train has 25344 instances versus 6336 for test.

In [None]:
train_dataset = get_anndata_from_hf_dataset(ds["train"])
test_dataset = get_anndata_from_hf_dataset(ds["test"])

Here each cell is represented as an array of roughly 37000 values, one per gene.

In [None]:
train_dataset.X.toarray()

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

First common step: store all the cell types in a list.

In [None]:
cell_types_train = train_dataset.obs["LVL1"].tolist()
cell_types_test = test_dataset.obs["LVL1"].tolist()

Now we convert these classes into integer ids for classification.

In [None]:
label_set = set(cell_types_train) | set(cell_types_test)  # Union of the labels seen in both splits
# Note: set iteration order is not guaranteed to be stable across sessions, so ids may differ between runs
class_id_dict = dict(zip(label_set, range(len(label_set))))
id_class_dict = {v: k for k, v in class_id_dict.items()}


# Replace each class name with its integer id
for i in range(len(cell_types_train)):
    cell_types_train[i] = class_id_dict[cell_types_train[i]]

for i in range(len(cell_types_test)):
    cell_types_test[i] = class_id_dict[cell_types_test[i]]

In [None]:
id_class_dict

{0: 'STROMA',
 1: 'MYELOID',
 2: 'PROGENITOR',
 3: 'MK',
 4: 'LYMPHOID',
 5: 'ERYTHROID'}
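
The mapping shown above comes from iterating over a Python set, and for strings that order can change from one session to the next. If the ids need to be reproducible (for example, to compare confusion matrices across runs), one option is to assign them in sorted label order. The following is a minimal sketch on toy labels; `build_label_maps` is a hypothetical helper, not part of Helical.

```python
def build_label_maps(*label_lists):
    """Return (class_to_id, id_to_class) with ids assigned in sorted label order."""
    labels = sorted(set().union(*label_lists))
    class_to_id = {label: idx for idx, label in enumerate(labels)}
    id_to_class = {idx: label for label, idx in class_to_id.items()}
    return class_to_id, id_to_class


# Toy labels standing in for cell_types_train / cell_types_test
toy_train = ["STROMA", "MYELOID", "MK", "MYELOID"]
toy_test = ["ERYTHROID", "MK"]

class_to_id, id_to_class = build_label_maps(toy_train, toy_test)
encoded_train = [class_to_id[label] for label in toy_train]

print(class_to_id)
print(encoded_train)
```

With this approach the same set of labels always produces the same ids, regardless of the order in which the cells appear.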

## Fine-tuning GENEFORMER

In **geneformer_config** the only change is that we now specify which model to use.

In [None]:
geneformer_config = GeneformerConfig(device=device, batch_size=50, model_name="gf-6L-30M-i2048") # Other model variants should be available as well
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=geneformer_config, fine_tuning_head="classification", output_size=len(label_set)) # "classification" or "regression"

Downloading: 100%|██████████| 941k/941k [00:00<00:00, 1.33MB/s]
Downloading: 100%|██████████| 788k/788k [00:00<00:00, 3.54MB/s]
Downloading: 100%|██████████| 3.96M/3.96M [00:00<00:00, 8.65MB/s]
Downloading: 100%|██████████| 565/565 [00:00<00:00, 7.18kB/s]
Downloading: 100%|██████████| 2.61k/2.61k [00:00<00:00, 33.3kB/s]
Downloading: 100%|██████████| 41.2M/41.2M [00:04<00:00, 8.85MB/s]


In [None]:
geneformer_train_dataset = geneformer_fine_tune.process_data(train_dataset) # use_raw_counts=False for the MS dataset
geneformer_test_dataset = geneformer_fine_tune.process_data(test_dataset)


In [None]:
geneformer_train_dataset

Dataset({
    features: ['input_ids', 'length', 'LVL1'],
    num_rows: 25344
})

We add to the model's dataset a column named LVL1 that holds the cell classes, encoded as the integer ids computed in the earlier cells.

In [None]:
geneformer_train_dataset = geneformer_train_dataset.add_column("LVL1", cell_types_train)
geneformer_test_dataset = geneformer_test_dataset.add_column("LVL1", cell_types_test)

We now run the fine-tuning. This is promising because it should let us match the training parameters used with CellPLM and run the benchmarks under equal conditions.

In principle this should work, but this dataset has too many training instances, which exceeds the memory capacity of Google Colab.
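
For a sense of scale, the number of optimisation steps in one epoch follows from the training set size and the `batch_size=50` set in `GeneformerConfig`. The sketch below assumes one step per batch with a partial last batch allowed; `steps_per_epoch` is a hypothetical helper, and the batch size of 8 is only an illustrative alternative, not a tested setting.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches needed to see every example once (last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


# Training split size and batch size taken from the cells above
print(steps_per_epoch(25344, 50))

# A smaller, purely illustrative batch size
print(steps_per_epoch(25344, 8))
```

The first value should match the total shown by the fine-tuning progress bar. A smaller batch means more steps per epoch but fewer cells held on the GPU at once.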

In [None]:
geneformer_fine_tune.train(
    train_dataset=geneformer_train_dataset.shuffle(),
    validation_dataset=geneformer_test_dataset,
    label="LVL1",
    freeze_layers=0,
    epochs=1,
    optimizer_params={"lr": 1e-4},
    lr_scheduler_params={"name": "linear", "num_warmup_steps": 0, "num_training_steps": 1},
)

Fine-Tuning:   0%|          | 0/507 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 9.70 GiB. GPU 0 has a total capacity of 14.74 GiB of which 7.51 GiB is free. Process 11000 has 7.22 GiB memory in use. Of the allocated memory 7.10 GiB is allocated by PyTorch, and 2.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
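
If the number of training instances is indeed the limiting factor, one possible workaround is to train on a stratified subsample that keeps the class proportions roughly intact. The following is a sketch on toy labels using only the standard library; `stratified_indices` is a hypothetical helper, and whether a smaller training set actually avoids the error above has not been tested here.

```python
import random
from collections import defaultdict


def stratified_indices(labels, fraction, seed=0):
    """Pick about `fraction` of the indices of each class, keeping at least one per class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    chosen = []
    for label in sorted(by_class):
        indices = by_class[label]
        k = max(1, round(len(indices) * fraction))
        chosen.extend(rng.sample(indices, k))
    return sorted(chosen)


# Toy, imbalanced labels standing in for cell_types_train
toy_labels = [0] * 80 + [1] * 15 + [2] * 5
subset = stratified_indices(toy_labels, fraction=0.2)
print(len(subset))
```

Assuming the processed dataset is a Hugging Face `Dataset`, as its printed representation suggests, the resulting indices could be passed to its `select` method to build the smaller training set.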

Next, we obtain the predictions for the test set and then the generated embeddings.

In [None]:
outputs = geneformer_fine_tune.get_outputs(geneformer_test_dataset)

In [None]:
embeddings = geneformer_fine_tune.get_embeddings(geneformer_test_dataset)
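
Later cells turn `outputs` into predicted classes with `outputs.argmax(axis=1)`, which suggests one row of class scores per cell. Under that assumption, the step from scores to class names can be sketched in plain Python as follows; the score values and the small id-to-name dictionary are made up for illustration.

```python
def argmax_rows(scores):
    """Return, for each row of scores, the index of its largest value."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Toy scores: 3 cells x 3 classes (hypothetical values)
toy_outputs = [
    [0.1, 2.3, -0.5],
    [1.7, 0.2, 0.4],
    [-1.0, -0.2, 3.1],
]
toy_id_class = {0: "ERYTHROID", 1: "MK", 2: "STROMA"}

predicted_ids = argmax_rows(toy_outputs)
predicted_names = [toy_id_class[i] for i in predicted_ids]
print(predicted_ids)
print(predicted_names)
```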

We finish by visualizing the results and computing the classification performance metrics.

In [None]:
reducer = umap.UMAP(min_dist=0.2, n_components=2, n_epochs=None, n_neighbors=4)
mapper = reducer.fit(embeddings)

plot_df = pd.DataFrame(mapper.embedding_, columns=['px', 'py'])
# Map the integer ids back to cell type names so the legend is readable
labels = [id_class_dict[i] for i in geneformer_test_dataset["LVL1"]]
plot_df['Cell Type'] = labels


# Create a matplotlib figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))

#plt.style.use("dark_background")

# Left panel: no hue variable, so no palette is passed
sns.scatterplot(data=plot_df, x='px', y='py', ax=axs[0])
axs[0].set_title('UMAP of Reference Data without labels')

sns.scatterplot(data=plot_df, x='px', y='py', hue='Cell Type', ax=axs[1], palette="pastel")
axs[1].set_title('UMAP of Reference Data with labels')

In [None]:
print(classification_report(cell_types_test,outputs.argmax(axis=1)))
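
To make the report's columns easier to read, here is a small sketch of how per-class precision, recall and F1 can be computed from true and predicted ids. It uses toy labels and a hypothetical helper, `per_class_metrics`; undefined ratios are returned as 0.0, which is a convention chosen for this sketch.

```python
def per_class_metrics(y_true, y_pred):
    """Return {class: (precision, recall, f1)} computed from two label sequences."""
    metrics = {}
    for cls in sorted(set(y_true) | set(y_pred)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[cls] = (precision, recall, f1)
    return metrics


# Toy ids standing in for cell_types_test and outputs.argmax(axis=1)
toy_true = [0, 0, 1, 1, 2, 2]
toy_pred = [0, 1, 1, 1, 2, 0]
for cls, (p, r, f) in per_class_metrics(toy_true, toy_pred).items():
    print(cls, round(p, 2), round(r, 2), round(f, 2))
```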

In [None]:
# Compute the confusion matrix
cm = confusion_matrix(cell_types_test, outputs.argmax(axis=1))

# Perform row-wise normalization
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Get unique labels in the order they appear in the confusion matrix
unique_labels = np.unique(np.concatenate((cell_types_test, outputs.argmax(axis=1))))

# Use id_class_dict to get the class names
class_names = [id_class_dict[label] for label in unique_labels]

# Create and plot the normalized confusion matrix
fig, ax = plt.subplots(figsize=(15, 15))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_normalized, display_labels=class_names)
disp.plot(ax=ax, xticks_rotation='vertical', values_format='.2f', cmap='coolwarm')

# Customize the plot
ax.set_title('Normalized Confusion Matrix (Row-wise)')
fig.set_facecolor("none")

plt.tight_layout()
plt.show()
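
Row-wise normalization means each row of the matrix is divided by the number of true samples of that class, so the diagonal reads as per-class recall. The sketch below reproduces the idea in plain Python on toy labels; the helper names are hypothetical. It also guards against a class that is predicted but never appears among the true labels, a case in which dividing by the row sum would otherwise be a division by zero.

```python
def confusion_counts(y_true, y_pred, n_classes):
    """Build an n_classes x n_classes count matrix: rows are true ids, columns are predicted ids."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def normalize_rows(matrix):
    """Divide each row by its sum; rows with no true samples are left as zeros."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([v / total if total else 0.0 for v in row])
    return normalized


toy_true = [0, 0, 0, 1, 1, 1, 1]
toy_pred = [0, 0, 1, 1, 1, 2, 1]  # class 2 is predicted but never true
cm = confusion_counts(toy_true, toy_pred, n_classes=3)
cm_norm = normalize_rows(cm)
print(cm)
print(cm_norm)
```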