# Paper Information


Paper: TransGAN: Two Transformers Can Make One Strong GAN, CVPR 2021, https://arxiv.org/abs/2102.07074

Authors: Yifan Jiang, Shiyu Chang, Zhangyang Wang

Code Authors: Ahmet Sarıgün & Dursun Bekci

## Importing Libraries

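```python
from __future__ import division
from __future__ import print_function

import os
import time
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy

# Generator, Discriminator, inits_weight and save_checkpoint come from the repository's own modules
from models import *
from utils import *
```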


## Training & Saving a Model for CIFAR-10

Although the paper presents this configuration, TransGAN-S, as a relatively small model, it is still computationally demanding: one epoch on CIFAR-10 takes nearly 7 minutes on a GPU, so training on a GPU is strongly recommended, and even then a full run may take about one day. For this reason, this notebook focuses mainly on the MNIST dataset. For details of the CIFAR-10 implementation, see the ```./cifar``` folder, whose README.md gives a brief explanation of the re-implementation.
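As a rough check of the "about one day" estimate, the sketch below multiplies the per-epoch time by an epoch count. The 200-epoch budget is an assumption borrowed from the MNIST hyperparameters below; the number of CIFAR-10 epochs is not stated here.

```python
def estimate_training_hours(minutes_per_epoch: float, num_epochs: int) -> float:
    """Wall-clock training time in hours, assuming a constant cost per epoch."""
    return minutes_per_epoch * num_epochs / 60.0

# ~7 minutes per CIFAR-10 epoch on a GPU (from the text) times an assumed 200 epochs
hours = estimate_training_hours(minutes_per_epoch=7, num_epochs=200)
print(f"{hours:.1f} hours")  # 23.3 hours
```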

## Hyperparameters for MNIST

These are the hyperparameters used for training on the MNIST dataset.

```python
# Training hyperparameters chosen by the code authors
lr_gen = 0.001          # learning rate for the generator
lr_dis = 0.001          # learning rate for the discriminator
latent_dim = 128        # dimension of the latent noise vector
gener_batch_size = 60   # batch size for generator updates
dis_batch_size = 30     # batch size for discriminator updates
epoch = 200             # number of epochs (the training cell below runs 25)
weight_decay = 1e-4     # weight decay
drop_rate = 0.5         # dropout rate

# Architecture details
image_size = 28         # height and width of the images fed to the discriminator
initial_size = 7        # initial spatial size (7x7) of the generator's feature map
patch_size = 14         # patch size of the discriminator's patch embedding
num_classes = 1         # number of discriminator outputs (a single real/fake score)
output_dir = 'checkpoint'  # directory for the saved model
dim = 128               # embedding dimension
depth = 1               # number of Transformer encoder blocks in the discriminator
depth1 = 1              # encoder blocks in the generator's first stage (after the MLP block)
depth2 = 1              # encoder blocks in the generator's second stage
depth3 = 1              # encoder blocks in the generator's third stage
heads = 1               # number of attention heads
```
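Several layer sizes follow directly from these hyperparameters. The sketch below derives them, assuming `mlp_ratio=4` (passed to the model constructors, not listed above) and single-channel MNIST input; the results match the layer shapes in the discriminator summary printed further down (384 `qkv` outputs, 512 hidden MLP units).

```python
def derived_shapes(image_size=28, patch_size=14, dim=128, mlp_ratio=4):
    """Return (num_patches, pixels_per_patch, qkv_width, mlp_hidden) for the discriminator."""
    if image_size % patch_size != 0:
        raise ValueError("patches must tile the image exactly")
    num_patches = (image_size // patch_size) ** 2  # non-overlapping patches per image
    patch_pixels = patch_size * patch_size          # pixels per patch (1 channel for MNIST)
    qkv_width = 3 * dim                             # query, key and value stacked in one Linear
    mlp_hidden = mlp_ratio * dim                    # hidden width of the MLP block
    return num_patches, patch_pixels, qkv_width, mlp_hidden

print(derived_shapes())  # (4, 196, 384, 512)
```

With a 14-pixel patch, each 28x28 MNIST image becomes a sequence of only four tokens, which keeps the discriminator's attention very cheap.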

## Training & Saving Model for MNIST
Although both the model and the dataset are much smaller than those in the paper's experiments, training still takes a long time, even for a single epoch. Hence, please use a GPU!

```python
# Use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = Generator(depth1=depth1, depth2=depth2, depth3=depth3, initial_size=initial_size,
                      dim=dim, heads=heads, mlp_ratio=4, drop_rate=drop_rate)
generator.to(device)
discriminator = Discriminator(image_size=image_size, patch_size=patch_size, input_channel=1,
                              num_classes=num_classes, dim=dim, depth=depth, heads=heads,
                              mlp_ratio=4, drop_rate=drop_rate)
discriminator.to(device)

# Initialize the weights of both networks with the repository's inits_weight
generator.apply(inits_weight)
discriminator.apply(inits_weight)
```



```
Discriminator(
  (patches): ImgPatches(
    (patch_embed): Conv2d(1, 128, kernel_size=(14, 14), stride=(14, 14))
  )
  (droprate): Dropout(p=0.5, inplace=False)
  (TransfomerEncoder): TransformerEncoder(
    (Encoder_Blocks): ModuleList(
      (0): Encoder_Block(
        (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=128, out_features=384, bias=False)
          (attention_dropout): Dropout(p=0.5, inplace=False)
          (out): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): Dropout(p=0.5, inplace=False)
          )
        )
        (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (droprateout): Dropout(p=0.5, inplace=False)
        )
      )
    )
  )
  (n
```
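The layer shapes in this summary are enough to count the parameters of one encoder block. The sketch below uses the standard counts for a linear layer (`in * out`, plus `out` if it has a bias) and an affine LayerNorm (a weight and a bias per feature); dropout and GELU have no parameters. The helper names are illustrative, not part of the repository.

```python
def linear_params(in_f: int, out_f: int, bias: bool = True) -> int:
    """Weights plus optional bias of a fully connected layer."""
    return in_f * out_f + (out_f if bias else 0)

def layernorm_params(dim: int) -> int:
    """elementwise_affine=True: one weight and one bias per feature."""
    return 2 * dim

def encoder_block_params(dim: int = 128, mlp_hidden: int = 512) -> int:
    attn = linear_params(dim, 3 * dim, bias=False) + linear_params(dim, dim)  # qkv + out projection
    mlp = linear_params(dim, mlp_hidden) + linear_params(mlp_hidden, dim)      # fc1 + fc2
    norms = 2 * layernorm_params(dim)                                          # ln1 + ln2
    return attn + mlp + norms

print(encoder_block_params())  # 197888
```

The patch-embedding `Conv2d(1, 128, kernel_size=14)` adds another 14·14·128 + 128 = 25,216 parameters, so the visible part of this discriminator is small in parameter terms; its cost comes mainly from computation and activations rather than weights.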

```python
# Adam optimizers over the trainable parameters, with L2 weight decay
optim_gen = optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()),
                       lr=lr_gen, weight_decay=weight_decay)
optim_dis = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),
                       lr=lr_dis, weight_decay=weight_decay)

# TensorBoard writer plus global step counters for training and validation
writer = SummaryWriter()
writer_dict = {'writer': writer,
               'train_global_steps': 0,
               'valid_global_steps': 0}
```
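Both optimizers pass `weight_decay` to `torch.optim.Adam`, which, per the PyTorch documentation, adds `weight_decay * param` to the gradient before the moment estimates (coupled L2 regularization, unlike the decoupled decay of `AdamW`). The scalar sketch below reproduces that update rule for illustration only; it is not PyTorch's implementation, and it assumes PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`.

```python
import math

def adam_step(param, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4):
    """One Adam update for a scalar parameter with L2-style (coupled) weight decay."""
    beta1, beta2 = betas
    grad = grad + weight_decay * param          # weight decay is folded into the gradient
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction for step t
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

p, m, v = adam_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(round(p, 6))  # 0.999
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about ±1, so each parameter moves by roughly the learning rate (0.001 here) regardless of the gradient's scale.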

```python
import os

import torchvision
from torchvision import transforms


def train(generator, discriminator, optim_gen, optim_dis,
          epoch, writer, img_size=28, latent_dim=128,
          gener_batch_size=60, device="cpu"):
    # The shared step counter lives in the global writer_dict
    gen_step = 0

    generator.train()
    discriminator.train()

    # Scale MNIST pixels from [0, 1] to [-1, 1]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    # Batch size 30 is the discriminator batch size
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=30, shuffle=True)

    os.makedirs('generated_images', exist_ok=True)

    for index, (img, _) in enumerate(train_loader):
        global_steps = writer_dict['train_global_steps']

        # Use the requested device instead of hard-coding CUDA tensors
        real_imgs = img.to(device, dtype=torch.float)
        noise = torch.randn(img.shape[0], latent_dim, device=device)

        # Discriminator step with the hinge loss
        optim_dis.zero_grad()
        real_valid = discriminator(real_imgs)
        fake_imgs = generator(noise).detach()  # no gradient flows into the generator here
        assert fake_imgs.size() == real_imgs.size(), \
            f"fake_imgs.size(): {fake_imgs.size()} real_imgs.size(): {real_imgs.size()}"
        fake_valid = discriminator(fake_imgs)

        loss_dis = torch.mean(F.relu(1.0 - real_valid)) + torch.mean(F.relu(1.0 + fake_valid))
        loss_dis.backward()
        optim_dis.step()
        writer.add_scalar("loss_dis", loss_dis.item(), global_steps)

        # Generator step once every 5 discriminator steps
        if global_steps % 5 == 0:
            optim_gen.zero_grad()
            gener_noise = torch.randn(gener_batch_size, latent_dim, device=device)
            generated_imgs = generator(gener_noise)
            fake_valid = discriminator(generated_imgs)
            gener_loss = -torch.mean(fake_valid)
            gener_loss.backward()
            optim_gen.step()
            writer.add_scalar("gener_loss", gener_loss.item(), global_steps)
            gen_step += 1

        # Save a 5x5 grid of samples and log progress every 50 batches
        if gen_step and index % 50 == 0:
            sample_imgs = generated_imgs[:25].detach()
            save_image(sample_imgs, f'generated_images/generated_img_{epoch}_{index % len(train_loader)}.jpg',
                       nrow=5, normalize=True, scale_each=True)
            tqdm.write("[Epoch %d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                       (epoch + 1, index % len(train_loader) + 50, len(train_loader),
                        loss_dis.item(), gener_loss.item()))

        # Advance the shared step counter; without this, global_steps stays 0
        # and the generator would be updated on every batch
        writer_dict['train_global_steps'] = global_steps + 1
```
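The two losses in `train` are the hinge losses: the discriminator minimizes `mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))` and the generator minimizes `-mean(D(G(z)))`. The NumPy sketch below evaluates them on hand-picked scores (illustrative values, not outputs of the model) so the formulas can be checked by hand.

```python
import numpy as np

def d_hinge_loss(real_scores, fake_scores):
    """Discriminator hinge loss: mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))."""
    real_scores = np.asarray(real_scores, dtype=float)
    fake_scores = np.asarray(fake_scores, dtype=float)
    return np.maximum(0.0, 1.0 - real_scores).mean() + np.maximum(0.0, 1.0 + fake_scores).mean()

def g_hinge_loss(fake_scores):
    """Generator loss: -mean(D(G(z)))."""
    return -np.asarray(fake_scores, dtype=float).mean()

print(d_hinge_loss([2.0, 0.5], [-2.0, 0.0]))  # 0.75
print(g_hinge_loss([-2.0, 0.0]))              # 1.0
print(d_hinge_loss([0.0, 0.0], [0.0, 0.0]))   # 2.0
```

A discriminator that outputs about 0 for every input scores exactly 2.0, which is close to the D loss of roughly 1.8 to 2.5 in the training log below; that suggests the discriminator is not yet separating real from generated images confidently.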

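```python
def validate(generator, writer_dict):
    # Placeholder validation: no metric (e.g. FID or IS) is computed, so this returns None.
    # It only switches the generator to eval mode and advances the validation step counter.
    global_steps = writer_dict['valid_global_steps']
    generator.eval()
    writer_dict['valid_global_steps'] = global_steps + 1
```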

```python
num_epochs = 25  # shorter than the 200 epochs listed in the hyperparameters
for epoch in range(num_epochs):
    train(generator, discriminator, optim_gen, optim_dis,
          epoch, writer, img_size=28, latent_dim=latent_dim,
          gener_batch_size=gener_batch_size, device=device)
    score = validate(generator, writer_dict)  # returns None; no metric is computed

# Save the final weights of both networks
checkpoint = {'epoch': epoch,
              'generator_state_dict': generator.state_dict(),
              'discriminator_state_dict': discriminator.state_dict()}
save_checkpoint(checkpoint, output_dir=output_dir)
```
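The checkpoint cell relies on `save_checkpoint` from the repository's `utils` module, whose code is not shown here. Below is a hypothetical stand-in (the name `save_checkpoint_sketch`, the `checkpoint.pth` filename, and the tiny `nn.Linear` modules are illustrative assumptions) that writes the same dictionary with `torch.save` and restores a generator's weights with `load_state_dict`.

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint_sketch(checkpoint: dict, output_dir: str, filename: str = "checkpoint.pth") -> str:
    """Hypothetical stand-in for utils.save_checkpoint: write the dict with torch.save."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    torch.save(checkpoint, path)
    return path

# Tiny stand-in modules instead of the real Generator/Discriminator
gen, dis = nn.Linear(4, 4), nn.Linear(4, 1)
ckpt = {"epoch": 24,
        "generator_state_dict": gen.state_dict(),
        "discriminator_state_dict": dis.state_dict()}

with tempfile.TemporaryDirectory() as tmp:
    path = save_checkpoint_sketch(ckpt, tmp)
    loaded = torch.load(path, map_location="cpu")
    restored = nn.Linear(4, 4)
    restored.load_state_dict(loaded["generator_state_dict"])
    print(loaded["epoch"], torch.equal(restored.weight, gen.weight))  # 24 True
```

Storing `state_dict`s rather than whole modules keeps the file independent of the class code; to reuse the weights, construct the `Generator` with the same hyperparameters and call `load_state_dict` on it.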

```
[Epoch 1] [Batch 50/2000] [D loss: 1.862579] [G loss: -0.103841]
[Epoch 1] [Batch 100/2000] [D loss: 1.939610] [G loss: -0.155793]
[Epoch 1] [Batch 150/2000] [D loss: 1.880374] [G loss: -0.136092]
[Epoch 1] [Batch 200/2000] [D loss: 1.985547] [G loss: -0.080095]
[Epoch 1] [Batch 250/2000] [D loss: 1.929414] [G loss: -0.030376]
[Epoch 1] [Batch 300/2000] [D loss: 2.034805] [G loss: 0.017664]
[Epoch 1] [Batch 350/2000] [D loss: 1.971889] [G loss: 0.109574]
[Epoch 1] [Batch 400/2000] [D loss: 2.040691] [G loss: -0.254877]
[Epoch 1] [Batch 450/2000] [D loss: 2.233412] [G loss: 0.133916]
[Epoch 1] [Batch 500/2000] [D loss: 2.450507] [G loss: -0.632127]
[Epoch 1] [Batch 550/2000] [D loss: 1.878941] [G loss: 0.043961]
[Epoch 1] [Batch 600/2000] [D loss: 2.232890] [G loss: 0.201755]
[Epoch 1] [Batch 650/2000] [D loss: 1.976163] [G loss: 0.257815]
[Epoch 1] [Batch 700/2000] [D loss: 1.846442] [G loss: 0.075465]
[Epoch 1] [Batch 750/2000] [D loss: 2.082432] [G loss: 0.040188]
[Epoch 1] [Batch 80
```
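Each epoch in the log has 2000 batches (60,000 MNIST training images at a discriminator batch size of 30). The sketch below counts how many of those batches also update the generator under the `global_steps % 5 == 0` rule, assuming the global step counter advances once per batch; if the counter is never advanced, the condition is always true and the generator is updated on every batch. The helper name is illustrative.

```python
def generator_updates_per_epoch(num_batches: int, n_critic: int = 5, start_step: int = 0) -> int:
    """Count batches on which the generator steps when it updates every n_critic global steps."""
    return sum(1 for step in range(start_step, start_step + num_batches) if step % n_critic == 0)

batches = 60000 // 30  # MNIST training images / discriminator batch size
print(batches, generator_updates_per_epoch(batches))  # 2000 400
```

With five discriminator updates per generator update, the generator receives 400 updates per epoch, so 25 epochs amount to 10,000 generator steps.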

## Results 

In this re-implementation, we aimed to train the smallest model in the paper, TransGAN-S, on CIFAR-10, but because training takes a long time and we did not have enough computational power, our qualitative and quantitative results were not good. We therefore reduced the model to a shallow version and trained it on the MNIST dataset. Although this model is very small compared with DCGAN, training still takes a long time, and its qualitative results are worse than those of convolution-based models. Because of several challenges we encountered, we could not reproduce the paper's results; we discuss the reasons in the Challenges and Discussions section. These are the qualitative results of our implementation on the MNIST dataset:

![f1](https://s3.gifyu.com/images/transgan_mnist1.gif) ![f2](https://s3.gifyu.com/images/transgan_mnist2.gif) 

We did not use metrics such as FID or Inception Score (IS), since the original paper does not report any benchmark on the MNIST dataset. You can try the model yourself with the MNIST-pretrained weights in ```./checkpoint```. Training also takes considerably longer than for convolution-based GANs, so for readers interested in the original benchmark, we provide the CIFAR-10 implementation in the ```./cifar``` path.
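For reference, the Fréchet Inception Distance compares Gaussian fits to Inception-network features of real ($r$) and generated ($g$) images, with feature means $\mu$ and covariances $\Sigma$; lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Because the score depends on features from a network trained on natural RGB images, FID values on grayscale MNIST would not be directly comparable to the CIFAR-10 numbers reported in the paper.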




## Challenges and Discussions

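This GAN model is implemented almost entirely without convolutions, using Transformer architectures, which became popular in computer vision after the Vision Transformer ([ViT paper](https://arxiv.org/abs/2010.11929)). The only convolutions are Conv2d layers in the generator's unflatten step and, as the model summary above shows, in the discriminator's patch embedding.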
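A `Conv2d` whose kernel size equals its stride, like the discriminator's patch embedding, is the same operation as ViT's "split into patches, then apply one linear layer". The NumPy sketch below (random image and weights, illustrative names) splits a 28x28 image into four 14x14 patches, projects them to 128 dimensions, and checks one output against the corresponding convolution window computed by hand.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W) image into flattened non-overlapping patches, in row-major order."""
    h, w = image.shape
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size).transpose(0, 2, 1, 3)
    return patches.reshape(gh * gw, patch_size * patch_size)

rng = np.random.default_rng(0)
image = rng.standard_normal((28, 28))
patches = patchify(image, 14)               # (4, 196)
weight = rng.standard_normal((196, 128))    # plays the role of the flattened Conv2d kernel
bias = np.zeros(128)
tokens = patches @ weight + bias            # (4, 128) patch embeddings

# Token 1 equals the convolution output at window (row 0, col 1) with stride 14, channel 0
manual = (image[0:14, 14:28] * weight[:, 0].reshape(14, 14)).sum() + bias[0]
print(tokens.shape, np.isclose(tokens[1, 0], manual))  # (4, 128) True
```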
 

During implementation, we encountered several challenges that prevented us from obtaining the intended results. First, the authors did not give detailed hyperparameters for each Transformer block and multi-head attention mechanism. Second, for training, they did not specify hyperparameters such as the dropout rate, weight decay, or batch normalization settings.

Contrary to the authors' claim, we did not find the Transformer-based generator to be memory-friendly, even though it uses the [upsampling](https://arxiv.org/pdf/1609.05158.pdf) strategy, which reduces the channel dimension while increasing the spatial resolution (H, W), together with ViT-style patch embedding. When we trained the CIFAR-10 implementation on Google Colab, we could not run it for more iterations.
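The linked paper introduces pixel shuffle (sub-pixel convolution) upsampling. As we understand its use in a Transformer generator, the token sequence is reshaped into a 2D map, pixel-shuffled so that the height and width double while the channels drop by 4, and flattened back into tokens. The NumPy sketch below follows the channel ordering of `torch.nn.PixelShuffle`; the 7x7 grid with 128 channels is an assumption based on `initial_size` and `dim`.

```python
import numpy as np

def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """(C*r*r, H, W) -> (C, H*r, W*r), with the channel ordering of torch.nn.PixelShuffle."""
    c, h, w = x.shape
    out_c = c // (r * r)
    x = x.reshape(out_c, r, r, h, w)   # [oc, i, j, h, w]
    x = x.transpose(0, 3, 1, 4, 2)     # [oc, h, i, w, j]
    return x.reshape(out_c, h * r, w * r)

def upsample_tokens(tokens: np.ndarray, height: int, width: int, r: int = 2):
    """Reshape (H*W, C) tokens to a map, pixel-shuffle it, and flatten back to tokens."""
    n, c = tokens.shape
    if n != height * width or c % (r * r) != 0:
        raise ValueError("token count must equal height*width and C must be divisible by r*r")
    x = tokens.T.reshape(c, height, width)   # token index = row * width + col
    x = pixel_shuffle(x, r)
    out_c, out_h, out_w = x.shape
    return x.reshape(out_c, out_h * out_w).T, out_h, out_w

tokens = np.arange(49 * 128, dtype=float).reshape(49, 128)  # 7x7 tokens, dim 128
up, h, w = upsample_tokens(tokens, 7, 7)
print(up.shape, (h, w))  # (196, 32) (14, 14)
```

The channel reduction saves parameters, but every upsampling stage multiplies the number of tokens by 4, and self-attention cost grows with the square of the token count.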

Therefore, our experience implementing and training this model suggests that it is relatively large compared with convolution-based models and needs more computational power to train.
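One way to see where the memory goes is to size the attention weights alone. The sketch below assumes plain global self-attention at each generator stage (7x7, 14x14 and 28x28 tokens), one head, float32, and the generator batch size of 60; it counts a single layer's N x N attention tensor in the forward pass and ignores every other activation, so real usage is higher.

```python
def attention_map_bytes(num_tokens: int, batch_size: int, heads: int = 1, bytes_per_elem: int = 4) -> int:
    """Size of one layer's (N x N) attention-weight tensor in float32, per forward pass."""
    return batch_size * heads * num_tokens * num_tokens * bytes_per_elem

# Assumed generator stages for MNIST: 7x7 -> 14x14 -> 28x28 tokens, batch size 60, 1 head
for side in (7, 14, 28):
    n = side * side
    mb = attention_map_bytes(n, batch_size=60) / 1e6
    print(f"{side}x{side}: {n} tokens, {mb:.1f} MB")
# 7x7: 49 tokens, 0.6 MB
# 14x14: 196 tokens, 9.2 MB
# 28x28: 784 tokens, 147.5 MB
```

Doubling the side length multiplies this tensor by 16, and autograd keeps it for the backward pass, so the highest-resolution stage dominates memory even when it has a single encoder block.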

## Citation
```
@article{jiang2021transgan,
  title={TransGAN: Two Transformers Can Make One Strong GAN},
  author={Jiang, Yifan and Chang, Shiyu and Wang, Zhangyang},
  journal={arXiv preprint arXiv:2102.07074},
  year={2021}
}
```
```
@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
```