# 單元 1：安裝環境

In [1]:
# Cell 1: Setup Environment
# We need pytorch-fid for evaluation 
!pip install pytorch-fid

Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torc

# 單元 2：導入必要的庫

In [2]:
# Cell 2: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from tqdm.auto import tqdm
import numpy as np
import os
import zipfile
from PIL import Image

# For FID calculation
from pytorch_fid.fid_score import calculate_fid_given_paths

# 單元 3：設定超參數與配置

In [31]:
# Cell 3: Hyperparameters & Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Dataset
IMG_SIZE = 28 # MNIST original size 
CHANNELS = 3  # As required by slides, convert to RGB 
IMG_SHAPE = (CHANNELS, IMG_SIZE, IMG_SIZE)

# Diffusion
TIMESTEPS = 1000 # Example: 1000 steps
BETA_START = 1e-4
BETA_END = 0.02

# Training 
EPOCHS = 30 # Adjust based on your training time
BATCH_SIZE = 128
LR = 2e-4

# Output directories
OUTPUT_DIR = "output_images"
REPORT_DIR = "report_visuals"
MODEL_PATH = "ddpm_mnist.pth"
FID_STATS_PATH = "mnist.npz" # Precalculated stats 

# Create output directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(REPORT_DIR, exist_ok=True)

Using device: cuda


# 單元 4：準備 MNIST 數據集

In [14]:
# Cell 4: Dataset & DataLoader

# We need to convert grayscale MNIST to RGB 
# We also normalize images to [-1, 1] as is common for diffusion models
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE), # Ensure 28x28
    transforms.Grayscale(num_output_channels=3), # Convert to RGB 
    transforms.ToTensor(), # Scales to [0, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # Scales to [-1, 1]
])

# Download training data [cite: 37]
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=True
)

print(f"Loaded {len(train_dataset)} training images.")

Loaded 60000 training images.


# 單元 5：下載 FID 預計算統計數據

In [28]:
# Cell 5: Download pre-calculated FID stats
# This file is provided in the slides 
# We download it from a common public host for that file.
!wget https://github.com/bioinf-jku/TTUR/releases/download/v1.0/fid_stats_mnist.npz -O {FID_STATS_PATH}
print(f"Downloaded {FID_STATS_PATH} for FID calculation.")

--2025-11-18 10:47:51--  https://github.com/bioinf-jku/TTUR/releases/download/v1.0/fid_stats_mnist.npz
Resolving github.com (github.com)... 4.237.22.38
Connecting to github.com (github.com)|4.237.22.38|:443... connected.
HTTP request sent, awaiting response... 404 Not Found
2025-11-18 10:47:52 ERROR 404: Not Found.

Downloaded mnist.npz for FID calculation.


# 單元 6：定義擴散模型 (U-Net)

In [16]:
# Cell 6: Diffusion Model (U-Net) [cite: 125]

# --- Positional Embedding for Timesteps ---
# We need to encode the timestep 't' so the model knows which noise level it's handling
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

# --- U-Net Building Block ---
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb):
        h = self.relu(self.bn1(self.conv1(x)))
        
        # Inject time embedding
        t_emb_proj = self.relu(self.time_mlp(t_emb))
        h = h + t_emb_proj.unsqueeze(-1).unsqueeze(-1) # (B, C, 1, 1)
        
        h = self.dropout(h)
        h = self.relu(self.bn2(self.conv2(h)))
        
        return h + self.shortcut(x)

