In [2]:
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim as optim
import torch
from torchinfo import summary
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torchvision.transforms as T
import os
from PIL import Image

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [4]:
model = smp.Unet('resnet34', classes=3, activation='softmax').to(device)

In [5]:
summary(model)

Layer (type:depth-idx)                        Param #
Unet                                          --
├─ResNetEncoder: 1-1                          --
│    └─Conv2d: 2-1                            9,408
│    └─BatchNorm2d: 2-2                       128
│    └─ReLU: 2-3                              --
│    └─MaxPool2d: 2-4                         --
│    └─Sequential: 2-5                        --
│    │    └─BasicBlock: 3-1                   73,984
│    │    └─BasicBlock: 3-2                   73,984
│    │    └─BasicBlock: 3-3                   73,984
│    └─Sequential: 2-6                        --
│    │    └─BasicBlock: 3-4                   230,144
│    │    └─BasicBlock: 3-5                   295,424
│    │    └─BasicBlock: 3-6                   295,424
│    │    └─BasicBlock: 3-7                   295,424
│    └─Sequential: 2-7                        --
│    │    └─BasicBlock: 3-8                   919,040
│    │    └─BasicBlock: 3-9                   1,180,672
│    │    └─Basi

In [6]:
for param in model.encoder.parameters():
    param.requires_grad = False

In [7]:
loss_fn = nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters()},
    {'params': model.segmentation_head.parameters()}
], lr=1e-4)

In [8]:
mask_dir = os.path.join(os.getcwd(),"labels/final_labels")
image_dir= os.path.join(os.getcwd(),"images_224_rename")

In [9]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.image_filenames = sorted(os.listdir(mask_dir))  # 假設 image 和 mask 檔名相同

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.image_filenames[idx])
        
        image = Image.open(image_path.replace(".png", ".jpg")).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # 灰階，單一通道

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

In [10]:
image_transform = T.Compose([
    T.ToTensor(),  # [0,1] 範圍
    T.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet 標準
                std=[0.229, 0.224, 0.225])
])

In [11]:
# 轉換 mask 成 tensor
mask_transform = T.Compose([
    T.ToTensor()
])



In [12]:
# 建立 Dataset
dataset = SegmentationDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    transform=image_transform,
    mask_transform=mask_transform
)

In [13]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [14]:
for imgs, masks in dataloader:
    print(imgs.shape, masks.shape)

torch.Size([4, 3, 224, 224]) torch.Size([4, 1, 224, 224])
torch.Size([4, 3, 224, 224]) torch.Size([4, 1, 224, 224])
torch.Size([4, 3, 224, 224]) torch.Size([4, 1, 224, 224])
torch.Size([2, 3, 224, 224]) torch.Size([2, 1, 224, 224])


In [15]:
def pixel_accuracy(preds, masks):
    """
    preds: prediction, shape = (B, H, W)
    masks: ground truth mask，shape = (B, H, W)
    """
    correct = (preds == masks).float()
    acc = correct.sum() / correct.numel()
    return acc

In [17]:
model.train()
for epoch in range(10):
    model.train()
    total_loss = 0
    for imgs, masks in dataloader:
        imgs = imgs.to(device)
        masks = masks.long().squeeze(1).to(device)  # 轉回 int 類型 + squeeze channel

        outputs = model(imgs)
        loss = loss_fn(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    with torch.inference_mode():
        val_loss = 0
        val_acc = 0
        val_batches = 0
    
        for imgs, masks in dataloader:
            imgs = imgs.to(device)
            masks = masks.long().squeeze(1).to(device)
            
            outputs = model(imgs)
            preds = torch.argmax(outputs, dim=1)
            loss = loss_fn(outputs, masks)
            acc = pixel_accuracy(preds, masks)
    
            val_loss += loss.item()
            val_acc += acc.item()
            val_batches += 1
    
        avg_loss = val_loss / val_batches
        avg_acc = val_acc / val_batches
        print(f"Epoch [{epoch+1}], Accuracy: {avg_acc:.4f}, Loss: {avg_loss:.4f}")

Epoch [1], Accuracy: 0.8550, Loss: 0.9290
Epoch [2], Accuracy: 0.8851, Loss: 0.9144
Epoch [3], Accuracy: 0.9127, Loss: 0.8998
Epoch [4], Accuracy: 0.9309, Loss: 0.8865
Epoch [5], Accuracy: 0.9495, Loss: 0.8723
Epoch [6], Accuracy: 0.9547, Loss: 0.8618
Epoch [7], Accuracy: 0.9569, Loss: 0.8514
Epoch [8], Accuracy: 0.9612, Loss: 0.8410
Epoch [9], Accuracy: 0.9641, Loss: 0.8320
Epoch [10], Accuracy: 0.9640, Loss: 0.8247
