In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Unlabeled CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip()
])
cifar10_unlabeled = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
unlabeled_loader = DataLoader(cifar10_unlabeled, batch_size=64, shuffle=True)

# Patch projection layer
patch_dim = 4 * 4 * 3  # = 48
embed_dim = 64
patch_proj = nn.Linear(patch_dim, embed_dim).to(device)

# Decoder
class SimpleDecoder(nn.Module):
    def __init__(self, input_dim, patch_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, patch_dim),
            nn.ReLU(),
            nn.Linear(patch_dim, patch_dim)
        )
    def forward(self, x):
        return self.fc(x)

decoder = SimpleDecoder(embed_dim, patch_dim).to(device)

# Optimizer
optimizer = torch.optim.AdamW(list(patch_proj.parameters()) + list(decoder.parameters()), lr=1e-3)

# Patch extraction function
def extract_patches(x, patch_size=4):
    return F.unfold(x, kernel_size=patch_size, stride=patch_size).transpose(1, 2)  # (B, N, P^2*C)

# Self-supervised pretraining loop
for epoch in range(5):  # you can increase to 100
    for images, _ in unlabeled_loader:
        images = images.to(device)
        patches = extract_patches(images, patch_size=4)  # (B, N, 48)

        B, N, _ = patches.shape
        n_keep = int(N * 0.25)
        idx = torch.rand(B, N).argsort(dim=1)[:, :n_keep]

        visible = torch.stack([
            patches[i, idx[i]] for i in range(B)
        ])  # (B, n_keep, 48)

        # Project to latent space
        encoded = patch_proj(visible)  # (B, n_keep, 64)

        # Reconstruct
        decoded = decoder(encoded)  # (B, n_keep, 48)

        # Loss: MSE between reconstructed and original visible patches
        loss = ((decoded - visible.detach()) ** 2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")


Epoch 1: loss = 0.0010
Epoch 2: loss = 0.0004
Epoch 3: loss = 0.0003
Epoch 4: loss = 0.0003
Epoch 5: loss = 0.0002
