#### Please check the previously saved versions of this notebook on - https://github.com/AvtnshM/SSL


In [17]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

cudnn.benchmark = True


Using device: cuda


In [18]:
# CIFAR10 normalization values
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

ssl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(3),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std)
])

In [19]:
class BYOLDataset(torch.utils.data.Dataset):
    def __init__(self, root, train, transform):
        self.dataset = datasets.CIFAR10(
            root=root,
            train=train,
            download=True,
            transform=None  # IMPORTANT: no transform here
        )
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]   # img is PIL image
        view1 = self.transform(img)
        view2 = self.transform(img)
        return view1, view2


ssl_dataset = BYOLDataset(
    root="/kaggle/working/data",
    train=True,
    transform=ssl_transform
)


In [20]:
loader = DataLoader(
    ssl_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

print("Batches per epoch:", len(loader))

Batches per epoch: 195


In [21]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        # Load ResNet18
        backbone = models.resnet18(weights=None)

        # Remove final classification layer
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # Output: (B, 512)


In [22]:
class Predictor(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)

In [23]:
class Projector(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)

In [24]:
def ssl_loss(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return F.mse_loss(p, z)


In [25]:
encoder = Encoder().to(device)
projector = Projector().to(device)
predictor = Predictor().to(device)

target_encoder = Encoder().to(device)
target_encoder.load_state_dict(encoder.state_dict())

for p in target_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(projector.parameters()) +
    list(predictor.parameters()),
    lr=1e-3
)

ema_tau = 0.996


In [26]:
@torch.no_grad()
def feature_variance(z):
    return z.var(dim=0).mean().item()

@torch.no_grad()
def cosine_similarity_mean(z1, z2):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    return (z1 * z2).sum(dim=1).mean().item()

@torch.no_grad()
def update_target_encoder(online, target, tau):
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data = tau * tp.data + (1 - tau) * op.data


In [27]:
epochs = 100
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

total_start_time = time.time()

for epoch in range(epochs):

    epoch_start = time.time()

    total_loss, total_var, total_cos = 0.0, 0.0, 0.0

    for view1, view2 in loader:

        view1 = view1.to(device)
        view2 = view2.to(device)

        z1 = encoder(view1)
        z2 = encoder(view2)

        h1 = projector(z1)
        h2 = projector(z2)

        p1 = predictor(h1)
        p2 = predictor(h2)

        with torch.no_grad():
            t1 = projector(target_encoder(view1))
            t2 = projector(target_encoder(view2))

        loss = ssl_loss(p1, t2.detach()) + ssl_loss(p2, t1.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        update_target_encoder(encoder, target_encoder, ema_tau)

        with torch.no_grad():
            total_var += feature_variance(z1)
            total_cos += cosine_similarity_mean(z1, z2)

        total_loss += loss.item()

    epoch_time = (time.time() - epoch_start) / 60
    total_time = (time.time() - total_start_time) / 60

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Loss {total_loss/len(loader):.4f} | "
        f"Var {total_var/len(loader):.4f} | "
        f"Cos {total_cos/len(loader):.4f} | "
        f"Epoch Time {epoch_time:.2f} min | "
        f"Total Time {total_time:.2f} min"
    )

# Save only final model (avoid 5GB explosion)
torch.save(encoder.state_dict(), "/kaggle/working/encoder_final.pt")

Epoch 1/100 | Loss 0.0010 | Var 2.4694 | Cos 0.6874 | Epoch Time 1.12 min | Total Time 1.12 min
Epoch 2/100 | Loss 0.0007 | Var 2.4105 | Cos 0.7370 | Epoch Time 1.12 min | Total Time 2.24 min
Epoch 3/100 | Loss 0.0009 | Var 2.2976 | Cos 0.7549 | Epoch Time 1.12 min | Total Time 3.36 min
Epoch 4/100 | Loss 0.0011 | Var 2.1429 | Cos 0.7708 | Epoch Time 1.12 min | Total Time 4.49 min
Epoch 5/100 | Loss 0.0012 | Var 2.0685 | Cos 0.7911 | Epoch Time 1.12 min | Total Time 5.61 min
Epoch 6/100 | Loss 0.0013 | Var 2.0318 | Cos 0.8032 | Epoch Time 1.13 min | Total Time 6.73 min
Epoch 7/100 | Loss 0.0014 | Var 2.0265 | Cos 0.8161 | Epoch Time 1.14 min | Total Time 7.87 min
Epoch 8/100 | Loss 0.0015 | Var 2.0014 | Cos 0.8202 | Epoch Time 1.13 min | Total Time 9.00 min
Epoch 9/100 | Loss 0.0016 | Var 1.9696 | Cos 0.8273 | Epoch Time 1.13 min | Total Time 10.13 min
Epoch 10/100 | Loss 0.0016 | Var 1.9688 | Cos 0.8347 | Epoch Time 1.13 min | Total Time 11.27 min
Epoch 11/100 | Loss 0.0016 | Var 1.94

In [28]:
from torchvision import datasets, transforms

# Probe transform (NO augmentation, only normalization)
probe_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std)
])

