In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Generator
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 32x32 image.

    Four transposed convolutions upsample the spatial size 1 -> 4 -> 8 -> 16 -> 32.
    The final ``Tanh`` keeps pixel values in ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the latent input ``z``.
        img_channels: Channels of the generated image (default 3, RGB).
        feature_maps: Width of the last hidden layer; the earlier hidden layers
            use 4x and 2x this value (the defaults reproduce 512/256/128).
    """

    def __init__(self, latent_dim, img_channels=3, feature_maps=128):
        super(Generator, self).__init__()

        fm = feature_maps
        # Layer order is kept fixed so parameter init order and state_dict keys
        # ("model.0.weight", ...) match regardless of the chosen widths.
        self.model = nn.Sequential(
            # (latent_dim, 1, 1) -> (4*fm, 4, 4)
            nn.ConvTranspose2d(latent_dim, fm * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(fm * 4),
            nn.ReLU(True),
            # -> (2*fm, 8, 8)
            nn.ConvTranspose2d(fm * 4, fm * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 2),
            nn.ReLU(True),
            # -> (fm, 16, 16)
            nn.ConvTranspose2d(fm * 2, fm, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm),
            nn.ReLU(True),
            # -> (img_channels, 32, 32), values in [-1, 1]
            nn.ConvTranspose2d(fm, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        """Map latent codes ``z`` (assumed shape ``(batch, latent_dim, 1, 1)``) to images."""
        return self.model(z)

# Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real (1) or fake (0).

    Three stride-2 convolutions downsample the spatial size 32 -> 16 -> 8 -> 4,
    and a final 4x4 convolution collapses that to a single sigmoid probability
    per image, so the output has shape ``(batch, 1, 1, 1)``. Inputs are assumed
    to be exactly 32x32; other sizes yield a differently sized output map or fail.

    Args:
        img_channels: Channels of the input image (default 3, RGB).
        feature_maps: Width of the first conv layer; later layers use 2x and
            4x this value (the defaults reproduce 128/256/512).
    """

    def __init__(self, img_channels=3, feature_maps=128):
        super(Discriminator, self).__init__()

        fm = feature_maps
        # Layer order is kept fixed so parameter init order and state_dict keys match.
        self.model = nn.Sequential(
            # (img_channels, 32, 32) -> (fm, 16, 16); no BatchNorm on the input layer
            nn.Conv2d(img_channels, fm, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (2*fm, 8, 8)
            nn.Conv2d(fm, fm * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (4*fm, 4, 4)
            nn.Conv2d(fm * 2, fm * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (1, 1, 1): one probability per image
            nn.Conv2d(fm * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return per-image real/fake probabilities of shape ``(batch, 1, 1, 1)``."""
        return self.model(img)

# Load the dataset
# Fix the global torch RNG so weight init, shuffling and sampled noise are
# reproducible across Restart & Run All (also seeds CUDA when available).
SEED = 0
torch.manual_seed(SEED)

transform = transforms.Compose([
    # Both networks require exactly 32x32 images. Resize(int) only fixes the
    # shorter side, so CenterCrop guarantees a square; CIFAR-10 images are
    # natively 32x32, so both are effectively no-ops here.
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    # Map [0, 1] -> [-1, 1] to match the generator's Tanh output range.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the generator and discriminator
latent_dim = 100

generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Set up the optimizers and loss function
# lr=2e-4 and beta1=0.5 are the usual DCGAN Adam settings.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Binary cross-entropy on the discriminator's sigmoid probabilities.
loss_function = nn.BCELoss()

from time import time 

# Training function
def train_gan(epochs, log_interval=100):
    """Adversarially train the module-level ``generator`` and ``discriminator``.

    Relies on ``generator``, ``discriminator``, ``optimizer_G``, ``optimizer_D``,
    ``loss_function``, ``dataloader``, ``latent_dim`` and ``device`` defined in
    the cells above. Each batch does one discriminator step (real images
    labelled 1, generated images labelled 0) followed by one generator step
    (generated images labelled 1, i.e. the non-saturating generator loss).

    Args:
        epochs: Number of full passes over ``dataloader``.
        log_interval: Print losses every this many batches; ``0``/``None``
            disables logging. Defaults to 100.
    """
    # BatchNorm must use batch statistics while training; an earlier eval()
    # call (e.g. when sampling images) would otherwise silently persist.
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # The last batch of an epoch may be smaller than the loader's batch_size.
            batch_size = imgs.size(0)

            # BCE targets, allocated directly on the training device.
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---- Discriminator step ----
            optimizer_D.zero_grad()

            real_imgs = imgs.to(device)
            real_validity = discriminator(real_imgs).view(-1, 1)
            real_loss = loss_function(real_validity, real_labels)

            z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            fake_imgs = generator(z)
            # detach() stops the D loss from back-propagating into the generator.
            fake_validity = discriminator(fake_imgs.detach()).view(-1, 1)
            fake_loss = loss_function(fake_validity, fake_labels)

            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # ---- Generator step ----
            optimizer_G.zero_grad()

            # Re-score the same fakes with the just-updated D; G wants them
            # classified as real. Gradients that also land in D here are
            # cleared by optimizer_D.zero_grad() on the next batch.
            validity = discriminator(fake_imgs).view(-1, 1)
            g_loss = loss_function(validity, real_labels)
            g_loss.backward()
            optimizer_G.step()

            if log_interval and i % log_interval == 0:
                print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# Train the GAN
# perf_counter() is monotonic and high-resolution; time() follows the
# adjustable system clock and can jump during a long training run.
from time import perf_counter

EPOCHS = 50

start = perf_counter()
train_gan(epochs=EPOCHS)
end = perf_counter()
print(f"Time taken to train model : {end - start} seconds")




Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29518385.91it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
[Epoch 0/50] [Batch 0/782] [D loss: 1.4623669385910034] [G loss: 2.99706768989563]
[Epoch 0/50] [Batch 100/782] [D loss: 0.04832854121923447] [G loss: 10.274909019470215]
[Epoch 0/50] [Batch 200/782] [D loss: 0.270407497882843] [G loss: 5.213973045349121]
[Epoch 0/50] [Batch 300/782] [D loss: 0.6626717448234558] [G loss: 5.145402908325195]
[Epoch 0/50] [Batch 400/782] [D loss: 0.47421324253082275] [G loss: 2.8240904808044434]
[Epoch 0/50] [Batch 500/782] [D loss: 0.9983822703361511] [G loss: 2.337174654006958]
[Epoch 0/50] [Batch 600/782] [D loss: 0.2577119469642639] [G loss: 3.5086512565612793]
[Epoch 0/50] [Batch 700/782] [D loss: 2.638639450073242] [G loss: 3.0727434158325195]
[Epoch 1/50] [Batch 0/782] [D loss: 1.1577092409133911] [G loss: 0.919029951095581]
[Epoch 1/50] [Batch 100/782] [D loss: 0.5643131136894226] [G loss: 4.040186405181885]
[Epoch 1/50] [Batch 200/782] [D loss: 0.2333410382270813] [G loss: 3.3695459365844727]
[Ep

In [5]:
import torchvision.utils as vutils
import numpy as np

In [10]:
# Generate a single image from the trained generator
def generate_image(generator, latent_dim, device):
    """Sample one image from ``generator`` and return it as an HWC NumPy array.

    The generator is switched to eval mode for sampling (so BatchNorm uses its
    running statistics) and its previous train/eval mode is restored afterwards,
    so calling this between training runs does not affect later training.

    Args:
        generator: Module mapping ``(1, latent_dim, 1, 1)`` noise to a
            ``(1, C, H, W)`` image with values in ``[-1, 1]`` (Tanh output).
        latent_dim: Number of latent channels expected by ``generator``.
        device: Device on which to sample the noise and run the generator.

    Returns:
        ``np.ndarray`` of shape ``(H, W, C)`` with values rescaled to ``[0, 1]``.
    """
    was_training = generator.training
    generator.eval()
    try:
        # Sample a random point in the latent space directly on the target device.
        z = torch.randn(1, latent_dim, 1, 1, device=device)
        with torch.no_grad():
            generated_img = generator(z)
    finally:
        # Restore whatever mode the caller had the generator in.
        generator.train(was_training)

    # Undo the [-1, 1] normalization and move channels last (NCHW -> NHWC).
    generated_img = (generated_img.cpu().numpy() + 1) / 2.0
    generated_img = np.transpose(generated_img, (0, 2, 3, 1))

    return generated_img[0]

# Draw one sample and write it to disk. generate_image hands back an HWC array
# in [0, 1], while save_image expects a CHW tensor, so the axes are permuted back.
generated_img = generate_image(generator, latent_dim, device)
img_chw = torch.from_numpy(generated_img).permute(2, 0, 1)
vutils.save_image(img_chw, "generated_image.png")