# Colab：RStab 训练/推理与多视频稳定性指标

说明：本 notebook 在 Colab GPU 上运行 RStab（CVPR 2024），对 5 段视频做推理，输出稳定后的视频以及与原视频对比的稳定性指标（指标集与 deep-stabilization/colab_inference 2.ipynb 保持一致）。

流程概览：
1. 克隆仓库并安装依赖（CUDA 11.6 + torch 1.12.0 环境）
2. 下载官方预训练权重到 `RStab_core/pretrained`
3. 准备至少 5 段原始视频到 `./input`（可用提供的样例或自行上传）
4. Deep3D 预处理（几何优化）
5. RStab rectify 推理，生成稳定视频到 `./output/RStab`
6. 计算稳定性指标（光流分布、抖动突变、相邻帧 MSE/SSIM 等），生成 per-video CSV 与 summary CSV
7. 打包下载结果



## 环境要求
- Colab 运行时选择 GPU（NVIDIA，CUDA 11.x）
- Python 3.10（torch 1.12.0 只提供到 Python 3.10 的预编译 wheel；若 Colab 默认 Python 版本更新，需先切换运行时版本或改用匹配的 torch 版本）
- torch 1.12.0 + cu116（按官方 README）



In [None]:
# Clone the repository (including submodules, shallow history) and enter it.
# Later cells use relative paths (input/, output/, RStab_core/, requirements.txt),
# so they assume the working directory set by the %cd below.
# NOTE: re-running this cell after the %cd clones a second, nested copy
# (RStab/RStab) and moves into it; restart from /content instead.
!git clone --depth 1 --recursive https://github.com/pzzz-cv/RStab.git
%cd RStab

# Sanity check: show the working directory and the repository contents.
!pwd
!ls


In [None]:
# Install dependencies (CUDA 11.6 build of torch 1.12.0, as pinned below).
!pip -q install --upgrade pip
!pip -q install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
# NOTE(review): requirements.txt is installed after the pinned torch; if it lists
# torch itself it may replace this version — check the printout below.
!pip -q install -r requirements.txt
# Extra packages used by the later cells (video I/O, metrics, checkpoint download).
!pip -q install opencv-python-headless ffmpeg-python pandas scikit-image gdown tqdm

# Verify the install: report the torch version and whether a CUDA GPU is visible.
# If torch was already imported in this kernel before the install, restart the
# runtime so the newly installed version is the one loaded.
# platform, subprocess and sys are not used in this cell.
import torch, platform, subprocess, os, sys
print('PyTorch:', torch.__version__, 'CUDA available:', torch.cuda.is_available())
print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None')


In [None]:
# Download the RStab pretrained weights into RStab_core/pretrained.
# The Google Drive file may or may not be a zip archive; both cases are handled.
import os
from pathlib import Path
import gdown, zipfile, shutil

pretrained_dir = Path('RStab_core/pretrained')
# The download is always stored under this name, whatever the remote file really is.
ckpt_file = Path('rstab_ckpt.zip')
pretrained_dir.mkdir(parents=True, exist_ok=True)

# Step 1: fetch the file unless a previous run already left it on disk.
if ckpt_file.exists():
    print('已存在 rstab_ckpt.zip，跳过下载')
else:
    gdown.download(id='1q3QM1damtvHLukhIOIAdv9IKm646Oj11', output=str(ckpt_file), quiet=False)

# Step 2: unpack an archive, or copy a raw checkpoint file, into the target directory.
if not ckpt_file.exists():
    print('未找到下载文件，请检查链接或网络。')
elif zipfile.is_zipfile(ckpt_file):
    with zipfile.ZipFile(ckpt_file, 'r') as archive:
        archive.extractall(pretrained_dir)
    print('已解压到 RStab_core/pretrained，文件列表：')
    for entry in pretrained_dir.rglob('*'):
        print(' -', entry)
else:
    # Not an archive: presumably a raw checkpoint (.pth/.ckpt), so copy it in as-is.
    target = pretrained_dir / ckpt_file.name
    shutil.copy(ckpt_file, target)
    print('下载文件非 zip，已直接放入', target)


