# Single Image Super-Resolution
This is the primary notebook for training the super-resolution models, one for each loss function.

## Importing required libraries

In [1]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import io, transforms
import torchvision.transforms.functional as F
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from collections import OrderedDict
import random
from preprocess import *

In [2]:
!pip install torchsummary
!pip install pytorch_msssim
!pip install torchmetrics

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amaz

In [3]:
import sys

# The pip output above reports a per-user installation, so that directory is
# added to the import path for this kernel.
# NOTE(review): absolute, machine-specific path — confirm it matches the host
# and Python version this notebook is run on.
user_site_packages = '/home/ec2-user/.local/lib/python3.7/site-packages'
if user_site_packages not in sys.path:  # keep the cell idempotent across re-runs
    sys.path.append(user_site_packages)

In [4]:
# from pytorch_msssim import ssim
from torchsummary import summary
from torchmetrics.functional import structural_similarity_index_measure as ssim

In [5]:
print(torch.__version__)
print(torchvision.__version__)
# The CUDA queries raise when no GPU is visible, while the device-selection cell
# that follows explicitly allows a CPU fallback — guard them so both setups run.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.get_device_properties(0))
else:
    print("CUDA not available; running on CPU")

1.13.0+cu117
0.14.0+cu117
Tesla T4
_CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15109MB, multi_processor_count=40)


In [6]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

## Data Pre-processing and Loading

In [7]:
# Paired data transformation for training (the rotation choices include 0 degrees)
paired_transform = PairedTransform([
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRotationSpecific([0, 90, 180, 270])
])

tensor_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert the image to a tensor with values between [0, 1]
    # transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

# Dataset configuration
batch_size = 16
hr_dir = "HR_patches"
# lr_x2_dir = "LR_x2_patches"
lr_x4_dir = "LR_x4_patches"

# Fixed seed so the subset and the train/validation split are the same on every
# "Restart & Run All"; an unseeded random_split draws a different split per run.
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)

# Fractions used for the two splits below.
subset_fraction = 0.35   # share of the full dataset that is used at all
train_fraction = 0.9     # share of that subset used for training

# Initialize the dataset
dataset = SRDataset(hr_dir,
                    lr_x4_dir,
                    paired_transform=paired_transform,
                    tensor_transform=tensor_transform,
                    paired_transform_prob=0.5)

# First split: keep a smaller subset of the dataset and discard the rest
subset_size = int(subset_fraction * len(dataset))
discard_size = len(dataset) - subset_size
subset_dataset, _discarded = random_split(dataset,
                                          [subset_size, discard_size],
                                          generator=split_generator)

# Second split: 90% for training, 10% for validation
# NOTE(review): both splits wrap the same SRDataset instance, so presumably the
# random paired augmentation is also applied to validation samples — confirm
# that this is intended.
train_size = int(train_fraction * len(subset_dataset))
val_size = len(subset_dataset) - train_size
train_dataset, val_dataset = random_split(subset_dataset,
                                          [train_size, val_size],
                                          generator=split_generator)

# Data loaders for training (shuffled) and validation (fixed order)
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True,
                          collate_fn=custom_collate_fn)

val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        pin_memory=True,
                        collate_fn=custom_collate_fn)

In [8]:
# Number of samples in the (training, validation) splits.
tuple(len(split) for split in (train_dataset, val_dataset))

(18900, 2100)

## Defining Model Architecture

In [9]:
# Defining Residual Block
class ResidualBlock(nn.Module):
    """Residual unit: 1x1 conv -> ReLU -> 3x3 conv -> (BatchNorm) -> ReLU, plus an identity skip.

    The block input is added to the output of the convolution branch and a final
    ReLU is applied to the sum. Because the skip is a plain addition, the input
    must have the same shape as the branch output.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        batch_norm: When True, a BatchNorm2d follows the 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else None

    def forward(self, x):
        """Return ``relu(relu(branch(x)) + x)``."""
        branch = self.conv2(self.relu(self.conv1(x)))
        if self.bn1 is not None:
            branch = self.bn1(branch)
        # Activation is applied both before and after the skip addition.
        return self.relu(self.relu(branch) + x)

# Defining Residual-in-Residual Block
class ResidualInResidualBlock(nn.Module):
    """Two chained ResidualBlocks wrapped in an outer identity skip connection.

    The outer skip adds the block input to the output of the second residual
    block, so the input shape must match that output shape.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, batch_norm)
        self.res2 = ResidualBlock(out_channels, out_channels, batch_norm)

    def forward(self, x):
        """Apply both inner residual blocks, then add the original input."""
        inner = self.res2(self.res1(x))
        return inner + x

