### Contributors
AmirMohammad Bandari (401110278) & Pouria Mahmoudkhan (401110289)

### Including Libraries

In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from collections import Counter

# reproducibility: seed numpy, the torch CPU RNG and every CUDA device RNG
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# force deterministic cuDNN kernels and disable autotuning: slower, but repeatable run to run
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# models and batches below are moved to this device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Phase 1: Preparing Data & EDA

Download the data into the `./data` folder and split it into `train`, `validation`, and `test` sets (the validation set is a 10% split of the official training set):

In [None]:
# ToTensor only: images stay in [0, 1] (no mean/std normalization), which is what the
# Sigmoid-output decoders and the BCE reconstruction loss below operate on
transform = transforms.Compose([
    transforms.ToTensor()
])

data_root = './data'
train_full = datasets.FashionMNIST(root=data_root, train=True, download=True, transform=transform)
test_ds = datasets.FashionMNIST(root=data_root, train=False, download=True, transform=transform)

# 90/10 train/validation split of the official training set; the dedicated seeded
# generator makes the split reproducible independently of the global torch RNG state
n_train = int(0.9 * len(train_full))
n_val = len(train_full) - n_train
train_ds, val_ds = random_split(train_full, [n_train, n_val], generator=torch.Generator().manual_seed(seed))

batch_size = 128
# only the training loader shuffles; validation/test batches come in a fixed order
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=torch.cuda.is_available())

# sanity output: split sizes and the shape of a single image tensor
len(train_ds), len(val_ds), len(test_ds), train_ds[0][0].shape

Sanity check: show 25 images from one training batch together with their labels.

In [None]:
# class names, index-aligned with the integer targets
classes = train_full.classes
# first 25 images of one shuffled training batch
x, y = next(iter(train_loader))
x = x[:25]
y = y[:25]

plt.figure(figsize=(7,7))
for i in range(25):
    plt.subplot(5,5,i+1)
    # drop the channel dim so imshow receives a 2-D array
    plt.imshow(x[i].squeeze(0), cmap='gray')
    plt.title(classes[int(y[i])], fontsize=8)
    plt.axis('off')
plt.tight_layout()
plt.show()

EDA: class distribution in the training split

In [None]:
# read labels straight from the targets tensor through the Subset indices,
# so no image has to be loaded or transformed
train_labels = [int(train_full.targets[i]) for i in train_ds.indices]
cnt = Counter(train_labels)
xs = np.arange(10)
# Counter returns 0 for a missing class, so every bar is defined
vals = np.array([cnt[i] for i in xs])

plt.figure(figsize=(7,3))
plt.bar(xs, vals)
plt.xticks(xs, classes, rotation=45, ha='right')
plt.ylabel('count')
plt.title('Train class distribution')
plt.tight_layout()
plt.show()

In [None]:
# pixel intensity histogram (sampled for speed)
def sample_pixels(loader, max_batches=80):
    """Gather every pixel value from the first `max_batches` batches of `loader`.

    Args:
        loader: iterable yielding `(images, labels)` pairs of torch tensors; labels are ignored.
        max_batches: maximum number of batches to read (at least one batch is always read).

    Returns:
        A flat 1-D numpy array holding all pixel values of the batches that were read.
    """
    chunks = []
    for batch_no, (images, _) in enumerate(loader, start=1):
        chunks.append(images.view(-1).cpu().numpy())
        if batch_no >= max_batches:
            break
    return np.concatenate(chunks, axis=0)

# 60 training batches are enough to estimate the intensity distribution
px = sample_pixels(train_loader, max_batches=60)
plt.figure(figsize=(6,3))
plt.hist(px, bins=50)
plt.title('Pixel intensity histogram (train, sampled)')
plt.xlabel('intensity')
plt.ylabel('freq')
plt.tight_layout()
plt.show()

In [None]:
# image-level histogram: mean pixel intensity per image (sampled)
# number of training batches to sample; the loop stops after reading exactly this many,
# the same convention sample_pixels uses for max_batches
n_mean_batches = 80
means = []
for i, (x, _) in enumerate(train_loader):
    # one scalar per image: the average of its flattened pixels
    means.append(x.view(x.size(0), -1).mean(dim=1).cpu().numpy())
    # i + 1 is the number of batches read so far
    if i + 1 >= n_mean_batches:
        break
means = np.concatenate(means, axis=0)

plt.figure(figsize=(6,3))
plt.hist(means, bins=50)
plt.title('Per-image mean intensity (train, sampled)')
plt.xlabel('mean intensity')
plt.ylabel('count')
plt.tight_layout()
plt.show()

## Helpers
Here we define some helper functions:
- `to_img_grid` takes a batch of images as a tensor and displays it as a grid; it is used for the visual demonstrations.
- `evaluate_vae` encapsulates the evaluation process (per-image reconstruction, KL, and total loss); all our unconditional models share it.
- `train_vae` encapsulates the training loop (with an optional KL warm-up); all our unconditional models share it.

