In [8]:
#!pip install lpips

In [9]:
import torch
import torchvision
#from torchvision.transforms import v2 ####
from torch import nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset
try:
    import lpips
except:
    pass
import numpy as np
import cv2
import os
import time
import shutil
import matplotlib.pyplot as plt
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [10]:
# Seed the global torch and numpy RNGs (note that the two seeds differ).
torch.manual_seed(42)
np.random.seed(12)

# Pipeline stage executed by this run.
phase = 2
# 1: train the VAE/GAN and export the token-encoding dataset
# 2: train the transformer
# 3: production (inference) phase

# Settings for the VQ autoencoder and for the run as a whole.
config = {'ver': f'afhq_maskgit_0412_v02_phase{phase}',
          'description': 'train_again',
          # 'activation' selects the module-level `activation` function;
          # 'num_res_blocks' is not read by the model code visible in this section.
          'activation': 'relu', 'num_res_blocks': 6, 'conv_in_channels': 64,
          # ResBlock bottleneck width = max(in, out) // this factor (floored at 32).
          'res_bottle_neck_factor': 2,
          # Per-stage channel multipliers applied to 'conv_in_channels'.
          'channels_mult': (1,2,2,3,4),
          # Codebook vector size and number of codebook entries.
          'embed_dim': 128, 'n_embed': 512,
          # Weight of the commitment term computed in CodeBook.forward.
          'codebook_beta': 0.25,
          'base_learning_rate': 3.0e-4, 
          'min_learning_rate':1e-4,
          'gan_loss_weight': 0.03,
          'batch_size': 32,
          'max_epochs': 40, 'train_test_split': 0.9,
          'save_every_n_epoch':40,
          'dataset_size': 15228,
          'rec_loss_weight': 1 ,#/ 0.06327039811675479 / 4, # pixel level rec loss
          'lpips_type': 'vgg',
          'lpips_loss_weight': 1,
          'cuda': torch.cuda.is_available(),
          'image_size': (256,256),###
          # Presumably a pretrained VQ checkpoint; the loading code is outside this section.
          'load': '/kaggle/input/afhq-vqvae-0412-v01/vq12.pth',
          'save': True,
          # The discriminator is only built when this is True AND phase == 1 (see VQGAN).
          'use_disc': False,
          'phase': phase  
          } 
#if phase!=1:
#    assert (not config['use_disc']) and config['lpips_loss_weight']<=0

# The encoder inserts one stride-2 DownSample between consecutive stages, so the
# total spatial reduction is 2 ** (number of stages - 1).
config['Ds_ratio'] = 2**(len(config['channels_mult'])-1)
# Height/width of the token grid produced by the encoder.
config['latent_size']=(round(config['image_size'][0]/config['Ds_ratio']),
                       round(config['image_size'][1]/config['Ds_ratio']))
# Ratio of the GAN loss weight to the reconstruction loss weight.
config['one_gan_loss_for_x_rec_loss'] = config['gan_loss_weight']/config['rec_loss_weight']


# Discriminator settings (read as default arguments by the Discriminator class).
d_config = {'channels_mult': (1,2,4),
            'conv_in_channels': 64,
            'base_learning_rate': 8e-5
           }

# The Discriminator applies one stride-2 convolution per entry of 'channels_mult'
# (conv_in plus one per remaining entry), hence 2 ** len(...).
d_config['Ds_ratio'] = 2**(len(d_config['channels_mult']))
# Image size divided by the stride-2 reduction. NOTE(review): the two trailing
# stride-1 4x4 convolutions shrink the map a little further, so this is presumably
# only the nominal patch-grid size — confirm where 'patch_size' is consumed.
d_config['patch_size']=(round(config['image_size'][0]/d_config['Ds_ratio']),
                        round(config['image_size'][1]/d_config['Ds_ratio']))

# Transformer (MaskGIT) settings. The keys mirror MaskGIT.__init__ arguments plus
# optimisation options; how they are passed to the model is outside this section.
t_config = {'n_pos':256,
            'n_tokens':512,
            'embed_dim':192,
            'nhead':8,
            'hidden_dim':768,
            'n_layers':8,
            'n_steps':10,
            'base_learning_rate':2e-4,
            'min_learning_rate':5e-5,
            'batch_size':16,
            'batch_acc':1,
            # Presumably a pretrained MaskGIT checkpoint — loading code not visible here.
            'load':'/kaggle/input/maskgit-0412-v01/maskgit40.pth',
            #'temperature':(3,1),
           }

# The transformer needs exactly one position per cell of the encoder's token grid.
assert t_config['n_pos']==config['latent_size'][0]*config['latent_size'][1]

# Input locations (Kaggle datasets) and the output directory of this run.
data_path_afhq = '/kaggle/input/afhq-512'
data_path_celeba = '/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba'
data_path_enc = '/kaggle/input/afhq-vqvae-0412-v01/12_enc.npy'
outcome_root = "/kaggle/working/vqvae"
# exist_ok=True replaces the former "check, then create" sequence, which could
# fail if the directory appeared between the check and the makedirs call.
os.makedirs(os.path.join(outcome_root, config['ver']), exist_ok=True)
# Dump the active configuration next to the run's artefacts ('w' truncates any
# previous dump for the same version string).
config_path = os.path.join(outcome_root, f"{config['ver']}/config.txt")
with open(config_path, 'w') as f:
    f.write(str(config) + '\n')
    if config['use_disc']:
        f.write(str(d_config) + '\n')
    f.write(str(t_config) + '\n')
    #f.write(f"N_GD: {config['max_epochs']*config['dataset_size']*config['train_test_split']/config['batch_size']:.0f}\n")

print(f"Cuda Availability: {config['cuda']}")


Cuda Availability: True


In [11]:
# Pick the non-linearity used by the convolutional blocks from config['activation'].
# The 'hardswish' and default (ReLU) variants operate in place on their input.
if config['activation'] == 'swish':
    def activation(x):
        """Swish / SiLU, x * sigmoid(x), computed out of place."""
        return x * F.sigmoid(x)
elif config['activation'] == 'hardswish':
    # torch.nn.functional exposes no `hardswish_` name; the in-place variant is
    # requested through the `inplace` flag of F.hardswish instead.
    def activation(x):
        """Hard-swish applied in place to ``x`` (``x`` is modified and returned)."""
        return F.hardswish(x, inplace=True)
else:
    # Any other value falls back to in-place ReLU.
    activation = F.relu_


def cf2cl(tensor):
    """Move axis 1 of a 4-D tensor to the end: (N, C, H, W) -> (N, H, W, C)."""
    return tensor.permute(0, 2, 3, 1)


def cl2cf(tensor):
    """Move the last axis of a 4-D tensor to position 1: (N, H, W, C) -> (N, C, H, W)."""
    return tensor.permute(0, 3, 1, 2)


def norm(in_channels):
    """Build the normalisation layer used by the encoder/decoder blocks.

    Currently a 32-group GroupNorm without learnable affine parameters.
    """
    # Alternatives that were tried and left disabled:
    #   nn.Identity()
    #   nn.BatchNorm2d(num_features=in_channels)
    #   nn.InstanceNorm2d(num_features=in_channels, eps=1e-05, momentum=0.1, affine=False)
    group_norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=False)
    return group_norm

def d_norm(in_channels):
    """Build the normalisation layer used inside the discriminator.

    Normalisation is currently disabled: an identity module is returned and
    ``in_channels`` is ignored.
    """
    # Alternatives that were tried and left disabled:
    #   nn.BatchNorm2d(num_features=in_channels)
    #   nn.InstanceNorm2d(num_features=in_channels, eps=1e-05, momentum=0.1, affine=False)
    #   nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=False)
    passthrough = nn.Identity()
    return passthrough


class DownSample(nn.Module):
    """Strided convolution that halves the spatial resolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; falls back to ``in_channels``
            when omitted (or otherwise falsy).
        kernel_size: convolution kernel size. The padding is fixed at 1.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3):
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=2, padding=1)

    def forward(self, x):
        """Apply the stride-2 convolution to ``x``."""
        return self.conv(x)


class UpSample(nn.Module):
    """Doubles the spatial resolution of a feature map.

    Two variants are available, selected by ``kernel_size``:

    * ``3`` (default): bilinear 2x interpolation followed by a 3x3 'same' convolution.
    * ``4``: a single 4x4 transposed convolution with stride 2.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; falls back to ``in_channels``
            when omitted (or otherwise falsy).
        kernel_size: 3 or 4, see above.

    Raises:
        ValueError: if ``kernel_size`` is neither 3 nor 4. Previously such a value
            built a module without ``upsample``/``conv`` that only failed at
            forward time with an AttributeError.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3):
        super().__init__()
        out_channels = out_channels if out_channels else in_channels

        if kernel_size == 4:
            # The transposed convolution does the upsampling itself.
            self.upsample = nn.Identity()
            self.conv = nn.ConvTranspose2d(in_channels=in_channels,
                                           out_channels=out_channels,
                                           kernel_size=4,
                                           stride=2,
                                           padding=1,
                                           output_padding=0)

        elif kernel_size == 3:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding='same')

        else:
            raise ValueError(
                f"UpSample supports kernel_size 3 or 4, got {kernel_size!r}")

    def forward(self, x):
        """Upsample ``x`` by a factor of two and apply the convolution."""
        h = self.upsample(x)
        h = self.conv(h)
        return h


