In [1]:
from google.colab import drive
# Mount Google Drive (Colab only); the dataset and output paths used by the
# cells below all live under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# dataset

import cv2
import os
import torch
import random
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def create_dataloader(split, SR_rate, augment, batch_size=1, shuffle=False, num_workers=1, pin_memory=True):
    """Build a ``DataLoader`` over a ``dataread`` dataset.

    Parameters
    ----------
    split, SR_rate, augment :
        Forwarded unchanged to ``dataread(split, SR_rate, augment)``.
    batch_size, shuffle, num_workers, pin_memory :
        Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns
    -------
    DataLoader
        Loader wrapping the freshly created dataset.
    """
    return DataLoader(
        dataread(split, SR_rate, augment),
        batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

def random_crop(LR_img, HR_img, crop_size, SR_rate):
    """Cut a random window out of an LR image and the matching window out of HR.

    Parameters
    ----------
    LR_img, HR_img :
        Arrays indexed as (row, column, channel).
    crop_size :
        ``(height, width)`` of the LR window.
    SR_rate :
        Integer factor by which the LR window coordinates are multiplied to
        obtain the HR window.

    Returns
    -------
    tuple
        ``(LR_crop, HR_crop)`` - slices (views) of the inputs.

    Raises
    ------
    AssertionError
        If the requested window is larger than ``LR_img``.
    """
    lr_rows, lr_cols = LR_img.shape[:2]
    _hr_rows, _hr_cols = HR_img.shape[:2]  # only checks that HR_img has >= 2 dims
    win_h, win_w = crop_size

    fits = win_h <= lr_rows and win_w <= lr_cols
    assert fits, 'crop_size is too large'

    # Draw the top-left corner: row first, then column.
    top = random.randint(0, lr_rows - win_h)
    left = random.randint(0, lr_cols - win_w)
    bottom, right = top + win_h, left + win_w

    lr_patch = LR_img[top:bottom, left:right, :]
    hr_patch = HR_img[SR_rate * top:SR_rate * bottom, SR_rate * left:SR_rate * right, :]
    return lr_patch, hr_patch

class dataread(Dataset):  # for training/testing
    """Stereo super-resolution dataset returning LR/HR left and right images.

    Each item is the tuple
    ``(LR_left, LR_right, HR_left, HR_right, left_name, right_name)`` where the
    four images are ``torch`` tensors in channel-first (C, H, W) layout holding
    the pixel values divided by 255.

    Parameters
    ----------
    split : str
        ``'train'``, ``'valid'`` or ``'test'``; selects the image folders.
    SR_rate : int
        Scale factor between LR and HR pixel coordinates (used for cropping).
        NOTE(review): every folder below is an "x2" folder regardless of this
        value - confirm before using another rate.
    augment : bool
        When true (and ``split == 'train'``) apply random crop, flips,
        transpose and intensity scaling.

    Raises
    ------
    NameError
        If ``split`` is not one of the three supported values.
    """

    def __init__(self, split, SR_rate, augment=False):

        self.split = split
        self.SR_rate = SR_rate
        self.augment = augment
        self.intensity_list = [1.0, 0.7, 0.5]
        self.crop_size = [32, 32]

        # data split
        if split == 'train':
            self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/Train/left')
            self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/Train/right')
            self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/train/left'
            self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/train/right'
            # Training file names are taken from the LR folders.
            self.img_names_l = sorted(os.listdir(self.LR_dir_l))
            self.img_names_r = sorted(os.listdir(self.LR_dir_r))

        elif split == 'valid':
            self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/val/left')
            self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/val/right')
            self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/val/left'
            self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/val/right'
            # Validation/test file names are taken from the HR folders.
            self.img_names_l = sorted(os.listdir(self.HR_dir_l))
            self.img_names_r = sorted(os.listdir(self.HR_dir_r))

        elif split == 'test':
            # Alternative test set (KITTI mix), kept for reference:
            # self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/LR/x2/left')
            # self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/LR/x2/right')
            # self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/HR_size_crctd/x2/left'
            # self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/HR_size_crctd/x2/right'

            self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/x2_versions/x2_flickr/LR/left')
            self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/x2_versions/x2_flickr/LR/right')
            self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/x2_versions/x2_flickr/HR_size_crctd/left'
            self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/x2_versions/x2_flickr/HR_size_crctd/right'

            self.img_names_l = sorted(os.listdir(self.HR_dir_l))
            self.img_names_r = sorted(os.listdir(self.HR_dir_r))
        else:
            raise NameError('data split must be "train", "valid" or "test". ')

    def __len__(self):
        """Number of stereo pairs: the shorter of the left/right name lists.

        Items are addressed by the same index in both lists, so an index past
        the end of the shorter list could never be loaded.
        """
        return min(len(self.img_names_l), len(self.img_names_r))

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV and return the pixel array divided by 255.

        Raises ``FileNotFoundError`` when OpenCV cannot read the file
        (``cv2.imread`` signals that by returning ``None``).
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'cv2.imread could not read image: {path}')
        return img / 255.

    def _paired_random_crop(self, LR_img_l, LR_img_r, HR_img_l, HR_img_r):
        """Crop the SAME random window from both views (and its HR counterpart).

        A stereo pair must be cropped at identical coordinates; cropping each
        view independently would pair unrelated image regions.
        """
        crop_h, crop_w = self.crop_size
        # The window has to fit inside both LR views.
        LR_h = min(LR_img_l.shape[0], LR_img_r.shape[0])
        LR_w = min(LR_img_l.shape[1], LR_img_r.shape[1])
        assert crop_h <= LR_h and crop_w <= LR_w, 'crop_size is too large'

        y1 = random.randint(0, LR_h - crop_h)
        x1 = random.randint(0, LR_w - crop_w)
        scale = self.SR_rate

        lr_win = (slice(y1, y1 + crop_h), slice(x1, x1 + crop_w))
        hr_win = (slice(scale * y1, scale * (y1 + crop_h)),
                  slice(scale * x1, scale * (x1 + crop_w)))
        return LR_img_l[lr_win], LR_img_r[lr_win], HR_img_l[hr_win], HR_img_r[hr_win]

    def __getitem__(self, index):
        """Load (and, for augmented training, transform) the ``index``-th pair."""
        name_l = self.img_names_l[index]
        name_r = self.img_names_r[index]

        # The same file name is used in the LR and HR folder of each view.
        LR_img_l = self._read_image(os.path.join(self.LR_dir_l, name_l))
        LR_img_r = self._read_image(os.path.join(self.LR_dir_r, name_r))
        HR_img_l = self._read_image(os.path.join(self.HR_dir_l, name_l))
        HR_img_r = self._read_image(os.path.join(self.HR_dir_r, name_r))

        if self.split == 'train' and self.augment:
            # random crop - one window shared by the left and right view
            LR_img_l, LR_img_r, HR_img_l, HR_img_r = self._paired_random_crop(
                LR_img_l, LR_img_r, HR_img_l, HR_img_r)

            # geometric transformations (applied identically to all four images)
            # NOTE(review): a horizontal flip or an H/W transpose changes the
            # direction of the disparity between the views - confirm this is
            # intended for the stereo model.
            if random.random() < 0.5:  # hflip
                LR_img_l, LR_img_r = LR_img_l[:, ::-1, :], LR_img_r[:, ::-1, :]
                HR_img_l, HR_img_r = HR_img_l[:, ::-1, :], HR_img_r[:, ::-1, :]
            if random.random() < 0.5:  # vflip
                LR_img_l, LR_img_r = LR_img_l[::-1, :, :], LR_img_r[::-1, :, :]
                HR_img_l, HR_img_r = HR_img_l[::-1, :, :], HR_img_r[::-1, :, :]
            if random.random() < 0.5:  # swap the H and W axes
                LR_img_l, LR_img_r = LR_img_l.transpose(1, 0, 2), LR_img_r.transpose(1, 0, 2)
                HR_img_l, HR_img_r = HR_img_l.transpose(1, 0, 2), HR_img_r.transpose(1, 0, 2)

            # intensity scale - one factor for the whole sample
            intensity_scale = random.choice(self.intensity_list)
            LR_img_l = LR_img_l * intensity_scale
            LR_img_r = LR_img_r * intensity_scale
            HR_img_l = HR_img_l * intensity_scale
            HR_img_r = HR_img_r * intensity_scale

        # Convert: HWC => CHW. ascontiguousarray also removes the negative
        # strides left by the flips, which torch.from_numpy cannot handle.
        LR_img_l = np.ascontiguousarray(LR_img_l.transpose(2, 0, 1))
        LR_img_r = np.ascontiguousarray(LR_img_r.transpose(2, 0, 1))
        HR_img_l = np.ascontiguousarray(HR_img_l.transpose(2, 0, 1))
        HR_img_r = np.ascontiguousarray(HR_img_r.transpose(2, 0, 1))

        return torch.from_numpy(LR_img_l), torch.from_numpy(LR_img_r), torch.from_numpy(HR_img_l), torch.from_numpy(HR_img_r), name_l, name_r

