In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# dataset

import cv2
import os
import torch
import random
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def create_dataloader(split, SR_rate, augment, batch_size=1, shuffle=False, num_workers=1, pin_memory=True):
    dataset = dataread(split, SR_rate, augment)
    dataloader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory)
    # print('check 0', dataloader) #me
    return dataloader

def random_crop(LR_img, HR_img, crop_size, SR_rate):
    # check the shape
    # print('LR_img.shape', LR_img.shape) #me
    # print('HR_img.shape', HR_img.shape) #me
    LR_h, LR_w = LR_img.shape[:2]
    HR_h, HR_w = HR_img.shape[:2]
    # print('np.round_(LR_h * SR_rate)=', np.round_(LR_h * SR_rate), ', HR_h=', HR_h)
    # assert np.round_(LR_h * SR_rate) == HR_h and np.round_(LR_w * SR_rate) == HR_w, 'SR_rate is not correct for LR and HR image'
    # check the crop size
    new_LR_h, new_LR_w = crop_size
    assert new_LR_h <= LR_h and new_LR_w <= LR_w, 'crop_size is too large'

    y1 = random.randint(0, LR_h - new_LR_h)
    x1 = random.randint(0, LR_w - new_LR_w)

    LR_crop = LR_img[y1:y1 + new_LR_h, x1:x1 + new_LR_w, :]
    HR_crop = HR_img[SR_rate * y1:SR_rate * (y1 + new_LR_h), SR_rate * x1:SR_rate * (x1 + new_LR_w), :]

    return LR_crop, HR_crop

class dataread(Dataset):  # for training/testing
    def __init__(self, split, SR_rate, augment=False):

        self.split = split
        self.SR_rate = SR_rate
        self.augment = augment
        self.intensity_list = [1.0, 0.7, 0.5]
        self.crop_size = [32, 32]

        # data split
        if split == 'train':
           #self.LR_dir = os.path.join('/content/drive/MyDrive/ColabNotebooks/Research_works/dataset/DIV2K/DIV2K_train_LR_bicubic/', 'X'+str(SR_rate))
            self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/Train/left')
            self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/Train/right')
            self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/train/left'
            self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/train/right'
            self.img_names_l = sorted(os.listdir(self.LR_dir_l))
            self.img_names_r = sorted(os.listdir(self.LR_dir_r))
            # print('Sorted HR_l images upto 5: ', self.img_names_l)
            # print('Sorted HR_r images upto 5: ', self.img_names_r)
            # print('LR_dir_l: ', self.LR_dir_l)
            # print('LR_dir_r: ', self.LR_dir_r)
            # print('HR_dir_l: ',self.HR_dir_l)
            # print('HR_dir_r: ',self.HR_dir_r)

        elif split == 'valid':
            #self.LR_dir = os.path.join('/content/drive/MyDrive/ColabNotebooks/Research_works/dataset/DIV2K/DIV2K_train_LR_bicubic/', 'X'+str(SR_rate))
            self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/val/left')
            self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/val/right')
            self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/val/left'
            self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/val/right'
            self.img_names_l = sorted(os.listdir(self.HR_dir_l))
            self.img_names_r = sorted(os.listdir(self.HR_dir_r))

        elif split == 'test':
            # self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/test/left')
            # self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/LR/X2/separated/test/right')
            # self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/test/left'
            # self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/Flickr_modifd/HR/size_corrected/hr_x2/test/right'

            self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/LR/x2/left')
            self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/LR/x2/right')
            self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/HR_size_crctd/x2/left'
            self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/kitti_mix/HR_size_crctd/x2/right'

            # self.LR_dir_l = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Middleburry_mix/LR/x2/left')
            # self.LR_dir_r = os.path.join('/content/drive/MyDrive/phd/StereoSR/datasets/Middleburry_mix/LR/x2/right')
            # self.HR_dir_l = '/content/drive/MyDrive/phd/StereoSR/datasets/Middleburry_mix/HR_size_corct/x2/left'
            # self.HR_dir_r = '/content/drive/MyDrive/phd/StereoSR/datasets/Middleburry_mix/HR_size_corct/x2/right'

            self.img_names_l = sorted(os.listdir(self.HR_dir_l))[:19]
            self.img_names_r = sorted(os.listdir(self.HR_dir_r))[:19]
        else:
            raise NameError('data split must be "train", "valid" or "test". ')



    def __len__(self):


        if self.split == 'train':
           return len(self.img_names_l)
        else:
          return len(self.img_names_r)



    def __getitem__(self, index):

        if self.split == 'train':
            #LR_img = cv2.imread(os.path.join(self.LR_dir, self.img_names[index][:-4]+'x'+str(self.SR_rate)+'.png')) / 255.  #me

            #load  LR stereo
            LR_path_l = os.path.join(self.LR_dir_l, self.img_names_l[index])
            LR_path_r = os.path.join(self.LR_dir_r, self.img_names_r[index])
            # print('LR_path_l: ',LR_path_l)
            # print('LR_path_r: ',LR_path_r)
            # print('type_l: ', type(LR_path_l))
            # print('type_r: ', type(LR_path_r))
            LR_img1_l = cv2.imread(LR_path_l)  #me
            LR_img1_r = cv2.imread(LR_path_r)  #me
            # print(type(LR_img1_l))
            LR_img_l = LR_img1_l/255
            LR_img_r = LR_img1_r/255

            #load HR stereo
            HR_img_l = cv2.imread(os.path.join(self.HR_dir_l, self.img_names_l[index])) / 255.
            HR_img_r = cv2.imread(os.path.join(self.HR_dir_r, self.img_names_r[index])) / 255

            if self.augment:
                # random crop
                LR_img_l, HR_img_l = random_crop(LR_img_l, HR_img_l, self.crop_size, self.SR_rate)
                LR_img_r, HR_img_r = random_crop(LR_img_r, HR_img_r, self.crop_size, self.SR_rate)

                # geometric transformations
                if random.random() < 0.5: # hflip
                    LR_img_l, LR_img_r = LR_img_l[:, ::-1, :], LR_img_r[:, ::-1, :]
                    HR_img_l, HR_img_r = HR_img_l[:, ::-1, :], HR_img_r[:, ::-1, :]
                if random.random() < 0.5: # vflip
                    LR_img_l, LR_img_r = LR_img_l[::-1, :, :], LR_img_r[::-1, :, :]
                    HR_img_l, HR_img_r = HR_img_l[::-1, :, :], HR_img_r[::-1, :, :]
                if random.random() < 0.5: # rot90
                    LR_img_l, LR_img_r = LR_img_l.transpose(1, 0, 2), LR_img_r.transpose(1, 0, 2)
                    HR_img_l, HR_img_r = HR_img_l.transpose(1, 0, 2), HR_img_r.transpose(1, 0, 2)

                # intensity scale
                intensity_scale = random.choice(self.intensity_list)
                LR_img_l *= intensity_scale
                LR_img_r *= intensity_scale
                HR_img_l *= intensity_scale
                HR_img_r *= intensity_scale


        else:
            #LR_img = cv2.imread(os.path.join(self.LR_dir, self.img_names[index][:-4]+'x'+str(self.SR_rate)+'.png')) / 255. #me
            LR_img_l = cv2.imread(os.path.join(self.LR_dir_l, self.img_names_l[index])) / 255.  #me
            LR_img_r = cv2.imread(os.path.join(self.LR_dir_r, self.img_names_r[index])) / 255.
            HR_img_l = cv2.imread(os.path.join(self.HR_dir_l, self.img_names_l[index])) / 255.
            HR_img_r = cv2.imread(os.path.join(self.HR_dir_r, self.img_names_r[index])) / 255.

        # Convert
        LR_img_l = np.ascontiguousarray(LR_img_l.transpose(2, 0, 1)) # HWC => CHW
        LR_img_r = np.ascontiguousarray(LR_img_r.transpose(2, 0, 1))
        HR_img_l = np.ascontiguousarray(HR_img_l.transpose(2, 0, 1))
        HR_img_r = np.ascontiguousarray(HR_img_r.transpose(2, 0, 1))

        #return torch.from_numpy(LR_img_l), torch.from_numpy(LR_img_r), torch.from_numpy(HR_img_l), torch.from_numpy(HR_img_r), self.img_names_l[index],self.img_names_r[index]
        return torch.from_numpy(LR_img_l), torch.from_numpy(LR_img_r), torch.from_numpy(HR_img_l), torch.from_numpy(HR_img_r), self.img_names_l[index],self.img_names_r[index]

