In [1]:
# %%
%reload_ext autoreload
%autoreload 2
import torch
from torch import optim
from FinetunePatientClassification import *
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

## Problem: the bench kept crashing during the train and validation loop

According to the torch profiler, the training loop ran, but the crash happened somewhere during validation: GPU usage spikes suddenly, then the error message appears.

Troubleshooting attempts so far:
- [x] downsampling the data
- [x] adjusting the batch size
- [x] adding gradient accumulation
- [x] mixed-precision training
- [x] freezing model parameters

## Input data and preprocessing

In [2]:
# Expression matrix already converted to gene symbols for scFoundation (one row per cell).
data = pd.read_csv(filepath_or_buffer='gene_symbol_converted_data_for_scF.csv')

## Downsample the data due to GPU limitations

In [3]:
# Metadata rows (incl. patient_id and the Response_3m target); show the target's class balance.
labels_df = pd.read_csv(filepath_or_buffer='patient_metadata.csv')
labels_df.loc[:, "Response_3m"].value_counts()

Response_3m
1    13160
0    11906
Name: count, dtype: int64

In [4]:
import random

patients = labels_df.patient_id.unique()

sample_size = 3  # number of patients to sample

# Randomly sample patient IDs with a dedicated, seeded RNG so the same subset is drawn
# on every Restart & Run All (the seed must actually be applied to take effect).
random_seed = 42
patient_rng = random.Random(random_seed)
sampled_patient_ids = patient_rng.sample(list(patients), sample_size)

# Keep only the metadata rows that belong to the sampled patients
labels_df_downsampled = labels_df[labels_df['patient_id'].isin(sampled_patient_ids)]

In [5]:
# Downsample the expression data to the same rows as the sampled metadata.
# The metadata index holds row *labels*, so select with label-based .loc
# (with the default RangeIndex from read_csv this picks the same rows as .iloc would).
data_downsampled = data.loc[labels_df_downsampled.index]

# If the metadata also carries cell_id, verify the expression rows and labels refer to the same cells.
if 'cell_id' in labels_df_downsampled.columns:
    assert (data_downsampled['cell_id'].to_numpy() == labels_df_downsampled['cell_id'].to_numpy()).all(), \
        "expression rows and metadata rows refer to different cells"

# remove the cell_id column so only gene columns are fed to the model
data_downsampled = data_downsampled.drop(columns=['cell_id'])

In [6]:
# make sure the data and labels are aligned: same number of rows and the same row labels, in order
assert len(data_downsampled) == len(labels_df_downsampled), "row count mismatch between expression data and labels"
assert data_downsampled.index.equals(labels_df_downsampled.index), "row order mismatch between expression data and labels"
data_downsampled.shape, labels_df_downsampled.shape

((1149, 19264), (1149, 7))

In [7]:
# Peek at the downsampled expression matrix (pandas truncates the rendered rows/columns).
data_downsampled

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3
414,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
415,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0
416,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,5.0,0.0,0.0
417,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
418,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17841,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0
17842,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,2.0,2.0,0.0,0.0,0.0,0.0,0.0,5.0,0.0,0.0
17843,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0
17844,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Fine-tune the model

In [8]:
# Release cached-but-unused blocks held by PyTorch's CUDA caching allocator before fine-tuning
# (memory still referenced by live tensors is not freed by this call).
torch.cuda.empty_cache()

In [9]:
class SingleCellDataset(Dataset):
    """Map-style dataset pairing expression rows with label-encoded ``Response_3m`` targets.

    Args:
        gene_expression_csv (pd.DataFrame): expression matrix, one row per sample
            (despite the name, an already-loaded DataFrame, not a file path).
        labels_df (pd.DataFrame): row-aligned metadata containing a ``Response_3m`` column.
        label_encoder: fitted encoder exposing ``transform`` (e.g. sklearn ``LabelEncoder``).

    Each item is a dict ``{'x': float32 expression row, 'targets': int64 label}``.
    """

    def __init__(self, gene_expression_csv, labels_df, label_encoder):
        # Only the row counts are compared here; row order is assumed to match.
        assert len(gene_expression_csv) == len(labels_df), "Mismatch in number of samples between gene expression data and labels"

        encoded_targets = label_encoder.transform(labels_df['Response_3m'])
        self.labels = torch.LongTensor(encoded_targets)

        expression_matrix = gene_expression_csv.values.astype(np.float32)
        self.gene_expression = torch.FloatTensor(expression_matrix)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        sample = {}
        sample['x'] = self.gene_expression[idx]
        sample['targets'] = self.labels[idx]
        return sample