# Defining the U-Net architecture
class UNetSRx4RiR(nn.Module):
    """U-Net style x4 super-resolution network built from residual-in-residual blocks.

    ``forward`` first enlarges the input by a factor of 4 with bicubic
    interpolation, then runs a four-level encoder, a bottleneck and a four-level
    decoder with concatenated skip connections. A sigmoid bounds the output.

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the super-resolved output.
        num_features: Channel width of the first encoder level; it doubles at
            every downsampling stage.
        dropout_rate: Dropout probability inside the downsampling stages
            (no dropout layer is created when it is 0).
    """

    def __init__(self, in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0):
        super().__init__()

        # Channel widths at each depth of the network.
        w1 = num_features
        w2, w4, w8, w16 = w1 * 2, w1 * 4, w1 * 8, w1 * 16

        # Encoder
        self.conv1 = nn.Conv2d(in_channels, w1, kernel_size=3, padding=1)
        self.encoder1 = ResidualInResidualBlock(w1, w1, batch_norm=True)
        self.downsample1 = self.downsample(w1, w2, dropout_rate)

        self.encoder2 = ResidualInResidualBlock(w2, w2, batch_norm=True)
        self.downsample2 = self.downsample(w2, w4, dropout_rate)

        self.encoder3 = ResidualInResidualBlock(w4, w4, batch_norm=True)
        self.downsample3 = self.downsample(w4, w8, dropout_rate)

        self.encoder4 = ResidualInResidualBlock(w8, w8, batch_norm=True)
        self.downsample4 = self.downsample(w8, w16, dropout_rate)

        # Bottleneck
        self.bottleneck = ResidualInResidualBlock(w16, w16, batch_norm=True)

        # Decoder: each level upsamples, refines the concatenated features
        # (hence the doubled width), then halves the channels with a 3x3 conv.
        self.upconv4 = nn.ConvTranspose2d(w16, w8, kernel_size=2, stride=2)
        self.decoder4 = ResidualInResidualBlock(w16, w16, batch_norm=True)
        self.conv3 = nn.Conv2d(w16, w8, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(w8, w4, kernel_size=2, stride=2)
        self.decoder3 = ResidualInResidualBlock(w8, w8, batch_norm=True)
        self.conv4 = nn.Conv2d(w8, w4, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(w4, w2, kernel_size=2, stride=2)
        self.decoder2 = ResidualInResidualBlock(w4, w4, batch_norm=True)
        self.conv5 = nn.Conv2d(w4, w2, kernel_size=3, padding=1)

        self.upconv1 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.decoder1 = ResidualInResidualBlock(w2, w2, batch_norm=True)

        # Final output layer
        self.final_conv = nn.Sequential(
            nn.Conv2d(w2, w1, kernel_size=3, padding=1),
            nn.Conv2d(w1, out_channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def downsample(self, in_channels, out_channels, dropout_rate, batch_norm=True):
        """Build a stride-2 3x3 conv stage: conv -> (BatchNorm) -> ReLU -> (Dropout)."""
        strided_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2)
        norm = [nn.BatchNorm2d(out_channels)] if batch_norm else []
        drop = [nn.Dropout(dropout_rate)] if dropout_rate > 0 else []
        return nn.Sequential(strided_conv, *norm, nn.ReLU(), *drop)

    def forward(self, x):
        """Super-resolve ``x``; the result has 4x the input height and width."""
        # First upscale x4 using bicubic interpolation
        x = torch.nn.functional.interpolate(x, scale_factor=4, mode='bicubic', align_corners=False)

        # Encoder: keep every level's output for the skip connections.
        skip1 = self.encoder1(self.conv1(x))
        skip2 = self.encoder2(self.downsample1(skip1))
        skip3 = self.encoder3(self.downsample2(skip2))
        skip4 = self.encoder4(self.downsample3(skip3))

        # Bottleneck
        feat = self.bottleneck(self.downsample4(skip4))

        # Decoder: upsample, concatenate the matching skip, refine, reduce channels.
        feat = torch.cat((self.upconv4(feat), skip4), dim=1)
        feat = self.conv3(self.decoder4(feat))

        feat = torch.cat((self.upconv3(feat), skip3), dim=1)
        feat = self.conv4(self.decoder3(feat))

        feat = torch.cat((self.upconv2(feat), skip2), dim=1)
        feat = self.conv5(self.decoder2(feat))

        feat = torch.cat((self.upconv1(feat), skip1), dim=1)
        feat = self.decoder1(feat)

        return self.final_conv(feat)

In [10]:
model_1 = UNetSRx4RiR(3, 3, 64, 0.0)
model_1.to(device)
# torchsummary builds its dummy input on "cuda" unless told otherwise; pass the
# active device type so the summary also works on the CPU fallback.
summary(model_1, input_size=(3, 64, 64), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,792
            Conv2d-2         [-1, 64, 256, 256]           4,096
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,864
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
              ReLU-7         [-1, 64, 256, 256]               0
     ResidualBlock-8         [-1, 64, 256, 256]               0
            Conv2d-9         [-1, 64, 256, 256]           4,096
             ReLU-10         [-1, 64, 256, 256]               0
           Conv2d-11         [-1, 64, 256, 256]          36,864
      BatchNorm2d-12         [-1, 64, 256, 256]             128
             ReLU-13         [-1, 64, 256, 256]               0
             ReLU-14         [-1, 64, 2

## Defining Loss Functions and Evaluation Metrics

### Defining PSNR metric

In [11]:
def calculate_psnr(img1, img2, max_pixel_value=1.0):
    """Peak signal-to-noise ratio (in dB) between two image tensors.

    The mean squared error is taken over every element of the inputs, so a
    batch yields a single PSNR value rather than one per image.

    Args:
        img1: First image tensor.
        img2: Second image tensor, broadcastable against ``img1``.
        max_pixel_value: Largest possible pixel value (1.0 for [0, 1] images).

    Returns:
        A 0-dim tensor holding the PSNR. Identical inputs (MSE of zero) return
        100 instead of infinity; the value is a tensor in that case as well, so
        callers can always use ``.item()`` on the result.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        # Same dtype/device as the regular result keeps the return type uniform.
        return torch.full_like(mse, 100.0)
    return 10 * torch.log10((max_pixel_value ** 2) / mse)

### Defining Tukey Loss

In [11]:
class TukeyLoss(nn.Module):
    """Tukey biweight loss, averaged over all elements.

    For a residual ``e`` with ``|e| <= c`` the per-element loss is
    ``c^2 / 6 * (1 - (1 - (e / c)^2)^3)``; larger residuals contribute the
    constant ``c^2 / 6``.
    """

    def __init__(self, c=0.3):  # Adjusted c value for normalized data
        super().__init__()
        self.c = c

    def forward(self, input, target):
        """Return the mean Tukey loss between ``input`` and ``target``."""
        residual = input - target
        plateau = self.c ** 2 / 6  # value the loss saturates at beyond the threshold

        inlier_loss = plateau * (1 - (1 - (residual / self.c) ** 2) ** 3)
        outlier_loss = plateau * torch.ones_like(residual)

        within_threshold = torch.abs(residual) <= self.c
        per_element = torch.where(within_threshold, inlier_loss, outlier_loss)
        return per_element.mean()

### Defining Charbonnier loss

In [12]:
class CharbonnierLoss(nn.Module):
    """Charbonnier loss: mean of ``sqrt((x - y)^2 + epsilon^2)`` over all elements."""

    def __init__(self, epsilon=1e-3):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x, y):
        """Return the mean Charbonnier penalty between ``x`` and ``y``."""
        residual = x - y
        # epsilon keeps the square root away from zero for a zero residual.
        smoothed_abs = (residual * residual + self.epsilon ** 2).sqrt()
        return smoothed_abs.mean()

### Defining Total Variation (TV) loss and combined TV + MAE loss

In [13]:
class TVLoss(nn.Module):
    """Total-variation penalty on a 4-D tensor indexed as (batch, channel, height, width).

    Squared differences between neighbouring elements along dimensions 2 and 3
    are summed, each normalised by the number of difference terms per sample,
    then scaled by ``2 * tv_weight`` and divided by the batch size.
    """

    def __init__(self, tv_weight=1e-4):
        super().__init__()
        self.tv_weight = tv_weight

    def forward(self, x):
        """Return the weighted total-variation value of ``x`` as a 0-dim tensor."""
        batch_size = x.size(0)
        diff_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        diff_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        # A single-row or single-column input has no neighbours along that axis;
        # clamp the count to 1 so the (zero) sum is not turned into NaN by 0 / 0.
        count_h = max(self._tensor_size(diff_h), 1)
        count_w = max(self._tensor_size(diff_w), 1)
        h_tv = torch.pow(diff_h, 2).sum()
        w_tv = torch.pow(diff_w, 2).sum()
        return self.tv_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def _tensor_size(t):
        """Number of elements per sample (product of dimensions 1, 2 and 3)."""
        return t.size(1) * t.size(2) * t.size(3)
    
def combined_tv_mae_loss(output, target, tv_weight=1e-4):
    """Calculate combined MAE and TV loss.

    The original body relied on module-level ``l1_loss`` and ``tv_loss`` names
    that are never defined in this notebook, so it failed with NameError on a
    fresh kernel. Both terms are now built explicitly.

    Args:
        output: Predicted image batch.
        target: Ground-truth image batch.
        tv_weight: Weight passed to ``TVLoss``; the default matches TVLoss's own.

    Returns:
        Sum of the mean absolute error and the total-variation penalty of ``output``.
    """
    # torch.nn.functional is spelled out because the notebook binds `F` to
    # torchvision.transforms.functional.
    mae_loss = torch.nn.functional.l1_loss(output, target)
    tv_out = TVLoss(tv_weight=tv_weight)(output)
    return tv_out + mae_loss

### Defining combined SSIM loss + MAE loss

In [12]:
def combined_ssim_l1_loss(output, target, alpha=0.7):
    """Weighted sum of L1 loss and SSIM dissimilarity.

    Returns ``alpha * L1 + (1 - alpha) * (1 - SSIM) + 1e-8``, with SSIM
    evaluated for a data range of 1.0.
    """
    structural_term = 1 - ssim(output, target, data_range=1.0)
    pixel_term = torch.nn.functional.l1_loss(output, target)
    weighted = alpha * pixel_term + (1 - alpha) * structural_term
    return weighted + 1e-8

### Defining Early Stopping class

In [12]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` is set to True.
        verbose: When True, print a message for every checkpoint and counter step.
        delta: Minimum decrease of the validation loss that counts as an improvement.
        path: File the best model's ``state_dict`` is written to.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one validation result; checkpoint on improvement, otherwise count towards patience."""
        # NaN is the only value that differs from itself. Every ordering
        # comparison with NaN is False, so without this check a NaN loss would
        # land in the "improved" branch and overwrite the best checkpoint.
        is_nan = bool(val_loss != val_loss)
        score = -val_loss

        no_improvement = is_nan or (
            self.best_score is not None and score < self.best_score + self.delta
        )
        if no_improvement:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # First valid loss, or an improvement of at least `delta`.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

## Model Training

### Training the model for MAE loss

In [None]:
from torch.cuda.amp import GradScaler, autocast

# NOTE(review): anomaly detection is documented by PyTorch as a debugging aid
# that slows execution — consider switching it off for long training runs.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=0.00005)
    scaler = GradScaler()
    
    # Initialize TensorBoard writer
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter('runs/UNetSRx4_L1_fin')

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, verbose=True)
    
    # Initialize learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    # try/finally guarantees the TensorBoard writer is closed even when the run
    # is interrupted or fails part-way through.
    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_psnr = 0.0
            running_ssim = 0.0

            # Wrap the train_loader with tqdm for progress bar
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

            for i, (hr_patches, lr_patches) in enumerate(progress_bar):
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

                optimizer.zero_grad()
                with autocast():
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)
                    if torch.isnan(loss).any():
                        print("Loss is NaN")
                        break

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()
                # Metrics are for logging only: detach them from the autograd
                # graph and cast to float32 before evaluating PSNR / SSIM.
                with torch.no_grad():
                    outputs = outputs.detach().float()
                    hr_patches = hr_patches.float()
                    # float() accepts both a 0-dim tensor and a plain number.
                    running_psnr += float(calculate_psnr(outputs, hr_patches))
                    running_ssim += ssim(outputs, hr_patches, data_range=1.0).item()

                # Update the progress bar with the running loss
                progress_bar.set_postfix(loss=loss.item())

            avg_train_loss = running_loss / len(train_loader)
            avg_train_psnr = running_psnr / len(train_loader)
            avg_train_ssim = running_ssim / len(train_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
            writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

            # Validation loop
            model.eval()
            val_loss = 0.0
            val_psnr = 0.0
            val_ssim = 0.0
            with torch.no_grad():
                for hr_patches, lr_patches in val_loader:
                    hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                    with autocast(enabled=True):
                        outputs = model(lr_patches)
                        loss = criterion(outputs, hr_patches)
                        if torch.isnan(loss).any():
                            print("Loss is NaN")
                            break
                    val_loss += loss.item()
                    # Convert to float32 before calculating the metrics
                    outputs = outputs.float()
                    hr_patches = hr_patches.float()
                    val_psnr += float(calculate_psnr(outputs, hr_patches))
                    val_ssim += ssim(outputs, hr_patches, data_range=1.0).item()

            avg_val_loss = val_loss / len(val_loader)
            avg_val_psnr = val_psnr / len(val_loader)
            avg_val_ssim = val_ssim / len(val_loader)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
            writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
            print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

            # Check early stopping condition
            early_stopping(avg_val_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            # Step the scheduler
            scheduler.step(avg_val_loss)
    finally:
        writer.close()
    
    print('Training complete.')

2024-09-07 04:55:39.783390: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-07 04:55:39.938763: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-07 04:55:40.670297: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-07 04:55:40.670367: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [1/25], Train Loss: 0.0471, Train PSNR: 23.4607, Train SSIM: 0.7081





Validation Loss: 0.0338, Validation PSNR: 25.0414, Validation SSIM: 0.7598
Validation loss decreased (inf --> 0.033751).  Saving model ...


Epoch 2/25: 100%|██████████| 1182/1182 [23:23<00:00,  1.19s/batch, loss=0.0268]

Epoch [2/25], Train Loss: 0.0350, Train PSNR: 24.9670, Train SSIM: 0.7618





Validation Loss: 0.0310, Validation PSNR: 25.4650, Validation SSIM: 0.7786
Validation loss decreased (0.033751 --> 0.030954).  Saving model ...


Epoch 3/25: 100%|██████████| 1182/1182 [23:24<00:00,  1.19s/batch, loss=0.0218]

Epoch [3/25], Train Loss: 0.0324, Train PSNR: 25.3283, Train SSIM: 0.7749





Validation Loss: 0.0314, Validation PSNR: 25.5211, Validation SSIM: 0.7861
EarlyStopping counter: 1 out of 5


Epoch 4/25: 100%|██████████| 1182/1182 [23:25<00:00,  1.19s/batch, loss=0.0442]

Epoch [4/25], Train Loss: 0.0310, Train PSNR: 25.5430, Train SSIM: 0.7819





Validation Loss: 0.0296, Validation PSNR: 25.7119, Validation SSIM: 0.7874
Validation loss decreased (0.030954 --> 0.029613).  Saving model ...


Epoch 5/25: 100%|██████████| 1182/1182 [23:24<00:00,  1.19s/batch, loss=0.026] 

Epoch [5/25], Train Loss: 0.0302, Train PSNR: 25.6532, Train SSIM: 0.7865





Validation Loss: 0.0292, Validation PSNR: 25.7905, Validation SSIM: 0.7906
Validation loss decreased (0.029613 --> 0.029197).  Saving model ...


Epoch 6/25: 100%|██████████| 1182/1182 [23:25<00:00,  1.19s/batch, loss=0.0337]

Epoch [6/25], Train Loss: 0.0296, Train PSNR: 25.7612, Train SSIM: 0.7899





Validation Loss: 0.0282, Validation PSNR: 25.8957, Validation SSIM: 0.7941
Validation loss decreased (0.029197 --> 0.028229).  Saving model ...


Epoch 7/25: 100%|██████████| 1182/1182 [23:25<00:00,  1.19s/batch, loss=0.0225]

Epoch [7/25], Train Loss: 0.0291, Train PSNR: 25.8263, Train SSIM: 0.7927





Validation Loss: 0.0275, Validation PSNR: 25.9694, Validation SSIM: 0.7973
Validation loss decreased (0.028229 --> 0.027493).  Saving model ...


Epoch 8/25: 100%|██████████| 1182/1182 [23:24<00:00,  1.19s/batch, loss=0.031] 

Epoch [8/25], Train Loss: 0.0288, Train PSNR: 25.9014, Train SSIM: 0.7948





Validation Loss: 0.0277, Validation PSNR: 25.9596, Validation SSIM: 0.7988
EarlyStopping counter: 1 out of 5


Epoch 9/25: 100%|██████████| 1182/1182 [23:24<00:00,  1.19s/batch, loss=0.0444]

Epoch [9/25], Train Loss: 0.0283, Train PSNR: 25.9511, Train SSIM: 0.7969





Validation Loss: 0.0282, Validation PSNR: 25.8762, Validation SSIM: 0.7927
EarlyStopping counter: 2 out of 5


Epoch 10/25: 100%|██████████| 1182/1182 [23:23<00:00,  1.19s/batch, loss=0.0347]

Epoch [10/25], Train Loss: 0.0281, Train PSNR: 25.9692, Train SSIM: 0.7981





Validation Loss: 0.0286, Validation PSNR: 25.9306, Validation SSIM: 0.8002
EarlyStopping counter: 3 out of 5
Epoch 00010: reducing learning rate of group 0 to 2.5000e-05.


Epoch 11/25: 100%|██████████| 1182/1182 [23:23<00:00,  1.19s/batch, loss=0.0261]

Epoch [11/25], Train Loss: 0.0271, Train PSNR: 26.1101, Train SSIM: 0.8014





Validation Loss: 0.0266, Validation PSNR: 26.1135, Validation SSIM: 0.8026
Validation loss decreased (0.027493 --> 0.026587).  Saving model ...


Epoch 12/25: 100%|██████████| 1182/1182 [23:23<00:00,  1.19s/batch, loss=0.0118]

Epoch [12/25], Train Loss: 0.0270, Train PSNR: 26.1436, Train SSIM: 0.8026





Validation Loss: 0.0271, Validation PSNR: 26.1051, Validation SSIM: 0.8037
EarlyStopping counter: 1 out of 5


Epoch 13/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.0349]

Epoch [13/25], Train Loss: 0.0268, Train PSNR: 26.1325, Train SSIM: 0.8033





Validation Loss: 0.0269, Validation PSNR: 26.1289, Validation SSIM: 0.8032
EarlyStopping counter: 2 out of 5


Epoch 14/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.0241]

Epoch [14/25], Train Loss: 0.0267, Train PSNR: 26.1934, Train SSIM: 0.8041





Validation Loss: 0.0265, Validation PSNR: 26.1533, Validation SSIM: 0.8073
Validation loss decreased (0.026587 --> 0.026497).  Saving model ...


Epoch 15/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.0282]

Epoch [15/25], Train Loss: 0.0267, Train PSNR: 26.1998, Train SSIM: 0.8047





Validation Loss: 0.0267, Validation PSNR: 26.1728, Validation SSIM: 0.8047
EarlyStopping counter: 1 out of 5


Epoch 16/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.0231]

Epoch [16/25], Train Loss: 0.0265, Train PSNR: 26.2153, Train SSIM: 0.8054





Validation Loss: 0.0275, Validation PSNR: 26.1121, Validation SSIM: 0.8050
EarlyStopping counter: 2 out of 5


Epoch 17/25: 100%|██████████| 1182/1182 [23:23<00:00,  1.19s/batch, loss=0.0343]

Epoch [17/25], Train Loss: 0.0265, Train PSNR: 26.2195, Train SSIM: 0.8058





Validation Loss: 0.0266, Validation PSNR: 26.1921, Validation SSIM: 0.8075
EarlyStopping counter: 3 out of 5
Epoch 00017: reducing learning rate of group 0 to 1.2500e-05.


Epoch 18/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.044] 

Epoch [18/25], Train Loss: 0.0260, Train PSNR: 26.2633, Train SSIM: 0.8075





Validation Loss: 0.0266, Validation PSNR: 26.1954, Validation SSIM: 0.8071
EarlyStopping counter: 4 out of 5


Epoch 19/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.0241]

Epoch [19/25], Train Loss: 0.0260, Train PSNR: 26.3188, Train SSIM: 0.8081





Validation Loss: 0.0262, Validation PSNR: 26.2290, Validation SSIM: 0.8081
Validation loss decreased (0.026497 --> 0.026174).  Saving model ...


Epoch 20/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.0228]

Epoch [20/25], Train Loss: 0.0259, Train PSNR: 26.3016, Train SSIM: 0.8085





Validation Loss: 0.0260, Validation PSNR: 26.2425, Validation SSIM: 0.8079
Validation loss decreased (0.026174 --> 0.026030).  Saving model ...


Epoch 21/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.0417]