# Smoke test for the data pipeline: create the output folder on Drive and
# report how many batches the 'test' split yields.
# NOTE: the loader is named `train_dataloader` but is built from the 'test'
# split (SR_rate=2, no augmentation).
if __name__ == '__main__':
    os.makedirs('/content/drive/MyDrive/phd/wk1/phase1_baseline/Output_UP_VM_CR_b4_loss_anchor/x2/test_dataloader', exist_ok=True)
    train_dataloader = create_dataloader('test',2, False, batch_size=1, shuffle=False, num_workers=1)
    print(f"len(train_dataloader): {len(train_dataloader)}")
    # Disabled debugging aid: fetch one batch, print the tensor sizes and dump
    # the four images as PNG files.
    #LR_img, HR_img, img_names = next(iter(train_dataloader))
    # iterator = iter(train_dataloader) #me
    # LR_img_l, LR_img_r, HR_img_l, HR_img_r, img_names_l, img_names_r = next(iterator)
    # # #LR_img = next(iterator)
    # # print(f"LR_img shape: {LR_img_l.size()}")
    # # print(f"LR_img shape: {LR_img_r.size()}")
    # # print(f"HR_img shape: {HR_img_l.size()}")
    # # print(f"HR_img shape: {HR_img_r.size()}")
    # # print('left image names in iterator: ', img_names_l, len(img_names_l))
    # # print('right image names in iterator: ', img_names_r, len(img_names_r))
    # LR_img_l = LR_img_l[0].numpy().transpose(1, 2, 0)
    # LR_img_r = LR_img_r[0].numpy().transpose(1, 2, 0)
    # HR_img_l = HR_img_l[0].numpy().transpose(1, 2, 0)
    # HR_img_r = HR_img_r[0].numpy().transpose(1, 2, 0)
    # cv2.imwrite('./test_dataloader/LR_img_l.png', np.uint8(LR_img_l*255))
    # cv2.imwrite('./test_dataloader/LR_img_r.png', np.uint8(LR_img_r*255))
    # cv2.imwrite('./test_dataloader/HR_img_l.png', np.uint8(HR_img_l*255))
    # cv2.imwrite('./test_dataloader/HR_img_r.png', np.uint8(HR_img_r*255))

len(train_dataloader): 112


In [3]:
# visualization

import os
import cv2

import torch
import numpy as np
import matplotlib.pyplot as plt

def save_res(pred_HR_l, pred_HR_r, img_names_l, img_names_r, folder):
    '''
    Write predicted left/right images into ``folder`` as 8-bit image files.

    Parameters
    ----------
    pred_HR_l, pred_HR_r : List
        predictions for the left / right view.
        each pred has a shape of 1x3xHxW. BGR, values expected in [0, 1]
    img_names_l, img_names_r : List
        file name (including extension) used for each left / right prediction
    folder : str
        directory the images are written to

    Returns
    -------
    None.

    Notes
    -----
    Values outside [0, 1] are clipped before the conversion to ``uint8``;
    without the clip an over/undershooting prediction wraps around
    (e.g. 1.5 * 255 does not fit into 8 bits).
    NOTE(review): both views are written into the same folder, so a left and a
    right image sharing one file name would overwrite each other - confirm the
    names are distinct.
    '''

    def _to_uint8_image(pred):
        # 1xCxHxW tensor -> HxWxC uint8 array
        img = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
        return np.uint8(np.clip(img, 0., 1.) * 255)

    for pred_l, pred_r, img_name_l, img_name_r in zip(pred_HR_l, pred_HR_r, img_names_l, img_names_r):
        cv2.imwrite(os.path.join(folder, img_name_l), _to_uint8_image(pred_l))
        cv2.imwrite(os.path.join(folder, img_name_r), _to_uint8_image(pred_r))
    return


def visualize_training(save_dir,epoch_start):
    """Plot the training log of one run and save it as ``results.png``.

    The log ``result_epoch_{epoch_start}.txt`` inside ``save_dir`` is read line
    by line. Every line is split on ``'|'`` and the text after the first
    ``':'`` of fields 0-4 is parsed as epoch (int), learning rate, training
    loss, validation loss and PSNR (floats). The last line of the file is not
    parsed.

    Parameters
    ----------
    save_dir : str
        Folder containing the log; the figure is written there as well.
    epoch_start :
        Value inserted into the log file name.

    Returns
    -------
    None.
    """
    log_path = os.path.join(save_dir, f"result_epoch_{epoch_start}.txt")
    with open(log_path, 'r') as f:
        rows = f.readlines()

    epoch, lr, train_loss, valid_loss, psnr = [], [], [], [], []
    # (destination list, converter) in the order the fields appear on a line
    columns = ((epoch, int), (lr, float), (train_loss, float), (valid_loss, float), (psnr, float))

    for row in rows[:-1]:
        fields = row.strip().split('|')
        for position, (series, convert) in enumerate(columns):
            series.append(convert(fields[position].split(':')[1]))

    fig = plt.figure(figsize=(16, 8), dpi=400)
    panels = (
        (221, 'Training loss', train_loss),
        (222, 'Validation loss', valid_loss),
        (223, 'PSNR (validation)', psnr),
        (224, 'learning rate', lr),
    )
    axes = [fig.add_subplot(slot) for slot, _, _ in panels]
    for ax, (_, caption, _) in zip(axes, panels):
        ax.title.set_text(caption)
    for ax, (_, _, series) in zip(axes, panels):
        ax.plot(epoch, series)

    plt.savefig(os.path.join(save_dir, 'results.png'), dpi=200)
    return