if __name__ == '__main__':
    os.makedirs('/content/drive/MyDrive/phd/wk1/phase1_baseline/output_jnl/x2/test_dataloader', exist_ok=True)
    train_dataloader = create_dataloader('train',2, False, batch_size=1, shuffle=False, num_workers=1)
    print(f"len(train_dataloader): {len(train_dataloader)}")
    #LR_img, HR_img, img_names = next(iter(train_dataloader))
    iterator = iter(train_dataloader) #me
    LR_img_l, LR_img_r, HR_img_l, HR_img_r, img_names_l, img_names_r = next(iterator)
    # #LR_img = next(iterator)
    # print(f"LR_img shape: {LR_img_l.size()}")
    # print(f"LR_img shape: {LR_img_r.size()}")
    # print(f"HR_img shape: {HR_img_l.size()}")
    # print(f"HR_img shape: {HR_img_r.size()}")
    # print('left image names in iterator: ', img_names_l, len(img_names_l))
    # print('right image names in iterator: ', img_names_r, len(img_names_r))
    LR_img_l = LR_img_l[0].numpy().transpose(1, 2, 0)
    LR_img_r = LR_img_r[0].numpy().transpose(1, 2, 0)
    HR_img_l = HR_img_l[0].numpy().transpose(1, 2, 0)
    HR_img_r = HR_img_r[0].numpy().transpose(1, 2, 0)
    cv2.imwrite('./test_dataloader/LR_img_l.png', np.uint8(LR_img_l*255))
    cv2.imwrite('./test_dataloader/LR_img_r.png', np.uint8(LR_img_r*255))
    cv2.imwrite('./test_dataloader/HR_img_l.png', np.uint8(HR_img_l*255))
    cv2.imwrite('./test_dataloader/HR_img_r.png', np.uint8(HR_img_r*255))

len(train_dataloader): 800


In [None]:
# visualization

import os
import cv2

import torch
import numpy as np
import matplotlib.pyplot as plt

