# PyTorch Geometric Temporal Model (GConvGRU)

This notebook trains a temporal graph neural network using GConvGRU for engine fault classification.

## 1. Setup and Data Loading

In [None]:
# Imports
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric_temporal.nn.recurrent import GConvGRU
from torch_geometric.nn import global_mean_pool
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Seed PyTorch's global RNG (Python's `random` and NumPy are not seeded in this cell)
SEED = 42
torch.manual_seed(SEED)

print("Libraries imported successfully.")

Libraries imported successfully.


In [3]:
# Load datasets
data_dir = Path("../results/datasets")
train_file = data_dir / "train_pyg_temporal.pt"
val_file = data_dir / "val_pyg_temporal.pt"
test_file = data_dir / "test_pyg_temporal.pt"

# weights_only=False goes through the full pickle loader, so only point this at trusted files.
train_data = torch.load(train_file, weights_only=False)
val_data = torch.load(val_file, weights_only=False)
test_data = torch.load(test_file, weights_only=False)

# Each loaded object is indexed with 'samples' to reach its sequence of samples.
print(f"Train samples: {len(train_data['samples'])}")
print(f"Val samples: {len(val_data['samples'])}")
print(f"Test samples: {len(test_data['samples'])}")

# Inspect the first training sample; later cells read the model dimensions from it.
sample = train_data['samples'][0]
print("\nSample structure:")
print(f"  x shape: {sample.x.shape}  # [num_nodes, 1, window_size]")
print(f"  edge_index shape: {sample.edge_index.shape}  # [2, num_edges]")
print(f"  y: {sample.y}  (dtype: {sample.y.dtype})")

Train samples: 171443
Val samples: 37658
Test samples: 37658

Sample structure:
  x shape: torch.Size([20, 1, 10])  # [num_nodes, 1, window_size]
  edge_index shape: torch.Size([2, 112])  # [2, num_edges]
  y: 1  (dtype: torch.int64)


In [10]:
# Configuration
# Input dimensions are read off the inspected training sample:
# 20 sensors (nodes), 1 feature per node, 10 timesteps per window.
NUM_NODES, NUM_FEATURES, WINDOW_SIZE = sample.x.shape

# Training hyperparameters
BATCH_SIZE = 4 * 2048
NUM_CLASSES = 5  # 5 fault types
HIDDEN_DIM = 64
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4

config_lines = [
    "Configuration:",
    f"  Batch size: {BATCH_SIZE}",
    f"  Num nodes: {NUM_NODES}",
    f"  Num features: {NUM_FEATURES}",
    f"  Window size: {WINDOW_SIZE}",
    f"  Hidden dim: {HIDDEN_DIM}",
    f"  Num classes: {NUM_CLASSES}",
]
print("\n".join(config_lines))

Configuration:
  Batch size: 8192
  Num nodes: 20
  Num features: 1
  Window size: 10
  Hidden dim: 64
  Num classes: 5


In [11]:
# Create DataLoaders
# Only the training split is shuffled; validation and test are iterated in stored order.
train_loader = DataLoader(train_data['samples'], batch_size=BATCH_SIZE, shuffle=True)

eval_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_data['samples'], **eval_loader_kwargs)
test_loader = DataLoader(test_data['samples'], **eval_loader_kwargs)

loader_summary = [
    "DataLoaders created:",
    f"  Train batches: {len(train_loader)}",
    f"  Val batches: {len(val_loader)}",
    f"  Test batches: {len(test_loader)}",
]
print("\n".join(loader_summary))

DataLoaders created:
  Train batches: 21
  Val batches: 5
  Test batches: 5


## 2. Model Definition

In [12]:
class TemporalGNN(nn.Module):
    """Graph-level classifier built on a single GConvGRU recurrent layer.

    The input window is consumed one timestep at a time: at each step the node
    features for that timestep and the previous hidden state are passed through
    the GConvGRU cell. The hidden state after the last step is mean-pooled per
    graph, passed through dropout and mapped to class logits by a linear layer.

    Args:
        num_features: Number of input channels per node at each timestep.
        hidden_dim: Size of the recurrent hidden state (and pooled embedding).
        num_classes: Number of output classes.
        dropout: Dropout probability applied to the pooled graph embedding.
        K: Chebyshev filter size forwarded to ``GConvGRU``.
    """

    def __init__(self, num_features, hidden_dim=64, num_classes=5, dropout=0.3, K=2):
        super().__init__()

        self.hidden_dim = hidden_dim

        # Temporal layer
        self.temporal = GConvGRU(
            in_channels=num_features,
            out_channels=hidden_dim,
            K=K
        )

        # Classification head
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return class logits for every graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``. ``x`` is
                indexed as ``x[:, :, t]``, i.e. its last axis is time.

        Returns:
            Tensor of logits with one row per graph.

        Raises:
            ValueError: If ``x`` has no timesteps, in which case no hidden
                state would be produced to pool.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        time_steps = x.size(2)
        if time_steps == 0:
            raise ValueError("TemporalGNN.forward received an input with zero timesteps")

        # Process temporal dimension; H=None on the first step leaves the
        # initial hidden state to GConvGRU.
        h = None
        for t in range(time_steps):
            x_t = x[:, :, t]  # [total_nodes, num_features]
            h = self.temporal(x_t, edge_index, H=h)

        # Pool each graph separately
        h_graph = global_mean_pool(h, batch)  # [batch_size, hidden_dim]

        # Classification
        h_graph = self.dropout(h_graph)
        logits = self.fc(h_graph)

        return logits

