In [None]:
import cv2
import numpy as np
import glob
import openpyxl

# ---------- Data version / paths ----------
DATA_VERSION = "v13"
PROJECT_ROOT = "movie"

# Training data: a folder of V*.mp4 clips plus a workbook whose first row holds
# the FPS and whose remaining rows hold one row of cut annotations per clip.
TRAIN_VIDEO_DIR  = f"{PROJECT_ROOT}/{DATA_VERSION}/dataset/cut"
TRAIN_EXCEL_FILE = f"{PROJECT_ROOT}/{DATA_VERSION}/dataset/cut.xlsx"

# Separate evaluation set, read by the standalone cut_test evaluation cell.
TEST_VIDEO_DIR   = f"{PROJECT_ROOT}/{DATA_VERSION}/dataset/cut_test"
TEST_EXCEL_FILE  = f"{PROJECT_ROOT}/{DATA_VERSION}/dataset/cut_test.xlsx"

# Parent folder for the timestamped per-run report folders.
REPORTS_BASE_DIR = f"{PROJECT_ROOT}/reports"

# ---------- Input size ----------
# NOTE(review): frames are resized via cv2.resize(frame, frame_size), and cv2 expects
# dsize as (W, H); this only agrees with the (H, W) convention because the size is square.
FRAME_SIZE = (224, 224)  # (H,W)

# ---------- Training hyperparameters ----------
# LR_INIT is the Adam learning rate.
EPOCHS     = 100
BATCH_SIZE = 1024
LR_INIT    = 1e-5

# ---------- Class imbalance (pos_weight) ----------
# POS_WEIGHT is the cross-entropy weight of the "cut" class ("non-cut" has weight 1.0).
# POS_WEIGHT_MODE "fixed" uses POS_WEIGHT every epoch; "epoch" uses that epoch's
# negative/positive ratio, capped at POS_WEIGHT_EPOCH_MAX.
POS_WEIGHT           = 40
POS_WEIGHT_MODE      = "fixed"   # "fixed" / "epoch"
POS_WEIGHT_EPOCH_MAX = 80.0

# ---------- Negative sampling (negatives re-drawn every epoch) ----------
# When enabled, all positives are kept and a fresh random subset of negatives is drawn:
#   "ratio"   -> keep NEG_SAMPLE_RATIO of all training negatives;
#   "per_pos" -> keep NEG_PER_POS negatives per positive.
USE_DYNAMIC_NEG_SAMPLING = False
NEG_SAMPLING_MODE        = "ratio"    # "ratio" or "per_pos"
NEG_SAMPLE_RATIO         = 0.20
NEG_PER_POS              = 5

# SEED_SPLIT seeds the video-level train/test shuffle;
# the per-epoch negative-sampling RNG is seeded with SEED_EPOCH_BASE + epoch.
SEED_SPLIT      = 42
SEED_EPOCH_BASE = 20260118

# ---------- CUDA / performance ----------
USE_CUDA          = True
CUDA_DEVICE_INDEX = 0

# cuDNN autotuning (benchmark) vs. deterministic kernel selection.
CUDNN_BENCHMARK     = True
CUDNN_DETERMINISTIC = False

# TF32 toggle and the value passed to torch.set_float32_matmul_precision.
ALLOW_TF32 = True
MATMUL_PRECISION = "high"

# DataLoader options; persistent workers are only enabled when NUM_WORKERS > 0.
DATALOADER_NUM_WORKERS        = 12
DATALOADER_PIN_MEMORY         = True
DATALOADER_PERSISTENT_WORKERS = True

# ---------- Backward compatibility: aliases for older variable names ----------
video_dir  = TRAIN_VIDEO_DIR
excel_file = TRAIN_EXCEL_FILE
frame_size = FRAME_SIZE

# Data loading: read every row of the training annotation workbook.
wb = openpyxl.load_workbook(excel_file, data_only=True)
ws = wb.active
rows = [*ws.iter_rows(values_only=True)]
if not rows:
    raise RuntimeError("Excel 里没有任何行！")

# Row 0 carries the frame rate; the remaining rows are per-video annotations.
fps_row = rows[0]

def get_fps_from_row(r, default=24.0):
    """Return the first usable frame rate found in an Excel row.

    Cells are scanned left to right. ``None``, blank strings, the literal
    ``"none"`` and strings that do not parse as a number are skipped.
    Booleans (an Excel TRUE would otherwise read as 1.0 fps) and values that
    cannot be a frame rate (NaN, infinity, zero or negative) are skipped too.

    Args:
        r: iterable of cell values (e.g. a row from ``iter_rows(values_only=True)``).
        default: value returned when no usable cell is found.

    Returns:
        The frame rate as a float.
    """
    for cell in r:
        if cell is None or isinstance(cell, bool):
            continue
        if isinstance(cell, (int, float)):
            value = float(cell)
        elif isinstance(cell, str):
            s = cell.strip()
            if s == "" or s.lower() == "none":
                continue
            try:
                value = float(s)
            except ValueError:
                continue
        else:
            # Dates, times and other cell types are not frame rates.
            continue
        # NaN fails both comparisons, so this also rejects NaN.
        if 0.0 < value < float("inf"):
            return value
    return float(default)

import re

# Frame rate used to convert "ss:ff" / "hh:mm:ss" timecodes into frame indices.
fps = get_fps_from_row(fps_row, default=24.0)
print(f"Using FPS from Excel first row: {fps}")

# Remaining rows: one row of cut annotations per video, paired with video_files by position.
data_rows = rows[1:]

# Natural sort ("V2" before "V10") so row k pairs with the k-th numbered clip even when
# file names are not zero-padded. Plain sorted() would place "V10.mp4" before "V2.mp4"
# and silently attach labels to the wrong video; for consistently zero-padded names the
# order is unchanged.
video_files = sorted(
    glob.glob(f"{video_dir}/V*.mp4"),
    key=lambda p: [int(tok) if tok.isdecimal() else tok for tok in re.split(r"(\d+)", p)],
)
assert len(video_files) == len(data_rows), f"Mismatch between number of videos ({len(video_files)}) and Excel data rows ({len(data_rows)})"

def timecode_to_frame(tc, fps):
    """Convert an Excel annotation cell into a frame index.

    Accepted forms:
      * a number (int/float) -> truncated to ``int``;
      * a digit-only string  -> that integer;
      * ``"ss:ff"``          -> ``int(ss * fps + ff)``;
      * ``"hh:mm:ss"``       -> ``int((hh*3600 + mm*60 + ss) * fps)``.

    Returns ``None`` for ``None``, blank or ``"none"`` strings, unparsable
    timecodes, other colon-separated lengths and any other cell type.
    """
    if tc is None:
        return None
    if isinstance(tc, (int, float)):
        return int(tc)
    if not isinstance(tc, str):
        return None

    text = tc.strip()
    if not text or text.lower() == "none":
        return None
    if text.isdigit():
        return int(text)
    if ":" not in text:
        return None

    fields = text.split(":")
    try:
        if len(fields) == 2:
            seconds, frames = (int(f) for f in fields)
            return int(seconds * fps + frames)
        if len(fields) == 3:
            hours, minutes, seconds = (int(f) for f in fields)
            return int((hours * 3600 + minutes * 60 + seconds) * fps)
    except Exception:
        return None
    return None

# Decode every training video into memory and derive one labelled sample per
# adjacent-frame pair: (vid_idx, i, label) refers to frames (i, i + 1), and
# label 1 marks a cut between them.
video_frames = []              # per video: list of resized frames
boundary_pairs = []            # all (vid_idx, i, label) tuples across videos
boundary_pairs_by_video = []   # the same tuples grouped per video ([] for skipped videos)
cut_indices_list = []          # per video: sorted pair indices labelled as cuts
total_frames_list = []         # per video: number of decoded frames
valid_video_mask = []          # per video: False when the clip has <= 1 frame
video_paths = list(video_files)

