In [None]:
!pip install torch-fidelity

In [None]:
!pip install torchmetrics[image]

In [8]:
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.cm as cm

# For FID calculation
from torchmetrics.image.fid import FrechetInceptionDistance

In [9]:
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## 1. Custom Dataset with Data Augmentation

The dataset class loads .npy files representing gravitational lensing images.
I have added a data augmentation pipeline that applies random horizontal/vertical flips and rotations.

In [10]:
class AugmentedLensDataset(Dataset):
    def __init__(self, root_dir, transform=None, augmentations=None):
        self.root_dir = root_dir
        self.transform = transform
        self.augmentations = augmentations
        self.file_list = [f for f in os.listdir(root_dir) if f.endswith('.npy')]
    
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir, self.file_list[idx])
        data = np.load(file_path)
        data = torch.from_numpy(data).float()
        if self.transform:
            data = self.transform(data)
        if self.augmentations:
            data = self.augmentations(data)
        return data

## 2. Helper Functions: Time Embedding and Cosine Noise Schedule

I have used a sinusoidal embedding for time steps and a cosine noise schedule.

In [11]:
def get_timestep_embedding(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float().unsqueeze(1) * emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb

In [12]:
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * (torch.pi / 2))**2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas = torch.clamp(betas, 0, 0.999)
    return betas

## 3. U-Net Architecture with Residual Blocks

U-Net model with residual blocks and time conditioning has been implemented.

In [13]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.time_emb = nn.Linear(time_emb_dim, out_channels)
        self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
    
    def forward(self, x, t_emb):
        h = self.conv1(x)
        h = self.norm1(h)
        h = F.silu(h)
        t_proj = self.time_emb(t_emb)[:, :, None, None]
        h = h + t_proj
        h = self.conv2(h)
        h = self.norm2(h)
        return F.silu(h + self.res_conv(x))

In [14]:
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels, time_emb_dim)
        self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
    
    def forward(self, x, t_emb):
        x = self.res_block(x, t_emb)
        skip = x
        x_down = self.downsample(x)
        return skip, x_down

In [15]:
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels, time_emb_dim):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        self.skip_proj = nn.Conv2d(skip_channels, out_channels, kernel_size=1) if skip_channels != out_channels else nn.Identity()
        self.res_block = ResidualBlock(out_channels * 2, out_channels, time_emb_dim)
    
    def forward(self, x, skip, t_emb):
        x = self.upsample(x)
        skip = self.skip_proj(skip)
        x = torch.cat([x, skip], dim=1)
        x = self.res_block(x, t_emb)
        return x

In [16]:
class DiffLensUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, time_emb_dim=128):
        super().__init__()
        self.time_emb_dim = time_emb_dim
        
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        self.init_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        
        # Encoder
        self.down1 = DownBlock(64, 128, time_emb_dim)
        self.down2 = DownBlock(128, 256, time_emb_dim)
        self.down3 = DownBlock(256, 256, time_emb_dim)
        
        self.bottleneck = ResidualBlock(256, 256, time_emb_dim)
        
        # Decoder
        self.up1 = UpBlock(in_channels=256, out_channels=256, skip_channels=256, time_emb_dim=time_emb_dim)
        self.up2 = UpBlock(in_channels=256, out_channels=128, skip_channels=256, time_emb_dim=time_emb_dim)
        self.up3 = UpBlock(in_channels=128, out_channels=64,  skip_channels=128, time_emb_dim=time_emb_dim)
        
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x, t):
        t_emb = get_timestep_embedding(t, self.time_emb_dim)
        t_emb = self.time_mlp(t_emb)
        
        x1 = F.silu(self.init_conv(x))
        skip1, x2 = self.down1(x1, t_emb)
        skip2, x3 = self.down2(x2, t_emb)
        skip3, x4 = self.down3(x3, t_emb)
        
        x4 = self.bottleneck(x4, t_emb)
        
        x = self.up1(x4, skip3, t_emb)
        x = self.up2(x, skip2, t_emb)
        x = self.up3(x, skip1, t_emb)
        return self.out_conv(x)

## 4. Diffusion Process with Cosine Noise Schedule
Defined the diffusion process with an option for a cosine noise schedule.