In [4]:
# metric


import torch

def cal_psnr(x, y, max_val=1.0):
    '''
    Peak signal-to-noise ratio between two image batches.

    Parameters
    ----------
    x, y : tensors of the same shape (B, C, H, W)
    max_val : float, optional
        largest possible pixel value of the data (1.0 for images scaled to
        [0, 1], 255 for 8-bit ranges). Must be > 0. The default reproduces
        the previous behaviour, which assumed a peak of 1.

    Returns
    -------
    score : tensor of shape (B,), PSNR in dB for every sample
        (``inf`` when a sample of x and y is identical).
    '''
    import math  # local import: this cell only imports torch

    # mean squared error per sample (reduce over C, H, W)
    mse = torch.mean((x - y) ** 2, dim=[1, 2, 3])
    # 10*log10(max_val^2 / mse), written so that max_val=1 adds exactly 0
    score = 20 * math.log10(max_val) - 10 * torch.log10(mse)
    return score

In [5]:
# loss

import torch
from torch import nn

class L1Loss(object):
    """Callable returning the mean absolute difference of two tensors."""

    def __call__(self, input, target):
        residual = input - target
        return residual.abs().mean()

class CharbonnierLoss(nn.Module):
    """Smooth L1-like loss: mean of ``sqrt((pred - gt)^2 + eps)``.

    ``eps`` (default 0.01) is added under the square root as is.
    """

    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps

    def forward(self, pred, gt):
        """Return the scalar loss averaged over every element."""
        difference = pred - gt
        return difference.pow(2).add(self.eps).sqrt().mean()

In [6]:
########################## new model #############################
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.quantization import QuantStub, DeQuantStub
from torchsummary import summary
from skimage import morphology
import torch.nn.functional as F

class ClippedReLU(nn.Module):
    """Activation that limits every value to the interval [0, 1]."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.clamp(x, min=0., max=1.)

class ConvRelu(nn.Module):
    """2-D convolution followed by an in-place ReLU.

    The convolution pads by ``kernel_size // 2`` on every side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=half_kernel, bias=bias)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))

class Dblock(nn.Module):
    """Grouped 3x3 convolution, in-place ReLU, then a 1x1 convolution.

    ``groups`` is passed to the 3x3 convolution; the spatial size of the input
    is preserved (padding 1 for the 3x3, none for the 1x1).
    """

    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        for stage in (self.conv0, self.relu, self.conv1):
            x = stage(x)
        return x

class ResB(nn.Module):
    """Residual block: ``x + conv3x3(LeakyReLU(conv3x3(x)))`` (bias-free convs).

    The computation lives in ``forward`` rather than in an overridden
    ``__call__`` so that calling the module goes through
    ``nn.Module.__call__`` and registered forward hooks are executed.
    """

    def __init__(self, channels):
        super(ResB, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        """Return ``body(x) + x``; the output has the same shape as ``x``."""
        out = self.body(x)
        return out + x

class SaveValues:
    """Holder for tensors captured by forward hooks.

    Parameters
    ----------
    device : optional
        Device the captured tensors are moved to. When omitted, ``'cuda'`` is
        used if CUDA is available (the previous hard-coded value) and
        ``'cpu'`` otherwise, so the hooks also work on machines without a GPU.
    """

    def __init__(self, device=None):
        self.weights = None
        self.activations = None
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

    def hook_weights(self, module, input, output):
        """Forward hook: store the module's weight, detached and flattened."""
        self.weights = module.weight.data.to(self.device).view(-1)

    def hook_activations(self, module, input, output):
        """Forward hook: store the module's output, detached from autograd."""
        self.activations = output.data.to(self.device)

    def clear(self):
        """Drop the stored tensors."""
        self.weights = None
        self.activations = None

class Linear_Activation(nn.Module):
    """Return the flattened weight matrix of a freshly created linear layer.

    NOTE(review): a new ``nn.Linear`` is instantiated on every call, so the
    returned weights are that layer's random initialisation; they depend on
    the shape of ``x`` only, are detached from autograd and are never trained.
    Confirm this is the intended behaviour.
    """

    def __init__(self):
        super(Linear_Activation, self).__init__()

        self.linear = None  # replaced by a new nn.Linear on every forward()
        self.save_values = SaveValues()

    def forward(self, x):
        """
        Parameters
        ----------
        x : tensor with exactly three dimensions (channels, height, width)

        Returns
        -------
        1-D tensor with ``height * width`` elements on the device of ``x``:
        the weight of ``nn.Linear(width, height)`` flattened row by row.
        """
        # Work on the input's own device instead of a hard-coded 'cuda'.
        device = x.device
        _, height, width = x.size()

        # Subtract mean along the last dimension
        x_subtracted = x - torch.mean(x, dim=-1, keepdim=True)

        # Create the linear layer dynamically based on the input width
        self.linear = nn.Linear(width, height).to(device)

        # The hook stores the layer's weight when the layer is run.
        self.save_values.device = device
        hook_handle_weights = self.linear.register_forward_hook(self.save_values.hook_weights)
        try:
            # The output itself is not needed; the call only triggers the hook.
            self.linear(x_subtracted)
        finally:
            hook_handle_weights.remove()

        return self.save_values.weights

class Get_weights_activations(nn.Module):
    """Return the output of a freshly created single-channel 3x3 convolution.

    NOTE(review): a new ``nn.Conv2d`` is instantiated on every call, so the
    map is produced with randomly initialised, untrained weights and is
    detached from autograd. Confirm this is the intended behaviour.

    Parameters
    ----------
    input_dim :
        Accepted for interface compatibility; the number of input channels is
        read from the tensor passed to ``forward``.
    """

    def __init__(self, input_dim):
        super(Get_weights_activations, self).__init__()
        self.conv2 = None  # replaced by a new nn.Conv2d on every forward()
        self.save_values = SaveValues()

    def forward(self, x):
        """
        Parameters
        ----------
        x : tensor of shape (batch, channels, height, width)

        Returns
        -------
        tensor of shape (batch, 1, height, width) on the device of ``x``.
        """
        # Follow the input's device / channel count instead of the previous
        # hard-coded 'cuda' and 64.
        device = x.device
        _, channels, _, _ = x.size()

        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=1, kernel_size=3, padding=1).to(device)

        # The hook stores the layer's output when the layer is run.
        self.save_values.device = device
        hook_handle_activations = self.conv2.register_forward_hook(self.save_values.hook_activations)
        try:
            self.conv2(x)
        finally:
            hook_handle_activations.remove()

        return self.save_values.activations

