In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
import os
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import math
import itertools
import torchvision.utils as vutils




# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# Every model and tensor created in later cells is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook displays the selected device

device(type='cuda')

In [2]:
# Global configuration shared by the data pipeline, the networks and the training loop.

batch_size = 32  # mini-batch size handed to the DataLoader; adjust as needed
source_data_path = 'data/source_data'  # root directory handed to ImageFolder
img_size = 128  # images are resized to img_size x img_size
img_channels = 3
real_label = 1  # NOTE(review): not referenced by the cells visible in this notebook
fake_label = 0  # NOTE(review): not referenced by the cells visible in this notebook

# Network hyperparameter grids (one entry is picked per training run)
latent_dim_values = [1,2,4, 8, 16, 32, 64, 128]
d_hidden_values = [64]
g_hidden_values = [64]


# Adam hyperparameter grids: learning rate and (beta1, beta2) for the
# generator ("gen") and the discriminator ("dis") optimizers
lr_gen_values = [0.001, 0.0005, 0.0002, 0.0001, 0.00005]
beta1_gen_values = [0.5, 0.9]
beta2_gen_values = [0.999]
lr_dis_values = [0.001, 0.0005, 0.0002, 0.0001, 0.00005]
beta1_dis_values = [0.5, 0.9]
beta2_dis_values = [0.999]

num_epochs = 5 # number of passes over the dataset per training run; adjust as needed

lambda_gp = 0.5  # NOTE(review): not referenced by the cells visible in this notebook




# Data preparation

In [3]:
# Preprocessing applied to every image: resize to img_size x img_size, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5
# (ToTensor scales pixels to [0, 1], so this maps them to [-1, 1]).
data_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 
])

# ImageFolder reads the images under source_data_path and applies data_transform to each.
dataset = ImageFolder(root=source_data_path, transform=data_transform)

# Shuffled mini-batches of batch_size images, loaded with 4 worker processes.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)