# --- U-Net Architecture ---
# This U-Net predicts the noise added to the image
class UNet(nn.Module):
    def __init__(self, img_channels=3, base_channels=64, time_emb_dim=256):
        super().__init__()
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        
        # Encoder (Downsampling)
        self.init_conv = nn.Conv2d(img_channels, base_channels, kernel_size=3, padding=1)
        self.down1 = ResBlock(base_channels, base_channels, time_emb_dim)
        self.down_pool1 = nn.MaxPool2d(2) # 28x28 -> 14x14
        self.down2 = ResBlock(base_channels, base_channels * 2, time_emb_dim)
        self.down_pool2 = nn.MaxPool2d(2) # 14x14 -> 7x7

        # Bottleneck
        self.bot1 = ResBlock(base_channels * 2, base_channels * 2, time_emb_dim)

        # Decoder (Upsampling)
        self.up_pool1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 7x7 -> 14x14
        self.up1 = ResBlock(base_channels * 4, base_channels, time_emb_dim) # Skip connection
        self.up_pool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 14x14 -> 28x28
        self.up2 = ResBlock(base_channels * 2, base_channels, time_emb_dim) # Skip connection
        
        # Final output convolution
        self.out_conv = nn.Conv2d(base_channels, img_channels, kernel_size=1)

    def forward(self, x, t):
        # x is (B, C, H, W)
        # t is (B,)
        
        t_emb = self.time_mlp(t)
        
        # Encoder
        x1 = self.init_conv(x)
        x2 = self.down1(x1, t_emb)
        x3 = self.down_pool1(x2)
        x4 = self.down2(x3, t_emb)
        x5 = self.down_pool2(x4)
        
        # Bottleneck
        x_bot = self.bot1(x5, t_emb)
        
        # Decoder
        x_up1 = self.up_pool1(x_bot)
        x_up1 = torch.cat([x_up1, x4], dim=1) # Skip connection from x4
        x_up1 = self.up1(x_up1, t_emb)

        x_up2 = self.up_pool2(x_up1)
        x_up2 = torch.cat([x_up2, x2], dim=1) # Skip connection from x2
        x_up2 = self.up2(x_up2, t_emb)
        
        output = self.out_conv(x_up2)
        return output

# 單元 7：定義擴散輔助函數

In [17]:
# Cell 7: Diffusion Helper Functions

# --- Pre-calculate diffusion constants ---
# This is the "variance schedule"
def get_betas(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, device=DEVICE)

BETAS = get_betas(TIMESTEPS, BETA_START, BETA_END)
ALPHAS = 1. - BETAS
ALPHAS_CUMPROD = torch.cumprod(ALPHAS, dim=0)
ALPHAS_CUMPROD_PREV = F.pad(ALPHAS_CUMPROD[:-1], (1, 0), value=1.0)
SQRT_RECIP_ALPHAS = torch.sqrt(1.0 / ALPHAS)

# Calculations for forward diffusion q(x_t | x_0)
SQRT_ALPHAS_CUMPROD = torch.sqrt(ALPHAS_CUMPROD)
SQRT_ONE_MINUS_ALPHAS_CUMPROD = torch.sqrt(1. - ALPHAS_CUMPROD)

# Calculations for reverse diffusion q(x_{t-1} | x_t, x_0)
POSTERIOR_VARIANCE = BETAS * (1. - ALPHAS_CUMPROD_PREV) / (1. - ALPHAS_CUMPROD)

