In [2]:
#In this code block implement our kaggle.json api key to use the dataset
from google.colab import files
files.upload()


Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"cakmak61","key":"55d16274829171b5e05a5792176e1508"}'}

In [3]:
!pip install -q kaggle

# Here we move our api key to the correct folder
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Then Download Stanford Cars Dataset
!kaggle datasets download -d jutrera/stanford-car-dataset-by-classes-folder

# Unzip the dataset
!unzip -q stanford-car-dataset-by-classes-folder.zip -d data/


Dataset URL: https://www.kaggle.com/datasets/jutrera/stanford-car-dataset-by-classes-folder
License(s): other
Downloading stanford-car-dataset-by-classes-folder.zip to /content
100% 1.82G/1.83G [00:08<00:00, 320MB/s]
100% 1.83G/1.83G [00:08<00:00, 244MB/s]


In [None]:
# Here we import the related libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

# WE set device settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Our Hyperparameters
batch_size = 64
image_size = 64
channels = 3
latent_dim = 128
n_epochs = 200
lambda_gp = 10
n_critic = 5
lr = 1e-4
beta1 = 0.0
beta2 = 0.9

# Data transformations
transform = transforms.Compose([
    transforms.Resize(64),  # veya 128
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

# Uploading the dataset
data_root = './data/stanford_cars'
train_dataset = datasets.ImageFolder(root='data/car_data/car_data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

# Generator Definition
class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels):
        super(Generator, self).__init__()
        self.init_size = image_size // 16
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 512 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 512, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

# Definition of Discriminator (Critic)
class Discriminator(nn.Module):
    def __init__(self, img_channels):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters))  # Instead of layer norm, we have utilized BatchNorm
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_channels, 64, bn=False),  # 64x64 -> 32x32
            *discriminator_block(64, 128),                      # 32x32 -> 16x16
            *discriminator_block(128, 256),                     # 16x16 -> 8x8
            *discriminator_block(256, 512),                     # 8x8 -> 4x4
        )

        ds_size = image_size // 2 ** 4  # 64 -> 4
        self.adv_layer = nn.Linear(512 * ds_size ** 2, 1)

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity


# Calculating Gradient penalty
def compute_gradient_penalty(D, real_samples, fake_samples):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(real_samples.size(0), 1, device=device, requires_grad=False)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Model and Optimization
generator = Generator(latent_dim, channels).to(device)
discriminator = Discriminator(channels).to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Training Loop
os.makedirs("wgan_gp_outputs", exist_ok=True)

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(tqdm(train_loader)):
        real_imgs = imgs.to(device)

        #  Training Discriminator
        for _ in range(n_critic):
            z = torch.randn(imgs.size(0), latent_dim, device=device)
            fake_imgs = generator(z).detach()
            real_validity = discriminator(real_imgs)
            fake_validity = discriminator(fake_imgs)
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        # Training Generator
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = -torch.mean(discriminator(gen_imgs))

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    print(f"[Epoch {epoch+1}/{n_epochs}] D loss: {d_loss.item():.4f} | G loss: {g_loss.item():.4f}")

    # Recording the sample images
    if (epoch + 1) % 10 == 0 or epoch == 0:
        generator.eval()
        with torch.no_grad():
            sample_z = torch.randn(4, latent_dim, device=device)
            samples = generator(sample_z)
            grid = make_grid(samples.cpu(), nrow=2, normalize=True)
            save_image(grid, f"wgan_gp_outputs/sample_epoch_{epoch+1:03d}.png")
        generator.train()

print("Training is completed.")


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 128/128 [00:27<00:00,  4.59it/s]


[Epoch 1/200] D loss: -16.1672 | G loss: 26.8570


100%|██████████| 128/128 [00:26<00:00,  4.77it/s]


[Epoch 2/200] D loss: -32.3738 | G loss: 30.1733


100%|██████████| 128/128 [00:26<00:00,  4.75it/s]


[Epoch 3/200] D loss: -94.3277 | G loss: 43.4314


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 4/200] D loss: -19.4803 | G loss: 37.3274


100%|██████████| 128/128 [00:27<00:00,  4.58it/s]


[Epoch 5/200] D loss: -78.0898 | G loss: 24.8689


100%|██████████| 128/128 [00:26<00:00,  4.80it/s]


[Epoch 6/200] D loss: -36.9418 | G loss: 68.2932


100%|██████████| 128/128 [00:26<00:00,  4.77it/s]


[Epoch 7/200] D loss: -49.2677 | G loss: 83.5457


100%|██████████| 128/128 [00:27<00:00,  4.73it/s]


[Epoch 8/200] D loss: 59.3851 | G loss: -51.7170


100%|██████████| 128/128 [00:26<00:00,  4.78it/s]


[Epoch 9/200] D loss: -387.9231 | G loss: 293.1530


100%|██████████| 128/128 [00:26<00:00,  4.78it/s]


