<a href="https://colab.research.google.com/github/ValentinaEmili/Sign_language/blob/main/ASL_recognition_100.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The glossary is made of 100 different words but the instances for each word are not the same as the ones in WLASL_v0.3 file. Indeed, some links were broken and the correspective instances have been removed. Every word has at least one instance.

In [1]:
# mount google drive on colab
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import pandas as pd
import cv2
from google.colab.patches import cv2_imshow
from tqdm import tqdm
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence, pad_packed_sequence
from torch.nn import LSTM
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
import shutil

In [3]:
js_100 = pd.read_json("/content/drive/MyDrive/NLP/WLASL100.json")
folder = "/content/drive/MyDrive/NLP/dataset/subset_100/"
original_folder = "/content/drive/MyDrive/NLP/dataset/"

training_folder = folder + "train/"
validation_folder = folder + "val/"
test_folder = folder + "test/"

training_video = training_folder + "video/"
validation_video = validation_folder + "video/"
test_video = test_folder + "video/"

training_images = training_folder + "images/"
validation_images = validation_folder + "images/"
test_images = test_folder + "images/"

os.makedirs(training_video, exist_ok=True)
os.makedirs(validation_video, exist_ok=True)
os.makedirs(test_video, exist_ok=True)

os.makedirs(training_images, exist_ok=True)
os.makedirs(validation_images, exist_ok=True)
os.makedirs(test_images, exist_ok=True)


Preprocess the data

In [4]:
train_gloss, val_gloss, test_gloss = set(), set(), set()
for image in os.listdir(training_images):
  word, _ = image.split("_")
  train_gloss.add(word)

for image in os.listdir(validation_images):
  word, _ = image.split("_")
  val_gloss.add(word)

for image in os.listdir(test_images):
  word, _ = image.split("_")
  test_gloss.add(word)
gloss = sorted(list(train_gloss | val_gloss | test_gloss))


label_map = {label: num for num, label in enumerate(gloss)}

In [8]:
print('train set:', len(os.listdir(training_images)))
print('val set:', len(os.listdir(validation_images)))
print('test set:', len(os.listdir(test_images)))
print('tot dataset:', len(os.listdir(training_images))+len(os.listdir(validation_images))+len(os.listdir(test_images)))

train set: 914
val set: 211
test set: 189
tot dataset: 1314


Build and train LSTM Neural Network

In [15]:
class SignLanguageDataset(Dataset):
  def __init__(self, image_dir, label_map):
     self.image_dir = image_dir
     self.label_map = label_map
     self.files = sorted(os.listdir(image_dir))

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    file_name = self.files[idx]
    np_array = np.load(os.path.join(self.image_dir, file_name))
    if np_array.size == 0 or len(np_array.shape) != 2 or np_array.shape[1] != 258:
      print(f"Warning: Empty or invalid shape for file: {file_name}")
      np_array = np.zeros((1, 258), dtype=np.float32)

    label, _ = file_name.split("_")
    label = self.label_map[label]

    return torch.tensor(np_array, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# Add zero-padding to get sequences of the same length for each batch
def collate_fn(batch):
  sequences, labels = zip(*batch)
  lengths = [len(seq) for seq in sequences]
  padded_sequences = pad_sequence(sequences, batch_first=True)

  # pack the padded sequence
  packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True, enforce_sorted=False)
  return packed_sequences, torch.tensor(labels)

train_dataset = SignLanguageDataset(training_images, label_map)
val_dataset = SignLanguageDataset(validation_images, label_map)
test_dataset = SignLanguageDataset(test_images, label_map)

batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [22]:
class SignLanguageLSTM(nn.Module):
  def __init__(self, input_size, hidden_size, num_classes, dropout_rate=0.5):
    super(SignLanguageLSTM, self).__init__()

    # input regularization
    self.input_bn = nn.BatchNorm1d(input_size)
    self.input_dropout = nn.Dropout(0.3)

    # single bidirectional LSTM layer
    self.lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        batch_first=True,
        dropout=dropout_rate,
        bidirectional=True)

    # fully connected layers
    self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
    self.fc2 = nn.Linear(hidden_size, num_classes)

    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, packed_input):
    # unpack input for batch normalization
    padded_input, lengths = pad_packed_sequence(packed_input, batch_first=True)

    # apply input normalization and dropout
    padded_input = padded_input.transpose(1, 2)
    padded_input = self.input_bn(padded_input)
    padded_input = padded_input.transpose(1, 2)
    padded_input = self.input_dropout(padded_input)

    # re-pack input
    packed_input = pack_padded_sequence(padded_input, lengths, batch_first=True, enforce_sorted=False)

    # LSTM
    packed_output, (hn, cn) = self.lstm(packed_input)

    output_forward = hn[0, :, :] # last hidden state for forward direction
    output_backward = hn[1, :, :] # last hidden state for backward direction
    output = torch.cat((output_forward, output_backward), dim=1)

    output = F.relu(self.fc1(output))
    output = self.fc2(self.dropout(output))

    return output

In [None]:
all_labels = [image.split("_")[0] for image in os.listdir(training_images)]
label_counts = {label: all_labels.count(label) for label in label_map}
weight = sum(label_counts.values()) / len(label_counts)
weights = torch.tensor([weight / count for _, count in label_counts.items()], dtype=torch.float32)

In [25]:
# training configuration tailored for small datasets
def get_training_config():
  return {
    'hidden_size': 256, # {64, 128, 256}
    'learning_rate': 1e-4, # {1e-3, 1e-4, 1e-5}
    'num_epochs': 100,
    'weight_decay': 1e-4, # {1e-2, 1e-3}
    'dropout_rate': 0.1,
    'scheduler_params': {
      'factor': 0.5,
      'min_lr': 1e-6
    },
  }

best_accuracy = 0.0
training_history = []
config = get_training_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = config['num_epochs']

model = SignLanguageLSTM(
    input_size=258,
    hidden_size = config['hidden_size'],
    num_classes = len(label_map),
    dropout_rate=config['dropout_rate']).to(device)

criterion = nn.CrossEntropyLoss() # for multi-class classification
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'])

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=config['scheduler_params']['factor'],
    min_lr=config['scheduler_params']['min_lr'],
    patience=5)

for epoch in range(num_epochs):
  model.train()
  running_loss = 0.0

  for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} [Train]'):
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

    running_loss += loss.item()

  avg_train_loss = running_loss / len(train_loader)
  print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}')

  # evaluation phase
  model.eval()
  val_loss, correct, total = 0, 0, 0

  with torch.no_grad():
    for inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch + 1}/{num_epochs} [Valid]'):
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()

      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  accuracy = correct / total
  avg_val_loss = val_loss / len(val_loader)

  # update learning rate based on validation accuracy
  scheduler.step(accuracy)

  print(f'Validation Accuracy: {accuracy * 100:.2f}%')
  print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}')

  # store training history
  training_history.append({
    'epoch': epoch + 1,
    'train_loss': avg_train_loss,
    'val_loss': avg_val_loss,
    'acc': round(accuracy * 100, 2),  # store as percentage
    'lr': optimizer.param_groups[0]['lr']
    })

  # save best model
  if accuracy > best_accuracy:
    best_accuracy = accuracy
    torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'accuracy': accuracy, # saved as decimal
      'val_loss': avg_val_loss,
    }, '/content/drive/MyDrive/NLP/saved_models/best_model_100.pth')
    print(f'Saved new best model with accuracy: {best_accuracy * 100:.2f}%\n-----')

Epoch 1/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.78it/s]


Epoch [1/100], Loss: 4.6099


Epoch 1/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.86it/s]