print("Model class defined.")

Model class defined.


## 3. Training and Evaluation Functions

In [13]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch.

    Args:
        model: Module called as ``model(batch)`` to produce logits.
        loader: Iterable of batches; each batch must support ``.to(device)``
            and expose integer labels as ``batch.y``.
        criterion: Loss called as ``criterion(logits, batch.y)``. It is assumed
            to return the mean over the batch (the ``nn.CrossEntropyLoss``
            default), which is what the per-sample weighting below relies on.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, f1)``: per-sample mean loss over the epoch,
        accuracy, and weighted F1 of the predictions made during training.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc='Training', leave=False)
    for batch in pbar:
        batch = batch.to(device)

        optimizer.zero_grad()
        logits = model(batch)
        loss = criterion(logits, batch.y)

        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a smaller final batch does not
        # count as much as a full one in the epoch average.
        batch_size = batch.y.size(0)
        batch_loss = loss.item()
        loss_sum += batch_loss * batch_size
        num_samples += batch_size

        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(batch.y.cpu().numpy())

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("train_epoch received a loader with no samples")

    avg_loss = loss_sum / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')

    return avg_loss, accuracy, f1


def evaluate(model, loader, criterion, device):
    """Evaluate the model without updating its parameters.

    Args:
        model: Module called as ``model(batch)`` to produce logits.
        loader: Iterable of batches; each batch must support ``.to(device)``
            and expose integer labels as ``batch.y``.
        criterion: Loss called as ``criterion(logits, batch.y)``. It is assumed
            to return the mean over the batch (the ``nn.CrossEntropyLoss``
            default), which is what the per-sample weighting below relies on.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, f1, all_preds, all_labels)``: per-sample
        mean loss, accuracy, weighted F1, and the flat lists of predicted and
        true labels in loader order.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating', leave=False):
            batch = batch.to(device)

            logits = model(batch)
            loss = criterion(logits, batch.y)

            # Weight each batch mean by its size so the (typically smaller)
            # last batch does not skew the reported loss.
            batch_size = batch.y.size(0)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size

            preds = logits.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch.y.cpu().numpy())

    if num_samples == 0:
        raise ValueError("evaluate received a loader with no samples")

    avg_loss = loss_sum / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')

    return avg_loss, accuracy, f1, all_preds, all_labels

print("Training and evaluation functions defined.")

Training and evaluation functions defined.


## 4. Training

In [14]:
# Initialize model
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

model = TemporalGNN(num_features=NUM_FEATURES, hidden_dim=HIDDEN_DIM, num_classes=NUM_CLASSES)
model = model.to(device)

# Loss and optimizer used by the training loop
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

num_parameters = sum(param.numel() for param in model.parameters())
print(f"Device: {device}")
print(f"Model parameters: {num_parameters:,}")

Device: cuda
Model parameters: 25,669


In [15]:
# Training loop
# Start below any reachable F1 so the first epoch always counts as an
# improvement and a checkpoint is guaranteed to exist for the evaluation cell,
# even if validation F1 stays at 0.
best_val_f1 = float("-inf")
patience = 5
patience_counter = 0
checkpoint_path = "../results/best_pyg_temporal_model.pt"

history = {
    'train_loss': [], 'train_acc': [], 'train_f1': [],
    'val_loss': [], 'val_acc': [], 'val_f1': []
}

start_time = time.time()

print("Starting training...")
print("=" * 70)

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()

    # Train
    train_loss, train_acc, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc, val_f1, _, _ = evaluate(model, val_loader, criterion, device)

    epoch_time = time.time() - epoch_start

    # Store history
    epoch_metrics = {
        'train_loss': train_loss, 'train_acc': train_acc, 'train_f1': train_f1,
        'val_loss': val_loss, 'val_acc': val_acc, 'val_f1': val_f1,
    }
    for metric_name, metric_value in epoch_metrics.items():
        history[metric_name].append(metric_value)

    # Print progress
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
    print(f"  Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
    print(f"  Epoch time: {epoch_time:.1f}s")

    # Early stopping
    # NOTE(review): patience is counted on validation F1 only (strict improvement);
    # validation loss is not consulted when deciding to stop.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        print(f"  ✓ Best model saved (F1: {best_val_f1:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    print("-" * 70)

elapsed_time = time.time() - start_time
print(f"\nTraining completed in {elapsed_time/60:.2f} minutes")

Starting training...


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 1/50
  Train Loss: 1.6617 | Acc: 0.1599 | F1: 0.1074
  Val   Loss: 1.6145 | Acc: 0.2235 | F1: 0.0816
  Epoch time: 34.0s
  ✓ Best model saved (F1: 0.0816)
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 2/50
  Train Loss: 1.6443 | Acc: 0.1717 | F1: 0.1335
  Val   Loss: 1.6011 | Acc: 0.2235 | F1: 0.0816
  Epoch time: 36.7s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 3/50
  Train Loss: 1.6268 | Acc: 0.1888 | F1: 0.1637
  Val   Loss: 1.5848 | Acc: 0.2235 | F1: 0.0816
  Epoch time: 34.4s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 4/50
  Train Loss: 1.6107 | Acc: 0.2143 | F1: 0.1944
  Val   Loss: 1.5746 | Acc: 0.2235 | F1: 0.0816
  Epoch time: 37.0s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 5/50
  Train Loss: 1.5981 | Acc: 0.2452 | F1: 0.2181
  Val   Loss: 1.5651 | Acc: 0.3535 | F1: 0.1846
  Epoch time: 34.9s
  ✓ Best model saved (F1: 0.1846)
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 6/50
  Train Loss: 1.5870 | Acc: 0.2750 | F1: 0.2317
  Val   Loss: 1.5566 | Acc: 0.3535 | F1: 0.1846
  Epoch time: 37.5s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 7/50
  Train Loss: 1.5769 | Acc: 0.3034 | F1: 0.2376
  Val   Loss: 1.5492 | Acc: 0.3535 | F1: 0.1846
  Epoch time: 37.5s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 8/50
  Train Loss: 1.5674 | Acc: 0.3280 | F1: 0.2390
  Val   Loss: 1.5431 | Acc: 0.3535 | F1: 0.1846
  Epoch time: 35.0s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 9/50
  Train Loss: 1.5606 | Acc: 0.3443 | F1: 0.2354
  Val   Loss: 1.5383 | Acc: 0.3535 | F1: 0.1846
  Epoch time: 37.1s
----------------------------------------------------------------------


Training:   0%|          | 0/21 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]


Epoch 10/50
  Train Loss: 1.5534 | Acc: 0.3581 | F1: 0.2317
  Val   Loss: 1.5348 | Acc: 0.3535 | F1: 0.1846
  Epoch time: 34.6s

Early stopping at epoch 10

Training completed in 5.98 minutes


## 5. Evaluation on Test Set

In [None]:
# Load best model
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU still loads when only the CPU is available. The file holds a
# plain state dict, so the restricted weights_only loader is sufficient.
best_state = torch.load(
    "../results/best_pyg_temporal_model.pt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(best_state)

# Evaluate on test set
test_loss, test_acc, test_f1, test_preds, test_labels = evaluate(model, test_loader, criterion, device)

print("\n" + "=" * 70)
print("TEST SET RESULTS")
print("=" * 70)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")
print("=" * 70)

In [None]:
# Classification report
# NOTE(review): assumes class index i in the dataset corresponds to label_names[i] —
# confirm against the label encoding used when the datasets were built.
label_names = ['corrosion', 'erosion', 'fouling', 'tip_clearance', 'no_fault']

# Passing the class indices explicitly keeps the report aligned with label_names
# even when a class is absent from both the test labels and the predictions;
# zero_division=0 reports 0 for classes that are never predicted.
print("\nClassification Report:")
print(classification_report(
    test_labels,
    test_preds,
    labels=list(range(len(label_names))),
    target_names=label_names,
    zero_division=0,
))

In [None]:
# Confusion matrix
# Fix the row/column order to the class indices so the matrix always has one
# row and column per entry of label_names, even if a class is absent from both
# the test labels and the predictions; otherwise the tick labels would misalign.
cm = confusion_matrix(test_labels, test_preds, labels=list(range(len(label_names))))

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_names, yticklabels=label_names)
plt.title('Confusion Matrix - PyG Temporal Model')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('../results/confusion_matrix_pyg_temporal.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. Training History Visualization

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One row per panel: (train history key, validation history key, y-axis label, title)
panels = [
    ('train_loss', 'val_loss', 'Loss', 'Loss History'),
    ('train_acc', 'val_acc', 'Accuracy', 'Accuracy History'),
    ('train_f1', 'val_f1', 'F1 Score', 'F1 Score History'),
]

# Draw the train and validation curves of each metric on its own axis
for ax, (train_key, val_key, y_label, panel_title) in zip(axes, panels):
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/training_history_pyg_temporal.png', dpi=300, bbox_inches='tight')
plt.show()