# Super Resolution Diffusion Model Training Code
## Group 6 Super Resolution Project

Written following the guide at:

https://huggingface.co/docs/diffusers/en/tutorials/basic_training

and with reference to

https://arxiv.org/pdf/2104.07636

https://arxiv.org/pdf/2006.11239

## Training Code

### Imports
Make sure to install our package beforehand.

In [None]:

import matplotlib.pyplot as plt
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.functional import interpolate
import py7zr as py7zr
import diffusers
import accelerate

from super_resolution.src.sen2venus_dataset import S2VSite, S2VSites, create_train_test_split
from super_resolution.src.visualization import plot_gallery


### Defining Training Configuration

All the training parameters are set here for convenience.

In [None]:
# Training Configuration
from dataclasses import dataclass, field


@dataclass
class TrainingConfig:
    """Hyper-parameters and paths shared by the training and testing cells.

    Every attribute is an annotated dataclass field, so individual defaults
    can be overridden per run, e.g. ``TrainingConfig(num_epochs=10)``, and
    ``repr(config)`` lists the values actually in use.
    """

    image_size: int = 256  # the generated image resolution
    train_batch_size: int = 1  # how many images to sample during training
    num_epochs: int = 1
    # default_factory gives every instance its own set instead of one set
    # shared (and mutable) at class level.
    train_sites: set = field(default_factory=lambda: {"SO2"})
    data_dir: str = "../Data"
    output_dir: str = '../models/Test xx'
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    # `no` for float32, `fp16` for automatic mixed precision.
    # The default is derived from the *default* device above; when overriding
    # `device` in the constructor, pass a matching `mixed_precision` as well.
    mixed_precision: str = 'fp16' if device == 'cuda' else 'no'
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0


config = TrainingConfig()


### Loading Data

In [None]:
def clamp_transform(x, y):
    """Keep the first three channels of each image and clip values to [0, 1].

    Both inputs are indexed as ``[:3, :, :]``, i.e. they are treated as
    3-D tensors with the channel axis first.

    Returns:
        The ``(x, y)`` pair after channel selection and clamping.
    """

    def _first_three_channels_unit_range(image):
        return image[:3, :, :].clamp(min=0, max=1)

    return _first_three_channels_unit_range(x), _first_three_channels_unit_range(y)
    
# Build the train/test datasets for the configured sites.
# NOTE(review): the split seed is hard-coded to -1 and `config.seed` is not
# used by any cell shown here — confirm what -1 means to
# create_train_test_split and whether config.seed was intended.
train_data, test_data = create_train_test_split(
    data_dir = config.data_dir,
    seed = -1,
    sites = config.train_sites,
    device = config.device,
)

# Apply the same channel selection / clamping to both splits.
train_data.set_transform(clamp_transform)
test_data.set_transform(clamp_transform)

# No `shuffle` argument is passed, so batches follow the dataset's own order.
train_dataloader = DataLoader(train_data, batch_size=config.train_batch_size)
# Test images are served one at a time (the sampling code below builds a
# single-image noise tensor).
test_dataloader = DataLoader(test_data, batch_size=1)

### Defining Transforms

In [None]:
import torchvision.transforms.v2 as transforms

def upscale(x):
    """Bicubically resize ``x`` to ``config.image_size`` x ``config.image_size``.

    ``config`` is read at call time, so later changes to ``config.image_size``
    are picked up.
    """
    target_size = (config.image_size, config.image_size)
    return interpolate(x, size=target_size, mode="bicubic")



### Defining the U-Net to be used

In [None]:
from diffusers import UNet2DModel


# Denoising U-Net conditioned on the upscaled low-resolution image: it takes
# the noisy sample and the conditioning image stacked along the channel axis
# and outputs a 3-channel prediction. Six resolution levels, with
# self-attention in the fifth down block and the second up block.
model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=6,  # 6 channels since the input is a concat of noise + upscaled low res
    out_channels=3,  # 3 RGB out channels
    layers_per_block=3,  # how many ResNet layers to use per UNet block
    # the number of output channels for each UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D"
      ),
)


### Noise Scheduler and Optimizer

In [None]:
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDPMPipeline

# Forward-diffusion schedule with 1000 training timesteps.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

optimizer = torch.optim.Adam(model.parameters(), lr = config.learning_rate)

# Cosine learning-rate schedule with warm-up. The total step count is
# (batches per epoch) x (epochs), matching the one lr_scheduler.step() call
# per batch made in train_loop.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs))


### Training Loop

Code adapted from a Hugging Face tutorial: https://huggingface.co/docs/diffusers/en/tutorials/basic_training

In [None]:
import torch.nn.functional as F
from accelerate import Accelerator
from tqdm.auto import tqdm
from pathlib import Path
import os

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler, upscaler, process=None):
    """Train the conditional denoising U-Net and save it as a ``DDPMPipeline``.

    For every ``(low, high)`` batch, the high-resolution image is noised at a
    random timestep, stacked channel-wise with the upscaled low-resolution
    image, and the model is trained to predict the added noise (MSE loss).
    The pipeline is saved to ``config.output_dir`` at the end of every epoch.

    Args:
        config: Object providing ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir`` and ``num_epochs``.
        model: The U-Net to train; called as ``model(x, timesteps)``.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer over ``model``'s parameters.
        train_dataloader: Iterable yielding ``(low, high)`` batches.
        lr_scheduler: Learning-rate scheduler, stepped once per batch.
        upscaler: Callable applied to the pre-processed low-resolution batch.
            Its output is concatenated with the noisy high-resolution batch
            along dim 1, so the two must agree in every other dimension.
        process: Optional callable applied to both ``low`` and ``high`` before
            anything else. When omitted, images are mapped from [0, 1] to
            [-1, 1], which is the inverse of the ``x / 2 + 0.5``
            de-normalisation performed in ``test_diffuse``.
    """
    if process is None:
        process = _scale_unit_to_symmetric

    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare the model and optimizer
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    # Only module-like upscalers have state to move; a plain function (such as
    # a bicubic-interpolation wrapper) has no `.to` and is used as-is.
    if hasattr(upscaler, "to"):
        upscaler.to(accelerator.device)
    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            low, high = batch
            low, high = process(low), process(high)  # pre processing (normalise)
            upscaled = upscaler(low)

            # Sample noise to add to the images
            noise = torch.randn(high.shape, device=high.device)
            bs = high.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=high.device,
                dtype=torch.int64
            )

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(high, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual from [noisy high-res | upscaled low-res]
                model_input = torch.concat([noisy_images, upscaled], dim=1)
                noise_pred = model(model_input, timesteps, return_dict=False)[0]

                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only on steps where gradients are actually synchronised
                # and applied, not on intermediate accumulation steps.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # Save the current model as a pipeline at the end of every epoch
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
            pipeline.save_pretrained(config.output_dir)


def _scale_unit_to_symmetric(images):
    """Map image values from [0, 1] to [-1, 1]."""
    return images * 2 - 1


### Launch Training

In [None]:
from accelerate import notebook_launcher

# Positional arguments forwarded to train_loop; the order must match its
# parameter list.
# NOTE(review): `process` is not defined in any cell shown above — make sure
# it exists before running this cell.
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler, upscale, process)

# Run train_loop in a single process from inside the notebook.
notebook_launcher(train_loop, args, num_processes=1)


### Resume Training

Set `pretrained_location` to the folder containing the pipeline checkpoints

In [None]:
# Folder of a previously saved pipeline; the resume cell below reads its
# `unet` and `scheduler` sub-folders.
pretrained_location = '../models/Test 11'

In [None]:
# Load the checkpoint FIRST: the optimizer must be built from the parameters
# of the model that is going to be trained. Creating it before `model` is
# rebound would leave it pointing at the previous model's parameters, so the
# loaded weights would never be updated.
model = UNet2DModel.from_pretrained(pretrained_location + "/unet")
model.train()
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_location + "/scheduler")
# noise scheduler doesn't need to be set to train

# A fresh optimizer and LR schedule are created here: optimizer state is not
# restored and the warm-up starts again from step 0.
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs))

In [None]:
from accelerate import notebook_launcher

# Positional arguments forwarded to train_loop; the order must match its
# parameter list. `model`, `noise_scheduler`, `optimizer` and `lr_scheduler`
# are the ones rebuilt in the resume cell above.
# NOTE(review): `process` is not defined in any cell shown above — make sure
# it exists before running this cell.
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler, upscale, process)

# Run train_loop in a single process from inside the notebook.
notebook_launcher(train_loop, args, num_processes=1)

## Testing Code

###  Load Model

Set `model_location` to the folder containing the saved pipeline files.

In [None]:
# Folder containing a saved pipeline.
model_location = '../models/Test 11'
# Load the U-Net and scheduler as a DDPMPipeline and move it to the configured device.
loadModel = DDPMPipeline.from_pretrained(model_location)
loadModel.to(config.device)

### Testing Code Loop
Generates a high-resolution sample from a low-resolution input.

In [None]:
@torch.no_grad()
def test_diffuse(pipeline, upscaler, low, generator=None, num_inference_steps=1000):
    """Generate one high-resolution image from a low-resolution input.

    Starts from Gaussian noise and runs the reverse diffusion process; at each
    step the current sample is concatenated channel-wise with the upscaled
    low-resolution image before being passed to the U-Net.

    Args:
        pipeline: Object providing ``unet``, ``scheduler``, ``device`` and
            ``progress_bar`` (e.g. a loaded ``DDPMPipeline``).
        upscaler: Callable that resizes ``low`` to the U-Net's sample size.
        low: Low-resolution input holding a single image; the initial noise
            is created with batch size 1.
        generator: Optional ``torch.Generator`` for reproducible sampling.
        num_inference_steps: Number of reverse-diffusion steps (default 1000).

    Returns:
        A CPU tensor with values in [0, 1], obtained from the final
        ``(C, H, W)`` sample via ``permute(2, 1, 0)``.

    NOTE(review): ``permute(2, 1, 0)`` yields ``(W, H, C)``, i.e. height and
    width are swapped relative to the usual ``(H, W, C)`` layout. It is kept
    because the gallery cell applies the same permutation to the other images.

    NOTE(review): ``low`` is upscaled as given; ``train_loop`` applies
    ``process`` to the low-resolution batch before upscaling — confirm whether
    the same pre-processing is needed here.
    """
    # Decorator is written with parentheses: the bare `@torch.no_grad` form is
    # not accepted by every PyTorch release.
    side = pipeline.unet.sample_size
    sample = torch.randn((1, 3, side, side), generator=generator)
    sample = sample.to(pipeline.device)

    # Conditioning image, fixed for the whole sampling run
    upscaled = upscaler(low).to(pipeline.device)

    # set step values
    pipeline.scheduler.set_timesteps(num_inference_steps)

    for t in pipeline.progress_bar(pipeline.scheduler.timesteps):
        # 1. predict noise model_output
        model_input = torch.concat([sample, upscaled], dim=1)
        model_output = pipeline.unet(model_input, t).sample

        # 2. compute previous image: x_t -> x_t-1
        sample = pipeline.scheduler.step(model_output, t, sample, generator=generator).prev_sample

    # Map from the model's [-1, 1] range back to [0, 1]
    sample = (sample / 2 + 0.5).clamp(0, 1)
    return sample.cpu().squeeze(0).permute(2, 1, 0)

### Generating Test Images
To generate a single test image, use this function

In [None]:
def diffuseSR(x):
    """Super-resolve ``x`` using the loaded pipeline and the bicubic upscaler."""
    return test_diffuse(loadModel, upscale, x)

#### Plotting

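This loop collects four images per test sample (low-res input, naive bicubic upscale, diffusion output, and high-res target) for the first `NUM_IMGS` samples, ready for plotting.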

In [None]:
NUM_IMGS = 10

test_imgs = []
titles = []
# zip() with range() stops after NUM_IMGS samples, so the rest of the test
# set is never loaded.
for i, img in zip(range(NUM_IMGS), test_dataloader):
    low, high = img
    upscaled = test_diffuse(loadModel, upscale, low)

    # Four panels per sample, in matching title/image order.
    titles += [
        f"Low Res {i}",
        f"Naive Upscale {i}",
        f"Upscaled {i}",
        f"High Res {i}",
    ]
    # The same permute(2, 1, 0) is applied inside test_diffuse, so all four
    # panels share one layout.
    test_imgs += [
        low.squeeze(0).permute(2, 1, 0),
        upscale(low).squeeze(0).permute(2, 1, 0),
        upscaled,
        high.squeeze(0).permute(2, 1, 0),
    ]
test_imgs = [t.cpu() for t in test_imgs]


In [None]:
# Show the collected images with their titles; each sample contributes four
# consecutive entries (low res, naive upscale, diffusion output, high res).
plot_gallery(test_imgs, titles, nrow = NUM_IMGS, xscale = 3, yscale = 3)