Validation Accuracy: 1.42%
Epoch [1/100], Validation Loss: 4.6051
Saved new best model with accuracy: 1.42%
-----


Epoch 2/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.81it/s]


Epoch [2/100], Loss: 4.5894


Epoch 2/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.57it/s]


Validation Accuracy: 0.95%
Epoch [2/100], Validation Loss: 4.6032


Epoch 3/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.52it/s]


Epoch [3/100], Loss: 4.5630


Epoch 3/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.09it/s]


Validation Accuracy: 1.90%
Epoch [3/100], Validation Loss: 4.5933
Saved new best model with accuracy: 1.90%
-----


Epoch 4/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.52it/s]


Epoch [4/100], Loss: 4.5330


Epoch 4/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.55it/s]


Validation Accuracy: 2.37%
Epoch [4/100], Validation Loss: 4.5851
Saved new best model with accuracy: 2.37%
-----


Epoch 5/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.78it/s]


Epoch [5/100], Loss: 4.4793


Epoch 5/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 20.07it/s]


Validation Accuracy: 3.32%
Epoch [5/100], Validation Loss: 4.5525
Saved new best model with accuracy: 3.32%
-----


Epoch 6/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.72it/s]


Epoch [6/100], Loss: 4.3986


Epoch 6/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.64it/s]


Validation Accuracy: 2.84%
Epoch [6/100], Validation Loss: 4.4630


Epoch 7/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.62it/s]


Epoch [7/100], Loss: 4.2990


Epoch 7/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.18it/s]


Validation Accuracy: 2.84%
Epoch [7/100], Validation Loss: 4.3946


Epoch 8/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.81it/s]


Epoch [8/100], Loss: 4.1843


Epoch 8/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.38it/s]


Validation Accuracy: 3.79%
Epoch [8/100], Validation Loss: 4.3272
Saved new best model with accuracy: 3.79%
-----


Epoch 9/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.24it/s]


Epoch [9/100], Loss: 4.1140


Epoch 9/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.91it/s]


Validation Accuracy: 3.32%
Epoch [9/100], Validation Loss: 4.2996


Epoch 10/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.15it/s]


Epoch [10/100], Loss: 4.0470


Epoch 10/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.01it/s]


Validation Accuracy: 2.84%
Epoch [10/100], Validation Loss: 4.2817


Epoch 11/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.55it/s]


Epoch [11/100], Loss: 4.0119


Epoch 11/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.10it/s]


Validation Accuracy: 3.79%
Epoch [11/100], Validation Loss: 4.2419


Epoch 12/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.03it/s]


Epoch [12/100], Loss: 3.9612


Epoch 12/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.95it/s]


Validation Accuracy: 3.79%
Epoch [12/100], Validation Loss: 4.2297


Epoch 13/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.53it/s]


Epoch [13/100], Loss: 3.8954


Epoch 13/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.10it/s]


Validation Accuracy: 4.27%
Epoch [13/100], Validation Loss: 4.2076
Saved new best model with accuracy: 4.27%
-----


Epoch 14/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.68it/s]


Epoch [14/100], Loss: 3.8464


Epoch 14/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.25it/s]


Validation Accuracy: 3.79%
Epoch [14/100], Validation Loss: 4.1461


Epoch 15/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.92it/s]


Epoch [15/100], Loss: 3.7971


Epoch 15/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.04it/s]


Validation Accuracy: 4.27%
Epoch [15/100], Validation Loss: 4.1417


Epoch 16/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.72it/s]


Epoch [16/100], Loss: 3.7272


Epoch 16/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.42it/s]


Validation Accuracy: 4.74%
Epoch [16/100], Validation Loss: 4.0598
Saved new best model with accuracy: 4.74%
-----


Epoch 17/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.76it/s]


Epoch [17/100], Loss: 3.6657


Epoch 17/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.93it/s]


Validation Accuracy: 4.27%
Epoch [17/100], Validation Loss: 4.0130


Epoch 18/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.20it/s]


