In [81]:
import datasets
import transformers
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import torch.nn as nn

from collections import defaultdict
from collections.abc import (
    Callable,
    Iterable
)
import numpy as np

In [82]:
config = {
    "seed": 0, 
    "device": "mps", 
    "features_dtype": torch.float32,
}

In [83]:
torch.manual_seed(config["seed"])

<torch._C.Generator at 0x168045730>

In [84]:
chess_features, chess_labels = torch.load('data/sample_dataset.pt')

  chess_features, chess_labels = torch.load('data/sample_dataset.pt')


In [85]:
features_train, features_valid, labels_train, labels_valid = train_test_split(
    chess_features, chess_labels, test_size=0.2, random_state=42
)

In [86]:
n, h, w, c = features_train.shape 
print(n,h,w,c)

5984 8 8 9


In [87]:
#patching the images
def patchify(images, n_patches):
    '''
    n is the number of images, 
    c is the number of channels, in our case it will be 9, 
    h is the height of the image and w is the width of the image, both be 8 in our case
    '''
    n, h, w, c = images.shape 

    assert h == w, "Patchify method is implemented for square images only"
    
    patches = torch.zeros(n, n_patches ** 2, h * w * c// n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size, :]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

In [88]:
# define function to get positional embeddings for a given sequence length and d, the current length of each element in the sequence
def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

Idea for two move: encode can_move element of d vector to 0 (restricting movement to only piece selected)

In [89]:
class MyViT(nn.Module):
  "Here we have initialization of the model and patching"
  def __init__(self, chw=(9, 8, 8), n_patches=4, hidden_layer_dim=18):
    # Super constructor
    super(MyViT, self).__init__()

    # Attributes
    self.chw = chw # (C, H, W)
    self.n_patches = n_patches
    self.hidden_layer_dim = hidden_layer_dim

    assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    
    self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)
    
    # mapping to a linear vector
    self.input_vector_dim = int(chw[0] * self.patch_size[0] * self.patch_size[1])
    self.linear_mapper = nn.Linear(self.input_vector_dim, self.hidden_layer_dim)
    
    # create a classification token
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_layer_dim))
    
    # create positional embeddings (reference from Vaswani et. al., 2017)
    self.pos_embed = nn.Parameter(torch.tensor(get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_layer_dim)))
    self.pos_embed.requires_grad = False # note: we don't want to change the positions during autograd
    
    
  def forward(self, images):
    # make patches and linearly map them to hidden layer size
    patches = patchify(images, self.n_patches)
    tokens = self.linear_mapper(patches)
    
    # add classification token to patches
    tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
    
    # add positional embedding to each of the dataset elements
    pos_embed = self.pos_embed.repeat(n, 1, 1)
    out = tokens + pos_embed
    
    return out

In [90]:
model = MyViT(
    chw=(9, 8, 8),
    n_patches=4
  )
print(model(features_train).shape) # torch.Size([5984, 16, 36])

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_layer_dim)))


torch.Size([5984, 17, 18])


In [91]:
torch.save(
    {
        "train_features": features_train,
        "train_labels": labels_train,
        "valid_features": features_valid,
        "valid_labels": labels_valid
    },
    "preprocessed_train_valid_data"
)