In [1]:
import torch
from torch import nn, optim
from GAN import Generator, Discriminator
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from PIL import Image
from customDataset import CustomDataset


In [2]:
#################### Define Generator and Discriminator

# Select the compute device here so the networks end up on the same device as
# the batches, which the training loop moves with `.to(device)`. Without this
# the models stay on the CPU and the first forward pass fails on a CUDA machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Binary cross-entropy loss, applied to the discriminator output.
# NOTE(review): BCELoss expects probabilities in [0, 1] -- presumably the
# Discriminator ends in a sigmoid; confirm in GAN.py.
criterion = nn.BCELoss()

# Create the generator and the discriminator and move them to `device` BEFORE
# building the optimizers, so the optimizers are constructed over the
# parameters that are actually used for training.
netG = Generator().to(device)
netD = Discriminator().to(device)

# Setup Adam optimizers for both G and D (same hyper-parameters for each)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [3]:
###################### Define Dataloader ######################

# Image folders handed to CustomDataset as its first and second argument.
# NOTE(review): the variable names and the folder names look crossed --
# `path_low_res` points at "HR_images" and `path_high_res` at "LR_images/nn".
# Confirm which order CustomDataset expects before trusting these names.
path_low_res, path_high_res = "HR_images", "LR_images/nn"

# Per-image preprocessing: convert to a tensor, then shift/scale each of the
# three channels with mean 0.5 and std 0.5 (normalize to [-1, 1]).
# NOTE(review): three-channel statistics assume RGB inputs -- confirm.
_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([transforms.ToTensor(), _normalize])

# Dataset built from the two folders, sharing the transform above
dataset = CustomDataset(path_low_res, path_high_res, transform)

# Shuffled loader over the dataset with batch_size=32
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

In [None]:
# Training configuration
num_epochs = 5  # number of full passes over the dataloader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Target value for real samples. Deliberately a float: torch.full() infers the
# tensor dtype from the fill value, and nn.BCELoss rejects integer targets, so
# an int here makes the loss call in the training loop raise.
real_label = 1.0

# Training loop: for every epoch, iterate over all batches from the dataloader.
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # Train with all-real batch
        # Clear gradients left on the discriminator from the previous step.
        netD.zero_grad()
        # First element of the batch, moved to `device` (despite the "_cpu"
        # suffix this tensor lives wherever `device` points).
        # NOTE(review): whether data[0] is the high- or low-resolution image
        # depends on CustomDataset's return order -- confirm.
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # actual batch size (first dimension)
        # One target per sample, all set to `real_label`.
        # NOTE(review): the label dtype follows the Python type of
        # `real_label`; BCELoss needs float targets, so `real_label` must be a
        # float (or pass dtype=torch.float here).
        label = torch.full((b_size,), real_label, device=device)
        # Discriminator forward pass, flattened to 1-D to line up with `label`.
        # NOTE(review): netD must be on the same device as `real_cpu`.
        output = netD(real_cpu).view(-1)
        # Loss of the discriminator on the real batch; backward() accumulates
        # its gradients.
        errD_real = criterion(output, label)
        errD_real.backward()
        # Mean discriminator output on the real batch, as a Python float.
        D_x = output.mean().item()