[Epoch 10/200] D loss: -478.5959 | G loss: 13.4427


100%|██████████| 128/128 [00:26<00:00,  4.81it/s]


[Epoch 11/200] D loss: -335.9915 | G loss: 17.4013


100%|██████████| 128/128 [00:27<00:00,  4.63it/s]


[Epoch 12/200] D loss: -140.2918 | G loss: -42.3727


100%|██████████| 128/128 [00:27<00:00,  4.58it/s]


[Epoch 13/200] D loss: -1038.6927 | G loss: 12.1145


100%|██████████| 128/128 [00:26<00:00,  4.81it/s]


[Epoch 14/200] D loss: -186.1796 | G loss: 198.8419


100%|██████████| 128/128 [00:26<00:00,  4.75it/s]


[Epoch 15/200] D loss: -1747.4446 | G loss: 417.6819


100%|██████████| 128/128 [00:26<00:00,  4.84it/s]


[Epoch 16/200] D loss: -1411.3182 | G loss: 758.9375


100%|██████████| 128/128 [00:26<00:00,  4.79it/s]


[Epoch 17/200] D loss: -2546.7358 | G loss: 1071.6855


100%|██████████| 128/128 [00:26<00:00,  4.79it/s]


[Epoch 18/200] D loss: 160.4355 | G loss: 1184.8698


100%|██████████| 128/128 [00:26<00:00,  4.89it/s]


[Epoch 19/200] D loss: -2401.4375 | G loss: 1662.4565


100%|██████████| 128/128 [00:26<00:00,  4.84it/s]


[Epoch 20/200] D loss: -4355.4697 | G loss: 1692.8518


100%|██████████| 128/128 [00:26<00:00,  4.80it/s]


[Epoch 21/200] D loss: -5771.8184 | G loss: 2234.8459


100%|██████████| 128/128 [00:26<00:00,  4.81it/s]


[Epoch 22/200] D loss: -4442.5308 | G loss: 2987.8047


100%|██████████| 128/128 [00:26<00:00,  4.88it/s]


[Epoch 23/200] D loss: -249.8095 | G loss: 3117.7871


100%|██████████| 128/128 [00:26<00:00,  4.74it/s]


[Epoch 24/200] D loss: -7343.3687 | G loss: 3453.4233


100%|██████████| 128/128 [00:26<00:00,  4.75it/s]


[Epoch 25/200] D loss: -10014.4795 | G loss: 4930.6362


100%|██████████| 128/128 [00:26<00:00,  4.81it/s]


[Epoch 26/200] D loss: -6420.6240 | G loss: 4796.1885


100%|██████████| 128/128 [00:26<00:00,  4.75it/s]


[Epoch 27/200] D loss: -3123.8765 | G loss: 4674.9580


100%|██████████| 128/128 [00:26<00:00,  4.82it/s]


[Epoch 28/200] D loss: -9653.2090 | G loss: 5431.8389


100%|██████████| 128/128 [00:26<00:00,  4.79it/s]


[Epoch 29/200] D loss: 730.7344 | G loss: 5718.0288


100%|██████████| 128/128 [00:26<00:00,  4.87it/s]


[Epoch 30/200] D loss: -12749.6611 | G loss: 6705.9805


100%|██████████| 128/128 [00:26<00:00,  4.83it/s]


[Epoch 31/200] D loss: -13733.9541 | G loss: 6996.3906


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 32/200] D loss: -16267.8340 | G loss: 7523.6753


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 33/200] D loss: -15956.3193 | G loss: 7965.6992


100%|██████████| 128/128 [00:26<00:00,  4.77it/s]


[Epoch 34/200] D loss: -18343.8281 | G loss: 8619.1426


100%|██████████| 128/128 [00:26<00:00,  4.81it/s]


[Epoch 35/200] D loss: -21510.3008 | G loss: 10065.7227


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 36/200] D loss: -37.3766 | G loss: 10550.2441


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 37/200] D loss: -21850.9902 | G loss: 6363.9604


100%|██████████| 128/128 [00:27<00:00,  4.73it/s]


[Epoch 38/200] D loss: 1813.4484 | G loss: 11436.2734


100%|██████████| 128/128 [00:26<00:00,  4.74it/s]


[Epoch 39/200] D loss: -9821.9053 | G loss: 8487.6367


100%|██████████| 128/128 [00:25<00:00,  4.92it/s]


[Epoch 40/200] D loss: -26569.4551 | G loss: 10242.6719


100%|██████████| 128/128 [00:25<00:00,  4.94it/s]


[Epoch 41/200] D loss: -4943.7295 | G loss: 2190.6440


100%|██████████| 128/128 [00:26<00:00,  4.83it/s]


[Epoch 42/200] D loss: 2054.7026 | G loss: 10880.3486


100%|██████████| 128/128 [00:26<00:00,  4.81it/s]


