# Creating a WGAN from Scratch

## Importing Libraries

In [1]:

import torch
import torchvision
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

## Model

### Creating the Critic

In [2]:
class Criti(nn.Module):
    """WGAN critic: scores a batch of images with one unbounded real number each.

    For a (N, channel_img, 64, 64) input, every stride-2 convolution halves the
    spatial size (64 -> 32 -> 16 -> 8 -> 4). The final 4x4 convolution with no
    padding collapses 4x4 to 1x1, so the output has shape (N, 1, 1, 1).
    The last layer is not followed by a sigmoid: the scores are raw values.

    Args:
        channel_img: number of channels of the input images.
        features_d: base channel width; the blocks use 1x, 2x, 4x and 8x of it.
    """

    def __init__(self, channel_img, features_d):
        super(Criti, self).__init__()
        widths = [features_d, features_d * 2, features_d * 4, features_d * 8]

        # First layer has no normalization: 64x64 -> 32x32.
        layers = [
            nn.Conv2d(in_channels=channel_img, out_channels=widths[0],
                      kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three normalized downsampling blocks: 32x32 -> 16x16 -> 8x8 -> 4x4.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(
                self._block(in_chanels=c_in, out_channels=c_out,
                            kernel_size=4, stride=2, padding=1)
            )
        # 4x4 -> 1x1: a single score per image.
        layers.append(
            nn.Conv2d(in_channels=widths[-1], out_channels=1,
                      kernel_size=4, stride=2, padding=0)
        )
        self.critic = nn.Sequential(*layers)

    def _block(self, in_chanels, out_channels, kernel_size, stride, padding):
        """Return Conv2d (no bias) -> InstanceNorm2d (learnable affine) -> LeakyReLU(0.2).

        The conv bias is omitted because the instance normalization that follows
        subtracts the per-channel mean, which would cancel a bias anyway.
        """
        conv = nn.Conv2d(in_channels=in_chanels, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride,
                         padding=padding, bias=False)
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Apply the critic network to ``x`` and return its raw scores."""
        scores = self.critic(x)
        return scores

### Creating the Generator

In [3]:
   
class Generator(nn.Module):
    """DCGAN-style generator: maps latent noise to images.

    Input shape (N, z_din, 1, 1), output shape (N, channel_img, 64, 64).
    Spatial size: 1 -> 4 (first transposed conv, stride 1, no padding)
    -> 8 -> 16 -> 32 -> 64 (stride-2 transposed convs). The final Tanh bounds
    pixel values to [-1, 1].

    Args:
        z_din: number of latent (noise) channels.
        channel_img: number of channels of the generated images.
        features_g: base channel width; hidden layers use 16x, 8x, 4x, 2x of it.
    """
    def __init__(self, z_din, channel_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            self._block(
                in_chanels=z_din,
                out_channels=features_g*16,
                kernel_size=4,
                stride=1,
                padding=0
            ),  # N x f_g*16 x 4 x 4
            self._block(
                in_chanels=features_g*16,
                out_channels=features_g*8,
                kernel_size=4,
                stride=2,
                padding=1
            ),  # 8x8
            self._block(
                in_chanels=features_g*8,
                out_channels=features_g*4,
                kernel_size=4,
                stride=2,
                padding=1
            ),  # 16x16
            self._block(
                in_chanels=features_g*4,
                out_channels=features_g*2,
                kernel_size=4,
                stride=2,
                padding=1
            ),  # 32x32
            nn.ConvTranspose2d(
                in_channels=features_g*2,
                out_channels=channel_img,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # N x channel_img x 64 x 64
            nn.Tanh(),
        )

    def _block(self,
               in_chanels,
               out_channels,
               kernel_size,
               stride,
               padding):
        """Return ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU.

        The transposed-conv bias is omitted because BatchNorm subtracts the
        per-channel mean right after it, which would cancel a bias anyway.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_chanels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            # nn.ReLU has no slope argument: its only parameter is `inplace`, so
            # passing 0.2 here would silently mean inplace=True. Use a plain ReLU.
            nn.ReLU(),
        )

    def forward(self, x):
        """Generate images from latent noise ``x``."""
        return self.net(x)

### Creating Initial Weights

In [4]:
        
def initialize_weights(model, std=0.02):
    """Re-initialize conv and batch-norm parameters of ``model`` in place (DCGAN scheme).

    * Conv2d / ConvTranspose2d weights ~ N(0, std).
    * BatchNorm2d scale ~ N(1, std) and shift = 0. Centering the scale at 1
      keeps normalized activations at unit scale; centering it at 0 would
      shrink every normalized activation toward zero at the start of training.

    All other modules (e.g. InstanceNorm2d, activations) and conv biases keep
    PyTorch's default initialization.

    Args:
        model: any nn.Module; submodules are visited recursively via ``modules()``.
        std: standard deviation of the normal initializers (default 0.02).
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.BatchNorm2d):
            # affine=False batch norms have no learnable parameters to initialize.
            if module.weight is not None:
                nn.init.normal_(module.weight, mean=1.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

### Testing the Models

In [5]:
def test():
    """Smoke-test both networks on random tensors, then print "Success".

    Checks that the critic maps an (8, 3, 64, 64) batch to (8, 1, 1, 1) scores
    and that the generator maps (8, 100, 1, 1) noise to (8, 3, 64, 64) images.
    """
    batch, channels, height, width = 8, 3, 64, 64
    latent = 100
    images = torch.randn((batch, channels, height, width))

    disc = Criti(channels, 8)
    initialize_weights(disc)
    scores = disc(images)
    assert scores.shape == (batch, 1, 1, 1)

    netg = Generator(latent, channels, 8)
    initialize_weights(netg)
    noise = torch.randn((batch, latent, 1, 1))
    samples = netg(noise)
    assert samples.shape == (batch, channels, height, width)

    print("Success")


test()

Success


## Hyperparameters

In [6]:
# Use the GPU when one is present. `is_available` must be *called*: the bare
# function object is always truthy, which would force "cuda" on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4   # RMSprop step size (the original WGAN paper's default is 5e-5)
BATCH_SIZE = 128
IMAGE_SIZE = 64        # images are resized to 64x64 to match the networks
CHANNEL_IMG = 1        # MNIST images have a single channel
NOICE_DIM = 100        # latent noise dimension
NUM_EPOCHS = 5
FEATURES_critic = 64   # base channel width of the critic
FEATHURE_DEN = 64      # base channel width of the generator
CRITIC_ITERATIONS = 5  # critic updates per generator update
WEIGHT_CLIP = 0.01     # critic weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

In [7]:
# Preprocessing pipeline: resize to IMAGE_SIZE, convert to a float tensor in
# [0, 1], then map to [-1, 1] with per-channel mean 0.5 / std 0.5 (matching the
# generator's Tanh output range).
# NOTE: the pipeline is bound to the name `transforms` because later cells use
# that name. This shadows the `torchvision.transforms` module, so the module is
# reached through `torchvision.transforms`; that keeps this cell safe to re-run.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(IMAGE_SIZE),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.5] * CHANNEL_IMG,
            [0.5] * CHANNEL_IMG,
        ),
    ]
)

## Dataset

In [8]:
# MNIST training split stored under ./dataset/ (downloaded when not already
# present), preprocessed with the `transforms` pipeline.
# NOTE: the dataset is bound to the name `datasets` because later cells use that
# name. This shadows the `torchvision.datasets` module, so the class is reached
# through `torchvision.datasets`; that keeps this cell safe to re-run.
datasets = torchvision.datasets.MNIST(
    'dataset/',
    train=True,
    transform=transforms,
    download=True,
)

## Initializing and Optimizing

### Initializing the DataLoader

In [9]:
# Shuffled mini-batches of BATCH_SIZE images; the final batch may be smaller
# because incomplete batches are not dropped.
loader = DataLoader(
    dataset=datasets,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

### Initializing Models and Weights

In [10]:
# Build both networks on `device`, then apply the custom weight initialization.
# The order (generator before critic) fixes the sequence of random draws.
gen = Generator(NOICE_DIM, CHANNEL_IMG, FEATHURE_DEN).to(device)
critic = Criti(CHANNEL_IMG, FEATURES_critic).to(device)

initialize_weights(gen)
initialize_weights(critic)

### Setting up the Optimizers

In [11]:
# RMSprop for both networks: the WGAN paper pairs weight clipping with RMSprop
# rather than momentum-based optimizers.
opt_gen = optim.RMSprop(params=gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(params=critic.parameters(), lr=LEARNING_RATE)

### Fixed Noise

In [12]:
# A constant batch of 32 latent vectors, intended for visualizing generator
# progress on the same inputs over the course of training.
fixed_noise = torch.randn(
    32, NOICE_DIM, 1, 1
).to(device)

## Setting up TensorBoard

In [13]:
# One TensorBoard writer per image stream, so real and generated grids land in
# separate log directories.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # TensorBoard global step, advanced each time images are logged

# Put both networks in training mode (the last expression is displayed).
gen.train()
critic.train()

Criti(
  (critic): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

## Training

In [14]:

for epoch in range(NUM_EPOCHS):
    # Target labels are not needed: GAN training is unsupervised.
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)],
        # CRITIC_ITERATIONS critic updates per generator update.
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, NOICE_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # detach(): the critic loss must not backpropagate into the generator.
            # This skips useless generator gradients and removes the need for
            # retain_graph=True; `fake` keeps its own graph for the generator step.
            critic_fake = critic(fake.detach()).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Clip critic weights to [-WEIGHT_CLIP, WEIGHT_CLIP] (original WGAN
            # Lipschitz constraint). Done in-place outside autograd.
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)].
        # Reuses the last `fake` batch, whose generator graph is still intact.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Every 100 batches: print losses and log image grids to TensorBoard.
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                # Generate from the fixed latent batch so successive grids are
                # directly comparable across logging steps.
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    data[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

 22%|██▏       | 101/469 [01:04<04:29,  1.36it/s]

Epoch [0/5] Batch 100/469                   Loss D: -1.4717, loss G: 0.7077


 43%|████▎     | 201/469 [02:08<03:13,  1.39it/s]

Epoch [0/5] Batch 200/469                   Loss D: -1.3025, loss G: 0.6658


 64%|██████▍   | 301/469 [03:13<02:01,  1.38it/s]

Epoch [0/5] Batch 300/469                   Loss D: -0.9044, loss G: 0.2499


 86%|████████▌ | 401/469 [04:18<00:50,  1.35it/s]

Epoch [0/5] Batch 400/469                   Loss D: -0.7128, loss G: 0.6106


100%|██████████| 469/469 [05:03<00:00,  1.55it/s]
 22%|██▏       | 101/469 [01:06<04:28,  1.37it/s]

Epoch [1/5] Batch 100/469                   Loss D: -1.0335, loss G: 0.4740


 43%|████▎     | 201/469 [02:12<03:15,  1.37it/s]

Epoch [1/5] Batch 200/469                   Loss D: -1.1112, loss G: 0.4883


 64%|██████▍   | 301/469 [03:16<02:00,  1.40it/s]

Epoch [1/5] Batch 300/469                   Loss D: -1.0003, loss G: 0.3582


 86%|████████▌ | 401/469 [04:21<00:48,  1.39it/s]

Epoch [1/5] Batch 400/469                   Loss D: -1.0068, loss G: 0.3867


100%|██████████| 469/469 [05:04<00:00,  1.54it/s]
 22%|██▏       | 101/469 [01:05<04:25,  1.38it/s]

Epoch [2/5] Batch 100/469                   Loss D: -0.9951, loss G: 0.5945


 43%|████▎     | 201/469 [02:10<03:13,  1.38it/s]

Epoch [2/5] Batch 200/469                   Loss D: -0.8270, loss G: 0.6007


 64%|██████▍   | 301/469 [03:16<02:03,  1.36it/s]

Epoch [2/5] Batch 300/469                   Loss D: -0.9082, loss G: 0.5998


 86%|████████▌ | 401/469 [04:22<00:50,  1.36it/s]

Epoch [2/5] Batch 400/469                   Loss D: -1.0253, loss G: 0.4621


100%|██████████| 469/469 [05:07<00:00,  1.53it/s]
 22%|██▏       | 101/469 [01:07<04:31,  1.36it/s]

Epoch [3/5] Batch 100/469                   Loss D: -0.8019, loss G: 0.6018


 43%|████▎     | 201/469 [02:13<03:14,  1.38it/s]

Epoch [3/5] Batch 200/469                   Loss D: -0.5835, loss G: 0.5531


 64%|██████▍   | 301/469 [03:18<02:01,  1.38it/s]

Epoch [3/5] Batch 300/469                   Loss D: -0.7138, loss G: 0.1851


 86%|████████▌ | 401/469 [04:23<00:49,  1.38it/s]

Epoch [3/5] Batch 400/469                   Loss D: -0.7046, loss G: 0.1939


100%|██████████| 469/469 [05:07<00:00,  1.53it/s]
 22%|██▏       | 101/469 [01:06<04:28,  1.37it/s]

Epoch [4/5] Batch 100/469                   Loss D: -0.6071, loss G: 0.1125


 43%|████▎     | 201/469 [02:11<03:13,  1.39it/s]

Epoch [4/5] Batch 200/469                   Loss D: -0.6142, loss G: 0.5552


 64%|██████▍   | 301/469 [03:16<02:01,  1.39it/s]

Epoch [4/5] Batch 300/469                   Loss D: -0.5056, loss G: 0.5821


 86%|████████▌ | 401/469 [04:21<00:49,  1.39it/s]

Epoch [4/5] Batch 400/469                   Loss D: -0.6112, loss G: 0.1013


100%|██████████| 469/469 [05:05<00:00,  1.54it/s]