train_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=False,
    transform=probe_transform
)

test_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=False,
    download=False,
    transform=probe_transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Train batches:", len(train_loader))
print("Test batches:", len(test_loader))

Train batches: 196
Test batches: 40


In [29]:
# Load final SSL weights
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load("/kaggle/working/encoder_final.pt"))
encoder.eval()

# Freeze encoder
for p in encoder.parameters():
    p.requires_grad = False

print("Encoder loaded and frozen.")


Encoder loaded and frozen.


In [30]:

class LinearProbe(nn.Module):
    def __init__(self, in_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


probe = LinearProbe(in_dim=512).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)


In [31]:
probe_epochs = 100

for epoch in range(probe_epochs):
    probe.train()
    total_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            features = encoder(images)

        outputs = probe(features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Probe Epoch {epoch+1}/{probe_epochs} | Loss {total_loss/len(train_loader):.4f}")


Probe Epoch 1/100 | Loss 1.3938
Probe Epoch 2/100 | Loss 1.2947
Probe Epoch 3/100 | Loss 1.2727
Probe Epoch 4/100 | Loss 1.2622
Probe Epoch 5/100 | Loss 1.2535
Probe Epoch 6/100 | Loss 1.2488
Probe Epoch 7/100 | Loss 1.2394
Probe Epoch 8/100 | Loss 1.2366
Probe Epoch 9/100 | Loss 1.2320
Probe Epoch 10/100 | Loss 1.2262
Probe Epoch 11/100 | Loss 1.2262
Probe Epoch 12/100 | Loss 1.2260
Probe Epoch 13/100 | Loss 1.2180
Probe Epoch 14/100 | Loss 1.2171
Probe Epoch 15/100 | Loss 1.2148
Probe Epoch 16/100 | Loss 1.2114
Probe Epoch 17/100 | Loss 1.2084
Probe Epoch 18/100 | Loss 1.2088
Probe Epoch 19/100 | Loss 1.2090
Probe Epoch 20/100 | Loss 1.2048
Probe Epoch 21/100 | Loss 1.2020
Probe Epoch 22/100 | Loss 1.2022
Probe Epoch 23/100 | Loss 1.1989
Probe Epoch 24/100 | Loss 1.1989
Probe Epoch 25/100 | Loss 1.1958
Probe Epoch 26/100 | Loss 1.1993
Probe Epoch 27/100 | Loss 1.1943
Probe Epoch 28/100 | Loss 1.1922
Probe Epoch 29/100 | Loss 1.1942
Probe Epoch 30/100 | Loss 1.1906
Probe Epoch 31/100 

In [32]:
probe.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        features = encoder(images)
        outputs = probe(features)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"\nðŸ”¹ Linear Probe Test Accuracy: {accuracy:.2f}%")



ðŸ”¹ Linear Probe Test Accuracy: 56.01%