In [None]:
# Prepare the input videos (at least 5; avi/mp4/m4v/mov are supported).
# Option A: upload videos.zip to /content via the Files panel on the left.
# Option B: run the upload dialog below and pick local files (zip archives and/or video files).
# Option C: download to /content/videos.zip with gdown/wget.
# Example videos.zip layout:
# videos/
#   vid1.m4v / vid1.mp4 / vid1.avi / vid1.mov
#   ...

import io, os, shutil, zipfile
from pathlib import Path

# Option B: running this opens a file-picker dialog.
from google.colab import files
uploaded = files.upload()  # maps each uploaded file name to its raw bytes
ZIP_NAME = next(iter(uploaded.keys())) if uploaded else 'videos.zip'
# Where Options A/C leave the archive. files.upload() itself writes into the
# *current* directory (/content/RStab after the %cd in the first cell), not
# /content, so Option B uploads are read from `uploaded` directly instead.
UPLOAD_ZIP = f'/content/{ZIP_NAME}'
INPUT_DIR = Path('input')
INPUT_DIR.mkdir(exist_ok=True)

EXTS = ['.avi', '.mp4', '.m4v', '.mov']

if uploaded:
    # Handle every uploaded file, not just the first one.
    for up_name, payload in uploaded.items():
        if zipfile.is_zipfile(io.BytesIO(payload)):
            with zipfile.ZipFile(io.BytesIO(payload), 'r') as zf:
                zf.extractall('input')
            print(f'已解压 {up_name} 到 ./input')
        elif Path(up_name).suffix.lower() in EXTS:
            # A bare video file: store it directly instead of trying to unzip it.
            (INPUT_DIR / Path(up_name).name).write_bytes(payload)
            print(f'已保存 {up_name} 到 ./input')
        else:
            print(f'跳过不支持的文件 {up_name}')
elif Path(UPLOAD_ZIP).exists():
    with zipfile.ZipFile(UPLOAD_ZIP, 'r') as zf:
        zf.extractall('input')
    print(f'已解压 {ZIP_NAME} 到 ./input')
else:
    print(f'未检测到 {UPLOAD_ZIP}，请先上传。')

# With only one sample, enable duplication to reach 5 videos
# (pipeline smoke test only — the resulting metrics are not representative).
DESIRED_VIDEO_COUNT = 5
ALLOW_DUPLICATE_SAMPLE = False


def _is_video(path):
    """Return True for real video files, skipping macOS zip artifacts (__MACOSX/, ._*)."""
    return (
        path.is_file()
        and path.suffix.lower() in EXTS
        and not path.name.startswith('._')
        and '__MACOSX' not in path.parts
    )


vid_list = sorted(p for p in INPUT_DIR.rglob('*') if _is_video(p))
print(f'当前检测到 {len(vid_list)} 个视频:', [p.name for p in vid_list])

if len(vid_list) < DESIRED_VIDEO_COUNT and ALLOW_DUPLICATE_SAMPLE and vid_list:
    src = vid_list[0]
    while len(vid_list) < DESIRED_VIDEO_COUNT:
        dup = INPUT_DIR / f'{src.stem}_dup{len(vid_list)}{src.suffix}'
        shutil.copy(src, dup)
        vid_list.append(dup)
        print('已复制示例到', dup)

print('最终视频列表:', [p.name for p in vid_list])



In [None]:
# Deep3D preprocessing (geometry and intermediate results).
# Runs geometry_optimizer.py once per input video, writing into ./output/Deep3D
# with --name set to the file name (extension included), which the rectify cell
# reuses as --expname.

import subprocess, shlex
import sys
from pathlib import Path

