In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    """
    Helper module to generate sinusoidal timestep embeddings.

    Maps a 1-D batch of timesteps of shape (B,) to a tensor of shape
    (B, dim). The first ``dim // 2`` features are sines and the next
    ``dim // 2`` are cosines of the timestep multiplied by geometrically
    spaced frequencies running from 1 down to 1/10000. When ``dim`` is odd,
    one trailing zero feature is appended so the output width always equals
    ``dim``.

    Args:
        dim: Width of the embedding to produce.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Embed a batch of timesteps.

        Args:
            time: 1-D tensor of timesteps, one per batch element.

        Returns:
            Floating-point tensor of shape (len(time), dim).
        """
        device = time.device
        half_dim = self.dim // 2

        # With dim == 2 or 3 there is a single frequency (half_dim == 1) and
        # the usual divisor (half_dim - 1) would be zero; clamp it to 1.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)

        # Outer product: (B, 1) * (1, half_dim) -> (B, half_dim)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # sin/cos halves only cover 2 * half_dim features; pad odd widths with
        # a zero column so downstream layers sized with `dim` still fit.
        if self.dim % 2 == 1:
            padding = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat((embeddings, padding), dim=-1)
        return embeddings

class ResidualBlock(nn.Module):
    """
    Pre-activation residual block conditioned on a time embedding and,
    optionally, a class embedding.

    Data flow: (GroupNorm -> SiLU -> 3x3 conv) -> add projected embeddings
    -> (GroupNorm -> SiLU -> Dropout -> 3x3 conv) -> add shortcut(x).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        time_emb_dim: Width of the time embedding; the class embedding is
            projected with a layer of the same input width.
        num_classes: When None, no class projection is created and any
            class embedding passed to ``forward`` is ignored.
        dropout: Dropout probability used inside the second conv stage.

    Both normalisation layers use 32 groups, so ``in_channels`` and
    ``out_channels`` must each be divisible by 32.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, num_classes=None, dropout=0.1):
        super().__init__()

        # Stage 1 normalises the block input and maps it to out_channels.
        stage_one = [
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.conv1 = nn.Sequential(*stage_one)

        # Stage 2 works on the conditioned features and keeps the width.
        stage_two = [
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.conv2 = nn.Sequential(*stage_two)

        # Projects the time embedding to one bias value per output channel.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))

        # Same projection for the class embedding (it shares time_emb_dim).
        if num_classes is None:
            self.class_mlp = None
        else:
            self.class_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))

        # A 1x1 conv is only needed on the skip path when the width changes.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, t_emb, c_emb=None):
        """
        Args:
            x: Feature map with ``in_channels`` channels.
            t_emb: Time embedding whose last dimension is ``time_emb_dim``.
            c_emb: Optional class embedding of the same width; used only
                when the block was built with ``num_classes`` set.

        Returns:
            Feature map with ``out_channels`` channels and the same spatial
            size as ``x``.
        """
        hidden = self.conv1(x)

        # Two trailing singleton axes let the per-channel bias broadcast
        # over the spatial dimensions.
        hidden = hidden + self.time_mlp(t_emb)[..., None, None]

        has_class_signal = c_emb is not None and self.class_mlp is not None
        if has_class_signal:
            hidden = hidden + self.class_mlp(c_emb)[..., None, None]

        return self.conv2(hidden) + self.shortcut(x)

class AttentionBlock(nn.Module):
    """
    Residual multi-head self-attention over the spatial positions of a
    feature map.

    The input is group-normalised, projected to query/key/value with a 1x1
    convolution, flattened to a sequence of H*W tokens and passed through
    ``nn.MultiheadAttention``. The attended tokens are reshaped back to a
    feature map, projected with another 1x1 convolution and added to the
    input.

    Note that ``nn.MultiheadAttention`` applies its own input and output
    projections in addition to ``qkv`` and ``proj_out``.

    Args:
        channels: Channels of the feature map (must be divisible by 32 for
            the GroupNorm and by ``num_heads`` for the attention layer).
        num_heads: Number of attention heads.
    """
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads

        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, 3 * channels, kernel_size=1, bias=False)
        self.attention = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            batch_first=True,
        )
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x: Feature map of shape (B, C, H, W).

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, chans, height, width = x.shape

        # (B, 3C, H, W) -> (B, 3C, H*W) -> (B, H*W, 3C): one token per pixel.
        tokens = self.qkv(self.norm(x)).flatten(2).transpose(1, 2)

        # Consecutive C-wide slices of the channel axis are q, k and v.
        query, key, value = tokens.split(chans, dim=-1)

        # batch_first=True, so inputs are (batch, sequence, embedding);
        # the returned attention weights are not needed.
        attended = self.attention(query, key, value)[0]

        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        attended = attended.transpose(1, 2).reshape(batch, chans, height, width)
        return x + self.proj_out(attended)

class DownBlock(nn.Module):
    """
    One encoder stage of the U-Net.

    Pipeline: ResidualBlock -> ResidualBlock -> (optional) AttentionBlock
    -> strided convolution. The features taken just before the strided
    convolution are returned as well so the decoder can use them as a skip
    connection.

    Args:
        in_channels: Channels entering the stage.
        out_channels: Channels produced by the stage.
        time_emb_dim: Width of the time/class embeddings.
        num_classes: Forwarded to the residual blocks (None disables their
            class projection).
        has_attn: Whether to insert self-attention after the residual blocks.
        dropout: Dropout probability inside the residual blocks.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, num_classes, has_attn=False, dropout=0.1):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, time_emb_dim, num_classes, dropout)
        self.res2 = ResidualBlock(out_channels, out_channels, time_emb_dim, num_classes, dropout)

        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

        # kernel 4 / stride 2 / padding 1 halves even spatial sizes exactly.
        self.downsample = nn.Conv2d(
            out_channels, out_channels, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x, t_emb, c_emb):
        """
        Returns:
            A pair ``(downsampled, skip)``: the stage output at half
            resolution and the full-resolution features before downsampling.
        """
        features = x
        for residual in (self.res1, self.res2):
            features = residual(features, t_emb, c_emb)
        features = self.attn(features)

        # `features` doubles as the skip tensor handed to the decoder.
        return self.downsample(features), features

