In [1]:
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])


In [2]:
transform = transforms.Compose([
    transforms.ToTensor()  # Converts PIL images to PyTorch tensors
])

In [3]:
import torch
import torch.nn as nn
from torchvision.models import resnet18


class Encoder(nn.Module):
    def __init__(self, D=128, device='cuda'):
        super(Encoder, self).__init__()
        self.resnet = resnet18(pretrained=False).to(device)
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=1)
        self.resnet.maxpool = nn.Identity()
        self.resnet.fc = nn.Linear(512, 512)
        self.fc = nn.Sequential(nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Linear(512, D))

    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        return x

    def encode(self, x):
        return self.forward(x)


class Projector(nn.Module):
    def __init__(self, D, proj_dim=512):
        super(Projector, self).__init__()
        self.model = nn.Sequential(nn.Linear(D, proj_dim),
                                   nn.BatchNorm1d(proj_dim),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(proj_dim, proj_dim),
                                   nn.BatchNorm1d(proj_dim),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(proj_dim, proj_dim)
                                   )

    def forward(self, x):
        return self.model(x)




In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [6]:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 59378568.39it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
class VICReg(nn.Module):
    def __init__(self, encoder, projector, device='cuda'):
        super(VICReg, self).__init__()
        self.encoder = encoder
        self.projector = projector
        self.device = device

    def forward(self, x):
        z = self.encoder(x)
        p = self.projector(z)
        return p

In [8]:
encoder = Encoder(D=128)
projector = Projector(D=128, proj_dim=512)
model = VICReg(encoder, projector, device)
model.to(device)




VICReg(
  (encoder): Encoder(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): Identity()
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(

In [9]:
def off_diagonal(x): #used chatGPT
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [10]:
def loss(x, y, lambda_param=25., mu=25., nu=1., sigma=1e-4, d=512, eps=1e-4, gamma=1):
  # L inverse
  inv_loss = nn.MSELoss()(x,y)

  # variance loss
  x_sigma = torch.sqrt(torch.var(x, dim=0) + eps)
  y_sigma = torch.sqrt(torch.var(y, dim=0) + eps)
  x_var = torch.mean(torch.max(gamma - x_sigma, torch.zeros_like(x_sigma)))
  y_var = torch.mean(torch.max(gamma - y_sigma, torch.zeros_like(y_sigma)))
  var_loss = x_var + y_var

  # cov loss
  x_cov = torch.cov(x)
  y_cov = torch.cov(y)
  cov_loss = (off_diagonal(x_cov).pow_(2).sum() / 256) + (off_diagonal(y_cov).pow_(2).sum() / 256)

  total_loss = lambda_param * inv_loss + mu * var_loss + nu * cov_loss

  return total_loss, inv_loss, var_loss, cov_loss


In [11]:
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-6)

In [None]:
from tqdm import tqdm
train_loss = []
train_inv_loss = []
train_var_loss = []
train_cov_loss = []
test_loss = []
test_inv_loss = []
test_var_loss = []
test_cov_loss = []

num_epochs = 10
for epoch in range(num_epochs):
  model.train()
  batch_loss = []
  batch_inv_loss = []
  batch_var_loss = []
  batch_cov_loss = []
  for imgs, _ in tqdm(train_loader):

    imgs = imgs.to(model.device)

    imgs1 = torch.stack([train_transform(img) for img in imgs]).to(model.device)
    imgs2 = torch.stack([train_transform(img) for img in imgs]).to(model.device)

    # imgs1 = train_transform(imgs)
    # imgs2 = test_transform(imgs)

    res1, res2 = model(imgs1), model(imgs2)

    total_loss, inv_loss, var_loss, cov_loss = loss(res1, res2)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    batch_inv_loss.append(inv_loss.item())
    batch_var_loss.append(var_loss.item())
    batch_cov_loss.append(cov_loss.item())
    batch_loss.append(total_loss.item())

  train_inv_loss.append(sum(batch_inv_loss) / len(batch_inv_loss))
  train_var_loss.append(sum(batch_var_loss) / len(batch_var_loss))
  train_cov_loss.append(sum(batch_cov_loss) / len(batch_cov_loss))
  train_loss.append(sum(batch_loss) / len(batch_loss))

  print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}')

  model.eval()
  with torch.no_grad():
    test_batch_loss = []
    test_batch_inv_loss = []
    test_batch_var_loss = []
    test_batch_cov_loss = []
    for batch in test_loader:
      imgs, _ = batch
      imgs = imgs.to(model.device)

      # imgs1 = test_transform(imgs)
      # imgs2 = test_transform(imgs)
      imgs1 = torch.stack([train_transform(img) for img in imgs]).to(model.device)
      imgs2 = torch.stack([train_transform(img) for img in imgs]).to(model.device)


      res1, res2 = model(imgs1), model(imgs2)
      total_loss, inv_loss, var_loss, cov_loss = loss(res1, res2)
      test_batch_inv_loss.append(inv_loss.item())
      test_batch_var_loss.append(var_loss.item())
      test_batch_cov_loss.append(cov_loss.item())
      test_batch_loss.append(total_loss.item())

    test_inv_loss.append(sum(test_batch_inv_loss) / len(test_batch_inv_loss))
    test_var_loss.append(sum(test_batch_var_loss) / len(test_batch_var_loss))
    test_cov_loss.append(sum(test_batch_cov_loss) / len(test_batch_cov_loss))
    test_loss.append(sum(test_batch_loss) / len(test_batch_loss))



