In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Hyperparameters and paths

In [None]:
# Image settings
# NOTE: the Generator/Discriminator layer stacks are hard-wired for 64x64 images;
# changing IMAGE_SIZE alone is not supported (the tensor shapes would no longer match).
IMAGE_SIZE = 64
IMAGE_CHANNELS = 3  # RGB images

# Training settings
BATCH_SIZE = 64  # Number of images per batch
NUM_EPOCHS = 50  # Number of full passes over the dataset
LEARNING_RATE = 0.0002  # Adam learning rate for both networks

# Network settings
LATENT_DIM = 100  # Size of the random noise vector fed to the Generator
FEATURE_DIM = 64  # Base number of convolution filters (layers use multiples of this)

# Reproducibility: seed torch's global RNG (weight init, noise, DataLoader shuffling)
# and numpy. GPU kernels may still be nondeterministic unless configured otherwise.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Directories
DATA_DIR = "data/wikiart"
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Preprocessing pipeline applied to every image:
#   Resize(int)  -> scales the *shorter* side to IMAGE_SIZE, keeping the aspect ratio
#   CenterCrop   -> cuts the central IMAGE_SIZE x IMAGE_SIZE square
#   ToTensor     -> PIL image to a float (C, H, W) tensor in [0, 1]
#   Normalize    -> (x - 0.5) / 0.5 per channel, i.e. [-1, 1], matching the Generator's Tanh
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# WikiArt laid out as DATA_DIR/<class_name>/<image>; the class labels are ignored by the GAN.
dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

print(f"Total images in dataset: {len(dataset)}")
print(f"Number of classes: {len(dataset.classes)}")
print(f"Number of batches per epoch: {len(dataloader)}")

