# Training LSTM Model

In [6]:
#Import modules (LSTM model from PyTorch)
import os
import json
import numpy as np
import torch
import torch.nn as nn
from Models import LSTMModel
from torch.utils.data import DataLoader, Dataset, random_split

In [7]:
# Parameters & filepaths
SEQUENCE_LENGTH = 5
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
LOGS_FOLDER = "./training_logs"
MODEL_SAVE_PATH = "./trained_models/lstm_model.pth"

In [8]:
# --- Prep Data ---
class CommunicationDataset(Dataset):
    def __init__(self, logs_folder, sequence_length):
        self.sequence_length = sequence_length
        self.samples = []

        all_entries = []
        for filename in os.listdir(logs_folder):
            if filename.endswith('.json'):
                with open(os.path.join(logs_folder, filename), 'r') as f:
                    data = json.load(f)
                    for entry in data:
                        all_entries.append(entry)

        # Group by robot_id to create sequences
        from collections import defaultdict
        robot_logs = defaultdict(list)
        for entry in all_entries:
            robot_logs[entry['robot_id']].append(entry)

        for robot_id, logs in robot_logs.items():
            logs = sorted(logs, key=lambda x: x['timestamp'])
            if len(logs) >= sequence_length:
                for i in range(len(logs) - sequence_length + 1):
                    seq = logs[i:i+sequence_length]
                    feature_seq = []
                    label_seq = []
                    for entry in seq:
                        pos = np.array(entry["position"])
                        num_tokens = entry["num_reported_tokens"]
                        avg_distance = np.mean([
                            np.linalg.norm(pos - np.array(token))
                            for token in entry["reported_tokens"]
                        ]) if entry["reported_tokens"] else 0
                        feature_seq.append(list(pos) + [num_tokens, avg_distance])
                        label_seq.append(entry["is_byzantine"])
                    self.samples.append((np.array(feature_seq), label_seq[-1]))  # Predict last label

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

In [None]:
# Train LSTM & Save at the end of the run
dataset = CommunicationDataset(LOGS_FOLDER, SEQUENCE_LENGTH)
print(f"Loaded {len(dataset)} sequences.")

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

model = LSTMModel(input_size=4, hidden_size=64, num_layers=1)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    model.train()
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.4f}")

# Evaluation on validation data
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        outputs = model(x_batch)
        _, predicted = torch.max(outputs.data, 1)
        total += y_batch.size(0)
        correct += (predicted == y_batch).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

os.makedirs("trained_models", exist_ok=True)
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")

Loaded 89100 sequences.
Epoch 1/10, Loss: 0.0000
Epoch 2/10, Loss: 0.0000
Epoch 3/10, Loss: 0.0000
Epoch 4/10, Loss: 0.0000
Epoch 5/10, Loss: 0.0000
Epoch 6/10, Loss: 0.0000
Epoch 7/10, Loss: 0.0000
Epoch 8/10, Loss: 0.0000
Epoch 9/10, Loss: 0.0000
Epoch 10/10, Loss: 0.0000
Test Accuracy: 100.00%
Model saved to lstm_model.pth
