In [None]:
'''
Steps:
    1) Divide input image into patches
    2) Flatten the patches to be handled as tokens
    3) Add classification token at the beginning
    4) Feed tokens into the encoder
    5) Number of output of the encoder is equal to the size of the input
    6) 
'''

In [1]:
import torch
import torch.nn as nn 
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import numpy as np
from torchsummary import summary

In [2]:
# TODO: LOAD OUR DATA 256X256

In [15]:
# https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
# Paper: https://arxiv.org/abs/1706.03762 

class VisionTransformerCoinImageClassifier(nn.Module):
    """Transformer-based image classifier operating on non-overlapping patches.

    Pipeline (see ``forward``):
        1. A strided ``Conv2d`` turns the image into one ``d_model``-dim
           embedding per patch.
        2. The patch grid is flattened into a token sequence and a fixed
           sinusoidal positional encoding is added.
        3. The tokens go through a ``TransformerEncoder``; its normalized
           output is used as the ``memory`` of two stacked
           ``TransformerDecoder`` modules.
        4. Tokens are mean-pooled, projected to ``num_classes`` and passed
           through a softmax.

    All hyperparameters are taken from the constructor arguments and stored
    on the instance; the class does not depend on notebook-level globals.

    Args:
        d_model: Embedding size of every token.
        nhead: Number of attention heads (must divide ``d_model``).
        num_encoder_layers: Number of layers in the encoder stack.
        num_decoder_layers: Number of layers in each decoder stack.
        dim_feedforward: Hidden size of the feed-forward sub-layers.
        activation: Activation of the feed-forward sub-layers (e.g. 'relu').
        norm_first: If True, LayerNorm is applied before attention/FFN.
        patch_size: Side length (pixels) of a square patch.
        image_size: Side length (pixels) of the square input image.
        num_classes: Number of output classes.

    Note:
        ``forward`` returns probabilities (softmax already applied).
        ``nn.CrossEntropyLoss`` expects raw logits, so do not feed this
        output directly into it.
    """

    def __init__(
        self,
        d_model,
        nhead,
        num_encoder_layers,
        num_decoder_layers,
        dim_feedforward,
        activation,
        norm_first,
        patch_size,
        image_size,
        num_classes
    ):
        super(VisionTransformerCoinImageClassifier, self).__init__()

        # Keep the hyperparameters on the instance so the _build_* helpers
        # use what was passed in rather than same-named global variables.
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.activation = activation
        self.norm_first = norm_first
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_classes = num_classes

        # Number of patches: (patches per side) ** 2, since the grid spans
        # both the width and the height of the image.
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding layer - sliding window over non-overlapping patches
        self.patch_embedding = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)

        # Positional encoding - fixed (not learned), hence a buffer: it moves
        # with .to(device) but is not a parameter. It is non-persistent
        # because it can always be regenerated from the hyperparameters.
        # https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
        self.register_buffer(
            "positional_encoding",
            self._generate_positional_encoding(),
            persistent=False,
        )

        # Encoder - Multihead Attention, FFN and Layer Normalization
        # Attention mechanism: https://machinelearningmastery.com/the-transformer-attention-mechanism/
        self.encoder = self._build_encoder()
        self.layer_norm = nn.LayerNorm(d_model)

        # Decoder first stack
        self.decoder_first_layer = self._build_decoder_first_layer()
        self.layer_norm2 = nn.LayerNorm(d_model)

        # Decoder second stack (previously built with _build_encoder by mistake)
        self.decoder = self._build_decoder()
        self.layer_norm3 = nn.LayerNorm(d_model)

        # Classification - Linear and Softmax
        self.classification_head = nn.Linear(d_model, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape (batch, 3, H, W). H and W must yield exactly
               ``num_patches`` patches (i.e. match ``image_size``).

        Returns:
            Tensor of shape (batch, num_classes) with class probabilities.

        Raises:
            ValueError: If the number of patches extracted from ``x`` differs
                from the length of the positional encoding.
        """
        x = self.patch_embedding(x)           # (batch, d_model, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)      # (batch, num_patches, d_model)

        if x.size(1) != self.positional_encoding.size(1):
            raise ValueError(
                f"Input produced {x.size(1)} patches but the model was built "
                f"for {self.positional_encoding.size(1)} "
                f"(image_size={self.image_size}, patch_size={self.patch_size})."
            )

        # Out-of-place add; the encoding broadcasts over the batch dimension.
        x = x + self.positional_encoding

        memory = self.layer_norm(self.encoder(x))

        # nn.TransformerDecoder needs (tgt, memory). The encoder output is
        # used both as the target sequence and as the attended memory.
        # NOTE: no tgt_mask is applied - all patches are available at once.
        x = self.decoder_first_layer(memory, memory)
        x = self.layer_norm2(x)
        x = self.decoder(x, memory)
        x = self.layer_norm3(x)

        x = x.mean(dim=1)                     # mean-pool over the patch tokens
        x = self.classification_head(x)
        x = self.softmax(x)

        return x

    def _generate_positional_encoding(self):
        """Build the fixed sinusoidal positional encoding.

        For position ``k`` and embedding index pair ``(2i, 2i + 1)``::

            PE[k, 2i]     = sin(k / n ** (2i / d_model))
            PE[k, 2i + 1] = cos(k / n ** (2i / d_model))

        Returns:
            Float tensor of shape (1, num_patches, d_model).
        """
        n = 10000.0  # Default base from "Attention Is All You Need"
        sequence_length = self.num_patches

        positions = torch.arange(sequence_length, dtype=torch.float32).unsqueeze(1)
        even_dims = torch.arange(0, self.d_model, 2, dtype=torch.float32)
        angles = positions / (n ** (even_dims / self.d_model))

        encoding = torch.zeros(sequence_length, self.d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns.
        encoding[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return encoding.unsqueeze(0)

    def _build_encoder(self):
        """Return a ``TransformerEncoder`` with ``num_encoder_layers`` layers."""
        # https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            activation=self.activation,
            batch_first=True,  # tokens are laid out as (batch, seq, d_model)
            norm_first=self.norm_first,
        )
        return nn.TransformerEncoder(encoder_layer=encoder_layer,
                                     num_layers=self.num_encoder_layers)

    def _build_decoder_first_layer(self):
        """Return the first decoder stack (same configuration as ``_build_decoder``)."""
        return self._build_decoder()

    def _build_decoder(self):
        """Return a ``TransformerDecoder`` with ``num_decoder_layers`` layers."""
        # https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            activation=self.activation,
            batch_first=True,  # tokens are laid out as (batch, seq, d_model)
            norm_first=self.norm_first,
        )
        return nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=self.num_decoder_layers)
    

# Hyperparameters
d_model = 256             # embedding size of each patch token
nhead = 8                 # attention heads
num_encoder_layers = 6
num_decoder_layers = 6
num_classes = 2
norm_first = False
patch_size = 16           # side length of a square patch, in pixels
image_size = 256          # side length of the square input image (256 x 256)
activation = 'relu'
dim_feedforward = 2048

# Each keyword receives its matching hyperparameter. The encoder/decoder layer
# counts were previously crossed, which only went unnoticed because both are 6.
model = VisionTransformerCoinImageClassifier(
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    activation=activation,
    norm_first=norm_first,
    patch_size=patch_size,
    image_size=image_size,
    num_classes=num_classes
)

print(model)



VisionTransformerCoinImageClassifier(
  (patch_embedding): Conv2d(3, 256, kernel_size=(16, 16), stride=(16, 16))
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (decoder_first_layer): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLaye