In [2]:
!pip install timm
!pip install pytorch-lightning

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.12
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.1.3-py3-none-any.whl (777 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m777.7/777.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.3.0.post0-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.10.0 pytorch-lightning-2.1

In [3]:
#Implementing Supervised Constrastive Learning using Pytorch Lightning with MNIST Dataset

# Stage 1: Training the Enconder

#import
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import Metric
import timm
import torch.nn as nn

In [4]:
# SupConLoss: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [5]:
#Converting 1 Channel Image to 3 Channels
class To3Channel:
  def __call__(self, img):

    stacked_img = torch.stack([img[0], img[0], img[0]])
    return stacked_img

In [6]:
# Transformation to get multiple versions of a image (anchor, positives)
class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]

In [7]:
#DataLoader
class MnistDataModuleSCL(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    # Single GPU, We download the data here
    def prepare_data(self):
        datasets.MNIST(self.data_dir, train = True,
                       download = True)
        datasets.MNIST(self.data_dir, train = False,
                       download = True)

    # Multiple GPUs
    def setup(self, stage):
        train_transform = transforms.Compose([
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([
                                                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                                                ], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              transforms.ToTensor(),
                                              To3Channel(),
                                            ])

        entire_dataset = datasets.MNIST(root = self.data_dir,
                                        train = True,
                                        transform = TwoCropTransform(train_transform),
                                        download = False,
                                       )
        self.train_ds, self.valid_ds = random_split(entire_dataset, [0.7,0.3])

        self.test_ds = datasets.MNIST(root = self.data_dir,
                                      train = False,
                                      transform = transforms.ToTensor(),
                                      download = False)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = True)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

In [14]:
# Model
class Encoder(pl.LightningModule):
    def __init__(self, model_name, emb_dim):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained = True)
        self.in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(self.in_features, emb_dim)
        self.loss_fn = SupConLoss(0.07, 'one', 0.07)


    def forward(self, x):
        emb = self.backbone(x)
        return emb

    # Difference between Normal and Lightning: The train, valid and test steps is written here inside the class
    def training_step(self, batch, batch_idx):
        images, labels = batch

        bsz = len(labels)

        images = torch.cat([images[0], images[1]], dim=0)

        #print(images.shape)

        features = self.forward(images)

        # Manipulating the features for SupConLoss
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        # Calculating SupConLoss
        loss = self.loss_fn(features,labels)

        return loss

    # We have training_epoch_end function
    # def on_train_epoch_end(self):
    #     #print("Epoch Done")

    def validation_step(self, batch , batch_idx):

        images, labels = batch

        bsz = len(labels)

        images = torch.cat([images[0], images[1]], dim=0)

        features = self.forward(images)

        # Manipulating the arrangment of features for SupConLoss
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)

        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        # Calculating SupConLoss
        loss = self.loss_fn(features,labels)

        return loss

    def test_step(self, batch, batch_idx):
        images, labels = batch

        bsz = len(labels)

        images = torch.cat([images[0], images[1]], dim=0)

        features = self.forward(images)

        # Manipulating the arrangment of features for SupConLoss
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)

        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        # Calculating SupConLoss
        loss = self.loss_fn(features,labels)
        return loss

    # We can add schedulers to this method
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr = 0.001)

In [15]:
device = torch.device("cuda" if torch.backends.mps.is_available() else 'cpu')

# HP
input_size = 784
emb_dim = 128
learning_rate = 0.001
batch_size = 64
num_epochs = 2
backbone = "resnet50"

#Data Loading
dm = MnistDataModuleSCL(data_dir = "dataset/", batch_size = batch_size, num_workers = 1)

# Init Model
model = Encoder(backbone, 128).to(device)

# Trainer
trainer = pl.Trainer(accelerator = "cuda", devices = [0], min_epochs = 1, max_epochs = num_epochs, precision = 16)
trainer.fit(model, dm)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type       | Params
----------------------------------------
0 | backbone | ResNet     | 23.8 M
1 | loss_fn  | SupConLoss | 0     
----------------------------------------
23.8 M    Trainable params
0         Non-trainable params
23.8 M    

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


