In [1]:
import os
import time

import torch
import torchvision.transforms as transforms
from PIL import Image

from networks.source import Unet_resize_conv

In [2]:
# Use CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [3]:
# Which training run's checkpoints to use; `path` below is only read by the
# model-loading cell when `use_baseline` is False.
nth_training = 48
# Checkpoint root. Set the ENLIGHTENGAN_CKPT_ROOT environment variable to run the
# notebook on a machine without this OneDrive layout; the default is unchanged.
root = os.environ.get(
    "ENLIGHTENGAN_CKPT_ROOT",
    r"C:/Users/ILLEGEAR/OneDrive - Universiti Malaya/FYP/checkpoints/enlightenGAN",
)
path = f"{root}/training_{nth_training}/iter_trained"
demo = f"{root}/demo"

In [4]:
# True  -> load net_G.pth from the `demo` checkpoint folder into the DataParallel baseline.
# False -> load net_G.pth from training run `nth_training` into a bare generator.
use_baseline: bool = True

In [5]:
# Generator wrapped in DataParallel on GPU 0. Presumably the demo checkpoint was
# saved from a DataParallel model (keys prefixed with "module.") — TODO confirm.
baseline = torch.nn.DataParallel(
    Unet_resize_conv(custom_attention=False),
    device_ids=[0],
)

In [6]:
# Pick the network: the DataParallel-wrapped baseline for the demo checkpoint,
# or a bare generator for a checkpoint from training run `nth_training`.
model = baseline if use_baseline else Unet_resize_conv(custom_attention=False)
model_pth = os.path.join(demo if use_baseline else path, 'net_G.pth')

# map_location remaps every stored tensor onto `device`, so a checkpoint saved
# on GPU also loads on a CPU-only machine.
state_dict = torch.load(model_pth, map_location=device)

# Checkpoints saved from a DataParallel model prefix every key with "module.".
# Strip that prefix (a no-op when it is absent) and load into the unwrapped
# network, so either checkpoint style works with either model type. Strict key
# matching is still enforced by load_state_dict.
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
bare_model = model.module if isinstance(model, torch.nn.DataParallel) else model
bare_model.load_state_dict(state_dict)

model.to(device)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

print("Done Loading Model")
print(model_pth)

Done Loading Model
C:/Users/ILLEGEAR/OneDrive - Universiti Malaya/FYP/checkpoints/enlightenGAN/demo\net_G.pth


In [7]:
# Resize (shorter side -> 286 px, bicubic), center-crop to 256x256, convert to a
# channel-first float tensor, then map each channel from [0, 1] to [-1, 1]
# via (x - 0.5) / 0.5.
_channel_stats = (0.5, 0.5, 0.5)
preprocess = transforms.Compose(
    [
        transforms.Resize(286, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=_channel_stats, std=_channel_stats),
    ]
)

def getAttentionLayer(image):
    """Build the attention map that is fed to the generator next to the image.

    `image` is expected to be a channel-first tensor normalized to [-1, 1]
    (as produced by `preprocess`); channels 0, 1 and 2 are read as R, G, B.
    Each channel is shifted to [0, 2], combined with luma weights, halved back
    to [0, 1] and inverted, so darker pixels receive larger attention values.

    Returns:
        The map with a leading channel axis of size 1, e.g. (1, H, W) for a
        (3, H, W) input.
    """
    red, green, blue = (image[channel] + 1 for channel in range(3))
    brightness = (0.299 * red + 0.587 * green + 0.114 * blue) / 2.
    return (1. - brightness).unsqueeze(dim=0)

def _toGrayscale(img):
    """Collapse an RGB channel-first tensor to one luma channel (Rec. 601 weights), keeping a channel axis."""
    r, g, b = img[0], img[1], img[2]
    return torch.unsqueeze(0.299 * r + 0.587 * g + 0.114 * b, dim=0)


def transformImage(image: Image.Image, mode='grayscale-1D'):
    """Turn a PIL image into batched model input plus its attention map.

    Args:
        image: input picture; it is passed through the module-level
            `preprocess` pipeline (resize, center-crop, to tensor, [-1, 1]).
        mode: 'RGB' keeps the three colour channels, 'grayscale-1D' returns a
            single luma channel, 'grayscale-3D' repeats that channel 3 times.

    Returns:
        (img, att): `img` is (1, C, H, W) with C = 1 for 'grayscale-1D' and
        C = 3 otherwise; `att` is (1, 1, H, W). Both are moved to the global
        `device`.

    Raises:
        AssertionError: if `mode` is not one of the supported values.
    """
    valid_modes = ('grayscale-1D', 'grayscale-3D', 'RGB')
    assert mode in valid_modes, f"mode must be one of {valid_modes}, got {mode!r}"

    img = preprocess(image)
    # The attention map is derived from the colour image before any grayscale
    # conversion, so it is the same for every mode.
    att = getAttentionLayer(img)

    if mode == 'grayscale-1D':
        img = _toGrayscale(img)
    elif mode == 'grayscale-3D':
        img = _toGrayscale(img).repeat(3, 1, 1)

    # Add the batch dimension and move both tensors to the inference device.
    img = img.unsqueeze(dim=0).to(device)
    att = att.unsqueeze(dim=0).to(device)

    return img, att
    

In [8]:
# Local dataset checkout that holds every benchmark split below.
_DATASETS_DIR = r"C:\Users\ILLEGEAR\Desktop\cheelam\FYP\Repositories\FYP_low_light_image_enhancement\datasets"

LOL = _DATASETS_DIR + r"\light_enhancement\testA"
EXDARK = _DATASETS_DIR + r"\test\testA"
SID = _DATASETS_DIR + r"\test\SID-RGB"

# Folder whose images the timing cell iterates over.
data_root = SID

In [9]:
n_sample = 120  # maximum number of images to time; None times every image
running_times = []  # per-image forward-pass latency, in milliseconds

# CUDA kernels run asynchronously, so on GPU the forward pass is timed with CUDA
# events recorded on the stream; on CPU a wall-clock timer is accurate.
use_cuda_timing = device.type == 'cuda'
if use_cuda_timing:
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

# Deterministic file order, restricted to extensions Pillow can open (synced
# Windows folders may contain files such as desktop.ini or Thumbs.db).
image_extensions = {ext.lower() for ext in Image.registered_extensions()}
image_files = sorted(
    name for name in os.listdir(data_root)
    if os.path.splitext(name)[1].lower() in image_extensions
)[:n_sample]

with torch.no_grad():
    # Warmup: early forward passes typically include one-off setup costs that
    # should not be measured. They run without autograd, like the timed passes.
    # Dummy inputs match the 256x256 crop produced by `preprocess`.
    dummy_img = torch.randn(1, 3, 256, 256, dtype=torch.float, device=device)
    dummy_att = torch.randn(1, 1, 256, 256, dtype=torch.float, device=device)
    for _ in range(10):
        model(dummy_img, dummy_att)

    for img_path in image_files:
        with Image.open(os.path.join(data_root, img_path)) as opened:
            image = opened.convert('RGB')

        # Preprocessing and the host-to-device copy are not part of the timing.
        img, att = transformImage(image, mode='RGB')

        if use_cuda_timing:
            starter.record()
            model(img, att)
            ender.record()
            # Wait for the GPU to reach `ender` before reading the elapsed time.
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
        else:
            t_start = time.perf_counter()
            model(img, att)
            curr_time = (time.perf_counter() - t_start) * 1000.0

        running_times.append(curr_time)

In [10]:
# `running_times` holds per-image forward-pass latencies in milliseconds.
if not running_times:
    raise RuntimeError(
        "No timings were recorded: the benchmark cell found no images in data_root."
    )

total_time = sum(running_times) / 1000  # ms -> s
avg_time = total_time / len(running_times)
avg_fps = 1 / avg_time

print(f"Total Running Time (s) for {len(running_times)} images:", total_time)
print("Average Running Time (s):", avg_time)
print("Average FPS:", avg_fps)

Total Running Time (s) for 120 images: 8.236103989601135
Average Running Time (s): 0.0686341999133428
Average FPS: 14.569995734817264


In [11]:
# SID_raw = r"C:\Users\ILLEGEAR\Desktop\cheelam\FYP\Repositories\FYP_low_light_image_enhancement\datasets\test\SID"
# SID_save = r"C:\Users\ILLEGEAR\Desktop\cheelam\FYP\Repositories\FYP_low_light_image_enhancement\datasets\test\SID-RGB"

# import rawpy

In [12]:
# image_filenames = os.listdir(SID_raw)

In [13]:
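# NOTE(review): the previous version built the PNG from `rawimg.raw_image` (the
# single-channel raw sensor data) and discarded the result of `img.convert('RGB')`,
# so images written to SID-RGB that way are not RGB. `postprocess()` demosaics to RGB.
# for i, filename in enumerate(image_filenames):
#     with rawpy.imread(os.path.join(SID_raw, filename)) as rawimg:
#         img = Image.fromarray(rawimg.postprocess())
#     img.save(os.path.join(SID_save, "%04d" % (i+1) + '.png'))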