## Training Loop

In [10]:
import gc
from torch.profiler import profile, record_function, ProfilerActivity

def print_gpu_memory(step_name):
    """Print the CUDA memory currently allocated by tensors (in GB), labelled with ``step_name``."""
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"GPU memory at {step_name}: {allocated_gb:.2f} GB")

def _train_one_epoch(model, train_loader, optimizer, scaler, device, epoch,
                     gradient_accumulation_steps):
    """Run one mixed-precision training epoch with gradient accumulation; return the mean batch loss."""
    model.train()
    # Start from clean gradients so nothing accumulated earlier leaks into this epoch.
    optimizer.zero_grad(set_to_none=True)
    train_loss = 0.0
    num_batches = len(train_loader)

    for i, batch in enumerate(train_loader):
        print_gpu_memory(f"Epoch {epoch}, Batch {i} start")
        # The loaders use pin_memory=True, so host->device copies can be non-blocking.
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

        with autocast():
            logits = model(batch)
            # Scale down so the accumulated gradient averages over the micro-batches.
            loss = model.compute_loss(logits, batch['targets'].float()) / gradient_accumulation_steps

        print_gpu_memory(f"Epoch {epoch}, Batch {i} after forward pass")
        assert loss.requires_grad, "Loss does not require gradients"

        scaler.scale(loss).backward()
        print_gpu_memory(f"Epoch {epoch}, Batch {i} after backward pass")

        # Step every `gradient_accumulation_steps` micro-batches AND on the final batch;
        # otherwise a trailing partial accumulation is never applied and its gradients
        # stay allocated through validation.
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch:
            scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            gc.collect()
            torch.cuda.empty_cache()
            print_gpu_memory(f"Epoch {epoch}, Batch {i} after optimizer step")

        train_loss += loss.item() * gradient_accumulation_steps
        del batch, logits, loss

    return train_loss / num_batches if num_batches else 0.0


