
# **Paper Information**


**TransGAN: Two Pure Transformers Can Make One Strong GAN, and That Can Scale Up, NeurIPS 2021**, Yifan Jiang, Shiyu Chang, Zhangyang Wang

*   Paper Link: https://arxiv.org/pdf/2102.07074v2.pdf
*   Official Implementation: https://github.com/VITA-Group/TransGAN
*   Paper Presentation by Ahmet Sarıgün: https://www.youtube.com/watch?v=xwrUkHiDoiY


**Project Group Members:**


*   Ahmet Sarıgün, ahmet.sarigun@metu.edu.tr
*   Dursun Bekci, bekci.dursun@metu.edu.tr

## **Paper Summary**
### **Introduction**
TransGAN is a transformer-based GAN and can be considered a pilot study of a GAN that is completely free of convolutions. Its architecture consists mainly of a memory-friendly transformer-based generator that progressively increases the feature resolution, and a corresponding patch-level discriminator that is also transformer-based. To overcome the general training instability of GANs, the original paper combines a series of training techniques such as data augmentation, modified normalization, and relative position encoding. In our work, we implemented data augmentation [(Dosovitskiy et al., 2020)](https://arxiv.org/pdf/2010.11929.pdf) and relative position encoding. The original paper evaluates the model on several datasets, including STL-10, CIFAR-10, and CelebA, and reports results competitive with current state-of-the-art convolutional GANs. In our project, we tested our implementation only on the CIFAR-10 dataset, as stated in our experimental result goals.

### **TransGAN Architecture**
The architecture pipeline of TransGAN is shown below in the figure taken from the original paper.

<img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/images/transgan.jpg"> 
    
Figure 1: The pipeline of the pure transformer-based generator and discriminator of TransGAN.

### **Transformer Encoder as Basic Block**
We used the transformer encoder [(Vaswani et al., 2017)](https://arxiv.org/pdf/1706.03762.pdf) as our basic block, as in the original paper. An encoder block is composed of two parts: the first is a multi-head self-attention module, and the second is a feed-forward MLP with GELU non-linearity. Layer normalization [(Ba et al., 2016)](https://arxiv.org/pdf/1607.06450.pdf) is applied before each of the two parts, and both parts employ a residual connection.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
<img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/images/vit.gif">
Credits for illustration of ViT: [@lucidrains](https://github.com/lucidrains)
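The encoder block described above can be sketched in NumPy as follows. This is a simplified, hypothetical illustration rather than the project's `models` code: biases, dropout, and the learnable LayerNorm scale/shift are omitted, GELU uses the tanh approximation, and the weights are random. The sizes (embedding dimension 384, 4 heads, MLP ratio 4) match the `Discriminator` configuration printed later in this notebook, and the 64 tokens correspond to an 8 × 8 grid.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the embedding (last) axis; learnable scale/shift omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def multi_head_attention(x, w_qkv, w_out, heads):
    """x: (N, C) tokens, w_qkv: (C, 3C), w_out: (C, C)."""
    n, c = x.shape
    d_k = c // heads
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)  # each (N, C)
    # Split channels into heads: (heads, N, d_k)
    q, k, v = [t.reshape(n, heads, d_k).transpose(1, 0, 2) for t in (q, k, v)]
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_k))  # (heads, N, N)
    out = (attn @ v).transpose(1, 0, 2).reshape(n, c)        # concatenate heads
    return out @ w_out

def encoder_block(x, params, heads=4):
    # Part 1: LayerNorm -> multi-head self-attention -> residual connection
    x = x + multi_head_attention(layer_norm(x), params["w_qkv"], params["w_out"], heads)
    # Part 2: LayerNorm -> MLP with GELU -> residual connection
    h = gelu(layer_norm(x) @ params["w_fc1"])
    return x + h @ params["w_fc2"]

rng = np.random.default_rng(0)
n_tokens, dim, mlp_ratio = 64, 384, 4
params = {
    "w_qkv": rng.normal(0, 0.02, (dim, 3 * dim)),
    "w_out": rng.normal(0, 0.02, (dim, dim)),
    "w_fc1": rng.normal(0, 0.02, (dim, mlp_ratio * dim)),
    "w_fc2": rng.normal(0, 0.02, (mlp_ratio * dim, dim)),
}
tokens = rng.normal(size=(n_tokens, dim))
out = encoder_block(tokens, params)
print(out.shape)  # (64, 384)
```

Because each part is wrapped in a residual connection, the block maps a token sequence to a sequence of the same shape, so blocks can be stacked freely.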


### **Memory-friendly Generator**
To build a memory-friendly generator, TransGAN adopts a design philosophy common in CNN-based GANs: iteratively upscaling the resolution over multiple stages. Figure 1 (left) illustrates the memory-friendly generator, which consists of multiple stages, each containing several transformer blocks. Stage by stage, the feature map resolution is increased until it reaches the target resolution *H × W*. The generator takes random noise as input and passes it through a multi-layer perceptron (MLP). The output vector is reshaped into an $H_0 × W_0$ feature map (by default $H_0 = W_0 = 8$), where each point is a C-dimensional embedding. This "feature map" is then treated as a length-64 sequence of C-dimensional tokens and combined with a learnable positional encoding.
Next, the transformer encoders take the embedding tokens as input and recursively model the correspondence between tokens. To synthesize higher-resolution images, we insert an upsampling module after each stage, consisting of a pixelshuffle [(Shi et al., 2016)](https://arxiv.org/pdf/1609.05158.pdf) module.
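A minimal NumPy sketch of the generator's shape bookkeeping is shown below. It assumes a single random linear layer in place of the MLP and a PixelShuffle-style rearrangement that trades a 4× reduction in channels for a 2× increase in height and width; the helper `pixel_shuffle_tokens` is hypothetical, and whether the project's generator divides the embedding dimension exactly this way at each stage is an assumption. A smaller latent size (64 instead of the notebook's 1024) keeps the example light. Starting from the default 8 × 8 grid with C = 384, two upsampling steps reach 32 × 32, the CIFAR-10 resolution.

```python
import numpy as np

def pixel_shuffle_tokens(tokens, height, width, r=2):
    """Upsample a (B, H*W, C) token sequence to (B, (r*H)*(r*W), C // r**2).

    Channel c*r*r + i*r + j of input pixel (h, w) becomes channel c of
    output pixel (h*r + i, w*r + j), as in PixelShuffle.
    """
    b, n, c = tokens.shape
    assert n == height * width and c % (r * r) == 0
    c_out = c // (r * r)
    x = tokens.reshape(b, height, width, c_out, r, r)  # split channels into (c_out, i, j)
    x = x.transpose(0, 1, 4, 2, 5, 3)                  # (B, H, i, W, j, c_out)
    x = x.reshape(b, height * r, width * r, c_out)     # merge (H, i) and (W, j)
    return x.reshape(b, height * r * width * r, c_out)

rng = np.random.default_rng(0)
batch, latent_dim, h0, dim = 2, 64, 8, 384
z = rng.normal(size=(batch, latent_dim))
# One random linear layer stands in for the generator's MLP (hypothetical weights).
w_mlp = rng.normal(0, 0.02, (latent_dim, h0 * h0 * dim))
x = (z @ w_mlp).reshape(batch, h0 * h0, dim)  # length-64 sequence of 384-d tokens
print(x.shape)                                # (2, 64, 384)
# ... the transformer blocks of stage 1 would operate here ...
x = pixel_shuffle_tokens(x, 8, 8)             # 16 x 16 grid, input to stage 2
print(x.shape)                                # (2, 256, 96)
x = pixel_shuffle_tokens(x, 16, 16)           # 32 x 32 grid, input to stage 3
print(x.shape)                                # (2, 1024, 24)
```

Because the number of tokens grows while the channel width shrinks, the high-resolution stages work on narrower embeddings, which is one way to keep memory in check.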

### **Tokenized-input for Discriminator**
The discriminator, shown in Figure 1 (right), takes the patches of an image as input. The input image $Y \in \mathbb{R}^{H×W×3}$ is split into an 8 × 8 grid of patches, each of which can be regarded as a "word". The patches are then converted into a 1D sequence of token embeddings through a linear flatten layer. After that, a learnable position encoding is added and the tokens pass through the transformer encoder. Finally, the classification head takes the tokens and outputs the real/fake prediction.
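The tokenization step can be sketched in NumPy as below; `image_to_patch_tokens` is a hypothetical helper, not the project's `ImgPatches` module. With `patch_size = 4`, a 32 × 32 × 3 CIFAR-10 image yields an 8 × 8 grid of 64 patches, each flattened to 4 · 4 · 3 = 48 values and linearly projected to `dim = 384`. The printed model performs this projection with `Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))`, which is equivalent to a linear layer applied to each non-overlapping 4 × 4 patch. The random weights and position encodings stand in for learned parameters, and the classification head is omitted.

```python
import numpy as np

def image_to_patch_tokens(images, patch_size):
    """Split (B, H, W, 3) images into non-overlapping patches and flatten each.

    Returns (B, (H // p) * (W // p), p * p * 3): one flattened "word" per patch,
    ordered row by row over the patch grid.
    """
    b, h, w, ch = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0
    x = images.reshape(b, h // p, p, w // p, p, ch)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, W/p, p, p, 3)
    return x.reshape(b, (h // p) * (w // p), p * p * ch)

rng = np.random.default_rng(0)
imgs = rng.uniform(-1, 1, size=(2, 32, 32, 3))  # CIFAR-10-sized batch in [-1, 1]
patches = image_to_patch_tokens(imgs, patch_size=4)
print(patches.shape)                            # (2, 64, 48)

dim = 384
w_embed = rng.normal(0, 0.02, (4 * 4 * 3, dim))  # linear patch embedding (hypothetical weights)
pos_embed = rng.normal(0, 0.02, (1, 64, dim))    # learnable position encoding, one per patch
tokens = patches @ w_embed + pos_embed           # broadcast over the batch
print(tokens.shape)                              # (2, 64, 384)
```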

### **Training the Model**
In this section, we show our training code and training scores on the CIFAR-10 dataset with the best-performing hyperparameters we found. We trained the largest model, TransGAN-XL, with data augmentation under different hyperparameter settings and recorded the results in the `cifar/experiments` folder.

## Importing Libraries

In [None]:
# install tensorboardX before importing it, and create the output folders used below
!pip install tensorboardX
!mkdir -p checkpoint generated_images fid_stat
%cd fid_stat
!wget -nc bioinf.jku.at/research/ttur/ttur_stats/fid_stats_cifar10_train.npz
%cd ..

import time
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy

# project modules from this repository; Generator, Discriminator, inits_weight,
# LinearLrDecay, save_checkpoint and get_fid are expected to come from these
from utils import *
from models import *
from fid_score import *
from inception_score import *

Using downloaded and verified file: ./data\cifar-10-python.tar.gz
Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Hyperparameters for CIFAR-10 Dataset
Since Google Colab provides limited computational power, we reduced `gener_batch_size` from 64 to 32 and ran training for only 10 epochs to demonstrate the training procedure and scores. On our local GPU machine, we trained the model with `gener_batch_size = 64` for 200 epochs.

In [None]:
# training hyperparameters given by code author

lr_gen = 0.0001 #Learning rate for generator
lr_dis = 0.0001 #Learning rate for discriminator
latent_dim = 1024 #Latent dimension
gener_batch_size = 32 #Batch size for generator
dis_batch_size = 32 #Batch size for discriminator (not used: train() hard-codes 30)
epoch = 10 #Number of epochs
weight_decay = 1e-3 #Weight decay (used only by the RMSprop option)
drop_rate = 0.5 #Dropout rate
n_critic = 5 #Number of discriminator steps per generator step
max_iter = 500000 #Sets the LR decay horizon (max_iter * n_critic steps)
img_name = "img_name"
lr_decay = True #Linearly decay the learning rates

# architecture details by authors
image_size = 32 #H,W size of image for discriminator
initial_size = 8 #Initial size for generator
patch_size = 4 #Patch size used to tokenize images in the discriminator
num_classes = 1 #Output size of the discriminator head (real/fake score)
output_dir = 'checkpoint' #saved model path
dim = 384 #Embedding dimension
optimizer = 'Adam' #Optimizer
loss = "wgangp_eps" #Loss function
phi = 1 #Target gradient norm in the gradient penalty
beta1 = 0 #Adam beta1
beta2 = 0.99 #Adam beta2
diff_aug = "translation,cutout,color" #data augmentation

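The `diff_aug` string selects DiffAugment-style policies (Zhao et al., 2020), which augment both real and generated images before they reach the discriminator. As a rough illustration of two of them, here is a NumPy sketch of random translation and cutout. The shift ratio of 1/8 and the cutout ratio of 1/2 are assumptions (we believe they match the DiffAugment defaults), the `color` policy is omitted, and unlike a PyTorch version this NumPy code is not differentiable, so it only shows what the transforms do to the pixels.

```python
import numpy as np

def rand_translation(x, rng, ratio=0.125):
    """Shift each image of a (B, H, W, C) batch by a random (dy, dx), zero-filling the border."""
    b, h, w, _ = x.shape
    max_dy, max_dx = int(h * ratio + 0.5), int(w * ratio + 0.5)
    out = np.zeros_like(x)
    for n in range(b):
        dy = int(rng.integers(-max_dy, max_dy + 1))
        dx = int(rng.integers(-max_dx, max_dx + 1))
        # out[y + dy, x + dx] = x[y, x] for every pixel that stays inside the image
        ay, by = max(0, dy), max(0, -dy)
        ax, bx = max(0, dx), max(0, -dx)
        out[n, ay:h - by, ax:w - bx] = x[n, by:h - ay, bx:w - ax]
    return out

def rand_cutout(x, rng, ratio=0.5):
    """Zero out one random square (about ratio*H by ratio*W pixels) in each image."""
    b, h, w, _ = x.shape
    ch, cw = int(h * ratio + 0.5), int(w * ratio + 0.5)
    out = x.copy()
    for n in range(b):
        cy, cx = int(rng.integers(0, h)), int(rng.integers(0, w))  # square centre
        y0, x0 = max(0, cy - ch // 2), max(0, cx - cw // 2)
        y1, x1 = min(h, cy - ch // 2 + ch), min(w, cx - cw // 2 + cw)
        out[n, y0:y1, x0:x1] = 0.0
    return out

rng = np.random.default_rng(0)
batch = np.ones((4, 32, 32, 3))
augmented = rand_cutout(rand_translation(batch, rng), rng)
print(augmented.shape)  # (4, 32, 32, 3)
```

Applying the same kind of augmentation to real and fake images keeps the discriminator from simply memorizing the small training set, which is why the paper treats augmentation as important for a data-hungry transformer discriminator.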

## Training & Saving Model for CIFAR-10
As mentioned above, we ran training for only 10 epochs because of the limitations of Google Colab. Over these 10 epochs the FID score fell from 253.6 to 138.5, though not monotonically (the best value, 135.5, was reached at epoch 8).

In [None]:
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"

device = torch.device(dev)

generator = Generator(depth1=5, depth2=4, depth3=2, initial_size=initial_size, dim=dim,
                      heads=4, mlp_ratio=4, drop_rate=drop_rate)
generator.to(device)

discriminator = Discriminator(diff_aug=diff_aug, image_size=image_size, patch_size=patch_size,
                              input_channel=3, num_classes=num_classes,
                              dim=dim, depth=7, heads=4, mlp_ratio=4,
                              drop_rate=drop_rate)
discriminator.to(device)


generator.apply(inits_weight)
discriminator.apply(inits_weight)



Discriminator(
  (patches): ImgPatches(
    (patch_embed): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
  )
  (droprate): Dropout(p=0.5, inplace=False)
  (TransfomerEncoder): TransformerEncoder(
    (Encoder_Blocks): ModuleList(
      (0): Encoder_Block(
        (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=False)
          (attention_dropout): Dropout(p=0.5, inplace=False)
          (out): Sequential(
            (0): Linear(in_features=384, out_features=384, bias=True)
            (1): Dropout(p=0.5, inplace=False)
          )
        )
        (ln2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (droprateout): Dropout(p=0.5, inplace=False)
        )
      )
      (1): Enco

In [None]:
if optimizer == 'Adam':
    optim_gen = optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_gen, betas=(beta1, beta2))

    optim_dis = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),lr=lr_dis, betas=(beta1, beta2))
elif optimizer == 'SGD':
    optim_gen = optim.SGD(filter(lambda p: p.requires_grad, generator.parameters()),
                lr=lr_gen, momentum=0.9)

    optim_dis = optim.SGD(filter(lambda p: p.requires_grad, discriminator.parameters()),
                lr=lr_dis, momentum=0.9)

elif optimizer == 'RMSprop':
    optim_gen = optim.RMSprop(filter(lambda p: p.requires_grad, generator.parameters()),
                lr=lr_gen, eps=1e-08, weight_decay=weight_decay, momentum=0, centered=False)

    optim_dis = optim.RMSprop(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_dis, eps=1e-08, weight_decay=weight_decay, momentum=0, centered=False)

gen_scheduler = LinearLrDecay(optim_gen, lr_gen, 0.0, 0, max_iter * n_critic)
dis_scheduler = LinearLrDecay(optim_dis, lr_dis, 0.0, 0, max_iter * n_critic)

#RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)

print("optimizer:",optimizer)

fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'

writer=SummaryWriter()
writer_dict = {'writer':writer}
writer_dict["train_global_steps"]=0
writer_dict["valid_global_steps"]=0

optimizer: Adam

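`LinearLrDecay` comes from the project's `utils` module, which is not shown here. Judging from the call `LinearLrDecay(optim_gen, lr_gen, 0.0, 0, max_iter * n_critic)`, its arguments appear to be the optimizer, start LR, end LR, decay start step, and decay end step. The class below is a hypothetical re-implementation under that assumption, named `LinearLrDecaySketch` to make clear it is not the project's class: the LR stays at `start_lr` until the decay starts, falls linearly to `end_lr`, and is written into every optimizer parameter group.

```python
from types import SimpleNamespace

class LinearLrDecaySketch:
    """Hypothetical linear LR decay, modelled on how LinearLrDecay is called above."""

    def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):
        assert start_lr > end_lr and decay_end_step > decay_start_step
        self.optimizer = optimizer
        self.start_lr, self.end_lr = start_lr, end_lr
        self.decay_start_step, self.decay_end_step = decay_start_step, decay_end_step
        # LR decrease per step inside the decay window
        self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step)

    def step(self, current_step):
        if current_step <= self.decay_start_step:
            lr = self.start_lr
        elif current_step >= self.decay_end_step:
            lr = self.end_lr
        else:
            lr = self.start_lr - self.delta * (current_step - self.decay_start_step)
        for group in self.optimizer.param_groups:  # torch optimizers expose param_groups
            group["lr"] = lr
        return lr

# Usage with a stand-in object that has `param_groups`, like a torch optimizer.
opt = SimpleNamespace(param_groups=[{"lr": 1e-4}])
sched = LinearLrDecaySketch(opt, 1e-4, 0.0, 0, 500000 * 5)
for step in (0, 16_670, 1_250_000, 2_500_000):
    print(step, sched.step(step))
```

Under this reading, the decay window is `max_iter * n_critic = 2,500,000` steps. If the global step advances once per batch, 10 epochs of 1667 batches (16,670 steps) cover only about 0.67% of that window, so the learning rates would barely move (from 1e-4 to roughly 9.93e-5) in the short Colab run.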

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Calculates the gradient penalty loss for WGAN-GP"""
    # Random weight term for interpolation between real and fake samples,
    # created on the samples' device so this also works on CPU
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    ones = torch.ones_like(d_interpolates)
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.contiguous().view(gradients.size(0), -1)
    # Penalize the deviation of each sample's gradient norm from phi
    gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
    return gradient_penalty


def train(noise, generator, discriminator, optim_gen, optim_dis,
          epoch, writer, schedulers, img_size=32, latent_dim=latent_dim,
          n_critic=n_critic, gener_batch_size=gener_batch_size, device=device):
    # `noise` and `writer` are kept for compatibility with the call below:
    # noise is resampled at every step and the writer is taken from writer_dict.
    writer = writer_dict['writer']
    gen_step = 0

    generator = generator.train()
    discriminator = discriminator.train()

    transform = transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    # Batch size is hard-coded to 30 (dis_batch_size is not used),
    # giving ceil(50000 / 30) = 1667 batches per epoch
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=30, shuffle=True)

    for index, (img, _) in enumerate(train_loader):

        global_steps = writer_dict['train_global_steps']

        real_imgs = img.to(device)

        # ---- Discriminator step ----
        noise = torch.randn(img.shape[0], latent_dim, device=device)

        optim_dis.zero_grad()
        real_valid = discriminator(real_imgs)
        fake_imgs = generator(noise).detach()
        fake_valid = discriminator(fake_imgs)

        if loss == 'hinge':
            loss_dis = torch.mean(F.relu(1.0 - real_valid)) + torch.mean(F.relu(1.0 + fake_valid))
        elif loss == 'wgangp_eps':
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs, phi)
            loss_dis = -torch.mean(real_valid) + torch.mean(fake_valid) + gradient_penalty * 10 / (phi ** 2)

        loss_dis.backward()
        optim_dis.step()

        writer.add_scalar("loss_dis", loss_dis.item(), global_steps)

        # ---- Generator step, once every n_critic discriminator steps ----
        if global_steps % n_critic == 0:

            optim_gen.zero_grad()
            if schedulers:
                gen_scheduler, dis_scheduler = schedulers
                g_lr = gen_scheduler.step(global_steps)
                d_lr = dis_scheduler.step(global_steps)
                writer.add_scalar('LR/g_lr', g_lr, global_steps)
                writer.add_scalar('LR/d_lr', d_lr, global_steps)

            gener_noise = torch.randn(gener_batch_size, latent_dim, device=device)

            generated_imgs = generator(gener_noise)
            fake_valid = discriminator(generated_imgs)

            gener_loss = -torch.mean(fake_valid)
            gener_loss.backward()
            optim_gen.step()
            writer.add_scalar("gener_loss", gener_loss.item(), global_steps)

            gen_step += 1

        # Advance the step counter every iteration; without this, global_steps stays 0,
        # so the generator is updated at every step and the LR schedule never advances.
        writer_dict['train_global_steps'] = global_steps + 1

        if gen_step and index % 100 == 0:
            sample_imgs = generated_imgs[:25]
            save_image(sample_imgs, f'generated_images/generated_img_{epoch}_{index % len(train_loader)}.jpg',
                       nrow=5, normalize=True, scale_each=True)
            tqdm.write("[Epoch %d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                       (epoch + 1, index % len(train_loader), len(train_loader), loss_dis.item(), gener_loss.item()))

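Written out, the objectives implemented in `train` and `compute_gradient_penalty` are as follows (our transcription of the code above: $\alpha$ is drawn uniformly from $[0, 1)$ for each sample, and the factor $10/\phi^2$ is the `gradient_penalty * 10 / (phi ** 2)` term). With `phi = 1`, the penalty weight reduces to the familiar value of 10, and the generator loss is the same for both choices of `loss`.

$$\mathcal{L}_D^{\text{hinge}} = \mathbb{E}_{x}\big[\max(0,\, 1 - D(x))\big] + \mathbb{E}_{z}\big[\max(0,\, 1 + D(G(z)))\big]$$

$$\mathcal{L}_D^{\text{WGAN-GP}} = -\mathbb{E}_{x}\big[D(x)\big] + \mathbb{E}_{z}\big[D(G(z))\big] + \frac{10}{\phi^2}\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - \phi\big)^2\Big], \qquad \hat{x} = \alpha x + (1 - \alpha)\, G(z)$$

$$\mathcal{L}_G = -\mathbb{E}_{z}\big[D(G(z))\big]$$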
In [None]:
def validate(generator, writer_dict, fid_stat):
    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    generator = generator.eval()
    # get_fid is provided by the star imports above; `epoch` is the global loop variable
    fid_score = get_fid(fid_stat, epoch, generator, num_img=5000, val_batch_size=60*2,
                        latent_dim=latent_dim, writer_dict=None, cls_idx=None)

    print(f"FID score: {fid_score}")

    writer.add_scalar('FID_score', fid_score, global_steps)

    writer_dict['valid_global_steps'] = global_steps + 1
    return fid_score

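FID compares the mean and covariance of Inception features of generated images with precomputed statistics of real CIFAR-10 images (here loaded from `fid_stat/fid_stats_cifar10_train.npz`): $\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert^2 + \mathrm{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1\Sigma_2)^{1/2}\big)$. The NumPy sketch below computes only this Fréchet distance; it is an illustration, not the code in the project's `fid_score` module, and it uses random 8-dimensional features in place of 2048-dimensional Inception activations.

```python
import numpy as np

def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)) for two Gaussians."""
    diff = mu1 - mu2
    # Tr((S1 S2)^(1/2)) equals Tr((S1^(1/2) S2 S1^(1/2))^(1/2)); the inner matrix is
    # symmetric PSD, so an eigh-based square root is enough.
    s1_half = _sqrtm_psd(sigma1)
    covmean_trace = np.trace(_sqrtm_psd(s1_half @ sigma2 @ s1_half))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * covmean_trace)

rng = np.random.default_rng(0)
feats_real = rng.normal(size=(5000, 8))
feats_fake = rng.normal(loc=0.5, size=(5000, 8))  # same spread, shifted mean
mu_r, cov_r = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
mu_f, cov_f = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)
print(frechet_distance(mu_r, cov_r, mu_r, cov_r))  # close to 0: identical statistics
print(frechet_distance(mu_r, cov_r, mu_f, cov_f))  # roughly 8 * 0.5**2 = 2 from the mean shift
```

Lower is better: the score is zero only when both the means and the covariances of the two feature distributions coincide.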
In [None]:
best = 1e4

for epoch in range(epoch):

    lr_schedulers = (gen_scheduler, dis_scheduler) if lr_decay else None

    train(noise, generator, discriminator, optim_gen, optim_dis,
    epoch, writer, lr_schedulers,img_size=32, latent_dim = latent_dim,
    n_critic = n_critic,
    gener_batch_size=gener_batch_size)

    checkpoint = {'epoch':epoch, 'best_fid':best}
    checkpoint['generator_state_dict'] = generator.state_dict()
    checkpoint['discriminator_state_dict'] = discriminator.state_dict()

    score = validate(generator, writer_dict, fid_stat)

    print(f'FID score: {score} - best ID score: {best} || @ epoch {epoch+1}.')
    if epoch == 0 or epoch > 30:
        if score < best:
            save_checkpoint(checkpoint, is_best=(score<best), output_dir=output_dir)
            print("Saved Latest Model!")
            best = score


checkpoint = {'epoch':epoch, 'best_fid':best}
checkpoint['generator_state_dict'] = generator.state_dict()
checkpoint['discriminator_state_dict'] = discriminator.state_dict()
score = validate(generator, writer_dict, fid_stat)
save_checkpoint(checkpoint,is_best=(score<best), output_dir=output_dir)

Files already downloaded and verified
[Epoch 1] [Batch 0/1667] [D loss: 2.170649] [G loss: 0.804360]
[Epoch 1] [Batch 100/1667] [D loss: -8.304835] [G loss: 10.228638]
[Epoch 1] [Batch 200/1667] [D loss: -3.785351] [G loss: 8.178776]
[Epoch 1] [Batch 300/1667] [D loss: -1.363279] [G loss: 8.804716]
[Epoch 1] [Batch 400/1667] [D loss: 3.654663] [G loss: 4.162188]
[Epoch 1] [Batch 500/1667] [D loss: -1.666954] [G loss: 10.869068]
[Epoch 1] [Batch 600/1667] [D loss: -0.912084] [G loss: 7.093612]
[Epoch 1] [Batch 700/1667] [D loss: 0.081573] [G loss: 6.099335]
[Epoch 1] [Batch 800/1667] [D loss: -0.579418] [G loss: 5.591341]
[Epoch 1] [Batch 900/1667] [D loss: -0.207472] [G loss: 5.181511]
[Epoch 1] [Batch 1000/1667] [D loss: -0.217765] [G loss: 3.403362]
[Epoch 1] [Batch 1100/1667] [D loss: -0.506849] [G loss: 4.797401]
[Epoch 1] [Batch 1200/1667] [D loss: 0.399069] [G loss: 3.140276]
[Epoch 1] [Batch 1300/1667] [D loss: 0.596609] [G loss: 4.660552]
[Epoch 1] [Batch 1400/1667] [D loss: -0

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.01it/s]
100%|██████████| 98/98 [00:17<00:00,  5.64it/s]


FID score: 253.59642028808594
FID score: 253.59642028808594 - best ID score: 10000.0 || @ epoch 1.
Saved Latest Model!
Files already downloaded and verified
[Epoch 2] [Batch 0/1667] [D loss: -1.014748] [G loss: 5.774697]
[Epoch 2] [Batch 100/1667] [D loss: -0.355909] [G loss: 6.370594]
[Epoch 2] [Batch 200/1667] [D loss: 1.714307] [G loss: 3.597477]
[Epoch 2] [Batch 300/1667] [D loss: -0.131275] [G loss: 5.499962]
[Epoch 2] [Batch 400/1667] [D loss: 0.279393] [G loss: 4.044170]
[Epoch 2] [Batch 500/1667] [D loss: -0.194252] [G loss: 4.526648]
[Epoch 2] [Batch 600/1667] [D loss: 0.094030] [G loss: 3.439218]
[Epoch 2] [Batch 700/1667] [D loss: 1.523710] [G loss: 4.529868]
[Epoch 2] [Batch 800/1667] [D loss: -0.376306] [G loss: 3.141362]
[Epoch 2] [Batch 900/1667] [D loss: -1.031662] [G loss: 2.987192]
[Epoch 2] [Batch 1000/1667] [D loss: -0.558286] [G loss: 2.625803]
[Epoch 2] [Batch 1100/1667] [D loss: 0.100764] [G loss: 3.178997]
[Epoch 2] [Batch 1200/1667] [D loss: -0.225897] [G loss:

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:17<00:00,  5.63it/s]


FID score: 256.2900390625
FID score: 256.2900390625 - best ID score: 253.59642028808594 || @ epoch 2.
Files already downloaded and verified
[Epoch 3] [Batch 0/1667] [D loss: -0.372902] [G loss: 1.947552]
[Epoch 3] [Batch 100/1667] [D loss: 0.475183] [G loss: 1.971642]
[Epoch 3] [Batch 200/1667] [D loss: -0.481462] [G loss: 1.334552]
[Epoch 3] [Batch 300/1667] [D loss: 0.233310] [G loss: 1.891104]
[Epoch 3] [Batch 400/1667] [D loss: -0.011959] [G loss: 2.165990]
[Epoch 3] [Batch 500/1667] [D loss: 0.746799] [G loss: 1.237505]
[Epoch 3] [Batch 600/1667] [D loss: 0.324373] [G loss: 1.170042]
[Epoch 3] [Batch 700/1667] [D loss: 0.274361] [G loss: 2.603842]
[Epoch 3] [Batch 800/1667] [D loss: -0.035530] [G loss: 2.541132]
[Epoch 3] [Batch 900/1667] [D loss: 0.114036] [G loss: 2.654665]
[Epoch 3] [Batch 1000/1667] [D loss: -0.258933] [G loss: 1.853575]
[Epoch 3] [Batch 1100/1667] [D loss: 0.695169] [G loss: 2.649147]
[Epoch 3] [Batch 1200/1667] [D loss: -0.254932] [G loss: 2.856803]
[Epoch 3

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:16<00:00,  5.98it/s]


FID score: 207.31369018554688
FID score: 207.31369018554688 - best ID score: 253.59642028808594 || @ epoch 3.
Files already downloaded and verified
[Epoch 4] [Batch 0/1667] [D loss: -0.365915] [G loss: 3.485483]
[Epoch 4] [Batch 100/1667] [D loss: 0.443299] [G loss: 3.688633]
[Epoch 4] [Batch 200/1667] [D loss: -0.308660] [G loss: 2.951671]
[Epoch 4] [Batch 300/1667] [D loss: 0.600909] [G loss: 1.788807]
[Epoch 4] [Batch 400/1667] [D loss: 0.937112] [G loss: 1.625080]
[Epoch 4] [Batch 500/1667] [D loss: 0.299240] [G loss: 2.000552]
[Epoch 4] [Batch 600/1667] [D loss: 0.024430] [G loss: 2.145702]
[Epoch 4] [Batch 700/1667] [D loss: -0.568093] [G loss: 1.746835]
[Epoch 4] [Batch 800/1667] [D loss: 0.112734] [G loss: 0.780115]
[Epoch 4] [Batch 900/1667] [D loss: 0.062658] [G loss: 1.361718]
[Epoch 4] [Batch 1000/1667] [D loss: -0.388711] [G loss: 2.048912]
[Epoch 4] [Batch 1100/1667] [D loss: 0.292507] [G loss: 2.032061]
[Epoch 4] [Batch 1200/1667] [D loss: 0.602682] [G loss: 1.755839]
[E

sample images: 100%|██████████| 41/41 [00:10<00:00,  3.98it/s]
100%|██████████| 98/98 [00:16<00:00,  5.83it/s]


FID score: 172.43017578125
FID score: 172.43017578125 - best ID score: 253.59642028808594 || @ epoch 4.
Files already downloaded and verified
[Epoch 5] [Batch 0/1667] [D loss: 0.172841] [G loss: 2.592175]
[Epoch 5] [Batch 100/1667] [D loss: 0.496616] [G loss: 2.731457]
[Epoch 5] [Batch 200/1667] [D loss: 0.140289] [G loss: 2.167178]
[Epoch 5] [Batch 300/1667] [D loss: -0.197189] [G loss: 2.397559]
[Epoch 5] [Batch 400/1667] [D loss: 0.439036] [G loss: 2.326895]
[Epoch 5] [Batch 500/1667] [D loss: 0.547368] [G loss: 1.894225]
[Epoch 5] [Batch 600/1667] [D loss: 0.337562] [G loss: 2.094888]
[Epoch 5] [Batch 700/1667] [D loss: 0.304808] [G loss: 2.826921]
[Epoch 5] [Batch 800/1667] [D loss: 0.845009] [G loss: 1.749743]
[Epoch 5] [Batch 900/1667] [D loss: 0.175082] [G loss: 1.266261]
[Epoch 5] [Batch 1000/1667] [D loss: 0.731524] [G loss: 0.654840]
[Epoch 5] [Batch 1100/1667] [D loss: 0.080363] [G loss: 1.845362]
[Epoch 5] [Batch 1200/1667] [D loss: -0.123356] [G loss: 2.430002]
[Epoch 5] 

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:17<00:00,  5.61it/s]


FID score: 157.23953247070312
FID score: 157.23953247070312 - best ID score: 253.59642028808594 || @ epoch 5.
Files already downloaded and verified
[Epoch 6] [Batch 0/1667] [D loss: 0.371292] [G loss: 0.855299]
[Epoch 6] [Batch 100/1667] [D loss: 0.441028] [G loss: 0.614191]
[Epoch 6] [Batch 200/1667] [D loss: -0.097045] [G loss: 0.962588]
[Epoch 6] [Batch 300/1667] [D loss: 0.367652] [G loss: 0.857501]
[Epoch 6] [Batch 400/1667] [D loss: 0.638341] [G loss: 1.082901]
[Epoch 6] [Batch 500/1667] [D loss: 0.210048] [G loss: 1.923066]
[Epoch 6] [Batch 600/1667] [D loss: -0.016896] [G loss: 1.761053]
[Epoch 6] [Batch 700/1667] [D loss: 0.193946] [G loss: 1.685530]
[Epoch 6] [Batch 800/1667] [D loss: 1.111356] [G loss: 1.994036]
[Epoch 6] [Batch 900/1667] [D loss: -0.004096] [G loss: 2.789292]
[Epoch 6] [Batch 1000/1667] [D loss: 0.067194] [G loss: 2.250651]
[Epoch 6] [Batch 1100/1667] [D loss: 1.058728] [G loss: 1.593802]
[Epoch 6] [Batch 1200/1667] [D loss: -0.007716] [G loss: 1.086996]
[E

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:17<00:00,  5.61it/s]


FID score: 143.97250366210938
FID score: 143.97250366210938 - best ID score: 253.59642028808594 || @ epoch 6.
Files already downloaded and verified
[Epoch 7] [Batch 0/1667] [D loss: 0.178964] [G loss: 1.296113]
[Epoch 7] [Batch 100/1667] [D loss: 0.032568] [G loss: 2.019550]
[Epoch 7] [Batch 200/1667] [D loss: 0.251444] [G loss: 2.284503]
[Epoch 7] [Batch 300/1667] [D loss: -0.062520] [G loss: 1.776974]
[Epoch 7] [Batch 400/1667] [D loss: 0.211068] [G loss: 2.131124]
[Epoch 7] [Batch 500/1667] [D loss: 0.180025] [G loss: 1.905816]
[Epoch 7] [Batch 600/1667] [D loss: 0.390855] [G loss: 1.178607]
[Epoch 7] [Batch 700/1667] [D loss: 0.813857] [G loss: 1.837076]
[Epoch 7] [Batch 800/1667] [D loss: 0.585625] [G loss: 2.185049]
[Epoch 7] [Batch 900/1667] [D loss: 0.245140] [G loss: 1.398675]
[Epoch 7] [Batch 1000/1667] [D loss: 0.579851] [G loss: 1.662638]
[Epoch 7] [Batch 1100/1667] [D loss: -0.435777] [G loss: 2.190656]
[Epoch 7] [Batch 1200/1667] [D loss: 0.447797] [G loss: 1.233313]
[Epo

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:17<00:00,  5.57it/s]


FID score: 142.53961181640625
FID score: 142.53961181640625 - best ID score: 253.59642028808594 || @ epoch 7.
Files already downloaded and verified
[Epoch 8] [Batch 0/1667] [D loss: 0.223762] [G loss: 2.058608]
[Epoch 8] [Batch 100/1667] [D loss: 0.664551] [G loss: 1.365796]
[Epoch 8] [Batch 200/1667] [D loss: -0.030917] [G loss: 0.785733]
[Epoch 8] [Batch 300/1667] [D loss: 0.205168] [G loss: 1.533132]
[Epoch 8] [Batch 400/1667] [D loss: 0.286476] [G loss: 1.991652]
[Epoch 8] [Batch 500/1667] [D loss: 0.340443] [G loss: 2.079747]
[Epoch 8] [Batch 600/1667] [D loss: 0.504586] [G loss: 1.812056]
[Epoch 8] [Batch 700/1667] [D loss: -0.231664] [G loss: 1.432201]
[Epoch 8] [Batch 800/1667] [D loss: 0.073055] [G loss: 2.221403]
[Epoch 8] [Batch 900/1667] [D loss: 0.441062] [G loss: 1.701555]
[Epoch 8] [Batch 1000/1667] [D loss: 0.664099] [G loss: 2.063335]
[Epoch 8] [Batch 1100/1667] [D loss: 0.518637] [G loss: 1.571779]
[Epoch 8] [Batch 1200/1667] [D loss: -0.353109] [G loss: 1.516699]
[Ep

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:17<00:00,  5.57it/s]


FID score: 135.49749755859375
FID score: 135.49749755859375 - best ID score: 253.59642028808594 || @ epoch 8.
Files already downloaded and verified
[Epoch 9] [Batch 0/1667] [D loss: 0.644349] [G loss: 1.506196]
[Epoch 9] [Batch 100/1667] [D loss: 0.430402] [G loss: 1.700323]
[Epoch 9] [Batch 200/1667] [D loss: 0.071626] [G loss: 1.859765]
[Epoch 9] [Batch 300/1667] [D loss: 0.712856] [G loss: 1.934990]
[Epoch 9] [Batch 400/1667] [D loss: -0.230229] [G loss: 1.745375]
[Epoch 9] [Batch 500/1667] [D loss: 0.363142] [G loss: 1.438794]
[Epoch 9] [Batch 600/1667] [D loss: 0.296138] [G loss: 1.327496]
[Epoch 9] [Batch 700/1667] [D loss: 0.589227] [G loss: 0.971769]
[Epoch 9] [Batch 800/1667] [D loss: 0.357308] [G loss: 1.326427]
[Epoch 9] [Batch 900/1667] [D loss: 0.602394] [G loss: 2.051815]
[Epoch 9] [Batch 1000/1667] [D loss: 0.024476] [G loss: 1.609274]
[Epoch 9] [Batch 1100/1667] [D loss: 0.068283] [G loss: 1.800756]
[Epoch 9] [Batch 1200/1667] [D loss: -0.241045] [G loss: 1.714448]
[Epo

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.00it/s]
100%|██████████| 98/98 [00:17<00:00,  5.59it/s]


FID score: 151.22991943359375
FID score: 151.22991943359375 - best ID score: 253.59642028808594 || @ epoch 9.
Files already downloaded and verified
[Epoch 10] [Batch 0/1667] [D loss: 0.345630] [G loss: 0.840272]
[Epoch 10] [Batch 100/1667] [D loss: 0.222541] [G loss: 1.538569]
[Epoch 10] [Batch 200/1667] [D loss: 0.289131] [G loss: 1.055008]
[Epoch 10] [Batch 300/1667] [D loss: 0.186299] [G loss: 1.278053]
[Epoch 10] [Batch 400/1667] [D loss: -0.148766] [G loss: 0.599416]
[Epoch 10] [Batch 500/1667] [D loss: 0.306913] [G loss: 1.451300]
[Epoch 10] [Batch 600/1667] [D loss: -0.133036] [G loss: 1.377939]
[Epoch 10] [Batch 700/1667] [D loss: 0.124857] [G loss: 1.209821]
[Epoch 10] [Batch 800/1667] [D loss: 0.867689] [G loss: 1.488301]
[Epoch 10] [Batch 900/1667] [D loss: 0.344258] [G loss: 1.940454]
[Epoch 10] [Batch 1000/1667] [D loss: 0.264276] [G loss: 1.534945]
[Epoch 10] [Batch 1100/1667] [D loss: -0.213311] [G loss: 1.180523]
[Epoch 10] [Batch 1200/1667] [D loss: 0.196303] [G loss: 

sample images: 100%|██████████| 41/41 [00:10<00:00,  4.01it/s]
100%|██████████| 98/98 [00:17<00:00,  5.56it/s]

FID score: 138.50048828125
FID score: 138.50048828125 - best ID score: 253.59642028808594 || @ epoch 10.


sample images: 100%|██████████| 41/41 [00:10<00:00,  3.88it/s]
100%|██████████| 98/98 [00:17<00:00,  5.52it/s]


FID score: 138.24746704101562


### **Experimental Result Goals vs. Achieved Results**
In this project, we aimed to reproduce the qualitative results (image samples generated on the CIFAR-10 dataset) and the quantitative results in Table 2 and Table 4 of the original paper, shown below.
<table>
<tr>
<td> <img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/results/table2.png" style="width: 400px;"/> </td>
<td> <img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/results/table4.png" style="width: 400px;"/> </td>
</tr></table>

Since we had limited computational resources and time for training every size of the TransGAN model on the CIFAR-10 dataset, we trained only the largest model with data augmentation, TransGAN-XL, for the Table 4 results.

## Test Model and Results
In this section, we load the pre-trained model and report the following qualitative and quantitative results.

### Qualitative Results
The following pictures show our generated images at different epoch numbers. 
<table>
<tr>
<td style="text-align: center">0 Epoch</td>
<td style="text-align: center">40 Epoch</td> 
<td style="text-align: center">100 Epoch</td>
<td style="text-align: center">200 Epoch</td> 
</tr>
<tr>
<td colspan="4"><p align="center"><img width="30%" src="https://raw.githubusercontent.com/asarigun/TransGAN/main/images/atransgan_cifar.gif"></p></td>
</tr>
<tr>
<td> <img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/results/0.jpg" style="width: 400px;"/> </td>
<td> <img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/results/40.jpg" style="width: 400px;"/> </td>
<td> <img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/results/100.jpg" style="width: 400px;"/> </td>
<td> <img src="https://raw.githubusercontent.com/asarigun/TransGAN/main/results/200.jpg" style="width: 400px;"/> </td>
</tr>
</table>

### Quantitative Results
As mentioned above, due to limited computational resources, we ran our experiments only with the largest model, TransGAN-XL. We decided not to implement 'Co-Training with Self-Supervised Auxiliary Task' and 'Locality-Aware Initialization for Self-Attention', since the paper shows that they make only small differences. The gap between our result and the original paper's result may stem from different hyperparameters and the implementation differences mentioned above. Our quantitative result, an FID score of 26.82, can be found [here](https://github.com/asarigun/TransGAN/blob/main/results/wgangp_eps_optim_Adam_lr_gen_0_0001_lr_dis_0_0001_epoch_200.txt).



## Challenges and Discussions

Since version 1 of the paper did not give detailed hyperparameters for each transformer block and for the multi-head attention mechanism, we had to search for the best hyperparameters ourselves. Version 1 also omitted training details such as the dropout rate, weight decay, and batch normalization. The latest version of the paper gives more detailed training hyperparameters, which helped us obtain more reasonable results.

During implementation, we first used the hinge loss and ran into convergence problems in training. Switching to the WGAN-GP loss, which is used in the latest version of the original paper, overcame the convergence problem and gave better results.

Because the earlier version of the paper did not describe the training process in detail, we struggled to get the FID score to converge during training. With the additional training details in the latest version, the FID score converged.

Due to limited computational resources, we trained only the largest model, TransGAN-XL, in our project. We implemented data augmentation, which the original paper considers crucial for TransGAN. We did not implement 'Co-Training with Self-Supervised Auxiliary Task' and 'Locality-Aware Initialization for Self-Attention', since the paper shows that they make only small differences.

## Citation
```
@article{jiang2021transgan,
  title={TransGAN: Two Transformers Can Make One Strong GAN},
  author={Jiang, Yifan and Chang, Shiyu and Wang, Zhangyang},
  journal={arXiv preprint arXiv:2102.07074},
  year={2021}
}
```
```
@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and  Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
```
```
@inproceedings{zhao2020diffaugment,
  title={Differentiable Augmentation for Data-Efficient GAN Training},
  author={Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}
```