Epoch [21/25], Train Loss: 0.0259, Train PSNR: 26.3158, Train SSIM: 0.8089





Validation Loss: 0.0260, Validation PSNR: 26.2444, Validation SSIM: 0.8082
Validation loss decreased (0.026030 --> 0.026021).  Saving model ...


Epoch 22/25: 100%|██████████| 1182/1182 [23:23<00:00,  1.19s/batch, loss=0.0413]

Epoch [22/25], Train Loss: 0.0258, Train PSNR: 26.3482, Train SSIM: 0.8092





Validation Loss: 0.0260, Validation PSNR: 26.2419, Validation SSIM: 0.8091
Validation loss decreased (0.026021 --> 0.025996).  Saving model ...


Epoch 23/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.0308]

Epoch [23/25], Train Loss: 0.0258, Train PSNR: 26.3186, Train SSIM: 0.8096





Validation Loss: 0.0259, Validation PSNR: 26.2640, Validation SSIM: 0.8089
Validation loss decreased (0.025996 --> 0.025868).  Saving model ...


Epoch 24/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.0217]

Epoch [24/25], Train Loss: 0.0258, Train PSNR: 26.3511, Train SSIM: 0.8098





Validation Loss: 0.0258, Validation PSNR: 26.2700, Validation SSIM: 0.8102
Validation loss decreased (0.025868 --> 0.025831).  Saving model ...


Epoch 25/25: 100%|█████████▉| 1180/1182 [23:21<00:02,  1.19s/batch, loss=0.0265]

In [None]:
# Serialise the whole module object (not only its state_dict).
# NOTE(review): `model` holds the weights from the end of the training loop; the
# best-validation weights are the ones EarlyStopping wrote to its checkpoint
# file — confirm which of the two should be exported here.
torch.save(obj=model, f="UNet_L1_final.pth")

### Training the model for MSE loss

In [14]:
# Training run for the MSE (L2) objective: trains UNetSRx4RiR with mixed precision,
# logs per-epoch loss / PSNR / SSIM to TensorBoard under runs/UNetSRx4_L2, and drives
# EarlyStopping and ReduceLROnPlateau from the mean validation loss.
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter

# Debugging aid: anomaly detection makes autograd raise on NaNs produced during
# backward. PyTorch documents it as a debug-only mode because it slows execution.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scaler = GradScaler()

    # Initialize TensorBoard writer
    writer = SummaryWriter('runs/UNetSRx4_L2')

    # Initialize early stopping (defined elsewhere in the notebook; judging by the
    # logged output it checkpoints the model whenever the validation loss improves)
    early_stopping = EarlyStopping(patience=5, verbose=True)

    # Halve the learning rate after the validation loss has stalled for 2 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    # Set as soon as any batch yields a NaN loss. The run is then abandoned so that a
    # partial (and therefore deflated) epoch average is never logged or handed to
    # early stopping / the LR scheduler, where it could pass for an improvement.
    nan_detected = False

    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_psnr = 0.0
            running_ssim = 0.0

            # Wrap the train_loader with tqdm for progress bar
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

            for hr_patches, lr_patches in progress_bar:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

                optimizer.zero_grad()
                with autocast():
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)

                # Bail out before backward so the weights are not touched by a NaN loss.
                if torch.isnan(loss).any():
                    print("Loss is NaN")
                    nan_detected = True
                    break

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Read the loss back from the device once and reuse the Python float.
                batch_loss = loss.item()
                running_loss += batch_loss

                # Metrics are reporting-only: keep autograd from recording them.
                with torch.no_grad():
                    running_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Cast to float32 before SSIM (outputs may be half precision under autocast).
                    running_ssim += ssim(outputs.float(), hr_patches.float(), data_range=1.0).item()

                # Update the progress bar with the current batch loss
                progress_bar.set_postfix(loss=batch_loss)

            if nan_detected:
                break

            avg_train_loss = running_loss / len(train_loader)
            avg_train_psnr = running_psnr / len(train_loader)
            avg_train_ssim = running_ssim / len(train_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
            writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

            # Validation loop
            model.eval()
            val_loss = 0.0
            val_psnr = 0.0
            val_ssim = 0.0
            with torch.no_grad():
                for hr_patches, lr_patches in val_loader:
                    hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                    with autocast(enabled=True):
                        outputs = model(lr_patches)
                        loss = criterion(outputs, hr_patches)

                    if torch.isnan(loss).any():
                        print("Loss is NaN")
                        nan_detected = True
                        break

                    val_loss += loss.item()
                    val_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Convert to float32 before calculating SSIM
                    val_ssim += ssim(outputs.float(), hr_patches.float(), data_range=1.0).item()

            if nan_detected:
                break

            avg_val_loss = val_loss / len(val_loader)
            avg_val_psnr = val_psnr / len(val_loader)
            avg_val_ssim = val_ssim / len(val_loader)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
            writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
            print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

            # Check early stopping condition
            early_stopping(avg_val_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            # Step the scheduler
            scheduler.step(avg_val_loss)
    finally:
        # Flush and release the TensorBoard event file even if the run is interrupted.
        writer.close()

    if nan_detected:
        print('Training stopped early because of a NaN loss.')
    else:
        print('Training complete.')

Epoch 1/25: 100%|██████████| 1182/1182 [25:41<00:00,  1.30s/batch, loss=0.00643]

Epoch [1/25], Train Loss: 0.0048, Train PSNR: 23.6951, Train SSIM: 0.7243





Validation Loss: 0.0032, Validation PSNR: 25.1840, Validation SSIM: 0.7579
Validation loss decreased (inf --> 0.003196).  Saving model ...


Epoch 2/25: 100%|██████████| 1182/1182 [23:18<00:00,  1.18s/batch, loss=0.008]  

Epoch [2/25], Train Loss: 0.0032, Train PSNR: 25.1770, Train SSIM: 0.7694





Validation Loss: 0.0032, Validation PSNR: 25.2142, Validation SSIM: 0.7714
Validation loss decreased (0.003196 --> 0.003168).  Saving model ...


Epoch 3/25: 100%|██████████| 1182/1182 [23:18<00:00,  1.18s/batch, loss=0.0029]  

Epoch [3/25], Train Loss: 0.0030, Train PSNR: 25.4919, Train SSIM: 0.7778





Validation Loss: 0.0028, Validation PSNR: 25.7216, Validation SSIM: 0.7849
Validation loss decreased (0.003168 --> 0.002848).  Saving model ...


Epoch 4/25: 100%|██████████| 1182/1182 [23:18<00:00,  1.18s/batch, loss=0.00875]

Epoch [4/25], Train Loss: 0.0029, Train PSNR: 25.7035, Train SSIM: 0.7853





Validation Loss: 0.0028, Validation PSNR: 25.7427, Validation SSIM: 0.7836
Validation loss decreased (0.002848 --> 0.002832).  Saving model ...


Epoch 5/25: 100%|██████████| 1182/1182 [23:18<00:00,  1.18s/batch, loss=0.000963]

Epoch [5/25], Train Loss: 0.0028, Train PSNR: 25.7884, Train SSIM: 0.7889





Validation Loss: 0.0032, Validation PSNR: 25.2679, Validation SSIM: 0.7767
EarlyStopping counter: 1 out of 5


Epoch 6/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.00404]

Epoch [6/25], Train Loss: 0.0028, Train PSNR: 25.8172, Train SSIM: 0.7893





Validation Loss: 0.0028, Validation PSNR: 25.7982, Validation SSIM: 0.7872
Validation loss decreased (0.002832 --> 0.002809).  Saving model ...


Epoch 7/25: 100%|██████████| 1182/1182 [23:22<00:00,  1.19s/batch, loss=0.00152]

Epoch [7/25], Train Loss: 0.0027, Train PSNR: 25.9402, Train SSIM: 0.7932





Validation Loss: 0.0028, Validation PSNR: 25.8125, Validation SSIM: 0.7930
Validation loss decreased (0.002809 --> 0.002786).  Saving model ...


Epoch 8/25: 100%|██████████| 1182/1182 [23:25<00:00,  1.19s/batch, loss=0.00232] 

Epoch [8/25], Train Loss: 0.0027, Train PSNR: 25.9980, Train SSIM: 0.7953





Validation Loss: 0.0028, Validation PSNR: 25.7723, Validation SSIM: 0.7816
EarlyStopping counter: 1 out of 5


Epoch 9/25: 100%|██████████| 1182/1182 [23:25<00:00,  1.19s/batch, loss=0.0144]  

Epoch [9/25], Train Loss: 0.0027, Train PSNR: 26.0056, Train SSIM: 0.7959





Validation Loss: 0.0026, Validation PSNR: 26.0768, Validation SSIM: 0.7960
Validation loss decreased (0.002786 --> 0.002638).  Saving model ...


Epoch 10/25: 100%|██████████| 1182/1182 [23:26<00:00,  1.19s/batch, loss=0.002]  

Epoch [10/25], Train Loss: 0.0026, Train PSNR: 26.0539, Train SSIM: 0.7983





Validation Loss: 0.0034, Validation PSNR: 25.0126, Validation SSIM: 0.7710
EarlyStopping counter: 1 out of 5


Epoch 11/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.00118] 

Epoch [11/25], Train Loss: 0.0026, Train PSNR: 26.1054, Train SSIM: 0.7991





Validation Loss: 0.0029, Validation PSNR: 25.5884, Validation SSIM: 0.7880
EarlyStopping counter: 2 out of 5


Epoch 12/25: 100%|██████████| 1182/1182 [23:29<00:00,  1.19s/batch, loss=0.00679] 

Epoch [12/25], Train Loss: 0.0026, Train PSNR: 26.1235, Train SSIM: 0.8004





Validation Loss: 0.0026, Validation PSNR: 26.1958, Validation SSIM: 0.8009
Validation loss decreased (0.002638 --> 0.002572).  Saving model ...


Epoch 13/25: 100%|██████████| 1182/1182 [23:29<00:00,  1.19s/batch, loss=0.0025]  

Epoch [13/25], Train Loss: 0.0026, Train PSNR: 26.1774, Train SSIM: 0.8018





Validation Loss: 0.0027, Validation PSNR: 25.9753, Validation SSIM: 0.7966
EarlyStopping counter: 1 out of 5


Epoch 14/25: 100%|██████████| 1182/1182 [23:30<00:00,  1.19s/batch, loss=0.00187] 

Epoch [14/25], Train Loss: 0.0026, Train PSNR: 26.1956, Train SSIM: 0.8010





Validation Loss: 0.0026, Validation PSNR: 26.1362, Validation SSIM: 0.8000
EarlyStopping counter: 2 out of 5


Epoch 15/25: 100%|██████████| 1182/1182 [23:32<00:00,  1.19s/batch, loss=0.00073] 

Epoch [15/25], Train Loss: 0.0026, Train PSNR: 26.2152, Train SSIM: 0.8034





Validation Loss: 0.0027, Validation PSNR: 25.9322, Validation SSIM: 0.7971
EarlyStopping counter: 3 out of 5
Epoch 00015: reducing learning rate of group 0 to 5.0000e-05.


Epoch 16/25: 100%|██████████| 1182/1182 [23:30<00:00,  1.19s/batch, loss=0.00362] 

Epoch [16/25], Train Loss: 0.0025, Train PSNR: 26.3441, Train SSIM: 0.8066





Validation Loss: 0.0026, Validation PSNR: 26.2225, Validation SSIM: 0.8023
Validation loss decreased (0.002572 --> 0.002557).  Saving model ...


Epoch 17/25: 100%|██████████| 1182/1182 [23:29<00:00,  1.19s/batch, loss=0.00132] 

Epoch [17/25], Train Loss: 0.0025, Train PSNR: 26.3697, Train SSIM: 0.8073





Validation Loss: 0.0026, Validation PSNR: 26.1714, Validation SSIM: 0.8032
EarlyStopping counter: 1 out of 5


Epoch 18/25: 100%|██████████| 1182/1182 [23:30<00:00,  1.19s/batch, loss=0.00247] 

Epoch [18/25], Train Loss: 0.0025, Train PSNR: 26.3721, Train SSIM: 0.8077





Validation Loss: 0.0025, Validation PSNR: 26.2374, Validation SSIM: 0.8025
Validation loss decreased (0.002557 --> 0.002545).  Saving model ...


Epoch 19/25: 100%|██████████| 1182/1182 [23:32<00:00,  1.19s/batch, loss=0.00299] 

Epoch [19/25], Train Loss: 0.0025, Train PSNR: 26.3594, Train SSIM: 0.8080





Validation Loss: 0.0026, Validation PSNR: 26.1002, Validation SSIM: 0.8024
EarlyStopping counter: 1 out of 5


Epoch 20/25: 100%|██████████| 1182/1182 [23:31<00:00,  1.19s/batch, loss=0.0018]  

Epoch [20/25], Train Loss: 0.0025, Train PSNR: 26.4157, Train SSIM: 0.8087





