# GAN and DCGAN task from Lecture
### Transferred to this directory to keep sanity between the lecture notes and Kaleb's implementation
Task description: "**Task**: Try generating more complex color images with DCGAN - for example, take one class from CIFAR-10 dataset."

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch import nn
from torch import optim
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
# Fix both random streams used in this notebook: NumPy draws the latent noise,
# PyTorch initialises the network weights.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Shared Requirements

In [None]:
# Settings shared by the GAN and the DCGAN experiments:
#   batch_size - samples per mini-batch handed to the data loaders
#   epochs     - number of passes over the training loader
#   plot_every - generator samples are plotted every `plot_every` epochs
batch_size, epochs, plot_every = 16, 100, 10

In [None]:
def weights_init(model):
    """Re-initialise a single layer in place; intended for ``Module.apply``.

    Layers whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Layers whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other layer is left untouched.
    """
    layer_kind = model.__class__.__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(model.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(model.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(model.bias.data, 0)

In [None]:
# Both pipelines end the same way: convert the PIL image to a tensor, then
# standardise every channel with a fixed mean / std.
_to_standardized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
]

# Training pipeline: random 32x32 crop after 4 px of reflect padding and a
# random horizontal flip, followed by the shared tail above.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    *_to_standardized_tensor,
])

# Test pipeline: no augmentation, only tensor conversion + standardisation.
transform_test = transforms.Compose(list(_to_standardized_tensor))

class SpecificClassDataset(Dataset):
    """View of a labelled dataset that exposes only the samples of one class.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.
        target_class: Label value to keep.

    Attributes:
        indices: Positions in ``dataset`` whose label equals ``target_class``.
    """

    def __init__(self, dataset, target_class):
        self.dataset = dataset
        self.target_class = target_class
        self.indices = self._matching_indices(dataset, target_class)

    @staticmethod
    def _matching_indices(dataset, target_class):
        """Return the positions in ``dataset`` labelled ``target_class``."""
        labels = getattr(dataset, 'targets', None)
        # Fast path: read the label list directly instead of going through
        # __getitem__, which would load and transform every image just to
        # look at its label. It is only valid when labels are not rewritten
        # by a target_transform, so fall back to plain iteration otherwise.
        if labels is None or getattr(dataset, 'target_transform', None) is not None:
            labels = (label for _, label in dataset)
        return [i for i, label in enumerate(labels) if label == target_class]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the filtered position into a position of the wrapped dataset.
        image, label = self.dataset[self.indices[idx]]
        return image, label

# CIFAR-10 splits, stored under ./data and downloaded on first use.
# Each split gets its own transform pipeline.
train_dataset_full = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=transform_train, download=True,
)
test_dataset_full = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=transform_test, download=True,
)

# Only samples carrying this label are kept.
target_class = 6 # frogs

# Single-class views of both splits.
train_dataset = SpecificClassDataset(train_dataset_full, target_class)
test_dataset = SpecificClassDataset(test_dataset_full, target_class)

# Loader options common to both splits; only the training loader shuffles.
_loader_opts = {'batch_size': batch_size, 'num_workers': 2}
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_opts)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **_loader_opts)

# The training functions read the training loader from position 0 of this tuple.
dataloaders = (train_loader, )

In [None]:
# Helpers to undo the per-channel normalisation and display an image
def unnormalize(img, mean, std):
    """Return a copy of ``img`` with ``x * std + mean`` applied channel by channel.

    ``img`` is iterated along its first dimension and paired with the entries
    of ``mean`` and ``std``; the input tensor itself is not modified.
    """
    restored = img.clone()
    for channel, channel_mean, channel_std in zip(restored, mean, std):
        # `channel` is a view into `restored`, so the in-place ops update it.
        channel.mul_(channel_std)
        channel.add_(channel_mean)
    return restored
    