class UpBlock(nn.Module):
    """
    One decoder stage of the U-Net.

    Pipeline: transposed-conv upsample -> concatenate skip tensor on the
    channel axis -> ResidualBlock -> ResidualBlock -> (optional)
    AttentionBlock.

    Args:
        in_channels: Channels seen by the first residual block, i.e. the
            channel count after the skip tensor has been concatenated.
        out_channels: Channels produced by the upsample layer and by the
            stage as a whole.
        time_emb_dim: Width of the time/class embeddings.
        num_classes: Forwarded to the residual blocks.
        has_attn: Whether to append self-attention.
        dropout: Dropout probability inside the residual blocks.
        up_in_channels: Channels of the tensor arriving from the stage
            below (input of the upsample layer). Defaults to
            ``in_channels // 2`` when omitted.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, num_classes,
                 has_attn=False, dropout=0.1, up_in_channels=None):
        super().__init__()

        upsample_in = in_channels // 2 if up_in_channels is None else up_in_channels

        self.res1 = ResidualBlock(in_channels, out_channels, time_emb_dim, num_classes, dropout)
        self.res2 = ResidualBlock(out_channels, out_channels, time_emb_dim, num_classes, dropout)

        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

        # kernel 4 / stride 2 / padding 1 doubles the spatial size and maps
        # the incoming width to out_channels before the skip is attached.
        self.upsample = nn.ConvTranspose2d(
            upsample_in, out_channels, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x, skip, t_emb, c_emb):
        """
        Args:
            x: Features from the stage below.
            skip: Encoder features to concatenate; must match the upsampled
                tensor in batch and spatial size.
            t_emb: Time embedding.
            c_emb: Class embedding or None.

        Returns:
            Features with ``out_channels`` channels at the upsampled size.
        """
        upsampled = self.upsample(x)

        # Skip features go first on the channel axis.
        merged = torch.cat([skip, upsampled], dim=1)

        merged = self.res1(merged, t_emb, c_emb)
        merged = self.res2(merged, t_emb, c_emb)
        return self.attn(merged)

class Unet(nn.Module):
    """
    Time- and (optionally) class-conditioned U-Net.

    Layout: initial 3x3 conv -> one DownBlock per entry of ``dim_mults`` ->
    bottleneck (ResidualBlock, AttentionBlock, ResidualBlock) -> one
    UpBlock per entry of ``dim_mults`` in reverse -> GroupNorm/SiLU/3x3 conv
    back to ``image_channels``.

    Args:
        image_channels: Channels of the input and output images.
        init_channels: Base width; level ``i`` uses
            ``init_channels * dim_mults[i]`` channels.
        dim_mults: Per-level width multipliers; its length is the number of
            down/up stages. Every stage halves (then restores) the
            resolution.
        num_classes: Number of class labels, or None for an unconditional
            model (no class embedding table is created).
        time_emb_dim: Width of the time and class embeddings (256 is
            4 * 64 with the defaults; the relationship is not enforced).
        dropout: Dropout probability inside residual blocks.

    Self-attention is enabled on the two deepest levels
    (``level >= len(dim_mults) - 2``) and in the bottleneck.
    """
    def __init__(
        self,
        image_channels=3,
        init_channels=64,
        dim_mults=(1, 2, 4),
        num_classes=2,
        time_emb_dim=256,
        dropout=0.1
    ):
        super().__init__()

        self.num_classes = num_classes

        # -- Conditioning: sinusoidal features refined by a two-layer MLP --
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # The class table only exists for conditional models.
        if num_classes is not None:
            self.class_embed = nn.Embedding(num_classes, time_emb_dim)

        # -- Stem --
        self.init_conv = nn.Conv2d(image_channels, init_channels, kernel_size=3, padding=1)

        depth = len(dim_mults)
        first_attn_level = depth - 2
        level_widths = [init_channels * mult for mult in dim_mults]

        # -- Encoder --
        self.down_blocks = nn.ModuleList()
        width = init_channels
        for level, level_width in enumerate(level_widths):
            stage = DownBlock(
                width,
                level_width,
                time_emb_dim,
                num_classes,
                has_attn=(level >= first_attn_level),
                dropout=dropout,
            )
            self.down_blocks.append(stage)
            width = level_width

        # -- Bottleneck (runs at the width of the deepest encoder level) --
        self.mid_block1 = ResidualBlock(width, width, time_emb_dim, num_classes, dropout)
        self.mid_attn = AttentionBlock(width)
        self.mid_block2 = ResidualBlock(width, width, time_emb_dim, num_classes, dropout)

        # -- Decoder: mirror of the encoder, deepest level first --
        self.up_blocks = nn.ModuleList()
        for level in range(depth - 1, -1, -1):
            level_width = level_widths[level]

            # Each encoder stage emits its skip tensor at its own width.
            skip_width = level_widths[level]

            stage = UpBlock(
                # First residual block sees upsampled features + skip.
                in_channels=level_width + skip_width,
                out_channels=level_width,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                has_attn=(level >= first_attn_level),
                dropout=dropout,
                # Width of the tensor arriving from the stage below.
                up_in_channels=width,
            )
            self.up_blocks.append(stage)
            width = level_width

        # -- Head --
        self.final_conv = nn.Sequential(
            nn.GroupNorm(32, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, image_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, time, y=None):
        """
        Args:
            x: Noisy image batch, (B, image_channels, H, W).
            time: Timesteps, (B,).
            y: Optional class labels, (B,). Ignored when None or when the
                model was built with ``num_classes=None``.

        Returns:
            Tensor with ``image_channels`` channels; it has the same spatial
            size as ``x`` when H and W are divisible by
            ``2 ** len(dim_mults)``.
        """
        cond = self.time_embed(time)

        # The class signal is used twice: summed into the shared
        # conditioning vector here, and passed separately to every block.
        class_cond = None
        if not (y is None or self.num_classes is None):
            class_cond = self.class_embed(y)
            cond = cond + class_cond

        feats = self.init_conv(x)

        # Stack of encoder features. The stem output at the bottom is never
        # popped because there is exactly one decoder stage per encoder stage.
        skip_stack = [feats]
        for down in self.down_blocks:
            feats, skip = down(feats, cond, class_cond)
            skip_stack.append(skip)

        feats = self.mid_block1(feats, cond, class_cond)
        feats = self.mid_attn(feats)
        feats = self.mid_block2(feats, cond, class_cond)

        # Decoder stages consume skips in last-in, first-out order.
        for up in self.up_blocks:
            feats = up(feats, skip_stack.pop(), cond, class_cond)

        return self.final_conv(feats)

In [8]:
# Smoke test: builds a conditional and an unconditional U-Net, pushes one
# random batch through each and checks that the output shape equals the
# input shape. No training happens here.
if __name__ == "__main__":
    # --- Shapes of a CIFAR-10 sized batch (32x32 RGB) ---
    BATCH_SIZE = 8
    IMG_SIZE = 32
    IMG_CHANNELS = 3
    NUM_CLASSES = 2 # Two labels: "cat" and "dog"

    # 1. Create a dummy batch of data
    # Stand-in for noisy images: pure Gaussian noise
    x = torch.randn(BATCH_SIZE, IMG_CHANNELS, IMG_SIZE, IMG_SIZE)

    # Random integer timesteps in [0, 999] (the upper bound is exclusive)
    t = torch.randint(0, 1000, (BATCH_SIZE,))

    # Random class labels (0 or 1)
    y = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,))

    # 2. Instantiate the U-Net
    # Three levels, each ending in a stride-2 convolution:
    # dim_mults=(1, 2, 4) -> 32x32 -> 16x16 -> 8x8 -> 4x4 (bottleneck)
    model = Unet(
        image_channels=IMG_CHANNELS,
        init_channels=64,
        dim_mults=(1, 2, 4),
        num_classes=NUM_CLASSES,
        time_emb_dim=256 # 64 * 4
    )

    print(f"Model parameter count: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    # 3. Run the forward pass
    # The output is whatever the training objective targets (noise or v);
    # the variable name is only a label.
    predicted_noise = model(x, t, y)

    print(f"Input shape:   {x.shape}")
    print(f"Output shape:  {predicted_noise.shape}")

    # The network must map images to tensors of the same shape
    assert x.shape == predicted_noise.shape

    print("\nSuccess! Model forward pass is working correctly.")

    # --- Example without class conditioning ---
    # num_classes=None builds the model without a class embedding table.
    model_unconditional = Unet(num_classes=None)
    predicted_noise_uncond = model_unconditional(x, t, y=None) # No labels supplied
    print(f"\nUnconditional output shape: {predicted_noise_uncond.shape}")

Model parameter count: 15.43M
Input shape:   torch.Size([8, 3, 32, 32])
Output shape:  torch.Size([8, 3, 32, 32])

Success! Model forward pass is working correctly.

Unconditional output shape: torch.Size([8, 3, 32, 32])


In [13]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm  # For a nice progress bar
import numpy as np

# We assume your Unet model is in a file named unet.py
# (No import needed if running in the same notebook)


# --- 1. Diffusion Scheduler ---
# This helper class manages the noise schedule (betas, alphas)
# and provides functions for the v-prediction objective.

class LinearNoiseScheduler:
    """
    Linear beta schedule for a DDPM-style forward process, with helpers for
    noising images and for computing the v-prediction target.

    Attributes (all 1-D tensors of length ``num_timesteps``):
        betas: Linearly spaced from ``beta_start`` to ``beta_end``.
        alphas: ``1 - betas``.
        alphas_cumprod: Running product of ``alphas`` (alpha_bar).
        sqrt_alphas_cumprod: ``sqrt(alpha_bar)``.
        sqrt_one_minus_alphas_cumprod: ``sqrt(1 - alpha_bar)``.
    """
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps

        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = self.alphas.cumprod(dim=0)

        # Cached because both public helpers need these on every call.
        self.sqrt_alphas_cumprod = self.alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - self.alphas_cumprod).sqrt()

    def add_noise(self, x_0, t, noise):
        """
        Sample x_t from the forward process.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x_0: Clean batch.
            t: Integer timestep index per batch element, shape (B,).
            noise: Noise tensor with the same shape as ``x_0``.
        """
        signal_scale, noise_scale = self._signal_and_noise_scales(t, x_0.shape)
        return signal_scale * x_0 + noise_scale * noise

    def get_velocity(self, x_0, t, noise):
        """
        Compute the v-prediction target.

        v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0

        Arguments are the same as for ``add_noise``.
        """
        signal_scale, noise_scale = self._signal_and_noise_scales(t, x_0.shape)
        return signal_scale * noise - noise_scale * x_0

    def _signal_and_noise_scales(self, t, shape):
        """Return (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t)), broadcastable to ``shape``."""
        signal_scale = self._get_tensor_values(t, shape, self.sqrt_alphas_cumprod)
        noise_scale = self._get_tensor_values(t, shape, self.sqrt_one_minus_alphas_cumprod)
        return signal_scale, noise_scale

    def _get_tensor_values(self, t, shape, schedule_tensor):
        """
        Look up ``schedule_tensor[t]`` for every batch element and reshape
        the result to (B, 1, ..., 1) with as many axes as ``shape`` so it
        broadcasts against a tensor of that shape.
        """
        # The schedule is moved to t's device on every call.
        per_sample = schedule_tensor.to(t.device).gather(-1, t)
        singleton_axes = (1,) * (len(shape) - 1)
        return per_sample.reshape((t.shape[0],) + singleton_axes)


# --- 2. Dataset Loader ---

def get_dataloader(batch_size, num_classes=2):
    """
    Build a shuffled DataLoader over the cat and dog images of the CIFAR-10
    training split, relabelled as 0 (cat) and 1 (dog).

    The dataset is downloaded to ``./data`` if it is not already there.
    Images are converted to tensors and rescaled from [0, 1] to [-1, 1].

    Args:
        batch_size: Number of samples per batch.
        num_classes: Accepted for call-site symmetry but not used; the
            function always selects exactly the two classes cat and dog.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    # CIFAR-10 classes: ['airplane', 'automobile', 'bird', 'cat', 'deer',
    #                    'dog', 'frog', 'horse', 'ship', 'truck']
    CAT_CLASS = 3
    DOG_CLASS = 5
    label_map = {CAT_CLASS: 0, DOG_CLASS: 1}

    # Standard normalization for diffusion models
    transform = transforms.Compose([
        transforms.ToTensor(),                # To [0, 1]
        transforms.Normalize((0.5,), (0.5,))  # To [-1, 1]
    ])

    dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)

    # Select on the raw label list. Iterating the dataset itself would decode
    # and transform every image just to read its label.
    indices = [i for i, label in enumerate(dataset.targets) if label in label_map]

    # Relabel only the selected samples (cat -> 0, dog -> 1). This must happen
    # after the selection above, because the new ids 0 and 1 collide with the
    # original 'airplane' and 'automobile' labels.
    for i in indices:
        dataset.targets[i] = label_map[dataset.targets[i]]

    filtered_dataset = Subset(dataset, indices)

    dataloader = DataLoader(
        filtered_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    return dataloader

# --- 3. Training Script ---

def train():
    """
    Train the class-conditional U-Net on CIFAR-10 cats and dogs with the
    v-prediction objective.

    Every step noises a batch at random timesteps, asks the model for the
    velocity and minimises the MSE against the scheduler's target. With
    probability 0.1 the labels of the whole batch are dropped so the model
    also learns the unconditional case (classifier-free guidance).

    Side effects: prints progress, writes ``unet_cifar10_epoch_<N>.pth``
    every 10 epochs and ``unet_cifar10_final.pth`` at the end.
    """
    # --- Hyperparameters ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    learning_rate = 1e-4
    batch_size = 256
    num_epochs = 100
    num_timesteps = 250  # length of the diffusion schedule
    num_classes = 2      # "cat" and "dog"
    label_drop_prob = 0.1
    checkpoint_every = 10

    print(f"Using device: {device}")

    # --- Setup: model, data, noise schedule, optimizer ---
    model = Unet(
        image_channels=3,
        init_channels=64,
        dim_mults=(1, 2, 4),
        num_classes=num_classes,
        time_emb_dim=256,
    ).to(device)
    dataloader = get_dataloader(batch_size, num_classes)
    scheduler = LinearNoiseScheduler(num_timesteps=num_timesteps)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # --- Training loop ---
    for epoch_number in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch_number}/{num_epochs}", leave=True)

        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            # Noise inherits device and dtype from the images.
            noise = torch.randn_like(images)

            # One timestep index per sample, drawn on the CPU then moved.
            t = torch.randint(0, scheduler.num_timesteps, (images.shape[0],)).to(device)

            noisy_images = scheduler.add_noise(images, t, noise)

            # Drop the labels of the whole batch with probability 0.1.
            conditioning = None if np.random.rand() < label_drop_prob else labels

            predicted_v = model(noisy_images, t, conditioning)
            target_v = scheduler.get_velocity(images, t, noise)
            loss = F.mse_loss(predicted_v, target_v)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            running_loss += step_loss
            progress_bar.set_postfix(Loss=step_loss)

        # Mean over batches (the last batch may be smaller than the rest).
        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch_number} finished. Average Loss: {avg_loss:.4f}")

        if epoch_number % checkpoint_every == 0:
            torch.save(model.state_dict(), f"unet_cifar10_epoch_{epoch_number}.pth")
            print(f"Saved model checkpoint at epoch {epoch_number}")

    print("Training complete.")
    torch.save(model.state_dict(), "unet_cifar10_final.pth")

