In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Define the generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Define the loss function (binary cross entropy)
criterion = nn.BCELoss()

# Initialize the generator and discriminator networks
generator = Generator()
discriminator = Discriminator()

# Define the optimizers for the generator and discriminator networks
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Train the generator and discriminator networks
# Define the number of epochs and batch size for training
num_epochs = 50
batch_size = 64

# Load the dataset
dataset = datasets.ImageFolder(root='PS2_sample_images', transform=transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))

# Create a dataloader for the dataset
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the real and fake labels for the discriminator loss
real_label = 1
fake_label = 0

# Train the generator and discriminator networks
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        # Train the discriminator with real images
        discriminator.zero_grad()
        output = discriminator(images)
        label = torch.full((batch_size,), real_label, device=device)
        loss_d_real = criterion(output, label)
        loss_d_real.backward()
        real_d_mean = output.mean().item()

        # Train the discriminator with fake images
        z = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_images = generator(z)
        output = discriminator(fake_images.detach())
        label.fill_(fake_label)
        loss_d_fake = criterion(output, label)
        loss_d_fake.backward()
        fake_d_mean = output.mean().item()

        # Update the discriminator weights
        loss_d = loss_d_real + loss_d_fake
        optimizer_d.step()

        # Train the generator
        generator.zero_grad()
        label.fill_(real_label)
        output = discriminator(fake_images)
        loss_g = criterion(output, label)
        loss_g.backward()

        # Update the generator weights
        optimizer_g.step()

        # Print the training progress
        if i % 100 == 0:
            print('[Epoch %d/%d][Batch %d/%d] Loss_D: %.4f (real: %.4f, fake: %.4f) Loss_G: %.4f'
                  % (epoch+1, num_epochs, i, len(dataloader), loss_d.item(), real_d_mean, fake_d_mean, loss_g.item()))

# Save the generator model
torch.save(generator.state_dict(), 'generator.pth')


FileNotFoundError: Couldn't find any class folder in PS2_sample_images.