<a href="https://colab.research.google.com/github/avrymi-asraf/Garden-of-GAN/blob/main/2-More-Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from IPython.display import clear_output
import time

In [None]:
from math import ceil
def train_model(dis, gen, epoches, batch_size, dataset, device, loss_f, optim_dis, optim_gen,
                latent_dim=100, show_progress=True):
    """Train a generative adversarial network (GAN).

    Each batch performs one discriminator update followed by one generator
    update. The discriminator loss is the mean of its loss on the real batch
    (target 1) and on a freshly generated batch (target 0); the generator loss
    is computed on a second generated batch with target 1.

    Args:
        dis (torch.nn.Module): The discriminator model. Its output is flattened
            to one value per sample before being passed to ``loss_f``.
        gen (torch.nn.Module): The generator model. It is called with noise of
            shape ``(batch, latent_dim)``.
        epoches (int): Number of training epochs.
        batch_size (int): Size of the batches used during training.
        dataset (torch.utils.data.Dataset): The dataset for training. Every
            item must be an ``(input, label)`` pair; the label is ignored.
        device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda').
        loss_f (callable): Called as ``loss_f(prediction, target)``.
        optim_dis (torch.optim.Optimizer): Optimizer for the discriminator.
        optim_gen (torch.optim.Optimizer): Optimizer for the generator.
        latent_dim (int): Length of the noise vector fed to ``gen``.
        show_progress (bool): When true, after every epoch the notebook output
            is cleared, the loss curves and five generated samples are plotted
            and the epoch run time is printed. Set to false to train without
            any plotting or notebook output.

    Returns:
        pd.DataFrame: A DataFrame containing the training progress, with columns:
                      "epoch", "batch", "loss_gen", and "loss_dis".
    """
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # One pre-allocated row per training step, so the loss plot keeps a fixed x-range.
    run_data = pd.DataFrame(
        columns=["epoch", "batch", "loss_gen", "loss_dis"],
        index=range(len(data_loader) * epoches),
    )
    run_ind = 0
    for epoch in range(epoches):
        start_time = time.time()
        for batch, (x, _) in enumerate(data_loader):
            x = x.to(device)
            n_samples = len(x)  # the last batch may be smaller than batch_size
            real_target = torch.ones(n_samples, device=device)
            fake_target = torch.zeros(n_samples, device=device)

            # --- discriminator step ---
            loss_real = loss_f(dis(x).flatten(), real_target)
            noise = torch.randn(n_samples, latent_dim, device=device)
            # Detach: this step must not backpropagate into the generator.
            fake = gen(noise).detach()
            loss_fake = loss_f(dis(fake).flatten(), fake_target)
            # Average of both terms (the parentheses matter: `a + b/2` would
            # halve only the fake term).
            loss_dis = (loss_real + loss_fake) / 2

            optim_dis.zero_grad()
            loss_dis.backward()
            optim_dis.step()

            # --- generator step ---
            noise = torch.randn(n_samples, latent_dim, device=device)
            # The generator is rewarded when the discriminator outputs "real".
            loss_gen = loss_f(dis(gen(noise)).flatten(), real_target)
            optim_gen.zero_grad()
            loss_gen.backward()
            optim_gen.step()

            run_data.loc[run_ind] = [epoch, batch, loss_gen.item(), loss_dis.item()]
            run_ind += 1

        if show_progress:
            clear_output(wait=True)
            px.line(run_data, x=run_data.index, y=["loss_gen", "loss_dis"]).show()
            # Sampling for display needs no autograd graph.
            with torch.no_grad():
                sample_noise = torch.randn(5, latent_dim, device=device)
                # (N, C, H, W) -> (N, H, W, C), the layout px.imshow is given here.
                example = gen(sample_noise).cpu().permute(0, 2, 3, 1)
            px.imshow(example, facet_col=0, facet_col_wrap=5).show()
            print(f'run time: {time.time()-start_time:.3f}')
    return run_data




In [None]:
# CIFAR-10 training split, downloaded into ./data on first use.
clifar10 = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Map pixels from [0, 1] to [-1, 1] so real images share the range of the
        # generators' tanh output. The previous Normalize(0, 1) left the data unchanged.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

