# LSTM Final Training and Evaluation

This notebook trains an LSTM classifier with the best hyperparameters found during tuning and evaluates it on the held-out test set.

**Task**: Multiclass fault classification (18 classes)

**Data Split**:
- Train: Model fitting
- Validation: Early stopping monitoring
- Test: Final evaluation (never seen during training)

**Modes**:
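- QUICK_MODE (enable with the environment variable `QUICK_MODE=true`): uses 1% of the training and validation simulation runs and trains for 1 epoch (for smoke-testing the pipeline)
- FULL MODE (default): uses all training and validation data and trains for up to 100 epochs with early stopping (patience 10)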

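**Outputs**:
- Trained model: `outputs/models/lstm_final[_quick].pt`
- Metrics: `outputs/metrics/lstm_metrics[_quick].json`
- Confusion matrix: `outputs/figures/lstm_confusion_matrix[_quick].png` (raw counts also saved to `outputs/metrics/lstm_confusion_matrix[_quick].csv`)
- Training history: `outputs/figures/lstm_training_history[_quick].png`
- Per-class F1 scores: `outputs/figures/lstm_per_class_f1[_quick].png`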

## Configuration

In [None]:
import os
import sys
import time
import json
import pickle
from pathlib import Path

start_time = time.time()
print("="*60)
print("LSTM Final Training and Evaluation")
print("="*60)
print(f"Started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Quick mode configuration: QUICK_MODE=true/1/yes (case-insensitive) in the
# environment switches to a tiny smoke-test run.
QUICK_MODE = os.getenv('QUICK_MODE', 'False').lower() in ('true', '1', 'yes')

if QUICK_MODE:
    TRAIN_FRACTION = 0.01   # fraction of simulation runs kept for train/val
    MAX_EPOCHS = 1
    PATIENCE = 1
    print("ðŸš€ QUICK MODE (1% data, 1 epoch)")
else:
    TRAIN_FRACTION = 1.0
    MAX_EPOCHS = 100
    PATIENCE = 10           # epochs without val-loss improvement before stopping
    print("ðŸ”¬ FULL MODE (100% data, up to 100 epochs)")

# Paths
DATA_DIR = Path('../data')
OUTPUT_DIR = Path('../outputs')
HYPERPARAM_DIR = OUTPUT_DIR / 'hyperparams'
MODEL_DIR = OUTPUT_DIR / 'models'
METRICS_DIR = OUTPUT_DIR / 'metrics'
FIGURES_DIR = OUTPUT_DIR / 'figures'

MODEL_DIR.mkdir(parents=True, exist_ok=True)
METRICS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_SEED = 42
MODE_SUFFIX = '_quick' if QUICK_MODE else ''

# Load best hyperparameters. The full tuning result is preferred in either
# mode; the quick tuning result is only a fallback.
full_hp_file = HYPERPARAM_DIR / 'lstm_best.json'
quick_hp_file = HYPERPARAM_DIR / 'lstm_best_quick.json'
if full_hp_file.exists():
    hp_file = full_hp_file
    print("Using FULL mode hyperparameters")
elif quick_hp_file.exists():
    hp_file = quick_hp_file
    print("Using QUICK mode hyperparameters")
    if not QUICK_MODE:
        # A FULL run on smoke-test hyperparameters is almost certainly unintended.
        print("WARNING: FULL mode is using hyperparameters from a QUICK tuning run")
else:
    raise FileNotFoundError(
        f"No tuned LSTM hyperparameters found: expected {full_hp_file} "
        f"or {quick_hp_file}. Run hyperparameter tuning first."
    )

with open(hp_file) as f:
    hp_data = json.load(f)
best_params = hp_data['best_params']

print("\nHyperparameters:")
for k, v in best_params.items():
    print(f"  {k}: {v}")
print("="*60)

## Imports

In [None]:
print("\n[Step 1/6] Loading libraries...")
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    classification_report, confusion_matrix, balanced_accuracy_score
)
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress all warnings globally to keep the notebook output compact.
warnings.filterwarnings('ignore')

# Seed the NumPy and PyTorch RNGs for reproducible runs.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"âœ“ Using device: {device}")

## Data Loading and Preprocessing

In [None]:
print("\n[Step 2/6] Loading datasets...")
data_load_start = time.time()

train = pd.read_csv(DATA_DIR / 'multiclass_train.csv')
val = pd.read_csv(DATA_DIR / 'multiclass_val.csv')
test = pd.read_csv(DATA_DIR / 'multiclass_test.csv')

print(f"âœ“ Train: {train.shape}")
print(f"âœ“ Val: {val.shape}")
print(f"âœ“ Test: {test.shape}")
print(f"âœ“ Data loading time: {time.time() - data_load_start:.2f}s")

# Feature columns: every column whose name contains 'xmeas' or 'xmv'
features = [col for col in train.columns if 'xmeas' in col or 'xmv' in col]
num_features = len(features)
print(f"âœ“ Number of features: {num_features}")


def _sample_runs_per_class(df, fraction, seed):
    """Keep a class-stratified sample of whole simulation runs from ``df``.

    Within each faultNumber, ``fraction`` of the (faultNumber, simulationRun)
    pairs is sampled, with at least one run per class, so no fault class
    disappears. Whole runs are kept (not random rows) to preserve sequences.
    """
    run_keys = ['faultNumber', 'simulationRun']
    runs = df[run_keys].drop_duplicates()
    sampled_runs = pd.concat([
        class_runs.sample(n=max(1, int(len(class_runs) * fraction)), random_state=seed)
        for _, class_runs in runs.groupby('faultNumber')
    ])
    return df.merge(sampled_runs, on=run_keys)


# Subsample if in quick mode. Sampling is stratified by fault class because a
# plain random 1% of runs can drop whole classes from train, after which
# label_encoder.transform raises on those labels in val/test.
if TRAIN_FRACTION < 1.0:
    train = _sample_runs_per_class(train, TRAIN_FRACTION, RANDOM_SEED)
    val = _sample_runs_per_class(val, TRAIN_FRACTION, RANDOM_SEED)

    print(f"âœ“ Subsampled train to {TRAIN_FRACTION*100:.1f}%: {train.shape}")
    print(f"âœ“ Subsampled val to {TRAIN_FRACTION*100:.1f}%: {val.shape}")

# Fit scaler on training data only
scaler = StandardScaler()
scaler.fit(train[features])

# Fit label encoder on training data only
label_encoder = LabelEncoder()
label_encoder.fit(train['faultNumber'])
num_classes = len(label_encoder.classes_)
class_names = [str(int(c)) for c in label_encoder.classes_]

# Fail early with a clear message instead of deep inside LabelEncoder.transform.
unseen_classes = np.setdiff1d(
    np.union1d(val['faultNumber'].unique(), test['faultNumber'].unique()),
    label_encoder.classes_,
)
if unseen_classes.size:
    raise ValueError(
        f"Val/test contain fault classes absent from the training data: {unseen_classes.tolist()}"
    )
print(f"âœ“ Number of classes: {num_classes}")

## Model and Dataset Definition

In [None]:
print("\n[Step 3/6] Defining model and dataset...")

class SimulationRunDataset(Dataset):
    """
    Dataset that creates windows WITHIN simulation runs only.
    No windows cross simulation run boundaries.

    Rows are grouped by (faultNumber, simulationRun) and ordered by 'sample'.
    Every run of length n contributes n - sequence_length + 1 windows of
    consecutive rows; a shorter run contributes none. A window's label is the
    encoded faultNumber of its last row.

    Args:
        df: DataFrame with 'faultNumber', 'simulationRun', 'sample' and the
            feature columns.
        features: feature column names, in model input order.
        scaler: fitted object whose .transform maps a (rows, features) array
            to scaled values.
        label_encoder: fitted object whose .transform maps fault numbers to
            class indices.
        sequence_length: number of consecutive samples per window.

    Attributes:
        windows: float32 array of shape (n_windows, sequence_length, len(features)).
        labels: int64 array of shape (n_windows,).
    """
    def __init__(self, df, features, scaler, label_encoder, sequence_length=10):
        self.seq_len = sequence_length
        window_chunks = []
        label_chunks = []

        for _, group in df.groupby(['faultNumber', 'simulationRun']):
            group = group.sort_values('sample')
            if len(group) < sequence_length:
                # Run too short to hold a single window.
                continue
            X = scaler.transform(group[features].values).astype(np.float32)
            y = label_encoder.transform(group['faultNumber'].values)

            # sliding_window_view over the time axis gives zero-copy views of
            # shape (n_windows, n_features, seq_len); transpose so that each
            # window is (seq_len, n_features), i.e. window i == X[i:i+seq_len].
            window_chunks.append(
                np.lib.stride_tricks.sliding_window_view(X, sequence_length, axis=0).transpose(0, 2, 1)
            )
            # Window i ends at row i + seq_len - 1, which supplies its label.
            label_chunks.append(y[sequence_length - 1:])

        if window_chunks:
            self.windows = np.ascontiguousarray(np.concatenate(window_chunks), dtype=np.float32)
            self.labels = np.concatenate(label_chunks).astype(np.int64)
        else:
            # Keep the 3-D window shape even when no run is long enough.
            self.windows = np.empty((0, sequence_length, len(features)), dtype=np.float32)
            self.labels = np.empty((0,), dtype=np.int64)
    
    def __len__(self):
        return len(self.windows)
    
    def __getitem__(self, idx):
        return torch.from_numpy(self.windows[idx]), torch.tensor(self.labels[idx])

class LSTMClassifier(nn.Module):
    """LSTM stack followed by dropout and a linear head on the last time step.

    Input is (batch, seq_len, input_size) because the LSTM uses
    batch_first=True; output is raw class logits of shape (batch, num_classes).
    """
    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.0):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers, so pass it only
        # when there is more than one layer.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, 
                           batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        # Classify from the hidden state of the final time step only.
        out = self.dropout(lstm_out[:, -1, :])
        return self.fc(out)

print("âœ“ Model and dataset classes defined")

In [None]:
print("\n[Step 4/6] Creating datasets...")
dataset_start = time.time()

# Window length and batch size both come from the tuned hyperparameters.
sequence_length, batch_size = best_params['sequence_length'], best_params['batch_size']

# One windowed dataset per split, all sharing the scaler/encoder fit on train.
train_dataset, val_dataset, test_dataset = (
    SimulationRunDataset(split_df, features, scaler, label_encoder, sequence_length)
    for split_df in (train, val, test)
)

print(f"âœ“ Train dataset: {len(train_dataset)} windows")
print(f"âœ“ Val dataset: {len(val_dataset)} windows")
print(f"âœ“ Test dataset: {len(test_dataset)} windows")

# Only the training loader shuffles; val/test keep a fixed order.
loader_kwargs = dict(batch_size=batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"âœ“ Dataset creation time: {time.time() - dataset_start:.2f}s")

## Model Training

In [None]:
print("\n[Step 5/6] Training final model with early stopping on validation set...")
train_start = time.time()

# Build model from the tuned architecture hyperparameters
model = LSTMClassifier(
    input_size=num_features,
    hidden_size=best_params['hidden_size'],
    num_layers=best_params['num_layers'],
    num_classes=num_classes,
    dropout=best_params['dropout']
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=best_params['learning_rate'])


def _run_epoch(loader, train_mode):
    """Make one pass over ``loader`` and return the mean per-sample loss.

    With ``train_mode=True`` the model is updated with ``optimizer``; otherwise
    gradients are disabled and dropout is off. Each batch's mean loss is
    weighted by its size, so a short final batch does not skew the average.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train(train_mode)
    total_loss = 0.0
    n_samples = 0
    with torch.set_grad_enabled(train_mode):
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            loss = criterion(model(X_batch), y_batch)
            if train_mode:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * y_batch.size(0)
            n_samples += y_batch.size(0)
    if n_samples == 0:
        raise ValueError("DataLoader is empty; cannot compute a loss")
    return total_loss / n_samples


# Training with early stopping on validation loss
best_val_loss = float('inf')
best_epoch = 0
patience_counter = 0
best_model_state = None
history = {'train_loss': [], 'val_loss': [], 'epoch': []}

print(f"Training for up to {MAX_EPOCHS} epochs with patience {PATIENCE}...")

for epoch in range(1, MAX_EPOCHS + 1):
    avg_train_loss = _run_epoch(train_loader, train_mode=True)
    avg_val_loss = _run_epoch(val_loader, train_mode=False)

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['epoch'].append(epoch)

    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch}/{MAX_EPOCHS}: Train Loss = {avg_train_loss:.6f}, Val Loss = {avg_val_loss:.6f}")

    # Early stopping on validation loss (strict improvement required)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        patience_counter = 0
        # Snapshot on CPU so the copy does not hold extra device memory.
        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore best model
if best_model_state is not None:
    model.load_state_dict({k: v.to(device) for k, v in best_model_state.items()})
else:
    # Val loss never improved on +inf (e.g. it was NaN every epoch).
    best_epoch = len(history['epoch'])
    print("WARNING: validation loss never improved; keeping final-epoch weights")

train_time = time.time() - train_start
print(f"\nâœ“ Training complete in {train_time:.2f}s ({len(history['epoch'])} epochs)")
print(f"âœ“ Best epoch: {best_epoch} (val_loss = {best_val_loss:.6f})")

## Evaluation on Test Set

In [None]:
print("\n[Step 6/6] Evaluating on test set...")

model.eval()
predicted_ids = []
true_ids = []

# Batched inference without autograd; the prediction is the argmax class logit.
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        logits = model(X_batch.to(device))
        predicted_ids.extend(logits.argmax(dim=1).cpu().numpy())
        true_ids.extend(y_batch.numpy())

y_test = np.array(true_ids)
y_pred = np.array(predicted_ids)

# Aggregate test-set metrics ('weighted' averages per-class scores by class support)
accuracy = accuracy_score(y_test, y_pred)
balanced_acc = balanced_accuracy_score(y_test, y_pred)
f1_weighted = f1_score(y_test, y_pred, average='weighted')
f1_macro = f1_score(y_test, y_pred, average='macro')
precision_weighted = precision_score(y_test, y_pred, average='weighted')
recall_weighted = recall_score(y_test, y_pred, average='weighted')

separator = '=' * 60
print(f"\n{separator}")
print(f"TEST SET RESULTS {'(QUICK MODE)' if QUICK_MODE else ''}")
print(separator)
print(f"Accuracy:          {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Balanced Accuracy: {balanced_acc:.4f} ({balanced_acc*100:.2f}%)")
print(f"F1 (weighted):     {f1_weighted:.4f}")
print(f"F1 (macro):        {f1_macro:.4f}")
print(f"Precision (weighted): {precision_weighted:.4f}")
print(f"Recall (weighted):    {recall_weighted:.4f}")
print(separator)

In [None]:
print("\nPer-Class Classification Report:")
# Pass the full label range explicitly. Without it sklearn reports only the
# labels present in y_test or y_pred, and target_names (one per encoded class)
# then has the wrong length, which raises ValueError.
print(classification_report(
    y_test, y_pred,
    labels=np.arange(num_classes),
    target_names=class_names,
    digits=4,
    zero_division=0,
))

## Visualizations

In [None]:
# Training history: per-epoch train/val loss with the lowest-val-loss epoch marked.
fig, ax = plt.subplots(figsize=(10, 5))
loss_curves = (('train_loss', 'b-', 'Train Loss'), ('val_loss', 'r-', 'Val Loss'))
for history_key, line_fmt, curve_label in loss_curves:
    ax.plot(history['epoch'], history[history_key], line_fmt, linewidth=2, label=curve_label)
ax.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch})')

quick_tag = " - QUICK" if QUICK_MODE else ""
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'LSTM Training History{quick_tag}')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
history_fig_path = FIGURES_DIR / f'lstm_training_history{MODE_SUFFIX}.png'
plt.savefig(history_fig_path, dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Confusion Matrix over the full label range, so it is always
# num_classes x num_classes and lines up with class_names even when a class is
# absent from both y_test and y_pred.
cm = confusion_matrix(y_test, y_pred, labels=np.arange(num_classes))

# Row-normalise (fraction of each true class); rows with no true samples stay
# 0 instead of becoming NaN from a 0/0 division.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm, row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

fig, axes = plt.subplots(1, 2, figsize=(20, 8))

sns.heatmap(cm, annot=False, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names, ax=axes[0])
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('Actual')
axes[0].set_title(f'LSTM Confusion Matrix (Counts){" - QUICK" if QUICK_MODE else ""}')

sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=axes[1])
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('Actual')
axes[1].set_title(f'LSTM Confusion Matrix (Normalized){" - QUICK" if QUICK_MODE else ""}')

plt.tight_layout()
plt.savefig(FIGURES_DIR / f'lstm_confusion_matrix{MODE_SUFFIX}.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"âœ“ Saved confusion matrix to {FIGURES_DIR / f'lstm_confusion_matrix{MODE_SUFFIX}.png'}")

In [None]:
# Per-class F1 scores: one score per encoded class 0..num_classes-1, so the bars
# (and the per-class dict saved later) align with class_names even when a class
# is absent from the test predictions and labels. Undefined scores count as 0.
f1_per_class = f1_score(
    y_test, y_pred,
    labels=np.arange(num_classes),
    average=None,
    zero_division=0,
)

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(class_names, f1_per_class, color='steelblue', edgecolor='black')
ax.axhline(y=f1_weighted, color='red', linestyle='--', label=f'Weighted Avg: {f1_weighted:.4f}')
ax.set_xlabel('Fault Class')
ax.set_ylabel('F1 Score')
ax.set_title(f'LSTM Per-Class F1 Scores{" - QUICK" if QUICK_MODE else ""}')
ax.set_ylim(0, 1.05)
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Annotate each bar with its score just above the bar top.
for bar, f1 in zip(bars, f1_per_class):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
            f'{f1:.3f}', ha='center', va='bottom', fontsize=8, rotation=90)

plt.tight_layout()
plt.savefig(FIGURES_DIR / f'lstm_per_class_f1{MODE_SUFFIX}.png', dpi=150, bbox_inches='tight')
plt.show()

## Save Results

In [None]:
# Wall-clock time since start_time was recorded in the configuration cell.
end_time = time.time()
total_runtime = end_time - start_time

metrics_path = METRICS_DIR / f'lstm_metrics{MODE_SUFFIX}.json'
model_path = MODEL_DIR / f'lstm_final{MODE_SUFFIX}.pt'
cm_csv_path = METRICS_DIR / f'lstm_confusion_matrix{MODE_SUFFIX}.csv'

# Run summary; the key order here is the field order in the JSON file.
metrics = {
    'model': 'LSTM',
    'task': 'multiclass',
    'quick_mode': QUICK_MODE,
    'train_fraction': TRAIN_FRACTION,
    'train_samples': len(train_dataset),
    'val_samples': len(val_dataset),
    'test_samples': len(test_dataset),
    'best_epoch': best_epoch,
    'best_val_loss': float(best_val_loss),
    'accuracy': float(accuracy),
    'balanced_accuracy': float(balanced_acc),
    'f1_weighted': float(f1_weighted),
    'f1_macro': float(f1_macro),
    'precision_weighted': float(precision_weighted),
    'recall_weighted': float(recall_weighted),
    'per_class_f1': {name: float(f1_per_class[i]) for i, name in enumerate(class_names)},
    'hyperparameters': best_params,
    'epochs_trained': len(history['epoch']),
    'training_time_seconds': float(train_time),
    'total_runtime_seconds': float(total_runtime),
    'random_seed': RANDOM_SEED
}

with open(metrics_path, 'w') as fh:
    json.dump(metrics, fh, indent=2)
print(f"âœ“ Saved metrics to {metrics_path}")

# Checkpoint: weights plus the architecture and preprocessing state
# (scaler statistics, class order, feature order) needed to reuse the model.
model_config = {
    'input_size': num_features,
    'hidden_size': best_params['hidden_size'],
    'num_layers': best_params['num_layers'],
    'num_classes': num_classes,
    'dropout': best_params['dropout'],
    'sequence_length': sequence_length
}
checkpoint = {
    'model_state_dict': model.state_dict(),
    'model_config': model_config,
    'scaler_mean': scaler.mean_.tolist(),
    'scaler_scale': scaler.scale_.tolist(),
    'label_encoder_classes': label_encoder.classes_.tolist(),
    'features': features
}
torch.save(checkpoint, model_path)
print(f"âœ“ Saved model to {model_path}")

# Raw confusion-matrix counts, labelled by class name on both axes.
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
cm_df.to_csv(cm_csv_path)
print(f"âœ“ Saved confusion matrix to {cm_csv_path}")

separator = '=' * 60
runtime_minutes, runtime_seconds = divmod(total_runtime, 60)
print(f"\n{separator}")
print(f"âœ“ LSTM Final Training Complete! {'(QUICK MODE)' if QUICK_MODE else ''}")
print(separator)
print(f"Total runtime: {int(runtime_minutes)}m {int(runtime_seconds)}s")
print(f"Best epoch: {best_epoch}")
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test F1 (weighted): {f1_weighted:.4f}")
print(separator)