In [3]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, Dataset
import numpy as np
from nptyping import Float32, NDArray, Number, Shape, UInt
from transformers import ViTModel
import pytorch_lightning as pl
import math
from torch.autograd import Variable

import os
import sys
# Put the parent of the current working directory at the front of the import
# search path (only once), so modules located there can be imported.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.insert(0, module_path)

: 

: 

In [None]:
class SignedDataset(Dataset):
    """Index-aligned pairing of an input collection with a target collection.

    Item ``i`` is the tuple ``(X[i], Y[i])``. The reported length is taken
    from ``X`` alone; ``Y`` is not length-checked here, so it is assumed to
    hold at least as many entries as ``X``.
    """

    def __init__(self, X, Y):
        # Only references are kept; nothing is copied or converted.
        self.X, self.Y = X, Y

    def __len__(self):
        inputs = self.X
        return len(inputs)

    def __getitem__(self, i):
        sample = self.X[i]
        target = self.Y[i]
        return sample, target

In [None]:
class ViT_FeatureExtractor(pl.LightningModule):
	"""Pretrained ViT backbone followed by a two-layer 1-D convolutional head.

	``vit_extract_features`` runs the backbone without gradient tracking and
	returns its last hidden state flattened to one row per batch element.
	``forward`` applies the Conv1d/ReLU head to features supplied by the caller.

	NOTE(review): the first Conv1d is built with 197 input channels, so
	``forward`` needs an input whose channel dimension is 197, whereas
	``vit_extract_features`` returns a flattened 2-D tensor. As written the two
	methods cannot be chained without a reshape in between -- confirm the
	intended usage.
	"""

	def __init__(
		self,
		nb_classes: int = 10,
	):
		"""Load the pretrained backbone and build the convolutional head.

		Args:
			nb_classes: number of output channels of the second convolution.
		"""
		super().__init__()

		# Backbone weights are fetched/loaded by name and the module is put in
		# evaluation mode immediately after loading.
		self.pretrained_vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
		self.pretrained_vit.eval()

		# 197 input channels: presumably the backbone's token count for 224x224
		# images cut into 16x16 patches (196 patches + 1 class token) -- TODO confirm.
		self.conv_1d_1 = torch.nn.Conv1d(
			in_channels=197,
			out_channels=64,
			kernel_size=3,
		)
		self.layer_1_relu = nn.ReLU()
		self.conv_1d_2 = torch.nn.Conv1d(
			in_channels=64,
			out_channels=nb_classes,  # one output channel per class (original note: "i/o 1")
			kernel_size=3,
		)
		self.layer_2_relu = nn.ReLU()

	def vit_extract_features(self, x) -> torch.Tensor:
		"""Run the pretrained backbone on ``x`` and flatten its last hidden state.

		Args:
			x: tensor passed to the backbone as ``pixel_values``.

		Returns:
			A 2-D tensor: dimension 0 of the backbone's last hidden state is
			kept and all remaining dimensions are flattened into one. It is
			computed under ``torch.no_grad()``, so it carries no gradient history.
		"""
		with torch.no_grad():
			outputs = self.pretrained_vit(pixel_values=x)
			vit_feat = torch.flatten(outputs.last_hidden_state, start_dim=1)
		return vit_feat

	def forward(
		self,
		vit_feat,
	) -> torch.Tensor:
		"""Apply the Conv1d -> ReLU -> Conv1d -> ReLU head.

		Args:
			vit_feat: tensor whose channel dimension is 197, e.g.
				``(batch, 197, length)``.

		Returns:
			A non-negative tensor with ``nb_classes`` channels. Each convolution
			uses ``kernel_size=3`` with default stride/padding, so the last
			dimension shrinks by 2 per layer (4 in total).
		"""
		x = self.conv_1d_1(vit_feat)
		x = self.layer_1_relu(x)
		x = self.conv_1d_2(x)
		x = self.layer_2_relu(x)
		return x