# FRAME_SIZE is documented as (H, W) but cv2.resize expects dsize=(W, H);
# swap explicitly so non-square sizes work (identical for 224x224).
resize_dsize = (int(frame_size[1]), int(frame_size[0]))

for vid_idx, (video_path, row) in enumerate(zip(video_files, data_rows)):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Warning: failed to open video {video_path}.")
    frames = []
    success, frame = cap.read()
    while success:
        frame_resized = cv2.resize(frame, resize_dsize)
        frames.append(frame_resized)
        success, frame = cap.read()
    cap.release()

    video_frames.append(frames)
    total_frames = len(frames)
    total_frames_list.append(total_frames)

    if total_frames <= 1:
        valid_video_mask.append(False)
        boundary_pairs_by_video.append([])
        cut_indices_list.append([])
        print(f"Warning: video {video_path} has {total_frames} frame(s), skip boundary generation.")
        continue

    valid_video_mask.append(True)

    raw_values = list(row) if row is not None else []
    cut_indices = []

    for v in raw_values:
        frame_idx = timecode_to_frame(v, fps)
        if frame_idx is None:
            continue

        # The sheet marks the first frame B of the new shot; the cut lies between
        # frames (B-1, B), i.e. pair index B-1. Marks at frame <= 0 are ignored.
        if frame_idx > 0:
            frame_idx -= 1
        else:
            continue

        # Marks past the end of the decoded clip are still clamped onto the last
        # pair, but now loudly, so mis-aligned annotations are visible.
        if frame_idx > total_frames - 2:
            print(f"Warning: video {video_path}: annotated cut at frame {frame_idx + 1} is beyond the clip ({total_frames} frames); clamped to pair {total_frames - 2}.")
        frame_idx = max(0, min(int(frame_idx), total_frames - 2))
        cut_indices.append(frame_idx)

    cut_indices = sorted(set(cut_indices))
    cut_indices_list.append(cut_indices)

    print(f"Video {vid_idx} ({video_path}): total_frames={total_frames}, cuts-1@frames={cut_indices}")

    cut_set = set(cut_indices)

    per_video_pairs = []
    for i in range(total_frames - 1):
        label = 1 if i in cut_set else 0
        tup = (vid_idx, i, label)
        boundary_pairs.append(tup)
        per_video_pairs.append(tup)

    boundary_pairs_by_video.append(per_video_pairs)

num_videos = len(video_files)

num_pairs = len(boundary_pairs)
num_cuts = sum(1 for _, _, lbl in boundary_pairs if lbl == 1)
num_noncuts = num_pairs - num_cuts

total_extracted_frames = sum(len(frames) for frames in video_frames)
num_valid_videos = sum(1 for v in valid_video_mask if v)

print(f"\nProcessed {num_videos} videos (valid {num_valid_videos}/{num_videos}), extracted {total_extracted_frames} frames.")
print(f"Generated {num_pairs} frame pairs: {num_cuts} cuts (positive) and {num_noncuts} non-cuts (negative).")

pos_count = num_cuts
neg_count = num_noncuts
pos_ratio = (pos_count / num_pairs) if num_pairs > 0 else 0.0
neg_ratio = (neg_count / num_pairs) if num_pairs > 0 else 0.0
print(f"Pos ratio: {pos_ratio:.6f} | Neg ratio: {neg_ratio:.6f}")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random
import cv2
import numpy as np

class ShotBoundaryDataset(Dataset):
    """Map-style dataset yielding one stacked tensor per adjacent-frame pair.

    Args:
        pairs: sequence of ``(vid_idx, frame_idx, label)`` tuples; a sample is
            built from frames ``frame_idx`` and ``frame_idx + 1`` of video ``vid_idx``.
        video_frames: per-video frame lists, indexed ``video_frames[vid_idx][frame_idx]``.

    Each item is ``(image, label)``. ``image`` is a float32 tensor in CHW layout:
    frame A, frame B and ``cv2.absdiff(A, B)`` concatenated along channels and
    scaled by 1/255 (9 channels, assuming 3-channel HWC frames). ``label`` is an
    int64 scalar tensor.
    """

    def __init__(self, pairs, video_frames):
        self.pairs = pairs
        self.video_frames = video_frames

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        vid_idx, frame_idx, label = self.pairs[idx]
        clip = self.video_frames[vid_idx]
        frame_a = clip[frame_idx]
        frame_b = clip[frame_idx + 1]

        stacked = np.concatenate([frame_a, frame_b, cv2.absdiff(frame_a, frame_b)], axis=2)
        scaled = stacked.astype("float32") / 255.0
        chw = scaled.transpose(2, 0, 1)  # HWC -> CHW for PyTorch convolutions

        return torch.tensor(chw, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# ============= Train/test split (by video) =============
# Split at the video level so frames of one clip never land in both sets:
# the first 95% of the shuffled video indices train, the rest test.
# NOTE: random.seed reseeds Python's global `random` module here.
all_video_indices = list(range(num_videos))
random.seed(SEED_SPLIT)
random.shuffle(all_video_indices)

# Videos with <= 1 frame produced no pairs but still occupy a slot in the split.
split_idx = int(len(all_video_indices) * 0.95)
train_vids = set(all_video_indices[:split_idx])
test_vids  = set(all_video_indices[split_idx:])

train_pairs_all = [p for p in boundary_pairs if p[0] in train_vids]
test_pairs      = [p for p in boundary_pairs if p[0] in test_vids]

# Positive/negative pools used by the per-epoch negative sampler.
train_pos_pairs = [p for p in train_pairs_all if p[2] == 1]
train_neg_pairs = [p for p in train_pairs_all if p[2] == 0]

# The test set always uses all of its pairs (no negative subsampling).
test_dataset = ShotBoundaryDataset(test_pairs, video_frames)

# PyTorch rejects persistent_workers=True with num_workers=0, hence the guard.
_nw = int(globals().get("DATALOADER_NUM_WORKERS", 0))
_pw = bool(globals().get("DATALOADER_PERSISTENT_WORKERS", False)) and (_nw > 0)

# NOTE(review): the standalone cut_test cell later rebinds the global `test_loader`.
test_loader  = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=_nw,
    pin_memory=bool(globals().get("DATALOADER_PIN_MEMORY", False)),
    persistent_workers=_pw,
)

print(f"[Split] train_vids={len(train_vids)} test_vids={len(test_vids)}")
print(f"[Pool] train_pos={len(train_pos_pairs)} train_neg={len(train_neg_pairs)} test_total={len(test_pairs)}")

