In [34]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
import torch.optim as optimizer
import numpy as np
import sys
from tqdm import tqdm

## Transform

1. Random Cropping with size ratio between .08 and 1.0 with resizing. RandomResizedCrop(32, scale=(0.08, 0.1)) in PyTorch.
2. Random horizontal flip with probability 0.5.
3. Color jittering of brightness, contrast, saturation and hue, with probability 0.8.
ColorJitter(0.4, 0.4, 0.2, 0.1) in PyTorch.
4. Grayscale with probability 0.2
5. Gaussian blur with probability 0.5 and kernel size 23. (Do we keep the sample kernel size for cifar-10?)
6. Solarization with probability 0.1.
7. Color normalization with mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224,
0.225).

https://pytorch.org/vision/main/transforms.html

GAUSSIAN_BLUR:
https://pytorch.org/vision/main/generated/torchvision.transforms.GaussianBlur.html#torchvision.transforms.GaussianBlur
    Inputs: 
    - kernel_size (int or sequence) – Size of the Gaussian kernel.
    - sigma (float or tuple of python:float (min, max)) – Standard deviation to be used for creating kernel to perform blurring.
    If float, sigma is fixed. If it is tuple of float (min, max), sigma is chosen uniformly at random to lie in the given range.



In [25]:
kernel_size=23
sigma=(0.1, 2.0)
solarize_threshold = .5
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=(32, 32), antialias=True),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(0.4, 0.4, 0.2, 0.1),
    transforms.RandomGrayscale(.2), # [BETA] Randomly convert image or videos to grayscale with a probability of p (default 0.1).
    transforms.GaussianBlur(kernel_size , sigma), # [BETA] Blurs image with randomly chosen Gaussian blur.
    transforms.RandomSolarize(solarize_threshold, p = .1),
    #transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


## VicReg Loss Function

In [14]:
lamb = 25 # invar loss weight
mu = 25 # var loss weight
nu = 1 # covar loss weight

In [32]:
def calculate_loss(z, z_prime):
    """Calculate the loss function. 
    
    The following is heavily based on the psuedo code provided on page 13 of https://arxiv.org/pdf/2105.04906.pdf
    
    Need to calculate 3 things: 
    
    1. Variance
    
    2. Invariance
    
    3. Covariance
    
    Args:
        z (_type_): batch of images transformed, encodeded, projected
        z_prime (_type_): batch of images transformed, encodeded, projected
    """
    
    # 1. Variance Loss
    var_epsilon = 1e-4
    std_z = torch.sqrt(z.var(dim=0) + var_epsilon)
    std_z_prime = torch.sqrt(z_prime.var(dim=0) + var_epsilon)
    std_loss = torch.mean(F.relu(1 - std_z)) + torch.mean(F.relu(1 - std_z_prime))
    
    
    # 2. Invariance Loss (Just MSE Loss)
    invar_loss = F.mse_loss(z, z_prime)
    
    # 3. Covariance Loss
    
    N , D = z.shape
    z = z - z.mean(dim=0)
    z_prime = z_prime - z_prime.mean(dim=0)
    cov_z = (z.T @ z) / (N - 1)
    cov_z_prime = (z_prime.T @ z_prime) / (N - 1)
    
    cov_z = cov_z.pow(2)
    cov_z_prime = cov_z_prime.pow(2)
    
    loss_c_a = (cov_z.sum() - cov_z.diagonal().sum()) / D
    loss_c_b = (cov_z_prime.sum() - cov_z_prime.diagonal().sum()) / D
    
    loss_cov = loss_c_a + loss_c_b
    
    loss = lamb * invar_loss + mu * std_loss + nu * loss_cov
    
    return loss

In [16]:
class SameTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return self.transform(x), self.transform(x)

In [17]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 256
default_transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download = True, transform = default_transform)

split_ratio = 0.8
total_size = len(trainset)
train_size = int(split_ratio * total_size)
valid_size = total_size - train_size
print(train_size)
train_dataset, valid_dataset = torch.utils.data.random_split(trainset, [train_size, valid_size])

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8)


testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform = default_transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
40000
Files already downloaded and verified


## Encoder

- Standard ResNet - 50 Backbone

In [18]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self, output_units=512):
        super(Encoder, self).__init__()
        # Load pre-trained ResNet-50 model from torchvision
        resnet = models.resnet50(pretrained=True)
        
        # Remove the fully connected layers at the end
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        
        # Add global average pooling layer
        self.global_avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
        
        # Output projection layer
        self.projection_layer = nn.Linear(2048, output_units)
    
    def forward(self, x):
        # Forward pass through ResNet-50 backbone
        x = self.resnet(x)
        
        # Global average pooling
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)
        
        # Projection layer
        x = self.projection_layer(x)
        
        return x

In [19]:

# Create an instance of the Encoder with 512 output units
encoder = Encoder(output_units=512)

# Test the encoder with a random input
random_input = torch.randn((1, 3, 224, 224))  # Assuming input image size is 224x224
output = encoder(random_input)

print("Output shape:", output.shape)

Output shape: torch.Size([1, 512])


In [20]:
class Expander(nn.Module):
    def __init__(self, input_size, output_size=8192):
        """
        
        expander hφ:
        Composed of two fully-connected layers with batch normalization and ReLU,
        and a third linear layer. The sizes of all 3 layers were set to 8192

        Args:
            input_size (int): cifar vector size 
            output_size (int): output vector size, also size of intermediate linear layers.
        """
        super(Expander, self).__init__()

          # Flatten layer
        self.flatten = nn.Flatten()
        # First fully-connected layer
        self.fc1 = nn.Linear(input_size, output_size)
        self.bn1 = nn.BatchNorm1d(output_size)
        self.relu1 = nn.ReLU()

        # Second fully-connected layer
        self.fc2 = nn.Linear(output_size, output_size)
        self.bn2 = nn.BatchNorm1d(output_size)
        self.relu2 = nn.ReLU()

        # Third linear layer
        self.fc3 = nn.Linear(output_size, output_size)

    def forward(self, x):
        # Forward pass through the layers
        x = self.flatten(x)
        x = self.fc1(x)
        if x.size(0) > 1:
            x = self.bn1(x)

        x = self.relu1(x)

        x = self.fc2(x)
        
        if x.size(0) > 1:
            x = self.bn2(x)

        x = self.relu2(x)

        x = self.fc3(x)

        return x

In [21]:
expander = Expander(input_size=512, output_size=2048)
expanded_output = expander(output)
print("Expanded output shape:", expanded_output.shape)

Expanded output shape: torch.Size([1, 2048])


In [22]:
class VICReg(nn.Module):
    
    def __init__(self, encoder_size, expander_size):
       
        super().__init__()
        self.encoder = Encoder(encoder_size)
        self.expander = Expander(encoder_size, expander_size)

    def forward(self, x):
        x = self.encoder(x)
        x = self.expander(x)
        
        return x  

In [23]:


train_epochs = 50

# Hyper parameters
learning_rate_decay = 10e-6
batch_size = 256
encoder_size = 256
expander_size = 512
base_lr = .01

#encoder = ResNet50_Weights.DEFAULT.transforms()
#expander = Expander(encoder_size, expander_size)

lr = (batch_size / 256) * base_lr

model = VICReg(encoder_size, expander_size)
model.to(device)
params = model.parameters()
optimiz = optimizer.SGD(params, lr = lr, weight_decay= learning_rate_decay)

In [35]:
def train_loop(model, optimizer, trainloader, criterion, device):
    tk0 = tqdm(trainloader)
    train_loss = []

    for batch, _ in tk0:

        model.train()

        batch = batch.to(device)

        x = transform(batch)
        x1 = transform(batch)
    
        fx = model(x)
        fx1 = model(x1)
        loss = criterion(fx, fx1)
        #print(loss)
        train_loss.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return train_loss