def imshow(img, title):
    """Undo the dataset normalisation on one image tensor and show it with ``title``."""
    restored = unnormalize(
        img,
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    )
    # Move the channel axis last (CHW -> HWC) before handing it to matplotlib.
    channels_last = np.transpose(restored.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.title(title)
    plt.show()

# Pull one (shuffled) batch from the training loader
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Display the first 5 samples of that batch together with their labels
for i in range(5):
    caption = f'Label: {labels[i].item()}'
    imshow(images[i], caption)

# GAN Implementation

In [None]:
# Optimiser settings for the fully connected GAN.
train_size = 1.0
lr = 1e-4                   # Adam learning rate
weight_decay = 8e-9         # Adam weight decay
beta1, beta2 = 0.5, 0.999   # Adam betas

In [None]:
def plotn(n, generator, device):
    """Sample ``n`` images from ``generator`` and show them side by side.

    Args:
        n: Number of images to draw (one subplot each).
        generator: Module mapping ``(n, 100)`` noise to image tensors whose
            values lie in [-1, 1].
        device: Device the noise is moved to before the forward pass.

    The generator is evaluated in eval mode without gradient tracking, and
    its previous train/eval mode is restored before returning, so calling
    this in the middle of training does not change how later batches are
    processed.
    """
    was_training = generator.training
    generator.eval()
    try:
        noise = torch.FloatTensor(np.random.normal(0, 1, (n, 100))).to(device)
        with torch.no_grad():  # sampling only, no graph needed
            imgs = generator(noise).cpu()
    finally:
        generator.train(was_training)

    # Rescale from [-1, 1] to [0, 1]
    imgs = (imgs + 1) / 2
    # squeeze=False keeps `ax` a 2-D array, so n == 1 works as well.
    fig, ax = plt.subplots(1, n, figsize=(n * 3, 3), squeeze=False)
    for axis, im in zip(ax[0], imgs):
        axis.imshow(np.transpose(im.numpy(), (1, 2, 0)))  # Convert from CHW to HWC format
        axis.axis('off')
    plt.show()

In [None]:
def train_gan(dataloaders, models, optimizers, loss_fn, epochs, plot_every, device,
              data_mean=(0.4914, 0.4822, 0.4465), data_std=(0.2023, 0.1994, 0.2010)):
    """Train a generator/discriminator pair and periodically plot samples.

    Args:
        dataloaders: Sequence whose first element is the training loader
            yielding ``(images, labels)`` batches; labels are ignored.
        models: ``(generator, discriminator)``.
        optimizers: ``(generator_optimizer, discriminator_optimizer)``.
        loss_fn: Criterion applied to discriminator outputs and 0/1 targets.
        epochs: Number of passes over the training loader.
        plot_every: Samples are plotted every ``plot_every`` epochs and after
            the last epoch.
        device: Device used for batches, noise and targets.
        data_mean, data_std: Per-channel statistics the loader's ``Normalize``
            transform applied. They are used to undo that standardisation so
            real images can be mapped to [-1, 1]. The defaults match the
            transform defined in this file. Pass ``None`` for either one if
            the loader already yields images in [-1, 1].

    Real batches are mapped to [-1, 1] because the generators in this file
    end in ``tanh`` and ``plotn`` displays samples via ``(x + 1) / 2``.
    Without this the discriminator compares standardised real images with
    fakes that live in a different value range.
    """
    tqdm_iter = tqdm(range(epochs))
    train_dataloader = dataloaders[0]

    gen, disc = models[0], models[1]
    optim_gen, optim_disc = optimizers[0], optimizers[1]

    rescale = data_mean is not None and data_std is not None
    if rescale:
        # Shaped (1, C, 1, 1) so they broadcast over (N, C, H, W) batches.
        mean = torch.tensor(data_mean, dtype=torch.float32, device=device).view(1, -1, 1, 1)
        std = torch.tensor(data_std, dtype=torch.float32, device=device).view(1, -1, 1, 1)

    # Avoid a division by zero when the loader is empty.
    num_batches = max(len(train_dataloader), 1)

    for epoch in tqdm_iter:
        gen.train()
        disc.train()

        train_gen_loss = 0.0
        train_disc_loss = 0.0

        for imgs, _ in train_dataloader:
            imgs = imgs.to(device)
            if rescale:
                # standardised -> [0, 1] -> [-1, 1]
                imgs = (imgs * std + mean) * 2.0 - 1.0
            n_samples = imgs.shape[0]

            noise = torch.FloatTensor(np.random.normal(0.0, 1.0, (n_samples, 100))).to(device)
            real_labels = torch.ones((n_samples, 1), device=device)
            fake_labels = torch.zeros((n_samples, 1), device=device)

            # --- generator step: make the discriminator answer "real" for fakes.
            # The discriminator is in eval mode while it only scores fakes here.
            disc.eval()
            gen.zero_grad()

            generated = gen(noise)
            disc_preds = disc(generated)

            g_loss = loss_fn(disc_preds, real_labels)
            g_loss.backward()
            optim_gen.step()

            # --- discriminator step: real -> 1, detached fakes -> 0.
            # zero_grad also discards the gradients the generator step left
            # on the discriminator's parameters.
            disc.train()
            disc.zero_grad()

            disc_real_loss = loss_fn(disc(imgs), real_labels)
            disc_fake_loss = loss_fn(disc(generated.detach()), fake_labels)

            d_loss = (disc_real_loss + disc_fake_loss) / 2.0
            d_loss.backward()
            optim_disc.step()

            train_gen_loss += g_loss.item()
            train_disc_loss += d_loss.item()

        train_gen_loss /= num_batches
        train_disc_loss /= num_batches

        if epoch % plot_every == 0 or epoch == epochs - 1:
            plotn(5, gen, device)

        tqdm_dct = {'generator loss:': train_gen_loss, 'discriminator loss:': train_disc_loss}
        tqdm_iter.set_postfix(tqdm_dct, refresh=True)

In [None]:
class KeiGenerator(nn.Module):
    """Fully connected generator: 100-d noise -> (3, 32, 32) image in [-1, 1].

    Three ``Linear -> BatchNorm1d -> LeakyReLU(0.2)`` stages widen the input
    100 -> 256 -> 512 -> 1024. A last linear layer produces 3072 values that
    are squashed by ``tanh`` and reshaped to 3x32x32.
    """

    def __init__(self):
        super().__init__()
        widths = (100, 256, 512, 1024)
        # Registers linear1/bn1, linear2/bn2, linear3/bn3 in that order.
        for stage, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'linear{stage}', nn.Linear(n_in, n_out))
            setattr(self, f'bn{stage}', nn.BatchNorm1d(n_out, momentum=0.2))
        self.linear4 = nn.Linear(1024, 3072)
        self.tanh = nn.Tanh()
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, input):
        stages = (
            (self.linear1, self.bn1),
            (self.linear2, self.bn2),
            (self.linear3, self.bn3),
        )
        hidden = input
        for linear, norm in stages:
            hidden = self.leaky_relu(norm(linear(hidden)))
        pixels = self.tanh(self.linear4(hidden))
        return pixels.view(input.shape[0], 3, 32, 32)

