In [None]:
# Install all required packages with version control
!pip install captum torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.2.0+cu121.html

In [None]:
# Detect environment
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
cuda_version = torch.version.cuda.replace('.', '') if torch.cuda.is_available() else "cpu"

# Install matching PyG packages
pyg_url = f"https://data.pyg.org/whl/torch-{torch.__version__}+cu{cuda_version}.html"
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f {pyg_url}

# Verify
try:
    from torch_geometric.nn import GCNConv
    print("\nSUCCESS: PyTorch Geometric installed correctly!")
except ImportError as e:
    print("\nERROR:", e)
    print("Try manually specifying versions above")

In [None]:
import os
import json
import torch
import cv2
import numpy as np
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
from matplotlib.colors import Normalize
from PIL import Image
from matplotlib.gridspec import GridSpec

# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hard-coded absolute locations of the processed graph JSON files and of the
# extracted image frames for the "Shooting" data.
base_graph_dir = "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Processed_Graph_Data/GraphFolders"
base_frame_dir = "/content/drive/MyDrive/E-RAU(DB)/MA680/data/Shooting/Processed_Frames/ImageFolders"

# Training subfolders of the two roots above.
train_graph_dir, train_frame_dir = (
    os.path.join(base_graph_dir, "TrainGraphFolders"),
    os.path.join(base_frame_dir, "TrainImageFolders"),
)

