In [None]:
from egoexo_dataloader_2 import EgoExoDataset
import os
import numpy as np
import torch
import random
from torch.utils.data import DataLoader

from mmpose.apis import MMPoseInferencer
from mmaction.apis import pose_inference, detection_inference

In [None]:
# Root of the local Ego-Exo4D copy. Set EGOEXO_DATASET_PATH to point elsewhere
# instead of editing this absolute path.
dataset_path = os.environ.get("EGOEXO_DATASET_PATH", '/media/thibault/T5 EVO/Datasets/Ego4D/')
takes_json = os.path.join(dataset_path, 'takes.json')

# Options shared by both splits, so train and val cannot silently diverge.
dataset_kwargs = dict(
    skill=True,
    get_frames=False,
    get_pose=True,
    get_hands_pose=False,
    frame_rate=3,
    transform=None,
)
train_dataset = EgoExoDataset(dataset_path, takes_json, split="train", **dataset_kwargs)
val_dataset = EgoExoDataset(dataset_path, takes_json, split="val", **dataset_kwargs)


In [None]:
# Pull one training sample (index 50) and display the first entry of its view "0" poses.
data = train_dataset[50]
data["pose"]["0"][0]

In [None]:
# Every camera path of the sample: the exo entries in order, then the ego path last.
paths = [*data["samples"]["exo"], data["samples"]["ego"]]
paths

In [None]:
# Number of samples in the train and val splits, space-separated.
print(f"{len(train_dataset)} {len(val_dataset)}")

In [None]:
# Sanity check: print the smallest keypoint value of every frame, view by view.
skeletons = list(data["pose"].values())

for view_frames in skeletons:
    for pose_frame in view_frames:
        print(np.min(pose_frame["keypoints"]))
        

In [None]:
from mmaction.apis import inference_skeleton, init_recognizer
from mmengine import Config
import torch
# Release cached CUDA memory left by earlier cells before loading another model.
torch.cuda.empty_cache()

# PoseC3D recognizer config and weights (SlowOnly-R50, keypoint input, NTU60 x-sub,
# according to the file names).
config_path = (
    "/home/thibault/Documents/Code/pckg/mmaction2/configs/skeleton/posec3d/"
    "slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py"
)
checkpoint_path = (
    "/home/thibault/Documents/Code/pckg/mmaction2/checkpoints/"
    "slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint_20220815-38db104b.pth"
)

# Parse the config, then build the recognizer and load its checkpoint on the first GPU.
cfg = Config.fromfile(config_path)
skeleton_model = init_recognizer(cfg, checkpoint_path, device="cuda:0")


In [None]:
import torch

def _first_frame(view):
    """Return one view's frame dict: the view itself if it is a mapping, else its first element."""
    from collections.abc import Mapping
    return view if isinstance(view, Mapping) else view[0]


def _keypoints_2d(keypoints):
    """Copy keypoints to a float32 array, dropping a leading axis of size 1 ((1, K, d) -> (K, d))."""
    kp = np.array(keypoints, dtype=np.float32)
    if kp.ndim == 3 and kp.shape[0] == 1:
        kp = kp[0]
    return kp


