<h1>QGGAN</h1>
Implementation according to the paper <a href="https://arxiv.org/abs/2407.11913">Quantised Global Autoencoder: A Holistic Approach to Representing Visual Data</a><br/>
This version is meant for "production" use: it is built on Lightning and can optionally apply additional sharpening. It can do everything the other notebook can do, but it is more complicated and requires more libraries.

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import time
import torch.nn as nn
import math
import numpy as np
import platform
import sys
import os
import matplotlib.pyplot as plt
import torchvision
import lpips
import lightning as L

# Print tensors with 8 decimal places and without scientific notation.
torch.set_printoptions(precision=8, sci_mode=False)

In [None]:
def mkdir(path):
    """Create the directory ``path`` (including missing parents), best effort.

    An already existing directory is not an error. Filesystem errors
    (e.g. missing permissions) are ignored on purpose so that callers can
    keep treating this as a fire-and-forget helper; unlike a bare
    ``except:``, interpreter-level signals such as ``KeyboardInterrupt``
    are no longer swallowed.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        # Deliberately best effort: a later write into the folder will surface the problem.
        pass

These settings are meant to compress ImageNet into 256 tokens, with additional VQGAN-style sharpening (QGGAN).

In [None]:
class Settings:
    """Central configuration for this notebook (quantiser, networks, GAN and run options).

    All values are plain class attributes; further attributes (e.g. RUN_NAME)
    are assigned on the shared ``SETTINGS`` instance later in the file.
    """
    ### SETTINGS FOR QUANTISED AUTOENCODER ###
    VQVAE_D = 64 #dimensionality in which we cluster
    VQVAE_K = 512 #number of different codewords (i.e. number of clusters)
    VQVAE_C = 256 #number of codewords to describe one encoded item; i.e. how many TOKENS we end up with
    DS_SPEED_FACTOR = 3 #number of times we downscale our feature map
    
    ### SETTINGS FOR NETWORKS ###
    CHANNELS_MAXIMUM = 128 #maximum number of channels in our Unets
    CHANNELS_MINIMUM = 16 #minimum number of channels we output in Unet
    NUM_WORKERS = 10 ###CHANGE THIS TO 0 FOR WINDOWS...
    LR = None
    RES_BLOCKS = 2 #number of residual blocks in sequence
    DIM_L_EMBED = 16 #number of frequency octaves used by the sin/cos positional encodings
    NUM_HEADS = 128 #must divide VQVAE_C (checked by an assert further down)
    USE_DROPOUT = False #set to True if you want to order the latent space hierarchically, i.e. with the first codeword containing the most, the second codeword [...]

    NUM_TRAIN_BATCHES_UNTIL_RESET = 1000 #number of epochs until we reset unused codebook entries
    #if we use more items per batch, you can reduce this
    # NOTE(review): the name counts train *batches* while the comment above says epochs -- confirm which unit is meant.

    ### GAN SETTINGS ###
    USE_GAN     = True #use VQGAN-like sharpener ("QGGAN")
    WARMUP_ITS_GAN = 5000

    ### RUN SETTINGS ###
    JID = -1
    # The trailing index picks the dataset; [-3] currently selects "IMAGENET128".
    DATASET = ["CIFAR128", "CIFAR", "MNIST", "SVHN", "IMAGENET", "IMAGENET64", "IMAGENET128", "CELEB", "CELEB64"][-3]

    CLUSTERRUN = False

    RUN_NAME = "" #is set later automatically

    #get number of GPUs available
    # (evaluated, and printed, once when the class body is executed; 0 on CPU-only machines)
    NUM_GPUS = torch.cuda.device_count()
    print("CURRENT RUN IS USING ", NUM_GPUS, " GPUs")

    USE_CHECKPOINTS = True

# Shared settings object used by the rest of the file.
SETTINGS = Settings()

In [None]:
# Run identifier: <job id>_<dataset>_v5_best, plus a suffix per enabled option.
SETTINGS.RUN_NAME = "".join([
    str(SETTINGS.JID),
    "_",
    SETTINGS.DATASET,
    "_v5_best",
    "_gan" if SETTINGS.USE_GAN else "",
    "_dropout" if SETTINGS.USE_DROPOUT else "",
])
print("RUN NAME: ",SETTINGS.RUN_NAME)

In [None]:
# Make sure the shared output folder and the per-run sub-folder exist.
for output_dir in ("outputs/", "outputs/"+SETTINGS.RUN_NAME+"/"):
    mkdir(output_dir)

In [None]:
def load_svhn():
    """Load the SVHN train and test splits from ``.mat`` files as float tensors.

    The paths below are cluster-specific placeholders, so the function
    refuses to run until they have been adapted: update both paths and then
    remove the guard.

    Returns:
        ``(train, test)``: two float tensors scaled to [0, 1]. Assuming
        ``mat['X']`` is stored as (H, W, C, N) -- TODO confirm against your
        copy of the files -- each tensor is shaped (N, C, H, W).

    Raises:
        AssertionError: always, until the placeholder guard is removed.
    """
    print("Find this error & update all the paths below accordingly!")
    # Raised explicitly (instead of ``assert False``) so the guard also fires under ``python -O``.
    raise AssertionError("load_svhn(): dataset paths are placeholders; update them and remove this guard")

    import scipy.io

    def _load_split(mat_path):
        # Move the sample axis to the front, rescale bytes to [0, 1], then go channels-first.
        samples = scipy.io.loadmat(mat_path)['X']
        samples = np.moveaxis(samples, -1, 0)
        samples = samples/255.0
        return torch.tensor(samples).permute(0, 3, 1, 2).float()

    data_train = _load_split('/clusterarchive/ImageDatasets/SVHN/train_32x32.mat')
    # Fixed: the test split previously re-read the *training* file.
    data_test = _load_split('/clusterarchive/ImageDatasets/SVHN/test_32x32.mat')
    return data_train, data_test

class CustomDataset(Dataset):
    """Expose an in-memory, indexable collection as a dataset with dummy labels.

    Every sample is paired with a zero label so that the dataset yields
    ``(sample, label)`` tuples like the torchvision datasets used elsewhere
    in this file.
    """

    def __init__(self, data):
        self.data = data
        # One placeholder label (0.0) per sample.
        self.labels = torch.zeros(len(data))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

In [None]:
# Build `trainloader` / `testloader` for the dataset selected in SETTINGS.DATASET.
# Branches that read from '/clusterarchive/...' are guarded by `assert(False)`:
# adapt the paths to your machine and remove the guard before using them.
# NOTE(review): the CIFAR/MNIST/SVHN test loaders use shuffle=False, while the
# ImageNet/CelebA test loaders use shuffle=True -- confirm this is intended.
if SETTINGS.DATASET == "CIFAR":
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=SETTINGS.NUM_WORKERS)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=SETTINGS.NUM_WORKERS)
elif SETTINGS.DATASET == "CIFAR128":
    #upscaled CIFAR, useful for testing
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor(),
    ])
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=SETTINGS.NUM_WORKERS)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=SETTINGS.NUM_WORKERS)
elif SETTINGS.DATASET == "MNIST":
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=SETTINGS.NUM_WORKERS)
    testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=SETTINGS.NUM_WORKERS)
elif SETTINGS.DATASET == "SVHN":
    svhn_train, svhn_test = load_svhn()
    # NOTE: `transform` is built here but not passed to CustomDataset; the
    # tensors returned by load_svhn() are used as they are.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = CustomDataset(svhn_train)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=SETTINGS.NUM_WORKERS)
    testset = CustomDataset(svhn_test)
    testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=SETTINGS.NUM_WORKERS)
elif SETTINGS.DATASET == "IMAGENET":
    print("Loading ImageNet...")
    # Resize, then center-crop to 256x256.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    
    print("Find this error & update all the paths below accordingly!")
    assert(False)

    imagenet_data = ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/',  transform=transform)#ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/', transform=transform)
    imagenet_data_test = ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/val/',  transform=transform)#ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/', transform=transform)

    bsize = 4
    
    trainloader = DataLoader(imagenet_data, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    testloader = DataLoader(imagenet_data_test, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    print("Done loading ImageNet!")

elif SETTINGS.DATASET == "IMAGENET64":
    print("Loading ImageNet64...")
    # Resize, then center-crop to 64x64.
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ])
    
    print("Find this error & update all the paths below accordingly!")
    assert(False)

    imagenet_data = ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/',  transform=transform)#ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/', transform=transform)
    imagenet_data_test = ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/val/',  transform=transform)#ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/', transform=transform)

    bsize = 4*3
    
    trainloader = DataLoader(imagenet_data, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    testloader = DataLoader(imagenet_data_test, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    print("Done loading ImageNet!")
elif SETTINGS.DATASET == "IMAGENET128":
    print("Loading ImageNet...")
    # Resize, then center-crop to 128x128.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ])
    
    print("Find this error & update all the paths below accordingly!")
    assert(False)

    imagenet_data = ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/',  transform=transform)#ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/', transform=transform)
    imagenet_data_test = ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/val/',  transform=transform)#ImageFolder('/clusterarchive/ImageDatasets/imagenet/images/train/', transform=transform)

    bsize = 16
    
    trainloader = DataLoader(imagenet_data, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    testloader = DataLoader(imagenet_data_test, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    print("Done loading ImageNet!")

elif SETTINGS.DATASET == "CELEB":
    from PIL import Image
    class EmbedInBlackBox(object):
        """Transform that pastes a PIL image, shrunk to fit, onto the centre of a black size x size canvas."""
        def __init__(self, size=256):
            self.size = size

        def __call__(self, img):
            # Ensure img is a PIL Image
            if not isinstance(img, Image.Image):
                raise TypeError("Input should be a PIL Image")

            # Resize image while maintaining aspect ratio
            img.thumbnail((self.size, self.size))

            # Create a new black image
            new_img = Image.new("RGB", (self.size, self.size), (0, 0, 0))

            # Get dimensions
            # (despite their names, new_width/new_height are the left/top paste offsets)
            width, height = img.size
            new_width = (self.size - width) // 2
            new_height = (self.size - height) // 2

            # Paste the original image onto the center of the black image
            new_img.paste(img, (new_width, new_height))

            return new_img
    
    print("Find this error & update all the paths below accordingly!")
    assert(False)
    
    #celeb is a bit tricky... in our case, we separated train/test split into folders,
    #the original format is a bit of a mess 
    pre_path = "/clusterarchive/ImageDatasets/CelebA/"
        
    transform = transforms.Compose([
        EmbedInBlackBox(size=256),
        transforms.ToTensor(),
    ])
    
    # NOTE: the variables keep their `imagenet_*` names but hold the CelebA folders here.
    imagenet_data = ImageFolder(pre_path+'distributed_train/', transform=transform)
    imagenet_data_test = ImageFolder(pre_path+'distributed_test/', transform=transform)
    
    bsize = 4 * 2
    
    trainloader = DataLoader(imagenet_data, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    testloader = DataLoader(imagenet_data_test, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
elif SETTINGS.DATASET == "CELEB64":
    
    print("Find this error & update all the paths below accordingly!")
    assert(False)
    
    #celeb is a bit tricky... in our case, we separated train/test split into folders,
    #the original format is a bit of a mess 
    from PIL import Image
    pre_path = "/clusterarchive/ImageDatasets/CelebA/"
        
    # Resize, then center-crop to 64x64.
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ])
    
    # NOTE: the variables keep their `imagenet_*` names but hold the CelebA folders here.
    imagenet_data = ImageFolder(pre_path+'distributed_train/', transform=transform)
    imagenet_data_test = ImageFolder(pre_path+'distributed_test/', transform=transform)
    
    bsize = 4*3
    
    trainloader = DataLoader(imagenet_data, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
    testloader = DataLoader(imagenet_data_test, batch_size=bsize, shuffle=True, num_workers=SETTINGS.NUM_WORKERS, pin_memory=False)
else:
    # Unknown dataset name: stop here.
    print("INVALID DATASET CHOSEN!")
    assert(False)

def imshow(img):
    """Show an image tensor with matplotlib when running inside a notebook kernel.

    Outside an interactive (ipykernel) session this is a no-op. The tensor
    is clamped to [0, 1] and its first axis is moved last before plotting.
    """
    if 'ipykernel' not in sys.modules:
        return
    clamped = img.clamp(0.0, 1.0).numpy()
    channels_last = np.transpose(clamped, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

def save_image(img, name):
    """Write ``img`` to the file ``name`` via ``torchvision.utils.save_image``."""
    torchvision.utils.save_image(img, name)

# Sanity check on the first training batch only: report value range and shape,
# then show the batch as an image grid and store it in the run's output folder.
for images, labels in trainloader:
    print("IMAGE VALUE RANGES: ",images.min(), " to ",images.max())
    print("IMAGE SHAPE: ",images.shape)
    preview_grid = torchvision.utils.make_grid(images)
    imshow(preview_grid)
    save_image(preview_grid, "outputs/"+SETTINGS.RUN_NAME+"/test.png")
    break

In [None]:
def positional_encoding(x, DIM_L_EMBED):
    """Encode ``x`` with sin/cos features at ``DIM_L_EMBED`` doubling frequencies.

    For every octave ``i`` in ``range(DIM_L_EMBED)`` the terms
    ``sin(2**i * x)`` and ``cos(2**i * x)`` are computed (in that order) and
    all terms are concatenated along the last axis, which therefore grows by
    a factor of ``2 * DIM_L_EMBED``.
    """
    features = [
        fn((2. ** octave) * x)
        for octave in range(DIM_L_EMBED)
        for fn in (torch.sin, torch.cos)
    ]
    return torch.cat(features, -1)

In [None]:
# Sanity check: the token count per item must be an integer multiple of the head count.
assert SETTINGS.VQVAE_C % SETTINGS.NUM_HEADS == 0

In [None]:
# Loader length divided by the GPU count (presumably the batches each device
# sees per epoch -- verify against the trainer setup). NUM_GPUS is 0 on
# CPU-only machines, so clamp the divisor to 1 to avoid a ZeroDivisionError.
SETTINGS.TRAIN_BATCHES = len(trainloader) / max(SETTINGS.NUM_GPUS, 1)
SETTINGS.TEST_BATCHES = len(testloader) / max(SETTINGS.NUM_GPUS, 1)

In [None]:
# Read channel count and spatial size from the first training batch only.
for images, labels in trainloader:
    SETTINGS.INPUT_C, SETTINGS.INPUT_W = images.size()[1], images.size()[2]
    break
# The height is set equal to the value read from dim 2, i.e. inputs are treated as square.
SETTINGS.INPUT_H = SETTINGS.INPUT_W

In [None]:
class SmallResidualBlock(nn.Module):
    """Conditioned residual block, modelled on the "Imagen" residual block.

    The conditioning tensor is projected to ``c_in`` channels by a 1x1
    convolution and added to the input; the sum then runs through
    GroupNorm -> SiLU -> 3x3 conv -> GroupNorm -> SiLU -> 3x3 conv. A 1x1
    convolution of the *unconditioned* input forms the skip connection.

    Args:
        c_in: channels of the input ``x``.
        c_int: channels between the two 3x3 convolutions.
        c_out: channels of the output.
        res_dims: channels of the conditioning tensor.
        relu_slope: negative slope of the ``relu`` attribute (kept for
            compatibility; ``forward`` only uses the SiLU activation).
    """

    @staticmethod
    def _largest_group_count(channels):
        # GroupNorm needs a group count that divides the channel count:
        # start at 16 and step down until it does.
        groups = 16
        while channels % groups != 0:
            groups -= 1
        return groups

    def __init__(self, c_in, c_int, c_out, res_dims, relu_slope=0.01):
        super().__init__()

        groups_in = self._largest_group_count(c_in)
        groups_int = self._largest_group_count(c_int)
        assert(groups_in > 0 and groups_int > 0)
        self.norm1 = nn.GroupNorm(num_groups=groups_in, num_channels=c_in, eps=1e-05, affine=True)
        self.norm2 = nn.GroupNorm(num_groups=groups_int, num_channels=c_int, eps=1e-05, affine=True)

        # 1x1 projections: skip path and conditioning path (both without bias).
        self.conv_res = nn.Conv2d(kernel_size=1, in_channels=c_in, out_channels=c_out, stride=1, padding=0, bias=False)
        self.conv_res_conditional = nn.Conv2d(kernel_size=1, in_channels=res_dims, out_channels=c_in, stride=1, padding=0, bias=False)
        # 3x3 convolutions of the main path.
        self.conv1 = nn.Conv2d(kernel_size=3, in_channels=c_in, out_channels=c_int, stride=1, padding=1)
        self.conv2 = nn.Conv2d(kernel_size=3, in_channels=c_int, out_channels=c_out, stride=1, padding=1)

        self.relu = nn.LeakyReLU(negative_slope=relu_slope)
        self.activation = nn.SiLU(inplace=False)

    def forward(self, x, x_conditional_res):
        """Return the block output for input ``x`` conditioned on ``x_conditional_res``."""
        shortcut = self.conv_res(x)

        conditioned = x + self.conv_res_conditional(x_conditional_res)
        # Ordering according to Google's Imagen: normalise and activate before each convolution.
        hidden = self.activation(self.norm1(conditioned))
        hidden = self.conv1(hidden)
        hidden = self.activation(self.norm2(hidden))
        return self.conv2(hidden) + shortcut
    
class ResidualBlocks(nn.Module):
    """Chain of ``RES_BLOCKS`` conditioned ``SmallResidualBlock`` modules.

    The first block maps ``c_in`` channels to the working width ``c_int``,
    the last block maps to ``c_out``; blocks in between stay at ``c_int``.
    Every block receives the same conditioning tensor.
    """

    def __init__(self, c_in, c_int, c_out, res_dims, relu_slope=0.01, RES_BLOCKS=3):
        super().__init__()

        last_idx = RES_BLOCKS - 1
        stages = []
        for idx in range(RES_BLOCKS):
            stage_in = c_in if idx == 0 else c_int
            stage_out = c_out if idx == last_idx else c_int
            stages.append(SmallResidualBlock(stage_in, c_int, stage_out, res_dims=res_dims, relu_slope=relu_slope))
        self.blocks = nn.ModuleList(stages)

    def forward(self, x, x_conditional_res):
        """Run ``x`` through all blocks in order, conditioning each on ``x_conditional_res``."""
        for stage in self.blocks:
            x = stage(x, x_conditional_res)
        return x

class UNetBlock(nn.Module):
    """One level of a recursively built U-Net.

    A level consists of an input convolution, input residual blocks, an
    optional nested ``UNetBlock`` working at half the ``width_height``,
    output residual blocks and an output convolution. All residual blocks are
    conditioned on the ``residual`` tensor passed to ``forward``.

    ``skip_up_layers`` / ``skip_down_layers`` suppress the upsampling /
    downsampling step on the first levels (``current_layer`` below the given
    count), so that the network as a whole changes the spatial resolution.
    ``target_out`` (only honoured on level 0) fixes the output channel count.
    """
    def __init__(self, SETTINGS, c_in, c_int, width_height, NO_CHANNELS_MAX, NO_CHANNELS_MIN, res_dims, relu_slope=0.01, current_layer=0, skip_up_layers=0, skip_down_layers=0, target_out=None):
        super(UNetBlock, self).__init__()

        self.in_conv = nn.Conv2d(kernel_size=3, in_channels=c_in, out_channels=c_int, stride=1, padding=1)
        # channel count doubles per level, capped at NO_CHANNELS_MAX
        c_next = min(c_int * 2, NO_CHANNELS_MAX)
        self.in_res = ResidualBlocks(c_int, c_int, c_out=c_next, res_dims=res_dims, relu_slope=relu_slope, RES_BLOCKS=SETTINGS.RES_BLOCKS)

        self.width_height = width_height

        # Recurse while the map is larger than 8x8, or while skipped layers remain to be consumed.
        if self.width_height > 8 or (skip_up_layers > current_layer or current_layer < skip_down_layers): #don't go below 8x8, that makes no sense (except if we have to)
            self.unet = UNetBlock(SETTINGS, c_in=c_next, c_int=c_next, width_height=int(width_height/2), NO_CHANNELS_MAX=NO_CHANNELS_MAX, NO_CHANNELS_MIN=NO_CHANNELS_MIN, res_dims=res_dims, relu_slope=relu_slope, current_layer=current_layer+1, skip_up_layers=skip_up_layers, skip_down_layers=skip_down_layers)
            # nn.Upsample with scale_factor=0.5 is used as the downsampling operator
            self.downsample = nn.Upsample(scale_factor=0.5, mode='bilinear')
            self.downsample_cond = nn.Upsample(scale_factor=0.5, mode='nearest')
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

            # the output stage consumes whatever the nested level emits
            next_input = self.unet.out_conv.out_channels
        else:
            # innermost level: no nested block and no resampling modules
            self.unet = None
            next_input = c_next

        dim_out = c_int
        if skip_up_layers > current_layer:
            # no upsampling on this level: keep the nested level's channel count
            dim_out = self.unet.out_conv.out_channels
        
        if current_layer == 0 and target_out != None:
            dim_out = target_out
        
        self.out_res = ResidualBlocks(next_input, max(c_int, dim_out), c_out=dim_out, res_dims=res_dims, relu_slope=relu_slope, RES_BLOCKS=SETTINGS.RES_BLOCKS)
        self.out_conv = nn.Conv2d(kernel_size=3, in_channels=dim_out, out_channels=dim_out, stride=1, padding=1)
        # NOTE: full_residual is created but not used by forward() of this class.
        self.full_residual = nn.Conv2d(kernel_size=1, in_channels=c_in, out_channels=dim_out, stride=1, padding=0)
        
        # On levels that skip downsampling, the three output modules created
        # above are replaced by versions with a reduced channel count.
        if current_layer < skip_down_layers:
            #make sure we do SHRINK the number of parameters!
            if self.unet == None:
                prev_ch = c_next
            else:
                prev_ch = int(self.unet.out_conv.out_channels)
            layers_to_go = min(NO_CHANNELS_MIN * (2 ** current_layer), prev_ch)
            if current_layer == 0 and target_out != None:
                layers_to_go = target_out
            
            self.out_res = ResidualBlocks(next_input, next_input, c_out=layers_to_go, res_dims=res_dims, relu_slope=relu_slope, RES_BLOCKS=SETTINGS.RES_BLOCKS)
            self.out_conv = nn.Conv2d(kernel_size=3, in_channels=layers_to_go, out_channels=layers_to_go, stride=1, padding=1)
            self.full_residual = nn.Conv2d(kernel_size=1, in_channels=c_in, out_channels=layers_to_go, stride=1, padding=0)
        
        # NOTE: activation is created but not used by forward() of this class.
        self.activation = nn.SiLU(inplace=False)

        self.current_layer = current_layer
        self.skip_up_layers = skip_up_layers
        self.skip_down_layers = skip_down_layers

    def forward(self, x, residual):
        """Process ``x`` on this level (and, recursively, the nested ones), conditioned on ``residual``."""
        #input block
        x = self.in_conv(x)

        ###print("\t" * self.current_layer, "ENCODER - CURRENT SIZE: ", x.size()[2], "x", x.size()[3], " with ", x.size()[1] , " channels")
        #apply (multiple?) residual blocks
        x = self.in_res(x, residual)
        if x.size()[2] > 1 and self.unet != None:
            #downscale
            # (skipped on the first skip_down_layers levels)
            if self.skip_down_layers <= self.current_layer:
                x = self.downsample(x)
            # shrink the conditioning tensor until it matches x spatially
            ds_res = residual
            while x.size()[2] < ds_res.size()[2]:
                ds_res = self.downsample_cond(ds_res)
            #recursive
            x = self.unet(x, ds_res)
            #upscale
            # (skipped on the first skip_up_layers levels)
            if self.skip_up_layers <= self.current_layer:
                x = self.upsample(x)
                
        #output residual
        # Resample the conditioning tensor to the spatial size of x.
        # NOTE(review): upsample/downsample only exist when a nested block was
        # built; on the innermost level these loops must not trigger -- confirm.
        while x.size()[2] > residual.size()[2]:
            residual = self.upsample(residual)
        while x.size()[2] < residual.size()[2]:
            residual = self.downsample(residual)

        ###print("\t" * self.current_layer, "DECODER - CURRENT SIZE: ", x.size()[2], "x", x.size()[3], " with ", x.size()[1] , " channels")
        x = self.out_res(x, residual)
        #convolution
        x = self.out_conv(x)

        return x

class UNet(nn.Module):
    """U-Net wrapper that adds positional conditioning and a global skip path.

    The input is concatenated with a fixed sin/cos positional grid to form
    the conditioning tensor for the recursive ``UNetBlock`` body. The body
    output is added to a 3x3-convolved copy of the input (resampled to the
    body's output size if necessary) and projected to ``c_out`` channels.
    At most one of ``skip_up_layers`` / ``skip_down_layers`` may be non-zero.
    """

    def __init__(self, SETTINGS, c_in, c_int, c_out, NO_CHANNELS_MAX, NO_CHANNELS_MIN, width_height, skip_up_layers=0, skip_down_layers=0):
        super().__init__()

        assert(skip_up_layers == 0 or skip_down_layers == 0)

        self.SETTINGS = SETTINGS
        self.wh = width_height

        # Frozen positional grids for width_height, width_height/2, ... down to 1.
        self.position = nn.ParameterList()
        level_wh = width_height
        while level_wh >= 1:
            level_grid = self.grid_positional_encoding(level_wh)[None]
            self.position.append(nn.Parameter(level_grid, requires_grad=False))
            level_wh = int(level_wh / 2)

        # Conditioning channels: sin and cos of two coordinates per octave, plus the raw input.
        res_dims = 2 * 2 * self.SETTINGS.DIM_L_EMBED + c_in
        self.unet_blocks = UNetBlock(SETTINGS, c_in, c_int, width_height=width_height, NO_CHANNELS_MAX=NO_CHANNELS_MAX, NO_CHANNELS_MIN=NO_CHANNELS_MIN, res_dims=res_dims, skip_up_layers=skip_up_layers, skip_down_layers=skip_down_layers, target_out=c_out)

        self.downsample = nn.Upsample(scale_factor=0.5, mode='bilinear')
        self.upsample = nn.Upsample(scale_factor=2.0, mode='bilinear')

        body_channels = self.unet_blocks.out_conv.out_channels
        self.mega_residual = torch.nn.Conv2d(kernel_size=3, in_channels=c_in, out_channels=body_channels, stride=1, padding=1, bias=False)

        self.out_projection = nn.Conv2d(kernel_size=3, in_channels=body_channels, out_channels=c_out, stride=1, padding=1, bias=False)

    def grid_positional_encoding(self, width_height):
        """Return a ``(4 * DIM_L_EMBED, width_height, width_height)`` sin/cos encoding of a square grid.

        The grid holds the normalised coordinates ``index / width_height`` of
        both axes; each octave ``i`` contributes ``sin(2**i * grid)`` followed
        by ``cos(2**i * grid)``.
        """
        grid = torch.ones(width_height, width_height, 2)
        for row in range(width_height):
            grid[row, :, 0] = row / width_height
        for col in range(width_height):
            grid[:, col, 1] = col / width_height
        bands = [
            fn((2. ** octave) * grid)
            for octave in range(self.SETTINGS.DIM_L_EMBED)
            for fn in (torch.sin, torch.cos)
        ]
        # The encoded features sit on the last axis; move them to the front.
        return torch.cat(bands, -1).permute(2, 0, 1)

    def decode_indices(self, indices, codebook):
        """Turn codebook indices into embeddings.

        ``codebook`` is indexed along its second axis; for ``indices`` of
        shape ``(B, w, h)`` the result has shape ``(B, D, w, h)`` where ``D``
        is the size of the codebook's first axis.
        """
        batch, w, h = indices.size()[0], indices.size()[1], indices.size()[2]
        entries = codebook.transpose(0, 1)[indices.view(-1)]
        entries = entries.transpose(0, 1)
        return entries.view(entries.size()[0], batch, w, h).transpose(0, 1)

    def get_positions(self, width_height):
        """Return copies of the stored positional grids for ``width_height`` and all its halvings down to 1."""
        positions = []
        cur_wh = width_height
        while cur_wh >= 1:
            # Grids are stored by level relative to the full resolution self.wh.
            if cur_wh == self.wh:
                level = 0
            else:
                level = int(math.log2(self.wh)) - int(math.log2(cur_wh))
            positions.append(self.position[level].clone())
            cur_wh = int(cur_wh / 2)
        return positions

    def forward(self, x):
        """Apply the positionally conditioned U-Net to ``x``."""
        batch_size = x.size()[0]
        tiled_position = self.position[0].repeat(batch_size, 1, 1, 1)
        residual = torch.cat((x, tiled_position), 1)
        mega_res = self.mega_residual(x)
        out = self.unet_blocks(x, residual)

        # Bring the global skip path to the spatial size of the body output.
        while out.size()[2] > mega_res.size()[2]:
            mega_res = self.upsample(mega_res)
        while out.size()[2] < mega_res.size()[2]:
            mega_res = self.downsample(mega_res)
        return self.out_projection(out + mega_res)

In [None]:
class ConvEncoder(nn.Module):
    """Plain convolutional encoder: four 3x3 convolutions with LeakyReLU and two bilinear 0.5x resamplings.

    Maps ``c_in`` input channels to ``SETTINGS.VQVAE_C`` channels, using
    ``SETTINGS.CHANNELS_MAXIMUM`` channels in between (both read from the
    module-level ``SETTINGS``).
    """

    def __init__(self, c_in):
        super().__init__()
        self.relu = torch.nn.LeakyReLU()
        self.downsample = nn.Upsample(scale_factor=0.5, mode='bilinear')

        def conv3x3(n_in, n_out):
            return torch.nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=1, padding=1)

        # The activation and resampling modules are stateless and shared between stages.
        self.encoder = nn.Sequential(
            conv3x3(c_in, SETTINGS.CHANNELS_MAXIMUM),
            self.relu,
            self.downsample,
            conv3x3(SETTINGS.CHANNELS_MAXIMUM, SETTINGS.CHANNELS_MAXIMUM),
            self.relu,
            conv3x3(SETTINGS.CHANNELS_MAXIMUM, SETTINGS.CHANNELS_MAXIMUM),
            self.relu,
            self.downsample,
            conv3x3(SETTINGS.CHANNELS_MAXIMUM, SETTINGS.VQVAE_C),
            self.relu,
        )

    def forward(self, x):
        """Encode ``x``."""
        return self.encoder(x)

class ConvDecoder(nn.Module):
    """Plain convolutional decoder: four 3x3 convolutions with LeakyReLU and two bilinear 2x upsamplings.

    Maps ``SETTINGS.VQVAE_C`` channels to ``c_out`` channels, using
    ``SETTINGS.CHANNELS_MAXIMUM`` channels in between (both read from the
    module-level ``SETTINGS``).
    """

    def __init__(self, c_out):
        super().__init__()
        self.relu = torch.nn.LeakyReLU()
        self.upsample = nn.Upsample(scale_factor=2.0, mode='bilinear')

        def conv3x3(n_in, n_out):
            return torch.nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=1, padding=1)

        # The activation and resampling modules are stateless and shared between stages.
        self.decoder = nn.Sequential(
            conv3x3(SETTINGS.VQVAE_C, SETTINGS.CHANNELS_MAXIMUM),
            self.relu,
            conv3x3(SETTINGS.CHANNELS_MAXIMUM, SETTINGS.CHANNELS_MAXIMUM),
            self.relu,
            self.upsample,
            conv3x3(SETTINGS.CHANNELS_MAXIMUM, SETTINGS.CHANNELS_MAXIMUM),
            self.relu,
            self.upsample,
            conv3x3(SETTINGS.CHANNELS_MAXIMUM, c_out),
            self.relu,  # final activation only because we have the output values in [0, 1]
        )

    def forward(self, x):
        """Decode ``x``."""
        return self.decoder(x)

class Discriminator(nn.Module):
    """Patch-wise discriminator: residual block -> U-Net body -> residual block.

    Every stage is conditioned on the raw input, and the final block emits a
    single-channel map.
    """

    def __init__(self, SETTINGS, c_in, c_int, wh):
        super().__init__()

        # Options shared by all stages: conditioning on the raw input, LeakyReLU slope 0.2.
        shared = dict(res_dims=c_in, relu_slope=0.2)

        self.block_in = SmallResidualBlock(c_in, c_int, c_int, **shared)

        self.unet_blocks = UNetBlock(SETTINGS, c_int, c_int, width_height=wh, NO_CHANNELS_MAX=SETTINGS.CHANNELS_MAXIMUM, NO_CHANNELS_MIN=16, **shared)

        self.block_out = SmallResidualBlock(c_int, c_int, 1, **shared)

    def forward(self, x):
        """Return the single-channel (patch-wise) discriminator output for ``x``."""
        hidden = self.block_in(x, x)
        hidden = self.unet_blocks(hidden, x)
        # The closing block is again conditioned on the raw input.
        return self.block_out(hidden, x)
# Smoke test: build the discriminator for the configured input and push one
# random batch through it. The probe batch follows the dataset's channel count
# and resolution (previously hard-coded to 3 x 128 x 128, which fails for
# datasets with a different channel count), and runs without autograd
# because the result is discarded.
discriminator = Discriminator(SETTINGS, SETTINGS.INPUT_C, SETTINGS.CHANNELS_MINIMUM, SETTINGS.INPUT_W)
with torch.no_grad():
    discriminator(torch.rand(4, SETTINGS.INPUT_C, SETTINGS.INPUT_H, SETTINGS.INPUT_W))
print("NO PARAMS: ", sum(p.numel() for p in discriminator.parameters() if p.requires_grad))

In [None]:
def get_fullsize(x):
    """Return the number of elements per batch item of ``x``.

    This is the product of all dimensions after the first one and works for
    any rank >= 2 (the previous implementation was limited to ranks 2..5).

    Raises:
        ValueError: if ``x`` has fewer than two dimensions. Earlier versions
            returned an error *string* (1-D) or ``None`` (other ranks)
            instead, which only failed later at the call site.
    """
    dims = x.size()
    if len(dims) < 2:
        raise ValueError("ERROR: get_fullsize() called with tensor of size {}".format(tuple(dims)))
    total = 1
    for dim in dims[1:]:
        total *= dim
    return total

class BaseAE(nn.Module):
    """Quantised autoencoder: U-Net encoder -> per-head down projection -> quantisation -> up projection -> U-Net decoder.

    The encoder output is reshaped so that the ``VQVAE_C`` channels become the token axis;
    every token is projected to ``VQVAE_D`` dimensions and rounded to the nearest entry of
    its OWN codebook (``self.codebook`` has shape [C x K x D]).

    All hyper-parameters are read from the ``SETTINGS`` object handed to the constructor
    (``self.SETTINGS``), never from a module-level global.
    """

    def __init__(self, SETTINGS, c_in):
        """
        Args:
            SETTINGS: settings object; reads INPUT_C, INPUT_W, INPUT_H, VQVAE_C, VQVAE_K, VQVAE_D,
                DS_SPEED_FACTOR, NUM_HEADS, CHANNELS_MAXIMUM and CHANNELS_MINIMUM.
            c_in: number of channels fed to the encoder (and produced by the decoder).
        """
        super().__init__()

        self.SETTINGS = SETTINGS
        self.relu = torch.nn.LeakyReLU() #we use leaky relu as activation function, as it also has a gradient for an input < 0
        self.reduce = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.upscale = torch.nn.Upsample(scale_factor=2.0, mode='bilinear')

        #1x1 convolutions used as shortcut around encoder / decoder
        self.residual_encoder = torch.nn.Conv2d(kernel_size=1, in_channels=SETTINGS.INPUT_C, out_channels=SETTINGS.VQVAE_C, stride=1, padding=0)
        self.residual_decoder = torch.nn.Conv2d(kernel_size=1, in_channels=SETTINGS.VQVAE_C, out_channels=SETTINGS.INPUT_C, stride=1, padding=0)

        self.encoder = UNet(SETTINGS, c_in=c_in, c_int=SETTINGS.CHANNELS_MINIMUM, c_out=SETTINGS.VQVAE_C, NO_CHANNELS_MAX=SETTINGS.CHANNELS_MAXIMUM, NO_CHANNELS_MIN=SETTINGS.CHANNELS_MINIMUM, width_height=SETTINGS.INPUT_W, skip_up_layers=SETTINGS.DS_SPEED_FACTOR)

        #number of spatial positions left after downscaling DS_SPEED_FACTOR times in both directions
        downscale = 2 ** SETTINGS.DS_SPEED_FACTOR
        latent_positions = int(SETTINGS.INPUT_W*SETTINGS.INPUT_H/downscale/downscale)

        #create a list of projections (one down/up pair per head):
        self.down_proj = nn.ModuleList()
        self.up_proj   = nn.ModuleList()
        for _ in range(SETTINGS.NUM_HEADS):
            self.down_proj.append(torch.nn.Conv1d(in_channels=latent_positions, out_channels=SETTINGS.VQVAE_D, kernel_size=1, stride=1, padding=0))
            self.up_proj.append(torch.nn.Conv1d(in_channels=SETTINGS.VQVAE_D, out_channels=latent_positions, kernel_size=1, stride=1, padding=0))
        self.decoder = UNet(SETTINGS, c_in=SETTINGS.VQVAE_C, c_int=SETTINGS.CHANNELS_MAXIMUM, c_out=c_in, NO_CHANNELS_MAX=SETTINGS.CHANNELS_MAXIMUM, NO_CHANNELS_MIN=SETTINGS.CHANNELS_MINIMUM, width_height=int(SETTINGS.INPUT_W/downscale), skip_down_layers=SETTINGS.DS_SPEED_FACTOR)

        #one codebook per token, values initialised uniformly in [-0.5, 0.5)
        self.codebook = torch.nn.Parameter((torch.rand(SETTINGS.VQVAE_C, SETTINGS.VQVAE_K, SETTINGS.VQVAE_D) * 2.0 - 1.0) * 0.5, requires_grad=True)
        #global offset added to the quantised values in encode()
        self.offset = torch.nn.Parameter((torch.rand(SETTINGS.VQVAE_D, SETTINGS.VQVAE_C) * 2.0 - 1.0) * 0.5, requires_grad=True)

        self.pos = positional_encoding(torch.arange(0, SETTINGS.VQVAE_C)[:,None] / (SETTINGS.VQVAE_C+1), 10)[None,:,:].transpose(1,2)

    def quantise_individually(self, x):
        """Round every token to the nearest codeword of its own codebook.

        Args:
            x: tensor of size [b x D x C x 1].

        Returns:
            tuple ``(x, loss_codebook, loss_commitment, indices)``: ``x`` carries the rounded
            values but the gradient of the input (straight-through), the two losses are
            scalars and ``indices`` has size [b x C x 1].
        """
        #size  x in: [b x D x C x 1]
        #size cb in: [C x K x D]
        x_before_rounding = x.clone()

        B = x.size()[0]
        x = x.transpose(1,2)[:,:,:,0]            # -> [b x C x D]
        codebook = self.codebook.transpose(1,2)  # -> [C x D x K]

        #squared euclidean distance of every token to every codeword of its codebook: [b x C x K]
        dists = (x[..., None] - codebook[None]).square().sum(-2)
        idx = dists.argmin(-1)                   # -> [b x C]

        #look up the winning codeword per token
        codebook_expanded = codebook[None].expand(B, -1, -1, -1)
        idx_expanded = idx[..., None, None].expand(-1, -1, self.SETTINGS.VQVAE_D, -1)
        x_rounded = torch.gather(codebook_expanded, -1, idx_expanded).transpose(1,2) # -> [b x D x C x 1]

        loss_commitment = (x_before_rounding - x_rounded.detach()).square().mean() #only have gradient for encoder
        loss_codebook   = (x_before_rounding.detach() - x_rounded).square().mean() #only have gradient for codebook

        #straight-through estimator: values of x_rounded, gradient of x_before_rounding
        x = x_before_rounding - (x_before_rounding - x_rounded).detach()

        return x, loss_codebook, loss_commitment, idx.view(x.size()[0], x.size()[2], x.size()[3])

    def quantise(self, x):
        """Quantise against ONE shared codebook of size [K x D] (not called by encode()).

        NOTE(review): ``self.codebook`` is created as [C x K x D]; the view below only
        succeeds when it holds exactly K*D values (i.e. VQVAE_C == 1), otherwise it raises.

        Args:
            x: tensor of size [b x VQVAE_D x W x H].

        Returns:
            tuple ``(x, loss_codebook, loss_commitment, indices)`` with ``x`` in the input
            layout and ``indices`` of size [b x W x H].
        """
        #   1. move the VQVAE_D axis to the end ([b x W x H x VQVAE_D]) and merge all other axes, so every row is one vector to quantise
        x = x.permute(0, 2, 3, 1).contiguous()
        size_in = x.size()
        x = x.view(-1, self.SETTINGS.VQVAE_D)
        #flat [K x D] view; used for both the distance computation and the lookup so they index the same axis
        codebook = self.codebook.view(self.SETTINGS.VQVAE_K, self.SETTINGS.VQVAE_D)
        #   2. compare [N x 1 x D] to [1 x K x D]: square + sum over D gives the squared euclidean distance of every row to every codeword,
        #      the argmin over K is the index of the closest codeword
        with torch.no_grad():
            if self.SETTINGS.DATASET != "IMAGENET":
                indices = (x[:,None] - codebook[None]).square().sum(dim=2).argmin(dim=1)
            else:
                #process in 8 chunks to limit the size of the distance tensor; at least one row per chunk,
                #otherwise a chunk size of 0 would never advance
                chunksize = max(1, int(x.size()[0]/8))
                acc_indices = []
                for index_begin in range(0, x.size()[0], chunksize):
                    chunk = x[index_begin:index_begin + chunksize]
                    acc_indices.append((chunk[:,None] - codebook[None]).square().sum(dim=2).argmin(dim=1))
                indices = torch.cat(acc_indices)

        #   3. look up the closest codeword ("centroid") for every row
        x_rounded = codebook[indices]

        #   4. commitment loss (does the encoder produce values close to the codebook?) and codebook loss (is the codebook close to the encoder outputs?)
        #      the value is the same for both terms; the detach() makes sure each term only moves one side, so both can be weighted differently
        loss_commitment = (x - x_rounded.detach()).square().mean() #only have gradient for encoder
        loss_codebook   = (x.detach() - x_rounded).square().mean() #only have gradient for codebook

        #   5. straight-through estimator: "(x - x_rounded).detach()" is a constant without gradient, so x takes the values of x_rounded
        #      while keeping its own (then slightly inexact) gradient
        x = x - (x - x_rounded).detach()
        #   6. reshape back to the original layout
        x = x.view(size_in)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x, loss_codebook, loss_commitment, indices.view(x.size()[0], x.size()[2], x.size()[3])

    def encode(self, x):
        """Encode an image batch into quantised tokens.

        Returns:
            tuple ``(x, loss_codebook, loss_commitment, indices)`` where ``x`` is the quantised
            latent of size [b x D x C x 1] with ``self.offset`` added.
        """
        #1. encode
        res = self.residual_encoder(x)
        x = self.encoder(x)
        while x.size()[2] < res.size()[2]:
            res = self.reduce(res) #shrink the shortcut until it matches the encoder output
        x = x + res
        #2. re-shape and swap dimensions: [b x C x positions] -> [b x positions x C]
        x = x.view(x.size()[0], self.SETTINGS.VQVAE_C, -1)
        x = x.transpose(1,2)
        #3. project down: every head handles its own slice of the token axis
        w = x.size()[2]
        heads = self.SETTINGS.NUM_HEADS
        x = torch.cat([self.down_proj[i](x[:,:,int(w/heads*i):int(w/heads*(i+1))]) for i in range(heads)], 2)

        #4. quantise
        #X B4 Q:  torch.Size([4, 64, 256])
        x = x.contiguous().view(x.size()[0], self.SETTINGS.VQVAE_D, -1, 1) #unfold, here just for quantisation

        x, loss_codebook, loss_commitment, indices = self.quantise_individually(x)

        #don't add individual offsets, but use the global one
        x = x + self.offset[None, :, :, None]
        return x, loss_codebook, loss_commitment, indices

    def decode(self, x, w, h):
        """Decode a quantised latent back to an image.

        Args:
            x: latent that can be viewed as [b x D x C].
            w, h: size of the original input along dims 2 and 3 (as read in forward()).
        """
        x = x.view(x.size()[0], self.SETTINGS.VQVAE_D, -1) #fold back after quantisation; redundant inside forward(), needed when only the decoder is used
        #5. project up: every head handles its own block of tokens
        heads = self.SETTINGS.NUM_HEADS
        blocksize = int(x.size()[2]/heads)
        x = torch.cat([self.up_proj[i](x[:,:,(blocksize*i):(blocksize*(i+1))]) for i in range(heads)], 2)

        #6. reshape back
        x = x.transpose(1,2) #swap global information and local information back again (channels vs spatial dimensions)
        x = x.contiguous().view(x.size()[0], self.SETTINGS.VQVAE_C, int(w/(2 ** self.SETTINGS.DS_SPEED_FACTOR)), int(h/(2 ** self.SETTINGS.DS_SPEED_FACTOR)))
        #7. decode
        res = self.residual_decoder(x)
        x = self.decoder(x)
        while x.size()[2] > res.size()[2]:
            res = self.upscale(res) #grow the shortcut until it matches the decoder output
        x = x + res
        return x

    def forward(self, x, dropout_mask=None):
        """Encode, quantise and decode ``x``.

        Args:
            x: input batch; dims 2 and 3 are handed to decode() as ``w`` and ``h``.
            dropout_mask: optional [b x C] tensor multiplied onto the quantised tokens.

        Returns:
            tuple ``(reconstruction, loss_codebook, loss_commitment, indices)``.
        """
        #encode, then quantise, then decode; pass the codebook loss and commitment loss through so we can use them for training
        w, h = x.size()[2], x.size()[3]
        x, loss_codebook, loss_commitment, indices = self.encode(x)

        x = x.view(x.size()[0], self.SETTINGS.VQVAE_D, -1) #fold back after quantisation
        if dropout_mask is not None:
            x = x * dropout_mask[:,None,:]

        x = self.decode(x, w, h)

        return x, loss_codebook, loss_commitment, indices

In [None]:
class Parallel_GQAE(L.LightningModule):
    """Lightning training wrapper around the autoencoder (and, optionally, a discriminator).

    Without GAN (``SETTINGS.USE_GAN`` false) Lightning's automatic optimisation is used and
    the reconstruction loss is the MSE (PSNR during validation). With GAN, optimisation is
    manual: after ``SETTINGS.WARMUP_ITS_GAN`` training iterations the discriminator is
    trained on odd batch indices and the autoencoder on the remaining ones, using
    perceptual + L1 + adversarial loss.

    The accumulated absolute codebook gradients (``self.magnitudes``) are used every
    ``SETTINGS.NUM_TRAIN_BATCHES_UNTIL_RESET`` training iterations to re-initialise
    codewords that never received a gradient.
    """

    def backward(self, loss, *args, **kwargs):
        #plain backward pass; extra arguments Lightning may hand over are accepted and ignored
        loss.backward()

    def __init__(self, SETTINGS, ae, discriminator=None):
        """
        Args:
            SETTINGS: settings object shared with the autoencoder.
            ae: the autoencoder to train.
            discriminator: discriminator network; only stored when ``SETTINGS.USE_GAN`` is set.
        """
        super().__init__()

        self.ae = ae
        self.SETTINGS = SETTINGS
        self.APPLY_GAN = False #switched on once the GAN warm-up is over
        self.STEPS = 0
        self.unique_indices = set()

        #per token / per codeword statistics, kept on the CPU
        self.index_count = torch.zeros(self.SETTINGS.VQVAE_C, self.SETTINGS.VQVAE_K)
        self.magnitudes = torch.zeros(self.SETTINGS.VQVAE_C, self.SETTINGS.VQVAE_K)
        self.EPSILON = 0.000001
        self.rec_loss = [] #last (up to) 100 reconstruction losses for the running mean
        self.TOTAL_ITS = 0
        self.last_output = None #time of the last progress print
        self.best_test = None

        self.start_epoch = time.time()
        self.epoch_start_time = time.time() #overwritten at the start of every training epoch

        if self.SETTINGS.USE_GAN:
            self.automatic_optimization = False
            self.hinge = nn.ReLU()
            self.loss_fn_perceptual = lpips.LPIPS(net='vgg')
            self.discriminator = discriminator

    def print(self, *args, **kwargs):
        #only rank 0 prints
        if self.global_rank == 0:
            print(*args, **kwargs)

    def _accumulate_codebook_gradients(self):
        #track gradients of codewords to find out which ones to reset
        grad = self.ae.codebook.grad
        if grad is not None:
            self.magnitudes += grad.abs().mean(dim=2).detach().cpu()

    def on_after_backward(self):
        self._accumulate_codebook_gradients()

    def generic_step(self, batch, batchidx, test=False):
        """Shared implementation of training and validation step.

        Returns:
            the total loss, or ``None`` when the step already ran its optimiser itself
            (any training step while ``SETTINGS.USE_GAN`` is set).
        """
        data, _ = batch

        if not test:
            self.TOTAL_ITS += 1

        prefix = "test" if test else "train"

        if not self.APPLY_GAN and self.SETTINGS.USE_GAN and self.TOTAL_ITS > self.SETTINGS.WARMUP_ITS_GAN:
            self.APPLY_GAN = True
            self.print("--> STARTING TO APPLY GAN")

        if self.SETTINGS.USE_GAN:
            optimizer_ae, optimizer_dis = self.optimizers()
            optimizer_ae.zero_grad()

        ### discriminator ###
        if self.APPLY_GAN and batchidx % 2 == 1 and not test:
            optimizer_dis.zero_grad()

            self.ae.train(False)
            self.discriminator.train(True)

            outputs, _, _, _ = self.ae(data)

            out_real = self.discriminator(data)
            out_generated = self.discriminator(outputs)

            #hinge loss: real should be 1.0, generated should be -1.0
            loss_discriminator = 0.5 * (self.hinge(1.0 - out_real).mean() + self.hinge(1.0 + out_generated).mean())
            self.log(prefix+"_loss_disc", loss_discriminator.item(), on_epoch=True, sync_dist=False)

            loss_discriminator.backward()
            optimizer_dis.step()

            return

        ### generator ###
        self.ae.train(True)
        if self.SETTINGS.USE_GAN:
            self.discriminator.train(False)

        #for roughly a third of the items (drawn value 0), zero all tokens from a random position >= 1 onwards
        dropout_mask_binary = torch.randint(0, 3, (data.size()[0],))
        dropout_mask = torch.ones(data.size()[0], self.SETTINGS.VQVAE_C, device=data.device)
        if self.SETTINGS.USE_DROPOUT and not test:
            for i in range(dropout_mask.size()[0]):
                if dropout_mask_binary[i] == 0:
                    dropout_mask[i,torch.randint(1, self.SETTINGS.VQVAE_C, (1,)).item():] = 0.0
        else:
            dropout_mask = None

        outputs, loss_codebook, loss_commitment, indices = self.ae(data, dropout_mask)

        indices = indices[:,:,0]

        if not test:
            #count how often every codeword of every token was used: shape [C x K]
            onehot = torch.nn.functional.one_hot(indices.detach().cpu(), num_classes=self.SETTINGS.VQVAE_K)
            self.index_count += onehot.sum(dim=0)

        #crop:
        if self.SETTINGS.DATASET == "CELEB":
            data = data[:,:,19:(19+218),39:(39+178)]
            outputs = outputs[:,:,19:(19+218),39:(39+178)]

        #output images for the first batch of every epoch
        if self.global_rank == 0 and batchidx == 0 and not test:
            grid_in = torchvision.utils.make_grid(data.cpu().detach()).cpu().detach()
            grid_out = torchvision.utils.make_grid(outputs.cpu().detach()).cpu().detach()
            save_image(grid_in, "outputs/"+self.SETTINGS.RUN_NAME+"/epoch_"+str(self.current_epoch)+"_in.png")
            save_image(grid_out, "outputs/"+self.SETTINGS.RUN_NAME+"/epoch_"+str(self.current_epoch)+"_out.png")
            imshow(grid_in)
            imshow(grid_out)

        #count unique indices:
        self.unique_indices.update(torch.unique(indices).tolist())

        if self.APPLY_GAN:
            #use perceptual loss + l1 loss
            loss_reconstruction = self.loss_fn_perceptual(outputs, data).mean()
            loss_l1 = (outputs - data).abs().mean()
        elif test:
            #PSNR per item (-10 * log10(MSE)), averaged over the batch
            squared_error = (outputs - data).square()
            loss_reconstruction = (-10.0 * torch.log10(squared_error.mean(dim=3).mean(dim=2).mean(dim=1))).mean()
        else:
            #MSE
            loss_reconstruction = (outputs - data).square().mean()

        if self.APPLY_GAN and not test:
            loss_gan = -(self.discriminator(outputs)).mean()

            #weight the GAN loss by the ratio of the gradient norms at the decoder's last convolution
            #(EPSILON is added element-wise to the GAN gradient before the norm is taken)
            loss_rec_grads = torch.autograd.grad(loss_reconstruction, self.ae.decoder.unet_blocks.out_conv.weight, retain_graph=True)[0]
            loss_gan_grads = torch.autograd.grad(loss_gan, self.ae.decoder.unet_blocks.out_conv.weight, retain_graph=True)[0]
            gan_lambda = 0.75 * torch.norm(loss_rec_grads) / torch.norm(loss_gan_grads + self.EPSILON)
            gan_lambda = gan_lambda.detach()

            loss_reconstruction += loss_l1
        elif self.APPLY_GAN:
            gan_lambda = 0.0
            loss_gan = torch.zeros(1, device=data.device)

        loss_non_rec = (loss_codebook * 0.25 + loss_commitment)
        loss = loss_reconstruction + loss_non_rec
        if self.APPLY_GAN:
            loss += gan_lambda * loss_gan
            self.log(prefix+"_loss_gan", gan_lambda * loss_gan.item(), on_epoch=True, sync_dist=False)

        self.log(prefix+"_loss_rec", loss_reconstruction.item(), on_epoch=True, sync_dist=False)
        self.log(prefix+"_loss_code", loss_codebook.item(), on_epoch=True, sync_dist=False)
        self.log(prefix+"_loss_comm", loss_commitment.item(), on_epoch=True, sync_dist=False)
        self.rec_loss.append(loss_reconstruction.item())
        self.rec_loss = self.rec_loss[-100:]

        if not test:
            if self.TOTAL_ITS == 100 and self.global_rank == 0:
                time_taken = time.time() - self.epoch_start_time
                time_per_it = time_taken / self.TOTAL_ITS
                self.print("Projected total time for this epoch: ",((self.SETTINGS.TRAIN_BATCHES+self.SETTINGS.TEST_BATCHES) * time_per_it)," seconds")

            #progress report, at most every 20 seconds
            if self.global_rank == 0 and (self.last_output is None or (time.time() > self.last_output + 20.0)) and batchidx > 100:
                self.print("\t"+str(batchidx/self.SETTINGS.TRAIN_BATCHES*100)+"% done, current running mean: "+str(sum(self.rec_loss)/len(self.rec_loss)))
                self.print("\t--> "+str((time.time() - self.epoch_start_time) / (batchidx/self.SETTINGS.TRAIN_BATCHES) - (time.time() - self.epoch_start_time))+" seconds left until epoch is concluded")
                self.last_output = time.time()

        if self.SETTINGS.USE_GAN and not test:
            #manual optimisation: on_after_backward() is not triggered by a direct backward() call, so accumulate here
            loss.backward()
            self._accumulate_codebook_gradients()
            optimizer_ae.step()
            return

        return loss

    def validation_step(self, loaded_data, batchidx):
        return self.generic_step(loaded_data, batchidx, test=True)

    def _reset_unused_codewords(self):
        #re-initialise codewords without accumulated gradient next to the codeword with the largest accumulated gradient
        print("---------------------------------")
        print("---> RE-ALIGNING FREQUENCIES <---")
        print("---------------------------------")

        #use .device (get_device() returns -1 for CPU tensors, which .to() cannot handle)
        device = self.ae.codebook.device
        re_aligned_frequencies_total = []
        for frequency in range(0, self.SETTINGS.VQVAE_C):
            #1. while there exist codewords with magnitude = 0, re-initialise them
            re_aligned_frequencies = 0
            while True:
                #find a codeword to re-distribute
                unused_codeword = self.magnitudes[frequency].argmin()
                if self.magnitudes[frequency, unused_codeword] > 0:
                    break

                #find a codeword to split up
                replacement_codeword = self.magnitudes[frequency].argmax()
                if self.magnitudes[frequency, replacement_codeword] <= self.EPSILON:
                    break

                #place the unused codeword right next to the replacement codeword (tiny random displacement)
                direction = (torch.rand(self.SETTINGS.VQVAE_D) * 2.0 - 1.0) * 0.0001
                with torch.no_grad():
                    self.ae.codebook[frequency, unused_codeword] = self.ae.codebook[frequency, replacement_codeword] + direction.to(device)

                #set the replacement codeword/unused codeword to EPSILON to make sure that we don't use it again
                self.magnitudes[frequency, replacement_codeword] = self.EPSILON
                self.magnitudes[frequency, unused_codeword]      = self.EPSILON
                re_aligned_frequencies += 1
            re_aligned_frequencies_total.append(re_aligned_frequencies)

            #set back to zero (accumulate all over again until the next reset)
            self.magnitudes[frequency] *= 0.0
        print("Re-aligned ",sum(re_aligned_frequencies_total)/len(re_aligned_frequencies_total)," frequencies on average")

    def training_step(self, loaded_data, batchidx):
        #re-initialise the codebooks every NUM_TRAIN_BATCHES_UNTIL_RESET training iterations:
        reset_interval = self.SETTINGS.NUM_TRAIN_BATCHES_UNTIL_RESET
        if reset_interval is not None and self.TOTAL_ITS % reset_interval == 0 and self.TOTAL_ITS > 0:
            self._reset_unused_codewords()
        #now do the step itself
        return self.generic_step(loaded_data, batchidx, test=False)

    def on_train_epoch_start(self):
        self.print("\n\n*** STARTING EPOCH "+str(self.current_epoch)+" ***")
        self.index_count = self.index_count * 0
        self.unique_indices = set()
        self.epoch_start_time = time.time()

    def on_train_epoch_end(self):
        #eval everything (rank 0 only): print statistics, save checkpoints, stop after the last epoch
        if self.global_rank != 0:
            return

        self.print("*** DONE WITH EPOCH ",self.current_epoch," AFTER ",self.TOTAL_ITS," ITS ***")

        if self.SETTINGS.DATASET == "CELEB":
            its_goal = 70
        elif "IMAGENET" in self.SETTINGS.DATASET:
            its_goal = 10000
        else:
            its_goal = 100

        #extrapolate from the average time per finished epoch (current_epoch is 0-based)
        time_run = time.time() - self.start_epoch
        epochs_done = self.current_epoch + 1
        time_left = max(time_run / epochs_done * its_goal - time_run, 0.0)
        print("PROGNOSTICATED TIME LEFT FOR ",its_goal," ITS: ",time_left/60.0/60.0," HOURS")

        metrics = self.trainer.callback_metrics

        self.print("\tUNIQUE INDICES IN TRAINING: ",len(self.unique_indices))
        if self.SETTINGS.USE_INDIVIDUAL_CODEBOOKS:
            no_unique = (self.index_count > 0).sum(dim=1)
            print("\tNo. of elements in each frequency band, min/avg/max: ",no_unique.min().item(),"/",no_unique.float().mean().item(),"/",no_unique.max().item())
        self.print("\tTIME TAKEN: ",time.time() - self.epoch_start_time," SECONDS")
        self.print("\tLosses TRAIN:")
        self.print("\t\tRec: ",metrics["train_loss_rec"].item())
        self.print("\t\tComm: ",metrics["train_loss_comm"].item())
        self.print("\t\tCode: ",metrics["train_loss_code"].item())
        if self.APPLY_GAN:
            self.print("\t\t\tGAN Losses:")
            self.print("\t\t\t\tDisc: ",metrics["train_loss_disc"].item())
            self.print("\t\t\t\tGen: ",metrics["train_loss_gan"].item())

        if not self.SETTINGS.USE_GAN:
            test_rec = metrics["test_loss_rec"].item()
            self.print("\tLosses TEST:")
            self.print("\t\tRec / PSNR: ",test_rec)
            self.print("\t\tComm: ",metrics["test_loss_comm"].item())
            self.print("\t\tCode: ",metrics["test_loss_code"].item())
            #the test metric is a PSNR here, so larger is better
            if self.best_test is None or self.best_test < test_rec:
                print("--> FOUND NEW BEST TEST LOSS: ",test_rec)
                print("\t...saving...")
                self.best_test = test_rec
                torch.save(self.ae.state_dict(), "outputs/"+self.SETTINGS.RUN_NAME+"/best.net")

        if self.current_epoch == its_goal:
            print("--> SHUTTING DOWN")
            torch.save(self.ae.state_dict(), "outputs/"+self.SETTINGS.RUN_NAME+"/last.net")
            sys.exit(0)
        if self.APPLY_GAN and self.current_epoch % 5 == 0:
            print("--> SAVING DOWN")
            torch.save(self.ae.state_dict(), "outputs/"+self.SETTINGS.RUN_NAME+"/"+str(self.current_epoch)+".net")

    #for debugging purposes:
    #def on_after_backward(self) -> None:
    #    for name, p in self.ae.named_parameters():
    #        if p.grad is None and p.requires_grad:
    #            print(name)

    def configure_optimizers(self):
        """Create AdamW for the autoencoder and, with GAN, a second AdamW for the discriminator.

        The learning rate is 0.0002 unless ``SETTINGS.LR`` is set.
        """
        LR = 0.0002
        if self.SETTINGS.LR is not None:
            LR = self.SETTINGS.LR
        optim_ae = torch.optim.AdamW(self.ae.parameters(), lr=LR, weight_decay=0.01)

        if self.SETTINGS.USE_GAN:
            optim_dis = torch.optim.AdamW(self.discriminator.parameters(), lr=LR, weight_decay=0.01)
            self.discriminator.train(False)

            return optim_ae, optim_dis

        return optim_ae

def get_GQAE_parallel(SETTINGS):
    """Build autoencoder (+ discriminator when ``SETTINGS.USE_GAN``), train with Lightning, then exit the process.

    Uses the module-level ``trainloader`` / ``testloader`` as data sources.
    """
    on_imagenet = "IMAGENET" in SETTINGS.DATASET

    #don't train forever on ImageNet: limit the batches per epoch to 0.1; with GAN the validation limit is 0.0
    limit_train = None
    limit_val = None
    if on_imagenet:
        limit_train = 0.1
        limit_val = 0.0 if SETTINGS.USE_GAN else 0.1

    #number of epochs to run
    if SETTINGS.DATASET == "CELEB64":
        max_epochs = 70
    elif on_imagenet or SETTINGS.DATASET == "CELEB":
        max_epochs = 1000000
    else:
        max_epochs = 100
    max_epochs += 1

    ae = BaseAE(SETTINGS, c_in=SETTINGS.INPUT_C)
    discriminator = Discriminator(SETTINGS, SETTINGS.INPUT_C, 16, SETTINGS.INPUT_W) if SETTINGS.USE_GAN else None
    parallel_net = Parallel_GQAE(SETTINGS, ae, discriminator)

    #both run modes share every trainer option; cluster runs additionally pick the DDP strategy
    trainer_options = dict(
        limit_train_batches=limit_train,
        limit_val_batches=limit_val,
        max_epochs=max_epochs,
        devices=torch.cuda.device_count(),
        accelerator="gpu",
        enable_progress_bar=False,
        enable_checkpointing=SETTINGS.USE_CHECKPOINTS,
        logger=False,
    )
    if SETTINGS.CLUSTERRUN:
        trainer_options["strategy"] = "ddp_find_unused_parameters_true"
    trainer = L.Trainer(**trainer_options)

    trainer.fit(model=parallel_net, train_dataloaders=trainloader, val_dataloaders=testloader)

    print("--> DONE TRAINING! shutting down...")
    sys.exit(0)
get_GQAE_parallel(SETTINGS)