# Libraries

In [None]:
!pip install medmnist

In [None]:
%matplotlib inline

import os

import medmnist
from medmnist import INFO, Evaluator

# Plotting
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import nn
from torch.utils.data import Subset
import torch.optim as optim

import numpy as np
import pandas as pd

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

# Augmentation Function

In [None]:
class TransformsSimCLR:
    """
    Stochastic augmentation that turns one input into two augmented views.

    Calling the instance applies the same random pipeline twice to the same
    input; the two results are used as a positive pair for contrastive
    training. The pipeline is: random resized crop, random horizontal flip,
    colour jitter (applied with probability 0.8), random grayscale
    (probability 0.2), tensor conversion, and normalisation with mean 0.5
    and std 0.5.
    """

    def __init__(self, size):
        jitter = transforms.ColorJitter(
            brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
        )
        pipeline = [
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5], std=[.5]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        # Each call to the pipeline draws fresh random parameters.
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

In [None]:
# Batch size used by the data loaders and by the contrastive loss.
BATCH_SIZE = 128
# Epoch budgets: contrastive pre-training first, then fitting the linear classifier.
num_epochs, num_epochs_finetune = 20, 20

# Load Dataset

In [None]:
# Choose which MedMNIST dataset to load.
# Other flags that have been used here: 'pneumoniamnist', 'bloodmnist'.
data_flag = 'pathmnist'
download = True

# Look up the dataset's metadata record and pull out the fields used below.
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])

# Resolve the dataset class on the medmnist package from its name in the metadata.
DataClass = getattr(medmnist, info['python_class'])

In [None]:
# preprocessing

# Pre-training: every sample is turned into two randomly augmented views.
train_transform = TransformsSimCLR(size=(28,28))

# Evaluation / linear-classifier fitting: tensor conversion and normalisation only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# load the data
train_dataset = DataClass(split='train', transform=train_transform, download=download)
test_dataset = DataClass(split='test', transform=test_transform, download=download)

# Same training split, but without augmentation, for fitting the linear classifier.
finetuning_set = DataClass(split='train', transform=test_transform, download=download)


# encapsulate data into dataloader form
# drop_last=True keeps every pre-training batch at exactly BATCH_SIZE samples,
# matching the batch size the contrastive loss is constructed with.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Keep every test sample: accuracy is reported over the whole test split, so
# dropping the trailing partial batch would silently count those samples as wrong.
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# Pull one batch to sanity-check the loader output: each batch is a pair of
# augmented views plus the labels.
dataiter = iter(train_loader)
(sample_1, sample_2), sample_label = next(dataiter)
# Notebook display of the first view's tensor shape.
sample_1.shape

# Model 