[Epoch 43/200] D loss: -14693.9912 | G loss: 12947.6670


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 44/200] D loss: -25422.9160 | G loss: 14282.6758


100%|██████████| 128/128 [00:26<00:00,  4.82it/s]


[Epoch 45/200] D loss: -18729.1602 | G loss: 15598.6621


100%|██████████| 128/128 [00:26<00:00,  4.75it/s]


[Epoch 46/200] D loss: 12233.1836 | G loss: 15571.3418


100%|██████████| 128/128 [00:27<00:00,  4.67it/s]


[Epoch 47/200] D loss: -15774.7754 | G loss: 13757.8125


100%|██████████| 128/128 [00:26<00:00,  4.83it/s]


[Epoch 48/200] D loss: -24013.6484 | G loss: 18066.6250


100%|██████████| 128/128 [00:25<00:00,  4.94it/s]


[Epoch 49/200] D loss: -39867.1289 | G loss: 18782.6406


100%|██████████| 128/128 [00:26<00:00,  4.79it/s]


[Epoch 50/200] D loss: 307.9802 | G loss: 19362.1523


100%|██████████| 128/128 [00:26<00:00,  4.80it/s]


[Epoch 51/200] D loss: 1203.2461 | G loss: 19455.0820


100%|██████████| 128/128 [00:26<00:00,  4.78it/s]


[Epoch 52/200] D loss: -11698.6250 | G loss: -14885.4434


100%|██████████| 128/128 [00:26<00:00,  4.82it/s]


[Epoch 53/200] D loss: -42380.7734 | G loss: 20805.7832


100%|██████████| 128/128 [00:26<00:00,  4.85it/s]


[Epoch 54/200] D loss: 27437.8008 | G loss: 20344.6641


100%|██████████| 128/128 [00:26<00:00,  4.80it/s]


[Epoch 55/200] D loss: 22953.5781 | G loss: 20675.4375


100%|██████████| 128/128 [00:26<00:00,  4.78it/s]


[Epoch 56/200] D loss: -42361.1289 | G loss: 22263.3398


100%|██████████| 128/128 [00:26<00:00,  4.82it/s]


[Epoch 57/200] D loss: 762.5079 | G loss: 25008.6680


100%|██████████| 128/128 [00:26<00:00,  4.77it/s]


[Epoch 58/200] D loss: 315.5729 | G loss: 25423.1406


100%|██████████| 128/128 [00:26<00:00,  4.84it/s]


[Epoch 59/200] D loss: -56373.1758 | G loss: 26100.9023


100%|██████████| 128/128 [00:26<00:00,  4.83it/s]


[Epoch 60/200] D loss: -22525.7461 | G loss: 2978.9893


100%|██████████| 128/128 [00:26<00:00,  4.89it/s]


[Epoch 61/200] D loss: -8752.5244 | G loss: 26089.3789


100%|██████████| 128/128 [00:26<00:00,  4.87it/s]


[Epoch 62/200] D loss: -58981.7539 | G loss: 28937.8633


100%|██████████| 128/128 [00:26<00:00,  4.87it/s]


[Epoch 63/200] D loss: -51605.2500 | G loss: 27256.7227


100%|██████████| 128/128 [00:26<00:00,  4.77it/s]


[Epoch 64/200] D loss: 1142.8862 | G loss: 29956.9023


100%|██████████| 128/128 [00:26<00:00,  4.86it/s]


[Epoch 65/200] D loss: -60108.1914 | G loss: 25817.0039


100%|██████████| 128/128 [00:26<00:00,  4.77it/s]


[Epoch 66/200] D loss: 683.4724 | G loss: 31661.3301


100%|██████████| 128/128 [00:27<00:00,  4.66it/s]


[Epoch 67/200] D loss: -5376.6362 | G loss: 30309.9688


100%|██████████| 128/128 [00:26<00:00,  4.88it/s]


[Epoch 68/200] D loss: -63989.5312 | G loss: 32968.0078


100%|██████████| 128/128 [00:26<00:00,  4.78it/s]


[Epoch 69/200] D loss: -54865.7383 | G loss: 29313.2539


100%|██████████| 128/128 [00:26<00:00,  4.83it/s]


[Epoch 70/200] D loss: 18415.8887 | G loss: 32201.6758


100%|██████████| 128/128 [00:26<00:00,  4.76it/s]


[Epoch 71/200] D loss: -73836.0312 | G loss: 35492.8242


100%|██████████| 128/128 [00:26<00:00,  4.80it/s]


[Epoch 72/200] D loss: -56797.9453 | G loss: 36010.2656


100%|██████████| 128/128 [00:27<00:00,  4.64it/s]


[Epoch 73/200] D loss: 13609.1035 | G loss: 34217.0273


100%|██████████| 128/128 [00:27<00:00,  4.69it/s]


[Epoch 74/200] D loss: 31634.8457 | G loss: 32979.0938


 79%|███████▉  | 101/128 [00:21<00:04,  5.64it/s]