<h1>This is the demo file for the paper:<br/>
"Multidimensional Byte Pair Encoding: Shortened Sequences for Improved Visual Data Generation"</h1>
We will later publish the full version, along with the C++ version of the MDBPE algorithm.

<h1><span style="color: red;">This is just a small demo with FEW examples and LITTLE compression - it's tuned so it runs FAST and shows the principle, not such that it works optimally!</span></h1>

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
import math
import numpy as np
import torch.nn as nn
import random

<h1>0. Setup</h1>

In [None]:
#grid size: use a reduced version of MNIST for this demo - 12-by-12 is easier to work with
WIDTH, HEIGHT = 12, 12

#size of the base vocabulary
BASE_TOKENS = 256
#total vocabulary size: 256 = only use the base tokens, 256 + 32 --> use 32 extra tokens to compact stuff into (i.e. expand vocabulary by 32 entries)
MAXIMUM_NUMBER_OF_TOKENS = 256 + 0
NUM_TOKENS = MAXIMUM_NUMBER_OF_TOKENS

#only use a subset of the data (so that this runs quickly); set to "None" to use all data
#(takes a while, our algorithm in python is made for readability, i.e. has loops! c++ is lightning fast!)
LIMITED_EXAMPLES = 1000

#can be a bit of a memory hog, but speeds up training by loading all the data into the RAM
PRE_LOAD = True

<h3>Load data:</h3>

In [None]:
#prefer the GPU when one is available
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#convert to tensors, then resize to WIDTH-by-HEIGHT
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((WIDTH, HEIGHT)),
])

#training split first, test split second - both are downloaded into ./data if missing
mnist_trainset, mnist_testnset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

#unshuffled loaders, so the transcription below always sees the same order
train_loader, test_loader = (
    DataLoader(split, batch_size=8, shuffle=False)
    for split in (mnist_trainset, mnist_testnset)
)

def show(img):
    """Display a 3-dimensional image tensor as a greyscale plot.

    Values are clamped to [0, 1] and singleton dimensions are squeezed away
    before the tensor is handed to matplotlib.
    """
    assert img.dim() == 3
    clipped = img.clamp(0.0, 1.0)
    plt.imshow(clipped.squeeze(), cmap='gray')
    plt.show()

#sanity check: display the first image of the first training batch, then stop iterating
for img, label in train_loader:
    show(img[0])
    break

<h1>1. Transcribe</h1>
Translate dataset into hard drive files:<br/>
"transcribed_data" is the tensors of [uniqueIDs, classIDs]<br/>
"transcribed_data_og" are the ORIGINAL tensors, used to look up what actual base tokens a larger token is made up of

In [None]:
#start with a clean slate: remove the output folders of earlier runs if they exist
import os
import shutil

for stale_folder in ("transcribed_data", "transcribed_data_og", "tokensequences", "tokenshapes"):
    if os.path.exists(stale_folder):
        shutil.rmtree(stale_folder)

#create the two transcription folders
for base_folder in ('transcribed_data', 'transcribed_data_og'):
    os.makedirs(base_folder, exist_ok=True)

#number of subfolders, make sure this is something large for e.g. ImageNet
FILE_MODULO = 128
assert len(train_loader.dataset) / FILE_MODULO < 500

#create one subfolder per bucket in both transcription folders
for i in range(FILE_MODULO):
    for base_folder in ('transcribed_data', 'transcribed_data_og'):
        os.makedirs(base_folder + '/' + str(i), exist_ok=True)

In [None]:
#go over (first few) elements, then write to file as tokens:
#write both the token class (e.g. "what greyscale value")
#and the unique id (for now: just a unique ID per pixel, later on a unique ID per constellation of pixels made up from multiple tokens)
its = 0 #number of images written so far; also used as the file name
unique = None
for img, label in train_loader:
    #make sure we can save some unique ID per token: number the pixels in row-major order
    if unique is None:
        rows, cols = img.size()[2], img.size()[3]
        #the row stride has to be the number of COLUMNS, otherwise IDs collide on non-square grids
        unique = torch.arange(rows * cols).view(rows, cols).long()
    #turn into tokens (greyscale 0..255); later replace with VQ-VAE
    img = (img * 255.0).int()

    for k in range(0, img.size()[0]):
        bucket = str(its % FILE_MODULO)
        #[uniqueIDs, classIDs] - this file is rewritten by the merge passes
        torch.save([unique, img[k,0].clone()], "transcribed_data/"+bucket+"/"+str(its)+".dat")
        #untouched copy of the base tokens, used to look up what a merged token is made of
        torch.save(img[k,0].clone(), "transcribed_data_og/"+bucket+"/"+str(its)+".dat")

        its += 1
        if LIMITED_EXAMPLES is not None and its >= LIMITED_EXAMPLES:
            break

    #the inner break only leaves the batch - leave the loader loop as well
    if LIMITED_EXAMPLES is not None and its >= LIMITED_EXAMPLES:
        break

<h1>2. Start MDBPE</h1>

