<b> Model </b>

In [2]:
from torchvision.models.resnet import resnet18
import torch.nn as nn


# Backbone: ResNet-18 with ImageNet weights. The `pretrained=True` flag is
# deprecated in torchvision; "IMAGENET1K_V1" is the checkpoint it used to load.
base_model = resnet18(weights="IMAGENET1K_V1")



In [3]:
class ContrastiveNet(nn.Module):
    """Backbone followed by a ReLU and a linear projection head.

    Args:
        base_model: Module used as the first stage. Its output is fed to a
            ``nn.Linear(1000, hidden_dim)``, so it must emit 1000 features
            per sample.
        hidden_dim: Size of the embedding produced by the projection head.
    """

    def __init__(self, base_model, hidden_dim = 128) -> None:
        super().__init__()
        projection_head = nn.Linear(in_features=1000, out_features=hidden_dim)
        # Kept as a positional Sequential under `contnet` so parameter names
        # stay `contnet.0.*` / `contnet.2.*`.
        stages = [base_model, nn.ReLU(inplace=True), projection_head]
        self.contnet = nn.Sequential(*stages)

    def forward(self, x):
        """Return the projected embedding for the batch ``x``."""
        embedding = self.contnet(x)
        return embedding


In [4]:
# Wrap the backbone with the projection head; 128 is the class default.
contrastive_model = ContrastiveNet(base_model, hidden_dim=128)

<b> DataLoaders </b>

In [5]:
import torch
import torchvision.transforms.v2 as transforms

# Augmentation pipeline applied to every image: resize, random flip,
# colour jitter (80% of the time), random grayscale (20% of the time),
# then conversion to a normalised float tensor.
contrast_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply(
        [
            transforms.ColorJitter(brightness=0.5,
                                   contrast=0.5,
                                   saturation=0.5,
                                   hue=0.1)
        ],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    # `ToTensor()` is deprecated in transforms.v2; this pair is its documented
    # replacement (image tensor, then float32 scaled to [0, 1]).
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize((0.5,), (0.5,)),
])



In [6]:
from torchvision import datasets

# Root folder of the images, relative to the notebook's working directory.
data_dir = 'dataset'
# Every sample goes through the contrastive augmentation pipeline on load.
dataset = datasets.ImageFolder(root=data_dir, transform=contrast_transforms)

Split the dataset into training and validation subsets with an 85:15 ratio.

In [7]:
import torch


train_ratio = 0.85
train_size = int(len(dataset) * train_ratio)
# Remainder goes to validation (name kept as `test_size` for compatibility).
test_size = len(dataset) - train_size
# Seed the split so the same samples land in train/val on every
# Restart & Run All; an unseeded split changes between kernel sessions.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [8]:
batch_size = 32
num_workers = 4
# Training batches are reshuffled every epoch.
trainLoader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          num_workers=num_workers, shuffle=True)
# Validation is not shuffled: the loss depends on which samples share a batch,
# so a fixed order keeps validation numbers comparable across epochs.
valLoader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
                                        num_workers=num_workers, shuffle=False)

<b> Loss Function & Training Loop </b>

In [9]:
from pytorch_metric_learning import losses
import torch.nn.functional as F

class SupervisedContrastiveLoss(nn.Module):
    """Label-aware contrastive loss over a batch of embeddings.

    The embeddings are L2-normalised and passed, together with their labels,
    to ``pytorch_metric_learning``'s ``NTXentLoss``. That loss expects
    embeddings and builds the pairwise similarities itself, so no similarity
    matrix is pre-computed here.

    Args:
        temperature: Softmax temperature forwarded to ``NTXentLoss``.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Built once and reused; the configured temperature is the only one applied.
        self.loss_fn = losses.NTXentLoss(temperature=temperature)

    def forward(self, feature_vectors, labels):
        """Compute the loss for one batch.

        Args:
            feature_vectors: 2-D tensor of embeddings, one row per sample
                (normalised along ``dim=1``).
            labels: Class labels, one per row; flattened to 1-D before use.

        Returns:
            The loss tensor returned by ``NTXentLoss``.
        """
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(feature_vectors, p=2, dim=1)
        # reshape(-1) instead of squeeze(): squeeze turns a batch of one
        # label into a 0-d tensor.
        return self.loss_fn(feature_vectors_normalized, labels.reshape(-1))

In [11]:
# Loss with its default temperature.
criterion = SupervisedContrastiveLoss()
# Adam over every parameter of the model (backbone and projection head).
optimizer = torch.optim.Adam(params=contrastive_model.parameters(), lr=1e-4)
# Cosine-annealed learning rate: T_max of 10 scheduler steps, floor of 1e-5.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=10,
    eta_min=1e-5,
)

In [12]:
import time
from tqdm import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Train ``model`` and evaluate it on ``val_loader`` once per epoch.

    Args:
        model: Module to optimise; moved to ``device`` in place.
        train_loader: Iterable of ``(inputs, labels)`` batches for training.
        val_loader: Iterable of ``(inputs, labels)`` batches for validation.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: Number of train + validation passes to run.
        device: Device the model and every batch are moved to.

    Returns:
        ``(train_loss_history, val_loss_history)``: two lists holding the
        sample-weighted mean loss of each epoch.
    """
    start = time.time()

    # Inputs are moved to `device` below, so the model must live there too.
    model.to(device)

    loaders = {'train': train_loader, 'val': val_loader}
    loss_history = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            loader = loaders[phase]
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            seen_samples = 0

            with tqdm(total=len(loader), desc=phase) as pbar:
                for inputs, labels in loader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # track gradients only while training
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    # weight the batch-mean loss by the batch size
                    batch_count = inputs.size(0)
                    running_loss += loss.item() * batch_count
                    seen_samples += batch_count

                    pbar.update(1)

            # Average over the samples actually processed; guard against an
            # empty loader instead of dividing by zero.
            epoch_loss = running_loss / max(seen_samples, 1)

            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            loss_history[phase].append(epoch_loss)

        # update learning rate based on scheduler
        scheduler.step()

        time_elapsed = time.time() - start
        print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('\n')

    return loss_history['train'], loss_history['val']

In [None]:
# Single epoch on the CPU.
num_epochs = 1
device = "cpu"
train_model(
    model=contrastive_model,
    train_loader=trainLoader,
    val_loader=valLoader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    device=device,
)