In [1]:
import sys
sys.path.append('../')

import os

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize

from lightly.data import LightlyDataset, SimCLRCollateFunction, collate

from src.dataset.bps_dataset import BPSMouseDataset
from src.dataset.augmentation import (
    NormalizeBPS,
    ResizeBPS,
    VFlipBPS,
    HFlipBPS,
    RotateBPS,
    RandomCropBPS,
    ToThreeChannels,
    ToTensor
)
from src.model.unsupervised.resnet101 import ResNet101

In [8]:
# DataLoader settings.
batch_size = 32
num_workers = 4

# Seed handed to pl.seed_everything and epoch budget handed to pl.Trainer.
seed = 1
max_epochs = 20

# Value forwarded as `input_size` to SimCLRCollateFunction.
input_size = 128

In [9]:
# Seed the global RNGs once, before any dataset/model is built.
# workers=True additionally asks Lightning to seed DataLoader worker
# processes; the loaders in this notebook use num_workers > 0 together with
# random augmentations, so without it the workers are not covered by `seed`.
pl.seed_everything(seed, workers=True)

Global seed set to 1


1

In [15]:
# NOTE(review): collate_fn is constructed here but is not passed to either
# DataLoader below — confirm whether it is still needed.
collate_fn = SimCLRCollateFunction(input_size=input_size, vf_prob=0.5, rr_prob=0.5)

train_csv_path = 'meta_dose_hi_hr_4_post_exposure_train.csv'
test_csv_path = 'meta_dose_hi_hr_4_post_exposure_test.csv'

# Training pipeline: normalize and resize, then random flips / rotation / crop.
# transforms.RandomApply iterates over its first argument, so every transform
# is wrapped in a list; passing a bare transform object fails with a TypeError
# as soon as the first sample is transformed.
transformations_train = transforms.Compose([
    NormalizeBPS(),
    ResizeBPS(256, 256),
    transforms.RandomApply([VFlipBPS()], p=0.5),
    transforms.RandomApply([HFlipBPS()], p=0.5),
    transforms.RandomApply([RotateBPS(90)], p=0.5),
    RandomCropBPS(200, 200),
    ToThreeChannels(),
    ToTensor(),
])

# Evaluation pipeline: same normalize/resize steps, no random augmentation.
transformations_test = transforms.Compose([
    NormalizeBPS(),
    ResizeBPS(256, 256),
    ToThreeChannels(),
    ToTensor(),
])

transformed_dataset_train = BPSMouseDataset(
    train_csv_path,
    '../Microscopy/train',
    transform=transformations_train,
    file_on_prem=True,
)

# NOTE(review): this "test" dataset reads the *training* CSV and directory,
# and test_csv_path above is never used. That may be intentional (embedding
# the un-augmented training images) — confirm before swapping in test_csv_path.
transformed_dataset_test = BPSMouseDataset(
    train_csv_path,
    '../Microscopy/train',
    transform=transformations_test,
    file_on_prem=True,
)

data_loader_train = DataLoader(
    transformed_dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
# The evaluation loader is not shuffled so its iteration order is stable
# from one pass to the next.
data_loader_test = DataLoader(
    transformed_dataset_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [11]:
from lightly.loss import NTXentLoss
from lightly.models.modules.heads import SimCLRProjectionHead

class SimCLRModel(pl.LightningModule):
    """Contrastive model: ResNet101 backbone, projection head, NT-Xent loss."""

    def __init__(self):
        """Build the backbone, the projection head and the loss."""
        super().__init__()

        # Keep every child module of the ResNet except the last one, which
        # removes the classification head.
        base_network = ResNet101(pretrained=False)
        feature_layers = list(base_network.children())[:-1]
        self.backbone = nn.Sequential(*feature_layers)

        # The head maps backbone features (width taken from the removed fc
        # layer's input size) to a 128-dimensional output.
        feature_dim = base_network.fc.in_features
        self.projection_head = SimCLRProjectionHead(feature_dim, feature_dim, 128)

        self.criterion = NTXentLoss()

    def forward(self, x):
        """Run `x` through the backbone and return the projection-head output."""
        features = self.backbone(x)
        return self.projection_head(features.flatten(start_dim=1))

    def training_step(self, batch, batch_idx):
        """Compute and log the contrastive loss for one batch.

        The batch is unpacked as three items; the first two are fed through
        the model and the third is ignored.
        NOTE(review): confirm the DataLoader actually yields this layout.
        """
        first_view, second_view, _ = batch
        loss = self.criterion(self.forward(first_view), self.forward(second_view))
        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters of the module."""
        return torch.optim.SGD(
            self.parameters(),
            lr=6e-2,
            momentum=0.9,
            weight_decay=5e-4,
        )

In [12]:
# Instantiate the model and train it on a single GPU for `max_epochs` epochs.
model = SimCLRModel()
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=max_epochs,
)
trainer.fit(model, data_loader_train)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params
---------------------------------------------------------
0 | backbone        | Sequential           | 42.5 M
1 | projection_head | SimCLRProjectionHead | 4.5 M 
2 | criterion       | NTXentLoss           | 0     
---------------------------------------------------------
47.0 M    Trainable params
0         Non-trainable params
47.0 M    Total params
187.844   Total estimated

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
