In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms 
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image

# Constants
IMAGE_SIZE = 64
CHANNELS = 3
BATCH_SIZE = 512
Z_DIM = 128
LEARNING_RATE = 0.0002
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.999
EPOCHS = 200
CRITIC_STEPS = 3
GP_WEIGHT = 10.0
LOAD_MODEL = False

# Data Loading and Preprocessing
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith(".png")]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        return img

train_data = CustomDataset(root_dir="/kaggle/input/gan-dataset/waqar_pics", transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# Generator Model
class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.main = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, CHANNELS, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(-1, self.z_dim, 1, 1)
        return self.main(x)

# Critic Model
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(CHANNELS, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False)
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)

# Wasserstein GAN with Gradient Penalty (WGAN-GP)
class WGANGP(nn.Module):
    def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):
        super(WGANGP, self).__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.critic_steps = critic_steps
        self.gp_weight = gp_weight
        self.c_optimizer = optim.Adam(critic.parameters(), lr=LEARNING_RATE * 0.5, betas=(ADAM_BETA_1, ADAM_BETA_2))
        self.g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE * 0.5, betas=(ADAM_BETA_1, ADAM_BETA_2))

    def gradient_penalty(self, real_images, fake_images):
        batch_size = real_images.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1, device=real_images.device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
        pred = self.critic(interpolated)
        gradients = torch.autograd.grad(outputs=pred, inputs=interpolated,
                                        grad_outputs=torch.ones_like(pred),
                                        create_graph=True, retain_graph=True)[0]
        gradients = gradients.view(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def forward(self, real_images):
        batch_size = real_images.size(0)
        for _ in range(self.critic_steps):
            z = torch.randn(batch_size, self.latent_dim, 1, 1, device=real_images.device)
            fake_images = self.generator(z)
            c_loss = -(torch.mean(self.critic(real_images)) - torch.mean(self.critic(fake_images)))
            gp = self.gradient_penalty(real_images.data, fake_images.data)
            c_loss += gp * self.gp_weight

            self.c_optimizer.zero_grad()
            c_loss.backward()
            self.c_optimizer.step()

        z = torch.randn(batch_size, self.latent_dim, 1, 1, device=real_images.device)
        fake_images = self.generator(z)
        g_loss = -torch.mean(self.critic(fake_images))

        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()

        return c_loss.item(), g_loss.item()

# Initialize models and move to appropriate device
generator = Generator(Z_DIM).cuda()
critic = Critic().cuda()
wgangp = WGANGP(critic, generator, Z_DIM, CRITIC_STEPS, GP_WEIGHT)

# Training loop
for epoch in range(EPOCHS):
    for i, real_images in enumerate(train_loader):
        real_images = real_images.cuda()

        c_loss, g_loss = wgangp(real_images)

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{EPOCHS}] Step [{i}/{len(train_loader)}]: "
                  f"Critic Loss: {c_loss:.4f}, Generator Loss: {g_loss:.4f}")

# Save models
torch.save(generator.state_dict(), "./models/generator.pth")
torch.save(critic.state_dict(), "./models/critic.pth")


Epoch [0/200] Step [0/18]: Critic Loss: 9.1676, Generator Loss: 0.1974
Epoch [1/200] Step [0/18]: Critic Loss: -97.2872, Generator Loss: 50.6284
Epoch [2/200] Step [0/18]: Critic Loss: -74.9557, Generator Loss: 25.4559
Epoch [3/200] Step [0/18]: Critic Loss: -60.1714, Generator Loss: 9.3193
Epoch [4/200] Step [0/18]: Critic Loss: -45.3899, Generator Loss: 1.3533
Epoch [5/200] Step [0/18]: Critic Loss: -32.1492, Generator Loss: -2.7314
Epoch [6/200] Step [0/18]: Critic Loss: -23.7207, Generator Loss: -0.1039
Epoch [7/200] Step [0/18]: Critic Loss: -20.5641, Generator Loss: 0.0613
Epoch [8/200] Step [0/18]: Critic Loss: -13.5584, Generator Loss: 18.5304
Epoch [9/200] Step [0/18]: Critic Loss: -15.8327, Generator Loss: 3.0187
Epoch [10/200] Step [0/18]: Critic Loss: -12.0708, Generator Loss: 11.4940
Epoch [11/200] Step [0/18]: Critic Loss: -12.8722, Generator Loss: 2.7178
Epoch [12/200] Step [0/18]: Critic Loss: -13.9859, Generator Loss: 4.4557
Epoch [13/200] Step [0/18]: Critic Loss: -10

RuntimeError: Parent directory ./models does not exist.