In [1]:
# %%
%reload_ext autoreload
%autoreload 2
import torch
from torch import optim
from FinetunePatientClassification import *
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

## Input data and preprocessing

In [2]:
# load the preprocessed data for scFoundation
data = pd.read_csv('gene_symbol_converted_data_for_scF.csv')

In [3]:
# extract the label data from the original dataset
import scanpy as sc
anndata = sc.read_h5ad("Preprocessed_Data/clono_filtered_counts_adata.h5ad")
anndata.shape

(25066, 15093)

In [4]:
# replace label values with binary values
anndata.obs['Response_3m'] =  anndata.obs['Response_3m'].replace({'CR':1, 'OR':1, 'NR':0})
anndata.obs['Response_3m'].value_counts()

  anndata.obs['Response_3m'] =  anndata.obs['Response_3m'].replace({'CR':1, 'OR':1, 'NR':0})
  anndata.obs['Response_3m'] =  anndata.obs['Response_3m'].replace({'CR':1, 'OR':1, 'NR':0})


Response_3m
1    13160
0    11906
Name: count, dtype: int64

## Downsample the data due to GPU limitations

In [5]:
import random
labels_df = anndata.obs
patients = labels_df.patient_id.unique()

sample_size = 1

# randomly sample patient IDs
random_seed=42
sampled_patient_ids = random.sample(list(patients), sample_size)

# Create a new DF with only the sampled patients
labels_df_downsampled = labels_df[labels_df['patient_id'].isin(sampled_patient_ids)]

In [6]:
# downsample the expression data accordingly
data_downsampled = data.iloc[labels_df_downsampled.index]

# remove the cell_id column
data_downsampled = data_downsampled.drop(columns=['cell_id'])

In [7]:
# make sure the data and labels are aligned
data_downsampled.shape, labels_df_downsampled.shape

((338, 19264), (338, 7))

In [9]:
data_downsampled

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3
6072,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,1.0,0.0
6073,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,13.0,0.0,0.0,0.0,0.0,0.0,11.0,0.0,0.0
6074,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0
6075,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,0.0,0.0
6076,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6410,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,...,1.0,15.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0
6411,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0
6412,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,1.0,3.0,0.0,0.0,0.0,0.0,1.0,7.0,1.0,0.0
6413,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


Finetune the model

In [8]:
torch.cuda.empty_cache()

In [9]:
class SingleCellDataset(Dataset):
    def __init__(self, gene_expression_csv, labels_df, label_encoder):
        # Load the gene expression data
        self.gene_expression = gene_expression_csv
        
        # Ensure the labels DataFrame has the same number of rows as the gene expression data
        assert len(self.gene_expression) == len(labels_df), "Mismatch in number of samples between gene expression data and labels"
        
        # Convert labels to numeric if they're categorical
        
        self.labels = torch.LongTensor(label_encoder.transform(labels_df['Response_3m']))
        
        
        # Convert gene expression data to torch tensor
        self.gene_expression = torch.FloatTensor(self.gene_expression.values.astype(np.float32))
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'x': self.gene_expression[idx],
            'targets': self.labels[idx]
        }

In [10]:
import gc
from torch.profiler import profile, record_function, ProfilerActivity