class Valid_Mask_CAM(object):
    """ Class Activation Mapping

    Multiplies the single-channel map produced by ``Get_weights_activations``
    for ``A`` with the weights returned by ``Linear_Activation`` for ``B`` and
    keeps only the products lying in ``[0, 0.1]`` (everything else becomes 0).
    """

    def __init__(self, channels):
        self.linear_activation = Linear_Activation()
        self.get_attention = Get_weights_activations(channels)

    def forward(self, A, B):
        """
        Parameters
        ----------
        A : 4-D tensor fed to ``get_attention``.
        B : 3-D tensor; its first two axes are swapped before it is fed to
            ``linear_activation``.

        Returns
        -------
        Tensor with the shape of the attention map, detached from autograd.
        """
        _b2, _c2, _h2, _w2 = A.shape  # A must have exactly four dimensions

        # Order matters: the attention map is computed before the weights.
        attention = self.get_attention(A)
        flat_weights = self.linear_activation(B.permute(1, 0, 2))

        b, c, h, w = attention.shape
        weight_map = flat_weights.view(b, c, h, w)

        product = attention * weight_map

        # Zero out values greater than 0.1 as well as negative values.
        rejected = (product > 0.1) | (product < 0)
        valid_mask_cam = product.masked_fill(rejected, 0)

        return valid_mask_cam.data

    def __call__(self, A, B):
        return self.forward(A, B)

class PAM(nn.Module):
    """Parallax-attention module exchanging features between two views.

    For every image row an attention matrix over the width positions of the
    other view is computed (softmax of a batched matrix product), used to warp
    the other view's features, and the result is fused with the view's own
    features and a one-channel mask from ``Valid_Mask_CAM``.

    The computation lives in ``forward`` (not in an overridden ``__call__``)
    so that calling the module goes through ``nn.Module.__call__`` and forward
    hooks are executed. Callers keep using ``pam(x_left, x_right, is_training)``.

    Parameters
    ----------
    channels : int
        Number of feature channels of both inputs.
    """

    def __init__(self, channels):
        super(PAM, self).__init__()
        self.b1 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.b2 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.b3 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.b4 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.softmax = nn.Softmax(-1)
        # Follows `channels` (was hard-coded to 64, the only value used so far).
        self.rb = ResB(channels)
        # Input of the fusion conv: warped features + 1 mask channel + own features.
        self.fusion = nn.Conv2d(channels * 2 + 1, channels, 1, 1, 0, bias=True)
        self.valid_mask = Valid_Mask_CAM(channels)

    def forward(self, x_left, x_right, is_training):
        """
        Parameters
        ----------
        x_left, x_right : tensors of shape (B, C, H, W)
        is_training : 1 or 0

        Returns
        -------
        is_training == 1 :
            ``out_l, out_r, (M_right_to_left, M_left_to_right),
            (M_left_right_left, M_right_left_right), (V_left_right, V_right_left)``
            with the attention matrices reshaped to (B, H, W, W).
        is_training == 0 :
            ``out_l, out_r``
        Any other value returns ``None``.
        """
        b, c, h, w = x_left.shape
        buffer_left = self.rb(x_left)
        buffer_right = self.rb(x_right)

        ### M_{right_to_left}
        Q = self.b1(buffer_left).permute(0, 2, 3, 1)                                                # B * H * W * C
        S = self.b2(buffer_right).permute(0, 2, 1, 3)                                               # B * H * C * W
        # Row-wise similarity between every left and every right position.
        score_l = torch.bmm(Q.contiguous().view(-1, w, c),
                          S.contiguous().view(-1, c, w))                                            # (B*H) * W * W
        M_right_to_left = self.softmax(score_l)

        ### M_{left_to_right}
        Q = self.b1(buffer_right).permute(0, 2, 3, 1)                                               # B * H * W * C
        S = self.b2(buffer_left).permute(0, 2, 1, 3)                                                # B * H * C * W
        score_r = torch.bmm(Q.contiguous().view(-1, w, c),
                          S.contiguous().view(-1, c, w))                                            # (B*H) * W * W
        M_left_to_right = self.softmax(score_r)

        ### one-channel masks
        Vin_L = self.b3(x_left)
        Vin_R = self.b3(x_right)
        V_right_left = self.valid_mask(Vin_L, M_right_to_left)
        V_left_right = self.valid_mask(Vin_R, M_left_to_right)

        if is_training==1:
            # Cycle attention (left -> right -> left and the reverse); only returned.
            M_left_right_left = torch.bmm(M_right_to_left, M_left_to_right)
            M_right_left_right = torch.bmm(M_left_to_right, M_right_to_left)

        ### fusion
        buffer_l = self.b4(x_right).permute(0,2,3,1).contiguous().view(-1, w, c)                      # (B*H) * W * C
        buffer_l = torch.bmm(M_right_to_left, buffer_l).contiguous().view(b, h, w, c).permute(0,3,1,2)  #  B * C * H * W
        out_l = self.fusion(torch.cat((buffer_l, V_left_right, x_left), 1))

        buffer_r = self.b4(x_left).permute(0,2,3,1).contiguous().view(-1, w, c)                      # (B*H) * W * C
        buffer_r = torch.bmm(M_left_to_right, buffer_r).contiguous().view(b, h, w, c).permute(0,3,1,2)  #  B * C * H * W
        out_r = self.fusion(torch.cat((buffer_r, V_right_left, x_right), 1))

        ## output
        if is_training == 1:
            return out_l, out_r, \
               (M_right_to_left.contiguous().view(b, h, w, w), M_left_to_right.contiguous().view(b, h, w, w)), \
               (M_left_right_left.view(b,h,w,w), M_right_left_right.view(b,h,w,w)), \
               (V_left_right, V_right_left)
        if is_training == 0:
            return out_l, out_r


