In [10]:
from torch import nn
from diffusers import DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchvision import datasets, transforms
CONFIG = {
    "image_size": 32,
    "batch_size_train": 32,
    "batch_size_eval": 32,
    "num_epochs": 5,
    "lr": 1e-4,
    "timesteps": 1000,
    "cfg_scale": 1,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}
transform = transforms.Compose([
    transforms.Resize((CONFIG["image_size"], CONFIG["image_size"])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]) 
])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size_train"], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size_eval"], shuffle=False)
noise_scheduler = DDPMScheduler(
    num_train_timesteps=CONFIG["timesteps"],
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear"
)
class ClassConditionedUNet(nn.Module):
    def __init__(self, num_classes=10, image_size=32):
        super().__init__()
        self.label_dropout_prob = 0.1
        self.null_class_id = num_classes
        self.unet = UNet2DModel(
            sample_size=image_size,
            in_channels=1,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(128, 256, 256),
            down_block_types=(
                "DownBlock2D", 
                "AttnDownBlock2D",
                "DownBlock2D"
            ),
            up_block_types=(
                "UpBlock2D", 
                "AttnUpBlock2D",
                "UpBlock2D"
            )
            ,
            class_embed_type="timestep",
            num_class_embeds=num_classes + 1,
        )
    
    def forward(self, x, t, class_labels=None):
        if class_labels is None:
            raise ValueError("class_labels must be provided.")
        
        if self.training:
            drop_mask = torch.rand(class_labels.shape[0], device=x.device) < self.label_dropout_prob
            class_labels = class_labels.clone()
            class_labels[drop_mask] = self.null_class_id
        
        return self.unet(x, t, class_labels=class_labels).sample


Evaluate

In [11]:

model = ClassConditionedUNet(num_classes=10, image_size=CONFIG["image_size"])
model.load_state_dict(torch.load('mnist_diffusion_model.pth'))
model=model.to(CONFIG["device"])

def to_rgb_grayscale(tensor):
    return tensor.repeat(1, 3, 1, 1)

@torch.no_grad()
def generate_samples(model, noise_scheduler, num_samples=1000, guidance_scale=7.5, 
                     num_classes=10, image_size=32, device="cuda"):
    model.eval()
    class_labels = torch.randint(0, num_classes, (num_samples,), device=device)
    generator = torch.Generator(device=device).manual_seed(42)
    latents = torch.randn(
        num_samples, 1, image_size, image_size,
        generator=generator, device=device
    )
    for t in tqdm(noise_scheduler.timesteps, desc="Sampling"):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
        uncond_latents = latent_model_input[:num_samples]
        cond_latents = latent_model_input[num_samples:]
        null_labels = torch.full_like(class_labels, model.null_class_id)
        noise_pred_uncond = model(uncond_latents, t, class_labels=null_labels)
        noise_pred_cond = model(cond_latents, t, class_labels=class_labels)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    samples = (latents.clamp(-1, 1) + 1) / 2
    return samples, class_labels

def evaluate(model, noise_scheduler, test_loader, device, num_samples=1):
    print("Starting Evaluation...")
    print(f"Generating {num_samples} samples with CFG scale = {CONFIG['cfg_scale']}...")
    gen_samples, gen_labels = generate_samples(
        model=model,
        noise_scheduler=noise_scheduler,
        num_samples=num_samples,
        guidance_scale=CONFIG["cfg_scale"],
        num_classes=10,
        image_size=CONFIG["image_size"],
        device=device
    )
    print("Loading real test images...")
    real_images = []
    for images, _ in test_loader:
        images = images.to(device)
        images = (images + 1) / 2
        real_images.append(images)
        if len(real_images) * CONFIG["batch_size_eval"] >= num_samples:
            break
    real_images = torch.cat(real_images, dim=0)[:num_samples]
    save_sample_grid(gen_samples[:64], "generated_samples.png",nrow=1)
    gen_rgb = to_rgb_grayscale(gen_samples)
    real_rgb = to_rgb_grayscale(real_images)
    gen_rgb = gen_rgb.float()
    real_rgb = real_rgb.float()
    print("Computing Inception Score...")
    is_metric = InceptionScore(normalize=True).to(device)
    is_metric.update(gen_rgb)
    is_mean, is_std = is_metric.compute()
    print("Computing FID...")
    fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    fid_metric.update(real_rgb, real=True)
    fid_metric.update(gen_rgb, real=False)
    fid_score = fid_metric.compute()
    metrics = {
        "inception_score_mean": is_mean.item(),
        "inception_score_std": is_std.item(),
        "fid_score": fid_score.item()
    }
    
    print("Evaluation Results:")
    print(f"   Inception Score: {is_mean:.4f} ± {is_std:.4f}")
    print(f"   FID: {fid_score:.4f}")
    save_sample_grid(gen_samples[:64], "generated_samples.png")
    
    return metrics

def save_sample_grid(samples, filename, nrow=8):
    from torchvision.utils import make_grid
    grid = make_grid(samples[:64], nrow=nrow, padding=2, normalize=True, value_range=(0, 1))
    grid_img = transforms.ToPILImage()(grid.cpu())
    grid_img.save(filename)
print("\n" + "="*50)
print("Running Final Evaluation")
print("="*50)
model.eval()
eval_metrics = evaluate(
    model=model,
    noise_scheduler=noise_scheduler,
    test_loader=test_loader,
    device=CONFIG["device"],
    num_samples=1000
)


Running Final Evaluation
Starting Evaluation...
Generating 1000 samples with CFG scale = 1...


Sampling: 100%|██████████| 1000/1000 [17:01<00:00,  1.02s/it]


Loading real test images...
Computing Inception Score...
Computing FID...
Evaluation Results:
   Inception Score: 1.9070 ± 0.0536
   FID: 23.3124
