# Globals

In [46]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [47]:
global_var = {
    # Resolutions
    'RGB_img_res': (3, 16, 64),

    # Parameters
    'batch_size': 8,
    'n_workers': 2,
    'seed': 10000,
    'lr': 1e-3,
    'lr_patience': 15,
    'epochs': 20,
    'n_workers': 2,
    'e_stop_epochs': 10,

    # Operations
    'do_print_model': True
}

augmentation_parameters = {
    # TODO
}

In [48]:
dataset_root = '/content/drive/MyDrive/NN_project/SSID_dataset/'
save_model_root = '/content/drive/MyDrive/NN_project/'
model_name = "Uformer"

# Imports

In [49]:
!pip install einops torchsummaryX --quiet

In [50]:
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as TT

from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchsummaryX import summary
from tqdm import tqdm

# Utils

In [51]:
def compute_accuracy(y_pred, y_true, thr=0.05):
  valid_mask = y_true > 0.0
  valid_pred = y_pred[valid_mask]
  valid_true = y_true[valid_mask]
  correct = torch.max((valid_true / valid_pred), (valid_pred / valid_true)) < (1 + thr)
  return 100 * torch.mean(correct.float())

def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_lr(optimizer):
  for param_group in optimizer.param_groups:
      return param_group['lr']

def hardware_check():
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
  print("Actual device: ", device)
  return device

def load_pretrained_model(model, device):
  print("Loading checkpoint...\n")
  model_dict = torch.load(save_model_root+"/Uformer_best_acc", map_location=torch.device(device))
  model.load_state_dict(model_dict)
  print("Checkpoint loaded!\n")
  return model

def plot_graph(f, g, f_label, g_label, title, path):
  epochs = range(0, len(f))
  plt.plot(epochs, f, 'b', label=f_label)
  plt.plot(epochs, g, 'orange', label=g_label)
  plt.title(title)
  plt.xlabel('Epochs')
  plt.legend()
  plt.grid('on', color='#cfcfcf')
  plt.tight_layout()
  plt.savefig(path + title + '.pdf')
  plt.close()

def plot_history(history):
  plot_graph(history['train_loss'], history['val_loss'], 'Train Loss', 'Val. Loss', 'TrainVal_loss', save_model_root)
  plot_graph(history['train_acc'], history['val_acc'], 'Train Acc.', 'Val. Acc.', 'TrainVal_acc', save_model_root)

def plot_loss(history,title):
  l_train_list = history['train_loss']
  l_test_list = history['val_loss']
  epochs = range(0, len(l_train_list))

  plt.plot(epochs, l_train_list, 'r', label='Train loss')
  plt.plot(epochs, l_test_list, 'g', label='Test loss')
  plt.title(title)
  plt.xlabel('Epochs')
  plt.grid('on', color='#cfcfcf')
  plt.legend()
  plt.tight_layout()
  plt.savefig(save_model_root + "/" + title + '.pdf')
  plt.close()

def print_model(model, device, input_shape):
  info = summary(model, torch.ones((global_var['batch_size'], input_shape[0], input_shape[1], input_shape[2])).to(device))
  info.to_csv(save_model_root + 'model_summary.csv')

def save_checkpoint(model, name):
  torch.save(model.state_dict(), save_model_root + name)

def save_csv_history(model_name):
  objects = []
  with (open(save_model_root + model_name + '_history.pkl', "rb")) as openfile:
      while True:
          try:
              objects.append(pickle.load(openfile))
          except EOFError:
              break
  df = pd.DataFrame(objects)
  df.to_csv(save_model_root + model_name + '_history.csv', header=False, index=False, sep=" ")

def save_history(history, filepath):
  tmp_file = open(filepath + '.pkl', "wb")
  pickle.dump(history, tmp_file)
  tmp_file.close()

# Data

## Data augmentation

In [52]:
# TODO

## Dataset