# --- Helper function to extract correct values for a batch ---
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# --- Forward process q(x_t | x_0) ---
# This adds noise to an image x_0 to get x_t
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(SQRT_ALPHAS_CUMPROD, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(SQRT_ONE_MINUS_ALPHAS_CUMPROD, t, x_start.shape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

# --- Loss Function ---
# We use L1 loss, as it's often more stable than L2 (MSE) 
loss_fn = nn.L1Loss()

def p_losses(denoise_model, x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    return loss_fn(noise, predicted_noise)

# 單元 8：模型初始化與優化器

In [18]:
# Cell 8: Model Initialization and Optimizer

model = UNet(
    img_channels=CHANNELS,
    base_channels=64,
    time_emb_dim=256
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

Model parameters: 1103363


# 單元 9：訓練循環

In [32]:
# Cell 9: Training Loop 

print("Starting training...")
for epoch in range(EPOCHS):
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=True)
    for step, (images, _) in enumerate(progress_bar):
        optimizer.zero_grad()
        
        batch_size = images.shape[0]
        images = images.to(DEVICE)
        
        # Sample random timesteps t for each image in the batch
        t = torch.randint(0, TIMESTEPS, (batch_size,), device=DEVICE).long()
        
        # Calculate loss
        loss = p_losses(model, images, t)
        
        # Backprop
        loss.backward()
        optimizer.step()
        
        progress_bar.set_postfix(loss=loss.item())

    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")

# Save the trained model
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")

Starting training...


Epoch 1/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 1 | Loss: 0.0486


Epoch 2/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 2 | Loss: 0.0410


Epoch 3/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 3 | Loss: 0.0373


Epoch 4/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 4 | Loss: 0.0450


Epoch 5/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 5 | Loss: 0.0572


Epoch 6/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 6 | Loss: 0.0442


Epoch 7/30:   0%|          | 0/468 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>^

  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
    assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

    self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
       if w.is_alive(): 
         ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^^^

Epoch 7 | Loss: 0.0446


Epoch 8/30:   0%|          | 0/468 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860><function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
        self._shutdown_workers()
self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
            ^ ^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
^    ^
assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/lib/python

Epoch 8 | Loss: 0.0482


Epoch 9/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.0422


Epoch 10/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.0346


Epoch 11/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 11 | Loss: 0.0436


Epoch 12/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 12 | Loss: 0.0400


Epoch 13/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 13 | Loss: 0.0447


Epoch 14/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 14 | Loss: 0.0379


Epoch 15/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 15 | Loss: 0.0383


Epoch 16/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 16 | Loss: 0.0383


Epoch 17/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 17 | Loss: 0.0402


Epoch 18/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 18 | Loss: 0.0400


Epoch 19/30:   0%|          | 0/468 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 19 | Loss: 0.0388


Epoch 20/30:   0%|          | 0/468 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>
Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b90e0208860>    
self._shutdown_workers()Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
self._shutdown_workers()
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
if w.is_alive():    
 if w.is_alive(): 
          ^ ^ ^^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive

      File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
assert self._par

Epoch 20 | Loss: 0.0319


Epoch 21/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 21 | Loss: 0.0390


Epoch 22/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 22 | Loss: 0.0346


Epoch 23/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 23 | Loss: 0.0328


Epoch 24/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 24 | Loss: 0.0335


Epoch 25/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 25 | Loss: 0.0354


Epoch 26/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 26 | Loss: 0.0368


Epoch 27/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 27 | Loss: 0.0366


Epoch 28/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 28 | Loss: 0.0401


Epoch 29/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 29 | Loss: 0.0374


Epoch 30/30:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 30 | Loss: 0.0399
Model saved to ddpm_mnist.pth


# 單元 10：採樣 (Sampling) 函數

In [38]:
# Cell 10: Sampling Functions (Reverse Process)

# --- p_sample ---
# This function samples x_{t-1} given the model's prediction for x_t
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(BETAS, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(SQRT_ONE_MINUS_ALPHAS_CUMPROD, t, x.shape)
    sqrt_recip_alphas_t = extract(SQRT_RECIP_ALPHAS, t, x.shape)
    
    # Model predicts noise
    predicted_noise = model(x, t)
    
    # Calculate mean
    model_mean = sqrt_recip_alphas_t * (x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
    
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(POSTERIOR_VARIANCE, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

# --- p_sample_loop (The full reverse process) ---
# This starts from pure noise (x_T) and iterates down to x_0 [cite: 17, 23]
@torch.no_grad()
def p_sample_loop(model, shape, return_process=False):
    device = next(model.parameters()).device
    
    b = shape[0]
    # Start from pure noise x_T [cite: 17]
    img = torch.randn(shape, device=device)
    
    imgs = []
    
    for i in tqdm(reversed(range(0, TIMESTEPS)), desc="Sampling", total=TIMESTEPS, leave=False):
        t = torch.full((b,), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, i)
        
        if return_process and i % (TIMESTEPS // 7) == 0:
            imgs.append(img.cpu())
            
    if return_process:
        imgs.append(img.cpu()) # Add final image
        return imgs
    
    return img

# --- Helper to denormalize images from [-1, 1] to [0, 255] for saving ---
def denormalize_image(img):
    img = (img.clamp(-1, 1) + 1) / 2 # from [-1, 1] to [0, 1]
    img = (img * 255).type(torch.uint8)
    return img

# 單元 11：生成 10,000 張圖像
>使用訓練好的模型生成 10,000 張圖像 ，並將它們保存為 00001.png 到 10000.png 。

In [39]:
# Cell 11: Generate 10,000 Images for Submission

print(f"Generating 10,000 images in {OUTPUT_DIR}...")

# Load the trained model if not already in memory
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

generated_count = 0
generation_batch_size = 64 # Adjust based on GPU memory

while generated_count < 10000:
    # Determine batch size for this iteration
    current_batch_size = min(generation_batch_size, 10000 - generated_count)
    if current_batch_size == 0:
        break
        
    shape = (current_batch_size, CHANNELS, IMG_SIZE, IMG_SIZE)
    generated_images = p_sample_loop(model, shape)
    
    # Denormalize
    generated_images = denormalize_image(generated_images.cpu())
    
    # Save images
    for i in range(current_batch_size):
        img_idx = generated_count + i + 1
        if img_idx > 10000:
            break
            
        img_filename = f"{img_idx:05d}.png" # Format: 00001.png 
        img_path = os.path.join(OUTPUT_DIR, img_filename)
        
        # Convert to PIL Image (C, H, W) -> (H, W, C)
        pil_img = Image.fromarray(generated_images[i].permute(1, 2, 0).numpy())
        pil_img.save(img_path)
    
    generated_count += current_batch_size
    print(f"Generated {generated_count} / 10000 images...")

print("Image generation complete.")

Generating 10,000 images in output_images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 64 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 128 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 192 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 256 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 320 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 384 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 448 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 512 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 576 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 640 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 704 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 768 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 832 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 896 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 960 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1024 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1088 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1152 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1216 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1280 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1344 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1408 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1472 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1536 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1600 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1664 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1728 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1792 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1856 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1920 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 1984 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2048 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2112 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2176 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2240 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2304 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2368 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2432 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2496 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2560 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2624 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2688 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2752 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2816 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2880 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 2944 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3008 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3072 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3136 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3200 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3264 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3328 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3392 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3456 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3520 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3584 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3648 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3712 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3776 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3840 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3904 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 3968 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4032 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4096 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4160 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4224 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4288 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4352 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4416 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4480 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4544 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4608 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4672 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4736 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4800 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4864 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4928 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 4992 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5056 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5120 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5184 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5248 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5312 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5376 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5440 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5504 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5568 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5632 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5696 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5760 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5824 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5888 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 5952 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6016 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6080 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6144 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6208 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6272 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6336 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6400 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6464 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6528 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6592 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6656 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6720 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6784 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6848 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6912 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 6976 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7040 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7104 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7168 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7232 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7296 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7360 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7424 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7488 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7552 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7616 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7680 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7744 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7808 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7872 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 7936 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8000 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8064 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8128 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8192 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8256 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8320 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8384 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8448 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8512 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8576 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8640 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8704 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8768 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8832 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8896 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 8960 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9024 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9088 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9152 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9216 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9280 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9344 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9408 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9472 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9536 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9600 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9664 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9728 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9792 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9856 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9920 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 9984 / 10000 images...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Generated 10000 / 10000 images...
Image generation complete.


# 單元 12：生成報告可視化圖像
> 此單元生成報告所需的擴散過程可視化網格 。

In [40]:
# Cell 12: Generate Report Visualization 

print("Generating visualization grid for report...")
model.eval()

# 1. Generate 8 different samples [cite: 138]
num_report_samples = 8
shape = (num_report_samples, CHANNELS, IMG_SIZE, IMG_SIZE)

# 2. Record 7 intermediate steps + noise + final image (total 9)
# We ask for the process back
# The slides ask for 7 equal parts[cite: 139], which means 8 images (start, 7 steps, end)
# Our p_sample_loop modification saves 7 steps + final. We'll add noise at the start.
process_imgs = p_sample_loop(model, shape, return_process=True)

# 3. Arrange into the grid [cite: 140]
print(f"Process recorded {len(process_imgs)} steps.")

# We want 8 rows (samples) and ~8 columns (steps)
# Let's re-organize: process_imgs is a list of [B, C, H, W] tensors
# We want [Step, Sample, C, H, W]
all_steps = torch.stack(process_imgs) # (NumSteps, B, C, H, W)
all_steps = all_steps.permute(1, 0, 2, 3, 4) # (B, NumSteps, C, H, W)

# We only need 8 samples, and 8 steps (e.g., T=1000, 875, 750... 0)
# Our list `process_imgs` has 8 tensors (7 steps + 1 final)
# Let's take 8 samples from the batch
final_grid_imgs = []
for i in range(num_report_samples): # Iterate over samples
    for j in range(len(process_imgs)): # Iterate over steps (reversed)
        final_grid_imgs.append(all_steps[i, j])

# Denormalize all
final_grid_imgs = [denormalize_image(img.unsqueeze(0)).squeeze(0) for img in final_grid_imgs]

# Create the 8x8 grid [cite: 140] (8 samples, 8 steps)
grid = make_grid(final_grid_imgs, nrow=len(process_imgs), padding=1)

# Save the grid
grid_pil = Image.fromarray(grid.permute(1, 2, 0).numpy())
grid_path = os.path.join(REPORT_DIR, "diffusion_process_grid.png")
grid_pil.save(grid_path)

print(f"Saved report visualization to {grid_path}")

Generating visualization grid for report...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Process recorded 9 steps.
Saved report visualization to report_visuals/diffusion_process_grid.png


# 單元 13：壓縮提交檔案

In [41]:
# Cell 13: Zip Files for Submission

STUDENT_ID = "R12345678" # !!! <--- 請在這裡填入您的學號
ZIP_FILENAME = f"img_{STUDENT_ID}.zip" # 

print(f"Zipping images into {ZIP_FILENAME}...")

with zipfile.ZipFile(ZIP_FILENAME, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Ensure no subdirectory is created 
    for img_file in sorted(os.listdir(OUTPUT_DIR)):
        if img_file.endswith(".png"):
            # arcname=img_file ensures it's at the root of the zip
            zipf.write(os.path.join(OUTPUT_DIR, img_file), arcname=img_file)

print(f"Created {ZIP_FILENAME} successfully.")

Zipping images into img_R12345678.zip...
Created img_R12345678.zip successfully.


# 單元 14：計算 FID 分數

In [None]:
# Cell A: Save Training Images for FID Reference
# Alternative method per Slide 7[cite: 63]: Use training dataset directly

REF_DIR = "mnist_train_ref"
os.makedirs(REF_DIR, exist_ok=True)

print(f"Saving training images to {REF_DIR} for FID calculation...")

# We use the train_dataset loaded in Cell 4
# We need to reverse the normalization to save as PNG
for idx, (img_tensor, _) in enumerate(tqdm(train_dataset, desc="Saving Ref Images")):
    # Denormalize: [-1, 1] -> [0, 1]
    img_tensor = (img_tensor * 0.5) + 0.5
    # [0, 1] -> [0, 255]
    img_tensor = (img_tensor * 255).clamp(0, 255).to(torch.uint8)
    
    # Save as PNG
    pil_img = Image.fromarray(img_tensor.permute(1, 2, 0).numpy())
    pil_img.save(os.path.join(REF_DIR, f"{idx:05d}.png"))

print("Reference images saved.")

In [42]:
# Cell 14: Calculate FID Score

print("Calculating FID Score... This may take a few minutes.")

# We will calculate FID against the pre-calculated stats [cite: 66]
fid_value = calculate_fid_given_paths(
    paths=[OUTPUT_DIR, REF_DIR],
    batch_size=50,
    device=DEVICE,
    dims=2048,
    num_workers=2
)

print(f"\n" + "="*30)
print(f"      FINAL FID SCORE: {fid_value:.4f}")
print("="*30)

# Print grading policy [cite: 75]
if fid_value < 30:
    print("Points: 90")
elif 30 <= fid_value <= 70:
    points = 60 + (90 - 60) * (70 - fid_value) / (70 - 30)
    print(f"Points (Linear 60-90): {points:.2f}")
elif 70 < fid_value <= 100:
    points = 0 + (60 - 0) * (100 - fid_value) / (100 - 70)
    print(f"Points (Linear 0-60): {points:.2f}")
else:
    print("Points: 0")

Calculating FID Score... This may take a few minutes.


100%|██████████| 200/200 [00:21<00:00,  9.42it/s]
100%|██████████| 1200/1200 [02:06<00:00,  9.49it/s]



      FINAL FID SCORE: 19.9555
Points: 90
