In [65]:
from typing import (
    Tuple,
    Optional,
    Literal
)

import torch
import torch.nn.functional as F
torch.manual_seed(0)

from schedule import linear_beta_schedule
from unet import Unet

# Number of steps in the forward (noising) process.
timesteps = 300

# Per-step beta values produced by the project's linear schedule helper.
betas = linear_beta_schedule(timesteps=timesteps)

# alpha_t = 1 - beta_t, its running product over steps, and that product
# shifted right by one step with 1.0 inserted at the front.
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], pad=(1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# Coefficients read by q_sample: sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t).
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

# Variance term built from beta_t, alpha_bar_{t-1} and alpha_bar_t
# (the posterior q(x_{t-1} | x_t, x_0) in the usual DDPM notation).
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def extract(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple):
    """Pick ``a[t[i]]`` for every batch element and shape it for broadcasting.

    Args:
        a: 1-D tensor of per-step values to index into.
        t: Integer index tensor with one entry per batch element.
        x_shape: Shape of the tensor the result will be multiplied with;
            only its length is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
        singleton dimensions, placed on the device of ``t``.
    """
    # The indices are moved to the CPU before the lookup and the result is
    # sent back to the device of ``t`` afterwards.
    gathered = a.gather(-1, t.cpu())
    singleton_dims = [1] * (len(x_shape) - 1)
    broadcast_shape = (t.shape[0], *singleton_dims)
    return gathered.reshape(broadcast_shape).to(t.device)

def q_sample(x_start: torch.Tensor, t: torch.Tensor, noise=None):
    """Produce a noised version of ``x_start`` at the steps given by ``t``.

    Args:
        x_start: Clean input batch.
        t: Integer step index for each batch element.
        noise: Noise to mix in; standard normal noise shaped like ``x_start``
            is drawn when omitted.

    Returns:
        ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``, with
        the per-step coefficients looked up from the module-level schedule
        tensors and broadcast over the non-batch dimensions.
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_scale * x_start + noise_scale * eps

def p_losses(
        denoise_model: Unet,
        x_start,
        t,
        noise=None,
        loss_type: Literal["l1", "l2", "huber"] = "l1",
        self_condition: Optional[torch.Tensor] = None
    ):
    """Noise-prediction loss for one batch.

    ``x_start`` is noised with ``q_sample`` at steps ``t``; the model is asked
    to predict the noise that was added and the prediction is compared with
    the true noise.

    Args:
        denoise_model: Model called with ``x`` and ``time`` (plus
            ``x_self_cond`` when ``denoise_model.self_condition`` is truthy).
        x_start: Clean input batch.
        t: Integer step index for each batch element.
        noise: Noise to add; drawn from a standard normal when omitted.
        loss_type: ``"l1"``, ``"l2"`` (MSE) or ``"huber"`` (smooth L1).
        self_condition: Conditioning tensor; required when the model has
            ``self_condition`` enabled.

    Returns:
        Scalar loss tensor.

    Raises:
        NotImplementedError: If ``loss_type`` is not one of the supported names.
        RuntimeError: If the model expects self-conditioning and none is given.
    """
    loss_functions = {
        "l1": F.l1_loss,
        "l2": F.mse_loss,
        "huber": F.smooth_l1_loss,
    }
    # Reject an unknown loss name before spending a forward pass on it.
    if loss_type not in loss_functions:
        raise NotImplementedError(
            f"Unsupported loss_type {loss_type!r}; expected one of {sorted(loss_functions)}."
        )

    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)

    model_kwargs = {}
    if denoise_model.self_condition:
        if self_condition is None:
            raise RuntimeError("The self-conditioning is not provided. ")
        model_kwargs["x_self_cond"] = self_condition

    # Call the module itself rather than ``.forward`` so that registered
    # hooks are executed as well.
    predicted_noise = denoise_model(x=x_noisy, time=t, **model_kwargs)

    return loss_functions[loss_type](noise, predicted_noise)

In [66]:
from torch.utils.data import Dataset
from torchvision.transforms import (
    Lambda,
    Compose,
    Resize
)
from torchvision.transforms import functional as TV_F

from torch.utils.data import DataLoader

import numpy as np
import random
import os

class WeatherFieldsDataset(Dataset):
    """Paired low-/high-resolution fields stored as ``.npy`` files.

    Files in ``train_2017_lr`` and ``train_2017_hr`` are matched through an
    integer key parsed from their names; only keys present in both folders
    become samples. Each item is a ``(lr_image, hr_image)`` tuple.
    """

    def __init__(self, root_dir, path_to_folder, transform=None):
        """
        Arguments:
            root_dir (string): Base directory that contains ``path_to_folder``.
            path_to_folder (string): Path, relative to ``root_dir``, of the
                folder holding the ``train_2017_lr`` and ``train_2017_hr``
                sub-folders.
            transform (callable, optional): Stored as ``self.transform``.
                ``__getitem__`` does not apply it; the fixed ``lr_transform``
                and ``hr_transform`` pipelines are used instead.
        """
        lr_data_folder = os.path.join(root_dir, path_to_folder, "train_2017_lr")
        hr_data_folder = os.path.join(root_dir, path_to_folder, "train_2017_hr")

        date_idx_to_hr_file_names = {}
        for hr_file_name in os.listdir(hr_data_folder):
            index = self._parse_date_index(hr_file_name, "_high_res_")
            if index is not None:
                date_idx_to_hr_file_names[index] = hr_file_name

        date_idx_to_file_pathes = {}
        for lr_file_name in os.listdir(lr_data_folder):
            index = self._parse_date_index(lr_file_name, "_low_res_")
            if index is None:
                continue
            hr_file_name = date_idx_to_hr_file_names.get(index)
            # Low-res files without a high-res counterpart are dropped.
            if hr_file_name is not None:
                date_idx_to_file_pathes[index] = (
                    os.path.join(lr_data_folder, lr_file_name),
                    os.path.join(hr_data_folder, hr_file_name),
                )

        self.transform = transform
        self.date_idx_to_file_pathes = date_idx_to_file_pathes
        # Ordering is numeric on the concatenated "date digits + number" key.
        # NOTE(review): that is not chronological if the trailing number varies
        # in digit count between files — confirm against the file naming.
        self.sorted_date_idx = sorted(date_idx_to_file_pathes)

        # Both pipelines divide by 255, map the result through t*2 - 1 and move
        # the last axis to the front; the low-res one additionally resizes to
        # 128x128.
        # NOTE(review): presumably the arrays are channel-last with values in
        # [0, 255] — confirm against the data.
        self.lr_transform = Compose([
            Lambda(lambda t: (t / 255.)),
            Lambda(lambda t: (t*2) - 1),
            Lambda(lambda t: t.permute(2, 0, 1)),
            Resize(size=(128, 128), antialias=True),
        ])

        self.hr_transform = Compose([
            Lambda(lambda t: (t / 255.)),
            Lambda(lambda t: (t*2) - 1),
            Lambda(lambda t: t.permute(2, 0, 1))
        ])

    @staticmethod
    def _parse_date_index(file_name, resolution_marker):
        """Return the integer key encoded in ``file_name``.

        The name is expected to look like
        ``<prefix><resolution_marker><date-with-dashes>_<number>.npy``. Entries
        that do not end in ``.npy`` (stray files, notebook checkpoint folders)
        yield ``None`` so the folder scan can skip them; a malformed ``.npy``
        name still raises ``ValueError``.
        """
        if not file_name.endswith(".npy"):
            return None
        stem = file_name.replace(resolution_marker, "_").replace(".npy", "")
        _, date, number = stem.split("_")
        return int("".join(date.split("-")) + number)

    def __len__(self):
        """Number of matched low-/high-resolution pairs."""
        return len(self.sorted_date_idx)

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, transform and randomly flip the pair at position ``index``.

        Returns:
            ``(lr_image, hr_image)``. Horizontal and vertical flips are each
            applied with probability 0.5, always to both images together.
        """
        date_idx = self.sorted_date_idx[index]
        lr_file_path, hr_file_path = self.date_idx_to_file_pathes[date_idx]

        lr_image = torch.from_numpy(np.load(lr_file_path))
        hr_image = torch.from_numpy(np.load(hr_file_path))

        lr_image = self.lr_transform(lr_image)
        hr_image = self.hr_transform(hr_image)

        if random.random() < 0.5:
            lr_image = TV_F.hflip(lr_image)
            hr_image = TV_F.hflip(hr_image)

        if random.random() < 0.5:
            lr_image = TV_F.vflip(lr_image)
            hr_image = TV_F.vflip(hr_image)

        return lr_image, hr_image

In [67]:
from eval_generation import (
    sample
)
from torch.optim import Adam
from unet import Unet

# Paired low-/high-resolution samples read from ../data/wrf_data.
dataset = WeatherFieldsDataset(
    root_dir=os.path.abspath(".."),
    path_to_folder=os.path.join(
        "data",
        "wrf_data",
    )
)

batch_size = 1
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
image_size = 128
channels = 3

# The dataset returns (low-res, high-res) in that order; unpack accordingly so
# the model is sized from the high-resolution sample it is trained to denoise.
lr_image, hr_image = dataset[0]
C, H, W = hr_image.shape

model = Unet(
    # NOTE(review): the image height is passed as `dim`; confirm that this is
    # the intended meaning of `dim` in the project's Unet.
    dim=H,
    channels=C,
    dim_mults=(1, 2, 4,),
    self_condition=True,
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

In [68]:
@torch.no_grad()
def spectral_noise_generator(shape: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw standard normal noise and return it with its 2-D real FFT.

    Args:
        shape: Shape of the noise tensor to draw.

    Returns:
        ``(noise, spectrum)`` where ``spectrum = torch.fft.rfft2(noise)`` is
        taken over the last two dimensions.
    """
    noise = torch.randn(shape)
    return noise, torch.fft.rfft2(noise)

def complex_mse_loss(
        input,
        target,
    ):
    """Mean over all elements of half the squared modulus of ``input - target``.

    Both arguments are expected to be complex tensors, since the real and
    imaginary parts of their difference are read separately.
    """
    residual = input - target
    squared_modulus = residual.real.square() + residual.imag.square()
    return squared_modulus.div(2).mean()

def p_spectral_losses(
        denoise_model: Unet,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
        self_condition: Optional[torch.Tensor] = None,
        fft_norm: str = "backward",
    ):
    """Pixel-space plus Fourier-space noise-prediction loss for one batch.

    Two independent noise samples are used. The first gives the usual MSE
    between the true and predicted noise. The second is added to ``x_start``
    in the same way, and the model's prediction for it is compared with the
    true noise after both have been passed through ``torch.fft.rfft2``.

    Args:
        denoise_model: Model called with ``x`` and ``time`` (plus
            ``x_self_cond`` when ``denoise_model.self_condition`` is truthy).
        x_start: Clean input batch.
        t: Integer step index for each batch element.
        noise: Noise for the pixel-space term; drawn from a standard normal
            when omitted.
        self_condition: Conditioning tensor; required when the model has
            ``self_condition`` enabled and not passed to the model otherwise.
        fft_norm: ``norm`` argument forwarded to ``torch.fft.rfft2``. The
            default ``"backward"`` is the unscaled forward transform, so the
            spectral term grows with the number of pixels; ``"ortho"`` should
            keep it on a scale comparable to the pixel-space term.

    Returns:
        Scalar tensor: pixel-space MSE plus the complex MSE of the spectra.

    Raises:
        RuntimeError: If the model expects self-conditioning and none is given.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    # Second, independent noise sample used only for the spectral term.
    domain_fourier_noise = torch.randn_like(x_start)
    fourier_noise = torch.fft.rfft2(domain_fourier_noise, norm=fft_norm)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    x_noisy_transformed = q_sample(x_start=x_start, t=t, noise=domain_fourier_noise)

    # Build the conditioning keyword once so both model calls always agree on
    # whether ``x_self_cond`` is passed.
    model_kwargs = {}
    if denoise_model.self_condition:
        if self_condition is None:
            raise RuntimeError("The self-conditioning is not provided. ")
        model_kwargs["x_self_cond"] = self_condition

    # Call the module itself rather than ``.forward`` so that registered
    # hooks are executed as well.
    predicted_noise = denoise_model(x=x_noisy, time=t, **model_kwargs)
    predicted_domain_fourier_noise = denoise_model(
        x=x_noisy_transformed, time=t, **model_kwargs
    )

    predicted_fourier_noise = torch.fft.rfft2(
        predicted_domain_fourier_noise, norm=fft_norm
    )

    loss = (
        F.mse_loss(noise, predicted_noise) +
        complex_mse_loss(fourier_noise, predicted_fourier_noise)
    )

    return loss

In [56]:
epochs = 6
# Number of optimisation steps to run per epoch. 1 gives a single-batch smoke
# test per epoch; set to None to iterate over the whole dataloader.
max_steps_per_epoch = 1
state = {
    "loss_train": []
}
for epoch in range(epochs):
    for step, (lr_batch, hr_batch) in enumerate(dataloader):
        optimizer.zero_grad()

        lr_batch = lr_batch.to(device)
        hr_batch = hr_batch.to(device)

        # Size of this particular batch, kept under its own name so the
        # configured global `batch_size` is not overwritten.
        current_batch_size = lr_batch.shape[0]

        # One uniformly random diffusion step per sample.
        t = torch.randint(0, timesteps, (current_batch_size,), device=device).long()

        loss = p_spectral_losses(
            denoise_model=model,
            x_start=hr_batch,
            t=t,
            self_condition=lr_batch,
        )

        if step % 100 == 0:
            print("Loss:", loss.item())

        loss.backward()
        state['loss_train'].append(float(loss.detach().cpu()))
        optimizer.step()

        # Checked at the end of the body so no extra batch is fetched.
        if max_steps_per_epoch is not None and step + 1 >= max_steps_per_epoch:
            break

Loss: 18686.611328125
Loss: 135340.046875
Loss: 19188.12109375
Loss: 19802.421875
Loss: 84844.78125
Loss: 46212.484375


In [63]:
from datetime import datetime

# Timestamp used in the checkpoint name: month_day_hour_minute_second. The
# hour is included so that names stay unambiguous within a day.
now = datetime.now().strftime('%m_%d_%H_%M_%S')
file_name = now + "_checkpoint.pkl"

pwd_path = os.path.abspath("..")
folder_path = os.path.join(
    pwd_path,
    "checkpoints"
)
# torch.save does not create missing directories.
os.makedirs(folder_path, exist_ok=True)

file_path = os.path.join(
    folder_path,
    file_name
)

# Everything needed to rebuild the model and resume optimisation, stored next
# to the training loss history already collected in `state`.
state["model_state_dict"] = model.state_dict()
state["optimizer_state_dict"] = optimizer.state_dict()
state["model_kwargs"] = {
    "dim": H,
    "channels": C,
    "dim_mults": (1, 2, 4,),
    "self_condition": True,
}
state["other"] = {
    "epochs": epochs,
    "batch_size": batch_size,
}
torch.save(state, file_path)