## Transweather model class

In [1]:
from transweather_model import Transweather
from torchinfo import summary
from functools import partial
from torch import nn


model = Transweather()

summary(model,
        input_size=[1,3,256,256], 
        verbose=0,
        col_names=["input_size", "output_size", "num_params"],
        col_width=20,
        row_settings=["var_names"],
        device = "cpu")

# see how dataset and dataloaders are formed for training and evaluation
# how different image size is handled in training and validation
# Note that we find the best model based on validating with raindrop data. 

# outs is a list where, outputs from encoder stage are appended x1 outputs
# how conv tail is working

Layer (type (var_name))                            Input Shape          Output Shape         Param #
Transweather (Transweather)                        [1, 3, 256, 256]     [1, 3, 256, 256]     --
├─Tenc (Tenc)                                      [1, 3, 256, 256]     [1, 64, 64, 64]      296,448
│    └─OverlapPatchEmbed (patch_embed1)            [1, 3, 256, 256]     [1, 4096, 64]        --
│    │    └─Conv2d (proj)                          [1, 3, 256, 256]     [1, 64, 64, 64]      9,472
│    │    └─LayerNorm (norm)                       [1, 4096, 64]        [1, 4096, 64]        128
│    └─OverlapPatchEmbed (mini_patch_embed1)       [1, 64, 64, 64]      [1, 1024, 128]       --
│    │    └─Conv2d (proj)                          [1, 64, 64, 64]      [1, 128, 32, 32]     73,856
│    │    └─LayerNorm (norm)                       [1, 1024, 128]       [1, 1024, 128]       256
│    └─ModuleList (block1)                         --                   --                   --
│    │    └─Block (0)

## EncoderTransformer class

In [2]:
from transweather_model import EncoderTransformer

model = EncoderTransformer(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 4, 4], mlp_ratios=[2, 2, 2, 2],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[4, 2, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)

summary(model,
        input_size=[1,3,256,256], 
        verbose=0,
        col_names=["input_size", "output_size", "num_params"],
        col_width=20,
        row_settings=["var_names"],
        device = "cpu")