def make_train_loader_for_epoch(epoch: int, batch_size: int = BATCH_SIZE, shuffle: bool = True):
    """Build the training DataLoader for a single epoch.

    Positives are always kept in full. If ``USE_DYNAMIC_NEG_SAMPLING`` is on
    and negatives exist, a fresh subset of negatives is drawn with an RNG
    seeded by ``SEED_EPOCH_BASE + epoch`` (size per ``NEG_SAMPLING_MODE``);
    otherwise every training pair is used.

    Args:
        epoch: zero-based epoch index (drives the sampling seed and the log line).
        batch_size: DataLoader batch size.
        shuffle: whether the DataLoader shuffles each pass.

    Returns:
        ``(loader, dataset, raw_pos_weight_epoch)`` where ``raw_pos_weight_epoch``
        is the negative/positive ratio of the pairs used this epoch (1.0 when
        sampling is off and there are no positives).

    Raises:
        ValueError: if sampling is on and ``NEG_SAMPLING_MODE`` is unknown.
    """
    n_pos = len(train_pos_pairs)
    sampling_on = USE_DYNAMIC_NEG_SAMPLING and len(train_neg_pairs) > 0

    if not sampling_on:
        epoch_pairs = list(train_pairs_all)
        raw_pos_weight_epoch = len(train_neg_pairs) / max(1, n_pos) if n_pos > 0 else 1.0
    else:
        epoch_rng = random.Random(SEED_EPOCH_BASE + int(epoch))

        if NEG_SAMPLING_MODE == "ratio":
            n_neg = int(len(train_neg_pairs) * float(NEG_SAMPLE_RATIO))
        elif NEG_SAMPLING_MODE == "per_pos":
            n_neg = int(n_pos * int(NEG_PER_POS))
        else:
            raise ValueError("NEG_SAMPLING_MODE must be 'ratio' or 'per_pos'")

        # Always draw at least one negative, and never more than exist.
        n_neg = max(1, min(n_neg, len(train_neg_pairs)))
        epoch_pairs = list(train_pos_pairs) + epoch_rng.sample(train_neg_pairs, n_neg)
        epoch_rng.shuffle(epoch_pairs)

        raw_pos_weight_epoch = (len(epoch_pairs) - n_pos) / max(1, n_pos)

    train_dataset_epoch = ShotBoundaryDataset(epoch_pairs, video_frames)

    workers = int(globals().get("DATALOADER_NUM_WORKERS", 0))
    keep_workers = bool(globals().get("DATALOADER_PERSISTENT_WORKERS", False)) and workers > 0

    train_loader_epoch = DataLoader(
        train_dataset_epoch,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=bool(globals().get("DATALOADER_PIN_MEMORY", False)),
        persistent_workers=keep_workers,
    )

    pos_n = n_pos
    neg_n = len(train_dataset_epoch) - pos_n
    print(f"[Epoch {epoch}] train_pairs={len(train_dataset_epoch)} (pos {pos_n}, neg {neg_n}) raw_pos_weight_ep={raw_pos_weight_epoch:.4f}")
    return train_loader_epoch, train_dataset_epoch, raw_pos_weight_epoch


In [None]:
# ============= Boundary CNN 训练部分 =============
import os
import platform
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import openpyxl
from datetime import datetime

# Timezone for report timestamps (Beijing time). The IANA key is "Asia/Shanghai";
# "Asia/BeiJing" is not a valid zone name, so ZoneInfo would always fail and
# timestamps would silently fall back to naive local time.
REPORT_TIMEZONE = "Asia/Shanghai"
try:
    from zoneinfo import ZoneInfo  # Python >= 3.9
    _TZ = ZoneInfo(REPORT_TIMEZONE)
except Exception:
    # zoneinfo unavailable (Python < 3.9) or no tz database: use naive local time.
    _TZ = None

def _g(name, default=None):
    return globals().get(name, default)

def _none_str(x):
    return "none" if x is None else str(x)