def save_res(pred_HR_l, pred_HR_r, img_names_l, img_names_r, folder):
    '''

    Parameters
    ----------
    preds : List
        each pred has a shape of 1x3xHxW. BGR
    img_names : List

    Returns
    -------
    None.

    '''
    # for pred_HR_l, pred_HR_r, img_names_l, img_names_r in zip(pred_HR_l, pred_HR_r, img_names_l, img_names_r):
    #     pred_img_l = pred[0].cpu().numpy().transpose(1,2,0)
    #     pred_img_r = pred[0].cpu().numpy().transpose(1,2,0)
    #     cv2.imwrite(os.path.join(save_dir, img_names_l), np.uint8(pred_img_l*255))
    #     cv2.imwrite(os.path.join(save_dir, img_names_r), np.uint8(pred_img_r*255))
    # return

    for pred_HR_l, pred_HR_r, img_name_l, img_name_r in zip(pred_HR_l, pred_HR_r, img_names_l, img_names_r):
        pred_img_l = pred_HR_l[0].detach().cpu().numpy().transpose(1, 2, 0)
        pred_img_r = pred_HR_r[0].detach().cpu().numpy().transpose(1, 2, 0)
        cv2.imwrite(os.path.join(folder, img_name_l), np.uint8(pred_img_l * 255))
        cv2.imwrite(os.path.join(folder, img_name_r), np.uint8(pred_img_r * 255))
    return


def visualize_training(save_dir):
    txt_res = os.path.join(save_dir, 'results.txt')
    with open(txt_res, 'r') as f:
        info = f.readlines()
    epoch, lr, train_loss, valid_loss, psnr = [], [], [], [], []

    for line in info[:-1]:
        line = line.strip().split('|')
        epoch.append(int(line[0].split(':')[1]))
        lr.append(float(line[1].split(':')[1]))
        train_loss.append(float(line[2].split(':')[1]))
        valid_loss.append(float(line[3].split(':')[1]))
        psnr.append(float(line[4].split(':')[1]))

    fig = plt.figure(figsize=(16, 8), dpi=400)
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(223)
    ax4 = fig.add_subplot(224)
    ax1.title.set_text('Training loss')
    ax2.title.set_text('Validation loss')
    ax3.title.set_text('PSNR (validation)')
    ax4.title.set_text('learning rate')
    ax1.plot(epoch, train_loss)
    ax2.plot(epoch, valid_loss)
    ax3.plot(epoch, psnr)
    ax4.plot(epoch, lr)
    plt.savefig(os.path.join(save_dir, 'results.png'), dpi=200)
    return

In [None]:
# metric


import torch

def cal_psnr(x, y):
    '''
    Parameters
    ----------
    x, y are two tensors has the same shape (1, C, H, W)

    Returns
    -------
    score : PSNR.
    '''

    mse = torch.mean((x - y) ** 2, dim=[1, 2, 3])
    score = - 10 * torch.log10(mse)
    return score

In [None]:
# loss

import torch
from torch import nn

class L1Loss(object):
    def __call__(self, input, target):
        return torch.abs(input - target).mean()

class CharbonnierLoss(nn.Module):

    def __init__(self, eps=0.01):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt):
        # print(len(pred),'&', len(gt))
        loss = torch.sqrt((pred - gt)**2 + self.eps).mean()

        return loss.mean()

In [15]:
########################## new model #############################
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.quantization import QuantStub, DeQuantStub
from torchsummary import summary
from skimage import morphology

class ClippedReLU(nn.Module):
    def __init__(self):
        super(ClippedReLU, self).__init__()

    def forward(self, x):
        return x.clamp(min=0., max=1.)

class ConvRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super(ConvRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # print("input tensor shape @ clipped relu =", x.shape)
        x = self.conv(x)
        x = self.relu(x)
        return x

class Gblock(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(Gblock, self).__init__()
        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        x = self.conv1(x)
        return x

class ResB(nn.Module):
    def __init__(self, channels):
        super(ResB, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
        )
    def __call__(self,x):
        out = self.body(x)
        return out + x

class PAM(nn.Module):
    def __init__(self, channels):
        super(PAM, self).__init__()
        self.b1 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.b2 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.b3 = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.softmax = nn.Softmax(-1)
        self.rb = ResB(64)
        self.fusion = nn.Conv2d(channels * 2 + 1, channels, 1, 1, 0, bias=True)
    def __call__(self, x_left, x_right, is_training):
        b, c, h, w = x_left.shape
        buffer_left = self.rb(x_left)
        buffer_right = self.rb(x_right)

        ### M_{right_to_left}
        Q = self.b1(buffer_left).permute(0, 2, 3, 1)                                                # B * H * W * C
        S = self.b2(buffer_right).permute(0, 2, 1, 3)                                               # B * H * C * W
        # print("Q shape:", Q.shape)
        # print("S shape:", S.shape)
        score = torch.bmm(Q.contiguous().view(-1, w, c),
                          S.contiguous().view(-1, c, w))                                            # (B*H) * W * W
        M_right_to_left = self.softmax(score)

        ### M_{left_to_right}
        Q = self.b1(buffer_right).permute(0, 2, 3, 1)                                               # B * H * W * C
        S = self.b2(buffer_left).permute(0, 2, 1, 3)                                                # B * H * C * W
        '''the input tensors Q and S may represent the hidden states of two
        different encoders. '''
        score = torch.bmm(Q.contiguous().view(-1, w, c),
                          S.contiguous().view(-1, c, w))                                            # (B*H) * W * W
        '''what Score do? The torch.bmm() operation can be used to calculate the
        similarity between the two hidden states (Q,S), which can be used to predict
        the relationship between the two input sequences.'''
        M_left_to_right = self.softmax(score)
        print('attention map:',M_left_to_right.shape)

        ### valid masks
        V_left_to_right = torch.sum(M_left_to_right.detach(), 1) > 0.1
        V_left_to_right = V_left_to_right.view(b, 1, h, w)                                          #  B * 1 * H * W
        V_left_to_right = morphologic_process(V_left_to_right)
        print(V_left_to_right.shape)
        if is_training==1:
            V_right_to_left = torch.sum(M_right_to_left.detach(), 1) > 0.1
            V_right_to_left = V_right_to_left.view(b, 1, h, w)                                      #  B * 1 * H * W
            V_right_to_left = morphologic_process(V_right_to_left)

            M_left_right_left = torch.bmm(M_right_to_left, M_left_to_right)
            M_right_left_right = torch.bmm(M_left_to_right, M_right_to_left)

        ### fusion
        buffer = self.b3(x_right).permute(0,2,3,1).contiguous().view(-1, w, c)                      # (B*H) * W * C
        buffer = torch.bmm(M_right_to_left, buffer).contiguous().view(b, h, w, c).permute(0,3,1,2)  #  B * C * H * W
        out = self.fusion(torch.cat((buffer, x_left, V_left_to_right), 1))
        '''This Conv2d layer in 'out' is  used to fuse the original input tensor
         with its reflection and a channel of 1s. This helps the network learn
         to distinguish between foreground and background pixels in the image.'''

        ## output
        if is_training == 1:
            return out, \
               (M_right_to_left.contiguous().view(b, h, w, w), M_left_to_right.contiguous().view(b, h, w, w)), \
               (M_left_right_left.view(b,h,w,w), M_right_left_right.view(b,h,w,w)), \
               (V_left_to_right, V_right_to_left)
        if is_training == 0:
            return out

def morphologic_process(mask):
    device = mask.device
    b,_,_,_ = mask.shape
    # mask = 1-mask
    mask = ~mask
    mask_np = mask.cpu().numpy().astype(bool)
    mask_np = morphology.remove_small_objects(mask_np, 20, 2)
    mask_np = morphology.remove_small_holes(mask_np, 10, 2)
    for idx in range(b):
        buffer = np.pad(mask_np[idx,0,:,:],((3,3),(3,3)),'constant')
        buffer = morphology.binary_closing(buffer, morphology.disk(3))
        mask_np[idx,0,:,:] = buffer[3:-3,3:-3]
    mask_np = 1-mask_np
    mask_np = mask_np.astype(float)

    return torch.from_numpy(mask_np).float().to(device)

class XLSR_stereo_AM(nn.Module):
    def __init__(self, SR_rate):
        super(XLSR_stereo_AM, self).__init__()
        ### feature extraction
        # self.init_feature  = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3,padding=0),
        #                                     nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=0),
        #                                     ConvRelu(in_channels=16, out_channels=32, kernel_size=1))
        self.init_feature  = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3,padding=1),
                                            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1),
                                           nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
                                            ConvRelu(in_channels=32, out_channels=64, kernel_size=1))

        ### paralax attention
        self.pam = PAM(64)

        ### Gblock
        self.Gblocks = nn.Sequential(Gblock(64, 64, 4),Gblock(64, 64, 4),
                                     Gblock(64, 64, 4))

        ### paralax attention
        # self.conv1 = ConvRelu(in_channels=3, out_channels=32, kernel_size=3)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
                                   nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
                                   nn.LeakyReLU(0.1, inplace=True))

        self.pam = PAM(64)

        #### upscaling

        # self.upscale = nn.Sequential(
        #     nn.Conv2d(in_channels=32, out_channels=3*SR_rate**2, kernel_size=3, padding=1),
        #     nn.PixelShuffle(SR_rate))
        self.upscale = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64*SR_rate**2, kernel_size=3, padding=1),
            nn.PixelShuffle(SR_rate),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1),  # No need for bias=False here
            nn.LeakyReLU(0.1, inplace=True),  # Add LeakyReLU activation after the last convolution
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),  # Add one more convolution with padding
            )


        self.clippedReLU = ClippedReLU()

        # weights initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu')
                _, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                std = math.sqrt(2/fan_out*0.1)
                torch.nn.init.normal_(m.weight.data, mean=0, std=std)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.01)


    def forward(self, x_left, x_right, is_training=1):
        ### feature extraction
        buffer_left = self.init_feature(x_left)
        buffer_right = self.init_feature(x_right)

        # print("buffer_left shape:", buffer_left.shape)

        if is_training == 1:
            ### parallax attention
            buffer, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
            (V_left_to_right, V_right_to_left) = self.pam(buffer_left, buffer_right, is_training)

            # print("PAM 1st PAM shape:", buffer.shape)

            ##### G-block
            res1 = self.Gblocks(buffer)

            # print("o/p of Gblock shape:", res1.shape)

            res2_l = self.conv1(x_left)
            res2_r = self.conv1(x_right)

            # print("res2_l shape:", res2_l.shape)
            # print("res2_r shape:", res2_r.shape)

            ### parallax attention_left
            buffer_l, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
            (V_left_to_right, V_right_to_left) = self.pam(res1, res2_l, is_training)

            # print("buffer_l shape after 2nd PAM left:", buffer_l.shape)
             ### parallax attention_right
            buffer_r, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
            (V_left_to_right, V_right_to_left) = self.pam(res1, res2_r, is_training)
            ### upscaling_left
            buffer_l = self.clippedReLU(buffer_l)
            # print("buffer_l shape after 2nd PAM_right and clippedRelu:", buffer_l.shape)
            out_l = self.upscale(buffer_l)
            # print("out_l shape:", out_l.shape)
            ### upscaling_right
            buffer_r = self.clippedReLU(buffer_r)
            out_r = self.upscale(buffer_r)

            return out_l, out_r, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
                   (V_left_to_right, V_right_to_left)
        if is_training == 0:
            ### parallax attention
            buffer = self.pam(buffer_left, buffer_right, is_training)
            #####G-block
            res1 = self.Gblocks(buffer)
            res2_l = self.conv1(x_left)
            res2_r = self.conv1(x_right)
            ### parallax attention_left
            buffer_l = self.pam(res1, res2_l, is_training)
            ### parallax attention_right
            buffer_r = self.pam(res1, res2_r, is_training)
            ### upscaling_left
            buffer_l = self.clippedReLU(buffer_l)
            out_l = self.upscale(buffer_l)
            ### upscaling_right
            buffer_r = self.clippedReLU(buffer_r)
            out_r = self.upscale(buffer_r)

            return out_l, out_r