In [None]:
class Generator(nn.Module):
    """
    Generator: maps random noise (a latent vector) to a fake image.

    Architecture (DCGAN-style):
    - Input: noise tensor of shape (batch, latent_dim, 1, 1)
    - Output: image tensor of shape (batch, channels, 64, 64), values in [-1, 1]
    - Five ConvTranspose2d layers upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64

    Args:
        latent_dim: Size of the noise vector. Defaults to LATENT_DIM.
        feature_dim: Base number of filters; hidden layers use 8x, 4x, 2x, 1x this.
            Defaults to FEATURE_DIM.
        channels: Number of output image channels. Defaults to IMAGE_CHANNELS.
    """

    def __init__(self, latent_dim=None, feature_dim=None, channels=None):
        super().__init__()

        # Fall back to the notebook-level config so Generator() behaves as before.
        self.latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        fd = FEATURE_DIM if feature_dim is None else feature_dim
        out_channels = IMAGE_CHANNELS if channels is None else channels

        self.network = nn.Sequential(
            # (latent_dim, 1, 1) -> (fd*8, 4, 4): 4x4 kernel, stride 1, no padding
            nn.ConvTranspose2d(self.latent_dim, fd * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(fd * 8),
            nn.ReLU(True),

            # -> (fd*4, 8, 8): each 4x4 / stride 2 / padding 1 layer doubles H and W
            nn.ConvTranspose2d(fd * 8, fd * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fd * 4),
            nn.ReLU(True),

            # -> (fd*2, 16, 16)
            nn.ConvTranspose2d(fd * 4, fd * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fd * 2),
            nn.ReLU(True),

            # -> (fd, 32, 32)
            nn.ConvTranspose2d(fd * 2, fd, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fd),
            nn.ReLU(True),

            # -> (channels, 64, 64); no BatchNorm on the output layer (DCGAN convention)
            nn.ConvTranspose2d(fd, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),  # output range [-1, 1], matching the normalized training images
        )

    def forward(self, x):
        """Map a (batch, latent_dim, 1, 1) noise tensor to (batch, channels, 64, 64) images."""
        return self.network(x)

In [None]:
class Discriminator(nn.Module):
    """
    Discriminator: scores images as real (from the dataset) or fake (from the generator).

    Architecture (DCGAN-style):
    - Input: image tensor of shape (batch, channels, 64, 64)
    - Output: tensor of shape (batch,) holding one probability in [0, 1] per image
    - Four stride-2 Conv2d layers downsample 64 -> 32 -> 16 -> 8 -> 4, then a final
      4x4 convolution collapses the 4x4 map to a single score

    Args:
        feature_dim: Base number of filters; layers use 1x, 2x, 4x, 8x this.
            Defaults to FEATURE_DIM.
        channels: Number of input image channels. Defaults to IMAGE_CHANNELS.
    """

    def __init__(self, feature_dim=None, channels=None):
        super().__init__()

        # Fall back to the notebook-level config so Discriminator() behaves as before.
        fd = FEATURE_DIM if feature_dim is None else feature_dim
        in_channels = IMAGE_CHANNELS if channels is None else channels

        self.network = nn.Sequential(
            # (channels, 64, 64) -> (fd, 32, 32); no BatchNorm on the input layer (DCGAN convention)
            nn.Conv2d(in_channels, fd, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (fd*2, 16, 16): each 4x4 / stride 2 / padding 1 layer halves H and W
            nn.Conv2d(fd, fd * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fd * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (fd*4, 8, 8)
            nn.Conv2d(fd * 2, fd * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fd * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (fd*8, 4, 4)
            nn.Conv2d(fd * 4, fd * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fd * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 kernel, stride 1, no padding: (fd*8, 4, 4) -> (1, 1, 1)
            nn.Conv2d(fd * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),  # probability in [0, 1], consumed by BCELoss
        )

    def forward(self, x):
        """Return a flat (batch,) tensor of real-vs-fake probabilities for 64x64 inputs."""
        # (batch, 1, 1, 1) -> (batch,); same result as .view(-1, 1).squeeze(1)
        return self.network(x).view(-1)

In [None]:
def weights_init(m):
    """
    Initialize a module's parameters following the DCGAN recipe.

    Meant to be passed to ``nn.Module.apply``, which calls it on every submodule:
    - Conv / ConvTranspose layers: weight ~ N(0, 0.02); bias (if present) set to 0.
    - BatchNorm layers: weight (scale) ~ N(1, 0.02); bias (shift) set to 0.
    All other modules (containers, activations, custom blocks) are left untouched.

    Args:
        m: The module to initialize in place.
    """
    conv_types = (
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    )
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    # isinstance rather than class-name matching: a custom container named e.g.
    # "ConvBlock" would otherwise match 'Conv' and crash on its missing .weight.
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, norm_types):
        # BatchNorm(affine=False) has no learnable scale/shift (both are None).
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
# Instantiate both networks on the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Re-initialise every layer with the DCGAN scheme; apply() visits all submodules recursively.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Binary cross-entropy: the discriminator ends in a Sigmoid, so it emits probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per network, with beta1 lowered to 0.5 as in the DCGAN paper.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

print("Models created successfully!")
print(f"Generator parameters: {sum(map(torch.numel, generator.parameters())):,}")
print(f"Discriminator parameters: {sum(map(torch.numel, discriminator.parameters())):,}")

# A fixed batch of 64 latent vectors, reused during training so the saved sample
# grids show how the *same* inputs evolve as the generator learns.
fixed_noise = torch.randn(64, LATENT_DIM, 1, 1, device=device)

In [None]:
# Per-iteration losses for the plotting cell (one entry per batch).
G_losses = []
D_losses = []

# generate_images() puts the generator into eval mode; make sure both networks
# are in training mode so BatchNorm uses batch statistics when (re)training.
generator.train()
discriminator.train()

print("\nStarting Training...")

for epoch in range(NUM_EPOCHS):
    for real_images, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        batch_size = real_images.size(0)
        real_images = real_images.to(device)

        # BCELoss targets: 1 = real, 0 = fake
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # ================================================================
        # (1) Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # ================================================================

        optimizer_D.zero_grad()

        # Real batch
        output_real = discriminator(real_images)
        loss_D_real = criterion(output_real, real_labels)

        # Fake batch; detach() so the discriminator step does not backprop into G
        noise = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        loss_D_fake = criterion(output_fake, fake_labels)

        # Total discriminator loss
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ================================================================
        # (2) Update Generator: maximize log(D(G(z)))
        # ================================================================

        optimizer_G.zero_grad()

        # Re-score the same fakes with the just-updated discriminator and label them
        # "real": the generator is rewarded for fooling D. Gradients this also leaves
        # on D's parameters are cleared by optimizer_D.zero_grad() next iteration.
        output = discriminator(fake_images)
        loss_G = criterion(output, real_labels)

        loss_G.backward()
        optimizer_G.step()

        # Record every iteration so the loss plot's "Iterations" axis is accurate.
        G_losses.append(loss_G.item())
        D_losses.append(loss_D.item())

    # Losses of the last batch in this epoch
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

    # Every 5 epochs, save a grid generated from the fixed noise to track progress.
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            fake = generator(fixed_noise).cpu()
        img_grid = make_grid(fake, padding=2, normalize=True)
        save_image(img_grid, f"{OUTPUT_DIR}/epoch_{epoch+1}.png")
        print(f"Saved generated images for epoch {epoch+1}")

print("\nTraining Complete!")

## Plot the training losses

In [None]:
# Generator vs. discriminator loss curves, also saved to OUTPUT_DIR for later comparison.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="Generator Loss")
ax.plot(D_losses, label="Discriminator Loss")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig(f"{OUTPUT_DIR}/training_loss.png")
plt.show()

## Generate new images with the trained generator

In [None]:
def generate_images(num_images=16):
    """
    Sample new art images from the trained generator, show them in a grid, and
    save the figure to OUTPUT_DIR/generated_samples.png.

    Args:
        num_images: Number of images to sample (>= 1). The grid is made as square
            as possible, e.g. 16 -> 4x4, 10 -> 4 columns x 3 rows.

    Raises:
        ValueError: If num_images < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Sample in eval mode (BatchNorm uses its running statistics), then restore the
    # previous mode so a later training run is not silently affected.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_images, LATENT_DIM, 1, 1, device=device)
            generated = generator(noise).cpu()
    finally:
        generator.train(was_training)

    # Square-ish grid instead of a hard-coded 4x4, which fails for more than 16 images.
    n_cols = int(np.ceil(np.sqrt(num_images)))
    n_rows = int(np.ceil(num_images / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx < num_images:
            # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1] for imshow
            img = (generated[idx].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
            ax.imshow(img.numpy())
    fig.tight_layout()
    fig.savefig(f"{OUTPUT_DIR}/generated_samples.png")
    plt.show()


# Generate sample images
generate_images(16)