In [16]:
# Save the model checkpoint
model_checkpoint_path = './supconencoder.ckpt'
trainer.save_checkpoint(model_checkpoint_path)

In [26]:
#DataLoader
class MnistDataModuleCE(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    # Single GPU, We download the data here
    def prepare_data(self):
        datasets.MNIST(self.data_dir, train = True,
                       download = True)
        datasets.MNIST(self.data_dir, train = False,
                       download = True)

    # Multiple GPUs
    def setup(self, stage):
        train_transform = transforms.Compose([
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([
                                                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                                                ], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              transforms.ToTensor(),
                                              To3Channel(),
                                            ])


        entire_dataset = datasets.MNIST(root = self.data_dir,
                                        train = True,
                                        transform = train_transform,
                                        download = False,
                                       )

        self.train_ds, self.valid_ds = random_split(entire_dataset, [0.7,0.3])

        self.test_ds = datasets.MNIST(root = self.data_dir,
                                      train = False,
                                      transform = transforms.Compose([transforms.ToTensor(), To3Channel()]),
                                      download = False)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = True)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

In [27]:
import torch
import torch.nn as nn
import pickle
import pytorch_lightning as pl
from torch.optim import Adam



class SupConCE(pl.LightningModule):
    def __init__(self,):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task = 'multiclass', num_classes = 10)
        self.f1_score = torchmetrics.F1Score(task = 'multiclass', num_classes = 10)

        model_path = '/content/supconencoder.ckpt'
        pretrained_model = Encoder.load_from_checkpoint(model_path, model_name = backbone, emb_dim = 128)


        #Freezing all the encoder layers
        for param in pretrained_model.parameters():
          param.requires_grad = False


        #Trainging only the last layer
        pretrained_model.backbone.fc = nn.Linear(in_features=pretrained_model.backbone.fc.in_features, out_features=10)

        pretrained_model.backbone.fc.requires_grad = True

        self.model = pretrained_model

    def forward(self, x):
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        images, labels = batch
        y_pred = self.forward(images)
        loss = self.loss_fn(y_pred,labels)
        accuracy = self.accuracy(y_pred,labels)
        f1_score = self.f1_score(y_pred,labels)
        self.log_dict({'train_loss': loss, 'train_accuracy': accuracy, 'train_f1_score': f1_score},
                      on_step = False, on_epoch = True, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        y_pred = self.forward(images)
        loss = self.loss_fn(y_pred,labels)
        accuracy = self.accuracy(y_pred,labels)
        f1_score = self.f1_score(y_pred,labels)
        self.log_dict({'valid_loss': loss, 'valid_accuracy': accuracy, 'valid_f1_score': f1_score},
                      on_step = False, on_epoch = True, prog_bar = True)
        return loss

    def test_step(self, batch, batch_idx):
        images, labels = batch
        y_pred = self.forward(images)
        loss = self.loss_fn(y_pred,labels)
        accuracy = self.accuracy(y_pred,labels)
        f1_score = self.f1_score(y_pred,labels)
        self.log_dict({'test_loss': loss, 'test_accuracy': accuracy, 'test_f1_score': f1_score},
                      on_step = False, on_epoch = True, prog_bar = True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr = 0.001)



dm_ce = MnistDataModuleCE(data_dir = "dataset/", batch_size = batch_size, num_workers = 1)

final_model = SupConCE().to(device)

trainer_ce = pl.Trainer(accelerator = "cuda", devices = [0], min_epochs = 1, max_epochs = 1, precision = 16)

trainer_ce.fit(final_model, dm_ce)

trainer.validate(final_model, dm_ce)

trainer.test(final_model, dm_ce)

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type               | Params
------------------------------------------------
0 | loss_fn  | CrossEntropyLoss   | 0     
1 | accuracy | MulticlassAccuracy | 0     
2 | f1_score | MulticlassF1Score  | 0     
3 | model    | Encoder            | 23.5 M
------------------------------------------------
20.5 K    Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.114    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.2413640171289444,
  'test_accuracy': 0.9704999923706055,
  'test_f1_score': 0.9704999923706055}]