In [None]:
#dataset over everything that was written to the transcribed data folder
class UniqueDataset(torch.utils.data.Dataset):
    """Serves the [uniqueIDs, classIDs] files stored below "transcribed_data/<bucket>/"."""

    def __init__(self):
        #collect "<bucket>/<file name>" for every file found in every bucket folder
        self.files = [
            str(bucket) + "/" + name
            for bucket in range(FILE_MODULO)
            for _root, _dirs, names in os.walk("transcribed_data/" + str(bucket))
            for name in names
        ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        #returns the unique IDs, the class IDs and the path they were loaded from
        full_path = "transcribed_data/" + self.files[idx]
        unique, classes = torch.load(full_path, weights_only=False)
        return unique, classes, full_path

unique_train = UniqueDataset()
#only the first sample is needed, to learn the grid size: one random RGB colour per grid cell
for a, _, _ in unique_train:
    display_colours = torch.rand(3, a.size()[0] * a.size()[1]) #random colours for visualisation
    break

In [None]:
def visualise_unique(input):
    """Plot a 2-dimensional grid of unique token IDs, one colour per ID.

    Every ID is used as a column index into the global ``display_colours``
    table, so cells that share an ID are drawn in the same colour.
    """
    assert input.dim() == 2
    rows, cols = input.size()[0], input.size()[1]

    #look up one RGB triple per cell, then fold the flat list back into a [3 x n x n] image
    flat_colours = display_colours[:, input.view(-1)]
    img = flat_colours.view(3, rows, cols)

    #matplotlib wants the channel dimension last
    plt.imshow(img.permute(1, 2, 0))
    #make sure plt shows during the execution
    plt.show()

#IDs of merged tokens start here
START_OF_NEW_TOKENS = 256
#the ID the next merged token will receive
NEW_TOKEN_ID = START_OF_NEW_TOKENS

print("NEW TOKEN STARTS AT ", NEW_TOKEN_ID)

#how many merged tokens the MDBPE loop has to create
CODEWORDS_TO_ADD = MAXIMUM_NUMBER_OF_TOKENS - START_OF_NEW_TOKENS

In [None]:
#the anchor point of a token constellation is the first cell (in row-major order) that carries its unique ID,
#i.e. the most left cell in the upper most row of the constellation
def get_anchor(unique, pos_x, pos_y):
    """Return the [row, column] tensor of the anchor of the token covering (pos_x, pos_y)."""
    token_id = unique[pos_x][pos_y]
    #nonzero lists the matching coordinates in row-major order, so the first entry is the anchor
    occurrences = (unique == token_id).nonzero()
    return occurrences[0]

#MDBPE main loop: every iteration
#   1. counts how often each pair of neighbouring tokens (in a specific spatial constellation) occurs,
#   2. merges all occurrences of the most frequent pair into one new token,
#   3. records the merge as a rule, so new images can be encoded the same way later on.
CODEWORDS_ADDED = 0
token_path_sizes = {}

time_start = time.time()
last_output = time_start
rules = []
while NEW_TOKEN_ID < MAXIMUM_NUMBER_OF_TOKENS:
    #re-load dataset to make sure we always can merge stuff
    unique_train = UniqueDataset()
    rnd_subset = DataLoader(unique_train, batch_size=1, shuffle=True)
    
    #############################
    ########## STEP 1 ###########
    ## go over data, count BPs  #
    #############################

    #actual BPE:

    #pass 1: count occurences of every unique token constellation in both direction
    map_bpe = {}
    it = 0

    #go over training data in random order:
    for unique_, classes_, _ in rnd_subset:
        #remove the batch from the dataloader
        unique = unique_[0]
        classes = classes_[0]
        
        if time.time() - last_output > 60:
            print("\tDONE WITH ",it/ len(unique_train)*100,"% on counting pass")
            last_output = time.time()

        if it == 0:
            visualise_unique(unique)
            
        it += 1

        #go over every pixel coordinate, track which ocurs most often:
        used = {}
        for x in range(0, unique.size()[0]):
            for y in range(0, unique.size()[1]):
                #get the two possible alignments of our box, horizontal and vertical:
                for n in [(x+1,y), (x,y+1)]: 
                    n_x, n_y = n[0], n[1] #position of our neighbour token, i.e. for two neighbouring tokens: [x,y][n_x,n_y]
                    #skip if neighbour is out of bounds
                    if n_x >= unique.size()[0] or n_y >= unique.size()[1]:
                        continue

                    #skip if they're the same unique ID (=we can't merge a single large token with itself)
                    if unique[x][y] == unique[n_x][n_y]:
                        continue

                    #find the class ID (greyscale value/later larger token values) of the two elements under our byte pair mask
                    class_a, class_b = classes[x][y].item(), classes[n_x][n_y].item()
                    #find the anchor points to identify the constellation they're in
                    anchor_a, anchor_b = get_anchor(unique, x, y), get_anchor(unique, n_x, n_y)
                    
                    #only do every centre pair once - if we already looked at two unique tokens, we don't need to look at them again (for this one image):
                    #   two unique tokens at specific positions are only counted once
                    pair_key = (anchor_a[0].item(), anchor_a[1].item(), anchor_b[0].item(), anchor_b[1].item())
                    if pair_key in used:
                        continue
                    used[pair_key] = True

                    #compute vector from anchor to anchor to uniquely identify the constellation
                    v = (anchor_a - anchor_b)
                    
                    #count occurence of this constellation of features:
                    bpe_key = (class_a, class_b, v[0].item(), v[1].item())
                    map_bpe[bpe_key] = map_bpe.get(bpe_key, 0) + 1
        #only count on a random tenth of the data. This has to test the counter of THIS pass ("it");
        #the previous "its" was the stale image counter of the transcription cell and ended the pass after a single image
        if it >= len(rnd_subset) / 10:
            break
                    
    #############################
    # find most frequent pair ###
    #############################

    #(named max_count instead of "max" so the builtin max() is not shadowed for the rest of the notebook)
    max_count = 0
    max_key = None
    for key in map_bpe:
        #print(key, map_bpe[key])
        if map_bpe[key] > max_count:
            max_count = map_bpe[key]
            max_key = key
    #number of constellations that are tied for the highest count
    max_occ = 0
    for key in map_bpe:
        if map_bpe[key] == max_count:
            max_occ += 1

    #no mergeable pair was counted: fail with a clear message instead of a TypeError on max_key[0]
    if max_key is None:
        raise RuntimeError("MDBPE found no token pair left to merge - reduce MAXIMUM_NUMBER_OF_TOKENS")
            
    #save which tokens in what constellation we merge - these are basically the rules we need to alter "encode" new images: which tokens (0,1) in what constellation (2,3) are merged
    #we can use for new images: just go over this loop, but load from the rules instead of the original data and we're golden
    rules.append([max_key[0], max_key[1], max_key[2], max_key[3]])

    #############################
    ########## STEP 2 ###########
    ### replace most frequent ###
    ###### with new token #######
    #############################

    #go over the data a second time; this time: identify all occurences of the most frequent BPE
    #and replace them with a new token
    it = 0
    total_saved = 0
    for unique, classes, path in unique_train:
        if time.time() - last_output > 60:
            print("\tDONE WITH ",it/ len(unique_train)*100,"% on replacement pass")
            last_output = time.time()
        
        changed = False
        used_uniques = {}
        for x in range(0, unique.size()[0]):
            for y in range(0, unique.size()[1]):
                if unique[x][y].item() in used_uniques: #skip if we already replaced this unique ID
                    continue
                #get the neighbours
                for n in [(x+1,y), (x,y+1)]: 
                    #skip if neighbour is out of bounds
                    n_x, n_y = n[0], n[1]
                    if n_x >= unique.size()[0] or n_y >= unique.size()[1]:
                        continue
                    if unique[n_x][n_y].item() in used_uniques: #skip if we already replaced this neighbour unique ID (this token has already been merged this pass)
                        continue

                    #skip if they're the same unique ID
                    if unique[x][y] == unique[n_x][n_y]:
                        continue
                    
                    #find the token classes we merge, i.e. the greyscale values (intially) or the token IDs (later)
                    class_a, class_b = classes[x][y].item(), classes[n_x][n_y].item()
                    
                    #only if the classes fit, even further consider if we're at the right place to merge
                    if max_key[0] != class_a or max_key[1] != class_b:
                        continue
                    
                    #get the centres
                    anchor_a, anchor_b = get_anchor(unique, x, y), get_anchor(unique, n_x, n_y)
                    
                    v = (anchor_a - anchor_b)
                    vx = v[0].item()
                    vy = v[1].item()
                    
                    #check if V matches, i.e. check if we're EXACTLY at the right place to merge:
                    if max_key[2] == vx and max_key[3] == vy:
                        a_unique = unique[x][y]

                        #replace with new token:
                        #   replace unique ID of every part of the SECOND token with the first one;
                        #   the unique ID of the first one stays the same

                        #collect the cells of both tokens BEFORE relabelling anything
                        indices_to_relabel_a = torch.nonzero(unique == unique[x,y])
                        indices_to_relabel_b = torch.nonzero(unique == unique[n_x,n_y])
                        
                        #replace the B item, both classID and uniqueID
                        for i in range(0, indices_to_relabel_b.size()[0]):
                            #   replace unique ID of the B element with the A one
                            unique[indices_to_relabel_b[i][0]][indices_to_relabel_b[i][1]] = a_unique
                            #   replace class ID of the B element with the new one
                            classes[indices_to_relabel_b[i][0]][indices_to_relabel_b[i][1]] = NEW_TOKEN_ID
                        #replace the classID (=what kind of token is this) of the A item; leave the unique ID as is
                        for i in range(0, indices_to_relabel_a.size()[0]):
                            #   replace class ID of the A element with the new one
                            classes[indices_to_relabel_a[i][0]][indices_to_relabel_a[i][1]] = NEW_TOKEN_ID
                        changed = True
                        used_uniques[unique[x][y].item()] = True
        
        #save if we changed something - yes, we do always safe, because later, e.g. ImageNet is 1TB large, we can't keep it in memory
        if changed:
            total_saved += 1
            torch.save([unique, classes], path)

            #count the unique entries in unique:
            unique_entries = len(torch.unique(unique))
            #print("COMPRESSION AT THE END: ",unique_entries/(WIDTH*HEIGHT) * 100,"%")
        it += 1
    
    #increment token ID
    NEW_TOKEN_ID += 1
    #the chosen pair has to exist at least once in the full data
    assert(total_saved > 0)
    CODEWORDS_ADDED += 1
    
    time_so_far = time.time() - time_start

    #extrapolate the remaining time from the average time per codeword so far
    time_so_far = time_so_far / CODEWORDS_ADDED * (CODEWORDS_TO_ADD - CODEWORDS_ADDED)

    print("\n\n---> DONE ADDING CODEWORD",CODEWORDS_ADDED,"/",CODEWORDS_TO_ADD,", (POSSIBLY) REPEATING... Time left (est.): ",time_so_far/60," minutes\n\n")
torch.save(rules, "rules.dat")

In [None]:
#fresh dataset over the transcribed_data folder, used by the export cells below
unique_train = UniqueDataset()

<h1>3. Export token shapes & sequences</h1>

In [None]:
#create folder tokenshapes
if not os.path.exists('tokenshapes'):
    os.makedirs('tokenshapes')
    
#export token shapes:
#we store a grid in the form of:
#[ -1  0  2  4 -1]
#[  2  1  3  5  6]
#[  2 -1 -1 -1 -1]
#[ -2 -1 -1 -1 -1]
#--> -1 means "empty/not this token", everything else are the ORIGINAL, underlying base token IDs
#
#every file holds [shape grid, offset_x, offset_y]: the offsets lead from the anchor of the token (the position
#that is written into the token sequences) to the top-left corner of the bounding box the shape grid starts at.
#offset_y is appended as a THIRD entry, so readers that only look at entries 0 and 1 keep working.
exported = {} #class IDs whose shape has been written already
token_sizes = {} #class ID -> number of base tokens the token covers
for unique, classes, path in unique_train:
    path_v2 = path.replace("transcribed_data", "transcribed_data_og")
    tokens_original = torch.load(path_v2)
    total_size = 0
    uniques_exported = {}
    
    for x in range(0, unique.size()[0]):
        for y in range(0, unique.size()[1]):
            class_to_export = classes[x][y].item()
            if class_to_export not in exported:
                uid_to_export = unique[x][y].item()
                #find the bounding box corner of the token (smallest index in both dimensions; if we'd take the
                #anchor itself, like in the example in this cell, we'd cut stuff off!)
                indices = torch.nonzero(unique == uid_to_export)
                most_left = indices[:, 0].min().item()
                most_top = indices[:, 1].min().item()
                
                #difference between bounding box corner and anchor (x, y). The anchor is the first cell in
                #row-major order, so the token can extend to smaller indices in the SECOND dimension only -
                #without offset_y such tokens are shifted when they are put back into a grid
                offset_x = most_left - x
                offset_y = most_top - y
                
                crop_unique = unique.clone()[most_left:, most_top:]
                crop_classes = classes.clone()[most_left:, most_top:]
                crop_original = tokens_original.clone()[most_left:, most_top:]

                #remove everything that isn't the unique ID
                not_this_token = crop_unique != uid_to_export
                crop_classes[not_this_token] = -1
                crop_original[not_this_token] = -1
                
                #crop properly: drop trailing rows and columns that are completely empty
                while (crop_classes[-1] != -1).sum() == 0:
                    crop_classes = crop_classes[:-1]
                    crop_original = crop_original[:-1]
                while (crop_classes[:,-1] != -1).sum() == 0:
                    crop_classes = crop_classes[:,:-1]
                    crop_original = crop_original[:,:-1]
                #count elements != -1:
                token_size = (crop_original != -1).sum().item()
                token_sizes[class_to_export] = token_size
                total_size += token_size
                
                uniques_exported[uid_to_export] = True
                torch.save([crop_original, offset_x, offset_y], "tokenshapes/"+str(class_to_export)+".dat")
                exported[class_to_export] = True
            elif not unique[x][y].item() in uniques_exported: #track size of unique tokens that we exported
                total_size += token_sizes[class_to_export]
                uniques_exported[unique[x][y].item()] = True
    
    #sanity check: the tokens of one image have to cover the grid exactly once
    assert(total_size == WIDTH*HEIGHT)

In [None]:
#output folders for the token sequences: one subfolder per bucket
os.makedirs('tokensequences', exist_ok=True)

for k in range(0, FILE_MODULO):
    os.makedirs('tokensequences/' + str(k), exist_ok=True)

In [None]:
#export sequences with positional information:
#go over the data a third time, item by item, and export every image as a sequence of
# [token_id, x, y]

index = 0
sequence_lengths = []
for unique, classes, _ in unique_train:
    #go over every pixel coordinate, write out sequence:
    sequence = []
    used_ids = {}
    rows, cols = unique.size()[0], unique.size()[1]
    for x in range(rows):
        for y in range(cols):
            uid = unique[x][y].item()
            #the first time we meet a unique token we are at its anchor --> add it to the sequence
            if uid not in used_ids:
                used_ids[uid] = True
                sequence.append([classes[x][y].item(), x, y])
    if index == 0:
        visualise_unique(unique)
    sequence_lengths.append(len(sequence))
    if index % 100 == 0:
        avg_length = sum(sequence_lengths)/len(sequence_lengths)
        print("DONE EXPORTING ",index,"/",len(unique_train)," SEQUENCES, AVG LENGTH: ",avg_length,", compared to ",WIDTH*HEIGHT,", i.e. ",avg_length/(WIDTH*HEIGHT)*100,"% of the intial size")
    #sequences are spread over the bucket folders by their running index
    torch.save(sequence, "tokensequences/"+str(index % FILE_MODULO)+"/"+str(index)+".dat")
    index += 1

<h1>4. Testwise load a shortened sequence & render it</h1>

In [None]:
def resolve_token(grid, token_to_put, x, y):
    """Write the base tokens that make up ``token_to_put`` into ``grid``.

    grid:         2-dimensional tensor of base tokens, modified in place
    token_to_put: class ID of the token; IDs below START_OF_NEW_TOKENS are plain 1-by-1 tokens
    x, y:         anchor position of the token, as stored in the token sequences

    Returns the (modified) grid.
    """
    if token_to_put < START_OF_NEW_TOKENS:
        #traditional, 1-by-1 token - just write it and be done
        grid[x, y] = token_to_put
        return grid
    #load full token shape: [shape grid, offset_x] and - in newer exports - offset_y as third entry
    tokendata = torch.load("tokenshapes/"+str(token_to_put)+".dat")
    tokenshape = tokendata[0]
    offset_x = int(tokendata[1])
    #the shape grid starts at the bounding box corner, which can lie at a smaller second index than the
    #anchor; files without this entry are treated as before (no shift)
    offset_y = int(tokendata[2]) if len(tokendata) > 2 else 0
    #copy every cell that belongs to the token (-1 marks "not this token")
    for tx, ty in torch.nonzero(tokenshape != -1).tolist():
        grid[x + tx + offset_x, y + ty + offset_y] = tokenshape[tx, ty]
    return grid

def generate_sequence(width, height):
    """Rebuild a [width x height] grid of base tokens from the stored sequence "tokensequences/0/0.dat".

    The sequence entries [token, x, y] are put into the grid one after another
    until no cell is left empty; the filled grid is returned.
    """
    #-1 = not generated yet; tells us if a cell has been (implicitly) generated by an earlier token
    output_tokens = torch.full((width, height), -1, dtype=torch.long)

    #stand-in for a generative model: just replay a stored sequence
    dummy_sequence = torch.load("tokensequences/0/0.dat")

    #the first entry is always placed
    head = dummy_sequence[0]
    output_tokens = resolve_token(output_tokens, head[0], head[1], head[2])
    cursor = 1

    #as long as there's token to generate
    while (output_tokens == -1).any():
        #fish next token from our sequence list
        entry = dummy_sequence[cursor]
        next_token, pos_x, pos_y = entry[0], entry[1], entry[2]
        #a token may only start on a cell that nothing has been written to yet
        assert(output_tokens[pos_x, pos_y] == -1)

        #enter tokens into output token grid:
        output_tokens = resolve_token(output_tokens, next_token, pos_x, pos_y)
        cursor += 1

    #todo: decode with VQ-VAE
    return output_tokens

#render the rebuilt grid: add a leading dimension (show() expects 3 dimensions) and divide by 255 before show() clamps to [0, 1]
show(generate_sequence(WIDTH, HEIGHT)[None].float() / 255.0)

<h1>5. Train transformer</h1>

<h3>5.1 Hyper parameters & Setup</h3>

In [None]:
BATCH_SIZE = 16

#transformer hyper parameters:
DIMENSIONS = 512 #basically what channels are to a regular NN
NUM_HEADS = 8 #number of heads, i.e. we split the input into 8 parts and process them in parallel, with the same architecture. this is the "multi-head" part of the transformer
NUM_LAYERS = 8 #number of layers in the transformer, i.e. how many times we apply the same architecture to the input

#for the positional encoding to work, the number of dimensions must be divisible by 16: a quarter of the
#dimensions each goes to the position of the token and of the follow-up token, and every quarter is built from
#int(DIMENSIONS/16) frequencies x (sine, cosine) x (x, y) = 4 * int(DIMENSIONS/16) values.
#divisibility by 8 alone is not enough, the quarter encodings would not fill their slot
assert(DIMENSIONS % 16 == 0)
assert(DIMENSIONS % NUM_HEADS == 0) #the number of dimensions must be divisible by the number of heads: the dimensions are split between the heads

In [None]:
#special tokens get the first two IDs after the vocabulary
#(NUM_TOKENS is the COUNT of unique tokens, not the maximum value)
SOS_token = torch.tensor([NUM_TOKENS]) #we use this token to signal the start of a sequence
EOS_token = torch.tensor([NUM_TOKENS+1]) #we use this token to signal the end of a sequence

#shape data of every merged token, indexed by (token ID - BASE_TOKENS)
pre_loaded_tokens = []
for i in range(BASE_TOKENS, NUM_TOKENS):
    pre_loaded_tokens.append(torch.load("tokenshapes/"+str(i)+".dat", weights_only=False))

def grid_positional_encoding(embed_dims):
    """Return a sine/cosine positional encoding for every cell of the WIDTH-by-HEIGHT grid.

    embed_dims: number of frequencies (2**0 ... 2**(embed_dims-1)) that are used

    Returns a tensor of size [WIDTH, HEIGHT, 4 * embed_dims]: per frequency the
    sine of (x, y) followed by the cosine of (x, y), with x and y scaled to [0, 1).
    """
    #normalised coordinates: channel 0 holds x / WIDTH, channel 1 holds y / HEIGHT
    grid = torch.ones(WIDTH, HEIGHT, 2)
    for x in range(0, WIDTH):
        grid[x,:,0] = x / WIDTH
    for y in range(0, HEIGHT):
        grid[:,y,1] = y / HEIGHT
    rets = []
    for i in range(embed_dims):
        for fn in [torch.sin, torch.cos]:
            rets.append(fn((2. ** i) * grid))
    return torch.cat(rets, -1)

#the names say how much of the model dimension an encoding fills: DIMENSIONS/2 and DIMENSIONS/4 values per cell
positional_encodings_half = grid_positional_encoding(int(DIMENSIONS/8))
positional_encodings_quarter = grid_positional_encoding(int(DIMENSIONS/16))

In [None]:
#use additional encoding information (token shape & next token)
USE_NPE = True #encode the position of the NEXT token
USE_IPE = True #encode the integrated area (shape) of the token

def compute_positional_encoding(x, y, tokenClass, n_x, n_y):
    """Build the positional encoding (a vector with DIMENSIONS entries) of one token.

    x, y:       anchor position of the token
    tokenClass: class ID of the token; IDs of BASE_TOKENS and above are looked up in pre_loaded_tokens
    n_x, n_y:   anchor position of the next token in the sequence, or None for the last token

    The vector encodes three things, so the model can learn to best "fill in" the token:
    a) first quarter:  position of the token itself
    b) second quarter: position of the next token (stays 0 without a next token or without USE_NPE)
    c) second half:    the INTEGRATED AREA of all pieces belonging to the token (stays 0 without USE_IPE)
    """
    #a) position of the token itself
    positional_encoding = torch.zeros(DIMENSIONS)
    positional_encoding[0:int(DIMENSIONS/4)] = positional_encodings_quarter[x][y]
    #b) position of the next token - only relevant for every token that isn't the last one; only use for ours
    if n_x is not None and USE_NPE:
        positional_encoding[int(DIMENSIONS/4):int(DIMENSIONS/4*2)] = positional_encodings_quarter[n_x][n_y]
        
    #c) the INTEGRATED AREA of all pieces belonging to the token
    #   -look up token shape mask (i.e. all X-Y coordinates that belong to the token), clone it, and offset all values by token position + offset values
    #   -embed individual values
    #   -sum them up
    if tokenClass < BASE_TOKENS and USE_IPE:
        #a base token covers exactly one cell: just use the PE of that cell
        positional_encoding[int(DIMENSIONS/2):] = positional_encodings_half[x][y]
    elif USE_IPE:
        shape_data = pre_loaded_tokens[tokenClass-BASE_TOKENS]
        tokenshape = shape_data[0]
        offset_x = shape_data[1]
        #the shape grid starts at the bounding box corner of the token, which can lie at a smaller second
        #index than the anchor; shape files without a third entry are treated as before (no shift)
        offset_y = shape_data[2] if len(shape_data) > 2 else 0
        #get all indices that are not -1, i.e. all positions that are meaningful for this token:
        indices = torch.nonzero(tokenshape != -1)
        indices[:,0] += x + offset_x
        indices[:,1] += y + offset_y
        #put these indices into positional_encodings_half:
        x_indices = indices[:, 0].long()
        y_indices = indices[:, 1].long()
        full_area = positional_encodings_half[x_indices, y_indices, :]
        #now a tensor of size [no_elements X DIMENSIONS/2]: sum up all elements along the first dimension
        full_area = full_area.sum(0)
        positional_encoding[int(DIMENSIONS/2):] = full_area
    return positional_encoding

class UniqueDataset(torch.utils.data.Dataset):
    """Training dataset over the exported token sequences in "<path>tokensequences/<bucket>/".

    NOTE: this class replaces the UniqueDataset of the MDBPE cells (same name, different constructor).

    Every item is a tuple (sequence, positional_encoding, training_mask):
        sequence:            [WIDTH*HEIGHT + 1] token IDs: SOS, the tokens of the image, padded with EOS
        positional_encoding: [WIDTH*HEIGHT + 1, DIMENSIONS], one encoding per sequence entry
        training_mask:       [WIDTH*HEIGHT, NUM_TOKENS], currently all ones
    """

    def __init__(self, path):
        """Collect the sequence files below ``path`` and, if PRE_LOAD is set, tokenise all of them right away."""
        self.pre_path = path
        #at most this many sequences are used (None = use all), so the demo trains quickly
        LIMIT = 128
        self.files = self._collect_files(LIMIT)
            
        if PRE_LOAD: #pre-load all data & tokenise it, so we don't have to do it on the fly
            self.data = []
            start = time.time()
            
            for file in self.files:
                self.data.append(self.tokenise(torch.load(file, weights_only=False)))
                #progress report, at most every 30 seconds
                if time.time() - start > 30:
                    print("DONE WITH ", len(self.data)/len(self.files)*100, "%...")
                    start = time.time()

    def _collect_files(self, limit):
        """Return the paths of the sequence files in all bucket folders, at most ``limit`` of them."""
        found = []
        for i in range(FILE_MODULO):
            folder = self.pre_path + "tokensequences/" + str(i)
            for root, dirs, files in os.walk(folder):
                for file in files:
                    #check BEFORE adding, so we end up with exactly `limit` files (not limit + 1)
                    if limit is not None and len(found) >= limit:
                        return found
                    found.append(folder + "/" + file)
        return found

    def __len__(self):
        return len(self.files)
    
    def tokenise(self, data):
        """Turn a stored sequence of [token, x, y] entries into the tensors used for training."""
        #go over data:
        #   build positional encoding
        #   add SOS token
        #   add EOS tokens until we're at max length
        sequence = torch.ones(WIDTH * HEIGHT + 1) * EOS_token
        sequence[0] = SOS_token
        
        positional_encoding = torch.zeros(WIDTH * HEIGHT + 1, DIMENSIONS)
        #SOS/EOS token
        #don't get any extra information:
        #   -no positional encoding (0 everywhere); we don't know where the token is
        #   -no next token information (0 everywhere); next token is always at [0,0]
        #   -no token shape information (IPE); there's no shape to the token
        for index in range(0, len(data)):
            tokenClass = data[index][0]
            x, y = data[index][1], data[index][2]

            #position of the follow-up token, if there is one
            n_x, n_y = None, None
            if index+1 < len(data):
                n_x, n_y = data[index+1][1], data[index+1][2]
            #entry 0 is the SOS token, so everything is shifted by one
            positional_encoding[index+1] = compute_positional_encoding(x, y, tokenClass, n_x, n_y)
            sequence[index+1] = tokenClass
        
        #make sure to write out: what tokens are impossible to reach? (placeholder: everything is allowed)
        training_mask = torch.ones(WIDTH * HEIGHT, NUM_TOKENS)

        return sequence.long(), positional_encoding, training_mask

    def __getitem__(self, idx):
        if PRE_LOAD:
            data, position_grid, training_mask = self.data[idx]
        else:
            #tokenise on the fly
            data = torch.load(self.files[idx], weights_only=False)
            data, position_grid, training_mask = self.tokenise(data)
        
        return data, position_grid, training_mask

#build the training dataset; the empty prefix means "tokensequences/" is looked up relative to the working directory
dataset = UniqueDataset("")
#fetch the first sample once; data, pos and training_mask stay bound to it
for data, pos, training_mask in dataset:
    break

In [None]:
#worker processes for data loading: 8 everywhere except on windows (os.name == 'nt'), where we use 0 (windows things...)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0 if os.name == 'nt' else 8)

def show(img):
    """Display a 3-dimensional image tensor as a greyscale plot.

    Values are clamped to [0, 1] and singleton dimensions are squeezed away
    before the tensor is handed to matplotlib.
    """
    assert img.dim() == 3
    clipped = img.clamp(0.0, 1.0)
    plt.imshow(clipped.squeeze(), cmap='gray')
    plt.show()

<h3>5.2 The transformer itself</h3>

In [None]:
#this is the transformer model, which has the TransformerEncoder at its core (a little odd, but with masking, there is no difference between an encoder and a decoder)
#it also does a little other stuff, like embedding the tokens and adding the positional encoding, and applying a linear layer at the end to get the classification output
class TransformerDecoderModel(nn.Module):
    """Causally masked transformer that predicts the next token of a sequence.

    Pipeline: token embedding + positional embedding -> masked TransformerEncoder -> linear
    classifier producing NUM_TOKENS + 2 logits per position (all tokens plus the two start/end tokens).
    """

    def __init__(self):
        super().__init__()

        #NOTE(review): nn.TransformerEncoder is documented to work on copies of the layer it is given; if so,
        #the weights of self.layer itself are registered (and counted by parameters()) but never used in forward.
        #kept as an attribute so state_dict keys stay unchanged - confirm before relying on the printed parameter count
        self.layer = nn.TransformerEncoderLayer(d_model=DIMENSIONS, nhead=NUM_HEADS, dim_feedforward=DIMENSIONS, batch_first=True)
        self.transformer = torch.nn.TransformerEncoder(self.layer, num_layers=NUM_LAYERS)

        #train an embedding of the tokens along with the model: encoding greyscale scalar values directly is always harder to learn for a NN than some higher dimensional embedding
        self.embedding = nn.Embedding(NUM_TOKENS + 2, DIMENSIONS) #we have our greyscale values PLUS the start/end token

        #this turns our transformer output into (logits of) a probability distribution over the tokens
        #i.e. says "which token is most likely to come next"
        self.linear = nn.Linear(DIMENSIONS, NUM_TOKENS + 2)

    def get_target_mask(self, squence_length, device=None):
        """Build the additive causal attention mask of shape [squence_length x squence_length].

        Entry (i, j) is 0 where position i may attend to position j (j <= i) and -inf where j lies
        in the future (the softmax inside the transformer turns -inf into weight 0):
          [0., -inf, -inf],
          [0.,   0., -inf],
          [0.,   0.,   0.]
          etc.
        This is what lets us train on all prefixes of a sequence at once, instead of only
        predicting the k+1th token from a sequence of length k.

        squence_length: number of positions in the sequence (parameter name kept as-is for existing keyword callers)
        device: optional device to create the mask on directly; None means torch's default device
        returns: float32 tensor holding the mask
        """
        #start from "everything blocked", then triu(diagonal=1) zeroes the diagonal and everything below it
        blocked = torch.full((squence_length, squence_length), float('-inf'), dtype=torch.float32, device=device)
        return torch.triu(blocked, diagonal=1)

    def forward(self, tokens, pos_embedding):
        """Training pass: return next-token logits for every position except the last one.

        tokens: a FULL sequence of tokens per batch entry, starting with the SOS token, followed by the content tokens, followed by an EOS token
        pos_embedding: positional / next-token embedding that is added onto the token embeddings
        returns: logits over NUM_TOKENS + 2 classes for each of the first (sequence length - 1) positions
        """
        #1. embed tokens
        #2. add positional embeddings & next-token-embedding: one says where a token is, one says where the NEXT token from this one is going to be
        embedded = self.embedding(tokens) + pos_embedding
        #3. apply transformer to [b X seq X d], then apply the linear layer that acts as a classifier ("which token is next?")
        #note that we only apply the transformer to all but the last token, as we never need to predict what comes after the EOS token
        inputs = embedded[:,:-1]
        mask = self.get_target_mask(squence_length=inputs.size()[1], device=inputs.device)
        return self.linear(self.transformer(inputs, mask, is_causal=True))

    #TODO: add next to pos, then cut that off
    def predict(self, tokens, pos_embedding):
        """Inference pass: same as forward, but keeps the last position.

        We throw incomplete sequences in here (e.g. predicting the 4th token from the first 3),
        so the logits of the final position are exactly the ones the caller wants.
        """
        embedded = self.embedding(tokens) + pos_embedding
        mask = self.get_target_mask(squence_length=embedded.size()[1], device=embedded.device)
        return self.linear(self.transformer(embedded, mask, is_causal=True))

#create the model & test to throw some stuff in there
transformer_decoder = TransformerDecoderModel()
#report the total number of scalar entries over everything parameters() returns
print("Transformer has ", sum(p.numel() for p in transformer_decoder.parameters()), " parameters.")

#smoke test: run exactly one batch through the (untrained) model and print the output size, then stop
for data, pos, training_mask in train_loader:
    #test if everything goes through
    print(transformer_decoder(data, pos).size())
    break

<h3>5.3 Train the model</h3>

In [None]:
def resolve_token(grid, token_to_put, x, y):
    """Write ``token_to_put`` into ``grid`` at position (x, y) and return the grid.

    Tokens below BASE_TOKENS are traditional 1-by-1 tokens and are written directly. Any other
    token is expanded: its shape is loaded from "tokenshapes/<token>.dat" (entry 0 = shape grid,
    entry 1 = offset along x) and every shape cell that is not -1 is copied into the grid.
    The grid is modified in place.
    """
    if token_to_put >= BASE_TOKENS:
        #merged token: load its full shape, then put the original tokens it consists of into the grid
        stored = torch.load("tokenshapes/"+str(token_to_put)+".dat", weights_only=False)
        shape_grid = stored[0]
        shift_x = stored[1]
        for row in range(shape_grid.size()[0]):
            for col in range(shape_grid.size()[1]):
                cell = shape_grid[row, col]
                #-1 marks cells that are not part of the token shape
                if cell != -1:
                    grid[x + row + shift_x, y + col] = cell
    else:
        #traditional, 1-by-1 token - just write it and be done
        grid[x, y] = token_to_put
    return grid

def nucleus(data, threshold=0.9):
    """Sample one index per distribution using nucleus (top-p style) filtering.

    data: probabilities along the last dimension; any number of leading (batch) dimensions,
          including none (a single 1-D distribution)
    threshold: tokens (sorted by probability, most likely first) are kept while their cumulative
               probability is <= threshold; the most likely token is always kept
    returns: long tensor of sampled indices into the last dimension of ``data``, shaped like
             ``data`` without its last dimension
    """
    #sort all tokens by probability, then calculate the cumulative probability (i.e. the probability of the most likely token, the two most likely tokens, the three most likely tokens, etc.)
    sorted_probs, sorted_indices = torch.sort(data, descending=True, dim=-1)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    mask = cum_probs <= threshold #find the point where we have enough tokens to reach the threshold, i.e. we have the most likely tokens we want to sample from
    mask[..., 0] = True #make sure we always include at least the most likely token (works for any number of leading dimensions)
    filtered_probs = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs)) #set probabilities to zero for everything outside the nucleus
    #torch.multinomial only takes 1-D/2-D input: flatten the leading dimensions for sampling, then restore them
    flat_probs = filtered_probs.reshape(-1, filtered_probs.size(-1))
    sampled_indices = torch.multinomial(flat_probs, num_samples=1).reshape(*filtered_probs.shape[:-1], 1)
    selected_indices = sorted_indices.gather(dim=-1, index=sampled_indices) #get the original indices of the sampled tokens (the currently sampled ones are indices from the sorted list)
    return selected_indices.squeeze(-1)