In [None]:
class KeiDiscriminator(nn.Module):
    """Fully connected discriminator: flattened 3072-value image -> probability.

    Three ``Linear -> LeakyReLU(0.2)`` stages shrink the input
    3072 -> 1024 -> 512 -> 256. A final linear layer plus sigmoid yields one
    score per sample.
    """

    def __init__(self):
        super().__init__()
        widths = (3072, 1024, 512, 256, 1)
        # Registers linear1 ... linear4 in order.
        for stage, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'linear{stage}', nn.Linear(n_in, n_out))
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # Flatten everything but the batch dimension.
        features = input.view(input.shape[0], -1)
        for layer in (self.linear1, self.linear2, self.linear3):
            features = self.leaky_relu(layer(features))
        return self.sigmoid(self.linear4(features))

In [None]:
# Build the fully connected GAN on the selected device.
generator = KeiGenerator().to(device)
discriminator = KeiDiscriminator().to(device)
models = (generator, discriminator)

# One Adam optimiser per network, both configured identically.
optimizers = tuple(
    optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay, betas=(beta1, beta2))
    for net in models
)
optimizer_generator, optimizer_discriminator = optimizers

# Binary cross-entropy on the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

In [None]:
# Layer-by-layer summaries: the generator for a single 100-d noise vector,
# the discriminator for a single 3x32x32 image.
generator_summary = summary(generator, input_size=(1, 100))
print(generator_summary)
discriminator_summary = summary(discriminator, input_size=(1, 3, 32, 32))
print(discriminator_summary)

