In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from sklearn.metrics import accuracy_score
from torch.utils.tensorboard import SummaryWriter

# Define LSTM model
class LSTMModel(nn.Module):
    """Three stacked LSTM layers feeding a two-layer fully connected head.

    All LSTM layers are built with ``batch_first=True``. The head maps the
    third LSTM's hidden size to 64 units (with ReLU) and then to
    ``output_size`` raw scores; no softmax is applied here.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, hidden_size3, output_size):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they are fixed.
        self.lstm1 = nn.LSTM(input_size, hidden_size1, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_size1, hidden_size2, batch_first=True)
        self.lstm3 = nn.LSTM(hidden_size2, hidden_size3, batch_first=True)
        self.fc1 = nn.Linear(hidden_size3, 64)
        self.fc2 = nn.Linear(64, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the LSTM stack and the classifier head.

        When the LSTM output is 3-dimensional, only the last index along
        dimension 1 is passed to the head; otherwise the head is applied to
        the LSTM output as-is.
        """
        for recurrent in (self.lstm1, self.lstm2, self.lstm3):
            # Each LSTM returns (output, state); the state is not needed.
            x, _ = recurrent(x)

        if x.dim() == 3:
            x = x[:, -1, :]

        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Paths and parameters
DATA_PATH = 'processed_data'  # directory holding the X_*/y_*/labels .npy files
MODEL_PATH = 'models/sign_lstm.pth'  # where the trained state dict is written
# Features per time step: 21*3*2 + 33*3 = 225.
# NOTE(review): presumably 2 hands x 21 landmarks x 3 coords plus 33 pose
# landmarks x 3 coords -- confirm against the preprocessing code.
INPUT_SIZE = 225  # 21*3*2 + 33*3
HIDDEN_SIZE1 = 128  # hidden size of the first LSTM layer
HIDDEN_SIZE2 = 64  # hidden size of the second LSTM layer
HIDDEN_SIZE3 = 32  # hidden size of the third LSTM layer
OUTPUT_SIZE = 6  # number of sign classes; load_data() checks labels.npy against it
EPOCHS = 25000  # upper bound on training epochs (training may stop earlier)
BATCH_SIZE = 32  # mini-batch size for the training DataLoader
LEARNING_RATE = 0.001  # initial learning rate passed to Adam
ACCURACY_THRESHOLD = 0.95  # training accuracy at which training stops early

def load_data():
    """Load the train/test arrays and label names from ``DATA_PATH``.

    Reads ``X_train.npy``, ``y_train.npy``, ``X_test.npy``, ``y_test.npy``
    and ``labels.npy`` (in that order).

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test, labels)`` where the
        feature tensors are ``torch.float32``, the target tensors are
        ``torch.long``, and ``labels`` is the array as loaded by NumPy.

    Raises:
        ValueError: If the number of entries in ``labels.npy`` differs from
            ``OUTPUT_SIZE``.
    """
    def _read(filename):
        return np.load(os.path.join(DATA_PATH, filename))

    # Tuple elements are evaluated left to right, so the files are read in
    # the same order as they are named here.
    features_train, targets_train = _read('X_train.npy'), _read('y_train.npy')
    features_test, targets_test = _read('X_test.npy'), _read('y_test.npy')
    labels = _read('labels.npy')

    found = len(labels)
    if found != OUTPUT_SIZE:
        print(f"Error: Expected {OUTPUT_SIZE} signs, found {found}")
        raise ValueError("Incorrect number of signs")

    return (
        torch.tensor(features_train, dtype=torch.float32),
        torch.tensor(targets_train, dtype=torch.long),
        torch.tensor(features_test, dtype=torch.float32),
        torch.tensor(targets_test, dtype=torch.long),
        labels,
    )

def _save_model(model, path):
    """Write ``model``'s state dict to ``path``, creating its parent directory."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)


def train_model():
    """Train ``LSTMModel`` on the training split and save it to ``MODEL_PATH``.

    Runs up to ``EPOCHS`` epochs with Adam, cross-entropy loss and a StepLR
    schedule, logging per-epoch loss, accuracy and learning rate to
    TensorBoard. Training stops early once the epoch's *training* accuracy
    reaches ``ACCURACY_THRESHOLD``.

    The current weights are written to ``MODEL_PATH`` exactly once per run:
    when the threshold is reached, on ``KeyboardInterrupt``, or otherwise
    when the loop ends or is aborted by an error. An existing file at
    ``MODEL_PATH`` is overwritten, so a stale checkpoint from an earlier run
    can never be mistaken for this run's result.

    Raises:
        ValueError: If the training set is empty (or if ``load_data``
            rejects the label file).
    """
    # Only the training split is consumed here; the test split and label
    # names returned by load_data() are not used during training.
    X_train, y_train, _X_test, _y_test, _labels = load_data()
    if len(X_train) == 0:
        # Without this guard the per-epoch averages would divide by zero.
        raise ValueError("Training set is empty")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = LSTMModel(INPUT_SIZE, HIDDEN_SIZE1, HIDDEN_SIZE2, HIDDEN_SIZE3, OUTPUT_SIZE).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)  # Decay LR by 0.9 every 50 epochs

    train_data = torch.utils.data.TensorDataset(X_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    writer = SummaryWriter()
    saved = False  # set once this run has written MODEL_PATH

    try:
        for epoch in range(EPOCHS):
            model.train()
            running_loss = 0.0
            correct_predictions = 0
            total_samples = 0

            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                outputs = model(X_batch)

                # Targets with more than one dimension are reduced to class
                # indices, which is what CrossEntropyLoss is given below.
                if y_batch.dim() > 1:
                    y_batch = torch.argmax(y_batch, dim=1)

                loss = criterion(outputs, y_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                correct_predictions += (preds == y_batch).sum().item()
                total_samples += y_batch.size(0)

            # Loss is averaged over batches; accuracy over samples.
            epoch_loss = running_loss / len(train_loader)
            epoch_accuracy = correct_predictions / total_samples

            writer.add_scalar("Loss/Train", epoch_loss, epoch)
            writer.add_scalar("Accuracy/Train", epoch_accuracy, epoch)
            writer.add_scalar("Learning Rate", optimizer.param_groups[0]['lr'], epoch)

            print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

            scheduler.step()

            # Early termination based on accuracy threshold
            if epoch_accuracy >= ACCURACY_THRESHOLD:
                print(f"Accuracy {epoch_accuracy:.4f} reached threshold {ACCURACY_THRESHOLD}. Terminating training.")
                _save_model(model, MODEL_PATH)
                saved = True
                print(f"Model saved to {MODEL_PATH}")
                break

    except KeyboardInterrupt:
        print("\nKeyboard interrupt detected. Saving model before termination...")
        _save_model(model, MODEL_PATH)
        saved = True
        print(f"Model saved to {MODEL_PATH} at interruption.")
    finally:
        try:
            # Decide by what this run did, not by whether a file already
            # exists: a leftover checkpoint must not suppress the save.
            if not saved:
                _save_model(model, MODEL_PATH)
                print(f"Training interrupted or completed without threshold. Model saved to {MODEL_PATH}")
        finally:
            # Close the TensorBoard writer even if the save above fails.
            writer.close()

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_model()

Epoch 1/25000, Loss: 1.7948, Accuracy: 0.2083
Epoch 2/25000, Loss: 1.7715, Accuracy: 0.2083
Epoch 3/25000, Loss: 1.7608, Accuracy: 0.2083
Epoch 4/25000, Loss: 1.7440, Accuracy: 0.3542
Epoch 5/25000, Loss: 1.7109, Accuracy: 0.3542
Epoch 6/25000, Loss: 1.6768, Accuracy: 0.4375
Epoch 7/25000, Loss: 1.6406, Accuracy: 0.5000
Epoch 8/25000, Loss: 1.6073, Accuracy: 0.3750
Epoch 9/25000, Loss: 1.5672, Accuracy: 0.3750
Epoch 10/25000, Loss: 1.5335, Accuracy: 0.3750
Epoch 11/25000, Loss: 1.4791, Accuracy: 0.4375
Epoch 12/25000, Loss: 1.4233, Accuracy: 0.5208
Epoch 13/25000, Loss: 1.3893, Accuracy: 0.5000
Epoch 14/25000, Loss: 1.3923, Accuracy: 0.4375
Epoch 15/25000, Loss: 1.2717, Accuracy: 0.4792
Epoch 16/25000, Loss: 1.2414, Accuracy: 0.6667
Epoch 17/25000, Loss: 1.1721, Accuracy: 0.7292
Epoch 18/25000, Loss: 1.0808, Accuracy: 0.7083
Epoch 19/25000, Loss: 1.0662, Accuracy: 0.8125
Epoch 20/25000, Loss: 1.0262, Accuracy: 0.7292
Epoch 21/25000, Loss: 0.9672, Accuracy: 0.6667
Epoch 22/25000, Loss: 