## Helper function

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupNorm(nn.Module):
    def __init__(self,channels):
        super(GroupNorm,self).__init__()
        self.gn = nn.GroupNorm(num_groups=32,num_channels=channels,eps=1e-6,affine=True)


    def forward(self,x):
        return self.gn(x)

class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(ResidualBlock,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Residual_block = nn.Sequential(
            GroupNorm(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels,out_channels,3,1,1),
            GroupNorm(out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels,out_channels,3,1,1),
        )
        if in_channels != out_channels:
            self.channel_up =  nn.Conv2d(in_channels,out_channels,1,1,0) # 1x1 kernel is applied
            # so no padding is needed since 1x1 kernels keep the spatial dimensions the same 

    def forward(self,x):
        if self.in_channels != self.out_channels:
            upchannel = self.channel_up(x)
            res_block = self.Residual_block(x)
            return self.channel_up(x) + self.Residual_block(x) 
        
        else:
            return x + self.Residual_block(x)

class UpSampleBlock(nn.Module):
    def __init__(self,channels):
        super(UpSampleBlock,self).__init__()
        self.conv = nn.Conv2d(channels,channels,3,1,1)
        
    def forward(self,x):
        x = F.interpolate(x, scale_factor=2.0) # double height and width
        # fill in spot by nearest neighbour
        return self.conv(x)
    
class DownSampleBlock(nn.Module):
    def __init__(self,channels):
        super(DownSampleBlock,self).__init__()
        self.conv = nn.Conv2d(channels,channels,3,2,0)
    
    def forward(self,x):
        pad = (0,1,0,1)
        x = F.pad(x,pad,mode="constant",value=0) # pad 0 on the height 
        # dimension and on the width dimension (done so that the dimension
        # will be half rounded to the nearest integer)
        return self.conv(x) 

class NonLocalBlock(nn.Module):
    def __init__(self,channels):
        super(NonLocalBlock, self).__init__()
        self.in_channels = channels

        self.gn = GroupNorm(channels) 

        self.q = nn.Conv2d(channels,channels,1,1,0)
        self.k = nn.Conv2d(channels,channels,1,1,0)
        self.v = nn.Conv2d(channels,channels,1,1,0)

        self.project_out = nn.Conv2d(channels,channels,1,1,0)

    def forward(self,x):
        embedding = self.gn(x)
        q = self.q(embedding)
        k = self.k(embedding)
        v = self.v(embedding)

        b,c,h,w = q.shape

        q = q.reshape(b,c,h*w) # put the image into a 
        # single vector instead of a 2d matrix
        q = q.permute(0,2,1) # effectively transposing Q
        k = k.reshape(b,c,h*w)
        v = v.reshape(b,c,h*w)

        attention = torch.bmm(q,k) # performs matrix-matrix multiplication
        # for all the batches so we obtain Q^T*K
        attention *= (int(c)**(-.5))
        attention = F.softmax(attention,dim=2)
        attention = attention.permute(0,2,1) # effectively transposing again
        # probably not really necessary

        A = torch.bmm(v,attention)  

        A = A.reshape(b,c,h,w)

        return x+A
        



# Encoder

In [2]:
import torch.nn as nn
#from torchsummary import summary
class Encoder(nn.Module):
    def __init__(self,args):
        super(Encoder,self).__init__()
        channels = [128,128,128,256,256,512] 
        attention_resolutions = [16]
        number_res_blocks = 2
        resolution = 256
        layers = [nn.Conv2d(args.image_channels,channels[0],3,1,1)]
        for i in range(len(channels)-1):
            in_channels = channels[i]
            out_channels = channels[i+1]
            for j in range(number_res_blocks):
                layers.append(ResidualBlock(in_channels,out_channels))
                in_channels = out_channels
                if resolution in attention_resolutions:
                    layers.append(NonLocalBlock(in_channels))
            if i!=len(channels)-2:
                layers.append(DownSampleBlock(channels[i+1]))
                resolution //=2
        layers.append(ResidualBlock(channels[-1],channels[-1]))
        layers.append(NonLocalBlock(channels[-1]))
        layers.append(ResidualBlock(channels[-1],channels[-1]))
        layers.append(GroupNorm(channels[-1]))
        layers.append(nn.SiLU())
        layers.append(nn.Conv2d(channels[-1],args.latent_dim,3,1,1))
        self.model = nn.Sequential(*layers)

    def forward(self,x):
        ##summary(self.model,)
        #print("Total length model", len(self.model))
        #for i,layer in enumerate(self.model):
            #print("Layer:",i)
            #print("Input dims:", x.shape)
            #classname = layer.__class__.__name__
            #if classname == "NonLocalBlock" and i==14:
                #print(layer.in_channels)
                ##print(layer.out_channels)
                
            #print("layername",classname)

            #x=layer(x)
            #print("Output dims: ", x.shape)
        #return x
        return self.model(x)


# Decoder

In [3]:
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self,args):
        super(Decoder,self).__init__()
        channels = [512,256,256,128,128]
        attention_resolution = [16]
        number_res_blocks = 3
        resolution = 16

        in_channels = channels[0]
        layers = [
            nn.Conv2d(args.latent_dim,in_channels,3,1,1),
            ResidualBlock(in_channels,in_channels),
            NonLocalBlock(in_channels),
            ResidualBlock(in_channels,in_channels)
            ]
        for i in range(len(channels)):
            out_channels=channels[i]
            for j in range(number_res_blocks):
                layers.append(ResidualBlock(in_channels,out_channels))
                in_channels = out_channels
                if resolution in attention_resolution:
                    layers.append(NonLocalBlock(in_channels))
            if i!=0:
                layers.append(UpSampleBlock(in_channels))
                resolution *= 2
        layers.append(GroupNorm(in_channels))
        layers.append(nn.SiLU())
        layers.append(nn.Conv2d(in_channels,args.image_channels,3,1,1))
        self.model = nn.Sequential(*layers)


    def forward(self,x):
        return self.model(x)

