In [1]:
# %%
%reload_ext autoreload
%autoreload 2
import torch
from torch import optim
from FinetunePatientClassification import *
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

## Problem: bench kept crashing during the train and validation loop. 

According to torch profiler, the training loop ran but crashed somewhere during validation. There is a sudden spike in GPU usage then the error message appeared. 
Troubleshooting attempts:
x tried downsampling data
x batch size
x add gradient accumulation
x mixed precision training
x freeze model parameters

## Input data and preprocessing

In [2]:
# load the preprocessed data for scFoundation
data = pd.read_csv('gene_symbol_converted_data_for_scF.csv')

## Downsample the data due to GPU limitations

In [3]:
labels_df = pd.read_csv('patient_metadata.csv')
labels_df["Response_3m"].value_counts() # 1: OR, 0: NR

Response_3m
1    13160
0    11906
Name: count, dtype: int64

Subset the data to reduce GPU memory burden during training

In [5]:
labels_df.patient_id = labels_df.patient_id.astype('category')

In [8]:
labels_df.patient_id.unique()

['ac01', 'ac02', 'ac03', 'ac04', 'ac05', ..., '30', '31', '32', 'NHL-6', 'NHL-7']
Length: 81
Categories (81, object): ['01', '02', '03', '04', ..., 'ac34', 'ac37', 'ac38', 'ac39']

In [9]:
import random
# remove these ids from the dataset
# rm_patient_ids = ['15', '11', '27', '03', 'Pt010', 'Pt375', 'ac33', 
                    # 'ac12', 'ac03', 'ac17', '26', 'ac22', 'ac08', 'ac32', '05', 
                    # 'ac02', 'ac05', 'ac07', 'ac13', 'ac19', 'ac24', 'ac25', 'ac29',
                    # 'Pt011', 'Pt025', 'Pt237', '02', '04', '07', '13', '14', '20',
                    # '24', '25', '30']

# # remove the patients from the dataset
# labels_df = labels_df[~labels_df['patient_id'].isin(rm_patient_ids)]
patients = labels_df.patient_id.unique()

# # randomly sample patient IDs
random.seed(40)
sample_size = 20
sampled_patient_ids = random.sample(list(patients), sample_size)
# Create a new DF with only the sampled patients
labels_df_downsampled = labels_df[labels_df['patient_id'].isin(sampled_patient_ids)]



# Create a new DF with only the sampled patients
labels_df_downsampled = labels_df_downsampled[labels_df_downsampled['patient_id'].isin(sampled_patient_ids)]




In [10]:
labels_df_downsampled.patient_id.unique()   

['ac04', 'ac05', 'ac08', 'ac09', 'ac13', ..., '21', '27', '28', '32', 'NHL-7']
Length: 20
Categories (81, object): ['01', '02', '03', '04', ..., 'ac34', 'ac37', 'ac38', 'ac39']

In [20]:
labels_df_downsampled

Unnamed: 0,cell_id,nCount_RNA,nFeature_RNA,patient_id,percent_mito,Response_3m,sample_source
1312,Good_Li_2023ac04_AAACCTGGTAGTAGTA-1,2704,1190,ac04,2.841303,0,Deng
1313,Good_Li_2023ac04_AAAGTAGAGAGGACGG-1,4728,1679,ac04,2.697842,0,Deng
1314,Good_Li_2023ac04_AAAGTAGCAGGAATGC-1,3613,1263,ac04,2.482362,0,Deng
1315,Good_Li_2023ac04_AAATGCCCACCGATAT-1,6556,2118,ac04,4.809334,0,Deng
1316,Good_Li_2023ac04_AAATGCCTCACCTTAT-1,704,412,ac04,1.899593,0,Deng
...,...,...,...,...,...,...,...
25061,Sheih_TTTGGTTCATTGTGCA-5,2722,1253,NHL-7,6.673766,1,Sheih
25062,Sheih_TTTGGTTTCGATCCCT-5,6209,1980,NHL-7,3.170656,1,Sheih
25063,Sheih_TTTGGTTTCTACGAGT-5,5606,1925,NHL-7,2.868498,1,Sheih
25064,Sheih_TTTGTCACAGTCAGAG-5,2317,1049,NHL-7,2.549102,1,Sheih