In [None]:
# Train the fully connected GAN with the shared settings defined above.
train_gan(
    dataloaders=dataloaders,
    models=models,
    optimizers=optimizers,
    loss_fn=loss_fn,
    epochs=epochs,
    plot_every=plot_every,
    device=device,
)

In [None]:
# Show five samples from the trained generator (switched to eval mode first).
generator.eval()
plotn(n=5, generator=generator, device=device)

# DC GAN Implementation

In [None]:
# Optimiser settings for the DCGAN.
train_size = 1.0
lr = 1e-4                   # Adam learning rate
weight_decay = 8e-9         # Adam weight decay
beta1, beta2 = 0.5, 0.999   # Adam betas

In [None]:
class KeiDCGenerator(nn.Module):
    """Transposed-convolution generator: (N, 100, 1, 1) noise -> (N, 3, 32, 32).

    Every stage is a 4x4, stride-2, padding-1 transposed convolution without
    bias, which doubles the spatial size: 1 -> 2 -> 4 -> 8 -> 16 -> 32. The
    first four stages are followed by BatchNorm2d + ReLU, the last by tanh.
    """

    def __init__(self):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1, bias=False)
        self.conv1 = nn.ConvTranspose2d(100, 1024, **upsample)
        self.bn1 = nn.BatchNorm2d(1024)
        self.conv2 = nn.ConvTranspose2d(1024, 512, **upsample)
        self.bn2 = nn.BatchNorm2d(512)
        self.conv3 = nn.ConvTranspose2d(512, 256, **upsample)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.ConvTranspose2d(256, 128, **upsample)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.ConvTranspose2d(128, 3, **upsample)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, input):
        stages = (
            (self.conv1, self.bn1),  # [N, 100, 1, 1]   -> [N, 1024, 2, 2]
            (self.conv2, self.bn2),  # [N, 1024, 2, 2]  -> [N, 512, 4, 4]
            (self.conv3, self.bn3),  # [N, 512, 4, 4]   -> [N, 256, 8, 8]
            (self.conv4, self.bn4),  # [N, 256, 8, 8]   -> [N, 128, 16, 16]
        )
        feature_map = input
        for deconv, norm in stages:
            feature_map = self.relu(norm(deconv(feature_map)))
        # [N, 128, 16, 16] -> [N, 3, 32, 32]
        return self.tanh(self.conv5(feature_map))

