# Project 1 · BASIC: MAE self-supervised pretraining and linear / k-NN probing on CIFAR-10 / CIFAR-100

In [54]:
import os, json, math, numpy as np, torch, torchvision
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from lightly.transforms import MAETransform
from lightly.models import utils
from timm.models.vision_transformer import vit_base_patch32_224
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
import seaborn as sns
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from lightly.models.modules.masked_vision_transformer_timm import MaskedVisionTransformerTIMM
from lightly.models.modules.masked_autoencoder_timm import MAEDecoderTIMM
import torch.multiprocessing as mp

In [55]:
# Use the "file_system" tensor-sharing strategy for worker processes
# (presumably to avoid per-process file-descriptor limits with several
# DataLoader workers — confirm on the target platform).
_SHARING_STRATEGY = "file_system"
mp.set_sharing_strategy(_SHARING_STRATEGY)

In [56]:
# Global seed for the run. `workers=True` makes Lightning give every DataLoader
# worker its own reproducible seed; the random MAETransform augmentations run
# inside those workers (train_loader uses num_workers=2).
SEED = 42
pl.seed_everything(SEED, workers=True)

Seed set to 42


42

In [57]:
# Target device for the notebook: the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# "high" lets float32 matmuls use faster reduced-internal-precision kernels
# where the hardware supports them.
torch.set_float32_matmul_precision("high")

In [58]:
def compute_cifar_stats(name="cifar100", root="data", batch_size=5000, num_workers=2):
    """Compute per-channel mean and standard deviation of a CIFAR training split.

    Images are converted with ``T.ToTensor()`` (values in [0, 1]) before the
    statistics are taken, so the result can be fed straight into ``T.Normalize``.
    The std is the population std over every pixel of the split.

    Args:
        name: torchvision dataset class name, case-insensitive
            (e.g. "cifar10" or "cifar100").
        root: directory where the dataset is stored / downloaded.
        batch_size: images per batch while streaming over the split.
        num_workers: DataLoader worker processes.

    Returns:
        ``(mean, std)``: two Python lists with one float per channel.

    Raises:
        ValueError: if the split contains no images.
    """
    ds_class = getattr(torchvision.datasets, name.upper())
    ds = ds_class(root, train=True, download=True, transform=T.ToTensor())
    loader = DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    # Accumulate in float64: E[x^2] - E[x]^2 over millions of pixels suffers
    # from cancellation / accumulated rounding in float32.
    ch_sum = torch.zeros(3, dtype=torch.float64)
    ch_sum_sq = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0
    for imgs, _ in loader:
        b, _, h, w = imgs.shape
        n_pixels += b * h * w
        ch_sum += imgs.sum(dim=(0, 2, 3), dtype=torch.float64)
        ch_sum_sq += (imgs ** 2).sum(dim=(0, 2, 3), dtype=torch.float64)

    if n_pixels == 0:
        raise ValueError(f"dataset {name!r} under {root!r} contains no images")

    mean = ch_sum / n_pixels
    # Clamp guards against a tiny negative variance produced by rounding.
    std = (ch_sum_sq / n_pixels - mean ** 2).clamp_min(0.0).sqrt()
    return mean.tolist(), std.tolist()