# 1. Enhanced Dataset Class with Image-Graph pairing
class PoseGraphDataset(torch.utils.data.Dataset):
    """Dataset pairing per-frame pose-graph JSON files with their image frames.

    Internally every sample is a ``(json_path, frame_path, label)`` tuple kept
    in ``self.samples``; ``__getitem__`` converts one into a ``Data`` object.

    Args:
        graph_dir: Directory containing the ``ErrorGraphFolders`` and
            ``NoErrorGraphFolders`` class folders (each holding one folder per
            video with the graph ``.json`` files).
        frame_dir: Directory containing the matching ``ErrorImageFolders`` /
            ``NoErrorImageFolders`` folders with ``frame_<n>.jpg`` files.
        mode: ``'train'`` loads the two labelled class folders. Any other
            value raises ``NotImplementedError`` (no test loader exists).
    """

    # Graph class folder -> (label, image class folder).
    # Label 0 is the "Error" class, label 1 the "NoError" class.
    # Exact folder names are required: substring tests are wrong here because
    # "NoErrorGraphFolders" contains both "Error" and "ErrorGraphFolders".
    _CLASS_FOLDERS = {
        "ErrorGraphFolders": (0, "ErrorImageFolders"),
        "NoErrorGraphFolders": (1, "NoErrorImageFolders"),
    }

    def __init__(self, graph_dir, frame_dir, mode='train'):
        self.graph_dir = graph_dir
        self.frame_dir = frame_dir
        self.mode = mode
        self.samples = []
        self.body_part_labels = self._create_body_part_labels()

        if mode == 'train':
            self._load_train_data()
        else:
            self._load_test_data()

        print(f"Loaded {len(self.samples)} {mode} samples")

    def _create_body_part_labels(self):
        """Create labels for key body parts.

        Returns:
            dict mapping node index -> display name. Indices 0-16 are named
            body joints, 33-53 are "L Hand 0..20" and 54-74 are
            "R Hand 0..20". Indices 17-32 intentionally have no entry.
        """
        labels = {
            0: "Nose", 1: "L Eye", 2: "R Eye", 3: "L Ear", 4: "R Ear",
            5: "L Shoulder", 6: "R Shoulder", 7: "L Elbow", 8: "R Elbow",
            9: "L Wrist", 10: "R Wrist", 11: "L Hip", 12: "R Hip",
            13: "L Knee", 14: "R Knee", 15: "L Ankle", 16: "R Ankle"
        }
        # Add hands (simplified)
        for i in range(33, 54):
            labels[i] = f"L Hand {i-33}"
        for i in range(54, 75):
            labels[i] = f"R Hand {i-54}"
        return labels

    def _find_corresponding_frame(self, json_path):
        """Find the corresponding image frame for a graph JSON file.

        The video folder is the JSON file's parent directory and the frame
        number is the second ``_``-separated field of the file name (up to the
        first ``.``), e.g. ``<prefix>_0001.json`` -> ``frame_0001.jpg``.

        Returns:
            The frame path if that file exists, otherwise ``None`` (also when
            the path or file name does not follow the expected pattern).
        """
        parts = os.path.normpath(json_path).split(os.sep)
        if len(parts) < 2:
            return None
        video_folder = parts[-2]

        name_fields = parts[-1].split('_')
        if len(name_fields) < 2:
            return None
        frame_num = name_fields[1].split('.')[0]

        # Match whole path components so a "NoErrorGraphFolders" path is never
        # mistaken for "ErrorGraphFolders"; anything else is treated as test data.
        image_class_folder = "TestImageFolders"
        for graph_folder, (_, image_folder) in self._CLASS_FOLDERS.items():
            if graph_folder in parts:
                image_class_folder = image_folder
                break

        frame_path = os.path.join(
            self.frame_dir, image_class_folder, video_folder, f"frame_{frame_num}.jpg"
        )
        return frame_path if os.path.exists(frame_path) else None

    def _load_train_data(self):
        """Load training data with image-graph pairs.

        Graphs without a matching frame on disk are skipped. Directory
        listings are sorted so sample indices are stable across runs.
        """
        for class_folder, (label, _) in self._CLASS_FOLDERS.items():
            class_dir = os.path.join(self.graph_dir, class_folder)

            for video_folder in sorted(os.listdir(class_dir)):
                video_path = os.path.join(class_dir, video_folder)
                if not os.path.isdir(video_path):
                    continue
                for json_file in sorted(os.listdir(video_path)):
                    if not json_file.endswith('.json'):
                        continue
                    json_path = os.path.join(video_path, json_file)
                    frame_path = self._find_corresponding_frame(json_path)
                    if frame_path:
                        self.samples.append((json_path, frame_path, label))

    def _load_test_data(self):
        """Placeholder for test-set loading.

        No test loader was ever defined for this class, so any non-'train'
        mode previously failed with an AttributeError. Fail explicitly instead.
        """
        raise NotImplementedError(
            f"PoseGraphDataset mode {self.mode!r} is not supported: "
            "only mode='train' has a loader"
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Build the ``Data`` object for sample ``idx``.

        Reads the ``'nodes'`` and ``'edges'`` entries of the JSON file; the
        edge list is transposed, so it is presumably stored as [src, dst]
        pairs (TODO confirm against the graph export).

        On any loading error the problem is printed and a placeholder graph
        (75 zero nodes with 3 features, one 0->0 edge, empty frame path) is
        returned so that iteration can continue.
        """
        json_path, frame_path, label = self.samples[idx]

        try:
            with open(json_path, 'r') as f:
                graph_data = json.load(f)

            x = torch.tensor(graph_data['nodes'], dtype=torch.float)
            edge_index = torch.tensor(graph_data['edges'], dtype=torch.long)
            if edge_index.numel() == 0:
                # An empty edge list would otherwise yield a 1-D tensor
                # instead of the (2, 0) layout expected for edge_index.
                edge_index = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index = edge_index.t().contiguous()

            return Data(
                x=x,
                edge_index=edge_index,
                y=torch.tensor([label], dtype=torch.long),
                frame_path=frame_path,
                body_part_labels=self.body_part_labels
            )

        except Exception as e:
            # Deliberate best-effort: report and substitute a placeholder.
            print(f"Error loading {json_path}: {str(e)}")
            return Data(
                x=torch.zeros((75, 3), dtype=torch.float),
                edge_index=torch.zeros((2, 1), dtype=torch.long),
                y=torch.tensor([label], dtype=torch.long),
                frame_path="",
                body_part_labels=self.body_part_labels
            )

# 2. Define ExplainableGCN Model
class ExplainableGCN(nn.Module):
    """Three-layer GCN graph classifier with gradient-based node saliency.

    Node features (3 per node) pass through three ``GCNConv`` layers
    (3 -> 64 -> 128 -> 256) with ReLU, are mean-pooled per graph and mapped
    to ``num_classes`` log-probabilities by a linear layer.

    Args:
        num_classes: Number of output classes (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(3, 64)
        self.conv2 = GCNConv(64, 128)
        self.conv3 = GCNConv(128, 256)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, data):
        """Return per-graph log-probabilities for ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor of log-softmax scores with classes along dim 1.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        x = global_mean_pool(x, batch)
        x = self.fc(x)

        return F.log_softmax(x, dim=1)

    def compute_node_importance(self, data):
        """Compute node importance using gradients.

        The importance of a node is the mean absolute gradient of the
        predicted class's log-probability with respect to that node's
        input features.

        The input is cloned and moved to the device of the model's own
        parameters (no dependency on a notebook-level ``device`` variable).
        Parameter gradients are left untouched and the model's train/eval
        mode is restored afterwards.

        Args:
            data: A single graph (the prediction is read with ``.item()``,
                so a multi-graph batch is not supported).

        Returns:
            ``(importance, pred_class)`` where ``importance`` has one score
            per node (all zeros if no gradient reaches the inputs) and
            ``pred_class`` is the predicted class index as an int.
        """
        model_device = next(self.parameters()).device
        data = data.clone().to(model_device)
        data.x.requires_grad_(True)

        was_training = self.training
        self.eval()
        try:
            output = self(data)
            pred_class = output.argmax(dim=1).item()

            # autograd.grad returns the input gradient directly, so parameter
            # .grad buffers are neither cleared nor filled by this call.
            (input_grad,) = torch.autograd.grad(
                output[0, pred_class], data.x, allow_unused=True
            )
        finally:
            self.train(was_training)

        if input_grad is None:
            return torch.zeros(data.x.size(0), device=data.x.device), pred_class
        return input_grad.abs().mean(dim=1), pred_class

# 3. Visualization function
def visualize_graph_with_image(data, importance, pred_class):
    """Visualize graph and corresponding image side by side.

    Left panel: the graph drawn at the nodes' (x, -y) feature coordinates,
    coloured by min-max-normalised importance, with the most important nodes
    labelled. Right panel: the frame at ``data.frame_path`` or a placeholder
    text when it is missing. The top scores are also printed.

    Args:
        data: Graph object with ``x``, ``edge_index``, ``body_part_labels``
            and optionally ``frame_path``.
        importance: 1-D tensor with one importance score per node.
        pred_class: Predicted class index; 0 is shown as 'Error', anything
            else as 'NoError'.

    Any exception while drawing is reported with a printed message instead of
    being raised (best-effort visualisation); the figure is always closed.
    """
    fig = plt.figure(figsize=(20, 10))
    gs = GridSpec(1, 2, width_ratios=[1, 1])

    try:
        # Work on detached CPU copies so plotting never touches autograd/GPU.
        importance = importance.detach().cpu()
        coords = data.x.detach().cpu()

        # Graph Visualization
        ax1 = fig.add_subplot(gs[0])
        G = to_networkx(data, to_undirected=True)
        # Negate y so the layout is flipped vertically relative to the raw coordinates.
        pos = {i: (coords[i, 0].item(), -coords[i, 1].item()) for i in range(coords.size(0))}

        # Epsilon keeps the division finite when all scores are equal.
        importance_norm = (importance - importance.min()) / (importance.max() - importance.min() + 1e-9)

        nx.draw_networkx_nodes(
            G, pos,
            node_color=importance_norm.numpy(),
            cmap=plt.cm.Reds,
            node_size=200,
            alpha=0.8,
            ax=ax1
        )
        nx.draw_networkx_edges(G, pos, alpha=0.2, ax=ax1)

        # Graphs with fewer than 5 nodes would make topk(k=5) fail.
        k = min(5, importance.numel())
        top_importances = torch.topk(importance, k=k)
        top_nodes = top_importances.indices.tolist()
        labels = {i: data.body_part_labels.get(i, f"Node {i}") for i in top_nodes}
        nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold', ax=ax1)

        ax1.set_title(f"Graph Importance\nPredicted: {'Error' if pred_class == 0 else 'NoError'}")
        ax1.axis('off')

        # Image Visualization
        ax2 = fig.add_subplot(gs[1])
        frame_path = getattr(data, 'frame_path', None)
        if frame_path and os.path.exists(frame_path):
            # Context manager releases the file handle once the pixels are drawn.
            with Image.open(frame_path) as img:
                ax2.imshow(img)
            ax2.set_title("Original Frame")
            ax2.axis('off')
        else:
            ax2.text(0.5, 0.5, "Image not found", ha='center', va='center')
            ax2.axis('off')

        fig.tight_layout()
        plt.show()

        print(f"\nTop {k} Important Body Parts:")
        for node_idx, score in zip(top_nodes, top_importances.values.tolist()):
            part_name = data.body_part_labels.get(node_idx, f"Node {node_idx}")
            print(f"{part_name}: {score:.4f}")

    except Exception as e:
        print(f"Visualization error: {str(e)}")
    finally:
        # Avoid leaking figures (e.g. when drawing failed half-way through).
        plt.close(fig)

# 4. Initialize dataset
train_dataset = PoseGraphDataset(train_graph_dir, train_frame_dir, mode='train')

# 5. Load or create model
model = ExplainableGCN().to(device)

# Try to load pretrained weights
model_path = "best_gcn_model.pth"
if os.path.exists(model_path):
    # map_location remaps tensors saved on a GPU so the checkpoint also loads
    # on a CPU-only runtime.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Loaded pretrained model weights")
else:
    # Without trained weights the importances below come from a randomly
    # initialised model.
    print("Initialized new model (no pretrained weights found)")

# 6. Generate visualizations
sample_indices = [0, 5, 10]  # Different samples to visualize
for idx in sample_indices:
    if idx < len(train_dataset):  # Check if index is valid
        sample_data = train_dataset[idx].to(device)
        node_importances, pred_class = model.compute_node_importance(sample_data)
        visualize_graph_with_image(sample_data, node_importances, pred_class)
    else:
        print(f"Skipping index {idx} - exceeds dataset size")