In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

cudnn.benchmark = True


Using device: cuda


In [2]:
base_transform = transforms.ToTensor()

ssl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.ToTensor()
])

to_pil = transforms.ToPILImage()


In [3]:
dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=True,
    transform=base_transform
)


In [4]:
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

print("Batches per epoch:", len(loader))


Batches per epoch: 195


In [5]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        # Load ResNet18
        backbone = models.resnet18(weights=None)

        # Remove final classification layer
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # Output: (B, 512)


In [6]:
class Predictor(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)

In [7]:
class Projector(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)

In [8]:
def ssl_loss(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return F.mse_loss(p, z)


In [18]:
encoder = Encoder().to(device)
projector = Projector().to(device)
predictor = Predictor().to(device)

target_encoder = Encoder().to(device)
target_encoder.load_state_dict(encoder.state_dict())

for p in target_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(projector.parameters()) +
    list(predictor.parameters()),
    lr=1e-3
)

ema_tau = 0.996


In [19]:
@torch.no_grad()
def feature_variance(z):
    return z.var(dim=0).mean().item()

@torch.no_grad()
def cosine_similarity_mean(z1, z2):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    return (z1 * z2).sum(dim=1).mean().item()

@torch.no_grad()
def update_target_encoder(online, target, tau):
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data = tau * tp.data + (1 - tau) * op.data


In [20]:
epochs = 100
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

for epoch in range(epochs):
    start = time.time()

    total_loss, total_var, total_cos = 0.0, 0.0, 0.0

    for images, _ in loader:
        images = images.to(device)

        view1 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        view2 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        z1 = encoder(view1)
        z2 = encoder(view2)

        h1 = projector(z1)
        h2 = projector(z2)

        p1 = predictor(h1)
        p2 = predictor(h2)

        with torch.no_grad():
            t1 = projector(target_encoder(view1))
            t2 = projector(target_encoder(view2))

        loss = ssl_loss(p1, t2.detach()) + ssl_loss(p2, t1.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        update_target_encoder(encoder, target_encoder, ema_tau)

        with torch.no_grad():
            total_var += feature_variance(z1)
            total_cos += cosine_similarity_mean(z1, z2)

        total_loss += loss.item()

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Loss {total_loss/len(loader):.4f} | "
        f"Var {total_var/len(loader):.4f} | "
        f"Cos {total_cos/len(loader):.4f} | "
        f"Time {(time.time()-start)/60:.2f} min"
    )

    torch.save(
        encoder.state_dict(),
        f"/kaggle/working/checkpoints/encoder_epoch_{epoch+1}.pt"
    )


Epoch 1/100 | Loss 0.0009 | Var 2.2349 | Cos 0.6894 | Time 1.96 min
Epoch 2/100 | Loss 0.0007 | Var 1.9018 | Cos 0.7820 | Time 1.96 min
Epoch 3/100 | Loss 0.0008 | Var 1.8067 | Cos 0.8282 | Time 1.96 min
Epoch 4/100 | Loss 0.0010 | Var 1.7264 | Cos 0.8503 | Time 1.96 min
Epoch 5/100 | Loss 0.0011 | Var 1.6807 | Cos 0.8667 | Time 1.96 min
Epoch 6/100 | Loss 0.0011 | Var 1.6395 | Cos 0.8764 | Time 1.99 min
Epoch 7/100 | Loss 0.0012 | Var 1.6545 | Cos 0.8830 | Time 1.96 min
Epoch 8/100 | Loss 0.0012 | Var 1.6519 | Cos 0.8912 | Time 1.96 min
Epoch 9/100 | Loss 0.0012 | Var 1.6551 | Cos 0.8978 | Time 1.96 min
Epoch 10/100 | Loss 0.0012 | Var 1.6402 | Cos 0.9016 | Time 1.96 min
Epoch 11/100 | Loss 0.0012 | Var 1.6293 | Cos 0.9053 | Time 1.96 min
Epoch 12/100 | Loss 0.0013 | Var 1.6189 | Cos 0.9061 | Time 1.96 min
Epoch 13/100 | Loss 0.0013 | Var 1.6003 | Cos 0.9071 | Time 1.96 min
Epoch 14/100 | Loss 0.0013 | Var 1.5803 | Cos 0.9075 | Time 1.96 min
Epoch 15/100 | Loss 0.0013 | Var 1.5606 | C

In [21]:
torch.save(encoder.state_dict(), "/kaggle/working/encoder_final.pt")


In [22]:
from torchvision import datasets, transforms

# No augmentation for probe
probe_transform = transforms.ToTensor()

train_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=False,
    transform=probe_transform
)

test_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=False,
    download=False,
    transform=probe_transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Train batches:", len(train_loader))
print("Test batches:", len(test_loader))


Train batches: 196
Test batches: 40


In [23]:
# Load final SSL weights
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load("/kaggle/working/encoder_final.pt"))
encoder.eval()

# Freeze encoder
for p in encoder.parameters():
    p.requires_grad = False

print("Encoder loaded and frozen.")


Encoder loaded and frozen.


In [24]:

class LinearProbe(nn.Module):
    def __init__(self, in_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


probe = LinearProbe(in_dim=512).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)


In [25]:
probe_epochs = 100

for epoch in range(probe_epochs):
    probe.train()
    total_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            features = encoder(images)

        outputs = probe(features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Probe Epoch {epoch+1}/{probe_epochs} | Loss {total_loss/len(train_loader):.4f}")


Probe Epoch 1/100 | Loss 1.5714
Probe Epoch 2/100 | Loss 1.4619
Probe Epoch 3/100 | Loss 1.4339
Probe Epoch 4/100 | Loss 1.4191
Probe Epoch 5/100 | Loss 1.4099
Probe Epoch 6/100 | Loss 1.3977
Probe Epoch 7/100 | Loss 1.3935
Probe Epoch 8/100 | Loss 1.3847
Probe Epoch 9/100 | Loss 1.3816
Probe Epoch 10/100 | Loss 1.3752
Probe Epoch 11/100 | Loss 1.3715
Probe Epoch 12/100 | Loss 1.3679
Probe Epoch 13/100 | Loss 1.3628
Probe Epoch 14/100 | Loss 1.3600
Probe Epoch 15/100 | Loss 1.3562
Probe Epoch 16/100 | Loss 1.3569
Probe Epoch 17/100 | Loss 1.3535
Probe Epoch 18/100 | Loss 1.3488
Probe Epoch 19/100 | Loss 1.3493
Probe Epoch 20/100 | Loss 1.3486
Probe Epoch 21/100 | Loss 1.3432
Probe Epoch 22/100 | Loss 1.3433
Probe Epoch 23/100 | Loss 1.3425
Probe Epoch 24/100 | Loss 1.3389
Probe Epoch 25/100 | Loss 1.3351
Probe Epoch 26/100 | Loss 1.3367
Probe Epoch 27/100 | Loss 1.3349
Probe Epoch 28/100 | Loss 1.3338
Probe Epoch 29/100 | Loss 1.3321
Probe Epoch 30/100 | Loss 1.3321
Probe Epoch 31/100 

In [27]:
probe.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        features = encoder(images)
        outputs = probe(features)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"\nðŸ”¹ Linear Probe Test Accuracy: {accuracy:.2f}%")



ðŸ”¹ Linear Probe Test Accuracy: 51.54%
