In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as T

import matplotlib.pyplot as plt
import seaborn as sns

from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from pathlib import Path

import numpy as np

# Use the GPU when one is available; fall back to CPU so the notebook still runs
# on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
!pip install gdown
!pip install scipy

# We are going to use the Caltech101 dataset

In [None]:
def _to_rgb(img):
    """Return ``img`` (a PIL image) converted to 3-channel RGB.

    Caltech101 includes some grayscale images. Without this step, ``ToTensor``
    turns them into 1-channel tensors, and ``DataLoader`` cannot stack those
    into the same batch as 3-channel ones.
    """
    return img.convert("RGB")


# Preprocessing: force RGB, resize to 224x224, convert to a float tensor in [0, 1].
input_transform = T.Compose([
    T.Lambda(_to_rgb),  # named function (not a lambda) so it stays picklable for worker processes
    T.Resize((224, 224)),
    T.ToTensor(),
])

dataset = torchvision.datasets.Caltech101(
    root=Path.cwd(),
    target_type="category",
    transform=input_transform,
    target_transform=None,
    download=True,
)

dataloader = DataLoader(dataset, batch_size=32)



In [None]:
# Show a grid of randomly chosen samples, titled with their category names.
N_ROWS, N_COLS = 3, 3
rng = np.random.default_rng(0)  # fixed seed so the grid is identical on every re-run
indexes = rng.integers(0, len(dataset), size=(N_ROWS, N_COLS))

fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(3 * N_COLS, 3 * N_ROWS))

for i in range(N_ROWS):
    for j in range(N_COLS):
        img, label = dataset[int(indexes[i, j])]
        # imshow needs (H, W, C) with C in {3, 4}; matplotlib rejects (H, W, 1),
        # so broadcast single-channel images to 3 channels first.
        if img.shape[0] == 1:
            img = img.expand(3, -1, -1)
        ax = axes[i, j]
        ax.imshow(img.permute(1, 2, 0))
        ax.set_title(dataset.categories[label], fontsize=8)
        ax.axis("off")

fig.tight_layout()
plt.show()

# Dataset presentation:
print(f"the dataset includes {len(dataset)} samples")

# Model creation

In [None]:
# Toy input for the attention modules: a matrix of N token embeddings, each of width D_h.
N = 100    # number of tokens
D_h = 768  # embedding width

z = torch.randn(N, D_h)

In [None]:
# Sanity check on the toy input: expected torch.Size([N, D_h]).
z.size()

In [None]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input to queries, keys and values with bias-free linear layers
    and returns ``softmax(q @ k^T / sqrt(D_h)) @ v``.

    Args:
        D_input: size of the last dimension of the input.
        D_h: size of the query/key/value projections, and of the output.

    Input shape is ``(..., N, D_input)``; output shape is ``(..., N, D_h)``.
    Any leading (e.g. batch) dimensions are preserved.
    """

    def __init__(self, D_input, D_h):
        super().__init__()
        self.D_h = D_h
        self.D_input = D_input

        self.q_mat = nn.Linear(in_features=self.D_input, out_features=self.D_h, bias=False)
        self.k_mat = nn.Linear(in_features=self.D_input, out_features=self.D_h, bias=False)
        self.v_mat = nn.Linear(in_features=self.D_input, out_features=self.D_h, bias=False)

    def forward(self, z):
        q, k, v = self.q_mat(z), self.k_mat(z), self.v_mat(z)
        # Swap the last two axes (rather than axes 0 and 1) so batched inputs work too.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.D_h ** 0.5)
        # Normalise over the key axis: each query's attention weights sum to 1.
        A = torch.softmax(scores, dim=-1)
        return torch.matmul(A, v)
        

class MSA(nn.Module):
    """Multi-head self-attention built from ``k`` independent ``SelfAttention`` heads.

    The last dimension of the input (``D_input``) is split into ``k`` contiguous
    chunks of size ``D_h = D_input // k``, and head ``i`` attends over chunk ``i``.
    The head outputs are concatenated back to ``D_input`` features and then
    mixed by a bias-free linear layer.

    NOTE(review): in the ViT paper each head projects the *full* input z,
    whereas here each head only sees a 1/k slice of the features. Confirm that
    this is intended.

    Args:
        D_input: size of the last dimension of the input; must be divisible by ``k``.
        k: number of heads.

    Input shape is ``(..., N, D_input)``; the output has the same shape.

    Raises:
        ValueError: if ``D_input`` is not divisible by ``k``. Otherwise the
            leftover features would be dropped and ``forward`` would fail
            with a shape mismatch.
    """

    def __init__(self, D_input, k):
        super().__init__()
        if D_input % k != 0:
            raise ValueError(f"D_input ({D_input}) must be divisible by k ({k})")
        self.k = k
        self.D_input = D_input
        self.D_h = D_input // k

        # nn.ModuleList (not a plain list) registers the heads as submodules, so their
        # weights show up in .parameters()/state_dict() and move with .to(device).
        self.attentions = nn.ModuleList(SelfAttention(self.D_h, self.D_h) for _ in range(k))
        self.unification_matrix = nn.Linear(self.D_input, self.D_input, bias=False)

    def forward(self, z):
        chunks = torch.split(z, split_size_or_sections=self.D_h, dim=-1)
        heads = torch.cat([head(chunk) for head, chunk in zip(self.attentions, chunks)], dim=-1)
        return self.unification_matrix(heads)

class MLP(nn.Module):
    """Transformer MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embedding_dim: size of the last dimension of the input and of the output.
        mlp_size: hidden width between the two linear layers.
        dropout: dropout probability used after each linear stage (default 0.1).

    The output has the same shape as the input. No residual connection is
    added here.
    """

    def __init__(self, embedding_dim, mlp_size, dropout=0.1):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.mlp_size = mlp_size

        self.block = nn.Sequential(
            # LayerNorm requires the size of the normalised (last) dimension.
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),
            nn.Dropout(p=dropout),
        )

    def forward(self, z):
        return self.block(z)



In [None]:
# Smoke-test the attention modules on the toy input z of shape (N, D_h).
SA = SelfAttention(D_h, D_h)
sa_out = SA(z)
assert sa_out.shape == (N, D_h), sa_out.shape

msa = MSA(D_input=D_h, k=16)
# Call the module itself (not .forward) so that nn.Module hooks run.
output_attention = msa(z)
assert output_attention.shape == z.shape, output_attention.shape

print(sa_out.shape)
print(output_attention.shape)

In [None]:
# Exploring torch.split on the toy input. Because the chunk size equals the
# full width D_h, this yields a single chunk of shape (N, D_h).
for x in z.split(D_h, dim=1):
    print(x.size())

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: placeholder, not implemented yet.

    TODO: presumably to be assembled from the MSA and MLP blocks in this notebook.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned or the module is called.
        super().__init__()

    def forward(self, x):
        # Fail loudly rather than silently returning None from an unfinished stub.
        raise NotImplementedError("Encoder.forward is not implemented yet")
