In [3]:
import os
import yaml
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torchvision.utils import save_image

from data_loaders_l import SynthesisDataModule
from model_architectures import DDPM

# Load configuration.
# `yaml.safe_load` returns None for an empty document; fall back to an empty
# dict so the `.get(..., default)` lookups below still work.
with open('config_l.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f) or {}

# Extract parameters from config (each one falls back to a default if absent)
batch_size = config.get('batch_size', 1)  # use a small batch size for inference
latent_dim = config.get('latent_dim', 100)
label_dim = config.get('label_dim', 4)
experiment_name = config.get('experiment_name', 'default_experiment')
model_dir = config.get('model_dir', 'saved_models')

# Checkpoint path to load the trained model from.
checkpoint_path = os.path.join(model_dir, 'ddpm-epoch=30-train_loss=0.1432.ckpt')

# Preprocessing pipeline (intended to be the same transform as used during
# training -- keep the two in sync).
transform = transforms.Compose([
    transforms.ToTensor(),
    # Per-image min-max rescale to [0, 1]; the epsilon avoids division by
    # zero for constant images.
    transforms.Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8)),
    # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.to(torch.float32))
])


In [4]:
# Set up data module for inference
data_module = SynthesisDataModule(batch_size=batch_size, transform=transform)
data_module.setup()

# Load the trained model from checkpoint. The extra keyword arguments are
# passed through to `load_from_checkpoint` together with the checkpoint path.
model = DDPM.load_from_checkpoint(
    checkpoint_path,
    label_dim=label_dim,
    learning_rate=0  # learning_rate is not used during inference
)

# Move model to the right device, then switch to evaluation mode (once).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # Set model to evaluation mode

# Create a directory to save inference results
inference_dir = os.path.join('inference_results', experiment_name)
os.makedirs(inference_dir, exist_ok=True)

# Choose dataloader: prefer the test split, otherwise use the validation split.
# NOTE(review): if SynthesisDataModule derives from LightningDataModule, the
# `test_dataloader` attribute presumably always exists, which would make the
# fallback branch unreachable -- confirm against the data module definition.
if hasattr(data_module, 'test_dataloader'):
    dataloader = data_module.test_dataloader()
else:
    dataloader = data_module.val_dataloader()

In [5]:
import matplotlib.pyplot as plt
import torch

# Sampling settings
num_images_to_generate = 10  # number of images to synthesize
num_timesteps = model.timesteps  # number of denoising steps, read from the model

# Device that holds the model's parameters
device = next(model.parameters()).device

# Function to generate images using DDPM
def sample_ddpm(model, num_samples, num_timesteps, label=None, image_shape=(1, 256, 256)):
    """Generate images with DDPM ancestral sampling (Ho et al., 2020, Algorithm 2).

    Starting from Gaussian noise, the loop walks ``t = num_timesteps-1 .. 0``
    and applies

        x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps) / sqrt(alpha_t)
                  + sqrt(beta_t) * z          (z = 0 at t == 0)

    where ``eps = model(x_t, t, label)`` is the predicted noise.

    Args:
        model: Callable as ``model(x, timesteps, labels)`` returning the
            predicted noise, and exposing ``alpha_hat`` indexable by timestep.
            ``alpha_hat`` is assumed to be the cumulative product of the
            per-step alphas -- TODO confirm against the DDPM definition.
        num_samples: Number of images to generate (batch size).
        num_timesteps: Number of reverse steps; ``alpha_hat`` is indexed with
            ``0 .. num_timesteps-1``.
        label: Optional integer class label applied to every sample. When
            ``None``, ``None`` is passed to the model as the label argument.
        image_shape: Per-image shape ``(channels, height, width)``. Defaults
            to grayscale 256x256.

    Returns:
        Tensor of shape ``(num_samples, *image_shape)`` clamped to ``[-1, 1]``,
        detached from autograd.
    """
    # Sample on the device the model actually lives on; only guess from CUDA
    # availability when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if label is not None:
        label_tensor = torch.full((num_samples,), label, dtype=torch.long, device=device)
    else:
        label_tensor = None

    # Move the schedule to the sampling device once instead of per step.
    alpha_hat = torch.as_tensor(model.alpha_hat).to(device)

    # Sampling needs no gradients; without no_grad the autograd graph would
    # grow across every timestep and the result would require grad.
    with torch.no_grad():
        noisy_img = torch.randn((num_samples, *image_shape), device=device)

        for t in reversed(range(num_timesteps)):
            timestep_tensor = torch.full((num_samples,), t, dtype=torch.long, device=device)
            predicted_noise = model(noisy_img, timestep_tensor, label_tensor)

            # Recover the per-step alpha_t / beta_t from the cumulative
            # schedule: alpha_t = alpha_hat_t / alpha_hat_{t-1}, alpha_hat_{-1} = 1.
            alpha_hat_t = alpha_hat[t]
            alpha_hat_prev = alpha_hat[t - 1] if t > 0 else torch.ones_like(alpha_hat_t)
            alpha_t = alpha_hat_t / alpha_hat_prev
            beta_t = 1 - alpha_t

            # Posterior mean of x_{t-1} given x_t and the predicted noise.
            noisy_img = (
                noisy_img - beta_t / torch.sqrt(1 - alpha_hat_t) * predicted_noise
            ) / torch.sqrt(alpha_t)

            # Add noise with variance beta_t on every step except the last.
            if t > 0:
                noisy_img = noisy_img + torch.sqrt(beta_t) * torch.randn_like(noisy_img)

    return torch.clamp(noisy_img, -1, 1)  # Ensure final values are in range




In [None]:
# Ensure inline plotting works in JupyterLab
%matplotlib inline  

# Generate images
generated_images = sample_ddpm(model, num_images_to_generate, num_timesteps)

# Convert generated images from [-1,1] to [0,1] for visualization
generated_images = (generated_images + 1) / 2
print("Generated Images Tensor Shape:", generated_images.shape)
print("Min:", generated_images.min().item(), "Max:", generated_images.max().item())

# Create figure
fig, axes = plt.subplots(1, num_images_to_generate, figsize=(15, 3))

for i in range(num_images_to_generate):
    img = generated_images[i].cpu().squeeze().numpy()
    axes[i].imshow(img, cmap='gray')
    axes[i].axis('off')

plt.show()
print("Generation complete!")