def simple_classification_report(y_true, y_pred, target_names):
    """Return a plain-text per-class precision / recall / F1 report.

    A lightweight stand-in for scikit-learn's ``classification_report``:
    class ``i`` is the label value ``i`` and is displayed as ``target_names[i]``.

    Args:
        y_true: ground-truth integer labels.
        y_pred: predicted integer labels, same length as ``y_true``.
        target_names: display names, one per class index.

    Returns:
        A multi-line string: a header row, one row per class, a blank line,
        then an accuracy row (accuracy is 0.0 for empty input).
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    n_classes = len(target_names)
    lines = []
    acc = (y_true == y_pred).sum() / len(y_true) if len(y_true) > 0 else 0.0

    # Pad the header by the 9-char name column plus separator so every label
    # sits right above its 9-wide numeric column.
    lines.append(f"{'':>9} {'precision':>9} {'recall':>9} {'f1-score':>9} {'support':>9}")

    for i in range(n_classes):
        name = target_names[i]
        true_i = (y_true == i)
        pred_i = (y_pred == i)

        tp = np.logical_and(true_i, pred_i).sum()
        fp = np.logical_and(~true_i, pred_i).sum()
        fn = np.logical_and(true_i, ~pred_i).sum()
        support = true_i.sum()

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall    = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1        = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        lines.append(f"{name:>9} {precision:9.4f} {recall:9.4f} {f1:9.4f} {support:9d}")

    lines.append(f"\naccuracy {acc:9.4f} {len(y_true):9d}")
    return "\n".join(lines)

# ===== 二分类指标 =====
def binary_metrics_from_probs(y_true, prob_pos, threshold=0.5):
    """Threshold positive-class probabilities and compute binary metrics.

    A sample is predicted positive when ``prob_pos >= threshold``. Label 1 is
    the positive class and label 0 the negative class.

    Returns:
        dict with ``threshold``, ``precision``, ``recall``, ``f1``, ``acc``,
        confusion counts ``tp``/``fp``/``tn``/``fn`` (ints), ``pos_pred_rate``
        (fraction predicted positive) and the mean probability among true
        positives / true negatives (``avg_prob_pos`` / ``avg_prob_neg``).
        Undefined ratios (zero denominators, empty classes) are reported as 0.0.
    """
    labels = np.array(y_true).astype(int)
    scores = np.array(prob_pos).astype(float)
    predicted = (scores >= threshold).astype(int)

    is_pos, is_neg = labels == 1, labels == 0
    said_pos, said_neg = predicted == 1, predicted == 0

    tp = int((is_pos & said_pos).sum())
    fp = int((is_neg & said_pos).sum())
    tn = int((is_neg & said_neg).sum())
    fn = int((is_pos & said_neg).sum())

    def _ratio(num, den):
        # Zero denominator -> 0.0 rather than a division error.
        return num / den if den else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    acc = (tp + tn) / max(1, tp + tn + fp + fn)

    pos_pred_rate = (tp + fp) / max(1, len(labels))
    avg_prob_pos = float(scores[is_pos].mean()) if is_pos.any() else 0.0
    avg_prob_neg = float(scores[is_neg].mean()) if is_neg.any() else 0.0

    return {
        "threshold": float(threshold),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
        "acc": float(acc),
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "pos_pred_rate": float(pos_pred_rate),
        "avg_prob_pos": float(avg_prob_pos),
        "avg_prob_neg": float(avg_prob_neg),
    }

# ===== 简单 AP / AUC =====
def average_precision_score(y_true, prob_pos):
    """Average precision: area under the step-wise precision-recall curve.

    Items are ranked by descending score, and precision/recall are evaluated
    once per *distinct* score. Tied scores therefore act as a single threshold
    (as in scikit-learn), instead of the result depending on the arbitrary
    order of tied items. Ties are common when softmax probabilities saturate.

    Args:
        y_true: labels; entries equal to 1 are positives, anything else counts
            as a negative in the ranking.
        prob_pos: predicted score/probability of the positive class.

    Returns:
        AP as a float; 0.0 for empty input or when there are no positives.
    """
    y_true = np.asarray(y_true).astype(int)
    prob_pos = np.asarray(prob_pos).astype(float)
    if y_true.size == 0:
        return 0.0

    # Stable descending sort (tie order no longer matters, but keep it deterministic).
    order = np.argsort(-prob_pos, kind="mergesort")
    y_sorted = y_true[order]
    s_sorted = prob_pos[order]

    # Cumulative TP / FP counts after each ranked item.
    tp_cum = np.cumsum(y_sorted == 1)
    fp_cum = np.cumsum(y_sorted != 1)

    # Evaluate only at the last item of each run of equal scores.
    group_ends = np.append(np.flatnonzero(np.diff(s_sorted)), y_sorted.size - 1)
    tp = tp_cum[group_ends]
    fp = fp_cum[group_ends]

    n_pos = max(1, int(y_true.sum()))
    precision = tp / np.maximum(1, tp + fp)
    recall = tp / n_pos

    # Sum of precision weighted by the recall gained at each threshold.
    recall_gain = np.diff(np.concatenate(([0.0], recall)))
    return float(np.sum(precision * np.maximum(0.0, recall_gain)))

def roc_auc_score_rank(y_true, prob_pos):
    """ROC AUC via the Mann–Whitney U statistic.

    Tied scores receive their average (mid) rank, so a positive and a negative
    with the same score count as half a correctly ordered pair, per the usual
    definition of AUC. Ordinal ranks (argsort of argsort) would break ties
    arbitrarily and bias the result when probabilities saturate.

    Args:
        y_true: labels; 1 = positive, 0 = negative (other values are ignored).
        prob_pos: predicted score/probability of the positive class.

    Returns:
        AUC as a float; 0.0 when either class is absent (AUC is undefined there).
    """
    y_true = np.asarray(y_true).astype(int)
    prob_pos = np.asarray(prob_pos).astype(float)
    pos = prob_pos[y_true == 1]
    neg = prob_pos[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        return 0.0

    all_scores = np.concatenate([pos, neg])
    order = np.argsort(all_scores, kind="mergesort")
    sorted_scores = all_scores[order]

    # Each run of equal scores gets the mean of the 1-based ranks it spans.
    _, first_idx, counts = np.unique(sorted_scores, return_index=True, return_counts=True)
    avg_ranks = first_idx + (counts + 1) / 2.0
    ranks = np.empty(all_scores.size, dtype=float)
    ranks[order] = np.repeat(avg_ranks, counts)

    # Mann–Whitney U for the positives (they occupy the first len(pos) slots).
    r_pos = ranks[:len(pos)]
    U = r_pos.sum() - len(pos) * (len(pos) + 1) / 2
    return float(U / (len(pos) * len(neg)))

# ===== Device =====
# Use the configured CUDA device when enabled and available, otherwise the CPU.
if bool(_g("USE_CUDA", True)) and torch.cuda.is_available():
    _cuda_idx = int(_g("CUDA_DEVICE_INDEX", 0))
    device = torch.device(f"cuda:{_cuda_idx}")
else:
    device = torch.device("cpu")

# Backend knobs are applied best-effort: each group has its own try/except so a
# setting unsupported by the installed torch build does not prevent the others.
try:
    torch.backends.cudnn.benchmark = bool(_g("CUDNN_BENCHMARK", True))
    torch.backends.cudnn.deterministic = bool(_g("CUDNN_DETERMINISTIC", False))
except Exception:
    pass

try:
    torch.backends.cuda.matmul.allow_tf32 = bool(_g("ALLOW_TF32", True))
    torch.backends.cudnn.allow_tf32 = bool(_g("ALLOW_TF32", True))
except Exception:
    pass

try:
    torch.set_float32_matmul_precision(str(_g("MATMUL_PRECISION", "high")))
except Exception:
    pass

print("Using device:", device)

# ----- Model -----
class BoundaryCNN(nn.Module):
    """Small CNN that classifies a 9-channel frame-pair tensor as non-cut / cut.

    Input: ``(batch, 9, H, W)``; expected to be frame A, frame B and their
    absolute difference stacked on channels, as built by the dataset.
    Output: ``(batch, 2)`` logits, index 1 = cut.

    Args:
        input_hw: optional ``(H, W)`` of the input frames. When omitted, the
            notebook-level ``frame_size`` is used, so ``BoundaryCNN()`` behaves
            exactly as before. Layer names/indices are unchanged, so existing
            state dicts still load.
    """

    def __init__(self, input_hw=None):
        super().__init__()
        if input_hw is None:
            input_hw = frame_size
        height, width = int(input_hw[0]), int(input_hw[1])

        self.features = nn.Sequential(
            nn.Conv2d(9, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Three 2x2 max-pools floor-divide each spatial dim by 2**3.
        flat_dim = 64 * (height // (2 ** 3)) * (width // (2 ** 3))
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        x = self.features(x)
        # flatten (unlike view) also works on non-contiguous / channels_last activations.
        x = torch.flatten(x, 1)
        return self.classifier(x)

# Instantiate the model, a default weighted loss and the optimizer.
boundary_model = BoundaryCNN().to(device)

# Nominal loss weight of the "cut" class.
pos_weight = float(_g("POS_WEIGHT", 1.0))
# Observed non-cut/cut ratio over all labelled pairs counted earlier (None if no cuts).
# NOTE(review): num_cuts/num_noncuts are rebound by the cut_test cell if it ran first.
raw_pos_weight = (num_noncuts / max(1, num_cuts)) if num_cuts > 0 else None

print(f"raw_pos_weight = {raw_pos_weight if raw_pos_weight is not None else 'NA'}, nominal pos_weight = {pos_weight}")

# Index 0 = non-cut, 1 = cut. The training loop rebuilds these per epoch.
class_weights = torch.tensor([1.0, pos_weight], dtype=torch.float32, device=device)
criterion_b = nn.CrossEntropyLoss(weight=class_weights)
optimizer_b = optim.Adam(boundary_model.parameters(), lr=float(_g("LR_INIT", 1e-3)))

# ===== 系统信息 =====
def get_system_info():
    """Collect environment details for the run report.

    Returns:
        dict of string values, in this insertion order (the report writes it
        row by row): time_start, python_version, platform, processor,
        torch_version, cuda_available, device_used, cuda_device_count,
        cuda_device_name_0. If a CUDA device query raises, both CUDA device
        fields are reported as "NA".
    """
    if _TZ:
        started = datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S %Z")
    else:
        started = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    cuda_ok = torch.cuda.is_available()
    info = {
        "time_start": started,
        "python_version": platform.python_version(),
        "platform": platform.platform(),
        "processor": platform.processor(),
        "torch_version": torch.__version__,
        "cuda_available": str(cuda_ok),
        "device_used": str(device),
    }

    try:
        if cuda_ok:
            info["cuda_device_count"] = str(torch.cuda.device_count())
            info["cuda_device_name_0"] = torch.cuda.get_device_name(0)
        else:
            info["cuda_device_count"] = "0"
            info["cuda_device_name_0"] = "NA"
    except Exception:
        info["cuda_device_count"] = "NA"
        info["cuda_device_name_0"] = "NA"
    return info

# Snapshot of the runtime environment, written to the "run_info" sheet later.
run_info = get_system_info()

# ===== Training =====
epochs = int(_g("EPOCHS", 1))
# Decision threshold on P(cut) for the per-epoch metrics and the final test.
threshold_default = 0.95

t0 = time.time()
# Optimizer steps taken so far, across epochs.
global_step = 0
# One metrics dict per epoch, later written to the "epoch_metrics" sheet.
epoch_rows = []

for epoch in range(epochs):
    boundary_model.train()

    # A fresh loader per epoch: negatives may be re-sampled every epoch.
    train_loader, train_dataset, raw_pos_weight_epoch = make_train_loader_for_epoch(epoch, batch_size=int(_g("BATCH_SIZE", 32)), shuffle=True)

    # Cut-class loss weight for this epoch:
    #   "fixed" -> the nominal pos_weight every epoch;
    #   "epoch" -> this epoch's neg/pos ratio, capped at POS_WEIGHT_EPOCH_MAX.
    mode = str(_g("POS_WEIGHT_MODE", "fixed")).lower()
    if mode == "fixed":
        used_pos_weight_epoch = float(pos_weight)
    elif mode == "epoch":
        used_pos_weight_epoch = float(min(float(raw_pos_weight_epoch), float(_g("POS_WEIGHT_EPOCH_MAX", pos_weight))))
    else:
        used_pos_weight_epoch = float(pos_weight)  # fallback for unrecognised modes

    # Rebuild the weighted loss with this epoch's weights (index 0 = non-cut, 1 = cut).
    class_weights = torch.tensor([1.0, used_pos_weight_epoch], dtype=torch.float32).to(device)
    criterion_b = nn.CrossEntropyLoss(weight=class_weights)

    losses = []
    # Labels and P(cut) gathered while training, so the epoch metrics below reflect
    # predictions made while the weights were still being updated.
    y_true_ep, prob_pos_ep = [], []
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer_b.zero_grad()
        outputs = boundary_model(imgs)
        loss = criterion_b(outputs, labels)
        loss.backward()
        optimizer_b.step()

        losses.append(float(loss.item()))
        probs = torch.softmax(outputs.detach(), dim=1)[:, 1].detach().cpu().numpy().tolist()
        y_true_ep.extend(labels.detach().cpu().numpy().tolist())
        prob_pos_ep.extend(probs)

        global_step += 1

    m = binary_metrics_from_probs(y_true_ep, prob_pos_ep, threshold=threshold_default)
    ap = average_precision_score(y_true_ep, prob_pos_ep)
    auc = roc_auc_score_rank(y_true_ep, prob_pos_ep)

    lr_now = optimizer_b.param_groups[0].get("lr", None)

    print(
        f"[Boundary] Epoch {epoch+1}/{epochs} | "
        f"loss {np.mean(losses) if len(losses) else 0.0:.4f} | "
        f"F1 {m['f1']:.4f} (P {m['precision']:.4f}, R {m['recall']:.4f}) | "
        f"AP {ap:.4f} | AUC {auc:.4f} | "
        f"TP {m['tp']} FP {m['fp']} TN {m['tn']} FN {m['fn']} | "
        f"pos_pred_rate {m['pos_pred_rate']:.4f} | "
        f"pos_weight {used_pos_weight_epoch:.2f} (raw_ep {raw_pos_weight_epoch:.2f})"
    )

    # Keys must match the column names used by the Excel writer.
    epoch_rows.append({
        "epoch": int(epoch+1),
        "global_step": int(global_step),
        "loss": float(np.mean(losses) if len(losses) else 0.0),
        "lr": float(lr_now) if lr_now is not None else None,
        "threshold": float(threshold_default),
        "precision": float(m["precision"]),
        "recall": float(m["recall"]),
        "f1": float(m["f1"]),
        "acc": float(m["acc"]),
        "tp": int(m["tp"]), "fp": int(m["fp"]), "tn": int(m["tn"]), "fn": int(m["fn"]),
        "pos_pred_rate": float(m["pos_pred_rate"]),
        "avg_prob_pos": float(m["avg_prob_pos"]),
        "avg_prob_neg": float(m["avg_prob_neg"]),
        "pr_auc_ap": float(ap),
        "roc_auc": float(auc),
        "best_threshold_train": None,
        "best_f1_train": None,
        "raw_pos_weight": float(raw_pos_weight) if raw_pos_weight is not None else None,
        "pos_weight": float(pos_weight),  # legacy field: the fixed nominal pos_weight
        "weight_noncut": float(class_weights[0].item()) if class_weights is not None else None,
        "weight_cut": float(class_weights[1].item()) if class_weights is not None else None,
        "train_pairs_epoch": int(len(train_loader.dataset)) if hasattr(train_loader, "dataset") else None,
        "raw_pos_weight_epoch": float(raw_pos_weight_epoch) if raw_pos_weight_epoch is not None else None,
        "pos_weight_mode": str(_g("POS_WEIGHT_MODE", "none")),
        "pos_weight_used_epoch": float(used_pos_weight_epoch) if used_pos_weight_epoch is not None else None,
        "neg_sampling_enabled": bool(_g("USE_DYNAMIC_NEG_SAMPLING", False)),
        "neg_sampling_mode": str(_g("NEG_SAMPLING_MODE", "none")),
        "neg_sample_ratio": float(_g("NEG_SAMPLE_RATIO", 0.0)) if _g("NEG_SAMPLE_RATIO", None) is not None else None,
        "neg_per_pos": int(_g("NEG_PER_POS", 0)) if _g("NEG_PER_POS", None) is not None else None,
    })

print("Finished training boundary detection model.")
t_train = time.time() - t0

# ===== Final evaluation on the held-out split (test_loader from the split cell) =====
# NOTE(review): the cut_test cell rebinds `test_loader`; run this cell before it.
boundary_model.eval()
y_true_b, prob_pos_b, y_pred_b = [], [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        outputs = boundary_model(imgs)
        probs = torch.softmax(outputs, dim=1)[:, 1]  # P(cut)
        pred = (probs >= threshold_default).long()

        y_true_b.extend(labels.cpu().numpy().tolist())
        prob_pos_b.extend(probs.cpu().numpy().tolist())
        y_pred_b.extend(pred.cpu().numpy().tolist())

m_test = binary_metrics_from_probs(y_true_b, prob_pos_b, threshold=threshold_default)
ap_test = average_precision_score(y_true_b, prob_pos_b)
auc_test = roc_auc_score_rank(y_true_b, prob_pos_b)

# Report the threshold actually used rather than a hard-coded literal.
print(f"\n[Boundary] FINAL TEST (threshold={threshold_default}) => "
      f"F1 {m_test['f1']:.4f} (P {m_test['precision']:.4f}, R {m_test['recall']:.4f}) | "
      f"AP {ap_test:.4f} | AUC {auc_test:.4f} | "
      f"TP {m_test['tp']} FP {m_test['fp']} TN {m_test['tn']} FN {m_test['fn']}")

print("\nShot Boundary Detection - Classification Report (TEST):")
print(simple_classification_report(y_true_b, y_pred_b, target_names=["Non-cut", "Cut"]))

# ===== Write the Excel report (sheet/column layout kept stable) =====
base_dir = str(_g("REPORTS_BASE_DIR", "movie/reports"))
os.makedirs(base_dir, exist_ok=True)

# One timestamp for both the folder and the file name, so they always agree
# (two separate now() calls could straddle a minute boundary).
_report_now = datetime.now(_TZ) if _TZ else datetime.now()
ts_folder = _report_now.strftime("%m%d%H%M")
REPORT_FOLDER_NAME = f"{ts_folder}_e{epochs}_w{int(pos_weight)}_{_none_str(_g('DATA_VERSION', 'NA'))}"
REPORT_DIR = os.path.join(base_dir, REPORT_FOLDER_NAME)
os.makedirs(REPORT_DIR, exist_ok=True)

ts = _report_now.strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(REPORT_DIR, f"boundary_train_metrics_{ts}.xlsx")

# In "epoch" mode the cut-class weight changes every epoch, so the nominal
# pos_weight is not what was used; point readers at the per-epoch column.
if str(_g("POS_WEIGHT_MODE", "fixed")).lower() == "epoch":
    _pos_weight_used_str = "per-epoch (see epoch_metrics.pos_weight_used_epoch)"
else:
    _pos_weight_used_str = _none_str(pos_weight)

wb = openpyxl.Workbook()

# Sheet 1: run_info
ws0 = wb.active
ws0.title = "run_info"
ws0.append(["key", "value"])
for k, v in run_info.items():
    ws0.append([k, str(v)])

ws0.append([""])
ws0.append(["epochs", _none_str(epochs)])
ws0.append(["train_seconds", f"{t_train:.3f}"])
ws0.append(["threshold_default", _none_str(threshold_default)])
ws0.append(["optimizer", "Adam"])
ws0.append(["lr_init", _none_str(_g("LR_INIT", None))])
ws0.append(["loss", "CrossEntropyLoss(weighted)"])
ws0.append(["class_weight_noncut", _none_str(1.0)])
ws0.append(["class_weight_cut", _pos_weight_used_str])
ws0.append(["raw_pos_weight", _none_str(raw_pos_weight)])
ws0.append(["pos_weight_used", _pos_weight_used_str])
ws0.append([""])
ws0.append(["pos_weight_mode", _none_str(_g("POS_WEIGHT_MODE", None))])
ws0.append(["pos_weight_epoch_max", _none_str(_g("POS_WEIGHT_EPOCH_MAX", None))])
ws0.append(["neg_sampling_enabled", _none_str(_g("USE_DYNAMIC_NEG_SAMPLING", None))])
ws0.append(["neg_sampling_mode", _none_str(_g("NEG_SAMPLING_MODE", None))])
ws0.append(["neg_sample_ratio", _none_str(_g("NEG_SAMPLE_RATIO", None))])
ws0.append(["neg_per_pos", _none_str(_g("NEG_PER_POS", None))])
ws0.append([""])
ws0.append(["report_folder_name", REPORT_FOLDER_NAME])
ws0.append(["report_dir", REPORT_DIR])


# Sheet 2: one row per epoch; columns are the keys of each epoch_rows dict.
ws1 = wb.create_sheet("epoch_metrics")
cols = [
    "epoch","global_step","loss","lr",
    "threshold","precision","recall","f1","acc",
    "tp","fp","tn","fn",
    "pos_pred_rate","avg_prob_pos","avg_prob_neg",
    "pr_auc_ap","roc_auc",
    "best_threshold_train","best_f1_train",
    "raw_pos_weight","pos_weight","weight_noncut","weight_cut"
]
extra_cols = [
    "train_pairs_epoch",
    "raw_pos_weight_epoch",
    "pos_weight_mode",
    "pos_weight_used_epoch",
    "neg_sampling_enabled",
    "neg_sampling_mode",
    "neg_sample_ratio",
    "neg_per_pos"
]
all_cols = cols + extra_cols

ws1.append(all_cols)
for r in epoch_rows:
    ws1.append([r.get(c, None) for c in all_cols])

# Sheet 3: final held-out test metrics plus the text classification report.
ws2 = wb.create_sheet("final_test")
ws2.append(["metric", "value"])
ws2.append(["threshold", m_test["threshold"]])
ws2.append(["precision", m_test["precision"]])
ws2.append(["recall", m_test["recall"]])
ws2.append(["f1", m_test["f1"]])
ws2.append(["acc", m_test["acc"]])
ws2.append(["tp", m_test["tp"]])
ws2.append(["fp", m_test["fp"]])
ws2.append(["tn", m_test["tn"]])
ws2.append(["fn", m_test["fn"]])
ws2.append(["pos_pred_rate", m_test["pos_pred_rate"]])
ws2.append(["avg_prob_pos", m_test["avg_prob_pos"]])
ws2.append(["avg_prob_neg", m_test["avg_prob_neg"]])
ws2.append(["pr_auc_ap", ap_test])
ws2.append(["roc_auc", auc_test])

ws2.append([""])
ws2.append(["classification_report", ""])
report_str = simple_classification_report(y_true_b, y_pred_b, target_names=["Non-cut", "Cut"])
for line in report_str.splitlines():
    ws2.append([line, ""])

wb.save(out_path)
print(f"\nSaved metrics Excel to: {out_path}")
print(f"[Report] Folder: {REPORT_DIR}")


In [None]:
import os
import glob
import cv2
import numpy as np
import openpyxl
import torch
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import platform

def _g(name, default=None):
    return globals().get(name, default)

# Guard: this evaluation cell needs a model from the training cell (or loaded by hand).
if "boundary_model" not in globals() or boundary_model is None:
    raise RuntimeError("boundary_model 不存在：请先运行 Cell3 完成训练/加载模型，再运行 Cell4。")

# Carried over from the training cell when available, else taken from the config.
pos_weight = float(_g("pos_weight", _g("POS_WEIGHT", 1.0)))
epochs = _g("epochs", _g("EPOCHS", None))


# Beijing time for timestamps. The IANA key is "Asia/Shanghai"; "Asia/BeiJing"
# is not a valid zone name and would always fall back to None.
try:
    from zoneinfo import ZoneInfo  # py>=3.9
    _TZ = ZoneInfo("Asia/Shanghai")
except Exception:
    _TZ = None

# Test-set locations. DATA_VERSION is looked up lazily so a missing config cell
# does not raise NameError when TEST_VIDEO_DIR / TEST_EXCEL_FILE are already set.
_data_version = _g("DATA_VERSION", "NA")
test_video_dir  = _g("TEST_VIDEO_DIR", f"movie/{_data_version}/dataset/cut_test")

test_excel_file = _g("TEST_EXCEL_FILE", f"movie/{_data_version}/dataset/cut_test.xlsx")


threshold_default = 0.95   # decision threshold on P(cut)
batch_size = 1024          # inference batch size
topk_suspects = 10         # presumably the number of top-scoring suspects listed later — TODO confirm

if bool(_g("USE_CUDA", True)) and torch.cuda.is_available():
    _cuda_idx = int(_g("CUDA_DEVICE_INDEX", 0))
    device = torch.device(f"cuda:{_cuda_idx}")
else:
    device = torch.device("cpu")
# Keep the model on the device the inputs are sent to (e.g. a model loaded on
# CPU while CUDA is available); a no-op when it is already there.
boundary_model.to(device)
boundary_model.eval()

wb = openpyxl.load_workbook(test_excel_file, data_only=True)
ws = wb.active
rows = list(ws.iter_rows(values_only=True))
if len(rows) == 0:
    raise RuntimeError("cut_test.xlsx 里没有任何行！")

# Row 0 carries the frame rate; the remaining rows are per-video annotations.
fps_row = rows[0]

def get_fps_from_row(r, default=24.0):
    """Return the first usable frame rate found in an Excel row.

    Cells are scanned left to right. ``None``, blank strings, the literal
    ``"none"`` and strings that do not parse as a number are skipped.
    Booleans (an Excel TRUE would otherwise read as 1.0 fps) and values that
    cannot be a frame rate (NaN, infinity, zero or negative) are skipped too.

    Args:
        r: iterable of cell values (e.g. a row from ``iter_rows(values_only=True)``).
        default: value returned when no usable cell is found.

    Returns:
        The frame rate as a float.
    """
    for cell in r:
        if cell is None or isinstance(cell, bool):
            continue
        if isinstance(cell, (int, float)):
            value = float(cell)
        elif isinstance(cell, str):
            s = cell.strip()
            if s == "" or s.lower() == "none":
                continue
            try:
                value = float(s)
            except ValueError:
                continue
        else:
            # Dates, times and other cell types are not frame rates.
            continue
        # NaN fails both comparisons, so this also rejects NaN.
        if 0.0 < value < float("inf"):
            return value
    return float(default)

import re

# Frame rate used to convert "ss:ff" / "hh:mm:ss" timecodes into frame indices.
fps = get_fps_from_row(fps_row, default=24.0)
print(f"[cut_test] Using FPS from Excel first row: {fps}")

# Remaining rows: one row of cut annotations per video, paired with video_files by position.
data_rows = rows[1:]
# Natural sort ("V2" before "V10") so row k pairs with the k-th numbered clip even when
# names are not zero-padded; for consistently zero-padded names the order is unchanged.
video_files = sorted(
    glob.glob(f"{test_video_dir}/V*.mp4"),
    key=lambda p: [int(tok) if tok.isdecimal() else tok for tok in re.split(r"(\d+)", p)],
)
assert len(video_files) == len(data_rows), \
    f"Mismatch: videos({len(video_files)}) vs excel rows({len(data_rows)})"

def timecode_to_frame(tc, fps):
    """Convert an Excel annotation cell into a frame index.

    Accepted forms:
      * a number (int/float) -> truncated to ``int``;
      * a digit-only string  -> that integer;
      * ``"ss:ff"``          -> ``int(ss * fps + ff)``, e.g. "01:12" -> 1 s 12 frames;
      * ``"hh:mm:ss"``       -> ``int((hh*3600 + mm*60 + ss) * fps)``.

    Returns ``None`` for ``None``, blank or ``"none"`` strings, unparsable
    timecodes, other colon-separated lengths and any other cell type.
    """
    if tc is None:
        return None
    if isinstance(tc, (int, float)):
        return int(tc)
    if not isinstance(tc, str):
        return None

    text = tc.strip()
    if not text or text.lower() == "none":
        return None
    if text.isdigit():
        return int(text)
    if ":" not in text:
        return None

    fields = text.split(":")
    try:
        if len(fields) == 2:
            seconds, frames = (int(f) for f in fields)
            return int(seconds * fps + frames)
        if len(fields) == 3:
            hours, minutes, seconds = (int(f) for f in fields)
            return int((hours * 3600 + minutes * 60 + seconds) * fps)
    except Exception:
        return None
    return None

# Decode every cut_test video and build (vid_idx, i, label) pairs for frames (i, i + 1).
# NOTE(review): this rebinds the training cell's globals (video_frames, num_pairs,
# num_cuts, num_noncuts); re-running the training cells after this one would mix
# test data into training. Run cells top to bottom.
video_frames = []
boundary_pairs_test = []

video_meta = []  # one dict per video: vid_idx, name, path, total_frames, gt_cut_indices

# FRAME_SIZE is documented as (H, W) but cv2.resize expects dsize=(W, H);
# swap explicitly so non-square sizes work (identical for 224x224).
resize_dsize = (int(frame_size[1]), int(frame_size[0]))

for vid_idx, (video_path, row) in enumerate(zip(video_files, data_rows)):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"[cut_test] Warning: failed to open video {video_path}.")
    frames = []
    success, frame = cap.read()
    while success:
        frame_resized = cv2.resize(frame, resize_dsize)
        frames.append(frame_resized)
        success, frame = cap.read()
    cap.release()

    video_frames.append(frames)
    total_frames = len(frames)
    if total_frames <= 1:
        print(f"[cut_test] Warning: {video_path} has {total_frames} frame(s), skip.")
        video_meta.append({
            "vid_idx": vid_idx,
            "name": os.path.basename(video_path),
            "path": video_path,
            "total_frames": total_frames,
            "gt_cut_indices": [],
        })
        continue

    raw_values = list(row) if row is not None else []

    cut_indices = []
    for v in raw_values:
        frame_idx = timecode_to_frame(v, fps)
        if frame_idx is None:
            continue

        # The sheet marks the first frame B of the new shot; the cut lies between
        # frames (B-1, B), i.e. pair index B-1. Marks at frame <= 0 are ignored.
        if frame_idx > 0:
            frame_idx = frame_idx - 1
        else:
            continue

        frame_idx = int(frame_idx)
        # Marks past the end of the clip are still clamped onto the last pair, but loudly.
        if frame_idx > total_frames - 2:
            print(f"[cut_test] Warning: {os.path.basename(video_path)}: annotated cut at frame {frame_idx + 1} is beyond the clip ({total_frames} frames); clamped to pair {total_frames - 2}.")
        frame_idx = max(0, min(frame_idx, total_frames - 2))
        cut_indices.append(frame_idx)

    cut_indices = sorted(set(cut_indices))
    print(f"[cut_test] Video {vid_idx} ({os.path.basename(video_path)}): total_frames={total_frames}, cuts-1@frames={cut_indices}")

    video_meta.append({
        "vid_idx": vid_idx,
        "name": os.path.basename(video_path),
        "path": video_path,
        "total_frames": total_frames,
        "gt_cut_indices": cut_indices,
    })

    cut_set = set(cut_indices)
    for i in range(total_frames - 1):
        label = 1 if i in cut_set else 0
        boundary_pairs_test.append((vid_idx, i, label))

num_pairs = len(boundary_pairs_test)
num_cuts = sum(1 for _, _, lbl in boundary_pairs_test if lbl == 1)
num_noncuts = num_pairs - num_cuts
print(f"\n[cut_test] Generated {num_pairs} pairs: {num_cuts} cuts, {num_noncuts} non-cuts")

class ShotBoundaryDataset(Dataset):
    """Consecutive-frame pairs turned into 9-channel inputs for cut classification.

    Item ``idx`` is built from ``pairs[idx] == (vid_idx, frame_idx, label)``:
    frame ``frame_idx`` (A), frame ``frame_idx + 1`` (B) and ``absdiff(A, B)``
    are stacked along the last (channel) axis, scaled by 1/255 and returned
    channel-first as a float32 tensor, together with ``label`` as an int64 tensor.

    Args:
        pairs: sequence of ``(vid_idx, frame_idx, label)`` tuples.
        video_frames: per-video lists of HxWxC frames (as decoded by OpenCV).
    """

    def __init__(self, pairs, video_frames):
        self.pairs = pairs
        self.video_frames = video_frames

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        video_idx, left_idx, cut_label = self.pairs[idx]
        clip = self.video_frames[video_idx]
        current, following = clip[left_idx], clip[left_idx + 1]

        stacked = np.concatenate(
            (current, following, cv2.absdiff(current, following)), axis=2
        )
        # HWC -> CHW, normalized to [0, 1].
        chw = np.transpose(stacked.astype("float32") / 255.0, (2, 0, 1))

        return (
            torch.tensor(chw, dtype=torch.float32),
            torch.tensor(cut_label, dtype=torch.long),
        )

# shuffle=False is required: predictions are matched back to boundary_pairs_test
# purely by position (pair_ptr below).
test_dataset = ShotBoundaryDataset(boundary_pairs_test, video_frames)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

y_true, prob_pos, y_pred = [], [], []

# (vid_idx, frame_idx, gt_label, prob_cut, pred_label) per pair, in dataset order
pair_details = []

# Put the model in inference mode. If it contains dropout / batch-norm layers and was
# left in train mode (e.g. by the training cell), test scores would otherwise be
# computed with training-time behaviour.
boundary_model.eval()

pair_ptr = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        bsz = labels.size(0)
        imgs = imgs.to(device)
        outputs = boundary_model(imgs)                 # logits [B,2]
        probs = torch.softmax(outputs, dim=1)[:, 1]    # P(cut)
        pred = (probs >= threshold_default).long()

        probs_np = probs.cpu().numpy()
        pred_np  = pred.cpu().numpy()
        labels_np = labels.cpu().numpy()

        y_true.extend(labels_np.tolist())
        prob_pos.extend(probs_np.tolist())
        y_pred.extend(pred_np.tolist())

        for j in range(bsz):
            vid_idx, frame_idx, gt = boundary_pairs_test[pair_ptr + j]
            pair_details.append((vid_idx, frame_idx, int(gt), float(probs_np[j]), int(pred_np[j])))
        pair_ptr += bsz

# Every pair must have been scored exactly once for the positional mapping to hold.
assert pair_ptr == len(boundary_pairs_test), (pair_ptr, len(boundary_pairs_test))

# Pair-level test metrics at the default decision threshold, plus AP / ROC-AUC
# computed directly from the P(cut) scores (no threshold involved).
m = binary_metrics_from_probs(y_true, prob_pos, threshold=threshold_default)
ap = average_precision_score(y_true, prob_pos)
auc = roc_auc_score_rank(y_true, prob_pos)

# Report the threshold actually used rather than a hard-coded value.
print(f"\n[cut_test] FINAL TEST (threshold={threshold_default}) => "
      f"F1 {m['f1']:.4f} (P {m['precision']:.4f}, R {m['recall']:.4f}) | "
      f"AP {ap:.4f} | AUC {auc:.4f} | "
      f"TP {m['tp']} FP {m['fp']} TN {m['tn']} FN {m['fn']}")

print("\nShot Boundary Detection - Classification Report (cut_test):")
report_str = simple_classification_report(y_true, y_pred, target_names=["Non-cut", "Cut"])
print(report_str)

# Regroup pair-level results by video: vid_idx -> [(frame_idx, gt, prob, pred), ...]
per_video_pairs = {meta["vid_idx"]: [] for meta in video_meta}
for detail in pair_details:
    per_video_pairs[detail[0]].append(detail[1:])


def _cut_str(indices):
    """Comma-join already-sorted frame indices, or "none" when there are none."""
    return ",".join(map(str, indices)) if indices else "none"


per_video_rows = []           # one summary dict per video
per_video_suspects_rows = []  # [name, vid_idx, "FP"/"FN", rank, frame_idx, prob]

for meta in video_meta:
    v_name = meta["name"]
    v_idx = meta["vid_idx"]
    gt_set = set(meta["gt_cut_indices"])
    v_pairs = per_video_pairs.get(v_idx, [])

    predicted = sorted({fr for fr, _gt, _p, flag in v_pairs if flag == 1})
    pred_set = set(predicted)

    # Videos without pairs naturally yield zero predictions, tp = fp = 0, fn = |GT|.
    per_video_rows.append({
        "vid": v_name, "vid_idx": v_idx, "total_frames": meta["total_frames"],
        "gt_cut_count": len(gt_set), "pred_cut_count": len(pred_set),
        "tp": len(gt_set & pred_set), "fp": len(pred_set - gt_set), "fn": len(gt_set - pred_set),
        "gt_cuts": _cut_str(sorted(gt_set)),
        "pred_cuts": _cut_str(predicted),
    })

    # Most suspicious errors first: highest-prob false positives, lowest-prob misses.
    false_pos = sorted(
        ((fr, p) for fr, g, p, flag in v_pairs if g == 0 and flag == 1),
        key=lambda t: -t[1],
    )
    false_neg = sorted(
        ((fr, p) for fr, g, p, flag in v_pairs if g == 1 and flag == 0),
        key=lambda t: t[1],
    )
    for err_kind, ranked in (("FP", false_pos), ("FN", false_neg)):
        for rank, (fr, p) in enumerate(ranked[:topk_suspects], start=1):
            per_video_suspects_rows.append([v_name, v_idx, err_kind, rank, fr, float(p)])

# Resolve the per-run report folder. An earlier cell may already have defined
# REPORT_DIR / REPORT_FOLDER_NAME; reuse them so all artifacts of a run share one folder.
base_dir = globals().get('REPORTS_BASE_DIR', 'movie/reports')
os.makedirs(base_dir, exist_ok=True)

if "REPORT_DIR" not in globals() or "REPORT_FOLDER_NAME" not in globals():
    ts_folder = datetime.now(_TZ).strftime("%m%d%H%M") if _TZ else datetime.now().strftime("%m%d%H%M")
    # epochs / pos_weight may not exist in this namespace (e.g. a test-only run):
    # reuse them when present, otherwise use an "NA" placeholder.
    _e = globals().get("epochs", "NA")
    _w = globals().get("pos_weight", "NA")
    REPORT_FOLDER_NAME = f"{ts_folder}_e{_e}_w{_w}"
    REPORT_DIR = os.path.join(base_dir, REPORT_FOLDER_NAME)

# Always ensure the folder exists: an inherited REPORT_DIR may point at a folder that
# was since removed, and openpyxl's save() does not create missing directories.
os.makedirs(REPORT_DIR, exist_ok=True)

ts = datetime.now(_TZ).strftime("%Y%m%d_%H%M%S") if _TZ else datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(REPORT_DIR, f"cut_test_report_{ts}.xlsx")


# ---- Assemble the cut_test Excel report: one sheet per view of the results ----
wb_out = openpyxl.Workbook()

# Sheet "run_info": environment and run configuration as key/value pairs.
if _TZ:
    _report_time = datetime.now(_TZ).strftime("%Y-%m-%d %H:%M:%S %Z")
else:
    _report_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

ws0 = wb_out.active
ws0.title = "run_info"
for info_row in (
    ("key", "value"),
    ("time", _report_time),
    ("python_version", platform.python_version()),
    ("platform", platform.platform()),
    ("processor", platform.processor()),
    ("torch_version", torch.__version__),
    ("cuda_available", str(torch.cuda.is_available())),
    ("device_used", str(device)),
    ("test_video_dir", test_video_dir),
    ("test_excel_file", test_excel_file),
    ("fps_from_excel", str(fps)),
    ("threshold_default", str(threshold_default)),
    ("batch_size", str(batch_size)),
    ("topk_suspects", str(topk_suspects)),
    ("report_folder_name", REPORT_FOLDER_NAME),
    ("report_dir", REPORT_DIR),
):
    ws0.append(list(info_row))

# Sheet "dataset_summary": pair counts, class balance and per-video frame statistics.
ws_sum = wb_out.create_sheet("dataset_summary")
summary_rows = [
    ["item", "value"],
    ["num_videos", len(video_files)],
    ["num_pairs", num_pairs],
    ["num_cuts", num_cuts],
    ["num_non_cuts", num_noncuts],
    ["pos_ratio", (num_cuts / num_pairs) if num_pairs else 0.0],
    [""],
    ["per_video_frame_stats", ""],
]
frames_list = [meta["total_frames"] for meta in video_meta if meta["total_frames"] is not None]
if frames_list:
    summary_rows += [
        ["min_frames", int(np.min(frames_list))],
        ["max_frames", int(np.max(frames_list))],
        ["mean_frames", float(np.mean(frames_list))],
        ["median_frames", float(np.median(frames_list))],
    ]
for summary_row in summary_rows:
    ws_sum.append(summary_row)

# Sheet "per_video": header names double as the keys of each per_video_rows dict.
per_video_columns = [
    "vid", "vid_idx", "total_frames",
    "gt_cut_count", "pred_cut_count",
    "tp", "fp", "fn",
    "gt_cuts", "pred_cuts",
]
ws_v = wb_out.create_sheet("per_video")
ws_v.append(list(per_video_columns))
for r in per_video_rows:
    ws_v.append([r[col] for col in per_video_columns])

# Sheet "suspects_topk": most confident FP / least confident FN pairs per video.
ws_sus = wb_out.create_sheet("suspects_topk")
ws_sus.append(["vid", "vid_idx", "type", "rank", "frame_idx", "prob_cut"])
for row in per_video_suspects_rows:
    ws_sus.append(row)

# Sheet "final_test": thresholded metrics from m, then the ranking metrics.
ws2 = wb_out.create_sheet("final_test")
ws2.append(["metric", "value"])
for metric_key in (
    "threshold", "precision", "recall", "f1", "acc",
    "tp", "fp", "tn", "fn",
    "pos_pred_rate", "avg_prob_pos", "avg_prob_neg",
):
    ws2.append([metric_key, m[metric_key]])
ws2.append(["pr_auc_ap", ap])
ws2.append(["roc_auc", auc])

# Sheet "classification_report": the text report, one line per row.
ws_rep = wb_out.create_sheet("classification_report")
ws_rep.append(["text"])
for line in report_str.splitlines():
    ws_rep.append([line])

wb_out.save(out_path)
print(f"\nSaved cut_test report Excel to: {out_path}")
print(f"[Report] Folder: {REPORT_DIR}")


In [None]:
# Persist the evaluated model next to the Excel report: a bare state_dict (.pt)
# and a checkpoint (.pth) that also carries run metadata.
# Best-effort: any failure is printed and the notebook keeps going.
try:
    os.makedirs(REPORT_DIR, exist_ok=True)
    ts_model = (datetime.now(_TZ) if _TZ else datetime.now()).strftime("%Y%m%d_%H%M%S")

    model_path = os.path.join(REPORT_DIR, f"boundary_model_{ts_model}.pt")
    ckpt_path = os.path.join(REPORT_DIR, f"boundary_ckpt_{ts_model}.pth")

    weights = boundary_model.state_dict()
    torch.save(weights, model_path)

    saved_fmt = "%Y-%m-%d %H:%M:%S %Z" if _TZ else "%Y-%m-%d %H:%M:%S"
    ckpt = dict(
        model_state_dict=weights,
        device_saved=str(device),
        fps_from_excel=float(fps),
        threshold_default=float(threshold_default),
        batch_size=int(batch_size),
        frame_size=tuple(frame_size) if "frame_size" in globals() else None,
        report_dir=REPORT_DIR,
        report_folder_name=REPORT_FOLDER_NAME,
        # Optional training-run info; None when the variables are not defined.
        epochs=globals().get("epochs", None),
        pos_weight=globals().get("pos_weight", None),
        build_id=globals().get("BUILD_ID", None),
        saved_time=(datetime.now(_TZ) if _TZ else datetime.now()).strftime(saved_fmt),
    )
    torch.save(ckpt, ckpt_path)

    print(f"[Model] Saved state_dict to: {model_path}")
    print(f"[Model] Saved checkpoint to: {ckpt_path}")

except Exception as exc:
    print(f"[Model] Save failed: {exc}")