In [2]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, Dataset
import numpy as np
from nptyping import Float32, NDArray, Number, Shape, UInt
from transformers import ViTModel
import pytorch_lightning as pl


import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.insert(0, module_path)

# from src.models.components.baseline.BaseSquareNet import BaseSquareNet
# from src.models.components.vit_baseline import ViTBaselineModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class SignedDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        # [n_video, nb_frames, 3, 320, 240]
        self.Y = Y
        # [n_video, nb_signes, 1]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.Y[i]


In [4]:
class ViT_FeatureExtractor(pl.LightningModule):
    def __init__(
        self,
        corpus: str = "/usr/share/dict/words",
    ):
        super().__init__()

        # self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
        self.vocabulary_size = 1999

        self.pretrained_vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.pretrained_vit.eval()

        self.conv_1d_1 = torch.nn.Conv1d(
            in_channels=197,
            out_channels=64,
            kernel_size=3,
        )
        self.layer_1_relu = nn.ReLU()
        self.conv_1d_2 = torch.nn.Conv1d(
            in_channels=64,
            out_channels=1,
            kernel_size=3,
        )
        self.layer_2_relu = nn.ReLU()

    def forward(
        self, x: NDArray[Shape["* batch, 224, 224, 3"], Float32]
    ) -> NDArray[Shape["* batch, * vocab size"], Float32]:

        outputs = self.pretrained_vit(pixel_values=x)
        vit_feat = outputs.last_hidden_state

        x = self.conv_1d_1(vit_feat)
        x = self.layer_1_relu(x)
        x = self.conv_1d_2(x)
        x = self.layer_2_relu(x)
        x = torch.squeeze(x, dim=0)
        print(f"{x.shape= }")
        return x

class ViT(pl.LightningModule):
    def __init__(
        self,
        corpus: str = "/usr/share/dict/words",
    ):
        super().__init__()

        # self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
        self.vocabulary_size = 1999

        self.pretrained_vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.pretrained_vit.eval()

        self.conv_1d_1 = torch.nn.Conv1d(
            in_channels=197,
            out_channels=64,
            kernel_size=3,
        )
        self.layer_1_relu = nn.ReLU()
        self.conv_1d_2 = torch.nn.Conv1d(
            in_channels=64,
            out_channels=1,
            kernel_size=3,
        )
        self.layer_2_relu = nn.ReLU()
        self.dense = nn.Linear(764, self.vocabulary_size)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(
        self, x: NDArray[Shape["* batch, 224, 224, 3"], Float32]
    ) -> NDArray[Shape["* batch, * vocab size"], Float32]:

        outputs = self.pretrained_vit(pixel_values=x)
        vit_feat = outputs.last_hidden_state

        x = self.conv_1d_1(vit_feat)
        x = self.layer_1_relu(x)
        x = self.conv_1d_2(x)
        x = self.layer_2_relu(x)
        # x = torch.squeeze(x, dim=0)
        # print(f"{x.shape= }")
        x = self.dense(x)
        x = self.softmax(x)
        return x

In [5]:
class GRU_Translator(pl.LightningModule):
    def __init__(
        self,
        H_input_size: int = 764,
        H_output_size: int = 100,
        num_layers: int = 1,
        dropout: int = 0,
        corpus: str = "/usr/share/dict/words",
    ):
        super().__init__()
        self.save_hyperparameters()
        # self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
        self.vocabulary_size = 1999
        self.layer_gru = nn.GRU(
            input_size=self.hparams.H_input_size,
            hidden_size=self.hparams.H_output_size,
            num_layers=self.hparams.num_layers,
            batch_first=True,
            dropout=self.hparams.dropout,
        )

        self.layer_1_dense = nn.Linear(self.hparams.H_output_size, self.hparams.H_output_size)
        self.layer_1_relu = nn.ReLU()
        self.layer_2_dense = nn.Linear(self.hparams.H_output_size, self.vocabulary_size)
        self.layer_2_relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # print(f"gru: {X.shape = }")
        X, hidden = self.layer_gru(X)
        # print(f"gru: {X.shape = }")
        X = self.layer_1_dense(X)
        X = self.layer_1_relu(X)
        # print(f"gru: {X.shape = }")
        X = self.layer_2_dense(X)
        X = self.layer_2_relu(X)
        # print(f"gru: {X.shape = }")
        X = self.softmax(X)
        print(f"{X.shape= }")
        # print(f"gru: {X.shape = }")
        return X


In [6]:
class BaseSquareNet(pl.LightningModule):
    def __init__(
        self,
        corpus: str = "/usr/share/dict/words",
        sequence_size: int = 16,
    ):
        super().__init__()
        self.save_hyperparameters()

        # self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
        self.vocabulary_size = 1999
        # self.image_feature_extractr = ViT_FeatureExtractor(corpus)
        self.image_feature_extractr = ViT(corpus)
        self.recurrent_translator = GRU_Translator(
            H_input_size=764,
            H_output_size=100,
            num_layers=1,
            dropout=0,
            corpus=corpus,
        )

    def forward(
        self, x: NDArray[Shape["* batch, 224, 224, 3"], Float32]
    ) -> NDArray[Shape["* batch, * vocab size"], Float32]:
        # x_seq = []
        # for i in range(self.hparams.sequence_size):
        #     print(f"{x.shape = }")
        #     b, f = x.shape
        #     x = x.view((b, 1, f))
        #     print(f"{x.shape = }")
        #     x_seq.append(x)
        # x_seq = torch.cat(x_seq, dim=1)
        # print(f"In: {x.shape = }")
        x = self.image_feature_extractr(x)
        # print(f"Vit: {x.shape = }")

        # b, f = x.shape
        # x_seq = x.view(1, b, f)
        # x = self.recurrent_translator(x_seq)
        return x


In [10]:
# x = np.random.rand(1, 3, 224, 224)
# x = torch.tensor(x, dtype=torch.float)
x = torch.rand((1, 3, 224, 244))
y = torch.tensor([0])

print(f"{x.shape= }")
print(f"{y.shape= }")

dataset = SignedDataset(x, y)
dataloader = DataLoader(dataset=dataset, batch_size=1)
model = BaseSquareNet()

learning_rate = 0.002
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

x.shape= torch.Size([1, 3, 224, 244])
y.shape= torch.Size([1])


In [8]:
def train(train_loader, model, loss_fn, optmizer):
    size = len(train_loader.dataset)
    batches_l = len(train_loader)
    loss = 0
    correct = 0

    while True:
        for batch_idx, (X, y) in enumerate(train_loader):
            pred = torch.squeeze(model(X), dim=0)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'loss: {loss}\r', end='')

In [9]:
train(dataloader, model, loss_fn, optimizer)

ValueError: Input image size (224*244) doesn't match model (224*224).