In [1]:
!unzip hw3-data-release.zip

Archive:  hw3-data-release.zip
   creating: hw3-data-release/
  inflating: hw3-data-release/test_image_name_to_ids.json  
   creating: hw3-data-release/test_release/
  inflating: hw3-data-release/test_release/009510f3-2d1a-435e-b733-90f5450baaca.tif  
  inflating: hw3-data-release/test_release/01ce9840-ea96-495e-8fd1-696a734956af.tif  
  inflating: hw3-data-release/test_release/02e1b69a-2441-4e23-a61e-4b36617efd06.tif  
  inflating: hw3-data-release/test_release/06efce1e-bec6-4314-a308-a76815507c6d.tif  
  inflating: hw3-data-release/test_release/07517165-7bd5-4a30-8433-3c8830358bc0.tif  
  inflating: hw3-data-release/test_release/0bd26f8e-81f6-4267-82ad-740e2786393a.tif  
  inflating: hw3-data-release/test_release/0fb9d9c0-f786-49c5-b485-b8dfdcce929c.tif  
  inflating: hw3-data-release/test_release/1059cdc7-e5cf-4c32-9e26-2c7770997301.tif  
  inflating: hw3-data-release/test_release/14792cd4-ce7e-44fa-a63b-9e5663e2f479.tif  
  inflating: hw3-data-release/test_release/15bd2c22-962f-450

In [4]:
!pip install torch torchvision numpy opencv-python tifffile tqdm Pillow pycocotools tifffile matplotlib

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import os
import json
import torch
import torchvision
import cv2
import numpy as np
import tifffile
import time
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2
from torchvision.transforms import functional as TF

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ROOT = './hw3-data-release'
TRAIN_DIR = os.path.join(ROOT, 'train')
BATCH_SIZE = 2
NUM_CLASSES = 5
EPOCHS = 10
VAL_RATIO = 0.2  # 验证集比例
DEBUG_SAMPLES = None  # 调试时使用的样本数量

class HW3Dataset(Dataset):
    def __init__(self, root, max_samples=None):
        self.root = root
        self.samples = sorted(os.listdir(root))
        if max_samples is not None:
            self.samples = self.samples[:max_samples]
        print(f"使用 {len(self.samples)} 個樣本進行訓練")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_id = self.samples[idx]
        sample_path = os.path.join(self.root, sample_id)

        # ✅ 用 cv2 取代 PIL/tifffile
        image_path = os.path.join(sample_path, 'image.tif')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            print(f"⚠️ 無法讀取圖像文件 {image_path}，跳過")
            return self.__getitem__((idx + 1) % len(self))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        masks, labels = [], []

        for class_id in range(1, 5):
            mask_path = os.path.join(sample_path, f'class{class_id}.tif')
            if not os.path.exists(mask_path):
                continue
            try:
                raw_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
                instance_ids = np.unique(raw_mask)
                instance_ids = instance_ids[instance_ids != 0]
                for inst_id in instance_ids:
                    bin_mask = (raw_mask == inst_id).astype(np.uint8)
                    if bin_mask.sum() == 0:
                        continue
                    masks.append(torch.as_tensor(bin_mask, dtype=torch.uint8))
                    labels.append(class_id)
            except Exception as e:
                print(f"無法讀取掩碼文件 {mask_path}: {str(e)}")
                continue

        boxes = []
        for mask in masks:
            y_indices, x_indices = torch.where(mask > 0)
            if len(y_indices) == 0 or len(x_indices) == 0:
                continue
            x_min = float(x_indices.min())
            y_min = float(y_indices.min())
            x_max = float(x_indices.max())
            y_max = float(y_indices.max())
            if x_max > x_min and y_max > y_min:
                boxes.append([x_min, y_min, x_max, y_max])

        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'masks': torch.stack(masks) if masks else torch.zeros((0, image.shape[0], image.shape[1]), dtype=torch.uint8),
            'image_id': torch.tensor([idx])
        }

        image = TF.to_tensor(image)
        return image, target



def get_model():
    model = maskrcnn_resnet50_fpn_v2(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, NUM_CLASSES)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden = 256
    model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, hidden, NUM_CLASSES)
    return model

def validate(model, val_loader):
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc='Validating'):
            images = [img.to(DEVICE) for img in images]
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

            outputs = model(images, targets)
            if isinstance(outputs, dict):
                losses = sum(loss for loss in outputs.values())
            else:
                # 如果模型返回的是预测结果而不是损失字典
                continue
            total_val_loss += losses.item()

    avg_val_loss = total_val_loss / len(val_loader)
    return avg_val_loss