def predict_next(cur_seq, pos_enc, x, y, strategy="nucleus"):
    """Pick the next token for the sequence generated so far and return it as a plain Python number.

    cur_seq / pos_enc: the current token sequence and its positional encoding, passed to transformer_decoder.predict
    x, y: grid position the new token will be placed at (used to rule out tokens that would not fit)
    strategy: "max" takes the most likely token, "nucleus" samples via nucleus() with threshold 0.9,
              anything else samples from the full (filtered) distribution
    Only the first batch entry is filtered, and the final .item() requires a single result.
    """
    #only the prediction at the last position matters; softmax turns its logits into probabilities
    logits = transformer_decoder.predict(cur_seq, pos_enc)
    probs = torch.nn.functional.softmax(logits[:,-1:].view(-1, NUM_TOKENS + 2), -1)

    #clamp SOS/EOS to zero:
    for special in (SOS_token, EOS_token):
        probs[0,special.item()] = 0.0

    #only allow merged tokens that fit the space: every occupied cell (!= -1) of the shape has to land inside the grid
    for token_id in range(BASE_TOKENS, NUM_TOKENS):
        stored = pre_loaded_tokens[token_id-BASE_TOKENS]
        shape, offset_x = stored[0], stored[1]
        sticks_out = any(
            shape[tx, ty] != -1
            and not (0 <= x + tx + offset_x < WIDTH and 0 <= y + ty < HEIGHT)
            for tx in range(shape.size()[0])
            for ty in range(shape.size()[1])
        )
        if sticks_out:
            probs[0,token_id] = 0.0

    #re-normalise probs to add up to 1:
    probs = probs / probs.sum(dim=-1, keepdim=True)

    if strategy == "max":
        choice = probs.argmax(dim=-1)
    elif strategy == "nucleus":
        choice = nucleus(probs, threshold=0.9)[:,None]
    else: #default = multinomial
        choice = torch.multinomial(probs, num_samples=1)

    return choice.item()

