In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

cudnn.benchmark = True


Using device: cuda


In [2]:
base_transform = transforms.ToTensor()

ssl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.ToTensor()
])

to_pil = transforms.ToPILImage()


In [3]:
dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=True,
    transform=base_transform
)


In [4]:
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

print("Batches per epoch:", len(loader))


Batches per epoch: 195


In [5]:
import torchvision.models as models

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        # Load ResNet18
        backbone = models.resnet18(weights=None)

        # Remove final classification layer
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # Output: (B, 512)


In [6]:
class Predictor(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)

In [7]:
class Projector(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)

In [8]:
def ssl_loss(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return F.mse_loss(p, z)


In [9]:
encoder = Encoder().to(device)
projector = Projector().to(device)
predictor = Predictor().to(device)

target_encoder = Encoder().to(device)
target_encoder.load_state_dict(encoder.state_dict())

for p in target_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(projector.parameters()) +
    list(predictor.parameters()),
    lr=1e-3
)

ema_tau = 0.99


In [10]:
@torch.no_grad()
def feature_variance(z):
    return z.var(dim=0).mean().item()

@torch.no_grad()
def cosine_similarity_mean(z1, z2):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    return (z1 * z2).sum(dim=1).mean().item()

@torch.no_grad()
def update_target_encoder(online, target, tau):
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data = tau * tp.data + (1 - tau) * op.data


In [11]:
epochs = 100
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

for epoch in range(epochs):
    start = time.time()

    total_loss, total_var, total_cos = 0.0, 0.0, 0.0

    for images, _ in loader:
        images = images.to(device)

        view1 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        view2 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        z1 = encoder(view1)
        z2 = encoder(view2)

        h1 = projector(z1)
        h2 = projector(z2)

        p1 = predictor(h1)
        p2 = predictor(h2)

        with torch.no_grad():
            t1 = projector(target_encoder(view1))
            t2 = projector(target_encoder(view2))

        loss = ssl_loss(p1, t2.detach()) + ssl_loss(p2, t1.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        update_target_encoder(encoder, target_encoder, ema_tau)

        with torch.no_grad():
            total_var += feature_variance(z1)
            total_cos += cosine_similarity_mean(z1, z2)

        total_loss += loss.item()

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Loss {total_loss/len(loader):.4f} | "
        f"Var {total_var/len(loader):.4f} | "
        f"Cos {total_cos/len(loader):.4f} | "
        f"Time {(time.time()-start)/60:.2f} min"
    )

    torch.save(
        encoder.state_dict(),
        f"/kaggle/working/checkpoints/encoder_epoch_{epoch+1}.pt"
    )


Epoch 1/100 | Loss 0.0008 | Var 2.0138 | Cos 0.6514 | Time 2.05 min
Epoch 2/100 | Loss 0.0006 | Var 1.9284 | Cos 0.7996 | Time 1.97 min
Epoch 3/100 | Loss 0.0007 | Var 1.8682 | Cos 0.8518 | Time 1.96 min
Epoch 4/100 | Loss 0.0007 | Var 1.8330 | Cos 0.8723 | Time 1.96 min
Epoch 5/100 | Loss 0.0008 | Var 1.7911 | Cos 0.8809 | Time 1.96 min
Epoch 6/100 | Loss 0.0008 | Var 1.7755 | Cos 0.8899 | Time 1.96 min
Epoch 7/100 | Loss 0.0009 | Var 1.7093 | Cos 0.8940 | Time 1.96 min
Epoch 8/100 | Loss 0.0010 | Var 1.6825 | Cos 0.8982 | Time 1.96 min
Epoch 9/100 | Loss 0.0012 | Var 1.6856 | Cos 0.9014 | Time 1.96 min
Epoch 10/100 | Loss 0.0012 | Var 1.6639 | Cos 0.9029 | Time 1.96 min
Epoch 11/100 | Loss 0.0013 | Var 1.6374 | Cos 0.9048 | Time 1.96 min
Epoch 12/100 | Loss 0.0014 | Var 1.6390 | Cos 0.9069 | Time 1.96 min
Epoch 13/100 | Loss 0.0014 | Var 1.6107 | Cos 0.9112 | Time 1.96 min
Epoch 14/100 | Loss 0.0014 | Var 1.5989 | Cos 0.9123 | Time 1.96 min
Epoch 15/100 | Loss 0.0014 | Var 1.5739 | C

In [12]:
torch.save(encoder.state_dict(), "/kaggle/working/encoder_final.pt")


In [13]:
from torchvision import datasets, transforms

# No augmentation for probe
probe_transform = transforms.ToTensor()

train_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=False,
    transform=probe_transform
)

test_dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=False,
    download=False,
    transform=probe_transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Train batches:", len(train_loader))
print("Test batches:", len(test_loader))


Train batches: 196
Test batches: 40


In [14]:
# Load final SSL weights
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load("/kaggle/working/encoder_final.pt"))
encoder.eval()

# Freeze encoder
for p in encoder.parameters():
    p.requires_grad = False

print("Encoder loaded and frozen.")


Encoder loaded and frozen.


In [15]:

class LinearProbe(nn.Module):
    def __init__(self, in_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


probe = LinearProbe(in_dim=512).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)


In [16]:
probe_epochs = 100

for epoch in range(probe_epochs):
    probe.train()
    total_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            features = encoder(images)

        outputs = probe(features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Probe Epoch {epoch+1}/{probe_epochs} | Loss {total_loss/len(train_loader):.4f}")


Probe Epoch 1/100 | Loss 1.4681
Probe Epoch 2/100 | Loss 1.3430
Probe Epoch 3/100 | Loss 1.3191
Probe Epoch 4/100 | Loss 1.3006
Probe Epoch 5/100 | Loss 1.2859
Probe Epoch 6/100 | Loss 1.2775
Probe Epoch 7/100 | Loss 1.2715
Probe Epoch 8/100 | Loss 1.2646
Probe Epoch 9/100 | Loss 1.2573
Probe Epoch 10/100 | Loss 1.2533
Probe Epoch 11/100 | Loss 1.2490
Probe Epoch 12/100 | Loss 1.2439
Probe Epoch 13/100 | Loss 1.2435
Probe Epoch 14/100 | Loss 1.2384
Probe Epoch 15/100 | Loss 1.2388
Probe Epoch 16/100 | Loss 1.2348
Probe Epoch 17/100 | Loss 1.2319
Probe Epoch 18/100 | Loss 1.2273
Probe Epoch 19/100 | Loss 1.2267
Probe Epoch 20/100 | Loss 1.2274
Probe Epoch 21/100 | Loss 1.2226
Probe Epoch 22/100 | Loss 1.2234
Probe Epoch 23/100 | Loss 1.2202
Probe Epoch 24/100 | Loss 1.2201
Probe Epoch 25/100 | Loss 1.2201
Probe Epoch 26/100 | Loss 1.2137
Probe Epoch 27/100 | Loss 1.2153
Probe Epoch 28/100 | Loss 1.2136
Probe Epoch 29/100 | Loss 1.2135
Probe Epoch 30/100 | Loss 1.2122
Probe Epoch 31/100 

In [17]:
probe.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        features = encoder(images)
        outputs = probe(features)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"\nðŸ”¹ Linear Probe Test Accuracy: {accuracy:.2f}%")



ðŸ”¹ Linear Probe Test Accuracy: 55.16%