def collate_pose_and_label_with_padding(batch, label_key='label'):
    """
    Collate samples with a varying number of views into padded tensors.

    Args:
        batch: non-empty list of samples, each a dict with
          - 'pose': mapping view-id -> view entry. A view entry is either a frame
            dict ({'keypoints': (1, K, d) or (K, d), 'keypoint_scores': (1, K) or (K,)})
            or a sequence of such frame dicts. For a sequence, only its FIRST frame is
            used, because the output has no time axis.
          - label_key: integer class label.
        label_key: key of the label in each sample. The default 'label' keeps the old
            behaviour; the EgoExo samples in this notebook store it under 'skill'.

    Returns:
      padded_keypoints: FloatTensor (B, V_max, K, d)
      padded_scores:    FloatTensor (B, V_max, K)
      view_mask:        BoolTensor  (B, V_max), True where a real view was written
      labels:           LongTensor  (B,)

    Raises:
        ValueError: if the batch is empty or no sample has any view.
    """
    if not batch:
        raise ValueError("cannot collate an empty batch")

    B = len(batch)
    labels = torch.tensor([s[label_key] for s in batch], dtype=torch.long)

    pose_arrays = [s['pose'] for s in batch]
    V_max = max(len(pa) for pa in pose_arrays)
    if V_max == 0:
        raise ValueError("no sample in the batch has any view")

    # Joint count K and coordinate dim d come from the first available view
    # (the last two axes, so both (1, K, d) and (K, d) layouts work).
    first = next(_first_frame(view) for pa in pose_arrays for view in pa.values())
    K, dim = _keypoints_2d(first['keypoints']).shape[-2:]

    padded_keypoints = torch.zeros(B, V_max, K, dim, dtype=torch.float)
    padded_scores = torch.zeros(B, V_max, K, dtype=torch.float)
    view_mask = torch.zeros(B, V_max, dtype=torch.bool)

    # Fill real views; slots past a sample's own view count stay zero / masked out.
    for b, pa in enumerate(pose_arrays):
        for v, view in enumerate(pa.values()):
            frame = _first_frame(view)
            padded_keypoints[b, v] = torch.from_numpy(_keypoints_2d(frame['keypoints']))
            scores = np.array(frame['keypoint_scores'], dtype=np.float32).reshape(K)
            padded_scores[b, v] = torch.from_numpy(scores)
            view_mask[b, v] = True

    return padded_keypoints, padded_scores, view_mask, labels


In [None]:
import torch

def collate_pose_and_label_with_padding_2(batch):
    """
    Pad each sample's pose dict so every sample has the batch's maximum view count.

    Args:
        batch: list of samples, each a dict with
          - 'pose':  dict mapping view ids '0', '1', ... to a per-frame sequence of
                     frame dicts {'keypoints': ndarray, 'keypoint_scores': ndarray}.
                     View '0' must exist when a sample needs padding.
          - 'skill': integer class label.

    Returns:
        pose_arrays: list of B pose dicts (copies; the input samples are not modified).
                     Missing views str(n_views) .. str(V_max - 1) are added as object
                     arrays of T all-zero frames, where T = len(view '0') and each
                     frame's arrays have the same shapes as view '0''s first frame.
        labels:      LongTensor (B,)
    """
    labels = torch.tensor([s['skill'] for s in batch], dtype=torch.long)

    pose_arrays = [dict(s['pose']) for s in batch]   # shallow copies: never mutate the dataset's dicts
    V_max = max((len(pose) for pose in pose_arrays), default=0)

    for pose in pose_arrays:
        n_views = len(pose)
        if n_views >= V_max:
            continue
        reference = pose["0"]
        T = len(reference)
        template = reference[0] if T else None
        for i in range(n_views, V_max):
            padding = np.empty(T, dtype=object)
            for t in range(T):
                # A fresh dict per frame, so editing one padded frame cannot alias the others.
                padding[t] = {
                    "keypoints": np.zeros(np.shape(template["keypoints"])),
                    "keypoint_scores": np.zeros(np.shape(template["keypoint_scores"])),
                }
            pose[str(i)] = padding

    return pose_arrays, labels


In [None]:
import torch

def collate_objects(batch):
    """
    Collate dataset samples without converting them to tensors.

    Args:
        batch: list of sample dicts; each must hold an integer under 'skill'.

    Returns:
        samples: a new list with the same sample objects, in batch order.
        labels:  LongTensor (B,) of the samples' 'skill' values.
    """
    skill_labels = torch.tensor([sample['skill'] for sample in batch], dtype=torch.long)
    samples = list(batch)
    return samples, skill_labels


In [None]:
from pathlib import Path

