In [None]:
# Install all required packages with version control
!pip install captum torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.2.0+cu121.html

In [None]:
# Detect environment
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
cuda_version = torch.version.cuda.replace('.', '') if torch.cuda.is_available() else "cpu"

# Install matching PyG packages
pyg_url = f"https://data.pyg.org/whl/torch-{torch.__version__}+cu{cuda_version}.html"
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f {pyg_url}

# Verify
try:
    from torch_geometric.nn import GCNConv
    print("\nSUCCESS: PyTorch Geometric installed correctly!")
except ImportError as e:
    print("\nERROR:", e)
    print("Try manually specifying versions above")

In [None]:
import os
import json
import torch
import cv2
import numpy as np
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
from matplotlib.colors import Normalize
from PIL import Image
from matplotlib.gridspec import GridSpec

# Compute device: CUDA when it is usable, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Root folders for the processed graph data and the processed frames
base_graph_dir = "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Processed_Graph_Data/GraphFolders"
base_frame_dir = "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Processed_Frames/ImageFolders"

# Training subsets of the two roots above
train_graph_dir, train_frame_dir = (
    os.path.join(base_graph_dir, "TrainGraphFolders"),
    os.path.join(base_frame_dir, "TrainImageFolders"),
)

# 1. Enhanced Dataset Class with Image-Graph pairing
class PoseGraphDataset(torch.utils.data.Dataset):
    """Pairs pose-graph JSON files with their image frames.

    Every sample is a ``(json_path, frame_path, label)`` tuple kept in
    ``self.samples``. Graph files whose frame image does not exist on disk are
    skipped.

    Args:
        graph_dir: Folder with the graph JSON files. In ``'train'`` mode it must
            contain ``ErrorGraphFolders`` and ``NoErrorGraphFolders``, each
            holding one sub-folder per video.
        frame_dir: Folder with the matching ``ErrorImageFolders`` /
            ``NoErrorImageFolders`` / ``TestImageFolders`` trees.
        mode: ``'train'`` loads the two labelled class folders (Error -> 0,
            NoError -> 1); any other value loads unlabelled test graphs.
    """

    # graph class folder -> (image class folder, label)
    _CLASS_FOLDERS = {
        "ErrorGraphFolders": ("ErrorImageFolders", 0),
        "NoErrorGraphFolders": ("NoErrorImageFolders", 1),
    }
    # Placeholder label for test graphs, which carry no class folder.
    _TEST_LABEL = -1

    def __init__(self, graph_dir, frame_dir, mode='train'):
        self.graph_dir = graph_dir
        self.frame_dir = frame_dir
        self.mode = mode
        self.samples = []
        self.body_part_labels = self._create_body_part_labels()

        if mode == 'train':
            self._load_train_data()
        else:
            self._load_test_data()

        print(f"Loaded {len(self.samples)} {mode} samples")

    def _create_body_part_labels(self):
        """Return a ``{node index: name}`` mapping for the graph nodes.

        Indices 0-32 follow the MediaPipe Pose landmark order, which matches the
        33 pose keypoints + 2 x 21 hand keypoints (75 nodes) this dataset uses.
        """
        labels = {
            0: "Nose", 1: "Left Eye Inner", 2: "Left Eye", 3: "Left Eye Outer",
            4: "Right Eye Inner", 5: "Right Eye", 6: "Right Eye Outer",
            7: "Left Ear", 8: "Right Ear", 9: "Mouth Left", 10: "Mouth Right",
            11: "Left Shoulder", 12: "Right Shoulder", 13: "Left Elbow",
            14: "Right Elbow", 15: "Left Wrist", 16: "Right Wrist",
            17: "Left Pinky", 18: "Right Pinky", 19: "Left Index",
            20: "Right Index", 21: "Left Thumb", 22: "Right Thumb",
            23: "Left Hip", 24: "Right Hip", 25: "Left Knee", 26: "Right Knee",
            27: "Left Ankle", 28: "Right Ankle", 29: "Left Heel", 30: "Right Heel",
            31: "Left Foot Index", 32: "Right Foot Index"
        }
        # Add hands (simplified)
        for i in range(33, 54):
            labels[i] = f"L Hand {i-33}"
        for i in range(54, 75):
            labels[i] = f"R Hand {i-54}"
        return labels

    def _find_corresponding_frame(self, json_path):
        """Return the image path matching a graph JSON file, or ``None``.

        The JSON file is expected to be named ``<prefix>_<frame>...`` and to
        live in ``<class folder>/<video folder>/``; the frame is looked up as
        ``frame_<frame>.jpg`` in the corresponding image class folder.
        """
        # Split on the OS separator so Windows paths work as well.
        components = os.path.normpath(json_path).split(os.sep)
        if len(components) < 2:
            return None
        video_folder = components[-2]

        name_tokens = components[-1].split('_')
        if len(name_tokens) < 2:
            # File name has no "<prefix>_<frame>" structure; nothing to pair with.
            return None
        frame_num = name_tokens[1].split('.')[0]

        # Compare whole path components: a substring test would treat
        # "NoErrorGraphFolders" as "ErrorGraphFolders".
        image_class = "TestImageFolders"
        for graph_class, (class_images, _) in self._CLASS_FOLDERS.items():
            if graph_class in components:
                image_class = class_images
                break

        frame_path = os.path.join(
            self.frame_dir, image_class, video_folder, f"frame_{frame_num}.jpg"
        )
        return frame_path if os.path.exists(frame_path) else None

    def _collect_samples(self, graph_root, label):
        """Append every graph/frame pair found under ``graph_root``.

        ``graph_root`` must contain one sub-folder per video. Entries are
        visited in sorted order so that sample indices are reproducible
        (``os.listdir`` returns entries in arbitrary order).
        """
        for video_folder in sorted(os.listdir(graph_root)):
            video_path = os.path.join(graph_root, video_folder)
            if not os.path.isdir(video_path):
                continue
            for json_file in sorted(os.listdir(video_path)):
                if not json_file.endswith('.json'):
                    continue
                json_path = os.path.join(video_path, json_file)
                frame_path = self._find_corresponding_frame(json_path)
                if frame_path:
                    self.samples.append((json_path, frame_path, label))

    def _load_train_data(self):
        """Load training data with image-graph pairs (Error=0, NoError=1)."""
        for class_folder, (_, label) in self._CLASS_FOLDERS.items():
            self._collect_samples(os.path.join(self.graph_dir, class_folder), label)

    def _load_test_data(self):
        """Load unlabelled test graphs, paired with ``TestImageFolders`` frames.

        NOTE(review): the test graph layout is not shown anywhere in this
        notebook. This assumes video folders sit either directly in
        ``graph_dir`` or in ``graph_dir/TestGraphFolders`` -- confirm against
        the data on disk. Samples get the placeholder label ``-1``.
        """
        nested_root = os.path.join(self.graph_dir, "TestGraphFolders")
        graph_root = nested_root if os.path.isdir(nested_root) else self.graph_dir
        self._collect_samples(graph_root, self._TEST_LABEL)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``Data`` object.

        If the JSON file cannot be read or parsed, the error is printed and a
        placeholder graph (75 zero nodes, one self-loop edge, empty
        ``frame_path``) carrying the sample's label is returned instead.
        """
        json_path, frame_path, label = self.samples[idx]

        try:
            with open(json_path, 'r') as f:
                graph_data = json.load(f)

            x = torch.tensor(graph_data['nodes'], dtype=torch.float)
            # reshape(-1, 2) keeps an empty edge list as a valid (2, 0) index
            # instead of a 1-D tensor; (E, 2) edge lists are unaffected.
            edge_index = (
                torch.tensor(graph_data['edges'], dtype=torch.long)
                .reshape(-1, 2)
                .t()
                .contiguous()
            )

            return Data(
                x=x,
                edge_index=edge_index,
                y=torch.tensor([label], dtype=torch.long),
                frame_path=frame_path,
                body_part_labels=self.body_part_labels
            )

        except Exception as e:
            print(f"Error loading {json_path}: {str(e)}")
            return Data(
                x=torch.zeros((75, 3), dtype=torch.float),
                edge_index=torch.zeros((2, 1), dtype=torch.long),
                y=torch.tensor([label], dtype=torch.long),
                frame_path="",
                body_part_labels=self.body_part_labels
            )

# 2. Define ExplainableGCN Model
class ExplainableGCN(nn.Module):
    """Three-layer GCN graph classifier with a gradient-based saliency helper.

    Node features with 3 channels go through three ``GCNConv`` + ReLU layers
    (64, 128, 256 channels), are mean-pooled per graph and mapped to
    ``num_classes`` log-probabilities.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(3, 64)
        self.conv2 = GCNConv(64, 128)
        self.conv3 = GCNConv(128, 256)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, data):
        """Return per-graph log-probabilities of shape ``(num_graphs, num_classes)``.

        ``data`` must expose ``x`` and ``edge_index``. ``batch`` is optional: a
        single graph that did not come from a loader is treated as one graph.
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, 'batch', None)
        if batch is None:
            # Un-batched input: assign every node to graph 0 so pooling works
            # regardless of whether the pooling op accepts batch=None.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        x = global_mean_pool(x, batch)
        x = self.fc(x)

        return F.log_softmax(x, dim=1)

    def compute_node_importance(self, data):
        """Compute node importance using gradients.

        Runs a forward pass on a copy of ``data`` (a single graph), picks the
        predicted class and back-propagates its log-probability to the node
        features. The importance of a node is the mean absolute gradient over
        its feature channels.

        Side effects: switches the model to eval mode and clears parameter
        gradients.

        Returns:
            ``(importance, pred_class)`` where ``importance`` is a 1-D tensor
            with one score per node and ``pred_class`` is an ``int``.
        """
        # Use the device the weights actually live on instead of a notebook
        # global, so the method also works if the model was moved elsewhere.
        model_device = next(self.parameters()).device

        data = data.clone().to(model_device)
        # detach() guarantees x is a leaf tensor; otherwise (e.g. if the input
        # already required grad) .grad would stay None after backward().
        data.x = data.x.detach().requires_grad_(True)

        self.eval()
        output = self(data)
        pred_class = output.argmax(dim=1).item()

        self.zero_grad()
        output[0, pred_class].backward()

        if data.x.grad is not None:
            gradients = data.x.grad.abs().mean(dim=1)
            return gradients, pred_class
        else:
            # Keep the fallback on the same device as the regular result.
            return torch.zeros(data.x.size(0), device=data.x.device), pred_class

# 4. Initialize dataset (training split: labelled graph/frame pairs)
train_dataset = PoseGraphDataset(
    graph_dir=train_graph_dir,
    frame_dir=train_frame_dir,
    mode='train',
)

# 3. Visualization function
# Number of pose landmarks at the start of the node list; hand joints follow.
_NUM_POSE_KEYPOINTS = 33

# MediaPipe Pose landmark names, keyed by landmark index.
_POSE_LANDMARK_NAMES = {
    0: "Nose", 1: "Left Eye Inner", 2: "Left Eye", 3: "Left Eye Outer",
    4: "Right Eye Inner", 5: "Right Eye", 6: "Right Eye Outer",
    7: "Left Ear", 8: "Right Ear", 9: "Mouth Left", 10: "Mouth Right",
    11: "Left Shoulder", 12: "Right Shoulder", 13: "Left Elbow",
    14: "Right Elbow", 15: "Left Wrist", 16: "Right Wrist",
    17: "Left Pinky", 18: "Right Pinky", 19: "Left Index",
    20: "Right Index", 21: "Left Thumb", 22: "Right Thumb",
    23: "Left Hip", 24: "Right Hip", 25: "Left Knee", 26: "Right Knee",
    27: "Left Ankle", 28: "Right Ankle", 29: "Left Heel", 30: "Right Heel",
    31: "Left Foot Index", 32: "Right Foot Index"
}

# Skeleton drawn on the graph panel, expressed in the same landmark indices as
# the names above.
_MAJOR_POSE_EDGES = [
    (0, 2), (0, 5), (2, 7), (5, 8),                    # Head: nose-eyes-ears
    (11, 12), (11, 13), (12, 14), (13, 15), (14, 16),  # Shoulders and arms
    (11, 23), (12, 24), (23, 24),                      # Torso
    (23, 25), (24, 26), (25, 27), (26, 28),            # Legs
]

# Joints that get a text label: nose, shoulders, elbows, hips, knees.
_KEY_JOINTS = [0, 11, 12, 13, 14, 23, 24, 25, 26]


def visualize_graph_with_image(data, importance, pred_class):
    """Show node importance on the pose skeleton next to the source frame.

    The left panel draws the first 33 nodes (pose landmarks) coloured by
    min-max normalised importance; the right panel shows the frame at
    ``data.frame_path`` with the same landmarks overlaid (red when the
    normalised importance exceeds 0.5, blue otherwise). Raw importance scores
    of the labelled joints are printed afterwards.

    Args:
        data: Graph with ``x`` (columns 0 and 1 are used as x/y coordinates),
            ``edge_index`` and optionally ``frame_path``.
            NOTE(review): coordinates are assumed to be normalised to [0, 1]
            with the origin at the top-left of the image -- confirm against
            the graph export.
        importance: 1-D tensor with one score per node.
        pred_class: Predicted class index; 0 is shown as "Error", anything else
            as "NoError".

    Any exception raised while plotting is caught and reported with a
    ``Visualization error: ...`` message; the figure is always closed.
    """
    fig = plt.figure(figsize=(24, 10))
    gs = GridSpec(1, 2, width_ratios=[1, 1])

    try:
        coords = data.x.detach().cpu().numpy()
        x_coords = coords[:, 0]
        y_image = coords[:, 1]   # image convention: y grows downwards
        y_plot = 1 - y_image     # plot convention: y grows upwards

        # Normalize importance
        scores = importance.detach().cpu()
        importance_norm = (scores - scores.min()) / (scores.max() - scores.min() + 1e-9)

        # Draw only the pose keypoints (skip hand/finger joints for clarity)
        visible_nodes = list(range(min(len(x_coords), _NUM_POSE_KEYPOINTS)))
        sub_pos = {i: (x_coords[i], y_plot[i]) for i in visible_nodes}
        # Only keep skeleton edges whose endpoints exist in this graph.
        skeleton_edges = [(u, v) for u, v in _MAJOR_POSE_EDGES if u in sub_pos and v in sub_pos]

        # 1. GRAPH VISUALIZATION
        ax1 = fig.add_subplot(gs[0])
        G = nx.Graph()
        G.add_nodes_from(range(len(x_coords)))
        G.add_edges_from(data.edge_index.t().tolist())
        subgraph = G.subgraph(visible_nodes)

        nx.draw_networkx_nodes(
            subgraph, sub_pos,
            nodelist=visible_nodes,  # fixes the node order node_color refers to
            node_color=importance_norm[:len(visible_nodes)].numpy(),
            cmap=plt.cm.Reds,
            node_size=300,
            alpha=0.8,
            ax=ax1
        )

        nx.draw_networkx_edges(
            subgraph, sub_pos,
            edgelist=skeleton_edges,
            edge_color='gray',
            width=2,
            alpha=0.5,
            ax=ax1
        )

        # Label only key joints
        labels = {i: _POSE_LANDMARK_NAMES.get(i, "") for i in _KEY_JOINTS if i in sub_pos}
        nx.draw_networkx_labels(
            subgraph, sub_pos,
            labels=labels,
            font_size=12,
            font_weight='bold',
            ax=ax1
        )

        ax1.set_title(f"Body Graph Importance\nPredicted: {'Error' if pred_class == 0 else 'NoError'}", fontsize=14)
        ax1.set_xlim(0, 1)
        ax1.set_ylim(0, 1)
        ax1.axis('off')

        # 2. IMAGE VISUALIZATION
        ax2 = fig.add_subplot(gs[1])
        frame_path = getattr(data, 'frame_path', None)
        if frame_path and os.path.exists(frame_path):
            # The context manager releases the file handle once it is drawn.
            with Image.open(frame_path) as img:
                img_width, img_height = img.size
                ax2.imshow(img)

            # Overlay keypoints on the image. Use the un-flipped y here: image
            # rows grow downwards, unlike the graph panel.
            for i in visible_nodes:
                px = x_coords[i] * img_width
                py = y_image[i] * img_height
                ax2.scatter(
                    px,
                    py,
                    s=50,
                    c='red' if importance_norm[i] > 0.5 else 'blue',
                    alpha=0.7
                )
                if i in _KEY_JOINTS:
                    ax2.text(
                        px + 10,
                        py + 10,
                        _POSE_LANDMARK_NAMES[i],
                        color='white',
                        fontsize=12,
                        bbox=dict(facecolor='black', alpha=0.5)
                    )

            ax2.set_title("Original Frame with Keypoints", fontsize=14)
            ax2.axis('off')
        else:
            ax2.text(0.5, 0.5, "Image not found", ha='center', va='center')
            ax2.axis('off')

        plt.tight_layout()
        plt.show()

        # Print importance for key joints
        print("\nKey Joint Importance Scores:")
        for i in sorted(_KEY_JOINTS):
            if i < len(scores):
                print(f"{_POSE_LANDMARK_NAMES.get(i, f'Joint {i}'):>12}: {scores[i].item():.3f}")

    except Exception as e:
        print(f"Visualization error: {str(e)}")
    finally:
        # Release the figure even when plotting failed; this function is called
        # in a loop and open figures would otherwise accumulate.
        plt.close(fig)

# 5. Load or create model
model = ExplainableGCN().to(device)

# Try to load pretrained weights. This happens before any visualization so the
# importance maps describe the trained model rather than random initial weights.
model_path = "best_gcn_model.pth"
if os.path.exists(model_path):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Loaded pretrained model weights")
else:
    print("Initialized new model (no pretrained weights found)")

# 6. Generate visualizations
sample_indices = list(range(0, 101, 5))  # Every 5th sample: 0, 5, ..., 100
for idx in sample_indices:
    if idx < len(train_dataset):  # Check if index is valid
        sample_data = train_dataset[idx].to(device)
        node_importances, pred_class = model.compute_node_importance(sample_data)
        visualize_graph_with_image(sample_data, node_importances, pred_class)
    else:
        print(f"Skipping index {idx} - exceeds dataset size")