<a href="https://colab.research.google.com/github/alixmacdonald10/TRACKER/blob/main/TRACKER_PreProcessing_Denoising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#PREREQUISITS

##Mount drive


In [1]:
from google.colab import drive
 
# mount drive to access file
drive.mount('/content/gdrive')
root_path = 'gdrive/My Drive/your_project_folder/'  #change dir to your project folder

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


##Install requirements

In [2]:
# install weights and bias for logging
!pip install wandb -qqq
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mamackerel[0m (use `wandb login --relogin` to force relogin)


True

In [3]:
# NN imports
!pip3 install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
!pip install tensorboard

Looking in links: https://download.pytorch.org/whl/torch_stable.html


In [4]:
# database specific imports
!pip3 install h5py



##Imports

In [5]:
%load_ext tensorboard
import torch
from torch.utils import data
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler

import PIL
import h5py
import numpy as np
from pathlib import Path
import random
import time
import copy
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
%matplotlib inline
from google.colab.patches import cv2_imshow

In [6]:
# debugging
from IPython.core.debugger import set_trace

##Check Cuda

In [7]:
# Check devices
num_devices = torch.cuda.device_count()
print(f'Number of cuda devices: {num_devices}')
for device in range(0, num_devices):
  device_name = torch.cuda.get_device_name(device)
  print(f'Cuda device name: {device_name}')

Number of cuda devices: 1
Cuda device name: Tesla P100-PCIE-16GB


#MODEL INFORMATION

##Set hyper parameters

In [8]:
# define hyper parameters
project_name = 'Tracker-PreProcessing-Denoising'
verbose = 1  # print out helpers
transformed_when = 'after patches'
transform1 = 'hflip'
transform2 = 'rotate'
patch_size = 256
patch_stride = (12 * patch_size)
plot_itterations = 5
num_epoch = 10
max_itterations = 4e5
initial_learning_rate = 2e-4
min_learning_rate = 1e-7
loss_type = "PSNR"
scheduler_type = "cosine_annealing"
optimizer_type = "Adam"
if optimizer_type == "SGD":
  batch = 1
else:
  batch = 8  #  max batch size for cuda memory
mini_batch_size = 4
fpath = '/content/gdrive/MyDrive/Programming/datasets/noisyDataset.hdf5'

In [9]:
# Track hyperparameters and run metadata
hyperparameters = {
  "device_name": device_name,
  "transformed_when": transformed_when,
  "transform1": transform1,
  "transform2": transform2,
  "verbose": verbose,
  "patch_size": patch_size,
  "patch_stride": patch_stride,
  "plot_itterations": plot_itterations,
  "batch": batch,
  "mini_batch_size": mini_batch_size,
  "num_epoch": num_epoch,
  "max_itterations": max_itterations,
  "initial_learning_rate": initial_learning_rate,
  "min_learning_rate": min_learning_rate,
  "scheduler_type": scheduler_type,
  "loss_type": "PSNR",
  "optimizer_type": optimizer_type,
  "architecture": "HINet",
  "dataset": "Smartphone Image Denoising Dataset (SIDD)",
  "project_name": project_name
}

##Define model

###Note for transfer learning and importing pre-trained models

In [10]:
''' 
Models can be imported from saved files and transfer learned by:
 
'''
#from torchvision import models
 
#model = models.resnet101(pretrained=True)
 
# TO ONLY TRAIN THE LAST LAYER (TRANSFER LEARNING) DO THE FOLLOWING
#for param in model.parameters():
#   param.requires_grad = False  # freezes all layers in beginning
 
# return number of features in output (to transfer learn)
#num_ftrs = model.fc.in_features
 
# set output fully connected using
#model.fc = nn.Linear(num_ftrs, 2)   # (input layers, output layers)
#model.to(device)

' \nModels can be imported from saved files and transfer learned by:\n \n'

###Model Architecture

In [11]:
'''
HINet: Half Instance Normalization Network for Image Restoration
@inproceedings{chen2021hinet,
  title={HINet: Half Instance Normalization Network for Image Restoration},
  author={Liangyu Chen and Xin Lu and Jie Zhang and Xiaojie Chu and Chengpeng Chen},
  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops},
  year={2021}
}
'''


def conv3x3(in_chn, out_chn, bias=True):
    layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
    return layer

def conv_down(in_chn, out_chn, bias=False):
    layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
    return layer

def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias, stride = stride)

## Supervised Attention Module
class SAM(nn.Module):
    def __init__(self, n_feat, kernel_size=3, bias=True):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        x1 = self.conv1(x)
        img = self.conv2(x) + x_img
        x2 = torch.sigmoid(self.conv3(img))
        x1 = x1*x2
        x1 = x1+x
        return x1, img