Epoch [18/100], Loss: 3.6195


Epoch 18/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.63it/s]


Validation Accuracy: 4.74%
Epoch [18/100], Validation Loss: 3.9768


Epoch 19/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 13.78it/s]


Epoch [19/100], Loss: 3.5441


Epoch 19/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.13it/s]


Validation Accuracy: 7.11%
Epoch [19/100], Validation Loss: 3.9218
Saved new best model with accuracy: 7.11%
-----


Epoch 20/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.33it/s]


Epoch [20/100], Loss: 3.4827


Epoch 20/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.94it/s]


Validation Accuracy: 6.64%
Epoch [20/100], Validation Loss: 3.9077


Epoch 21/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 14.33it/s]


Epoch [21/100], Loss: 3.4483


Epoch 21/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.49it/s]


Validation Accuracy: 8.06%
Epoch [21/100], Validation Loss: 3.8722
Saved new best model with accuracy: 8.06%
-----


Epoch 22/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.92it/s]


Epoch [22/100], Loss: 3.3483


Epoch 22/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.49it/s]


Validation Accuracy: 7.58%
Epoch [22/100], Validation Loss: 3.8390


Epoch 23/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 13.14it/s]


Epoch [23/100], Loss: 3.3063


Epoch 23/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.75it/s]


Validation Accuracy: 8.06%
Epoch [23/100], Validation Loss: 3.7866


Epoch 24/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 16.05it/s]


Epoch [24/100], Loss: 3.2617


Epoch 24/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.30it/s]


Validation Accuracy: 8.53%
Epoch [24/100], Validation Loss: 3.7661
Saved new best model with accuracy: 8.53%
-----


Epoch 25/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.96it/s]


Epoch [25/100], Loss: 3.1731


Epoch 25/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.42it/s]


Validation Accuracy: 8.06%
Epoch [25/100], Validation Loss: 3.7401


Epoch 26/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.44it/s]


Epoch [26/100], Loss: 3.1470


Epoch 26/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.17it/s]


Validation Accuracy: 8.06%
Epoch [26/100], Validation Loss: 3.7158


Epoch 27/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.84it/s]


Epoch [27/100], Loss: 3.0743


Epoch 27/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.75it/s]


Validation Accuracy: 9.00%
Epoch [27/100], Validation Loss: 3.6712
Saved new best model with accuracy: 9.00%
-----


Epoch 28/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.82it/s]


Epoch [28/100], Loss: 3.0379


Epoch 28/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.28it/s]


Validation Accuracy: 8.53%
Epoch [28/100], Validation Loss: 3.6886


Epoch 29/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.08it/s]


Epoch [29/100], Loss: 3.0067


Epoch 29/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.77it/s]


Validation Accuracy: 11.85%
Epoch [29/100], Validation Loss: 3.6408
Saved new best model with accuracy: 11.85%
-----


Epoch 30/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.89it/s]


Epoch [30/100], Loss: 2.9274


Epoch 30/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.58it/s]


Validation Accuracy: 10.43%
Epoch [30/100], Validation Loss: 3.6209


Epoch 31/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.78it/s]


Epoch [31/100], Loss: 2.9096


Epoch 31/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.77it/s]


Validation Accuracy: 8.53%
Epoch [31/100], Validation Loss: 3.5614


Epoch 32/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.05it/s]


Epoch [32/100], Loss: 2.8527


Epoch 32/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.59it/s]


Validation Accuracy: 11.85%
Epoch [32/100], Validation Loss: 3.5120


Epoch 33/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 13.73it/s]


Epoch [33/100], Loss: 2.7603


Epoch 33/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.12it/s]


Validation Accuracy: 10.90%
Epoch [33/100], Validation Loss: 3.5084


Epoch 34/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 16.01it/s]


Epoch [34/100], Loss: 2.7578


Epoch 34/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.60it/s]