Layer (type (var_name))                       Input Shape          Output Shape         Param #
EncoderTransformer (EncoderTransformer)       [1, 3, 256, 256]     [1, 64, 64, 64]      296,448
├─OverlapPatchEmbed (patch_embed1)            [1, 3, 256, 256]     [1, 4096, 64]        --
│    └─Conv2d (proj)                          [1, 3, 256, 256]     [1, 64, 64, 64]      9,472
│    └─LayerNorm (norm)                       [1, 4096, 64]        [1, 4096, 64]        128
├─OverlapPatchEmbed (mini_patch_embed1)       [1, 64, 64, 64]      [1, 1024, 128]       --
│    └─Conv2d (proj)                          [1, 64, 64, 64]      [1, 128, 32, 32]     73,856
│    └─LayerNorm (norm)                       [1, 1024, 128]       [1, 1024, 128]       256
├─ModuleList (block1)                         --                   --                   --
│    └─Block (0)                              [1, 4096, 64]        [1, 4096, 64]        --
│    │    └─LayerNorm (norm1)                 [1, 4096, 64]        [1, 

##  OverlapPatchEmbed class

In [3]:
from transweather_model import OverlapPatchEmbed

model = OverlapPatchEmbed(img_size=256, patch_size=7, stride=4, in_chans=3, embed_dim=64)
summary(model,
        input_size=[1,3,256,256], 
        verbose=0,
        col_names=["input_size", "output_size", "num_params"],
        col_width=20,
        row_settings=["var_names"],
        device = "cpu")

Layer (type (var_name))                  Input Shape          Output Shape         Param #
OverlapPatchEmbed                        --                   --                   --
├─Conv2d (proj)                          [1, 3, 256, 256]     [1, 64, 64, 64]      9,472
├─LayerNorm (norm)                       [1, 4096, 64]        [1, 4096, 64]        128
Total params: 9,600
Trainable params: 9,600
Non-trainable params: 0
Total mult-adds (M): 38.80
Input size (MB): 0.79
Forward/backward pass size (MB): 4.19
Params size (MB): 0.04
Estimated Total Size (MB): 5.02

## Block class

In [4]:
from transweather_model import Block

model = Block(dim=64, num_heads=1, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1) # dim = embedding dimension
summary(model,
        input_size=[(1,4096, 64), (64, 1), (64,1)], 
        verbose=0,
        col_names=["input_size", "output_size", "num_params"],
        col_width=20,
        row_settings=["var_names"],
        device = "cpu")


RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [LayerNorm: 1, Attention: 1, Linear: 2, Linear: 2, Dropout: 2, Linear: 2, Dropout: 2, Identity: 1, LayerNorm: 1, Linear: 2]

## DWConv class

In [26]:
from transweather_model import DWConv

model = DWConv()
summary(model,
        input_size=[(1,2048, 64), (64,), (64,)], 
        verbose=0,
        col_names=["input_size", "output_size", "num_params"],
        col_width=20,
        row_settings=["var_names"],
        device = "cpu")

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

In [9]:
import cv2

path = "/home/ananth/TransWeather/dataset/allweather/input/0_rain.png"
img = cv2.imread(path)
img.shape

(480, 720, 3)

In [7]:
import torchvision

image = torchvision.io.read_image(path)
image.shape

torch.Size([3, 480, 720])

In [5]:
from train_data_functions import TrainData

train_dataset = TrainData(crop_size = [256, 256], train_data_dir = './allweather/',train_filename = 'allweather.txt')
len(train_dataset)

18069

In [12]:
from val_data_functions import ValData

val_dataset = ValData(val_data_dir = './dataset/test_a/', val_filename = 'raindroptesta.txt')
len(val_dataset)

58

torch.Size([3, 256, 256])

'city_read_14216'

In [15]:
import os
dataset_path = "./dataset/test_b/input/"
files = os.listdir(dataset_path)

# files


with open('./dataset/test_b/raindroptestb.txt', 'w') as file:
    # Write each filename followed by a newline character
    for filename in files:
        file.write(dataset_path + filename + '\n')

In [4]:
import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from val_data_functions import ValData
from utils2 import validation, validation_val
from transweather_model import Transweather
import sys

net = Transweather()
exp_name = "Transweather_scratch_val_on_testset_a/"

try:
    ckp_path = "./{}best.pth".format(exp_name)
    ckp = torch.load(ckp_path)
    net.load_state_dict(ckp)
    print("Model loaded successfully")
except:
    print("Unsuccessful in loading model")
    sys.exit(1)


Model loaded successfully


## Testing Transweather on raindrop test set

In [2]:
import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from val_data_functions import ValData
from utils2 import validation, validation_val
import os
import numpy as np
import random
from transweather_model import Transweather
import sys

# import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"


val_batch_size = 1
exp_name = "Transweather_scratch_conv_tail_depthwise_conv/"
val_data_dir = './dataset/test_a/'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
val_filename = 'raindroptesta.txt'

val_data_loader = DataLoader(ValData(val_data_dir,val_filename), batch_size=val_batch_size, shuffle=False)

net = Transweather()


# try:
#     net.load_state_dict(torch.load('./{}/best.pth'.format(exp_name)))
#     print("Model loaded successfully")
# except:
#     print("Unsuccessful in loading model")
#     sys.exit(1)

try:
    model_state_dict = torch.load('./{}/best.pth'.format(exp_name))
    new_state_dict = {}

    for key, value in model_state_dict.items():
        new_key = key.replace("module.",'')
        new_state_dict[new_key] = value
        
    net.load_state_dict(new_state_dict)
    print("Model loaded successfully")
except:
    print("Unsuccessful in loading model")
    sys.exit(1)

    
net.to(device)
net.eval()
# if os.path.exists('./{}_results/{}/'.format(category,exp_name))==False:
#     os.mkdir('./{}_results/{}/'.format(category,exp_name))	
#     os.mkdir('./{}_results/{}/rain/'.format(category,exp_name))
print('--- Testing starts! ---')
start_time = time.time()
val_psnr, val_ssim = validation(net, val_data_loader, device, exp_name, save_tag=False)
end_time = time.time() - start_time
print('val_psnr: {0:.2f}, val_ssim: {1:.4f}'.format(val_psnr, val_ssim))
print('validation time is {0:.4f}'.format(end_time))

Model loaded successfully
--- Testing starts! ---
val_psnr: 33.54, val_ssim: 0.9055
validation time is 4.2633


# Raindrop testB results
## with Transweather_scratch_val_on_testset_a model
* val_psnr: 29.48
* val_ssim: 0.8606
  

## with Transweather_scratch_val_on_testset_b model
* val_psnr: 29.89
* val_ssim: 0.8603



## Testing Transweather on raindrop test set

In [1]:
import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from utils2 import validation
import os
import numpy as np
import random
from transweather_model import Transweather
import sys

### Validation dataset

In [2]:
import torch.utils.data as data
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
import numpy as np

# --- Validation/test dataset --- #
class ValData(data.Dataset):
    def __init__(self, val_input_filename ,val_gt_filename):
        super().__init__()

        with open(val_input_filename) as f:
            contents = f.readlines()
            input_names = [i.strip() for i in contents]
        
        with open(val_gt_filename) as f:
            contents = f.readlines()
            gt_names = [i.strip() for i in contents]
            
            
        self.input_names = input_names
        self.gt_names = gt_names

    def get_images(self, index):
        
        input_name = self.input_names[index]
        gt_name = self.gt_names[index]

        input_img = Image.open(input_name)
        gt_img = Image.open(gt_name)

        # Resizing image in the multiple of 16"
        wd_new,ht_new = input_img.size
        if ht_new>wd_new and ht_new>1024:
            wd_new = int(np.ceil(wd_new*1024/ht_new))
            ht_new = 1024
        elif ht_new<=wd_new and wd_new>1024:
            ht_new = int(np.ceil(ht_new*1024/wd_new))
            wd_new = 1024
        wd_new = int(16*np.ceil(wd_new/16.0))
        ht_new = int(16*np.ceil(ht_new/16.0))
        input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
        gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)

        # --- Transform to tensor --- #
        transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        input_im = transform_input(input_img)
        gt = transform_gt(gt_img)

        return input_im, gt, input_name

    def __getitem__(self, index):
        res = self.get_images(index)
        return res

    def __len__(self):
        return len(self.input_names)