In [None]:
class KeiDCDiscriminator(nn.Module):
    """Convolutional discriminator: (N, 3, 32, 32) image -> (N, 1) probability.

    Four 4x4, stride-2, padding-1 convolutions without bias halve the spatial
    size (32 -> 16 -> 8 -> 4 -> 2), each followed by LeakyReLU(0.2). All but
    the first are batch-normalised. A final 2x2 convolution reduces the 2x2
    map to one value per sample, which is passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        downsample = dict(kernel_size=4, stride=2, padding=1, bias=False)
        self.conv1 = nn.Conv2d(3, 64, **downsample)
        self.conv2 = nn.Conv2d(64, 128, **downsample)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, **downsample)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, **downsample)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(512, 1, kernel_size=2, stride=1, padding=0, bias=False)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # [N, 3, 32, 32] -> [N, 64, 16, 16]; the first stage has no batch norm
        feature_map = self.leaky_relu(self.conv1(input))
        normalised_stages = (
            (self.conv2, self.bn2),  # -> [N, 128, 8, 8]
            (self.conv3, self.bn3),  # -> [N, 256, 4, 4]
            (self.conv4, self.bn4),  # -> [N, 512, 2, 2]
        )
        for conv, norm in normalised_stages:
            feature_map = self.leaky_relu(norm(conv(feature_map)))
        # [N, 512, 2, 2] -> [N, 1, 1, 1] -> [N, 1]
        score = self.sigmoid(self.conv5(feature_map))
        return score.view(input.shape[0], -1)

In [None]:
# Build both DCGAN networks on the selected device and re-initialise their
# Conv / BatchNorm layers with weights_init (Module.apply returns the module).
adv_generator = KeiDCGenerator().to(device).apply(weights_init)
adv_discriminator = KeiDCDiscriminator().to(device).apply(weights_init)

# Layer-by-layer summaries: one noise vector for the generator,
# a batch of 64 images for the discriminator.
adv_generator_summary = summary(adv_generator, input_size=(1, 100, 1, 1))
print(adv_generator_summary)
adv_discriminator_summary = summary(adv_discriminator, input_size=(64, 3, 32, 32))
print(adv_discriminator_summary)

In [None]:
def dcplotn(n, generator, device):
    """Sample ``n`` images from a convolutional generator and show them in a row.

    Args:
        n: Number of images to draw (one subplot each).
        generator: Module mapping ``(n, 100, 1, 1)`` noise to image tensors
            whose values lie in [-1, 1].
        device: Device the noise is moved to before the forward pass.

    The generator is evaluated in eval mode without gradient tracking, and
    its previous train/eval mode is restored before returning. Leaving it in
    eval mode would make its BatchNorm layers use running statistics for any
    training that follows the call.
    """
    was_training = generator.training
    generator.eval()
    try:
        noise = torch.FloatTensor(np.random.normal(0, 1, (n, 100, 1, 1))).to(device)
        with torch.no_grad():  # sampling only, no graph needed
            imgs = generator(noise).cpu()
    finally:
        generator.train(was_training)

    # Rescale from [-1, 1] to [0, 1]
    imgs = (imgs + 1) / 2
    # squeeze=False keeps `ax` a 2-D array, so n == 1 works as well.
    fig, ax = plt.subplots(1, n, figsize=(n * 3, 3), squeeze=False)
    for axis, im in zip(ax[0], imgs):
        axis.imshow(np.transpose(im.numpy(), (1, 2, 0)))  # Convert from CHW to HWC format
        axis.axis('off')
    plt.show()

In [None]:
def train_dcgan(dataloaders, models, optimizers, loss_fn, epochs, plot_every, device,
                data_mean=(0.4914, 0.4822, 0.4465), data_std=(0.2023, 0.1994, 0.2010)):
    """Train a DCGAN generator/discriminator pair and periodically plot samples.

    Args:
        dataloaders: Sequence whose first element is the training loader
            yielding ``(images, labels)`` batches; labels are ignored.
        models: ``(generator, discriminator)``.
        optimizers: ``(generator_optimizer, discriminator_optimizer)``.
        loss_fn: Criterion applied to discriminator outputs and 0/1 targets.
        epochs: Number of passes over the training loader.
        plot_every: Samples are plotted every ``plot_every`` epochs and after
            the last epoch.
        device: Device used for batches, noise and targets.
        data_mean, data_std: Per-channel statistics the loader's ``Normalize``
            transform applied. They are used to undo that standardisation
            before real images are mapped to [-1, 1]. The defaults match the
            transform defined in this file. Pass ``None`` for either one if
            the loader yields plain [0, 1] images.

    Two details matter for correctness:

    * Both networks are switched back to train mode at the start of every
      epoch. ``dcplotn`` puts the generator in eval mode, and without the
      reset its BatchNorm layers would use running statistics for all
      training after the first plot.
    * ``2 * x - 1`` maps [0, 1] pixels to the generator's tanh range, so the
      loader's standardisation is undone first instead of rescaling
      already-standardised values.
    """
    tqdm_iter = tqdm(range(epochs))
    train_dataloader = dataloaders[0]

    gen, disc = models[0], models[1]
    optim_gen, optim_disc = optimizers[0], optimizers[1]

    destandardize = data_mean is not None and data_std is not None
    if destandardize:
        # Shaped (1, C, 1, 1) so they broadcast over (N, C, H, W) batches.
        mean = torch.tensor(data_mean, dtype=torch.float32, device=device).view(1, -1, 1, 1)
        std = torch.tensor(data_std, dtype=torch.float32, device=device).view(1, -1, 1, 1)

    # Avoid a division by zero when the loader is empty.
    num_batches = max(len(train_dataloader), 1)

    for epoch in tqdm_iter:
        # Re-enable train mode every epoch (plotting may have switched to eval).
        gen.train()
        disc.train()

        train_gen_loss = 0.0
        train_disc_loss = 0.0

        for imgs, _ in train_dataloader:
            imgs = imgs.to(device)
            if destandardize:
                imgs = imgs * std + mean  # back to [0, 1] pixel values
            imgs = 2.0 * imgs - 1.0       # [0, 1] -> [-1, 1]
            n_samples = imgs.shape[0]

            noise = torch.FloatTensor(np.random.normal(0.0, 1.0, (n_samples, 100, 1, 1))).to(device)
            real_labels = torch.ones((n_samples, 1), device=device)
            fake_labels = torch.zeros((n_samples, 1), device=device)

            # --- generator step: make the discriminator answer "real" for fakes.
            gen.zero_grad()
            generated = gen(noise)
            disc_preds = disc(generated)

            g_loss = loss_fn(disc_preds, real_labels)
            g_loss.backward()
            optim_gen.step()

            # --- discriminator step: real -> 1, detached fakes -> 0.
            # zero_grad also discards the gradients the generator step left
            # on the discriminator's parameters.
            disc.zero_grad()
            disc_real_loss = loss_fn(disc(imgs), real_labels)
            disc_fake_loss = loss_fn(disc(generated.detach()), fake_labels)

            d_loss = (disc_real_loss + disc_fake_loss) / 2.0
            d_loss.backward()
            optim_disc.step()

            train_gen_loss += g_loss.item()
            train_disc_loss += d_loss.item()

        train_gen_loss /= num_batches
        train_disc_loss /= num_batches

        if epoch % plot_every == 0 or epoch == epochs - 1:
            dcplotn(5, gen, device)

        tqdm_dct = {'generator loss:': train_gen_loss, 'discriminator loss:': train_disc_loss}
        tqdm_iter.set_postfix(tqdm_dct, refresh=True)

In [None]:
# Pair up the DCGAN networks and give each one an identically configured Adam.
adv_models = (adv_generator, adv_discriminator)
adv_optimizers = tuple(
    optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay, betas=(beta1, beta2))
    for net in adv_models
)
optimizer_generator, optimizer_discriminator = adv_optimizers

# Binary cross-entropy on the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

In [None]:
# Train the DCGAN with the shared settings defined above.
train_dcgan(
    dataloaders=dataloaders,
    models=adv_models,
    optimizers=adv_optimizers,
    loss_fn=loss_fn,
    epochs=epochs,
    plot_every=plot_every,
    device=device,
)

In [None]:
# Show five samples from the trained DCGAN generator (switched to eval mode first).
adv_generator.eval()
dcplotn(n=5, generator=adv_generator, device=device)