In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Hyperparameters, random seed and paths

In [None]:
IMAGE_SIZE = 64  # images are resized + center-cropped to IMAGE_SIZE x IMAGE_SIZE; the networks are laid out for 64x64
IMAGE_CHANNELS = 3  # RGB

BATCH_SIZE = 64
NUM_EPOCHS = 50
LEARNING_RATE = 0.0005

LATENT_DIM = 100  # length of the random noise vector fed to the Generator
FEATURE_DIM = 64  # base channel width; both networks scale it by 1, 2, 4, 8

# Fix the global torch RNG so weight init, noise sampling and DataLoader shuffling
# repeat across runs (GPU kernels may still add some nondeterminism).
SEED = 42
torch.manual_seed(SEED)

DATA_DIR = "/kaggle/input/wikiart"
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Resize/crop to IMAGE_SIZE and scale pixels to [-1, 1] to match the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)

ALL_CLASSES = ['Abstract_Expressionism', 'Impressionism', 'Expressionism',
               'Cubism', 'Post_Impressionism']
SELECTED_CLASSES = [ALL_CLASSES[3]]  # train on 'Cubism' only

class_to_idx = full_dataset.class_to_idx
selected_indices = []

for class_name in SELECTED_CLASSES:
    # Folder names on disk may differ in case or use '-' instead of '_'.
    # Try the spellings in order (dict.fromkeys drops duplicates, keeps order)
    # and use the first one that exists.
    possible_names = list(dict.fromkeys(
        [class_name, class_name.lower(), class_name.replace('_', '-')]
    ))
    matched_name = next((name for name in possible_names if name in class_to_idx), None)
    if matched_name is None:
        print(f"Warning: class '{class_name}' not found in {DATA_DIR} (tried {possible_names})")
        continue

    class_idx = class_to_idx[matched_name]
    indices = [i for i, (_, label) in enumerate(full_dataset.samples) if label == class_idx]
    selected_indices.extend(indices)
    print(f"Found class '{matched_name}' with {len(indices)} images")

# Fail here with a clear message instead of building an empty DataLoader that
# only breaks later inside the training loop.
if not selected_indices:
    raise ValueError(
        f"No images found for {SELECTED_CLASSES} under {DATA_DIR}; "
        f"available classes: {sorted(class_to_idx)}"
    )

dataset = torch.utils.data.Subset(full_dataset, selected_indices)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

print("\nFiltered dataset:")
print(f"Total images in selected classes: {len(dataset)}")
print(f"Number of batches per epoch: {len(dataloader)}")