In [None]:
class GRU_Translator(pl.LightningModule):
	"""GRU followed by a linear layer and a softmax over the class dimension.

	``forward`` returns *probabilities* (softmax already applied). Note that
	``torch.nn.CrossEntropyLoss`` expects unnormalised logits; when training on
	this module's output, feed log-probabilities to an NLL loss instead.
	"""

	def __init__(
		self,
		nb_classes: int,
		H_input_size: int = 151296,
		H_output_size: int = 10,
		num_layers: int = 1,
		dropout: float = 0,
	):
		"""Build the recurrent layer and the classification head.

		Args:
			nb_classes: number of classes; used as the GRU hidden size and as
				the input/output width of the dense layer.
			H_input_size: size of each input feature vector fed to the GRU.
			H_output_size: not used by any active layer; kept so existing
				callers and saved hyperparameters stay valid.
			num_layers: number of stacked GRU layers.
			dropout: dropout probability forwarded to ``nn.GRU``.
		"""
		super().__init__()
		# Must run before any argument is rebound: it records the __init__
		# arguments into self.hparams.
		self.save_hyperparameters()
		self.vocabulary_size = nb_classes
		self.layer_gru = nn.GRU(
			input_size=self.hparams.H_input_size,
			hidden_size=self.hparams.nb_classes,
			num_layers=self.hparams.num_layers,
			batch_first=True,
			dropout=self.hparams.dropout,
		)

		self.layer_1_dense = nn.Linear(self.hparams.nb_classes, self.hparams.nb_classes)
		# Not applied in forward(); kept as an attribute for compatibility.
		self.layer_1_relu = nn.ReLU()
		# Normalise over the last (class) dimension. For batched 3-D input this
		# is the same axis as dim=2, and it also works for unbatched 2-D input,
		# which nn.GRU accepts but a hard-coded dim=2 would reject.
		self.softmax = nn.Softmax(dim=-1)

	def forward(self, X):
		"""Map a feature sequence to per-step class probabilities.

		Args:
			X: ``(batch, seq, H_input_size)`` tensor (the GRU is built with
				``batch_first=True``), or an unbatched ``(seq, H_input_size)`` one.

		Returns:
			Tensor with the same leading dimensions as ``X`` and ``nb_classes``
			values in the last dimension, summing to 1 along that dimension.
		"""
		X, _ = self.layer_gru(X)
		X = self.layer_1_dense(X)
		X = self.softmax(X)
		return X

In [None]:

class BaseSquareNet(pl.LightningModule):
	"""Container holding a ViT feature extractor and a GRU translator.

	NOTE: a class with the same name is defined again in a later cell; when the
	cells run in order, that later definition replaces this one.
	"""

	def __init__(
		self,
		batch_size: int = 1,
		seq_size: int = 1,
		nb_classes: int = 10,
		h_in: int = 10,
	):
		super().__init__()
		# Record the constructor arguments before anything else happens.
		self.save_hyperparameters()

		self.batch_size, self.nb_seq_sizebatch = batch_size, seq_size

		# Sub-modules are created in this order on purpose: extractor first,
		# translator second (creation order fixes parameter order and the
		# order in which initial weights are drawn).
		extractor = ViT_FeatureExtractor(nb_classes=nb_classes)
		self.image_feature_extractr = extractor

		translator = GRU_Translator(
			nb_classes=nb_classes,
			H_input_size=h_in,
			num_layers=1,
			dropout=0,
		)
		self.recurrent_translator = translator

	def forward(
		self, x: NDArray[Shape["* batch, 224, 224, 3"], Float32]
	) -> NDArray[Shape["* batch, * vocab size"], Float32]:
		# Only the recurrent translator is applied; the feature extractor is
		# not called from here.
		return self.recurrent_translator(x)