In [59]:
def make_datasets(root="data", val_size=5_000, img_size=224, seed=42, name="cifar100"):
    """Build the MAE pretraining set and the evaluation splits for a CIFAR dataset.

    The official training split is divided into a stratified train / validation
    partition. Normalization statistics come from ``compute_cifar_stats`` for
    the same dataset; note they are computed over the whole training split,
    including the validation indices.

    Args:
        root: directory where the dataset is stored / downloaded.
        val_size: number of training images held out as the validation split.
        img_size: side length of the square images fed to the model.
        seed: random state for the stratified split.
        name: torchvision dataset class name, case-insensitive
            ("cifar100" by default, "cifar10" also works).

    Returns:
        ``(train_ssl, val_set, test_set)``:
        ``train_ssl`` is the training subset with ``MAETransform`` applied;
        ``val_set`` is the held-out subset and ``test_set`` the official test
        split, both with a deterministic resize / center-crop / normalize
        transform.
    """
    ds_class = getattr(torchvision.datasets, name.upper())
    mean, std = compute_cifar_stats(name=name, root=root)

    ssl_tfm = MAETransform(input_size=img_size, min_scale=0.2, normalize={"mean": mean, "std": std})
    eval_tfm = T.Compose([
        # Resize a little larger than the target, then center-crop to img_size.
        T.Resize(img_size + 10),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    # Only the labels are needed for the stratified split, so no transform here.
    targets = np.asarray(ds_class(root, train=True, download=True).targets)
    tr_idx, val_idx = train_test_split(
        np.arange(len(targets)), test_size=val_size, stratify=targets, random_state=seed
    )

    train_ssl = Subset(ds_class(root, train=True, transform=ssl_tfm, download=False), tr_idx)
    val_set = Subset(ds_class(root, train=True, transform=eval_tfm, download=False), val_idx)
    test_set = ds_class(root, train=False, transform=eval_tfm, download=True)
    return train_ssl, val_set, test_set

In [60]:
class MAE(pl.LightningModule):
    """Masked autoencoder built on a timm ``vit_base_patch32_224`` encoder.

    During training a random ``mask_ratio`` (75%) of the token positions is
    hidden from the encoder; a single-block decoder predicts the pixels of the
    hidden patches, and the MSE against those patches is the loss.

    Args:
        lr: learning rate passed to AdamW.
    """

    def __init__(self, lr=1.5e-4):
        super().__init__()
        # Construction order is kept fixed: each module draws its random init
        # from the global RNG in this sequence.
        vit = vit_base_patch32_224()
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        # Token-sequence length seen by the encoder (presumably patches plus a
        # class token — see the `idx_mask - 1` shift in training_step).
        self.sequence_len = self.backbone.sequence_length
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=512,
            decoder_depth=1,
            decoder_num_heads=16,
            mlp_ratio=4.0,
        )
        self.criterion = nn.MSELoss()
        self.lr = lr

    def forward_encoder(self, imgs, idx_keep=None):
        """Run the masked encoder; with ``idx_keep`` only those token positions are encoded."""
        encoded = self.backbone.encode(images=imgs, idx_keep=idx_keep)
        return encoded

    def forward_decoder(self, x_enc, idx_keep, idx_mask):
        """Decode visible tokens and return pixel predictions for the masked positions.

        The embedded visible tokens are written into a full-length sequence of
        decoder mask tokens at ``idx_keep``; after decoding, only positions
        ``idx_mask`` are projected to patch pixels.
        """
        visible = self.decoder.embed(x_enc)
        canvas = utils.repeat_token(self.decoder.mask_token, (x_enc.size(0), self.sequence_len))
        canvas = utils.set_at_index(canvas, idx_keep, visible.type_as(canvas))
        decoded = self.decoder.decode(canvas)
        return self.decoder.predict(utils.get_at_index(decoded, idx_mask))

    def training_step(self, batch, batch_idx):
        """One MAE step: mask, encode, decode, and regress the masked patches."""
        views, _labels = batch
        imgs = views[0]  # only the first view is used
        idx_keep, idx_mask = utils.random_token_mask(
            size=(imgs.size(0), self.sequence_len),
            mask_ratio=self.mask_ratio,
            device=imgs.device,
        )
        x_pred = self.forward_decoder(self.forward_encoder(imgs, idx_keep), idx_keep, idx_mask)

        # Patchified images have no non-patch token, hence the shift by one.
        target = utils.get_at_index(utils.patchify(imgs, self.patch_size), idx_mask - 1)
        loss = self.criterion(x_pred, target)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.lr``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
# MAE pretraining run. The model trains in memory only
# (enable_checkpointing=False), so later cells use `model` directly.
USE_CUDA = torch.cuda.is_available()

train_ssl, val_set, test_set = make_datasets()
train_loader = DataLoader(
    train_ssl,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    # Pinned host memory only helps host->GPU transfers.
    pin_memory=USE_CUDA,
    # Keep the workers alive across epochs instead of re-spawning them each epoch
    # (Lightning's data_connector warns about this when it is False).
    persistent_workers=True,
    prefetch_factor=2,
)

model = MAE()

trainer = pl.Trainer(
    max_epochs=80,
    accelerator="gpu" if USE_CUDA else "cpu",
    devices=1,
    # fp16 AMP is CUDA-only: Lightning rejects precision="16-mixed" with the CPU
    # accelerator, so the CPU fallback trains in full fp32.
    precision="16-mixed" if USE_CUDA else "32-true",
    enable_checkpointing=False,
    log_every_n_steps=50,
)
trainer.fit(model, train_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                        | Params | Mode 
------------------------------------------------------------------
0 | backbone  | MaskedVisionTransformerTIMM | 88.2 M | train
1 | decoder   | MAEDecoderTIMM              | 5.1 M  | train
2 | criterion | MSELoss                     | 0      | train
------------------------------------------------------------------
93.3 M    Trainable params
64.0 K    Non-trainable params
93.4 M    Total params
373.497   Total estimated model params size (MB)
292       Modules in train mode
0         Modules in eval mode
d:\Repos\Warsztaty_Badawcze\env\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 2:  39%|███▉      | 68/175 [03:39<05:44,  0.31it/s, v_num=7, train_loss_step=0.280, train_loss_epoch=0.350] 

In [None]:
def extract_embs(mod, dataset, batch_size=256, num_workers=None, device=None):
    """Embed every image of ``dataset`` with the frozen encoder of ``mod``.

    ``mod.backbone.encode`` returns the whole token sequence as one tensor
    (``MAE.forward_decoder`` uses it that way), so each image is represented by
    the token at index 0 — the non-patch position that ``MAE.training_step``
    skips with its ``idx_mask - 1`` shift.

    Args:
        mod: module exposing ``backbone.encode(images=...)``, ``device`` and
            ``training``/``eval()``/``train()`` (e.g. ``MAE``).
        dataset: map-style dataset yielding ``(image_tensor, label)`` pairs.
        batch_size: images per forward pass.
        num_workers: DataLoader workers; defaults to half the CPU cores (min 1).
        device: device to run the encoder on; defaults to ``mod.device``. If
            given, ``mod`` is moved there.

    Returns:
        ``(embeddings, labels)`` as NumPy arrays of shape ``(N, D)`` and ``(N,)``.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    if num_workers is None:
        # os.cpu_count() may return None; always keep at least one worker.
        num_workers = max(1, (os.cpu_count() or 2) // 2)
    if device is not None:
        mod = mod.to(device)
        run_device = torch.device(device)
    else:
        run_device = mod.device

    # DataLoader rejects prefetch_factor when num_workers == 0.
    worker_kwargs = {"prefetch_factor": 4} if num_workers > 0 else {}
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers,
        pin_memory=run_device.type == "cuda",
        **worker_kwargs,
    )

    embs, lbls = [], []
    was_training = mod.training
    mod.eval()
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.to(run_device, non_blocking=True)
                tokens = mod.backbone.encode(images=imgs)
                embs.append(tokens[:, 0].float().cpu())
                lbls.append(labels)
    finally:
        # Leave the module in the mode the caller handed it over in.
        mod.train(was_training)

    if not embs:
        raise ValueError("extract_embs: dataset is empty")
    return torch.cat(embs).numpy(), torch.cat(lbls).numpy()

# Frozen-feature probes: both classifiers are fit on the held-out validation
# split (`val_set`) and scored on the official test split (`test_set`).
train_embs, train_lbls = extract_embs(model, val_set)
test_embs,  test_lbls  = extract_embs(model, test_set)

lin_clf = LogisticRegression(max_iter=500)
lin_clf.fit(train_embs, train_lbls)
knn_clf = KNeighborsClassifier(n_neighbors=20)
knn_clf.fit(train_embs, train_lbls)

print(f"Linear probe accuracy: {lin_clf.score(test_embs, test_lbls):.4f}")
print(f"k-NN            acc.: {knn_clf.score(test_embs, test_lbls):.4f}")

def plot_2d(emb, lbl, title, max_legend_classes=20):
    """Scatter-plot a 2-D projection of embeddings, coloured by class label.

    Args:
        emb: array of shape ``(N, 2)``; column 0 is plotted on x, column 1 on y.
        lbl: length-``N`` array of class labels, used as the (categorical) hue.
        title: figure title.
        max_legend_classes: the legend is drawn only when there are at most this
            many distinct labels; with e.g. 100 classes it would swamp the plot.
    """
    n_classes = len(np.unique(lbl))
    # A 10-colour palette cannot give more than 10 classes distinct colours,
    # so grow the qualitative palette with the number of classes.
    if n_classes <= 10:
        palette = "tab10"
    elif n_classes <= 20:
        palette = "tab20"
    else:
        palette = sns.color_palette("husl", n_classes)
    show_legend = n_classes <= max_legend_classes

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.scatterplot(
        x=emb[:, 0], y=emb[:, 1], hue=lbl, palette=palette,
        s=10, linewidth=0, legend="auto" if show_legend else False, ax=ax,
    )
    ax.set(title=title, xlabel="component 1", ylabel="component 2")
    if show_legend:
        # Anchor the legend's upper-left corner at the axes' upper-right corner,
        # i.e. outside the plotting area.
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=2, fontsize="x-small")
    fig.tight_layout()
    plt.show()

def _project_and_plot(reducer, title):
    """Fit ``reducer`` on the test embeddings, plot the 2-D result, and return it."""
    coords = reducer.fit_transform(test_embs)
    plot_2d(coords, test_lbls, title)
    return coords

pca_emb  = _project_and_plot(PCA(n_components=2), "PCA of MAE embeddings")
tsne_emb = _project_and_plot(TSNE(n_components=2, random_state=42), "t-SNE of MAE embeddings")
umap_emb = _project_and_plot(umap.UMAP(random_state=42), "UMAP of MAE embeddings")