#copy of the SOS token on the compute device (DEVICE); generate_sequence starts every sequence from it
SOS_token_gpu = SOS_token.to(DEVICE)
def generate_sequence(width, height):
    """Generate one token grid of shape [width x height] autoregressively with the transformer.

    Starting from the SOS token, the next token is sampled with predict_next, written into the
    grid with resolve_token, and generation continues at the first grid cell that is still empty,
    until no empty cell is left.

    width, height: size of the grid to fill. predict_next checks token fit against the global
                   WIDTH/HEIGHT, so these should match the globals.
    returns: long tensor [width x height] of base tokens
    """
    #make sure we have a shape that tells us if a token has been (implicitly) generated
    #-1 = not generated yet
    output_tokens = torch.full((width, height), -1, dtype=torch.long)

    sequence_x, sequence_y = 0, 0 #current position in generative process

    current_sequence = torch.ones(1, 1).long().to(DEVICE) * SOS_token_gpu #start with SOS token
    #one row for SOS plus at most one row per grid cell
    positional_encoding = torch.zeros(width * height + 1, DIMENSIONS, device=DEVICE)

    #as long as there's tokens to generate:
    generated = 0
    while (output_tokens == -1).any():
        #the positional encoding of every token so far also says where the NEXT token goes, which gives the net an easier time to estimate where we are
        token = predict_next(current_sequence, positional_encoding[0:current_sequence.size()[1]], sequence_x, sequence_y)
        current_sequence = torch.cat([current_sequence, torch.tensor([[token]], device=DEVICE)], 1)

        #enter tokens into output token grid:
        output_tokens = resolve_token(output_tokens, token, sequence_x, sequence_y)

        #find new X/Y position to put the token by finding the first occurrence of -1 in the output tokens
        sequence_index = (output_tokens == -1).long().argmax().item()
        #the grid is [width x height], so the flat index is x * height + y (the row stride is the height, not the width)
        pos_x, pos_y = divmod(sequence_index, height)

        #prepare positional encoding for the predicted token: where is the token, what is the token (for shape), and where is the next token
        generated += 1
        positional_encoding[generated] = compute_positional_encoding(sequence_x, sequence_y, token, pos_x, pos_y)

        sequence_x, sequence_y = pos_x, pos_y

    #afterwards: decode with VQ-VAE (if not in MNIST case)
    return output_tokens

