In [None]:
# 1. Bibliotheken importieren
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from model import HockeyActionModel   # Modell aus model.py
from dataloader import HockeyDataset  # DataLoader aus dataloader.py
import os

# 2. Pick the compute device: use the CUDA GPU when PyTorch can see one,
#    otherwise fall back to the CPU. All tensors and the model go here.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Benutze Gerät: {device}")

# 3. Hyperparameters for this training run.
# Optimisation
num_epochs = 10        # full passes over the dataset
learning_rate = 0.001  # step size handed to the Adam optimizer
# Data
batch_size = 2         # samples (clips) per DataLoader batch
frames_per_clip = 10   # passed to HockeyDataset as frames_per_clip
# Model
num_classes = 4        # output classes; must match the label map used in training

# 4. Prepare the data.
# Per-frame preprocessing: Resize((224, 224)) forces every frame to exactly
# 224x224 (aspect ratio is not preserved), then ToTensor turns it into a
# float C x H x W tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# 'labels.csv' and 'frames' are passed positionally -- presumably the label
# file and the frame root directory; confirm against dataloader.py.
dataset = HockeyDataset(
    'labels.csv',
    'frames',
    transform=transform,
    frames_per_clip=frames_per_clip,
)

# shuffle=True reshuffles the sample order at the start of every epoch.
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# 5. Build the classifier and move its parameters to the selected device.
#    nn.Module.to() moves the module in place and returns it.
model = HockeyActionModel(num_classes=num_classes)
model = model.to(device)

# 6. Objective and optimizer: cross-entropy over class indices, optimized
#    with Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# 7. Training loop

# String class labels delivered by HockeyDataset -> class indices for
# CrossEntropyLoss. Defined once: it is constant for the whole run.
# NOTE(review): must have exactly num_classes entries, ordered to match the
# model's output units -- confirm against model.py / labels.csv.
label_map = {'Check': 0, 'Neutral': 1, 'Schuss': 2, 'Tor': 3}


def _stack_clip_batch(frames):
    """Return a collated batch of clips as one batch-first tensor.

    If a dataset item's frames are a list of per-frame tensors, the default
    collate function transposes the batch into a list of length
    ``frames_per_clip`` whose entries each have shape (batch, C, H, W).
    Stacking that list along dim 1 yields (batch, frames_per_clip, C, H, W);
    stacking along dim 0 would put time first instead.
    If the dataset already returns one stacked tensor per clip, the collated
    batch is a tensor and is returned unchanged.
    """
    if torch.is_tensor(frames):
        return frames
    return torch.stack(frames, dim=1)


for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for frames, labels in dataloader:
        # Batch-first clip tensor, e.g. (Batch, 10, 3, 224, 224).
        frames = _stack_clip_batch(frames).to(device)

        # Encode label strings as int64 class indices, as CrossEntropyLoss expects.
        labels = torch.tensor(
            [label_map[label] for label in labels],
            dtype=torch.long,
            device=device,
        )

        # Forward pass
        outputs = model(frames)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics. CrossEntropyLoss() uses reduction='mean', so weight by the
        # batch's sample count to get a true per-sample epoch average even when
        # the final batch is smaller than batch_size.
        n_samples = labels.size(0)
        running_loss += loss.item() * n_samples
        predicted = outputs.argmax(dim=1)
        total += n_samples
        correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of failing with ZeroDivisionError.
    if total == 0:
        raise RuntimeError("Der DataLoader hat keine Daten geliefert.")

    # Epoch results
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

# 8. Persist the trained weights. Only the state_dict (parameter and buffer
#    tensors) is written, so reloading requires constructing
#    HockeyActionModel first and calling load_state_dict on it.
os.makedirs('models', exist_ok=True)  # create the target directory if missing
weights = model.state_dict()
torch.save(weights, 'models/hockey_action_model.pth')
print("✅ Modell erfolgreich gespeichert!")
