In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import EMNIST
import timm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

In [None]:
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("crawford/emnist")

# print("Path to dataset files:", path)

In [2]:
path = 'C:\\Users\\scanu\\.cache\\kagglehub\\datasets\\crawford\\emnist\\versions\\3'

dataset_file_train = "/emnist-letters-train.csv"
dataset_file_test = "/emnist-letters-test.csv"

train = pd.read_csv(path + dataset_file_train, delimiter=',')
test = pd.read_csv(path + dataset_file_test, delimiter=',')

In [3]:
y_train = np.array(train.iloc[:,0].values)
x_train = np.array(train.iloc[:,1:].values)

y_test = np.array(test.iloc[:,0].values)
x_test = np.array(test.iloc[:,1:].values)

In [4]:
n_data = len(x_train)
height = 28
resizer = transforms.Resize((224,224),interpolation=transforms.InterpolationMode.BICUBIC)

print(x_train.reshape(n_data,height,height).shape)
rgb_batch = np.repeat(x_train.reshape(n_data,height,height), 3, axis = 0).reshape(n_data,3,height,height)
print(rgb_batch.shape)

train_dataloader = DataLoader( [[rgb_batch[i], y_train[i]] for i in range(len(y_train))], batch_size=32, shuffle=False)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

(88799, 28, 28)
(88799, 3, 28, 28)


In [5]:
# Modello Transformer
class TextRecognitionModel(nn.Module):
    def __init__(self, num_classes=27):  # 26 lettere + background
        super(TextRecognitionModel, self).__init__()
        self.backbone = timm.create_model("vit_base_patch16_224", pretrained=True)
        self.backbone.head = nn.Identity()  # Rimuoviamo la testa di classificazione
        
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=768, nhead=8), num_layers=6
        )
        self.fc = nn.Linear(768, num_classes)
        
    def forward(self, x):
        features = self.backbone(x)
        features = features.unsqueeze(0)  # Adatta la dimensione per il decoder
        decoded = self.decoder(features, features)
        output = self.fc(decoded.squeeze(0))
        return output

In [None]:
# Istanziamento del modello
model = TextRecognitionModel(num_classes=27)

# Definizione della funzione di perdita e ottimizzatore
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Addestramento del modello
epochs = 5
for epoch in range(epochs):
    for images, labels in train_dataloader:
        images = resizer(images).float()
#         images = images.unsqueeze(0)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"This Batch, Loss: {loss.item():.4f}")
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

print("Addestramento completato!")

## Attempt with Cuda

In [6]:
availability = torch.cuda.is_available()
device = torch.device("cuda" if availability else "cpu")
print("Device:", device)

Device: cuda


In [7]:
# Istanziamento del modello
model = TextRecognitionModel(num_classes=27).to(device)

In [8]:
# Definizione della funzione di perdita e ottimizzatore
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Addestramento del modello
epochs = 1
for epoch in range(epochs):
    for images, labels in train_dataloader:
        images = resizer(images).float().to(device)
        labels = labels.to(device)
#         images = images.unsqueeze(0)
        optimizer.zero_grad()
        outputs = model(images).to(device)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"This Batch, Loss: {loss.item():.4f}")
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}", end = "\n")