def print_model_summary(model):
    print("--------------- Model Summary ---------------")
    total_params = 0
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Module):
            print(f"{name:30} -> {str(module):}")
            total_params += sum(p.numel() for p in module.parameters())
    print(f"Total Trainable Parameters: {total_params}")
    print("----------------------------------------------")


if __name__ == '__main__':
    device = 'cpu'  # Change to 'cpu' if you don't have a GPU
    model = XLSR_stereo_AM(3).to(device)
    model.eval()

    # Create random left and right low-resolution stereo images (batch size = 1)
    left_image = torch.randn(1, 3, 32, 32).to(device)
    right_image = torch.randn(1, 3, 32, 32).to(device)

    # Forward pass through the network
    pred = model(left_image, right_image, is_training=1)

    print_model_summary(model)

    # Unpack the elements of the prediction tuple
    if len(pred) == 5:  # Assuming the tuple contains 5 elements as per your model definition
        HR_pred_l, HR_pred_r, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
         (V_left_to_right, V_right_to_left) = pred
    elif len(pred) == 2:
        HR_pred_l, HR_pred_r = pred



pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters =", pytorch_total_params)

# Print the model summary
# summary(model, input_size=[(3, 32, 32), (3, 32, 32)])
summary(model, input_size=[(3, 32, 32), (3, 32, 32)])