In [23]:
# method 1: randomly spliting
def train_valid_split(label_df, seed_num, split_ratio):
    patient_list = label_df["patient_id"].unique()
    np.random.seed(seed_num)
    reference_patients = np.random.choice(patient_list, size=int(len(patient_list) * split_ratio), replace=False)
    train_data = label_df[label_df['patient_id'].isin(reference_patients)]   
    valid_data = label_df[~label_df['patient_id'].isin(reference_patients)]  
    return train_data, valid_data

In [24]:
train_label, test_label = train_valid_split(labels_df_downsampled, 42, 0.8)


In [29]:
train_label.patient_id.unique()

['ac04', 'ac05', 'ac08', 'ac09', 'ac13', ..., '21', '27', '28', '32', 'NHL-7']
Length: 16
Categories (81, object): ['01', '02', '03', '04', ..., 'ac34', 'ac37', 'ac38', 'ac39']

In [27]:
test_label.patient_id.unique() 

['ac27', 'ac28', 'Pt011', '12']
Categories (81, object): ['01', '02', '03', '04', ..., 'ac34', 'ac37', 'ac38', 'ac39']

In [16]:
train_label["Response_3m"].value_counts(), test_label["Response_3m"].value_counts()

(Response_3m
 1    11997
 0     8821
 Name: count, dtype: int64,
 Response_3m
 0    3085
 1    1163
 Name: count, dtype: int64)

In [30]:
# downsample the expression data accordingly
train_data_downsampled = data.iloc[train_label.index]
test_data_downsampled = data.iloc[test_label.index]

# remove the cell_id column
train_data_downsampled = train_data_downsampled.drop(columns=['cell_id'])
test_data_downsampled = test_data_downsampled.drop(columns=['cell_id'])


In [31]:
# make sure the data and labels are aligned
train_data_downsampled.shape, train_label.shape

((5084, 19264), (5084, 7))

In [32]:
test_data_downsampled.shape, test_label.shape

((1365, 19264), (1365, 7))

In [34]:
test_data_downsampled

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3
8349,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,3.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0
8350,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0
8351,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
8352,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,...,0.0,7.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0
8353,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,1.0,0.0,...,1.0,3.0,0.0,0.0,0.0,0.0,0.0,10.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19940,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,1.0,0.0
19941,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,11.0,0.0,0.0
19942,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0
19943,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,5.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0


Finetune the model

In [40]:
torch.cuda.empty_cache()

In [39]:
class SingleCellDataset(Dataset):
    def __init__(self, gene_expression_csv, labels_df, label_encoder):
        # Load the gene expression data
        self.gene_expression = gene_expression_csv
        
        # Ensure the labels DataFrame has the same number of rows as the gene expression data
        assert len(self.gene_expression) == len(labels_df), "Mismatch in number of samples between gene expression data and labels"
        
        # Convert labels to numeric if they're categorical
        
        self.labels = torch.LongTensor(label_encoder.transform(labels_df['Response_3m']))
        
        
        # Convert gene expression data to torch tensor
        self.gene_expression = torch.FloatTensor(self.gene_expression.values.astype(np.float32))
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'x': self.gene_expression[idx],
            'targets': self.labels[idx]
        }

## Training Loop