In [None]:
def to_img_grid(x, nrow, ncol, figsize=(10,4), title=None):
    """Show the leading images of a batch as an `nrow` x `ncol` grid of grayscale subplots.

    Args:
        x: image tensor indexed along dim 0. Each item is squeezed on its first dim before
           plotting, so single-channel (1, H, W) images are expected.
        nrow, ncol: grid dimensions.
        figsize: matplotlib figure size.
        title: optional figure-level title.

    If `x` holds fewer than `nrow * ncol` images, only the available ones are drawn and the
    remaining cells stay empty. The previous version raised an IndexError in that case.
    """
    x = x.detach().cpu()
    n_show = min(nrow * ncol, x.size(0))
    plt.figure(figsize=figsize)
    for i in range(n_show):
        plt.subplot(nrow, ncol, i + 1)
        plt.imshow(x[i].squeeze(0), cmap='gray')
        plt.axis('off')
    if title is not None:
        plt.suptitle(title)
    plt.tight_layout()
    plt.show()

@torch.no_grad()
def evaluate_vae(model, loader, beta=1.0):
    """Average the VAE loss terms over every image in `loader` (eval mode, no gradients).

    Per batch, the reconstruction term is the summed binary cross-entropy between `x_hat`
    and the input. The KL term is the closed form KL(N(mu, exp(logvar)) || N(0, I)),
    summed over the batch.

    Args:
        model: module whose forward returns a dict with "x_hat", "mu" and "logvar".
        loader: iterable of (images, labels); labels are ignored.
        beta: weight of the KL term in the total.

    Returns:
        dict with "recon_per_img", "kld_per_img" and "total_per_img": the summed terms
        divided by the number of images seen.
    """
    model.eval()
    totals = {"rec": 0.0, "kld": 0.0, "total": 0.0}
    seen = 0
    for images, _ in loader:
        images = images.to(device)
        result = model(images)
        mu, logvar = result["mu"], result["logvar"]
        recon_term = F.binary_cross_entropy(result["x_hat"], images, reduction='sum')
        kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        totals["rec"] += recon_term.item()
        totals["kld"] += kl_term.item()
        totals["total"] += (recon_term + beta * kl_term).item()
        seen += images.size(0)
    return {
        "recon_per_img": totals["rec"] / seen,
        "kld_per_img": totals["kld"] / seen,
        "total_per_img": totals["total"] / seen,
    }