In [4]:
def weights_init(m):
    """Re-initialise a single layer in place; intended for ``module.apply(weights_init)``.

    The layer is dispatched on its class name:

    * names containing ``'Conv'``: weights drawn from N(0.0, 0.02);
    * names containing ``'BatchNorm'``: weights drawn from N(1.0, 0.02), bias set to 0;
    * anything else (e.g. ``Linear``) is left untouched.

    Args:
        m: the layer to initialise.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [58]:
# Generator Network
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    The stack expands a ``(N, latent_dim, 1, 1)`` noise tensor to 4x4 and then
    doubles the spatial size four times, so for a 1x1 input the output is
    ``(N, image_channels, 64, 64)`` with values in [-1, 1] (Tanh).

    NOTE(review): a second class named ``Generator`` is defined further down in
    this same cell and rebinds the name, so this convolutional version is not
    the one the rest of the notebook instantiates.

    Args:
        latent_dim: number of channels of the input noise tensor.
        image_channels: number of channels of the generated image.
        img_size: accepted for signature parity but unused; the output
            resolution is fixed by the layer stack.
        g_hidden: base channel width; layers use 8x, 4x, 2x and 1x this value.
    """

    def __init__(self, latent_dim, image_channels, img_size, g_hidden):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer: 1x1 -> 4x4
            nn.ConvTranspose2d(latent_dim, g_hidden * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(g_hidden * 8),
            nn.ReLU(True),
            # 1st hidden layer: 4x4 -> 8x8
            nn.ConvTranspose2d(g_hidden * 8, g_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 4),
            nn.ReLU(True),
            # 2nd hidden layer: 8x8 -> 16x16
            nn.ConvTranspose2d(g_hidden * 4, g_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden * 2),
            nn.ReLU(True),
            # 3rd hidden layer: 16x16 -> 32x32
            nn.ConvTranspose2d(g_hidden * 2, g_hidden, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_hidden),
            nn.ReLU(True),
            # output layer: 32x32 -> 64x64.
            # Uses the constructor argument; the original read the global
            # `img_channels`, silently ignoring `image_channels`.
            nn.ConvTranspose2d(g_hidden, image_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Map a ``(N, latent_dim, 1, 1)`` noise tensor to a batch of images."""
        return self.main(input)
# Discriminator Network
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Four stride-2 convolutions halve the spatial size each time and a final
    4x4 convolution plus Sigmoid produces one score per image, so a
    ``(N, image_channels, 64, 64)`` input yields a ``(N,)`` tensor in [0, 1].

    NOTE(review): a second class named ``Discriminator`` is defined further
    down in this same cell and rebinds the name, so this convolutional version
    is not the one the rest of the notebook instantiates.

    Args:
        image_channels: number of channels of the input images.
        img_size: accepted for signature parity but unused by the layer stack.
        d_hidden: base channel width; layers use 1x, 2x, 4x and 8x this value.
    """

    def __init__(self, image_channels, img_size, d_hidden):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer.
            # Uses the constructor argument; the original read the global
            # `img_channels`, silently ignoring `image_channels`.
            nn.Conv2d(image_channels, d_hidden, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(d_hidden, d_hidden * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(d_hidden * 2, d_hidden * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(d_hidden * 4, d_hidden * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_hidden * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer: collapses the remaining 4x4 map to a single score
            nn.Conv2d(d_hidden * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return one score per input image as a 1-D tensor."""
        return self.main(input).view(-1, 1).squeeze(1)

# Generator Network
class Generator(nn.Module):
    """Fully connected generator.

    Three ``Linear -> BatchNorm1d -> ReLU`` stages widen the latent vector to
    256, 512 and 1024 features; a final ``Linear -> Tanh`` emits a flat image
    of ``img_channels * img_size * img_size`` values in [-1, 1].

    Args:
        latent_dim: size of the input noise vector.
        img_channels: channels of the generated image.
        img_size: height and width of the generated image.
        g_hidden: accepted but not used by this architecture.
    """

    def __init__(self, latent_dim, img_channels, img_size, g_hidden):
        super().__init__()

        widths = (latent_dim, 256, 512, 1024)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.BatchNorm1d(fan_out))
            stages.append(nn.ReLU())

        flat_pixels = img_channels * img_size * img_size
        stages.append(nn.Linear(widths[-1], flat_pixels))
        stages.append(nn.Tanh())

        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Map a ``(N, latent_dim)`` noise batch to ``(N, C*H*W)`` flat images."""
        flat_images = self.model(z)
        return flat_images

# Discriminator Network
class Discriminator(nn.Module):
    """Fully connected discriminator.

    A flat image of ``img_channels * img_size * img_size`` values goes through
    ``Linear(1024) -> LeakyReLU -> Linear(512) -> LeakyReLU -> Linear(1) -> Sigmoid``,
    giving one score in [0, 1] per sample.

    Args:
        img_channels: channels of the input image.
        img_size: height and width of the input image.
        d_hidden: accepted but not used by this architecture.
    """

    def __init__(self, img_channels, img_size, d_hidden):
        super().__init__()

        flat_pixels = img_channels * img_size * img_size
        stack = [
            nn.Linear(flat_pixels, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.01),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Score a ``(N, C*H*W)`` batch of flat images; returns shape ``(N, 1)``."""
        scores = self.model(x)
        return scores

In [59]:
def train_gan(n_epochs, latent_dim, lr_gen, beta1_gen, beta2_gen, lr_dis, beta1_dis, beta2_dis, g_hidden, d_hidden,img_channels, img_size):
    """Train a fresh generator/discriminator pair on the module-level ``data_loader``.

    Each batch performs one discriminator update (real batch plus detached fake
    batch) followed by one generator update, both using binary cross-entropy.

    Args:
        n_epochs: number of passes over ``data_loader``.
        latent_dim: size of the noise vector fed to the generator.
        lr_gen, beta1_gen, beta2_gen: Adam learning rate and betas for the generator.
        lr_dis, beta1_dis, beta2_dis: Adam learning rate and betas for the discriminator.
        g_hidden: forwarded to the ``Generator`` constructor.
        d_hidden: forwarded to the ``Discriminator`` constructor.
        img_channels: number of image channels, forwarded to both networks.
        img_size: image height/width, forwarded to both networks.

    Returns:
        tuple: ``(generator_losses, discriminator_losses, generator, discriminator)``.
        The two loss lists hold one float per training batch. Generator losses
        come first because that is the order every caller in this notebook
        unpacks; the original returned them swapped, which mislabelled the
        plotted curves.
    """
    # Create GAN
    generator = Generator(latent_dim, img_channels, img_size, g_hidden)
    generator.apply(weights_init)
    discriminator = Discriminator(img_channels, img_size, d_hidden)
    discriminator.apply(weights_init)
    adversarial_loss = nn.BCELoss()

    # Move everything to the selected device (GPU when available)
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    adversarial_loss = adversarial_loss.to(device)

    # Create optimizers for the generator and discriminator
    generator_optimizer = optim.Adam(generator.parameters(), lr=lr_gen, betas=(beta1_gen, beta2_gen))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr_dis, betas=(beta1_dis, beta2_dis))

    generator_losses = []
    discriminator_losses = []

    print_interval = 10

    for epoch in range(n_epochs):
        for batch_idx, (real_images, _) in enumerate(data_loader):
            real_images = real_images.to(device)
            current_batch = real_images.size(0)  # the last batch may be smaller

            # ---- Training the discriminator ----
            discriminator_optimizer.zero_grad()

            z = torch.randn(current_batch, latent_dim, device=device)
            # Reshaping to the real batch's shape fails fast if the generator's
            # output size does not match the data.
            fake_images = generator(z).view(real_images.size())

            real_images_flat = real_images.view(current_batch, -1)
            fake_images_flat = fake_images.view(current_batch, -1)

            real_labels = torch.ones(current_batch, 1, device=device) * 0.9  # Label smoothing
            fake_labels = torch.zeros(current_batch, 1, device=device)
            real_loss = adversarial_loss(discriminator(real_images_flat), real_labels)
            # detach(): the discriminator step must not back-propagate into the generator
            fake_loss = adversarial_loss(discriminator(fake_images_flat.detach()), fake_labels)
            discriminator_loss = real_loss + fake_loss

            discriminator_loss.backward()
            discriminator_optimizer.step()

            # ---- Training the generator ----
            generator_optimizer.zero_grad()

            z = torch.randn(current_batch, latent_dim, device=device)
            fake_images = generator(z).view(real_images.size())
            fake_images_flat = fake_images.view(current_batch, -1)

            # The generator is rewarded when fakes are scored as real; the
            # target reuses the smoothed (0.9) real labels from above.
            generator_loss = adversarial_loss(discriminator(fake_images_flat), real_labels)

            generator_loss.backward()
            generator_optimizer.step()

            discriminator_losses.append(discriminator_loss.item())
            generator_losses.append(generator_loss.item())

            if batch_idx % print_interval == 0:
                # Report against this call's own n_epochs, not the global num_epochs
                print(f"Epoch [{epoch}/{n_epochs}], Batch [{batch_idx}/{len(data_loader)}], "
                      f"Discriminator Loss: {discriminator_loss.item():.4f}, "
                      f"Generator Loss: {generator_loss.item():.4f}")
    return generator_losses, discriminator_losses, generator, discriminator


In [60]:
# Plot loss curves
def plot_train_losses(generator_losses, discriminator_losses):
    """Show both per-iteration loss histories on a single labelled figure.

    Args:
        generator_losses: sequence of generator loss values, one per iteration.
        discriminator_losses: sequence of discriminator loss values, one per iteration.
    """
    _, ax = plt.subplots(figsize=(10, 5))

    # The discriminator curve is drawn first so it keeps the first colour of the cycle.
    curves = (
        (discriminator_losses, 'Discriminator Loss'),
        (generator_losses, 'Generator Loss'),
    )
    for history, caption in curves:
        ax.plot(history, label=caption)

    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('GAN Training Loss')
    plt.show()


In [61]:
def generate_n_image(num_images_to_generate, generator, latent_dim):
    """Sample images from ``generator`` and display them side by side.

    Args:
        num_images_to_generate: number of images to sample; must be >= 1.
        generator: network mapping ``(N, latent_dim)`` noise to an output that
            can be viewed as ``(N, img_channels, img_size, img_size)``.
        latent_dim: size of the noise vector expected by ``generator``.

    Raises:
        ValueError: if ``num_images_to_generate`` is smaller than 1.
    """
    if num_images_to_generate < 1:
        raise ValueError("num_images_to_generate must be at least 1")

    # Remember the current mode so the caller's train/eval state is restored
    # afterwards instead of being forced to train mode.
    was_training = generator.training
    generator.eval()
    try:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            z = torch.randn(num_images_to_generate, latent_dim, device=device)
            generated_images = generator(z).view(num_images_to_generate, img_channels, img_size, img_size)
    finally:
        generator.train(was_training)

    # The training data is normalised with mean 0.5 / std 0.5 and the generator
    # ends in Tanh, so samples live in [-1, 1]. Map them back to [0, 1] before
    # the uint8 conversion in to_pil_image, otherwise negative values are corrupted.
    generated_images = ((generated_images + 1.0) / 2.0).clamp(0.0, 1.0).cpu()

    # squeeze=False keeps `axes` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, num_images_to_generate, figsize=(15, 3), squeeze=False)

    for ax, image in zip(axes[0], generated_images):
        ax.imshow(F.to_pil_image(image))
        ax.axis('off')

    plt.show()
#generate_n_image(5,generator, latent_dim_values[2])


In [62]:
# Single training run with one hand-picked configuration from the grids:
# latent_dim=4, lr_gen=0.0005, betas_gen=(0.5, 0.999), lr_dis=0.0002, betas_dis=(0.5, 0.999)
generator_losses, discriminator_losses, generator, discriminator = train_gan(
    n_epochs=num_epochs,
    latent_dim=latent_dim_values[2],
    lr_gen=lr_gen_values[1],
    beta1_gen=beta1_gen_values[0],
    beta2_gen=beta2_gen_values[0],
    lr_dis=lr_dis_values[2],
    beta1_dis=beta1_dis_values[0],
    beta2_dis=beta2_dis_values[0],
    g_hidden=g_hidden_values[0],
    d_hidden=d_hidden_values[0],
    img_channels=img_channels,
    img_size=img_size,
)

# Inspect the run: loss curves first, then a row of five generated samples.
plot_train_losses(generator_losses, discriminator_losses)
generate_n_image(5, generator, latent_dim_values[2])


Epoch [0/5], Batch [0/38], Discriminator Loss: 1.3976, Generator Loss: 1.7852
Epoch [0/5], Batch [10/38], Discriminator Loss: 0.7890, Generator Loss: 3.5531
Epoch [0/5], Batch [20/38], Discriminator Loss: 0.9192, Generator Loss: 3.4375
Epoch [0/5], Batch [30/38], Discriminator Loss: 0.8320, Generator Loss: 2.9094
Epoch [1/5], Batch [0/38], Discriminator Loss: 0.8017, Generator Loss: 2.4348
Epoch [1/5], Batch [10/38], Discriminator Loss: 1.0712, Generator Loss: 1.4518
Epoch [1/5], Batch [20/38], Discriminator Loss: 0.8958, Generator Loss: 1.2222
Epoch [1/5], Batch [30/38], Discriminator Loss: 0.9125, Generator Loss: 1.3649


In [None]:
def hparams_tuning(num_epochs):
    """Grid-search the GAN hyperparameters.

    Trains one GAN for every combination in the Cartesian product of the
    module-level ``*_values`` lists, shows sample images and loss curves after
    each run, and collects the outcome.

    Args:
        num_epochs: number of epochs passed to ``train_gan`` for every combination.

    Returns:
        list: one 4-tuple per combination holding the four values returned by
        ``train_gan``, in the order it returned them.

    NOTE(review): every trained generator and discriminator is kept alive in
    the returned list, so memory use grows with the number of combinations.
    """
    search_space = itertools.product(
        latent_dim_values, lr_gen_values, beta1_gen_values, beta2_gen_values,
        lr_dis_values, beta1_dis_values, beta2_dis_values, g_hidden_values, d_hidden_values)

    res = []
    for iters, combo in enumerate(search_space):
        (latent_dim, lr_gen, beta1_gen, beta2_gen,
         lr_dis, beta1_dis, beta2_dis, g_hidden, d_hidden) = combo

        print('--------------- iteration: ', iters, '-----------------')
        print('latent_dim: ', latent_dim, 'lr_gen: ', lr_gen, 'beta1_gen: ', beta1_gen, 'beta2_gen: ', beta2_gen, 'lr_dis: ', lr_dis, 'beta1_dis: ', beta1_dis, 'beta2_dis: ', beta2_dis)
        print('------------------------------------------------------')

        outcome = train_gan(
            num_epochs, latent_dim,
            lr_gen, beta1_gen, beta2_gen,
            lr_dis, beta1_dis, beta2_dis,
            g_hidden, d_hidden, img_channels, img_size)
        generator_losses, discriminator_losses, generator, discriminator = outcome

        generate_n_image(5, generator, latent_dim)
        plot_train_losses(generator_losses, discriminator_losses)
        res.append((generator_losses, discriminator_losses, generator, discriminator))

    return res

In [None]:

# Run the full grid search with num_epochs epochs per combination;
# `a` holds the list of per-combination results returned by hparams_tuning.
a = hparams_tuning(num_epochs)