Validation Loss: 0.0027, Validation PSNR: 26.0875, Validation SSIM: 0.8029
EarlyStopping counter: 2 out of 5


Epoch 21/25: 100%|██████████| 1182/1182 [23:33<00:00,  1.20s/batch, loss=0.00167] 

Epoch [21/25], Train Loss: 0.0025, Train PSNR: 26.4230, Train SSIM: 0.8091





Validation Loss: 0.0025, Validation PSNR: 26.3188, Validation SSIM: 0.8043
Validation loss decreased (0.002545 --> 0.002501).  Saving model ...


Epoch 22/25: 100%|██████████| 1182/1182 [23:32<00:00,  1.19s/batch, loss=0.000775]

Epoch [22/25], Train Loss: 0.0024, Train PSNR: 26.4401, Train SSIM: 0.8097





Validation Loss: 0.0039, Validation PSNR: 25.1265, Validation SSIM: 0.7962
EarlyStopping counter: 1 out of 5


Epoch 23/25: 100%|██████████| 1182/1182 [23:33<00:00,  1.20s/batch, loss=0.00112] 

Epoch [23/25], Train Loss: 0.0024, Train PSNR: 26.4512, Train SSIM: 0.8099





Validation Loss: 0.0025, Validation PSNR: 26.2909, Validation SSIM: 0.8045
EarlyStopping counter: 2 out of 5


Epoch 24/25: 100%|██████████| 1182/1182 [23:32<00:00,  1.19s/batch, loss=0.00234] 

Epoch [24/25], Train Loss: 0.0024, Train PSNR: 26.4427, Train SSIM: 0.8101





Validation Loss: 0.0027, Validation PSNR: 26.0393, Validation SSIM: 0.8028
EarlyStopping counter: 3 out of 5
Epoch 00024: reducing learning rate of group 0 to 2.5000e-05.


Epoch 25/25: 100%|██████████| 1182/1182 [23:32<00:00,  1.20s/batch, loss=0.00137] 

Epoch [25/25], Train Loss: 0.0024, Train PSNR: 26.5032, Train SSIM: 0.8119





Validation Loss: 0.0025, Validation PSNR: 26.3680, Validation SSIM: 0.8077
Validation loss decreased (0.002501 --> 0.002475).  Saving model ...
Training complete.


In [17]:
# Persist the MSE-trained network. The whole model object is pickled (not just a
# state_dict), so loading it later requires the model class to be importable.
torch.save(obj=model, f="UNet_L2_final.pth")

### Training the model for Tukey loss (C = 0.3)

In [13]:
# Training run for the Tukey objective (c = 0.3): trains UNetSRx4RiR with mixed
# precision, logs per-epoch loss / PSNR / SSIM to TensorBoard under runs/UNetSRx4_Tukey,
# and drives EarlyStopping and ReduceLROnPlateau from the mean validation loss.
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter

# Debugging aid: anomaly detection makes autograd raise on NaNs produced during
# backward. PyTorch documents it as a debug-only mode because it slows execution.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = TukeyLoss(c=0.3)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scaler = GradScaler()

    # Initialize TensorBoard writer
    writer = SummaryWriter('runs/UNetSRx4_Tukey')

    # Initialize early stopping (defined elsewhere in the notebook; judging by the
    # logged output it checkpoints the model whenever the validation loss improves)
    early_stopping = EarlyStopping(patience=5, verbose=True)

    # Halve the learning rate after the validation loss has stalled for 2 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    # Set as soon as any batch yields a NaN loss. The run is then abandoned so that a
    # partial (and therefore deflated) epoch average is never logged or handed to
    # early stopping / the LR scheduler, where it could pass for an improvement.
    nan_detected = False

    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_psnr = 0.0
            running_ssim = 0.0

            # Wrap the train_loader with tqdm for progress bar
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

            for hr_patches, lr_patches in progress_bar:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

                optimizer.zero_grad()
                with autocast():
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)

                # Bail out before backward so the weights are not touched by a NaN loss.
                if torch.isnan(loss).any():
                    print("Loss is NaN")
                    nan_detected = True
                    break

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Read the loss back from the device once and reuse the Python float.
                batch_loss = loss.item()
                running_loss += batch_loss

                # Metrics are reporting-only: keep autograd from recording them.
                with torch.no_grad():
                    running_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Cast to float32 before SSIM (outputs may be half precision under autocast).
                    running_ssim += ssim(outputs.float(), hr_patches.float(), data_range=1.0).item()

                # Update the progress bar with the current batch loss
                progress_bar.set_postfix(loss=batch_loss)

            if nan_detected:
                break

            avg_train_loss = running_loss / len(train_loader)
            avg_train_psnr = running_psnr / len(train_loader)
            avg_train_ssim = running_ssim / len(train_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
            writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

            # Validation loop
            model.eval()
            val_loss = 0.0
            val_psnr = 0.0
            val_ssim = 0.0
            with torch.no_grad():
                for hr_patches, lr_patches in val_loader:
                    hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                    with autocast(enabled=True):
                        outputs = model(lr_patches)
                        loss = criterion(outputs, hr_patches)

                    if torch.isnan(loss).any():
                        print("Loss is NaN")
                        nan_detected = True
                        break

                    val_loss += loss.item()
                    val_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Convert to float32 before calculating SSIM
                    val_ssim += ssim(outputs.float(), hr_patches.float(), data_range=1.0).item()

            if nan_detected:
                break

            avg_val_loss = val_loss / len(val_loader)
            avg_val_psnr = val_psnr / len(val_loader)
            avg_val_ssim = val_ssim / len(val_loader)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
            writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
            print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

            # Check early stopping condition
            early_stopping(avg_val_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            # Step the scheduler
            scheduler.step(avg_val_loss)
    finally:
        # Flush and release the TensorBoard event file even if the run is interrupted.
        writer.close()

    if nan_detected:
        print('Training stopped early because of a NaN loss.')
    else:
        print('Training complete.')

2024-09-01 17:53:37.455563: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-01 17:53:37.595472: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-01 17:53:38.294487: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-01 17:53:38.294561: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [1/25], Train Loss: 0.0016, Train PSNR: 23.7686, Train SSIM: 0.7276





Validation Loss: 0.0010, Validation PSNR: 25.4096, Validation SSIM: 0.7745
Validation loss decreased (inf --> 0.001032).  Saving model ...


Epoch 2/25: 100%|██████████| 1182/1182 [23:06<00:00,  1.17s/batch, loss=0.000624]

Epoch [2/25], Train Loss: 0.0011, Train PSNR: 25.2389, Train SSIM: 0.7747





Validation Loss: 0.0010, Validation PSNR: 25.7575, Validation SSIM: 0.7839
Validation loss decreased (0.001032 --> 0.000957).  Saving model ...


Epoch 3/25: 100%|██████████| 1182/1182 [23:07<00:00,  1.17s/batch, loss=0.00142] 

Epoch [3/25], Train Loss: 0.0010, Train PSNR: 25.5453, Train SSIM: 0.7841





Validation Loss: 0.0010, Validation PSNR: 25.4694, Validation SSIM: 0.7880
EarlyStopping counter: 1 out of 5


Epoch 4/25: 100%|██████████| 1182/1182 [23:08<00:00,  1.17s/batch, loss=0.000908]

Epoch [4/25], Train Loss: 0.0010, Train PSNR: 25.6855, Train SSIM: 0.7895





Validation Loss: 0.0009, Validation PSNR: 25.9637, Validation SSIM: 0.7889
Validation loss decreased (0.000957 --> 0.000913).  Saving model ...


Epoch 5/25: 100%|██████████| 1182/1182 [23:09<00:00,  1.18s/batch, loss=0.000859]

Epoch [5/25], Train Loss: 0.0009, Train PSNR: 25.8234, Train SSIM: 0.7930





Validation Loss: 0.0009, Validation PSNR: 26.0153, Validation SSIM: 0.7978
Validation loss decreased (0.000913 --> 0.000900).  Saving model ...


Epoch 6/25: 100%|██████████| 1182/1182 [23:10<00:00,  1.18s/batch, loss=0.00119] 

Epoch [6/25], Train Loss: 0.0009, Train PSNR: 25.8832, Train SSIM: 0.7952





Validation Loss: 0.0009, Validation PSNR: 25.9024, Validation SSIM: 0.7983
EarlyStopping counter: 1 out of 5


Epoch 7/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.000317]

Epoch [7/25], Train Loss: 0.0009, Train PSNR: 25.9626, Train SSIM: 0.7981





Validation Loss: 0.0009, Validation PSNR: 26.0827, Validation SSIM: 0.8015
Validation loss decreased (0.000900 --> 0.000890).  Saving model ...


Epoch 8/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.00153] 

Epoch [8/25], Train Loss: 0.0009, Train PSNR: 25.9923, Train SSIM: 0.7996





Validation Loss: 0.0009, Validation PSNR: 26.1851, Validation SSIM: 0.8041
Validation loss decreased (0.000890 --> 0.000871).  Saving model ...


Epoch 9/25: 100%|██████████| 1182/1182 [23:12<00:00,  1.18s/batch, loss=0.000549]

Epoch [9/25], Train Loss: 0.0009, Train PSNR: 26.0511, Train SSIM: 0.8008





Validation Loss: 0.0009, Validation PSNR: 26.1436, Validation SSIM: 0.7966
EarlyStopping counter: 1 out of 5


Epoch 10/25: 100%|██████████| 1182/1182 [23:12<00:00,  1.18s/batch, loss=0.000881]

Epoch [10/25], Train Loss: 0.0009, Train PSNR: 26.1100, Train SSIM: 0.8028





Validation Loss: 0.0009, Validation PSNR: 26.1679, Validation SSIM: 0.8051
EarlyStopping counter: 2 out of 5


Epoch 11/25: 100%|██████████| 1182/1182 [23:15<00:00,  1.18s/batch, loss=0.000349]

Epoch [11/25], Train Loss: 0.0009, Train PSNR: 26.1385, Train SSIM: 0.8039





Validation Loss: 0.0009, Validation PSNR: 26.1200, Validation SSIM: 0.8037
EarlyStopping counter: 3 out of 5
Epoch 00011: reducing learning rate of group 0 to 5.0000e-05.


Epoch 12/25: 100%|██████████| 1182/1182 [23:15<00:00,  1.18s/batch, loss=0.000546]

Epoch [12/25], Train Loss: 0.0008, Train PSNR: 26.2438, Train SSIM: 0.8072





Validation Loss: 0.0008, Validation PSNR: 26.2961, Validation SSIM: 0.8063
Validation loss decreased (0.000871 --> 0.000844).  Saving model ...


Epoch 13/25: 100%|██████████| 1182/1182 [23:17<00:00,  1.18s/batch, loss=0.000903]

Epoch [13/25], Train Loss: 0.0008, Train PSNR: 26.2781, Train SSIM: 0.8078





Validation Loss: 0.0009, Validation PSNR: 26.1882, Validation SSIM: 0.8082
EarlyStopping counter: 1 out of 5


Epoch 14/25: 100%|██████████| 1182/1182 [23:18<00:00,  1.18s/batch, loss=0.000656]

Epoch [14/25], Train Loss: 0.0008, Train PSNR: 26.2705, Train SSIM: 0.8083





Validation Loss: 0.0009, Validation PSNR: 26.2512, Validation SSIM: 0.8034
EarlyStopping counter: 2 out of 5


Epoch 15/25: 100%|██████████| 1182/1182 [23:19<00:00,  1.18s/batch, loss=0.000488]

Epoch [15/25], Train Loss: 0.0008, Train PSNR: 26.2863, Train SSIM: 0.8090





Validation Loss: 0.0008, Validation PSNR: 26.3992, Validation SSIM: 0.8104
Validation loss decreased (0.000844 --> 0.000830).  Saving model ...


Epoch 16/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.00164] 

Epoch [16/25], Train Loss: 0.0008, Train PSNR: 26.3229, Train SSIM: 0.8096





Validation Loss: 0.0008, Validation PSNR: 26.4150, Validation SSIM: 0.8111
Validation loss decreased (0.000830 --> 0.000826).  Saving model ...


Epoch 17/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.00069] 

Epoch [17/25], Train Loss: 0.0008, Train PSNR: 26.3318, Train SSIM: 0.8098





Validation Loss: 0.0008, Validation PSNR: 26.4138, Validation SSIM: 0.8097
EarlyStopping counter: 1 out of 5


Epoch 18/25: 100%|██████████| 1182/1182 [23:20<00:00,  1.18s/batch, loss=0.000941]

Epoch [18/25], Train Loss: 0.0008, Train PSNR: 26.3245, Train SSIM: 0.8105





Validation Loss: 0.0008, Validation PSNR: 26.3849, Validation SSIM: 0.8130
Validation loss decreased (0.000826 --> 0.000824).  Saving model ...


Epoch 19/25: 100%|██████████| 1182/1182 [23:20<00:00,  1.18s/batch, loss=0.000869]