INPUT_DIR = Path('input')
# Same formats the upload cell accepts. Narrow this set if geometry_optimizer.py
# rejects a container format.
VIDEO_EXTS = {'.avi', '.mp4', '.m4v', '.mov'}
# Recursive search: videos.zip extracts into a subfolder (input/videos/...).
# macOS zip artifacts (__MACOSX/, ._*) carry video suffixes but are not videos.
# NOTE: files with the same name in different subfolders would share one --name.
videos = sorted(
    p for p in INPUT_DIR.rglob('*')
    if p.is_file()
    and p.suffix.lower() in VIDEO_EXTS
    and not p.name.startswith('._')
    and '__MACOSX' not in p.parts
)
print('待处理视频:', [p.name for p in videos])

out_dir = Path('output/Deep3D')
out_dir.mkdir(parents=True, exist_ok=True)
failed = []
for p in videos:
    # argv list instead of a shell string, so paths with spaces stay one argument.
    cmd = [sys.executable, 'geometry_optimizer.py',
           '--video_path', str(p), '--output_dir', str(out_dir), '--name', p.name]
    print('Running:', shlex.join(cmd))
    ret = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(ret.stdout)
    if ret.returncode != 0:
        # Keep going: one bad video should not block the remaining ones.
        print('❌ 失败:', p)
        failed.append(p)

if failed:
    print(f'⚠️ {len(failed)} 个视频预处理失败:', [p.name for p in failed])



In [None]:
# RStab rectify inference.
# Depends on the preprocessing results expected under ./output/Deep3D/<video_name>.

import subprocess, shlex
import sys
from pathlib import Path

OUTPUT_DEEP3D = Path('output/Deep3D')
# Must select the same videos as the Deep3D cell: recursive, all supported
# formats, skipping macOS zip artifacts (__MACOSX/, ._*).
VIDEO_EXTS = {'.avi', '.mp4', '.m4v', '.mov'}
videos = sorted(
    p for p in Path('input').rglob('*')
    if p.is_file()
    and p.suffix.lower() in VIDEO_EXTS
    and not p.name.startswith('._')
    and '__MACOSX' not in p.parts
)
print('待推理视频:', [p.name for p in videos])

failed = []
for p in videos:
    name = p.name
    # Warn only: the exact Deep3D output layout is decided by geometry_optimizer.py,
    # so rectify.py still gets the final say.
    if not (OUTPUT_DEEP3D / name).exists():
        print(f'⚠️ 未找到 {OUTPUT_DEEP3D / name}，Deep3D 预处理可能未成功')
    # argv list instead of a shell string, so names with spaces stay one argument.
    cmd = [sys.executable, 'RStab_core/rectify.py', '--expname', name]
    print('Running:', shlex.join(cmd))
    ret = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    print(ret.stdout)
    if ret.returncode != 0:
        # Keep going: one failed video should not block the remaining ones.
        print('❌ 失败:', p)
        failed.append(p)

if failed:
    print(f'⚠️ {len(failed)} 个视频推理失败:', [p.name for p in failed])



## 计算稳定性指标（与 deep-stabilization notebook 一致）
指标列表：
- 光流分布：`mean_flow`/`std_flow`/`flow_p95`/`flow_iqr`（对每对相邻帧计算 Farneback 光流幅值的中位数，再统计该序列的均值/标准差/95 分位/四分位距）
- 抖动突变：`flow_jitter_std`（上述逐帧中位光流幅值一阶差分的标准差）
- 相邻帧平滑度：`temporal_mse`（越小越平滑）、`temporal_ssim`（越大越平滑）
- 稳定性收益：`stability_gain = (orig_mean_flow - stab_mean_flow) / orig_mean_flow`
输出：
- per-video `stabilization_metrics.csv`
- 汇总 `stabilization_metrics_summary.csv`（mean/std/min/max）


In [None]:
import os, cv2, numpy as np, pandas as pd
from typing import List, Tuple
from pathlib import Path
from skimage.metrics import structural_similarity as ssim


def load_video(path: str, max_frames: "int | None" = None, resize: float = 1.0):
    """Decode a video file into a list of frames with OpenCV.

    Args:
        path: Video file path (``str`` or any ``os.PathLike``).
        max_frames: Stop after this many frames; ``None`` or 0 reads the whole video.
        resize: Scale factor applied to both axes (bilinear); 1.0 keeps the size.

    Returns:
        ``(frames, fps, (width, height))``. ``frames`` holds the frames as returned
        by ``cv2.VideoCapture.read`` (possibly resized), ``fps`` is the value OpenCV
        reports (it may be 0 when the container does not provide one), and the size
        is ``(0, 0)`` when no frame could be decoded.
    """
    # str() so pathlib.Path arguments work with every OpenCV version.
    cap = cv2.VideoCapture(str(path))
    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if resize != 1.0:
                frame = cv2.resize(frame, (0, 0), fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
            frames.append(frame)
            # A falsy max_frames (None or 0) means "no limit".
            if max_frames and len(frames) >= max_frames:
                break
    finally:
        # Release the decoder even if reading or resizing raises.
        cap.release()
    size = (frames[0].shape[1], frames[0].shape[0]) if frames else (0, 0)
    return frames, fps, size


def flow_stats(frames: List[np.ndarray]):
    """Summarize inter-frame motion with dense Farneback optical flow.

    For every adjacent frame pair, the flow between their grayscale versions is
    computed and the median flow magnitude of that pair is recorded.

    Args:
        frames: Ordered frames, either 3-channel BGR (as decoded by OpenCV) or
            already single-channel grayscale.

    Returns:
        Dict with ``mean``/``std``/``p95``/``iqr`` of the per-pair median
        magnitudes, plus ``jitter_std``: the std of their first difference, i.e.
        how abruptly the motion changes. All values are 0.0 with fewer than 2 frames.
    """
    med_per_frame = []
    prev_gray = None
    for frame in frames:
        # Convert each frame once and reuse it as the "previous" frame of the
        # next pair, so interior frames are not converted twice.
        gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        if prev_gray is not None:
            # Positional args: pyr_scale=0.5, levels=3, winsize=15,
            # iterations=3, poly_n=5, poly_sigma=1.2, flags=0.
            flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
            mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            med_per_frame.append(np.median(mag))
        prev_gray = gray
    if not med_per_frame:
        return {"mean": 0.0, "std": 0.0, "p95": 0.0, "iqr": 0.0, "jitter_std": 0.0}
    arr = np.array(med_per_frame)
    diffs = np.diff(arr)
    return {
        "mean": float(np.mean(arr)),
        "std": float(np.std(arr)),
        "p95": float(np.percentile(arr, 95)),
        "iqr": float(np.percentile(arr, 75) - np.percentile(arr, 25)),
        # A single pair yields no differences; report 0.0 rather than NaN.
        "jitter_std": float(np.std(diffs)) if len(diffs) > 0 else 0.0,
    }


def temporal_mse(frames: List[np.ndarray]) -> float:
    """Average per-pixel squared difference between consecutive frames.

    Each adjacent pair contributes the mean of its squared pixel differences
    (computed in float32, so uint8 inputs do not wrap around); the result is the
    mean over all pairs. Lower means smoother. Returns 0.0 with fewer than 2 frames.
    """
    per_pair = [
        np.mean((prev.astype(np.float32) - curr.astype(np.float32)) ** 2)
        for prev, curr in zip(frames, frames[1:])
    ]
    if not per_pair:
        return 0.0
    return float(np.mean(per_pair))


def temporal_ssim(frames: List[np.ndarray]) -> float:
    """Mean SSIM between consecutive grayscale frames (higher = smoother).

    Frames are expected as 3-channel BGR (converted with COLOR_BGR2GRAY);
    single-channel frames are used as-is. ``data_range=255`` assumes 8-bit
    intensities. Returns 0.0 with fewer than 2 frames.
    """
    scores = []
    prev_gray = None
    for frame in frames:
        # Convert each frame once and reuse it for the next pair.
        gray = frame if frame.ndim == 2 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        if prev_gray is not None:
            scores.append(ssim(prev_gray, gray, data_range=255))
        prev_gray = gray
    return float(np.mean(scores)) if scores else 0.0


def find_stab_video(seq_name: str, base_dir="output/RStab", exts=(".avi",)) -> "Path | None":
    """Locate the stabilized output video for input sequence ``seq_name``.

    ``seq_name`` may include an extension (the inference cell passes the input
    file name as ``--expname``), so both ``seq_name`` and its stem are tried.
    Search order, first hit wins:
      1. files under ``base_dir/<seq_name>`` or ``base_dir/<stem>`` (recursive);
      2. files anywhere under ``base_dir`` whose stem equals the sequence stem;
      3. files whose stem contains the sequence stem as a whole token, i.e. not
         glued to other letters/digits (``vid1`` matches ``vid1_stab`` but not
         ``vid10``); the shortest such stem wins.
    Candidates within a tier are sorted, so the result is deterministic.

    Args:
        seq_name: Input video name, with or without extension.
        base_dir: Root directory of the stabilized outputs.
        exts: Accepted output suffixes (compared case-insensitively).

    Returns:
        Path of the matching file, or None if nothing matches.
    """
    import re

    seq_stem = Path(seq_name).stem
    base = Path(base_dir)
    wanted = {e.lower() for e in exts}

    def _videos(root):
        # Only real files: a directory named "<x>.avi" must never be returned.
        if not root.is_dir():
            return []
        return sorted(p for p in root.rglob('*') if p.is_file() and p.suffix.lower() in wanted)

    # Tier 1: a per-sequence output directory (dict.fromkeys de-duplicates, keeps order).
    for dir_name in dict.fromkeys((seq_name, seq_stem)):
        hits = _videos(base / dir_name)
        if hits:
            return hits[0]

    everything = _videos(base)

    # Tier 2: exact stem match anywhere under base.
    exact = [p for p in everything if p.stem == seq_stem]
    if exact:
        return exact[0]

    # Tier 3: whole-token match; plain substring matching would map vid1 -> vid10.
    if seq_stem:
        token = re.compile(rf"(?<![A-Za-z0-9]){re.escape(seq_stem)}(?![A-Za-z0-9])")
        loose = [p for p in everything if token.search(p.stem)]
        if loose:
            return min(loose, key=lambda p: (len(p.stem), str(p)))
    return None


def compare_one(orig_path: str, stab_path: str, max_frames: int = 600, resize: float = 0.5):
    """Compute stability metrics for an original video and its stabilized version.

    Both videos are decoded (at most ``max_frames`` frames, scaled by ``resize``)
    and truncated to the shorter length, so both sides are measured over the same
    number of frames.

    Returns:
        One flat dict (a row of the per-video CSV): the two paths, the frame count
        used, both fps values, ``<metric>_orig`` / ``<metric>_stab`` pairs for the
        flow statistics, temporal MSE and temporal SSIM, and ``stability_gain``,
        the relative reduction of mean flow (0.0 if the original mean flow is 0).

    Raises:
        ValueError: if either video yields fewer than 2 frames.
    """
    orig_frames, fps_o, _ = load_video(orig_path, max_frames=max_frames, resize=resize)
    stab_frames, fps_s, _ = load_video(stab_path, max_frames=max_frames, resize=resize)
    n = min(len(orig_frames), len(stab_frames))
    if n < 2:
        raise ValueError(f"帧数不足，orig={len(orig_frames)}, stab={len(stab_frames)}")
    orig_frames = orig_frames[:n]
    stab_frames = stab_frames[:n]

    def _side_metrics(frames):
        # Metric name -> value for one video; the key order fixes the CSV column order.
        flow = flow_stats(frames)
        return {
            "mean_flow": flow["mean"],
            "std_flow": flow["std"],
            "flow_p95": flow["p95"],
            "flow_iqr": flow["iqr"],
            "flow_jitter_std": flow["jitter_std"],
            "temporal_mse": temporal_mse(frames),
            "temporal_ssim": temporal_ssim(frames),
        }

    orig_m = _side_metrics(orig_frames)
    stab_m = _side_metrics(stab_frames)

    record = {
        "orig_video": orig_path,
        "stab_video": stab_path,
        "num_frames_used": n,
        "fps_orig": fps_o,
        "fps_stab": fps_s,
    }
    for key in orig_m:
        record[f"{key}_orig"] = orig_m[key]
        record[f"{key}_stab"] = stab_m[key]

    base_mean = orig_m["mean_flow"]
    if base_mean == 0:
        record["stability_gain"] = 0.0
    else:
        record["stability_gain"] = (base_mean - stab_m["mean_flow"]) / base_mean
    return record


def run_all(video_root: str = "./input", desired_count: "int | None" = None, resize: float = 0.5, max_frames: int = 600):
    """Compute stability metrics for every input video that has a stabilized output.

    Args:
        video_root: Directory searched recursively for original videos
            (.avi/.mp4/.m4v/.mov, case-insensitive).
        desired_count: If given, warn when fewer sequences exist and keep at most
            this many (in sorted path order).
        resize: Frame scale factor forwarded to ``compare_one``.
        max_frames: Per-video frame cap forwarded to ``compare_one``.

    Returns:
        ``(per_video_df, summary_df)``, also written as CSV files to output/RStab;
        ``(None, None)`` when no sequence produced metrics.
    """
    video_exts = {'.avi', '.mp4', '.m4v', '.mov'}
    seq_list = [
        p for p in sorted(Path(video_root).rglob('*'))
        if p.is_file()
        and p.suffix.lower() in video_exts
        # Skip macOS zip artifacts (__MACOSX/ folder, AppleDouble "._*" files):
        # they carry video suffixes but are not decodable videos.
        and not p.name.startswith('._')
        and '__MACOSX' not in p.parts
    ]
    if desired_count is not None:
        if len(seq_list) < desired_count:
            print(f"⚠️ 当前只有 {len(seq_list)} 个序列，少于期望的 {desired_count} 个，请补齐后重跑以避免偶然性。")
        seq_list = seq_list[:desired_count]

    records = []
    for p in seq_list:
        stab = find_stab_video(p.name)
        if not stab:
            print(f"跳过 {p.name}，未找到稳定视频输出")
            continue
        # Per-video best effort: one unreadable video must not abort the batch.
        try:
            rec = compare_one(str(p), str(stab), max_frames=max_frames, resize=resize)
        except Exception as e:
            print(f"失败 {p.name}: {e}")
            continue
        rec["seq"] = p.name
        records.append(rec)
        print(f"done: {p.name}")

    if not records:
        print("未生成任何指标")
        return None, None

    df = pd.DataFrame(records)
    out_dir = Path('output/RStab')
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "stabilization_metrics.csv"
    df.to_csv(out_path, index=False)
    print("保存:", out_path)

    # Aggregate only the numeric metric columns; the path/name columns are text.
    metric_cols = [c for c in df.columns if c not in {"seq", "orig_video", "stab_video"}]
    summary_df = df[metric_cols].agg(["mean", "std", "min", "max"])
    summary_path = out_dir / "stabilization_metrics_summary.csv"
    summary_df.to_csv(summary_path)
    print("保存统计:", summary_path)
    return df, summary_df


# Compute metrics for every input video that has a stabilized output, then show
# the per-video table (run_all returns (None, None) when nothing was computed).
metrics_df, metrics_summary = run_all()
metrics_df


In [None]:
# Package the results (stabilized videos + metrics CSVs) and download them.
# The whole output/ tree is archived, including the Deep3D intermediates.
# NOTE: zip -r updates an existing rstab_results.zip rather than recreating it,
# so files deleted from output/ since an earlier run remain in the archive.
!zip -qr rstab_results.zip output
from google.colab import files
# `os` comes from an earlier cell's import.
if os.path.exists('rstab_results.zip'):
    files.download('rstab_results.zip')
else:
    print('未找到打包文件，请检查推理是否成功。')