Validation Accuracy: 11.85%
Epoch [34/100], Validation Loss: 3.4704


Epoch 35/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.38it/s]


Epoch [35/100], Loss: 2.6962


Epoch 35/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.55it/s]


Validation Accuracy: 12.32%
Epoch [35/100], Validation Loss: 3.4642
Saved new best model with accuracy: 12.32%
-----


Epoch 36/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.93it/s]


Epoch [36/100], Loss: 2.6431


Epoch 36/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.64it/s]


Validation Accuracy: 13.27%
Epoch [36/100], Validation Loss: 3.4501
Saved new best model with accuracy: 13.27%
-----


Epoch 37/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.95it/s]


Epoch [37/100], Loss: 2.6001


Epoch 37/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.80it/s]


Validation Accuracy: 13.74%
Epoch [37/100], Validation Loss: 3.4856
Saved new best model with accuracy: 13.74%
-----


Epoch 38/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.35it/s]


Epoch [38/100], Loss: 2.5855


Epoch 38/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.82it/s]


Validation Accuracy: 11.85%
Epoch [38/100], Validation Loss: 3.4605


Epoch 39/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.87it/s]


Epoch [39/100], Loss: 2.5178


Epoch 39/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.65it/s]


Validation Accuracy: 13.74%
Epoch [39/100], Validation Loss: 3.4367


Epoch 40/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.75it/s]


Epoch [40/100], Loss: 2.4502


Epoch 40/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.87it/s]


Validation Accuracy: 14.22%
Epoch [40/100], Validation Loss: 3.3919
Saved new best model with accuracy: 14.22%
-----


Epoch 41/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.28it/s]


Epoch [41/100], Loss: 2.4324


Epoch 41/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 20.10it/s]


Validation Accuracy: 13.27%
Epoch [41/100], Validation Loss: 3.3977


Epoch 42/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.98it/s]


Epoch [42/100], Loss: 2.4110


Epoch 42/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.80it/s]


Validation Accuracy: 14.69%
Epoch [42/100], Validation Loss: 3.3776
Saved new best model with accuracy: 14.69%
-----


Epoch 43/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.49it/s]


Epoch [43/100], Loss: 2.3170


Epoch 43/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.85it/s]


Validation Accuracy: 18.01%
Epoch [43/100], Validation Loss: 3.3363
Saved new best model with accuracy: 18.01%
-----


Epoch 44/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.54it/s]


Epoch [44/100], Loss: 2.2740


Epoch 44/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.49it/s]


Validation Accuracy: 15.64%
Epoch [44/100], Validation Loss: 3.3269


Epoch 45/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.52it/s]


Epoch [45/100], Loss: 2.2421


Epoch 45/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.47it/s]


Validation Accuracy: 17.06%
Epoch [45/100], Validation Loss: 3.2634


Epoch 46/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.43it/s]


Epoch [46/100], Loss: 2.2119


Epoch 46/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.71it/s]


Validation Accuracy: 16.11%
Epoch [46/100], Validation Loss: 3.3796


Epoch 47/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.55it/s]


Epoch [47/100], Loss: 2.1607


Epoch 47/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.71it/s]


Validation Accuracy: 17.54%
Epoch [47/100], Validation Loss: 3.3219


Epoch 48/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.84it/s]


Epoch [48/100], Loss: 2.1818


Epoch 48/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.34it/s]


Validation Accuracy: 18.01%
Epoch [48/100], Validation Loss: 3.2850


Epoch 49/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.41it/s]


Epoch [49/100], Loss: 2.1508


Epoch 49/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.82it/s]


Validation Accuracy: 19.91%
Epoch [49/100], Validation Loss: 3.3167
Saved new best model with accuracy: 19.91%
-----


Epoch 50/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.74it/s]


Epoch [50/100], Loss: 2.0532


Epoch 50/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.75it/s]


