In [3]:
from model import Generator, Discriminator
from loss import ContentLoss,  GeneratoradversarialLoss, DiscriminatorLoss
import torch 
import torch.nn as nn
torch.set_num_threads(1)
import config
from imageloader import create_dataloaders
from torchvision.models import vgg19, VGG19_Weights

In [4]:
device = config.DEVICE

# Create the models
discriminator = Discriminator(input_channels=config.DISCRIMINATOR_INPUT_CHANNELS).to(device)
generator = Generator(in_channels=3, out_channels=3, 
                      num_residual_blocks=config.NUM_RESIDUAL_BLOCKS, 
                      num_upsample_blocks=config.NUM_UPSAMPLE_BLOCKS, 
                      upsample_factor=2).to(device)

# Create the optimizers
generator_optim = torch.optim.Adam(generator.parameters(), lr=1e-4)
discriminator_optim = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

# Create the loss functions
generator_adversarial_loss = GeneratoradversarialLoss().to(device)
discriminator_loss = DiscriminatorLoss().to(device)
content_loss = ContentLoss().to(device)

# Create the dataloaders
train_loader, validation_loader, test_loader = create_dataloaders(low_res_dir=config.LOW_RES_FOLDER, 
                                                                high_res_dir=config.HIGH_RES_FOLDER, 
                                                                batch_size=config.BATCH_SIZE, 
                                                                num_workers=config.NUM_WORKERS)

In [5]:
lr, hr = next(iter(train_loader))
# print(f"Shape of low res image: {lr.shape}")
# print(f"Shape of high res image: {hr.shape}")

# real_outputs = discriminator(hr)
# print(f"Shape of real outputs: {real_outputs.shape}")

# fake_images = generator(lr)
# fake_outputs = discriminator(fake_images)
# print(f"Shape of fake outputs: {fake_outputs.shape}")


# loss_D = discriminator_loss(real_outputs, fake_outputs)
# print(f"Discriminator loss: {loss_D}")

# # Generator 
# fake_outputs = discriminator(fake_images)
# adversarial_loss = generator_adversarial_loss(fake_outputs)
# print(f"Adversarial loss: {adversarial_loss}")

In [8]:
fake_images = generator(lr)
content_loss_value = content_loss(fake_images, hr)

In [9]:
content_loss_value

tensor(5.5002, grad_fn=<MseLossBackward0>)