# Creating Codebook 

In [4]:
class Codebook(nn.Module):
    def __init__(self,args):
        super(Codebook,self).__init__()
        self.num_codebook_vectors = args.num_codebook_vectors
        self.latent_dim = args.latent_dim
        self.beta = args.beta

        self.embedding = nn.Embedding(self.num_codebook_vectors,self.latent_dim) # matrix with as rows the different embedding vectors

        # takes as input tensor with indices, output will be a tensor containing all the requested embedding vectors that corr with the indices
        self.embedding.weight.data.uniform_(-1.0/self.num_codebook_vectors,1.0/self.num_codebook_vectors) #the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized uniformly now

    def forward(self,z):
        # z is normally of shape (batch_size,channels,height, width), after permutation its (batch_size, height,width,channels)
        z = z.permute(0,2,3,1).contiguous() # prepending latent vectors for finding the minimal distance to the codebook vectors
        z_flattened = z.view(-1,self.latent_dim)

        d = torch.sum(z_flattened**2,dim=1,keepdim=True)+\
            torch.sum(self.embedding.weight**2, dim=1)-\
            2*(torch.matmul(z_flattened,self.embedding.weight.t()))

        min_encoding_indices = torch.argmin(d,dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        loss = torch.mean((z_q.detach()-z)**2)+ self.beta * torch.mean((z_q-z.detach())**2)
        # above we first remove the gradient from the quantized latent vectors from the gradient flow and substract it from the original latent vector
        # in the second part we remove tha gradient from the original latent vector and keep the one of the quantized latent vector and substract them 

        z_q = z + (z_q-z).detach() # here we make sure that z_q has the gradient of z but keeps its quantized value
        z_q = z_q.permute(0,3,1,2) 

        return z_q, min_encoding_indices, loss 
        



# VQ GAN

In [5]:
import torch.nn as nn

class VQGAN(nn.Module):
    def __init__(self,args):
        super(VQGAN, self).__init__()
        self.encoder = Encoder(args).to(device=args.device)
        self.decoder= Decoder(args).to(device=args.device)
        self.codebook = Codebook(args).to(device=args.device)
        self.quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim,1).to(device=args.device)
        self.post_quant_conv = nn.Conv2d(args.latent_dim,args.latent_dim,1).to(device=args.device)
    
    def forward(self,imgs):
        enc_imgs = self.encoder(imgs)
        pre_quant_conv = self.quant_conv(enc_imgs)
        quantized_imgs, quantized_indices,q_loss = self.codebook(pre_quant_conv)
        post_quant_conv_mapping = self.post_quant_conv(quantized_imgs)
        decoded_imgs = self.decoder(post_quant_conv_mapping)

        return decoded_imgs,quantized_indices, q_loss
    
    def encode(self,imgs):
        encoded_images = self.encoder(imgs)
        quant_conv_encoded_images = self.quant_conv(encoded_images)
        quantized_imgs, indices, q_loss = self.codebook(quant_conv_encoded_images)

        return quantized_imgs,indices,q_loss
    
    def decode(self,z):
        post_quant_conv_mapping = self.post_quant_conv(z)
        decoded_imgs = self.decoder(post_quant_conv_mapping)
        return decoded_imgs

    def calculate_lambda(self,perceptual_loss,gan_loss):
        last_layer = self.decoder.model[-1]
        last_layer_weight = last_layer.weight
        perceptual_loss_grads = torch.autograd.grad(perceptual_loss,last_layer_weight,retain_graph=True)[0]# retain graph makes sure the computational 
        # graph won't get discarded after calling .grad, this way the gan loss can also be calculated wrt to the weights `[0]` here means we get the 
        # gradient of the zero'th element in the sequence of inputs we've given
        gan_loss_grads = torch.autograd.grad(gan_loss,last_layer_weight,retain_graph=True)[0]
        
        lamb = torch.norm(perceptual_loss_grads)/(torch.norm(gan_loss_grads)+1e-4)
        lamb = torch.clamp(lamb,0,1e4).detach()
        return .8* lamb

    @staticmethod
    def adopt_weight(disc_factor,i,threshold,value=0.):
        if i<threshold:
            disc_factor = value
        return disc_factor

    def load_checkpoint(self,path):
        self.load_state_dict(torch.load(path))



