In [None]:
!pip install segmentation_models_pytorch

In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
from torchvision.models.segmentation import deeplabv3_resnet50
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image


# Part 2: dataset, model factory, and training loop

In [None]:
class ETTDataset(Dataset):
    """Paired image / mask dataset read from two directories.

    An image and a mask belong together when their filenames share the same
    stem (name without extension). Only regular files with a supported
    extension are indexed (images: jpg/png/jpeg, masks: png, matched
    case-insensitively), so sub-directories (e.g. notebook checkpoint
    folders) and unrelated files never become samples. The real on-disk
    filename is remembered, so every indexed sample can actually be opened.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable applied to the RGB PIL image and, unless
            ``mask_transform`` is given, to the single-channel mask as well.
        mask_transform: Optional callable applied to the mask instead of
            ``transform`` (for example to resize with nearest-neighbour
            interpolation so mask values are not blended).

    Attributes:
        filenames: Sorted list of stems present in both directories.
        image_files / mask_files: Mapping of stem -> filename on disk.
    """

    # Supported extensions, in order of preference when several files in the
    # same directory share one stem.
    _IMAGE_EXTS = (".jpg", ".png", ".jpeg")
    _MASK_EXTS = (".png",)

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform

        # Keep only stems that have both an image and a mask.
        self.image_files = self._index_by_stem(image_dir, self._IMAGE_EXTS)
        self.mask_files = self._index_by_stem(mask_dir, self._MASK_EXTS)
        self.filenames = sorted(self.image_files.keys() & self.mask_files.keys())

        if len(self.filenames) == 0:
            print("找不到圖像與遮罩對應的交集檔案。")

    @staticmethod
    def _index_by_stem(directory, extensions):
        """Return {stem: filename} for regular files in ``directory`` whose
        extension (case-insensitive) is listed in ``extensions``."""

        def rank(name):
            ext = os.path.splitext(name)[1]
            # Earlier extension wins; exact lower-case spelling breaks ties.
            return (extensions.index(ext.lower()), ext != ext.lower(), name)

        candidates = {}
        for entry in os.listdir(directory):
            stem, ext = os.path.splitext(entry)
            if ext.lower() not in extensions:
                continue
            if not os.path.isfile(os.path.join(directory, entry)):
                continue
            candidates.setdefault(stem, []).append(entry)
        return {stem: min(names, key=rank) for stem, names in candidates.items()}

    def _paths(self, idx):
        """Return ``(image_path, mask_path)`` for the sample at ``idx``."""
        stem = self.filenames[idx]
        return (
            os.path.join(self.image_dir, self.image_files[stem]),
            os.path.join(self.mask_dir, self.mask_files[stem]),
        )

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, mask)``: the image converted to RGB and the mask
            converted to 8-bit grayscale, each passed through its transform
            when one is configured.

        Raises:
            IndexError: If ``idx`` is out of range.
            FileNotFoundError: If a file was removed after indexing.
        """
        img_path, mask_path = self._paths(idx)

        # Context managers close the underlying files once decoding is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        # Fall back to the shared transform to stay backward compatible.
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf is not None:
            mask = mask_tf(mask)

        return image, mask

In [None]:
def get_model(name):
    """Build a segmentation network with a single output channel.

    Args:
        name: One of ``"unet"``, ``"unetpp"`` or ``"deeplab"``.

    Returns:
        A new model instance: an smp U-Net / U-Net++ with a ResNet-34 encoder
        created without encoder weights and 3 input channels, or a
        torchvision DeepLabV3-ResNet50 with ``num_classes=1``.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == "unet":
        return smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)
    if name == "unetpp":
        return smp.UnetPlusPlus(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)
    if name == "deeplab":
        # NOTE(review): only the segmentation weights are disabled here; the
        # backbone keeps torchvision's default initialisation, unlike the smp
        # models above (encoder_weights=None) — confirm this is intended.
        try:
            # `weights=` replaces the deprecated `pretrained=` keyword
            # (torchvision >= 0.13).
            return deeplabv3_resnet50(weights=None, num_classes=1)
        except TypeError:
            # Older torchvision releases only know the legacy keyword.
            return deeplabv3_resnet50(pretrained=False, num_classes=1)
    raise ValueError("模型名稱必須為 'unet'、'unetpp' 或 'deeplab'")


In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    The model is trained when ``optimizer`` is given; otherwise it is only
    evaluated (eval mode, gradients disabled). Returns 0.0 for an empty loader.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            if training:
                optimizer.zero_grad()

            outputs = model(images)
            if isinstance(outputs, dict):  # special handling for DeepLabV3
                outputs = outputs["out"]

            loss = criterion(outputs, masks)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else 0.0


def train_and_eval(model_name, base_dir, fold="Fold1", epochs=3, batch_size=4):
    """Train a segmentation model on one fold and report the loss per epoch.

    Data is read from ``<base_dir>/<fold>/train`` + ``trainannot`` (training)
    and ``<base_dir>/<fold>/val`` + ``valannot`` (validation). Images and
    masks are resized to 256x256 and converted to tensors. The model is
    optimised with Adam (lr=1e-3) against ``BCEWithLogitsLoss`` and runs on
    CUDA when available, otherwise on the CPU.

    Args:
        model_name: Identifier understood by ``get_model``.
        base_dir: Root directory that contains the fold folders.
        fold: Name of the fold sub-directory.
        epochs: Number of passes over the training set.
        batch_size: Batch size for both loaders.

    Returns:
        None. After every epoch one line is printed with the mean per-batch
        training and validation loss (means, so the two are comparable even
        though the loaders have different lengths).
    """
    print(f"開始訓練模型：{model_name.upper()}，資料集：{fold}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(model_name).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    # Image and mask folders for this fold.
    train_images = os.path.join(base_dir, fold, "train")
    train_masks  = os.path.join(base_dir, fold, "trainannot")
    val_images   = os.path.join(base_dir, fold, "val")
    val_masks    = os.path.join(base_dir, fold, "valannot")

    train_dataset = ETTDataset(train_images, train_masks, transform)
    val_dataset = ETTDataset(val_images, val_masks, transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    for epoch in range(epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        # Validation pass: no optimizer, so no weight updates.
        val_loss = _run_epoch(model, val_loader, criterion, device)

        print(f"[{model_name}] Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")


In [None]:
base_dir = "/content"

# Train every architecture in turn on the same fold.
for model_name in ("unet", "deeplab", "unetpp"):
    train_and_eval(model_name, base_dir, fold="Fold1")