In [None]:
# Mount Google Drive at /content/drive so that files kept on Drive (presumably
# the Data/ and Weights/ folders referenced below) are visible to this runtime.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# %cd "drive/My Drive/..."   <- change into the project folder on the mounted Drive (fill in the path)

In [None]:
'''
Notebook to train a conditional DDPM (denoising diffusion probabilistic model) for super-resolution (SR)
'''

# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
import torchvision
from PIL import Image
from modules_conditional import UNet, Diffusion
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

def save_images(generated, conditional, path, **kwargs):
    """Save a grid of (conditional, generated) image pairs to ``path``.

    Pairs are laid out two per row: each pair occupies two adjacent columns,
    the conditioning image on the left and the generated image on the right.
    The number of rows grows with the number of pairs (10 pairs give the
    same 5x4 grid / 20x20 inch figure as before; more pairs no longer raise
    IndexError, and unused cells are left blank).

    Args:
        generated: tensor of generated images. Each element is squeezed
            before display, so single-channel images are expected
            (assumed (N, 1, H, W) -- TODO confirm against Diffusion.sample).
        conditional: tensor of conditioning images, paired index-wise with
            ``generated``; as with ``zip``, extra elements are ignored.
        path: output file path passed to ``Figure.savefig``.
        **kwargs: extra keyword arguments forwarded to ``Axes.imshow``.
            ``cmap`` defaults to ``'gray'`` but can now be overridden.
    """
    # detach() so tensors that still require grad can also be converted.
    generated = generated.detach().to('cpu').numpy()
    conditional = conditional.detach().to('cpu').numpy()

    # Default colormap without clashing with a caller-supplied cmap.
    kwargs.setdefault('cmap', 'gray')

    pairs_per_row = 2
    n_pairs = min(len(generated), len(conditional))
    nrows = max(1, -(-n_pairs // pairs_per_row))  # ceil division, at least one row

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(nrows=nrows, ncols=2 * pairs_per_row,
                             figsize=(20, 4 * nrows), squeeze=False)
    try:
        # Adjust space between images
        fig.subplots_adjust(wspace=0.5, hspace=0.5)

        # Hide every axis up front so any unused grid cells stay blank.
        for ax in axes.flat:
            ax.axis('off')

        for i, (gen_img, cond_img) in enumerate(zip(generated, conditional)):
            row, slot = divmod(i, pairs_per_row)
            col = 2 * slot

            # Conditional images
            axes[row, col].imshow(cond_img.squeeze(), **kwargs)
            axes[row, col].set_title(f"Conditional {i+1}", fontsize=10, y=1.05)

            # Generated images
            axes[row, col + 1].imshow(gen_img.squeeze(), **kwargs)
            axes[row, col + 1].set_title(f"Generated {i+1}", fontsize=10, y=1.05)

        # Save the figure
        fig.savefig(path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

# Load training data: paired high-resolution (HR, the diffusion target) and
# low-resolution (LR, used as the conditioning input) lensing images.
# The '...' in the paths looks like a placeholder for the dataset folder.
# Both arrays are reshaped to (N, 1, 64, 64), so the LR images are assumed to
# already be stored at 64x64 (e.g. pre-upsampled) -- TODO confirm.
# High-Resolution lensing data
x_trainHR = np.load('./Data/.../train_HR.npy').astype(np.float32).reshape(-1, 1, 64, 64)
# Low-Resolution lensing data
x_trainLR = np.load('./Data/.../train_LR.npy').astype(np.float32).reshape(-1, 1, 64, 64)
# torch.from_numpy shares the float32 buffer with the array instead of making
# a second full copy of the dataset like the legacy torch.Tensor(...) did.
x_trainHR = torch.from_numpy(x_trainHR)
x_trainLR = torch.from_numpy(x_trainLR)
# Print data dimensions
print(x_trainHR.shape)
print(x_trainLR.shape)

# Pair each HR target with its LR conditioning image (matched by index) and
# batch them for training. shuffle=True so every epoch visits the samples in
# a fresh random order instead of the fixed on-disk order.
dataset = TensorDataset(x_trainHR, x_trainLR)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Prefer the GPU; fall back to CPU (much slower) rather than crashing with a
# CUDA error on a runtime without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Warning: CUDA is not available; training on CPU will be slow.")

# Conditional UNet (from modules_conditional) that predicts the added noise.
model = UNet().to(device)
# To resume from an earlier checkpoint, uncomment (map_location lets a
# checkpoint saved on GPU load on whichever device is in use):
# model.load_state_dict(torch.load('./Weights/Diff_ckpt_1.pt', map_location=device))
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=64, device=device)
l = len(dataloader)  # number of batches per epoch
epochs = 100

# torch.save does not create missing directories; make sure it exists.
os.makedirs("Weights", exist_ok=True)

# Epochs are numbered 1..epochs inclusive (range(1, epochs) ran only 99).
for epoch in range(1, epochs + 1):
    print(f"Starting epoch {epoch}:")
    # Ensure training mode each epoch (a sampling step, if re-enabled below,
    # may switch the model to eval mode).
    model.train()
    pbar = tqdm(dataloader)
    avg_mse = 0.0
    for images, conditions in pbar:
        images = images.to(device)
        conditions = conditions.to(device)
        # Diffusion comes from modules_conditional: presumably sample_timesteps
        # draws one timestep per image and noise_images returns the noised
        # images plus the noise that was added. The UNet, conditioned on the LR
        # images, is trained to predict that noise.
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t, conditions)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_mse = loss.item()
        pbar.set_postfix(MSE=batch_mse)
        avg_mse += batch_mse

    # Average over the real number of batches (previously divided by a
    # hard-coded 1000); max(..., 1) guards against an empty dataloader.
    avg_mse /= max(len(dataloader), 1)
    print(f'Average MSE: {avg_mse:.5f}\n')
    # sampled_images = diffusion.sample(model, n=images.shape[0], lat=conditions)
    # save_images(sampled_images, conditions, os.path.join("Results/Diff/Diff", f"{epoch}.png"))
    # Overwrite the same checkpoint every epoch.
    torch.save(model.state_dict(), os.path.join("Weights", "Diff_ckpt_1.pt"))