# Discriminator (copy of cycleGAN)


In [6]:
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self,args,num_filters_last=64,n_layers=3):
        super(Discriminator,self).__init__()

        # layers start with a convolutional layer that takes an input image with args.image_channels number of channels
        layers = [nn.Conv2d(args.image_channels,num_filters_last,4,2,1), nn.LeakyReLU(.2)]
        num_filters_mult=1
        # we multiply the number of outputted filters by 2 every iteration
        # untill we would multiply the original number of filters by 8, after that the number of 
        # features stay the same
        for i in range(1,n_layers+1):
            num_filters_mult_last = num_filters_mult
            num_filters_mult = min(2**i,8)

            layers += [
                nn.Conv2d(num_filters_last*num_filters_mult_last, num_filters_last* num_filters_mult,4, 2 if i <n_layers else 1,bias=False),
                nn.BatchNorm2d(num_filters_last*num_filters_mult),
                nn.LeakyReLU(.2,True)
            ]
        layers.append(nn.Conv2d(num_filters_last*num_filters_mult,1,4,1,1))
        self.model = nn.Sequential(*layers)
    def forward(self,x):
        #print("Total length model", len(self.model))
        #for i,layer in enumerate(self.model):
            #print("Layer:",i)
            #print("Input dims:", x.shape)
            #classname = layer.__class__.__name__
            #if classname == "Conv2d" and i==6:
                #print(layer.in_channels)
                #print(layer.out_channels)
                
            #print("layername",classname)

            #x=layer(x)
            #print("Output dims: ", x.shape)
        #return x
        return self.model(x)

# LPIPS

In [7]:
from torchvision.models import vgg16

import os
import torch
import torch.nn as nn
import requests
from tqdm import tqdm
from collections import namedtuple

URL_DCT = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_DCT = {
    "vgg_lpips": "vgg.pth"
}

