## Imports

In [1]:
import time
import torch
from torch.utils.data import DataLoader, Dataset
from utils import validation
from transweather_model import Transweather
import sys
import os
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
import numpy as np

import warnings
warnings.filterwarnings('ignore')

## Generate `input{i}.txt` and `gt.txt` path lists for images `0.png`–`9.png` in each subfolder


In [None]:
import os

def write_meta_lists(main_folder, meta_folder, num_images=10, gt_filename="Image1.png"):
    """Write ``input{i}.txt`` and ``gt.txt`` path-list files for the subfolders of *main_folder*.

    Each subfolder is expected to contain ``0.png`` ... ``{num_images - 1}.png``
    plus one ground-truth image *gt_filename*. For every index ``i`` the file
    ``input{i}.txt`` in *meta_folder* lists one ``{i}.png`` path per line, and a
    single ``gt.txt`` lists the matching ground-truth paths.

    All ``input{i}.txt`` files are paired line-by-line with the same ``gt.txt``,
    so only subfolders that contain *every* expected image are written; the
    others are reported and skipped. (Previously a missing image caused the
    previous subfolder's paths to be written again, silently misaligning pairs.)

    Returns the list of subfolder names that were written, in sorted order.
    """
    expected = [f"{i}.png" for i in range(num_images)] + [gt_filename]

    complete, skipped = [], []
    for subfolder in sorted(os.listdir(main_folder)):
        folder = os.path.join(main_folder, subfolder)
        is_complete = os.path.isdir(folder) and all(
            os.path.exists(os.path.join(folder, name)) for name in expected
        )
        (complete if is_complete else skipped).append(subfolder)

    if skipped:
        # warnings are globally ignored in this notebook, so report via print.
        print(f"Skipping {len(skipped)} incomplete entries in {main_folder}: {skipped}")

    os.makedirs(meta_folder, exist_ok=True)

    # gt.txt is shared by every input{i}.txt, so it is written exactly once.
    with open(os.path.join(meta_folder, "gt.txt"), "w") as gt_file:
        for subfolder in complete:
            gt_file.write(os.path.join(main_folder, subfolder, gt_filename) + "\n")

    for i in range(num_images):
        with open(os.path.join(meta_folder, f"input{i}.txt"), "w") as input_file:
            for subfolder in complete:
                input_file.write(os.path.join(main_folder, subfolder, f"{i}.png") + "\n")

    return complete


main_folder = "./dataset/AIWD6/Sunny_to_Rainy/"
written_subfolders = write_meta_lists(main_folder, "./dataset/AIWD6/meta/Sunny_to_Rainy/")
print(f"Wrote meta lists for {len(written_subfolders)} subfolders")

## Dataloader

In [2]:
class ValData(Dataset):
    """Paired validation images read from two parallel path-list files.

    Line ``k`` of *val_input_filename* is the path of an input image and line
    ``k`` of *val_gt_filename* is the path of its ground truth. Pairing is purely
    positional, so both files must list the same number of non-empty paths.

    Each item is ``(input_tensor, gt_tensor, input_path)``. Both images are
    resized to the size computed from the input image (see ``_target_size``).
    The input is normalised with mean/std 0.5 per channel, and the ground truth
    only goes through ``ToTensor``.
    """

    def __init__(self, val_input_filename, val_gt_filename):
        super().__init__()
        self.input_names = self._read_path_list(val_input_filename)
        self.gt_names = self._read_path_list(val_gt_filename)
        # A length mismatch means every pair after the gap would be wrong
        # (or indexing would fail mid-evaluation), so fail early and clearly.
        if len(self.input_names) != len(self.gt_names):
            raise ValueError(
                f"{val_input_filename} lists {len(self.input_names)} images but "
                f"{val_gt_filename} lists {len(self.gt_names)}"
            )

    @staticmethod
    def _read_path_list(filename):
        """Return the whitespace-stripped, non-empty lines of *filename*."""
        with open(filename) as f:
            stripped = (line.strip() for line in f)
            return [path for path in stripped if path]

    @staticmethod
    def _target_size(width, height, max_side=1024, multiple=16):
        """Return ``(width, height)`` with the longer side capped and both sides rounded up.

        If the longer side exceeds *max_side*, it is set to *max_side* and the
        other side is scaled proportionally (rounded up). Both sides are then
        rounded up to the next multiple of *multiple*.
        """
        if height > width and height > max_side:
            width = int(np.ceil(width * max_side / height))
            height = max_side
        elif height <= width and width > max_side:
            height = int(np.ceil(height * max_side / width))
            width = max_side
        width = int(multiple * np.ceil(width / float(multiple)))
        height = int(multiple * np.ceil(height / float(multiple)))
        return width, height

    def get_images(self, index):
        """Load, resize and convert the input/ground-truth pair at *index*."""
        input_name = self.input_names[index]
        gt_name = self.gt_names[index]

        # convert("RGB") loads the pixels, so the file handle can be closed
        # immediately. It also guarantees the 3 channels that Normalize below
        # is configured for, even for RGBA / palette / grayscale PNGs.
        with Image.open(input_name) as img:
            input_img = img.convert("RGB")
        with Image.open(gt_name) as img:
            gt_img = img.convert("RGB")

        # PIL's .size is (width, height). The ground truth is resized to the
        # input's target size so both tensors have identical spatial dims.
        size = self._target_size(*input_img.size)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        input_img = input_img.resize(size, Image.LANCZOS)
        gt_img = gt_img.resize(size, Image.LANCZOS)

        # --- Transform to tensor --- #
        transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        return transform_input(input_img), transform_gt(gt_img), input_name

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.input_names)