attention map: torch.Size([32, 32, 32])
torch.Size([1, 1, 32, 32])
attention map: torch.Size([32, 32, 32])
torch.Size([1, 1, 32, 32])
attention map: torch.Size([32, 32, 32])
torch.Size([1, 1, 32, 32])
--------------- Model Summary ---------------
                               -> XLSR_stereo_AM(
  (init_feature): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ConvRelu(
      (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (relu): ReLU(inplace=True)
    )
  )
  (pam): PAM(
    (b1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (b2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (b3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (softmax): Softmax(dim=-1)
    (rb): ResB(
      (body): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1,

In [None]:
# train.py

import os
import yaml
import argparse
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import time
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import pandas as pd


# flag = True # for storing checkpoints
# CP = True   # for storing checkpoints



def train(model, dataloader, criteria, device, optimizer, scheduler):
    loss_epoch = 0.
    criterion_L1 = L1Loss()
    for LR_img_l, LR_img_r,HR_img_l, HR_img_r, _ , _ in dataloader:

        optimizer.zero_grad()
        LR_img_l, LR_img_r = LR_img_l.to(device).float(), LR_img_r.to(device).float()
        HR_img_l, HR_img_r = HR_img_l.to(device).float(), HR_img_r.to(device).float()
        # print(len(LR_img_l), len(LR_img_r), len(HR_img_l), len(HR_img_r))  ##### check
        # HR_pred_l, HR_pred_r = model(LR_img_l, LR_img_r)
        HR_pred_l, HR_pred_r, (M_right_to_left, M_left_to_right), (M_left_right_left, M_right_left_right), \
         (V_left_to_right, V_right_to_left) = model(LR_img_l, LR_img_r, is_training = 1)

        b, c, h, w = LR_img_l.shape

        # print('predct_r & l=',HR_pred_r.shape, '&',HR_pred_l.shape)   ##### check/
        # print('gt_r & l=',HR_img_r.shape, '&',HR_img_l.shape)   ##### check
        loss_l = criteria(HR_pred_l, HR_img_l)
        # loss_r = criteria(HR_pred_r, HR_img_r)
        loss_SR = loss_l

        ### loss_smoothness
        loss_h = criterion_L1(M_right_to_left[:, :-1, :, :], M_right_to_left[:, 1:, :, :]) + \
                   criterion_L1(M_left_to_right[:, :-1, :, :], M_left_to_right[:, 1:, :, :])
        loss_w = criterion_L1(M_right_to_left[:, :, :-1, :-1], M_right_to_left[:, :, 1:, 1:]) + \
                   criterion_L1(M_left_to_right[:, :, :-1, :-1], M_left_to_right[:, :, 1:, 1:])
        loss_smooth = loss_w + loss_h

        ### loss_cycle
        Identity = Variable(torch.eye(w, w).repeat(b, h, 1, 1), requires_grad=False).to(device)
        loss_cycle = criterion_L1(M_left_right_left * V_left_to_right.permute(0, 2, 1, 3), Identity * V_left_to_right.permute(0, 2, 1, 3)) + \
                         criterion_L1(M_right_left_right * V_right_to_left.permute(0, 2, 1, 3), Identity * V_right_to_left.permute(0, 2, 1, 3))

        ### loss_photometric
        LR_right_warped = torch.bmm(M_right_to_left.contiguous().view(b*h,w,w), LR_img_r.permute(0,2,3,1).contiguous().view(b*h, w, c))
        LR_right_warped = LR_right_warped.view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
        LR_left_warped = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), LR_img_l.permute(0, 2, 3, 1).contiguous().view(b * h, w, c))
        LR_left_warped = LR_left_warped.view(b, h, w, c).contiguous().permute(0, 3, 1, 2)

        loss_photo = criterion_L1(LR_img_l * V_left_to_right, LR_right_warped * V_left_to_right) + \
                          criterion_L1(LR_img_r * V_right_to_left, LR_left_warped * V_right_to_left)

        ### losses
        loss = loss_SR + 0.005 * (loss_photo + loss_smooth + loss_cycle)


        # Backpropagation
        loss.backward()
        optimizer.step()
        scheduler.step()
        # print("start_step:", scheduler.last_epoch)   # me
        # print("end_step:", scheduler.end_step)       #me
        loss_epoch += loss.item()
    loss_epoch /= len(dataloader)
    lr_epoch = scheduler.get_last_lr()[0]
    scheduler.step() # me
    return loss_epoch, lr_epoch

def validation(model, dataloader, criteria, device):
    loss_epoch = 0.
    psnr_epoch = 0.
    pred_list_l = []
    pred_list_r = []
    name_list_l = []
    name_list_r = []
    with torch.no_grad():
        for LR_img_l, LR_img_r, HR_img_l, HR_img_r, img_name_l, img_name_r in dataloader:
            LR_img_l, LR_img_r = LR_img_l.to(device).float(), LR_img_r.to(device).float()
            HR_img_l , HR_img_r =  HR_img_l.to(device).float(), HR_img_r.to(device).float()

            HR_pred_l, HR_pred_r = model(LR_img_l, LR_img_r, is_training=0)
            SR_left = torch.clamp(HR_pred_l, 0, 1)

            # HR_pred_l, HR_pred_r = HR_pred[:, :3, :, :], HR_pred[:, 3:, :, :]
            loss_l = criteria(HR_pred_l, HR_img_l)
            # loss_r = criteria(HR_pred_r, HR_img_r)
            loss = loss_l
            loss_epoch += loss.item()
            psnr_epoch_l = cal_psnr(HR_pred_l, HR_img_l)
            # psnr_epoch_r = cal_psnr(HR_pred_r, HR_img_r)
            psnr_batch = psnr_epoch_l
            psnr_epoch += psnr_batch.item()
            pred_list_l.append(HR_pred_l)
            pred_list_r.append(HR_pred_r)
            name_list_l += img_name_l
            name_list_r += img_name_r
    loss_epoch /= len(dataloader)
    psnr_epoch /= len(dataloader)
    return loss_epoch, psnr_epoch, pred_list_l, pred_list_r, name_list_l, name_list_r

save_dir = "/content/drive/MyDrive/phd/wk1/phase1_baseline/output_jnl/x2/output4"
SR_rate = 2
pretrained_model = "/content/drive/MyDrive/phd/wk1/phase1_baseline/output_jnl/x2/output3/best.pt"
# epochs = 5000
epochs = 100
batch_size = 16
lr_max = 25e-04
pct_epoch = 30
augment = 'store_true'
workers = 2
device = 0
div_factor = 50.0
final_div_factor = 0.5

# if os.path.exists(save_dir):
#         print(f"Warning: {save_dir} exists, please delete it manually if it is useless.")


# txt file to record training process in EXCEL
# results_df = pd.DataFrame(columns=['Epoch', 'Learning Rate', 'Training Loss',
#                                    'Validation Loss', 'PSNR', 'Time'])
txt_path = os.path.join(save_dir, 'results.txt')
if os.path.exists(txt_path):
        os.remove(txt_path)

# folder to save the predicted HR image in the validation
valid_folder = os.path.join(save_dir, 'valid_res')
os.makedirs(valid_folder, exist_ok=True)

device = 'cuda'

model = XLSR_stereo_AM(SR_rate)
#model = PTSQ_quantized_model(opt[1])

# load pretrained model        ########################## uncomment if traing needed a restart
if pretrained_model.endswith('.pt') and os.path.exists(pretrained_model):
        # filter conv4 weights which have different conv channels
        model_dict = model.state_dict()
        pretrained_dict = torch.load(pretrained_model)
        filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(filtered_dict)
        model.load_state_dict(model_dict)
        print(f"Loaded pretrained model {pretrained_model}" )


model.to(device)

train_dataloader = create_dataloader('train', SR_rate, augment, batch_size,
                                     shuffle=True, num_workers=workers)
batch = next(iter(train_dataloader))


valid_dataloader = create_dataloader('valid', SR_rate, False, 1,
                                     shuffle=True, num_workers=1)

criteria = CharbonnierLoss()
optimizer = optim.Adam(model.parameters(), lr=lr_max/div_factor, betas=(0.9, 0.999), eps=1e-08)


scheduler = lr_scheduler.OneCycleLR(optimizer, lr_max, epochs=epochs, steps_per_epoch=len(train_dataloader), pct_start=pct_epoch/epochs, anneal_strategy='cos', \
                                        cycle_momentum=False, div_factor=div_factor, final_div_factor=final_div_factor)


epoch_start = 1
best_psnr = 0.
for idx in range(epoch_start, epochs+1):
        t0 = time.time()
        train_loss_epoch, lr_epoch = train(model, train_dataloader, criteria, device, optimizer, scheduler)
        t1 = time.time()
        valid_loss_epoch, psnr_epoch, pred_HR_l, pred_HR_r, img_names_l, img_names_r = validation(model, valid_dataloader, criteria, device)
        t2 = time.time()

        # result_row = {
        # 'Epoch': idx,
        # 'Learning Rate': lr_epoch,
        # 'Training Loss': train_loss_epoch,
        # 'Validation Loss': valid_loss_epoch,
        # 'PSNR': psnr_epoch,
        # 'Time': t2 - t0 }
        # results_df = results_df.concat(result_row, ignore_index=True)

        print(f"Epoch: {idx} | lr: {lr_epoch:.5f} | training loss: {train_loss_epoch:.5f} | validation loss: {valid_loss_epoch:.5f} | PSNR: {psnr_epoch:.3f} | Time: {t2-t0:.1f}")
        with open(txt_path, 'a') as f:
            f.write(f"Epoch: {idx} | lr: {lr_epoch:.5f} | training loss: {train_loss_epoch:.5f} | validation loss: {valid_loss_epoch:.5f} | PSNR: {psnr_epoch:.3f} | Time: {t2-t0:.1f}" +'\n')

        if psnr_epoch > best_psnr:
            best_psnr = psnr_epoch

            torch.save(model.state_dict(), os.path.join(save_dir, 'best.pt'))
            # save predicted HR image on validation set
            save_res(pred_HR_l,pred_HR_r, img_names_l,img_names_r, valid_folder)
        del pred_HR_l, pred_HR_r


#  # Save the DataFrame to an Excel file after the loop
# excel_file_path = os.path.join(save_dir, 'results.xlsx')
# results_df.to_excel(excel_file_path, index=False)

#  # visualize the training process
# visualize_training(save_dir)
# print(f"Training is finished, the best PSNR is {best_psnr:.3f}")
# with open(excel_file_path, 'a') as f:
#             f.write(f"Training is finished, the best PSNR is {best_psnr:.3f}")

 # visualize the training process
visualize_training(save_dir)
print(f"Training is finished, the best PSNR is {best_psnr:.3f}")
with open(txt_path, 'a') as f:
            f.write(f"Training is finished, the best PSNR is {best_psnr:.3f}")



Loaded pretrained model /content/drive/MyDrive/phd/wk1/phase1_baseline/output_jnl/x2/output3/best.pt
Epoch: 1 | lr: 0.00006 | training loss: 0.10734 | validation loss: 0.11235 | PSNR: 25.403 | Time: 1869.1
Epoch: 2 | lr: 0.00008 | training loss: 0.10732 | validation loss: 0.11229 | PSNR: 25.445 | Time: 114.7
Epoch: 3 | lr: 0.00011 | training loss: 0.10735 | validation loss: 0.11254 | PSNR: 25.335 | Time: 112.2
Epoch: 4 | lr: 0.00016 | training loss: 0.10732 | validation loss: 0.11295 | PSNR: 25.124 | Time: 109.0
Epoch: 5 | lr: 0.00022 | training loss: 0.10672 | validation loss: 0.11496 | PSNR: 24.381 | Time: 108.9
Epoch: 6 | lr: 0.00029 | training loss: 0.10775 | validation loss: 0.11346 | PSNR: 24.898 | Time: 109.6
Epoch: 7 | lr: 0.00038 | training loss: 0.10768 | validation loss: 0.11290 | PSNR: 25.134 | Time: 106.6
Epoch: 8 | lr: 0.00047 | training loss: 0.10810 | validation loss: 0.11277 | PSNR: 25.213 | Time: 109.9
Epoch: 9 | lr: 0.00057 | training loss: 0.10750 | validation loss:

In [None]:
# test

import os
import argparse
import torch
import time

#from model import XLSR
#from dataset import create_dataloader
#from metric import cal_psnr
#from visualization import save_res


def test(model, dataloader, device, txt_path):
    pred_list_l = []
    pred_list_r = []
    name_list_l = []
    name_list_r = []
    avg_psnr = 0.
    avg_time = 0.
    with torch.no_grad():
        # print("warm up ...")
        random_input_l = torch.randn(1, 3, 640, 360).to(device)
        random_input_r = torch.randn(1, 3, 640, 360).to(device)
        # # warm up
        # for _ in range(10):
        #     model(random_input_l, random_input_r, is_training=0 )

        # with torch.autograd.profiler.profile() as prof:
        #     model(random_input_l, random_input_r, is_training=0)
        # print(prof.key_averages().table(sort_by="self_cpu_time_total"))

        # print("Start testing the model speed on 640*360 input ...")
        # test_t = 0.
        # for idx in range(100):
        #     if device != 'cpu':
        #         torch.cuda.synchronize()
        #     t0 = time.perf_counter()
        #     model(random_input_l, random_input_r, is_training=0)
        #     if device != 'cpu':
        #         torch.cuda.synchronize()
        #     t1 = time.perf_counter()
        #     print(f"Inference #{idx}, inference time: {1000*(t1-t0):.2f}ms")
        #     test_t += t1 - t0
        # print(f"Average inference time on 640*360 input: {1000*test_t/100:.2f}ms")
        # with open(txt_path, 'a') as f:
        #     f.write(f"Average inference time on 640*360 input: {1000*test_t/100:.2f}ms" + '\n')

        print("Start the inference ...")
        for LR_img_l, LR_img_r, HR_img_l, HR_img_r, img_nam_l , img_nam_r in dataloader:
            LR_img_l, LR_img_r = LR_img_l.to(device).float(), LR_img_r.to(device).float()
            HR_img_l, HR_img_r = HR_img_l.to(device).float(), HR_img_r.to(device).float()
            if device != 'cpu':
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            # print('LR shape = ', LR_img_l.shape,LR_img_r.shape )
            HR_pred_l, HR_pred_r = model(LR_img_l, LR_img_r, is_training=0)
            # HR_pred_l, HR_pred_r = HR_pred[:, :3, :, :], HR_pred[:, 3:, :, :]
            # print('HR_Predictions shape = ', HR_pred_l.shape,HR_pred_r.shape )
            if device != 'cpu':
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            psnr_l = cal_psnr(HR_pred_l, HR_img_l).item()
            # psnr_r = cal_psnr(HR_pred_r, HR_img_r).item()
            psnr = psnr_l
            inference_time = t1 - t0
            print(f"PSRN on {img_nam_l} and {img_nam_r} : {psnr:.3f}, inference time: {1000*inference_time:.2f}ms")
            with open(txt_path, 'a') as f:
                f.write(f"PSRN on {img_nam_l} and {img_nam_r} : {psnr:.3f}, inference time: {1000*inference_time:.2f}ms" + '\n')
            avg_psnr += psnr
            avg_time += inference_time
            pred_list_l.append(HR_pred_l)
            pred_list_r.append(HR_pred_r)
            name_list_l += img_nam_l
            name_list_r += img_nam_r
    avg_psnr /= len(test_dataloader)
    avg_time /= len(test_dataloader)
    print(f"Average PSRN: {avg_psnr:.3f}, average inference time: {1000*avg_time:.2f}ms")
    with open(txt_path, 'a') as f:
        f.write(f"Average PSRN: {avg_psnr:.3f}, average inference time: {1000*avg_time:.2f}ms")
    return pred_list_l,pred_list_r, name_list_l, name_list_r

'''
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--save-dir', type=str, default='exp/OneCyclicLR_exp0', help='hyperparameters path')
    parser.add_argument('--SR-rate', type=int, default=3, help='the scale rate for SR')
    parser.add_argument('--model', type=str, default='', help='the path to the saved model')
    parser.add_argument('--device', type=str, default='cpu', help='gpu id or "cpu"')
    opt = parser.parse_args()
'''
save_dir = '/content/drive/MyDrive/phd/wk1/phase1_baseline/output_jnl/x2/kitti_final'
SR_rate = 2
# model_path = '/content/drive/MyDrive/phd/wk1_in_smriti_07_23_gmail/try2/output5_2/best.pt'
model_path = '/content/drive/MyDrive/phd/wk1/phase1_baseline/output_jnl/x2/output4/best.pt'
device1 = 'cuda'
opt = [save_dir, SR_rate, model, device1]

os.makedirs(save_dir, exist_ok=True)

# cuDnn configurations

if device1 != 'cpu':
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

# txt file to record training process
txt_path = os.path.join(save_dir, 'test_res.txt')
if os.path.exists(txt_path):
        os.remove(txt_path)

    # folder to save the predicted HR image in the validation
test_folder = os.path.join(save_dir, 'test_res')
os.makedirs(test_folder, exist_ok=True)

# device = 'cuda:' + str(device1) if device1 != 'cpu' else 'cpu'
device = device1
model = XLSR_stereo_AM(SR_rate)

# load pretrained model
if model_path.endswith('.pt') and os.path.exists(model_path):
        # model.load_state_dict(torch.load(model, map_location=device))
        model.load_state_dict(torch.load(model_path))
else:
        model.load_state_dict(torch.load(os.path.join(save_dir, 'best.pt'), map_location=device))
model.to(device)
model.eval()

test_dataloader = create_dataloader('test', SR_rate, False, batch_size=1, shuffle=False, num_workers=1)
print(len(test_dataloader))
 # evaluate
# pred_HR_l,pred_HR_r, img_names_l,img_names_r, valid_folder = test(model, test_dataloader, device, txt_path)
pred_HR_l,pred_HR_r, img_names_l,img_names_r = test(model, test_dataloader, device, txt_path)

print("Saving the predicted HR images")
save_res(pred_HR_l,pred_HR_r, img_names_l,img_names_r, test_folder)
print(f"Testing is done!, predicted HR images are saved in {test_folder}")

19
Start the inference ...
PSRN on ['000_L.png'] and ['000_R.png'] : 26.731, inference time: 817.64ms
PSRN on ['001_L.png'] and ['001_R.png'] : 29.520, inference time: 196.35ms
PSRN on ['002_L.png'] and ['002_R.png'] : 28.446, inference time: 738.49ms
PSRN on ['003_L.png'] and ['003_R.png'] : 30.178, inference time: 219.79ms
PSRN on ['004_L.png'] and ['004_R.png'] : 30.421, inference time: 728.56ms
PSRN on ['005_L.png'] and ['005_R.png'] : 29.493, inference time: 212.54ms
PSRN on ['006_L.png'] and ['006_R.png'] : 29.098, inference time: 215.13ms
PSRN on ['007_L.png'] and ['007_R.png'] : 26.588, inference time: 201.05ms
PSRN on ['008_L.png'] and ['008_R.png'] : 28.440, inference time: 234.97ms
PSRN on ['009_L.png'] and ['009_R.png'] : 29.021, inference time: 745.72ms
PSRN on ['010_L.png'] and ['010_R.png'] : 28.384, inference time: 197.23ms
PSRN on ['011_L.png'] and ['011_R.png'] : 27.690, inference time: 197.08ms
PSRN on ['012_L.png'] and ['012_R.png'] : 31.332, inference time: 209.69m