In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

cudnn.benchmark = True


Using device: cuda


In [2]:
base_transform = transforms.ToTensor()

ssl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.ToTensor()
])

to_pil = transforms.ToPILImage()


In [3]:
dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=True,
    transform=base_transform
)


In [4]:
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

print("Batches per epoch:", len(loader))


Batches per epoch: 195


In [5]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        # Load ResNet18
        backbone = models.resnet18(weights=None)

        # Remove final classification layer
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # Output: (B, 512)


In [6]:
class Predictor(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)

In [7]:
class Projector(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)

In [8]:
def ssl_loss(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return F.mse_loss(p, z)


In [9]:
encoder = Encoder().to(device)
projector = Projector().to(device)
predictor = Predictor().to(device)

target_encoder = Encoder().to(device)
target_encoder.load_state_dict(encoder.state_dict())

for p in target_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(projector.parameters()) +
    list(predictor.parameters()),
    lr=1e-3
)

ema_tau = 0.99


In [10]:
@torch.no_grad()
def feature_variance(z):
    return z.var(dim=0).mean().item()

@torch.no_grad()
def cosine_similarity_mean(z1, z2):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    return (z1 * z2).sum(dim=1).mean().item()

@torch.no_grad()
def update_target_encoder(online, target, tau):
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data = tau * tp.data + (1 - tau) * op.data


In [11]:
epochs = 50
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

for epoch in range(epochs):
    start = time.time()

    total_loss, total_var, total_cos = 0.0, 0.0, 0.0

    for images, _ in loader:
        images = images.to(device)

        view1 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        view2 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        z1 = encoder(view1)
        z2 = encoder(view2)

        h1 = projector(z1)
        h2 = projector(z2)

        p1 = predictor(h1)
        p2 = predictor(h2)

        with torch.no_grad():
            t1 = projector(target_encoder(view1))
            t2 = projector(target_encoder(view2))

        loss = ssl_loss(p1, t2.detach()) + ssl_loss(p2, t1.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        update_target_encoder(encoder, target_encoder, ema_tau)

        with torch.no_grad():
            total_var += feature_variance(z1)
            total_cos += cosine_similarity_mean(z1, z2)

        total_loss += loss.item()

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Loss {total_loss/len(loader):.4f} | "
        f"Var {total_var/len(loader):.4f} | "
        f"Cos {total_cos/len(loader):.4f} | "
        f"Time {(time.time()-start)/60:.2f} min"
    )

    torch.save(
        encoder.state_dict(),
        f"/kaggle/working/checkpoints/encoder_epoch_{epoch+1}.pt"
    )


Epoch 1/50 | Loss 0.0009 | Var 1.9817 | Cos 0.6455 | Time 2.27 min
Epoch 2/50 | Loss 0.0008 | Var 1.8870 | Cos 0.8097 | Time 2.19 min
Epoch 3/50 | Loss 0.0008 | Var 1.8653 | Cos 0.8648 | Time 2.20 min
Epoch 4/50 | Loss 0.0008 | Var 1.8646 | Cos 0.8825 | Time 2.19 min
Epoch 5/50 | Loss 0.0009 | Var 1.8538 | Cos 0.8886 | Time 2.19 min
Epoch 6/50 | Loss 0.0010 | Var 1.8269 | Cos 0.8935 | Time 2.19 min
Epoch 7/50 | Loss 0.0011 | Var 1.7906 | Cos 0.8974 | Time 2.19 min
Epoch 8/50 | Loss 0.0011 | Var 1.7658 | Cos 0.9025 | Time 2.18 min
Epoch 9/50 | Loss 0.0011 | Var 1.7558 | Cos 0.9034 | Time 2.18 min
Epoch 10/50 | Loss 0.0012 | Var 1.7675 | Cos 0.9023 | Time 2.18 min
Epoch 11/50 | Loss 0.0013 | Var 1.7556 | Cos 0.9057 | Time 2.18 min
Epoch 12/50 | Loss 0.0012 | Var 1.7463 | Cos 0.9086 | Time 2.19 min
Epoch 13/50 | Loss 0.0012 | Var 1.7615 | Cos 0.9090 | Time 2.20 min
Epoch 14/50 | Loss 0.0012 | Var 1.7461 | Cos 0.9098 | Time 2.18 min
Epoch 15/50 | Loss 0.0012 | Var 1.7317 | Cos 0.9111 | Tim

In [12]:
torch.save(encoder.state_dict(), "/kaggle/working/encoder_final.pt")


In [13]:
from torchvision import datasets, transforms

# No augmentation for probe
probe_transform = transforms.ToTensor()

train_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=False,
    transform=probe_transform
)

test_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=False,
    download=False,
    transform=probe_transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Train batches:", len(train_loader))
print("Test batches:", len(test_loader))


Train batches: 196
Test batches: 40


In [14]:
# Load final SSL weights
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load("/kaggle/working/encoder_final.pt"))
encoder.eval()

# Freeze encoder
for p in encoder.parameters():
    p.requires_grad = False

print("Encoder loaded and frozen.")


Encoder loaded and frozen.


In [15]:

class LinearProbe(nn.Module):
    def __init__(self, in_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


probe = LinearProbe(in_dim=512).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)


In [16]:
probe_epochs = 50

for epoch in range(probe_epochs):
    probe.train()
    total_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            features = encoder(images)

        outputs = probe(features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Probe Epoch {epoch+1}/{probe_epochs} | Loss {total_loss/len(train_loader):.4f}")


Probe Epoch 1/50 | Loss 1.6067
Probe Epoch 2/50 | Loss 1.5217
Probe Epoch 3/50 | Loss 1.4968
Probe Epoch 4/50 | Loss 1.4840
Probe Epoch 5/50 | Loss 1.4736
Probe Epoch 6/50 | Loss 1.4646
Probe Epoch 7/50 | Loss 1.4613
Probe Epoch 8/50 | Loss 1.4564
Probe Epoch 9/50 | Loss 1.4494
Probe Epoch 10/50 | Loss 1.4431
Probe Epoch 11/50 | Loss 1.4438
Probe Epoch 12/50 | Loss 1.4388
Probe Epoch 13/50 | Loss 1.4361
Probe Epoch 14/50 | Loss 1.4349
Probe Epoch 15/50 | Loss 1.4296
Probe Epoch 16/50 | Loss 1.4252
Probe Epoch 17/50 | Loss 1.4241
Probe Epoch 18/50 | Loss 1.4214
Probe Epoch 19/50 | Loss 1.4202
Probe Epoch 20/50 | Loss 1.4172
Probe Epoch 21/50 | Loss 1.4141
Probe Epoch 22/50 | Loss 1.4138
Probe Epoch 23/50 | Loss 1.4122
Probe Epoch 24/50 | Loss 1.4087
Probe Epoch 25/50 | Loss 1.4111
Probe Epoch 26/50 | Loss 1.4039
Probe Epoch 27/50 | Loss 1.4040
Probe Epoch 28/50 | Loss 1.4004
Probe Epoch 29/50 | Loss 1.4051
Probe Epoch 30/50 | Loss 1.4013
Probe Epoch 31/50 | Loss 1.3996
Probe Epoch 32/50

In [17]:
probe.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        features = encoder(images)
        outputs = probe(features)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"\nðŸ”¹ Linear Probe Test Accuracy: {accuracy:.2f}%")



ðŸ”¹ Linear Probe Test Accuracy: 48.25%