In [None]:
from tqdm import tqdm

#fresh model for the actual training run, moved to the compute device (this replaces any model created earlier under the same name)
transformer_decoder = TransformerDecoderModel().to(DEVICE)
print("Transformer has ", sum(p.numel() for p in transformer_decoder.parameters()), " parameters.")
#the optimiser updates everything returned by parameters(), with a learning rate of 1e-4
optimiser = torch.optim.AdamW(transformer_decoder.parameters(), lr=0.0001) #use AdamW here so you don't have to run schedulefree for this demo

In [None]:
#element-wise binary cross-entropy on the logits; reduction='none' keeps one loss value per (batch, position, token) so we can mask before averaging
loss_function = torch.nn.BCEWithLogitsLoss(reduction='none')
losses = [] #mean per-batch training loss of each epoch
epochs = [] #matching epoch indices
for epoch in range(0, 1000):
    transformer_decoder.train()

    total_loss = 0
    batches_seen = 0 #number of optimisation steps in this epoch, used to average the loss correctly
    print("\t\ttraining...")
    for i in range(0, 100): #compensate for so few examples - else, we only go for like 10 batches^^
        for data, pos, training_mask in train_loader:
            #move the batch (token sequences, positional encodings, masks) to the compute device
            data, pos, training_mask = data.to(DEVICE), pos.to(DEVICE), training_mask.to(DEVICE)

            #predict the next token for the whole sequence
            #i.e. for an input of [b x s] many tokens, we get [b x s x num_tokens] many logits (for each sub-sequence, the probability distribution over the tokens),
            #i.e. we predict the next token for [SOS], for [SOS, 0], for [SOS, 0, 255], etc.
            output = transformer_decoder(data, pos)
            #turn our target (=our input) that is currently a [b x s] many tokens into a one-hot encoded tensor of shape [b x s x num_tokens]
            #we always take the tokens shifted to the right ([1:]), because we want to predict the next token for each token in the sequence
            #i.e. predict the 3rd token given the first two, predict the 4th token given the first three, etc.
            target = torch.nn.functional.one_hot(data[:,1:], num_classes=NUM_TOKENS+2).float()
            loss = loss_function(output, target) #apply binary crossentropy
            #mask out everything we won't generate later on anyway; don't waste resources on learning that a token doesn't fit!
            loss[:,:,:-2] = loss[:,:,:-2] * training_mask #mask out impossible tokens
            loss[:,:,-2:] = 0.0 #mask out loss for SOS and EOS tokens
            loss = loss.mean() #reduce to scalar

            #do the optimisation step
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            total_loss += loss.detach().item()
            batches_seen += 1

    #average over ALL batches processed in this epoch (the loader is traversed 100 times, so dividing by
    #len(train_loader) alone would report a value 100x too large); max() guards against an empty loader
    mean_loss = total_loss / max(batches_seen, 1)
    print("\tTRAIN LOSS: ",mean_loss)
    losses.append(mean_loss)
    epochs.append(epoch)

    print("***DONE WITH EPOCH ",epoch," ***")

    if WIDTH == 12 and HEIGHT == 12: #MNIST, just show something
        print("\tOutputs after epoch",epoch,":")
        with torch.no_grad():
            #switch off training mode while sampling, and switch it back on afterwards
            transformer_decoder.train(False)
            for i in range(0, 3):
                output = generate_sequence(WIDTH, HEIGHT)
                #tokens are greyscale values 0..255 - scale to [0, 1] for plotting
                show(output.view(1, WIDTH, HEIGHT) / 255.0)
            transformer_decoder.train(True)