def train_vae(model, train_loader, val_loader, epochs=10, lr=2e-3, beta=1.0, warmup_epochs=0):
    """Train a VAE with Adam on summed BCE + beta * KL, validating after every epoch.

    Args:
        model: module whose forward returns a dict with "x_hat", "mu" and "logvar".
        train_loader, val_loader: iterables of (images, labels); labels are ignored.
        epochs: number of passes over `train_loader`.
        lr: Adam learning rate.
        beta: target KL weight. Validation always uses this full value.
        warmup_epochs: if > 0, the training KL weight is beta * min(1, epoch / warmup_epochs),
            with epochs counted from 1.

    Returns:
        dict of per-epoch lists "train_total", "val_total", "train_rec", "train_kld",
        "val_rec", "val_kld" (per-image averages). During warm-up "train_total" uses the
        reduced weight while "val_total" uses `beta`, so the two are not directly comparable.
    """
    model.to(device)
    opt = Adam(model.parameters(), lr=lr)
    hist = {key: [] for key in ("train_total", "val_total", "train_rec", "train_kld", "val_rec", "val_kld")}
    for ep in range(1, epochs + 1):
        model.train()
        beta_ep = beta * min(1.0, ep / warmup_epochs) if warmup_epochs > 0 else beta
        rec_total = kld_total = loss_total = 0.0
        n_imgs = 0
        for images, _ in train_loader:
            images = images.to(device)
            result = model(images)
            mu, logvar = result["mu"], result["logvar"]
            rec = F.binary_cross_entropy(result["x_hat"], images, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = rec + beta_ep * kld
            opt.zero_grad()
            loss.backward()
            opt.step()
            rec_total += rec.item()
            kld_total += kld.item()
            loss_total += loss.item()
            n_imgs += images.size(0)

        tr = {"recon_per_img": rec_total / n_imgs, "kld_per_img": kld_total / n_imgs, "total_per_img": loss_total / n_imgs}
        va = evaluate_vae(model, val_loader, beta=beta)
        for split, stats in (("train", tr), ("val", va)):
            hist[f"{split}_total"].append(stats["total_per_img"])
            hist[f"{split}_rec"].append(stats["recon_per_img"])
            hist[f"{split}_kld"].append(stats["kld_per_img"])
        print(f"ep {ep:02d} | beta={beta_ep:.3f} | train total {tr['total_per_img']:.2f} rec {tr['recon_per_img']:.2f} kld {tr['kld_per_img']:.2f} | val total {va['total_per_img']:.2f}")
    return hist

# Phase 2: VAE implementation & quality improvement

### Define VAE Model

We define our base model: a VAE whose encoder and decoder are fully connected Multi-Layer Perceptrons (MLPs).
We set the latent dimension to 20.

In [None]:
class MLPVAE(nn.Module):
    """Fully connected VAE for 28x28 single-channel images.

    Encoder: flatten -> Linear(784, h_dim) -> ReLU, then two linear heads giving the
    posterior mean and log-variance. Decoder: Linear(z_dim, h_dim) -> ReLU ->
    Linear(h_dim, 784) -> Sigmoid, reshaped to (N, 1, 28, 28).
    """

    def __init__(self, z_dim=20, h_dim=400):
        super().__init__()
        self.z_dim = z_dim
        n_pixels = 28 * 28
        # modules are created in a fixed order so parameter init draws from the RNG identically
        self.enc = nn.Sequential(nn.Flatten(), nn.Linear(n_pixels, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, n_pixels), nn.Sigmoid(),
        )

    def reparam(self, mu, logvar):
        """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode `x`, sample z, and decode. Returns a dict with x_hat, mu, logvar and z."""
        hidden = self.enc(x)
        mu, logvar = self.mu(hidden), self.logvar(hidden)
        z = self.reparam(mu, logvar)
        return {"x_hat": self.dec(z).view(-1, 1, 28, 28), "mu": mu, "logvar": logvar, "z": z}

@torch.no_grad()
def recon_samples(model, loader, n=20):
    """Reconstruct the first `n` images of the first batch of `loader`.

    Returns:
        (inputs, reconstructions), both on `device`.
    """
    model.eval()
    first_batch_images = next(iter(loader))[0]
    inputs = first_batch_images[:n].to(device)
    return inputs, model(inputs)["x_hat"]

@torch.no_grad()
def sample_prior(model, n=50):
    """Decode `n` latents drawn from N(0, I) through an MLPVAE-style `model.dec`.

    Returns the decoded images reshaped to (n, 1, 28, 28).
    """
    model.eval()
    # noise comes from the CPU RNG and is then moved to `device`
    latents = torch.randn(n, model.z_dim).to(device)
    return model.dec(latents).view(-1, 1, 28, 28)

In [None]:
# baseline: fully connected VAE (z_dim=20), 10 epochs, constant beta=1 (no KL warm-up)
vae_base = MLPVAE(z_dim=20, h_dim=400).to(device)
hist_base = train_vae(vae_base, train_loader, val_loader, epochs=10, lr=2e-3, beta=1.0)

# per-image test-set reconstruction / KL / total loss
base_test = evaluate_vae(vae_base, test_loader, beta=1.0)
base_test

In [None]:
# reconstructions (20)
# test_loader is not shuffled, so these are always the first 20 test images
x_in, x_out = recon_samples(vae_base, test_loader, n=20)
to_img_grid(x_in, 2, 10, figsize=(12,3), title="Baseline: inputs")
to_img_grid(x_out, 2, 10, figsize=(12,3), title="Baseline: reconstructions")

In [None]:
# sampling (50)
# decode 50 draws from the N(0, I) prior
samp = sample_prior(vae_base, n=50)
to_img_grid(samp, 5, 10, figsize=(12,6), title="Baseline: prior samples")

Now we define our improved architecture.
This new version has a convolution based structure.
- First, a 32-channel 4x4 convolution with stride 2 and padding 1 (28x28 → 14x14).
- Then, a 64-channel 4x4 convolution with stride 2 and padding 1 (14x14 → 7x7).
- Finally, a 128-channel 3x3 convolution with stride 1 and padding 1 (7x7 → 7x7).

This leaves 128 channels of size 7x7, which are flattened and fed into a fully connected layer.

The decoder is built to be the inverse of this structure.

Because convolutions build in a spatial inductive bias suited to images (locality and weight sharing), this model should give better results than the MLP.

In [None]:
class ConvVAE(nn.Module):
    """Convolutional VAE for 1x28x28 images.

    Encoder: three convolutions (spatial 28 -> 14 -> 7 -> 7, channels 32/64/128), a
    256-unit hidden layer, then linear heads for mu and logvar.
    Decoder: two linear layers back to a 128x7x7 feature map, two transposed
    convolutions (7 -> 14 -> 28), and a final 3x3 conv + Sigmoid down to one channel.
    """

    def __init__(self, z_dim=32):
        super().__init__()
        self.z_dim = z_dim
        feat = 128 * 7 * 7  # flattened size of the encoder's last feature map

        # submodules are created in the same order as before (same init RNG draws, same state_dict keys)
        self.enc_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),    # -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),   # -> 64x7x7
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),  # -> 128x7x7
        )
        self.enc_fc = nn.Linear(feat, 256)
        self.mu = nn.Linear(256, z_dim)
        self.logvar = nn.Linear(256, z_dim)

        self.dec_fc = nn.Sequential(
            nn.Linear(z_dim, 256), nn.ReLU(),
            nn.Linear(256, feat), nn.ReLU(),
        )
        self.dec_deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),  # -> 64x14x14
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),   # -> 32x28x28
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid(),          # -> 1x28x28
        )

    def reparam(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I), keeping the graph differentiable."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Encode, sample and decode `x`. Returns a dict with x_hat, mu, logvar and z."""
        features = self.enc_conv(x).view(x.size(0), -1)
        hidden = F.relu(self.enc_fc(features))
        mu, logvar = self.mu(hidden), self.logvar(hidden)
        z = self.reparam(mu, logvar)
        decoded = self.dec_deconv(self.dec_fc(z).view(-1, 128, 7, 7))
        return {"x_hat": decoded, "mu": mu, "logvar": logvar, "z": z}