In [17]:
class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=128, device=device, noise_schedule='cosine'):
        self.noise_steps = noise_steps
        self.img_size = img_size
        self.device = device
        
        if noise_schedule == 'linear':
            self.beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
        elif noise_schedule == 'cosine':
            self.beta = cosine_beta_schedule(noise_steps).to(device)
        else:
            raise ValueError("noise_schedule must be either 'linear' or 'cosine'")
        
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
    
    def add_noise(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
    
    def sample_timesteps(self, batch_size):
        return torch.randint(0, self.noise_steps, (batch_size,), device=self.device)
    
    def sample_images(self, model, n):
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 1, self.img_size, self.img_size), device=self.device)
            pbar = tqdm(reversed(range(1, self.noise_steps)), desc="Sampling", bar_format='{l_bar}{bar} {postfix}')
            for i in pbar:
                t = torch.full((n,), i, device=self.device, dtype=torch.long)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        return x
    
    def convert_to_rgb(self, images):
        rgb_images = []
        for image in images:
            image_np = image.squeeze().cpu().numpy()
            rgb = cm.viridis(image_np)[..., :3]
            rgb_tensor = torch.from_numpy(rgb.astype(np.float32)).permute(2, 0, 1)
            rgb_images.append(rgb_tensor)
        return torch.stack(rgb_images, dim=0).to(self.device)
    
    def calculate_fid(self, model, real_dataloader, num_samples=100):
        fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
        model.eval()
        fake_images = self.sample_images(model, num_samples)
        fake_images_rgb = self.convert_to_rgb(fake_images)
        
        real_images_list = []
        count = 0
        for batch in real_dataloader:
            real_images_list.append(batch.to(self.device))
            count += batch.size(0)
            if count >= num_samples:
                break
        real_images = torch.cat(real_images_list, dim=0)[:num_samples]
        real_images_rgb = self.convert_to_rgb(real_images)
        
        fid_metric.update(real_images_rgb, real=True)
        fid_metric.update(fake_images_rgb, real=False)
        fid_score = fid_metric.compute()
        model.train()
        return fid_score

## 5. Exponential Moving Average (EMA)

EMA keeps a running average of model parameters to stabilize training.

In [18]:
class EMA:
    def __init__(self, model, decay=0.999):
        self.ema_model = copy.deepcopy(model)
        self.decay = decay
        for param in self.ema_model.parameters():
            param.requires_grad = False
    
    def update(self, model):
        with torch.no_grad():
            msd = model.state_dict()
            for key, param in self.ema_model.state_dict().items():
                param.copy_(param * self.decay + msd[key] * (1. - self.decay))

## 6. Training Setup and Loop

Seting hyper-parameters, create dataloaders (with data augmentation), model, optimizer, EMA, and train.

In [19]:
lr = 3e-4
epochs = 100
img_size = 128
batch_size = 24
plot_freq = 25
data_dir = '/kaggle/input/deeplensetask4/Samples'

In [20]:
# Base transformation: ensure images have the right size
base_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
])

In [21]:
# Data augmentation: random horizontal/vertical flips and rotation (customize as needed)
augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
])

In [22]:
# Create dataset and dataloader with augmentation
dataset = AugmentedLensDataset(root_dir=data_dir, transform=base_transform, augmentations=augmentation_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [23]:
model = DiffLensUNet(in_channels=1, out_channels=1, time_emb_dim=128).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(dataloader), epochs=epochs)
mse_loss = nn.MSELoss()
diffusion = Diffusion(noise_steps=300, img_size=img_size, device=device, noise_schedule='cosine')

ema = EMA(model, decay=0.999)

## 7. Utility Functions for Saving and Plotting

In [24]:
def save_sample_images(images, path, nrow=6):
    grid = torchvision.utils.make_grid(images, nrow=nrow, normalize=True)
    ndarr = grid.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(6, 6))
    plt.axis('off')
    plt.imshow(ndarr)
    plt.savefig(path, bbox_inches='tight')
    plt.close()

In [25]:
def plot_sample_images(images, nrow=6):
    grid = torchvision.utils.make_grid(images, nrow=nrow, normalize=True)
    ndarr = grid.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(6, 6))
    plt.axis('off')
    plt.imshow(ndarr)
    plt.show()

In [26]:
os.makedirs("Results", exist_ok=True)
os.makedirs("Models", exist_ok=True)

## 8. Training Loop with EMA and Data Augmentation

The training loop updates EMA after each optimizer step.