Epoch [19/25], Train Loss: 0.0008, Train PSNR: 26.3521, Train SSIM: 0.8109





Validation Loss: 0.0008, Validation PSNR: 26.1917, Validation SSIM: 0.8103
EarlyStopping counter: 1 out of 5


Epoch 20/25: 100%|██████████| 1182/1182 [23:20<00:00,  1.19s/batch, loss=0.00226] 

Epoch [20/25], Train Loss: 0.0008, Train PSNR: 26.3403, Train SSIM: 0.8113





Validation Loss: 0.0008, Validation PSNR: 26.0074, Validation SSIM: 0.8089
EarlyStopping counter: 2 out of 5


Epoch 21/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.00123] 

Epoch [21/25], Train Loss: 0.0008, Train PSNR: 26.3725, Train SSIM: 0.8115





Validation Loss: 0.0008, Validation PSNR: 26.4476, Validation SSIM: 0.8121
Validation loss decreased (0.000824 --> 0.000822).  Saving model ...


Epoch 22/25: 100%|██████████| 1182/1182 [23:19<00:00,  1.18s/batch, loss=0.000695]

Epoch [22/25], Train Loss: 0.0008, Train PSNR: 26.3990, Train SSIM: 0.8122





Validation Loss: 0.0008, Validation PSNR: 26.3637, Validation SSIM: 0.8101
EarlyStopping counter: 1 out of 5


Epoch 23/25: 100%|██████████| 1182/1182 [23:20<00:00,  1.18s/batch, loss=0.00109] 

Epoch [23/25], Train Loss: 0.0008, Train PSNR: 26.4172, Train SSIM: 0.8124





Validation Loss: 0.0008, Validation PSNR: 26.4527, Validation SSIM: 0.8124
Validation loss decreased (0.000822 --> 0.000821).  Saving model ...


Epoch 24/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.000339]

Epoch [24/25], Train Loss: 0.0008, Train PSNR: 26.4036, Train SSIM: 0.8130





Validation Loss: 0.0008, Validation PSNR: 26.1608, Validation SSIM: 0.8085
EarlyStopping counter: 1 out of 5


Epoch 25/25: 100%|██████████| 1182/1182 [23:21<00:00,  1.19s/batch, loss=0.00037] 

Epoch [25/25], Train Loss: 0.0008, Train PSNR: 26.4178, Train SSIM: 0.8130





Validation Loss: 0.0008, Validation PSNR: 26.4577, Validation SSIM: 0.8140
Validation loss decreased (0.000821 --> 0.000818).  Saving model ...
Training complete.


In [14]:
# Persist the Tukey-trained network. The whole model object is pickled (not just a
# state_dict), so loading it later requires the model class to be importable.
torch.save(obj=model, f="UNet_Tukey_c-0.3_final.pth")

### Training the model for SSIM + MAE loss

In [None]:
# Training run for the combined SSIM + L1 objective: trains UNetSRx4RiR, logs per-epoch
# loss / PSNR / SSIM to TensorBoard under runs/UNetSRx4_ssim, and drives EarlyStopping
# and ReduceLROnPlateau from the mean validation loss.
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter

# Debugging aid: anomaly detection makes autograd raise on NaNs produced during
# backward. PyTorch documents it as a debug-only mode because it slows execution.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = combined_ssim_l1_loss
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # The scaler is kept even though autocast is switched off below, so the
    # backward/step sequence matches the other training cells.
    scaler = GradScaler()

    # Initialize TensorBoard writer
    writer = SummaryWriter('runs/UNetSRx4_ssim')

    # Initialize early stopping (defined elsewhere in the notebook; judging by the
    # logged output it checkpoints the model whenever the validation loss improves)
    early_stopping = EarlyStopping(patience=5, verbose=True)

    # Halve the learning rate after the validation loss has stalled for 2 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    # Set as soon as any batch yields a NaN loss. The run is then abandoned so that a
    # partial (and therefore deflated) epoch average is never logged or handed to
    # early stopping / the LR scheduler, where it could pass for an improvement.
    nan_detected = False

    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_psnr = 0.0
            running_ssim = 0.0

            # Wrap the train_loader with tqdm for progress bar
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

            for hr_patches, lr_patches in progress_bar:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

                optimizer.zero_grad()
                # Autocast is explicitly disabled for this loss, so the forward pass is
                # not down-cast (presumably for the numerical stability of the SSIM
                # term -- TODO confirm).
                with autocast(enabled=False):
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)

                # Bail out before backward so the weights are not touched by a NaN loss.
                if torch.isnan(loss).any():
                    print("Loss is NaN")
                    nan_detected = True
                    break

                scaler.scale(loss).backward()
                # Gradient clipping is currently disabled. If it is re-enabled, call
                # scaler.unscale_(optimizer) first so the norm is measured on unscaled
                # gradients:
                # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                # Read the loss back from the device once and reuse the Python float.
                batch_loss = loss.item()
                running_loss += batch_loss

                # Metrics are reporting-only: keep autograd from recording them.
                with torch.no_grad():
                    running_psnr += calculate_psnr(outputs, hr_patches).item()
                    running_ssim += ssim(outputs.float(), hr_patches.float(), data_range=1.0).item()

                # Update the progress bar with the current batch loss
                progress_bar.set_postfix(loss=batch_loss)

            if nan_detected:
                break

            avg_train_loss = running_loss / len(train_loader)
            avg_train_psnr = running_psnr / len(train_loader)
            avg_train_ssim = running_ssim / len(train_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
            writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

            # Validation loop
            model.eval()
            val_loss = 0.0
            val_psnr = 0.0
            val_ssim = 0.0
            with torch.no_grad():
                for hr_patches, lr_patches in val_loader:
                    hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                    with autocast(enabled=False):
                        outputs = model(lr_patches)
                        loss = criterion(outputs, hr_patches)

                    if torch.isnan(loss).any():
                        print("Loss is NaN")
                        nan_detected = True
                        break

                    val_loss += loss.item()
                    val_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Convert to float32 before calculating SSIM
                    val_ssim += ssim(outputs.float(), hr_patches.float(), data_range=1.0).item()

            if nan_detected:
                break

            avg_val_loss = val_loss / len(val_loader)
            avg_val_psnr = val_psnr / len(val_loader)
            avg_val_ssim = val_ssim / len(val_loader)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
            writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
            print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

            # Check early stopping condition
            early_stopping(avg_val_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            # Step the scheduler
            scheduler.step(avg_val_loss)
    finally:
        # Flush and release the TensorBoard event file even if the run is interrupted.
        writer.close()

    if nan_detected:
        print('Training stopped early because of a NaN loss.')
    else:
        print('Training complete.')

2024-09-02 07:38:50.728703: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-02 07:38:50.874078: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-02 07:38:51.596594: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-02 07:38:51.596666: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [1/25], Train Loss: 0.0985, Train PSNR: 24.3744, Train SSIM: 0.7674





Validation Loss: 0.0856, Validation PSNR: 25.3044, Validation SSIM: 0.7966
Validation loss decreased (inf --> 0.085587).  Saving model ...


Epoch 2/25: 100%|██████████| 2363/2363 [1:07:03<00:00,  1.70s/batch, loss=0.0927]

Epoch [2/25], Train Loss: 0.0837, Train PSNR: 25.5299, Train SSIM: 0.7971





Validation Loss: 0.0807, Validation PSNR: 25.7845, Validation SSIM: 0.8023
Validation loss decreased (0.085587 --> 0.080728).  Saving model ...


Epoch 3/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.105] 

Epoch [3/25], Train Loss: 0.0805, Train PSNR: 25.8157, Train SSIM: 0.8034





Validation Loss: 0.0794, Validation PSNR: 26.0624, Validation SSIM: 0.8055
Validation loss decreased (0.080728 --> 0.079366).  Saving model ...


Epoch 4/25: 100%|██████████| 2363/2363 [1:07:03<00:00,  1.70s/batch, loss=0.0824]

Epoch [4/25], Train Loss: 0.0792, Train PSNR: 25.9106, Train SSIM: 0.8062





Validation Loss: 0.0766, Validation PSNR: 26.2448, Validation SSIM: 0.8104
Validation loss decreased (0.079366 --> 0.076637).  Saving model ...


Epoch 5/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0441]

Epoch [5/25], Train Loss: 0.0775, Train PSNR: 26.0461, Train SSIM: 0.8098





Validation Loss: 0.0772, Validation PSNR: 26.2642, Validation SSIM: 0.8098
EarlyStopping counter: 1 out of 5


Epoch 6/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0362]

Epoch [6/25], Train Loss: 0.0767, Train PSNR: 26.1087, Train SSIM: 0.8114





Validation Loss: 0.0758, Validation PSNR: 26.2083, Validation SSIM: 0.8130
Validation loss decreased (0.076637 --> 0.075763).  Saving model ...


Epoch 7/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0495]

Epoch [7/25], Train Loss: 0.0760, Train PSNR: 26.1805, Train SSIM: 0.8129





Validation Loss: 0.0755, Validation PSNR: 26.3510, Validation SSIM: 0.8129
Validation loss decreased (0.075763 --> 0.075503).  Saving model ...


Epoch 8/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0613]

Epoch [8/25], Train Loss: 0.0755, Train PSNR: 26.2333, Train SSIM: 0.8141





Validation Loss: 0.0751, Validation PSNR: 26.3607, Validation SSIM: 0.8145
Validation loss decreased (0.075503 --> 0.075054).  Saving model ...


Epoch 9/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.2]   

Epoch [9/25], Train Loss: 0.0751, Train PSNR: 26.2753, Train SSIM: 0.8149





Validation Loss: 0.0757, Validation PSNR: 26.2728, Validation SSIM: 0.8156
EarlyStopping counter: 1 out of 5


Epoch 10/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.129] 

Epoch [10/25], Train Loss: 0.0746, Train PSNR: 26.3246, Train SSIM: 0.8158





Validation Loss: 0.0739, Validation PSNR: 26.3606, Validation SSIM: 0.8173
Validation loss decreased (0.075054 --> 0.073905).  Saving model ...


Epoch 11/25: 100%|██████████| 2363/2363 [1:07:03<00:00,  1.70s/batch, loss=0.101] 

Epoch [11/25], Train Loss: 0.0743, Train PSNR: 26.3315, Train SSIM: 0.8167





Validation Loss: 0.0739, Validation PSNR: 26.3166, Validation SSIM: 0.8180
Validation loss decreased (0.073905 --> 0.073875).  Saving model ...


Epoch 12/25: 100%|██████████| 2363/2363 [1:07:03<00:00,  1.70s/batch, loss=0.0677]

Epoch [12/25], Train Loss: 0.0739, Train PSNR: 26.3506, Train SSIM: 0.8175





Validation Loss: 0.0733, Validation PSNR: 26.4991, Validation SSIM: 0.8176
Validation loss decreased (0.073875 --> 0.073291).  Saving model ...


Epoch 13/25: 100%|██████████| 2363/2363 [1:07:03<00:00,  1.70s/batch, loss=0.0588]

Epoch [13/25], Train Loss: 0.0736, Train PSNR: 26.3789, Train SSIM: 0.8181





Validation Loss: 0.0741, Validation PSNR: 26.3952, Validation SSIM: 0.8176
EarlyStopping counter: 1 out of 5


Epoch 14/25: 100%|██████████| 2363/2363 [1:07:03<00:00,  1.70s/batch, loss=0.0969]

Epoch [14/25], Train Loss: 0.0733, Train PSNR: 26.3943, Train SSIM: 0.8189





Validation Loss: 0.0731, Validation PSNR: 26.5981, Validation SSIM: 0.8180
Validation loss decreased (0.073291 --> 0.073068).  Saving model ...


Epoch 15/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.101] 

Epoch [15/25], Train Loss: 0.0730, Train PSNR: 26.4522, Train SSIM: 0.8195





Validation Loss: 0.0737, Validation PSNR: 26.5020, Validation SSIM: 0.8168
EarlyStopping counter: 1 out of 5


Epoch 16/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0324]

Epoch [16/25], Train Loss: 0.0728, Train PSNR: 26.4072, Train SSIM: 0.8200





Validation Loss: 0.0736, Validation PSNR: 26.4453, Validation SSIM: 0.8184
EarlyStopping counter: 2 out of 5


Epoch 17/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0512]

Epoch [17/25], Train Loss: 0.0725, Train PSNR: 26.4743, Train SSIM: 0.8207





Validation Loss: 0.0724, Validation PSNR: 26.5374, Validation SSIM: 0.8208
Validation loss decreased (0.073068 --> 0.072357).  Saving model ...


Epoch 18/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0826]

Epoch [18/25], Train Loss: 0.0723, Train PSNR: 26.4904, Train SSIM: 0.8211





Validation Loss: 0.0741, Validation PSNR: 26.3242, Validation SSIM: 0.8197
EarlyStopping counter: 1 out of 5


