In [2]:
import sys, os
# Make the parent directory importable so the project packages imported in
# the following cell (`models`, `preprocessing`) can be resolved.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [3]:
from models.mobilenetv3 import MobileNetV3Extractor
from models.lstm_attention import BiLSTMWithAttention
from preprocessing.dataset import SignLanguageDataset

In [13]:

# --- Evaluation settings -------------------------------------------------
# Checkpoint whose weights are loaded into the model further down.
model_path = "../checkpoints/mobilenetv3_lstm_aug_smooth.pth"

# num_frames is forwarded to the dataset; batch_size to the DataLoader.
num_frames, batch_size = 20, 2

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [14]:
val_root = "../data/frames/validation"

# One class per sub-directory of val_root. Plain files and hidden entries
# (e.g. .DS_Store, .ipynb_checkpoints) are skipped: counting them as classes
# would shift every label index and change the number of classes the model
# is built with.
# NOTE(review): the resulting order must match the label mapping used when
# the checkpoint was trained — confirm against the training code.
class_names = sorted(
    entry
    for entry in os.listdir(val_root)
    if not entry.startswith(".") and os.path.isdir(os.path.join(val_root, entry))
)

# class name -> integer label, and the reverse lookup.
label_map = {name: idx for idx, name in enumerate(class_names)}
inv_label_map = {idx: name for name, idx in label_map.items()}

# Note: no random augmentation is applied during validation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Deterministic validation pipeline: resize to 224x224, normalise every
# channel with mean 0.5 / std 0.5, then convert to a tensor.
_half = (0.5,) * 3
val_steps = [
    A.Resize(224, 224),
    A.Normalize(mean=_half, std=_half),
    ToTensorV2(),
]
val_transform = A.Compose(val_steps)

# Validation dataset, built from the settings defined in the cells above.
val_dataset_kwargs = dict(
    root_dir=val_root,
    label_map=label_map,
    num_frames=num_frames,
    split="validation",
    transform=val_transform,
)
val_dataset = SignLanguageDataset(**val_dataset_kwargs)

# Samples are served in dataset order (no shuffling) by two worker processes.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)


[DEBUG] First 20 samples loaded:
Sample 0: class=write, label=284
    Frame count: 20 | First frame: ../data/frames/validation/write/64061/frame_000.jpg
Sample 1: class=your, label=291
    Frame count: 20 | First frame: ../data/frames/validation/your/64423/frame_000.jpg
Sample 2: class=your, label=291
    Frame count: 20 | First frame: ../data/frames/validation/your/64434/frame_000.jpg
Sample 3: class=apple, label=7
    Frame count: 20 | First frame: ../data/frames/validation/apple/69213/frame_000.jpg
Sample 4: class=apple, label=7
    Frame count: 20 | First frame: ../data/frames/validation/apple/02999/frame_000.jpg
Sample 5: class=apple, label=7
    Frame count: 20 | First frame: ../data/frames/validation/apple/65086/frame_000.jpg
Sample 6: class=accident, label=1
    Frame count: 20 | First frame: ../data/frames/validation/accident/00626/frame_000.jpg
Sample 7: class=accident, label=1
    Frame count: 20 | First frame: ../data/frames/validation/accident/00627/frame_000.jpg
Sample 8

In [15]:
class FullSLRModel(nn.Module):
    """Sign-language recognition model: frame feature extractor + temporal classifier.

    The input clip is passed through ``MobileNetV3Extractor`` and the resulting
    features through ``BiLSTMWithAttention``; only the class logits are returned.

    Args:
        num_classes: Number of output classes, forwarded to the temporal model.
        feature_dim: ``input_dim`` of the temporal model. Defaults to 960
            (presumably the extractor's output width — confirm against
            ``MobileNetV3Extractor``).
        hidden_dim: ``hidden_dim`` of the temporal model. Defaults to 256.

    The defaults reproduce the previously hard-coded configuration, so
    existing calls of the form ``FullSLRModel(num_classes=...)`` are unaffected.
    """

    def __init__(self, num_classes, feature_dim=960, hidden_dim=256):
        super().__init__()
        # These attribute names become prefixes of the state_dict keys, so
        # they must stay as they are for saved checkpoints to load.
        self.feature_extractor = MobileNetV3Extractor()
        self.temporal_model = BiLSTMWithAttention(
            input_dim=feature_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
        )

    def forward(self, x):  # x: [B, T, C, H, W]
        """Return the class logits for a batch of clips.

        The temporal model returns a pair; its second element is discarded.
        """
        features = self.feature_extractor(x)
        logits, _ = self.temporal_model(features)
        return logits

# Check the checkpoint before building the model. The path is relative, so it
# is resolved against the kernel's working directory; reporting the absolute
# path makes a wrong working directory or a missing file obvious.
if not os.path.isfile(model_path):
    raise FileNotFoundError(
        f"Checkpoint not found: {os.path.abspath(model_path)} "
        "(relative paths are resolved from the kernel's working directory)"
    )

model = FullSLRModel(num_classes=len(label_map)).to(device)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (or pass weights_only=True
# on PyTorch versions that support it).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

# Switch to inference mode; as the last expression this also displays the model.
model.eval()



FileNotFoundError: [Errno 2] No such file or directory: '../checkpoints/mobilenetv3_lstm_aug_smooth.pth'

In [None]:
# Collect ground-truth labels and predictions over the validation set.
# Without this loop `all_labels` / `all_preds` are undefined on a fresh kernel.
# NOTE(review): assumes each batch is indexable as (frames, labels, ...) with
# frames shaped [B, T, C, H, W] — confirm against SignLanguageDataset.__getitem__.
all_labels, all_preds = [], []
model.eval()
with torch.no_grad():  # inference only: no gradients needed
    for batch in tqdm(val_loader, desc="Evaluating"):
        frames, labels = batch[0], batch[1]
        logits = model(frames.to(device))
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())
        all_labels.extend(torch.as_tensor(labels).cpu().tolist())

# Pass the full label set explicitly: otherwise classification_report raises
# when a class is absent from both labels and predictions, because the number
# of observed classes no longer matches len(target_names).
# zero_division=0 reports 0 (without warnings) for classes never predicted.
label_ids = list(range(len(class_names)))
print(
    classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
    )
)

In [None]:
# Pin the label set so the matrix always has one row/column per entry of
# class_names, even when a class is absent from both labels and predictions;
# otherwise its size would not match display_labels.
label_ids = np.arange(len(class_names))
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

fig, ax = plt.subplots(figsize=(12, 12))
# Per-cell counts are illegible and slow to draw when there are many classes,
# so they are only rendered for small label sets.
show_cell_values = len(class_names) <= 30
disp.plot(xticks_rotation=90, ax=ax, include_values=show_cell_values)
ax.set_title("Confusion Matrix")
plt.show()