In [None]:
# ===== Path config (aligned with Untitled2) =====
# Root folder that the dataset and report paths are built from.
PROJECT_ROOT = "movie"
# Dataset revision label.
DATA_VERSION = "v13"
# Reports folder directly under the project root.
REPORTS_BASE_DIR = "/".join((PROJECT_ROOT, "reports"))



In [None]:
import cv2
import numpy as np
import glob
import openpyxl

# Paths and parameters
video_dir = "/".join([PROJECT_ROOT, DATA_VERSION, "dataset", "cut"])        # directory containing V001.mp4, V002.mp4, ...
excel_file = "/".join([PROJECT_ROOT, DATA_VERSION, "dataset", "cut.xlsx"])  # Excel file with annotations
frame_size = (224, 224)                                                     # frames are resized to this before use

# -------- Read the annotation workbook with openpyxl (used instead of pandas) --------
wb = openpyxl.load_workbook(excel_file, data_only=True)
ws = wb.active  # the workbook's active sheet

# Materialise every row as a tuple of plain cell values.
rows = [*ws.iter_rows(values_only=True)]
if not rows:
    raise RuntimeError("Excel 里没有任何行！")

# Row 1 carries the global frame rate (e.g. 24).
fps_row = rows[0]

def get_fps_from_row(r, default=24.0):
    """Return the first usable frame rate found in a worksheet row.

    Cells are scanned left to right. A cell is usable when it is a number, or
    a string that parses as one, and the value is finite and greater than
    zero. Everything else (``None``, booleans, blank or "none" text, labels,
    NaN, infinities, zero, negatives) is skipped.

    Args:
        r: Iterable of cell values, e.g. one tuple produced by
            ``ws.iter_rows(values_only=True)``.
        default: Frame rate returned when no cell qualifies.

    Returns:
        The frame rate as a ``float``.
    """
    import math  # local import keeps the function self-contained

    for cell in r:
        # bool is a subclass of int; a TRUE/FALSE cell is not a frame rate.
        if cell is None or isinstance(cell, bool):
            continue

        if isinstance(cell, str):
            candidate = cell.strip()
            if not candidate or candidate.lower() == "none":
                continue
        elif isinstance(cell, (int, float)):
            candidate = cell
        else:
            continue

        try:
            value = float(candidate)
        except (ValueError, OverflowError):
            # Non-numeric text (e.g. a header label) or an int too large for float.
            continue

        # A frame rate of 0, a negative number, NaN or inf would corrupt every
        # timecode conversion downstream, so keep looking instead.
        if math.isfinite(value) and value > 0:
            return value

    return float(default)

fps = get_fps_from_row(fps_row, default=24.0)
print(f"Using FPS from Excel first row: {fps}")

# Every remaining row describes one video; each cell in it is the timecode of one cut.
data_rows = rows[1:]

# Videos are matched to rows by position: the i-th row goes with the i-th
# file in sorted filename order.
video_files = glob.glob(video_dir + "/V*.mp4")
video_files.sort()
assert len(video_files) == len(data_rows), (
    f"Mismatch between number of videos ({len(video_files)}) "
    f"and Excel data rows ({len(data_rows)})"
)


def timecode_to_frame(tc, fps):
    """Convert one annotation cell from the Excel sheet into a frame number.

    Accepted forms:
      - ``None``, ``""`` or ``"none"`` (any case): no cut, returns ``None``.
      - A number, or a numeric string such as ``"36"`` or ``"36.0"``: taken
        as a frame number and truncated to ``int``.
      - ``"ss:ff"``: seconds plus frames, ``ss * fps + ff`` (``"01:12"`` at
        24 fps is 36).
      - ``"hh:mm:ss"``: ``(hh * 3600 + mm * 60 + ss) * fps``.

    Anything else (booleans, NaN or infinite numbers, unparseable text, other
    cell types) yields ``None``.

    Args:
        tc: Raw cell value.
        fps: Frames per second, used by the colon-separated forms.

    Returns:
        The frame number as an ``int``, or ``None`` when the cell holds no
        usable annotation.
    """
    import math  # local import keeps the function self-contained

    # bool is a subclass of int; a TRUE/FALSE cell is not a frame number.
    if tc is None or isinstance(tc, bool):
        return None

    if isinstance(tc, (int, float)):
        # int(nan) / int(inf) raise, so reject them explicitly.
        if isinstance(tc, float) and not math.isfinite(tc):
            return None
        return int(tc)

    if not isinstance(tc, str):
        return None

    s = tc.strip()
    if not s or s.lower() == "none":
        return None

    if ":" in s:
        try:
            fields = [int(part) for part in s.split(":")]
            if len(fields) == 2:
                sec, frm = fields
                return int(sec * fps + frm)
            if len(fields) == 3:
                h, m, sec = fields
                return int((h * 3600 + m * 60 + sec) * fps)
        except (ValueError, TypeError, OverflowError):
            # Non-integer field, or an fps value that cannot produce a frame number.
            return None
        return None

    try:
        if s.isdigit():
            return int(s)
        # Numbers stored as text may carry a decimal part ("36.0").
        value = float(s)
    except ValueError:
        return None
    return int(value) if math.isfinite(value) else None


# -------- Data containers (these variables are used later by Cell2) --------
video_frames = []                 # list[list[np.ndarray]]: all frames of each video (already resized)
boundary_pairs = []               # list[tuple]: (vid_idx, i, label)
boundary_pairs_by_video = []      # list[list[tuple]]: each video's own pairs (handy for smarter sampling later)
cut_indices_list = []             # list[list[int]]: each video's cut indices (i means a cut between frame i and i+1)
total_frames_list = []            # list[int]: frame count of each video
valid_video_mask = []             # list[bool]: whether the video is usable (>= 2 frames)
video_paths = list(video_files)   # copy of the paths, convenient for printing / locating videos later

# -------- Walk through the videos together with their Excel rows --------
for vid_idx, (video_path, row) in enumerate(zip(video_files, data_rows)):
    # Decode the whole video, resizing every frame to frame_size.
    cap = cv2.VideoCapture(video_path)
    frames = []
    success, frame = cap.read()
    while success:
        frame_resized = cv2.resize(frame, frame_size)
        frames.append(frame_resized)
        success, frame = cap.read()
    cap.release()

    video_frames.append(frames)
    total_frames = len(frames)
    total_frames_list.append(total_frames)

    # With fewer than two frames no (i, i+1) pair exists. Empty placeholders
    # keep the per-video lists index-aligned with video_files.
    if total_frames <= 1:
        valid_video_mask.append(False)
        boundary_pairs_by_video.append([])
        cut_indices_list.append([])
        print(f"Warning: video {video_path} has {total_frames} frame(s), skip boundary generation.")
        continue

    valid_video_mask.append(True)

    raw_values = list(row) if row is not None else []

    cut_indices = []
    for v in raw_values:
        frame_idx = timecode_to_frame(v, fps)
        if frame_idx is None:
            continue

        # The Excel sheet marks the first frame of the shot AFTER the cut (B-start).
        # For boundary detection the cut lies between (B-1, B), i.e. pair index i = B-1.
        if frame_idx > 0:
            frame_idx = frame_idx - 1
        else:
            continue

        frame_idx = int(frame_idx)
        # NOTE(review): an annotation past the end of the video is clamped onto
        # the last pair instead of being dropped -- confirm this is intended.
        frame_idx = max(0, min(frame_idx, total_frames - 2))
        cut_indices.append(frame_idx)

    # De-duplicate and order the cut positions.
    cut_indices = sorted(set(cut_indices))
    cut_indices_list.append(cut_indices)

    print(f"Video {vid_idx} ({video_path}): total_frames={total_frames}, cuts-1@frames={cut_indices}")

    cut_set = set(cut_indices)

    # One sample per adjacent frame pair (i, i+1); label 1 marks a cut.
    per_video_pairs = []
    for i in range(total_frames - 1):
        label = 1 if i in cut_set else 0
        tup = (vid_idx, i, label)
        boundary_pairs.append(tup)
        per_video_pairs.append(tup)

    boundary_pairs_by_video.append(per_video_pairs)

# -------- Key global variable (used directly by Cell2) --------
num_videos = len(video_files)

# -------- Dataset statistics --------
num_pairs = len(boundary_pairs)
num_cuts = sum(1 for _, _, lbl in boundary_pairs if lbl == 1)
num_noncuts = num_pairs - num_cuts

total_extracted_frames = sum(len(frames) for frames in video_frames)
num_valid_videos = sum(1 for v in valid_video_mask if v)

