<a href="https://colab.research.google.com/github/RuwaAbey/Computer_vision_based_group_activity_detection/blob/main/Setting_up_collective_activity_keypoint_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import numpy as np
import cv2
import mmcv
from mmengine.registry import init_default_scope
from mmpose.apis import inference_topdown, init_model as init_pose_estimator
from mmpose.structures import merge_data_samples
from mmdet.apis import inference_detector, init_detector
from mmpose.evaluation.functional import nms
import pickle
from tqdm import tqdm

# Constants from Collective Activity Dataset
# Number of frames in each sequence, keyed by sequence id.
# create_keypoint_dataset() reads it to know how many clips a sequence yields.
FRAMES_NUM = {
    1: 302, 2: 347, 3: 194, 4: 257, 5: 536, 6: 401, 7: 968, 8: 221, 9: 356, 10: 302,
    11: 1813, 12: 1084, 13: 851, 14: 723, 15: 464, 16: 1021, 17: 905, 18: 600, 19: 203, 20: 342,
    21: 650, 22: 361, 23: 311, 24: 321, 25: 617, 26: 734, 27: 1804, 28: 470, 29: 635, 30: 356,
    31: 690, 32: 194, 33: 193, 34: 395, 35: 707, 36: 914, 37: 1049, 38: 653, 39: 518, 40: 401,
    41: 707, 42: 420, 43: 410, 44: 356
}

# Configuration
# NOTE: these are absolute paths on one specific machine; adjust them before
# running anywhere else.
# DATA_ROOT is expected to contain seqNN/frameNNNN.jpg images.
DATA_ROOT = '/home/akila17/e19-group-activity/Group_Activity/Datasets/collective_dataset'
POSE_CONFIG = '/home/akila17/e19-group-activity/Group_Activity/Datasets/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py'
POSE_CHECKPOINT = '/home/akila17/e19-group-activity/Group_Activity/Datasets/Skeleton_Data/hrnet_w32_coco_256x192.pth'
DET_CONFIG = '/home/akila17/e19-group-activity/Group_Activity/Datasets/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
DET_CHECKPOINT = '/home/akila17/e19-group-activity/Group_Activity/Datasets/Skeleton_Data/faster_rcnn_r50_fpn_1x_coco.pth'
OUTPUT_PKL = 'collective_keypoints.pkl'  # relative path: written to the current working directory
MAX_PEOPLE = 13  # M: person slots per frame (results are zero-padded / truncated to this)
NUM_FRAMES = 10  # T: frames stored per clip
NUM_KEYPOINTS = 17  # V (COCO keypoints)

# Check GPU availability
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Only keypoint coordinates and scores are read from the pose results in this
# file, so do not ask the pose head to also return per-person heatmaps.
cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=False)))

# Initialize detector and pose estimator
detector = init_detector(DET_CONFIG, DET_CHECKPOINT, device=device)
pose_estimator = init_pose_estimator(
    POSE_CONFIG,
    POSE_CHECKPOINT,
    device=device,
    cfg_options=cfg_options
)

def process_frame(img_path, detector, pose_estimator):
    """Detect people in one frame and estimate a 2D pose for each of them.

    Args:
        img_path: Path of the frame image.
        detector: MMDetection model (as returned by ``init_detector``).
        pose_estimator: MMPose top-down model (as returned by ``init_model``).

    Returns:
        dict of fixed-size arrays; person slots beyond the detected people
        are all zeros:
            'keypoints':       (MAX_PEOPLE, NUM_KEYPOINTS, 2)
            'keypoint_scores': (MAX_PEOPLE, NUM_KEYPOINTS)
            'boxes':           (MAX_PEOPLE, 4)

    Raises:
        AssertionError: If ``img_path`` does not exist.
    """
    assert os.path.exists(img_path), f"Image not found at {img_path}"

    # Run detection
    scope = detector.cfg.get('default_scope', 'mmdet')
    if scope is not None:
        init_default_scope(scope)
    detect_result = inference_detector(detector, img_path)
    pred_instance = detect_result.pred_instances.cpu().numpy()
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    # Keep label 0 (presumably the "person" class of the COCO detector --
    # confirm against the detector config) scoring above 0.3.
    is_person = np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
    bboxes = bboxes[is_person]
    # Only the first MAX_PEOPLE boxes are ever returned, so drop the rest
    # before spending pose inference on them.
    bboxes = bboxes[nms(bboxes, 0.3)][:MAX_PEOPLE, :4]

    # Pre-allocate the padded outputs; unused person slots stay zero.
    keypoints = np.zeros((MAX_PEOPLE, NUM_KEYPOINTS, 2))
    keypoint_scores = np.zeros((MAX_PEOPLE, NUM_KEYPOINTS))
    boxes = np.zeros((MAX_PEOPLE, 4))

    num_people = len(bboxes)
    # inference_topdown substitutes a whole-image box when it receives no
    # boxes, which would fabricate a person; skip it so that a frame without
    # detections stays all zeros.
    if num_people > 0:
        pose_results = inference_topdown(pose_estimator, img_path, bboxes)
        for slot, pose_result in enumerate(pose_results):
            instance = pose_result.pred_instances
            # Each result may carry a leading instance axis, e.g. (1, V, 2);
            # reshape drops it so the row fits its (V, 2) slot.
            keypoints[slot] = np.asarray(instance.keypoints).reshape(NUM_KEYPOINTS, 2)
            keypoint_scores[slot] = np.asarray(instance.keypoint_scores).reshape(NUM_KEYPOINTS)
        boxes[:num_people] = bboxes

    return {
        'keypoints': keypoints,               # Shape: (MAX_PEOPLE, NUM_KEYPOINTS, 2)
        'keypoint_scores': keypoint_scores,   # Shape: (MAX_PEOPLE, NUM_KEYPOINTS)
        'boxes': boxes                        # Shape: (MAX_PEOPLE, 4)
    }

def create_keypoint_dataset():
    """Create keypoint dataset for Collective Activity Dataset.

    Every sequence listed in FRAMES_NUM is cut into consecutive,
    non-overlapping clips of NUM_FRAMES frames (frame ids start at 1); a
    trailing partial clip is dropped. Each frame goes through
    process_frame() and the per-frame results are stacked along the time
    axis.

    The pickled object written to OUTPUT_PKL has the layout
    ``{seq_key: {start_frame: clip_dict}}`` where ``clip_dict`` holds
        'keypoints':       (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)
        'keypoint_scores': (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS)
        'boxes':           (MAX_PEOPLE, NUM_FRAMES, 4)
    Frames whose image file is missing are reported and left as zeros.
    """
    dataset_dict = {}

    # Iterate over every sequence that has a known frame count
    for sid in tqdm(sorted(FRAMES_NUM), desc="Processing sequences"):
        seq_key = f'seq{sid:02d}'
        seq_dict = {}
        total_frames = FRAMES_NUM[sid]
        seq_path = os.path.join(DATA_ROOT, seq_key)

        # Clip starts are NUM_FRAMES apart; the last valid start is
        # total_frames - NUM_FRAMES + 1, so incomplete clips never appear.
        for start_frame in range(1, total_frames - NUM_FRAMES + 2, NUM_FRAMES):
            clip_dict = {
                'keypoints': np.zeros((MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)),
                'keypoint_scores': np.zeros((MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS)),
                'boxes': np.zeros((MAX_PEOPLE, NUM_FRAMES, 4))  # Per-frame boxes
            }

            # Process the NUM_FRAMES frames of the clip
            for t, fid in enumerate(range(start_frame, start_frame + NUM_FRAMES)):
                img_path = os.path.join(seq_path, f'frame{fid:04d}.jpg')
                if not os.path.exists(img_path):
                    print(f"Warning: Image not found at {img_path}")
                    continue

                result = process_frame(img_path, detector, pose_estimator)
                clip_dict['keypoints'][:, t] = result['keypoints']
                clip_dict['keypoint_scores'][:, t] = result['keypoint_scores']
                clip_dict['boxes'][:, t] = result['boxes']

            seq_dict[start_frame] = clip_dict

        dataset_dict[seq_key] = seq_dict

    # Save to pickle file
    with open(OUTPUT_PKL, 'wb') as f:
        pickle.dump(dataset_dict, f)
    print(f"Dataset saved to {OUTPUT_PKL}")

# Build and save the dataset when this code is executed directly
# (as a script or as a notebook cell), not when it is imported.
if __name__ == "__main__":
    create_keypoint_dataset()

In [None]:

%pip install torch==2.1.0+cu118 torchvision==0.16.0+cu118 torchaudio==2.1.0+cu118 --index-url https://download.pytorch.org/whl/cu118



Looking in indexes: https://download.pytorch.org/whl/cu118


In [None]:
# check NVCC version
!nvcc -V

# check GCC version
!gcc --version

# check python in conda environment
!which python

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

/usr/local/bin/python


In [None]:
import torch
import torchvision
import torchaudio
# Label every line with the package whose version it actually reports.
print("Torch version:", torch.__version__)

print("Torchvision version:", torchvision.__version__)
print("Torchaudio version:", torchaudio.__version__)
print("CUDA available:", torch.cuda.is_available())

Torch version: 2.1.0+cu118
Torch version: 0.16.0+cu118
Torch version: 2.1.0+cu118
CUDA available: True


In [None]:
!pip install "numpy<2"



In [None]:

# install MMEngine, MMCV and MMDetection using MIM
%pip install -U openmim
!mim install mmengine
!mim install "mmcv==2.1.0"
!mim install "mmdet<3.3.0,>=3.0.0"

Looking in links: https://download.openmmlab.com/mmcv/dist/cu118/torch2.1.0/index.html
Looking in links: https://download.openmmlab.com/mmcv/dist/cu118/torch2.1.0/index.html
Looking in links: https://download.openmmlab.com/mmcv/dist/cu118/torch2.1.0/index.html
Ignoring mmcv: markers 'extra == "mim"' don't match your environment
Ignoring mmengine: markers 'extra == "mim"' don't match your environment


In [None]:
import mmcv
# Confirm that the mmcv build installed above is the one being imported.
print("mmcv version:", mmcv.__version__)

mmcv version: 2.1.0


In [None]:
# for better Colab compatibility, install xtcocotools from source
%pip install git+https://github.com/jin-s13/xtcocoapi

Collecting git+https://github.com/jin-s13/xtcocoapi
  Cloning https://github.com/jin-s13/xtcocoapi to /tmp/pip-req-build-mkiz2ajy
  Running command git clone --filter=blob:none --quiet https://github.com/jin-s13/xtcocoapi /tmp/pip-req-build-mkiz2ajy
  Resolved https://github.com/jin-s13/xtcocoapi to commit d74033ff1635e9002133b2380862bc2b728584d2
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:

!git clone https://github.com/open-mmlab/mmpose.git
# The master branch is version 1.x
%cd mmpose
%pip install -r requirements.txt
%pip install -v -e .
# "-v" means verbose, or more output
# "-e" means installing a project in editable mode,
# thus any local modifications made to the code will take effect without reinstallation.

fatal: destination path 'mmpose' already exists and is not an empty directory.
/content/mmpose
Using pip 24.1.2 from /usr/local/lib/python3.11/dist-packages/pip (python 3.11)
Obtaining file:///content/mmpose
  Running command python setup.py egg_info
  running egg_info
  creating /tmp/pip-pip-egg-info-68fjtitx/mmpose.egg-info
  writing manifest file '/tmp/pip-pip-egg-info-68fjtitx/mmpose.egg-info/SOURCES.txt'
  writing manifest file '/tmp/pip-pip-egg-info-68fjtitx/mmpose.egg-info/SOURCES.txt'
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting xtcocotools@ git+https://github.com/jin-s13/xtcocoapi (from mmpose==1.3.2)
  Cloning https://github.com/jin-s13/xtcocoapi to /tmp/pip-install-hec44s1q/xtcocotools_9383190f3ca042678db04582baf66ca5
  Running command git version
  git version 2.34.1
  Running command git clone --filter=blob:none https://github.com/jin-s13/xtcocoapi /tmp/pip-install-hec44s1q/xtcocotools_9383190f3ca042678db04582baf66ca5
  Cloning into '/tmp/pip-install-hec

In [None]:

# Check PyTorch installation (the last value printed is CUDA availability)
import torch, torchvision

print('torch version:', torch.__version__, torch.cuda.is_available())
print('torchvision version:', torchvision.__version__)

# Check MMPose installation
import mmpose

print('mmpose version:', mmpose.__version__)

# Check mmcv installation
# NOTE(review): presumably these report the CUDA version and compiler that
# mmcv's ops were built with, not the toolchain of this machine -- confirm.
from mmcv.ops import get_compiling_cuda_version, get_compiler_version

print('cuda version:', get_compiling_cuda_version())
print('compiler information:', get_compiler_version())

torch version: 2.1.0+cu118 True
torchvision version: 0.16.0+cu118
mmpose version: 1.3.2
cuda version: 11.8
compiler information: GCC 9.3


In [None]:

from google.colab import drive
# Mount Google Drive: the dataset directory and the output pickle used by
# this notebook both live under /content/drive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import torch
import numpy as np
import cv2
import mmcv
from mmengine.registry import init_default_scope
from mmpose.apis import inference_topdown, init_model as init_pose_estimator
from mmpose.structures import merge_data_samples
from mmdet.apis import inference_detector, init_detector
from mmpose.evaluation.functional import nms
import pickle
from tqdm import tqdm

In [None]:
# Constants from Collective Activity Dataset
# Frame count per sequence id. Only sequence 1 is listed here because this
# notebook run processes a single sequence as a trial; the standalone script
# cells in this notebook carry the full 44-sequence table.
FRAMES_NUM = {
    1: 302
}


In [None]:
# Configuration
# DATA_ROOT is expected to contain seqNN/frameNNNN.jpg images.
DATA_ROOT = '/content/drive/My Drive/Collective'
# The two config paths are relative: they resolve only while the working
# directory is the cloned mmpose repository (see the earlier `%cd mmpose`).
POSE_CONFIG = 'configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py'
# Checkpoints are given as URLs and are fetched when the models are initialised.
POSE_CHECKPOINT = 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth'
DET_CONFIG = 'demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
DET_CHECKPOINT = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
OUTPUT_PKL = '/content/drive/My Drive/Collective/collective_keypoints.pkl'
MAX_PEOPLE = 13  # M: person slots per frame (results are zero-padded / truncated to this)
NUM_FRAMES = 10  # T: frames stored per clip
NUM_KEYPOINTS = 17  # V (COCO keypoints)

In [None]:
# Check GPU availability
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Only keypoint coordinates and scores are read from the pose results in this
# notebook, so do not ask the pose head to also return per-person heatmaps.
cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=False)))

# Initialize detector and pose estimator
detector = init_detector(DET_CONFIG, DET_CHECKPOINT, device=device)
pose_estimator = init_pose_estimator(
    POSE_CONFIG,
    POSE_CHECKPOINT,
    device=device,
    cfg_options=cfg_options
)

Loads checkpoint by http backend from path: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth


In [None]:
def process_frame(img_path, detector, pose_estimator):
    """Detect people in one frame and estimate a 2D pose for each of them.

    This is the debug variant: it prints box counts and array shapes at every
    stage.

    Args:
        img_path: Path of the frame image.
        detector: MMDetection model (as returned by ``init_detector``).
        pose_estimator: MMPose top-down model (as returned by ``init_model``).

    Returns:
        dict of fixed-size arrays; person slots beyond the detected people
        are all zeros:
            'keypoints':       (MAX_PEOPLE, NUM_KEYPOINTS, 2)
            'keypoint_scores': (MAX_PEOPLE, NUM_KEYPOINTS)
            'boxes':           (MAX_PEOPLE, 4)

    Raises:
        AssertionError: If ``img_path`` does not exist.
    """
    assert os.path.exists(img_path), f"Image not found at {img_path}"

    # Run detection
    scope = detector.cfg.get('default_scope', 'mmdet')
    if scope is not None:
        init_default_scope(scope)
    detect_result = inference_detector(detector, img_path)
    pred_instance = detect_result.pred_instances.cpu().numpy()
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    bboxes = bboxes[np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)]
    # Record the candidate count before NMS so the two debug lines below
    # really report different stages.
    num_before_nms = len(bboxes)
    bboxes = bboxes[nms(bboxes, 0.3)][:, :4]

    # Debug: Print number of boxes
    print(f"Before NMS: {num_before_nms} boxes")
    print(f"After NMS: {len(bboxes)} boxes")
    print(bboxes.shape)

    # Run pose estimation. inference_topdown substitutes a whole-image box
    # when it receives no boxes, which would fabricate a person and leave the
    # keypoints and boxes with different lengths; skip it for empty frames.
    if len(bboxes) > 0:
        pose_results = inference_topdown(pose_estimator, img_path, bboxes)
    else:
        pose_results = []

    # Extract keypoints and scores, removing extra batch dimension
    all_keypoints = []
    all_scores = []
    for pose_result in pose_results:
        keypoints = pose_result.pred_instances.keypoints  # Shape: (1, 17, 2) or (17, 2)
        scores = pose_result.pred_instances.keypoint_scores  # Shape: (1, 17) or (17,)
        # Remove batch dimension if present
        keypoints = keypoints.squeeze(0) if keypoints.shape[0] == 1 else keypoints  # Shape: (17, 2)
        scores = scores.squeeze(0) if scores.shape[0] == 1 else scores  # Shape: (17,)
        print(f"Pose result: keypoints.shape = {keypoints.shape}")
        all_keypoints.append(keypoints)
        all_scores.append(scores)

    # Convert to arrays, handle empty cases
    if len(all_keypoints) > 0:
        all_keypoints = np.stack(all_keypoints, axis=0)  # Shape: (M, 17, 2)
        all_scores = np.stack(all_scores, axis=0)        # Shape: (M, 17)
    else:
        all_keypoints = np.zeros((0, NUM_KEYPOINTS, 2))  # Empty array
        all_scores = np.zeros((0, NUM_KEYPOINTS))

    boxes = bboxes  # Shape: (M, 4)

    # Debug: Print shapes before padding
    print(f"Before padding: all_keypoints.shape = {all_keypoints.shape}")
    print(f"Before padding: all_scores.shape = {all_scores.shape}")
    print(f"Before padding: boxes.shape = {boxes.shape}")
    print(f"M = {all_keypoints.shape[0]}, MAX_PEOPLE = {MAX_PEOPLE}")

    # Pad if fewer than MAX_PEOPLE (concatenating onto an empty array is fine,
    # so the M == 0 case needs no special handling)
    M = all_keypoints.shape[0]
    if M < MAX_PEOPLE:
        # Pad keypoints: (M, 17, 2) -> (MAX_PEOPLE, 17, 2)
        keypoints_pad = np.zeros((MAX_PEOPLE - M, NUM_KEYPOINTS, 2))
        all_keypoints = np.concatenate([all_keypoints, keypoints_pad], axis=0)

        # Pad scores: (M, 17) -> (MAX_PEOPLE, 17)
        scores_pad = np.zeros((MAX_PEOPLE - M, NUM_KEYPOINTS))
        all_scores = np.concatenate([all_scores, scores_pad], axis=0)

    # Pad boxes by their own row count so the result always has MAX_PEOPLE
    # rows: (M, 4) -> (MAX_PEOPLE, 4)
    if len(boxes) < MAX_PEOPLE:
        boxes_pad = np.zeros((MAX_PEOPLE - len(boxes), 4))
        boxes = np.concatenate([boxes, boxes_pad], axis=0)

    # Debug: Print shapes after padding
    print(f"After padding: all_keypoints.shape = {all_keypoints.shape}")
    print(f"After padding: all_scores.shape = {all_scores.shape}")
    print(f"After padding: boxes.shape = {boxes.shape}")

    return {
        'keypoints': all_keypoints[:MAX_PEOPLE],        # Shape: (13, 17, 2)
        'keypoint_scores': all_scores[:MAX_PEOPLE],     # Shape: (13, 17)
        'boxes': boxes[:MAX_PEOPLE]                     # Shape: (13, 4)
    }

In [None]:
def create_keypoint_dataset():
    """Create keypoint dataset for Collective Activity Dataset.

    Every sequence listed in FRAMES_NUM is cut into consecutive,
    non-overlapping clips of NUM_FRAMES frames (frame ids start at 1); a
    trailing partial clip is dropped. Each frame goes through
    process_frame() and the per-frame results are stacked along the time
    axis.

    The pickled object written to OUTPUT_PKL has the layout
    ``{seq_key: {start_frame: clip_dict}}`` where ``clip_dict`` holds
        'keypoints':       (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)
        'keypoint_scores': (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS)
        'boxes':           (MAX_PEOPLE, NUM_FRAMES, 4)
    Frames whose image file is missing are reported and left as zeros.
    """
    dataset_dict = {}

    # Iterate over every sequence that has a known frame count, so the loop
    # always matches whatever FRAMES_NUM currently lists
    for sid in tqdm(sorted(FRAMES_NUM), desc="Processing sequences"):
        seq_key = f'seq{sid:02d}'
        seq_dict = {}
        total_frames = FRAMES_NUM[sid]
        seq_path = os.path.join(DATA_ROOT, seq_key)

        # Clip starts are NUM_FRAMES apart; the last valid start is
        # total_frames - NUM_FRAMES + 1, so incomplete clips never appear.
        for start_frame in range(1, total_frames - NUM_FRAMES + 2, NUM_FRAMES):
            clip_dict = {
                'keypoints': np.zeros((MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)),
                'keypoint_scores': np.zeros((MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS)),
                'boxes': np.zeros((MAX_PEOPLE, NUM_FRAMES, 4))  # Per-frame boxes
            }

            # Process the NUM_FRAMES frames of the clip
            for t, fid in enumerate(range(start_frame, start_frame + NUM_FRAMES)):
                img_path = os.path.join(seq_path, f'frame{fid:04d}.jpg')
                if not os.path.exists(img_path):
                    print(f"Warning: Image not found at {img_path}")
                    continue

                result = process_frame(img_path, detector, pose_estimator)
                clip_dict['keypoints'][:, t] = result['keypoints']
                clip_dict['keypoint_scores'][:, t] = result['keypoint_scores']
                clip_dict['boxes'][:, t] = result['boxes']

            seq_dict[start_frame] = clip_dict

        dataset_dict[seq_key] = seq_dict

    # Save to pickle file
    with open(OUTPUT_PKL, 'wb') as f:
        pickle.dump(dataset_dict, f)
    print(f"Dataset saved to {OUTPUT_PKL}")

In [None]:
# Build and save the dataset when this cell is executed.
if __name__ == "__main__":
    create_keypoint_dataset()

Processing sequences:   0%|          | 0/1 [00:00<?, ?it/s]