### Validating for AIWD6 dataset

In [6]:
val_batch_size = 32
exp_name = "Transweather_scratch_val_on_testset_a"


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

meta_path = "/home/ananth/TransWeather/dataset/AIWD6/meta/"
meta_files = os.listdir(meta_path) 

for file in meta_files:
    
    val_input_filename = f'./dataset/AIWD6/meta/{file}/input.txt'
    val_gt_filename = f'./dataset/AIWD6/meta/{file}/gt.txt'
    
    
    dataset = ValData(val_input_filename,val_gt_filename)
    val_data_loader = DataLoader(dataset, batch_size=val_batch_size, 
                                 shuffle=False, num_workers = 8, pin_memory = True)
    
    net = Transweather()
    net.to(device)
    
    # try:
    #     net.load_state_dict(torch.load('./{}/latest.pth'.format(exp_name)))
    #     print("Model loaded successfully")
    # except:
    #     print("Unsuccessful in loading model")
    #     sys.exit(1)
    try:
        model_state_dict = torch.load('./{}/best.pth'.format(exp_name))
        new_state_dict = {}
    
        for key, value in model_state_dict.items():
            new_key = key.replace("module.",'')
            new_state_dict[new_key] = value
            
        net.load_state_dict(new_state_dict)
        print("Model loaded successfully")
    except:
        print("Unsuccessful in loading model")
        sys.exit(1)

    
    net.eval()
    # if os.path.exists('./{}_results/{}/'.format(category,exp_name))==False:
    #     os.mkdir('./{}_results/{}/'.format(category,exp_name))	
    #     os.mkdir('./{}_results/{}/rain/'.format(category,exp_name))
    print('--- Testing starts! ---')
    start_time = time.time()
    val_psnr, val_ssim = validation(net, val_data_loader, device, exp_name, save_tag=False)
    end_time = time.time() - start_time
    
    print(f"{file}: Images extracted :", len(dataset))
    
    print('val_psnr: {0:.2f}, val_ssim: {1:.4f}'.format(val_psnr, val_ssim))
    print('validation time is {0:.4f}'.format(end_time))


Unsuccessful in loading model


AttributeError: 'tuple' object has no attribute 'tb_frame'

### Validating for Snow100k-L dataset

In [3]:
val_batch_size = 1
exp_name = "Transweather_scratch_enc_depth_3_3_2_2"


device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")


    
val_input_filename = './dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/synthetic.txt'
val_gt_filename = './dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/gt.txt'


dataset = ValData(val_input_filename,val_gt_filename)
val_data_loader = DataLoader(dataset, batch_size=val_batch_size, 
                             shuffle=False, num_workers = 8, pin_memory = True)

net = Transweather()
net.to(device)

# try:
#     net.load_state_dict(torch.load('./{}/best.pth'.format(exp_name)))
#     print("Model loaded successfully")
# except:
#     print("Unsuccessful in loading model")
#     sys.exit(1)
try:
    model_state_dict = torch.load('./{}/best.pth'.format(exp_name))
    new_state_dict = {}
    
    for key, value in model_state_dict.items():
        new_key = key.replace("module.",'')
        new_state_dict[new_key] = value
        
    net.load_state_dict(new_state_dict)
    print("Model loaded successfully")