class HINet(nn.Module):

    def __init__(self, in_chn=3, wf=64, depth=5, relu_slope=0.2, hin_position_left=0, hin_position_right=4):
        super(HINet, self).__init__()
        self.depth = depth
        self.down_path_1 = nn.ModuleList()
        self.down_path_2 = nn.ModuleList()
        self.conv_01 = nn.Conv2d(in_chn, wf, 3, 1, 1)
        self.conv_02 = nn.Conv2d(in_chn, wf, 3, 1, 1)

        prev_channels = self.get_input_chn(wf)
        for i in range(depth): #0,1,2,3,4
            use_HIN = True if hin_position_left <= i and i <= hin_position_right else False
            downsample = True if (i+1) < depth else False
            self.down_path_1.append(UNetConvBlock(prev_channels, (2**i) * wf, downsample, relu_slope, use_HIN=use_HIN))
            self.down_path_2.append(UNetConvBlock(prev_channels, (2**i) * wf, downsample, relu_slope, use_csff=downsample, use_HIN=use_HIN))
            prev_channels = (2**i) * wf

        self.up_path_1 = nn.ModuleList()
        self.up_path_2 = nn.ModuleList()
        self.skip_conv_1 = nn.ModuleList()
        self.skip_conv_2 = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path_1.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope))
            self.up_path_2.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope))
            self.skip_conv_1.append(nn.Conv2d((2**i)*wf, (2**i)*wf, 3, 1, 1))
            self.skip_conv_2.append(nn.Conv2d((2**i)*wf, (2**i)*wf, 3, 1, 1))
            prev_channels = (2**i)*wf
        self.sam12 = SAM(prev_channels)
        self.cat12 = nn.Conv2d(prev_channels*2, prev_channels, 1, 1, 0)

        self.last = conv3x3(prev_channels, in_chn, bias=True)

    def forward(self, x):
        image = x
        #stage 1
        x1 = self.conv_01(image)
        encs = []
        decs = []
        for i, down in enumerate(self.down_path_1):
            if (i+1) < self.depth:
                x1, x1_up = down(x1)
                encs.append(x1_up)
            else:
                x1 = down(x1)

        for i, up in enumerate(self.up_path_1):
            x1 = up(x1, self.skip_conv_1[i](encs[-i-1]))
            decs.append(x1)

        sam_feature, out_1 = self.sam12(x1, image)
        #stage 2
        x2 = self.conv_02(image)
        x2 = self.cat12(torch.cat([x2, sam_feature], dim=1))
        blocks = []
        for i, down in enumerate(self.down_path_2):
            if (i+1) < self.depth:
                x2, x2_up = down(x2, encs[i], decs[-i-1])
                blocks.append(x2_up)
            else:
                x2 = down(x2)

        for i, up in enumerate(self.up_path_2):
            x2 = up(x2, self.skip_conv_2[i](blocks[-i-1]))

        out_2 = self.last(x2)
        out_2 = out_2 + image
        return [out_1, out_2]

    def get_input_chn(self, in_chn):
        return in_chn

    def _initialize(self):
        gain = nn.init.calculate_gain('leaky_relu', 0.20)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.orthogonal_(m.weight, gain=gain)
                if not m.bias is None:
                    nn.init.constant_(m.bias, 0)


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, downsample, relu_slope, use_csff=False, use_HIN=False):
        super(UNetConvBlock, self).__init__()
        self.downsample = downsample
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
        self.use_csff = use_csff

        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
        self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)

        if downsample and use_csff:
            self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1)
            self.csff_dec = nn.Conv2d(out_size, out_size, 3, 1, 1)

        if use_HIN:
            self.norm = nn.InstanceNorm2d(out_size//2, affine=True)
        self.use_HIN = use_HIN

        if downsample:
            self.downsample = conv_down(out_size, out_size, bias=False)

    def forward(self, x, enc=None, dec=None):
        out = self.conv_1(x)

        if self.use_HIN:
            out_1, out_2 = torch.chunk(out, 2, dim=1)
            out = torch.cat([self.norm(out_1), out_2], dim=1)
        out = self.relu_1(out)
        out = self.relu_2(self.conv_2(out))

        out += self.identity(x)
        if enc is not None and dec is not None:
            assert self.use_csff
            out = out + self.csff_enc(enc) + self.csff_dec(dec)
        if self.downsample:
            out_down = self.downsample(out)
            return out_down, out
        else:
            return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, relu_slope):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
        self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)

    def forward(self, x, bridge):
        up = self.up(x)
        out = torch.cat([up, bridge], 1)
        out = self.conv_block(out)
        return out

class Subspace(nn.Module):

    def __init__(self, in_size, out_size):
        super(Subspace, self).__init__()
        self.blocks = nn.ModuleList()
        self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2))
        self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)

    def forward(self, x):
        sc = self.shortcut(x)
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)
        return x + sc


class skip_blocks(nn.Module):

    def __init__(self, in_size, out_size, repeat_num=1):
        super(skip_blocks, self).__init__()
        self.blocks = nn.ModuleList()
        self.re_num = repeat_num
        mid_c = 128
        self.blocks.append(UNetConvBlock(in_size, mid_c, False, 0.2))
        for i in range(self.re_num - 2):
            self.blocks.append(UNetConvBlock(mid_c, mid_c, False, 0.2))
        self.blocks.append(UNetConvBlock(mid_c, out_size, False, 0.2))
        self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)

    def forward(self, x):
        sc = self.shortcut(x)
        for m in self.blocks:
            x = m(x)
        return x + sc

###Dataset

**NOTE:** The dataset section can vary between analysis!

####Dataset class and transforms

In [12]:
# define dataset class