def download(url, local_path, chunk_size=1024):
    # creates the necessary directory for the local path using os.makedirs
    os.makedirs(os.path.split(local_path)[0],exist_ok=True)
    # sends get request to specified url with stream == True so that the response 
    # is streamed instead of loaded into memory at once
    with requests.get(url,stream=True) as r:
        total_size = int(r.headers.get("content-length",0))
        # creates progress bar to display the download progress, the total= Total_size
        # which is the the size of the file being downloaded retrieved from content
        # length header 
        with tqdm(total=total_size,unit="B",unit_scale=True) as pbar:
            with open(local_path,"wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        # we iterate over the data stream and write to the
                        # file
                        f.write(data)
                        # progress bar is updated
                        pbar.update(chunk_size)

def get_ckpt_path(name,root):
    assert name in URL_DCT
    path = os.path.join(root,CKPT_DCT[name])
    if not os.path.exists(path):
        print(f"Downloading {name} model from {URL_DCT[name]} to {path}")
        download(URL_DCT[name], path)
    return path



In [8]:
class NetLinLayer(nn.Module):
    def __init__(self,in_channels,out_channels=1):
        super(NetLinLayer, self).__init__()
        self.model=nn.Sequential(
            nn.Dropout(),
            nn.Conv2d(in_channels,out_channels,1,1,0,bias=False)
        )

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16,self).__init__()
        vgg_pretrained_features = vgg16(pretrained=True).features
        slices = [vgg_pretrained_features[i] for i in range(30)]
        self.slice1 = nn.Sequential(*slices[0:4])
        self.slice2 = nn.Sequential(*slices[4:9])
        self.slice3 = nn.Sequential(*slices[9:16])
        self.slice4 = nn.Sequential(*slices[16:23])
        self.slice5 = nn.Sequential(*slices[23:30])

        for param in self.parameters():
            param.requires_grad = False
        
    def forward(self,x):

        h=x
        #print("Total length model slice1", len(self.slice1))
        #for i,layer in enumerate(self.slice1):
            #print("Layer:",i)
            #print("Input dims:", h.shape)
            #classname = layer.__class__.__name__
            #if classname == "NonLocalBlock" and i==14:
                #print(layer.in_channels)
                ##print(layer.out_channels)
                
            #print("layername",classname)

            #h=layer(h)
            #print("Output dims: ", h.shape)
        h = self.slice1(x)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        vgg_outputs = namedtuple("VGGOutputs", ['relu1_2','relu2_2','relu3_3', "relu4_3","relu5_3"])
        return vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5)

In [9]:
def norm_tensor(x):
    '''
    Computes the L2 norm of each channel vector of of input tensor x,
    and returns the normalized version of x
    '''
    norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
    return x/(norm_factor+1e-10)

def spatial_average(x):
    '''
    Computes the average of each channel over the spatial dimensions of the tensor
    When x has shape (batch_size,num_channels,height,width) then the output tensor
    will have shape (batch_size,num_channels, 1,1), the scalar represents how similar 
    the real and fake images are.
    '''
    return x.mean([2,3],keepdim=True)

In [10]:
class LPIPS(nn.Module):
    def __init__(self):
        super(LPIPS,self).__init__()
        self.scaling_layer = ScalingLayer()
        self.channels = [64,128,256,512,512]
        self.vgg = VGG16()

        self.lins = nn.ModuleList([
            NetLinLayer(self.channels[0]),
            NetLinLayer(self.channels[1]),
            NetLinLayer(self.channels[2]),
            NetLinLayer(self.channels[3]),
            NetLinLayer(self.channels[4]),
        ])

        self.load_from_pretrained()

        for param in self.parameters():
            param.requires_grad = False
        
    def load_from_pretrained(self,name="vgg_lpips"):
        ckpt = get_ckpt_path(name,"vgg_lpips")
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")),strict=False)

    def forward(self,real_img,fake_img):
        features_real = self.vgg(self.scaling_layer(real_img))
        features_fake = self.vgg(self.scaling_layer(fake_img))

        diffs = {}

        for i in range(len(self.channels)):
            diffs[i] = (norm_tensor(features_real[i])-norm_tensor(features_fake[i]))**2
        
        return sum([spatial_average(self.lins[i].model(diffs[i])) for i in range(len(self.channels))])

# The reason for the scaling layer class is to preprocess input images to match the 
# expected input format of the VGG-16 network used to compute image feature representations
class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer,self).__init__()
        # creates tensor with torch.tensor([-.030,-.088,-1.88]) but adds new dimension
        # of size 1 at the start, : selects all elements along the original tensor resulting
        # in shape (1,3), then two additional dimensions are added using None indexing
        # resulting in shape (1,3,1), the third None adds a new dimension of size 1 at the
        # fourth position (1,3,1,1)
        # equivalent: self.register_buffer("shift", torch.tensor([-0.030, -0.088, -1.88]).unsqueeze(0).unsqueeze(2).unsqueeze(3))

        self.register_buffer("shift",torch.tensor([-.030,-.088,-1.88])[None,:,None,None])
        self.register_buffer("scale",torch.tensor([.458,.448,.45])[None,:,None,None])
        # dimensions are added to make it broadcastable since the images will be of size
        # (batch_size, 3, height, width)

    def forward(self,x):
        return (x-self.shift)/self.scale