class ResBlock(nn.Module):
    """Pre-activation residual block with a bottleneck.

    Main path: norm -> activation -> 3x3 conv (to the bottleneck width) ->
    norm -> activation -> 3x3 conv (to ``out_channels``). The skip path is the
    identity, or a 1x1 convolution when the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        bottle_neck_channels: width between the two convolutions. When ``None`` it
            is ``max(in, out) // config['res_bottle_neck_factor']``, floored at 32.
        out_channels: channels of the output; falls back to ``in_channels``.
    """

    def __init__(self, in_channels, bottle_neck_channels=None, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels if out_channels else in_channels

        if bottle_neck_channels is None:
            widest = max(self.out_channels, self.in_channels)
            derived = widest // config['res_bottle_neck_factor']
            # Only the derived width is floored at 32; an explicit value is used as is.
            bottle_neck_channels = max(32, derived)
        self.bottle_neck_channels = bottle_neck_channels

        self.norm1 = norm(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.bottle_neck_channels,
                               kernel_size=3, stride=1, padding='same')
        self.norm2 = norm(self.bottle_neck_channels)
        self.conv2 = nn.Conv2d(self.bottle_neck_channels, self.out_channels,
                               kernel_size=3, stride=1, padding='same')

        if self.in_channels == self.out_channels:
            self.conv_shortcut = nn.Identity()
        else:
            # 1x1 projection so the skip path matches the new channel count.
            self.conv_shortcut = nn.Conv2d(in_channels, self.out_channels,
                                           kernel_size=1, stride=1, padding='same')
        self.rescale = 1  # / config['num_res_blocks']

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        residual = self.conv1(activation(self.norm1(x)))
        residual = self.conv2(activation(self.norm2(residual)))
        skip = self.conv_shortcut(x)

        return skip + residual # self.rescale

class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Queries, keys and values are 1x1 convolutions of the normalised input; the
    attended values go through a 1x1 output projection and are added to the input.

    Args:
        in_channels: channels of the input (and output) feature map.
        embed_channels: width of the query/key/value projections; falls back to
            ``in_channels`` when ``None``.
    """

    def __init__(self, in_channels,embed_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = in_channels if embed_channels is None else embed_channels

        self.norm = norm(in_channels)

        def pointwise(c_in, c_out):
            # 1x1 convolution, i.e. a per-pixel linear map.
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0)

        self.q = pointwise(in_channels, self.embed_channels)
        self.k = pointwise(in_channels, self.embed_channels)
        self.v = pointwise(in_channels, self.embed_channels)
        self.proj_out = pointwise(self.embed_channels, in_channels)

    def forward(self, x):
        """Return ``x`` plus the projected attention output."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # Attention weights between every pair of spatial positions.
        batch, channels, height, width = query.shape
        n_pix = height * width
        query = query.reshape(batch, channels, n_pix).permute(0, 2, 1)  # b, hw, c
        key = key.reshape(batch, channels, n_pix)                       # b, c, hw
        scores = torch.bmm(query, key)        # scores[b, i, j] = sum_c q[b, i, c] k[b, c, j]
        scores = scores * (int(channels)**(-0.5))
        scores = F.softmax(scores, dim=2)

        # Weighted sum of the values for each query position.
        value = value.reshape(batch, channels, n_pix)
        scores = scores.permute(0, 2, 1)      # b, hw (key), hw (query)
        attended = torch.bmm(value, scores)   # b, c, hw (query)
        attended = attended.reshape(batch, channels, height, width)

        return x + self.proj_out(attended)



class Encoder(nn.Module):
    """Convolutional encoder mapping an image to pre-quantisation features.

    Layout: 3x3 stem -> for every stage but the last, a ResBlock followed by a
    stride-2 DownSample that switches to the stage's width -> a final ResBlock ->
    two middle ResBlocks (the attention slot between them is an identity) ->
    norm, activation and a 1x1 convolution to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        conv_in_channels: base width; stage widths are this times ``channels_mult``.
        out_channels: channels of the produced feature map.
        channels_mult: per-stage width multipliers.
    """

    def __init__(self, in_channels=3,
                 conv_in_channels=config['conv_in_channels'],
                 out_channels=config['embed_dim'],
                 channels_mult=config['channels_mult']):
        super().__init__()

        width = conv_in_channels * channels_mult[0]
        self.conv_in = nn.Conv2d(in_channels, width,
                                 kernel_size=3, stride=1, padding='same')

        stages = nn.ModuleList()
        last_stage = len(channels_mult) - 1
        for stage, mult in enumerate(channels_mult):
            next_width = conv_in_channels * mult
            if stage == last_stage:
                # The last stage only changes the width, not the resolution.
                stages.append(ResBlock(in_channels=width, out_channels=next_width))
            else:
                stages.append(ResBlock(in_channels=width, out_channels=width))  ###
                stages.append(DownSample(in_channels=width, out_channels=next_width))
            width = next_width
        self.layers = stages

        self.mid_res1 = ResBlock(in_channels=width, out_channels=width)
        self.mid_attn = nn.Identity()#AttnBlock(in_channels=current_channels)
        self.mid_res2 = ResBlock(in_channels=width, out_channels=width)

        self.norm_out = norm(width)
        self.pre_vq_conv = nn.Conv2d(width, out_channels,
                                     kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        """Encode ``x`` into the feature map handed to the codebook."""
        feat = self.conv_in(x)
        for stage in self.layers:
            feat = stage(feat)
        feat = self.mid_res2(self.mid_attn(self.mid_res1(feat)))
        feat = activation(self.norm_out(feat))
        return self.pre_vq_conv(feat)


class Decoder(nn.Module):
    """Convolutional decoder mapping quantised features back to an image.

    Layout: 3x3 stem at the widest stage -> two middle ResBlocks (the attention
    slot between them is an identity) -> a ResBlock for every entry of the
    reversed multiplier list, each but the first followed by a 2x UpSample ->
    norm, activation and a 1x1 convolution to ``out_channels``.

    Args:
        in_channels: channels of the quantised feature map.
        conv_in_channels: base width; stage widths are this times the multipliers.
        channels_mult: the encoder's per-stage multipliers (must be a tuple, since
            it is concatenated with ``(1,)``).
        out_channels: channels of the reconstructed image.
    """

    def __init__(self, in_channels=config['embed_dim'],
                 conv_in_channels=config['conv_in_channels'],
                 channels_mult=config['channels_mult'],
                 out_channels=3):
        super().__init__()
        width = conv_in_channels * channels_mult[-1]
        self.conv_in = nn.Conv2d(in_channels, width ,
                                 kernel_size=3, stride=1, padding='same')

        self.mid_res1 = ResBlock(in_channels=width, out_channels=width)
        self.mid_attn = nn.Identity()#AttnBlock(in_channels=current_channels)
        self.mid_res2 = ResBlock(in_channels=width, out_channels=width)

        stages = nn.ModuleList()
        # Walk the stage widths from the deepest back down to the base width.
        for stage, mult in enumerate(reversed((1,) + channels_mult[:-1])):
            next_width = conv_in_channels * mult
            stages.append(ResBlock(in_channels=width, out_channels=next_width)) ###
            if stage != 0:
                # Every stage except the first doubles the resolution.
                stages.append(UpSample(in_channels=next_width, out_channels=next_width))
            width = next_width
        self.layers = stages

        self.norm_out = norm(width)
        self.conv_out = nn.Conv2d(width, out_channels,
                                  kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        """Decode the feature map ``x`` into an image tensor."""
        feat = self.conv_in(x)
        feat = self.mid_res2(self.mid_attn(self.mid_res1(feat)))
        for stage in self.layers:
            feat = stage(feat)
        feat = activation(self.norm_out(feat))
        return self.conv_out(feat)


class EmbeddingEMA(nn.Module):
    """Codebook embedding table whose entries are maintained by EMA statistics.

    All three tensors are stored as non-trainable parameters:

    * ``weight``       - the codebook vectors, shape (num_tokens, codebook_dim)
    * ``cluster_size`` - EMA of how many inputs were assigned to each entry
    * ``embed_avg``    - EMA of the sum of the inputs assigned to each entry

    Args:
        num_tokens: number of codebook entries.
        codebook_dim: size of each codebook vector.
        decay: EMA decay factor.
        eps: smoothing constant used when normalising by the cluster sizes.
    """

    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        initial = torch.randn(num_tokens, codebook_dim)
        self.weight = nn.Parameter(initial, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(initial.clone(), requires_grad=False)
        # Read by CodeBook.forward to decide whether EMA updates are applied.
        self.update = True

    def forward(self, embed_id):
        """Look up the codebook vectors for the integer indices ``embed_id``."""
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        """Blend ``new_cluster_size`` into the running cluster sizes, in place."""
        running = self.cluster_size.data
        running.mul_(self.decay)
        running.add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        """Blend ``new_embed_avg`` into the running embedding sums, in place."""
        running = self.embed_avg.data
        running.mul_(self.decay)
        running.add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        """Recompute ``weight`` as ``embed_avg`` over the smoothed cluster sizes."""
        total = self.cluster_size.sum()
        # Additive smoothing keeps entries with (near) zero usage from dividing by 0.
        smoothed = (self.cluster_size + self.eps) / (total + num_tokens * self.eps) * total
        self.weight.data.copy_(self.embed_avg / smoothed.unsqueeze(1))

    def reinit(self,probs):
        """Redraw from U(-1, 1) every codebook row ``i`` for which ``probs[i] == 0``."""
        n_rows = self.weight.shape[0]
        for row in range(n_rows):
            if probs[row]==0:
                init.uniform_(self.weight[row], -1.0, 1.0)
        
        


class CodeBook(nn.Module):
    """Vector quantiser with an EMA-updated codebook.

    Args:
        embed_dim: size of each codebook vector (channel count of the input).
        n_embed: number of codebook entries.
        beta: weight of the commitment loss.
        decay: EMA decay handed to the embedding table.
        eps: smoothing constant handed to the embedding table.
    """

    def __init__(self, embed_dim=config['embed_dim'], n_embed=config['n_embed'],
                 beta=config['codebook_beta'], decay=0.99, eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.embed = EmbeddingEMA(num_tokens=n_embed, codebook_dim=embed_dim,
                                  decay=decay, eps=eps)
        # weight[num_embeddings, embedding_dim]; only `weight` is redrawn here,
        # `embed_avg` keeps the initial values set by EmbeddingEMA.
        init.uniform_(self.embed.weight, -1.0, 1.0)  # ?
        self.beta = beta
        self.decay = decay
        self.eps = eps

    def forward(self, z_e: torch.Tensor):
        """Quantise ``z_e`` (b, c, h, w) to its nearest codebook vectors.

        Returns a dict with:
            'z_q': quantised features (b, c, h, w), carrying straight-through
                gradients to ``z_e``;
            'loss': commitment loss, ``beta * mean((z_e - z_q)^2)``;
            'encodings_sum': number of inputs assigned to each codebook entry;
            'encodings': flat tensor with the selected index of every position.
        """
        # b, c, h, w -> b, h, w, c, then one row per spatial position
        z_e = z_e.permute(0, 2, 3, 1).contiguous()
        flat = z_e.view(-1, self.embed_dim)

        distances = torch.cdist(flat, self.embed.weight, p=2)
        nearest = distances.argmin(dim=1)

        z_q = self.embed(nearest).view(z_e.shape)
        commitment = self.beta * (z_e - z_q.detach()) ** 2
        loss = torch.mean(commitment)

        one_hot = F.one_hot(nearest, self.n_embed).type(z_e.dtype)
        usage = one_hot.sum(axis=0).detach()

        if self.training and self.embed.update:
            with torch.no_grad():
                # EMA of the cluster sizes
                self.embed.cluster_size_ema_update(usage)
                # EMA of the per-entry sums of assigned inputs
                assigned_sum = one_hot.transpose(0, 1) @ flat
                self.embed.embed_avg_ema_update(assigned_sum)
                # renormalise and write the new codebook vectors
                self.embed.weight_update(self.n_embed)

        # Straight-through estimator: forward value z_q, gradient of z_e.
        z_q = z_e + (z_q - z_e).detach()
        # b, h, w, c -> b, c, h, w
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return {'z_q': z_q, 'loss': loss, 'encodings_sum': usage,
               'encodings': nearest.detach()}
    
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a map of probabilities.

    Layout: a stride-2 4x4 stem with LeakyReLU, one stride-2 4x4 conv block per
    remaining entry of ``channels_mult``, one stride-1 4x4 conv block, and a
    stride-1 4x4 output convolution followed by a sigmoid.

    Args:
        in_channels: channels of the input image.
        conv_in_channels: base width; block widths are this times ``channels_mult``.
        out_channels: channels of the output map.
        channels_mult: per-block width multipliers.
    """

    def __init__(self, in_channels=3,
                 conv_in_channels=d_config['conv_in_channels'],
                 out_channels=1,
                 channels_mult=d_config['channels_mult']): # (1,2,4)
        super().__init__()

        width = conv_in_channels * channels_mult[0]
        self.conv_in = nn.Conv2d(in_channels, width,
                                 kernel_size=4, stride=2, padding=1)
        self.conv_in_activation = nn.LeakyReLU(0.2,True)

        body = nn.ModuleList()
        for mult in channels_mult[1:]:
            next_width = conv_in_channels * mult
            body.extend([
                nn.Conv2d(width, next_width,
                          kernel_size=4, stride=2, padding=1, bias=True),
                d_norm(next_width),
                nn.LeakyReLU(0.2,True),
            ])
            width = next_width

        # One more block at constant width and stride 1.
        body.extend([
            nn.Conv2d(width, width,
                      kernel_size=4, stride=1, padding=1, bias=True),
            d_norm(width),
            nn.LeakyReLU(0.2,True),
        ])
        self.layers = body

        self.conv_out = nn.Conv2d(width, out_channels,
                                  kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        """Return the sigmoid-activated score map for ``x``."""
        feat = self.conv_in_activation(self.conv_in(x))
        for block in self.layers:
            feat = block(feat)
        return F.sigmoid(self.conv_out(feat))
    

class VQGAN(nn.Module):
    """VQ autoencoder (encoder -> codebook -> decoder) with optional GAN/LPIPS losses.

    The discriminator and the LPIPS network are only created in phase 1 (reads the
    module-level ``phase``); in the other phases the model is a plain tokenizer.

    Args:
        channels: number of image channels.
        embed_dim: size of the codebook vectors.
        use_disc: whether to build a discriminator (effective only in phase 1).
    """

    def __init__(self, channels=3, embed_dim=config['embed_dim'], use_disc=config['use_disc']):
        super().__init__()
        self.embed_dim=embed_dim
        self.encoder = Encoder(in_channels=channels, out_channels=embed_dim)
        self.code_book = CodeBook(embed_dim=embed_dim)
        self.decoder = Decoder(in_channels=embed_dim, out_channels=channels)
        if use_disc and phase==1:
            self.discriminator = Discriminator(in_channels=channels, out_channels=1)
        self.rec_loss_weight = config['rec_loss_weight']
        self.gan_loss_weight = config['gan_loss_weight']
        self.disc_loss_weight = 1
        self.use_disc= use_disc and phase==1
        if config['lpips_loss_weight']>0 and phase==1:
            self.lpips_loss_weight = config['lpips_loss_weight']
            self.lpips_fn = lpips.LPIPS(net=config['lpips_type'])
            if config['cuda']:
                self.lpips_fn.cuda()
        else:
            # None marks "LPIPS disabled"; `lpips_fn` does not exist in this case.
            self.lpips_loss_weight = None


    @torch.no_grad()
    def update_gan_loss_weight(self,epoch,d_loss_val):
        """Hook for scheduling ``gan_loss_weight``; currently a no-op."""
        return
        #if epoch<=3:
         #   self.gan_loss_weight = 0.15
        #elif 5<epoch<9:
        #    self.gan_loss_weight = 0.06*(epoch-4)
        #else:
         #   self.gan_loss_weight = 0.3
        #self.gan_loss_weight = self.gan_loss_weight * max(0.2,-d_loss_val)

    @torch.no_grad()
    def update_disc_loss_weight(self,epoch,d_loss_val):
        """Hook for scheduling ``disc_loss_weight``; currently a no-op."""
        return
        #self.disc_loss_weight = 1 #max(0.5,(d_loss_val+1)/2)

    @torch.no_grad()
    def encode(self, x):
        """Return the flat codebook indices of the images ``x`` (no gradients)."""
        z_e = self.encoder(x)
        codebook_output = self.code_book(z_e)
        return codebook_output['encodings']

    @torch.no_grad()
    def decode(self, ind):
        """Decode codebook indices back to images (no gradients).

        ``ind`` is reshaped to a grid of ``config['latent_size']`` tokens per image.
        """
        z_q = self.code_book.embed(ind).view([-1, config['latent_size'][0],
                                              config['latent_size'][1],
                                              self.embed_dim]).contiguous()
        # b, h, w, c -> b, c, h, w
        z_q = torch.permute(z_q, [0, 3, 1, 2]).contiguous()
        recx = self.decoder(z_q)
        return recx

    def calculate_adaptive_weight(self, nll_loss, g_loss):
        """Balance the GAN loss against the reconstruction loss.

        Returns ``||d nll_loss / d W|| / (||d g_loss / d W|| + 1e-4)`` clamped to
        [0, 1e4] and detached, where W is the weight of the decoder's last layer.
        """
        last_layer = self.decoder.conv_out.weight
        # Both losses flow through the same decoder graph, and the caller still has
        # to back-propagate afterwards, so the graph must be retained; with
        # retain_graph=False the second grad call fails on freed buffers.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        # d_weight = d_weight * self.discriminator_weight
        return d_weight

    def reconstruct(self,x):
        """Encode, quantise and decode ``x``.

        Returns a dict with the reconstruction clipped to [-1, 1] ('recx'), the
        codebook loss and the codebook usage statistics / indices.
        """
        z_e = self.encoder(x)
        codebook_output = self.code_book(z_e)
        z_q, codebook_loss = codebook_output['z_q'], codebook_output['loss']
        recx = self.decoder(z_q)
        recx = torch.clip(recx,-1,1)

        return {'recx':recx,
                'codebook_loss':codebook_loss,
                'encodings_sum':codebook_output['encodings_sum'],
                'encodings':codebook_output['encodings'],
                }

    def discriminate(self,fake_x,real_x=None,on_train=True):
        """Run the discriminator.

        With ``real_x=None`` (generator step) only ``fake_x`` is scored, with the
        discriminator temporarily in eval mode; it is switched back to train mode
        afterwards when ``on_train`` is True. Otherwise both batches are scored.
        Requires the discriminator to have been built.
        """
        if real_x is None:
            self.discriminator.eval()
            disc_out_fake = self.discriminator(fake_x)
            if on_train:
                self.discriminator.train()
            disc_out_real=None
        else:
            disc_out_fake = self.discriminator(fake_x)
            disc_out_real = self.discriminator(real_x)
        return {'disc_out_fake':disc_out_fake,
                'disc_out_real':disc_out_real}

    @staticmethod
    def calculate_bce(proba,target=1):
        """Binary cross-entropy of probabilities against a constant target (0 or 1).

        A 1e-2 offset inside the log bounds the loss. Returns None for any other
        target value.
        """
        if target==1:
            return -torch.mean(torch.log(proba+1e-2))
        elif target==0:
            return -torch.mean(torch.log((1-proba)+1e-2))

    def calculate_g_loss(self,x,recx,codebook_loss,disc_out_fake=None):
        """Generator-side losses for a reconstruction.

        Combines the L1 reconstruction loss, the codebook loss, the LPIPS loss
        (when enabled) and the GAN loss (when the discriminator is in use).
        Disabled terms are reported as 0.
        """
        rec_loss = torch.mean(torch.abs(recx - x)) #+ torch.abs(recx - x) * 0.05
        if self.lpips_loss_weight is None:
            # LPIPS disabled: `lpips_fn` was never created and the weight is None,
            # so the term must be skipped instead of evaluated.
            lpips_loss = 0
            lpips_term = 0
        else:
            lpips_loss = self.lpips_fn.forward(recx,x).mean()
            lpips_term = lpips_loss*self.lpips_loss_weight
        if self.use_disc:
            gan_loss = self.calculate_bce(disc_out_fake,target=1)
            gan_acc = 1-torch.mean(disc_out_fake)
        else:
            gan_loss = 0
            gan_acc = 0
        tot_loss = rec_loss*self.rec_loss_weight+\
                   codebook_loss+\
                   lpips_term+\
                   gan_loss*self.gan_loss_weight
        return {'rec_loss':rec_loss,
                'codebook_loss':codebook_loss,
                'lpips_loss':lpips_loss,
                'gan_loss':gan_loss,
                'gan_accuracy':gan_acc,
                'tot_loss':tot_loss
               }

    def calculate_d_loss(self,disc_out_fake,disc_out_real):
        """Discriminator losses: BCE(fake, 0) + BCE(real, 1), plus mean-score stats."""
        disc_loss = self.calculate_bce(disc_out_fake,target=0)+\
                    self.calculate_bce(disc_out_real,target=1)
        tot_loss = disc_loss * self.disc_loss_weight
        disc_acc_r = torch.mean(disc_out_real)
        disc_acc_f = 1-torch.mean(disc_out_fake)
        return {'disc_loss':disc_loss,
                'tot_loss':tot_loss,
                'disc_accuracy_real':disc_acc_r,
                'disc_accuracy_fake':disc_acc_f,
               }

class MaskGIT(nn.Module):
    """Bidirectional transformer over token sequences, trained by masked prediction.

    Tokens at masked positions are replaced by a learned mask embedding; the
    model predicts a distribution over ``n_tokens`` for every position. Output
    logits reuse the token embedding matrix plus a per-position bias.

    Args:
        n_tokens: vocabulary size.
        n_pos: sequence length (number of positions).
        embed_dim: embedding / model width.
        nhead: attention heads per transformer layer.
        hidden_dim: feed-forward width of each transformer layer.
        n_layers: number of transformer encoder layers.
        n_steps: default number of decoding steps for generation.
    """

    def __init__(self, n_tokens=512, n_pos=143,
                 embed_dim=128, nhead=8, hidden_dim=1024,
                 n_layers=6,
                 n_steps=8):
        super().__init__()

        self.n_steps=n_steps
        self.n_pos=n_pos
        self.embed_dim=embed_dim
        self.n_tokens=n_tokens

        weight = torch.randn(n_tokens,embed_dim)
        self.token_embed = nn.Parameter(weight, requires_grad=True)
        init.trunc_normal_(self.token_embed, 0, 0.02)
        weight = torch.randn(n_pos,embed_dim)
        self.pos_embed = nn.Parameter(weight, requires_grad=True)
        init.trunc_normal_(self.pos_embed, 0, 0.02)

        self.mask_embed=nn.Parameter(torch.zeros([embed_dim,]), requires_grad=True)

        self.embed_out=nn.ModuleList([nn.GELU(),
                                      nn.LayerNorm(embed_dim)])

        self.encoder = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim,
                                                                 nhead=nhead,
                                                                 dim_feedforward=hidden_dim,
                                                                 dropout=0.1,
                                                                 activation='relu',
                                                                 batch_first=True)
                                      for _ in range(n_layers)])


        self.proj_out=nn.ModuleList([nn.Linear(embed_dim,embed_dim),
                                     nn.GELU(),
                                     nn.LayerNorm(embed_dim)])
        self.bias = nn.Parameter(torch.zeros([n_pos,n_tokens]), requires_grad=True)
        self.debugger=None

    @torch.no_grad()
    def calculate_n_mask(self,x=None):
        """Cosine schedule: number of masked positions for progress ``x`` in [0, 1].

        ``x`` must be a one-element tensor; returns ``round(cos(x*pi/2) * n_pos)``
        as a Python int.
        """
        n = torch.cos(x*3.1415926535/2)*self.n_pos
        n = torch.round(n).int()
        return n.item()

    @torch.no_grad()
    def sample_mask(self):
        """Draw a random boolean mask of shape (n_pos,) for one training sample.

        The number of masked positions follows the cosine schedule at a uniformly
        random progress value; the positions themselves are uniformly random.
        """
        n=self.calculate_n_mask(x=torch.rand((1,)))
        # The schedule rounds to 0 for draws very close to 1. An all-False mask
        # would hand calculate_loss an empty selection, whose mean cross-entropy is
        # NaN, so at least one position is always masked.
        n=min(self.n_pos,max(1,n))
        mask=torch.full((self.n_pos,),False,dtype=torch.bool)
        r = torch.rand((self.n_pos,))
        _, selected_positions = torch.topk(r, k=n, dim=-1) # (n_masked,)
        mask[selected_positions]=True
        return mask

    def embed(self,ind,mask=None): # ind/mask [batch_size,n_pos] or [n_pos,]
        """Embed token indices, substituting the mask embedding where ``mask`` is True."""
        embedding = self.token_embed[ind] # [batch_size,n_pos,embed_dim] or [n_pos,embed_dim]
        if mask is not None:
            embedding[mask]=self.mask_embed # [n_masked,embed_dim] <- [embed_dim,]
        embedding = embedding + self.pos_embed # [(batch_size,)n_pos,embed_dim]+[n_pos,embed_dim]
        return embedding

    def train_val_step(self,x): # x[batch_size, n_pos]
        """Mask each sample of ``x`` independently and predict all positions.

        Returns a dict with the 'logits' [batch_size, n_pos, n_tokens] and the
        sampled boolean 'mask' [batch_size, n_pos].
        """
        masks=[]
        for b in range(x.shape[0]):
            masks.append(self.sample_mask())
        mask=torch.vstack(masks)
        embedding=self.embed(x,mask=mask)
        logits=self.forward(masked_embedding=embedding)
        return {'logits':logits,'mask':mask}

    def calculate_loss(self,x,logits,mask):
        """Label-smoothed cross-entropy over the masked positions only.

        Returns a dict with 'tot_loss' and 'raw_perplexity'; the latter is the mean
        log-probability assigned to the true tokens at the masked positions.
        """
        logits_=logits[mask].view(-1,self.n_tokens).contiguous()
        x_=x[mask].view(-1).contiguous().long()
        ce_loss=F.cross_entropy(logits_,
                                target=x_,
                                label_smoothing=0.1).mean()
        log_proba = F.log_softmax(logits_.detach(),dim=-1)
        raw_perplexity = torch.gather(log_proba,dim=1,index=x_.unsqueeze(-1)).mean()
        return {"tot_loss":ce_loss,"raw_perplexity":raw_perplexity}

    def forward(self,masked_embedding):
        """Map embeddings [batch, n_pos, embed_dim] to logits [batch, n_pos, n_tokens]."""
        h=masked_embedding
        for layer in self.embed_out:
            h=layer(h)
        for layer in self.encoder:
            h=layer(h)
        for layer in self.proj_out:
            h=layer(h)
        # Output projection is tied to the token embedding matrix.
        logits = torch.matmul(h,self.token_embed.T) + self.bias
        return logits

    @torch.no_grad()
    def unconditional_generate(self,temperature=(1,1),n_steps=None):
        """Iteratively decode one sequence starting from a fully masked state.

        At every step tokens are sampled for all masked positions and the most
        confident ones are kept, so that the number of still-masked positions
        follows the cosine schedule; the last step fills whatever remains.
        Puts the model in eval mode. Only ``temperature[0]`` is used (for the
        intermediate steps).

        Returns:
            A list with the index tensor [n_pos,] after each of the ``n_steps`` steps.
        """
        # assert batch_size == 1
        if n_steps is None:
            n_steps=self.n_steps
        self.eval()
        ind_ls=[]
        # Initial values are placeholders: every position starts masked.
        current_ind = (torch.rand((self.n_pos,))*(self.n_tokens-1)).long() # [n_pos,]
        n_masked=self.n_pos
        mask = torch.full((self.n_pos,),True,dtype=torch.bool) # [n_pos,]
        for t in range(1,n_steps):
            embedding=self.embed(current_ind,mask=mask) # [n_pos,embed_dim]
            logits=self.forward(masked_embedding=embedding.view(1,self.n_pos,self.embed_dim)).squeeze(0) # [n_pos,n_tokens]
            masked_logits=logits.clone()[mask]/temperature[0] # [n_masked,n_tokens]
            token_dis=torch.distributions.categorical.Categorical(logits=masked_logits)
            token_sample=token_dis.sample() # [n_masked,]
            token_confidence=torch.gather(token_dis.probs, # [n_masked,]
                                          dim=-1,
                                          index=token_sample.unsqueeze(-1)).squeeze(-1)

            sorted_confidence,_=torch.sort(token_confidence,
                                           dim=-1,descending=True)
            n=self.calculate_n_mask(x=torch.tensor(t/n_steps,
                                                   dtype=torch.float32).view(1,))
            # dn = how many positions should be revealed in this step.
            dn=n_masked-n
            n_masked=n
            threshold_confidence=sorted_confidence[dn] # [1,]
            confident_token_flag=token_confidence>threshold_confidence # [n_masked,]
            current_ind[mask]=torch.where(confident_token_flag.cpu(),
                                          token_sample.cpu(),
                                          current_ind[mask])
            mask[mask.clone()]=~confident_token_flag.cpu()

            assert torch.abs(torch.sum(mask)-n_masked).cpu().item()<=1
            ind_ls.append(current_ind.clone().squeeze(0))
        # Final step: sample every position that is still masked.
        embedding=self.embed(current_ind,mask=mask) # [n_pos,embed_dim]
        logits=self.forward(masked_embedding=embedding.view(1,self.n_pos,self.embed_dim).contiguous()).squeeze(0) # [n_pos,n_tokens]
        masked_logits=logits.clone()[mask] # [n_masked,n_token]
        token_dis=torch.distributions.categorical.Categorical(logits=masked_logits)
        token_sample=token_dis.sample() # [n_masked,]
        current_ind[mask]=token_sample.cpu()
        ind_ls.append(current_ind.clone().squeeze(0))
        return ind_ls

class EMALogger:
    """Exponential moving average of a scalar, starting at 0.

    Args:
        decay: weight given to the previous average on each update.
    """

    def __init__(self,decay=0.9):
        self.decay=decay
        self.val=0

    def update(self,val):
        """Fold ``val`` into the running average stored in ``self.val``."""
        keep = self.decay
        self.val = keep*self.val + val*(1-keep)


In [12]:
@torch.no_grad()
def save_phase1(epoch=0):
    """Checkpoint the VQ model and dump the code indices it assigns to the data.

    Two files are written into the run folder under ``outcome_root``:
    ``vq{epoch}.pth`` (the model state dict) and ``{epoch}_enc.npy``. The
    array holds the output of ``model.encode`` cast to uint16 and reshaped to
    rows of ``n_pos`` entries: every batch yielded by ``train_dataloader``
    first, followed by every batch yielded by ``test_dataloader``.

    The loaders are consumed exactly as they were configured, so shuffling or
    dropped partial batches on their side carry over to the saved rows.
    The model is left in eval mode afterwards.
    """
    assert phase==1
    ckpt_path = os.path.join(outcome_root, f"{config['ver']}/vq{epoch}.pth")
    torch.save(model.state_dict(), ckpt_path)

    n_pos = config['latent_size'][0]*config['latent_size'][1]
    model.eval()

    code_rows = []
    # Train batches first, then test batches.
    for loader in (train_dataloader, test_dataloader):
        for batch_data in loader:
            if config['cuda']:
                batch_data = batch_data.cuda()
            codes = model.encode(batch_data).cpu().numpy().astype('uint16')
            code_rows.append(np.reshape(codes, [-1, n_pos]))

    enc_path = os.path.join(outcome_root, f"{config['ver']}/{epoch}_enc.npy")
    np.save(enc_path, np.vstack(code_rows))

@torch.no_grad()
def save_phase2(epoch='maskgit'):
    """Write the MaskGIT state dict to ``maskgit{epoch}.pth`` in the run folder.

    Only valid in phase 2. The transformer is switched to eval mode before
    saving and stays in that mode when the function returns.
    """
    assert phase==2
    maskgit.eval()
    ckpt_path = os.path.join(outcome_root, f"{config['ver']}/maskgit{epoch}.pth")
    weights = maskgit.state_dict()
    torch.save(weights, ckpt_path)

    
def print_num_params(model, name=""):
    """Report the total number of elements across ``model.parameters()``.

    The line ``"{name} parameters: {count}"`` is appended to the file at
    ``config_path`` and then printed.
    """
    with torch.no_grad():
        total = sum(p.numel() for p in model.parameters())
        with open(config_path, 'a') as log_file:
            log_file.write(f"{name} parameters: {total}\n")
        print(f"{name} parameters: {total}")
        
def train_step(model, epoch):
    """Run one adversarial training epoch over ``train_dataloader``.

    For every batch the discriminator is updated first (on a detached
    reconstruction), then the generator side is updated through ``g_optim``.
    After the loop the function computes the code-usage perplexity of the
    epoch, writes a reconstruction grid, optionally checkpoints and
    re-initialises the codebook, and appends a summary to ``config_path``.

    Args:
        model: object exposing ``reconstruct``, ``discriminate``,
            ``calculate_d_loss``, ``calculate_g_loss`` and
            ``code_book.embed.reinit``.
        epoch: epoch counter (the main loop counts from 1); drives the save
            and re-init schedule and names the output image.

    Uses the module-level ``train_dataloader``, ``d_optim``, ``g_optim``,
    ``d_loss_logger``, ``config`` and ``config_path``.
    """
    model.train()
    n_batch = 0
    acc_rec_loss = 0
    acc_cb_loss = 0
    acc_lpips_loss = 0
    acc_gan_loss = 0
    acc_gan_acc = 0
    acc_encodings_sum = 0
    acc_d_loss = 0
    acc_d_acc_r = 0
    acc_d_acc_f = 0
    disc_skip = 0
    for batch_ind, batch_data in enumerate(train_dataloader):
        n_batch += 1
        if config['cuda']:
            batch_data = batch_data.cuda()

        # --- discriminator update -------------------------------------
        d_optim.zero_grad()
        rec_out = model.reconstruct(batch_data)
        # The reconstruction is detached so the discriminator loss does not
        # back-propagate into the encoder/decoder.
        disc_out = model.discriminate(fake_x=rec_out['recx'].detach(),
                                      real_x=batch_data,on_train=True)
        d_loss = model.calculate_d_loss(disc_out['disc_out_fake'],
                                        disc_out['disc_out_real'])
        # Skip the discriminator step while real + fake accuracy is 1.8 or more.
        if d_loss['disc_accuracy_real']+d_loss['disc_accuracy_fake']<1.8:
            d_loss['tot_loss'].backward()
            d_optim.step()
        else:
            disc_skip+=1

        # --- generator update -----------------------------------------
        g_optim.zero_grad()
        # Second discriminator pass on the non-detached reconstruction so the
        # generator loss can reach the encoder/decoder.
        disc_out = model.discriminate(fake_x=rec_out['recx'],on_train=True)
        g_loss = model.calculate_g_loss(x=batch_data,
                                        recx=rec_out['recx'],
                                        codebook_loss=rec_out['codebook_loss'],
                                        disc_out_fake=disc_out['disc_out_fake'])

        g_loss['tot_loss'].backward()
        g_optim.step()

        # Smoothed discriminator loss; only consumed by the disabled hooks below.
        d_loss_logger.update(d_loss['disc_loss'].detach().cpu().item())

        #model.update_gan_loss_weight(epoch=epoch,d_loss_val=d_loss_logger.val)
        #model.update_disc_loss_weight(epoch=epoch,d_loss_val=d_loss_logger.val)

        acc_rec_loss += g_loss['rec_loss'].detach()
        acc_cb_loss += g_loss['codebook_loss'].detach()
        acc_lpips_loss += g_loss['lpips_loss'].detach()
        acc_gan_loss += g_loss['gan_loss'].detach()
        acc_gan_acc += g_loss['gan_accuracy'].detach()
        acc_encodings_sum += rec_out['encodings_sum'].detach().cpu()
        acc_d_loss += d_loss['disc_loss'].detach()
        acc_d_acc_r += d_loss['disc_accuracy_real'].detach()
        acc_d_acc_f += d_loss['disc_accuracy_fake'].detach()

    # Normalise the accumulated per-code sums into a distribution; the
    # perplexity is the exponential of its entropy.
    avg_probs = acc_encodings_sum / torch.sum(acc_encodings_sum)
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
    # Always true for integer epochs; kept as the hook for a visualisation interval.
    if epoch % 1 == 0:
        with torch.no_grad():
            # Grid of the last batch of the epoch and its reconstruction.
            vis_img(batch_data, rec_out['recx'], f"train {epoch}")
    if epoch%config['save_every_n_epoch']==0:
        save_phase1(epoch)
    # Fires on epochs 4, 7, 10, ...
    if epoch % 3 == 1 and epoch >= 3:
        with torch.no_grad():
            # Move to the GPU only when the run is configured for CUDA, like
            # every other transfer in this function; an unconditional .cuda()
            # fails on CPU-only machines.
            usage = avg_probs.cuda() if config['cuda'] else avg_probs
            model.code_book.embed.reinit(usage)
    info = f'Train Epoch: {epoch}.\n' +\
           f'rec_loss: {acc_rec_loss / n_batch:.4f}; ' +\
           f'codebook_loss: {acc_cb_loss / n_batch:.4f}; ' +\
           f'lpips_loss: {acc_lpips_loss / n_batch:.4f}; ' +\
           f'gan_loss: {acc_gan_loss / n_batch:.4f}; ' +\
           f'gan_accuray: {acc_gan_acc / n_batch:.4f}; ' +\
           f'perplexity: {perplexity:.4f}; ' +\
           f'disc_loss: {acc_d_loss / n_batch:.4f}; ' +\
           f'disc_accuracy: {acc_d_acc_r/n_batch:.4f}/{acc_d_acc_f/n_batch:.4f}; ' + \
           f'disc_skip: {disc_skip}\n'
    with open(config_path, 'a') as f:
        f.write(info)

    print(info)


def val_step(model, epoch):
    """Evaluate the VQ model and its discriminator on ``test_dataloader``.

    No optimiser step is taken. Generator and discriminator metrics are
    averaged over the batches, the code-usage perplexity is computed from the
    accumulated ``encodings_sum``, a reconstruction grid of the last batch is
    written, and the summary is appended to ``config_path`` and printed.
    On the final epoch a sorted bar chart of code usage is saved as ``Dis.png``.
    """
    model.eval()
    g_keys = ('rec_loss', 'codebook_loss', 'lpips_loss', 'gan_loss', 'gan_accuracy')
    d_keys = ('disc_loss', 'disc_accuracy_real', 'disc_accuracy_fake')
    with torch.no_grad():
        totals = {key: 0 for key in g_keys + d_keys}
        code_counts = 0
        n_batch = 0

        for batch_data in test_dataloader:
            n_batch += 1
            if config['cuda']:
                batch_data = batch_data.cuda()
            rec_out = model.reconstruct(batch_data)
            disc_out = model.discriminate(fake_x=rec_out['recx'],
                                          real_x=batch_data, on_train=False)
            d_loss = model.calculate_d_loss(disc_out['disc_out_fake'],
                                            disc_out['disc_out_real'])
            g_loss = model.calculate_g_loss(x=batch_data,
                                            recx=rec_out['recx'],
                                            codebook_loss=rec_out['codebook_loss'],
                                            disc_out_fake=disc_out['disc_out_fake'])
            for key in g_keys:
                totals[key] += g_loss[key].detach()
            code_counts += rec_out['encodings_sum'].detach().cpu()
            for key in d_keys:
                totals[key] += d_loss[key].detach()

        # Usage distribution over codes and the exponential of its entropy.
        avg_probs = code_counts / torch.sum(code_counts)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        if epoch % 1 == 0:
            vis_img(batch_data, rec_out['recx'], f"test {epoch}")

        mean = {key: total / n_batch for key, total in totals.items()}
        info = ''.join([
            f'Test Epoch: {epoch}.\n',
            f"rec_loss: {mean['rec_loss']:.4f}; ",
            f"codebook_loss: {mean['codebook_loss']:.4f}; ",
            f"lpips_loss: {mean['lpips_loss']:.4f}; ",
            f"gan_loss: {mean['gan_loss']:.4f}; ",
            f"gan_accuray: {mean['gan_accuracy']:.4f}; ",
            f'perplexity: {perplexity:.4f}; ',
            f"disc_loss: {mean['disc_loss']:.4f}; ",
            f"disc_accuracy: {mean['disc_accuracy_real']:.4f}/{mean['disc_accuracy_fake']:.4f}\n",
        ])
        with open(config_path, 'a') as log_file:
            log_file.write(info)

        if epoch == config['max_epochs']:
            # Final epoch only: bar chart of code usage in ascending order.
            sorted_usage = np.sort(avg_probs.numpy())
            plt.bar(np.arange(sorted_usage.shape[0]), sorted_usage)
            plt.savefig(os.path.join(outcome_root, f"{config['ver']}/Dis.png"),dpi=80)
            plt.clf()
    print(info)


def train_step_vae(model, epoch):
    """Run one training epoch of the VQ autoencoder without a discriminator.

    Each batch is reconstructed, the generator loss is computed with
    ``disc_out_fake=None`` and one ``g_optim`` step is taken. After the loop
    the function computes the code-usage perplexity, writes a reconstruction
    grid of the last batch, optionally checkpoints and re-initialises the
    codebook, and appends a summary to ``config_path``.

    Args:
        model: object exposing ``reconstruct``, ``calculate_g_loss`` and
            ``code_book.embed.reinit``.
        epoch: epoch counter (the main loop counts from 1).

    Uses the module-level ``train_dataloader``, ``g_optim``, ``config`` and
    ``config_path``.
    """
    model.train()
    n_batch = 0
    acc_rec_loss = 0
    acc_cb_loss = 0
    acc_lpips_loss = 0
    acc_encodings_sum = 0
    for batch_ind, batch_data in enumerate(train_dataloader):
        n_batch += 1
        if config['cuda']:
            batch_data = batch_data.cuda()

        rec_out = model.reconstruct(batch_data)
        g_optim.zero_grad()
        g_loss = model.calculate_g_loss(x=batch_data,
                                        recx=rec_out['recx'],
                                        codebook_loss=rec_out['codebook_loss'],
                                        disc_out_fake=None)

        g_loss['tot_loss'].backward()
        g_optim.step()

        acc_rec_loss += g_loss['rec_loss'].detach()
        acc_cb_loss += g_loss['codebook_loss'].detach()
        acc_lpips_loss += g_loss['lpips_loss'].detach()
        acc_encodings_sum += rec_out['encodings_sum'].detach().cpu()

    # Normalise the accumulated per-code sums into a distribution; the
    # perplexity is the exponential of its entropy.
    avg_probs = acc_encodings_sum / torch.sum(acc_encodings_sum)
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
    # Always true for integer epochs; kept as the hook for a visualisation interval.
    if epoch % 1 == 0:
        with torch.no_grad():
            vis_img(batch_data, rec_out['recx'], f"train {epoch}")
    if epoch%config['save_every_n_epoch']==0:
        save_phase1(epoch)
    # Fires on epochs 4, 7, 10, ...
    if epoch % 3 == 1 and epoch >= 3:
        with torch.no_grad():
            # Move to the GPU only when the run is configured for CUDA, like
            # the batch transfer above; an unconditional .cuda() fails on
            # CPU-only machines.
            usage = avg_probs.cuda() if config['cuda'] else avg_probs
            model.code_book.embed.reinit(usage)
    info = f'Train Epoch: {epoch}.\n' +\
           f'rec_loss: {acc_rec_loss / n_batch:.4f}; ' +\
           f'codebook_loss: {acc_cb_loss / n_batch:.4f}; ' +\
           f'lpips_loss: {acc_lpips_loss / n_batch:.4f}; ' +\
           f'perplexity: {perplexity:.4f}\n'

    with open(config_path, 'a') as f:
        f.write(info)

    print(info)


def val_step_vae(model, epoch):
    """Evaluate the discriminator-free VQ autoencoder on ``test_dataloader``.

    Averages the reconstruction, codebook and LPIPS losses over the batches,
    computes the code-usage perplexity, writes a reconstruction grid of the
    last batch, and appends the summary to ``config_path`` before printing it.
    On the final epoch a sorted bar chart of code usage is saved as ``Dis.png``.
    """
    model.eval()
    loss_keys = ('rec_loss', 'codebook_loss', 'lpips_loss')
    with torch.no_grad():
        totals = {key: 0 for key in loss_keys}
        code_counts = 0
        n_batch = 0

        for batch_data in test_dataloader:
            n_batch += 1
            if config['cuda']:
                batch_data = batch_data.cuda()
            rec_out = model.reconstruct(batch_data)
            g_loss = model.calculate_g_loss(x=batch_data,
                                            recx=rec_out['recx'],
                                            codebook_loss=rec_out['codebook_loss'],
                                            disc_out_fake=None)
            for key in loss_keys:
                totals[key] += g_loss[key].detach()
            code_counts += rec_out['encodings_sum'].detach().cpu()

        # Usage distribution over codes and the exponential of its entropy.
        avg_probs = code_counts / torch.sum(code_counts)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        if epoch % 1 == 0:
            vis_img(batch_data, rec_out['recx'], f"test {epoch}")

        mean = {key: total / n_batch for key, total in totals.items()}
        info = ''.join([
            f'Test Epoch: {epoch}.\n',
            f"rec_loss: {mean['rec_loss']:.4f}; ",
            f"codebook_loss: {mean['codebook_loss']:.4f}; ",
            f"lpips_loss: {mean['lpips_loss']:.4f}; ",
            f'perplexity: {perplexity:.4f}\n',
        ])
        with open(config_path, 'a') as log_file:
            log_file.write(info)

        if epoch == config['max_epochs']:
            # Final epoch only: bar chart of code usage in ascending order.
            sorted_usage = np.sort(avg_probs.numpy())
            plt.bar(np.arange(sorted_usage.shape[0]), sorted_usage)
            plt.savefig(os.path.join(outcome_root, f"{config['ver']}/Dis.png"),dpi=80)
            plt.clf()
    print(info)

@torch.no_grad()
def vis_img(x, y, name):
    """Save a 4x4 comparison grid of the first eight images of ``x`` and ``y``.

    The canvas is four image-heights tall and four image-widths wide, with
    tile size taken from ``config['image_size']``. Row bands alternate:
    ``x[0:4]``, ``y[0:4]``, ``x[4:8]``, ``y[4:8]``. Values are mapped with
    ``(v + 1) / 2 * 255`` and clipped, i.e. inputs are treated as lying in
    [-1, 1]. The file is written with ``cv2.imwrite`` to ``{name}.png`` in
    the run folder.

    ``cf2cl`` is defined elsewhere; presumably it converts channel-first
    batches to channel-last, to be confirmed against its definition.
    """
    h, w = config['image_size']
    out_path = os.path.join(outcome_root, f"{config['ver']}/{name}.png")

    def to_unit_range(batch):
        # First eight images, shifted from [-1, 1] to [0, 1].
        return (cf2cl(batch.detach().cpu()).numpy()[:8] + 1)/2

    upper = to_unit_range(x)
    lower = to_unit_range(y)

    canvas = np.zeros((4 * h, 4 * w, 3))
    for k in range(8):
        band, col = divmod(k, 4)
        top = 2 * band * h
        cols = slice(col * w, (col + 1) * w)
        canvas[top:top + h, cols] = upper[k, :, :, :]
        canvas[top + h:top + 2 * h, cols] = lower[k, :, :, :]

    canvas = np.clip(canvas * 255, 0, 255).astype(np.uint8)
    cv2.imwrite(out_path, canvas)

In [13]:
def train_step_maskgit(maskgit, epoch):
    """Train the MaskGIT transformer for one epoch on ``train_dataloader``.

    Each batch of token ids is cast to int, passed through
    ``maskgit.train_val_step``, and the loss from ``maskgit.calculate_loss``
    drives one ``t_optim`` step. The epoch summary (mean ``tot_loss`` and
    ``exp(-mean raw_perplexity)``) is appended to ``config_path`` and printed.
    A checkpoint is written every ``save_every_n_epoch`` epochs.
    """
    maskgit.train()
    n_batch = 0
    ce_total = 0
    raw_ppl_total = 0

    # The loader wraps a TensorDataset, so every item is a one-element sequence.
    for (tokens,) in train_dataloader:
        n_batch += 1
        tokens = tokens.int()
        if config['cuda']:
            tokens = tokens.cuda()

        step_out = maskgit.train_val_step(tokens)
        t_optim.zero_grad()
        losses = maskgit.calculate_loss(x=tokens,
                                        logits=step_out['logits'],
                                        mask=step_out['mask'])
        losses['tot_loss'].backward()
        t_optim.step()

        ce_total += losses['tot_loss'].detach()
        raw_ppl_total += losses['raw_perplexity'].detach()

    perplexity = torch.exp(-raw_ppl_total/n_batch)
    if epoch%config['save_every_n_epoch']==0:
        save_phase2(epoch)

    info = ''.join([
        f'Train Epoch: {epoch}.\n',
        f'ce_loss: {ce_total / n_batch:.4f}; ',
        f'perplexity: {perplexity:.2f}\n',
    ])
    with open(config_path, 'a') as log_file:
        log_file.write(info)
    print(info)
    

@torch.no_grad()
def val_step_maskgit(maskgit, epoch):
    """Evaluate MaskGIT on the held-out token set and render sample images.

    Iterates ``test_dataloader`` so the reported "Test Epoch" numbers come
    from data the optimiser has not seen. The summary (mean ``tot_loss`` and
    ``exp(-mean raw_perplexity)``) is appended to ``config_path`` and printed,
    and ``vis_maskgit_unconditional_generate`` is called to write sample
    grids named ``m{epoch}``.

    Args:
        maskgit: transformer exposing ``train_val_step`` and ``calculate_loss``.
        epoch: epoch counter used in the log line and the image names.

    Raises:
        RuntimeError: if ``test_dataloader`` yields no batch (for example a
            held-out split smaller than the batch size with ``drop_last=True``).
    """
    maskgit.eval()
    n_batch = 0
    acc_ce_loss = 0
    acc_raw_perplexity = 0

    # Use the held-out split; iterating the training loader here would make
    # the "Test" metrics a second pass over training data.
    for [batch_data] in test_dataloader:
        n_batch += 1
        batch_data = batch_data.int()
        if config['cuda']:
            batch_data = batch_data.cuda()

        t_out = maskgit.train_val_step(batch_data)
        t_loss = maskgit.calculate_loss(x=batch_data,
                                        logits=t_out['logits'],
                                        mask=t_out['mask'])
        acc_ce_loss += t_loss['tot_loss']
        acc_raw_perplexity += t_loss['raw_perplexity']

    if n_batch == 0:
        # Fail with an explicit message instead of a ZeroDivisionError below.
        raise RuntimeError("test_dataloader produced no batches; "
                           "check the test split size against the batch size")

    perplexity = torch.exp(-acc_raw_perplexity/n_batch)
    info = f'Test Epoch: {epoch}.\n' +\
           f'ce_loss: {acc_ce_loss / n_batch:.4f}; ' +\
           f'perplexity: {perplexity:.2f}\n'
    vis_maskgit_unconditional_generate(batch_size=8,name=f"m{epoch}")
    with open(config_path, 'a') as f:
        f.write(info)
    print(info)

@torch.no_grad()
def vis_maskgit_unconditional_generate(batch_size=8,name='1',t=None):
    """Sample from MaskGIT, decode with the VQ model and save the images.

    Outside phase 3 a step-by-step sheet ``{name}.png`` is written first:
    ``batch_size`` rows, each row showing the decoded image of every entry
    returned by one ``maskgit.unconditional_generate(temperature=(1,1))``
    call (``t`` is not used for this sheet).

    Two 16-sample grids are then written through ``vis_img``:
    ``{name}_`` with temperature ``t`` or ``(1,1)`` when ``t`` is None, and
    ``{name}__`` with ``t`` or ``(2.5,1)``. Sample ``k`` is generated with
    ``n_steps = maskgit.n_steps-4+round(k/16*8)`` and only the last entry
    of the returned list is decoded.

    Uses the module-level ``maskgit``, ``model``, ``config`` and ``outcome_root``.
    """
    if config['phase']!=3:
        strips = []
        for _ in range(batch_size):
            steps = torch.vstack(maskgit.unconditional_generate(temperature=(1,1)))
            if config['cuda']:
                steps = steps.cuda()
            decoded = cf2cl(model.decode(steps).cpu()).numpy()
            # Lay the images of one generation run side by side.
            strips.append(np.hstack(decoded))
        sheet = np.clip(np.vstack(strips)*127.5+127.5,0,255).astype(np.uint8)
        cv2.imwrite(os.path.join(outcome_root, f"{config['ver']}/{name}.png"), sheet)

    def steps_for(k):
        # Step budget for sample k, relative to maskgit.n_steps.
        return maskgit.n_steps-4+round(k/16*8)

    def render_grid(temperature, label):
        # Draw 16 final token maps, decode them and hand both halves to vis_img.
        finals = [maskgit.unconditional_generate(temperature=temperature,
                                                 n_steps=steps_for(k))[-1]
                  for k in range(16)]
        ind = torch.vstack(finals)
        if config['cuda']:
            ind = ind.cuda()
        imgs = model.decode(ind)
        vis_img(imgs[:8],imgs[8:],label)

    render_grid(t if t is not None else (1,1), f"{name}_")
    render_grid(t if t is not None else (2.5,1), f"{name}__")

In [17]:
class CELEBAImageDataset(Dataset):
    """CelebA images read from disk with OpenCV and returned as float tensors.

    Each item is the image at ``image_paths[idx]`` with its first 10 rows and
    first 2 columns cropped away, pixel values rescaled with ``x/127.5-1``
    and the float32 array passed through ``torchvision.transforms.ToTensor``.
    No resizing happens here; ``target_size`` is stored but unused by this class.
    """

    def __init__(self, image_paths):
        self.image_paths = image_paths
        self.target_size = config['image_size']
        self.to_tensor = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Crop first, then rescale: 0 maps to -1 and 255 maps to 1.
        cropped = cv2.imread(self.image_paths[idx])[10:, 2:]
        scaled = (cropped / 127.5 - 1).astype(np.float32)
        return self.to_tensor(scaled)
    
    
class AFHQImageDataset(Dataset):
    """AFHQ images resized to ``config['image_size']`` with a random mirror.

    Each item is read with OpenCV, resized with Lanczos interpolation,
    rescaled with ``x/127.5-1``, passed through
    ``torchvision.transforms.ToTensor`` and flipped horizontally with
    probability 0.5. The flip is applied on every access, whichever split
    the dataset is used for.

    Attributes:
        target_size: ``config['image_size']``, read as (height, width), the
            same order ``vis_img`` unpacks it in.
    """

    def __init__(self, image_paths):
        self.target_size = config['image_size']
        self.intrp_method = cv2.INTER_LANCZOS4
        self.image_paths = image_paths
        self.to_tensor=torchvision.transforms.ToTensor()
        self.flip=torchvision.transforms.RandomHorizontalFlip(p=0.5)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the processed image tensor for ``image_paths[idx]``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
        # cv2.resize expects dsize as (width, height) while target_size is
        # (height, width); swap so non-square sizes are not transposed.
        height, width = self.target_size
        image = cv2.resize(image, (width, height), interpolation=self.intrp_method)
        image = self.to_tensor((image/127.5-1).astype(np.float32))
        image = self.flip(image)

        return image

# Build train/test loaders for the active phase. Phase 3 creates no loaders.
if phase==1:
    # Phase 1 trains on images; the dataset family is picked from the run name.
    # NOTE(review): if config['ver'] contains neither 'afhq' nor 'celeba',
    # train_data/test_data are never assigned and the print below raises NameError.
    if 'afhq' in config['ver']:
        # Files are addressed as <class>/<4-digit index>.png; the first `split`
        # indices of every class go to training, the rest up to dataset_size/3
        # to testing. Assumes three equally sized classes - TODO confirm.
        split=round(config['dataset_size']/3*config['train_test_split'])
        train_paths=[os.path.join(data_path_afhq,s,f"{i:0>4d}.png")\
                     for i in range(split)\
                     for s in ['cat','dog','wild']]
        test_paths=[os.path.join(data_path_afhq,s,f"{i:0>4d}.png")\
                    for i in range(split,round(config['dataset_size']/3))\
                    for s in ['cat','dog','wild']]

        train_data = AFHQImageDataset(train_paths)
        test_data = AFHQImageDataset(test_paths)

    elif 'celeba' in config['ver']:
        # Files are addressed by 1-based, 6-digit index: 000001.jpg, 000002.jpg, ...
        split=round(config['dataset_size']*config['train_test_split'])
        train_paths=[os.path.join(data_path_celeba,f"{i:0>6d}.jpg") for i in range(1,split+1)]
        test_paths=[os.path.join(data_path_celeba,f"{i:0>6d}.jpg") for i in range(split+1,config['dataset_size']+1)]

        train_data = CELEBAImageDataset(train_paths)
        test_data = CELEBAImageDataset(test_paths)

    print(f"train data shape: {len(train_data)}\ntest data shape: {len(test_data)}")


    # Both loaders shuffle and drop the trailing partial batch. save_phase1
    # iterates these same loaders, so the encodings it writes inherit that
    # order and omit the dropped samples.
    train_dataloader = DataLoader(train_data, batch_size=config['batch_size'],
                                  shuffle=True,num_workers=4,drop_last=True)

    test_dataloader = DataLoader(test_data, batch_size=config['batch_size'],
                                 shuffle=True,num_workers=4,drop_last=True)

elif phase==2:
    # Phase 2 trains on a saved array of code indices. The values are loaded
    # as float32 here and cast back with .int() inside the MaskGIT steps.
    ori_data = np.load(data_path_enc).astype('float32')

    # Cap dataset_size at the number of rows available, then split by row position.
    # NOTE(review): if this file was produced by save_phase1, its rows are in
    # shuffled loader order, so this split need not match the phase-1 split - confirm.
    config['dataset_size']=min(config['dataset_size'],ori_data.shape[0])
    split_ind = int(config['dataset_size'] * config['train_test_split'])
    train_data, test_data = ori_data[:split_ind], ori_data[split_ind:config['dataset_size']]
    print(f"train data shape: {train_data.shape}\ntest data shape: {test_data.shape}")

    # TensorDataset yields one-element tuples, which the MaskGIT loops unpack
    # as `[batch_data]`.
    train_data_tensor = torch.from_numpy(train_data)
    train_dataset = torch.utils.data.TensorDataset(train_data_tensor)
    train_dataloader = DataLoader(train_dataset, batch_size=t_config['batch_size'],
                                  shuffle=True,drop_last=True)

    test_data_tensor = torch.from_numpy(test_data)
    test_dataset = torch.utils.data.TensorDataset(test_data_tensor)
    test_dataloader = DataLoader(test_dataset, batch_size=t_config['batch_size'],
                                 shuffle=True,drop_last=True)

# The VQ autoencoder is built in every phase: it is trained in phase 1 and
# used to decode MaskGIT samples in phases 2 and 3.
model = VQGAN(channels=3)
device='cuda' if config['cuda'] else 'cpu'
if config['load']:
    # strict=False lets the load proceed when checkpoint and model keys differ;
    # presumably needed because optional sub-modules (LPIPS, discriminator) are
    # not in every checkpoint - TODO confirm.
    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(config['load'],
                                     map_location=torch.device(device)),
                          strict=False)
    print("vq_loaded")
model.eval()

# Parameter counts are appended to the run log and printed.
print_num_params(model.encoder,"Encoder")
print_num_params(model.code_book,"CodeBook")
print_num_params(model.decoder,"Decoder")
if phase==1:
    print_num_params(model.lpips_fn,"LPIPS")

if config['use_disc']:
    print_num_params(model.discriminator,"Discriminator")
if config['cuda']:
    model.cuda()

# Phase-1 optimisers. The generator optimiser covers the encoder and decoder
# parameters, with a cosine learning-rate schedule over max_epochs.
# NOTE(review): model.code_book parameters are not passed to g_optim;
# presumably the codebook is updated by another mechanism - confirm.
if phase==1:
    g_optim = torch.optim.Adam(list(model.encoder.parameters())+\
                               list(model.decoder.parameters()),
                               lr=config['base_learning_rate'],
                               betas=(0.5,0.9))
    g_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(g_optim,
                                                           T_max=config['max_epochs'],
                                                           eta_min=config['min_learning_rate'])
# The discriminator optimiser exists only when a discriminator is in use; it
# has no scheduler.
if config['use_disc'] and phase==1:
    d_optim = torch.optim.Adam(model.discriminator.parameters(),
                               lr=d_config['base_learning_rate'],
                               betas=(0.5,0.9))

# MaskGIT token transformer for phases 2 (training) and 3 (sampling only),
# dimensioned entirely from t_config.
if phase==2 or phase==3:
    maskgit=MaskGIT(n_tokens=t_config['n_tokens'], n_pos=t_config['n_pos'], 
                    embed_dim=t_config['embed_dim'], nhead=t_config['nhead'],
                    hidden_dim=t_config['hidden_dim'],n_layers=t_config['n_layers'],
                    n_steps=t_config['n_steps'])
    device='cuda' if config['cuda'] else 'cpu'
    if t_config['load']:
        # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
        maskgit.load_state_dict(torch.load(t_config['load'],
                                           map_location=torch.device(device)),
                                strict=False)
        print("maskgit_loaded")
    maskgit.eval()
    print_num_params(maskgit,"MaskGIT")
    if config['cuda']:
        maskgit.cuda()
    # The optimiser and scheduler are also created in phase 3, where the main
    # loop never steps them.
    t_optim=torch.optim.Adam(maskgit.parameters(),
                             lr=t_config['base_learning_rate'],
                             betas=(0.9, 0.96))
    t_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(t_optim,
                                                           T_max=config['max_epochs'],
                                                           eta_min=t_config['min_learning_rate'])
# Main driver: run the epoch loop for the training phases, or sweep sampling
# temperatures in phase 3.
if phase==1 or phase==2:
    t0 = time.time()
    t1 = time.time()
    # Smoothed discriminator loss, updated inside train_step.
    d_loss_logger=EMALogger(decay=0.9)
    for epoch in range(1, config['max_epochs'] + 1):
        t0 = t1
        if phase==1:
            if config['use_disc']:
                train_step(model, epoch)
                val_step(model, epoch)
                # NOTE(review): g_scheduler.step() is only reached on the
                # non-discriminator path below, so the generator learning rate
                # stays constant when use_disc is True - confirm this is intended.
            else:
                train_step_vae(model, epoch)
                val_step_vae(model, epoch)
                g_scheduler.step()
        elif phase==2:
            train_step_maskgit(maskgit,epoch)
            val_step_maskgit(maskgit,epoch)
            t_scheduler.step()
        t1 = time.time()
        print(f"time: {t1 - t0:.2f}")

    # Only the duration of the last epoch is written to the log file.
    with open(config_path, 'a') as f:
        f.write(f'time: {t1 - t0:.2f}\n')

    # Final save, skipped when the per-epoch save inside the train step already
    # fired on the last epoch.
    if config['save'] and config['max_epochs']%config['save_every_n_epoch']!=0:
        if phase==1:
            save_phase1(config['max_epochs'])
        elif phase==2:
            save_phase2(config['max_epochs'])
elif phase==3:
    # Sweep the first temperature component over 0.5, 1.0, ..., 4.5 and draw
    # two independent sets of samples ("_01", "_02") per value.
    for i in np.linspace(0.5,4.5,9):
        vis_maskgit_unconditional_generate(batch_size=8,name=f"{i:.1f}_01",t=(i,1))
        vis_maskgit_unconditional_generate(batch_size=8,name=f"{i:.1f}_02",t=(i,1))

Encoder parameters: 2629952
CodeBook parameters: 131584
Decoder parameters: 3015555
MaskGIT parameters: 3875456


In [18]:
from sklearn.decomposition import PCA
if phase==1:
    # Phase 1: scatter plot of the codebook vectors projected onto their first
    # two principal components.
    X = model.code_book.embed.weight.detach().cpu().numpy()
    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X)
    # Plot the PCA projection, not the first two raw embedding dimensions.
    plt.scatter(X_2d[:,0],X_2d[:,1])
    plt.savefig(os.path.join(outcome_root, f"{config['ver']}/PCA.png"),dpi=80)
    plt.clf()
elif phase==2 or phase==3:
    def f(k):
        """Map subplot index ``k`` (0..11) to a 0-based PCA component index.

        Indices 0-3 map to components 0-3, 4-7 to 5, 7, 9, 11 and 8-11 to
        14, 17, 20, 23; anything else falls back to component 0.
        """
        if 0<=k<=3: # 0 3
            return k
        elif 4<=k<=7: # 5 11
            return 2*k-3
        elif 8<=k<=11:  # 14 23
            return 3*k-10
        else:
            return 0
    # Phases 2/3: show selected principal components of the learned position
    # embedding, each reshaped onto the latent grid.
    X = maskgit.pos_embed.detach().cpu().numpy()
    pca = PCA(n_components=24)
    X = pca.fit_transform(X)
    fig, axs = plt.subplots(3, 4)
    #plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i,ax in enumerate(axs.flat):
        k=f(i)
        ax.imshow(np.reshape(X[:,k],config['latent_size']))
        ax.set_xticks([])
        ax.set_yticks([])
        # Label with the 1-based component number.
        ax.set_xlabel(str(k+1))
    plt.tight_layout(pad=0.5, h_pad=0.5, w_pad=0.5)
    plt.savefig(os.path.join(outcome_root, f"{config['ver']}/PosEmbed.png"),dpi=80)
    plt.clf()

<Figure size 640x480 with 0 Axes>

In [19]:
# Zip the run folder so it can be downloaded as a single file.
# NOTE(review): the source directory is hard-coded rather than derived from
# outcome_root; presumably outcome_root is /kaggle/working/vqvae - confirm.
shutil.make_archive(f"/kaggle/working/{config['ver']}", "zip", f"/kaggle/working/vqvae/{config['ver']}")

'/kaggle/working/afhq_maskgit_0412_v01_phase3.zip'