def print_gpu_memory(step_name):
    print(f"GPU memory at {step_name}: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

def finetune_scFoundation(gene_expression_csv, labels_df, model_class, ckpt_path,
                          batch_size=4, num_epochs=5, lr=0.001,
                          validation_split=0.2, device='cuda',
                          gradient_accumulation_steps=2):
    
    gene_exp_train, gene_exp_val, labels_train, labels_val = train_test_split(gene_expression_csv, 
                                                                              labels_df, test_size=validation_split, 
                                                                              random_state=42)

    # Fit LabelEncoder on combined dataset
    le = LabelEncoder()
    combined_labels = pd.concat([labels_train['Response_3m'], labels_val['Response_3m']])
    le.fit(combined_labels)

    #create datasets
    train_dataset = SingleCellDataset(gene_exp_train, labels_train, le)
    val_dataset = SingleCellDataset(gene_exp_val, labels_val, le)

    #create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = model_class(ckpt_path=ckpt_path)
    model.build()
    model = model.to(device)

    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Initialize gradient scaler for mixed precision training
    scaler = GradScaler()

    # Optionally load checkpoint
    start_epoch = 0
    

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True, record_shapes=True) as prof:

        for epoch in range(start_epoch, num_epochs):
            model.train()
            train_loss = 0.0

            # Clear CUDA cache
            torch.cuda.empty_cache()

            for i, batch in enumerate(train_loader):
                print_gpu_memory(f"Epoch {epoch}, Batch {i} start")
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                # Mixed precision training
                with autocast():
                    # Forward pass
                    logits = model(batch)

                    # Compute loss
                    loss = model.compute_loss(logits, batch['targets'].float()) / gradient_accumulation_steps
                
                print_gpu_memory(f"Epoch {epoch}, Batch {i} after forward pass")
                # Ensure loss requires gradient
                assert loss.requires_grad, "Loss does not require gradients"

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()

                print_gpu_memory(f"Epoch {epoch}, Batch {i} after backward pass")

                if (i + 1) % gradient_accumulation_steps == 0:
                    # Unscale gradients and optimizer step
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    # add garbage collection
                    gc.collect()
                    torch.cuda.empty_cache()
                
                    print_gpu_memory(f"Epoch {epoch}, Batch {i} after optimizer step")

                train_loss += loss.item() * gradient_accumulation_steps
            print(f"Max GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    


        #validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        torch.cuda.empty_cache()
        print_gpu_memory(f"Before validation")

        with torch.no_grad():
            for j, batch in enumerate(val_loader):
                print_gpu_memory(f"Validation, Batch {j} start")
                batch = {k: v.to(device) for k, v in batch.items()}
                logits = model(batch)
                val_loss += model.compute_loss(logits, batch['targets'].float()).item()
                predicted = (torch.sigmoid(logits) > 0.5).float()
                total += batch['targets'].size(0)
                correct += (predicted == batch['targets']).sum().item()

                # Move data back to CPU
                for k in batch.keys():
                    batch[k] = batch[k].cpu()
                del batch, logits, predicted
                torch.cuda.empty_cache()

                print_gpu_memory(f"Validation, Batch {j} end")

                
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss/len(train_loader):.4f}")
        print(f"Validation Loss: {val_loss/len(val_loader):.4f}")
        print(f"Validation Accuracy: {100*correct/total:.2f}%")
        print("-----------------------------")
    
    print(prof.key_averages().table(sort_by="cuda_memory_total", row_limit=10))

    return model

In [13]:
anndata.obs

Unnamed: 0,cell_id,nCount_RNA,nFeature_RNA,patient_id,percent_mito,Response_3m,sample_source
0,Good_Li_2023ac01_AAACGGGAGATGTGTA-1,652,433,ac01,3.935860,1,Deng
1,Good_Li_2023ac01_AAAGATGCAGCCTTGG-1,18106,3790,ac01,2.465065,1,Deng
2,Good_Li_2023ac01_AAAGTAGTCATGGTCA-1,1145,688,ac01,3.858785,1,Deng
3,Good_Li_2023ac01_AAATGCCAGTACCGGA-1,3430,1413,ac01,2.677339,1,Deng
4,Good_Li_2023ac01_AAATGCCGTTCAGACT-1,924,554,ac01,1.731161,1,Deng
...,...,...,...,...,...,...,...
25061,Sheih_TTTGGTTCATTGTGCA-5,2722,1253,NHL-7,6.673766,1,Sheih
25062,Sheih_TTTGGTTTCGATCCCT-5,6209,1980,NHL-7,3.170656,1,Sheih
25063,Sheih_TTTGGTTTCTACGAGT-5,5606,1925,NHL-7,2.868498,1,Sheih
25064,Sheih_TTTGTCACAGTCAGAG-5,2317,1049,NHL-7,2.549102,1,Sheih


tried downsampling data, batch size

In [12]:
finetuned_model = finetune_scFoundation(data_downsampled, labels_df_downsampled, model_class=FinetunePatientClassification, 
                                        ckpt_path='./models/models.ckpt', num_epochs=1, lr=0.001, device='cuda',
                                        validation_split=0.2, batch_size=4)   