Validation Accuracy: 20.85%
Epoch [50/100], Validation Loss: 3.2707
Saved new best model with accuracy: 20.85%
-----


Epoch 51/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.82it/s]


Epoch [51/100], Loss: 2.0239


Epoch 51/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.51it/s]


Validation Accuracy: 19.91%
Epoch [51/100], Validation Loss: 3.1837


Epoch 52/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 13.58it/s]


Epoch [52/100], Loss: 1.9935


Epoch 52/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.86it/s]


Validation Accuracy: 19.43%
Epoch [52/100], Validation Loss: 3.2516


Epoch 53/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.77it/s]


Epoch [53/100], Loss: 2.0148


Epoch 53/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.75it/s]


Validation Accuracy: 20.38%
Epoch [53/100], Validation Loss: 3.2469


Epoch 54/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.82it/s]


Epoch [54/100], Loss: 1.9030


Epoch 54/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.41it/s]


Validation Accuracy: 20.38%
Epoch [54/100], Validation Loss: 3.2574


Epoch 55/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.20it/s]


Epoch [55/100], Loss: 1.8198


Epoch 55/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.90it/s]


Validation Accuracy: 18.96%
Epoch [55/100], Validation Loss: 3.2332


Epoch 56/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.88it/s]


Epoch [56/100], Loss: 1.8759


Epoch 56/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.52it/s]


Validation Accuracy: 20.38%
Epoch [56/100], Validation Loss: 3.2244


Epoch 57/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.95it/s]


Epoch [57/100], Loss: 1.7949


Epoch 57/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.60it/s]


Validation Accuracy: 22.75%
Epoch [57/100], Validation Loss: 3.1271
Saved new best model with accuracy: 22.75%
-----


Epoch 58/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 13.20it/s]


Epoch [58/100], Loss: 1.7083


Epoch 58/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.74it/s]


Validation Accuracy: 21.80%
Epoch [58/100], Validation Loss: 3.1099


Epoch 59/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.56it/s]


Epoch [59/100], Loss: 1.6396


Epoch 59/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.27it/s]


Validation Accuracy: 22.27%
Epoch [59/100], Validation Loss: 3.1462


Epoch 60/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.92it/s]


Epoch [60/100], Loss: 1.6532


Epoch 60/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.52it/s]


Validation Accuracy: 23.22%
Epoch [60/100], Validation Loss: 3.1211
Saved new best model with accuracy: 23.22%
-----


Epoch 61/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.03it/s]


Epoch [61/100], Loss: 1.6080


Epoch 61/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.46it/s]


Validation Accuracy: 21.80%
Epoch [61/100], Validation Loss: 3.1573


Epoch 62/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.80it/s]


Epoch [62/100], Loss: 1.5791


Epoch 62/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.32it/s]


Validation Accuracy: 24.64%
Epoch [62/100], Validation Loss: 3.1084
Saved new best model with accuracy: 24.64%
-----


Epoch 63/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.80it/s]


Epoch [63/100], Loss: 1.5454


Epoch 63/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.12it/s]


Validation Accuracy: 24.17%
Epoch [63/100], Validation Loss: 3.0828


Epoch 64/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 13.53it/s]


Epoch [64/100], Loss: 1.5327


Epoch 64/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.61it/s]


Validation Accuracy: 24.17%
Epoch [64/100], Validation Loss: 3.0299


Epoch 65/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.91it/s]


Epoch [65/100], Loss: 1.5133


Epoch 65/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.66it/s]


Validation Accuracy: 24.64%
Epoch [65/100], Validation Loss: 3.1057


Epoch 66/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.73it/s]


Epoch [66/100], Loss: 1.5213


Epoch 66/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.53it/s]


Validation Accuracy: 24.64%
Epoch [66/100], Validation Loss: 3.0788


Epoch 67/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.20it/s]


Epoch [67/100], Loss: 1.4337


Epoch 67/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.09it/s]


Validation Accuracy: 24.17%
Epoch [67/100], Validation Loss: 3.1046