In [None]:
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}", bar_format='{l_bar}{bar} {postfix}')
    for images in pbar:
        images = images.to(device)
        t = diffusion.sample_timesteps(images.shape[0])
        x_noisy, noise = diffusion.add_noise(images, t)
        pred_noise = model(x_noisy, t)
        loss = mse_loss(pred_noise, noise)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        ema.update(model)
        
        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{running_loss / (pbar.n or 1):.4f}"})
    
    if (epoch + 1) % plot_freq == 0:
        model.eval()
        with torch.no_grad():
            sampled_images = diffusion.sample_images(ema.ema_model, n=images.shape[0])
        save_sample_images(sampled_images, os.path.join("Results", f"epoch_{epoch+1}.png"))
        torch.save(model.state_dict(), os.path.join("Models", f"ckpt_epoch_{epoch+1}.pt"))
        fid_score = diffusion.calculate_fid(ema.ema_model, dataloader, num_samples=100)
        print(f"Epoch {epoch+1}: FID score: {fid_score.item():.2f}")

Epoch 1: 100%|██████████ , loss=0.0879
Epoch 2: 100%|██████████ , loss=0.0172
Epoch 3: 100%|██████████ , loss=0.0105
Epoch 4: 100%|██████████ , loss=0.0078
Epoch 5: 100%|██████████ , loss=0.0071
Epoch 6: 100%|██████████ , loss=0.0055
Epoch 7: 100%|██████████ , loss=0.0049
Epoch 8: 100%|██████████ , loss=0.0049
Epoch 9: 100%|██████████ , loss=0.0047
Epoch 10: 100%|██████████ , loss=0.0041
Epoch 11: 100%|██████████ , loss=0.0038
Epoch 12: 100%|██████████ , loss=0.0039
Epoch 13: 100%|██████████ , loss=0.0035
Epoch 14: 100%|██████████ , loss=0.0038
Epoch 15: 100%|██████████ , loss=0.0032
Epoch 16: 100%|██████████ , loss=0.0033
Epoch 17: 100%|██████████ , loss=0.0033
Epoch 18: 100%|██████████ , loss=0.0031
Epoch 19: 100%|██████████ , loss=0.0032
Epoch 20: 100%|██████████ , loss=0.0032
Epoch 21: 100%|██████████ , loss=0.0029
Epoch 22: 100%|██████████ , loss=0.0029
Epoch 23: 100%|██████████ , loss=0.0031
Epoch 24: 100%|██████████ , loss=0.0029
Epoch 25: 100%|██████████ , loss=0.0028
Sampling:

Epoch 25: FID score: 270.77


Epoch 26: 100%|██████████ , loss=0.0028
Epoch 27: 100%|██████████ , loss=0.0027
Epoch 28: 100%|██████████ , loss=0.0027
Epoch 29: 100%|██████████ , loss=0.0027
Epoch 30: 100%|██████████ , loss=0.0026
Epoch 31: 100%|██████████ , loss=0.0026
Epoch 32: 100%|██████████ , loss=0.0027
Epoch 33: 100%|██████████ , loss=0.0025
Epoch 34: 100%|██████████ , loss=0.0025
Epoch 35: 100%|██████████ , loss=0.0026
Epoch 36: 100%|██████████ , loss=0.0027
Epoch 37: 100%|██████████ , loss=0.0025
Epoch 38: 100%|██████████ , loss=0.0025
Epoch 39: 100%|██████████ , loss=0.0026
Epoch 40: 100%|██████████ , loss=0.0023
Epoch 41: 100%|██████████ , loss=0.0024
Epoch 42: 100%|██████████ , loss=0.0024
Epoch 43: 100%|██████████ , loss=0.0022
Epoch 44: 100%|██████████ , loss=0.0025
Epoch 45: 100%|██████████ , loss=0.0025
Epoch 46: 100%|██████████ , loss=0.0023
Epoch 47: 100%|██████████ , loss=0.0023
Epoch 48: 100%|██████████ , loss=0.0023
Epoch 49: 100%|██████████ , loss=0.0025
Epoch 50: 100%|██████████ , loss=0.0023


Epoch 50: FID score: 30.65


