<a href="https://colab.research.google.com/github/avrymi-asraf/Garden-of-GAN/blob/main/2-More-Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import time
from math import ceil

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from IPython.display import clear_output
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# CIFAR-10 training split; `download=True` fetches the archive into ./data when needed.
# Normalize(mean=0, std=1) computes (x - 0) / 1, i.e. it leaves the tensors unchanged.
# NOTE(review): the Generator in this notebook ends in tanh — confirm whether the real
# images should be rescaled to the same range (the CelebA cell uses mean 0.5 / std 0.5).
clifar10 = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=0, std=1),
        ]
    ),
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 81408099.70it/s]


Extracting data/cifar-10-python.tar.gz to data


In [3]:
class Discrimnator(nn.Module):
    """Convolutional discriminator: image (batch) -> probability of being real.

    Three 4x4 / stride-2 / padding-1 convolutions, each followed by
    LeakyReLU(0.2), halve the spatial size three times. The feature map is
    then flattened and mapped to one sigmoid-activated unit.

    Args:
        im_dim: ``(channels, height, width)`` of the input images. The width
            of the final linear layer is derived from it, so image sizes
            other than the default 32x32 are handled correctly.

    Raises:
        ValueError: if the image is too small to survive three halvings.
    """

    def __init__(self,im_dim=(3,32,32)):
        super().__init__()
        self.im_dim = im_dim
        # Only im_dim[0] used to be read, so tolerate a channels-only im_dim
        # by falling back to the 32x32 size that was previously hard-coded.
        height, width = (im_dim[1], im_dim[2]) if len(im_dim) >= 3 else (32, 32)
        # A kernel-4 / stride-2 / padding-1 convolution maps a size n to n // 2.
        for _ in range(3):
            height, width = height // 2, width // 2
        if height == 0 or width == 0:
            raise ValueError(f"im_dim {tuple(im_dim)} is too small for three stride-2 convolutions")
        # Number of features entering the linear layer (256*4*4 for 32x32 inputs).
        self.flat_features = 256 * height * width

        self.conv1 = nn.Conv2d(in_channels=self.im_dim[0],out_channels=64,kernel_size=4,stride=2,padding=1)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1)
        self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1)
        self.lin = nn.Linear(in_features=self.flat_features,out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        """Return a ``(batch, 1)`` tensor of probabilities for the images in ``x``."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        # view(-1, features) rather than flatten(1): it also turns a single
        # un-batched image into a batch of one.
        x = x.view(-1,self.flat_features)
        x = self.lin(x)
        x = self.sigmoid(x)
        return x

In [4]:
# Smoke test: run a freshly initialised (untrained) discriminator on one CIFAR-10 sample.
dis = Discrimnator()
# clifar10[0][0] is the image of the first (image, label) pair; the cell displays the network's output.
dis(clifar10[0][0])


tensor([[0.5052]], grad_fn=<SigmoidBackward0>)

In [5]:
class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> 3-channel image.

    A linear layer expands the latent vector to a 256x4x4 feature map; three
    4x4 / stride-2 / padding-1 transposed convolutions then upsample it, with
    LeakyReLU(0.2) between layers and tanh on the output.

    Args:
        letant_dim: size of the latent (noise) vector fed to ``forward``.
    """

    def __init__(self,letant_dim=100):
        super().__init__()
        # Settings shared by all three upsampling layers.
        upsample = dict(kernel_size=4, stride=2, padding=1, output_padding=0)

        self.lin = nn.Linear(letant_dim, 256 * 4 * 4)
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = nn.ConvTranspose2d(256, 128, **upsample)
        self.conv2 = nn.ConvTranspose2d(128, 64, **upsample)
        self.conv3 = nn.ConvTranspose2d(64, 3, **upsample)
        self.tanh = nn.Tanh()

    def forward(self,x):
        """Map latent vectors ``x`` to images with values bounded by tanh."""
        # Project the noise and reshape it into the initial 256x4x4 feature map.
        feature_map = self.relu(self.lin(x)).view(-1, 256, 4, 4)
        # The two hidden upsampling stages share the same activation.
        for deconv in (self.conv1, self.conv2):
            feature_map = self.relu(deconv(feature_map))
        # The last stage produces the image channels and squashes them with tanh.
        return self.tanh(self.conv3(feature_map))


In [None]:
# px.imshow(gen(torch.randn(1,100)).detach().squeeze().permute(1,2,0))

In [13]:
# Train on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fresh, untrained networks placed on the chosen device.
generator = Generator().to(device)
discriminator = Discrimnator().to(device)

# One Adam optimiser per network, both with learning rate 2e-4.
optim_gen = torch.optim.Adam(generator.parameters(), lr=2e-4)
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
loss_f = nn.BCELoss()

In [14]:
# Training configuration for the CIFAR-10 run.
epocs = 20        # number of passes over the dataset
batch_size = 128  # images per optimisation step

# Loss log with one pre-allocated row per optimisation step
# (batches per epoch, rounded up for the final partial batch, times epochs).
# `ceil` is imported in the notebook's top import cell.
run_data = pd.DataFrame(columns=["epoch","batch","loss_gen","loss_dis"],index=range(ceil(len(clifar10)/batch_size)*epocs))

In [15]:
# GAN training on CIFAR-10: per batch, one discriminator update followed by one
# generator update; after every epoch the loss curves and five samples are redrawn.
run_ind = 0  # next free row of run_data
for i in range(epocs):
    for j,(x,y) in enumerate(DataLoader(clifar10,batch_size=batch_size,shuffle=True)):
        x = x.to(device)
        # Targets sized to the actual batch (the last batch may be smaller).
        real_labels = torch.ones(len(x),device=device)
        fake_labels = torch.zeros(len(x),device=device)

        # ---- discriminator step: real images -> 1, generated images -> 0 ----
        y_hat = discriminator(x).flatten()
        loss_dis_1 = loss_f(y_hat,real_labels)

        noise = torch.randn(len(x),100).to(device)
        # detach(): this step only updates the discriminator, so there is no
        # need to backpropagate through the generator.
        y_hat = discriminator(generator(noise).detach()).flatten()
        loss_dis_2 = loss_f(y_hat,fake_labels)
        # Average of the two terms (the parentheses matter: without them only
        # the fake term was halved).
        loss_dis = (loss_dis_1 + loss_dis_2)/2

        optim_dis.zero_grad()
        loss_dis.backward()
        optim_dis.step()

        # ---- generator step: try to make the discriminator answer 1 on fakes ----
        noise = torch.randn(len(x),100).to(device)
        y_hat = discriminator(generator(noise)).flatten()
        loss_gen = loss_f(y_hat,real_labels)
        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        run_data.loc[run_ind] = [i,j,loss_gen.item(),loss_dis.item()]
        run_ind+=1

    # End of epoch: replace the previous plots with fresh ones.
    clear_output(wait=True)
    px.line(run_data,x=run_data.index,y=["loss_gen","loss_dis"]).show()
    # Sampling needs no autograd graph; permute to (N, H, W, C) for imshow.
    with torch.no_grad():
        exapmle = generator(torch.randn(5,100).to(device)).cpu().permute(0,2,3,1)
    px.imshow(exapmle,facet_col=0,facet_col_wrap=5).show()




In [None]:
# MNIST training split, reshaped to the 3x32x32 layout the networks in this notebook use:
# tensor conversion, Normalize(mean=0, std=1) (computes (x - 0) / 1, i.e. no change),
# resize to 32x32, then the single channel is repeated three times.
mnist = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=0, std=1),
            transforms.Resize((32, 32)),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        ]
    ),
)
# px.imshow(mnist[0][0].permute(1,2,0))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 227177264.87it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 38140961.53it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 248478115.78it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18549687.21it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [None]:
# Re-create everything from scratch for the MNIST run.
# Train on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fresh, untrained networks placed on the chosen device.
generator = Generator().to(device)
discriminator = Discrimnator().to(device)

# One Adam optimiser per network, both with learning rate 2e-4.
optim_gen = torch.optim.Adam(generator.parameters(), lr=2e-4)
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
loss_f = nn.BCELoss()

In [None]:
# Training configuration for the MNIST run.
epocs = 20        # number of passes over the dataset
batch_size = 128  # images per optimisation step

# Loss log with one pre-allocated row per optimisation step
# (batches per epoch, rounded up for the final partial batch, times epochs).
# `ceil` is imported in the notebook's top import cell.
run_data = pd.DataFrame(columns=["epoch","batch","loss_gen","loss_dis"],index=range(ceil(len(mnist)/batch_size)*epocs))

In [None]:
# GAN training on MNIST: per batch, one discriminator update followed by one
# generator update; after every epoch the loss curves, five samples and the
# epoch's wall-clock time are shown.
run_ind = 0  # next free row of run_data
for i in range(epocs):
    start_time = time.time()
    for j,(x,y) in enumerate(DataLoader(mnist,batch_size=batch_size,shuffle=True)):
        x = x.to(device)
        # Targets sized to the actual batch (the last batch may be smaller).
        real_labels = torch.ones(len(x),device=device)
        fake_labels = torch.zeros(len(x),device=device)

        # ---- discriminator step: real images -> 1, generated images -> 0 ----
        y_hat = discriminator(x).flatten()
        loss_dis_1 = loss_f(y_hat,real_labels)

        noise = torch.randn(len(x),100).to(device)
        # detach(): this step only updates the discriminator, so there is no
        # need to backpropagate through the generator.
        y_hat = discriminator(generator(noise).detach()).flatten()
        loss_dis_2 = loss_f(y_hat,fake_labels)
        # Average of the two terms (the parentheses matter: without them only
        # the fake term was halved).
        loss_dis = (loss_dis_1 + loss_dis_2)/2

        optim_dis.zero_grad()
        loss_dis.backward()
        optim_dis.step()

        # ---- generator step: try to make the discriminator answer 1 on fakes ----
        noise = torch.randn(len(x),100).to(device)
        y_hat = discriminator(generator(noise)).flatten()
        loss_gen = loss_f(y_hat,real_labels)
        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        run_data.loc[run_ind] = [i,j,loss_gen.item(),loss_dis.item()]
        run_ind+=1

    # End of epoch: replace the previous plots with fresh ones.
    clear_output(wait=True)
    px.line(run_data,x=run_data.index,y=["loss_gen","loss_dis"]).show()
    # Sampling needs no autograd graph; permute to (N, H, W, C) for imshow.
    with torch.no_grad():
        exapmle = generator(torch.randn(5,100).to(device)).cpu().permute(0,2,3,1)
    px.imshow(exapmle,facet_col=0,facet_col_wrap=5).show()
    # Seconds spent on this epoch, plotting included.
    print(f'{time.time()-start_time:03f}')




26.012212


In [None]:
# Download the mini CelebA archive from the project repository and unpack it
# into the current working directory.
!wget https://github.com/avrymi-asraf/Garden-of-GAN/raw/main/mini_celebA.zip -O mini_celebA.zip
!unzip mini_celebA.zip
# Remove the wget/unzip log from the cell output.
clear_output()

In [None]:
!mkdir data
!mkdir data/mini_celebA
!mkdir data/mini_celebA/class_0
!mv mini_selebA* data/mini_celebA/class_0

In [None]:
# Define the data transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = datasets.ImageFolder(root='data/mini_celebA', transform=transform)

In [None]:
# Shape of the first image tensor in the dataset (first element of sample 0).
dataset[0][0].shape

torch.Size([3, 218, 178])

In [None]:
class Discrimnator(nn.Module):
    """Convolutional discriminator: image (batch) -> probability of being real.

    Three 4x4 / stride-2 / padding-1 convolutions, each followed by
    LeakyReLU(0.2), halve the spatial size three times. The feature map is
    then flattened and mapped to one sigmoid-activated unit.

    Args:
        im_dim: ``(channels, height, width)`` of the input images. The width
            of the final linear layer is derived from it, so image sizes
            other than the default 32x32 (e.g. CelebA crops) are handled.

    Raises:
        ValueError: if the image is too small to survive three halvings.
    """

    def __init__(self,im_dim=(3,32,32)):
        super().__init__()
        self.im_dim = im_dim
        # Only im_dim[0] used to be read, so tolerate a channels-only im_dim
        # by falling back to the 32x32 size that was previously hard-coded.
        height, width = (im_dim[1], im_dim[2]) if len(im_dim) >= 3 else (32, 32)
        # A kernel-4 / stride-2 / padding-1 convolution maps a size n to n // 2.
        for _ in range(3):
            height, width = height // 2, width // 2
        if height == 0 or width == 0:
            raise ValueError(f"im_dim {tuple(im_dim)} is too small for three stride-2 convolutions")
        # Number of features entering the linear layer (256*4*4 for 32x32 inputs).
        self.flat_features = 256 * height * width

        self.conv1 = nn.Conv2d(in_channels=self.im_dim[0],out_channels=64,kernel_size=4,stride=2,padding=1)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1)
        self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1)
        self.lin = nn.Linear(in_features=self.flat_features,out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        """Return a ``(batch, 1)`` tensor of probabilities for the images in ``x``."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        # view(-1, features) rather than flatten(1): it also turns a single
        # un-batched image into a batch of one.
        x = x.view(-1,self.flat_features)
        x = self.lin(x)
        x = self.sigmoid(x)
        return x

class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> 3-channel image.

    A linear layer expands the latent vector to a 256x4x4 feature map; three
    4x4 / stride-2 / padding-1 transposed convolutions then upsample it, with
    LeakyReLU(0.2) between layers and tanh on the output.

    Args:
        letant_dim: size of the latent (noise) vector fed to ``forward``.
    """

    def __init__(self,letant_dim=100):
        super().__init__()
        self.lin = nn.Linear(in_features=letant_dim,out_features=256*4*4)
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=4,stride=2,padding=1,output_padding=0)
        self.conv2 = nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=4,stride=2,padding=1,output_padding=0)
        self.conv3 = nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=4,stride=2,padding=1,output_padding=0)
        self.tanh = nn.Tanh()

    def forward(self,x):
        """Map latent vectors ``x`` of shape (batch, letant_dim) to images.

        Without this method the redefined class could be built but not called.
        """
        x = self.lin(x)
        x = self.relu(x)
        # Reshape the projected noise into the initial 256x4x4 feature map.
        x = x.view(-1,256,4,4)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.tanh(x)
        return x


torch.Size([3, 218, 178])

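$$
h_{out} = (h_{in}-1)\times \text{stride} -2\times \text{padding} + \text{dilation}\times(\text{kernel size}-1) + \text{output padding}+1
$$
With stride $=2$ and the default dilation $=1$:
$$
h_{out} = (h_{in}-1)\times 2 -2\times \text{padding} + (\text{kernel size}-1) + \text{output padding}+1
$$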

In [None]:
# Centre-crop the first CelebA image (3x218x178) to 200x156.
# The image is taken from `dataset` explicitly so the cell does not depend on
# whatever `x` was left behind by an earlier cell.
im = transforms.CenterCrop((200,156))(dataset[0][0])
# Undo Normalize(mean=0.5, std=0.5) for display, then move channels last for imshow.
px.imshow((im*0.5+0.5).permute(1,2,0))