## Loading model

In [3]:
net = Transweather()
exp_name = "Transweather_scratch_val_on_testset_a/"
ckp_path = "./{}best.pth".format(exp_name)

try:
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the evaluation cell moves `net` to the selected device afterwards.
    ckp = torch.load(ckp_path, map_location="cpu")
    net.load_state_dict(ckp)
except Exception as err:
    # Keep the original cause (missing file, state-dict key mismatch, ...)
    # visible. A bare `except:` + sys.exit hid it and Jupyter shows no traceback.
    raise RuntimeError(f"Unsuccessful in loading model from {ckp_path}") from err

print("Model loaded successfully")

Model loaded successfully


## Evaluation

In [5]:
val_batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
net.eval()

meta_folder = "./dataset/AIWD6/meta/"


def list_input_files(folder):
    """Return the ``input<N>.txt`` files in *folder*, sorted numerically by N.

    Files are selected by name instead of dropping the first entry of a sorted
    listing. That keeps ``gt.txt`` and stray entries (e.g. ``.ipynb_checkpoints``)
    out of the evaluation, and numeric sorting puts ``input10.txt`` after
    ``input9.txt``.
    """
    prefix, suffix = "input", ".txt"
    indexed = []
    for name in os.listdir(folder):
        stem = name[len(prefix):-len(suffix)]
        if name.startswith(prefix) and name.endswith(suffix) and stem.isdigit():
            indexed.append((int(stem), name))
    return [name for _, name in sorted(indexed)]


subfolders = sorted(
    name for name in os.listdir(meta_folder)
    if os.path.isdir(os.path.join(meta_folder, name))
)

print('--- Evaluation starts! ---')

for subfolder in subfolders:
    subfolder_path = os.path.join(meta_folder, subfolder)
    val_gt_filename = os.path.join(subfolder_path, "gt.txt")
    if not os.path.exists(val_gt_filename):
        print(f"Skipping {subfolder}: no gt.txt found")
        continue

    for input_file in list_input_files(subfolder_path):
        val_input_filename = os.path.join(subfolder_path, input_file)

        dataset = ValData(val_input_filename, val_gt_filename)
        # Pinned host memory only helps host->GPU copies; skip it on CPU.
        val_data_loader = DataLoader(dataset, batch_size=val_batch_size,
                                     shuffle=False, num_workers=4,
                                     pin_memory=(device.type == "cuda"))

        start_time = time.perf_counter()
        val_psnr, val_ssim = validation(net, val_data_loader, device, exp_name, save_tag=False)
        elapsed = time.perf_counter() - start_time

        print("Folder: ", subfolder, "\nImage: ", input_file)
        print(f"Extracted {len(dataset)} images")
        print('val_psnr: {0:.2f}, val_ssim: {1:.4f}'.format(val_psnr, val_ssim))
        print('validation time is {0:.4f}'.format(elapsed))
        print("*"*100)

print('--- Evaluation finish! ---')
    



--- Evaluation starts! ---


Folder:  Cloudy_to_Rainy 
Image:  input0.txt
Extracted 293 images
val_psnr: 28.42, val_ssim: 0.9128
validation time is 4.3073
****************************************************************************************************
Folder:  Cloudy_to_Rainy 
Image:  input1.txt
Extracted 293 images
val_psnr: 30.16, val_ssim: 0.9242
validation time is 3.6657
****************************************************************************************************
Folder:  Cloudy_to_Rainy 
Image:  input2.txt
Extracted 293 images
val_psnr: 29.44, val_ssim: 0.9210
validation time is 3.7255
****************************************************************************************************
Folder:  Cloudy_to_Rainy 
Image:  input3.txt
Extracted 293 images
val_psnr: 26.82, val_ssim: 0.9028
validation time is 4.1622
****************************************************************************************************
Folder:  Cloudy_to_Rainy 
Image:  input4.txt
Extracted 293 images
val_psnr: 23.91, val_ssim: