In [None]:
# Research Papers
# [Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition](https://arxiv.org/abs/1801.07455)
# [On loss functions and regret bounds for multi-category classification](https://arxiv.org/abs/2005.08155)

In [None]:
# Input: Keypoints [T, num_keypoints, 2] per clip, batched as [B, T, num_keypoints, 2] (e.g., [30, 7, 2] per clip)
#     ↓
# Graph Feature Extractor (GCN):
#     - Models relationships between body parts
#     - Outputs spatial embeddings [T, num_keypoints, d]
#     ↓
# Temporal Module (GRU or Transformer):
#     - Captures temporal dynamics in keypoint movement
#     - Outputs temporal embeddings [T, d]
#     ↓
# Global Average Pooling:
#     - Aggregates information across time
#     ↓
# Fully Connected Layers:
#     - Dense layers for classification
#     - Dropout for regularization
#     ↓
# Output: Behavior class logits [num_classes] (apply softmax to obtain probabilities)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from PIL import Image
from rich import print
import os

In [None]:
import torch
import torch.nn as nn
import torch_geometric.nn as gnn

from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms

In [None]:

# --- Frame geometry and model/training sizes ---
IMG_WIDTH, IMG_HEIGHT = 320, 240
NUM_KEYPOINTS = 7
NUM_BATCH = 16

# --- Filesystem locations (relative paths, resolved against the kernel's working directory) ---
DATASET_ROOT = "../datasets"
MODEL_PATH = "../models"
DATASET_FILE = f"{DATASET_ROOT}/preprocessed_dataset.csv"

# Load the preprocessed table; its column schema is defined by the preprocessing step (not shown here).
dataset = pd.read_csv(DATASET_FILE)

# TODO(review): a `SKELETON = []` placeholder was stubbed out here; presumably the keypoint
# edge list for the graph model — define it or remove this note.

In [None]:
# Device-Agnostic
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"""
Device: {device}
Device CUDNN enabled: {torch.backends.cudnn.enabled}
""")

In [None]:

class BirdBehaviorClassifier(nn.Module):
    """Skeleton-based behavior classifier: per-frame GCN -> GRU over time -> MLP head.

    Each frame is treated as a small graph whose nodes are keypoints and whose node
    features are that keypoint's coordinates. A GCN layer embeds every keypoint. The
    keypoint embeddings of a frame are concatenated into one vector per time step and
    fed to a GRU. The GRU outputs are averaged over time, and an MLP produces
    class logits.

    Args:
        num_keypoints: Number of keypoints (graph nodes) per frame.
        num_classes: Number of behavior classes (size of the output logits).
        hidden_dim: Width of the per-keypoint GCN embeddings and of the GRU hidden state.
        in_channels: Number of features per keypoint (2 for (x, y) coordinates).
    """

    def __init__(self, num_keypoints: int, num_classes: int, hidden_dim: int = 128,
                 in_channels: int = 2):
        super().__init__()
        self.num_keypoints = num_keypoints
        self.in_channels = in_channels

        # Spatial GCN: nodes are keypoints, so the per-node input width is the number of
        # coordinate channels (not num_keypoints * 2).
        self.gcn = gnn.GCNConv(in_channels, hidden_dim)

        # Temporal Module (GRU): each time step receives every keypoint embedding of one
        # frame, concatenated, so joint identity is preserved.
        self.temporal_gru = nn.GRU(num_keypoints * hidden_dim, hidden_dim, batch_first=True)

        # Classification Head
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, keypoints: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
        """Classify a batch of keypoint sequences.

        Args:
            keypoints: Tensor of shape [B, T, num_keypoints, in_channels].
            edges: Skeleton edge index of shape [2, E] with node ids in
                [0, num_keypoints). The same skeleton is applied to every frame.

        Returns:
            Logits of shape [B, num_classes] (no softmax applied).

        Raises:
            ValueError: If `keypoints` is not 4-D or does not match the configured
                keypoint count / channel count.
        """
        if keypoints.dim() != 4:
            raise ValueError(
                f"expected keypoints of shape [B, T, N, C], got {tuple(keypoints.shape)}"
            )
        batch_size, time_steps, num_keypoints, channels = keypoints.shape
        if num_keypoints != self.num_keypoints or channels != self.in_channels:
            raise ValueError(
                f"expected {self.num_keypoints} keypoints with {self.in_channels} channels, "
                f"got {num_keypoints} keypoints with {channels} channels"
            )

        # Spatial GCN. The B*T frames are batched as one disjoint graph (PyG-style
        # mini-batching). Node features are stacked frame by frame, and each frame's
        # copy of the skeleton edges is shifted by frame_index * num_keypoints.
        num_graphs = batch_size * time_steps
        node_features = keypoints.reshape(num_graphs * num_keypoints, channels)
        num_edges = edges.size(1)
        offsets = (
            torch.arange(num_graphs, device=edges.device).repeat_interleave(num_edges)
            * num_keypoints
        )
        batched_edges = edges.repeat(1, num_graphs) + offsets
        gcn_out = self.gcn(node_features, batched_edges)  # [B*T*N, hidden_dim]

        # Temporal Module: concatenate a frame's keypoint embeddings into one GRU input.
        frame_embeddings = gcn_out.reshape(batch_size, time_steps, -1)  # [B, T, N*hidden_dim]
        temporal_out, _ = self.temporal_gru(frame_embeddings)  # [B, T, hidden_dim]

        # Classification Head
        pooled = temporal_out.mean(dim=1)  # Global Average Pooling over time
        return self.fc(pooled)  # Behavior logits