Epoch 51: 100%|██████████ , loss=0.0022
Epoch 52: 100%|██████████ , loss=0.0024
Epoch 53: 100%|██████████ , loss=0.0022
Epoch 54: 100%|██████████ , loss=0.0023
Epoch 55: 100%|██████████ , loss=0.0021
Epoch 56: 100%|██████████ , loss=0.0021
Epoch 57: 100%|██████████ , loss=0.0023
Epoch 58: 100%|██████████ , loss=0.0022
Epoch 59: 100%|██████████ , loss=0.0022
Epoch 60: 100%|██████████ , loss=0.0023
Epoch 61: 100%|██████████ , loss=0.0021
Epoch 62: 100%|██████████ , loss=0.0023
Epoch 63: 100%|██████████ , loss=0.0023
Epoch 64: 100%|██████████ , loss=0.0022
Epoch 65: 100%|██████████ , loss=0.0023
Epoch 66: 100%|██████████ , loss=0.0022
Epoch 67: 100%|██████████ , loss=0.0024
Epoch 68: 100%|██████████ , loss=0.0022
Epoch 69: 100%|██████████ , loss=0.0021
Epoch 70: 100%|██████████ , loss=0.0024
Epoch 71: 100%|██████████ , loss=0.0021
Epoch 72: 100%|██████████ , loss=0.0021
Epoch 73: 100%|██████████ , loss=0.0023
Epoch 74: 100%|██████████ , loss=0.0021
Epoch 75: 100%|██████████ , loss=0.0021


Epoch 75: FID score: 27.96


Epoch 76: 100%|██████████ , loss=0.0021
Epoch 77: 100%|██████████ , loss=0.0021
Epoch 78: 100%|██████████ , loss=0.0023
Epoch 79: 100%|██████████ , loss=0.0021
Epoch 80: 100%|██████████ , loss=0.0021
Epoch 81: 100%|██████████ , loss=0.0020
Epoch 82: 100%|██████████ , loss=0.0021
Epoch 83: 100%|██████████ , loss=0.0021
Epoch 84: 100%|██████████ , loss=0.0019
Epoch 85: 100%|██████████ , loss=0.0020
Epoch 86: 100%|██████████ , loss=0.0020
Epoch 87: 100%|██████████ , loss=0.0020
Epoch 88: 100%|██████████ , loss=0.0020
Epoch 89: 100%|██████████ , loss=0.0022
Epoch 90: 100%|██████████ , loss=0.0019
Epoch 91: 100%|██████████ , loss=0.0020
Epoch 92: 100%|██████████ , loss=0.0020
Epoch 93: 100%|██████████ , loss=0.0021
Epoch 94: 100%|██████████ , loss=0.0020
Epoch 95: 100%|██████████ , loss=0.0020
Epoch 96: 100%|██████████ , loss=0.0021
Epoch 97: 100%|██████████ , loss=0.0020
Epoch 98: 100%|██████████ , loss=0.0021
Epoch 99: 100%|██████████ , loss=0.0022
Epoch 100: 100%|██████████ , loss=0.0018

## 9. Final Evaluation

We evaluate using the EMA model and report the final FID score

In [29]:
final_fid = diffusion.calculate_fid(ema.ema_model, dataloader, num_samples=100)
print("Final FID Score (EMA Model with Augmentation):", final_fid.item())

Sampling: |           


Final FID Score (EMA Model with Augmentation): 26.49726104736328


# Conclusion
- **Model Performance:**  
  The diffusion model enhanced with EMA and data augmentation achieved a final FID score of approximately 26.5 with 300 noise steps. This indicates that the model is generating high-quality, realistic strong gravitational lensing images.

- **Impact of Techniques:**  
  - **Exponential Moving Average (EMA):**  
    EMA helped stabilize training and produced smoother generated samples.  
  - **Data Augmentation:**  
    Augmentation increased the diversity and robustness of the training data, contributing to improved performance.  
  - **Noise Schedule & Steps:**  
    Using a cosine noise schedule and reducing noise steps to 300 provided a balance between computational efficiency and sample quality.

# Discussion
- **Quality vs. Efficiency Trade-off:**  
  Reducing the number of noise steps speeds up sampling but may compromise the gradual denoising quality. The experiments suggest that 300 steps are sufficient for competitive results, whereas further increasing the number of steps might yield diminishing returns relative to the extra computational cost.

- **Computational Considerations:**  
  Diffusion models require many iterative sampling steps, resulting in high-quality outputs at the cost of longer inference times. Accelerated sampling techniques (e.g., DDIM) could be explored to improve efficiency without sacrificing quality.

# Future Work
- **Accelerated Sampling:**  
  Investigate methods such as DDIM to reduce inference time while maintaining or enhancing sample quality.

- **Hyperparameter Optimization:**  
  Conduct a systematic search over hyperparameters (e.g., noise steps, learning rate, model depth) to further optimize performance and potentially lower the FID score.

- **Architecture Enhancements:**  
  Consider incorporating additional layers, such as self-attention mechanisms, or more advanced conditioning techniques to capture finer details in the generated images.