class HDF5Dataset(data.Dataset):
  """Represents a HDF5 dataset. with X = 'data' and y = 'labels'
  
  Input params:
      file_path: Path to the HDF5 file
      transform: PyTorch transform to apply to every data instance (default=None).
  """
  def __init__(self, config, file_loc=None, transform=True):
    super().__init__()
    self.transform = transform
    self.config = config

    P = Path(file_loc)
    # Search for all h5 files in path
    self.file = h5py.File(P, 'r')
    # return number of samples
    self.n_samples = len(self.file.get('data'))
    
    # create patches of patch_size
    self.patch_size = config.patch_size
    self.stride = config.patch_stride  # for even patches
       

  def __getitem__(self, index):
    '''
    returns items X and y. If config
    '''
    X, y = self.load_file(index)  # -> [w, h, c]

    if self.config.transformed_when == 'before':
      if self.transform:
        X, y = transform_func(X, y, self.config)
    else:    
      X = torch.from_numpy(np.float32(X)).permute(2, 0, 1)  # -> [c, w, h]
      if self.config.patch_size > 0:
        X = X.unfold(1, self.patch_size, self.stride).unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride) # -> [c, vert_patches, horiz_patches, img, height, width]
        X = torch.flatten(X, start_dim=1, end_dim=3).permute(1, 0, 2, 3) / 255 # -> [patches, c, height, width]  0 to 1 

      y = torch.from_numpy(np.float32(y)).permute(2, 0, 1)  # -> [c, w, h]
      if self.config.patch_size > 0:
        y = y.unfold(1, self.patch_size, self.stride).unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride) # -> [c, vert_patches, horiz_patches, img, height, width]
        y = torch.flatten(y, start_dim=1, end_dim=3).permute(1, 0, 2, 3) / 255 # -> [patches, c, height, width]  0 to 1

      num_patches = int(X.shape[0])

      if self.transform:
        for img in range(0, num_patches):
          X_temp, y_temp = transform_func(X[img], y[img], self.config)
          X_temp = X_temp[None, :, :]
          y_temp = y_temp[None, :, :]
          torch.cat((X, X_temp), 0)
          torch.cat((y, y_temp), 0)
    return (X, y)


  def __len__(self):
    # allows length of dataset to be returned
    return self.n_samples


  def load_file(self, index):
    ''' load index of database '''
    X = self.file.get('data')[index]
    y = self.file.get('labels')[index]

    return (X, y)

In [13]:
class MiniBatchDataset(HDF5Dataset):
  """Represents a HDF5 dataset. with X = 'data' and y = 'labels'
  
  Input params:
      file_path: Path to the HDF5 file
      transform: PyTorch transform to apply to every data instance (default=None).
  """
  def __init__(self, X, y, config):
    super().__init__(config, fpath)
    assert X.shape[0] == y.shape[0] # assuming shape[0] = dataset size
    self.X = X
    self.y = y

  def __len__(self):
    return self.y.shape[0]

  def __getitem__(self, index):
    return self.X[index], self.y[index]


In [14]:
  #transform functions
  
  def transform_func(image, label, config):

      '''
      Apply transforms to image and label
      '''
      
      # convert to PIL image
      image = transforms.functional.to_pil_image(image)
      label = transforms.functional.to_pil_image(label)

      if config.transform1 == "hflip":
        # Random horizontal flipping
        if random.random() > 0.5:
            image = transforms.functional.hflip(image)
            label = transforms.functional.hflip(label)

      if config.transform1 == "hflip":
        # Random 90 deg rotation
        if random.random() > 0.5:
            image = transforms.functional.rotate(image, 90)
            label = transforms.functional.rotate(label, 90)

      # Transform to tensor
      image = transforms.functional.to_tensor(image)
      label = transforms.functional.to_tensor(label)

      return image, label

####Split dataset into train, val, test split

In [15]:
def train_val_test_split(dataset):

  # perform random splits on dataset to return train, val, test sets (manual seed fixes output for repeatable results , remove device = cuda for non. GPU )
  dataset_length = len(dataset)

  # set sizes
  train_set_size = 0.8
  val_set_size = 0.1
  test_set_size = 0.1

  # return lengths
  train_set_length = round(train_set_size * dataset_length, 0)
  val_set_length = round(val_set_size * dataset_length, 0)
  test_set_length = dataset_length - train_set_length - val_set_length

  # check
  total = train_set_length + val_set_length + test_set_length
  print(f'total length: {total} / train set length: {train_set_length} / validation set length: {val_set_length} / test set length: {test_set_length}')
  assert total == dataset_length

  # create datasets
  train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
      dataset,
      [int(train_set_length), int(val_set_length), int(test_set_length)],
      generator=torch.Generator().manual_seed(42)
  )

  return train_dataset, val_dataset, test_dataset


####Test plot functions

In [16]:
def test_plot_Xy_batch(dataset):
 
  # plot images
  num_cols = 4
  num_patches = int(num_cols / 2)
  num_rows = 2
  fig, axs = plt.subplots(num_rows, num_cols, figsize=(16, 16), sharey=True)
  fig.suptitle("Train sample example images")

  row = 0
  for image in range(0, num_rows):  # plot two images in batch
    X, y = dataset.__getitem__(image)
    col = 0
    for patch in range(0, num_patches):  #plot patches
      X_img_patches = X.permute(0, 2, 3, 1)
      X_img = X_img_patches[patch]

      y_img_patches = y.permute(0, 2, 3, 1)
      y_img = y_img_patches[patch]

      axs[row, col].imshow(X_img)
      axs[row, col].title.set_text(f'Batch {image} X train example patch {patch}')
      col += 1
      axs[row, col].imshow(y_img)
      axs[row, col].title.set_text(f'Batch {image} y train example patch {patch}')
      col += 1
    row += 1

