<a href="https://colab.research.google.com/github/aju22/VQ-GANs/blob/main/VQ_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import utils as vutils
import numpy as np
import matplotlib.pyplot as plt

#VQ-GAN: Architecture
![Architecture](https://miro.medium.com/v2/resize:fit:828/format:webp/1*JOrCybe84dKUvgiVNe0TaA.png)

In [2]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

In [3]:
class ResidualBlock(nn.Module):
  
  def __init__(self, in_channels, out_channels):
    
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    
    self.block = nn.Sequential(
        nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True),
        Swish(),
        nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True),
        Swish(),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1)
    )

    if self.in_channels != self.out_channels:
      self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

  def forward(self, x):

      if self.in_channels != self.out_channels:
        return self.channel_up(x)+self.block(x)

      else:
        return x + self.block(x)  

In [4]:
class UpSampleBlock(nn.Module):
    
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0)
        return self.conv(x)

In [5]:
class DownSampleBlock(nn.Module):
    
    def __init__(self, channels):
        
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 2, 0)

    def forward(self, x):
        
        pad = (0, 1, 0, 1)
        x = F.pad(x, pad, mode="constant", value=0)
        return self.conv(x)

In [6]:
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        
        super().__init__()
        self.in_channels = channels

        self.gn = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, 1, 1, 0)
        self.k = nn.Conv2d(channels, channels, 1, 1, 0)
        self.v = nn.Conv2d(channels, channels, 1, 1, 0)
        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0)

    def forward(self, x):
        
        h_ = self.gn(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape

        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        v = v.reshape(b, c, h*w)

        attn = torch.bmm(q, k)
        attn = attn * (int(c)**(-0.5))
        attn = F.softmax(attn, dim=2)
        attn = attn.permute(0, 2, 1)

        A = torch.bmm(v, attn)
        A = A.reshape(b, c, h, w)

        return x + A

# Encoder

![]( https://miro.medium.com/max/828/0*PitQ20hZW7-HjPWr)

In [7]:
class Encoder(nn.Module):
    
    def __init__(self, image_channels, latent_dim):
        
        super().__init__()
        channels = [128, 128, 128, 256, 256, 512]
        attn_resolutions = [16]
        num_res_blocks = 2
        resolution = 256
        
        layers = [nn.Conv2d(image_channels, channels[0], 3, 1, 1)]
        
        for i in range(len(channels)-1):
            in_channels = channels[i]
            out_channels = channels[i + 1]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(AttentionBlock(in_channels))
            if i != len(channels)-2:
                layers.append(DownSampleBlock(channels[i+1]))
                resolution //= 2
        
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(AttentionBlock(channels[-1]))
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(nn.GroupNorm(32, channels[-1]))
        layers.append(Swish())
        layers.append(nn.Conv2d(channels[-1], latent_dim, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        
        return self.model(x)

# Decoder

![](https://miro.medium.com/max/828/0*Uv9K77hnuyYplsxw)

In [8]:
class Decoder(nn.Module):
    
    def __init__(self, image_channels, latent_dim):
        
        super().__init__()
        channels = [512, 256, 256, 128, 128]
        attn_resolutions = [16]
        num_res_blocks = 3
        resolution = 16

        in_channels = channels[0]
        layers = [nn.Conv2d(latent_dim, in_channels, 3, 1, 1),
                  ResidualBlock(in_channels, in_channels),
                  AttentionBlock(in_channels),
                  ResidualBlock(in_channels, in_channels)]

        for i in range(len(channels)):
            out_channels = channels[i]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(AttentionBlock(in_channels))
            if i != 0:
                layers.append(UpSampleBlock(in_channels))
                resolution *= 2

        layers.append(nn.GroupNorm(32, in_channels))
        layers.append(Swish())
        layers.append(nn.Conv2d(in_channels, image_channels, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        
        return self.model(x)

#CodeBook

####Note --> VQ-VAE Loss


![](https://miro.medium.com/max/786/1*9kcpNwDdfVInvNSm_kM94A.webp)

In [9]:
class Codebook(nn.Module):
    
    def __init__(self, num_codebook_vectors, latent_dim, beta):
        
        super().__init__()
        self.num_codebook_vectors = num_codebook_vectors
        self.latent_dim = latent_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

    def forward(self, z):
        
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.latent_dim)

        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - \
            2*(torch.matmul(z_flattened, self.embedding.weight.t()))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) #Stop gradient loss implementation

        z_q = z + (z_q - z).detach() #This is preserving gradients in backprop

        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss

#VQ-GAN

In [10]:
class VQGAN(nn.Module):
    
    def __init__(self, image_channels, latent_dim, num_codebook_vectors, beta, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        
        super().__init__()
        self.encoder = Encoder(image_channels, latent_dim).to(device=device)
        self.decoder = Decoder(image_channels, latent_dim).to(device=device)
        self.codebook = Codebook(num_codebook_vectors, latent_dim, beta).to(device=device)
        self.quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device)
        self.post_quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device)

    
    #Encode from image space to latent(z) space
    def encode(self, imgs):
        
        encoded_images = self.encoder(imgs)
        quant_conv_encoded_images = self.quant_conv(encoded_images)
        codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images)
        return codebook_mapping, codebook_indices, q_loss

    #Decode from latent(z) space to image space
    def decode(self, z):
        
        post_quant_conv_mapping = self.post_quant_conv(z)
        decoded_images = self.decoder(post_quant_conv_mapping)
        return decoded_images
    
    def forward(self, imgs):
        
        codebook_mapping, codebook_indices, q_loss = self.encode(imgs)
        decoded_images = self.decode(codebook_mapping)

        return decoded_images, codebook_indices, q_loss

   #Lambda is used by authors as a weighting factor betwen VQ-VAE Loss and GAN-Loss based on Perceptual Losses.
    def calculate_lambda(self, perceptual_loss, gan_loss):
        
        last_layer = self.decoder.model[-1]
        last_layer_weight = last_layer.weight
        perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
        gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]

        λ = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
        λ = torch.clamp(λ, 0, 1e4).detach()
        return 0.8 * λ

    #Start the discriminator later, for generator to learn some basic reconstruction
    @staticmethod
    def adopt_weight(disc_factor, i, threshold, value=0.):
        if i < threshold:
            disc_factor = value
        return disc_factor

    def load_checkpoint(self, path):
        self.load_state_dict(torch.load(path))

# Discriminator

### Taken from PatchGAN Paper


In [11]:
class PatchDiscriminator(nn.Module):
    
    def __init__(self, image_channels, num_filters_last=64, n_layers=3):
        
        super().__init__()
        
        layers = [nn.Conv2d(image_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)]
        num_filters_mult = 1

        for i in range(1, n_layers + 1):
            num_filters_mult_last = num_filters_mult
            num_filters_mult = min(2 ** i, 8)
            layers += [
                nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4,
                          2 if i < n_layers else 1, 1, bias=False),
                nn.BatchNorm2d(num_filters_last * num_filters_mult),
                nn.LeakyReLU(0.2, True)
            ]

        layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        
        return self.model(x)

# VGG Perceptual Loss

In [12]:
class VGGPerceptualLoss(nn.Module):
    
    def __init__(self, resize=False):
        
        super().__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        
        self.blocks = nn.ModuleList(blocks)
        self.transform = F.interpolate
        self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3]):
        
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        
        x = input
        y = target
        loss = 0.0
        
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
   
            if i in feature_layers:
            
                loss += torch.abs(x - y).mean(dim=[1,2,3], keepdim=True)
          
        return loss