Before NMS: 5 boxes
After NMS: 5 boxes
(5, 4)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Before padding: all_keypoints.shape = (5, 17, 2)
Before padding: all_scores.shape = (5, 17)
Before padding: boxes.shape = (5, 4)
M = 5, MAX_PEOPLE = 13
After padding: all_keypoints.shape = (13, 17, 2)
After padding: all_scores.shape = (13, 17)
After padding: boxes.shape = (13, 4)
Before NMS: 4 boxes
After NMS: 4 boxes
(4, 4)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Before padding: all_keypoints.shape = (4, 17, 2)
Before padding: all_scores.shape = (4, 17)
Before padding: boxes.shape = (4, 4)
M = 4, MAX_PEOPLE = 13
After padding: all_keypoints.shape = (13, 17, 2)
After padding: all_scores.shape = (13, 17)
After padding: boxes.shape = (13

Processing sequences: 100%|██████████| 1/1 [01:21<00:00, 81.54s/it]

Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Pose result: keypoints.shape = (17, 2)
Before padding: all_keypoints.shape = (4, 17, 2)
Before padding: all_scores.shape = (4, 17)
Before padding: boxes.shape = (4, 4)
M = 4, MAX_PEOPLE = 13
After padding: all_keypoints.shape = (13, 17, 2)
After padding: all_scores.shape = (13, 17)
After padding: boxes.shape = (13, 4)
Dataset saved to /content/drive/My Drive/Collective/collective_keypoints.pkl





In [None]:
# Reload the pickle written by create_keypoint_dataset() to sanity-check its
# structure. pickle.load executes arbitrary code from the file, so only use it
# on files produced by this notebook.
path = '/content/drive/My Drive/Collective/collective_keypoints.pkl'
with open(path, 'rb') as f:
    data = pickle.load(f)


# Check the top-level keys
print(data.keys())  # One key per processed sequence, e.g. dict_keys(['seq01'])

dict_keys(['seq01'])


In [None]:
# Keys of a sequence entry are the start-frame ids of its clips.
print(data['seq01'].keys())

dict_keys([1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 101, 111, 121, 131, 141, 151, 161, 171, 181, 191, 201, 211, 221, 231, 241, 251, 261, 271, 281, 291])


In [None]:
# Expected shape: (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)
print(data['seq01'][1]['keypoints'].shape)

(13, 10, 17, 2)


In [None]:
import os
import torch
import numpy as np
import cv2
import mmcv
from mmengine.registry import init_default_scope
from mmpose.apis import inference_topdown, init_model as init_pose_estimator
from mmpose.structures import merge_data_samples
from mmdet.apis import inference_detector, init_detector
from mmpose.evaluation.functional import nms
import pickle
from tqdm import tqdm

# Constants from Collective Activity Dataset
# Number of frames in each sequence, keyed by sequence id.
# create_keypoint_dataset() reads it to know how many clips a sequence yields.
FRAMES_NUM = {
    1: 302, 2: 347, 3: 194, 4: 257, 5: 536, 6: 401, 7: 968, 8: 221, 9: 356, 10: 302,
    11: 1813, 12: 1084, 13: 851, 14: 723, 15: 464, 16: 1021, 17: 905, 18: 600, 19: 203, 20: 342,
    21: 650, 22: 361, 23: 311, 24: 321, 25: 617, 26: 734, 27: 1804, 28: 470, 29: 635, 30: 356,
    31: 690, 32: 194, 33: 193, 34: 395, 35: 707, 36: 914, 37: 1049, 38: 653, 39: 518, 40: 401,
    41: 707, 42: 420, 43: 410, 44: 356
}

# Configuration
# NOTE: these are absolute paths on one specific machine; adjust them before
# running anywhere else.
# DATA_ROOT is expected to contain seqNN/frameNNNN.jpg images.
DATA_ROOT = '/home/akila17/e19-group-activity/Group_Activity/Datasets/collective_dataset'
POSE_CONFIG = '/home/akila17/e19-group-activity/Group_Activity/Datasets/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py'
POSE_CHECKPOINT = '/home/akila17/e19-group-activity/Group_Activity/Datasets/Skeleton_Data/hrnet_w32_coco_256x192.pth'
DET_CONFIG = '/home/akila17/e19-group-activity/Group_Activity/Datasets/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
DET_CHECKPOINT = '/home/akila17/e19-group-activity/Group_Activity/Datasets/Skeleton_Data/faster_rcnn_r50_fpn_1x_coco.pth'
OUTPUT_PKL = 'collective_keypoints.pkl'  # relative path: written to the current working directory
MAX_PEOPLE = 13  # M: person slots per frame (results are zero-padded / truncated to this)
NUM_FRAMES = 10  # T: frames stored per clip
NUM_KEYPOINTS = 17  # V (COCO keypoints)

# Check GPU availability
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Only keypoint coordinates and scores are read from the pose results in this
# file, so do not ask the pose head to also return per-person heatmaps.
cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=False)))

# Initialize detector and pose estimator
detector = init_detector(DET_CONFIG, DET_CHECKPOINT, device=device)
pose_estimator = init_pose_estimator(
    POSE_CONFIG,
    POSE_CHECKPOINT,
    device=device,
    cfg_options=cfg_options
)