except:
    print("Unsuccessful in loading model")
    sys.exit(1)


net.eval()

print('--- Testing starts! ---')
start_time = time.time()
val_psnr, val_ssim = validation(net, val_data_loader, device, exp_name, save_tag=False)
end_time = time.time() - start_time

print('val_psnr: {0:.2f}, val_ssim: {1:.4f}'.format(val_psnr, val_ssim))
print('validation time is {0:.4f}'.format(end_time))


Model loaded successfully
--- Testing starts! ---


  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)
  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)


val_psnr: 33.39, val_ssim: 0.8897
validation time is 883.5207


# Results on Snowtest 100k-L
## Scratch_a
* Val PSNR = 33.40
* Val SSIM = 0.8901
## Scratch_b
* Val PSNR = 33.61
* Val SSIM = 0.8889

In [38]:

import torchvision
img_name = "./dataset/AIWD6/Rainy_to_Cloudy/0 to 293/Image1.png"
img = torchvision.io.read_image(img_name)

img.shape

torch.Size([3, 256, 512])

In [39]:

normalized_image = img.float() / 255.0
# img_normalized = normalize_image_tensor(img)
torchvision.utils.save_image(normalized_image, "./demo_image.png")

In [2]:
# Generate input.txt and gt.txt files for each image from img0 to img9 for each subfolder
import os

main_folder = "./dataset/AIWD6/Sunny_to_Rainy/"

subfolders = os.listdir(main_folder)
subfolders.sort()

for i in range(0, 10):  # For each image from img0 to img9
    input_file = open(f'./dataset/AIWD6/meta/Sunny_to_Rainy/input{i}.txt', 'w')
    gt_file = open('./dataset/AIWD6/meta/Sunny_to_Rainy/gt.txt', 'w')

    for subfolder in subfolders:
        if os.path.exists(os.path.join(main_folder, subfolder, f'{i}.png')):
            img_path = os.path.join(main_folder, subfolder, f'{i}.png')
            gt_path = os.path.join(main_folder, subfolder, 'Image1.png')

        input_file.write(img_path + '\n')
        gt_file.write(gt_path + '\n')

    input_file.close()  
    gt_file.close()

In [47]:
# prepare meta for AIWD6 data

import os


dataset_path = "./dataset/AIWD6/Rainy_to_Sunny/"
folders = os.listdir(dataset_path)
folders.sort()

for folder in folders:
    img_list = os.listdir(dataset_path+folder)
    img_list.sort()  # to remove interpolated image from list 
    
    if len(img_list)!=13: # some folder do not have 13 images
        continue
    img_list = img_list[:-1]
    
    # store gt image and Image1 or Image2 dependong on folder
    gt_img = img_list[-1]
    # temp_img = img_list[-1]
    
    # remove gt and add that Image1 or Image2
    img_list = img_list[:-1]
    # img_list.append(temp_img)
    
    with open("./dataset/AIWD6/meta/Rainy_to_Sunny/input.txt", 'a') as file:
        for img_name in img_list:
            file.write(dataset_path + folder + '/' + img_name + '\n')
            
    with open("./dataset/AIWD6/meta/Rainy_to_Sunny/gt.txt", 'a') as file:
        for _ in range(len(img_list)):
            file.write(dataset_path + folder + '/' + gt_img + '\n')

    
    

In [2]:
# prepare meta for raindrop SnowTest100k-L data

import os
dataset_path = "./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/synthetic/"
files = os.listdir(dataset_path)
files.sort()

input_filename = "synthetic.txt"
with open('./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/' + input_filename, 'w') as file:
    # Write each filename followed by a newline character
    for filename in files:
        file.write(dataset_path + filename + '\n')



dataset_path = "./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/gt/"
files = os.listdir(dataset_path)
files.sort()
gt_filename = "gt.txt"
with open('./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/' + gt_filename, 'w') as file:
    # Write each filename followed by a newline character
    for filename in files:
        file.write(dataset_path + filename + '\n')

In [24]:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import cv2

img_1_path = "./dataset/AIWD6/Rainy_to_Cloudy/0 to 293/Image1.png"
im1 = torchvision.io.read_image(img_1_path)

img_2_path = "./demo_image.png"
im2 = torchvision.io.read_image(img_2_path)

im1 = im1.reshape(im1.shape[1],im1.shape[2],3)
im2 = im2.reshape(im2.shape[1],im2.shape[2],3)

im1 = im1.numpy()
im2 = im2.numpy()

im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]