class XLSR_stereo_AM(nn.Module):
    """Stereo super-resolution network with parallax attention and anchors.

    Both views go through the same layers: feature extraction, a PAM between
    the views, D-blocks, a second PAM between the D-block output and a shallow
    convolution of the same view, D-blocks again, and an upscaling head whose
    output is added to the bicubically upsampled input.

    Parameters
    ----------
    SR_rate : int
        Upscaling factor (used by the pixel shuffle and the bicubic resize).
    """

    def __init__(self, SR_rate):
        super(XLSR_stereo_AM, self).__init__()
        self.upscale_factor = SR_rate

        ### feature extraction
        self.init_feature  = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3,padding=1),
                                            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1),
                                           nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
                                            ConvRelu(in_channels=32, out_channels=64, kernel_size=1))

        ### parallax attention
        self.pam = PAM(64)

        ### Dblocks
        self.Dblocks = nn.Sequential(Dblock(64, 64, 4),Dblock(64, 64, 4),
                                     Dblock(64, 64, 4))

        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
                                   nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
                                   nn.LeakyReLU(0.1, inplace=True))

        ########### Anchors making
        # The anchor is the 3-channel input repeated 10 times along the
        # channel axis (30 channels) and projected to 64 channels.
        self.anch_func = lambda x_list: torch.cat(x_list, dim=1)
        self.conv_anch = nn.Conv2d(in_channels=30, out_channels=64, kernel_size=3, padding=1)

        #### upscale module
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64*SR_rate**2, kernel_size=3, padding=1)
        self.clippedReLU = ClippedReLU()
        self.upscale = nn.Sequential(nn.PixelShuffle(SR_rate),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
            )

        # weights initialization: zero-mean normal with std sqrt(0.2 / fan_out)
        # for every Conv2d weight, constant 0.01 for every Conv2d bias
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                _, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                std = math.sqrt(2/fan_out*0.1)
                torch.nn.init.normal_(m.weight.data, mean=0, std=std)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.01)

    def _reconstruct(self, features, x_lr):
        """Turn 64-channel features of one view into its super-resolved image."""
        ### creating anchors
        anchor = self.conv_anch(self.anch_func([x_lr] * 10))

        ### upscaling
        features = self.clippedReLU(self.conv2(features))
        anchor = self.clippedReLU(self.conv2(anchor))
        residual = self.upscale(features + anchor)

        upsampled = F.interpolate(x_lr, scale_factor=self.upscale_factor, mode='bicubic', align_corners=False)

        # NOTE(review): the upper bound is 255 - confirm it matches the value
        # range of the network input.
        return torch.clamp(upsampled + residual, 0., 255.)

    def forward(self, x_left, x_right, is_training):
        """
        Parameters
        ----------
        x_left, x_right : tensors of shape (B, 3, H, W)
        is_training : 1 or 0 (forwarded to every PAM call)

        Returns
        -------
        is_training == 1 :
            ``out_l, out_r, (M_right_to_left, M_left_to_right),
            (M_left_right_left, M_right_left_right), (V_left_to_right, V_right_to_left)``;
            the three pairs are those returned by the LAST PAM call (the
            right-view one), as before.
        is_training == 0 :
            ``out_l, out_r``
        Any other value returns ``None``.
        """
        ### feature extraction
        buffer_left = self.init_feature(x_left)
        buffer_right = self.init_feature(x_right)

        training = is_training == 1
        if not training and not (is_training == 0):
            return None

        ### parallax attention between the two views
        cross_view = self.pam(buffer_left, buffer_right, is_training)
        buffer_lt, buffer_rt = cross_view[0], cross_view[1]

        ##### G-block
        res1_l = self.Dblocks(buffer_lt)
        res1_r = self.Dblocks(buffer_rt)

        res2_l = self.conv1(x_left)
        res2_r = self.conv1(x_right)

        ### second parallax attention, once per view
        # left: first output is kept; right: second output is kept
        pam_left = self.pam(res1_l, res2_l, is_training)
        pam_right = self.pam(res1_r, res2_r, is_training)

        buffer_l = self.Dblocks(pam_left[0])
        buffer_r = self.Dblocks(pam_right[1])

        out_l = self._reconstruct(buffer_l, x_left)
        out_r = self._reconstruct(buffer_r, x_right)

        if training:
            return out_l, out_r, pam_right[2], pam_right[3], pam_right[4]
        return out_l, out_r


if __name__ == '__main__':
    # Sanity check: build the x4 model once and report its trainable size.
    # Prefer the GPU, but fall back to the CPU so this cell also runs on
    # CPU-only runtimes instead of failing inside `.to('cuda')`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = XLSR_stereo_AM(4).to(device)
    model.eval()

    # # Create random left and right low-resolution stereo images (batch size = 1)
    # left_image = torch.randn(1, 3, 32, 32).to(device)
    # right_image = torch.randn(1, 3, 32, 32).to(device)

    # # Forward pass through the network
    # pred = model(left_image, right_image, is_training=1)

    # # Unpack the elements of the prediction tuple
    # if len(pred) == 5:  # Assuming the tuple contains 5 elements as per your model definition
    #     HR_pred_l, HR_pred_r, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
    #      (V_left_to_right, V_right_to_left) = pred
    # elif len(pred) == 2:
    #     HR_pred_l, HR_pred_r = pred

    # Count only parameters that receive gradients. This lives inside the guard
    # because `model` is only defined here; outside it would raise NameError
    # whenever the guard is false.
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total parameters =", pytorch_total_params)

# Print the model summary
# summary(model, input_size=[(3, 32, 32), (3, 32, 32)])


Total parameters = 776551


In [None]:
# # train.py

# import os
# import yaml
# import argparse
# import torch
# from torch.autograd import Variable
# from torch.utils.data import DataLoader
# import torch.backends.cudnn as cudnn
# import time
# import torch.optim as optim
# import torch.optim.lr_scheduler as lr_scheduler
# import pandas as pd


# # flag = True # for storing checkpoints
# # CP = True   # for storing checkpoints



# def train(model, dataloader, criteria, device, optimizer, scheduler):
#     loss_epoch = 0.
#     criterion_L1 = L1Loss()
#     for LR_img_l, LR_img_r,HR_img_l, HR_img_r, _ , _ in dataloader:

#         optimizer.zero_grad()
#         # LR_img_l, LR_img_r = LR_img_l.to(device).float(), LR_img_r.to(device).float()
#         # HR_img_l, HR_img_r = HR_img_l.to(device).float(), HR_img_r.to(device).float()
#         LR_left, LR_right = LR_img_l.to(device).float(), LR_img_r.to(device).float()
#         HR_left, HR_right = HR_img_l.to(device).float(), HR_img_r.to(device).float()
#         # HR_pred_l, HR_pred_r, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
#         #  (V_left_to_right, V_right_to_left) = model(LR_img_l, LR_img_r, is_training = 1)

#         SR_left, SR_right, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
#          (V_left, V_right)= model(LR_left, LR_right, is_training = 1)
#         b, c, h, w = LR_left.shape

#         # loss_l = criteria(HR_pred_l, HR_img_l)
#         # loss_r = criteria(HR_pred_r, HR_img_r)

#         loss_SR = criteria(SR_left, HR_left) + criteria(SR_right, HR_right)

#         # ### loss_smoothness
#         # loss_h = criterion_L1(M_right_to_left[:, :-1, :, :], M_right_to_left[:, 1:, :, :]) + \
#         #            criterion_L1(M_left_to_right[:, :-1, :, :], M_left_to_right[:, 1:, :, :])
#         # loss_w = criterion_L1(M_right_to_left[:, :, :-1, :-1], M_right_to_left[:, :, 1:, 1:]) + \
#         #            criterion_L1(M_left_to_right[:, :, :-1, :-1], M_left_to_right[:, :, 1:, 1:])
#         # loss_smooth = loss_w + loss_h

#         # ### loss_cycle
#         # Identity = Variable(torch.eye(w, w).repeat(b, h, 1, 1), requires_grad=False).to(device)
#         # loss_cycle = criterion_L1(M_left_right_left * V_left_to_right.permute(0, 2, 1, 3), Identity * V_left_to_right.permute(0, 2, 1, 3)) + \
#         #                  criterion_L1(M_right_left_right * V_right_to_left.permute(0, 2, 1, 3), Identity * V_right_to_left.permute(0, 2, 1, 3))

#         # ### loss_photometric
#         # LR_right_warped = torch.bmm(M_right_to_left.contiguous().view(b*h,w,w), LR_img_r.permute(0,2,3,1).contiguous().view(b*h, w, c))
#         # LR_right_warped = LR_right_warped.view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         # LR_left_warped = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), LR_img_l.permute(0, 2, 3, 1).contiguous().view(b * h, w, c))
#         # LR_left_warped = LR_left_warped.view(b, h, w, c).contiguous().permute(0, 3, 1, 2)

#         # loss_photo = criterion_L1(LR_img_l * V_left_to_right, LR_right_warped * V_left_to_right) + \
#         #                   criterion_L1(LR_img_r * V_right_to_left, LR_left_warped * V_right_to_left)

#         # ### losses
#         # loss = loss_SR + 0.005 * (loss_photo + loss_smooth + loss_cycle)

