In [None]:
'''
Steps:
    1) Divide imput image into patches
    2) Flatten the patches to be handles as tokens
    3) Add classification token at the beginning
    4) Feed tokens into the encoder
    5) Number of output of the encoder is equal to the size of the input
    6) 
'''

In [54]:
import torch
import torch.nn as nn 
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import numpy as np
from torchsummary import summary
from PIL import Image
import torch.nn.functional as F
from torchvision.datasets import FakeData # Test
from torchvision.transforms import v2

In [55]:
# TODO: LOAD OUR DATA 256X256

In [99]:
# https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
# Paper: https://arxiv.org/abs/1706.03762 
# Tutorial: https://www.akshaymakes.com/blogs/vision-transformer

# Hyperparameters
d_model = 256 # Number of features in each patch
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6 
num_classes = 2 
norm_first = False 
patch_size = 16 
image_size = 256 # Square it for true size 
activation = 'relu'
dim_feedforward = 2048
num_channels = 3

class VisionTransformerCoinImageClassifier(nn.Module):
    def __init__(
        self, 
        d_model, 
        nhead, 
        num_encoder_layers, 
        num_decoder_layers, 
        dim_feedforward, 
        activation, 
        norm_first, 
        patch_size, 
        image_size, 
        num_classes,
    ):
        super(VisionTransformerCoinImageClassifier, self).__init__()

        # Calculate the number of patches
        num_patches = (image_size // patch_size) ** 2 # **2 = ^2 due to X(W) and Y(H) position considered

        # Path embedding layer - Sliding window - Non-overlapping image patches
        self.patch_embedding = nn.Conv2d(
            in_channels=num_channels, 
            out_channels=num_patches, 
            kernel_size=patch_size, 
            stride=patch_size
        )
        
        '''
        # Positional Encoding - Fixed positional encoding
        self.positional_encoding = self._generate_positional_encoding(256) # https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
        '''
        
        # Encoder - Multihead Attention, FFN and Layer Normalization 
        # Attention mechanism: https://machinelearningmastery.com/the-transformer-attention-mechanism/
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward,
            activation=activation,
            layer_norm_eps=1e-5, # Epsilon value for the layer normalization
            norm_first=False, # Layer normalization is after each sublayer i.e. normfirst = False
            bias=True, # Additive bias i.e. True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, 
            num_layers=num_encoder_layers,
        )

        # Decoder - Multihead Attention, FFN and Layer Normalization
        # https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            activation=activation,
            #tgt_mask=None, # Apply to input and not encoder output
            layer_norm_eps=1e-5,
            norm_first=False,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer, 
            num_layers=num_decoder_layers)
    
        # Classification - Linear and Softmax
        self.classification_head = nn.Linear(
            d_model, 
            num_classes
        )
        self.softmax = nn.Softmax(dim=1)


    def forward(self, x):
        x = x.float()
        
        print(f"Input shape: {x.shape}")
    
        x = self.patch_embedding(x)
        print(f"Weights of the patch embedding layer: {self.patch_embedding.weight.shape}")

        x = x.view(x.size(0), -1, d_model) # Flatten
        
        encoder_output = self.encoder(x)
        print(f"After encoder shape: {x.shape}")
        
        x = self.decoder(encoder_output)
        print(f"After decoder shape: {x.shape}")
        
        x = x.mean(dim=1) #??
        x = self.classification_head(x)
        x = self.softmax(x)

        return x
    
    '''
    def _generate_positional_encoding(self, seq_length):
        n = 10000 # Default
        d = d_model
        P = np.zeros((seq_length, d)) # Init zero matrix
        for k in range(seq_length):
            for i in np.arange(int(d/2)):
                denominator = np.power(n, 2*i/d)
                P[k, 2*i] = np.sin(k/denominator)
                P[k, 2*i+1] = np.cos(k/denominator)
        return P 
    '''

#############################################################################################################

model = VisionTransformerCoinImageClassifier(
    d_model=d_model, 
    nhead=nhead, 
    num_decoder_layers=num_encoder_layers, 
    num_encoder_layers=num_decoder_layers, 
    dim_feedforward=dim_feedforward, 
    activation=activation, 
    norm_first=norm_first,
    patch_size=patch_size,
    image_size=image_size,
    num_classes=num_classes,
)

# Get model info of the model
print(model)



# TEST
preproc = v2.Compose(
    [
        v2.PILToTensor(),
    ]
)

dataset = FakeData(size=10, image_size=(3,256,256), num_classes=2, transform=preproc)
img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
output = model(img)

# Get detailed info of the model
summary(
    model=model, 
    input_data=(3, 256, 256), 
    col_names=('output_size','kernel_size','num_params'), 
    verbose=1, # More info = 2
)
'''
for name, param in model.named_parameters():
    print(f'{name}: {param.data}')
'''

VisionTransformerCoinImageClassifier(
  (patch_embedding): Conv2d(3, 256, kernel_size=(16, 16), stride=(16, 16))
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQu

TypeError: TransformerDecoder.forward() missing 1 required positional argument: 'memory'