In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from GAN import Generator, Discriminator
from customDataset import CustomDataset

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nz = 100
ngf = 64

G = Generator(ngf=ngf, nz=nz).to(device)
D = Discriminator().to(device)

# Initialize BCELoss function
criterion = nn.BCEWithLogitsLoss()

# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

# Setup Adam optimizers for G and D
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002)
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002)

In [3]:
###################### Define Dataloader ######################

# Define a transformation operation to normalize the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize to [-1, 1]
])

# Set the paths for your image folders
path_low_res = "HR_images"
path_high_res = "LR_images/nn"

# Create the dataset
dataset = CustomDataset(path_low_res, path_high_res, transform)

# Create the dataloader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, drop_last=True)

In [4]:
# Training Loop
num_epochs = 50 # this number can be adjusted

for epoch in range(num_epochs):

    for i, data in enumerate(dataloader, 0):

        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        D.zero_grad()
        real_cpu = data[0].to(device)       #gpu 
        b_size = real_cpu.size(0)           # batch size
        # print("real_cpu size: ", real_cpu.size(0))
        label = torch.full((b_size,), 1, device=device)
        output = D(real_cpu).view(-1)
        label = label.float()
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        print("updated Discriminator with real data")

        noise = torch.randn(b_size, nz, device=device)
        # print("noise: ", noise)
        fake = G(noise)
        label.fill_(0)
        output = D(fake.detach()).view(-1)
        label = label.float()
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        print("updated Discriminator with fake data")

        errD = errD_real + errD_fake
        optimizerD.step()

        # (2) Update G network: maximize log(D(G(z)))
        G.zero_grad()
        label.fill_(1)
        output = D(fake).view(-1)
        label = label.float()
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        print("updated Generator")

updated Discriminator with real data


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 100]