In [None]:
!pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps

In [1]:
import os
import sys

import seaborn as sns
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch import seed_everything

In [2]:
torch.set_float32_matmul_precision('high')

In [3]:
%load_ext autoreload

In [4]:
%autoreload
sys.path.append("/projects/b1042/GoyalLab/jaekj/scTab-devel")
from cellnet.estimators import EstimatorCellTypeClassifier

ModuleNotFoundError: No module named 'cellnet'

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



## Load Embedding

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

class EmbeddingDataset(Dataset):
    def __init__(self, embeddings, labels=None):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        x = self.embeddings[idx]
        if self.labels is not None:
            return x, self.labels[idx]
        return x


class EmbeddingDataModule:
    def __init__(self, train_emb, val_emb, test_emb, batch_size=2048):
        self.train_dataset = EmbeddingDataset(train_emb['X'], train_emb['y'])
        self.val_dataset = EmbeddingDataset(val_emb['X'], val_emb['y'])
        self.test_dataset = EmbeddingDataset(test_emb['X'], test_emb['y'])
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


In [13]:
DATA_PATH = "/projects/b1042/GoyalLab/jaekj/SCTAB_FINAL/merlin_cxg_2023_05_15_sf-log1p"

In [14]:
estim = EstimatorCellTypeClassifier(DATA_PATH)

train_emb = torch.load("/projects/b1042/GoyalLab/jaekj/SCTAB_FINAL/train_embedding_batch_fixed.pt",  weights_only=True)
val_emb = torch.load("/projects/b1042/GoyalLab/jaekj/SCTAB_FINAL/val_embedding.pt",  weights_only=True)
test_emb = torch.load("/projects/b1042/GoyalLab/jaekj/SCTAB_FINAL/test_embedding.pt",  weights_only=True)

estim.datamodule = EmbeddingDataModule(train_emb, val_emb, test_emb, batch_size=2048)

# Init model

In [15]:
# config parameters
MODEL = 'cxg_2023_05_15_linear'
CHECKPOINT_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)
LOGS_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)
SEED = 1

In [16]:
estim.init_trainer(
    trainer_kwargs={
        'max_epochs': 12,
        'default_root_dir': CHECKPOINT_PATH,
        'accelerator': 'gpu',
        'devices': 1,
        'num_sanity_val_steps': 0,
        'check_val_every_n_epoch': 1,
        'logger': [TensorBoardLogger(LOGS_PATH, name='default')],
        'log_every_n_steps': 100,
        'detect_anomaly': False,
        'enable_progress_bar': True,
        'enable_model_summary': False,
        'enable_checkpointing': True,
        'callbacks': [
            TQDMProgressBar(refresh_rate=50),
            LearningRateMonitor(logging_interval='step'),
            ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max',
                            every_n_epochs=1, save_top_k=2),
            ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min',
                            every_n_epochs=1, save_top_k=2)
        ],
    }
)

estim.init_model(
    model_type='linear',
    model_kwargs={
        'learning_rate': 0.0005,
        'weight_decay': 0.05,
        'optimizer': torch.optim.AdamW,
        'lr_scheduler': torch.optim.lr_scheduler.StepLR,
        'lr_scheduler_kwargs': {'step_size': 3, 'gamma': 0.9, 'verbose': True},
    },
)

print(ModelSummary(estim.model))



MisconfigurationException: No supported gpu backend found!

# Find learning rate

In [None]:
lr_find_res = estim.find_lr(lr_find_kwargs={'early_stop_threshold': 10., 'min_lr': 1e-8, 'max_lr': 10., 'num_training': 120})

In [None]:
ax = sns.lineplot(x=lr_find_res[1]['lr'], y=lr_find_res[1]['loss'])
ax.set_xscale('log')
ax.set_ylim(2., top=9.)
ax.set_xlim(1e-6, 10.)
print(f'Suggested learning rate: {lr_find_res[0]}')

# Fit model

In [None]:
estim.train()