In [4]:
import os

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report

# File paths
# Directory holding the train/val/test CSV splits, and where the trained
# weights (state_dict) are written at the end of the run.
data_dir = "/home/yugdes/snapDetection/dataV0"
model_output_path = "/home/yugdes/snapDetection/LSTM/models/attention_lstm.pt"
# Create the output directory up front so torch.save() cannot fail on a
# missing folder after training has already finished.
os.makedirs(os.path.dirname(model_output_path), exist_ok=True)

# One CSV per split; each is read by SnapWindowedDataset, which looks for
# columns whose name contains "vel" and a column named "label".
train_file = os.path.join(data_dir, "joint_velocity_train.csv")
val_file = os.path.join(data_dir, "joint_velocity_val.csv")
test_file = os.path.join(data_dir, "joint_velocity_test.csv")

# Hyperparameters
WINDOW_SIZE = 30  # Time steps per sliding window (default window_size of the dataset)
NUM_JOINTS = 7  # Number of joint channels normalized by the dataset and read by the model
TAIL_SIZE = 5  # A window is positive if any of its last 5 time steps is labeled 1
BATCH_SIZE = 64  # Batch size for all three DataLoaders
HIDDEN_SIZE = 64  # Hidden size of the shared LSTM
EPOCHS = 30  # Number of full passes over the training set
LEARNING_RATE = 0.001  # Adam learning rate
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#######################################
# Dataset Class with Sliding Windows #
#######################################
class SnapWindowedDataset(Dataset):
    """Sliding-window dataset over joint-velocity time series.

    The CSV is read with pandas. Every column whose name contains ``"vel"``
    is treated as a joint-velocity channel, and the ``"label"`` column holds
    the per-time-step label. The first ``num_joints`` channels are scaled
    with one ``MinMaxScaler`` each, then the series is cut into overlapping
    windows with stride 1.

    Attributes:
        scaler: List of per-joint ``MinMaxScaler`` objects. Pass the list
            from the training dataset to the validation/test datasets so
            all splits share the same normalization.
        X: Float32 tensor of shape ``(n_windows, n_vel_columns, window_size)``.
        y: Float32 tensor of shape ``(n_windows,)`` holding 0.0 or 1.0.
    """

    def __init__(self, csv_file, window_size=WINDOW_SIZE, num_joints=NUM_JOINTS, tail_size=TAIL_SIZE, scaler=None, fit_scaler=False):
        """Load ``csv_file``, normalize it and build the windows.

        Args:
            csv_file: Path to a CSV with "vel" columns and a "label" column.
            window_size: Number of time steps per window.
            num_joints: Number of leading velocity columns to normalize.
            tail_size: Number of trailing time steps of each window that
                are inspected for a positive label.
            scaler: Optional list of per-joint scalers to reuse. When
                omitted (or empty), fresh ``MinMaxScaler`` objects are
                created; these must then be fitted via ``fit_scaler=True``
                before they can transform anything.
            fit_scaler: If True, fit each scaler on this file's data.

        Raises:
            ValueError: If ``tail_size`` is not in ``[1, window_size]`` or
                the file has fewer rows than ``window_size``.
        """
        # A tail outside the window would slice with negative/empty ranges
        # and surface as an obscure numpy error, so fail early and clearly.
        if not 1 <= tail_size <= window_size:
            raise ValueError(
                f"tail_size must be between 1 and window_size ({window_size}), got {tail_size}"
            )

        df = pd.read_csv(csv_file)
        # Assume columns with "vel" in their name are joint velocities and one column "label" exists.
        joint_cols = [col for col in df.columns if "vel" in col]
        label_col = "label"

        # Work on a float copy: assigning scaled values into an integer
        # array would silently truncate them.
        data = df[joint_cols].values.astype(np.float64)  # (total_time_steps, n_vel_columns)
        labels = df[label_col].values  # (total_time_steps,)

        if len(data) < window_size:
            raise ValueError(
                f"{csv_file} has {len(data)} rows, fewer than window_size={window_size}"
            )

        # Initialize or use provided scalers for each joint (robot-specific normalization)
        self.scaler = scaler if scaler else [MinMaxScaler() for _ in range(num_joints)]
        for j in range(num_joints):
            column = data[:, j].reshape(-1, 1)
            if fit_scaler:
                self.scaler[j].fit(column)
            data[:, j] = self.scaler[j].transform(column).ravel()

        sliding = np.lib.stride_tricks.sliding_window_view
        # Windowing along axis 0 appends the window axis last, which already
        # yields the (n_windows, n_vel_columns, window_size) layout.
        windows = sliding(data, window_size, axis=0)
        # Row k of the label view is labels[k:k + tail_size]; window i ends at
        # i + window_size, so its tail starts at k = i + window_size - tail_size.
        tail_max = sliding(labels, tail_size)[window_size - tail_size:].max(axis=1)

        # sliding_window_view returns a read-only view; copy before handing
        # the data to torch.
        self.X = torch.tensor(np.ascontiguousarray(windows), dtype=torch.float32)
        # A window is positive when any of its tail steps is labeled 1.
        self.y = torch.tensor(tail_max == 1, dtype=torch.float32)

    def __len__(self):
        """Return the number of windows."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(window, label)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]

##########################################
# Attention-based LSTM Model Definition  #
##########################################
class AttentionLSTM(nn.Module):
    """Shared per-joint LSTM with attention pooling over joints.

    Each joint's scalar time series is encoded by the same LSTM. The last
    LSTM output of every joint is scored by a linear layer, the scores are
    softmax-normalized across joints, and the weighted sum of joint features
    is passed to a small MLP ending in a sigmoid.

    Args:
        num_joints: Number of leading joint channels of the input to use.
        input_size: Feature size fed to the LSTM per time step. ``forward``
            feeds one scalar per step, so this is expected to be 1.
        hidden_size: Hidden size of the LSTM and of the joint features.
    """

    def __init__(self, num_joints, input_size, hidden_size):
        super().__init__()
        self.num_joints = num_joints
        # One LSTM shared by all joints.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        # Attention layer: one score per joint feature vector
        self.attn_fc = nn.Linear(hidden_size, 1)
        # Classifier to decide snap event based on aggregated features.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one probability per sample.

        Args:
            x: Tensor of shape ``(batch, num_joints, window_size)``.

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``. The
            batch dimension is kept even when ``batch == 1``.
        """
        batch_size, window_size = x.shape[0], x.shape[-1]
        num_joints = self.num_joints

        # The LSTM weights are shared across joints, so fold the joint axis
        # into the batch axis and run a single LSTM call instead of one per
        # joint. Row b * num_joints + j holds the series of joint j, sample b.
        joint_seqs = x[:, :num_joints, :].reshape(batch_size * num_joints, window_size, 1)
        out, _ = self.lstm(joint_seqs)
        # Use the last output of the LSTM as each joint's feature vector and
        # restore the joint axis: (batch, num_joints, hidden_size).
        joint_features = out[:, -1, :].reshape(batch_size, num_joints, -1)

        # Compute attention scores and weights for each joint
        attn_scores = self.attn_fc(joint_features).squeeze(-1)  # (batch, num_joints)
        attn_weights = torch.softmax(attn_scores, dim=1)          # (batch, num_joints)
        # Aggregate features using the computed attention weights
        attn_output = torch.sum(attn_weights.unsqueeze(-1) * joint_features, dim=1)  # (batch, hidden_size)
        # Squeeze only the trailing feature axis: a bare squeeze() would also
        # drop the batch axis for a single-sample batch and return a 0-d tensor.
        return self.classifier(attn_output).squeeze(-1)