for epoch in range(train_epochs):
    train_loss = train_loop(model, optimiz, trainloader, calculate_loss, device)
    print(np.mean(train_loss))

  0%|          | 0/157 [00:00<?, ?it/s]

encoder done
encoder done
tensor(38.6075, device='cuda:0', grad_fn=<AddBackward0>)


  1%|▏         | 2/157 [00:18<19:48,  7.67s/it]

encoder done
encoder done
tensor(39.2890, device='cuda:0', grad_fn=<AddBackward0>)


  2%|▏         | 3/157 [00:18<11:23,  4.44s/it]

encoder done
encoder done
tensor(38.6002, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 4/157 [00:19<07:27,  2.92s/it]

encoder done
encoder done
tensor(38.3940, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 5/157 [00:20<05:16,  2.08s/it]

encoder done
encoder done
tensor(38.8999, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▍         | 6/157 [00:20<03:57,  1.58s/it]

encoder done
encoder done
tensor(38.6960, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▍         | 7/157 [00:21<03:08,  1.25s/it]

encoder done
encoder done
tensor(38.7402, device='cuda:0', grad_fn=<AddBackward0>)


  5%|▌         | 8/157 [00:21<02:35,  1.05s/it]

encoder done
encoder done
tensor(38.2629, device='cuda:0', grad_fn=<AddBackward0>)


  6%|▌         | 9/157 [00:22<02:13,  1.11it/s]

encoder done
encoder done
tensor(38.1990, device='cuda:0', grad_fn=<AddBackward0>)


  6%|▋         | 10/157 [00:23<01:58,  1.24it/s]

encoder done
encoder done
tensor(38.7832, device='cuda:0', grad_fn=<AddBackward0>)


  7%|▋         | 11/157 [00:23<01:48,  1.35it/s]

encoder done
encoder done
tensor(38.2666, device='cuda:0', grad_fn=<AddBackward0>)


  8%|▊         | 12/157 [00:24<01:41,  1.43it/s]

encoder done
encoder done
tensor(37.8624, device='cuda:0', grad_fn=<AddBackward0>)


  8%|▊         | 13/157 [00:24<01:36,  1.50it/s]

encoder done
encoder done
tensor(38.0759, device='cuda:0', grad_fn=<AddBackward0>)


  9%|▉         | 14/157 [00:25<01:31,  1.56it/s]

encoder done
encoder done
tensor(37.5139, device='cuda:0', grad_fn=<AddBackward0>)


 10%|▉         | 15/157 [00:26<01:28,  1.60it/s]

encoder done
encoder done
tensor(37.3466, device='cuda:0', grad_fn=<AddBackward0>)


 10%|█         | 16/157 [00:26<01:26,  1.63it/s]

encoder done
encoder done
tensor(37.9120, device='cuda:0', grad_fn=<AddBackward0>)


 11%|█         | 17/157 [00:27<01:25,  1.64it/s]

encoder done
encoder done
tensor(37.8515, device='cuda:0', grad_fn=<AddBackward0>)


 11%|█▏        | 18/157 [00:27<01:24,  1.64it/s]

encoder done
encoder done
tensor(37.2447, device='cuda:0', grad_fn=<AddBackward0>)


 12%|█▏        | 19/157 [00:28<01:23,  1.65it/s]

encoder done
encoder done
tensor(37.0279, device='cuda:0', grad_fn=<AddBackward0>)


 13%|█▎        | 20/157 [00:29<01:22,  1.67it/s]

encoder done
encoder done
tensor(37.5178, device='cuda:0', grad_fn=<AddBackward0>)


 13%|█▎        | 21/157 [00:29<01:21,  1.66it/s]

encoder done
encoder done
tensor(37.0042, device='cuda:0', grad_fn=<AddBackward0>)


 14%|█▍        | 22/157 [00:30<01:20,  1.67it/s]

encoder done
encoder done
tensor(38.1645, device='cuda:0', grad_fn=<AddBackward0>)


 15%|█▍        | 23/157 [00:30<01:20,  1.67it/s]

encoder done
encoder done
tensor(38.3390, device='cuda:0', grad_fn=<AddBackward0>)


 15%|█▌        | 24/157 [00:31<01:19,  1.66it/s]

encoder done
encoder done
tensor(36.6897, device='cuda:0', grad_fn=<AddBackward0>)


 16%|█▌        | 25/157 [00:32<01:19,  1.67it/s]

encoder done
encoder done
tensor(38.6175, device='cuda:0', grad_fn=<AddBackward0>)


 17%|█▋        | 26/157 [00:32<01:18,  1.67it/s]

encoder done
encoder done
tensor(37.5574, device='cuda:0', grad_fn=<AddBackward0>)


 17%|█▋        | 27/157 [00:33<01:17,  1.68it/s]

encoder done
encoder done
tensor(36.3419, device='cuda:0', grad_fn=<AddBackward0>)


 18%|█▊        | 28/157 [00:33<01:17,  1.67it/s]

encoder done
encoder done
tensor(37.4368, device='cuda:0', grad_fn=<AddBackward0>)


 18%|█▊        | 29/157 [00:34<01:16,  1.67it/s]

encoder done
encoder done
tensor(37.3829, device='cuda:0', grad_fn=<AddBackward0>)


 19%|█▉        | 30/157 [00:35<01:16,  1.66it/s]

encoder done
encoder done
tensor(37.3691, device='cuda:0', grad_fn=<AddBackward0>)


 20%|█▉        | 31/157 [00:35<01:16,  1.66it/s]

encoder done
encoder done
tensor(37.5535, device='cuda:0', grad_fn=<AddBackward0>)


 20%|██        | 32/157 [00:36<01:14,  1.67it/s]

encoder done
encoder done
tensor(35.3534, device='cuda:0', grad_fn=<AddBackward0>)


 21%|██        | 33/157 [00:36<01:14,  1.67it/s]

encoder done
encoder done
tensor(35.6452, device='cuda:0', grad_fn=<AddBackward0>)


 22%|██▏       | 34/157 [00:37<01:13,  1.68it/s]

encoder done
encoder done
tensor(36.7222, device='cuda:0', grad_fn=<AddBackward0>)


 22%|██▏       | 35/157 [00:38<01:13,  1.67it/s]

encoder done
encoder done
tensor(36.6905, device='cuda:0', grad_fn=<AddBackward0>)


 23%|██▎       | 36/157 [00:38<01:12,  1.66it/s]

encoder done
encoder done
tensor(37.0516, device='cuda:0', grad_fn=<AddBackward0>)


 24%|██▎       | 37/157 [00:39<01:11,  1.68it/s]

encoder done
encoder done
tensor(36.7217, device='cuda:0', grad_fn=<AddBackward0>)


 24%|██▍       | 38/157 [00:39<01:10,  1.68it/s]

encoder done
encoder done
tensor(36.2134, device='cuda:0', grad_fn=<AddBackward0>)


 25%|██▍       | 39/157 [00:40<01:11,  1.64it/s]

encoder done
encoder done
tensor(36.4855, device='cuda:0', grad_fn=<AddBackward0>)


 25%|██▌       | 40/157 [00:41<01:11,  1.63it/s]

encoder done
encoder done
tensor(35.8487, device='cuda:0', grad_fn=<AddBackward0>)


 26%|██▌       | 41/157 [00:41<01:10,  1.64it/s]

encoder done
encoder done
tensor(35.4543, device='cuda:0', grad_fn=<AddBackward0>)


 27%|██▋       | 42/157 [00:42<01:09,  1.66it/s]

encoder done
encoder done
tensor(38.4773, device='cuda:0', grad_fn=<AddBackward0>)


 27%|██▋       | 43/157 [00:42<01:09,  1.65it/s]

encoder done
encoder done
tensor(35.7569, device='cuda:0', grad_fn=<AddBackward0>)


 28%|██▊       | 44/157 [00:43<01:07,  1.67it/s]

encoder done
encoder done
tensor(37.5954, device='cuda:0', grad_fn=<AddBackward0>)
