In [1]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, Dataset
import numpy as np
from nptyping import Float32, NDArray, Number, Shape, UInt
from transformers import ViTModel
import pytorch_lightning as pl


import os
import sys
# Put the notebook's parent directory first on the import path (once) so that
# project-local modules can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
if not any(entry == module_path for entry in sys.path):
    sys.path.insert(0, module_path)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class SignedDataset(Dataset):
    """Map-style dataset pairing each input ``X[i]`` with its target ``Y[i]``.

    In this notebook ``X`` is the tensor of precomputed ViT features and ``Y``
    the integer class labels built in the setup cell. The class was originally
    sketched for video tensors (``[n_video, nb_frames, 3, 320, 240]``) with
    per-video sign labels (``[n_video, nb_signes, 1]``). Any two equal-length
    indexables work.

    Args:
        X: indexable of inputs.
        Y: indexable of targets, same length as ``X``.

    Raises:
        ValueError: if ``X`` and ``Y`` have different lengths.
    """

    def __init__(self, X, Y):
        # Fail fast. __len__ follows X, so a shorter Y would only surface as an
        # IndexError mid-epoch, and a longer Y would silently drop labels.
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.Y[i]

In [7]:
# Number of output classes shared by every model below (used in place of reading the corpus file).
nb_classes: int = 1999

class ViT_FeatureExtractor(pl.LightningModule):
	"""Pretrained ViT backbone (kept in eval mode) followed by a small 1-D conv head.

	Entry points:
	* ``vit_extract_features`` runs the ViT without gradients and returns its last
	  hidden state flattened to ``(batch, 1, tokens * hidden)``.
	* ``forward`` runs only the conv head (see the notes on the input it expects).

	Args:
		corpus: kept for interface compatibility; currently unused (the vocabulary
			size is taken from ``nb_classes``).
	"""

	def __init__(
		self,
		corpus: str = "/usr/share/dict/words",
	):
		super().__init__()

		# self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
		self.vocabulary_size = nb_classes

		self.pretrained_vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
		# Used as a fixed feature extractor; the train() override below keeps it in eval mode.
		self.pretrained_vit.eval()

		# in_channels=197: presumably one channel per ViT token (196 patches + CLS for
		# 224x224 inputs with 16x16 patches) — TODO confirm against the backbone config.
		self.conv_1d_1 = torch.nn.Conv1d(
			in_channels=197,
			out_channels=64,
			kernel_size=3,
		)
		self.layer_1_relu = nn.ReLU()
		self.conv_1d_2 = torch.nn.Conv1d(
			in_channels=64,
			out_channels=1,
			kernel_size=3,
		)
		self.layer_2_relu = nn.ReLU()

	def train(self, mode: bool = True):
		"""Set training mode on this module while pinning the ViT backbone to eval mode.

		``nn.Module.train`` recurses into every child module, and ``eval()`` is
		implemented as ``train(False)``. Without this override, any ``.train()`` call
		on this module or on a parent would undo the ``pretrained_vit.eval()`` from
		``__init__`` and re-enable whatever dropout the backbone is configured with.

		Returns:
			self, as ``nn.Module.train`` does.
		"""
		super().train(mode)
		# getattr guard in case train() runs before __init__ has attached the backbone.
		backbone = getattr(self, "pretrained_vit", None)
		if backbone is not None:
			backbone.eval()
		return self

	def vit_extract_features(self, x):
		"""Return the backbone's last hidden state, flattened, without tracking gradients.

		Args:
			x: pixel values passed to the ViT as ``pixel_values``; the setup cell
				passes a ``(batch, 3, 224, 224)`` float tensor.

		Returns:
			Tensor of shape ``(batch, 1, tokens * hidden)``; it was
			``(16, 1, 151296)`` in the recorded run.
		"""
		with torch.no_grad():
			outputs = self.pretrained_vit(pixel_values=x)
			vit_feat = outputs.last_hidden_state
			# Merge tokens and hidden units into one axis, then add a length-1 sequence axis.
			vit_feat = torch.flatten(vit_feat, start_dim=1)
			vit_feat = torch.unsqueeze(vit_feat, dim=1)
		return vit_feat

	def forward(
		self,
		vit_feat,
	) -> NDArray[Shape["* batch, * vocab size"], Float32]:
		"""Run the 1-D convolutional head (conv -> ReLU -> conv -> ReLU).

		NOTE(review): ``conv_1d_1`` has ``in_channels=197``, so the input must be
		shaped ``(batch, 197, length)``, e.g. the un-flattened ``last_hidden_state``.
		The ``(batch, 1, 151296)`` tensor returned by ``vit_extract_features`` would
		fail with a channel mismatch. The visible cells never call this head;
		confirm the intended input before using it.

		NOTE(review): ``torch.squeeze(x, dim=2)`` only removes the last axis when
		its length is 1. After two kernel_size=3 convolutions it is ``length - 4``,
		so the result stays ``(batch, 1, length - 4)`` rather than the annotated
		``(batch, vocab size)``. ``dim=1`` may have been intended — TODO confirm.
		"""
		x = self.conv_1d_1(vit_feat)
		x = self.layer_1_relu(x)
		x = self.conv_1d_2(x)
		x = self.layer_2_relu(x)
		x = torch.squeeze(x, dim=2)
		return x

