# Evaluation on Generated Images

Purpose: Quantitatively validate the semantic consistency of VAE outputs by classifying decoded samples with a pretrained ResNet-18.

Includes:
- Accuracy on real images
- Accuracy on generated images


In [1]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import models
import numpy as np

# Locate the project root: the nearest directory, starting from the current
# working directory and walking upward, that contains a "src" folder.
start_dir = Path().resolve()
current = next(
    (
        candidate
        for candidate in (start_dir, *start_dir.parents)
        if (candidate / "src").exists()
    ),
    None,
)
if current is None:
    # A plain `current = current.parent` loop never terminates in this case,
    # because the filesystem root is its own parent; fail loudly instead.
    raise FileNotFoundError(
        f"No 'src' directory found in {start_dir} or any of its parents"
    )

# Make project-local packages importable. Skip the append when the entry is
# already present so re-running this cell does not pile up duplicates.
if str(current) not in sys.path:
    sys.path.append(str(current))
print("Project root:", current)


  warn(


Project root: /workspace


In [2]:
from src.models.vae import ConvVAE
from src.datasets.grayscale_datasets import get_grayscale_loader


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Latent vector size and the number of images to sample.
latent_dim, num_samples = 32, 1000


Using device: cpu


In [13]:
# Dataset key; it is interpolated into the checkpoint filenames loaded below
# (resnet18_<DATASET>.pt and vae_<DATASET>_sharp_64.pt).
DATASET = "fashion"     # edit this to evaluate checkpoints for another dataset
# Number of outputs given to the classifier's final linear layer.
NUM_CLASSES = 10


In [14]:
import torch.nn as nn
from torchvision import models

# Start from an untrained ResNet-18 and adapt both ends before loading the
# project checkpoint: a 1-input-channel stem and a NUM_CLASSES-way head.
cnn = models.resnet18(weights=None)
cnn.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
cnn.fc = nn.Linear(cnn.fc.in_features, NUM_CLASSES)

# Checkpoint location is relative to the project root found earlier.
cnn_ckpt = current.joinpath("checkpoints", "grayscale", f"resnet18_{DATASET}.pt")
cnn.load_state_dict(torch.load(cnn_ckpt, map_location=device))

# Move to the target device and switch to inference mode in one step.
cnn = cnn.to(device).eval()

print("Loaded CNN:", cnn_ckpt.name)


Loaded CNN: resnet18_fashion.pt


In [15]:
from src.models.vae import ConvVAE

# Must match the latent size the checkpoint was trained with.
latent_dim = 32

# Checkpoint location is relative to the project root found earlier.
vae_ckpt = current.joinpath("checkpoints", "grayscale", f"vae_{DATASET}_sharp_64.pt")

# Build the model on the target device, restore its weights, then put it in
# inference mode.
vae = ConvVAE(latent_dim=latent_dim).to(device)
vae.load_state_dict(torch.load(vae_ckpt, map_location=device))
vae.eval()

print("Loaded SHARP VAE:", vae_ckpt.name)


Loaded SHARP VAE: vae_fashion_sharp_64.pt


In [16]:
def decode_from_latent(model, z, feature_shape=(128, 7, 7)):
    """Decode latent vectors into images using only the model's decoder.

    Args:
        model: Object exposing ``decoder.fc`` and ``decoder.deconv`` callables.
        z: Batch of latent vectors; its first dimension is the batch size.
        feature_shape: ``(channels, height, width)`` that each row of the
            ``decoder.fc`` output is reshaped to before it is passed to
            ``decoder.deconv``. Defaults to ``(128, 7, 7)``, the shape that
            was previously hard-coded here.

    Returns:
        The output of ``model.decoder.deconv``, computed with gradient
        tracking disabled.
    """
    with torch.no_grad():
        hidden = model.decoder.fc(z)
        # Keep the batch dimension explicit (rather than -1) so a mismatch
        # between the fc output size and feature_shape raises immediately.
        hidden = hidden.view(z.size(0), *feature_shape)
        return model.decoder.deconv(hidden)


In [17]:
num_samples = 1000
# Fixed seed so the sampled latents, and therefore the class counts reported
# further down, are reproducible after Restart & Run All.
SEED = 0

# Draw from a dedicated CPU generator and then move to `device`: this leaves
# the global RNG state untouched and gives the same latents on CPU and GPU.
latent_rng = torch.Generator().manual_seed(SEED)

with torch.no_grad():
    z = torch.randn(num_samples, latent_dim, generator=latent_rng).to(device)
    images = decode_from_latent(vae, z)
    # Maps [-1, 1] onto [0, 1]. NOTE(review): assumes the decoder output lies
    # in [-1, 1] and that the classifier expects [0, 1] inputs — confirm.
    images = (images + 1) / 2

print("Generated images:", images.shape)


Generated images: torch.Size([1000, 1, 28, 28])


In [18]:
import torch.nn.functional as F

# Resize the decoded batch to 64x64 with bilinear interpolation before it is
# fed to the classifier.
# NOTE(review): 64x64 is presumably the classifier's training resolution — confirm.
images_64 = F.interpolate(images, size=(64, 64), mode="bilinear", align_corners=False)

print("Upscaled images:", images_64.shape)


Upscaled images: torch.Size([1000, 1, 64, 64])


In [19]:
# Forward pass only; no gradients are needed for evaluation.
with torch.no_grad():
    logits = cnn(images_64)

# Most likely class per image, as a NumPy array on the host for the
# histogram computed in the next cell.
preds = logits.argmax(dim=1).cpu().numpy()


In [20]:
import numpy as np

# Count how many generated images were assigned to each predicted class.
# Only classes that occur at least once appear in `unique`.
unique, counts = np.unique(preds, return_counts=True)

print(f"\nPrediction distribution on SHARP VAE ({DATASET.upper()}):")
for class_id, n_images in zip(unique, counts):
    print(f"Class {class_id}: {n_images}")



Prediction distribution on SHARP VAE (FASHION):
Class 3: 988
Class 6: 12