In [None]:
class Discrimnator_32(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    Three stride-2 convolutions (64, 128 and 256 output channels), each
    followed by LeakyReLU(0.2), feed a single linear unit and a sigmoid.
    The linear layer expects 256*4*4 features, which is what the three
    convolutions produce for 32x32 inputs.

    Args:
        im_dim: (channels, height, width) of the input images. Only the
            channel count is used when building the layers.
    """

    def __init__(self,im_dim=(3,32,32)):
        super().__init__()
        self.im_dim = im_dim
        # Every convolution halves the spatial resolution.
        halve = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(self.im_dim[0], 64, **halve)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(64, 128, **halve)
        self.conv3 = nn.Conv2d(128, 256, **halve)
        self.lin = nn.Linear(256 * 4 * 4, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        """Return a tensor of shape (-1, 1) with values in (0, 1)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        features = x.view(-1, 256 * 4 * 4)
        return self.sigmoid(self.lin(features))
class Generator_32(nn.Module):
    """Transposed-convolution generator mapping latent vectors to 3x32x32 images.

    A linear layer expands the latent vector to a 256x4x4 feature map; three
    stride-2 transposed convolutions then double the resolution each time
    (4 -> 8 -> 16 -> 32) while reducing the channels 256 -> 128 -> 64 -> 3.
    LeakyReLU(0.2) follows every layer except the last, which uses tanh.

    Args:
        letant_dim: Length of the input latent vector.
    """

    def __init__(self,letant_dim=100):
        super().__init__()
        # Settings shared by all three up-sampling layers.
        double = dict(kernel_size=4, stride=2, padding=1, output_padding=0)
        self.lin = nn.Linear(letant_dim, 256 * 4 * 4)
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = nn.ConvTranspose2d(256, 128, **double)
        self.conv2 = nn.ConvTranspose2d(128, 64, **double)
        self.conv3 = nn.ConvTranspose2d(64, 3, **double)
        self.tanh = nn.Tanh()

    def forward(self,x):
        """Return generated images with values in [-1, 1]."""
        seed_map = self.relu(self.lin(x)).view(-1, 256, 4, 4)
        hidden = self.relu(self.conv1(seed_map))
        hidden = self.relu(self.conv2(hidden))
        return self.tanh(self.conv3(hidden))

In [None]:
# dis = Discrimnator_32()
# dis(clifar10[0][0])

In [None]:
# px.imshow(gen(torch.randn(1,100)).detach().squeeze().permute(1,2,0))

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The notebook defines no plain `Generator` / `Discrimnator` classes (calling
# them raised NameError); the 32x32 variants are used here.
generator  = Generator_32().to(device)
discriminator = Discrimnator_32().to(device)
# One Adam optimizer per network so each can be stepped independently.
optim_gen = torch.optim.Adam(generator.parameters(),lr=0.0002)
optim_dis = torch.optim.Adam(discriminator.parameters(),lr=0.0002)
# Binary cross-entropy on the discriminator's sigmoid output.
loss_f = nn.BCELoss()

In [None]:
# Number of training epochs and samples per batch.
epocs, batch_size = 20, 128

In [None]:
# MNIST training split, downloaded into ./data on first use and converted to
# 3x32x32 tensors so it fits the 32x32 three-channel networks.
mnist = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Map pixels from [0, 1] to [-1, 1] so real images share the range of the
        # generator's tanh output. The previous Normalize(0, 1) left the data unchanged.
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Resize((32, 32)),
        # Repeat the single grey channel three times.
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ]),
)
# px.imshow(mnist[0][0].permute(1,2,0))
clear_output()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fresh networks for this run. The notebook defines no plain `Generator` /
# `Discrimnator` classes (calling them raised NameError); the 32x32 variants
# are used here.
generator  = Generator_32().to(device)
discriminator = Discrimnator_32().to(device)
# One Adam optimizer per network so each can be stepped independently.
optim_gen = torch.optim.Adam(generator.parameters(),lr=0.0002)
optim_dis = torch.optim.Adam(discriminator.parameters(),lr=0.0002)
# Binary cross-entropy on the discriminator's sigmoid output.
loss_f = nn.BCELoss()

In [None]:
# Training length (epochs) and mini-batch size for the next training call.
epocs, batch_size = 20, 128

In [None]:
# Train on the MNIST dataset. The call is left as the cell's final expression,
# so the notebook displays the returned DataFrame.
train_model(
    dis=discriminator,
    gen=generator,
    epoches=epocs,
    batch_size=batch_size,
    dataset=mnist,
    device=device,
    loss_f=loss_f,
    optim_dis=optim_dis,
    optim_gen=optim_gen,
)