print(f"\nProcessed {num_videos} videos (valid {num_valid_videos}/{num_videos}), extracted {total_extracted_frames} frames.")
print(f"Generated {num_pairs} frame pairs: {num_cuts} cuts (positive) and {num_noncuts} non-cuts (negative).")

# -------- Optional: keep the global positive/negative ratios around for computing pos_weight later --------
pos_count = num_cuts
neg_count = num_noncuts
pos_ratio = (pos_count / num_pairs) if num_pairs > 0 else 0.0
neg_ratio = (neg_count / num_pairs) if num_pairs > 0 else 0.0
print(f"Pos ratio: {pos_ratio:.6f} | Neg ratio: {neg_ratio:.6f}")


In [None]:
# ============= Cell2: Dataset + 按视频划分 + 正例全量/负例动态抽样（每epoch不一样） =============
import torch
from torch.utils.data import Dataset, DataLoader
import random
import cv2
import numpy as np

class ShotBoundaryDataset(Dataset):
    """Dataset of frame-pair samples for shot-boundary classification.

    Each entry of ``pairs`` is ``(vid_idx, frame_idx, label)``. A sample is
    built from four frames of ``video_frames[vid_idx]`` around ``frame_idx``
    (previous, A = frame_idx, B = frame_idx + 1, next) plus the three absolute
    differences between neighbouring frames, concatenated along the channel
    axis (21 channels when the frames are 3-channel images).
    """

    def __init__(self, pairs, video_frames):
        self.pairs = pairs
        self.video_frames = video_frames

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a float32 tensor in channel-first layout with values
        divided by 255; ``label`` is an int64 scalar tensor.
        """
        vid_idx, frame_idx, label = self.pairs[idx]
        clip = self.video_frames[vid_idx]
        last = len(clip) - 1

        # 4-frame context; indices are clamped so the window stays inside the clip.
        picks = (
            max(0, frame_idx - 1),
            max(0, min(frame_idx, last)),
            max(0, min(frame_idx + 1, last)),
            max(0, min(frame_idx + 2, last)),
        )
        window = [clip[p] for p in picks]

        # |prev - A|, |A - B|, |B - next|
        deltas = [cv2.absdiff(left, right) for left, right in zip(window, window[1:])]

        # Channel order: prev, A, B, next, diff(prev,A), diff(A,B), diff(B,next)
        stacked = np.concatenate(window + deltas, axis=2).astype("float32") / 255.0
        chw = stacked.transpose(2, 0, 1)

        return (
            torch.tensor(chw, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )


# ============= Sampling parameters: this is the only place you need to edit =============
BATCH_SIZE = 1024

USE_DYNAMIC_NEG_SAMPLING = False       # True = resample negatives every epoch; False = train on the full pool
NEG_SAMPLING_MODE = "ratio"           # "ratio" or "per_pos"
NEG_SAMPLE_RATIO = 0.10               # mode="ratio": fraction of the negative pool drawn per epoch (0.1 = one tenth)
NEG_PER_POS = 5                       # mode="per_pos": negatives per epoch = number of positives * NEG_PER_POS

SEED_SPLIT = 42                       # makes the train/test split reproducible
SEED_EPOCH_BASE = 20260118            # makes per-epoch sampling reproducible (each epoch still differs)


# ============= Train/test split (by video) =============
all_video_indices = list(range(num_videos))
random.seed(SEED_SPLIT)
random.shuffle(all_video_indices)

# First 95% of the shuffled videos train; the remainder is held out.
split_idx = int(len(all_video_indices) * 0.95)
train_vids = set(all_video_indices[:split_idx])
test_vids  = set(all_video_indices[split_idx:])

train_pairs_all = [p for p in boundary_pairs if p[0] in train_vids]
test_pairs      = [p for p in boundary_pairs if p[0] in test_vids]

# Split the training pool into positives and negatives
train_pos_pairs = [p for p in train_pairs_all if p[2] == 1]
train_neg_pairs = [p for p in train_pairs_all if p[2] == 0]

# Test set: always used in full, never subsampled
test_dataset = ShotBoundaryDataset(test_pairs, video_frames)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"[Split] train_vids={len(train_vids)} test_vids={len(test_vids)}")
print(f"[Pool] train_pos={len(train_pos_pairs)} train_neg={len(train_neg_pairs)} test_total={len(test_pairs)}")


def make_train_loader_for_epoch(epoch: int, batch_size: int = BATCH_SIZE, shuffle: bool = True):
    """Build the training DataLoader for one epoch.

    When dynamic negative sampling is enabled and the negative pool is not
    empty, the epoch keeps every positive pair and draws a fresh random subset
    of the negatives, seeded by ``SEED_EPOCH_BASE + epoch``. Otherwise the
    full training pool is used.

    Args:
        epoch: Epoch index, added to ``SEED_EPOCH_BASE`` to seed the sampler.
        batch_size: Batch size handed to the DataLoader.
        shuffle: Whether the DataLoader shuffles.

    Returns:
        ``(loader, dataset, raw_pos_weight)`` where ``raw_pos_weight`` is
        ``negatives / max(1, positives)`` for this epoch's dataset.

    Raises:
        ValueError: If sampling is active and ``NEG_SAMPLING_MODE`` is neither
            ``"ratio"`` nor ``"per_pos"``.
    """
    sample_negatives = USE_DYNAMIC_NEG_SAMPLING and len(train_neg_pairs) != 0

    if sample_negatives:
        epoch_rng = random.Random(SEED_EPOCH_BASE + int(epoch))
        neg_pool_size = len(train_neg_pairs)

        if NEG_SAMPLING_MODE == "ratio":
            wanted = int(neg_pool_size * float(NEG_SAMPLE_RATIO))
        elif NEG_SAMPLING_MODE == "per_pos":
            wanted = int(len(train_pos_pairs) * int(NEG_PER_POS))
        else:
            raise ValueError("NEG_SAMPLING_MODE must be 'ratio' or 'per_pos'")

        # Never fewer than one negative, never more than the pool holds.
        wanted = max(1, min(wanted, neg_pool_size))

        epoch_pairs = list(train_pos_pairs) + epoch_rng.sample(train_neg_pairs, wanted)
        epoch_rng.shuffle(epoch_pairs)
    else:
        epoch_pairs = list(train_pairs_all)

    epoch_dataset = ShotBoundaryDataset(epoch_pairs, video_frames)
    epoch_loader = DataLoader(epoch_dataset, batch_size=batch_size, shuffle=shuffle)

    n_pos = len(train_pos_pairs)
    n_neg = len(epoch_dataset) - n_pos
    epoch_raw_weight = n_neg / max(1, n_pos)

    # Printed so the per-epoch negative count and raw weight can be checked at a glance.
    print(f"[Epoch {epoch}] train_pairs={len(epoch_dataset)} (pos {n_pos}, neg {n_neg}) raw_pos_weight_ep={epoch_raw_weight:.4f}")
    return epoch_loader, epoch_dataset, epoch_raw_weight


In [None]:
# ============= Cell3-L: Linear Baseline 训练（1层网络：Pool + Linear）+ 每Epoch输出多指标 + 写Excel =============

import os
import platform
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import openpyxl
from datetime import datetime

try:
    from zoneinfo import ZoneInfo  # py>=3.9
    # China Standard Time is published in the IANA database as "Asia/Shanghai";
    # there is no "Asia/Beijing" key. The previous "Asia/BeiJing" lookup raised
    # and was swallowed below, silently leaving _TZ = None.
    _TZ = ZoneInfo("Asia/Shanghai")
except Exception:
    # zoneinfo missing (Python < 3.9) or no tz database installed:
    # fall back to naive local timestamps.
    _TZ = None

# ===== Minimal classification_report replacement =====
def simple_classification_report(y_true, y_pred, target_names):
    """Format per-class precision / recall / F1 / support as a text table.

    Class ``i`` is the integer label ``i``; ``target_names[i]`` is its display
    name. A ratio whose denominator is zero is reported as 0. The last line
    gives overall accuracy and the sample count, preceded by a blank line.

    Args:
        y_true: Sequence of true integer labels.
        y_pred: Sequence of predicted integer labels.
        target_names: Display names, one per class index.

    Returns:
        The report as a single newline-joined string.
    """
    truth = np.array(y_true)
    guess = np.array(y_pred)
    total = len(truth)
    accuracy = (truth == guess).sum() / total if total > 0 else 0.0

    report = ["precision    recall  f1-score   support"]
    for cls, name in enumerate(target_names):
        is_cls = (truth == cls)
        said_cls = (guess == cls)

        tp = np.logical_and(is_cls, said_cls).sum()
        fp = np.logical_and(~is_cls, said_cls).sum()
        fn = np.logical_and(is_cls, ~said_cls).sum()
        support = is_cls.sum()

        predicted_total = tp + fp
        actual_total = tp + fn
        precision = tp / predicted_total if predicted_total > 0 else 0.0
        recall = tp / actual_total if actual_total > 0 else 0.0
        pr_sum = precision + recall
        f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0.0

        report.append(f"{name:10s}  {precision:0.4f}   {recall:0.4f}   {f1:0.4f}   {support:5d}")

    report.append(f"\naccuracy                        {accuracy:0.4f}   {total:5d}")
    return "\n".join(report)

# ===== 二分类指标（沿用你的）=====
def _safe_div(a, b):
    return float(a) / float(b) if b else 0.0

def binary_metrics_from_probs(y_true, prob_pos, threshold=0.95):
    y_true = np.asarray(y_true, dtype=np.int64)
    prob_pos = np.asarray(prob_pos, dtype=np.float64)
    y_pred = (prob_pos >= threshold).astype(np.int64)

    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    precision = _safe_div(tp, tp + fp)
    recall    = _safe_div(tp, tp + fn)
    f1        = _safe_div(2 * precision * recall, precision + recall)
    acc       = _safe_div(tp + tn, tp + tn + fp + fn)

    pos_pred_rate = _safe_div(tp + fp, len(y_true))

    avg_prob_pos = float(np.mean(prob_pos[y_true == 1])) if np.any(y_true == 1) else 0.0
    avg_prob_neg = float(np.mean(prob_pos[y_true == 0])) if np.any(y_true == 0) else 0.0

    return {
        "threshold": float(threshold),
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "precision": precision, "recall": recall, "f1": f1, "acc": acc,
        "pos_pred_rate": pos_pred_rate,
        "avg_prob_pos": avg_prob_pos,
        "avg_prob_neg": avg_prob_neg,
    }

def average_precision_score(y_true, prob_pos):
    """Average precision: mean of precision@k over the ranks k of the positives.

    Samples are ranked by descending probability; at every rank holding a
    positive (label 1) the precision so far is recorded, and the recorded
    values are averaged over the number of positives.

    Args:
        y_true: Sequence of integer labels; 1 marks a positive.
        prob_pos: Sequence of positive-class scores, same length.

    Returns:
        The average precision as a float; 0.0 when there are no positives.
    """
    labels = np.asarray(y_true, dtype=np.int64)
    scores = np.asarray(prob_pos, dtype=np.float64)

    n_pos = int(np.sum(labels == 1))
    if n_pos == 0:
        return 0.0

    ranking = np.argsort(-scores)
    is_hit = labels[ranking] == 1

    # Positives seen up to and including each rank, and the 1-based rank itself.
    hits_so_far = np.cumsum(is_hit)
    rank = np.arange(1, len(is_hit) + 1)

    precision_at_hits = hits_so_far[is_hit] / rank[is_hit]
    return float(np.sum(precision_at_hits) / n_pos)

def roc_auc_score_rank(y_true, prob_pos):
    """ROC AUC computed from rank sums, with tied scores sharing their mean rank.

    Args:
        y_true: Sequence of integer labels; 1 is positive, 0 is negative.
        prob_pos: Sequence of positive-class scores, same length.

    Returns:
        The AUC as a float; 0.0 when either class is absent.
    """
    labels = np.asarray(y_true, dtype=np.int64)
    scores = np.asarray(prob_pos, dtype=np.float64)

    is_pos = labels == 1
    n_pos = int(np.sum(is_pos))
    n_neg = int(np.sum(labels == 0))
    if not (n_pos and n_neg):
        return 0.0

    # 1-based ranks in ascending score order.
    order = np.argsort(scores)
    ranks = np.empty_like(order, dtype=np.float64)
    ranks[order] = np.arange(1, len(scores) + 1, dtype=np.float64)

    # Walk over runs of equal scores and give each run its mean rank.
    ordered_scores = scores[order]
    total = len(ordered_scores)
    start = 0
    while start < total:
        stop = start + 1
        while stop < total and ordered_scores[stop] == ordered_scores[start]:
            stop += 1
        if stop - start > 1:
            tied = order[start:stop]
            ranks[tied] = float(np.mean(ranks[tied]))
        start = stop

    pos_rank_sum = float(np.sum(ranks[is_pos]))
    return float((pos_rank_sum - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg))

def find_best_threshold_f1(y_true, prob_pos, num_thresholds=101):
    """Grid-search the decision threshold that gives the highest F1.

    ``num_thresholds`` evenly spaced thresholds over [0, 1] are evaluated with
    ``binary_metrics_from_probs``; the first threshold reaching the best F1 is
    kept.

    Args:
        y_true: Sequence of 0/1 labels.
        prob_pos: Sequence of positive-class probabilities, same length.
        num_thresholds: Number of grid points. With 1, only threshold 0.0 is
            evaluated; with 0 or less, nothing is evaluated.

    Returns:
        ``(best_threshold, best_f1)`` as floats; ``(0.5, -1.0)`` when no
        threshold was evaluated.
    """
    best_t = 0.5
    best_f1 = -1.0
    # A single grid point would make the spacing 1 / (num_thresholds - 1)
    # divide by zero, so the divisor is floored at 1.
    steps = max(1, num_thresholds - 1)
    for k in range(num_thresholds):
        t = k / steps
        f1 = binary_metrics_from_probs(y_true, prob_pos, threshold=t)["f1"]
        if f1 > best_f1:
            best_f1 = f1
            best_t = t
    return float(best_t), float(best_f1)

# ===== Device =====
# Use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ===== Linear baseline: parameter-free pooling + one Linear layer (the only learnable layer) =====
class BoundaryLinearBaseline(nn.Module):
    """Adaptive average pooling followed by a single linear classifier.

    The input is pooled to ``pool_hw x pool_hw`` per channel, flattened, and
    mapped to two logits.

    Args:
        in_ch: Number of input channels.
        pool_hw: Side length of the pooled feature map.
    """

    def __init__(self, in_ch=21, pool_hw=16):
        super().__init__()
        grid = (pool_hw, pool_hw)
        self.pool = nn.AdaptiveAvgPool2d(grid)                 # no parameters
        self.fc = nn.Linear(in_ch * pool_hw * pool_hw, 2)      # the only layer with parameters

    def forward(self, x):
        pooled = self.pool(x)
        flat = pooled.view(pooled.size(0), -1)  # one feature row per sample
        return self.fc(flat)

baseline_model = BoundaryLinearBaseline(in_ch=21, pool_hw=16).to(device)

# ===== Class-imbalance weighting =====
# The positive-class weight is set by hand (pos_weight below) instead of being
# derived from num_noncuts / num_cuts; the raw ratio is computed only for
# logging and for the report.

# ===== POS_WEIGHT_MODE is kept; the default "fixed" means the weight does not change per epoch =====
POS_WEIGHT_MODE = "fixed"      # "fixed" or "epoch"
POS_WEIGHT_EPOCH_MAX = 40.0    # upper bound on the weight in "epoch" mode
threshold_default = 0.95
pos_weight = 40.0
# Guard the ratio: with no positive pairs the division raises ZeroDivisionError.
# Code that consumes raw_pos_weight later already checks for None.
raw_pos_weight = (num_noncuts / num_cuts) if num_cuts > 0 else None
print(f"raw_pos_weight = {raw_pos_weight if raw_pos_weight is not None else 'NA'}, used pos_weight = {pos_weight}")

# ===== Optimizer: kept simple; this model is only a comparison baseline =====
lr_init = 1e-5
optimizer_b = optim.Adam(baseline_model.parameters(), lr=lr_init)

def get_system_info():
    """Collect a snapshot of the runtime environment for the run report.

    Returns:
        Dict of strings, in insertion order: start time (timezone-aware when
        ``_TZ`` is set), Python version, platform, processor, torch version,
        CUDA availability, CUDA device count, name of CUDA device 0, and the
        device in use. CUDA details are "NA" when they cannot be queried.
    """
    if _TZ:
        started = datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S %Z")
    else:
        started = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        try:
            gpu_count = str(torch.cuda.device_count())
            gpu_name = torch.cuda.get_device_name(0)
        except Exception:
            # Either query failing marks both fields as unavailable.
            gpu_count = "NA"
            gpu_name = "NA"
    else:
        gpu_count = "0"
        gpu_name = "NA"

    return {
        "time_start": started,
        "python_version": platform.python_version(),
        "platform": platform.platform(),
        "processor": platform.processor(),
        "torch_version": torch.__version__,
        "cuda_available": str(has_cuda),
        "cuda_device_count": gpu_count,
        "cuda_device_name_0": gpu_name,
        "device_used": str(device),
    }

run_info = get_system_info()

# ===== Training =====
epochs = 100  # for a fair comparison, set this to the same number of epochs as the CNN
epoch_rows = []
global_step = 0
t0 = time.time()

baseline_model.train()
for epoch in range(epochs):
    running_loss = 0.0

    # Build this epoch's train_loader (defined in Cell2)
    train_loader, train_dataset, raw_pos_weight_epoch = make_train_loader_for_epoch(
        epoch, batch_size=BATCH_SIZE, shuffle=True
    )

    seed_epoch = (int(SEED_EPOCH_BASE) + int(epoch)) if USE_DYNAMIC_NEG_SAMPLING else None
    print(f"[Epoch {epoch}] seed_for_neg_sampling = {seed_epoch}")

    # Choose the positive-class weight used in this epoch
    if POS_WEIGHT_MODE == "fixed":
        used_pos_weight_epoch = float(pos_weight)
    elif POS_WEIGHT_MODE == "epoch":
        used_pos_weight_epoch = float(min(raw_pos_weight_epoch, POS_WEIGHT_EPOCH_MAX))
    else:
        used_pos_weight_epoch = float(pos_weight)  # fallback for an unrecognised mode

    # CrossEntropyLoss(weight=[noncut, cut])
    class_weights = torch.tensor([1.0, used_pos_weight_epoch], dtype=torch.float32).to(device)
    criterion_b = nn.CrossEntropyLoss(weight=class_weights)

    # Labels and P(cut) collected batch by batch for the epoch-level metrics.
    y_true_ep = []
    prob_pos_ep = []

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer_b.zero_grad()
        outputs = baseline_model(imgs)   # logits [B,2]
        loss = criterion_b(outputs, labels)
        loss.backward()
        optimizer_b.step()

        # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
        running_loss += loss.item() * imgs.size(0)
        global_step += 1

        with torch.no_grad():
            probs = torch.softmax(outputs, dim=1)[:, 1]
            prob_pos_ep.extend(probs.detach().cpu().numpy().tolist())
            y_true_ep.extend(labels.detach().cpu().numpy().tolist())

    epoch_loss = running_loss / len(train_loader.dataset)

    # Metrics below use the outputs gathered during the training pass of this epoch.
    m = binary_metrics_from_probs(y_true_ep, prob_pos_ep, threshold=threshold_default)
    ap = average_precision_score(y_true_ep, prob_pos_ep)
    auc = roc_auc_score_rank(y_true_ep, prob_pos_ep)
    best_t, best_f1 = find_best_threshold_f1(y_true_ep, prob_pos_ep, num_thresholds=101)

    lr_now = optimizer_b.param_groups[0].get("lr", None)

    print(
        f"[Linear] Epoch {epoch+1}/{epochs} | "
        f"loss {epoch_loss:.4f} | "
        f"F1 {m['f1']:.4f} (P {m['precision']:.4f}, R {m['recall']:.4f}) | "
        f"AP {ap:.4f} | AUC {auc:.4f} | "
        f"TP {m['tp']} FP {m['fp']} TN {m['tn']} FN {m['fn']} | "
        f"pos_pred_rate {m['pos_pred_rate']:.4f} | "
        f"best_t(train) {best_t:.2f} bestF1(train) {best_f1:.4f} | "
        f"pos_weight {used_pos_weight_epoch:.2f} (raw_ep {raw_pos_weight_epoch:.2f}) | "
        f"train_pairs {len(train_loader.dataset)} "
        f"(pos {len(train_pos_pairs)}, neg {len(train_loader.dataset)-len(train_pos_pairs)}) | "
        f"neg_sampling {USE_DYNAMIC_NEG_SAMPLING} {NEG_SAMPLING_MODE} "
        f"ratio {NEG_SAMPLE_RATIO} per_pos {NEG_PER_POS}"
    )

    # ===== epoch_rows: one dict per epoch, later written to the epoch_metrics sheet =====
    epoch_rows.append({
        "epoch": epoch + 1,
        "global_step": global_step,
        "loss": float(epoch_loss),
        "threshold": float(threshold_default),
        "precision": float(m["precision"]),
        "recall": float(m["recall"]),
        "f1": float(m["f1"]),
        "acc": float(m["acc"]),
        "tp": int(m["tp"]),
        "fp": int(m["fp"]),
        "tn": int(m["tn"]),
        "fn": int(m["fn"]),
        "pos_pred_rate": float(m["pos_pred_rate"]),
        "avg_prob_pos": float(m["avg_prob_pos"]),
        "avg_prob_neg": float(m["avg_prob_neg"]),
        "pr_auc_ap": float(ap),
        "roc_auc": float(auc),
        "best_threshold_train": float(best_t),
        "best_f1_train": float(best_f1),
        "lr": float(lr_now) if lr_now is not None else 0.0,

        # ===== Original fields, meaning unchanged =====
        "pos_weight": float(pos_weight),
        "raw_pos_weight": float(raw_pos_weight) if raw_pos_weight is not None else 0.0,
        "weight_noncut": float(class_weights[0].item()) if class_weights is not None else 0.0,
        "weight_cut": float(class_weights[1].item()) if class_weights is not None else 0.0,

        # ===== Additional fields (the extra_cols of the report) =====
        "train_pairs_epoch": int(len(train_loader.dataset)) if train_loader is not None else 0,
        "raw_pos_weight_epoch": float(raw_pos_weight_epoch) if raw_pos_weight_epoch is not None else 0.0,
        "pos_weight_mode": str(POS_WEIGHT_MODE) if POS_WEIGHT_MODE is not None else "unavailable",
        "pos_weight_used_epoch": float(used_pos_weight_epoch) if used_pos_weight_epoch is not None else 0.0,
        "neg_sampling_enabled": bool(USE_DYNAMIC_NEG_SAMPLING),
        "neg_sampling_mode": str(NEG_SAMPLING_MODE) if NEG_SAMPLING_MODE is not None else "unavailable",
        "neg_sample_ratio": float(NEG_SAMPLE_RATIO) if NEG_SAMPLE_RATIO is not None else 0.0,
        "neg_per_pos": int(NEG_PER_POS) if NEG_PER_POS is not None else 0,
    })

print("Finished training Linear Baseline model.")
t_train = time.time() - t0

# ===== Final test: test_loader is the full held-out set from the video-level split in Cell2 =====
baseline_model.eval()
y_true_b, prob_pos_b, y_pred_b = [], [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        outputs = baseline_model(imgs)                 # logits [B,2]
        probs = torch.softmax(outputs, dim=1)[:, 1]    # P(cut)
        pred = (probs >= threshold_default).long()

        y_true_b.extend(labels.detach().cpu().numpy().tolist())
        prob_pos_b.extend(probs.detach().cpu().numpy().tolist())
        y_pred_b.extend(pred.detach().cpu().numpy().tolist())

m_test = binary_metrics_from_probs(y_true_b, prob_pos_b, threshold=threshold_default)
ap_test = average_precision_score(y_true_b, prob_pos_b)
auc_test = roc_auc_score_rank(y_true_b, prob_pos_b)

print("\n[Linear] FINAL TEST (threshold=0.95) => "
      f"F1 {m_test['f1']:.4f} (P {m_test['precision']:.4f}, R {m_test['recall']:.4f}) | "
      f"AP {ap_test:.4f} | AUC {auc_test:.4f} | "
      f"TP {m_test['tp']} FP {m_test['fp']} TN {m_test['tn']} FN {m_test['fn']}")

print("\nShot Boundary Detection - Classification Report (TEST):")
print(simple_classification_report(y_true_b, y_pred_b, target_names=["Non-cut", "Cut"]))

# ===== Write the Excel report under the shared reports folder =====
# Use the notebook-level REPORTS_BASE_DIR (f"{PROJECT_ROOT}/reports") instead of
# a second hard-coded "movie/reports", so changing PROJECT_ROOT moves the
# reports as well.
base_dir = REPORTS_BASE_DIR
os.makedirs(base_dir, exist_ok=True)

ts_folder = datetime.now(_TZ).strftime("%m%d%H%M") if _TZ else datetime.now().strftime("%m%d%H%M")

# Folder naming scheme: {time}_e{epochs}_w{pos_weight}.
# The "_linear" suffix tells these runs apart from CNN runs without changing
# the e/w parts of the name.
REPORT_FOLDER_NAME = f"{ts_folder}_e{epochs}_w{int(pos_weight)}_linear"
REPORT_DIR = os.path.join(base_dir, REPORT_FOLDER_NAME)
os.makedirs(REPORT_DIR, exist_ok=True)

ts = datetime.now(_TZ).strftime("%Y%m%d_%H%M%S") if _TZ else datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(REPORT_DIR, f"boundary_train_metrics_{ts}.xlsx")

wb = openpyxl.Workbook()

# Sheet 1: run_info (environment snapshot followed by run configuration)
ws0 = wb.active
ws0.title = "run_info"
ws0.append(["key", "value"])
for k, v in run_info.items():
    ws0.append([k, str(v)])

ws0.append([""])
ws0.append(["model_name", "BoundaryLinearBaseline"])
ws0.append(["epochs", str(epochs)])
ws0.append(["train_seconds", f"{t_train:.3f}"])
ws0.append(["threshold_default", str(threshold_default)])
ws0.append(["optimizer", "Adam"])
ws0.append(["lr_init", str(lr_init)])
ws0.append(["loss", "CrossEntropyLoss(weighted)"])
ws0.append(["class_weight_noncut", str(float(1.0))])
ws0.append(["class_weight_cut", str(float(pos_weight))])
ws0.append(["raw_pos_weight", str(raw_pos_weight) if raw_pos_weight is not None else "0"])
ws0.append(["pos_weight_used", str(pos_weight)])

ws0.append([""])
ws0.append(["pos_weight_mode", str(POS_WEIGHT_MODE) if POS_WEIGHT_MODE is not None else "unavailable"])
ws0.append(["pos_weight_epoch_max", str(POS_WEIGHT_EPOCH_MAX) if POS_WEIGHT_EPOCH_MAX is not None else "0"])
ws0.append(["neg_sampling_enabled", str(USE_DYNAMIC_NEG_SAMPLING)])
ws0.append(["neg_sampling_mode", str(NEG_SAMPLING_MODE) if NEG_SAMPLING_MODE is not None else "unavailable"])
ws0.append(["neg_sample_ratio", str(NEG_SAMPLE_RATIO) if NEG_SAMPLE_RATIO is not None else "0"])
ws0.append(["neg_per_pos", str(NEG_PER_POS) if NEG_PER_POS is not None else "0"])

ws0.append([""])
ws0.append(["report_folder_name", REPORT_FOLDER_NAME])
ws0.append(["report_dir", REPORT_DIR])

# Sheet 2: epoch_metrics (one row per epoch; column names and order are fixed below)
ws1 = wb.create_sheet("epoch_metrics")
cols = [
    "epoch","global_step","loss","lr",
    "threshold","precision","recall","f1","acc",
    "tp","fp","tn","fn",
    "pos_pred_rate","avg_prob_pos","avg_prob_neg",
    "pr_auc_ap","roc_auc",
    "best_threshold_train","best_f1_train",
    "raw_pos_weight","pos_weight","weight_noncut","weight_cut"
]
extra_cols = [
    "train_pairs_epoch",
    "raw_pos_weight_epoch",
    "pos_weight_mode",
    "pos_weight_used_epoch",
    "neg_sampling_enabled",
    "neg_sampling_mode",
    "neg_sample_ratio",
    "neg_per_pos"
]
all_cols = cols + extra_cols

ws1.append(all_cols)
for r in epoch_rows:
    row_out = []
    for c in all_cols:
        v = r.get(c, None)
        # Missing values: string columns become "unavailable", numeric columns become 0
        if v is None:
            if c in ["pos_weight_mode", "neg_sampling_mode"]:
                v = "unavailable"
            else:
                v = 0
        row_out.append(v)
    ws1.append(row_out)

# Sheet 3: final_test (held-out metrics followed by the text classification report)
ws2 = wb.create_sheet("final_test")
ws2.append(["metric", "value"])
ws2.append(["threshold", m_test["threshold"]])
ws2.append(["precision", m_test["precision"]])
ws2.append(["recall", m_test["recall"]])
ws2.append(["f1", m_test["f1"]])
ws2.append(["acc", m_test["acc"]])
ws2.append(["tp", m_test["tp"]])
ws2.append(["fp", m_test["fp"]])
ws2.append(["tn", m_test["tn"]])
ws2.append(["fn", m_test["fn"]])
ws2.append(["pos_pred_rate", m_test["pos_pred_rate"]])
ws2.append(["avg_prob_pos", m_test["avg_prob_pos"]])
ws2.append(["avg_prob_neg", m_test["avg_prob_neg"]])
ws2.append(["pr_auc_ap", ap_test])
ws2.append(["roc_auc", auc_test])

ws2.append([""])
ws2.append(["classification_report", ""])
report_str = simple_classification_report(y_true_b, y_pred_b, target_names=["Non-cut", "Cut"])
# One worksheet row per line of the text report.
for line in report_str.splitlines():
    ws2.append([line, ""])

wb.save(out_path)
print(f"\nSaved metrics Excel to: {out_path}")
print(f"[Report] Folder: {REPORT_DIR}")

# ===== Save the model into the same report folder =====
# Best effort: a failure here is printed, not raised, so the metrics already
# written above are not lost.
try:
    ts_model = datetime.now(_TZ).strftime("%Y%m%d_%H%M%S") if _TZ else datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(REPORT_DIR, f"linear_model_{ts_model}.pt")
    ckpt_path  = os.path.join(REPORT_DIR, f"linear_ckpt_{ts_model}.pth")

    # Plain state_dict, for loading into an identically constructed model.
    torch.save(baseline_model.state_dict(), model_path)

    # Checkpoint: the same weights plus the run settings needed to interpret them.
    ckpt = {
        "model_state_dict": baseline_model.state_dict(),
        "model_name": "BoundaryLinearBaseline",
        "device_saved": str(device),
        "threshold_default": float(threshold_default),
        "batch_size": int(BATCH_SIZE),
        "report_dir": REPORT_DIR,
        "report_folder_name": REPORT_FOLDER_NAME,
        "epochs": int(epochs),
        "pos_weight": float(pos_weight),
        "raw_pos_weight": float(raw_pos_weight) if raw_pos_weight is not None else 0.0,
        "saved_time": datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S %Z") if _TZ else datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    torch.save(ckpt, ckpt_path)

    print(f"[Model] Saved state_dict to: {model_path}")
    print(f"[Model] Saved checkpoint to: {ckpt_path}")
except Exception as e:
    print(f"[Model] Save failed: {e}")


In [None]:
# ============= Cell4: cut_test 测试（详细报告 + 写Excel，严格对齐“代码4.docx 里的原Cell4格式”） =============

import os
import glob
import cv2
import numpy as np
import openpyxl
import torch
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import platform
import time

try:
    from zoneinfo import ZoneInfo  # py>=3.9
    # China Standard Time is published in the IANA database as "Asia/Shanghai";
    # there is no "Asia/Beijing" key. The previous "Asia/BeiJing" lookup raised
    # and was swallowed below, silently leaving _TZ = None.
    _TZ = ZoneInfo("Asia/Shanghai")
except Exception:
    # zoneinfo missing (Python < 3.9) or no tz database installed:
    # fall back to naive local timestamps.
    _TZ = None

# ===== Paths: point these at your test set =====
test_video_dir  = f"{PROJECT_ROOT}/{DATA_VERSION}/dataset/cut_test"
test_excel_file = f"{PROJECT_ROOT}/{DATA_VERSION}/dataset/cut_test.xlsx"

# ===== Parameters =====
threshold_default = 0.95
batch_size = 1024
topk_suspects = 10  # per video, report the top-K "most suspicious FP / most suspicious FN"

# ===== Input size: reuse frame_size from the earlier cells; fall back to 224 if it is absent =====
if "frame_size" in globals() and isinstance(frame_size, (tuple, list)) and len(frame_size) == 2:
    _FRAME_SIZE = (int(frame_size[0]), int(frame_size[1]))
else:
    _FRAME_SIZE = (224, 224)

# ===== Device and model =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Accept either baseline_model (single layer) or boundary_model (CNN); baseline_model wins if both exist
_model = None
_model_name = "unavailable"
if "baseline_model" in globals():
    _model = baseline_model
    _model_name = "baseline_model"
elif "boundary_model" in globals():
    _model = boundary_model
    _model_name = "boundary_model"

if _model is None:
    raise RuntimeError("No model found: baseline_model / boundary_model 都不存在。请先运行你的训练Cell。")

_model = _model.to(device)
_model.eval()

# ===== Reuse simple_classification_report if an earlier cell defined it; otherwise define a minimal one =====
if "simple_classification_report" not in globals():
    def simple_classification_report(y_true, y_pred, target_names):
        """Format per-class precision / recall / F1 / support as a text table.

        Class ``i`` is the integer label ``i``; ``target_names[i]`` is its
        display name. A ratio whose denominator is zero is reported as 0. The
        last line gives overall accuracy and the sample count, preceded by a
        blank line.

        Args:
            y_true: Sequence of true integer labels.
            y_pred: Sequence of predicted integer labels.
            target_names: Display names, one per class index.

        Returns:
            The report as a single newline-joined string.
        """
        truth = np.array(y_true)
        guess = np.array(y_pred)
        total = len(truth)
        accuracy = (truth == guess).sum() / total if total > 0 else 0.0

        report = ["precision    recall  f1-score   support"]
        for cls, name in enumerate(target_names):
            is_cls = (truth == cls)
            said_cls = (guess == cls)

            tp = np.logical_and(is_cls, said_cls).sum()
            fp = np.logical_and(~is_cls, said_cls).sum()
            fn = np.logical_and(is_cls, ~said_cls).sum()
            support = is_cls.sum()

            predicted_total = tp + fp
            actual_total = tp + fn
            precision = tp / predicted_total if predicted_total > 0 else 0.0
            recall = tp / actual_total if actual_total > 0 else 0.0
            pr_sum = precision + recall
            f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0.0

            report.append(f"{name:10s}  {precision:0.4f}   {recall:0.4f}   {f1:0.4f}   {support:5d}")

        report.append(f"\naccuracy                        {accuracy:0.4f}   {total:5d}")
        return "\n".join(report)

# ===== 指标函数（与原docx一致风格）=====
def _safe_div(a, b):
    return float(a) / float(b) if b else 0.0

def binary_metrics_from_probs(y_true, prob_pos, threshold=0.95):
    """Threshold positive-class probabilities and summarise the binary result.

    Args:
        y_true: ground-truth labels (1 = positive, 0 = negative).
        prob_pos: positive-class probability for each sample.
        threshold: a sample is predicted positive when its probability is
            greater than or equal to this value.

    Returns:
        A dict with the threshold, the confusion counts (tp/fp/tn/fn),
        precision, recall, f1, acc, the fraction of samples predicted
        positive, the mean probability over true positives and over true
        negatives, and the thresholded predictions (``y_pred``). Every ratio
        with a zero denominator, and every mean over an empty group, is 0.0.
    """
    def ratio(numerator, denominator):
        # Same contract as the module's safe division: falsy denominator -> 0.0.
        return float(numerator) / float(denominator) if denominator else 0.0

    labels = np.asarray(y_true, dtype=np.int64)
    scores = np.asarray(prob_pos, dtype=np.float64)
    decisions = (scores >= threshold).astype(np.int64)

    is_pos, is_neg = (labels == 1), (labels == 0)
    said_pos, said_neg = (decisions == 1), (decisions == 0)

    tp = int(np.sum(is_pos & said_pos))
    fp = int(np.sum(is_neg & said_pos))
    tn = int(np.sum(is_neg & said_neg))
    fn = int(np.sum(is_pos & said_neg))

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)

    summary = {
        "threshold": float(threshold),
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "precision": precision,
        "recall": recall,
        "f1": ratio(2 * precision * recall, precision + recall),
        "acc": ratio(tp + tn, tp + tn + fp + fn),
        "pos_pred_rate": ratio(tp + fp, len(labels)),
        "avg_prob_pos": float(np.mean(scores[is_pos])) if np.any(is_pos) else 0.0,
        "avg_prob_neg": float(np.mean(scores[is_neg])) if np.any(is_neg) else 0.0,
        "y_pred": decisions,
    }
    return summary

def average_precision_score(y_true, prob_pos):
    """Average precision: mean of precision-at-rank over the positive samples.

    Samples are ranked by descending probability. At each rank that holds a
    positive, precision is (positives seen so far) / (samples seen so far).
    The result is the sum of those precisions divided by the number of
    positives. Returns 0.0 when there is no positive sample.
    """
    labels = np.asarray(y_true, dtype=np.int64)
    scores = np.asarray(prob_pos, dtype=np.float64)

    pos_count = int(np.sum(labels == 1))
    if pos_count == 0:
        return 0.0

    # Boolean "is a positive" flags in ranking order (highest score first).
    ranked_hits = labels[np.argsort(-scores)] == 1
    hits_so_far = np.cumsum(ranked_hits)
    seen_so_far = np.arange(1, len(ranked_hits) + 1)

    precisions_at_hits = hits_so_far[ranked_hits] / seen_so_far[ranked_hits]
    return float(np.sum(precisions_at_hits) / pos_count)

def roc_auc_score_rank(y_true, prob_pos):
    """ROC AUC computed from the rank sum of the positive samples.

    Scores are ranked ascending starting at 1; samples with equal scores all
    receive the mean rank of their tie group. With R the rank sum of the
    positives, the result is ``(R - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)``.
    Returns 0.0 when either class is absent.
    """
    labels = np.asarray(y_true, dtype=np.int64)
    scores = np.asarray(prob_pos, dtype=np.float64)

    n_pos = int(np.sum(labels == 1))
    n_neg = int(np.sum(labels == 0))
    if n_pos == 0 or n_neg == 0:
        return 0.0

    ascending = np.argsort(scores)
    ordered = scores[ascending]
    total = len(ordered)

    ranks = np.empty_like(ascending, dtype=np.float64)
    ranks[ascending] = np.arange(1, total + 1, dtype=np.float64)

    # Walk the sorted scores one tie group [start, stop) at a time and give
    # every member of a multi-element group the group's mean rank.
    start = 0
    while start < total:
        stop = start + 1
        while stop < total and ordered[stop] == ordered[start]:
            stop += 1
        if stop - start > 1:
            members = ascending[start:stop]
            ranks[members] = float(np.mean(ranks[members]))
        start = stop

    positive_rank_sum = float(np.sum(ranks[labels == 1]))
    auc = (positive_rank_sum - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
    return float(auc)

# ===== Load the test annotation sheet (row 1 = FPS; each later row = one video's cut annotations, one per column) =====
wb = openpyxl.load_workbook(test_excel_file, data_only=True)
ws = wb.active
rows = [*ws.iter_rows(values_only=True)]
if not rows:
    raise RuntimeError("cut_test.xlsx 里没有任何行！")

# The first row carries the global frame rate.
fps_row = rows[0]

def get_fps_from_row(r, default=24.0):
    """Return the first usable FPS value found in a worksheet row.

    Cells are scanned left to right. A numeric cell is returned as a float.
    A string cell is stripped and parsed as a float; blank strings, the text
    "none" (any case) and unparsable strings are skipped. Cells of any other
    type are ignored. If nothing usable is found, ``default`` is returned.
    """
    for cell in r:
        if isinstance(cell, (int, float)):
            return float(cell)
        if not isinstance(cell, str):
            continue

        text = cell.strip()
        if text == "" or text.lower() == "none":
            continue
        try:
            return float(text)
        except ValueError:
            # Not a number: keep scanning the remaining cells.
            continue
    return float(default)

fps = get_fps_from_row(fps_row, default=24.0)
print(f"[cut_test] Using FPS from Excel first row: {fps}")

# Every row after the FPS row holds the cut annotations of one video.
data_rows = rows[1:]

# Videos are matched to annotation rows purely by position (i-th file in
# sorted order <-> i-th data row), so the two counts must agree. This is data
# validation, so raise explicitly: an `assert` would be stripped when Python
# runs with -O and the mismatch would go unnoticed.
video_files = sorted(glob.glob(f"{test_video_dir}/V*.mp4"))
if len(video_files) != len(data_rows):
    raise RuntimeError(f"Mismatch: videos({len(video_files)}) vs excel rows({len(data_rows)})")

def timecode_to_frame(tc, fps):
    """Convert one annotation cell into a frame number, or None if unusable.

    Accepted forms:
      * int / float            -> truncated to int and used as the frame number
      * all-digit string       -> used as the frame number
      * "ss:ff"                -> int(ss * fps + ff)
      * "hh:mm:ss"             -> int((hh * 3600 + mm * 60 + ss) * fps)
    None, blank strings, "none" (any case), other types and anything that
    does not parse yield None.
    """
    if isinstance(tc, (int, float)):
        return int(tc)
    if not isinstance(tc, str):
        return None

    text = tc.strip()
    if text == "" or text.lower() == "none":
        return None
    if text.isdigit():
        return int(text)
    if ":" not in text:
        return None

    fields = text.split(":")
    if len(fields) not in (2, 3):
        return None
    try:
        if len(fields) == 2:
            seconds, extra_frames = int(fields[0]), int(fields[1])
            return int(seconds * fps + extra_frames)
        hours, minutes, seconds = int(fields[0]), int(fields[1]), int(fields[2])
        return int((hours * 3600 + minutes * 60 + seconds) * fps)
    except Exception:
        # Non-integer fields (or an unusable fps) mean "no annotation".
        return None

# ===== Decode every test video and build boundary_pairs_test =====
# The sheet marks the first frame B of the new shot; the positive label is put
# on the pair (B-1, B), i.e. on pair index i = B-1.
video_frames = []
boundary_pairs_test = []
video_meta = []  # one dict per video: vid_idx, name, path, total_frames, gt_cut_indices

for vid_idx, (video_path, row) in enumerate(zip(video_files, data_rows)):
    video_name = os.path.basename(video_path)

    # Read the whole video, resizing each frame as it is decoded.
    frames = []
    cap = cv2.VideoCapture(video_path)
    while True:
        success, frame = cap.read()
        if not success:
            break
        frames.append(cv2.resize(frame, _FRAME_SIZE))
    cap.release()

    video_frames.append(frames)
    total_frames = len(frames)
    meta = {
        "vid_idx": vid_idx,
        "name": video_name,
        "path": video_path,
        "total_frames": total_frames,
        "gt_cut_indices": [],
    }

    # A video needs at least two frames to form a pair; otherwise it only
    # gets a metadata entry and contributes no samples.
    if total_frames <= 1:
        print(f"[cut_test] Warning: {video_path} has {total_frames} frame(s), skip.")
        video_meta.append(meta)
        continue

    last_pair = total_frames - 2
    cut_set = set()
    for v in ([] if row is None else row):
        frame_idx = timecode_to_frame(v, fps)
        # Empty cells and frame numbers <= 0 carry no usable cut.
        if frame_idx is None or frame_idx <= 0:
            continue
        # Shift B -> B-1, then clamp into the valid pair range [0, total_frames - 2].
        cut_set.add(max(0, min(int(frame_idx - 1), last_pair)))

    cut_indices = sorted(cut_set)
    print(f"[cut_test] Video {vid_idx} ({video_name}): total_frames={total_frames}, cuts-1@frames={cut_indices}")

    meta["gt_cut_indices"] = cut_indices
    video_meta.append(meta)

    # One sample per consecutive frame pair (i, i+1); label 1 marks a cut.
    boundary_pairs_test.extend(
        (vid_idx, i, 1 if i in cut_set else 0) for i in range(total_frames - 1)
    )

num_pairs = len(boundary_pairs_test)
num_cuts = sum(lbl == 1 for _vid, _pair, lbl in boundary_pairs_test)
num_noncuts = num_pairs - num_cuts
print(f"\n[cut_test] Generated {num_pairs} pairs: {num_cuts} cuts, {num_noncuts} non-cuts")

# ===== Dataset: prev / A / B / next frames plus their three absolute differences (21 channels for 3-channel frames, not 9) =====
class ShotBoundaryDataset(Dataset):
    """Frame-pair samples for shot-boundary classification.

    ``pairs`` holds ``(vid_idx, frame_idx, label)`` tuples and
    ``video_frames[vid_idx]`` is the frame list of that video. Each sample
    stacks four neighbouring frames and the absolute differences between
    consecutive ones along the channel axis, scaled by 1/255 and returned
    channel-first.
    """

    def __init__(self, pairs, video_frames):
        self.pairs, self.video_frames = pairs, video_frames

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(float32 tensor in C,H,W order, int64 label tensor)`` for sample ``idx``."""
        vid_idx, frame_idx, label = self.pairs[idx]
        clip = self.video_frames[vid_idx]
        last = len(clip) - 1

        def clamp(position):
            # Keep a neighbour index inside [0, last].
            return max(0, min(position, last))

        # Context window: previous frame (only floored at 0), A, B, next.
        window = [clip[max(0, frame_idx - 1)]]
        window += [clip[clamp(frame_idx + offset)] for offset in (0, 1, 2)]

        # Differences between consecutive window frames: (prev,A), (A,B), (B,next).
        diffs = [cv2.absdiff(first, second) for first, second in zip(window, window[1:])]

        # Channel layout: prev, A, B, next, diff(prev,A), diff(A,B), diff(B,next).
        stacked = np.concatenate(window + diffs, axis=2).astype("float32") / 255.0
        channel_first = np.transpose(stacked, (2, 0, 1))

        return (
            torch.tensor(channel_first, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )

test_dataset = ShotBoundaryDataset(boundary_pairs_test, video_frames)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# ===== Inference: collect overall arrays plus one detail record per pair =====
y_true, prob_pos, y_pred = [], [], []
pair_details = []  # (vid_idx, frame_idx, gt, prob, pred)

t0 = time.time()
pair_ptr = 0  # offset of the current batch inside boundary_pairs_test (the loader is not shuffled)
with torch.no_grad():
    for imgs, labels in test_loader:
        bsz = labels.size(0)
        imgs = imgs.to(device)
        out = _model(imgs)

        # Two-column output -> softmax, take column 1 as the cut probability;
        # any other shape is flattened and passed through a sigmoid.
        two_class = out.dim() == 2 and out.size(1) == 2
        probs = torch.softmax(out, dim=1)[:, 1] if two_class else torch.sigmoid(out.view(-1))
        pred = (probs >= threshold_default).long()

        probs_np = probs.detach().cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        labels_np = labels.detach().cpu().numpy()

        y_true.extend(labels_np.tolist())
        prob_pos.extend(probs_np.tolist())
        y_pred.extend(pred_np.tolist())

        # Re-attach (video, pair index) to each prediction of this batch.
        batch_pairs = boundary_pairs_test[pair_ptr:pair_ptr + bsz]
        pair_details.extend(
            (vid_idx, frame_idx, int(gt), float(prob), int(flag))
            for (vid_idx, frame_idx, gt), prob, flag in zip(batch_pairs, probs_np, pred_np)
        )
        pair_ptr += bsz

t_infer = time.time() - t0

# ===== Overall metrics =====
m = binary_metrics_from_probs(y_true, prob_pos, threshold=threshold_default)
ap = average_precision_score(y_true, prob_pos)
auc = roc_auc_score_rank(y_true, prob_pos)

# Print the threshold that was actually applied. The banner used to say
# "threshold=0.5" regardless of threshold_default, which mislabelled the run.
print(f"\n[cut_test] FINAL TEST (threshold={m['threshold']}) => "
      f"F1 {m['f1']:.4f} (P {m['precision']:.4f}, R {m['recall']:.4f}) | "
      f"AP {ap:.4f} | AUC {auc:.4f} | "
      f"TP {m['tp']} FP {m['fp']} TN {m['tn']} FN {m['fn']}")

# y_pred was thresholded with the same threshold_default during inference.
print("\nShot Boundary Detection - Classification Report (cut_test):")
report_str = simple_classification_report(y_true, y_pred, target_names=["Non-cut", "Cut"])
print(report_str)

# ===== Per-video statistics (rows for the per_video and suspects_topk sheets) =====
per_video_pairs = {vm["vid_idx"]: [] for vm in video_meta}
for vid_idx, frame_idx, gt, p, pred in pair_details:
    per_video_pairs[vid_idx].append((frame_idx, gt, p, pred))

per_video_rows = []
per_video_suspects_rows = []  # rows of [vid, vid_idx, type, rank, frame_idx, prob_cut]

for vm in video_meta:
    vid_idx, name = vm["vid_idx"], vm["name"]
    gt_cuts = set(vm["gt_cut_indices"])
    pairs = per_video_pairs.get(vid_idx, [])

    # A video without pairs simply produces empty collections below, which
    # gives zero counts, "none" for pred_cuts and no suspect rows.
    pred_cut_set = {fr for fr, _gt, _p, flag in pairs if flag == 1}
    pred_cuts = sorted(pred_cut_set)

    per_video_rows.append({
        "vid": name,
        "vid_idx": vid_idx,
        "total_frames": vm["total_frames"],
        "gt_cut_count": len(gt_cuts),
        "pred_cut_count": len(pred_cut_set),
        "tp": len(gt_cuts & pred_cut_set),
        "fp": len(pred_cut_set - gt_cuts),
        "fn": len(gt_cuts - pred_cut_set),
        "gt_cuts": ",".join(map(str, sorted(gt_cuts))) if gt_cuts else "none",
        "pred_cuts": ",".join(map(str, pred_cuts)) if pred_cuts else "none",
    })

    # Suspects: false positives ranked by highest cut probability first,
    # false negatives by lowest first; keep the top-k of each.
    false_pos = sorted(
        ((fr, prob) for fr, truth, prob, flag in pairs if truth == 0 and flag == 1),
        key=lambda item: -item[1],
    )[:topk_suspects]
    false_neg = sorted(
        ((fr, prob) for fr, truth, prob, flag in pairs if truth == 1 and flag == 0),
        key=lambda item: item[1],
    )[:topk_suspects]

    for kind, ranked in (("FP", false_pos), ("FN", false_neg)):
        for rank, (fr, pp) in enumerate(ranked, start=1):
            per_video_suspects_rows.append([name, vid_idx, kind, rank, fr, float(pp)])

# ===== Resolve where the cut_test Excel report is written =====
# Use the notebook-wide REPORTS_BASE_DIR from the path-config cell so reports
# follow PROJECT_ROOT; fall back to the previous literal if it is not defined.
base_dir = globals().get("REPORTS_BASE_DIR", "movie/reports")
os.makedirs(base_dir, exist_ok=True)

# Reuse the report folder created by the training cell; if it is missing,
# build one with the same naming rule: <MMDDHHMM>_e<epochs>_w<pos_weight>.
if "REPORT_DIR" not in globals() or "REPORT_FOLDER_NAME" not in globals():
    ts_folder = datetime.now(_TZ).strftime("%m%d%H%M") if _TZ else datetime.now().strftime("%m%d%H%M")
    _e = globals().get("epochs", "NA")
    _w = globals().get("pos_weight", "NA")
    REPORT_FOLDER_NAME = f"{ts_folder}_e{_e}_w{_w}"
    REPORT_DIR = os.path.join(base_dir, REPORT_FOLDER_NAME)

# Create the folder in both cases: an inherited REPORT_DIR may not exist on
# disk, and the workbook save would otherwise fail.
os.makedirs(REPORT_DIR, exist_ok=True)

ts = datetime.now(_TZ).strftime("%Y%m%d_%H%M%S") if _TZ else datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(REPORT_DIR, f"cut_test_report_{ts}.xlsx")

wb_out = openpyxl.Workbook()

# Sheet: run_info — key/value description of this evaluation run
ws0 = wb_out.active
ws0.title = "run_info"
for _entry in (
    ["key", "value"],
    ["time", datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S %Z") if _TZ else datetime.now().strftime("%Y-%m-%d %H:%M:%S")],
    ["python_version", platform.python_version()],
    ["platform", platform.platform()],
    ["processor", platform.processor()],
    ["torch_version", torch.__version__],
    ["cuda_available", str(torch.cuda.is_available())],
    ["device_used", str(device)],
    ["test_video_dir", test_video_dir],
    ["test_excel_file", test_excel_file],
    ["fps_from_excel", str(fps)],
    ["threshold_default", str(threshold_default)],
    ["batch_size", str(batch_size)],
    ["topk_suspects", str(topk_suspects)],
    ["model_used", _model_name],
    ["infer_seconds", f"{t_infer:.3f}"],
    ["report_folder_name", REPORT_FOLDER_NAME],
    ["report_dir", REPORT_DIR],
):
    ws0.append(_entry)

# Sheet: dataset_summary — pair counts followed by per-video frame statistics
ws_sum = wb_out.create_sheet("dataset_summary")
for _entry in (
    ["item", "value"],
    ["num_videos", len(video_files)],
    ["num_pairs", num_pairs],
    ["num_cuts", num_cuts],
    ["num_non_cuts", num_noncuts],
    ["pos_ratio", (num_cuts / num_pairs) if num_pairs else 0.0],
    [""],
    ["per_video_frame_stats", ""],
):
    ws_sum.append(_entry)

frames_list = [vm["total_frames"] for vm in video_meta if vm["total_frames"] is not None]
if frames_list:
    for _entry in (
        ["min_frames", int(np.min(frames_list))],
        ["max_frames", int(np.max(frames_list))],
        ["mean_frames", float(np.mean(frames_list))],
        ["median_frames", float(np.median(frames_list))],
    ):
        ws_sum.append(_entry)

# Sheet: per_video — one row per video, columns in this fixed order
_per_video_columns = [
    "vid", "vid_idx", "total_frames",
    "gt_cut_count", "pred_cut_count",
    "tp", "fp", "fn",
    "gt_cuts", "pred_cuts",
]
ws_v = wb_out.create_sheet("per_video")
ws_v.append(list(_per_video_columns))
for _stats in per_video_rows:
    ws_v.append([_stats[column] for column in _per_video_columns])

# Sheet: suspects_topk — top-ranked false positives / false negatives per video
ws_sus = wb_out.create_sheet("suspects_topk")
ws_sus.append(["vid", "vid_idx", "type", "rank", "frame_idx", "prob_cut"])
for _suspect in per_video_suspects_rows:
    ws_sus.append(_suspect)

# Sheet: final_test — overall metrics at the applied threshold, then AP and ROC AUC
ws2 = wb_out.create_sheet("final_test")
ws2.append(["metric", "value"])
for _metric in (
    "threshold", "precision", "recall", "f1", "acc",
    "tp", "fp", "tn", "fn",
    "pos_pred_rate", "avg_prob_pos", "avg_prob_neg",
):
    ws2.append([_metric, m[_metric]])
ws2.append(["pr_auc_ap", ap])
ws2.append(["roc_auc", auc])

# Sheet: classification_report — the text report, one line per row
ws_rep = wb_out.create_sheet("classification_report")
ws_rep.append(["text"])
for _line in report_str.splitlines():
    ws_rep.append([_line])

wb_out.save(out_path)
print(f"\nSaved cut_test report Excel to: {out_path}")
print(f"[Report] Folder: {REPORT_DIR}")

# ===== Save the model next to the report (bare state_dict + a checkpoint with run metadata) =====
try:
    os.makedirs(REPORT_DIR, exist_ok=True)
    ts_model = datetime.now(_TZ).strftime("%Y%m%d_%H%M%S") if _TZ else datetime.now().strftime("%Y%m%d_%H%M%S")

    # File names carry the model kind: baseline_* for baseline_model,
    # boundary_* for anything else.
    _prefix = "baseline" if _model_name == "baseline_model" else "boundary"
    model_path = os.path.join(REPORT_DIR, f"{_prefix}_model_{ts_model}.pt")
    ckpt_path = os.path.join(REPORT_DIR, f"{_prefix}_ckpt_{ts_model}.pth")

    # 1) Weights only.
    torch.save(_model.state_dict(), model_path)

    # 2) Weights plus the settings this evaluation ran with. epochs,
    #    pos_weight and BUILD_ID are taken from the notebook namespace when
    #    present and stored as None otherwise.
    _saved_time = (
        datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S %Z")
        if _TZ
        else datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    )
    ckpt = {
        "model_state_dict": _model.state_dict(),
        "model_used": _model_name,
        "device_saved": str(device),
        "fps_from_excel": float(fps),
        "threshold_default": float(threshold_default),
        "batch_size": int(batch_size),
        "frame_size": tuple(_FRAME_SIZE),
        "report_dir": REPORT_DIR,
        "report_folder_name": REPORT_FOLDER_NAME,
        "epochs": globals().get("epochs", None),
        "pos_weight": globals().get("pos_weight", None),
        "build_id": globals().get("BUILD_ID", None),
        "saved_time": _saved_time,
    }
    torch.save(ckpt, ckpt_path)

    print(f"[Model] Saved state_dict to: {model_path}")
    print(f"[Model] Saved checkpoint to: {ckpt_path}")

except Exception as e:
    # Best effort: a failed save is reported but does not abort the cell.
    print(f"[Model] Save failed: {e}")