In [None]:
class BasicModel(pl.LightningModule):
	"""Baseline: one linear layer from flattened ViT features to class probabilities.

	Args:
		input_size: length of the feature vector on the last axis. The default,
			151296, matches the flattened ViT features built in this notebook.

	NOTE(review): forward returns softmax probabilities. If they are fed to
	``nn.CrossEntropyLoss``, which applies log-softmax itself, the softmax is
	effectively applied twice.
	"""

	def __init__(
		self,
		input_size: int = 151296,
	):
		super().__init__()
		self.save_hyperparameters()

		self.vocabulary_size = nb_classes
		self.layer = nn.Linear(input_size, self.vocabulary_size)
		# nn.Linear puts the classes on the last axis, so normalise over dim=-1.
		# For 3-D (batch, seq, features) inputs this is the same axis as dim=2;
		# unlike dim=2 it also works for 2-D (batch, features) inputs.
		self.softmax = torch.nn.Softmax(dim=-1)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Map features to class probabilities over the last axis.

		Args:
			x: features whose last axis has length ``input_size``, e.g. the
				``(batch, 1, 151296)`` output of ``vit_extract_features``.

		Returns:
			Tensor with the same leading axes as ``x`` and a last axis of size
			``vocabulary_size`` that sums to 1.
		"""
		x = self.layer(x)
		x = self.softmax(x)
		return x

In [3]:
class GRU_Translator(pl.LightningModule):
	"""GRU followed by two dense layers, mapping a feature sequence to per-step class scores.

	Args:
		H_input_size: feature size of each sequence step. The default, 151296,
			is the flattened ViT feature length used in this notebook.
		H_output_size: GRU hidden size, also the width of the first dense layer.
		num_layers: number of stacked GRU layers.
		dropout: passed to ``nn.GRU``.
		corpus: kept for interface compatibility; currently unused (the vocabulary
			size is taken from ``nb_classes``).
		return_logits: if True, forward returns the raw output of the last dense
			layer, which is what ``nn.CrossEntropyLoss`` expects. If False (the
			default), the original ReLU + softmax head is applied and probabilities
			are returned.
	"""

	def __init__(
		self,
		H_input_size: int = 151296,
		H_output_size: int = 100,
		num_layers: int = 1,
		dropout: float = 0,
		corpus: str = "/usr/share/dict/words",
		return_logits: bool = False,
	):
		super().__init__()
		self.save_hyperparameters()
		# self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
		self.vocabulary_size = nb_classes
		self.return_logits = return_logits
		self.layer_gru = nn.GRU(
			input_size=self.hparams.H_input_size,
			hidden_size=self.hparams.H_output_size,
			num_layers=self.hparams.num_layers,
			batch_first=True,
			dropout=self.hparams.dropout,
		)

		self.layer_1_dense = nn.Linear(self.hparams.H_output_size, self.hparams.H_output_size)
		self.layer_1_relu = nn.ReLU()
		self.layer_2_dense = nn.Linear(self.hparams.H_output_size, self.vocabulary_size)
		self.layer_2_relu = nn.ReLU()
		# Classes are on the last axis; dim=-1 equals dim=2 for batched 3-D input and
		# also handles unbatched 2-D input.
		self.softmax = nn.Softmax(dim=-1)

	def forward(self, X):
		"""Translate a feature sequence into per-step class scores.

		Args:
			X: ``(batch, seq, H_input_size)`` tensor (``batch_first=True``); the setup
				cell produces ``(batch, 1, 151296)`` ViT features.

		Returns:
			``(batch, seq, vocabulary_size)`` probabilities, or raw scores when
			``return_logits`` is True.
		"""
		# nn.GRU returns (output, h_n); only the per-step outputs are needed.
		X, _ = self.layer_gru(X)
		X = self.layer_1_dense(X)
		X = self.layer_1_relu(X)
		X = self.layer_2_dense(X)
		if self.return_logits:
			return X
		X = self.layer_2_relu(X)
		X = self.softmax(X)
		return X

class BaseSquareNet(pl.LightningModule):
	"""ViT feature extractor followed by the GRU translator.

	Args:
		corpus: forwarded to the sub-models; currently not used to size the
			vocabulary, which comes from ``nb_classes``.
		sequence_size: stored in ``hparams`` only; not used by this class yet.
	"""

	def __init__(
		self,
		corpus: str = "/usr/share/dict/words",
		sequence_size: int = 16,
	):
		super().__init__()
		self.save_hyperparameters()

		# self.vocabulary_size = len(np.array(open(corpus).read().splitlines()))
		self.vocabulary_size = nb_classes
		self.image_feature_extractr = ViT_FeatureExtractor(corpus)
		# H_input_size must equal the flattened length from vit_extract_features
		# ((16, 1, 151296) in the recorded run).
		self.recurrent_translator = GRU_Translator(
			H_input_size=151296,
			H_output_size=100,
			num_layers=1,
			dropout=0,
			corpus=corpus,
		)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Translate images or precomputed ViT features into class scores.

		Args:
			x: either a 4-D image batch ``(batch, 3, 224, 224)``, which is first
				run through the ViT feature extractor, or precomputed features
				``(batch, seq, 151296)`` as returned by ``vit_extract_features``.

		Returns:
			The output of ``recurrent_translator`` for those features; presumably
			``(batch, seq, vocabulary_size)`` — the training loop squeezes dim 1.
		"""
		if x.dim() == 4:
			# Raw images: extract (gradient-free) ViT features first.
			x = self.image_feature_extractr.vit_extract_features(x)
		return self.recurrent_translator(x)



# Smoke-test data: random images with random labels. The labels carry no signal, so
# this only checks that the pipeline runs end to end and can memorise a single batch.
seed = 0
# Seed before creating the data *and* the model so inputs, labels and initial
# weights are reproducible across runs.
torch.manual_seed(seed)

nb_batch = 1
batch_size = 16
x = torch.rand((batch_size * nb_batch, 3, 224, 224))
y = torch.randint(0, nb_classes, (batch_size * nb_batch,))

# Take the corpus path from the H2T_CORPUS environment variable when it is set,
# instead of relying on a machine-specific absolute path. The default preserves
# the original location.
corpus = os.environ.get(
	"H2T_CORPUS",
	"/home/dolmalin/Documents/work/42ai/Hand2Text/data/H2T/wlasl_words",
)

model = BaseSquareNet(corpus=corpus)
# Compute the ViT features once, up front and without gradients; training then only
# runs the translator on these cached tensors.
vit_feat = model.image_feature_extractr.vit_extract_features(x)
dataset = SignedDataset(vit_feat, y)
# model = BasicModel()
print(f"{vit_feat.shape= }")

dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

learning_rate = 1e-2
# NOTE(review): CrossEntropyLoss expects raw logits and applies log-softmax itself,
# but GRU_Translator's head ends in a softmax, so softmax is applied twice. This
# bounds the loss below by about log(e + nb_classes - 1) - 1, which is ~6.60 for
# 1999 classes (the recorded run stopped at 6.6013). The 6.7 threshold in train()
# sits just above that floor.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(train_loader, model, loss_fn, optmizer, target_loss=6.7, max_epochs=None):
	"""Train ``model`` until an epoch ends with a loss at or below ``target_loss``.

	Args:
		train_loader: iterable of ``(X, y)`` batches, e.g. a ``DataLoader``.
		model: callable mapping ``X`` to predictions. Dim 1 of its output is
			squeezed before the loss; it is the length-1 sequence axis of the
			features built in the setup cell.
		loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
		optmizer: optimizer over ``model``'s parameters. The spelling is kept for
			compatibility with keyword callers.
		target_loss: stop once the last batch of an epoch has a loss <= this value.
		max_epochs: optional cap on the number of epochs. ``None`` (the default)
			loops until ``target_loss`` is reached.

	Returns:
		The loss of the last processed batch as a float (``inf`` if no epoch ran).

	Raises:
		ValueError: if ``train_loader`` yields no batches in an epoch, which would
			otherwise loop forever.
	"""
	last_loss = float("inf")
	epoch = 0
	while last_loss > target_loss:
		if max_epochs is not None and epoch >= max_epochs:
			break
		n_batches = 0
		for batch_idx, (X, y) in enumerate(train_loader):
			pred = model(X)
			# Drop the length-1 sequence axis so pred is (batch, classes) for the loss.
			pred = torch.squeeze(pred, dim=1)
			loss = loss_fn(pred, y)

			# Step the optimizer passed in as an argument, not a global.
			optmizer.zero_grad()
			loss.backward()
			optmizer.step()

			last_loss = loss.item()
			n_batches += 1
			if batch_idx % 100 == 0:
				print(f'loss: {last_loss}\r', end='')
		if n_batches == 0:
			raise ValueError("train_loader yielded no batches; training would never terminate")
		epoch += 1
	return last_loss

vit_feat.shape= torch.Size([16, 1, 151296])


In [4]:
# Runs until the last batch of an epoch has loss <= 6.7; progress is printed in place.
train(
	train_loader=dataloader,
	model=model,
	loss_fn=loss_fn,
	optmizer=optimizer,
)

loss: 6.601263046264648