Epoch 68/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.92it/s]


Epoch [68/100], Loss: 1.5140


Epoch 68/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.40it/s]


Validation Accuracy: 27.49%
Epoch [68/100], Validation Loss: 3.0396
Saved new best model with accuracy: 27.49%
-----


Epoch 69/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.75it/s]


Epoch [69/100], Loss: 1.4485


Epoch 69/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.93it/s]


Validation Accuracy: 24.64%
Epoch [69/100], Validation Loss: 3.1149


Epoch 70/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.28it/s]


Epoch [70/100], Loss: 1.4047


Epoch 70/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.67it/s]


Validation Accuracy: 24.64%
Epoch [70/100], Validation Loss: 3.0886


Epoch 71/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.85it/s]


Epoch [71/100], Loss: 1.4132


Epoch 71/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.41it/s]


Validation Accuracy: 23.22%
Epoch [71/100], Validation Loss: 3.0626


Epoch 72/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.52it/s]


Epoch [72/100], Loss: 1.3845


Epoch 72/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.76it/s]


Validation Accuracy: 26.54%
Epoch [72/100], Validation Loss: 3.0198


Epoch 73/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.38it/s]


Epoch [73/100], Loss: 1.4416


Epoch 73/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.73it/s]


Validation Accuracy: 24.17%
Epoch [73/100], Validation Loss: 3.0533


Epoch 74/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.87it/s]


Epoch [74/100], Loss: 1.3767


Epoch 74/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.31it/s]


Validation Accuracy: 30.33%
Epoch [74/100], Validation Loss: 3.0307
Saved new best model with accuracy: 30.33%
-----


Epoch 75/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.52it/s]


Epoch [75/100], Loss: 1.3448


Epoch 75/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.61it/s]


Validation Accuracy: 24.17%
Epoch [75/100], Validation Loss: 3.0547


Epoch 76/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.51it/s]


Epoch [76/100], Loss: 1.3226


Epoch 76/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.70it/s]


Validation Accuracy: 27.01%
Epoch [76/100], Validation Loss: 3.0839


Epoch 77/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.93it/s]


Epoch [77/100], Loss: 1.3292


Epoch 77/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.33it/s]


Validation Accuracy: 22.75%
Epoch [77/100], Validation Loss: 3.0139


Epoch 78/100 [Train]: 100%|██████████| 58/58 [00:04<00:00, 14.46it/s]


Epoch [78/100], Loss: 1.2648


Epoch 78/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.16it/s]


Validation Accuracy: 24.64%
Epoch [78/100], Validation Loss: 3.0763


Epoch 79/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.44it/s]


Epoch [79/100], Loss: 1.2655


Epoch 79/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.05it/s]


Validation Accuracy: 27.96%
Epoch [79/100], Validation Loss: 3.0893


Epoch 80/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.56it/s]


Epoch [80/100], Loss: 1.2742


Epoch 80/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.31it/s]


Validation Accuracy: 27.96%
Epoch [80/100], Validation Loss: 2.9864


Epoch 81/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.32it/s]


Epoch [81/100], Loss: 1.1794


Epoch 81/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.21it/s]


Validation Accuracy: 27.96%
Epoch [81/100], Validation Loss: 2.9954


Epoch 82/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.90it/s]


Epoch [82/100], Loss: 1.1388


Epoch 82/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.50it/s]


Validation Accuracy: 24.64%
Epoch [82/100], Validation Loss: 3.0354


Epoch 83/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.93it/s]


Epoch [83/100], Loss: 1.1380


Epoch 83/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.67it/s]


Validation Accuracy: 27.01%
Epoch [83/100], Validation Loss: 3.0621


Epoch 84/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.48it/s]


Epoch [84/100], Loss: 1.1203


Epoch 84/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 17.73it/s]


Validation Accuracy: 28.44%
Epoch [84/100], Validation Loss: 3.0326


