## gans architecture
Paper:[Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434)

![imge](screenshot_7.png)

- this is the Generator architecture
- the discriminator architecture is the opposite of the generator architecture

## Design of the network
- Replace all pooling layers with strided convolutions.
- Remove all fully connected layers.
- Use transposed convolutions for upsampling.
- Use Batch Normalization after every layer except after the output layer of the generator and the input layer of the discriminator.
- Use ReLU non-linearity for each layer in the generator except for output layer  use tanh.
- Use Leaky-ReLU non-linearity for each layer of the disciminator excpet for output layer use sigmoid.

## Hyperparameters

Hyperparameters are chosen as given in the paper.

- mini-batch size: 128
- learning rate: 0.0002
- momentum term beta1: 0.5
- slope of leak of LeakyReLU: 0.2
- For the optimizer Adam (with beta2 = 0.999) has been used instead of SGD as described in the paper.

## Implementation

In [1]:
import torch
from torch import nn 
%load_ext autoreload
%autoreload 2


### Discriminator


In [2]:
class Discriminator(nn.Module):
    def __init__(self, channel_num, filters_num) -> None:  # * channel_num: 3, filters_num is the number of filters in the first layer
        super(Discriminator, self).__init__()

        # * Input: (N, channel_num, 64, 64)
        self.disc = nn.Sequential(
            nn.Conv2d(
                channel_num,
                filters_num,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=True,
            ),
            nn.LeakyReLU(0.2),
            # * Input: (N, filters_num, 32, 32)
            self._block(filters_num, filters_num * 2, 4, 2, 1),
            # * Input: (N, filters_num*2, 16, 16)
            self._block(filters_num*2, filters_num * 4, 4, 2, 1),
            # * Input: (N, filters_num*8, 8, 8)
            self._block(filters_num*4, filters_num * 8, 4, 2, 1),
            # * Input: (N, filters_num*8, 4, 4)
            nn.Conv2d(filters_num * 8, 1, kernel_size=4, stride=2, padding=0, bias=False),
            # * Output: (N, 1, 1, 1)
            #nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,  # * No bias in batch norm
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

In [3]:
disc = Discriminator(3, 64) #* 3 channels, starting with 64 filters


In [4]:
def _output_shape_cal(input_shape, kernel_size, stride, padding):
    return (input_shape - kernel_size + 2 * padding) / stride + 1

_output_shape_cal(64, 4, 2, 1)

32.0

In [5]:
#from torchinfo import summary

#summary(disc, (1, 3, 64, 64), device="cuda", verbose=0)
# 1, 3 , 64, 64 -> 1, 64, 32, 32 

### Generator

In [6]:
class Generator(nn.Module):
    def __init__(self, z_dim , channel_num, filters_num) -> None:
        super(Generator, self).__init__()
        
        self.net = nn.Sequential(
            #* Input: (N, z_dim, 1, 1)
            self._block(z_dim, filters_num*16, 4, 1, 0),
            
            #* Input: (N, filters_num*16, 4, 4)
            self._block(filters_num*16, filters_num*8, 4, 2, 1), 
            
            #* Input: (N, filters_num*8, 8, 8)
            self._block(filters_num*8, filters_num*4, 4, 2, 1),
            
            #* Input: (N, filters_num*4, 16, 16)
            self._block(filters_num*4, filters_num*2, 4, 2, 1),
            
            
            #self._block(filters_num*2, filters_num, 4, 2, 1),
            
            #* Input: (N, filters_num*2, 32, 32)
            nn.ConvTranspose2d(
                in_channels = filters_num*2,
                out_channels= channel_num,
                kernel_size=4,
                stride=2,
                padding=1),
            #* Output: (N, channel_num, 64, 64)
            nn.Tanh(), #* [-1, 1]
        )
        
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,  # * No bias in batch norm
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
    
    def forward(self, x):
        return self.net(x)

In [7]:
#gen = Generator(100, 3, 64)
#gen

## initialize weights

In [8]:
def initializing_weights(model):
    for m in model.modules():
        if isinstance(m , (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            #* Initializing weights
            nn.init.normal_(m.weight.data, 0.0, 0.02) #* Normal distribution with mean 0 and std 0.02

In [9]:
N, in_channels, H, W = 8, 3, 64, 64
z_dim = 100
x = torch.randn((N, in_channels, H, W))
disc = Discriminator(in_channels, 8)
initializing_weights(disc)
print(disc(x).shape) #*-> get down to 1, 1, 1
gen = Generator(z_dim, in_channels, 8)
z = torch.randn((N, z_dim, 1, 1))
initializing_weights(gen)
print(gen(z).shape) #* -> upsacle to 1, 3, 64, 64

torch.Size([8, 1, 1, 1])
torch.Size([8, 3, 64, 64])


## Training

1. Imports

In [11]:
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.model import *


2. Hyperparameters

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
LEARNING_RATE = 2e-4
BATCH_SIZE = 32
IMAG_SIZE = 64
CHANNELS_IMG = 1 
Z_DIM = 100
NUM_EPOCH = 5

FEATURES_DISC = 64 #* NUMBER OF FILETER IN DISCRIMINATOR
FEATURES_GEN = 64 #* NUMBER OF FILETER IN GENERATOR

3. transforms

In [13]:
transformation = transforms.Compose([
    transforms.Resize(IMAG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
)]) #* mean and std = [0.5] [0.5] for each channel
    
                                    


## Downloading the dataset

In [14]:
dataset =datasets.MNIST(
    root='data/',
    download=True,
    train=True,
    transform=transformation,
    target_transform=None,
)

loader = DataLoader(
    dataset= dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

### Cearing Generator and Discriminator

In [15]:
gen = Generator(
    z_dim= Z_DIM,
    channel_num= CHANNELS_IMG,
    filters_num= FEATURES_GEN,
).to(device)

disc = Discriminator(
    channel_num= CHANNELS_IMG,
    filters_num=FEATURES_DISC
).to(device)


#* initialize weights with normal distribution of mean 0 and std 0.02
initializing_weights(gen)
initializing_weights(disc)

#### LOSS and Optimizer

In [16]:
opt_gen = optim.Adam(
    params=gen.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999)  # * beta1 and beta2 ? default = (0.9, 0.999)
    # * beta1: exponential decay rate for the first moment estimates
    # * beta2: exponential decay rate for the second moment estimates
    # * beta1 and beta2 are used to calculate the first and second moment estimates
    #*  coefficients used for computing running averages of gradient and its square
)

opt_disc = optim.Adam(params=disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss() 
#loss_no_logits = nn.BCELoss()


## Fews more steps + tensorboard

In [17]:
fixed_noise = torch.randn(32, Z_DIM,1,1)
fixed_noise.shape

writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(F'logs/fake')
step = 0


In [18]:
real, _ = next(iter(loader))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_Disc_loss(gen, disc, real, batch_size, z_dim, criterion, device):
    real = real.to(device)
    noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
    gen_images = gen(noise)

    disc_fake = disc(gen_images).detach().reshape(-1)
    disc_real = disc(real).reshape(-1)

    Disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
    Disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
    Disc_loss = (Disc_fake_loss + Disc_real_loss) / 2
    return Disc_loss

# opt_disc.zero_grad()
# Disc_loss = get_Disc_loss(gen, disc , real, BATCH_SIZE, Z_DIM, criterion, opt_disc , device)
# Disc_loss.backward(retain_graph=True)
# opt_disc.step()

# Disc_loss.item()
# get_Disc_loss(gen, disc , real, BATCH_SIZE, Z_DIM, criterion, opt_disc , device)

In [19]:
def get_gen_loss(gen, disc, batch_size, z_dim, criterion, device):
    noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
    generated_img = gen(noise).detach()
    disc_gen = disc(generated_img)
    Gen_loss = criterion(disc_gen, torch.ones_like(disc_gen))
    return Gen_loss

# opt_gen.zero_grad()
# Gen_loss = get_gen_loss(gen, disc, BATCH_SIZE, Z_DIM, criterion)
# Gen_loss.backward()
# opt_gen.step()


In [20]:
from tqdm.autonotebook import tqdm

In [21]:
device = 'cuda'
epochs = 1
# BATCH_SIZE = 16
gen.train()
disc.train()
gen_loss_hist = []
disc_loss_hist = []
for _ in range(epochs):
    total_disc_loss = 0
    total_gen_loss = 0
    for batch_idx , (real, _) in enumerate(tqdm(loader)):
        # real, _ = next(iter(loader))
        real = real.to(device)
        opt_disc.zero_grad()
        Disc_loss = get_Disc_loss(disc=disc,
                                gen=gen,
                                real=real,
                                batch_size=BATCH_SIZE,
                                z_dim=Z_DIM,
                                criterion=criterion,
                                device='cuda'
                                )
        Disc_loss.backward()
        opt_disc.step()
        
        total_disc_loss += Disc_loss.item()
        
        opt_gen.zero_grad()
        Gen_loss = get_gen_loss(
            disc=disc,
            gen=gen,
            batch_size=BATCH_SIZE,
            z_dim=Z_DIM,
            criterion=criterion,
            device=device
        )
        opt_gen.step()
        total_gen_loss += Gen_loss.item()
        
    disc_loss_hist.append(total_disc_loss//len(loader))
    gen_loss_hist.append(total_disc_loss//len(loader))

disc_loss_hist, gen_loss_hist

  0%|          | 0/1875 [00:00<?, ?it/s]

([3.0], [3.0])