In [None]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import opendatasets as od
from tqdm import tqdm

In [None]:
# Download the Kaggle dog-images dataset; presumably it lands in ./dog-images,
# the folder get_image_urls() reads from by default.
DATASET_URL = 'https://www.kaggle.com/datasets/yaswanthgali/dog-images/'
od.download(DATASET_URL)

In [None]:
def get_image_urls(root_folder="dog-images/images/images"):
    """Collect image files below ``root_folder`` as slash-prefixed, '/'-separated paths.

    Each entry is ``"/" + os.path.join(<dir>, <file>)`` with OS separators
    replaced by ``'/'``. ``GANDataset`` removes the leading slash again before
    opening the file.

    Directories and files are visited in sorted order, so the returned list is
    the same on every run and filesystem (plain ``os.walk`` order is
    filesystem-dependent, which makes training runs harder to reproduce).

    Args:
        root_folder: Directory searched recursively. A missing directory
            yields an empty list.

    Returns:
        list[str]: Paths of files whose lower-cased name ends in a known
        image extension.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')
    image_urls = []

    for folder_path, dir_names, file_names in os.walk(root_folder):
        # Sorting dir_names in place makes os.walk descend deterministically.
        dir_names.sort()
        for file_name in sorted(file_names):
            if file_name.lower().endswith(image_extensions):
                path = os.path.join(folder_path, file_name)
                image_urls.append(f"/{path.replace(os.sep, '/')}")

    return image_urls

In [None]:
def pad_to_square(img, fill="black"):
    """Pad a PIL image with a solid border so it becomes square, content centred.

    The shorter side is padded on both ends; with an odd difference the extra
    pixel goes to the bottom/right. The image content is never rescaled or
    cropped.

    Args:
        img: PIL image. Palette ('P') images are not specially handled, so
            convert them to RGB first (GANDataset does this before transforming).
        fill: Border colour passed to ``Image.new``. The default ``"black"`` is
            resolved by Pillow for the image's mode. It gives ``(0, 0, 0)`` for
            RGB, like the previous hard-coded tuple, and it also works for
            single-band modes such as 'L', where a 3-tuple is rejected.

    Returns:
        ``img`` itself when it is already square, otherwise a new
        ``side x side`` image with ``side = max(width, height)``.
    """
    width, height = img.size
    if width == height:
        return img

    side = max(width, height)
    result = Image.new(img.mode, (side, side), fill)
    # One of these offsets is always 0: only the short axis gets padding.
    result.paste(img, ((side - width) // 2, (side - height) // 2))
    return result

class GANDataset(Dataset):
    """Map-style dataset that loads image files as RGB tensors for GAN training.

    Args:
        image_urls: Sequence of image paths. Each may carry a single leading
            ``'/'``, which is the format produced by ``get_image_urls``.
        transform: Callable applied to the RGB PIL image. When ``None``, the
            image is padded to square, resized to 256x256, converted to a tensor
            and normalised to ``[-1, 1]``. That matches the range of the
            generator's ``tanh`` output, so real and generated images share one
            value range.
    """

    def __init__(self, image_urls, transform=None):
        self.image_urls = image_urls
        if transform is None:
            transform = transforms.Compose([
                transforms.Lambda(pad_to_square),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),  # float tensor in [0, 1]
                # [0, 1] -> [-1, 1]: (x - 0.5) / 0.5
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.image_urls)

    @staticmethod
    def _url_to_path(url):
        """Remove exactly one leading '/' (the prefix added by get_image_urls).

        ``lstrip('/')`` would also strip the root of an absolute path
        (``'//data/x.jpg'`` -> ``'data/x.jpg'``).
        """
        return url[1:] if url.startswith('/') else url

    def __getitem__(self, idx):
        img_path = self._url_to_path(self.image_urls[idx])
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        return self.transform(img)
    


In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to a 3-channel image.

    ``z`` is expected as ``(N, latent_dim, 1, 1)``, which is how GAN samples it.
    A stride-1 transposed convolution expands it to 4x4. Six stride-2
    transposed convolutions then double the resolution each time (4 -> 256)
    while the channels shrink 512 -> 16 -> 3. ``tanh`` bounds the output to
    ``[-1, 1]``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = [512, 256, 128, 64, 32, 16]

        # Project the 1x1 latent to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(latent_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Upsampling stages: each one doubles H and W.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsample to RGB.
        layers += [
            nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)
    

class Discriminator(nn.Module):
    """Convolutional discriminator producing one raw logit per input image.

    Four stages follow the pattern Conv(stride 2) -> BatchNorm ->
    LeakyReLU(0.2) -> MaxPool. A final 1-channel convolution follows them. For
    3x256x256 inputs the spatial size goes
    128 -> 65 -> 32 -> 17 -> 8 -> 5 -> 2 -> 2 -> 1, so ``forward`` returns
    shape ``(N,)``.
    """

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256), (256, 512)):
            stages.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.MaxPool2d(2, 2, 1),
            ])
        stages.append(nn.Conv2d(512, 1, 4, 1, 1, bias=False))
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        logits = self.model(img)
        # Flatten to 1-D; equivalent to view(-1, 1).squeeze(1).
        return logits.view(-1)
    



In [None]:
class GAN(pl.LightningModule):
    """Lightning module that trains ``Generator`` and ``Discriminator`` adversarially.

    Uses manual optimisation: each training step updates the generator, then
    the discriminator.

    Args:
        latent_dim: Size of the noise vector fed to the generator.
        lr: Adam learning rate shared by both optimisers.
        b1: Adam beta1 for both optimisers.
        b2: Adam beta2 for both optimisers.
    """

    def __init__(self, latent_dim=100, lr=0.0002, b1=0.5, b2=0.999):
        super().__init__()
        self.save_hyperparameters()
        # Two optimisers are stepped explicitly in training_step.
        self.automatic_optimization = False

        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator()

        # Fixed noise reused every epoch so the sample grids are comparable.
        self.validation_z = torch.randn(4, self.hparams.latent_dim, 1, 1)

    def forward(self, z):
        """Generate images from noise ``z`` of shape (N, latent_dim, 1, 1)."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy on raw discriminator logits.

        ``y`` is reshaped to ``y_hat``'s shape instead of ``squeeze()``d. With a
        batch of one, ``squeeze()`` produced a 0-d target that no longer matched
        the ``(1,)`` logits, and the loss raised a size-mismatch error.
        """
        return F.binary_cross_entropy_with_logits(y_hat, y.reshape(y_hat.shape))

    def training_step(self, batch, batch_idx):
        imgs = batch
        batch_size = imgs.size(0)

        optimizer_g, optimizer_d = self.optimizers()

        # Sample noise and labels on the same device/dtype as the batch.
        z = torch.randn(batch_size, self.hparams.latent_dim, 1, 1).type_as(imgs)
        real_labels = torch.ones(batch_size, 1).type_as(imgs)
        fake_labels = torch.zeros(batch_size, 1).type_as(imgs)

        # --- generator step: push D to label generated images as real ---
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)
        # Reuse the images generated above rather than running the generator again.
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), real_labels)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # --- discriminator step: real images -> 1, generated images -> 0 ---
        self.toggle_optimizer(optimizer_d)
        real_loss = self.adversarial_loss(self.discriminator(imgs), real_labels)
        # Regenerate with the just-updated generator; detach so no gradient reaches it.
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake_labels)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d]

    def on_train_epoch_end(self):
        """Show a 2x2 grid of images generated from the fixed ``validation_z``."""
        # With several devices, only one process should draw.
        if not self.trainer.is_global_zero:
            return

        z = self.validation_z.type_as(self.generator.model[0].weight)

        # Sampling needs no autograd graph. The generator stays in training
        # mode, so BatchNorm still normalises with these 4 samples' statistics.
        with torch.no_grad():
            sample_imgs = self(z)

        # Tensor -> float32 numpy (guards against half-precision weights),
        # then rescale from [-1, 1] to [0, 1].
        sample_imgs = sample_imgs.float().cpu().numpy()
        sample_imgs = ((sample_imgs + 1) / 2.0).clip(0.0, 1.0)

        fig, axs = plt.subplots(2, 2, figsize=(8, 8))
        for idx, ax in enumerate(axs.flat):  # row-major, same as i * 2 + j
            ax.imshow(sample_imgs[idx].transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
            ax.axis('off')

        fig.suptitle(f'Generated Images at Epoch {self.current_epoch}')
        fig.tight_layout()
        plt.show()
        # Release the figure so one figure per epoch does not accumulate.
        plt.close(fig)

In [None]:
SEED = 42
BATCH_SIZE = 512
MAX_EPOCHS = 200

# Seed Python/NumPy/torch (and DataLoader workers) so runs are reproducible.
pl.seed_everything(SEED, workers=True)

image_urls = get_image_urls()
assert image_urls, "No images found - check that the dataset download above succeeded."
dataset = GANDataset(image_urls)
# Drop a ragged final batch (when there is more than one batch), since tiny
# batches give unrepresentative BatchNorm statistics.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=len(dataset) > BATCH_SIZE,
)

model = GAN()
# Lightning 2.x treats "16" as a legacy alias of "16-mixed". fp16 AMP is not
# supported on CPU, so fall back to full precision when no CUDA GPU is present.
precision = "16-mixed" if torch.cuda.is_available() else "32-true"
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="auto",
    devices="auto",
    strategy="auto",
    precision=precision,
)
trainer.fit(model, dataloader)

torch.save(model.state_dict(), 'gan_model.pth')