Epoch 19/25: 100%|██████████| 2363/2363 [1:07:02<00:00,  1.70s/batch, loss=0.0702]

Epoch [19/25], Train Loss: 0.0721, Train PSNR: 26.5160, Train SSIM: 0.8216





Validation Loss: 0.0722, Validation PSNR: 26.5913, Validation SSIM: 0.8206
Validation loss decreased (0.072357 --> 0.072178).  Saving model ...


Epoch 20/25:  76%|███████▌  | 1791/2363 [50:49<16:13,  1.70s/batch, loss=0.0776] 

Continuing the training after epoch 19 because the session ended abruptly.

In [13]:
# Continuing the training after epoch 19 because the session ended abruptly
from torch.cuda.amp import GradScaler, autocast

torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = combined_ssim_l1_loss
    
    # Load the model's state dict from the checkpoint
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint)
    
    optimizer = optim.Adam(model.parameters(), lr=0.00005) # Reducing the loss function slightly
    scaler = GradScaler()
    start_epoch = 19  # Start from the next epoch
    
    # Initialize TensorBoard writer
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter('runs/UNetSRx4_ssim')

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, verbose=True)
    
    # Initialize learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    for epoch in range(start_epoch, num_epochs): # Start from start_rpoch
        model.train()
        running_loss = 0.0
        running_psnr = 0.0
        running_ssim = 0.0

        # Wrap the train_loader with tqdm for progress bar
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        
        for i, (hr_patches, lr_patches) in enumerate(progress_bar):
            hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
            
            optimizer.zero_grad()
            with autocast(enabled=False):
                outputs = model(lr_patches)
                loss = criterion(outputs, hr_patches)
                if torch.isnan(loss).any():
                    print("Loss is NaN")
                    break
            
            scaler.scale(loss).backward()
            # Apply gradient clipping
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            
            running_loss += loss.item()
            running_psnr += calculate_psnr(outputs, hr_patches).item()
            outputs = outputs.float()
            hr_patches = hr_patches.float()
            running_ssim += ssim(outputs, hr_patches, data_range=1.0).item()

            # Update the progress bar with the running loss
            progress_bar.set_postfix(loss=loss.item())
        
        avg_train_loss = running_loss / len(train_loader)
        avg_train_psnr = running_psnr / len(train_loader)
        avg_train_ssim = running_ssim / len(train_loader)
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
        writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')
    
        # Validation loop
        model.eval()
        val_loss = 0.0
        val_psnr = 0.0
        val_ssim = 0.0     
        with torch.no_grad():
            for hr_patches, lr_patches in val_loader:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                with autocast(enabled=False):
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)
                    if torch.isnan(loss).any():
                        print("Loss is NaN")
                        break
                val_loss += loss.item()
                val_psnr += calculate_psnr(outputs, hr_patches).item()
                # Convert to float32 before calculating SSIM
                outputs = outputs.float()
                hr_patches = hr_patches.float()
                val_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
    
        avg_val_loss = val_loss / len(val_loader)
        avg_val_psnr = val_psnr / len(val_loader)
        avg_val_ssim = val_ssim / len(val_loader)
        writer.add_scalar('Loss/val', avg_val_loss, epoch)
        writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
        writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
        print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

        # Check early stopping condition
        early_stopping(avg_val_loss, model)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
        
        # Step the scheduler
        scheduler.step(avg_val_loss)
    
    print('Training complete.')

2024-09-03 07:40:48.212839: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-03 07:40:48.351979: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-03 07:40:49.042439: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-03 07:40:49.042508: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [20/25], Train Loss: 0.0713, Train PSNR: 26.5379, Train SSIM: 0.8230





Validation Loss: 0.0721, Validation PSNR: 26.4905, Validation SSIM: 0.8205
Validation loss decreased (inf --> 0.072086).  Saving model ...


Epoch 21/25: 100%|██████████| 2363/2363 [1:06:03<00:00,  1.68s/batch, loss=0.079] 

Epoch [21/25], Train Loss: 0.0711, Train PSNR: 26.5909, Train SSIM: 0.8236





Validation Loss: 0.0723, Validation PSNR: 26.4309, Validation SSIM: 0.8206
EarlyStopping counter: 1 out of 5


Epoch 22/25: 100%|██████████| 2363/2363 [1:06:07<00:00,  1.68s/batch, loss=0.0449]

Epoch [22/25], Train Loss: 0.0708, Train PSNR: 26.6048, Train SSIM: 0.8242





Validation Loss: 0.0722, Validation PSNR: 26.5221, Validation SSIM: 0.8199
EarlyStopping counter: 2 out of 5


Epoch 23/25: 100%|██████████| 2363/2363 [1:06:08<00:00,  1.68s/batch, loss=0.0326]

Epoch [23/25], Train Loss: 0.0708, Train PSNR: 26.5959, Train SSIM: 0.8243





Validation Loss: 0.0723, Validation PSNR: 26.4147, Validation SSIM: 0.8202
EarlyStopping counter: 3 out of 5
Epoch 00004: reducing learning rate of group 0 to 2.5000e-05.


Epoch 24/25: 100%|██████████| 2363/2363 [1:06:12<00:00,  1.68s/batch, loss=0.109] 

Epoch [24/25], Train Loss: 0.0702, Train PSNR: 26.7097, Train SSIM: 0.8256





Validation Loss: 0.0717, Validation PSNR: 26.5632, Validation SSIM: 0.8209
Validation loss decreased (0.072086 --> 0.071669).  Saving model ...


Epoch 25/25: 100%|██████████| 2363/2363 [1:06:12<00:00,  1.68s/batch, loss=0.0878]

Epoch [25/25], Train Loss: 0.0700, Train PSNR: 26.6955, Train SSIM: 0.8260





Validation Loss: 0.0725, Validation PSNR: 26.2958, Validation SSIM: 0.8188
EarlyStopping counter: 1 out of 5
Training complete.


In [15]:
# Serialise the whole module object (not just its state_dict) for the SSIM + L1 run.
# NOTE(review): `model` holds the weights left in memory by the last training cell
# that ran, i.e. the final epoch -- not necessarily the best-validation weights,
# which presumably live in 'checkpoint.pt'. Confirm which one is intended.
torch.save(obj=model, f="UNet_l1_ssim_final.pth")

### Training the model for Charbonnier loss

In [None]:
# Trains a fresh UNetSRx4RiR with the Charbonnier loss under mixed precision,
# logging per-epoch loss / PSNR / SSIM to TensorBoard, with early stopping and
# a plateau-based learning-rate schedule driven by the validation loss.
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter

# Anomaly detection pinpoints the op that produced a NaN/Inf during backward, but
# it adds overhead to every step; it is a debugging aid, not a training default.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = CharbonnierLoss(epsilon=1e-4)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scaler = GradScaler()

    # Initialize TensorBoard writer
    writer = SummaryWriter('runs/UNetSRx4_charbonnier')

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, verbose=True)

    # Initialize learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_psnr = 0.0
        running_ssim = 0.0
        train_batches = 0  # batches that actually contributed to the running sums

        # Wrap the train_loader with tqdm for progress bar
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

        for hr_patches, lr_patches in progress_bar:
            hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

            optimizer.zero_grad()
            with autocast(enabled=True):
                outputs = model(lr_patches)
                loss = criterion(outputs, hr_patches)

            # Abort the rest of this epoch before a NaN loss reaches the weights.
            if torch.isnan(loss).any():
                print("Loss is NaN")
                break

            scaler.scale(loss).backward()
            # Gradient clipping is currently disabled. If it is re-enabled, call
            # scaler.unscale_(optimizer) first so the norm is measured on unscaled grads.
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Metrics are for logging only: detach the prediction and disable grad
            # tracking so PSNR/SSIM do not build an autograd graph of their own.
            with torch.no_grad():
                outputs = outputs.detach()
                running_loss += loss.item()
                running_psnr += calculate_psnr(outputs, hr_patches).item()
                # Convert to float32 before calculating SSIM
                outputs = outputs.float()
                hr_patches = hr_patches.float()
                running_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
            train_batches += 1

            # Update the progress bar with the running loss
            progress_bar.set_postfix(loss=loss.item())

        # Average over the batches that really ran. This equals len(train_loader)
        # unless the NaN guard cut the epoch short, in which case dividing by the
        # full loader length would understate every metric.
        train_denominator = max(train_batches, 1)
        avg_train_loss = running_loss / train_denominator
        avg_train_psnr = running_psnr / train_denominator
        avg_train_ssim = running_ssim / train_denominator
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
        writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

        # Validation loop
        model.eval()
        val_loss = 0.0
        val_psnr = 0.0
        val_ssim = 0.0
        val_batches = 0  # same bookkeeping as train_batches
        with torch.no_grad():
            for hr_patches, lr_patches in val_loader:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                with autocast(enabled=True):
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)

                if torch.isnan(loss).any():
                    print("Loss is NaN")
                    break

                val_loss += loss.item()
                val_psnr += calculate_psnr(outputs, hr_patches).item()
                # Convert to float32 before calculating SSIM
                outputs = outputs.float()
                hr_patches = hr_patches.float()
                val_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
                val_batches += 1

        val_denominator = max(val_batches, 1)
        avg_val_loss = val_loss / val_denominator
        avg_val_psnr = val_psnr / val_denominator
        avg_val_ssim = val_ssim / val_denominator
        writer.add_scalar('Loss/val', avg_val_loss, epoch)
        writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
        writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
        print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

        # Check early stopping condition
        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        # Step the scheduler
        scheduler.step(avg_val_loss)

    # Flush pending events to disk and release the event file.
    writer.close()
    print('Training complete.')

2024-09-04 04:24:43.462607: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-04 04:24:43.603426: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-04 04:24:44.318968: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-04 04:24:44.319057: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [1/25], Train Loss: 0.0455, Train PSNR: 23.6736, Train SSIM: 0.7262





Validation Loss: 0.0327, Validation PSNR: 25.4157, Validation SSIM: 0.7709
Validation loss decreased (inf --> 0.032664).  Saving model ...


Epoch 2/25: 100%|██████████| 1182/1182 [23:26<00:00,  1.19s/batch, loss=0.0362]

Epoch [2/25], Train Loss: 0.0338, Train PSNR: 25.2284, Train SSIM: 0.7751





Validation Loss: 0.0334, Validation PSNR: 25.4563, Validation SSIM: 0.7848
EarlyStopping counter: 1 out of 5


Epoch 3/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.0424]

Epoch [3/25], Train Loss: 0.0316, Train PSNR: 25.5467, Train SSIM: 0.7848





Validation Loss: 0.0320, Validation PSNR: 25.6961, Validation SSIM: 0.7841
Validation loss decreased (0.032664 --> 0.032025).  Saving model ...


Epoch 4/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.0171]

Epoch [4/25], Train Loss: 0.0305, Train PSNR: 25.6753, Train SSIM: 0.7899





Validation Loss: 0.0281, Validation PSNR: 26.1552, Validation SSIM: 0.7945
Validation loss decreased (0.032025 --> 0.028102).  Saving model ...


Epoch 5/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.0243]

Epoch [5/25], Train Loss: 0.0298, Train PSNR: 25.8089, Train SSIM: 0.7933





Validation Loss: 0.0279, Validation PSNR: 26.1927, Validation SSIM: 0.7957
Validation loss decreased (0.028102 --> 0.027871).  Saving model ...


Epoch 6/25: 100%|██████████| 1182/1182 [23:29<00:00,  1.19s/batch, loss=0.0267]

Epoch [6/25], Train Loss: 0.0292, Train PSNR: 25.8740, Train SSIM: 0.7957





Validation Loss: 0.0297, Validation PSNR: 26.0723, Validation SSIM: 0.7992
EarlyStopping counter: 1 out of 5


Epoch 7/25: 100%|██████████| 1182/1182 [23:29<00:00,  1.19s/batch, loss=0.0663]

Epoch [7/25], Train Loss: 0.0288, Train PSNR: 25.9514, Train SSIM: 0.7981





Validation Loss: 0.0277, Validation PSNR: 26.2551, Validation SSIM: 0.7996
Validation loss decreased (0.027871 --> 0.027697).  Saving model ...


Epoch 8/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.0256]

Epoch [8/25], Train Loss: 0.0283, Train PSNR: 26.0037, Train SSIM: 0.8003





Validation Loss: 0.0271, Validation PSNR: 26.3387, Validation SSIM: 0.7997
Validation loss decreased (0.027697 --> 0.027074).  Saving model ...


Epoch 9/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.0397]

Epoch [9/25], Train Loss: 0.0281, Train PSNR: 26.0428, Train SSIM: 0.8017





Validation Loss: 0.0282, Validation PSNR: 26.1520, Validation SSIM: 0.7997
EarlyStopping counter: 1 out of 5


Epoch 10/25: 100%|██████████| 1182/1182 [23:27<00:00,  1.19s/batch, loss=0.0478]