In [None]:
class Generator(nn.Module):
    """DCGAN generator: turns latent noise into an IMAGE_CHANNELS-channel image in [-1, 1].

    Input shape is (N, LATENT_DIM, 1, 1). Five transposed convolutions upsample
    1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64, so the output is
    (N, IMAGE_CHANNELS, 64, 64). Channel widths go 8*F -> 4*F -> 2*F -> F with
    F = FEATURE_DIM.
    """

    def __init__(self):
        super().__init__()

        def up_block(in_channels, out_channels, stride, padding):
            # 4x4 transposed conv (no bias: the following BatchNorm adds the shift),
            # then BatchNorm and an in-place ReLU.
            return [
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Project the 1x1 latent vector to a (8*F) x 4x4 feature map.
        layers = up_block(LATENT_DIM, FEATURE_DIM * 8, 1, 0)

        # Each stride-2 block doubles the spatial size and halves the channels:
        # 4x4 -> 8x8 -> 16x16 -> 32x32.
        for width_mult in (8, 4, 2):
            layers += up_block(FEATURE_DIM * width_mult, FEATURE_DIM * width_mult // 2, 2, 1)

        # Final upsample to 64x64 image channels; Tanh bounds pixels to [-1, 1].
        layers += [
            nn.ConvTranspose2d(FEATURE_DIM, IMAGE_CHANNELS, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Map latent noise x of shape (N, LATENT_DIM, 1, 1) to images."""
        return self.network(x)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores each image with the probability that it is real.

    Four stride-2 convolutions downsample 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4.
    A final 4x4 convolution collapses that to a single value, and a Sigmoid maps it
    into (0, 1). forward() flattens the result to shape (N,).
    """

    def __init__(self):
        super().__init__()

        # First layer: no BatchNorm on the raw image input.
        layers = [
            nn.Conv2d(IMAGE_CHANNELS, FEATURE_DIM, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Each block halves the spatial size and doubles the channels:
        # F -> 2F -> 4F -> 8F over 32x32 -> 16x16 -> 8x8 -> 4x4.
        for width_mult in (1, 2, 4):
            out_channels = FEATURE_DIM * width_mult * 2
            layers += [
                nn.Conv2d(FEATURE_DIM * width_mult, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Collapse the (8F) x 4x4 map to one logit, then squash it to a probability.
        layers += [
            nn.Conv2d(FEATURE_DIM * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return a flat tensor of per-image real-probabilities."""
        return self.network(x).view(-1)

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, intended for ``model.apply(weights_init)``.

    - Modules whose class name contains "Conv" (e.g. Conv2d, ConvTranspose2d):
      weights drawn from N(0, 0.02). Their biases are not touched.
    - Modules whose class name contains "BatchNorm": weights drawn from
      N(1, 0.02), biases set to 0.
    - Every other module is left unchanged.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Build both networks on the target device. Then replace PyTorch's default
# parameter initialisation with weights_init (applied recursively to every submodule).
generator = Generator().to(device)
discriminator = Discriminator().to(device)

generator.apply(weights_init)
discriminator.apply(weights_init)

# The discriminator already ends in a Sigmoid, so BCELoss works directly on its probabilities.
criterion = nn.BCELoss()

# Adam with beta1 = 0.5 instead of the default 0.9, a common choice for DCGAN training.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

print("Models created successfully!")
print("Generator parameters: {:,}".format(sum(param.numel() for param in generator.parameters())))
print("Discriminator parameters: {:,}".format(sum(param.numel() for param in discriminator.parameters())))

# One fixed batch of 64 latent vectors, re-rendered at checkpoints so that
# snapshots from different epochs show the same "seeds".
fixed_noise = torch.randn(64, LATENT_DIM, 1, 1, device=device)

In [None]:
G_losses = []
D_losses = []

print("\nStarting Training...")

if len(dataloader) == 0:
    # With no batches, loss_D / loss_G below would never be assigned.
    raise RuntimeError("The dataloader is empty; check DATA_DIR and SELECTED_CLASSES.")

# generate_images() (further down) switches the generator to eval mode. Put both
# networks back into training mode so BatchNorm uses per-batch statistics, even
# when this cell is re-run after sampling.
generator.train()
discriminator.train()

for epoch in range(NUM_EPOCHS):
    for i, (real_images, _) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")):
        batch_size = real_images.size(0)
        real_images = real_images.to(device)

        # BCELoss targets, shape (batch_size,) to match the discriminator output: 1 = real, 0 = fake.
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        optimizer_D.zero_grad()

        output_real = discriminator(real_images)
        loss_D_real = criterion(output_real, real_labels)

        noise = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
        fake_images = generator(noise)
        # detach(): the discriminator update must not backpropagate into the generator.
        output_fake = discriminator(fake_images.detach())
        loss_D_fake = criterion(output_fake, fake_labels)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator step: make the updated discriminator label the fakes as real ----
        optimizer_G.zero_grad()

        output = discriminator(fake_images)
        loss_G = criterion(output, real_labels)

        loss_G.backward()
        optimizer_G.step()

        # Record losses for the plot at batches 0, 50, 100, ... of every epoch.
        if i % 50 == 0:
            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

    # Every 5 epochs, render the fixed noise batch so progress can be compared across epochs.
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            fake = generator(fixed_noise).detach().cpu()
        img_grid = make_grid(fake, padding=2, normalize=True)
        save_image(img_grid, f"{OUTPUT_DIR}/epoch_{epoch+1}.png")
        print(f"Saved generated images for epoch {epoch+1}")

print("\nTraining Complete!")

Plot the generator and discriminator training losses

In [None]:
# Losses are appended only at every 50th batch of each epoch (see the training
# loop), so the x-axis counts logged points, not raw iterations.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(G_losses, label="Generator Loss")
ax.plot(D_losses, label="Discriminator Loss")
ax.set(
    title="Generator and Discriminator Loss During Training",
    xlabel="Logged step (every 50th batch of each epoch)",
    ylabel="Loss",
)
ax.legend()
fig.savefig(f"{OUTPUT_DIR}/training_loss.png")
plt.show()

Generate new images from random noise

In [None]:
def generate_images(num_images=16):
    """Sample ``num_images`` images from the generator and display them in a grid.

    The generator is switched to eval mode for sampling, so BatchNorm uses its
    running statistics. Afterwards it is restored to whatever mode it was in, so
    the notebook's training cell is not silently affected.

    Args:
        num_images: number of images to sample (>= 1). They are laid out on a
            near-square grid; 16 gives the original 4x4 layout.

    Raises:
        ValueError: if ``num_images`` < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_images, LATENT_DIM, 1, 1, device=device)
            generated = generator(noise).cpu()
    finally:
        generator.train(was_training)

    # Tanh output is in [-1, 1]: rescale to [0, 1] and move channels last for imshow.
    images = (generated * 0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1).numpy()

    # Near-square grid big enough for every image (4x4 for the default 16).
    cols = int(np.ceil(np.sqrt(num_images)))
    rows = -(-num_images // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, image in zip(axes.flat, images):
        ax.imshow(image)
    fig.tight_layout()
    plt.show()

generate_images(16)

print("\n All done!")