@torch.no_grad()
def sample_prior_conv(model, n=50):
    """Decode `n` latents drawn from N(0, I) through a ConvVAE's `dec_fc` + `dec_deconv` decoder."""
    model.eval()
    # noise comes from the CPU RNG and is then moved to `device`
    latents = torch.randn(n, model.z_dim).to(device)
    feature_map = model.dec_fc(latents).view(-1, 128, 7, 7)
    return model.dec_deconv(feature_map)

In [None]:
# improved model: ConvVAE with z_dim=32, 12 epochs, KL weight ramped linearly to 1 over the first 6
vae_imp = ConvVAE(z_dim=32).to(device)
hist_imp = train_vae(vae_imp, train_loader, val_loader, epochs=12, lr=2e-3, beta=1.0, warmup_epochs=6)

# per-image test-set losses (full beta=1), directly comparable to base_test
imp_test = evaluate_vae(vae_imp, test_loader, beta=1.0)
imp_test

In [None]:
# before vs after (reconstruction)
# both models reconstruct the same first 20 test images (test_loader is not shuffled)
x_in, x_out_base = recon_samples(vae_base, test_loader, n=20)
_, x_out_imp = recon_samples(vae_imp, test_loader, n=20)

# NOTE(review): x_in (the originals) is computed but not plotted here; showing it too would make the comparison easier to judge
to_img_grid(x_out_base, 2, 10, figsize=(12,3), title="Baseline reconstructions")
to_img_grid(x_out_imp, 2, 10, figsize=(12,3), title="Improved reconstructions")

In [None]:
# before vs after (sampling)
# each model needs its own decoder path: MLP decoder vs. fc + transposed-conv decoder
samp_base = sample_prior(vae_base, n=50)
samp_imp = sample_prior_conv(vae_imp, n=50)
to_img_grid(samp_base, 5, 10, figsize=(12,6), title="Baseline samples")
to_img_grid(samp_imp, 5, 10, figsize=(12,6), title="Improved samples")

We now examine our test metrics.

In [None]:
def print_table(rows, headers):
    """Print `rows` as a left-aligned plain-text table under `headers`.

    Each column is as wide as its longest cell, header included. Cells are rendered
    with `str`. An empty `rows` prints only the header and the separator line; the
    previous version crashed there because `max()` got an empty sequence.

    Args:
        rows: sequence of row sequences, each with at least `len(headers)` cells.
        headers: column titles (strings).
    """
    widths = [
        max(len(h), max((len(str(r[i])) for r in rows), default=0))
        for i, h in enumerate(headers)
    ]
    print(" | ".join(h.ljust(w) for h, w in zip(headers, widths)))
    print("-+-".join("-" * w for w in widths))
    for r in rows:
        print(" | ".join(str(r[i]).ljust(w) for i, w in enumerate(widths)))

# test-set per-image losses for both models (both evaluated with beta=1, so totals are comparable)
rows = [
    ["MLP VAE (base)", f"{base_test['recon_per_img']:.2f}", f"{base_test['kld_per_img']:.2f}", f"{base_test['total_per_img']:.2f}"],
    ["Conv VAE (imp)", f"{imp_test['recon_per_img']:.2f}", f"{imp_test['kld_per_img']:.2f}", f"{imp_test['total_per_img']:.2f}"],
]
print_table(rows, ["model", "recon/img", "kld/img", "total/img"])

# Phase 3: Controlling latent space

Here we vary the regularization weight β (the weight of the KL term) to find a good value for this hyperparameter.
First we define a function for latent traversal.

In [None]:
@torch.no_grad()
def latent_traversal(model, x, dims, steps=7, span=3.0):
    """Decode images while sweeping one latent coordinate at a time.

    The first image of `x` is passed through the model, and its posterior mean is
    used as the anchor latent. For every index `d` in `dims`, coordinate `d` is set
    to each of `steps` evenly spaced values in [-span, span]. These are absolute
    values, not offsets from the mean. The resulting latents are then decoded.

    Returns:
        list with one decoded image batch (of `steps` images) per entry of `dims`.
    """
    model.eval()
    anchor = model(x.to(device))["mu"][0].clone()
    sweep = torch.linspace(-span, span, steps, device=device)

    def with_coord(d, value):
        # copy of the anchor with coordinate d overwritten
        z = anchor.clone()
        z[d] = value
        return z

    def decode(latents):
        # ConvVAE uses an fc + transposed-conv decoder; MLP models use `dec` directly
        if isinstance(model, ConvVAE):
            return model.dec_deconv(model.dec_fc(latents).view(-1, 128, 7, 7))
        return model.dec(latents).view(-1, 1, 28, 28)

    grids = []
    for d in dims:
        latents = torch.stack([with_coord(d, value) for value in sweep])
        grids.append(decode(latents))
    return grids