In [None]:
# Download the mini CelebA archive from the project's GitHub repository and
# extract it into the current working directory.
!wget https://github.com/avrymi-asraf/Garden-of-GAN/raw/main/mini_celebA.zip -O mini_celebA.zip
!unzip mini_celebA.zip
# Clear the output printed by the shell commands above.
clear_output()

In [None]:
!mkdir data
!mkdir data/mini_celebA
!mkdir data/mini_celebA/class_0
!mv mini_selebA* data/mini_celebA/class_0

In [None]:
# Pre-processing pipeline applied to every image loaded with this transform:
# convert to a tensor, keep the central 178x178 region, scale it to 128x128
# and shift/scale each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.CenterCrop(size=(178, 178)),
        transforms.Resize(size=(128, 128)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)



In [None]:
class Discrimnator_128(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    Three stride-2 convolutions (64, 128 and 256 output channels) are followed
    by one stride-4 convolution (256 channels); LeakyReLU(0.2) is applied after
    each. The linear layer expects 256*4*4 features, which is what the four
    convolutions produce for 128x128 inputs.

    Args:
        im_dim: (channels, height, width) of the input images. Only the
            channel count is used when building the layers.
    """

    def __init__(self,im_dim=(3,128,128)):
        super().__init__()
        self.im_dim = im_dim
        # The first three convolutions halve the resolution; the fourth
        # reduces 16x16 to 4x4.
        halve = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(self.im_dim[0], 64, **halve)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(64, 128, **halve)
        self.conv3 = nn.Conv2d(128, 256, **halve)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=1)
        self.lin = nn.Linear(256 * 4 * 4, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        """Return a tensor of shape (-1, 1) with values in (0, 1)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(conv(x))
        features = x.view(-1, 256 * 4 * 4)
        return self.sigmoid(self.lin(features))

class Generator_128(nn.Module):
    """Transposed-convolution generator mapping latent vectors to 3x128x128 images.

    A linear layer expands the latent vector to a 256x4x4 feature map. One
    stride-2 transposed convolution (4 -> 8) and two stride-4 transposed
    convolutions (8 -> 32 -> 128) follow, reducing the channels
    256 -> 128 -> 64 -> 3. LeakyReLU(0.2) follows every layer except the
    last, which uses tanh.

    Args:
        letant_dim: Length of the input latent vector.
    """

    def __init__(self,letant_dim=100):
        super().__init__()
        # Settings shared by the two x4 up-sampling layers.
        quadruple = dict(kernel_size=4, stride=4, padding=1, output_padding=2)
        self.lin = nn.Linear(letant_dim, 256 * 4 * 4)
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv2 = nn.ConvTranspose2d(128, 64, **quadruple)
        self.conv3 = nn.ConvTranspose2d(64, 3, **quadruple)
        self.tanh = nn.Tanh()

    def forward(self,x):
        """Return generated images with values in [-1, 1]."""
        seed_map = self.relu(self.lin(x)).view(-1, 256, 4, 4)
        hidden = self.relu(self.conv1(seed_map))
        hidden = self.relu(self.conv2(hidden))
        return self.tanh(self.conv3(hidden))

In [None]:
# print(Discrimnator_128()(dataset[0][0]).item())
# px.imshow(Generator_128()(torch.randn(1,100)).detach().squeeze().permute(1,2,0))


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Images are read from the class sub-folders of data/mini_celebA and passed
# through `transform`, which resizes them to 128x128.
dataset = datasets.ImageFolder(root='data/mini_celebA', transform=transform)
# The notebook defines no plain `Generator` / `Discrimnator` classes (calling
# them raised NameError); the 128x128 variants match the transformed images.
gen = Generator_128().to(device)
dis = Discrimnator_128().to(device)

In [None]:
# Training length (epochs) and mini-batch size.
epoches, batch_size = 100, 64

# Binary cross-entropy loss and one Adam optimizer (learning rate 2e-4) for
# each of the two networks.
loss_f = nn.BCELoss()
optim_dis = torch.optim.Adam(params=dis.parameters(), lr=2e-4)
optim_gen = torch.optim.Adam(params=gen.parameters(), lr=2e-4)

In [None]:
# Run the training loop on `dataset` and keep the per-batch loss log that
# train_model returns.
record_data = train_model(
    dis=dis,
    gen=gen,
    epoches=epoches,
    batch_size=batch_size,
    dataset=dataset,
    device=device,
    loss_f=loss_f,
    optim_dis=optim_dis,
    optim_gen=optim_gen,
)