# Training starts as soon as this cell executes. The call is deliberately
# unguarded here.
# NOTE: if this code is moved into an importable module, wrap the call in
# an `if __name__ == "__main__":` guard so that importing does not train.

# Runs the full training loop and writes the checkpoint files.
train()

Using device: cuda


Epoch 1/100: 100%|██████████| 40/40 [00:26<00:00,  1.51it/s, Loss=0.354]


Epoch 1 finished. Average Loss: 0.4810


Epoch 2/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.262]


Epoch 2 finished. Average Loss: 0.2963


Epoch 3/100: 100%|██████████| 40/40 [00:25<00:00,  1.54it/s, Loss=0.218]


Epoch 3 finished. Average Loss: 0.2335


Epoch 4/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.152]


Epoch 4 finished. Average Loss: 0.2002


Epoch 5/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.191]


Epoch 5 finished. Average Loss: 0.1962


Epoch 6/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.154]


Epoch 6 finished. Average Loss: 0.1825


Epoch 7/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.156]


Epoch 7 finished. Average Loss: 0.1723


Epoch 8/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.174]


Epoch 8 finished. Average Loss: 0.1661


Epoch 9/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.152]


Epoch 9 finished. Average Loss: 0.1636


Epoch 10/100: 100%|██████████| 40/40 [00:25<00:00,  1.55it/s, Loss=0.12]