Now we train three models, one for each value of β.

In [None]:
betas = [0.5, 1.0, 4.0]
beta_models = {}
beta_tests = {}

# one fresh ConvVAE per beta: 8 epochs, KL weight ramped to that beta over the first 4
for b in betas:
    m = ConvVAE(z_dim=32).to(device)
    _ = train_vae(m, train_loader, val_loader, epochs=8, lr=2e-3, beta=b, warmup_epochs=4)
    beta_models[b] = m
    # each model is scored with its own beta, so "total/img" differs in definition
    # across rows; compare recon/img and kld/img between settings instead
    beta_tests[b] = evaluate_vae(m, test_loader, beta=b)

rows = [[f"beta={b}", f"{beta_tests[b]['recon_per_img']:.2f}", f"{beta_tests[b]['kld_per_img']:.2f}", f"{beta_tests[b]['total_per_img']:.2f}"] for b in betas]
print_table(rows, ["setting", "recon/img", "kld/img", "total/img"])

In [None]:
# anchor the traversal on the first test image
x0, _ = next(iter(test_loader))
x0 = x0[:1]
# hand-picked subset of the 32 latent dimensions (all indices must stay < z_dim)
dims = [0, 3, 7, 12, 20]

for b in betas:
    grids = latent_traversal(beta_models[b], x0, dims=dims, steps=7, span=3.0)
    merged = torch.cat(grids, dim=0)  # (5*7, 1, 28, 28)
    # one row per traversed dim, one column per sweep value
    to_img_grid(merged, 5, 7, figsize=(10,7), title=f"beta={b} | latent traversal (dims={dims}, span=-3..+3)")

# Phase 4: Image generation with labels

Here we define our Conditional VAE (CVAE).

`one_hot` turns a batch of labels into vectors of size `num_classes` that are zero everywhere except at index `y`, where they are one.

Our model is an MLP VAE in which the one-hot label is concatenated both to the flattened image at the encoder input and to the latent vector at the decoder input.

In [None]:
def one_hot(y, num_classes=10):
    """Encode integer class labels as float one-hot rows.

    Args:
        y: tensor of class indices; cast to long and read row by row.
        num_classes: width of each encoding.

    Returns:
        float tensor of shape (y.size(0), num_classes) on `y`'s device, with a single 1.0
        per row at the label's index.
    """
    labels = y.long()
    encoded = torch.zeros(labels.size(0), num_classes, device=labels.device)
    # write 1.0 into column labels[i] of row i
    return encoded.scatter_(1, labels.view(-1, 1), 1.0)

