# In [None]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision


# --- 1. DEFINE THE SAMPLING FUNCTION FOR A SINGLE TIMESTEP ---

@torch.no_grad()  # Sampling only: gradients are not needed, which saves memory.
def sample_timestep(x, t, model):
    """
    Denoise a batch by one reverse-diffusion step (Algorithm 2 of the DDPM paper).

    The model's noise prediction ``eps = model(x, t)`` is turned into the mean

        sqrt_recip_alphas_t * (x - betas_t * eps / sqrt_one_minus_alphas_cumprod_t)

    and, for every sample whose timestep is not 0, Gaussian noise scaled by
    ``sqrt(posterior_variance_t)`` is added. Samples at ``t == 0`` get the
    mean only, since the paper adds no noise on the final step.

    Args:
        x: Batch of noisy images at timestep ``t``.
        t: Integer tensor with one timestep per sample in ``x``
            (the caller in this file builds it with ``torch.full((batch,), i)``).
        model: Noise-prediction network, called as ``model(x, t)``
            (assumed to return a tensor shaped like ``x`` — TODO confirm).

    Returns:
        The batch after one denoising step.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Algorithm 2, step 4: remove the scaled noise prediction from x.
    predicted_noise = model(x, t)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # Decide per sample whether to add noise. A batch-level test such as
    # `t.all() == 0` is True as soon as ANY sample has t == 0, which would
    # wrongly drop the noise for every other sample in a mixed batch.
    is_noisy = t != 0
    if not is_noisy.any():
        # Every sample is on the final step: return the mean, no noise.
        return model_mean

    # Broadcast the (batch,) mask over the remaining dimensions of x.
    noise_mask = is_noisy.to(device=x.device, dtype=x.dtype).reshape(
        -1, *([1] * (x.dim() - 1))
    )
    noise = torch.randn_like(x)
    return model_mean + noise_mask * torch.sqrt(posterior_variance_t) * noise

# --- 2. DEFINE THE FULL SAMPLING LOOP ---

@torch.no_grad()
def sample_plot_images(model, num_images=16, *, channels=3, img_size=None):
    """
    Generate a batch of images with the full reverse-diffusion chain and show them.

    Starts from standard Gaussian noise and applies ``sample_timestep`` for
    timesteps ``TIMESTEPS - 1`` down to ``0``. The result is passed to
    ``show_images``.

    Args:
        model: Trained noise-prediction network. The caller is responsible
            for putting it in evaluation mode.
        num_images: Number of images generated in one batch.
        channels: Number of image channels (keyword-only). Defaults to 3,
            the value previously hard-coded here.
        img_size: Height and width of the square images (keyword-only).
            ``None`` uses the module-level ``IMG_SIZE``.
    """
    if img_size is None:
        img_size = IMG_SIZE

    print("Generating new images...")
    # Start from pure Gaussian noise: the "blank canvas" the chain denoises.
    img = torch.randn((num_images, channels, img_size, img_size), device=DEVICE)

    # Walk the chain backwards, T-1 -> 0, with a progress bar.
    for i in tqdm(reversed(range(TIMESTEPS)), desc='Sampling loop', total=TIMESTEPS):
        # The same timestep for every image in the batch.
        t = torch.full((num_images,), i, device=DEVICE, dtype=torch.long)
        img = sample_timestep(img, t, model)

    show_images(img, "Generated Images")


# --- 3. LOAD YOUR TRAINED MODEL AND RUN SAMPLING ---

# Create a new instance of the model (make sure architecture is the same)
model_path = "ddpm_cifar10_10_epochs.pth"
model = SimpleUnet().to(DEVICE)

# Restore the trained parameters. map_location places the loaded tensors on
# DEVICE, whatever device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# On torch >= 1.13, passing weights_only=True limits loading to tensor data.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))

# Switch to inference behaviour, then draw samples from the trained model.
model.eval()
sample_plot_images(model)