In [17]:
def test_plot_Xy_train_val_test_dataset(train_dataset, val_dataset, test_dataset):
  # Get dataset items
  test_index = 1
  X_train, y_train = train_dataset.__getitem__(test_index)
  X_val, y_val = val_dataset.__getitem__(test_index)
  X_test, y_test = test_dataset.__getitem__(test_index)

  patch = 0
  
  # plot images
  fig, axs = plt.subplots(3, 2, figsize=(16, 16), sharey=True)
  fig.suptitle("Train, validation, test dataset sample example images")
  axs[0, 0].imshow(X_train.permute(0, 2, 3, 1)[patch])
  axs[0, 0].title.set_text('X train example image')
  axs[0, 1].imshow(y_train.permute(0, 2, 3, 1)[patch])
  axs[0, 1].title.set_text('y train example image')
  axs[1, 0].imshow(X_val.permute(0, 2, 3, 1)[patch])
  axs[1, 0].title.set_text('X validation example image')
  axs[1, 1].imshow(y_val.permute(0, 2, 3, 1)[patch])
  axs[1, 1].title.set_text('y validation example image')
  axs[2, 0].imshow(X_test.permute(0, 2, 3, 1)[patch])
  axs[2, 0].title.set_text('X test example image')
  axs[2, 1].imshow(y_test.permute(0, 2, 3, 1)[patch])
  axs[2, 1].title.set_text('y test example image')

###Custom Loss Function 

In [18]:
# create custom PSNRLoss class

class PSNRLoss(nn.Module):

  def __init__(self):
    super(PSNRLoss, self).__init__()

  def forward(self, R, X, y, data_range=1.0, reduction='mean', convert_to_greyscale=False):
      r"""Compute Peak Signal-to-Noise Ratio for a batch of images.
      Supports both greyscale and color images with RGB channel order.
      Args:
          R: An input tensor. Shape :math:`(2, N, C, H, W)`.  where position 0 is [stage1 output, stage 2 output] -> outputs from model
          x: An input tensor. Shape :math:`(N, C, H, W)`.
          y: A target tensor. Shape :math:`(N, C, H, W)`.
          data_range: Maximum value range of images (usually 1.0 or 255).
          reduction: Specifies the reduction type:
              ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
          convert_to_greyscale: Convert RGB image to YCbCr format and computes PSNR
              only on luminance channel if `True`. Compute on all 3 channels otherwise.
      Returns:
          PSNR Index of similarity betwen two images.
      References:
          https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
      """
      # Constant for numerical stability
      EPS = 1e-8

      X = X / float(data_range)
      y = y / float(data_range)
      R[0] = (R[0] / float(data_range)).to(device)
      R[1] = (R[1] / float(data_range)).to(device)
      

      if (X.size(1) == 3) and convert_to_greyscale:
          # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
          rgb_to_grey = torch.tensor([0.299, 0.587, 0.114]).view(1, -1, 1, 1).to(X)
          R[0] = torch.sum(R[0] * rgb_to_grey, dim=1, keepdim=True)
          R[1] = torch.sum(R[1] * rgb_to_grey, dim=1, keepdim=True)
          X = torch.sum(X * rgb_to_grey, dim=1, keepdim=True)
          y = torch.sum(y * rgb_to_grey, dim=1, keepdim=True)

      score = []
      for i in range(0, len(R)):
        mse = torch.mean(((R[i].add(X)) - y) ** 2, dim=[1, 2, 3])
        max_value = 1. if X[0].max() <= 1 else 255. # max pixel value
        score.append(20. * torch.log10(max_value / torch.sqrt(mse)))

      summed_loss = score[0].add(score[1])
      summed_loss = -1. * summed_loss

      return _reduced(summed_loss, reduction) # reduced to single value

    
def _reduced(loss, reduction_type):
  r"""Reduce input in batch dimension if needed.
  Args:
      x: Tensor with shape (N, *).
      reduction: Specifies the reduction type:
          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
  """
  if reduction_type == 'none':
      return loss
  elif reduction_type == 'mean':
      return loss.mean(dim=0)
  elif reduction_type == 'sum':
      return loss.sum(dim=0)
  else:
      raise ValueError("Uknown reduction. Expected one of {'none', 'mean', 'sum'}")



###Make model

