In [1]:
import torch
import torch.nn as nn
import cv2
from torchvision import transforms
from PIL import Image
from collections import Counter

class Dance3DCNN(nn.Module):
    def __init__(self, num_classes):
        super(Dance3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1))
        self.fc1 = None
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool1(x)
        x = torch.relu(self.conv2(x))
        x = self.pool2(x)

        if self.fc1 is None:
            fc1_input_dim = x.numel() // x.size(0) 
            self.fc1 = nn.Linear(fc1_input_dim, 128).to(x.device)
            # print(f"Dynamic fc1 initialized with input_dim: {fc1_input_dim}")

            
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x



num_classes = 4
class_mapping = {0: "HipHop", 1: "Jazz", 2: "Kata", 3: "Taichi"}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Dance3DCNN(num_classes=num_classes).to(device)
state_dict = torch.load("ModelTry.pth", map_location=device)


state_dict.pop("fc1.weight", None)
state_dict.pop("fc1.bias", None)


model.load_state_dict(state_dict, strict=False)
model.eval()


Dance3DCNN(
  (conv1): Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=0, dilation=1, ceil_mode=False)
  (fc2): Linear(in_features=128, out_features=4, bias=True)
)

In [2]:

video_path = "TestVideo_KT01.mp4"


cap = cv2.VideoCapture(video_path)


transform = transforms.Compose([
    transforms.Resize((250, 250)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


frame_count = 16 
frame_interval = 2 
frames = []
frame_predictions = []
frame_idx = 0


while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break


    if frame_idx % frame_interval == 0:

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame)
        frame = transform(frame)
        frames.append(frame)

    frame_idx += 1


    if len(frames) == frame_count:

        input_tensor = torch.stack(frames).unsqueeze(0).permute(0, 2, 1, 3, 4).to(device)  
        # [batch_size, channels, time, height, width]


        # print(f"Input shape to model: {input_tensor.shape}")


        with torch.no_grad():
            outputs = model(input_tensor)
            _, predicted = torch.max(outputs, 1)
            frame_predictions.append(predicted.item()) 
        frames = []

cap.release()


prediction_counts = Counter(frame_predictions)
final_class = prediction_counts.most_common(1)[0][0]
predicted_class_name = class_mapping[final_class]


print(f"Frame Predictions: {frame_predictions}")
print(f"Final Predicted Class: {predicted_class_name}")


Frame Predictions: [1, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 1, 1, 3, 1, 2, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 3, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 0, 1, 2, 2, 2, 3, 0, 1, 1, 0, 0, 1, 2, 0, 3, 2, 2, 2, 0, 0, 2, 2, 2, 3, 3, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 0, 2, 2, 2, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0]
Final Predicted Class: HipHop