# Dataset

In [None]:
!wget https://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz
!tar -xvf "17flowers.tgz"

In [14]:
class CustomDataSet(Dataset):
    
    def __init__(self, main_dir, size=None):
        
        self.main_dir = main_dir
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.CenterCrop(size),
            transforms.Lambda(lambda image: ((image/127.5) - 1.0))
            
        ])
        self.all_imgs = [fn for fn in os.listdir(main_dir) if fn.endswith("jpg")]

    def __len__(self):
        
        return len(self.all_imgs)

    def __getitem__(self, idx):
        
        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        
        return tensor_image

In [15]:
def trainer(path=None, batch_size=64):

  train_data = CustomDataSet(path, size=32)
  train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

  return train_loader

# Stage 1: Training VQ-GAN

## Overall Objective Function:
![](https://miro.medium.com/max/720/1*W4z0g6n8S0uvBv9bfcTsVA.webp)

![](https://miro.medium.com/max/640/1*58MFhwgN3FaojP8KNahy7w.webp)

![](https://miro.medium.com/max/366/1*s8SWQzPbVn-mwBQ3do93gg.webp)

In [16]:
class TrainVQGAN:
    def __init__(self,
                 path = None, 
                 image_channels = 3, 
                 latent_dim = 256, 
                 num_codebook_vectors = 1024, 
                 beta = 0.25,
                 lr = 3e-4,
                 beta1 = 0.5,
                 beta2 = 0.9,
                 disc_factor = 1.,
                 disc_start = 100,
                 perceptual_loss_factor = 1.,
                 rec_loss_factor = 1.,
                 epochs = 10,
                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        
        self.path = path
        self.image_channels = image_channels 
        self.latent_dim = latent_dim 
        self.num_codebook_vectors = num_codebook_vectors 
        self.beta = beta
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.disc_factor = disc_factor
        self.disc_start = disc_start
        self.perceptual_loss_factor = perceptual_loss_factor
        self.rec_loss_factor = rec_loss_factor
        self.epochs = epochs
        self.device = device
        
        self.vqgan = VQGAN(image_channels, latent_dim, num_codebook_vectors, beta, device).to(device=device)
        self.discriminator = PatchDiscriminator(image_channels).to(device=device)
        self.perceptual_loss = VGGPerceptualLoss().eval().to(device=device)
        
        self.opt_vq, self.opt_disc = self.configure_optimizers()
        self.prepare_training()
        #self.train(epochs, path, disc_factor, disc_start, perceptual_loss_factor, rec_loss_factor, device)

    def configure_optimizers(self):
  
        opt_vq = torch.optim.Adam(
            list(self.vqgan.encoder.parameters()) +
            list(self.vqgan.decoder.parameters()) +
            list(self.vqgan.codebook.parameters()) +
            list(self.vqgan.quant_conv.parameters()) +
            list(self.vqgan.post_quant_conv.parameters()),
            lr=self.lr, eps=1e-08, betas=(self.beta1, self.beta2)
        )
        opt_disc = torch.optim.Adam(self.discriminator.parameters(),
                                    lr=self.lr, eps=1e-08, betas=(self.beta1, self.beta2))

        return opt_vq, opt_disc

    @staticmethod
    def prepare_training():
        
        os.makedirs("results", exist_ok=True)
        os.makedirs("checkpoints", exist_ok=True)

    def train(self):
        
        train_dataset = trainer(path=self.path)
        steps_per_epoch = len(train_dataset)
        
        for epoch in range(self.epochs):
            with tqdm(range(len(train_dataset))) as pbar:
                
                for i, imgs in zip(pbar, train_dataset):
                    
                    imgs = imgs.to(device=self.device)
                    decoded_images, _, q_loss = self.vqgan(imgs)

                    disc_real = self.discriminator(imgs)
                    disc_fake = self.discriminator(decoded_images)

                    disc_factor = self.vqgan.adopt_weight(self.disc_factor, epoch*steps_per_epoch+i, threshold=self.disc_start)

                    perceptual_loss = self.perceptual_loss(imgs, decoded_images)
                    rec_loss = torch.abs(imgs - decoded_images)
                    perceptual_rec_loss = self.perceptual_loss_factor * perceptual_loss + self.rec_loss_factor * rec_loss
                    perceptual_rec_loss = perceptual_rec_loss.mean()
                    g_loss = -torch.mean(disc_fake)

                    λ = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
                    vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss

                    d_loss_real = torch.mean(F.relu(1. - disc_real))
                    d_loss_fake = torch.mean(F.relu(1. + disc_fake))
                    gan_loss = disc_factor * 0.5*(d_loss_real + d_loss_fake)

                    self.opt_vq.zero_grad()
                    vq_loss.backward(retain_graph=True)

                    self.opt_disc.zero_grad()
                    gan_loss.backward()

                    self.opt_vq.step()
                    self.opt_disc.step()

                    if i % 10 == 0:
                        with torch.no_grad():
                            real_fake_images = torch.cat((imgs[:4], decoded_images.add(1).mul(0.5)[:4]))
                            vutils.save_image(real_fake_images, os.path.join("/content/results", f"{epoch}_{i}.jpg"), nrow=4)

                    pbar.set_postfix(
                        VQ_Loss=np.round(vq_loss.cpu().detach().numpy().item(), 5),
                        GAN_Loss=np.round(gan_loss.cpu().detach().numpy().item(), 3)
                    )
                    pbar.update(0)
                
                torch.save(self.vqgan.state_dict(), os.path.join("/content/checkpoints", f"vqgan_epoch_{epoch}.pt"))

In [17]:
vqgan_trainer = TrainVQGAN(path = '/content/jpg')

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [18]:
vqgan_trainer.train()

100%|██████████| 22/22 [00:24<00:00,  1.12s/it, GAN_Loss=0, VQ_Loss=4.53]
100%|██████████| 22/22 [00:17<00:00,  1.28it/s, GAN_Loss=0, VQ_Loss=1.8]
100%|██████████| 22/22 [00:17<00:00,  1.29it/s, GAN_Loss=0, VQ_Loss=1.84]
100%|██████████| 22/22 [00:17<00:00,  1.29it/s, GAN_Loss=0, VQ_Loss=1.24]
100%|██████████| 22/22 [00:17<00:00,  1.28it/s, GAN_Loss=1.28, VQ_Loss=5.44]
100%|██████████| 22/22 [00:17<00:00,  1.26it/s, GAN_Loss=0.991, VQ_Loss=-44.2]
100%|██████████| 22/22 [00:17<00:00,  1.26it/s, GAN_Loss=0.984, VQ_Loss=0.173]
100%|██████████| 22/22 [00:17<00:00,  1.26it/s, GAN_Loss=1.33, VQ_Loss=-1.24]
100%|██████████| 22/22 [00:18<00:00,  1.22it/s, GAN_Loss=0.904, VQ_Loss=1.64]
100%|██████████| 22/22 [00:17<00:00,  1.26it/s, GAN_Loss=1, VQ_Loss=-.084]


# Transformer

In [19]:
!git clone https://github.com/karpathy/minGPT.git
%cd /content/minGPT
!pip install -e .

Cloning into 'minGPT'...
remote: Enumerating objects: 489, done.[K
remote: Counting objects: 100% (489/489), done.[K
remote: Compressing objects: 100% (230/230), done.[K
remote: Total 489 (delta 266), reused 425 (delta 248), pack-reused 0[K
Receiving objects: 100% (489/489), 1.43 MiB | 2.73 MiB/s, done.
Resolving deltas: 100% (266/266), done.
/content/minGPT
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/minGPT
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: minGPT
  Running setup.py develop for minGPT
Successfully installed minGPT-0.0.1


In [20]:
from mingpt.model import GPT

![](https://miro.medium.com/max/828/0*YhKKu0sDIEfe_vI3)

In [21]:
class VQGANTransformer(nn.Module):
    
    def __init__(self, 
                 num_codebook_vectors = 1024, 
                 pkeep = 0.5,
                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        
        super().__init__()

        self.sos_token = 0
        self.num_codebook_vectors = num_codebook_vectors
        self.vqgan = self.load_vqgan(checkpoint_path = '/content/checkpoints',
                                     device = device)

        model_config = GPT.get_default_config()
        model_config.vocab_size = num_codebook_vectors
        model_config.block_size = 512
        model_config.model_type = 'gpt2-medium'
        
        self.transformer = GPT(model_config)

        self.pkeep = pkeep

    @staticmethod
    def load_vqgan(
                 image_channels = 3, 
                 latent_dim = 256, 
                 num_codebook_vectors = 1024, 
                 beta = 0.25,
                 checkpoint_path=None,
                 device=None
                 ):
        
        model = VQGAN(image_channels, latent_dim, num_codebook_vectors, beta, device)
        
        if os.path.exists(checkpoint_path):
          dirFiles = os.listdir(checkpoint_path)
          dirFiles.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
          last_ckpt = os.path.join(checkpoint_path, dirFiles[-1])
          
          model.load_checkpoint(last_ckpt)
        
        model = model.eval()
        return model

    @torch.no_grad()
    def encode_to_z(self, x):
        quant_z, indices, _ = self.vqgan.encode(x)
        indices = indices.view(quant_z.shape[0], -1)
        return quant_z, indices

    @torch.no_grad()
    def z_to_image(self, indices, p1=2, p2=2):
      
        ix_to_vectors = self.vqgan.codebook.embedding(indices).reshape(indices.shape[0], p1, p2, 256)
        ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2)
        image = self.vqgan.decode(ix_to_vectors)
        return image

    def forward(self, x):
        _, indices = self.encode_to_z(x)

        sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
        sos_tokens = sos_tokens.long().to("cuda")

        mask = torch.bernoulli(self.pkeep * torch.ones(indices.shape, device=indices.device))
        mask = mask.round().to(dtype=torch.int64)
        random_indices = torch.randint_like(indices, self.num_codebook_vectors)
        new_indices = mask * indices + (1 - mask) * random_indices

        new_indices = torch.cat((sos_tokens, new_indices), dim=1)

        target = indices

        logits, _ = self.transformer(new_indices[:, :-1])

        return logits, target

    def top_k_logits(self, logits, k):
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float("inf")
        return out

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, top_k=100):
        self.transformer.eval()
        x = torch.cat((c, x), dim=1)
        
        for k in range(steps):
            logits, _ = self.transformer(x)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)

            probs = F.softmax(logits, dim=-1)

            ix = torch.multinomial(probs, num_samples=1)

            x = torch.cat((x, ix), dim=1)

        x = x[:, c.shape[1]:]
        self.transformer.train()
        return x

    @torch.no_grad()
    def log_images(self, x):
        log = dict()

        _, indices = self.encode_to_z(x)
        sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
        sos_tokens = sos_tokens.long().to("cuda")

        start_indices = indices[:, :indices.shape[1] // 2]
        sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1])
        half_sample = self.z_to_image(sample_indices)

        start_indices = indices[:, :0]
        sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1])
        full_sample = self.z_to_image(sample_indices)

        x_rec = self.z_to_image(indices)

        log["input"] = x
        log["rec"] = x_rec
        log["half_sample"] = half_sample
        log["full_sample"] = full_sample

        return log, torch.concat((x, x_rec, half_sample, full_sample))

In [22]:
def plot_images(images):
    x = images["input"]
    reconstruction = images["rec"]
    half_sample = images["half_sample"]
    full_sample = images["full_sample"]

    fig, axarr = plt.subplots(1, 4)
    axarr[0].imshow(x.cpu().detach().numpy()[0].transpose(1, 2, 0))
    axarr[1].imshow(reconstruction.cpu().detach().numpy()[0].transpose(1, 2, 0))
    axarr[2].imshow(half_sample.cpu().detach().numpy()[0].transpose(1, 2, 0))
    axarr[3].imshow(full_sample.cpu().detach().numpy()[0].transpose(1, 2, 0))
    plt.show()

#Stge 2: Training Transformer

In [23]:
class TrainTransformer():
    def __init__(self, path = None, epochs = 10, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        self.model = VQGANTransformer().to(device=device)
        self.optim = self.configure_optimizers()
        self.device = device
        self.path = path
        self.epochs = epochs
        self.prepare_training()

    @staticmethod
    def prepare_training():
        
        os.makedirs("/content/transformer_results", exist_ok=True)
        os.makedirs("/content/transformer_checkpoints", exist_ok=True)

    
    def configure_optimizers(self):
        decay, no_decay = set(), set()
        whitelist_weight_modules = (nn.Linear, )
        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)

        for mn, m in self.model.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn

                if pn.endswith("bias"):
                    no_decay.add(fpn)

                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)

                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.model.transformer.named_parameters()}


        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        optimizer = torch.optim.AdamW(optim_groups, lr=4.5e-06, betas=(0.9, 0.95))
        return optimizer

    def train(self):
        
        train_dataset = trainer(path = self.path)
        
        for epoch in range(self.epochs):
            
            with tqdm(range(len(train_dataset))) as pbar:
                
                for i, imgs in zip(pbar, train_dataset):
                    self.optim.zero_grad()
                    imgs = imgs.to(device=self.device)
                    logits, targets = self.model(imgs)
                    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                    loss.backward()
                    self.optim.step()
                    pbar.set_postfix(Transformer_Loss=np.round(loss.cpu().detach().numpy().item(), 4))
                    pbar.update(0)
            
            log, sampled_imgs = self.model.log_images(imgs[0][None])
            
            vutils.save_image(sampled_imgs, os.path.join("/content/transformer_results", f"transformer_{epoch}.jpg"), nrow=4)
            
            torch.save(self.model.state_dict(), os.path.join("/content/transformer_checkpoints", f"transformer_{epoch}.pt"))

In [24]:
transformer_trainer = TrainTransformer(path = '/content/jpg')

number of parameters: 303.88M


In [25]:
transformer_trainer.train()

100%|██████████| 22/22 [00:17<00:00,  1.29it/s, Transformer_Loss=0.0093]
100%|██████████| 22/22 [00:14<00:00,  1.50it/s, Transformer_Loss=0.0002]
100%|██████████| 22/22 [00:14<00:00,  1.50it/s, Transformer_Loss=0.0002]
100%|██████████| 22/22 [00:14<00:00,  1.50it/s, Transformer_Loss=0.0001]
100%|██████████| 22/22 [00:14<00:00,  1.50it/s, Transformer_Loss=0.0001]
100%|██████████| 22/22 [00:14<00:00,  1.49it/s, Transformer_Loss=0.0001]
100%|██████████| 22/22 [00:14<00:00,  1.50it/s, Transformer_Loss=0.0001]
100%|██████████| 22/22 [00:15<00:00,  1.43it/s, Transformer_Loss=0.0001]
100%|██████████| 22/22 [00:14<00:00,  1.52it/s, Transformer_Loss=0]
100%|██████████| 22/22 [00:14<00:00,  1.52it/s, Transformer_Loss=0]