{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

STAGE:2024-07-24 17:44:33 214234:214234 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


GPU memory at Epoch 0, Batch 0 start: 0.40 GB
GPU memory at Epoch 0, Batch 0 after forward pass: 0.41 GB
GPU memory at Epoch 0, Batch 0 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 1 start: 0.42 GB
GPU memory at Epoch 0, Batch 1 after forward pass: 0.42 GB
GPU memory at Epoch 0, Batch 1 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 1 after optimizer step: 0.42 GB
GPU memory at Epoch 0, Batch 2 start: 0.42 GB
GPU memory at Epoch 0, Batch 2 after forward pass: 0.42 GB
GPU memory at Epoch 0, Batch 2 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 3 start: 0.42 GB
GPU memory at Epoch 0, Batch 3 after forward pass: 0.42 GB
GPU memory at Epoch 0, Batch 3 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 3 after optimizer step: 0.42 GB
GPU memory at Epoch 0, Batch 4 start: 0.42 GB
GPU memory at Epoch 0, Batch 4 after forward pass: 0.42 GB
GPU memory at Epoch 0, Batch 4 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 5 start: 0.42 GB
GPU mem

STAGE:2024-07-24 17:45:13 214234:214234 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-07-24 17:45:13 214234:214234 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


OutOfMemoryError: CUDA out of memory. Tried to allocate 5.26 GiB. GPU 

In [14]:
import torch
import gc

def get_gpu_memory():
    """Returns total and available GPU memory in gigabytes"""
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        reserved_memory = torch.cuda.memory_reserved(0) / 1e9
        allocated_memory = torch.cuda.memory_allocated(0) / 1e9
        available_memory = total_memory - (reserved_memory + allocated_memory)
        return total_memory, available_memory
    else:
        return 0, 0

def get_model_size(model):
    """Returns the size of the model in gigabytes"""
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_gb = (param_size + buffer_size) / 1024**3
    return size_all_gb

def check_model_fit(model):
    # Clear GPU cache
    torch.cuda.empty_cache()
    gc.collect()

    total_memory, available_memory = get_gpu_memory()
    model_size = get_model_size(model)

    print(f"Total GPU Memory: {total_memory:.2f} GB")
    print(f"Available GPU Memory: {available_memory:.2f} GB")
    print(f"Estimated Model Size: {model_size:.2f} GB")

    if model_size > available_memory:
        print("WARNING: Model size exceeds available GPU memory!")
    else:
        print("Model size is within available GPU memory.")

    # Consider memory needed for activations, gradients, and optimizer states
    estimated_total_need = model_size * 3  # A rough estimate
    print(f"Estimated total memory need (including activations and gradients): {estimated_total_need:.2f} GB")

    if estimated_total_need > available_memory:
        print("WARNING: Estimated total memory need exceeds available GPU memory!")
    else:
        print("Estimated total memory need is within available GPU memory.")

# Usage
model = FinetunePatientClassification(ckpt_path='./models/models.ckpt')  # Initialize your model
check_model_fit(model)

Total GPU Memory: 23.61 GB
Available GPU Memory: 22.99 GB
Estimated Model Size: 0.00 GB
Model size is within available GPU memory.
Estimated total memory need (including activations and gradients): 0.00 GB
Estimated total memory need is within available GPU memory.


In [23]:
finetuned_model = finetune_scFoundation(data_downsampled, labels_df_downsampled, model_class=FinetunePatientClassification, 
                                        ckpt_path='./models/models.ckpt', num_epochs=5, lr=0.001, device='cuda',
                                        validation_split=0.2, load_checkpoint_or_not=True, batch_size=8)   

{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

KeyError: 'model_state_dict'

In [None]:
import os

def save_finetuned_model(model, save_path, model_name):
    """
    Save the finetuned model using both methods: entire model and state dict.
    
    Args:
    model (torch.nn.Module): The finetuned model to save
    save_path (str): Directory to save the model
    model_name (str): Name to use for the saved model files
    """
    # Ensure the save directory exists
    os.makedirs(save_path, exist_ok=True)
    
    # 1. Save the entire model
    entire_model_path = os.path.join(save_path, f"{model_name}_entire.pth")
    torch.save(model, entire_model_path)
    print(f"Entire model saved to {entire_model_path}")
    
    # 2. Save only the state dict
    state_dict_path = os.path.join(save_path, f"{model_name}_state_dict.pth")
    torch.save(model.state_dict(), state_dict_path)
    print(f"Model state dict saved to {state_dict_path}")


In [None]:
save_path = './saved_models'
model_name = 'finetuned_scFoundation'
save_finetuned_model(finetuned_model, save_path, model_name)

In [17]:
x = torch.tensor([1, 2, 3, 4])
torch.unsqueeze(x, 1)

tensor([[1],
        [2],
        [3],
        [4]])

In [None]:
# 1. Save the entire model
entire_model_path = os.path.join(save_path, f"{model_name}_entire.pth")
torch.save(model, entire_model_path)
print(f"Entire model saved to {entire_model_path}")

# 2. Save only the state dict
state_dict_path = os.path.join(save_path, f"{model_name}_state_dict.pth")
torch.save(model.state_dict(), state_dict_path)
print(f"Model state dict saved to {state_dict_path}")