In [None]:
class BaseSquareNet(pl.LightningModule):
	"""Bundle of a ViT feature extractor and a GRU translator.

	The extractor is stored on the module but is *not* invoked by ``forward``:
	callers compute features separately (for example through
	``image_feature_extractr.vit_extract_features``) and pass those features in.

	This definition replaces an earlier class of the same name from a previous
	cell.
	"""

	def __init__(
		self,
		batch_size: int = 1,
		seq_size: int = 1,
		nb_classes: int = 10,
		h_in: int = 10,
	):
		"""Create both sub-modules.

		Args:
			batch_size: stored on the instance; not used by ``forward``.
			seq_size: stored on the instance (as ``nb_seq_sizebatch``); not
				used by ``forward``.
			nb_classes: number of classes, forwarded to both sub-modules.
			h_in: feature size expected by the GRU (its ``H_input_size``).
		"""
		super().__init__()
		# Records the constructor arguments into self.hparams; keep it first.
		self.save_hyperparameters()

		self.batch_size = batch_size
		self.nb_seq_sizebatch = seq_size
		self.image_feature_extractr = ViT_FeatureExtractor(nb_classes=nb_classes)
		self.recurrent_translator = GRU_Translator(
			nb_classes=nb_classes,
			H_input_size=h_in,
			num_layers=1,
			dropout=0,
		)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Run the recurrent translator on pre-computed features.

		Args:
			x: feature tensor handed unchanged to ``recurrent_translator``; its
				last dimension has to match the ``h_in`` given at construction.
				Raw images are not accepted here because the feature extractor
				is not called.

		Returns:
			Whatever ``recurrent_translator`` returns for ``x``.
		"""
		x = self.recurrent_translator(x)
		return x

In [None]:
# Hyperparameters read by the cells below.
batch_size = 2          # DataLoader batch size; also the first dimension of the label tensor
seq_size = 2            # second dimension of the label tensor
nb_classes = 10         # number of classes (exclusive upper bound of the random labels)
# Feature size handed to the GRU as H_input_size.
# NOTE(review): GRU_Translator's own default is 151296 -- confirm that 10
# matches the width of the features actually fed to the model.
h_in = 10
# NOTE(review): the visible sweep cell iterates over its own list of learning
# rates and does not read this value.
learning_rate = 0.0001

In [None]:
# Random stand-in data: batch_size * seq_size tensors of shape (3, 224, 224)
# and one integer label in [0, nb_classes) per (batch, step) position.
# NOTE(review): no RNG seed is set in the visible cells before this point, so
# these tensors differ from run to run.
x = torch.rand(batch_size * seq_size, 3, 224, 224)
y = torch.randint(low=0, high=nb_classes, size=(batch_size, seq_size))

# print(f"{y.size()=}")
# print(f"{y=}")

x = tensor([[[0.1326, 0.8331, 0.4766, 0.6109, 0.9087, 0.2746, 0.8763, 0.0457,
          0.7431, 0.3337],
         [0.7922, 0.1349, 0.3717, 0.6364, 0.8544, 0.5529, 0.1504, 0.8544,
          0.5674, 0.2392],
         [0.5797, 0.9653, 0.5151, 0.7264, 0.9557, 0.0348, 0.6179, 0.0335,
          0.8419, 0.4402],
         [0.6365, 0.3544, 0.1200, 0.0209, 0.3555, 0.0998, 0.2111, 0.8125,
          0.6757, 0.2069]],

        [[0.8706, 0.1550, 0.1041, 0.6045, 0.5984, 0.1453, 0.5757, 0.9145,
          0.1843, 0.6810],
         [0.8202, 0.2809, 0.8687, 0.7622, 0.7350, 0.3684, 0.2455, 0.5685,
          0.3301, 0.8901],
         [0.1646, 0.9842, 0.0604, 0.3706, 0.1392, 0.7003, 0.4133, 0.4687,
          0.8430, 0.7095],
         [0.2424, 0.4553, 0.8838, 0.0189, 0.9084, 0.4076, 0.3309, 0.2160,
          0.4893, 0.4794]],

        [[0.2542, 0.6651, 0.3881, 0.3942, 0.7424, 0.0986, 0.8370, 0.2860,
          0.8439, 0.7679],
         [0.9598, 0.7701, 0.2160, 0.3785, 0.6271, 0.7881, 0.5095, 0.0943,
         

In [None]:
from tqdm import tqdm 

# Models
model = BaseSquareNet(batch_size=batch_size, seq_size=seq_size, nb_classes=nb_classes, h_in=h_in)

# One flattened feature row per image: x holds batch_size * seq_size images.
vit_feat = model.image_feature_extractr.vit_extract_features(x)
# Regroup the per-image rows into sequences so that dataset item i is one
# sequence of seq_size feature vectors, aligned with row i of y (y has shape
# (batch_size, seq_size)). Without this the dataset would report
# batch_size * seq_size items while y only has batch_size rows.
# Assumes the images in x are laid out sequence after sequence -- TODO confirm.
vit_feat = vit_feat.reshape(batch_size, seq_size, -1)
# NOTE(review): the GRU inside BaseSquareNet is built with input size h_in;
# training only runs if h_in equals vit_feat.shape[-1] -- confirm h_in.

dataset = SignedDataset(vit_feat, y)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
	# lr = 10 ** (- e_lr / 10)

def probability_nll_loss(probs, target, eps=1e-12):
	"""Negative log-likelihood for inputs that are already probabilities.

	``torch.nn.CrossEntropyLoss`` applies ``log_softmax`` to its input and so
	expects unnormalised logits. Feeding it softmax outputs normalises twice,
	which flattens gradients and puts a floor under the reachable loss. This
	helper takes the log of the probabilities and uses plain NLL.

	Args:
		probs: probabilities with the class dimension at index 1, e.g.
			``(batch, classes)`` or ``(batch, classes, seq)``.
		target: integer class indices, e.g. ``(batch,)`` or ``(batch, seq)``.
		eps: lower clamp applied before the log so exact zeros do not give -inf.

	Returns:
		Scalar tensor: mean negative log-probability of the target classes.
	"""
	return F.nll_loss(torch.log(probs.clamp_min(eps)), target)


loss_fn = probability_nll_loss

def train(train_loader, model, loss_fn, learning_rate: float, epochs: int = 100):
	"""Optimise ``model`` with Adam and return the loss of every step.

	Args:
		train_loader: iterable yielding ``(X, y)`` batches.
		model: callable module; ``model(X)`` must return a 3-D tensor whose
			last dimension is the class dimension.
		loss_fn: called as ``loss_fn(pred, y)`` where ``pred`` has been
			permuted so that the class dimension is at index 1.
		learning_rate: Adam learning rate.
		epochs: number of passes over ``train_loader``.

	Returns:
		List of Python floats, one per optimisation step, in execution order.
	"""
	optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
	losses = []
	for _ in tqdm(range(epochs)):
		for X, y in train_loader:
			# Clear gradients first so nothing left over from earlier
			# backward passes leaks into this step.
			optimizer.zero_grad()
			pred = model(X)
			# Move the class dimension from last to index 1 for the loss.
			pred = pred.permute(0, 2, 1)
			loss = loss_fn(pred, y)
			loss.backward()
			optimizer.step()
			# .item() yields a Python float whatever device the loss is on.
			losses.append(loss.item())
	return losses

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 9))

for lr in [1e-3]:
	# Re-seed before building each model so every learning rate starts from
	# the same initial weights and the curves are comparable.
	torch.manual_seed(6)
	model = BaseSquareNet(nb_classes=nb_classes, seq_size=seq_size, batch_size=batch_size, h_in=h_in)
	losses = train(dataloader, model, loss_fn, learning_rate=lr, epochs=int(1e3))
	print(f"For {lr = }, {min(losses) = }")
	plt.plot(losses, label=f"{lr:e}")

# train() records one loss value per optimisation step.
plt.xlabel("optimisation step")
plt.ylabel("training loss")
plt.title("Training loss per learning rate")
plt.legend(loc='best', title="learning rate")
plt.show()

 25%|██▍       | 247/1000 [01:06<02:40,  4.70it/s]