In [55]:
import datasets
import transformers
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import torch.nn as nn

from collections import defaultdict
from collections.abc import (
    Callable,
    Iterable
)
import numpy as np

In [35]:
# Notebook-wide settings; tune these here rather than inside individual cells.
config = {
    # Seed handed to torch.manual_seed so runs are repeatable.
    "seed": 0,
    # Prefer the GPU, but fall back to the CPU so the notebook still runs on
    # machines without a usable CUDA device instead of failing later.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Floating-point precision intended for the feature tensors.
    "features_dtype": torch.float32,
}

In [36]:
# Seed PyTorch's global random number generator from the notebook config.
torch.manual_seed(seed=config["seed"])

<torch._C.Generator at 0x168115570>

In [None]:
# Read the serialized (features, labels) pair from disk.
# NOTE(review): torch.load unpickles the file -- only load files from a trusted source.
(chess_features, chess_labels) = torch.load(f='Sample Dataset.pt')

In [33]:
# Hold out 20% of the samples for validation; the fixed random_state keeps
# the split identical from run to run.
(
    features_train,
    features_valid,
    labels_train,
    labels_valid,
) = train_test_split(chess_features, chess_labels, test_size=0.2, random_state=42)

In [34]:
# Persist both halves of the split in a single file, keyed by role.
torch.save(
    obj=dict(
        train_features=features_train,
        train_labels=labels_train,
        valid_features=features_valid,
        valid_labels=labels_valid,
    ),
    f="preprocessed_train_valid_data",
)

torch.Size([5984, 8, 8, 9])
torch.Size([1497, 8, 8, 9])
torch.Size([5984, 2])


In [40]:
# Unpack the four axes of the training features; patchify below reads this
# same order as (images, height, width, channels).
(n, h, w, c) = features_train.shape
print(*(n, h, w, c))

5984 8 8 9


In [60]:
#patching the images
def patchify(images, n_patches):
    '''
    Split a batch of square, channels-last images into flattened patches.

    Every image is cut into an n_patches x n_patches grid of non-overlapping
    square tiles. Tiles are emitted in row-major grid order (tile (i, j) is
    stored at index i * n_patches + j) and each tile is flattened in
    (row, column, channel) order.

    Args:
        images: tensor of shape (n, h, w, c), where n is the number of images,
            h the height, w the width and c the number of channels. For the
            chess data in this notebook that is (n, 8, 8, 9).
        n_patches: number of tiles along each spatial axis.

    Returns:
        Tensor of shape (n, n_patches ** 2, (h // n_patches) ** 2 * c) that
        keeps the dtype and device of images.

    Raises:
        AssertionError: if the images are not square, or if the side length
            is not a multiple of n_patches.
    '''
    n, h, w, c = images.shape

    assert h == w, "Patchify method is implemented for square images only"
    # Without this check a bad n_patches only surfaces later as an opaque
    # shape-mismatch error.
    assert h % n_patches == 0, "Input shape not entirely divisible by number of patches"

    patch_size = h // n_patches

    # Split each spatial axis into (grid index, offset inside the tile):
    # (n, h, w, c) -> (n, grid_row, tile_row, grid_col, tile_col, c)
    tiles = images.reshape(n, n_patches, patch_size, n_patches, patch_size, c)
    # Put the two grid axes next to each other so one tile's values are adjacent:
    # -> (n, grid_row, grid_col, tile_row, tile_col, c)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5)
    # Merge the grid axes into a patch index and the rest into a feature vector.
    return tiles.reshape(n, n_patches ** 2, patch_size * patch_size * c)

In [None]:
# Debug aid: print the position and shape of every training image, one line each.
for idx in range(len(features_train)):
    image = features_train[idx]
    print(idx, image.shape)

In [56]:
#try 2 of patching: vectorized with Tensor.unfold. This redefines patchify from
#the earlier cell and produces the same (n, n_patches ** 2, patch_dim) layout.
def patchify(images, n_patches):
    '''
    Split a batch of square, channels-last images into flattened patches.

    Args:
        images: tensor of shape (n, h, w, c), where n is the number of images,
            h the height, w the width and c the number of channels. For the
            chess data in this notebook that is (n, 8, 8, 9).
        n_patches: number of patches along each spatial axis.

    Returns:
        Tensor of shape (n, n_patches ** 2, (h // n_patches) ** 2 * c).
        Patch (i, j) of the grid is stored at index i * n_patches + j and is
        flattened in (row, column, channel) order.

    Raises:
        AssertionError: if the images are not square, or if the side length
            is not a multiple of n_patches.
    '''
    n, h, w, c = images.shape

    assert h == w, "Patchify method is implemented for square images only"
    # unfold silently drops any remainder, so reject sizes that do not tile evenly.
    assert h % n_patches == 0, "Input shape not entirely divisible by number of patches"

    patch_size = h // n_patches

    # Non-overlapping windows: step by exactly one patch.
    stride = patch_size

    # The layout is (n, h, w, c), so the spatial axes are dims 1 and 2.
    # unfold appends the window axis at the end:
    #   after dim 1: (n, grid_row, w, c, tile_row)
    #   after dim 2: (n, grid_row, grid_col, c, tile_row, tile_col)
    patches = images.unfold(1, patch_size, stride).unfold(2, patch_size, stride)
    # Move channels last inside each tile: (n, grid_row, grid_col, tile_row, tile_col, c)
    patches = patches.permute(0, 1, 2, 4, 5, 3)
    # Merge the grid axes into a patch index and the rest into a feature vector.
    patches = patches.reshape(n, n_patches ** 2, patch_size * patch_size * c)

    return patches


In [63]:
class MyViT(nn.Module):
    """Vision-Transformer skeleton: stores its settings and patchifies inputs.

    Args:
        chw: input shape as a (C, H, W) tuple.
        n_patches: number of patches along each spatial axis; both H and W
            must be divisible by it.
    """

    def __init__(self, chw=(9, 8, 8), n_patches=4):
        super().__init__()

        # Keep the constructor arguments around for later stages.
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches

        # Index 1 is H and index 2 is W; each must split evenly into patches.
        for axis in (1, 2):
            assert chw[axis] % n_patches == 0, "Input shape not entirely divisible by number of patches"

    def forward(self, images):
        """Return the result of the module-level patchify for this batch."""
        return patchify(images, self.n_patches)

In [65]:
# Smoke test: build the model and push the whole training set through it.
model = MyViT(n_patches=4, chw=(9, 8, 8))
# Expected output: torch.Size([5984, 16, 36])
print(model(features_train).shape)

torch.Size([5984, 16, 36])
torch.Size([16, 36])
tensor([[0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1.,
         1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],
        [1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,