print("Addestramento completato!")

  x = F.scaled_dot_product_attention(


This Batch, Loss: 3.4337
This Batch, Loss: 5.3205
This Batch, Loss: 4.1498
This Batch, Loss: 4.8940
This Batch, Loss: 5.5230
This Batch, Loss: 4.2530
This Batch, Loss: 3.6919
This Batch, Loss: 3.9770
This Batch, Loss: 4.7903
This Batch, Loss: 3.9632
This Batch, Loss: 3.9731
This Batch, Loss: 3.9964
This Batch, Loss: 3.6504
This Batch, Loss: 3.8707
This Batch, Loss: 3.5302
This Batch, Loss: 3.4320
This Batch, Loss: 3.5896
This Batch, Loss: 3.9214
This Batch, Loss: 3.8904
This Batch, Loss: 3.5118
This Batch, Loss: 3.5252
This Batch, Loss: 3.5261
This Batch, Loss: 3.4565
This Batch, Loss: 3.2842
This Batch, Loss: 3.5011
This Batch, Loss: 3.3788
This Batch, Loss: 3.3413
This Batch, Loss: 3.4378
This Batch, Loss: 3.6197
This Batch, Loss: 3.4944
This Batch, Loss: 3.5242
This Batch, Loss: 3.4634
This Batch, Loss: 3.4879
This Batch, Loss: 3.2614
This Batch, Loss: 3.7557
This Batch, Loss: 3.5204
This Batch, Loss: 3.4082
This Batch, Loss: 3.4270
This Batch, Loss: 3.3051
This Batch, Loss: 3.3423


This Batch, Loss: 3.3737
This Batch, Loss: 3.2501
This Batch, Loss: 3.2895
This Batch, Loss: 3.4029
This Batch, Loss: 3.3403
This Batch, Loss: 3.4136
This Batch, Loss: 3.2123
This Batch, Loss: 3.4943
This Batch, Loss: 3.4103
This Batch, Loss: 3.3397
This Batch, Loss: 3.3969
This Batch, Loss: 3.3487
This Batch, Loss: 3.3963
This Batch, Loss: 3.3098
This Batch, Loss: 3.2532
This Batch, Loss: 3.4470
This Batch, Loss: 3.2708
This Batch, Loss: 3.3711
This Batch, Loss: 3.3923
This Batch, Loss: 3.2854
This Batch, Loss: 3.2167
This Batch, Loss: 3.2515
This Batch, Loss: 3.3193
This Batch, Loss: 3.3117
This Batch, Loss: 3.4182
This Batch, Loss: 3.3506
This Batch, Loss: 3.2352
This Batch, Loss: 3.4043
This Batch, Loss: 3.6024
This Batch, Loss: 3.4204
This Batch, Loss: 3.2356
This Batch, Loss: 3.3345
This Batch, Loss: 3.3581
This Batch, Loss: 3.4412
This Batch, Loss: 3.4299
This Batch, Loss: 3.2556
This Batch, Loss: 3.3851
This Batch, Loss: 3.4171
This Batch, Loss: 3.2702
This Batch, Loss: 3.4628


This Batch, Loss: 3.2464
This Batch, Loss: 3.3732
This Batch, Loss: 3.3823
This Batch, Loss: 3.3568
This Batch, Loss: 3.3023
This Batch, Loss: 3.3553
This Batch, Loss: 3.2680
This Batch, Loss: 3.3137
This Batch, Loss: 3.2051
This Batch, Loss: 3.3131
This Batch, Loss: 3.3613
This Batch, Loss: 3.2707
This Batch, Loss: 3.3223
This Batch, Loss: 3.3074
This Batch, Loss: 3.3199
This Batch, Loss: 3.2924
This Batch, Loss: 3.3191
This Batch, Loss: 3.3028
This Batch, Loss: 3.2811
This Batch, Loss: 3.2798
This Batch, Loss: 3.3430
This Batch, Loss: 3.2899
This Batch, Loss: 3.2547
This Batch, Loss: 3.3581
This Batch, Loss: 3.2253
This Batch, Loss: 3.2878
This Batch, Loss: 3.3400
This Batch, Loss: 3.2581
This Batch, Loss: 3.2312
This Batch, Loss: 3.3400
This Batch, Loss: 3.3199
This Batch, Loss: 3.3407
This Batch, Loss: 3.2425
This Batch, Loss: 3.2835
This Batch, Loss: 3.2377
This Batch, Loss: 3.3376
This Batch, Loss: 3.3217
This Batch, Loss: 3.1862
This Batch, Loss: 3.2990
This Batch, Loss: 3.3276


This Batch, Loss: 3.3501
This Batch, Loss: 3.2713
This Batch, Loss: 3.2945
This Batch, Loss: 3.2253
This Batch, Loss: 3.1939
This Batch, Loss: 3.3477
This Batch, Loss: 3.2668
This Batch, Loss: 3.2326
This Batch, Loss: 3.2978
This Batch, Loss: 3.3975
This Batch, Loss: 3.2555
This Batch, Loss: 3.1914
This Batch, Loss: 3.2473
This Batch, Loss: 3.2482
This Batch, Loss: 3.3877
This Batch, Loss: 3.2437
This Batch, Loss: 3.2878
This Batch, Loss: 3.3254
This Batch, Loss: 3.2779
This Batch, Loss: 3.3419
This Batch, Loss: 3.3763
This Batch, Loss: 3.3451
This Batch, Loss: 3.3800
This Batch, Loss: 3.3539
This Batch, Loss: 3.3263
This Batch, Loss: 3.3794
This Batch, Loss: 3.1928
This Batch, Loss: 3.2818
This Batch, Loss: 3.3034
This Batch, Loss: 3.3328
This Batch, Loss: 3.3065
This Batch, Loss: 3.2756
This Batch, Loss: 3.3323
This Batch, Loss: 3.2761
This Batch, Loss: 3.2974
This Batch, Loss: 3.2510
This Batch, Loss: 3.2391
This Batch, Loss: 3.2895
This Batch, Loss: 3.3259
This Batch, Loss: 3.2045


This Batch, Loss: 3.2542
This Batch, Loss: 3.2770
This Batch, Loss: 3.2808
This Batch, Loss: 3.2603
This Batch, Loss: 3.3145
This Batch, Loss: 3.3039
This Batch, Loss: 3.3172
This Batch, Loss: 3.2935
This Batch, Loss: 3.3510
This Batch, Loss: 3.2449
This Batch, Loss: 3.3312
This Batch, Loss: 3.3352
This Batch, Loss: 3.2223
This Batch, Loss: 3.3248
This Batch, Loss: 3.2622
This Batch, Loss: 3.2481
This Batch, Loss: 3.2153
This Batch, Loss: 3.2657
This Batch, Loss: 3.2417
This Batch, Loss: 3.2607
This Batch, Loss: 3.2455
This Batch, Loss: 3.3223
This Batch, Loss: 3.3207
This Batch, Loss: 3.3052
This Batch, Loss: 3.2861
This Batch, Loss: 3.3147
This Batch, Loss: 3.3269
This Batch, Loss: 3.2394
This Batch, Loss: 3.2101
This Batch, Loss: 3.2586
This Batch, Loss: 3.2863
This Batch, Loss: 3.2585
This Batch, Loss: 3.2584
This Batch, Loss: 3.2482
This Batch, Loss: 3.3822
This Batch, Loss: 3.2673
This Batch, Loss: 3.2514
This Batch, Loss: 3.4033
This Batch, Loss: 3.2735
This Batch, Loss: 3.3498


KeyboardInterrupt: 