In [19]:
def make(config, device):
    
    verbose = config.verbose

    # create dataset
    file_loc = fpath
    dataset = HDF5Dataset(config, file_loc=fpath, transform=True)
    train_dataset, val_dataset, test_dataset = train_val_test_split(dataset)
    
    if verbose==1:
      # plot dataset info
      test_plot_Xy_batch(train_dataset)
      test_plot_Xy_train_val_test_dataset(train_dataset, val_dataset, test_dataset)

    # create dataloader
    train_dataloader = make_loader(config, train_dataset)
    val_dataloader = make_loader(config, val_dataset)
    test_dataloader = make_loader(config, test_dataset)
    # set up a dictionary of dataloaders for train and val
    dataloader = {'train': train_dataloader, 'val': val_dataloader, 'test': test_dataloader}
    dataset_sizes = {'train': len(train_dataloader), 'val': len(val_dataloader), 'test': len(test_dataloader)} 
    if verbose==1:
      print(f'Dataset sizes: {dataset_sizes}')

    # Make the model
    model = HINet().to(device)

    # set loss type
    if config.loss_type == "PSNR":
      criterion = PSNRLoss().to(device)  # custom loss
    elif config.loss_type == "MSE":
      criterion = MSELoss().to(device)

    # set optimizer type
    if config.optimizer_type == "Adam":
      optimizer = torch.optim.Adam(model.parameters(), lr=config.initial_learning_rate)
    elif config.optimizer_type == "SGD":
      optimizer = torch.optim.SGD(model.parameters(), lr=config.initial_learning_rate)

    # set learning rate scheduler 
    if config.scheduler_type == "cosine_annealing":
      model_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.max_itterations, eta_min=config.min_learning_rate, last_epoch=-1, verbose=False)


    return model, dataloader, dataset_sizes, criterion, optimizer, model_lr_scheduler

####Dataloader

In [20]:
def make_loader(config, dataset):
    loader = data.DataLoader(
        dataset=dataset,
        batch_size=config.batch, 
        shuffle=False,
        pin_memory=False, 
        num_workers=1
    )
    return loader

###Training Loop

In [21]:
# training loop

def train_model(model, config, device, dataloader, dataset_sizes, criterion, optimizer, scheduler):
  verbose = config.verbose

  plot_itter = config.plot_itterations
  num_epochs = config.num_epoch
  mini_batch_size = config.mini_batch_size
  # tell wandb to watch
  wandb.watch(model, criterion, log="all", log_freq=plot_itter)
  # start timer
  since = time.time()
  if verbose == 1:
    print('Total progress...')
  #initialise variables
  output_dict = {
      'num_training_examples': 0
  }
  num_itters = 0
  num_training_examples = 0
  num_validation_examples = 0
  temp_dict = {}
  # load existing best weights and reset accuracy
  best_model_wts = copy.deepcopy(model.state_dict())

  # begin analysis
  for epoch in tqdm(range(num_epochs)):
    since_epoch = time.time()
    if num_itters > max_itterations:
      break
    else:
      print(f'\n\nEpoch {epoch + 1} / {num_epochs}')
      print('-' * 10)

      for phase in ['train', 'val']:
        since_phase = time.time()
        if verbose == 1:
          print(f'\nPhase -> {phase}')
          print(f'Batches...')
          print('-' * 10)
        # loop train and val phase for each dataset
        best_loss = 1e5  # large so always saves best loss
        epoch_running_loss = 0
        epoch_loss = 0
        total_imgs_in_batches = 0
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader[phase])):
          # get images and labels
          images_batch = torch.flatten(images, start_dim=0, end_dim=1)  # -> [imgs (batch*patches), channels, height, width]
          labels_batch = torch.flatten(labels, start_dim=0, end_dim=1)       
          imgs_in_batch = images_batch.shape[0]
          total_imgs_in_batches += imgs_in_batch

          # create dataset for mini-batches
          mini_batch_data = MiniBatchDataset(images_batch, labels_batch, config) 
          mini_batch_dataloader = make_loader(config, mini_batch_data) 

          for mini_batch_idx, (mini_batch_images, mini_batch_labels) in enumerate(mini_batch_dataloader):
            # get images labels and indices
            mini_batch_images = mini_batch_images.to(device)  # -> [mb, c, h, w]
            mini_batch_labels = mini_batch_labels.to(device)
            imgs_in_mini_batch = mini_batch_images.shape[0]

            if phase == 'train':  # TRAINING
              print(f'Mini-batches -> Total training examples = {num_training_examples}')
              model.train().to(device)  # set to training mode
              # FORWARD PASS
              with torch.set_grad_enabled(True):
                outputs = model(mini_batch_images) # -> [mini_batch, channels, height, width] -> outputs is [out_1 (STAGE 1 output image), out_2 (STAGE 2 output image)]
                loss = criterion(outputs, mini_batch_images, mini_batch_labels, data_range=1.0, reduction='mean').to(device)
                num_training_examples += mini_batch_images.shape[0]
                epoch_running_loss += (loss.item() * imgs_in_mini_batch)  # mult by batch to allow avg (epoch loss) to be calcd.
                # BACKWARD PASS
                optimizer.zero_grad()  # emptys cache
                loss.backward()
                optimizer.step()
                scheduler.step()
                if (epoch==0 and batch_idx==0 and mini_batch_idx==0):
                  print('Logging learning rate!')
                  wandb.log({'Learning rate': scheduler.get_last_lr()},
                    commit=False
                  )
                elif scheduler.get_last_lr() != scheduler.get_lr():
                  print('Learning rate changed - logging learning rate!')
                  wandb.log({'Learning rate': scheduler.get_last_lr()},
                    commit=False
                  )
                num_itters += 1

            else:  # VALIDATION
              model.eval()  # set to training mode
              # FORWARD PASS
              with torch.set_grad_enabled(False):
                outputs = model(mini_batch_images) # -> [mini_batch, channels, height, width] -> outputs is [out_1 (STAGE 1 output image), out_2 (STAGE 2 output image)]
                loss = criterion(outputs, mini_batch_images, mini_batch_labels, data_range=1.0, reduction='mean').to(device)
                num_validation_examples += mini_batch_images.shape[0]
                epoch_running_loss += (loss.item() * imgs_in_mini_batch)
                num_itters += 1
                
                if (epoch==0 and
                    batch_idx==0 and
                    mini_batch_idx==0):  # first epoch
                  print('Logging initial data to wandb!')
                  temp_dict['number of training examples'] = num_training_examples
                  temp_dict_append = temp_log_data(mini_batch_images[0], mini_batch_labels[0], outputs[0][0], outputs[1][0], epoch)
                  temp_dict.update(temp_dict_append)
                  temp_dict_append = 0
                
                elif (epoch+1 % plot_itter==0 and
                      batch_idx==0 and
                      mini_batch_idx==0):  # every scheduled index
                  print('Logging data to wandb!')
                  temp_dict['number of training examples'] = num_training_examples
                  temp_dict_append = temp_log_data(mini_batch_images[0], mini_batch_labels[0], outputs[0][0], outputs[1][0], epoch)
                  temp_dict.update(temp_dict_append)
                  temp_dict_append = 0

        # print phase time
        phase_time = time.time() - since_phase 
        print(f'Epoch {epoch+1} {phase} time: {phase_time // 60:.0f}m {phase_time % 60:.0f}s')
        if phase == 'train':
          total_train_time += phase_time
        else:
          total_val_time += phase_time
        
        # log epoch loss
        epoch_loss = epoch_running_loss / total_imgs_in_batches # average of epoch losses
        print(f'Epoch {epoch+1} {phase} loss: {epoch_loss}')
        temp_dict[f"epoch {phase} loss"] = epoch_loss  # log train and val loss
        
        if epoch== 0 and phase=='val':  # first epoch
          wandb.log(temp_dict,
                    commit=True
                  )
          temp_dict = {}  #zero to reduce memory overhead

        if epoch + 1 % plot_itter == 0 and phase=='val':
          wandb.log(temp_dict,
                    commit=True
                  )
          temp_dict = {}  #zero to reduce memory overhead

        # deep copy the model if best loss
        if (phase == 'val' and
            epoch_loss < best_loss):
          best_loss = epoch_loss
          best_loss_epoch = epoch + 1 
          best_model_wts = copy.deepcopy(model.state_dict())
          wandb.run.summary['Best val loss'] = best_loss
          wandb.run.summary['Best val loss epoch'] = best_loss_epoch
        del loss, outputs

    # print time for epoch
    epoch_time = time.time() - since_epoch
    print(f'Epoch {epoch+1} time: {epoch_time // 60:.0f}m {epoch_time % 60:.0f}s')
    print(f'Current best validation loss: {best_loss:.4f} @ epoch {best_loss_epoch}')
    total_time += epoch_time
    print(f'Total training time: {total_train_time // 60:.0f}m {total_train_time % 60:.0f}s / Total validation time: {total_val_time // 60:.0f}m {total_val_time % 60:.0f}s / Total time: {total_time // 60:.0f}m {total_time % 60:.0f}s')

  # print total training time over all epochs
  total_time = time.time() - since
  print(f'Total time for training {num_epochs} epochs: {total_time // 60:.0f}m {total_time % 60:.0f}s')
  output_dict['num_training_examples'] = num_training_examples

  return output_dict