Epoch 10 finished. Average Loss: 0.1585
Saved model checkpoint at epoch 10


Epoch 11/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.149]


Epoch 11 finished. Average Loss: 0.1569


Epoch 12/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.181]


Epoch 12 finished. Average Loss: 0.1527


Epoch 13/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.153]


Epoch 13 finished. Average Loss: 0.1502


Epoch 14/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.15]


Epoch 14 finished. Average Loss: 0.1504


Epoch 15/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.21]


Epoch 15 finished. Average Loss: 0.1469


Epoch 16/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.164]


Epoch 16 finished. Average Loss: 0.1453


Epoch 17/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.146]


Epoch 17 finished. Average Loss: 0.1415


Epoch 18/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.115]


Epoch 18 finished. Average Loss: 0.1400


Epoch 19/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.17]


Epoch 19 finished. Average Loss: 0.1415


Epoch 20/100: 100%|██████████| 40/40 [00:25<00:00,  1.55it/s, Loss=0.155]


Epoch 20 finished. Average Loss: 0.1423
Saved model checkpoint at epoch 20


Epoch 21/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.166]


Epoch 21 finished. Average Loss: 0.1364


Epoch 22/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.12]


Epoch 22 finished. Average Loss: 0.1348


Epoch 23/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.119]


Epoch 23 finished. Average Loss: 0.1319


