In [2]:
import pprint
import torch
import numpy as np
import urllib
from PIL import Image
from torch import nn
import torchvision
import torchvision.models as models
import torchvision.transforms as T
from torchvision.datasets import CIFAR10

from pl_bolts.models.self_supervised import SimCLR
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer


## Dataset Download

In [7]:
# CIFAR-10 is downloaded into the working directory. Images are only converted to
# tensors here: T.ToTensor turns the PIL image into a float tensor in [0, 1].
# (The module is imported as `T`; the original `transforms` name was never imported.)
DATA_ROOT = './'

train_set = CIFAR10(DATA_ROOT, download=True, transform=T.ToTensor(), train=True)
test_set = CIFAR10(DATA_ROOT, download=True, transform=T.ToTensor(), train=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


170499072it [00:13, 13034076.91it/s]                               


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [8]:
batch_size = 64
# Shuffle only the training data; the held-out split is iterated in a fixed order
# so evaluation runs are reproducible.
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
valloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [11]:
# Pull a single training batch to sanity-check tensor shapes.
dataiter = iter(trainloader)
# Use the builtin next(): the `.next()` method is not available on DataLoader
# iterators in newer PyTorch releases.
images, labels = next(dataiter)

print(images.shape)
print(labels.shape)

torch.Size([64, 3, 32, 32])
torch.Size([64])


In [57]:
# Take the first sample of the batch as a standalone image for quick experiments.
dummy_image = images[0, ...]
dummy_image.size()

torch.Size([3, 32, 32])

## Model

The goal is to run a pretrained encoder in evaluation mode to extract image embeddings. For now, we handle a single image: we apply each transformation from a set of augmentations, embed every augmented view, and average the resulting embeddings into one mean embedding.

In [71]:
class Encoder(LightningModule):
    """Frozen, pretrained image encoder that maps an image batch to embeddings.

    Supported backbones (see ``SUPPORTED_ENCODERS``):

    * ``'resnet50_supervised'``: torchvision ResNet-50 with ImageNet weights and
      its final ``fc`` layer removed; ``forward`` returns the flattened pooled features.
    * ``'simclr_r50'``: encoder taken from the pl_bolts SimCLR ImageNet checkpoint;
      ``forward`` returns the first element of the encoder's output.
    * ``'vit_*'``: timm Vision Transformer built with ``num_classes=0`` (no
      classification head); ``forward`` returns the model output unchanged.

    Args:
        encoder: Name of the backbone to load; must be in ``SUPPORTED_ENCODERS``.

    Attributes:
        transform: Preprocessing pipeline resolved by timm for ViT backbones
            (``None`` for the other backbones). It is NOT applied by ``forward``;
            callers apply it to their inputs if needed.

    Raises:
        AssertionError: If ``encoder`` is not supported (type kept for backward
            compatibility with existing callers).
    """

    # TODO: move the list of encoders to a configuration file.
    SUPPORTED_ENCODERS = (
        'resnet50_supervised',
        'simclr_r50',
        'vit_base_patch8_224',
        'vit_base_patch16_224_in21k',
        'vit_base_patch32_224_in21k',
    )

    SIMCLR_WEIGHTS_URL = ('https://pl-bolts-weights.s3.us-east-2.amazonaws.com/'
                          'simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt')

    def __init__(self, encoder='resnet50_supervised'):
        super().__init__()

        if encoder not in self.SUPPORTED_ENCODERS:
            raise AssertionError("Encoder not in the list of supported encoders.")

        self.encoder = encoder
        self.transform = None

        if encoder == 'resnet50_supervised':
            backbone = models.resnet50(pretrained=True)
            # Drop the classification head, keep everything up to the global pooling.
            self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])

        elif encoder == 'simclr_r50':
            simclr = SimCLR.load_from_checkpoint(self.SIMCLR_WEIGHTS_URL, strict=False)
            self.feature_extractor = simclr.encoder

        else:
            # Every remaining supported name is a timm ViT (including
            # 'vit_base_patch8_224', which previously had no branch at all).
            # timm was used here before but never imported, so import it explicitly.
            import timm
            from timm.data import create_transform, resolve_data_config

            self.feature_extractor = timm.create_model(encoder, pretrained=True, num_classes=0)
            # Resolve the preprocessing config from the model we just built
            # (the original referenced an undefined `model`) and keep it.
            config = resolve_data_config({}, model=self.feature_extractor)
            self.transform = create_transform(**config)

    def forward(self, x):
        """Embed a batch of images with the frozen backbone.

        The backbone is put in eval mode and run under ``torch.no_grad()``.

        Args:
            x: Image batch passed straight to the backbone; presumably
               (B, 3, H, W). TODO: confirm the expected H/W per backbone (the ViT
               models were exercised with 224x224 inputs).

        Returns:
            Embedding tensor produced by the selected backbone.
        """
        self.feature_extractor.eval()
        with torch.no_grad():
            features = self.feature_extractor(x)

        if self.encoder == 'resnet50_supervised':
            # Collapse the pooled feature map to one vector per image.
            return features.flatten(1)
        if self.encoder == 'simclr_r50':
            # The pl_bolts encoder output is indexed; the first entry is used as
            # the representation (presumably the pooled features -- confirm).
            return features[0]
        # timm ViTs with num_classes=0 already return pooled features.
        return features

In [24]:
aug_list = [T.RandomGrayscale(p=0.2), T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.4, 0.1)]

# Seed the global torch RNG so both the synthetic image and the random
# augmentations applied to it later are reproducible.
torch.manual_seed(0)

# Synthetic 224x224 RGB image for a smoke test of the pipeline. Values are drawn
# uniformly from [0, 1): T.ToPILImage multiplies float tensors by 255 and casts
# to uint8, so values outside [0, 1] (as produced by randn) would overflow.
test_img = torch.rand(3, 224, 224)
test_tensor = test_img.unsqueeze(0)
print(test_tensor.shape)

torch.Size([1, 3, 224, 224])


In [72]:
# Test-time augmentation for a single image:
#   1. build an augmented view of the image with each augmentation in `aug_list`,
#   2. pass that view through the encoder,
#   3. collect the resulting embedding,
#   4. average the collected embeddings.
# NOTE(review): the timm preprocessing for ViT backbones (resize / normalisation)
# is not applied here; the augmented tensors are fed to the encoder as-is.
encoder = Encoder(encoder='vit_base_patch32_224_in21k')

embeddings = []
for aug in aug_list:
    print("Augmentation: ", aug)
    preprocess = T.Compose([T.ToPILImage(), aug, T.ToTensor()])
    aug_img = preprocess(test_img)
    print(aug_img.size())

    embedding = encoder(aug_img.unsqueeze(0))
    print('embedding: ', embedding.size())
    embeddings.append(embedding)

# Concatenate once, rather than growing a tensor inside the loop.
all_embeddings = torch.cat(embeddings, dim=0)
mean_embedding = all_embeddings.mean(dim=0)
mean_embedding.size()

Augmentation:  RandomGrayscale(p=0.2)
torch.Size([3, 224, 224])
embedding:  torch.Size([1, 768])
Augmentation:  RandomHorizontalFlip(p=0.5)
torch.Size([3, 224, 224])
embedding:  torch.Size([1, 768])
Augmentation:  ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=[-0.1, 0.1])
torch.Size([3, 224, 224])
embedding:  torch.Size([1, 768])


torch.Size([768])