# Utils

In [11]:
import os
import albumentations
import numpy as np
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset,DataLoader
import matplotlib.pyplot as plt


# Data utils

class ImagePaths(Dataset):
    def __init__(self,path,size=None):
        self.size = size

        self.images = [os.path.join(path,file) for file in os.listdir(path)]
        self._length = len(self.images)

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
        self.preprocessor = albumentations.Compose([self.rescaler,self.cropper])

    def __len__(self):
        return self._length
    
    def preprocess_image(self,image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image.convert("RGB")
    
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # the original image has colors 0-255 pixel values (8bit)
        # so dividing by 127.5 will make the values range from [0,2].
        # then dividing by -1 will give you a range of [-1,1]
        image = (image/127.5-1.0).astype(np.float32)
        # normal PIL library will make channels the last dimension
        # so we transpose to make the channel go first 
        image = image.transpose(2,0,1)

        return image
    
    def __getitem__(self, index):
        img = self.preprocess_image(self.images[index])
        return img 

def load_data(args):
    train_data = ImagePaths(args.dataset_path,size=256)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    return train_loader

In [12]:
#Module utils (encoder, decoder)

# for initializing the weights of certain classes properly
def weights_init(m):
    classname = m.__class__.__name__

    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif classname.find("BatchNorm")!=-1:
        nn.init.normal_(m.weight.data,mean=1.0,std=.02)
        nn.init.constant_(m.bias.data,0)

#plots images of the transformer stage
def plot_images(images):
    x = images["inputs"]
    reconstruction = images["reconstructions"]
    sample_half = images["samples_half"]
    sample_nopix = images["samples_nopix"]
    sample_det = images["samples_det"]
    
    fig, axarr = plt.subplots(1,5)
    axarr[0].imshow(x)
    axarr[1].imshow(reconstruction)
    axarr[2].imshow(sample_half)
    axarr[3].imshow(sample_nopix)
    axarr[4].imshow(sample_det)
    plt.show()


        

# Training the VQ-GAN

In [13]:
import os
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import utils as vutils

class TrainVQGAN:
    def __init__(self,args):
        self.vqgan = VQGAN(args).to(device=args.device)
        self.discriminator = Discriminator(args).to(device=args.device)
        self.discriminator.apply(weights_init)
        self.perceptual_loss = LPIPS().eval().to(device=args.device)
        self.opt_vq, self.opt_disc = self.configure_optimizers(args)

        self.prepare_training()

        self.train(args)

    def configure_optimizers(self,args):
        lr = args.learning_rate
        opt_vq = torch.optim.Adam(
            list(self.vqgan.encoder.parameters()) +
            list(self.vqgan.decoder.parameters()) +
            list(self.vqgan.codebook.parameters()) +
            list(self.vqgan.quant_conv.parameters()) +
            list(self.vqgan.post_quant_conv.parameters()) 
        )
        opt_disc = torch.optim.Adam(self.discriminator.parameters(),
                                    lr=lr,eps=1e-8,betas = (args.beta1,args.beta2))
        return opt_vq,opt_disc
    
    @staticmethod
    def prepare_training():
        os.makedirs("results", exist_ok=True)
        os.makedirs("checkpoints", exist_ok=True)

    def train(self,args):
        train_dataset = load_data(args)
        steps_per_epoch = len(train_dataset)
        for epoch in range(args.epochs):
            with tqdm(range(len(train_dataset))) as pbar:
                for i,imgs in zip(pbar,train_dataset):
                    imgs = imgs.to(device=args.device)
                    decoded_imgs, _ ,q_loss = self.vqgan(imgs)

                    disc_real = self.discriminator(imgs)
                    disc_fake = self.discriminator(decoded_imgs)

                    disc_factor = self.vqgan.adopt_weight(args.disc_factor, epoch*steps_per_epoch+i,threshold=args.disc_start)

                    perceptual_loss = self.perceptual_loss(imgs,decoded_imgs)
                    rec_loss = torch.abs(imgs-decoded_imgs)
                    perceptual_rec_loss = args.perceptual_loss_factor * perceptual_loss+ args.rec_loss_factor*rec_loss
                    perceptual_rec_loss = perceptual_rec_loss.mean()
                    g_loss = -torch.mean(disc_fake)


                    #$\lambda = \frac{\nabla_{G_L}[\mathcal{L}_{rec}]}{\nabla_{G_L}[\mathcal{L}_{GAN}]+\delta}$
                    lam = self.vqgan.calculate_lambda(perceptual_rec_loss,g_loss)
                    vq_loss = perceptual_rec_loss * q_loss+ disc_factor * lam * g_loss

                    d_loss_real = torch.mean(F.relu(1.0-disc_real)) #hinge loss
                    d_loss_fake = torch.mean(F.relu(1.0+disc_fake)) #hinge loss, inspired by svm's

                    gan_loss = disc_factor * 0.5*(d_loss_real + d_loss_fake)

                    self.opt_vq.zero_grad()
                    vq_loss.backward(retain_graph=True)

                    self.opt_disc.zero_grad()
                    gan_loss.backward()

                    self.opt_vq.step()
                    self.opt_disc.step()
                    
                    if i %10 ==0:
                        with torch.no_grad():
                            real_fake_images = torch.cat((imgs[:4],decoded_imgs.add(1).mul(.5)[:4]))
                            vutils.save_image(real_fake_images,os.path.join("results", f"{epoch}_{i}.jpg"),nrow=4)
                    pbar.set_postfix(
                        VQ_Loss=np.round(vq_loss.cpu().detach().numpy().item(),5),
                        GAN_Loss=np.round(gan_loss.cpu().detach().numpy().item(),3)
                    )
                    pbar.update(0)
                    torch.save(self.vqgan.state_dict(),os.path.join("checkpoints", f"vqgan_epoch_{epoch}.pt"))


In [14]:
import os
print(os.getcwd())
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
dataset_path = os.path.join(parent_dir, "atari/atari_v1/screens/revenge/1")
print(dataset_path)

/mnt/c/Users/avilu/OneDrive/Documenten/math/deep_learning/project/VQ-GAN_from_scratch
/mnt/c/Users/avilu/OneDrive/Documenten/math/deep_learning/project/atari/atari_v1/screens/revenge/1


In [15]:
import argparse
# Parser (from github)
parser = argparse.ArgumentParser(description="VQGAN")
parser.add_argument('--latent-dim', type=int, default=256, help='Latent dimension n_z (default: 256)')
parser.add_argument('--image-size', type=int, default=256, help='Image height and width (default: 256)')
parser.add_argument('--num-codebook-vectors', type=int, default=1024, help='Number of codebook vectors (default: 256)')
parser.add_argument('--beta', type=float, default=0.25, help='Commitment loss scalar (default: 0.25)')
parser.add_argument('--image-channels', type=int, default=3, help='Number of channels of images (default: 3)')
parser.add_argument('--dataset-path', type=str, default='/data', help='Path to data (default: /data)')
parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on')
parser.add_argument('--batch-size', type=int, default=2, help='Input batch size for training (default: 6)')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
parser.add_argument('--learning-rate', type=float, default=2.25e-05, help='Learning rate (default: 0.0002)')
parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta param (default: 0.5)')
parser.add_argument('--beta2', type=float, default=0.9, help='Adam beta param (default: 0.9)')
parser.add_argument('--disc-start', type=int, default=10000, help='When to start the discriminator (default: 10000)')
parser.add_argument('--disc-factor', type=float, default=1., help='')
parser.add_argument('--rec-loss-factor', type=float, default=1., help='Weighting factor for reconstruction loss.')
parser.add_argument('--perceptual-loss-factor', type=float, default=1., help='Weighting factor for perceptual loss.')

args, _ = parser.parse_known_args()
args.dataset_path = dataset_path

In [16]:
train_vq_gan = TrainVQGAN(args)

 14%|█▍        | 290/2018 [09:24<57:24,  1.99s/it, GAN_Loss=0, VQ_Loss=1e-5]    