Epoch 24/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.108]


Epoch 24 finished. Average Loss: 0.1310


Epoch 25/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.12]


Epoch 25 finished. Average Loss: 0.1319


Epoch 26/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.115]


Epoch 26 finished. Average Loss: 0.1278


Epoch 27/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.16]


Epoch 27 finished. Average Loss: 0.1295


Epoch 28/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.142]


Epoch 28 finished. Average Loss: 0.1312


Epoch 29/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.1]


Epoch 29 finished. Average Loss: 0.1247


Epoch 30/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.107]


Epoch 30 finished. Average Loss: 0.1267
Saved model checkpoint at epoch 30


Epoch 31/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.115]


Epoch 31 finished. Average Loss: 0.1252


Epoch 32/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0947]


Epoch 32 finished. Average Loss: 0.1235


Epoch 33/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0946]


Epoch 33 finished. Average Loss: 0.1246


Epoch 34/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.125]


Epoch 34 finished. Average Loss: 0.1258


Epoch 35/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.212]


Epoch 35 finished. Average Loss: 0.1258


Epoch 36/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.108]


Epoch 36 finished. Average Loss: 0.1228


Epoch 37/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.132]


Epoch 37 finished. Average Loss: 0.1212


Epoch 38/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.122]


Epoch 38 finished. Average Loss: 0.1225