In [53]:
class SSID_Dataset(Dataset):
    def __init__(self, data_root):
        self.dataset_path = data_root
        self.dir_list = data_root + "Scene_Instances.txt"
        self.data_dir = data_root + "Data/"
        self.data_directiories = []
        self.img_paths = []
        self.target_paths = []
        self.post_processing = TT.Compose([
            TT.ToTensor(),
            TT.Resize((global_var['RGB_img_res'][1], global_var['RGB_img_res'][2]),antialias=None),
        ])

        data_dir_file = open(dataset_root+"Scene_Instances.txt", 'r')
        self.data_directories = [elem.strip() for elem in data_dir_file.readlines()]
        data_dir_file.close()

        for elem in self.data_directories:
          data_path = self.data_dir + elem
          content = sorted(os.listdir(data_path))
          self.target_paths.append(content[0])
          self.img_paths.append(content[1])

    def __getitem__(self, index):
        img_path = self.data_dir + self.data_directories[index] + "/" + self.img_paths[index]
        img = self.post_processing(Image.open(img_path))

        target_path = self.data_dir + self.data_directories[index] + "/" + self.target_paths[index]
        target = self.post_processing(Image.open(target_path))

        return img.float(), target.float()

    def __len__(self):
        return len(self.img_paths)

## Dataloader

In [54]:
dataset = SSID_Dataset(dataset_root)
train_dataset, test_dataset = random_split(dataset, [112, 48])

train_loader = DataLoader(dataset=train_dataset,
                          batch_size = global_var['batch_size'],
                          num_workers = global_var['n_workers'],
                          shuffle = True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size = global_var['batch_size'],
                         num_workers = global_var['n_workers'],
                         shuffle = True)

print("Train data percentage: ", len(train_dataset)/(len(train_dataset)+len(test_dataset)))
print("Test data percentage: ", len(test_dataset)/(len(train_dataset)+len(test_dataset)))

Train data percentage:  0.7
Test data percentage:  0.3


# Loss

In [55]:
class Cha_loss(nn.Module):
  def __init__(self, epsilon=1e-3):
    super(Cha_loss,self).__init__()
    self.epsilon = epsilon

  def forward(self,pred,truth):
    return torch.mean(torch.sqrt((pred-truth)**2 + self.epsilon**2))

# class Cha_loss(nn.Module):
#     """Charbonnier Loss (L1)"""
#     def __init__(self, eps=1e-6):
#         super(Cha_loss, self).__init__()
#         self.eps = eps

#     def forward(self, x, y):
#         b, c, h, w = y.size()
#         loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
#         return loss/(c*b*h*w)

# Evaluation metrics

In [56]:
# ATTENTION: PYTORCH HAS PIXEL RANGE BETWEEN 0.0 AND 1.0, NOT BETWEEN 0 AND 255
# It works, compared with torchmetrics.image import PeakSignalNoiseRatio
def psnr(original_img, compressed_img, max_pix_val=1.0):
  mse = torch.mean((original_img-compressed_img)**2)
  return 20 * torch.log10(max_pix_val/torch.sqrt(mse))

def gaussian(window_size, sigma):
    gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

# ATTENTION: PYTORCH HAS PIXEL RANGE BETWEEN 0.0 AND 1.0, NOT BETWEEN 0 AND 255
# ATTENTION: 4D tensors needed
# It works, compared with StructuralSimilarityIndexMeasure from torchmetrics.image
def ssim(original_img, restored_img, max_pix_val=1.0, window_size=11, window=None, size_average=True, full=False):
    (_, channel, height, width) = original_img.size()
    real_size = min(window_size, height, width)
    window = create_window(real_size, channel=channel).to(original_img.device)

    mu1 = F.conv2d(original_img, window, padding=0, groups=channel)
    mu2 = F.conv2d(restored_img, window, padding=0, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(original_img ** 2, window, padding=0, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(restored_img ** 2, window, padding=0, groups=channel) - mu2_sq
    sigma12 = F.conv2d(original_img * restored_img, window, padding=0, groups=channel) - mu1_mu2

    C1 = (0.01 * max_pix_val) ** 2
    C2 = (0.03 * max_pix_val) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2

    return (((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)).mean()


def compute_evaluation(test_dataloader, model, device='cpu'):
  model.eval()
  psnr_values = []
  ssim_values = []

  for _, (inputs, targets) in enumerate(test_dataloader):
      inputs, targets = inputs.to(device), targets.to(device)

      with torch.no_grad():
          predictions = model(inputs)

      psnr_values.append(psnr(targets,predictions))
      ssim_values.append(ssim(targets,predictions))

  return torch.mean(torch.Tensor(psnr_values)).item(),torch.mean(torch.Tensor(ssim_values)).item()

# Architecture

In [69]:
# Attention components
# Attention module
class W_MSA(nn.Module):
  def __init__(self, dim=32, num_heads=8, qkv_bias=False):
    super(W_MSA, self).__init__()
    self.num_heads = num_heads
    self.head_dim = dim // num_heads

    # nn.Linear(in_features, out_features): the input of the layer has to have the last dimension equal to in_features (namely (*, in_features)). The output of the layer has the
    # same dimension of the input except for the last one which is equal to out_features (namely (*, out_features))

    # self.qkv = nn.Linear(dim, num_heads, self.head_dim, bias=qkv_bias) # this layer returns the queries, keys and values
    self.qkv = nn.Linear(dim, self.head_dim*self.num_heads*3, bias=qkv_bias)

    # these are default layers for the attention module
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(0.)
    self.attn_drop = nn.Dropout(0.)

  def forward(self, x):
    B, N, C = x.shape
    # print("********* x shape: ",x.shape)
    # print("********* self qkv shape:",self.qkv(x).shape)

    qkv_temp = self.qkv(x)
    mult = qkv_temp.shape[0] * qkv_temp.shape[1] * qkv_temp.shape[2]
    mult = mult//3
    mult = mult//(self.num_heads*4)
    mult = mult//global_var['batch_size']

    qkv = qkv_temp.reshape(B, mult, self.num_heads*4, 3)
    # print("********* qkv shape:",qkv.shape)
    q, k, v = qkv.unbind(dim=-1) # this returnes a tuple of tensors whose each element is portion of the original tensor (qkv) (ref: https://pytorch.org/docs/stable/generated/torch.unbind.html)

    # this is the implementation of the attention formula described on the paper
    scale = (C // self.head_dim) ** (0.5)
    attn = ((q @ k.transpose(-2, -1)) // scale) + B # from the github: the final B can be also removed
    attn = attn.softmax(dim=1)
    attn = self.attn_drop(attn)
    # print("********** x no reshape:",(attn @ v).transpose(1, 2).shape)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)

    # these are default for the attention module
    x = self.proj(x)
    x = self.proj_drop(x)

    return x

# this is the simple implementation described in the paper
class LeFF(nn.Module):
  def __init__(self, dim=32, hidden_dim=128):
    super(LeFF, self).__init__()
    self.dim = dim
    self.hidden_dim = hidden_dim

    self.layer1 = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())
    self.layer2 = nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1), nn.GELU())
    self.layer3 = nn.Sequential(nn.Linear(hidden_dim, dim))

  def forward(self, x):
    # print("Before 1st layer x: ", x.shape)
    x = self.layer1(x)
    x = self.layer2(x.permute(2,1,0))
    x = self.layer3(x.permute(2,1,0))

    return x

  # def forward(self, x):
  #   # bs x hw x c
  #   bs, hw, c = x.size()
  #   hh = int(math.sqrt(hw))

  #   x = self.layer1(x)

  #   # spatial restore
  #   x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)

  #   x = self.layer2(x)

  #   # flatten
  #   x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)

  #   x = self.layer3(x)

  #   return x

# NN BLOCKS
# LeWin Transformer Block (from the paper, it is made up of a sequence of: NormLayer, W_MSA, NormLayer, LeFF)
class TransformerBlock(nn.Module):
  def __init__(self, dim):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim)
    self.w_msa = W_MSA(dim=dim)
    self.norm2 = nn.LayerNorm(dim)
    self.leff = LeFF(dim=dim)
    self.dropout = nn.Dropout(0.)

  def forward(self, x):
    x = self.dropout(self.norm1(x))
    x = self.w_msa(x)
    x = self.dropout(self.norm2(x))
    x = self.leff(x)

    return x

# Down-sampling Block (reduces the size of the feature map)
# reshape the flattened features into 2D spatial feature maps, and then down-sample the maps, double the channels using 4 × 4 convolution with stride 2
class DownsampleBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DownsampleBlock, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x): # remember that x is a tensor!!
        B, L, C = x.shape
        W = int(math.sqrt(L))*2
        # W = int(math.sqrt(L))
        H = (((W//4)**2)//(W//2))*2
        # print("*************** H: ", H)
        # print("*************** W: ", W)
        # print("*************** x: ",x.shape)
        # print("*************** x: ", x.transpose(1, 2).contiguous().shape)
        x = x.transpose(1, 2).contiguous().view(B, C, H, W) # this transposes the 1st and 2nd dimension of x, then the size of x is reshaped with view (the new size is (B, C, H, W))
                                                            # (.contiguous() is required to make view workable, since view works only on contiguous data)

        out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # this pass the input x to the downsample layer, then the 2nd dimension of the output is flattened with the 3rd
                                                                   # and finally its 1st and 2nd dimensions are transposed

                                                                   # (B, C, H*W) is the size of the out after flatten(2)
                                                                   # (B H*W C) is the final size of the out after transpose(1, 2)
        # print("************** out: ", out.shape)
        return out


# Up-sampling Block (reduces half of the channels and doubles the size of the feature map)
# 2 × 2 transposed convolution with stride 2
class UpsampleBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
      super(UpsampleBlock, self).__init__()
      self.in_channel = in_channel
      self.out_channel = out_channel
      self.deconv = nn.Sequential(
        nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,padding=padding),
      )

    def forward(self, x):
      B, L, C = x.shape
      # H = int(math.sqrt(L))
      # W = int(math.sqrt(L))
      W = int(math.sqrt(L))*2
      H = (((W//4)**2)//(W//2))*2
      x = x.transpose(1, 2).contiguous().view(B, C, H, W)
      out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C

      return out

# Input Projection Block (extracts the low-level features)
# 3 x 3 convolutional layer with LeakyReLu
class InputProjBlock(nn.Module):
    def __init__(self, in_channel=3, out_channel=32, kernel_size=3):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=kernel_size//2),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2).contiguous()  # B H*W C

        return x

# Output Projection Block (returns the residual)
# 3 x 3 convolutional layer
class OutputProjBlock(nn.Module):
    def __init__(self, in_channel=64, out_channel=3, kernel_size=3):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.proj = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=kernel_size//2),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(x)

        return x # the output of this block si called residual

# complete uformer class
class Uformer(nn.Module):
  def __init__(self, embed_dim=32):
    super(Uformer, self).__init__()

    # encoder
    self.input_proj = InputProjBlock()

    self.transformerblock_0 = TransformerBlock(embed_dim)
    self.downsample_0 = DownsampleBlock(embed_dim, embed_dim*2)

    self.transformerblock_1 = TransformerBlock(embed_dim*2)
    self.downsample_1 = DownsampleBlock(embed_dim*2, embed_dim*4)

    self.transformerblock_2 = TransformerBlock(embed_dim*4)
    self.downsample_2 = DownsampleBlock(embed_dim*4, embed_dim*8)

    self.transformerblock_3 = TransformerBlock(embed_dim*8)
    self.downsample_3 = DownsampleBlock(embed_dim*8, embed_dim*16)


    # bottleneck
    self.transformerblock_4 = TransformerBlock(embed_dim*16)


    # decoder
    self.upsample_0 = UpsampleBlock(embed_dim*16, embed_dim*8,kernel_size=2,stride=2,padding=0)
    self.transformerblock_5 = TransformerBlock(embed_dim*16)

    self.upsample_1 = UpsampleBlock(embed_dim*16, embed_dim*4,kernel_size=2,stride=2,padding=0)
    self.transformerblock_6 = TransformerBlock(embed_dim*8)

    self.upsample_2 = UpsampleBlock(embed_dim*8, embed_dim*2,kernel_size=2,stride=2,padding=0)
    self.transformerblock_7 = TransformerBlock(embed_dim*4)

    self.upsample_3 = UpsampleBlock(embed_dim*4, embed_dim,kernel_size=2,stride=2,padding=0)
    self.transformerblock_8 = TransformerBlock(embed_dim*2)

    self.output_proj = OutputProjBlock()


  def forward(self, x):
    degraded_image = x # x is the degraded image


    # encoder
    y = self.input_proj(x)
    t0 = self.transformerblock_0(y)
    d0 = self.downsample_0(t0)
    t1 = self.transformerblock_1(d0)
    d1 = self.downsample_1(t1)
    t2 = self.transformerblock_2(d1)
    d2 = self.downsample_2(t2)
    t3 = self.transformerblock_3(d2)
    d3 = self.downsample_3(t3)


    # bottleneck
    t4 = self.transformerblock_4(d3)


    # decoder
    u0 = self.upsample_0(t4)
    # print("1) Upsampled in: ",u0.shape)
    # print("1) Skipped in: ",t3.shape)
    skippedconn_0 = torch.cat([u0, t3], -1) # this creates a skipped connection between t3 and t6 (u0 would have to be the input of t5)
    t5 = self.transformerblock_5(skippedconn_0)

    u1 = self.upsample_1(t5)
    # print("2) Upsampled in: ",u1.shape)
    # print("2) Skipped in: ",t2.shape)
    skippedconn_1 = torch.cat([u1, t2], -1)
    t6 = self.transformerblock_6(skippedconn_1)

    u2 = self.upsample_2(t6)
    # print("3) Upsampled in: ",u2.shape)
    # print("3) Skipped in: ",t1.shape)
    skippedconn_2 = torch.cat([u2, t1], -1)
    t7 = self.transformerblock_7(skippedconn_2)

    u3 = self.upsample_3(t7)
    # print("4) Upsampled in: ",u3.shape)
    # print("4) Skipped in: ",t0.shape)
    skippedconn_3 = torch.cat([u3, t0], -1)
    t8 = self.transformerblock_8(skippedconn_3)

    residual = self.output_proj(t8)


    # final residual summation
    # print("Degraded: ",degraded_image.shape)
    # print("Residual: ",residual.shape)
    restored_image = degraded_image + residual.reshape([residual.shape[0],residual.shape[1],degraded_image.shape[2],degraded_image.shape[3]])

    return restored_image

# Train

In [70]:
def train(device,train_dataloader,test_dataloader):
  # Set-seed
  os.environ["PYTHONHASHSEED"] = str(global_var['seed'])
  np.random.seed(global_var['seed'])
  torch.cuda.manual_seed(global_var['seed'])
  torch.cuda.empty_cache()

  # Globals
  history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'lrs': []}
  min_acc = 0
  train_acc_list = []
  test_acc_list = []
  train_loss_list = []
  test_loss_list = []

  # Loss
  criterion = Cha_loss()

  # Model
  model = Uformer()
  model.to(device=device)

  if global_var['do_print_model']:
    print_model(model, device, input_shape=global_var['RGB_img_res'])
    print('The {} model has: {} trainable parameters'.format(model_name, count_parameters(model)))

  # # Optimizer
  # optimizer = torch.optim.AdamW(
  #   model.parameters(), lr=global_var['lr'], betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False
  # )

  # # Scheduler
  # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
  #   optimizer, mode='min', factor=0.1, patience=global_var['lr_patience'], threshold=1e-4, threshold_mode='rel',
  #   cooldown=0, min_lr=1e-8, eps=1e-08, verbose=False
  # )

  # # Early stopping
  # trigger_times, early_stopping_epochs = 0, global_var['e_stop_epochs']

  # print("--- Start training: {} ---\n".format(model_name))
  # # Train

  # for epoch in range(global_var['epochs']):
  #   iter = 1
  #   model.train(mode=True)
  #   running_loss, accuracy = 0, 0

  #   with tqdm(train_dataloader, unit="step", position=0, leave=True) as tepoch:
  #     for batch in tepoch:
  #       tepoch.set_description(f"Epoch {epoch + 1}/{global_var['epochs']} - Training")

  #       # Load data
  #       inputs, targets = batch[0].to(device=device), batch[1].to(device=device)

  #       # Forward
  #       optimizer.zero_grad()
  #       outputs = model(inputs)

  #       # Compute loss
  #       loss = criterion(outputs, targets)

  #       # Backward
  #       loss = torch.clone(loss).detach().requires_grad_(True)
  #       loss.backward()
  #       optimizer.step()

  #       # Evaluation and Stats
  #       running_loss += loss.item()
  #       train_loss_list.append(loss.item())

  #       accuracy += compute_accuracy(outputs, targets)

  #       tepoch.set_postfix({'Loss': running_loss / iter,
  #                           'Acc': accuracy.item() / iter,
  #                           'Lr': global_var['lr'] if not history['lrs'] else history['lrs'][-1]})
  #       iter += 1

  #   # Validation
  #   iter = 1
  #   model.eval()
  #   test_loss, test_accuracy = 0, 0
  #   with tqdm(test_dataloader, unit="step", position=0, leave=True) as tepoch:
  #     for batch in tepoch:
  #       tepoch.set_description(f"Epoch {epoch + 1}/{global_var['epochs']} - Validation")
  #       inputs, targets = batch[0].to(device=device), batch[1].to(device=device)

  #       # Validation loop
  #       with torch.no_grad():
  #         outputs = model(inputs)

  #         # Evaluation metrics
  #         test_accuracy += compute_accuracy(outputs, targets)

  #         # Loss
  #         loss = criterion(outputs, targets)
  #         test_loss += loss.item()
  #         test_loss_list.append(loss.item())

  #         tepoch.set_postfix({'Loss': test_loss / iter, 'Acc': test_accuracy.item() / iter})
  #         iter += 1

  #       # Update history infos
  #       history['lrs'].append(get_lr(optimizer))
  #       history['train_loss'].append(running_loss / len(train_dataloader))
  #       history['val_loss'].append(test_loss / len(test_dataloader))
  #       history['train_acc'].append(accuracy.item() / len(train_dataloader))
  #       history['val_acc'].append(test_accuracy.item() / len(test_dataloader))

  #       # Save model by best ACCURACY
  #       if min_acc <= (test_accuracy / len(test_dataloader)):
  #         min_acc = test_accuracy / len(test_dataloader)
  #         save_checkpoint(model, model_name + '_best_acc')
  #         print('New best ACCURACY: {:.3f} at epoch {}'.format(min_acc, epoch + 1))

  #         if trigger_times > 4:
  #           trigger_times = trigger_times - 2
  #           print(f"EarlyStopping increased due to Accuracy, stop in {early_stopping_epochs - trigger_times} epochs")


  #       save_history(history, save_model_root + model_name + '_history')
  #       # Empty CUDA cache
  #       torch.cuda.empty_cache()

  #       if trigger_times == early_stopping_epochs:
  #           print('Val Loss did not imporved for {} epochs, training stopped'.format(early_stopping_epochs + 1))
  #           break

  #       # Save loss for graphs
  #       np.save(save_model_root + 'train.npy', np.array(train_loss_list))
  #       np.save(save_model_root + 'test.npy', np.array(test_loss_list))

  # print('--- Finished Training ---')
  # save_csv_history(model_name=model_name)
  # plot_history(history)
  # plot_loss(history, title='Loss Trend')

  # return history, min_acc, train_acc_list, test_acc_list, train_loss_list, test_loss_list

In [71]:
device = hardware_check()
stats = train(device,train_loader,test_loader)

Actual device:  cuda:0


RuntimeError: ignored

# Test

In [None]:
def test(device,test_dataloader):
  model = Uformer()
  model = load_pretrained_model(model,device)
  model.to(device)

  if global_var['do_print_model']:
    print_model(model, device, input_shape=global_var['RGB_img_res'])
    print('The {} model has: {} trainable parameters'.format(model_name, count_parameters(model)))

  # Evaluate
  print(" --- Begin evaluation --- ")
  mean_psnr, mean_ssim = compute_evaluation(test_dataloader,model,device)
  print(" --- End evaluation --- ")
  print("Mean PSNR: ",mean_psnr)
  print("Mean SSIM: ",mean_ssim)

In [None]:
test(device,test_loader)

Loading checkpoint...

Checkpoint loaded!

 --- Begin evaluation --- 
 --- End evaluation --- 
Mean PSNR:  24.364336013793945
Mean SSIM:  0.8439557552337646
