In [10]:
import numpy as np 
import matplotlib.pyplot as plt 
import torch 
from PIL import Image
import os
from model import Generator, Discriminator

In [11]:
gen = Generator(in_channels=4, out_channels=4, num_upsample_blocks=2, num_residual_blocks=16)
dis = Discriminator(input_channels=4)

In [12]:
img_path = "/Users/archismanchakraborti/Desktop/python_files/Super-Resolution-Tool/data/low_res_data/3_lr.png"
img = Image.open(img_path)


# Convert img to torch tensor 
img = np.array(img)
img_torch = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float()

In [13]:
print(f"Input image shape = {img_torch.shape}")
out_img = gen(img_torch)
print(f"Generator output shape = {out_img.shape}")
dis_out = dis(out_img)
print(f"Discriminator output shape = {dis_out.shape}, Discriminator output = {dis_out.item()}")

Input image shape = torch.Size([1, 4, 32, 32])
Generator output shape = torch.Size([1, 4, 96, 96])
Discriminator output shape = torch.Size([1, 1]), Discriminator output = 0.4223564863204956


In [14]:
# count the number of parameters in the generator and discriminator
gen_params = sum(p.numel() for p in gen.parameters())
dis_params = sum(p.numel() for p in dis.parameters())

print(f"Generator has {gen_params} parameters")
print(f"Discriminator has {dis_params} parameters")
print(f"Total number of parameters = {gen_params + dis_params: ,d}")

Generator has 1561028 parameters
Discriminator has 23566081 parameters
Total number of parameters =  25,127,109


In [15]:
import config 
from imageloader import create_dataloaders

In [16]:
train_dl, val_dl, test_dl = create_dataloaders(low_res_dir=config.LOW_RES_FOLDER, 
                                               high_res_dir=config.HIGH_RES_FOLDER, 
                                      batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS)

In [17]:
next(iter(train_dl))

[tensor([[[[0.0627, 0.0627, 0.0627,  ..., 0.7451, 0.5529, 0.4039],
           [0.0627, 0.0627, 0.0627,  ..., 0.5059, 0.4392, 0.4549],
           [0.0627, 0.0627, 0.0627,  ..., 0.4353, 0.5098, 0.5373],
           ...,
           [0.7804, 0.7765, 0.7765,  ..., 0.0627, 0.0627, 0.0627],
           [0.7765, 0.7725, 0.7686,  ..., 0.0627, 0.0627, 0.0627],
           [0.7725, 0.7725, 0.7686,  ..., 0.0627, 0.0627, 0.0627]],
 
          [[0.0627, 0.0627, 0.0627,  ..., 0.7451, 0.5569, 0.4078],
           [0.0627, 0.0627, 0.0627,  ..., 0.5059, 0.4588, 0.4588],
           [0.0627, 0.0627, 0.0627,  ..., 0.4275, 0.5020, 0.5294],
           ...,
           [0.7882, 0.7922, 0.7922,  ..., 0.0627, 0.0627, 0.0627],
           [0.7922, 0.7922, 0.7961,  ..., 0.0627, 0.0627, 0.0627],
           [0.7961, 0.7961, 0.7961,  ..., 0.0627, 0.0627, 0.0627]],
 
          [[0.0627, 0.0627, 0.0627,  ..., 0.1765, 0.1451, 0.1373],
           [0.0627, 0.0627, 0.0627,  ..., 0.1294, 0.1294, 0.1686],
           [0.0627, 0.06