class MLPcVAE(nn.Module):
    """Conditional MLP VAE in which both encoder and decoder see a one-hot class label.

    Encoder input: the flattened 784-pixel image concatenated with the one-hot label,
    passed through two ReLU hidden layers of width `h_dim`, then linear heads for mu and
    logvar. Decoder input: the latent z concatenated with the same one-hot label, two ReLU
    hidden layers, then a Sigmoid output reshaped to (N, 1, 28, 28).
    """

    def __init__(self, z_dim=20, h_dim=512, num_classes=10):
        super().__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        n_pixels = 28 * 28

        # submodules are created in the same order as before (same init RNG draws, same state_dict keys)
        self.enc = nn.Sequential(
            nn.Linear(n_pixels + num_classes, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
        )
        self.mu = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)

        self.dec = nn.Sequential(
            nn.Linear(z_dim + num_classes, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, n_pixels), nn.Sigmoid(),
        )

    def reparam(self, mu, logvar):
        """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x, y):
        """Encode `(x, y)`, sample z, and decode `(z, y)`. Returns a dict with x_hat, mu, logvar and z."""
        flat = x.view(x.size(0), -1)
        label_code = one_hot(y, self.num_classes)
        hidden = self.enc(torch.cat([flat, label_code], dim=1))
        mu, logvar = self.mu(hidden), self.logvar(hidden)
        z = self.reparam(mu, logvar)
        x_hat = self.dec(torch.cat([z, label_code], dim=1)).view(-1, 1, 28, 28)
        return {"x_hat": x_hat, "mu": mu, "logvar": logvar, "z": z}

@torch.no_grad()
def evaluate_cvae(model, loader, beta=1.0):
    """Per-image BCE reconstruction, KL, and total (rec + beta * KL) of a conditional VAE.

    This is the same metric as `evaluate_vae`, except that the model is called as
    `model(x, y)` so the true labels condition the encoder and decoder. Runs in eval mode
    without gradients.

    Returns:
        dict with "recon_per_img", "kld_per_img" and "total_per_img".
    """
    model.eval()
    sums = [0.0, 0.0, 0.0]  # reconstruction, KL, total
    count = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        result = model(images, labels)
        mu, logvar = result["mu"], result["logvar"]
        rec = F.binary_cross_entropy(result["x_hat"], images, reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        for idx, term in enumerate((rec, kld, rec + beta * kld)):
            sums[idx] += term.item()
        count += images.size(0)
    rec_sum, kld_sum, total_sum = sums
    return {"recon_per_img": rec_sum/count, "kld_per_img": kld_sum/count, "total_per_img": total_sum/count}

def train_cvae(model, train_loader, val_loader, epochs=10, lr=2e-3, beta=1.0, warmup_epochs=0):
    """Train a conditional VAE with Adam on summed BCE + beta * KL, printing per-epoch stats.

    Args:
        model: module called as `model(x, y)`, returning a dict with "x_hat", "mu", "logvar".
        train_loader, val_loader: iterables of (images, labels).
        epochs: number of passes over `train_loader`.
        lr: Adam learning rate.
        beta: target KL weight. Validation always uses this full value.
        warmup_epochs: if > 0, the training KL weight is beta * min(1, epoch / warmup_epochs).

    Returns:
        the trained `model` (the same object that was passed in).
    """
    model.to(device)
    opt = Adam(model.parameters(), lr=lr)
    for ep in range(1, epochs + 1):
        model.train()
        if warmup_epochs > 0:
            beta_ep = beta * min(1.0, ep / warmup_epochs)
        else:
            beta_ep = beta
        rec_acc = kld_acc = loss_acc = 0.0
        count = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            result = model(images, labels)
            mu, logvar = result["mu"], result["logvar"]
            rec = F.binary_cross_entropy(result["x_hat"], images, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = rec + beta_ep * kld
            opt.zero_grad()
            loss.backward()
            opt.step()
            rec_acc += rec.item()
            kld_acc += kld.item()
            loss_acc += loss.item()
            count += images.size(0)
        va = evaluate_cvae(model, val_loader, beta=beta)
        print(f"ep {ep:02d} | beta={beta_ep:.3f} | train total {loss_acc/count:.2f} rec {rec_acc/count:.2f} kld {kld_acc/count:.2f} | val total {va['total_per_img']:.2f}")
    return model

In [None]:
# CVAE: z_dim=20, two 512-unit hidden layers per side, 12 epochs with a 6-epoch KL warm-up
cvae = MLPcVAE(z_dim=20, h_dim=512).to(device)
_ = train_cvae(cvae, train_loader, val_loader, epochs=12, lr=2e-3, beta=1.0, warmup_epochs=6)
# per-image test-set losses with the full beta=1
cvae_test = evaluate_cvae(cvae, test_loader, beta=1.0)
cvae_test

Here we generate 20 images per class: we sample random latent vectors from the prior and decode each one together with the class label.

In [None]:
@torch.no_grad()
def cvae_generate(model, labels, n_per_label=20):
    """Sample `n_per_label` images for each class in `labels` from a conditional VAE.

    Latents are drawn from N(0, I) and decoded together with the one-hot label. Labels
    are repeated in contiguous blocks, so the output is grouped by class in the order
    given by `labels`.

    Args:
        model: MLPcVAE-like module exposing `z_dim`, `dec` and (optionally) `num_classes`.
        labels: sequence of integer class ids.
        n_per_label: number of samples per class.

    Returns:
        (images, ys): images of shape (len(labels) * n_per_label, 1, 28, 28) and the
        matching label tensor, both on `device`.
    """
    model.eval()
    ys = torch.tensor(labels, device=device).repeat_interleave(n_per_label)
    z = torch.randn(len(labels) * n_per_label, model.z_dim, device=device)
    # use the model's own label width rather than a hard-coded 10, so the one-hot size
    # always matches the decoder input (z_dim + num_classes)
    yoh = one_hot(ys, getattr(model, "num_classes", 10))
    x_hat = model.dec(torch.cat([z, yoh], dim=1)).view(-1, 1, 28, 28)
    return x_hat, ys

# generate 20 per class (grid per class)
imgs, ys = cvae_generate(cvae, labels=list(range(10)), n_per_label=20)
plt.figure(figsize=(20,10))
idx = 0
# cvae_generate groups samples by label (repeat_interleave), so row r holds the 20 samples of class r
for r in range(10):
    for c in range(20):
        plt.subplot(10,20,idx+1)
        plt.imshow(imgs[idx].squeeze(0).cpu(), cmap='gray')
        plt.axis('off')
        idx += 1
plt.suptitle('CVAE: 20 samples per class (rows=classes 0..9)')
plt.tight_layout()
plt.show()

### Classifier for controllability and Features

In [None]:
from torchvision.models import resnet18

class FashionResNet18(nn.Module):
    """ResNet-18 adapted to 1x28x28 inputs, returning both logits and penultimate features.

    The stem is replaced by a single-channel 3x3 stride-1 convolution and the max-pool
    is removed (presumably to keep resolution on small 28x28 images). The original fc
    layer becomes an identity, so the backbone emits its 512-d pooled features, and a
    separate linear `head` maps those features to class logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = resnet18(weights=None)
        # same modification order and attribute names as the checkpoint expects ("backbone.*", "head.*")
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.head = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor):
        """Return `(logits, feats)` for a batch of images."""
        feats = self.backbone(x)
        return self.head(feats), feats

