In [1]:
import torch
import torch.nn as nn

# 定义与训练时一致的模型结构
class Dance3DCNN(nn.Module):
    """Small 3D CNN that classifies a short video clip into ``num_classes`` classes.

    Architecture: Conv3d(3->16) -> MaxPool3d(2, stride 2) -> Conv3d(16->32)
    -> MaxPool3d(2, stride 1) -> Linear(->128) -> Linear(->num_classes).
    Layer names are part of the checkpoint format (state-dict keys) and must
    not be renamed.

    Args:
        num_classes: Number of output logits.
        input_shape: ``(time, height, width)`` of the clips the model will
            receive. Only used to size the first fully connected layer. The
            default ``(5, 250, 250)`` yields ``32 * 1 * 124 * 124`` input
            features, i.e. the size that was previously hard-coded.

    Raises:
        ValueError: If ``input_shape`` does not have three entries or is too
            small to survive both pooling layers.
    """

    def __init__(self, num_classes, input_shape=(5, 250, 250)):
        super().__init__()
        self.conv1 = nn.Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1))
        self.fc1 = nn.Linear(self._flat_features(input_shape), 128)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _flat_features(input_shape):
        """Return the number of features entering ``fc1`` for a given clip shape.

        Both convolutions use kernel 3 / stride 1 / padding 1 and therefore keep
        the size. ``pool1`` (kernel 2, stride 2) maps ``d -> d // 2`` and
        ``pool2`` (kernel 2, stride 1) maps ``d -> d - 1``, so every dimension
        ends up as ``d // 2 - 1``.
        """
        if len(input_shape) != 3:
            raise ValueError(
                f"input_shape must be (time, height, width), got {tuple(input_shape)}"
            )
        pooled = [int(d) // 2 - 1 for d in input_shape]
        if min(pooled) < 1:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small: every dimension must be >= 4"
            )
        t, h, w = pooled
        # 32 is the number of output channels of conv2.
        return 32 * t * h * w

    def forward(self, x):
        """Compute class logits for a batch shaped ``(batch, 3, time, height, width)``.

        The clip dimensions must match the ``input_shape`` given at construction
        time, otherwise the flattened size will not fit ``fc1``.
        """
        x = torch.relu(self.conv1(x))
        x = self.pool1(x)
        x = torch.relu(self.conv2(x))
        x = self.pool2(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the model
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = Dance3DCNN(num_classes=4)
model.to(device)  # nn.Module.to() moves the parameters in place

# map_location places the checkpoint tensors on the same device as the model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("xiaoewochaoshini.pth", map_location=device))

# Switch to inference mode; as the last expression of the cell this also
# displays the module summary.
model.eval()


Dance3DCNN(
  (conv1): Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=492032, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=4, bias=True)
)

In [2]:
import cv2
from torchvision import transforms
import numpy as np
from PIL import Image  # 确保导入 PIL 库

# Preprocessing pipeline applied to every frame
preprocessing_steps = [
    # Resize to 250x250 (per the original note: keep consistent with training)
    transforms.Resize(size=(250, 250)),
    # PIL image -> float tensor
    transforms.ToTensor(),
    # Per-channel normalization with mean 0.5 and std 0.5
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Path of the video to classify
video_path = "TestVideo_JA01.mp4"

# Length of the time dimension fed to the model
NUM_FRAMES = 5

# Open the video and collect the first NUM_FRAMES preprocessed frames
cap = cv2.VideoCapture(video_path)
frames = []
frame_count = 0

try:
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video: {video_path}")

    while len(frames) < NUM_FRAMES:
        ret, frame = cap.read()
        if not ret:
            # End of stream (or a read error) before enough frames were collected.
            break

        # OpenCV decodes frames as BGR; convert to RGB before building the PIL image
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame)
        frame = transform(frame).unsqueeze(0)  # add a batch dimension
        frames.append(frame)
        frame_count += 1
finally:
    # Always release the capture handle, even if preprocessing fails.
    cap.release()

if len(frames) < NUM_FRAMES:
    raise RuntimeError(
        f"Expected {NUM_FRAMES} frames from {video_path}, but only read {len(frames)}"
    )

# Convert to the model input layout:
# stacked (time, batch, channels, height, width)
# -> (batch_size=1, channels=3, time=5, height=250, width=250)
input_tensor = torch.stack(frames).permute(1, 2, 0, 3, 4).to(device)
print(f"Input Tensor Shape: {input_tensor.shape}")


Input Tensor Shape: torch.Size([1, 3, 5, 250, 250])


In [3]:
# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(input_tensor)
    # Index of the largest logit along the class dimension
    _, predicted = outputs.max(dim=1)
    predicted_index = predicted.item()
    print(f"Predicted Class: {predicted_index}")

# Class index -> label mapping (adjust to your task)
class_mapping = dict(enumerate(["HipHop", "Jazz", "Kata", "Taichi"]))
predicted_class = class_mapping[predicted_index]
print(f"Predicted Dance Type: {predicted_class}")


Predicted Class: 1
Predicted Dance Type: Jazz
