# In [None]:
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms


# --- Directories ---
# ``__file__`` is not defined when this code runs inside a notebook cell or an
# interactive session; fall back to the current working directory there so the
# module can still be executed instead of failing with NameError.
try:
    _MODULE_DIR = os.path.dirname(__file__)
except NameError:
    _MODULE_DIR = os.getcwd()

# The project root is the parent of the directory that holds this module.
PROJECT_ROOT = os.path.abspath(os.path.join(_MODULE_DIR, ".."))
RESULTS_DIR = os.path.join(PROJECT_ROOT, "results")
EXPERIMENTS_DIR = os.path.join(PROJECT_ROOT, "experiments")

# Output directories are created eagerly so later saves never hit a missing path.
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(EXPERIMENTS_DIR, exist_ok=True)


def extract_frames(video_path: str, max_frames: int = 16) -> List[np.ndarray]:
    """Sample up to ``max_frames`` evenly spaced RGB frames from a video.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        max_frames: Upper bound on the number of frames returned.

    Returns:
        The decoded frames with the channel axis reversed (BGR -> RGB), each
        stored as a contiguous array. The list is empty when the container
        reports no frames or ``max_frames`` is not positive, and it may be
        shorter than requested because frames that fail to decode are skipped.

    Raises:
        FileNotFoundError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"Could not open {video_path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0 or max_frames <= 0:
            return []
        # Evenly spaced frame positions spanning the whole clip.
        indices = np.linspace(0, total - 1, num=min(max_frames, total)).astype(int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                # Best effort: a frame that cannot be decoded is skipped.
                continue
            # Reversing the channel axis yields a negative-stride view; copy it
            # into contiguous memory so downstream consumers that reject such
            # views (e.g. torch.from_numpy) can use the frame directly.
            frames.append(np.ascontiguousarray(frame[:, :, ::-1]))
        return frames
    finally:
        # Always free the capture handle, including on the error paths.
        cap.release()


class VideoFrameDataset(Dataset):
    """Dataset that turns video records into fixed-length feature sequences.

    Each record is a mapping that provides a ``"video_path"`` and a ``"label"``
    entry. Frames are sampled with ``extract_frames`` and embedded by a frozen,
    ImageNet-pretrained ResNet-18 whose last child module (the classification
    layer) has been removed.
    """

    # Feature width used for the all-zero fallback when no frame is available;
    # ResNet-18's pooled output is 512-dimensional.
    _FEATURE_DIM = 512

    def __init__(self, records: List[Dict[str, Any]], max_frames: int = 16):
        """Store the records and build the frozen feature extractor.

        Args:
            records: Mappings with ``"video_path"`` and ``"label"`` entries.
            max_frames: Sequence length every sample is truncated/padded to.
        """
        self.records = records
        self.max_frames = max_frames
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Keep everything except the last child so the network emits features
        # instead of class scores.
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        self.feature_extractor.eval()
        for p in self.feature_extractor.parameters():
            p.requires_grad = False
        self.preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.records)

    def _frames_to_features(self, frames: List[np.ndarray]) -> np.ndarray:
        """Embed ``frames`` and pad the sequence to ``max_frames`` rows.

        Args:
            frames: Decoded frames; only the first ``max_frames`` are used.

        Returns:
            A ``(max_frames, feature_dim)`` array. Rows beyond the number of
            available frames are zero; with no frames at all the result is an
            all-zero float32 array of width ``_FEATURE_DIM``.
        """
        selected = frames[:self.max_frames]
        if not selected:
            return np.zeros((self.max_frames, self._FEATURE_DIM), dtype=np.float32)

        # One batched forward pass instead of one pass per frame.
        batch = torch.stack([self.preprocess(frame) for frame in selected])
        with torch.no_grad():
            # flatten(1) collapses everything after the batch axis and, unlike
            # a bare squeeze(), keeps the batch axis when there is one frame.
            feats = self.feature_extractor(batch).flatten(1).cpu().numpy()

        if len(feats) == self.max_frames:
            return feats
        padded = np.zeros((self.max_frames, feats.shape[1]), dtype=feats.dtype)
        padded[:len(feats)] = feats
        return padded

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(features, label)`` for the record at ``idx``.

        The features are a float32 tensor built from ``_frames_to_features``;
        the label is the record's ``"label"`` converted with ``int``.
        """
        rec = self.records[idx]
        frames = extract_frames(rec["video_path"], max_frames=self.max_frames)
        feats = self._frames_to_features(frames)
        return torch.tensor(feats, dtype=torch.float32), int(rec["label"])


