In [1]:
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import scanpy as sc
import torchvision
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader


from Modules import *
from predict import *


In [13]:
class CLR(pl.LightningModule):
    def __init__(self, learning_rate = 1e-3, alpha=25, beta=25, gamma=1):
        super(CLR, self).__init__()
        self.learning_rate = learning_rate
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.img_encoder = VICReg()
        self.exp_encoder = VariationalEncoder()
        self.loss = 0
        

    def forward(self, img, exp):
        img_emb = self.img_encoder(img)
        exp_emb = self.exp_encoder(exp)
        self.loss = VICReg_Loss(img_emb, exp_emb, self.alpha, self.beta, self.gamma)
        return self.loss
    
    def configure_optimizers(self):
        optim=torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        optim_dict = {'optimizer': optim}
        return optim_dict
    
    """Training steps"""
    def training_step(self, batch, batch_idx):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        patch, center, exp, label, adj, oris, sfs, *_ = batch
        
        # resize the tile to 224
        transforms = torch.nn.Sequential(torchvision.transforms.Resize((224, 224)),)
        patch = transforms(patch.squeeze(0))
        
        """Model inference"""
        loss = self(patch, exp)
        self.log('Loss', loss, on_epoch=True, prog_bar=True, logger=True) 
        return loss
        
    def test_step(self, batch, batch_idx):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        patch, center, exp, label, adj, oris, sfs, *_ = batch
        
        # resize the tile to 224
        transforms = torch.nn.Sequential(torchvision.transforms.Resize((224, 224)),)
        patch = transforms(patch.squeeze(0))
        
        """Model inference"""
        loss = self(patch, exp)
        self.log('Loss', loss, on_epoch=True, prog_bar=True, logger=True)


In [None]:
# folds=[0, 6, 12, 18, 24, 27, 31]
"""Training loops"""
seed = 42
epochs = 100
fold = 0

"""Load dataset"""
seed_everything(seed)
trainset = pk_load(mode='train',fold=fold, )
testset = pk_load(mode='test',fold=fold, )
train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)

"""Define model"""
model = CLR()

"""Setup trainer"""
logger = pl.loggers.CSVLogger("logs", name=f"Earlystop_gamma{gamma}_{fold}")
trainer = pl.Trainer(accelerator='auto',callbacks=[EarlyStopping(monitor='Loss',mode='min')], max_epochs=epochs,logger=False)
trainer.fit(model, test_loader)
trainer.test(model, test_loader)

""" Save model and clean memory """
torch.save(model.state_dict(),f"./model/Earlystop/VICReg+VAE-seed{seed}-epochs{epochs}.ckpt")
gc.collect()

[rank: 0] Global seed set to 42


test sample: ['A1']
validation samples: ['A6', 'C2', 'C6']
Loading imgs...
Loading metadata...


Using cache found in /home/s4654864/.cache/torch/hub/facebookresearch_vicreg_main
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type               | Params
---------------------------------------------------
0 | img_encoder | VICReg             | 23.5 M
1 | exp_encoder | VariationalEncoder | 5.0 M 
---------------------------------------------------
5.0 M     Trainable params
23.5 M    Non-trainable params
28.5 M    Total params
114.045   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 36:   0%|          | 0/1 [00:00<?, ?it/s, loss=2.39e+03, Loss_step=2.11e+3, Loss_epoch=2.11e+3]        