# In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import cv2
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

########################################################################
# 0. 設定裝置
########################################################################
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"
device = torch.device(dev)

########################################################################
# 1. UNet 分割模型定義與載入
########################################################################
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the second stage.
        mid_channels: number of channels between the two stages; when
            omitted (or falsy) it defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        # Build the two stages in order so the Sequential indices stay
        # 0..5 (conv, bn, relu, conv, bn, relu).
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(
                # padding=1 with a 3x3 kernel keeps the spatial size;
                # the conv has no bias because BatchNorm follows it.
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        out = self.double_conv(x)
        return out

class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Kept as a single Sequential so parameter names remain
        # "maxpool_conv.1.*".
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` by a factor of 2, then apply the double convolution."""
        out = self.maxpool_conv(x)
        return out

class Up(nn.Module):
    """Upsampling stage: upsample, align with the skip tensor, concat, DoubleConv.

    Args:
        in_channels: channel count passed to the DoubleConv (i.e. the channel
            count after concatenating the skip tensor with the upsampled one).
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, upsample with bilinear interpolation
            (``nn.Upsample``); otherwise use a stride-2 transposed
            convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # Interpolation does not change the channel count, so the
            # DoubleConv narrows to half the channels in its middle layer.
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, and fuse them.

        Args:
            x1: tensor to upsample.
            x2: skip-connection tensor; its dims 2 and 3 give the target size.

        Returns:
            DoubleConv output of ``x2`` and the padded, upsampled ``x1``
            concatenated along dim 1.
        """
        upsampled = self.up(x1)
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left = gap_w // 2
        top = gap_h // 2
        # F.pad takes (left, right, top, bottom) for the last two dims; any
        # odd remainder goes to the right/bottom side.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Output head: a single 1x1 convolution (with bias).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``; spatial size is unchanged."""
        out = self.conv(x)
        return out

class UNet(nn.Module):
    """Simple U-Net: four downsampling stages, four upsampling stages with
    skip connections, and a 1x1 output convolution.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the output map (raw scores; no activation is
            applied here).
        bilinear: forwarded to every ``Up`` stage; when True the channel
            widths of the deepest encoder stage and the first three decoder
            stages are halved.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        factor = 1 if not bilinear else 2

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Output head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the output map for input ``x``."""
        # Encoder path; every intermediate result is kept as a skip tensor.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        out = self.down4(skip4)

        # Decoder path: pair each Up stage with the matching encoder output,
        # deepest first.
        decoder = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
        )
        for stage, skip in decoder:
            out = stage(out, skip)
        return self.outc(out)

# Load the pre-trained U-Net weights (1 input channel, 1 output channel).
seg_model = UNet(n_channels=1, n_classes=1).to(device)
seg_model_path = "unet_best_model_LV_HorizontalFlip_112.pth"
if not os.path.exists(seg_model_path):
    # Fail early: nothing below can run without the weights.
    raise FileNotFoundError("找不到分割模型權重檔案！")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
seg_model.load_state_dict(torch.load(seg_model_path, map_location=device))
# Inference only: switch BatchNorm to its running statistics.
seg_model.eval()
print(f"成功載入分割模型權重：{seg_model_path}")

########################################################################
# 2. 後處理：取最大連通區 + 凸包
########################################################################
def apply_edge_prior(binary_mask):
    """Refine a mask to the filled convex hull of its largest connected region.

    Steps:
      1) keep only the largest 8-connected foreground component;
      2) take the convex hull of that component's outer contour;
      3) return the filled hull as the refined mask.

    Args:
        binary_mask: NumPy mask array (expected to be a single-channel 2-D
            image, as required by OpenCV). If its maximum is <= 1 it is
            scaled by 255; otherwise its values are cast to uint8 as-is.
            Non-zero pixels are treated as foreground.

    Returns:
        A uint8 array with the same shape as the input: 255 inside the hull
        and 0 elsewhere. If the mask has no foreground, an all-zero mask is
        returned.
    """
    # OpenCV's connected-component and contour routines need an 8-bit image.
    # Masks in {0, 1} are scaled to {0, 255}; anything else is assumed to be
    # on a 0-255 scale already and is only cast, so integer/float masks no
    # longer reach OpenCV with an unsupported dtype. astype() also copies,
    # so the caller's array is never modified.
    if binary_mask.max() <= 1:
        mask_uint8 = (binary_mask * 255).astype(np.uint8)
    else:
        mask_uint8 = binary_mask.astype(np.uint8)

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_uint8, connectivity=8)
    if num_labels > 1:
        # Label 0 is the background, so search from row 1 and shift the index back.
        largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        largest_component = (labels == largest_label).astype(np.uint8) * 255
    else:
        # Only background was found: the mask is empty.
        largest_component = mask_uint8.copy()

    contours, _ = cv2.findContours(largest_component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return largest_component
    largest_contour = max(contours, key=cv2.contourArea)
    hull = cv2.convexHull(largest_contour)

    refined_mask = np.zeros_like(largest_component)
    # thickness=-1 fills the hull polygon instead of drawing its outline.
    cv2.drawContours(refined_mask, [hull], -1, 255, thickness=-1)
    return refined_mask

def get_final_mask(binary_mask):
    """Return the post-processed version of ``binary_mask``.

    Thin wrapper: the whole post-processing step is delegated to
    ``apply_edge_prior``.
    """
    refined_mask = apply_edge_prior(binary_mask)
    return refined_mask

########################################################################
# 3. 產生遮罩
########################################################################
def precompute_video_masks(video_path, seg_model, transform, device,
                           threshold=0.5, mask_dir="precomputed_masks"):
    """Segment every frame of a video and save the refined masks as one .npy.

    For each frame: convert to grayscale, apply ``transform``, run
    ``seg_model``, apply a sigmoid, binarise at ``threshold`` and post-process
    with ``get_final_mask``. The per-frame masks are stacked along axis 0 and
    written to ``{mask_dir}/{video_basename}_masks.npy``.

    Args:
        video_path: path of the video file to read.
        seg_model: segmentation model called on the transformed frame; its
            output is treated as logits.
        transform: callable mapping a PIL grayscale image to a tensor
            (presumably (C, H, W); a batch dimension is added here).
        device: torch device the input tensor is moved to.
        threshold: probability cut-off used to binarise the soft mask.
        mask_dir: output directory; created if it does not exist.

    Returns:
        The path of the saved .npy file, or None when the video cannot be
        opened or yields no frames.
    """
    os.makedirs(mask_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        print("無法開啟影片:", video_path)
        return None

    video_basename = os.path.splitext(os.path.basename(video_path))[0]
    save_path = os.path.join(mask_dir, f"{video_basename}_masks.npy")

    masks = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # 1) Convert the BGR frame to grayscale.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            pil_img = Image.fromarray(gray)

            # 2) Preprocess and add the batch dimension.
            inp = transform(pil_img).unsqueeze(0).to(device)

            # 3) U-Net forward pass; no gradients are needed for inference.
            with torch.no_grad():
                seg_out = seg_model(inp)
                soft_mask = torch.sigmoid(seg_out).cpu().numpy().squeeze()

            # 4) Binarise and post-process.
            binary_mask = (soft_mask > threshold).astype(np.uint8)
            masks.append(get_final_mask(binary_mask))
    finally:
        # Release the capture handle even if inference raises mid-video.
        cap.release()

    if not masks:
        # np.stack([]) would raise; a video without readable frames is
        # reported and skipped instead of aborting the whole batch.
        print("影片沒有可讀取的影格:", video_path)
        return None

    stacked = np.stack(masks, axis=0)  # (num_frames, H, W)
    np.save(save_path, stacked)
    print(f"產生遮罩檔: {save_path}")
    return save_path

########################################################################
# 4. 主程式: 對 CSV 清單中的影片產生遮罩
########################################################################
if __name__ == "__main__":
    # CSV listing, folder holding the videos, and folder for the mask files.
    csv_file = "a4c-video-dir/FileList.csv"
    videos_dir = "a4c-video-dir/Videos"
    mask_dir = "precomputed_masks"

    # Preprocessing for the segmentation model (keep identical to training).
    seg_transform = transforms.Compose(
        [
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
        ]
    )

    # Read the CSV and run mask inference for every listed video.
    df_all = pd.read_csv(csv_file)
    for idx, row in df_all.iterrows():
        fn = row["FileName"]
        # Entries may come without the extension; normalise to "<name>.avi".
        fn = fn if fn.lower().endswith(".avi") else fn + ".avi"
        video_path = os.path.join(videos_dir, fn)

        base_name = os.path.splitext(os.path.basename(fn))[0]
        save_path = os.path.join(mask_dir, f"{base_name}_masks.npy")

        # Masks already on disk are not recomputed.
        if os.path.exists(save_path):
            continue

        print(f"[precompute] 處理 {fn} ...")
        precompute_video_masks(
            video_path,
            seg_model,
            seg_transform,
            device=device,
            threshold=0.5,
            mask_dir=mask_dir,
        )