#########################
# Data Loading Section  #
#########################
# The training split fits the per-joint scalers; validation and test reuse
# them so all splits share the same normalization.
train_set = SnapWindowedDataset(train_file, fit_scaler=True)
val_set = SnapWindowedDataset(val_file, scaler=train_set.scaler)
test_set = SnapWindowedDataset(test_file, scaler=train_set.scaler)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

#############################
# Training and Evaluation   #
#############################
def _collect_predictions(net, loader):
    """Run ``net`` over ``loader`` in eval mode without gradients.

    Returns:
        A ``(probabilities, targets)`` pair of flat Python lists with one
        entry per sample, in loader order.
    """
    probabilities, targets = [], []
    net.eval()
    with torch.no_grad():
        for X, y in loader:
            # reshape(-1) keeps the output one-dimensional even if the model
            # returns a 0-d tensor for a trailing batch of size 1, which
            # list.extend() could not iterate.
            batch_probs = net(X.to(DEVICE)).reshape(-1).cpu().numpy()
            probabilities.extend(batch_probs)
            targets.extend(y.numpy())
    return probabilities, targets


model = AttentionLSTM(NUM_JOINTS, input_size=1, hidden_size=HIDDEN_SIZE).to(DEVICE)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): in the recorded run the validation recall for class 1 stays
# at 0.00, i.e. the model predicts only the majority class. Consider class
# weighting or resampling -- to be confirmed by experiment.
print("Starting training...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for X, y in train_loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        # Flatten so the prediction shape matches the (batch,) target that
        # BCELoss requires, including for a batch of size 1.
        preds = model(X).reshape(-1)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation after each epoch
    val_preds, val_targets = _collect_predictions(model, val_loader)

    # Binary threshold at 0.5
    val_preds_bin = [1 if p > 0.5 else 0 for p in val_preds]
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss/len(train_loader):.4f}")
    print("Validation Classification Report:")
    print(classification_report(val_targets, val_preds_bin, zero_division=0))

# Save the trained model (weights after the final epoch)
torch.save(model.state_dict(), model_output_path)
print(f"Model saved to {model_output_path}")

# Final Evaluation on Test Data
test_preds, test_targets = _collect_predictions(model, test_loader)

test_preds_bin = [1 if p > 0.5 else 0 for p in test_preds]
print("Test Classification Report:")
print(classification_report(test_targets, test_preds_bin, zero_division=0))


Starting training...
Epoch 1/30 - Loss: 0.3921
Validation Classification Report:
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93      5885
         1.0       0.00      0.00      0.00       901

    accuracy                           0.87      6786
   macro avg       0.43      0.50      0.46      6786
weighted avg       0.75      0.87      0.81      6786

Epoch 2/30 - Loss: 0.3891
Validation Classification Report:
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93      5885
         1.0       0.00      0.00      0.00       901

    accuracy                           0.87      6786
   macro avg       0.43      0.50      0.46      6786
weighted avg       0.75      0.87      0.81      6786

Epoch 3/30 - Loss: 0.3889
Validation Classification Report:
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93      5885
         1.0       0.00      0.00   