def train():
    # 创建数据集并划分训练集和验证集
    full_dataset = HW3Dataset(TRAIN_DIR, max_samples=DEBUG_SAMPLES)
    val_size = int(len(full_dataset) * VAL_RATIO)
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    print(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

    model = get_model().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        print(f"\n開始第 {epoch+1}/{EPOCHS} 個 epoch")
        model.train()
        epoch_loss = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}')

        for images, targets in pbar:
            images = [img.to(DEVICE) for img in images]
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            epoch_loss += losses.item()

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            pbar.set_postfix({'loss': f'{losses.item():.4f}'})

        avg_train_loss = epoch_loss / len(train_loader)
        print(f'Epoch {epoch+1} 訓練完成，平均訓練損失: {avg_train_loss:.4f}')

        # 验证
        val_loss = validate(model, val_loader)
        print(f'Epoch {epoch+1} 驗證完成，平均驗證損失: {val_loss:.4f}')

        # 保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_maskrcnn.pth')
            print(f'保存最佳模型，驗證損失: {val_loss:.4f}')

        # 保存最新模型
        torch.save(model.state_dict(), 'latest_maskrcnn.pth')

if __name__ == '__main__':
    train()


使用 209 個樣本進行訓練
训练集大小: 168, 验证集大小: 41





開始第 1/10 個 epoch


Epoch 1/10:  79%|███████▊  | 66/84 [00:50<00:18,  1.01s/it, loss=1.2311]

In [None]:
import os
import json
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import functional as TF
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2
from pycocotools import mask as mask_utils
from tqdm import tqdm

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_CLASSES = 5
TEST_DIR = './hw3-data-release/test_release'
ID_MAP_PATH = './hw3-data-release/test_image_name_to_ids.json'
WEIGHT_PATH = './best_maskrcnn.pth'
OUTPUT_PATH = './test-results.json'

def encode_binary_mask(mask):
    mask = np.asfortranarray(mask.astype(np.uint8))
    encoded = mask_utils.encode(mask)
    encoded['counts'] = encoded['counts'].decode('utf-8')
    return encoded

def get_model():
    model = maskrcnn_resnet50_fpn_v2(pretrained=False)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor.cls_score = torch.nn.Linear(in_features, NUM_CLASSES)
    model.roi_heads.box_predictor.bbox_pred = torch.nn.Linear(in_features, NUM_CLASSES * 4)
    model.roi_heads.mask_predictor.conv5_mask = torch.nn.Conv2d(256, 256, kernel_size=2, padding=0)
    model.roi_heads.mask_predictor.mask_fcn_logits = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
    model.load_state_dict(torch.load(WEIGHT_PATH, map_location=DEVICE))
    model.to(DEVICE).eval()
    return model

def predict():
    with open(ID_MAP_PATH, 'r') as f:
        image_info_list = json.load(f)
        image_name_to_id = {info['file_name']: info['id'] for info in image_info_list}

    model = get_model()
    results = []

    for file_name in tqdm(sorted(os.listdir(TEST_DIR))):
        if not file_name.endswith('.tif'):
            continue
        image_path = os.path.join(TEST_DIR, file_name)
        image_id = image_name_to_id[file_name]

        image = Image.open(image_path).convert('RGB')
        tensor = TF.to_tensor(image).to(DEVICE)
        with torch.no_grad():
            outputs = model([tensor])[0]

        for i in range(len(outputs['scores'])):
            bin_mask = outputs['masks'][i, 0].cpu().numpy() > 0.5

            rle = mask_utils.encode(np.asfortranarray(bin_mask.astype(np.uint8)))
            rle['counts'] = rle['counts'].decode('utf-8')

            bbox = outputs['boxes'][i].detach().cpu().numpy()
            x1, y1, x2, y2 = bbox
            width = x2 - x1
            height = y2 - y1
            bbox = [float(x1), float(y1), float(width), float(height)]

            result = {
                'image_id': int(image_id),
                'category_id': int(outputs['labels'][i]),
                'bbox': bbox,
                'score': float(outputs['scores'][i]),
                'segmentation': {
                    'size': bin_mask.shape,
                    'counts': rle['counts']
                }
            }

            results.append(result)

    with open(OUTPUT_PATH, 'w') as f:
        json.dump(results, f)
    print(f"\n✅ Saved to: {OUTPUT_PATH}")

if __name__ == '__main__':
    predict()