clf = FashionResNet18(num_classes=10).to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source
ckpt = torch.load("./classifier/fashion_resnet18_classifier.pt", map_location=device)
# accept either a {"model_state_dict": ...} wrapper or a bare state_dict
state = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
# normalization statistics used later to standardize classifier inputs.
# NOTE(review): this line requires "mean"/"std" keys, so the bare-state_dict branch above
# would likely fail here with a KeyError
mean, std = float(ckpt["mean"][0]), float(ckpt["std"][0])

clf.load_state_dict(state, strict=True)

# freeze the classifier: it is only used as a fixed evaluator
clf.eval()
for p in clf.parameters():
    p.requires_grad = False

In [None]:
@torch.no_grad()
def clf_acc_on_loader(clf, loader, normalize=True):
    """Top-1 accuracy of `clf` over every (image, label) pair in `loader`.

    Args:
        clf: classifier whose forward returns `(logits, features)`.
        loader: iterable of (images, labels) with raw `ToTensor()` images in [0, 1].
        normalize: if True (default), standardize the images with the checkpoint's
            module-level `mean`/`std` before classification. The generated-image
            evaluation applies the same preprocessing, so both accuracies are measured
            on the same input scale. Presumably this matches how the classifier was
            trained, since the checkpoint stores these statistics.

    Returns:
        Fraction of correctly classified samples.
    """
    clf.eval()
    corr, n = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if normalize:
            x = (x - mean) / std
        logits, _ = clf(x)
        corr += (logits.argmax(1) == y).sum().item()
        n += y.numel()
    return corr / n

# sanity check: test-set accuracy of the loaded classifier before using it as an evaluator
clf_acc_on_loader(clf, test_loader)

Controllability: for each class, we generate images with the CVAE and measure how often the given classifier predicts the requested class.

In [None]:

def norm_for_clf(x):
    """Standardize images with the classifier checkpoint's module-level `mean` and `std`."""
    centered = x - mean
    return centered / std

@torch.no_grad()
def cvae_accuracy_on_generated(model_cvae, clf, n_per_class=200):
    """Per-class controllability of a CVAE, judged by the frozen classifier.

    For each class k in 0..9, `n_per_class` images conditioned on k are generated and
    standardized with `norm_for_clf`. The score for k is the fraction that the
    classifier predicts as k.

    Returns:
        (per_class_accuracies, mean_accuracy): a list of 10 floats and their mean.
    """
    model_cvae.eval()
    clf.eval()

    def class_accuracy(k):
        imgs, ys = cvae_generate(model_cvae, labels=[k], n_per_label=n_per_class)
        logits, _ = clf(norm_for_clf(imgs.to(device)))
        return (logits.argmax(dim=1) == ys.to(device)).float().mean().item()

    accs = [class_accuracy(k) for k in range(10)]
    return accs, float(np.mean(accs))

# per-class fraction of generated images classified as the requested class, plus their mean
accs, acc_mean = cvae_accuracy_on_generated(cvae, clf, n_per_class=200)
accs, acc_mean

In [None]:
# visualization
plt.figure(figsize=(7,3))
plt.bar(np.arange(10), accs)
plt.xticks(np.arange(10), classes, rotation=45, ha='right')
plt.ylim(0, 1.0)
plt.ylabel('accuracy')
plt.title(f'Classifier accuracy on CVAE generated images (mean={acc_mean:.3f})')
plt.tight_layout()
plt.show()

### FID using classifier features

In [None]:
def cov_np(x):
    """Unbiased sample covariance (ddof=1) of the rows of a 2-D array `x` (n_samples, n_features)."""
    x = x - x.mean(axis=0, keepdims=True)
    return (x.T @ x) / (x.shape[0] - 1)

def sqrtm_psd(A, eps=1e-9):
    """Principal square root of a symmetric positive semi-definite matrix.

    `A` is symmetrized first to absorb round-off, and eigenvalues below `eps` are clipped
    so tiny negative values from numerical error do not produce NaNs. Only valid for
    symmetric input: symmetrizing a non-symmetric matrix changes its square root.
    """
    A = 0.5 * (A + A.T)
    w, V = np.linalg.eigh(A)
    w = np.clip(w, eps, None)
    return (V * np.sqrt(w)) @ V.T

