In [22]:
from pathlib import Path

import matplotlib.pyplot as plt 
import numpy as np 
import torch 
from PIL import Image

from model import Generator, Discriminator

In [23]:
# Seed torch's global RNG before building the networks: PyTorch layers draw
# their initial weights from it, so seeding makes the freshly initialised
# weights (and the untrained discriminator output printed later) reproducible
# on re-run.
SEED = 0
torch.manual_seed(SEED)

# 4 input/output channels to match the 4-channel (presumably RGBA) sample image.
# With these settings the printed shapes show a 32x32 input becoming 96x96.
gen = Generator(in_channels=4, out_channels=4, num_upsample_blocks=2, num_residual_blocks=16)
dis = Discriminator(input_channels=4)

In [24]:
# Load one low-resolution sample and turn it into a (1, C, H, W) float tensor.
# DATA_DIR is relative to the repository root (the notebook's working directory,
# since `from model import ...` resolves there); change it if running elsewhere.
DATA_DIR = Path("data")
img_path = DATA_DIR / "low_res_data" / "3_lr.png"

# Open inside a context manager so the file handle is released once the pixels
# are copied out. Force RGBA so the array always has the 4 channels the
# Generator was built with (in_channels=4), whatever mode the PNG was saved in.
with Image.open(img_path) as pil_img:
    img = np.array(pil_img.convert("RGBA"))

# HWC uint8 array -> CHW float tensor with a leading batch dimension.
# NOTE(review): pixel values stay in [0, 255]; no normalisation is applied.
img_torch = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float()

In [25]:
# Shape sanity check through the freshly constructed networks:
# low-res input -> generator output -> discriminator output.
print(f"Input image shape = {img_torch.shape}")

out_img = gen(img_torch)
print(f"Generator output shape = {out_img.shape}")

dis_out = dis(out_img)
print(
    "Discriminator output shape = {}, Discriminator output = {}".format(
        dis_out.shape, dis_out.item()
    )
)

Input image shape = torch.Size([1, 4, 32, 32])
Generator output shape = torch.Size([1, 4, 96, 96])
Discriminator output shape = torch.Size([1, 1]), Discriminator output = 0.4969562888145447


In [26]:
# Count the number of parameters in the generator and discriminator.
def count_parameters(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements across all of ``model``'s parameters."""
    return sum(p.numel() for p in model.parameters())


gen_params = count_parameters(gen)
dis_params = count_parameters(dis)

# `:,d` groups digits with commas; no sign flag, so there is no leading space.
print(f"Generator has {gen_params:,d} parameters")
print(f"Discriminator has {dis_params:,d} parameters")
print(f"Total number of parameters = {gen_params + dis_params:,d}")

Generator has 1561028 parameters
Discriminator has 23566081 parameters
Total number of parameters =  25,127,109


In [27]:
from torchvision.models import vgg19

In [28]:
# Load ImageNet-pretrained VGG-19. `weights=` replaces the deprecated
# `pretrained=True` flag (torchvision >= 0.13); IMAGENET1K_V1 is the checkpoint
# that `pretrained=True` loaded.
from torchvision.models import VGG19_Weights

vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

# VGG's first conv takes 3 input channels, so drop the 4th (presumably alpha) channel.
# NOTE(review): this rebinds `img_torch`; re-running the generator cell after this
# one feeds 3 channels into a Generator built with in_channels=4 and will fail.
img_torch = img_torch[:, :3, :, :]

In [33]:
# Keep VGG-19's convolutional trunk up to index 35. In torchvision's VGG-19
# layout, index 35 is the ReLU after the last conv of the fifth block (conv5_4),
# and index 36, the final MaxPool, is excluded. Presumably this is the feature
# extractor for a perceptual (VGG) loss; confirm.
# Slicing an nn.Sequential shares the child modules with `vgg.features`, so
# `.eval()` also switches those layers inside `vgg` to eval mode.
vgg_features = vgg.features[:36].eval()
vgg_features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo