In [49]:
%pip install nptyping

Note: you may need to restart the kernel to use updated packages.


In [50]:
%pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [51]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, Dataset
import numpy as np
from nptyping import Float32, NDArray, Number, Shape, UInt
from transformers import ViTModel
import pytorch_lightning as pl


import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.insert(0, module_path)

In [52]:
class SignedDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.Y[i]

In [53]:
class ViT_FeatureExtractor(pl.LightningModule):
	def __init__(
		self,
		nb_classes: int = 10,
	):
		super().__init__()
		# print("---VIT INIT---")

		self.pretrained_vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
		self.pretrained_vit.eval()

		self.conv_1d_1 = torch.nn.Conv1d(
			in_channels=197,
			out_channels=64,
			kernel_size=3,
		)
		self.layer_1_relu = nn.ReLU()
		self.conv_1d_2 = torch.nn.Conv1d(
			in_channels=64,
			out_channels=nb_classes, # <-- i/o 1
			kernel_size=3,
		)
		self.layer_2_relu = nn.ReLU()

	def vit_extract_features(self, x):
		# print("---VIT EXTRACT FEATURES---")
		with torch.no_grad():
			outputs = self.pretrained_vit(pixel_values=x)
			vit_feat = outputs.last_hidden_state
			print(f"{vit_feat.shape= }")
			vit_feat = torch.flatten(vit_feat, start_dim=1)
		return vit_feat
	
	def forward(
		self,
		vit_feat, 
	) -> NDArray[Shape["* batch, * vocab size"], Float32]:
		# print("---VIT FORWARD---")
		x = self.conv_1d_1(vit_feat)
		x = self.layer_1_relu(x)
		x = self.conv_1d_2(x)
		x = self.layer_2_relu(x)
		return x

In [54]:
class BasicModel(pl.LightningModule):
	def __init__(
		self,
	):
		# print("---BASIC MODEL INIT---")
		super().__init__()
		self.save_hyperparameters()

		self.vocabulary_size = nb_classes
		self.layer = nn.Linear(151296, self.vocabulary_size)
		self.softmax = torch.nn.Softmax(dim=2)

	def forward(
		self, x: NDArray[Shape["* batch, 224, 224, 3"], Float32]
	) -> NDArray[Shape["* batch, * vocab size"], Float32]:
		# print("---BASIC MODEL FORWARD---")
		x = self.layer(x)
		x = self.softmax(x)
		return x

In [63]:
class GRU_Translator(pl.LightningModule):
	def __init__(
		self,
		nb_classes,
		H_input_size: int = 151296,
		H_output_size: int = 100,
		num_layers: int = 1,
		dropout: int = 0,
	):
		# print("---GRU INIT---")
		super().__init__()
		self.save_hyperparameters()
		self.vocabulary_size = nb_classes
		self.layer_gru = nn.GRU(
			input_size=self.hparams.H_input_size,
			hidden_size=self.hparams.nb_classes,
			num_layers=self.hparams.num_layers,
			batch_first=True,
			dropout=self.hparams.dropout,
		)
		
		# self.layer_1_dense = nn.Linear(self.hparams.H_output_size, self.hparams.H_output_size)
		# self.layer_1_relu = nn.ReLU()
		# self.layer_2_dense = nn.Linear(self.hparams.H_output_size, self.vocabulary_size)
		# self.layer_2_relu = nn.ReLU()
		self.softmax = nn.Softmax(dim=2) # <-- i/o dim=2

	def forward(self, X):
		# print("---GRU FORWARD---")	
		# print(f'{X.shape =}')
		X, _ = self.layer_gru(X)
		# X = self.layer_1_dense(X)
		# X = self.layer_1_relu(X)
		# X = self.layer_2_dense(X)
		# X = self.layer_2_relu(X)
		# print(f'{X.shape =}')
		X = self.softmax(X)
		# print(f'{X.shape =}')
		return X


In [56]:

class BaseSquareNet(pl.LightningModule):
	def __init__(
		self,
		batch_size: int = 1,
		seq_size: int = 1,
		nb_classes: int = 10,
		h_in: int = 10,
	):
		super().__init__()
		self.save_hyperparameters()

		self.batch_size = batch_size
		self.nb_seq_sizebatch = seq_size
		self.image_feature_extractr = ViT_FeatureExtractor(nb_classes=nb_classes)
		self.recurrent_translator = GRU_Translator(
			nb_classes = nb_classes,
			H_input_size=h_in,
			# H_output_size=100,
			num_layers=1,
			dropout=0,
		)

	def forward(
		self, x: NDArray[Shape["* batch, 224, 224, 3"], Float32]
	) -> NDArray[Shape["* batch, * vocab size"], Float32]:
		x = self.recurrent_translator(x)
		return x

In [72]:
# Hyperparameters
nb_classes=1999
seq_size = 16
batch_size = 2
learning_rate = 1e-3
h_in = 10

In [73]:
# Data
# x = torch.rand((batch_size, seq_size, 3, 224, 224))
# y = torch.randint(0, nb_classes, (batch_size, seq_size, 1))

# x = torch.rand((batch_size, 3, 224, 224))
x = torch.rand((batch_size, seq_size, h_in))
y = torch.randint(0, nb_classes, (batch_size, seq_size))

print(f"{y.size()=}")

y.size()=torch.Size([2, 16])


In [74]:
# Models
model = BaseSquareNet(nb_classes=nb_classes, seq_size=seq_size, batch_size=batch_size)
# vit_feat = model.image_feature_extractr.vit_extract_features(x)

# dataset = SignedDataset(vit_feat, y)
dataset = SignedDataset(x, y)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(train_loader, model, loss_fn, optimizer):
	loss = 10
	idx = 0
	while loss > 6.7:
		for batch_idx, (X, y) in enumerate(train_loader):
			pred = model(X)
			pred = pred.permute(0, 2, 1)
			loss = loss_fn(pred, y)
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()

			# if idx % 10 == 0:
			print(f'[{idx}] loss: {loss}')
			idx += 1

	print(f'[{idx}] final loss: {loss}')

In [75]:
train(dataloader, model, loss_fn, optimizer)

[0] loss: 7.600401878356934
[1] loss: 7.600377082824707
[2] loss: 7.600338935852051
[3] loss: 7.60021448135376


KeyboardInterrupt: 