Epoch 39/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.145]


Epoch 39 finished. Average Loss: 0.1233


Epoch 40/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.102]


Epoch 40 finished. Average Loss: 0.1208
Saved model checkpoint at epoch 40


Epoch 41/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.146]


Epoch 41 finished. Average Loss: 0.1209


Epoch 42/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.131]


Epoch 42 finished. Average Loss: 0.1206


Epoch 43/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.12]


Epoch 43 finished. Average Loss: 0.1216


Epoch 44/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.114]


Epoch 44 finished. Average Loss: 0.1213


Epoch 45/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.109]


Epoch 45 finished. Average Loss: 0.1174


Epoch 46/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.119]


Epoch 46 finished. Average Loss: 0.1177


Epoch 47/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.147]


Epoch 47 finished. Average Loss: 0.1171


Epoch 48/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0992]


Epoch 48 finished. Average Loss: 0.1190


Epoch 49/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.121]


Epoch 49 finished. Average Loss: 0.1195


Epoch 50/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.104]


Epoch 50 finished. Average Loss: 0.1179
Saved model checkpoint at epoch 50


Epoch 51/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.117]


Epoch 51 finished. Average Loss: 0.1167


Epoch 52/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.174]


Epoch 52 finished. Average Loss: 0.1168


Epoch 53/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.112]


Epoch 53 finished. Average Loss: 0.1167


Epoch 54/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0776]


Epoch 54 finished. Average Loss: 0.1149


Epoch 55/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0959]


Epoch 55 finished. Average Loss: 0.1154


Epoch 56/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.115]


Epoch 56 finished. Average Loss: 0.1156


Epoch 57/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.103]


Epoch 57 finished. Average Loss: 0.1139


Epoch 58/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0963]


Epoch 58 finished. Average Loss: 0.1139


Epoch 59/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0955]


Epoch 59 finished. Average Loss: 0.1153


Epoch 60/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.114]


Epoch 60 finished. Average Loss: 0.1146
Saved model checkpoint at epoch 60


Epoch 61/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.131]


Epoch 61 finished. Average Loss: 0.1157


Epoch 62/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.155]


Epoch 62 finished. Average Loss: 0.1150


Epoch 63/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.104]


Epoch 63 finished. Average Loss: 0.1137


Epoch 64/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.101]


Epoch 64 finished. Average Loss: 0.1129


Epoch 65/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.138]


Epoch 65 finished. Average Loss: 0.1129


Epoch 66/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.104]


Epoch 66 finished. Average Loss: 0.1118


Epoch 67/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0954]


Epoch 67 finished. Average Loss: 0.1129


Epoch 68/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.113]


Epoch 68 finished. Average Loss: 0.1119


Epoch 69/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.101]


Epoch 69 finished. Average Loss: 0.1107


Epoch 70/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.13]


Epoch 70 finished. Average Loss: 0.1127
Saved model checkpoint at epoch 70


Epoch 71/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.118]


Epoch 71 finished. Average Loss: 0.1121


Epoch 72/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.11]


Epoch 72 finished. Average Loss: 0.1112


Epoch 73/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.095]


Epoch 73 finished. Average Loss: 0.1089


Epoch 74/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.0964]


Epoch 74 finished. Average Loss: 0.1119


Epoch 75/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.115]


Epoch 75 finished. Average Loss: 0.1108


Epoch 76/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.123]


Epoch 76 finished. Average Loss: 0.1113


Epoch 77/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.106]


Epoch 77 finished. Average Loss: 0.1112


Epoch 78/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.11]


Epoch 78 finished. Average Loss: 0.1094


Epoch 79/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0917]


Epoch 79 finished. Average Loss: 0.1093


Epoch 80/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0847]


Epoch 80 finished. Average Loss: 0.1091
Saved model checkpoint at epoch 80


Epoch 81/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.137]


Epoch 81 finished. Average Loss: 0.1105


Epoch 82/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.102]


Epoch 82 finished. Average Loss: 0.1097


Epoch 83/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0666]


Epoch 83 finished. Average Loss: 0.1092


Epoch 84/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0847]


Epoch 84 finished. Average Loss: 0.1095


Epoch 85/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.101]