####Image logging helper function

In [22]:
# plot function 

def temp_log_data(image, label, output_1, output_2, epoch):

  temp_dict_append = {
    "epoch": epoch + 1,
    "input image": [wandb.Image(image.permute(1, 2, 0).to('cpu').numpy(), caption="Input image")],
    "ground truth image": [wandb.Image(label.permute(1, 2, 0).to('cpu').numpy(), caption="Ground truth image")],
    "Output image stage 1": [wandb.Image(output_1.permute(1, 2, 0).to('cpu').detach().numpy(), caption="Output image stage 1")],
    "Output image stage 2": [wandb.Image(output_2.permute(1, 2, 0).to('cpu').detach().numpy(), caption="Output image stage 2")]
  }
  
  return temp_dict_append

###Test Function

In [23]:
def test(model, config, device, criterion, dataloader, dataset_sizes, num_training_examples):
  if config.verbose == 1:
    print(f'\nPhase -> testing')
    print('-' * 10)
    
  with torch.no_grad(): # stops computation of grads
    running_loss = 0
    total_imgs = 0
    num_test_examples = 0
    for batch_idx, (images, labels) in enumerate(tqdm(dataloader['test'])):
      # get images and labels
      images_batch = torch.flatten(images, start_dim=0, end_dim=1)  # -> [imgs (batch*patches), channels, height, width]
      labels_batch = torch.flatten(labels, start_dim=0, end_dim=1)       
      imgs_in_batch = images_batch.shape[0]
      total_imgs += imgs_in_batch

      # create dataset for mini-batches
      mini_batch_data = MiniBatchDataset(images_batch, labels_batch, config) 
      mini_batch_dataloader = make_loader(config, mini_batch_data) 
      print(f'Mini-batches -> Testing after {num_training_examples} -> Test examples = {num_test_examples}')
      for mini_batch_idx, (mini_batch_images, mini_batch_labels) in enumerate(mini_batch_dataloader):
        # get images labels and indices
        mini_batch_images = mini_batch_images.to(device)  # -> [mb, c, h, w]
        mini_batch_labels = mini_batch_labels.to(device)
        mini_batch_size = mini_batch_images.shape[0]

        outputs = model(mini_batch_images) # -> [mini_batch, channels, height, width] -> outputs is [out_1 (STAGE 1 output image), out_2 (STAGE 2 output image)]
        output_stage1 = outputs[0].to(device)
        output_stage2 = outputs[1].to(device)
        # determine losses    
        loss = criterion(outputs, mini_batch_images, mini_batch_labels, data_range=1.0, reduction='mean').to(device)
        num_test_examples += mini_batch_images.shape[0]
        running_loss += loss.item() * mini_batch_size  # mult by batch to allow avg (epoch loss) to be calcd.
      
    # log test loss
    total_loss = running_loss / total_imgs
    wandb.run.summary['Test loss'] = total_loss
    wandb.run.summary['Number of training examples'] = num_training_examples
    print(f'\nTest loss of the network: {loss}')