class LSTMClassifier(nn.Module):
    """Sequence classifier: an LSTM encoder followed by a linear output layer.

    Args:
        feat_dim: Size of each input feature vector.
        hidden_size: Size of the LSTM hidden state.
        num_classes: Number of output logits.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, feat_dim: int = 512, hidden_size: int = 256,
                 num_classes: int = 10, num_layers: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch-first sequence batch to class logits.

        Only the final hidden state of the last LSTM layer is fed to the
        linear layer; the per-step outputs and the cell state are discarded.
        """
        _step_outputs, (hidden, _cell) = self.lstm(x)
        last_layer_state = hidden[-1]
        return self.fc(last_layer_state)


def _train_one_epoch(model: nn.Module, loader: DataLoader,
                     optimizer: torch.optim.Optimizer, loss_fn: nn.Module,
                     device: str) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for feats, labels in loader:
        feats, labels = feats.to(device), labels.to(device)
        loss = loss_fn(model(feats), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    n_batches = len(loader)
    # An empty loader yields 0.0 rather than a ZeroDivisionError.
    return total_loss / n_batches if n_batches else 0.0


def _evaluate_accuracy(model: nn.Module, loader: DataLoader, device: str) -> float:
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for feats, labels in loader:
            feats, labels = feats.to(device), labels.to(device)
            preds = model(feats).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total > 0 else 0.0


def train_video_baseline(train_records: List[Dict[str, Any]], val_records: List[Dict[str, Any]],
                         num_classes: int = 10, device: str = "cpu",
                         experiment_name: str = "exp_video", epochs: int = 5,
                         lr: float = 1e-3, batch_size: int = 8) -> Dict[str, Any]:
    """Train the LSTM baseline and persist a plot and a JSON log of the run.

    Args:
        train_records: Records used to build the training ``VideoFrameDataset``.
        val_records: Records used to build the validation ``VideoFrameDataset``.
        num_classes: Number of output classes of the classifier.
        device: Device the model and batches are moved to.
        experiment_name: Prefix of the plot and log file names.
        epochs: Number of training epochs.
        lr: Learning rate passed to Adam.
        batch_size: Batch size of both data loaders.

    Returns:
        The experiment record that is also written to the JSON log, with the
        keys ``"experiment"``, ``"timestamp"``, ``"params"`` and ``"history"``
        (per-epoch ``"train_loss"`` and ``"val_acc"`` lists).
    """
    train_ds = VideoFrameDataset(train_records)
    val_ds = VideoFrameDataset(val_records)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    model = LSTMClassifier(num_classes=num_classes).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    history = {"train_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        history["train_loss"].append(
            _train_one_epoch(model, train_loader, optim, loss_fn, device))
        acc = _evaluate_accuracy(model, val_loader, device)
        history["val_acc"].append(acc)
        print(f"Epoch {epoch} train loss {history['train_loss'][-1]:.4f}, val acc {acc:.4f}")

    # --- Save plots ---
    # Plot against 1-based epoch numbers so the axis matches the printed epochs.
    epoch_axis = range(1, len(history["train_loss"]) + 1)
    plt.figure()
    plt.plot(epoch_axis, history["train_loss"], marker="o", label="Train Loss")
    plt.plot(epoch_axis, history["val_acc"], marker="s", label="Val Acc")
    plt.xlabel("Epoch")
    plt.title("Video Classification Training")
    plt.legend()
    plot_path = os.path.join(RESULTS_DIR, f"{experiment_name}_plot.png")
    plt.savefig(plot_path)
    plt.close()

    # --- Save experiment log ---
    exp_record = {
        "experiment": experiment_name,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        # str() keeps the record JSON-serialisable even if a torch.device is passed.
        "params": {"epochs": epochs, "lr": lr, "device": str(device)},
        "history": history
    }
    log_path = os.path.join(EXPERIMENTS_DIR, f"{experiment_name}_log.json")
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(exp_record, f, indent=2)

    print(f"✅ Results plot saved to {plot_path}")
    print(f"✅ Experiment log saved to {log_path}")
    return exp_record


if __name__ == "__main__":
    # Running the file directly only prints a usage hint; training has to be
    # started by calling train_video_baseline with prepared records.
    usage_hint = "Run train_video_baseline with prepared train/val records."
    print(usage_hint)