#         ''' Photometric Loss '''
#         Res_left = torch.abs(HR_left - F.interpolate(LR_left, scale_factor=scale, mode='bicubic', align_corners=False))
#         Res_left = F.interpolate(Res_left, scale_factor=1 / scale, mode='bicubic', align_corners=False)
#         Res_right = torch.abs(HR_right - F.interpolate(LR_right, scale_factor=scale, mode='bicubic', align_corners=False))
#         Res_right = F.interpolate(Res_right, scale_factor=1 / scale, mode='bicubic', align_corners=False)
#         Res_leftT = torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), Res_right.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
#                                   ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         Res_rightT = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), Res_left.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
#                                    ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         loss_photo = criterion_L1(Res_left * V_left.repeat(1, 3, 1, 1), Res_leftT * V_left.repeat(1, 3, 1, 1)) + \
#                          criterion_L1(Res_right * V_right.repeat(1, 3, 1, 1), Res_rightT * V_right.repeat(1, 3, 1, 1))

#         ''' Smoothness Loss '''
#         loss_h = criterion_L1(M_right_to_left[:, :-1, :, :], M_right_to_left[:, 1:, :, :]) + \
#                      criterion_L1(M_left_to_right[:, :-1, :, :], M_left_to_right[:, 1:, :, :])
#         loss_w = criterion_L1(M_right_to_left[:, :, :-1, :-1], M_right_to_left[:, :, 1:, 1:]) + \
#                      criterion_L1(M_left_to_right[:, :, :-1, :-1], M_left_to_right[:, :, 1:, 1:])
#         loss_smooth = loss_w + loss_h

#         ''' Cycle Loss '''
#         Res_left_cycle = torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), Res_rightT.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
#                                        ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         Res_right_cycle = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), Res_leftT.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
#                                         ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         loss_cycle = criterion_L1(Res_left * V_left.repeat(1, 3, 1, 1), Res_left_cycle * V_left.repeat(1, 3, 1, 1)) + \
#                          criterion_L1(Res_right * V_right.repeat(1, 3, 1, 1), Res_right_cycle * V_right.repeat(1, 3, 1, 1))

#         ''' Consistency Loss '''
#         SR_left_res = F.interpolate(torch.abs(HR_left - SR_left), scale_factor=1 / scale, mode='bicubic', align_corners=False)
#         SR_right_res = F.interpolate(torch.abs(HR_right - SR_right), scale_factor=1 / scale, mode='bicubic', align_corners=False)
#         SR_left_resT = torch.bmm(M_right_to_left.detach().contiguous().view(b * h, w, w), SR_right_res.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
#                                      ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         SR_right_resT = torch.bmm(M_left_to_right.detach().contiguous().view(b * h, w, w), SR_left_res.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
#                                       ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
#         loss_cons = criterion_L1(SR_left_res * V_left.repeat(1, 3, 1, 1), SR_left_resT * V_left.repeat(1, 3, 1, 1)) + \
#                        criterion_L1(SR_right_res * V_right.repeat(1, 3, 1, 1), SR_right_resT * V_right.repeat(1, 3, 1, 1))

#         ''' Total Loss '''
#         loss = loss_SR + 0.1 * loss_cons + 0.1 * (loss_photo + loss_smooth + loss_cycle)


#         # Backpropagation
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         # print("start_step:", scheduler.last_epoch)   # me
#         # print("end_step:", scheduler.end_step)       #me
#         loss_epoch += loss.item()
#     loss_epoch /= len(dataloader)
#     lr_epoch = scheduler.get_last_lr()[0]
#     scheduler.step() # me
#     return loss_epoch, lr_epoch

# def validation(model, dataloader, criteria, device):
#     loss_epoch = 0.
#     psnr_epoch = 0.
#     pred_list_l = []
#     pred_list_r = []
#     name_list_l = []
#     name_list_r = []
#     with torch.no_grad():
#         for LR_img_l, LR_img_r, HR_img_l, HR_img_r, img_name_l, img_name_r in dataloader:
#             LR_img_l, LR_img_r = LR_img_l.to(device).float(), LR_img_r.to(device).float()
#             HR_img_l , HR_img_r =  HR_img_l.to(device).float(), HR_img_r.to(device).float()

#             HR_pred_l, HR_pred_r = model(LR_img_l, LR_img_r, is_training=0)
#             SR_left = torch.clamp(HR_pred_l, 0, 1)

#             # HR_pred_l, HR_pred_r = HR_pred[:, :3, :, :], HR_pred[:, 3:, :, :]
#             loss_l = criteria(HR_pred_l, HR_img_l)
#             # loss_r = criteria(HR_pred_r, HR_img_r)
#             loss = loss_l
#             loss_epoch += loss.item()
#             psnr_epoch_l = cal_psnr(HR_pred_l, HR_img_l)
#             # psnr_epoch_r = cal_psnr(HR_pred_r, HR_img_r)
#             psnr_batch = psnr_epoch_l
#             psnr_epoch += psnr_batch.item()
#             pred_list_l.append(HR_pred_l)
#             pred_list_r.append(HR_pred_r)
#             name_list_l += img_name_l
#             name_list_r += img_name_r
#     loss_epoch /= len(dataloader)
#     psnr_epoch /= len(dataloader)
#     return loss_epoch, psnr_epoch, pred_list_l, pred_list_r, name_list_l, name_list_r

# save_dir = "/content/drive/MyDrive/phd/wk1/phase1_baseline/Output_UP_VM_CR_b4_loss_anchor/x2/output1"
# SR_rate = 2
# pretrained_model = "/content/drive/MyDrive/phd/wk1/phase1_baseline/Output_UP_VM_CR_b4_loss_anchor/x2/output1/best.pt"
# # epochs = 5000
# epochs = 500
# batch_size = 16
# lr_max = 25e-04
# pct_epoch = 30
# augment = 'store_true'
# workers = 2
# device = 0
# div_factor = 50.0
# final_div_factor = 0.5

# scale = SR_rate
# # if os.path.exists(save_dir):
# #         print(f"Warning: {save_dir} exists, please delete it manually if it is useless.")


# # txt file to record training process in EXCEL
# # results_df = pd.DataFrame(columns=['Epoch', 'Learning Rate', 'Training Loss',
# #                                    'Validation Loss', 'PSNR', 'Time'])
# # txt_path = os.path.join(save_dir, 'results.txt')
# # if os.path.exists(txt_path):
# #         os.remove(txt_path)

# # folder to save the predicted HR image in the validation
# valid_folder = os.path.join(save_dir, 'valid_res')
# os.makedirs(valid_folder, exist_ok=True)

# device = 'cuda'
# # device = 'cpu'

# model = XLSR_stereo_AM(SR_rate)
# #model = PTSQ_quantized_model(opt[1])

# # load pretrained model        ########################## uncomment if training needs a restart
# if pretrained_model.endswith('.pt') and os.path.exists(pretrained_model):
#         # filter conv4 weights which have different conv channels
#         model_dict = model.state_dict()
#         pretrained_dict = torch.load(pretrained_model)
#         filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
#         model_dict.update(filtered_dict)
#         model.load_state_dict(model_dict)
#         print(f"Loaded pretrained model {pretrained_model}" )


# model.to(device)

# train_dataloader = create_dataloader('train', SR_rate, augment, batch_size,
#                                      shuffle=True, num_workers=workers)
# batch = next(iter(train_dataloader))