Epoch 85 finished. Average Loss: 0.1080


Epoch 86/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.109]


Epoch 86 finished. Average Loss: 0.1085


Epoch 87/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0969]


Epoch 87 finished. Average Loss: 0.1086


Epoch 88/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.0792]


Epoch 88 finished. Average Loss: 0.1084


Epoch 89/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.121]


Epoch 89 finished. Average Loss: 0.1071


Epoch 90/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.131]


Epoch 90 finished. Average Loss: 0.1099
Saved model checkpoint at epoch 90


Epoch 91/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.089]


Epoch 91 finished. Average Loss: 0.1073


Epoch 92/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.0783]


Epoch 92 finished. Average Loss: 0.1091


Epoch 93/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.102]


Epoch 93 finished. Average Loss: 0.1071


Epoch 94/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.102]


Epoch 94 finished. Average Loss: 0.1080


Epoch 95/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0796]


Epoch 95 finished. Average Loss: 0.1067


Epoch 96/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.0896]


Epoch 96 finished. Average Loss: 0.1064


Epoch 97/100: 100%|██████████| 40/40 [00:25<00:00,  1.55it/s, Loss=0.149]


Epoch 97 finished. Average Loss: 0.1094


Epoch 98/100: 100%|██████████| 40/40 [00:25<00:00,  1.56it/s, Loss=0.12]


Epoch 98 finished. Average Loss: 0.1075


Epoch 99/100: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s, Loss=0.118]


Epoch 99 finished. Average Loss: 0.1089


Epoch 100/100: 100%|██████████| 40/40 [00:25<00:00,  1.57it/s, Loss=0.102]


Epoch 100 finished. Average Loss: 0.1067
Saved model checkpoint at epoch 100
Training complete.


In [16]:
import torch
from torchvision.utils import save_image
from tqdm import tqdm
import torch.nn.functional as F

# --- 1. INFERENCE FUNCTION (SAMPLER) ---

def generate_images(
    model,
    scheduler,
    num_images=16,
    class_label=0,  # 0 for "cat", 1 for "dog"
    guidance_scale=5.0,
    device="cuda",
    image_size=32,
    channels=3,
):
    """
    Generates images using the trained U-Net and a DDPM-style sampler.
    This sampler is specifically adapted for a v-prediction model and applies
    classifier-free guidance (CFG) at every step.

    Args:
        model: Network called as ``model(x, t, y=None)`` for the unconditional
            prediction and ``model(x, t, y=labels)`` for the conditional one.
            Its output is treated as the v-prediction and must be
            broadcastable with ``x``. ``model.eval()`` is called and the model
            is left in evaluation mode.
        scheduler: Noise schedule exposing ``num_timesteps``, the 1-D tensors
            ``betas``, ``alphas``, ``alphas_cumprod``, ``sqrt_alphas_cumprod``
            and ``sqrt_one_minus_alphas_cumprod``, plus the helper
            ``_get_tensor_values(t, shape, values)``.
            NOTE(review): the helper is assumed to return the per-sample
            schedule values on the same device as ``t`` -- confirm against the
            scheduler definition.
        num_images: Number of images sampled in one batch.
        class_label: Integer class id used for the conditional prediction.
        guidance_scale: CFG weight ``s`` in ``v = v_uncond + s * (v_cond - v_uncond)``.
        device: Device on which the sampling tensors are created.
        image_size: Height and width of the (square) images to sample.
        channels: Number of image channels.

    Returns:
        ``torch.uint8`` tensor of shape
        ``(num_images, channels, image_size, image_size)`` with values in [0, 255].
    """

    # --- 1. Setup ---
    model.eval()  # Set the model to evaluation mode
    num_timesteps = scheduler.num_timesteps

    # Start with random noise x_T
    # Shape: (batch_size, channels, height, width)
    img = torch.randn(num_images, channels, image_size, image_size, device=device)

    # Create the label tensor
    labels = torch.full((num_images,), class_label, dtype=torch.long, device=device)

    # --- 2. Pre-calculate sampler constants ---
    # Everything derived from these tensors already lives on `device`.
    betas = scheduler.betas.to(device)
    alphas = scheduler.alphas.to(device)
    alphas_cumprod = scheduler.alphas_cumprod.to(device)

    # 1 / sqrt(alpha_t)
    sqrt_recip_alphas = 1.0 / torch.sqrt(alphas)

    # (1 - alpha_t) / sqrt(1 - alpha_bar_t)
    epsilon_coeff = (1.0 - alphas) / torch.sqrt(1.0 - alphas_cumprod)

    # alpha_bar_{t-1}, with alpha_bar_{-1} := 1. Built with an explicit concat
    # so that a one-step schedule (empty slice) is handled as well.
    alphas_cumprod_prev = torch.cat(
        [alphas_cumprod.new_ones(1), alphas_cumprod[:-1]]
    )

    # Posterior variance: (1 - alpha_bar_t-1) / (1 - alpha_bar_t) * beta_t
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    # The variance is exactly 0 at t=0; clamp before the square root so the
    # standard deviation stays finite and well defined.
    posterior_std = torch.sqrt(posterior_variance.clamp(min=1e-20))

    # --- 3. The Sampling Loop ---
    print(f"Generating {num_images} images of class '{'cat' if class_label==0 else 'dog'}'...")

    # No step of the sampler needs gradients.
    with torch.no_grad():
        # Loop from T (num_timesteps - 1) down to 0.
        # `reversed(range(...))` has no len(), so the total is passed explicitly
        # to let tqdm display a real progress bar with an ETA.
        for t in tqdm(
            reversed(range(num_timesteps)), total=num_timesteps, desc="Sampling"
        ):
            # Current timestep, duplicated for each image in the batch
            t_tensor = torch.full((num_images,), t, dtype=torch.long, device=device)

            # --- a. Classifier-Free Guidance (CFG) ---
            v_uncond = model(img, t_tensor, y=None)
            v_cond = model(img, t_tensor, y=labels)
            pred_v = v_uncond + guidance_scale * (v_cond - v_uncond)

            # --- b. Convert v-prediction to epsilon-prediction ---
            # epsilon = sqrt(1 - alpha_bar_t) * x_t + sqrt(alpha_bar_t) * v
            sqrt_alpha_bar_t = scheduler._get_tensor_values(
                t_tensor, img.shape, scheduler.sqrt_alphas_cumprod
            )
            sqrt_one_minus_alpha_bar_t = scheduler._get_tensor_values(
                t_tensor, img.shape, scheduler.sqrt_one_minus_alphas_cumprod
            )
            pred_epsilon = sqrt_one_minus_alpha_bar_t * img + sqrt_alpha_bar_t * pred_v

            # --- c. DDPM Sampling Step ---
            # mean = (1/sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon)
            # The coefficients are 0-dim tensors and broadcast over the batch.
            mean = sqrt_recip_alphas[t] * (img - epsilon_coeff[t] * pred_epsilon)

            if t > 0:
                # Add fresh noise scaled by the posterior standard deviation
                img = mean + posterior_std[t] * torch.randn_like(img)
            else:
                # At t=0, there is no noise
                img = mean

    # --- 4. Post-process and Return ---
    # Undo the normalization from [-1, 1] back to [0, 1]
    img = (img.clamp(-1, 1) + 1) / 2
    # Convert to [0, 255]. Round to the nearest level first: a plain cast to
    # uint8 truncates and would darken every pixel by up to one level.
    img = (img * 255).round().to(torch.uint8)

    return img