def fid_from_features(f1, f2):
    """Fréchet distance between Gaussians fitted to two feature sets.

    FID = ||m1 - m2||^2 + Tr(C1 + C2 - 2 (C1 C2)^{1/2}).

    C1 @ C2 is generally *not* symmetric, so it must not be passed to `sqrtm_psd`
    directly, because the symmetrization would change its square root. Instead we use
    S = C1^{1/2} C2 C1^{1/2}, which is symmetric PSD and has the same eigenvalues as
    C1 C2, hence Tr((C1 C2)^{1/2}) = Tr(S^{1/2}).

    Args:
        f1, f2: arrays of shape (n_samples, n_features); the sample counts may differ.

    Returns:
        The distance as a Python float.
    """
    m1, m2 = f1.mean(axis=0), f2.mean(axis=0)
    C1, C2 = cov_np(f1), cov_np(f2)
    diff = m1 - m2
    sqrt_C1 = sqrtm_psd(C1)
    covmean = sqrtm_psd(sqrt_C1 @ C2 @ sqrt_C1)
    return float(diff @ diff + np.trace(C1 + C2 - 2 * covmean))

@torch.no_grad()
def extract_features(loader, clf, max_items=None, normalize=True):
    """Collect the classifier's penultimate features for images from `loader`.

    Args:
        loader: iterable of (images, labels) with raw `ToTensor()` images; labels are ignored.
        clf: classifier whose forward returns `(logits, features)`.
        max_items: if given, stop once this many images were processed and truncate the
            result to exactly `max_items` rows.
        normalize: if True (default), standardize with the checkpoint's module-level
            `mean`/`std` before the forward pass. This is the same preprocessing used for
            the generated-image accuracy, so the features come from the input scale the
            classifier is evaluated on.

    Returns:
        numpy array of shape (n_images, feature_dim).
    """
    clf.eval()
    feats = []
    seen = 0
    for x, _ in loader:
        x = x.to(device)
        if normalize:
            x = (x - mean) / std
        _, f = clf(x)
        feats.append(f.cpu().numpy())
        seen += x.size(0)
        if max_items is not None and seen >= max_items:
            break
    feats = np.concatenate(feats, axis=0)
    if max_items is not None:
        feats = feats[:max_items]
    return feats

@torch.no_grad()
def extract_features_from_images(imgs, clf, batch=256, normalize=True):
    """Classifier penultimate features for an in-memory image tensor, processed in chunks.

    Args:
        imgs: tensor of raw images in [0, 1], indexed along dim 0.
        clf: classifier whose forward returns `(logits, features)`.
        batch: number of images per forward pass.
        normalize: if True (default), apply the same `mean`/`std` standardization as
            `extract_features`, so real and generated features are comparable.

    Returns:
        numpy array of shape (imgs.size(0), feature_dim).
    """
    clf.eval()
    feats = []
    for i in range(0, imgs.size(0), batch):
        x = imgs[i:i+batch].to(device)
        if normalize:
            x = (x - mean) / std
        _, f = clf(x)
        feats.append(f.cpu().numpy())
    return np.concatenate(feats, axis=0)

In [None]:
# generate 10,000 images from improved VAE
@torch.no_grad()
def generate_uncond_10000(model, n=10000, batch=256):
    """Draw `n` unconditional prior samples in chunks of at most `batch`, returned on the CPU.

    Uses `sample_prior_conv` for ConvVAE models and `sample_prior` otherwise.
    """
    model.eval()
    sampler = sample_prior_conv if isinstance(model, ConvVAE) else sample_prior
    chunks = [sampler(model, n=min(batch, n - start)).cpu() for start in range(0, n, batch)]
    return torch.cat(chunks, dim=0)

# 10,000 unconditional samples from the improved ConvVAE, kept on the CPU
gen_10k = generate_uncond_10000(vae_imp, n=10000, batch=256)
gen_10k.shape

In [None]:
# FID-style score computed in this classifier's feature space (not Inception features), so it is
# only comparable with other scores computed using the same classifier.
# The real side uses all 10,000 test images; the fake side uses the 10,000 generated ones.
f_real = extract_features(test_loader, clf, max_items=10000)
f_fake = extract_features_from_images(gen_10k, clf, batch=256)
fid_proxy = fid_from_features(f_real, f_fake)
fid_proxy

# Phase 5: Report

> See `README.md` for a more detailed explanation of the process in this notebook.

In [None]:
# summary of test-set per-image losses; the FID proxy was only computed for the improved
# model, hence "-" in the other rows
final_rows = [
    ["MLP VAE (base)", f"{base_test['recon_per_img']:.2f}", f"{base_test['kld_per_img']:.2f}", f"{base_test['total_per_img']:.2f}", "-"],
    ["Conv VAE (imp)", f"{imp_test['recon_per_img']:.2f}", f"{imp_test['kld_per_img']:.2f}", f"{imp_test['total_per_img']:.2f}", f"{fid_proxy:.2f}"],
    ["CVAE (label)", f"{cvae_test['recon_per_img']:.2f}", f"{cvae_test['kld_per_img']:.2f}", f"{cvae_test['total_per_img']:.2f}", "-"],
]
print_table(final_rows, ["model", "recon/img", "kld/img", "total/img", "FID proxy"])