100%|██████████| 196/196 [04:16<00:00,  1.31s/it]


Epoch [1/10], Loss: 12.3682


 52%|█████▏    | 101/196 [02:08<01:57,  1.23s/it]

In [None]:
import plotly.graph_objects as go
import plotly.subplots as sp

# Plot training and test loss components
fig = sp.make_subplots(rows=2, cols=2, subplot_titles=("Invariance Loss", "Variance Loss", "Covariance Loss"))

# Invariance Loss
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=train_inv_loss, mode='lines+markers', name='Train Invariance Loss'), row=1, col=1)
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=test_inv_loss, mode='lines+markers', name='Test Invariance Loss', line=dict(dash='dash')), row=1, col=1)

# Variance Loss
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=train_var_loss, mode='lines+markers', name='Train Variance Loss'), row=1, col=2)
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=test_var_loss, mode='lines+markers', name='Test Variance Loss', line=dict(dash='dash')), row=1, col=2)

# Covariance Loss
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=train_cov_loss, mode='lines+markers', name='Train Covariance Loss'), row=2, col=1)
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=test_cov_loss, mode='lines+markers', name='Test Covariance Loss', line=dict(dash='dash')), row=2, col=1)

# OverallLoss
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=train_loss, mode='lines+markers', name='Train Loss'), row=2, col=2)
fig.add_trace(go.Scatter(x=list(range(num_epochs)), y=test_loss, mode='lines+markers', name='Test Loss', line=dict(dash='dash')), row=2, col=2)
# Update layout
fig.update_layout(title_text="VICReg Loss Components over Training Epochs", height=600, width=1800)
fig.update_xaxes(title_text="Epochs")
fig.update_yaxes(title_text="Loss")

# Show plot
fig.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Step 1: Load the Test Images and Encode them
model.eval()
representations = []
labels = []
with torch.no_grad():
    for imgs, lbls in test_loader:
        imgs = imgs.to(model.device)
        reps = model.encoder.encode(imgs)
        representations.append(reps.cpu().numpy())
        labels.extend(lbls.numpy())

representations = np.concatenate(representations, axis=0)

# Step 2: Apply PCA and T-SNE
pca = PCA(n_components=2)
pca_result = pca.fit_transform(representations)

tsne = TSNE(n_components=2, perplexity=30, n_iter=300)
tsne_result = tsne.fit_transform(representations)

# Step 3: Plot the Results
def plot_reduction(result, labels, title):
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(result[:, 0], result[:, 1], c=labels, cmap='viridis', alpha=0.7)
    plt.colorbar(scatter, label='Classes')
    plt.title(title)
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.show()

plot_reduction(pca_result, labels, 'PCA 2D Representation')
plot_reduction(tsne_result, labels, 'T-SNE 2D Representation')


In [None]:
import torch.nn.functional as F

# Step 1: Freeze the Encoder
for param in model.encoder.parameters():
    param.requires_grad = False

# Step 2: Create the Classifier
class LinearProbingClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LinearProbingClassifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

classifier = LinearProbingClassifier(input_dim=128, num_classes=10)
classifier.to(device)

# Step 3: Train the Classifier
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-3)

num_epochs = 10
for epoch in range(num_epochs):
    classifier.train()
    total_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.no_grad():
            representations = model.encoder.encode(imgs)

        optimizer.zero_grad()
        outputs = classifier(representations)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

# Step 4: Evaluate the Classifier
classifier.eval()
correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        representations = model.encoder.encode(imgs)
        outputs = classifier(representations)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the classifier on the test set: {accuracy:.2f}%')
