In [None]:
# 1. Bibliotheken importieren
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from model import HockeyActionModel
from dataloader import HockeyDataset
from torch.utils.data import DataLoader

# 2. Choose the compute device: use the GPU when CUDA is available, otherwise the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Benutze Gerät: {device}")

# 3. Prepare the model: build the architecture, restore the trained weights,
#    move it to the selected device and switch it to evaluation mode.
num_classes = 4  # NOTE(review): presumably must match the entries of label_map_reverse
weights_path = 'models/hockey_action_model.pth'

model = HockeyActionModel(num_classes=num_classes)
# map_location remaps the stored tensors onto `device` while loading.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
# Evaluation mode (affects layers such as dropout / batch norm, if the model has any).
model.eval()
print("✅ Modell geladen und bereit.")

# 4. Preprocessing pipeline handed to the dataset (presumably applied to each
#    frame by HockeyDataset — confirm): resize to 224x224, then convert to a tensor.
resize_step = transforms.Resize((224, 224))
tensor_step = transforms.ToTensor()
transform = transforms.Compose([resize_step, tensor_step])

# 5. Evaluation data: clips with frames_per_clip=10, one clip per batch, in file order.
# NOTE(review): this reads 'labels.csv' / 'frames' — confirm these hold a held-out
# split rather than the data the model was trained on.
test_dataset = HockeyDataset(
    'labels.csv',
    'frames',
    transform=transform,
    frames_per_clip=10,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# 6. Visualisation: show the frames of one clip together with the predicted class.
# Maps the model's output index to a class name.
# NOTE(review): presumably must match HockeyDataset's label encoding and num_classes.
label_map_reverse = {0: 'Check', 1: 'Neutral', 2: 'Schuss', 3: 'Tor'}


def _clip_tensor(frames):
    """Return the clip delivered by the DataLoader as a single tensor.

    Depending on what the dataset returns, PyTorch's default collation yields
    either one tensor (batch dimension already in front) or a sequence of
    per-frame tensors. ``torch.stack`` only accepts a sequence, so an
    already-collated tensor is passed through unchanged; a sequence is
    stacked along dim 0 as before.
    """
    if isinstance(frames, torch.Tensor):
        return frames
    return torch.stack(list(frames))


def _show_clip(clip, title, cols=5):
    """Plot every frame of ``clip`` in a grid and return ``(fig, flat_axes)``.

    Args:
        clip: CPU tensor of frames where ``clip[i]`` has layout (C, H, W).
        title: figure title.
        cols: number of grid columns; the row count follows from the frame count.
    """
    num_frames = len(clip)
    rows = max(1, (num_frames + cols - 1) // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    axes = axes.flatten()
    for idx, ax in enumerate(axes):
        # Hide ticks on every cell so unused trailing cells stay blank.
        ax.axis('off')
        if idx < num_frames:
            ax.imshow(clip[idx].permute(1, 2, 0))  # (C, H, W) -> (H, W, C)
            ax.set_title(f"Frame {idx+1}")
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.show()
    return fig, axes


# Look at a single clip only
frames, labels = next(iter(test_loader))
frames = _clip_tensor(frames).to(device)

# Prediction. predicted.item() needs a single element; this relies on batch_size=1
# (assuming the model returns (batch, num_classes) scores).
with torch.no_grad():
    outputs = model(frames)
    _, predicted = torch.max(outputs, 1)
    prediction = label_map_reverse[predicted.item()]
    print(f"✅ Vorhersage: {prediction}")

# Display: with batch_size=1, fold all leading dimensions into the frame axis so the
# result is (T, C, H, W) whether the batch dim comes first (1, T, C, H, W) or after
# the frame dim (T, 1, C, H, W); squeeze(0) alone only handles the first layout.
frames = frames.reshape(-1, *frames.shape[-3:]).cpu()
fig, axes = _show_clip(frames, f"Vorhersage für Clip: {prediction}")
