Libraries

In [1]:
import cv2
import torch
import torch.nn as nn
import numpy as np
from torchvision.transforms import ToTensor

SGP layer

In [2]:
# SGP Layer
class SGPLayer(nn.Module):
    """Multi-scale 1-D convolution block.

    Three parallel convolutions (kernel sizes 3, 5 and 7) look at the input
    sequence at increasing temporal extents. Each uses stride 1 and padding
    ``(k - 1) // 2``, so the sequence length is preserved. Their outputs are
    concatenated along the channel axis, mixed back down to ``out_channels``
    with a pointwise (kernel size 1) convolution, and passed through ReLU.

    Args:
        in_channels: Number of channels of the input sequence.
        out_channels: Number of channels produced by each branch and by the
            fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # One branch per granularity: fine (3), coarser (5), coarsest (7).
        self.conv1 = nn.Conv1d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv1d(in_channels, out_channels, 5, padding=2)
        self.conv3 = nn.Conv1d(in_channels, out_channels, 7, padding=3)
        # Pointwise convolution that merges the three stacked branch outputs.
        self.fusion = nn.Conv1d(3 * out_channels, out_channels, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the fused, rectified multi-scale features of ``x``.

        ``x`` has shape (batch, in_channels, length); the result has shape
        (batch, out_channels, length).
        """
        branches = (self.conv1, self.conv2, self.conv3)
        multi_scale = torch.cat([branch(x) for branch in branches], dim=1)
        return self.relu(self.fusion(multi_scale))

Punch Detection Model

In [3]:
class PunchDetectionModel(nn.Module):
    """Sequence classifier: SGP layer -> global temporal average -> linear head.

    Args:
        in_channels: Number of feature channels per time step of the input.
        num_classes: Number of output logits.
        hidden_channels: Width of the SGP layer output and of the linear
            head's input. Defaults to 128, the width that was previously
            hard-coded, so existing callers and checkpoints are unaffected.
    """

    def __init__(self, in_channels, num_classes, hidden_channels=128):
        super().__init__()
        self.sgp_layer = SGPLayer(in_channels, hidden_channels)
        # Output size 1 collapses the time axis whatever the sequence length.
        self.temporal_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(hidden_channels, num_classes)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, in_channels, sequence_length)

        Returns the unnormalised class scores, shape (batch_size, num_classes).
        """
        sgp_out = self.sgp_layer(x)  # (batch, hidden_channels, sequence_length)
        # Pool to (batch, hidden_channels, 1), then drop the singleton time axis.
        pooled_features = self.temporal_pooling(sgp_out).squeeze(-1)
        logits = self.fc(pooled_features)  # Final classification layer
        return logits

Video Annotation

In [None]:
def annotate_video(video_path, output_path, model, sequence_length=100, threshold=0.5):
    """
    Annotate a video with hit predictions.

    Frames are consumed in windows of ``sequence_length``. Each frame is
    resized to 224x224, converted to a float CHW tensor and reduced to its
    per-channel spatial mean, so a window reaches the model as a tensor of
    shape (1, C, window_length) (C is 3 for colour frames). The softmax
    probability of class index 1 is drawn on every original-resolution frame
    of the window, and those frames are written to the output video. A final
    window shorter than ``sequence_length`` is scored and written as well, so
    the output contains every frame that was read.

    NOTE(review): the spatial-mean reduction is an assumption made so the
    input matches a model taking (batch, in_channels, sequence_length) with
    in_channels=3. Channels stay in OpenCV's BGR order. Confirm both against
    the preprocessing used when the model was trained.

    Args:
        video_path (str): Path to the input video file.
        output_path (str): Path to save the annotated video.
        model (nn.Module): Trained punch detection model.
        sequence_length (int): Number of frames in a sequence for the model.
        threshold (float): Probability threshold for detecting a hit.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
        IOError: If the input video cannot be opened or the output writer
            cannot be created.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be at least 1")

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open input video: {video_path}")

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional rate (e.g. 29.97): truncating it changes the
        # playback speed of the output.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0  # Container did not report a usable rate; fall back.

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            raise IOError(f"Could not open output video for writing: {output_path}")

        model.eval()
        transform = ToTensor()  # HWC uint8 image -> CHW float tensor in [0, 1]

        frame_buffer = []    # original frames awaiting annotation
        feature_buffer = []  # one (C,) feature vector per buffered frame

        # Read until the decoder stops: the container's reported frame count
        # is not always reliable.
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            resized_frame = cv2.resize(frame, (224, 224))  # Resize to fit model input
            frame_tensor = transform(resized_frame)  # already (C, H, W)
            feature_buffer.append(frame_tensor.mean(dim=(1, 2)))
            frame_buffer.append(frame)

            if len(frame_buffer) == sequence_length:
                _annotate_window(model, out, frame_buffer, feature_buffer, threshold)
                frame_buffer = []
                feature_buffer = []

        # Score and write the trailing partial window instead of dropping it.
        if frame_buffer:
            _annotate_window(model, out, frame_buffer, feature_buffer, threshold)
    finally:
        cap.release()
        if out is not None:
            out.release()

    print(f"Annotated video saved to {output_path}")


def _annotate_window(model, writer, frames, features, threshold):
    """Score one window with ``model`` and write its labelled frames to ``writer``."""
    # Stack the per-frame (C,) vectors along time: (C, L) -> (1, C, L).
    input_tensor = torch.stack(features, dim=1).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.softmax(logits, dim=1)
    prediction = probs[0, 1].item()

    label = "Hit" if prediction >= threshold else "No Hit"
    color = (0, 255, 0) if label == "Hit" else (0, 0, 255)  # BGR: green / red
    text = f"{label} ({prediction:.2f})"

    for frame in frames:
        # Draw on the original-resolution frame so it matches the writer's size.
        cv2.putText(frame, text,
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)
        writer.write(frame)

Sample usage

In [None]:
if __name__ == "__main__":
    # Define input parameters
    video_path = "input_video.mp4"
    output_path = "output_annotated.mp4"

    # Load trained model
    in_channels = 3
    num_classes = 2
    model = PunchDetectionModel(in_channels, num_classes)

    # map_location="cpu" lets a checkpoint saved from a GPU load on a machine
    # without CUDA; annotate_video builds its input tensors on the CPU.
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    state_dict = torch.load("punch_detection_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)  # Load your trained model weights

    # Annotate the video
    annotate_video(video_path, output_path, model)