Epoch 85/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.83it/s]


Epoch [85/100], Loss: 1.1790


Epoch 85/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.69it/s]


Validation Accuracy: 27.01%
Epoch [85/100], Validation Loss: 3.0932


Epoch 86/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.83it/s]


Epoch [86/100], Loss: 1.2040


Epoch 86/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.10it/s]


Validation Accuracy: 26.07%
Epoch [86/100], Validation Loss: 3.0506


Epoch 87/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.06it/s]


Epoch [87/100], Loss: 1.0802


Epoch 87/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 18.53it/s]


Validation Accuracy: 26.54%
Epoch [87/100], Validation Loss: 3.0362


Epoch 88/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.94it/s]


Epoch [88/100], Loss: 1.1025


Epoch 88/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.84it/s]


Validation Accuracy: 26.54%
Epoch [88/100], Validation Loss: 3.0206


Epoch 89/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.90it/s]


Epoch [89/100], Loss: 1.1253


Epoch 89/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.44it/s]


Validation Accuracy: 24.64%
Epoch [89/100], Validation Loss: 3.0417


Epoch 90/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.94it/s]


Epoch [90/100], Loss: 1.1144


Epoch 90/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.91it/s]


Validation Accuracy: 25.12%
Epoch [90/100], Validation Loss: 3.0182


Epoch 91/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.99it/s]


Epoch [91/100], Loss: 1.0880


Epoch 91/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 20.05it/s]


Validation Accuracy: 27.01%
Epoch [91/100], Validation Loss: 2.9886


Epoch 92/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.88it/s]


Epoch [92/100], Loss: 1.1013


Epoch 92/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.79it/s]


Validation Accuracy: 24.17%
Epoch [92/100], Validation Loss: 3.0681


Epoch 93/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.80it/s]


Epoch [93/100], Loss: 1.0633


Epoch 93/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.52it/s]


Validation Accuracy: 27.96%
Epoch [93/100], Validation Loss: 3.0057


Epoch 94/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.97it/s]


Epoch [94/100], Loss: 1.0545


Epoch 94/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.95it/s]


Validation Accuracy: 26.07%
Epoch [94/100], Validation Loss: 3.0328


Epoch 95/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 16.07it/s]


Epoch [95/100], Loss: 1.0441


Epoch 95/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 20.20it/s]


Validation Accuracy: 25.12%
Epoch [95/100], Validation Loss: 3.0785


Epoch 96/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 14.99it/s]


Epoch [96/100], Loss: 1.0559


Epoch 96/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.11it/s]


Validation Accuracy: 27.49%
Epoch [96/100], Validation Loss: 2.9991


Epoch 97/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 16.04it/s]


Epoch [97/100], Loss: 1.0521


Epoch 97/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.39it/s]


Validation Accuracy: 26.07%
Epoch [97/100], Validation Loss: 2.9932


Epoch 98/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.80it/s]


Epoch [98/100], Loss: 1.0228


Epoch 98/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.46it/s]


Validation Accuracy: 27.96%
Epoch [98/100], Validation Loss: 3.0018


Epoch 99/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.03it/s]


Epoch [99/100], Loss: 1.0546


Epoch 99/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.49it/s]


Validation Accuracy: 29.86%
Epoch [99/100], Validation Loss: 2.9755


Epoch 100/100 [Train]: 100%|██████████| 58/58 [00:03<00:00, 15.82it/s]


Epoch [100/100], Loss: 1.0696


Epoch 100/100 [Valid]: 100%|██████████| 14/14 [00:00<00:00, 19.24it/s]

Validation Accuracy: 28.91%
Epoch [100/100], Validation Loss: 3.0105





In [26]:
model.eval()
correct, total = 0, 0
with torch.no_grad():
  for inputs, labels in test_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
test_accuracy = correct / total  # Test accuracy
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')

Test Accuracy: 22.22%
