In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:

# Locations of the training and test image folders
train_dir = "data/Food-3/train"
test_dir = "data/Food-3/test"

# Per-image preprocessing: resize to 64x64, then convert to a tensor.
# NOTE(review): per the torchvision docs ToTensor yields floats in [0, 1],
# while the Generator in this notebook ends in tanh ([-1, 1]) — confirm the
# value ranges are meant to differ.
data_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# One ImageFolder per split; both share the same image transform and
# leave the labels untransformed.
train_data = datasets.ImageFolder(train_dir, transform=data_transform)
test_data = datasets.ImageFolder(test_dir, transform=data_transform)

# Merge the two splits into a single dataset.
dataset = torch.utils.data.ConcatDataset((train_data, test_data))


In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    Args:
        latent_dim: Width of the input latent vector (default 128).
        image_shape: ``(channels, height, width)`` of the produced images
            (default ``(3, 64, 64)``).

    The last linear layer is sized from ``image_shape`` so that each latent
    vector yields exactly one image. (Previously the layer emitted 3072
    values while the output was reshaped to 3x64x64 = 12288 values, so a
    batch of N latent vectors collapsed into N/4 images.)
    """

    def __init__(self, latent_dim=128, image_shape=(3, 64, 64)):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_shape = tuple(image_shape)
        channels, height, width = self.image_shape

        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        # One output value per pixel per channel of the target image.
        self.fc4 = nn.Linear(1024, channels * height * width)

    def forward(self, x):
        """Turn latent vectors into images.

        Args:
            x: Tensor whose last dimension equals ``latent_dim``.

        Returns:
            Tensor of shape ``(-1, *image_shape)`` with values in [-1, 1]
            (the final activation is tanh).
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.tanh(self.fc4(x))
        return x.view(-1, *self.image_shape)

class Discriminator(nn.Module):
    """Small convolutional classifier producing one sigmoid score per image.

    Two stride-2 3x3 convolutions (3 -> 64 -> 128 channels) are followed by a
    single linear layer over 128 * 16 * 16 flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64,
                               kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128,
                               kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(in_features=128 * 16 * 16, out_features=1)

    def forward(self, x):
        """Return a score in (0, 1) for each row of the flattened features."""
        features = self.conv2(self.conv1(x).relu()).relu()
        # Flatten to rows of fc1.in_features (= 128 * 16 * 16) values.
        flat = features.view(-1, self.fc1.in_features)
        return self.fc1(flat).sigmoid()


# Instantiate the two networks (generator first, then discriminator)
generator = Generator()
discriminator = Discriminator()

# Binary cross-entropy loss, plus one Adam optimizer per network
# (both optimizers use the library's default hyperparameters)
loss_fn = nn.BCELoss()
gen_optimizer = torch.optim.Adam(params=generator.parameters())
dis_optimizer = torch.optim.Adam(params=discriminator.parameters())

# Shuffled mini-batches of 64 samples drawn from the combined dataset
dataloader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
)

In [None]:
# Train the GAN
NUM_EPOCHS = 100
LATENT_DIM = 128        # width of the noise vector fed to the generator
DIS_UPDATE_EVERY = 2    # the discriminator is updated on every 2nd batch

for epoch in range(NUM_EPOCHS):
    # Running sums so the epoch log reports mean losses rather than whatever
    # the last batch happened to produce (the discriminator value could even
    # stem from an earlier batch than the generator value).
    gen_loss_sum, gen_steps = 0.0, 0
    dis_loss_sum, dis_steps = 0.0, 0

    for i, (real_images, _) in enumerate(dataloader):
        # Generate fake images, one noise vector per real image in the batch
        noise = torch.randn(real_images.shape[0], LATENT_DIM)
        fake_images = generator(noise)

        # Train the discriminator: real images -> 1, fake images -> 0
        if i % DIS_UPDATE_EVERY == 0:
            dis_optimizer.zero_grad()
            real_output = discriminator(real_images)
            # detach() keeps this backward pass out of the generator
            fake_output = discriminator(fake_images.detach())
            real_loss = loss_fn(real_output, torch.ones_like(real_output))
            fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
            dis_loss = real_loss + fake_loss
            dis_loss.backward()
            dis_optimizer.step()
            dis_loss_sum += dis_loss.item()
            dis_steps += 1

        # Train the generator: it is rewarded when fakes are scored as real
        gen_optimizer.zero_grad()
        fake_output = discriminator(fake_images)
        gen_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        gen_loss.backward()
        gen_optimizer.step()
        gen_loss_sum += gen_loss.item()
        gen_steps += 1

    if gen_steps == 0:
        # Without this guard the log line below would fail on undefined losses.
        raise RuntimeError("dataloader yielded no batches; check the dataset paths")

    # dis_steps >= 1 whenever gen_steps >= 1, because batch 0 always updates
    # the discriminator.
    print(f"Epoch {epoch+1} of {NUM_EPOCHS}, Generator Loss: {gen_loss_sum / gen_steps:.4f}, Discriminator Loss: {dis_loss_sum / dis_steps:.4f}")


In [None]:
# generate new images
noise = torch.randn(64, 128)
with torch.no_grad():  # sampling only: no autograd graph is needed
    # Clip once for the whole batch; imshow expects float RGB values in [0, 1]
    gen_images = generator(noise).clip(0, 1)

# plot the images: the first 16 samples on a 4x4 grid
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(10, 10))
for ax, image in zip(axes.flatten(), gen_images):
    # (C, H, W) -> (H, W, C), the layout imshow expects
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax.axis("off")
plt.show()