# ST-GNN Model Prototyping and Development

This notebook provides a complete walkthrough of the Spatio-Temporal Graph Neural Network (ST-GNN) model development for human motion prediction in the Predictive Human-Robot Collaboration system.

## Objectives
1. Explore the Human3.6M dataset structure (using a synthetic stand-in with the same layout)
2. Implement and test the ST-GNN architecture
3. Train the model with sample data
4. Evaluate prediction performance
5. Visualize results

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import h5py
import json
from pathlib import Path
import sys
import os
from tqdm import tqdm

# Make the project's `src` directory importable. The path is resolved to an
# absolute path against the notebook's current working directory, and it is
# only added if missing, so re-running this cell does not keep appending
# duplicate entries to sys.path.
SRC_DIR = str(Path('../src').resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Set random seeds for reproducibility (torch and NumPy's global RNG)
torch.manual_seed(42)
np.random.seed(42)

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Data Generation and Exploration

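Because the actual Human3.6M dataset is not available here, we generate synthetic data that mimics its structure for prototyping. The data is an array of shape `(sequences, frames, joints, 3)` holding x/y/z coordinates. The first 30 frames of each sequence serve as model input, and the following 45 frames are the prediction targets.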

In [None]:
# Generate synthetic human motion data
def generate_synthetic_motion_data(num_sequences=1000, seq_length=75, num_joints=34):
    """Generate synthetic human motion data that mimics realistic movement patterns."""
    data = np.zeros((num_sequences, seq_length, num_joints, 3))
    
    for seq in range(num_sequences):
        # Generate different motion patterns
        motion_type = seq % 4
        
        for frame in range(seq_length):
            t = frame / 30.0  # Time in seconds
            
            for joint in range(num_joints):
                if motion_type == 0:  # Walking pattern
                    data[seq, frame, joint, 0] = 0.2 * np.sin(2 * np.pi * 1.0 * t + joint * 0.1)
                    data[seq, frame, joint, 1] = 0.1 * np.cos(2 * np.pi * 0.5 * t + joint * 0.2)
                    data[seq, frame, joint, 2] = 1.5 + 0.1 * np.sin(2 * np.pi * 2.0 * t + joint * 0.15)
                elif motion_type == 1:  # Arm reaching
                    data[seq, frame, joint, 0] = 0.3 * np.sin(2 * np.pi * 0.3 * t) if joint in [11, 12, 13, 14, 15, 16] else 0
                    data[seq, frame, joint, 1] = 0.2 * np.cos(2 * np.pi * 0.4 * t) if joint in [11, 12, 13, 14, 15, 16] else 0
                    data[seq, frame, joint, 2] = 1.0 + 0.05 * np.sin(2 * np.pi * 1.5 * t)
                elif motion_type == 2:  # Sitting motion
                    data[seq, frame, joint, 0] = 0.05 * np.sin(2 * np.pi * 0.1 * t + joint * 0.05)
                    data[seq, frame, joint, 1] = 0.03 * np.cos(2 * np.pi * 0.15 * t + joint * 0.1)
                    data[seq, frame, joint, 2] = 1.2 + 0.02 * np.sin(2 * np.pi * 0.5 * t + joint * 0.2)
                else:  # General movement
                    data[seq, frame, joint, 0] = 0.1 * np.sin(2 * np.pi * 0.7 * t + joint * 0.1)
                    data[seq, frame, joint, 1] = 0.1 * np.cos(2 * np.pi * 0.8 * t + joint * 0.15)
                    data[seq, frame, joint, 2] = 1.0 + 0.05 * np.sin(2 * np.pi * 1.2 * t + joint * 0.3)
                
                # Add some noise
                data[seq, frame, joint] += np.random.normal(0, 0.005, 3)
    
    return data

# Generate training data.
# The model observes `input_seq_len` frames and predicts the following
# `output_seq_len` frames, so each generated sequence must hold both.
input_seq_len = 30
output_seq_len = 45
total_seq_len = input_seq_len + output_seq_len

print("Generating synthetic motion data...")
motion_data = generate_synthetic_motion_data(
    num_sequences=1000, seq_length=total_seq_len, num_joints=34
)
print(f"Generated data shape: {motion_data.shape}")
print("Shape interpretation: (sequences, frames, joints, coordinates)")

# Split into input and target sequences
input_sequences = motion_data[:, :input_seq_len, :, :]  # observed frames
target_sequences = motion_data[:, input_seq_len:total_seq_len, :, :]  # frames to predict

print(f"Input sequences shape: {input_sequences.shape}")
print(f"Target sequences shape: {target_sequences.shape}")

# Visualize a sample sequence: x/y/z trajectories of a few joints over time.
sample_seq = 0
# NOTE(review): these names are illustrative only. The synthetic generator
# does not model a real skeleton (it treats joints 11-16 alike as "arm"
# joints), so the labels do not correspond to anatomical joints in this data.
joints_to_plot = [0, 8, 11, 14]
joint_names = ['Hip', 'Neck', 'L_Shoulder', 'R_Shoulder']
# Derive the frame axis from the data so the plot follows any sequence length.
frames = np.arange(motion_data.shape[1])

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for ax, joint_idx, joint_name in zip(axes.flat, joints_to_plot, joint_names):
    # Plot x, y, z coordinates
    for coord, (color, label) in enumerate(zip('rgb', 'XYZ')):
        ax.plot(frames, motion_data[sample_seq, :, joint_idx, coord],
                f'{color}-', label=label, alpha=0.7)

    # Frames [0, input_seq_len) are input and the rest are targets; draw the
    # divider halfway between the last input frame and the first target frame.
    ax.axvline(x=input_seq_len - 0.5, color='black', linestyle='--', alpha=0.5, label='Split')

    ax.set_title(f'{joint_name} Joint Motion')
    ax.set_xlabel('Frame')
    ax.set_ylabel('Position (m)')
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print("Data generation complete!")