# Configs and weights for the two-stage skeleton extraction (person detector,
# then 2D pose model), relative to the notebook's working directory.
_SKELETON_ASSET_DIR = Path("./generate_skeletons")

DETECTION_CONFIG = _SKELETON_ASSET_DIR / "faster-rcnn_r50_fpn_2x_coco_infer.py"
DETECTION_CHECKPOINT = (
    _SKELETON_ASSET_DIR
    / "faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"
)

SKELETON_CONFIG = _SKELETON_ASSET_DIR / "td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py"
SKELETON_CHECKPOINT = _SKELETON_ASSET_DIR / "hrnet_w32_coco_256x192-c78dce93_20200708.pth"

# Windowing parameters; presumably meant for frame_windows(window_size, sampling_rate) — TODO confirm.
WINDOW_SIZE = 128
SAMPLING_RATE = 3

In [None]:

import cv2
import numpy as np
from typing import Callable, Generator, List
from IPython.display import clear_output
import os
import contextlib


def frame_windows(video_path: str,
                  window_size: int,
                  sampling_rate: int = 1
                 ) -> Generator[List[np.ndarray], None, None]:
    """
    Lazily read a video and yield it as chunks of sampled frames.

    Keeps one frame out of every `sampling_rate` (raw indices 0, s, 2s, ...) and
    yields lists of `window_size` sampled frames. A final shorter list holds any
    leftover sampled frames. If the file cannot be opened, an error is printed
    and nothing is yielded.

    The capture is released in a `finally` block, so it is also freed when the
    consumer stops iterating early (generator closed) or an exception propagates.

    Raises:
        ValueError: if window_size or sampling_rate is < 1 (raised on the first
            iteration, since this is a generator).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if sampling_rate < 1:
        raise ValueError(f"sampling_rate must be >= 1, got {sampling_rate}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error opening video file {video_path}")
            return
        print(f'Length of video: {cap.get(cv2.CAP_PROP_FRAME_COUNT)}')

        frames: List[np.ndarray] = []
        raw_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # only keep every `sampling_rate`th frame
            if raw_idx % sampling_rate == 0:
                frames.append(frame)
                if len(frames) == window_size:
                    yield frames
                    frames = []

            raw_idx += 1

        # leftover sampled frames
        if frames:
            yield frames
    finally:
        cap.release()

def extract_skeleton(video_path: str,
                     window_size: int = 32,
                     sampling_rate: int = 1,
                     predict: Callable[..., tuple] = None
                    ) -> np.ndarray:
    """
    Run person detection + pose estimation over a video, one window at a time.

    Frames are read lazily through `frame_windows`, so the whole video is never
    held in RAM. For each window:
      - boxes come from `detection_inference(DETECTION_CONFIG, DETECTION_CHECKPOINT, frames)`;
      - poses come from `predict(SKELETON_CONFIG, SKELETON_CHECKPOINT, frames, det_results)`.
    `predict` must therefore follow mmaction's `pose_inference` argument order and
    return a `(pose_results, extra)` pair. This notebook passes `pose_inference` itself.
    Anything printed to stdout/stderr by the two calls is discarded.

    Args:
        video_path: Path to the input video file.
        window_size: Number of sampled frames per chunk.
        sampling_rate: Keep one raw frame out of every `sampling_rate`.
        predict: Pose-estimation function as described above (required).

    Returns:
        The per-window `np.array(pose_results)` chunks concatenated along axis 0,
        one entry per pose result. An empty array if no frames were read.

    Raises:
        ValueError: if `predict` is None.
    """
    if predict is None:
        raise ValueError("You must pass in a `predict` function.")

    skeleton_chunks = []
    total = 0
    for frames in frame_windows(video_path, window_size, sampling_rate):
        # NOTE(review): both calls receive config/checkpoint paths, so the models
        # appear to be rebuilt for every window — consider loading them once.
        with open(os.devnull, 'w') as devnull, \
                contextlib.redirect_stdout(devnull), \
                contextlib.redirect_stderr(devnull):
            det_results, _ = detection_inference(
                str(DETECTION_CONFIG), str(DETECTION_CHECKPOINT), frames)
            pose_results, _ = predict(
                str(SKELETON_CONFIG), str(SKELETON_CHECKPOINT), frames, det_results)
            skeleton_chunks.append(np.array(pose_results))
            total += len(pose_results)
        print(total)  # running count of pose results so far

    # Concatenate along the temporal axis.
    if skeleton_chunks:
        skeleton_sequence = np.concatenate(skeleton_chunks, axis=0)
    else:
        skeleton_sequence = np.array([])
    assert skeleton_sequence.shape[0] == total

    return skeleton_sequence


In [None]:
# Reload a previously saved skeleton array for inspection. allow_pickle=True is
# needed for object arrays, but unpickling can run arbitrary code — only load
# files you produced yourself.
test_skeleton = np.load(file="test_skeleton.npy", allow_pickle=True)
test_skeleton

In [None]:
def process_skeletons(poses, skeleton_model, pose_inference,
                      window_size=128, sampling_rate=64, img_shape=(1920, 1080)):
    """
    Extract skeletons from every camera of every sample, then encode each view.

    Args:
        poses: list of dataset samples. Each has `sample["samples"]["exo"]` (an
            iterable of video paths) and `sample["samples"]["ego"]` (one path).
        skeleton_model: recognizer passed to `inference_skeleton`.
        pose_inference: pose function forwarded to `extract_skeleton` as `predict`.
        window_size: forwarded to `extract_skeleton` (default keeps the old 128).
        sampling_rate: forwarded to `extract_skeleton`. The default keeps the old
            hard-coded 64; note the notebook's SAMPLING_RATE constant is 3.
        img_shape: forwarded to `inference_skeleton`. NOTE(review): mmaction's
            skeleton demo passes (h, w); (1920, 1080) looks like (w, h) — confirm.

    Returns:
        List with one entry per sample: `torch.stack` of the per-view
        `inference_skeleton` results, with views ordered exo..., ego.
        NOTE(review): torch.stack needs tensors — confirm that inference_skeleton
        returns a tensor here (mmaction's API returns a data-sample object).
    """
    # Stage 1: skeleton extraction for every video of every sample.
    per_sample_skeletons = []
    for sample in poses:
        video_paths = [*sample["samples"]["exo"], sample["samples"]["ego"]]
        per_sample_skeletons.append({
            str(i): extract_skeleton(path, window_size=window_size,
                                     sampling_rate=sampling_rate, predict=pose_inference)
            for i, path in enumerate(video_paths)
        })

    # Stage 2: encode every view with the recognizer and stack views per sample.
    batch = []
    for views in per_sample_skeletons:
        representations = [
            inference_skeleton(skeleton_model, view, img_shape, test_pipeline=None)
            for view in views.values()
        ]
        batch.append(torch.stack(representations))
    return batch



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
# Raise the 'mmdet' logger's threshold to WARNING so its INFO/DEBUG records are dropped.
logging.getLogger("mmdet").setLevel(level=logging.WARNING)



class MultiViewSkeletonClassifier(nn.Module):
    """
    Attention-fused multi-view skeleton-feature classifier.

    forward() receives a list of B per-sample tensors, each (V, T, D) with
    D == feat_dim and T == seq_len. It fuses the V views using a learned score
    per view, applies two temporal conv layers, and returns (B, num_classes) logits.

    Args:
        skeleton_model: stored as `self.skeleton_model`; forward() does not use it.
            When it is an nn.Module it is registered as a submodule, so its weights
            appear in parameters() and state_dict().
        feat_dim: per-view feature size D.
        seq_len: number of time steps T; the classification head is sized for it.
        num_views: length of `view_weights` (not used by forward()).
        hidden_dim: width after `fusion_proj` and in the conv layers.
        num_classes: number of output logits.
    """
    def __init__(self, skeleton_model, feat_dim=512, seq_len=20, num_views=4, hidden_dim=512, num_classes=10):
        super(MultiViewSkeletonClassifier, self).__init__()
        self.num_views = num_views
        self.seq_len = seq_len
        self.feat_dim = feat_dim
        self.skeleton_model = skeleton_model

        # Per-view scorer. It is applied to the raw (feat_dim) view summaries, so
        # its input is feat_dim; shapes are unchanged when feat_dim == hidden_dim.
        self.attn_head = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Learnable per-view weights. Unused by forward(), but kept so existing
        # checkpoints (state_dict keys) still load.
        self.view_weights = nn.Parameter(torch.ones(num_views))

        # Project fused features to the hidden dimension
        self.fusion_proj = nn.Linear(feat_dim, hidden_dim)

        # Temporal convolutional layers (length-preserving: kernel 3, padding 1)
        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

        # Classification head over the flattened (hidden_dim * seq_len) features
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * seq_len, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, batch):
        """
        Args:
            batch: list of B tensors, each (V, T, D); stacked to (B, V, T, D).
        Returns:
            logits: tensor of shape (B, num_classes)
        """
        # Follow the module's own device instead of hard-coding cuda:0.
        x = torch.stack(batch).to(self.view_weights.device)  # (B, V, T, D)
        B, V, T, D = x.shape

        # per-view temporal mean -> scalar score -> softmax over views -> weighted sum
        summary = x.mean(dim=2)                        # (B, V, D)
        raw_w = self.attn_head(summary).squeeze(-1)    # (B, V)
        w = F.softmax(raw_w, dim=1)                    # (B, V); dim=1 is the legacy implicit choice for 2-D input
        fused = (w.view(B, V, 1, 1) * x).sum(dim=1)    # (B, T, D)

        # Project features
        fused = self.fusion_proj(fused)  # (B, T, hidden_dim)
        fused = F.relu(fused)

        # Prepare for temporal conv: (B, hidden_dim, T)
        fused = fused.permute(0, 2, 1)

        # Temporal convolutional block
        out = F.relu(self.bn1(self.conv1(fused)))
        out = F.relu(self.bn2(self.conv2(out)))

        # Flatten temporal features: (B, hidden_dim * T)
        out = out.view(B, -1)

        # Classification head
        logits = self.fc(out)
        return logits


if __name__ == "__main__":
    # Smoke test of the full pipeline (videos -> skeletons -> recognizer features
    # -> multi-view classifier) on two copies of the sample `data` loaded above.
    smoke_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    T, D = 20, 512
    num_views = 4
    num_classes = 4
    multiview_model = MultiViewSkeletonClassifier(
        skeleton_model=skeleton_model,
        feat_dim=D,
        seq_len=T,
        num_views=num_views,
        hidden_dim=D,
        num_classes=num_classes,
    ).to(smoke_device)

    batch = process_skeletons([data, data], skeleton_model, pose_inference)
    results = multiview_model(batch)  # Forward pass

    print("Input shapes:", [tuple(sample.shape) for sample in batch])  # each expected (V, T, D)
    print(results)  # expected shape (2, num_classes)


In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs    = 30
learning_rate = 0.001
batch_size    = 8
num_classes   = 5
# Per-epoch checkpoints are written here; the evaluation cell loads from ./ckpt/.
ckpt_dir      = "ckpt"

# Model
multiview_model = MultiViewSkeletonClassifier(
    skeleton_model=skeleton_model,
    feat_dim=512,
    seq_len=20,
    num_views=4,
    hidden_dim=512,
    num_classes=num_classes
).to(device)
model = multiview_model

# Loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)


# Dataloaders (collate_objects yields raw dataset samples, not tensors)
train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_objects,
    batch_size=batch_size,
    shuffle=True
)
val_loader = DataLoader(
    val_dataset,  # assuming you have a separate validation dataset
    collate_fn=collate_objects,
    batch_size=batch_size,
    shuffle=False
)

os.makedirs(ckpt_dir, exist_ok=True)

# NOTE(review): process_skeletons re-runs video decoding + pose estimation for
# every batch of every epoch; caching its output per sample would save most of
# the training time.

# Training + Validation loop
for epoch in range(num_epochs):
    # ---- TRAIN ----
    model.train()
    # model.train() recurses into the registered recognizer; keep it in inference mode.
    skeleton_model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for views, labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{num_epochs}"):
        labels = labels.to(device)

        # Turn raw samples into per-sample (V, T, D) features, as the evaluation cell does.
        batch = process_skeletons(views, skeleton_model, pose_inference)

        # forward + backward
        outputs = model(batch)
        loss    = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # stats
        running_loss += loss.item()
        _, preds = outputs.max(1)
        total   += labels.size(0)
        correct += preds.eq(labels).sum().item()

    train_loss = running_loss / max(len(train_loader), 1)
    train_acc  = 100. * correct / max(total, 1)

    # ---- VALIDATION ----
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total   = 0

    with torch.no_grad():
        for views, labels in tqdm(val_loader, desc=f"Val   Epoch {epoch+1}/{num_epochs}"):
            labels = labels.to(device)

            batch = process_skeletons(views, skeleton_model, pose_inference)
            outputs = model(batch)
            loss    = criterion(outputs, labels)

            val_loss += loss.item()
            _, preds = outputs.max(1)
            val_total   += labels.size(0)
            val_correct += preds.eq(labels).sum().item()

    val_loss_epoch = val_loss / max(len(val_loader), 1)
    val_acc        = 100. * val_correct / max(val_total, 1)
    scheduler.step(val_loss_epoch)
    # ---- EPOCH SUMMARY ----
    print(
        f"Epoch {epoch+1}/{num_epochs} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
        f"Val   Loss: {val_loss_epoch:.4f}, Val   Acc: {val_acc:.2f}%"
    )
    torch.save(model.state_dict(), os.path.join(ckpt_dir, f"multiview_model_epoch_{epoch+1}.pth"))
print("Training complete.")



In [None]:
# Testing loop
# NOTE(review): this evaluates on val_dataset — the split whose loss drives the
# LR scheduler in the training cell — so it is not a held-out test score.
test_loader = DataLoader(
    val_dataset,
    collate_fn=collate_objects,
    batch_size=batch_size,
    shuffle=False
)

# Load the last-epoch checkpoint written by the training cell. map_location lets a
# checkpoint saved on GPU load on whatever `device` is available now.
ckpt_path = os.path.join("ckpt", f"multiview_model_epoch_{num_epochs}.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

test_loss = 0.0
test_correct = 0
test_total = 0

with torch.no_grad():
    videos = []
    ego_model_predictions = []
    exo_model_predictions = []
    for views, labels in tqdm(test_loader, desc="Testing"):
        labels = labels.to(device)

        # Take/video identifier: 5th path component from the end of the ego path.
        videos += [sample['samples']['ego'].split("/")[-5] for sample in views]
        batch = process_skeletons(views, skeleton_model, pose_inference)

        outputs = model(batch)

        # NOTE(review): this stores the class-0 logit repeated 4 times per sample
        # (as device tensors) — confirm this is the intended per-view prediction.
        exo_model_predictions += [[outputs[i][0]] * 4 for i in range(len(outputs))]
        loss = criterion(outputs, labels)

        test_loss += loss.item()
        _, preds = outputs.max(1)
        test_total += labels.size(0)
        test_correct += preds.eq(labels).sum().item()

test_loss_epoch = test_loss / max(len(test_loader), 1)
test_acc = 100. * test_correct / max(test_total, 1)

# ---- TEST SUMMARY ----
print(
    f"Test Loss: {test_loss_epoch:.4f}, Test Acc: {test_acc:.2f}%"
)