Epoch [10/25], Train Loss: 0.0278, Train PSNR: 26.1011, Train SSIM: 0.8033





Validation Loss: 0.0269, Validation PSNR: 26.2393, Validation SSIM: 0.8031
Validation loss decreased (0.027074 --> 0.026872).  Saving model ...


Epoch 11/25: 100%|██████████| 1182/1182 [23:28<00:00,  1.19s/batch, loss=0.0316]

Epoch [11/25], Train Loss: 0.0276, Train PSNR: 26.1277, Train SSIM: 0.8041





Validation Loss: 0.0265, Validation PSNR: 26.4285, Validation SSIM: 0.8055
Validation loss decreased (0.026872 --> 0.026546).  Saving model ...


Epoch 12/25: 100%|██████████| 1182/1182 [23:28<00:00,  1.19s/batch, loss=0.0177]

Epoch [12/25], Train Loss: 0.0272, Train PSNR: 26.1828, Train SSIM: 0.8057





Validation Loss: 0.0267, Validation PSNR: 26.3562, Validation SSIM: 0.8048
EarlyStopping counter: 1 out of 5


Epoch 13/25: 100%|██████████| 1182/1182 [23:26<00:00,  1.19s/batch, loss=0.0361]

Epoch [13/25], Train Loss: 0.0271, Train PSNR: 26.2104, Train SSIM: 0.8064





Validation Loss: 0.0272, Validation PSNR: 26.0432, Validation SSIM: 0.8039
EarlyStopping counter: 2 out of 5


Epoch 14/25: 100%|██████████| 1182/1182 [23:25<00:00,  1.19s/batch, loss=0.0289]

Epoch [14/25], Train Loss: 0.0268, Train PSNR: 26.2478, Train SSIM: 0.8073





Validation Loss: 0.0278, Validation PSNR: 26.2759, Validation SSIM: 0.8047
EarlyStopping counter: 3 out of 5
Epoch 00014: reducing learning rate of group 0 to 5.0000e-05.


Epoch 15/25: 100%|██████████| 1182/1182 [23:28<00:00,  1.19s/batch, loss=0.0304]

Epoch [15/25], Train Loss: 0.0261, Train PSNR: 26.3472, Train SSIM: 0.8100





Validation Loss: 0.0255, Validation PSNR: 26.5442, Validation SSIM: 0.8080
Validation loss decreased (0.026546 --> 0.025503).  Saving model ...


Epoch 16/25: 100%|██████████| 1182/1182 [23:31<00:00,  1.19s/batch, loss=0.0365]

Epoch [16/25], Train Loss: 0.0260, Train PSNR: 26.3464, Train SSIM: 0.8107





Validation Loss: 0.0254, Validation PSNR: 26.3714, Validation SSIM: 0.8083
Validation loss decreased (0.025503 --> 0.025384).  Saving model ...


Epoch 17/25:  57%|█████▋    | 678/1182 [13:29<10:02,  1.19s/batch, loss=0.0228]

Continuing the training after epoch 16, as the session ended abruptly

In [13]:
# Continuing the training after epoch 16, as the session ended abruptly.
# Rebuilds the network, restores its weights from 'checkpoint.pt' and runs the
# remaining epochs (displayed as 17..25) with the Charbonnier loss under mixed
# precision. Only the model weights are restored here: the Adam state, the LR
# scheduler and the EarlyStopping tracker are all created fresh.
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter

# Anomaly detection pinpoints the op that produced a NaN/Inf during backward, but
# it adds overhead to every step; it is a debugging aid, not a training default.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    criterion = CharbonnierLoss(epsilon=1e-4)

    # Load the model's state dict from the checkpoint.
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on GPU can still be restored on a CPU-only session.
    checkpoint = torch.load('checkpoint.pt', map_location=device)
    model.load_state_dict(checkpoint)

    optimizer = optim.Adam(model.parameters(), lr=0.00005)  # Reduced learning rate for the resumed run
    scaler = GradScaler()
    start_epoch = 16  # Zero-based index of the next epoch (shown as epoch 17)

    # Initialize TensorBoard writer
    writer = SummaryWriter('runs/UNetSRx4_charbonnier')

    # Initialize early stopping
    # NOTE(review): a fresh tracker has no memory of the best loss reached before
    # the interruption -- presumably the first resumed epoch overwrites the
    # checkpoint regardless; confirm against the EarlyStopping implementation.
    early_stopping = EarlyStopping(patience=5, verbose=True)

    # Initialize learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    for epoch in range(start_epoch, num_epochs):  # Start from start_epoch
        model.train()
        running_loss = 0.0
        running_psnr = 0.0
        running_ssim = 0.0
        train_batches = 0  # batches that actually contributed to the running sums

        # Wrap the train_loader with tqdm for progress bar
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

        for hr_patches, lr_patches in progress_bar:
            hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

            optimizer.zero_grad()
            with autocast(enabled=True):
                outputs = model(lr_patches)
                loss = criterion(outputs, hr_patches)

            # Abort the rest of this epoch before a NaN loss reaches the weights.
            if torch.isnan(loss).any():
                print("Loss is NaN")
                break

            scaler.scale(loss).backward()
            # Gradient clipping is currently disabled. If it is re-enabled, call
            # scaler.unscale_(optimizer) first so the norm is measured on unscaled grads.
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Metrics are for logging only: detach the prediction and disable grad
            # tracking so PSNR/SSIM do not build an autograd graph of their own.
            with torch.no_grad():
                outputs = outputs.detach()
                running_loss += loss.item()
                running_psnr += calculate_psnr(outputs, hr_patches).item()
                # Convert to float32 before calculating SSIM
                outputs = outputs.float()
                hr_patches = hr_patches.float()
                running_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
            train_batches += 1

            # Update the progress bar with the running loss
            progress_bar.set_postfix(loss=loss.item())

        # Average over the batches that really ran. This equals len(train_loader)
        # unless the NaN guard cut the epoch short, in which case dividing by the
        # full loader length would understate every metric.
        train_denominator = max(train_batches, 1)
        avg_train_loss = running_loss / train_denominator
        avg_train_psnr = running_psnr / train_denominator
        avg_train_ssim = running_ssim / train_denominator
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
        writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

        # Validation loop
        model.eval()
        val_loss = 0.0
        val_psnr = 0.0
        val_ssim = 0.0
        val_batches = 0  # same bookkeeping as train_batches
        with torch.no_grad():
            for hr_patches, lr_patches in val_loader:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                with autocast(enabled=True):
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)

                if torch.isnan(loss).any():
                    print("Loss is NaN")
                    break

                val_loss += loss.item()
                val_psnr += calculate_psnr(outputs, hr_patches).item()
                # Convert to float32 before calculating SSIM
                outputs = outputs.float()
                hr_patches = hr_patches.float()
                val_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
                val_batches += 1

        val_denominator = max(val_batches, 1)
        avg_val_loss = val_loss / val_denominator
        avg_val_psnr = val_psnr / val_denominator
        avg_val_ssim = val_ssim / val_denominator
        writer.add_scalar('Loss/val', avg_val_loss, epoch)
        writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
        writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
        print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

        # Check early stopping condition
        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        # Step the scheduler
        scheduler.step(avg_val_loss)

    # Flush pending events to disk and release the event file.
    writer.close()
    print('Training complete.')

2024-09-04 15:36:39.383257: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-04 15:36:49.704962: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-04 15:37:36.680281: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-04 15:37:36.680679: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [17/25], Train Loss: 0.0256, Train PSNR: 26.4188, Train SSIM: 0.8120





Validation Loss: 0.0254, Validation PSNR: 26.5091, Validation SSIM: 0.8177
Validation loss decreased (inf --> 0.025369).  Saving model ...


Epoch 18/25: 100%|██████████| 1182/1182 [23:10<00:00,  1.18s/batch, loss=0.0415]

Epoch [18/25], Train Loss: 0.0256, Train PSNR: 26.4330, Train SSIM: 0.8123





Validation Loss: 0.0253, Validation PSNR: 26.4657, Validation SSIM: 0.8169
Validation loss decreased (0.025369 --> 0.025327).  Saving model ...


Epoch 19/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.0335]

Epoch [19/25], Train Loss: 0.0255, Train PSNR: 26.4339, Train SSIM: 0.8128





Validation Loss: 0.0251, Validation PSNR: 26.4881, Validation SSIM: 0.8174
Validation loss decreased (0.025327 --> 0.025119).  Saving model ...


Epoch 20/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.0279]

Epoch [20/25], Train Loss: 0.0255, Train PSNR: 26.4537, Train SSIM: 0.8123





Validation Loss: 0.0257, Validation PSNR: 26.2949, Validation SSIM: 0.8145
EarlyStopping counter: 1 out of 5


Epoch 21/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.0277]

Epoch [21/25], Train Loss: 0.0255, Train PSNR: 26.4522, Train SSIM: 0.8134





Validation Loss: 0.0250, Validation PSNR: 26.5333, Validation SSIM: 0.8184
Validation loss decreased (0.025119 --> 0.025020).  Saving model ...


Epoch 22/25: 100%|██████████| 1182/1182 [23:12<00:00,  1.18s/batch, loss=0.02]  

Epoch [22/25], Train Loss: 0.0253, Train PSNR: 26.4974, Train SSIM: 0.8141





Validation Loss: 0.0254, Validation PSNR: 26.4749, Validation SSIM: 0.8172
EarlyStopping counter: 1 out of 5


Epoch 23/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.0215]

Epoch [23/25], Train Loss: 0.0253, Train PSNR: 26.4909, Train SSIM: 0.8143





Validation Loss: 0.0255, Validation PSNR: 26.3423, Validation SSIM: 0.8159
EarlyStopping counter: 2 out of 5


Epoch 24/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.0328]

Epoch [24/25], Train Loss: 0.0252, Train PSNR: 26.4916, Train SSIM: 0.8146





Validation Loss: 0.0249, Validation PSNR: 26.5465, Validation SSIM: 0.8187
Validation loss decreased (0.025020 --> 0.024882).  Saving model ...


Epoch 25/25: 100%|██████████| 1182/1182 [23:11<00:00,  1.18s/batch, loss=0.0185]

Epoch [25/25], Train Loss: 0.0252, Train PSNR: 26.5254, Train SSIM: 0.8150





Validation Loss: 0.0256, Validation PSNR: 26.3645, Validation SSIM: 0.8166
EarlyStopping counter: 1 out of 5
Training complete.


In [14]:
# Serialise the whole module object (not just its state_dict) for the Charbonnier run.
# NOTE(review): `model` holds the weights left in memory by the last training cell
# that ran, i.e. the final epoch -- not necessarily the best-validation weights,
# which presumably live in 'checkpoint.pt'. Confirm which one is intended.
torch.save(obj=model, f="UNet_charbonnier_eps1e-4_final.pth")

### Training the model for TV + MAE loss

In [14]:
from torch.cuda.amp import GradScaler, autocast