# valid_dataloader = create_dataloader('valid', SR_rate, False, 1,
#                                      shuffle=True, num_workers=1)

# criteria = CharbonnierLoss()
# optimizer = optim.Adam(model.parameters(), lr=lr_max/div_factor, betas=(0.9, 0.999), eps=1e-08)


# scheduler = lr_scheduler.OneCycleLR(optimizer, lr_max, epochs=epochs, steps_per_epoch=len(train_dataloader), pct_start=pct_epoch/epochs, anneal_strategy='cos', \
#                                         cycle_momentum=False, div_factor=div_factor, final_div_factor=final_div_factor)


# epoch_start = 487
# txt_path = os.path.join(save_dir, f"result_epoch_{epoch_start}.txt")
# best_psnr = 0.
# for idx in range(epoch_start, epochs+1):
#         t0 = time.time()
#         train_loss_epoch, lr_epoch = train(model, train_dataloader, criteria, device, optimizer, scheduler)
#         t1 = time.time()
#         valid_loss_epoch, psnr_epoch, pred_HR_l, pred_HR_r, img_names_l, img_names_r = validation(model, valid_dataloader, criteria, device)
#         t2 = time.time()
#         print(f"Epoch: {idx} | lr: {lr_epoch:.5f} | training loss: {train_loss_epoch:.5f} | validation loss: {valid_loss_epoch:.5f} | PSNR: {psnr_epoch:.3f} | Time: {t2-t0:.1f}")
#         with open(txt_path, 'a') as f:
#             f.write(f"Epoch: {idx} | lr: {lr_epoch:.5f} | training loss: {train_loss_epoch:.5f} | validation loss: {valid_loss_epoch:.5f} | PSNR: {psnr_epoch:.3f} | Time: {t2-t0:.1f}" +'\n')

#         if psnr_epoch > best_psnr:
#             best_psnr = psnr_epoch

#             torch.save(model.state_dict(), os.path.join(save_dir, 'best.pt'))
#             # save predicted HR image on validation set
#             save_res(pred_HR_l,pred_HR_r, img_names_l,img_names_r, valid_folder)
#         del pred_HR_l, pred_HR_r


# #  # Save the DataFrame to an Excel file after the loop
# # excel_file_path = os.path.join(save_dir, 'results.xlsx')
# # results_df.to_excel(excel_file_path, index=False)

# #  # visualize the training process
# # visualize_training(save_dir)
# # print(f"Training is finished, the best PSNR is {best_psnr:.3f}")
# # with open(excel_file_path, 'a') as f:
# #             f.write(f"Training is finished, the best PSNR is {best_psnr:.3f}")

#  # visualize the training process
# visualize_training(save_dir, epoch_start)
# print(f"Training is finished, the best PSNR is {best_psnr:.3f}")
# with open(txt_path, 'a') as f:
#             f.write(f"Training is finished, the best PSNR is {best_psnr:.3f}")



Loaded pretrained model /content/drive/MyDrive/phd/wk1/phase1_baseline/Output_UP_VM_CR_b4_loss_anchor/x2/output1/best.pt


KeyboardInterrupt: 

In [7]:
# Computing the Structural Similarity Index Measure (SSIM)

#luminance computation
def l_x_y(x,y,C1):
    """Luminance comparison term of SSIM.

    Uses the mean over *all* elements of each tensor (no windowing).

    Args:
        x, y: tensors to compare.
        C1: stabilising constant added to numerator and denominator.

    Returns:
        Scalar tensor ``(2*mu_x*mu_y + C1) / (mu_x^2 + mu_y^2 + C1)``.
    """
    mu_x, mu_y = torch.mean(x), torch.mean(y)
    top = 2 * mu_x * mu_y + C1
    bottom = mu_x ** 2 + mu_y ** 2 + C1
    return top / bottom
#contrast computation
def c_x_y(x,y,C2):
    """Contrast comparison term of SSIM.

    Uses the unbiased (n-1) standard deviation over *all* elements of each
    tensor (no windowing).

    Args:
        x, y: tensors to compare.
        C2: stabilising constant added to numerator and denominator.

    Returns:
        Scalar tensor ``(2*sd_x*sd_y + C2) / (sd_x^2 + sd_y^2 + C2)``.
    """
    sd_x, sd_y = torch.std(x, unbiased=True), torch.std(y, unbiased=True)
    top = 2 * sd_x * sd_y + C2
    bottom = sd_x ** 2 + sd_y ** 2 + C2
    return top / bottom
#structure of the image computation
def s_x_y(x,y,C3):
    """Structure comparison term of SSIM.

    Computes ``(cov(x, y) + C3) / (sd_x * sd_y + C3)`` using statistics taken
    over *all* elements of each tensor (no windowing). Both the covariance and
    the standard deviations use the unbiased ``n - 1`` normalisation.

    Args:
        x, y: tensors with the same number of elements.
        C3: stabilising constant added to numerator and denominator.

    Returns:
        Scalar tensor; roughly the correlation coefficient of ``x`` and ``y``
        when ``C3`` is small (so it is negative for anti-correlated inputs).
    """
    mean_x = torch.mean(x)
    mean_y = torch.mean(y)
    std_x = torch.std(x,unbiased=True)
    std_y = torch.std(y,unbiased=True)
    x_ = x - mean_x
    # Centre y with its own mean. Previously x was centred a second time
    # (`x - mean_y`), which turned the covariance into a shifted variance of x.
    y_ = y - mean_y
    sigma_x_y = torch.sum(x_ * y_) / (torch.numel(x) - 1)
    numerator = sigma_x_y + C3
    denominator = (std_x*std_y) + C3
    return (numerator/denominator)

#computing SSIM
def compute_SSIM(image_1,image_2):
    """Combine the luminance, contrast and structure terms into one SSIM score.

    Each term comes from `l_x_y`, `c_x_y` and `s_x_y` with the same
    stabilising constant (1e-5), is raised to its exponent (all 1 here), and
    the three results are multiplied together.

    Args:
        image_1, image_2: tensors to compare.

    Returns:
        Scalar tensor ``l * c * s``.
    """
    stabiliser = 0.00001
    alpha, beta, gamma = 1, 1, 1  # exponents of the l / c / s terms

    luminance = l_x_y(image_1, image_2, stabiliser) ** alpha
    contrast = c_x_y(image_1, image_2, stabiliser) ** beta
    structure = s_x_y(image_1, image_2, stabiliser) ** gamma
    return luminance * contrast * structure

In [8]:
# test

import os
import argparse
import torch
import time

#from model import XLSR
#from dataset import create_dataloader
#from metric import cal_psnr
#from visualization import save_res


