In [7]:
import os , glob
import numpy as np
from PIL import Image
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg16
import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [8]:


class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with a fixed scale factor.

    Args:
        scale_factor: multiplier applied to the spatial size(s) of the input.
        mode: interpolation algorithm passed to ``interpolate`` (e.g. ``'bilinear'``).
        align_corners: forwarded to ``interpolate`` for the interpolating modes
            (``linear``, ``bilinear``, ``bicubic``, ``trilinear``). Defaults to
            ``True``, matching the original hard-coded behaviour. PyTorch rejects
            any value other than ``None`` for the remaining modes (``nearest``,
            ``area``, ...), so for those it is not forwarded.
    """

    # Modes for which PyTorch accepts a non-None ``align_corners``.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode, align_corners=True):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # Passing align_corners=True with e.g. mode='nearest' raises a ValueError.
        align = self.align_corners if self.mode in self._ALIGNABLE_MODES else None
        return self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align)
    
    
def gridImage(imgs):
    """Place several image batches next to each other in one tensor.

    Every element of ``imgs`` is laid out as a single-column grid via
    ``make_grid(..., nrow=1, normalize=False)``. The grids are then joined, in
    order, along their last (width) dimension. With one element its grid is
    returned as-is. An empty ``imgs`` raises IndexError.
    """
    columns = [make_grid(imgs[0], nrow=1, normalize=False)]
    columns.extend(make_grid(imgs[k], nrow=1, normalize=False) for k in range(1, len(imgs)))

    if len(columns) == 1:
        return columns[0]
    return torch.cat(columns, -1)

In [9]:
"""  
arima loss type
loss_edge_corr = criterion_edge_corr(reduce(lambda x,y:x*y,GCLoss(*net(Input)), 0)
loss_edge_sum = criterion_edge_sum(gradB + gradR, imgs_grad_label)
loss_edge = (epoch - 1) * (loss_edge_sum + loss_edge_corr)   
"""


class GCLoss(nn.Module):
    """Fixed Sobel edge responses for a pair of 3-channel images ``B`` and ``R``.

    Each image goes through two non-trainable 3x3 convolutions (horizontal and
    vertical Sobel kernels, replicated over the 3 input channels so the channel
    responses are summed into one output channel). The two responses are added
    and passed through ``Tanhshrink`` (x - tanh(x)).

    ``forward(B, R)`` returns ``(gradout_B, gradout_R)``. Each has 1 channel and
    is 2 pixels smaller in each spatial dimension (3x3 kernel, no padding).
    """

    def __init__(self):
        super(GCLoss, self).__init__()

        # Sobel kernels are written out by hand, shaped (out=1, in=3, 3, 3).
        kernel_x = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view((1, 1, 3, 3)).repeat(1, 3, 1, 1)
        kernel_y = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view((1, 1, 3, 3)).repeat(1, 3, 1, 1)

        # Assignment order (gxb, gyb, gxr, gyr) fixes submodule / state_dict order.
        self.gxb = self._frozen_sobel(kernel_x)
        self.gyb = self._frozen_sobel(kernel_y)
        self.gxr = self._frozen_sobel(kernel_x)
        self.gyr = self._frozen_sobel(kernel_y)

        self.af_B = nn.Tanhshrink()
        self.af_R = nn.Tanhshrink()

    @staticmethod
    def _frozen_sobel(kernel):
        """Build a bias-free 3->1 channel 3x3 conv whose weight is ``kernel``, excluded from training."""
        conv = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=0, bias=False)
        conv.weight = nn.Parameter(kernel)
        conv.weight.requires_grad = False
        return conv

    def forward(self, B, R):
        edges_B = self.gxb(B) + self.gyb(B)
        edges_R = self.gxr(R) + self.gyr(R)
        return self.af_B(edges_B), self.af_R(edges_R)

In [10]:


# Per-channel normalization statistics taken from prior work (they are the
# standard ImageNet mean/std values).
mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))
# Border width in pixels: added symmetrically around each input image before
# inference and cropped off again afterwards.
padsize = 150

class testImageDataset(Dataset):
    """Input-only dataset over every file matching ``<root>/input/*.*`` (sorted by path).

    Each item is a dict:
      * ``"R"``: the image as RGB in [0, 1], mirror-padded by ``padsize`` pixels on
        every side, then passed through ``ToTensor`` and ``Normalize(mean, std)``
        -> tensor of shape (3, H + 2*padsize, W + 2*padsize).
      * ``"Name"``: the file name without its extension.
    """

    def __init__(self, root):

        self.tensor_setup = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        self.files = sorted(glob.glob(root + "/input/*.*"))

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` as a float32 (H, W, 3) RGB array scaled to [0, 1]."""
        # convert("RGB") gives every image mode (grayscale, palette, RGBA, ...) the
        # same 3-channel layout. Without it a grayscale file yields a 2-D array that
        # the 3-axis np.pad rejects, and a palette ('P') file yields palette indices
        # instead of colours. The context manager closes the underlying file.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"), 'f') / 255.

    def __getitem__(self, index):
        filePath = self.files[index % len(self.files)]
        R = self._read_rgb(filePath)
        # Mirror the borders; this margin is cropped off again after inference.
        R = np.pad(R, [(padsize, padsize), (padsize, padsize), (0, 0)], 'symmetric')

        # splitext keeps dots inside the stem ("a.b.png" -> "a.b"), so distinct
        # inputs cannot collapse onto the same output name.
        name = os.path.splitext(os.path.basename(filePath))[0]
        return {"R": self.tensor_setup(R), "Name": name}

    def __len__(self):
        return len(self.files)



class gtTestImageDataset(Dataset):
    """Paired dataset: ground truth ``<root>/gt/*.*`` with inputs ``<root>/input/*.*``.

    Both lists are sorted by path and paired by position, so the i-th
    ground-truth file must correspond to the i-th input file.

    Each item is a dict:
      * ``"R"``: the input image as RGB in [0, 1], mirror-padded by ``padsize``
        pixels on every side, then ``ToTensor`` + ``Normalize(mean, std)``.
      * ``"B"``: the ground-truth image as a float32 numpy array (H, W, 3) in
        [0, 1], neither padded nor normalized.
      * ``"Name"``: the ground-truth file name without its extension.

    Raises:
        ValueError: if ``gt/`` and ``input/`` hold a different number of matching files.
    """

    def __init__(self, root):

        self.tensor_setup = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        self.files_base = sorted(glob.glob(root + "/gt/*.*"))
        self.files_input = sorted(glob.glob(root + "/input/*.*"))
        # Pairing is positional: a count mismatch would silently pair the wrong
        # files (or hit an IndexError part-way through), so refuse it up front.
        if len(self.files_base) != len(self.files_input):
            raise ValueError(
                "gt/ has %d files but input/ has %d under %r"
                % (len(self.files_base), len(self.files_input), root)
            )

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` as a float32 (H, W, 3) RGB array scaled to [0, 1]."""
        # convert("RGB") normalizes grayscale/palette/RGBA/... files to 3 channels
        # (np.array on a palette image would return indices, not colours);
        # the context manager closes the underlying file.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"), 'f') / 255.

    def __getitem__(self, index):
        index = index % len(self.files_base)
        filePath = self.files_base[index]
        B = self._read_rgb(filePath)
        R = self._read_rgb(self.files_input[index])
        # Mirror the input borders; the ground truth stays at its original size.
        R = np.pad(R, [(padsize, padsize), (padsize, padsize), (0, 0)], 'symmetric')

        return {"R": self.tensor_setup(R),
                "B": B,
                # splitext keeps dots inside the stem ("a.b.png" -> "a.b").
                "Name": os.path.splitext(os.path.basename(filePath))[0]}

    def __len__(self):
        return len(self.files_base)

In [11]:


class ConvBlock(nn.Module):
    """Two stacked (3x3 conv, padding 1 -> BatchNorm2d -> activation) stages.

    Spatial size is preserved. Channels go ``in_channels -> middle_channels
    -> out_channels``.

    Args:
        in_channels: channels of the input tensor.
        middle_channels: channels after the first convolution.
        out_channels: channels after the second convolution.
        activation: module applied after each BatchNorm. The same instance is
            used at both positions in the block. When omitted (``None``), a
            fresh ``nn.LeakyReLU(negative_slope=0.01, inplace=True)`` is
            created for this block.
    """

    def __init__(self, in_channels, middle_channels, out_channels, activation=None):
        super(ConvBlock, self).__init__()
        if activation is None:
            # Built per call: a module used as a default argument is evaluated once
            # at definition time and would be shared by every ConvBlock ever created.
            activation = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1),
            nn.BatchNorm2d(middle_channels),
            activation,
            nn.Conv2d(middle_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            activation)

    def forward(self, x):
        return self.model(x)


class GCNet(nn.Module):
    """Nested U-Net (UNet++-style) encoder-decoder used as the image generator.

    Node ``x{i}_{j}`` lives at depth ``i`` (spatial size halved ``i`` times by
    ``self.pool``) and is the ``j``-th node at that depth. Encoder nodes
    ``x{i}_0`` come from the pooled node one level up. Every other node
    concatenates all earlier nodes at its depth with the 2x-upsampled node one
    level deeper, then applies a ``ConvBlock``. Only ``x0_4`` feeds the output
    head ``final4``.

    Args:
        in_channels: channels of the input tensor (default 3).
        out_channels: channels produced by ``final4`` (default 3).
        n_residual_blocks: accepted but not used anywhere in this class.

    Input is a 4-D tensor (N, in_channels, H, W) with H and W divisible by 16.
    It is max-pooled four times, and each 2x upsample must match the spatial
    size of the tensors it is concatenated with. Output has ``out_channels``
    channels and the input's spatial size.

    NOTE(review): ``final1``-``final3`` and ``G_x_D``/``G_y_D``/``G_x_G``/``G_y_G``
    are created but never used in ``forward``. They still own parameters, so
    presumably they are kept so that the saved checkpoint's keys match under
    ``load_state_dict``'s default strict mode. Confirm before removing them.
    """
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GCNet, self).__init__()


        filt = [32, 64, 128, 256, 512]   # ConvBlock channel widths at depths 0..4

        self.pool = nn.MaxPool2d(2, 2)
        # 2x bilinear upsampling with align_corners=True (same arguments as
        # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)).
        self.up = Interpolate(scale_factor=2, mode='bilinear')


        # Encoder (backbone) column: x{i}_0, each fed by the pooled level above.
        self.conv0_0 = ConvBlock(in_channels, filt[0], filt[0])
        self.conv1_0= ConvBlock(filt[0]  , filt[1], filt[1])
        self.conv2_0= ConvBlock(filt[1],  filt[2], filt[2])
        self.conv3_0 = ConvBlock(filt[2] , filt[3], filt[3])
        self.conv4_0= ConvBlock( filt[3] , filt[4], filt[4])

        # Nested decoder nodes. Input channels of x{i}_{j} are
        # j * filt[i] (all earlier nodes at depth i) + filt[i+1] (upsampled x{i+1}_{j-1}).
        self.conv0_1 = ConvBlock(filt[0]+filt[1], filt[0], filt[0])
        self.conv1_1 = ConvBlock(filt[1]+filt[2], filt[1], filt[1])
        self.conv2_1 = ConvBlock(filt[2]+filt[3], filt[2], filt[2])
        self.conv3_1 = ConvBlock(filt[3]+filt[4], filt[3], filt[3])

        self.conv0_2 = ConvBlock(filt[0]*2+filt[1], filt[0], filt[0])
        self.conv1_2 = ConvBlock(filt[1]*2+filt[2], filt[1], filt[1])
        self.conv2_2 = ConvBlock(filt[2]*2+filt[3], filt[2], filt[2])


        self.conv0_3 = ConvBlock(filt[0]*3+filt[1], filt[0], filt[0])
        self.conv1_3 = ConvBlock(filt[1]*3+filt[2], filt[1], filt[1])

        self.conv0_4 = ConvBlock(filt[0]*4+filt[1], filt[0], filt[0])

        # Output heads. Only final4 is used in forward(); final1-final3 are unused
        # (see the class docstring note).
        self.final1 = nn.Sequential(nn.Conv2d(filt[0], out_channels, kernel_size=3, padding=1))
        self.final2 = nn.Sequential(nn.Conv2d(filt[0], out_channels, kernel_size=3, padding=1))
        self.final3 = nn.Sequential(nn.Conv2d(filt[0], out_channels, kernel_size=3, padding=1))
        self.final4 = nn.Sequential(nn.Conv2d(filt[0], filt[0], 5, padding=2),
            nn.BatchNorm2d(filt[0]),
            nn.LeakyReLU(negative_slope=0.01,inplace=True),
            nn.Conv2d(filt[0], out_channels, kernel_size=3, padding=1))

        # Randomly initialised 3->1 channel 3x3 convs, not used in forward().
        # NOTE(review): presumably kept only for checkpoint compatibility; confirm.
        self.G_x_D = nn.Conv2d(3,1,kernel_size=3,stride=1,padding=0,bias=False)
        self.G_y_D = nn.Conv2d(3,1,kernel_size=3,stride=1,padding=0,bias=False)
        self.G_x_G = nn.Conv2d(3,1,kernel_size=3,stride=1,padding=0,bias=False)
        self.G_y_G = nn.Conv2d(3,1,kernel_size=3,stride=1,padding=0,bias=False)

    def forward(self, x):
        # Each group computes the next encoder node, then every nested decoder
        # node that has just become computable. torch.cat(..., 1) joins along the
        # channel dimension.
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        # Full-resolution, most deeply nested node -> out_channels.
        output4 = self.final4(x0_4)

        return output4


In [14]:



def convert_to_numpy(input,H,W):
    """Crop the ``padsize`` border from a single-image network tensor and de-normalize it.

    Args:
        input: tensor of shape (1, 3, >=H, >=W). The reshape to 3 channels
            requires exactly one image in the batch. It is not modified.
        H, W: height/width of the padded region. ``padsize`` pixels are removed
            from each of its sides.

    Returns:
        numpy array of shape (H - 2*padsize, W - 2*padsize, 3), same dtype as
        ``input``, with ``x * std[c] + mean[c]`` applied per channel ``c``.
    """
    inner_h = H - 2 * padsize
    inner_w = W - 2 * padsize

    # clone() first: for CPU tensors .numpy() shares memory, and the in-place
    # de-normalization below must not touch the caller's tensor.
    cropped = input[:, :, padsize:H - padsize, padsize:W - padsize].clone()
    hwc = cropped.cpu().numpy().reshape(3, inner_h, inner_w).transpose(1, 2, 0)

    for channel, (scale, shift) in enumerate(zip(std, mean)):
        hwc[:, :, channel] = hwc[:, :, channel] * scale + shift

    return hwc

dataset = "test_dataset"
dataset_root = "../.dataset/" + dataset

# make output directory
os.makedirs(dataset_root + "/output", exist_ok=True)

# Pick the device; cuDNN autotuning only applies on CUDA.
if torch.cuda.is_available():
    device = 'cuda'
    torch.backends.cudnn.benchmark = True
else:
    device = 'cpu'

# Initialize generator with the pretrained weights (inference only).
Generator = GCNet().to(device)
Generator.eval()
Generator.load_state_dict(torch.load("../checkpoint/model_gc.pth", map_location=device, weights_only=True))

# Ground truth is used when gt/ exists and holds as many images as input/.
# Count with the same "*.*" patterns the datasets glob, so this check agrees with
# what gtTestImageDataset will actually pair (os.listdir would also count
# sub-directories and extension-less files).
gtAvailable = False
if os.path.isdir(dataset_root + "/gt"):
    n_input_files = len(glob.glob(dataset_root + "/input/*.*"))
    n_gt_files = len(glob.glob(dataset_root + "/gt/*.*"))
    gtAvailable = n_input_files == n_gt_files

if gtAvailable:
    image_dataset = gtTestImageDataset(dataset_root)
else:
    image_dataset = testImageDataset(dataset_root)

# run
all_psnr = 0.0
all_ssim = 0.0
print("[Dataset name: %s] --> %d images" % (dataset, len(image_dataset)))
for image_num in tqdm(range(len(image_dataset)), ncols=100):

    data = image_dataset[image_num]
    R = data["R"].to(device)

    # Zero-pad right/bottom (zero is the mean colour after Normalize) so H and W
    # become multiples of 16, as GCNet pools four times. The formula always adds
    # 1..16 pixels, i.e. 16 even when a side is already aligned.
    _, first_h, first_w = R.size()
    R = torch.nn.functional.pad(R, (0, (R.size(2)//16)*16+16-R.size(2), 0, (R.size(1)//16)*16+16-R.size(1)), "constant")
    R = R.unsqueeze(0)  # add the batch dimension -> (1, 3, H, W)
    with torch.no_grad():
        output = Generator(R)

    # Crop the padsize border and undo normalization. Then add a small constant
    # offset, clip to [0, 1], and never exceed the input intensity element-wise.
    output_np = np.clip(convert_to_numpy(output, first_h, first_w) + 0.015, 0, 1)
    R_np = convert_to_numpy(R, first_h, first_w)
    final_output = np.fmin(output_np, R_np)

    # save output (clip guards the uint8 cast against round-off just outside [0, 1])
    Image.fromarray(np.uint8(np.clip(final_output, 0, 1) * 255)).save(dataset_root + "/output/" + data["Name"] + ".png")

    # Calculate PSNR/SSIM if available
    if gtAvailable:
        B = data["B"].astype(np.float32)
        pred = final_output.astype(np.float32)
        # Both images are floats in (approximately) [0, 1], so the data range is
        # 1.0. A data_range of 255 would make SSIM's stabilising constants swamp
        # the image statistics and push every score towards 1.
        thisPSNR = psnr(B, pred, data_range=1.0)
        thisSSIM = ssim(B, pred, channel_axis=2, data_range=1.0)
        all_psnr += thisPSNR
        all_ssim += thisSSIM
        print("[%s] PSNR:%4.2f SSIM:%4.3f" % (data["Name"], thisPSNR, thisSSIM), end="\r")

if gtAvailable:
    n_scored = len(image_dataset)
    if n_scored > 0:  # both folders may be empty; avoid ZeroDivisionError
        all_psnr = all_psnr/n_scored
        all_ssim = all_ssim/n_scored
    print("hogya : [[%s]]" % (dataset))
    print("PSNR: %4.2f / SSIM: %4.3f" % (all_psnr, all_ssim))
else:
    print("Complete.")

[Dataset name: test_dataset] --> 2 images


100%|█████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.73s/it]

Complete.



