In [2]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

cudnn.benchmark = True


Using device: cuda


In [3]:
base_transform = transforms.ToTensor()

ssl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.ToTensor()
])

to_pil = transforms.ToPILImage()


In [4]:
dataset = datasets.CIFAR10(
    root="/kaggle/working/data",
    train=True,
    download=True,
    transform=base_transform
)


100%|██████████| 170M/170M [00:08<00:00, 20.6MB/s] 


In [5]:
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

print("Batches per epoch:", len(loader))


Batches per epoch: 195


In [6]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

    def forward(self, x):
        x = self.net(x)
        return x.view(x.size(0), -1)


In [7]:
class Predictor(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)


In [8]:
class Projector(nn.Module):
    def __init__(self, in_dim=128, hidden_dim=256, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)


In [9]:
def ssl_loss(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return F.mse_loss(p, z)


In [10]:
encoder = Encoder().to(device)
projector = Projector().to(device)
predictor = Predictor().to(device)

target_encoder = Encoder().to(device)
target_encoder.load_state_dict(encoder.state_dict())

for p in target_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(projector.parameters()) +
    list(predictor.parameters()),
    lr=1e-3
)

ema_tau = 0.99


In [11]:
@torch.no_grad()
def feature_variance(z):
    return z.var(dim=0).mean().item()

@torch.no_grad()
def cosine_similarity_mean(z1, z2):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    return (z1 * z2).sum(dim=1).mean().item()

@torch.no_grad()
def update_target_encoder(online, target, tau):
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data = tau * tp.data + (1 - tau) * op.data


In [12]:
epochs = 50
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

for epoch in range(epochs):
    start = time.time()

    total_loss, total_var, total_cos = 0.0, 0.0, 0.0

    for images, _ in loader:
        images = images.to(device)

        view1 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        view2 = torch.stack([
            ssl_transform(to_pil(img.cpu())) for img in images
        ]).to(device)

        z1 = encoder(view1)
        z2 = encoder(view2)

        h1 = projector(z1)
        h2 = projector(z2)

        p1 = predictor(h1)
        p2 = predictor(h2)

        with torch.no_grad():
            t1 = projector(target_encoder(view1))
            t2 = projector(target_encoder(view2))

        loss = ssl_loss(p1, t2.detach()) + ssl_loss(p2, t1.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        update_target_encoder(encoder, target_encoder, ema_tau)

        with torch.no_grad():
            total_var += feature_variance(z1)
            total_cos += cosine_similarity_mean(z1, z2)

        total_loss += loss.item()

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Loss {total_loss/len(loader):.4f} | "
        f"Var {total_var/len(loader):.4f} | "
        f"Cos {total_cos/len(loader):.4f} | "
        f"Time {(time.time()-start)/60:.2f} min"
    )

    torch.save(
        encoder.state_dict(),
        f"/kaggle/working/checkpoints/encoder_epoch_{epoch+1}.pt"
    )


Epoch 1/50 | Loss 0.0034 | Var 0.0039 | Cos 0.9853 | Time 1.57 min
Epoch 2/50 | Loss 0.0023 | Var 0.0021 | Cos 0.9751 | Time 1.56 min
Epoch 3/50 | Loss 0.0022 | Var 0.0016 | Cos 0.9677 | Time 1.56 min
Epoch 4/50 | Loss 0.0022 | Var 0.0013 | Cos 0.9613 | Time 1.53 min
Epoch 5/50 | Loss 0.0023 | Var 0.0011 | Cos 0.9538 | Time 1.53 min
Epoch 6/50 | Loss 0.0024 | Var 0.0010 | Cos 0.9479 | Time 1.53 min
Epoch 7/50 | Loss 0.0024 | Var 0.0009 | Cos 0.9417 | Time 1.52 min
Epoch 8/50 | Loss 0.0024 | Var 0.0008 | Cos 0.9348 | Time 1.52 min
Epoch 9/50 | Loss 0.0026 | Var 0.0007 | Cos 0.9315 | Time 1.53 min
Epoch 10/50 | Loss 0.0026 | Var 0.0006 | Cos 0.9266 | Time 1.57 min
Epoch 11/50 | Loss 0.0027 | Var 0.0006 | Cos 0.9225 | Time 1.56 min
Epoch 12/50 | Loss 0.0028 | Var 0.0005 | Cos 0.9168 | Time 1.58 min
Epoch 13/50 | Loss 0.0028 | Var 0.0005 | Cos 0.9140 | Time 1.54 min
Epoch 14/50 | Loss 0.0028 | Var 0.0004 | Cos 0.9119 | Time 1.60 min
Epoch 15/50 | Loss 0.0028 | Var 0.0004 | Cos 0.9108 | Tim

In [13]:
torch.save(encoder.state_dict(), "/kaggle/working/encoder_final.pt")