def test(model, dataloader, device, txt_path):
    """Run stereo inference over ``dataloader`` and log per-image metrics.

    For every batch the model is called with ``is_training=0`` and is expected
    to return ``(HR_pred_l, HR_pred_r)``. PSNR (via ``cal_psnr``) and SSIM (via
    ``compute_SSIM``) are computed on the left view only. Each result line is
    printed and appended to ``txt_path``; a final line holds the averages.

    Args:
        model: network called as ``model(LR_left, LR_right, is_training=0)``.
        dataloader: iterable yielding
            ``(LR_l, LR_r, HR_l, HR_r, names_l, names_r)`` batches.
        device: device the inputs are moved to (string, int or ``torch.device``).
        txt_path: text file that the log lines are appended to.

    Returns:
        ``(pred_list_l, pred_list_r, name_list_l, name_list_r)`` — the
        predictions per batch and the accumulated image names for both views.
    """
    pred_list_l = []
    pred_list_r = []
    name_list_l = []
    name_list_r = []
    avg_psnr = 0.
    avg_time = 0.
    avg_ssim = 0.
    num_batches = 0

    # Normalise the device first: comparing `device != 'cpu'` directly is wrong
    # for torch.device('cpu') objects and would call cuda.synchronize() on CPU.
    use_cuda = torch.device(device).type == 'cuda'

    with torch.no_grad():
        print("Start the inference ...")
        for LR_img_l, LR_img_r, HR_img_l, HR_img_r, img_nam_l , img_nam_r in dataloader:
            LR_img_l, LR_img_r = LR_img_l.to(device).float(), LR_img_r.to(device).float()
            HR_img_l, HR_img_r = HR_img_l.to(device).float(), HR_img_r.to(device).float()

            # Synchronise around the forward pass so the timer measures the
            # actual GPU work rather than just the asynchronous kernel launch.
            if use_cuda:
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            HR_pred_l, HR_pred_r = model(LR_img_l, LR_img_r, is_training=0)
            if use_cuda:
                torch.cuda.synchronize()
            t1 = time.perf_counter()

            # Metrics are reported for the left view only.
            psnr = cal_psnr(HR_pred_l, HR_img_l).item()
            inference_time = t1 - t0
            ssim_l = compute_SSIM(HR_pred_l, HR_img_l).item()

            line = f"PSRN on {img_nam_l} and {img_nam_r} : {psnr:.3f}, inference time: {1000*inference_time:.2f}ms, SSIM = {ssim_l:.1f}"
            print(line)
            # Append per image so partial results survive an interrupted run.
            with open(txt_path, 'a') as f:
                f.write(line + '\n')

            avg_psnr += psnr
            avg_time += inference_time
            avg_ssim += ssim_l
            num_batches += 1
            pred_list_l.append(HR_pred_l)
            pred_list_r.append(HR_pred_r)
            name_list_l += img_nam_l
            name_list_r += img_nam_r

    if num_batches == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print("No batches in dataloader; nothing to evaluate.")
        return pred_list_l,pred_list_r, name_list_l, name_list_r

    # Average over the batches actually processed from the `dataloader`
    # argument (previously this read the notebook-global `test_dataloader`).
    avg_psnr /= num_batches
    avg_time /= num_batches
    avg_ssim /= num_batches
    summary = f"Average PSRN: {avg_psnr:.3f}, average inference time: {1000*avg_time:.2f}ms, average SSIM : {avg_ssim:.1f}"
    print(summary)
    with open(txt_path, 'a') as f:
        f.write(summary)
    return pred_list_l,pred_list_r, name_list_l, name_list_r

'''
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--save-dir', type=str, default='exp/OneCyclicLR_exp0', help='hyperparameters path')
    parser.add_argument('--SR-rate', type=int, default=3, help='the scale rate for SR')
    parser.add_argument('--model', type=str, default='', help='the path to the saved model')
    parser.add_argument('--device', type=str, default='cpu', help='gpu id or "cpu"')
    opt = parser.parse_args()
'''
# ---- Test configuration ----
save_dir = '/content/drive/MyDrive/phd/wk1/phase1_baseline/Output_UP_VM_CR_b4_loss_anchor/x2/flkr_500'
SR_rate = 2
# model_path = '/content/drive/MyDrive/phd/wk1_in_smriti_07_23_gmail/try2/output5_2/best.pt'
model_path = '/content/drive/MyDrive/phd/wk1/phase1_baseline/Output_UP_VM_CR_b4_loss_anchor/x2/output1/best.pt'
# Prefer the GPU but fall back to the CPU so the cell also runs on CPU-only runtimes.
device1 = 'cuda' if torch.cuda.is_available() else 'cpu'
img_name = 1
# Mirrors the old argparse options (save-dir, SR-rate, model path, device).
# The third entry used to be `model`, which is not defined yet at this point
# on a fresh run of this cell; the checkpoint path is what belongs here.
opt = [save_dir, SR_rate, model_path, device1]

os.makedirs(save_dir, exist_ok=True)

# cuDNN configuration (only relevant when running on the GPU)
if device1 != 'cpu':
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

# txt file to record the test results; start from a clean file on every run
txt_path = os.path.join(save_dir, f"test_res_{img_name}.txt")
if os.path.exists(txt_path):
        os.remove(txt_path)

# folder to save the predicted HR images
test_folder = os.path.join(save_dir, f"test_res{img_name}")
os.makedirs(test_folder, exist_ok=True)

device = device1
model = XLSR_stereo_AM(SR_rate)

# Load the checkpoint: use `model_path` when it points to an existing .pt file,
# otherwise fall back to `best.pt` inside `save_dir`. Weights are loaded onto
# the CPU first and moved to the target device together with the model.
if model_path.endswith('.pt') and os.path.exists(model_path):
        checkpoint_path = model_path
else:
        checkpoint_path = os.path.join(save_dir, 'best.pt')
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.to(device)
model.eval()

test_dataloader = create_dataloader('test', SR_rate, False, batch_size=1, shuffle=False, num_workers=1)
print(len(test_dataloader))

# evaluate
pred_HR_l,pred_HR_r, img_names_l,img_names_r = test(model, test_dataloader, device, txt_path)

print("Saving the predicted HR images")
save_res(pred_HR_l,pred_HR_r, img_names_l,img_names_r, test_folder)
print(f"Testing is done!, predicted HR images are saved in {test_folder}")

112
Start the inference ...
PSRN on ['0001_L.png'] and ['0001_R.png'] : 28.921, inference time: 2806.44ms, SSIM = 1.0
PSRN on ['0002_L.png'] and ['0002_R.png'] : 33.248, inference time: 508.30ms, SSIM = 1.0
PSRN on ['0003_L.png'] and ['0003_R.png'] : 25.835, inference time: 410.85ms, SSIM = 0.9
PSRN on ['0004_L.png'] and ['0004_R.png'] : 25.495, inference time: 1536.78ms, SSIM = 1.0
PSRN on ['0005_L.png'] and ['0005_R.png'] : 33.307, inference time: 448.05ms, SSIM = 1.0
PSRN on ['0006_L.png'] and ['0006_R.png'] : 31.225, inference time: 307.54ms, SSIM = 1.0
PSRN on ['0007_L.png'] and ['0007_R.png'] : 34.040, inference time: 63.39ms, SSIM = 1.0
PSRN on ['0008_L.png'] and ['0008_R.png'] : 33.418, inference time: 84.64ms, SSIM = 1.0
PSRN on ['0009_L.png'] and ['0009_R.png'] : 22.566, inference time: 356.15ms, SSIM = 0.9
PSRN on ['0010_L.png'] and ['0010_R.png'] : 23.040, inference time: 500.86ms, SSIM = 1.0
PSRN on ['0011_L.png'] and ['0011_R.png'] : 23.579, inference time: 96.24ms, SSIM 