In [36]:
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [37]:
import wandb as wb
wb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkristint[0m ([33mmackall_lab[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [42]:
import gc
import time
from torch.profiler import profile, record_function, ProfilerActivity


run = wb.init(project='scF', job_type='scF_finetune')



def print_gpu_memory(step_name):
    print(f"GPU memory at {step_name}: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

def finetune_scFoundation(gene_exp_train, gene_exp_val, labels_train, labels_val, model_class, ckpt_path,
                          batch_size=4, num_epochs=5, lr=0.001,
                          device='cuda',
                          gradient_accumulation_steps=2,
                          pretrained_model_path=None):
    
    # gene_exp_train, gene_exp_val, labels_train, labels_val = train_test_split(gene_expression_csv, 
    #                                                                           labels_df, test_size=validation_split, 
    #                                                                           random_state=42)

    # Fit LabelEncoder on combined dataset
    le = LabelEncoder()
    # combined_labels = pd.concat([labels_train['Response_3m'], labels_val['Response_3m']])
    # le.fit(combined_labels)

    #create datasets
    train_dataset = SingleCellDataset(gene_exp_train, labels_train, le)
    val_dataset = SingleCellDataset(gene_exp_val, labels_val, le)

    #create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                              num_workers=4, pin_memory=True, drop_last=True) # drop the last incomplete batch
    # to ensure at least 2 samples per batch
    val_loader = DataLoader(val_dataset, batch_size = batch_size // 2 , shuffle=False, 
                            num_workers=4, pin_memory=True, drop_last=True)  



    
    # Initialize model
    model = model_class(ckpt_path=ckpt_path)
    model.build()
    if pretrained_model_path:
        print(f"Loading the pretrained model from {pretrained_model_path}")
        start_time = time.time()    
        state_dict = torch.load(pretrained_model_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Model loaded in {time.time() - start_time:.2f} seconds")

    model = model.to(device)

    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Initialize gradient scaler for mixed precision training
    scaler = GradScaler()

    # Optionally load checkpoint
    start_epoch = 0
    

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            profile_memory=True, record_shapes=True) as prof:
        
        best_val_loss = float('inf')
        patience = 10
        patience_counter = 0


    # log memory usage
        torch.cuda.memory._record_memory_history(
        max_entries = 100000
    )

        for epoch in range(start_epoch, num_epochs):
            model.train()
            train_loss = 0.0

            # Clear CUDA cache
            torch.cuda.empty_cache()

            for i, batch in enumerate(train_loader):
                print_gpu_memory(f"Epoch {epoch}, Batch {i} start")
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                # Mixed precision training
                with autocast():
                    # Forward pass
                    logits = model(batch)

                    try:
                        torch.cuda.memory._dump_snapshot("forward_pass.pickle")
                    except Exception as e:
                        run.log(f"Failed to capture memory snapshot {e}")

                    # Compute loss
                    loss = model.compute_loss(logits, batch['targets'].float()) / gradient_accumulation_steps
                
                # print_gpu_memory(f"Epoch {epoch}, Batch {i} after forward pass")
                # Ensure loss requires gradient
                assert loss.requires_grad, "Loss does not require gradients"

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()

                try:
                    torch.cuda.memory._dump_snapshot("backward_pass.pickle")
                except Exception as e:
                    run.log(f"Failed to capture memory snapshot {e}")

                # print_gpu_memory(f"Epoch {epoch}, Batch {i} after backward pass")

                if (i + 1) % gradient_accumulation_steps == 0:
                    # Unscale gradients and optimizer step
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    # add garbage collection
                    gc.collect()
                    torch.cuda.empty_cache()
                
                    # print_gpu_memory(f"Epoch {epoch}, Batch {i} after optimizer step")

                train_loss += loss.item() * gradient_accumulation_steps
            print(f"Max GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    


            #validation
            
            model.eval()
            
            val_loss = 0.0
            correct = 0
            total = 0

            torch.cuda.empty_cache()
            print_gpu_memory(f"Before validation")

            with torch.no_grad():
                for j, batch in enumerate(val_loader):
                    
                    print_gpu_memory(f"Validation, Batch {j} start")
                    if batch['x'].size(0) < 2:  # Skip batches smaller than 2
                        continue
                    batch = {k: v.to(device) for k, v in batch.items()}
                    try:
                        torch.cuda.memory._dump_snapshot("validation.pickle")
                    except Exception as e:
                        run.log(f"Failed to capture memory snapshot {e}")

                    logits = model(batch)
                    val_loss += model.compute_loss(logits, batch['targets'].float()).item()
                    
                    predicted = (torch.sigmoid(logits) > 0.5).float()
                    predicted = predicted.squeeze().flatten()
                    print("predicted: ", predicted)
                    print("targets: ", batch['targets'])
                    # calculate correct predictions
                    correct_pred = (predicted == batch['targets']).sum().item()
                    print(f'correct predictions: {correct_pred}')


                    total += batch['targets'].size(0)
                    correct += correct_pred

                    # Move data back to CPU
                    for k in batch.keys():
                        batch[k] = batch[k].cpu()
                    del batch, logits, predicted
                    torch.cuda.empty_cache()

                    # print_gpu_memory(f"Validation, Batch {j} end")

            
            
            # Print epoch results
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"total samples in validation: {total}; correct predictions: {correct}")
            print(f"Train Loss: {train_loss/len(train_loader):.4f}")
            print(f"Validation Loss: {val_loss/len(val_loader):.4f}")
            if total == 0:
                print("No validation samples")
            else:
                print(f"Validation Accuracy: {100*correct/total:.2f}%")
            print("-----------------------------")
        # Stop recording memory snapshot history.
        torch.cuda.memory._record_memory_history(enabled=None)
    
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    print(f"Peak CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    print(f"Peak CUDA memory reserved: {torch.cuda.max_memory_reserved() / 1e9:.2f} GB")
    print("\nCUDA Memory Summary:")
    print(torch.cuda.memory_summary())

    return model

    

In [43]:
finetuned_model = finetune_scFoundation(train_data_downsampled,
                                        test_data_downsampled, 
                                        train_label, test_label, 
                                        model_class=FinetunePatientClassification, 
                                        ckpt_path='./models/models.ckpt', 
                                        num_epochs=3, 
                                        lr=0.0001, 
                                        device='cuda',
                                        batch_size=4, 
                                        pretrained_model_path=None)   

{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

STAGE:2024-07-29 03:23:48 13950:13950 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


GPU memory at Epoch 0, Batch 0 start: 0.40 GB
GPU memory at Epoch 0, Batch 1 start: 0.45 GB
GPU memory at Epoch 0, Batch 2 start: 0.48 GB
GPU memory at Epoch 0, Batch 3 start: 0.51 GB
GPU memory at Epoch 0, Batch 4 start: 0.48 GB
GPU memory at Epoch 0, Batch 5 start: 0.51 GB
GPU memory at Epoch 0, Batch 6 start: 0.48 GB
GPU memory at Epoch 0, Batch 7 start: 0.51 GB
GPU memory at Epoch 0, Batch 8 start: 0.48 GB
GPU memory at Epoch 0, Batch 9 start: 0.51 GB
GPU memory at Epoch 0, Batch 10 start: 0.48 GB
GPU memory at Epoch 0, Batch 11 start: 0.51 GB
GPU memory at Epoch 0, Batch 12 start: 0.48 GB
GPU memory at Epoch 0, Batch 13 start: 0.51 GB
GPU memory at Epoch 0, Batch 14 start: 0.48 GB
GPU memory at Epoch 0, Batch 15 start: 0.51 GB
GPU memory at Epoch 0, Batch 16 start: 0.48 GB
GPU memory at Epoch 0, Batch 17 start: 0.51 GB
GPU memory at Epoch 0, Batch 18 start: 0.48 GB
GPU memory at Epoch 0, Batch 19 start: 0.51 GB
GPU memory at Epoch 0, Batch 20 start: 0.48 GB
GPU memory at Epoch 0, 

STAGE:2024-07-29 07:11:14 13950:13950 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-07-29 07:11:27 13950:13950 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         0.19%        3.275s         5.05%       89.282s     220.593us       0.000us         0.00%     1128.118s       2.787ms           0 b           0 b   18568.44 Gb   -2075.75 G

In [17]:
finetuned_model = finetune_scFoundation(data, labels_df, model_class=FinetunePatientClassification, 
                                        ckpt_path='./models/models.ckpt', num_epochs=3, lr=0.0001, device='cuda',
                                        validation_split=0.2, batch_size=4, pretrained_model_path=None)   

{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

STAGE:2024-07-28 05:11:35 9276:9276 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


GPU memory at Epoch 0, Batch 0 start: 0.40 GB
GPU memory at Epoch 0, Batch 1 start: 0.42 GB
GPU memory at Epoch 0, Batch 2 start: 0.42 GB
GPU memory at Epoch 0, Batch 3 start: 0.42 GB
GPU memory at Epoch 0, Batch 4 start: 0.42 GB
GPU memory at Epoch 0, Batch 5 start: 0.42 GB
GPU memory at Epoch 0, Batch 6 start: 0.42 GB
GPU memory at Epoch 0, Batch 7 start: 0.42 GB
GPU memory at Epoch 0, Batch 8 start: 0.42 GB
GPU memory at Epoch 0, Batch 9 start: 0.42 GB
GPU memory at Epoch 0, Batch 10 start: 0.42 GB
GPU memory at Epoch 0, Batch 11 start: 0.42 GB
GPU memory at Epoch 0, Batch 12 start: 0.42 GB
GPU memory at Epoch 0, Batch 13 start: 0.42 GB
GPU memory at Epoch 0, Batch 14 start: 0.42 GB
GPU memory at Epoch 0, Batch 15 start: 0.42 GB
GPU memory at Epoch 0, Batch 16 start: 0.42 GB
GPU memory at Epoch 0, Batch 17 start: 0.42 GB
GPU memory at Epoch 0, Batch 18 start: 0.42 GB
GPU memory at Epoch 0, Batch 19 start: 0.42 GB
GPU memory at Epoch 0, Batch 20 start: 0.42 GB
GPU memory at Epoch 0, 

In [25]:
finetuned_model = finetune_scFoundation(data_downsampled, labels_df_downsampled, model_class=FinetunePatientClassification, 
                                        ckpt_path='./models/models.ckpt', num_epochs=3, lr=0.0001, device='cuda',
                                        validation_split=0.2, batch_size=4, pretrained_model_path=None)#"saved_models/finetuned_scFoundation_v1_state_dict.pth")   

{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma

STAGE:2024-07-28 23:13:26 7406:7406 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


GPU memory at Epoch 0, Batch 0 start: 0.40 GB
GPU memory at Epoch 0, Batch 1 start: 0.45 GB
GPU memory at Epoch 0, Batch 2 start: 0.48 GB
GPU memory at Epoch 0, Batch 3 start: 0.51 GB
GPU memory at Epoch 0, Batch 4 start: 0.48 GB
GPU memory at Epoch 0, Batch 5 start: 0.51 GB
GPU memory at Epoch 0, Batch 6 start: 0.48 GB
GPU memory at Epoch 0, Batch 7 start: 0.51 GB
GPU memory at Epoch 0, Batch 8 start: 0.48 GB
GPU memory at Epoch 0, Batch 9 start: 0.51 GB
GPU memory at Epoch 0, Batch 10 start: 0.48 GB
GPU memory at Epoch 0, Batch 11 start: 0.51 GB
GPU memory at Epoch 0, Batch 12 start: 0.48 GB
GPU memory at Epoch 0, Batch 13 start: 0.51 GB
GPU memory at Epoch 0, Batch 14 start: 0.48 GB
GPU memory at Epoch 0, Batch 15 start: 0.51 GB
GPU memory at Epoch 0, Batch 16 start: 0.48 GB
GPU memory at Epoch 0, Batch 17 start: 0.51 GB
GPU memory at Epoch 0, Batch 18 start: 0.48 GB
GPU memory at Epoch 0, Batch 19 start: 0.51 GB
GPU memory at Epoch 0, Batch 20 start: 0.48 GB
GPU memory at Epoch 0, 

STAGE:2024-07-29 00:07:13 7406:7406 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-07-29 00:07:16 7406:7406 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::linear         0.25%     762.389ms         7.41%       22.928s     233.321us       0.000us         0.00%      225.269s       2.292ms           0 b           0 b    3783.49 Gb    -446.58 G

In [44]:
import os

def save_finetuned_model(model, save_path, model_name):
    """
    Save the finetuned model using both methods: entire model and state dict.
    
    Args:
    model (torch.nn.Module): The finetuned model to save
    save_path (str): Directory to save the model
    model_name (str): Name to use for the saved model files
    """
    # Ensure the save directory exists
    os.makedirs(save_path, exist_ok=True)
    
    # 1. Save the entire model
    entire_model_path = os.path.join(save_path, f"{model_name}_entire.pth")
    torch.save(model, entire_model_path)
    print(f"Entire model saved to {entire_model_path}")
    
    # 2. Save only the state dict
    state_dict_path = os.path.join(save_path, f"{model_name}_state_dict.pth")
    torch.save(model.state_dict(), state_dict_path)
    print(f"Model state dict saved to {state_dict_path}")


In [45]:
save_path = './saved_models'
model_name = 'finetuned_scFoundation_2_layers_20_patients'
save_finetuned_model(finetuned_model, save_path, model_name)

Entire model saved to ./saved_models/finetuned_scFoundation_2_layers_20_patients_entire.pth
Model state dict saved to ./saved_models/finetuned_scFoundation_2_layers_20_patients_state_dict.pth


In [14]:
from torchinfo import summary

summary(finetune_model)

NameError: name 'finetune_model' is not defined

In [18]:
x = torch.tensor([1, 2, 3, 4])
torch.unsqueeze(x, 1)

tensor([[1],
        [2],
        [3],
        [4]])