In [None]:
class CNN(nn.Module):
    """Small convolutional encoder followed by a two-layer MLP head.

    Five Conv(3x3) -> BatchNorm -> ReLU stages (the second and fifth end with
    a 2x2 max pool) feed a fully connected head that maps the flattened
    feature map to ``num_features`` outputs. The head's input width is fixed
    at ``64 * 4 * 4``, so the input's spatial size must reduce to 4x4 after
    the convolutional stages (28x28 inputs do).

    Args:
        in_channels: Number of channels in the input images.
        num_features: Size of the output vector.
    """

    @staticmethod
    def _conv_stage(n_in, n_out, padding=0, pool=False):
        """Build one Conv(3x3) -> BatchNorm -> ReLU stage, optionally max-pooled."""
        modules = [
            nn.Conv2d(n_in, n_out, kernel_size=3, padding=padding),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def __init__(self, in_channels, num_features):
        super().__init__()

        # Stages are created in forward order so parameter initialisation
        # consumes the random generator in the same sequence as before.
        self.layer1 = self._conv_stage(in_channels, 16)
        self.layer2 = self._conv_stage(16, 16, pool=True)
        self.layer3 = self._conv_stage(16, 64)
        self.layer4 = self._conv_stage(64, 64)
        self.layer5 = self._conv_stage(64, 64, padding=1, pool=True)

        head = [
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, num_features),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Encode a batch of images into ``num_features``-dimensional vectors."""
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        for stage in stages:
            x = stage(x)
        # Collapse channel and spatial dimensions, keeping the batch dimension.
        flat = x.view(x.size(0), -1)
        return self.fc(flat)


In [None]:
class SimCLR_Model(nn.Module):
    """Wrapper that encodes both views of a positive pair with one shared CNN.

    No separate projection head is added: the wrapped CNN already ends in an
    MLP, so its output is returned directly as the embedding.

    Args:
        in_channels: Number of channels in the input images.
        n_features: Dimensionality of the returned embeddings.
    """

    def __init__(self, in_channels, n_features):
        super().__init__()
        self.n_features = n_features
        # Shared encoder applied to both views.
        self.cnn = CNN(in_channels, n_features)

    def forward(self, x_i, x_j):
        """Return the embeddings of ``x_i`` and ``x_j``, in that order."""
        z_i = self.cnn(x_i)
        z_j = self.cnn(x_j)
        return z_i, z_j

# Loss and training

In [None]:
class SimCLR_Loss(nn.Module):
    """NT-Xent contrastive loss over a batch of positive pairs.

    For the ``2B`` embeddings obtained by stacking ``z_i`` on ``z_j``, each
    row's positive is its counterpart in the other half. The loss for a row
    is ``-log(exp(sim_pos / T) / sum_{k != row} exp(sim_k / T))`` with cosine
    similarity, and the result is the mean over all ``2B`` rows.

    Args:
        batch_size: Expected number of pairs per batch; used to pre-build the
            self-similarity mask. Batches of a different size are still
            handled in ``forward``.
        temperature: Softmax temperature ``T`` dividing the similarities.
    """

    def __init__(self, batch_size, temperature=1.0):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        # A non-persistent buffer follows the module through .to()/.cuda()
        # without depending on a module-level `device`, and stays out of
        # state_dict.
        self.register_buffer(
            "mask", self.mask_correlated_samples(batch_size), persistent=False
        )
        self.criterion = nn.CrossEntropyLoss(reduction="mean")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` boolean mask that is False only on the diagonal."""
        N = 2 * batch_size
        return ~torch.eye(N, dtype=torch.bool)

    def forward(self, z_i, z_j):
        """Compute the mean NT-Xent loss for the paired embeddings.

        Args:
            z_i: Embeddings of the first views, one row per sample.
            z_j: Embeddings of the second views, same shape as ``z_i``.

        Returns:
            Scalar tensor holding the loss averaged over all ``2B`` rows.

        Raises:
            ValueError: If ``z_i`` and ``z_j`` differ in shape.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )

        # Use the actual batch size so a smaller batch still pairs up correctly.
        batch_size = z_i.shape[0]
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        mask = self.mask
        if mask.shape[0] != N:
            mask = self.mask_correlated_samples(batch_size)
        mask = mask.to(sim.device)

        # Excluding self-similarity with -inf and applying cross-entropy is
        # the same quantity as -log(exp(pos) / sum(exp(others))), but goes
        # through log-softmax and so does not overflow or underflow.
        logits = sim.masked_fill(~mask, float("-inf"))

        # Row r's positive sits batch_size rows away, wrapping around.
        rows = torch.arange(N, device=sim.device)
        targets = (rows + batch_size) % N

        return self.criterion(logits, targets)

In [None]:
def train(model, train_loader, optimizer, criterion, epoch):
    """Run one contrastive pre-training epoch and return the summed batch loss.

    Args:
        model: Module called as ``model(x_i, x_j)`` that returns the two
            embeddings.
        train_loader: Sized iterable yielding ``((x_i, x_j), label)`` batches;
            the label is ignored.
        optimizer: Optimizer updating the model's parameters.
        criterion: Callable ``criterion(z_i, z_j)`` returning a scalar loss
            tensor.
        epoch: Epoch number, used only in the progress messages.

    Returns:
        The sum (not the mean) of the per-batch loss values.
    """
    model.train()

    # Send inputs to wherever the model lives instead of relying on a
    # module-level `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    n_batches = len(train_loader)
    # Report roughly three times per epoch; the floor of 1 avoids a modulo
    # by zero when the loader has fewer than three batches.
    log_every = max(1, n_batches // 3)

    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()

        if target_device is not None:
            x_i = x_i.to(target_device)
            x_j = x_j.to(target_device)

        # positive pair, with encoding
        z_i, z_j = model(x_i, x_j)

        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_epoch += batch_loss

        if step % log_every == 0:
            print(f"Epoch {epoch}[{step}/{n_batches}] - Loss: {batch_loss}")

    return loss_epoch

In [None]:
def make_optimizer(optimizer_name, model, **kwargs):
    """Create the optimizer selected by ``optimizer_name`` for ``model``.

    Supported names and the keyword arguments they read:
      * ``'Adam'``: ``lr``
      * ``'SGD'``: ``lr``, ``momentum``, ``weight_decay``

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
        KeyError: If a keyword argument required by the chosen optimizer is
            missing.
    """
    if optimizer_name == 'Adam':
        return torch.optim.Adam(model.parameters(), lr=kwargs['lr'])

    if optimizer_name == 'SGD':
        return torch.optim.SGD(
            model.parameters(),
            lr=kwargs['lr'],
            momentum=kwargs['momentum'],
            weight_decay=kwargs['weight_decay'],
        )

    raise ValueError('Not valid optimizer name')
    
def make_scheduler(scheduler_name, optimizer, **kwargs):
    """Create the learning-rate scheduler selected by ``scheduler_name``.

    Supported names and the keyword arguments they read:
      * ``'MultiStepLR'``: ``milestones``, ``factor`` (passed on as ``gamma``)
      * ``'CosineAnnealingLR'``: ``tmax`` (passed on as ``T_max``)

    Raises:
        ValueError: If ``scheduler_name`` is not one of the supported names.
        KeyError: If a keyword argument required by the chosen scheduler is
            missing.
    """
    schedulers = torch.optim.lr_scheduler

    if scheduler_name == 'MultiStepLR':
        return schedulers.MultiStepLR(
            optimizer,
            milestones=kwargs['milestones'],
            gamma=kwargs['factor'],
        )

    if scheduler_name == 'CosineAnnealingLR':
        return schedulers.CosineAnnealingLR(optimizer, T_max=kwargs['tmax'])

    raise ValueError('Not valid scheduler name')

# Parameters and training

In [None]:
# Pre-training hyperparameters and the optimizer / scheduler to build.
learning_rate = 0.001
optimizer_name = 'Adam'
scheduler_name = 'CosineAnnealingLR'

# Encoder producing 32-dimensional embeddings, and the contrastive loss built
# for the loader's batch size.
model = SimCLR_Model(in_channels=n_channels, n_features=32).to(device)
criterion = SimCLR_Loss(batch_size=BATCH_SIZE, temperature=0.5).to(device)
optimizer = make_optimizer(optimizer_name, model, lr=learning_rate)
# Alternative step schedule kept for reference (needs scheduler_name = 'MultiStepLR'):
#scheduler = make_scheduler(scheduler_name, optimizer, milestones=[20], factor=0.1)
scheduler = make_scheduler(scheduler_name, optimizer, tmax=num_epochs)


# Contrastive pre-training: one pass over train_loader per epoch, then advance
# the learning-rate schedule.
for epoch in range(1,num_epochs+1):
    loss = train(model, train_loader, optimizer, criterion, epoch)
    # The value returned by train() is the sum of the per-batch losses, not a mean.
    print(f'Epoch {epoch} - Loss: {loss}')
    scheduler.step()

# Final Result

In [None]:
# linear evaluation
# Fit a linear classifier on top of the frozen encoder, using a random 10% of
# the (un-augmented) training split.
finetuning_size = int(0.1 * len(finetuning_set))
finetuning_indices = np.random.choice(len(finetuning_set), finetuning_size, replace=False)
finetuning_dataset = Subset(finetuning_set, finetuning_indices)
finetuning_dataloader = DataLoader(finetuning_dataset, batch_size=32, shuffle=True)

# Freeze the encoder. requires_grad=False alone is not enough: in training
# mode BatchNorm still normalises with per-batch statistics and keeps updating
# its running statistics, so the "frozen" features would drift. eval() stops that.
for param in model.parameters():
    param.requires_grad = False
model.eval()

classifier = nn.Linear(model.n_features, n_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

for epoch in range(num_epochs_finetune):
    loss_epoch = 0
    correct = 0
    for inputs, labels in finetuning_dataloader:
        inp = inputs.to(device)
        # Flatten labels to one class index per sample; .long() is a no-op
        # when they are already int64, which CrossEntropyLoss requires.
        labels = labels.to(device).view(-1).long()

        # Frozen features: no graph is needed, and a single encoder pass
        # replaces the earlier model(inp, inp) call that encoded the batch twice.
        with torch.no_grad():
            z_i = model.cnn(inp)

        # Forward pass through the classifier
        outputs = classifier(z_i)
        pred = torch.argmax(outputs, dim=1)

        # Accumulate as a Python int so the printed accuracy is a plain number.
        correct += (pred == labels).sum().item()

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_epoch += loss.item()

    print(f'Epoch {epoch+1} - Loss: {loss_epoch} - Accuracy: {correct/finetuning_size}')

In [None]:
# Evaluate the frozen encoder plus linear classifier on the test split.
# eval() makes BatchNorm use its running statistics rather than per-batch
# statistics, so predictions do not depend on how the test set is batched.
model.eval()
classifier.eval()

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inp = inputs.to(device)
        labels = labels.to(device).view(-1)

        # One encoder pass is enough; both outputs of model(inp, inp) are identical here.
        z_i = model.cnn(inp)
        outputs = classifier(z_i)
        pred = torch.argmax(outputs, dim=1)

        correct += (pred == labels).sum().item()
        # Count the samples actually evaluated so the denominator stays right
        # even if the loader drops a trailing partial batch.
        total += labels.numel()

print(f'Test Accuracy: {correct/max(total, 1)}')