ans = structural_similarity(im1_y, im2_y, data_range=1, multichannel=True)

In [60]:
from PIL import Image
img_1_path = "./dataset/AIWD6/Rainy_to_Cloudy/0 to 293/0.png"
im1 = Image.open(img_1_path)
# im1 = im1.unsqueeze(dim=0)

transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
im1 = transform_input(im1).unsqueeze(dim=0).to(device)
output = net(im1)


In [6]:
# cleaning snow100K-L dataset
import torchvision
import os

datapath = "./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/synthetic/"
files = os.listdir(datapath)


input_folder = "./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/synthetic/"
gt_folder = "./dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L/gt/"

bad_imgs = []

for file in files:
    img_path = input_folder + file
    gt_path = gt_folder + file

    img = torchvision.io.read_image(img_path)
    gt = torchvision.io.read_image(gt_path)

    if(img.shape != gt.shape):
        bad_imgs.append(file)
    

    
    

    

## Function take input and gt and compute output image and PSNR and SSIM

In [1]:
import torchvision
from PIL import Image
from torchvision.transforms import Normalize, ToTensor, Compose
import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from val_data_functions import ValData
from utils2 import validation, validation_val
import os
import numpy as np
import random
from transweather_model import Transweather
import sys
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import cv2



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
inp_image_path = "./dataset/realistic/sidewalk winter -grayscale -gray_00440.jpg"
gt_image_path = "./dataset/AIWD6/Rainy_to_Cloudy/27 to 320/Image2.png"


input_img = Image.open(inp_image_path)
gt_img = Image.open(gt_image_path)

# Resizing image in the multiple of 16"
wd_new,ht_new = input_img.size
if ht_new>wd_new and ht_new>1024:
    wd_new = int(np.ceil(wd_new*1024/ht_new))
    ht_new = 1024
elif ht_new<=wd_new and wd_new>1024:
    ht_new = int(np.ceil(ht_new*1024/wd_new))
    wd_new = 1024
wd_new = int(16*np.ceil(wd_new/16.0))
ht_new = int(16*np.ceil(ht_new/16.0))
input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)

# --- Transform to tensor --- #
transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform_gt = Compose([ToTensor()])
input_im = transform_input(input_img).unsqueeze(dim=0).to(device)
gt = transform_gt(gt_img)



exp_name = "Transweather_scratch_conv_tail_depthwise_conv/"
net = Transweather()


# try:
#     net.load_state_dict(torch.load('./{}/best.pth'.format(exp_name)))
#     print("Model loaded successfully")
# except:
#     print("Unsuccessful in loading model")
#     sys.exit(1)

try:
    model_state_dict = torch.load('./{}/best.pth'.format(exp_name))
    new_state_dict = {}

    for key, value in model_state_dict.items():
        new_key = key.replace("module.",'')
        new_state_dict[new_key] = value
        
    net.load_state_dict(new_state_dict)
    print("Model loaded successfully")
except:
    print("Unsuccessful in loading model")
    sys.exit(1)

    
net.to(device)
net.eval()

print('--- Evaluation starts! ---')
pred_image = net(input_im)


input_im = input_im.squeeze()
pred_image = pred_image.squeeze()

print("Input image", input_im.shape)
print("predicted Image", pred_image.shape)
print("Ground Truth Image", gt.shape)

save_image = pred_image.detach().cpu().reshape(3, pred_image.shape[1], pred_image.shape[2])
# gt = gt.cpu().numpy().reshape(gt.shape[1], gt.shape[2], 3)
# pred_image = pred_image.detach().cpu().numpy().reshape(pred_image.shape[1], pred_image.shape[2], 3)

# im1_y = cv2.cvtColor(gt, cv2.COLOR_BGR2YCR_CB)[:, :, 0] 
# im2_y = cv2.cvtColor(pred_image, cv2.COLOR_BGR2YCR_CB)[:, :, 0] 


# ssim = structural_similarity(im1_y, im2_y, data_range=1, multichannel=True)
# psnr = peak_signal_noise_ratio(im1_y, im2_y)

# print("\nSSIM", ssim)
# print("PSNR", psnr)

print(save_image.shape)
torchvision.utils.save_image(save_image, "./output_image_conv.png")

  input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
  gt_img = gt_img.resize((wd_new, ht_new), Image.ANTIALIAS)


Model loaded successfully
--- Evaluation starts! ---
Input image torch.Size([3, 432, 640])
predicted Image torch.Size([3, 432, 640])
Ground Truth Image torch.Size([3, 432, 640])
torch.Size([3, 432, 640])
