In [42]:
import os
import json
import numpy as np
from tifffile import imread
import albumentations as A
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection import MaskRCNN, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models import ResNet50_Weights

import torch
from torch.optim import SGD, lr_scheduler

In [43]:
class MedicalDataset(Dataset):
    def __init__(self, root_dir, transform=None, is_test=False):
        self.root = root_dir
        self.transform = transform
        self.is_test = is_test
        self.samples = self._load_samples()
        
    def _load_samples(self):
        samples = []
        for img_dir in os.listdir(self.root):
            img_path = os.path.join(self.root, img_dir, 'image.tif')
            if not self.is_test:
                mask_paths = {
                    cls: os.path.join(self.root, img_dir, f'{cls}.tif') 
                    for cls in ['class1', 'class2', 'class3', 'class4']
                }
            samples.append({'image': img_path, 'masks': mask_paths if not self.is_test else None})
        return samples
    
    def _merge_masks(self, mask_dict):
        """合成四通道mask為單通道實例標籤"""
        h, w = imread(mask_dict['class1']).shape
        combined = np.zeros((h, w), dtype=np.uint8)
        for idx, cls in enumerate(['class1', 'class2', 'class3', 'class4'], 1):
            mask = imread(mask_dict[cls])
            combined[mask > 0] = idx
        return combined

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image = imread(sample['image']).astype(np.float32) / 255.0
        
        if not self.is_test:
            mask = self._merge_masks(sample['masks'])
            if self.transform:
                transformed = self.transform(image=image, mask=mask)
                image, mask = transformed['image'], transformed['mask']
            return image.transpose(2,0,1), mask  # CHW格式
        else:
            return image.transpose(2,0,1)

    def __len__(self):
        return len(self.samples)

In [44]:
project_root = '..'
train_dir = os.path.join(project_root, 'dataset/train')
test_dir = os.path.join(project_root, 'dataset/test_release')

In [45]:
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=15, p=0.4),
    A.CLAHE(p=0.5),
    A.GridDistortion(p=0.2),
    A.RandomBrightnessContrast(p=0.3)
], additional_targets={'mask': 'mask'})

train_set = MedicalDataset(root_dir=train_dir, transform=train_transform)
test_set = MedicalDataset(root_dir=test_dir, is_test=True)

BATCH_SIZE = 4
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

In [None]:
class EnhancedMaskRCNN(MaskRCNN):
    def __init__(self, backbone, num_classes=None, **kwargs):
        super().__init__(backbone, num_classes, **kwargs)
        # 添加邊界感知分支
        self.boundary_head = self._build_boundary_head()
        
    def _build_boundary_head(self):
        layers = [
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(256, 1, 1)  # 邊界檢測輸出
        ]
        return torch.nn.Sequential(*layers)

    def forward(self, images, targets=None):
        # 原始輸出
        outputs = super().forward(images, targets)
        
        # 邊界檢測分支
        if self.training and targets is not None:
            boundary_maps = self.boundary_head(outputs['features'])
            outputs['boundary_loss'] = self.compute_boundary_loss(boundary_maps, targets)
            
        return outputs

def create_model(num_classes=5, pretrained=True):
    backbone = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
    # backbone = torchvision.models._utils.IntermediateLayerGetter(
    #     backbone, return_layers={'layer4': 'out'}
    # )
    # backbone = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
    
    model = EnhancedMaskRCNN(
        backbone,
        num_classes=num_classes,
        min_size=512,
        max_size=512
    )
    
    # 修改分類頭
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    
    # 修改mask頭
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, 256, num_classes
    )
    
    return model


In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = create_model().to(device)
params = [p for p in model.parameters() if p.requires_grad]

optimizer = SGD(params, lr=0.003, momentum=0.9, weight_decay=0.0005)
lr_sched = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# 混合損失函數
def hybrid_loss(pred_masks, gt_masks, boundaries):
    mask_loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_masks, gt_masks)
    boundary_loss = torch.nn.functional.mse_loss(pred_masks, boundaries)
    return mask_loss + 0.3 * boundary_loss

NUM_EPOCHS = 30

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    
    for images, targets in train_loader:
        images = list(img.to(device) for img in images)
        targets = [{'masks': t.to(device)} for t in targets]
        
        optimizer.zero_grad()
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        
        losses.backward()
        optimizer.step()
        
        total_loss += losses.item()
    
    lr_sched.step()
    print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(train_loader):.4f}")



AttributeError: 'ResNet' object has no attribute 'out_channels'

In [None]:
def masks_to_coco(results, image_ids):
    coco_results = []
    for img_id, output in zip(image_ids, results):
        for score, mask, label in zip(output['scores'], output['masks'], output['labels']):
            rle = binary_mask_to_rle(mask)
            coco_results.append({
                "image_id": img_id,
                "category_id": label.item(),
                "segmentation": rle,
                "score": score.item()
            })
    return coco_results

def binary_mask_to_rle(mask):
    # RLE編碼實現
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return {'size': list(mask.shape[-2:]), 'counts': runs.tolist()}


In [None]:
model.eval()
test_loader = DataLoader(test_set, batch_size=2, shuffle=False)

results = []
with torch.no_grad():
    for batch in test_loader:
        outputs = model(batch.to(device))
        results.extend(outputs)

# 生成最終提交文件
with open('test-results.json', 'w') as f:
    json.dump(masks_to_coco(results, test_set.image_ids), f)

print("Submission file generated!")