# --- 2. SCRIPT TO RUN THE INFERENCE ---

# --- Setup Model and Scheduler ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): presumably has to match the number of diffusion steps used
# during training -- verify against the training cell.
NUM_TIMESTEPS = 250
NUM_CLASSES = 2
CHECKPOINT_PATH = "unet_cifar10_epoch_100.pth" # Or "unet_cifar10_final.pth"

# 1. Load your trained model
# (This assumes the 'Unet' class is already defined in a previous cell)
model = Unet(
    image_channels=3,
    init_channels=64,
    dim_mults=(1, 2, 4),
    num_classes=NUM_CLASSES,
    time_emb_dim=256
).to(DEVICE)

# Load the checkpoint file from training.
# Only a state dict (tensors) is needed here, so restrict unpickling with
# weights_only=True instead of running the full pickle machinery on the file.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
)

# 2. Create the scheduler
# (This assumes the 'LinearNoiseScheduler' class is already defined)
scheduler = LinearNoiseScheduler(num_timesteps=NUM_TIMESTEPS)


# --- Generate "Cat" images (class_label=0) ---
print("--- Generating Cats ---")
# Sample a batch of 16 images conditioned on class 0 ("cat").
generated_cats = generate_images(
    model, scheduler, num_images=16, class_label=0, guidance_scale=5.0, device=DEVICE
)

# Rescale the [0, 255] result to [0, 1] floats and write it as a 4x4 grid.
save_image(generated_cats.float().div(255.0), "generated_cats_epoch_100.png", nrow=4)
print("Saved cat images to 'generated_cats_epoch_100.png'")


# --- Generate "Dog" images (class_label=1) ---
print("\n--- Generating Dogs ---")
# Sample a batch of 16 images conditioned on class 1 ("dog").
generated_dogs = generate_images(
    model, scheduler, num_images=16, class_label=1, guidance_scale=5.0, device=DEVICE
)

# Rescale the [0, 255] result to [0, 1] floats and write it as a 4x4 grid.
save_image(generated_dogs.float().div(255.0), "generated_dogs_epoch_100.png", nrow=4)
print("Saved dog images to 'generated_dogs_epoch_100.png'")

--- Generating Cats ---
Generating 16 images of class 'cat'...


Sampling: 250it [00:08, 28.91it/s]


Saved cat images to 'generated_cats_epoch_100.png'

--- Generating Dogs ---
Generating 16 images of class 'dog'...


Sampling: 250it [00:08, 28.67it/s]

Saved dog images to 'generated_dogs_epoch_100.png'



