# Super Resolution

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import math
import copy
import cv2
import os
import json
import time

from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from model import UNet, ExponentialMovingAverage
from dsr import DSRDataset
from pathlib import Path

In [None]:
# Dataset root; the split file lists which samples belong to each partition.
root = Path('./dsr/')

with open(root / 'train_valid_test_split.json', 'r', encoding='utf-8') as f:
    split = json.load(f)

train_dataset = DSRDataset(root, split['train'])
# NOTE(review): validation is built from the 'test' partition although the
# split file's name suggests a separate validation partition. Confirm the keys
# in train_valid_test_split.json; if validation loss drives LR scheduling or
# checkpoint selection, using the test partition here leaks it into training.
val_dataset = DSRDataset(root, split['test'])

# Loader settings shared by both partitions.
loader_kwargs = {'batch_size': 16, 'num_workers': 2, 'prefetch_factor': 2}

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Validation is not shuffled so evaluation, and the sample taken from its first
# batch later in the notebook, are reproducible from run to run.
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Fetch one training batch, time it, and print basic statistics. The first call
# also pays for starting the DataLoader worker processes.
start_time = time.perf_counter()  # monotonic clock, meant for measuring intervals

images = next(iter(train_loader))
lowres_images = images[0].numpy()
highres_images = images[1].numpy()

elapsed_time = time.perf_counter() - start_time
print(f"Time taken to get next batch of images: {elapsed_time} seconds")

print(lowres_images.shape)
print(highres_images.shape)

print('mean:', highres_images.mean())
print('variance:', highres_images.var())
print('min:', highres_images.min())
print('max:', highres_images.max())

# Show up to 4 (condition, target) pairs side by side. Pixels are assumed to be
# in [-1, 1] (confirm with the min/max printed above). They are mapped to
# [0, 1] and clipped so imshow does not warn about out-of-range values.
n_show = min(4, len(lowres_images))
fig, axes = plt.subplots(figsize=(10, 5 * n_show), nrows=n_show, ncols=2, squeeze=False)
for row, lowres, highres in zip(axes, lowres_images, highres_images):
    # Arrays are channel-first; imshow expects channel-last, hence transpose.
    row[0].imshow(np.clip((lowres + 1) / 2, 0, 1).transpose(1, 2, 0))
    row[0].set_title('Upscaled Image')

    row[1].imshow(np.clip((highres + 1) / 2, 0, 1).transpose(1, 2, 0))
    row[1].set_title('Original Image')

plt.show()

## U-Net

![U-Net architecture diagram](unet.png)

In [None]:
# Model, EMA, optimisation and diffusion-schedule setup.
# NOTE(review): this import rebinds `UNet`, shadowing `from model import UNet`
# in the first cell; `ExponentialMovingAverage` still comes from `model`.
from simple_diffusion.model import UNet
from simple_diffusion.scheduler import DDIMScheduler

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 6 input channels: presumably the noisy image stacked with the 3-channel
# conditioning image (3 + 3). TODO confirm against simple_diffusion's UNet.
unet = UNet(in_channels=6).to(device)

# The EMA helper wraps a deep copy of the network with gradients disabled.
ema = ExponentialMovingAverage(copy.deepcopy(unet).requires_grad_(False))

optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=1e-3,
    betas=(0.9, 0.99),
    weight_decay=0.0,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
criterion = nn.MSELoss()

epochs = 50
# Only referenced by the alternative scheduler mentioned below.
T = 2000
# Alternative kept for reference: DiffusionScheduler(T, schedule_type='linear')
diffusion_scheduler = DDIMScheduler(beta_schedule="cosine")
tensorboard_path = "./runs/diffusion-2M-test"

# Count trainable parameters and report which device the model lives on.
num_params = sum(param.numel() for param in unet.parameters() if param.requires_grad)
print(f"Number of parameters: {num_params} {device}")

In [None]:
from train import train

# Run the training loop from train.py. The arguments are positional, so their
# order must match that function's signature.
train(
    unet, ema, diffusion_scheduler,
    train_loader, val_loader,
    epochs, device,
    optimizer, scheduler, criterion,
    tensorboard_path,
)

In [None]:
# Persist the trained U-Net weights (the raw model, not the EMA copy).
# NOTE(review): `ema` is passed to train() but not saved here; if sampling
# should use the EMA weights, save those as well.
checkpoint_path = 'dsr_sr_cos2.pth'
torch.save(unet.state_dict(), checkpoint_path)

# To skip training and reuse an earlier checkpoint instead, e.g.:
# unet.load_state_dict(torch.load('dsr_sr_cos.pth', map_location=device))

In [None]:
# Super-resolve one validation image and compare the condition, the generated
# sample and the ground-truth target.
X, y = next(iter(val_loader))
X = X[4].unsqueeze(0).to(device)  # keep a batch dimension of 1 for the model
y = y[4].to(device)

# Sample in eval mode without autograd bookkeeping, then restore the previous
# mode so a later training run is unaffected.
was_training = unet.training
unet.eval()
try:
    with torch.no_grad():
        samples = diffusion_scheduler.generate(unet, X)
finally:
    unet.train(was_training)

# Rescale condition/target for display. `X` is rebound on purpose: the
# animation cell below concatenates this rescaled X[0] with every frame.
X = (X + 1) / 2
y = (y + 1) / 2

# Clip to [0, 1] the same way the animation cell does, so imshow does not warn.
final_sample = samples[-1][0].detach().cpu().clamp(0, 1)

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(X[0].permute(1, 2, 0).cpu().numpy())
ax[0].set_title("Condition image")

ax[1].imshow(final_sample.permute(1, 2, 0).numpy())
ax[1].set_title("Sample image")

ax[2].imshow(y.permute(1, 2, 0).cpu().numpy())
ax[2].set_title("Target image")

plt.show()

In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

# Animate the sampling trajectory: the condition image on the left and the
# current sample on the right. The final sample is held for extra frames.
HOLD_FRAMES = 60

# The condition image is the same in every frame, so convert it once.
# This relies on `X` having been rescaled to [0, 1] by the sampling cell.
condition = X[0].cpu().numpy()


def frame_image(i):
    """Return the channel-last image for frame *i*, holding on the last sample."""
    step = min(i, len(samples) - 1)
    sample = np.clip(samples[step][0].detach().cpu().numpy(), 0, 1)
    # Place condition and sample side by side along the width axis.
    return np.transpose(np.concatenate((condition, sample), axis=2), (1, 2, 0))


fig, ax = plt.subplots(figsize=(10, 5))
ax.axis('off')
# Draw once and swap the pixel data per frame. Calling imshow in every update
# would stack a new image artist on the axes for each frame.
frame = ax.imshow(frame_image(0))


def update(i):
    frame.set_data(frame_image(i))
    return [frame]


ani = animation.FuncAnimation(fig, update, frames=len(samples) + HOLD_FRAMES, interval=60)
plt.close(fig)  # show only the animation, not the static figure as well
HTML(ani.to_jshtml())

In [None]:
# Export the animation as a GIF. Pillow is installed with matplotlib, so no
# external ImageMagick binary is needed.
# NOTE: fps=30 plays faster than the 60 ms/frame (~16.7 fps) HTML preview.
ani.save('denoising_sr.gif', writer='pillow', fps=30)