###Model pipeline

In [24]:
#Model Pipeline - train, val, test

def model_pipeline(hyperparameters):
    
    verbose = hyperparameters['verbose']

    # run on device
    if torch.cuda.is_available():
      device = 'cuda'
    else:
      device = 'cpu'

    # tell wandb to get started
    with wandb.init(project=project_name, config=hyperparameters):
      # access all HPs through wandb.config, so logging matches execution!
      config = wandb.config
      
      # make the model, data, and optimization problem
      model, dataloader, dataset_sizes, criterion, optimizer, model_lr_scheduler = make(config, device)
      print(model)

      # perform training and evaluating
      output_dict = train_model(model, config, device, dataloader, dataset_sizes, criterion, optimizer, model_lr_scheduler)

      # Test final performance
      test(model, config, device, criterion, dataloader, dataset_sizes, output_dict['num_training_examples'])

    return model

#RUN

##Train, Validate, test!

In [None]:
model = model_pipeline(hyperparameters)

total length: 160.0 / train set length: 128.0 / validation set length: 16.0 / test set length: 16.0
Dataset sizes: {'train': 16, 'val': 2, 'test': 2}
HINet(
  (down_path_1): ModuleList(
    (0): UNetConvBlock(
      (identity): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (conv_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_1): LeakyReLU(negative_slope=0.2)
      (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_2): LeakyReLU(negative_slope=0.2)
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (downsample): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (1): UNetConvBlock(
      (identity): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (conv_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_1): LeakyReLU(negative_slope=0.2)
      (conv_2): Conv2d(128, 128, kernel_si

  0%|          | 0/10 [00:00<?, ?it/s]



Epoch 1 / 10
----------

Phase -> train
Batches...
----------


  0%|          | 0/16 [00:00<?, ?it/s]

Mini-batches -> Total training examples = 0
Mini-batches -> Total training examples = 8




Mini-batches -> Total training examples = 16
Mini-batches -> Total training examples = 24
Mini-batches -> Total training examples = 32
Mini-batches -> Total training examples = 40
Mini-batches -> Total training examples = 48
Mini-batches -> Total training examples = 56
Mini-batches -> Total training examples = 64
Mini-batches -> Total training examples = 72
Mini-batches -> Total training examples = 80
Mini-batches -> Total training examples = 88
Mini-batches -> Total training examples = 96
Mini-batches -> Total training examples = 104
Mini-batches -> Total training examples = 112
Mini-batches -> Total training examples = 120
Mini-batches -> Total training examples = 128
Mini-batches -> Total training examples = 136
Mini-batches -> Total training examples = 144
Mini-batches -> Total training examples = 152
Mini-batches -> Total training examples = 160
Mini-batches -> Total training examples = 168
Mini-batches -> Total training examples = 176
Mini-batches -> Total training examples = 184

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1 val time: 0m 32s
Epoch 1 val loss: -49.829071044921875
Epoch 1 time: 4m 22s
Current best validation loss: -49.8291 @ epoch 1


Epoch 2 / 10
----------

Phase -> train
Batches...
----------


  0%|          | 0/16 [00:00<?, ?it/s]

Mini-batches -> Total training examples = 256
Mini-batches -> Total training examples = 264
Mini-batches -> Total training examples = 272
Mini-batches -> Total training examples = 280
Mini-batches -> Total training examples = 288
Mini-batches -> Total training examples = 296
Mini-batches -> Total training examples = 304
Mini-batches -> Total training examples = 312
Mini-batches -> Total training examples = 320
Mini-batches -> Total training examples = 328
Mini-batches -> Total training examples = 336
Mini-batches -> Total training examples = 344
Mini-batches -> Total training examples = 352
Mini-batches -> Total training examples = 360
Mini-batches -> Total training examples = 368
Mini-batches -> Total training examples = 376
Mini-batches -> Total training examples = 384
Mini-batches -> Total training examples = 392
Mini-batches -> Total training examples = 400
Mini-batches -> Total training examples = 408
Mini-batches -> Total training examples = 416
Mini-batches -> Total training exa

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 2 val time: 0m 30s
Epoch 2 val loss: -55.08488082885742
Epoch 2 time: 4m 22s
Current best validation loss: -55.0849 @ epoch 2


Epoch 3 / 10
----------

Phase -> train
Batches...
----------


  0%|          | 0/16 [00:00<?, ?it/s]

Mini-batches -> Total training examples = 512
Mini-batches -> Total training examples = 520
Mini-batches -> Total training examples = 528
Mini-batches -> Total training examples = 536
Mini-batches -> Total training examples = 544
Mini-batches -> Total training examples = 552
Mini-batches -> Total training examples = 560
Mini-batches -> Total training examples = 568
Mini-batches -> Total training examples = 576
Mini-batches -> Total training examples = 584
Mini-batches -> Total training examples = 592
Mini-batches -> Total training examples = 600
Mini-batches -> Total training examples = 608
Mini-batches -> Total training examples = 616
Mini-batches -> Total training examples = 624
Mini-batches -> Total training examples = 632
Mini-batches -> Total training examples = 640
Mini-batches -> Total training examples = 648
Mini-batches -> Total training examples = 656
Mini-batches -> Total training examples = 664
Mini-batches -> Total training examples = 672
Mini-batches -> Total training exa

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 3 val time: 0m 30s
Epoch 3 val loss: -55.78398036956787
Epoch 3 time: 4m 22s
Current best validation loss: -55.7840 @ epoch 3


Epoch 4 / 10
----------

Phase -> train
Batches...
----------


  0%|          | 0/16 [00:00<?, ?it/s]

Mini-batches -> Total training examples = 768
Mini-batches -> Total training examples = 776
Mini-batches -> Total training examples = 784
Mini-batches -> Total training examples = 792
Mini-batches -> Total training examples = 800
Mini-batches -> Total training examples = 808
Mini-batches -> Total training examples = 816
Mini-batches -> Total training examples = 824
Mini-batches -> Total training examples = 832
Mini-batches -> Total training examples = 840
Mini-batches -> Total training examples = 848
Mini-batches -> Total training examples = 856
Mini-batches -> Total training examples = 864
Mini-batches -> Total training examples = 872
Mini-batches -> Total training examples = 880
Mini-batches -> Total training examples = 888
Mini-batches -> Total training examples = 896
Mini-batches -> Total training examples = 904
Mini-batches -> Total training examples = 912
Mini-batches -> Total training examples = 920
Mini-batches -> Total training examples = 928
Mini-batches -> Total training exa

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 4 val time: 0m 30s
Epoch 4 val loss: -59.44982147216797
Epoch 4 time: 4m 25s
Current best validation loss: -59.4498 @ epoch 4


Epoch 5 / 10
----------

Phase -> train
Batches...
----------


  0%|          | 0/16 [00:00<?, ?it/s]

Mini-batches -> Total training examples = 1024
Mini-batches -> Total training examples = 1032
Mini-batches -> Total training examples = 1040
Mini-batches -> Total training examples = 1048
Mini-batches -> Total training examples = 1056
Mini-batches -> Total training examples = 1064
Mini-batches -> Total training examples = 1072
Mini-batches -> Total training examples = 1080
Mini-batches -> Total training examples = 1088
Mini-batches -> Total training examples = 1096
Mini-batches -> Total training examples = 1104
Mini-batches -> Total training examples = 1112
Mini-batches -> Total training examples = 1120
Mini-batches -> Total training examples = 1128
Mini-batches -> Total training examples = 1136
Mini-batches -> Total training examples = 1144
Mini-batches -> Total training examples = 1152
Mini-batches -> Total training examples = 1160
Mini-batches -> Total training examples = 1168
Mini-batches -> Total training examples = 1176
Mini-batches -> Total training examples = 1184
Mini-batches 

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 5 val time: 0m 33s
Epoch 5 val loss: -61.32639122009277
Epoch 5 time: 4m 28s
Current best validation loss: -61.3264 @ epoch 5


Epoch 6 / 10
----------

Phase -> train
Batches...
----------


  0%|          | 0/16 [00:00<?, ?it/s]

Mini-batches -> Total training examples = 1280
Mini-batches -> Total training examples = 1288
Mini-batches -> Total training examples = 1296
Mini-batches -> Total training examples = 1304
Mini-batches -> Total training examples = 1312
Mini-batches -> Total training examples = 1320
Mini-batches -> Total training examples = 1328
Mini-batches -> Total training examples = 1336
Mini-batches -> Total training examples = 1344
Mini-batches -> Total training examples = 1352
Mini-batches -> Total training examples = 1360
Mini-batches -> Total training examples = 1368
Mini-batches -> Total training examples = 1376
Mini-batches -> Total training examples = 1384
Mini-batches -> Total training examples = 1392
Mini-batches -> Total training examples = 1400
Mini-batches -> Total training examples = 1408
Mini-batches -> Total training examples = 1416
Mini-batches -> Total training examples = 1424
Mini-batches -> Total training examples = 1432
Mini-batches -> Total training examples = 1440
Mini-batches 

  0%|          | 0/2 [00:00<?, ?it/s]

##Save model weights and bias

In [None]:
# save model state dict
PATH = f'/content/gdrive/MyDrive/Programming/models/{project_name}.pth'
torch.save(model.state_dict(), PATH)

#INFERENCE

Run inference using already trained weights


In [None]:
# load model weights
model.load_state_dict(torch.load(PATH))
model.eval()