# NOTE: anomaly detection is a debugging aid that adds overhead to every
# backward pass. It is left enabled to keep this run's behaviour unchanged;
# consider turning it off once NaN issues are resolved.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # CUDA AMP (autocast + GradScaler) only has an effect on a CUDA device, so
    # request it explicitly only there instead of relying on the library to
    # warn and disable itself on the CPU fallback.
    use_amp = device.type == 'cuda'

    model = UNetSRx4RiR(in_channels=3, out_channels=3, num_features=64, dropout_rate=0.0).to(device)
    # l1_loss / tv_loss are not referenced directly in this cell; presumably
    # combined_tv_mae_loss (defined elsewhere) reads them as globals — TODO confirm
    # before removing.
    l1_loss = nn.L1Loss()
    tv_loss = TVLoss(tv_weight=1e-4)
    optimizer = optim.Adam(model.parameters(), lr=0.00005)
    scaler = GradScaler(enabled=use_amp)

    # Initialize TensorBoard writer
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter('runs/UNetSRx4_tv-1e-4_mae_final')

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, verbose=True)

    # Initialize learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

    num_epochs = 25
    try:
        for epoch in range(num_epochs):
            # ---------------- Training pass ----------------
            model.train()
            running_loss = 0.0
            running_psnr = 0.0
            running_ssim = 0.0
            # Number of batches that actually contributed to the running sums;
            # batches skipped because of a NaN loss must not dilute the averages.
            train_batches = 0

            # Wrap the train_loader with tqdm for progress bar
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

            for hr_patches, lr_patches in progress_bar:
                hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)

                optimizer.zero_grad()
                with autocast(enabled=use_amp):
                    outputs = model(lr_patches)
                    loss = combined_tv_mae_loss(outputs, hr_patches)
                    if torch.isnan(loss).any():
                        # Skip the optimizer step for this batch entirely.
                        print("Loss is NaN")
                        continue

                scaler.scale(loss).backward()
                # Apply gradient clipping
                # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                batch_loss = loss.item()
                running_loss += batch_loss
                # Metrics are for reporting only: compute them without autograd
                # so no extra graph is built on top of the model outputs.
                with torch.no_grad():
                    running_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Convert to float32 before calculating SSIM
                    outputs = outputs.float()
                    hr_patches = hr_patches.float()
                    running_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
                train_batches += 1

                # Update the progress bar with the running loss
                progress_bar.set_postfix(loss=batch_loss)

            # max(..., 1) guards against an empty loader or an epoch in which
            # every batch was skipped.
            train_denominator = max(train_batches, 1)
            avg_train_loss = running_loss / train_denominator
            avg_train_psnr = running_psnr / train_denominator
            avg_train_ssim = running_ssim / train_denominator
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('PSNR/train', avg_train_psnr, epoch)
            writer.add_scalar('SSIM/train', avg_train_ssim, epoch)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}')

            # ---------------- Validation pass ----------------
            model.eval()
            val_loss = 0.0
            val_psnr = 0.0
            val_ssim = 0.0
            val_batches = 0
            with torch.no_grad():
                for hr_patches, lr_patches in val_loader:
                    hr_patches, lr_patches = hr_patches.to(device), lr_patches.to(device)
                    with autocast(enabled=use_amp):
                        outputs = model(lr_patches)
                        loss = combined_tv_mae_loss(outputs, hr_patches)
                        if torch.isnan(loss).any():
                            print("Loss is NaN")
                            continue
                    val_loss += loss.item()
                    val_psnr += calculate_psnr(outputs, hr_patches).item()
                    # Convert to float32 before calculating SSIM
                    outputs = outputs.float()
                    hr_patches = hr_patches.float()
                    val_ssim += ssim(outputs, hr_patches, data_range=1.0).item()
                    val_batches += 1

            val_denominator = max(val_batches, 1)
            avg_val_loss = val_loss / val_denominator
            avg_val_psnr = val_psnr / val_denominator
            avg_val_ssim = val_ssim / val_denominator
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('PSNR/val', avg_val_psnr, epoch)
            writer.add_scalar('SSIM/val', avg_val_ssim, epoch)
            print(f'Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}, Validation SSIM: {avg_val_ssim:.4f}')

            # Check early stopping condition
            early_stopping(avg_val_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            # Step the scheduler
            scheduler.step(avg_val_loss)
    finally:
        # Flush and release the TensorBoard event file even if training is
        # interrupted or raises.
        writer.close()

    print('Training complete.')

2024-09-06 12:18:04.038916: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-06 12:18:17.010591: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-06 12:19:00.370499: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-09-06 12:19:00.370912: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

Epoch [1/25], Train Loss: 0.0469, Train PSNR: 23.4509, Train SSIM: 0.7100





Validation Loss: 0.0376, Validation PSNR: 24.7016, Validation SSIM: 0.7602
Validation loss decreased (inf --> 0.037649).  Saving model ...


Epoch 2/25: 100%|██████████| 1182/1182 [23:48<00:00,  1.21s/batch, loss=0.0429]

Epoch [2/25], Train Loss: 0.0353, Train PSNR: 24.9652, Train SSIM: 0.7632





Validation Loss: 0.0322, Validation PSNR: 25.4593, Validation SSIM: 0.7760
Validation loss decreased (0.037649 --> 0.032246).  Saving model ...


Epoch 3/25: 100%|██████████| 1182/1182 [23:51<00:00,  1.21s/batch, loss=0.0399]

Epoch [3/25], Train Loss: 0.0325, Train PSNR: 25.3697, Train SSIM: 0.7758





Validation Loss: 0.0303, Validation PSNR: 25.7446, Validation SSIM: 0.7854
Validation loss decreased (0.032246 --> 0.030276).  Saving model ...


Epoch 4/25: 100%|██████████| 1182/1182 [23:50<00:00,  1.21s/batch, loss=0.0421]

Epoch [4/25], Train Loss: 0.0311, Train PSNR: 25.5657, Train SSIM: 0.7826





Validation Loss: 0.0297, Validation PSNR: 25.8333, Validation SSIM: 0.7908
Validation loss decreased (0.030276 --> 0.029730).  Saving model ...


Epoch 5/25: 100%|██████████| 1182/1182 [23:48<00:00,  1.21s/batch, loss=0.0197]

Epoch [5/25], Train Loss: 0.0303, Train PSNR: 25.6908, Train SSIM: 0.7878





Validation Loss: 0.0283, Validation PSNR: 26.0205, Validation SSIM: 0.7936
Validation loss decreased (0.029730 --> 0.028310).  Saving model ...


Epoch 6/25: 100%|██████████| 1182/1182 [23:49<00:00,  1.21s/batch, loss=0.0501]

Epoch [6/25], Train Loss: 0.0295, Train PSNR: 25.7955, Train SSIM: 0.7912





Validation Loss: 0.0282, Validation PSNR: 26.0504, Validation SSIM: 0.7972
Validation loss decreased (0.028310 --> 0.028214).  Saving model ...


Epoch 7/25: 100%|██████████| 1182/1182 [23:48<00:00,  1.21s/batch, loss=0.0365]

Epoch [7/25], Train Loss: 0.0289, Train PSNR: 25.8855, Train SSIM: 0.7942





Validation Loss: 0.0280, Validation PSNR: 26.1129, Validation SSIM: 0.7988
Validation loss decreased (0.028214 --> 0.028016).  Saving model ...


Epoch 8/25: 100%|██████████| 1182/1182 [23:47<00:00,  1.21s/batch, loss=0.0224]

Epoch [8/25], Train Loss: 0.0287, Train PSNR: 25.9381, Train SSIM: 0.7959





Validation Loss: 0.0274, Validation PSNR: 26.1736, Validation SSIM: 0.7982
Validation loss decreased (0.028016 --> 0.027355).  Saving model ...


Epoch 9/25: 100%|██████████| 1182/1182 [23:46<00:00,  1.21s/batch, loss=0.0295]

Epoch [9/25], Train Loss: 0.0282, Train PSNR: 25.9896, Train SSIM: 0.7980





Validation Loss: 0.0278, Validation PSNR: 26.1704, Validation SSIM: 0.7992
EarlyStopping counter: 1 out of 5


Epoch 10/25: 100%|██████████| 1182/1182 [23:46<00:00,  1.21s/batch, loss=0.0185]

Epoch [10/25], Train Loss: 0.0280, Train PSNR: 26.0230, Train SSIM: 0.7996





Validation Loss: 0.0274, Validation PSNR: 26.1925, Validation SSIM: 0.7993
EarlyStopping counter: 2 out of 5


Epoch 11/25: 100%|██████████| 1182/1182 [23:45<00:00,  1.21s/batch, loss=0.0383]

Epoch [11/25], Train Loss: 0.0278, Train PSNR: 26.0902, Train SSIM: 0.8006





Validation Loss: 0.0266, Validation PSNR: 26.2838, Validation SSIM: 0.8039
Validation loss decreased (0.027355 --> 0.026636).  Saving model ...


Epoch 12/25: 100%|██████████| 1182/1182 [23:44<00:00,  1.21s/batch, loss=0.0418]

Epoch [12/25], Train Loss: 0.0275, Train PSNR: 26.1054, Train SSIM: 0.8021





Validation Loss: 0.0270, Validation PSNR: 26.2725, Validation SSIM: 0.8045
EarlyStopping counter: 1 out of 5


Epoch 13/25: 100%|██████████| 1182/1182 [23:44<00:00,  1.20s/batch, loss=0.0435]

Epoch [13/25], Train Loss: 0.0273, Train PSNR: 26.1504, Train SSIM: 0.8031





Validation Loss: 0.0318, Validation PSNR: 25.2761, Validation SSIM: 0.7928
EarlyStopping counter: 2 out of 5


Epoch 14/25: 100%|██████████| 1182/1182 [23:43<00:00,  1.20s/batch, loss=0.0249]

Epoch [14/25], Train Loss: 0.0271, Train PSNR: 26.1726, Train SSIM: 0.8039





Validation Loss: 0.0271, Validation PSNR: 26.2755, Validation SSIM: 0.8033
EarlyStopping counter: 3 out of 5
Epoch 00014: reducing learning rate of group 0 to 2.5000e-05.


Epoch 15/25: 100%|██████████| 1182/1182 [23:45<00:00,  1.21s/batch, loss=0.0434]

Epoch [15/25], Train Loss: 0.0263, Train PSNR: 26.2761, Train SSIM: 0.8069





Validation Loss: 0.0258, Validation PSNR: 26.3978, Validation SSIM: 0.8074
Validation loss decreased (0.026636 --> 0.025815).  Saving model ...


Epoch 16/25: 100%|██████████| 1182/1182 [23:44<00:00,  1.21s/batch, loss=0.022] 

Epoch [16/25], Train Loss: 0.0261, Train PSNR: 26.3118, Train SSIM: 0.8080





Validation Loss: 0.0255, Validation PSNR: 26.4300, Validation SSIM: 0.8087
Validation loss decreased (0.025815 --> 0.025523).  Saving model ...


Epoch 17/25: 100%|██████████| 1182/1182 [23:43<00:00,  1.20s/batch, loss=0.0345]

Epoch [17/25], Train Loss: 0.0261, Train PSNR: 26.3207, Train SSIM: 0.8084





Validation Loss: 0.0259, Validation PSNR: 26.4123, Validation SSIM: 0.8066
EarlyStopping counter: 1 out of 5


Epoch 18/25: 100%|██████████| 1182/1182 [23:43<00:00,  1.20s/batch, loss=0.03]  

Epoch [18/25], Train Loss: 0.0260, Train PSNR: 26.3400, Train SSIM: 0.8089





Validation Loss: 0.0254, Validation PSNR: 26.4487, Validation SSIM: 0.8091
Validation loss decreased (0.025523 --> 0.025404).  Saving model ...


Epoch 19/25: 100%|██████████| 1182/1182 [23:43<00:00,  1.20s/batch, loss=0.0257]

Epoch [19/25], Train Loss: 0.0259, Train PSNR: 26.3579, Train SSIM: 0.8096





Validation Loss: 0.0257, Validation PSNR: 26.4356, Validation SSIM: 0.8096
EarlyStopping counter: 1 out of 5


Epoch 20/25: 100%|██████████| 1182/1182 [23:44<00:00,  1.21s/batch, loss=0.0343]

Epoch [20/25], Train Loss: 0.0259, Train PSNR: 26.3747, Train SSIM: 0.8099





Validation Loss: 0.0255, Validation PSNR: 26.4686, Validation SSIM: 0.8098
EarlyStopping counter: 2 out of 5


Epoch 21/25: 100%|██████████| 1182/1182 [23:42<00:00,  1.20s/batch, loss=0.0254]

Epoch [21/25], Train Loss: 0.0258, Train PSNR: 26.4058, Train SSIM: 0.8105





Validation Loss: 0.0255, Validation PSNR: 26.4624, Validation SSIM: 0.8109
EarlyStopping counter: 3 out of 5
Epoch 00021: reducing learning rate of group 0 to 1.2500e-05.


Epoch 22/25: 100%|██████████| 1182/1182 [23:43<00:00,  1.20s/batch, loss=0.0292]

Epoch [22/25], Train Loss: 0.0254, Train PSNR: 26.4531, Train SSIM: 0.8118





Validation Loss: 0.0254, Validation PSNR: 26.4769, Validation SSIM: 0.8112
Validation loss decreased (0.025404 --> 0.025362).  Saving model ...


Epoch 23/25: 100%|██████████| 1182/1182 [23:43<00:00,  1.20s/batch, loss=0.0201]

Epoch [23/25], Train Loss: 0.0254, Train PSNR: 26.4558, Train SSIM: 0.8122





Validation Loss: 0.0252, Validation PSNR: 26.4951, Validation SSIM: 0.8110
Validation loss decreased (0.025362 --> 0.025232).  Saving model ...


Epoch 24/25: 100%|██████████| 1182/1182 [23:44<00:00,  1.20s/batch, loss=0.0438]

Epoch [24/25], Train Loss: 0.0253, Train PSNR: 26.4742, Train SSIM: 0.8126





Validation Loss: 0.0254, Validation PSNR: 26.4898, Validation SSIM: 0.8113
EarlyStopping counter: 1 out of 5


Epoch 25/25: 100%|██████████| 1182/1182 [23:44<00:00,  1.21s/batch, loss=0.0176]

Epoch [25/25], Train Loss: 0.0253, Train PSNR: 26.5016, Train SSIM: 0.8130





Validation Loss: 0.0252, Validation PSNR: 26.5008, Validation SSIM: 0.8117
Validation loss decreased (0.025232 --> 0.025196).  Saving model ...
Training complete.


In [15]:
# Serialize the entire in-memory model object (architecture + weights, via
# pickle) rather than only its state_dict. Per the training log directly
# above, the last epoch of this run also produced the lowest validation loss,
# so these last-epoch weights coincide with the best-validation ones.
torch.save(obj=model, f="UNet_1e-4tv_mae_final.pth")