def _validate_one_epoch(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader``; return (mean loss over evaluated batches, correct, total)."""
    model.eval()
    val_loss = 0.0
    evaluated_batches = 0
    correct = 0
    total = 0

    torch.cuda.empty_cache()
    print_gpu_memory(f"Before validation")

    with torch.no_grad():
        for j, batch in enumerate(val_loader):
            print_gpu_memory(f"Validation, Batch {j} start")
            # Batches with fewer than 2 samples are skipped (presumably the model cannot
            # handle a single sample — TODO confirm).
            if batch['x'].size(0) < 2:
                continue
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            targets = batch['targets']

            # Use the same autocast as training: running validation in full fp32 enlarges
            # activations and likely contributes to the memory spike seen at validation.
            with autocast():
                logits = model(batch)
                val_loss += model.compute_loss(logits, targets.float()).item()
            evaluated_batches += 1

            # The model returns logits of shape (B, 1) while targets are (B,). Flatten
            # before comparing; otherwise broadcasting builds a (B, B) matrix and
            # over-counts correct predictions.
            predicted = (torch.sigmoid(logits) > 0.5).long().reshape(-1)
            correct += (predicted == targets.reshape(-1)).sum().item()
            total += targets.size(0)

            del batch, targets, logits, predicted
            torch.cuda.empty_cache()
            print_gpu_memory(f"Validation, Batch {j} end")

    mean_val_loss = val_loss / evaluated_batches if evaluated_batches else 0.0
    return mean_val_loss, correct, total


def finetune_scFoundation(gene_expression_csv, labels_df, model_class, ckpt_path,
                          batch_size=4, num_epochs=5, lr=0.001,
                          validation_split=0.2, device='cuda',
                          gradient_accumulation_steps=2, enable_profiler=True):
    """Fine-tune an scFoundation-based classifier and print per-epoch train/validation metrics.

    Args:
        gene_expression_csv (pd.DataFrame): expression matrix, one row per sample
            (despite the name, an already-loaded DataFrame, not a path).
        labels_df (pd.DataFrame): row-aligned metadata with a ``Response_3m`` column.
        model_class: constructed as ``model_class(ckpt_path=ckpt_path)`` then ``.build()``;
            must provide ``compute_loss(logits, targets)``.
        ckpt_path (str): pretrained checkpoint passed to ``model_class``.
        batch_size (int): training batch size; validation uses ``batch_size // 2``
            clamped to the range [2, 8].
        num_epochs (int): number of training epochs.
        lr (float): Adam learning rate.
        validation_split (float): fraction of rows held out for validation.
        device (str): torch device to train on.
        gradient_accumulation_steps (int): micro-batches per optimizer step (>= 1).
        enable_profiler (bool): wrap the run in ``torch.profiler`` with memory/shape
            recording (adds host-memory overhead); set False to skip profiling.

    Returns:
        The fine-tuned model.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` < 1.

    NOTE(review): the split below is over rows, not patients. If ``Response_3m`` is a
    per-patient label, rows from the same patient land in both splits and the
    validation accuracy will be optimistic — consider a patient-grouped split.
    """
    from contextlib import nullcontext

    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be >= 1")

    gene_exp_train, gene_exp_val, labels_train, labels_val = train_test_split(
        gene_expression_csv, labels_df, test_size=validation_split, random_state=42)

    # Fit the LabelEncoder on both splits so every class seen anywhere gets a code.
    le = LabelEncoder()
    le.fit(pd.concat([labels_train['Response_3m'], labels_val['Response_3m']]))

    train_dataset = SingleCellDataset(gene_exp_train, labels_train, le)
    val_dataset = SingleCellDataset(gene_exp_val, labels_val, le)

    # Validation batch: half the training batch, clamped to [2, 8]. The lower bound keeps
    # batch_size=1 from producing an invalid size of 0 (and size-1 batches are skipped).
    val_batch_size = max(2, min(8, batch_size // 2))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    model = model_class(ckpt_path=ckpt_path)
    model.build()
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scaler = GradScaler()  # loss scaling for mixed-precision training

    start_epoch = 0  # placeholder for resuming from a checkpoint (not implemented)

    profiler_ctx = (
        profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                profile_memory=True, record_shapes=True)
        if enable_profiler else nullcontext()
    )
    with profiler_ctx as prof:
        for epoch in range(start_epoch, num_epochs):
            torch.cuda.empty_cache()

            mean_train_loss = _train_one_epoch(model, train_loader, optimizer, scaler, device,
                                               epoch, gradient_accumulation_steps)
            print(f"Max GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

            mean_val_loss, correct, total = _validate_one_epoch(model, val_loader, device)

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"total samples in validation: {total}; correct predictions: {correct}")
            print(f"Train Loss: {mean_train_loss:.4f}")
            print(f"Validation Loss: {mean_val_loss:.4f}")
            if total == 0:
                print("No validation samples")
            else:
                print(f"Validation Accuracy: {100*correct/total:.2f}%")
            print("-----------------------------")

    # `prof` is None when profiling is disabled (nullcontext yields None).
    if prof is not None:
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    print(f"Peak CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    print(f"Peak CUDA memory reserved: {torch.cuda.max_memory_reserved() / 1e9:.2f} GB")
    print("\nCUDA Memory Summary:")
    print(torch.cuda.memory_summary())

    return model

    

In [11]:
finetuned_model = finetune_scFoundation(
    data_downsampled,
    labels_df_downsampled,
    model_class=FinetunePatientClassification,
    ckpt_path='./models/models.ckpt',
    num_epochs=1,
    lr=0.001,
    device='cuda',
    validation_split=0.2,
    batch_size=4,
)

{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

STAGE:2024-07-26 17:55:44 120286:120286 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


GPU memory at Epoch 0, Batch 0 start: 0.40 GB
logits shape:  torch.Size([4, 1])
target shape:  torch.Size([4])
squeezed logits shape:  torch.Size([4])
GPU memory at Epoch 0, Batch 0 after forward pass: 0.41 GB
GPU memory at Epoch 0, Batch 0 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 1 start: 0.42 GB
logits shape:  torch.Size([4, 1])
target shape:  torch.Size([4])
squeezed logits shape:  torch.Size([4])
GPU memory at Epoch 0, Batch 1 after forward pass: 0.42 GB
GPU memory at Epoch 0, Batch 1 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 1 after optimizer step: 0.42 GB
GPU memory at Epoch 0, Batch 2 start: 0.42 GB
logits shape:  torch.Size([4, 1])
target shape:  torch.Size([4])
squeezed logits shape:  torch.Size([4])
GPU memory at Epoch 0, Batch 2 after forward pass: 0.42 GB
GPU memory at Epoch 0, Batch 2 after backward pass: 0.42 GB
GPU memory at Epoch 0, Batch 3 start: 0.42 GB
logits shape:  torch.Size([4, 1])
target shape:  torch.Size([4])
squeezed logits sh

STAGE:2024-07-26 17:56:57 120286:120286 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-07-26 17:56:58 120286:120286 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         0.05%      30.368ms         5.02%        2.799s     114.794us       0.000us         0.00%       45.189s       1.854ms           0 b           0 b     684.79 Gb    -100.88 G

In [None]:
import os

def save_finetuned_model(model, save_path, model_name):
    """
    Write the fine-tuned model to disk twice: as a pickled module and as a state dict.

    Args:
    model (torch.nn.Module): the fine-tuned model to persist
    save_path (str): target directory; created (including parents) if missing
    model_name (str): filename prefix for "<model_name>_entire.pth" and
        "<model_name>_state_dict.pth"

    Note: the "_entire" file is a pickle of the whole module; loading it requires the
    model class to be importable and should only be done with trusted files.
    """
    os.makedirs(save_path, exist_ok=True)

    # Whole-module pickle: convenient, but bound to the class definition at load time.
    pickled_module_file = os.path.join(save_path, f"{model_name}_entire.pth")
    torch.save(model, pickled_module_file)
    print(f"Entire model saved to {pickled_module_file}")

    # Parameters/buffers only: the portable form, restored via load_state_dict().
    weights_file = os.path.join(save_path, f"{model_name}_state_dict.pth")
    torch.save(model.state_dict(), weights_file)
    print(f"Model state dict saved to {weights_file}")


In [None]:
# Persist the fine-tuned model (pickled module + state dict) under ./saved_models.
save_path = './saved_models'
model_name = 'finetuned_scFoundation'
save_finetuned_model(model=finetuned_model, save_path=save_path, model_name=model_name)

In [17]:
# Scratch check: unsqueeze at dim 1 turns a (4,) tensor into (4, 1), the shape the model's logits have.
x = torch.tensor([1, 2, 3, 4])
x.unsqueeze(dim=1)

tensor([[1],
        [2],
        [3],
        [4]])

In [None]:
# Re-save the fine-tuned model. `model` only exists inside finetune_scFoundation, so the
# notebook-level object to save is the returned `finetuned_model`.
# 1. Save the entire model (pickled module)
entire_model_path = os.path.join(save_path, f"{model_name}_entire.pth")
torch.save(finetuned_model, entire_model_path)
print(f"Entire model saved to {entire_model_path}")

# 2. Save only the state dict (portable parameters/buffers)
state_dict_path = os.path.join(save_path, f"{model_name}_state_dict.pth")
torch.save(finetuned_model.state_dict(), state_dict_path)
print(f"Model state dict saved to {state_dict_path}")