def process_frame(img_path, detector, pose_estimator):
    """Detect people in one frame and estimate a 2D pose for each of them.

    Args:
        img_path: Path of the frame image.
        detector: MMDetection model (as returned by ``init_detector``).
        pose_estimator: MMPose top-down model (as returned by ``init_model``).

    Returns:
        dict of fixed-size arrays; person slots beyond the detected people
        are all zeros:
            'keypoints':       (MAX_PEOPLE, NUM_KEYPOINTS, 2)
            'keypoint_scores': (MAX_PEOPLE, NUM_KEYPOINTS)
            'boxes':           (MAX_PEOPLE, 4)

    Raises:
        AssertionError: If ``img_path`` does not exist.
    """
    assert os.path.exists(img_path), f"Image not found at {img_path}"

    # Run detection
    scope = detector.cfg.get('default_scope', 'mmdet')
    if scope is not None:
        init_default_scope(scope)
    detect_result = inference_detector(detector, img_path)
    pred_instance = detect_result.pred_instances.cpu().numpy()
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    # Keep label 0 (presumably the "person" class of the COCO detector --
    # confirm against the detector config) scoring above 0.3.
    is_person = np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.3)
    bboxes = bboxes[is_person]
    # Only the first MAX_PEOPLE boxes are ever returned, so drop the rest
    # before spending pose inference on them.
    bboxes = bboxes[nms(bboxes, 0.3)][:MAX_PEOPLE, :4]

    # Pre-allocate the padded outputs; unused person slots stay zero.
    keypoints = np.zeros((MAX_PEOPLE, NUM_KEYPOINTS, 2))
    keypoint_scores = np.zeros((MAX_PEOPLE, NUM_KEYPOINTS))
    boxes = np.zeros((MAX_PEOPLE, 4))

    num_people = len(bboxes)
    # inference_topdown substitutes a whole-image box when it receives no
    # boxes, which would fabricate a person; skip it so that a frame without
    # detections stays all zeros.
    if num_people > 0:
        pose_results = inference_topdown(pose_estimator, img_path, bboxes)
        for slot, pose_result in enumerate(pose_results):
            instance = pose_result.pred_instances
            # Each result may carry a leading instance axis, e.g. (1, V, 2);
            # reshape drops it so the row fits its (V, 2) slot.
            keypoints[slot] = np.asarray(instance.keypoints).reshape(NUM_KEYPOINTS, 2)
            keypoint_scores[slot] = np.asarray(instance.keypoint_scores).reshape(NUM_KEYPOINTS)
        boxes[:num_people] = bboxes

    return {
        'keypoints': keypoints,               # Shape: (MAX_PEOPLE, NUM_KEYPOINTS, 2)
        'keypoint_scores': keypoint_scores,   # Shape: (MAX_PEOPLE, NUM_KEYPOINTS)
        'boxes': boxes                        # Shape: (MAX_PEOPLE, 4)
    }

def create_keypoint_dataset():
    """Create keypoint dataset for Collective Activity Dataset.

    Every sequence listed in FRAMES_NUM is cut into consecutive,
    non-overlapping clips of NUM_FRAMES frames (frame ids start at 1); a
    trailing partial clip is dropped. Each frame goes through
    process_frame() and the per-frame results are stacked along the time
    axis.

    The pickled object written to OUTPUT_PKL has the layout
    ``{seq_key: {start_frame: clip_dict}}`` where ``clip_dict`` holds
        'keypoints':       (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)
        'keypoint_scores': (MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS)
        'boxes':           (MAX_PEOPLE, NUM_FRAMES, 4)
    Frames whose image file is missing are reported and left as zeros.
    """
    dataset_dict = {}

    # Iterate over every sequence that has a known frame count
    for sid in tqdm(sorted(FRAMES_NUM), desc="Processing sequences"):
        seq_key = f'seq{sid:02d}'
        seq_dict = {}
        total_frames = FRAMES_NUM[sid]
        seq_path = os.path.join(DATA_ROOT, seq_key)

        # Clip starts are NUM_FRAMES apart; the last valid start is
        # total_frames - NUM_FRAMES + 1, so incomplete clips never appear.
        for start_frame in range(1, total_frames - NUM_FRAMES + 2, NUM_FRAMES):
            clip_dict = {
                'keypoints': np.zeros((MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS, 2)),
                'keypoint_scores': np.zeros((MAX_PEOPLE, NUM_FRAMES, NUM_KEYPOINTS)),
                'boxes': np.zeros((MAX_PEOPLE, NUM_FRAMES, 4))  # Per-frame boxes
            }

            # Process the NUM_FRAMES frames of the clip
            for t, fid in enumerate(range(start_frame, start_frame + NUM_FRAMES)):
                img_path = os.path.join(seq_path, f'frame{fid:04d}.jpg')
                if not os.path.exists(img_path):
                    print(f"Warning: Image not found at {img_path}")
                    continue

                result = process_frame(img_path, detector, pose_estimator)
                clip_dict['keypoints'][:, t] = result['keypoints']
                clip_dict['keypoint_scores'][:, t] = result['keypoint_scores']
                clip_dict['boxes'][:, t] = result['boxes']

            seq_dict[start_frame] = clip_dict

        dataset_dict[seq_key] = seq_dict

    # Save to pickle file
    with open(OUTPUT_PKL, 'wb') as f:
        pickle.dump(dataset_dict, f)
    print(f"Dataset saved to {OUTPUT_PKL}")

# Build and save the dataset when this code is executed directly
# (as a script or as a notebook cell), not when it is imported.
if __name__ == "__main__":
    create_keypoint_dataset()