In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os


class Dance3DCNN(nn.Module):
    """Two-stage 3D CNN that classifies an RGB clip into ``num_classes`` labels.

    ``forward`` takes a tensor of shape ``(batch, 3, T, H, W)``. ``fc1`` is sized
    from ``input_shape`` = ``(T, H, W)``. The default ``(5, 250, 250)`` gives
    ``32 * 1 * 124 * 124 = 492032`` features, identical to the previously
    hard-coded value, so existing checkpoints keep loading.

    Args:
        num_classes: number of output logits produced by ``fc2``.
        input_shape: ``(T, H, W)`` of the clips that will be fed to ``forward``.
            A non-default shape changes ``fc1``'s weight shape, so such a model
            cannot load checkpoints trained at the default size.

    Raises:
        ValueError: if ``input_shape`` does not have exactly 3 entries, or any
            entry is too small for the pooling stack (each must be >= 4).
    """

    def __init__(self, num_classes, input_shape=(5, 250, 250)):
        super(Dance3DCNN, self).__init__()
        # Attribute names below are the state_dict keys of saved checkpoints,
        # so they must not be renamed.
        self.conv1 = nn.Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        # Stride 1 (not 2): shrinks each of T, H, W by one instead of halving it.
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1))
        self.fc1 = nn.Linear(
            self._flattened_features(input_shape, self.conv2.out_channels), 128
        )
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _flattened_features(input_shape, channels=32):
        """Return the number of features entering ``fc1`` for a ``(T, H, W)`` input.

        Per dimension: conv1/conv2 (kernel 3, stride 1, padding 1) keep the size,
        pool1 (kernel 2, stride 2) maps ``n -> (n - 2) // 2 + 1``, and pool2
        (kernel 2, stride 1) maps ``n -> n - 1``.

        Args:
            input_shape: ``(T, H, W)`` of the input clip.
            channels: channel count produced by ``conv2``.

        Raises:
            ValueError: if ``input_shape`` is not 3-long or any dimension
                collapses to zero.
        """
        if len(input_shape) != 3:
            raise ValueError(f"input_shape must be (T, H, W), got {input_shape!r}")
        total = channels
        for size in input_shape:
            reduced = (size - 2) // 2 + 1  # after pool1
            reduced -= 1                   # after pool2
            if reduced < 1:
                raise ValueError(
                    f"input_shape {tuple(input_shape)!r} is too small for the "
                    "pooling layers (each dimension must be >= 4)"
                )
            total *= reduced
        return total

    def forward(self, x):
        """Map a ``(batch, 3, T, H, W)`` clip to raw class scores (no softmax).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = torch.relu(self.conv1(x))
        x = self.pool1(x)
        x = torch.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Checkpoint to restore, and the number of logits its fc2 layer was trained with
# (load_state_dict below fails if this does not match the checkpoint).
MODEL_PATH = "ModelTry.pth"
NUM_CLASSES = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Dance3DCNN(num_classes=NUM_CLASSES).to(device)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered .pth cannot execute arbitrary code. A state_dict needs nothing more.
# (Requires torch >= 1.13; it is the default from torch 2.6.)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.eval()  # disable training-mode behaviour for inference; displays the repr



Dance3DCNN(
  (conv1): Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=492032, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=4, bias=True)
)

In [2]:

# Per-frame preprocessing. Resize to 250x250, the spatial size Dance3DCNN's fc1
# was sized for. ToTensor turns the PIL image into a float tensor in [0, 1].
# Normalize with mean = std = 0.5 per channel maps that range onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((250, 250)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Output-index -> label lookup for the model's 4 logits.
# NOTE(review): presumably this is the label order used during training — confirm.
class_mapping = dict(enumerate(("HipHop", "Jazz", "Kata", "Taichi")))

In [3]:

# Classify a single still frame. The model consumes clips shaped
# (batch, 3, T, H, W), so the image is turned into a static clip by repeating
# it NUM_FRAMES times along the time axis.
# NOTE(review): a frozen clip may look unlike the motion clips the model was
# trained on — treat the result as a smoke test, not a calibrated prediction.
test_image_path = "Taichi_S6_C0_00765.png"
NUM_FRAMES = 5  # the default Dance3DCNN fc1 was sized for 5-frame clips

if not os.path.exists(test_image_path):
    print(f"文件未找到：{test_image_path}")
else:
    # `with` closes the file handle Image.open holds. convert() returns a
    # separate, fully loaded image that stays valid after the block exits.
    with Image.open(test_image_path) as pil_image:
        rgb_image = pil_image.convert("RGB")

    # (3, H, W) -> (1, 3, H, W) -> (1, 3, 1, H, W) -> (1, 3, NUM_FRAMES, H, W)
    image = (
        transform(rgb_image)
        .unsqueeze(0)
        .unsqueeze(2)
        .repeat(1, 1, NUM_FRAMES, 1, 1)
        .to(device)
    )

    print(f"Input Tensor Shape: {image.shape}")

    with torch.no_grad():
        outputs = model(image)
    # argmax picks the top-scoring class index without binding the max values
    # to `_`, which would shadow IPython's last-output variable.
    predicted = outputs.argmax(dim=1)
    predicted_class = class_mapping[predicted.item()]

    print(f"Predicted Class: {predicted.item()}")
    print(f"Predicted Dance Type: {predicted_class}")


Input Tensor Shape: torch.Size([1, 3, 5, 250, 250])
Predicted Class: 3
Predicted Dance Type: Taichi
