In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import re
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def natural_sort_key(filename):
    return [int(s) if s.isdigit() else s for s in re.split("(\d+)", filename)]

def addPadding(srcShapeTensor, tensor_whose_shape_isTobechanged):

    if(srcShapeTensor.shape != tensor_whose_shape_isTobechanged.shape):
        target = torch.zeros(srcShapeTensor.shape)
        target[:, :, :tensor_whose_shape_isTobechanged.shape[2],
               :tensor_whose_shape_isTobechanged.shape[3]] = tensor_whose_shape_isTobechanged
        return target.to(device)
    return tensor_whose_shape_isTobechanged.to(device)

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # 使用 Sigmoid 將輸出壓縮到 [0, 1]
        # inputs = torch.sigmoid(inputs)

        # 展平
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # 計算交集和
        intersection = (inputs * targets).sum()
        dice_coefficient = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

        # 計算 Dice Loss
        dice_loss = 1 - dice_coefficient
        return dice_loss

class CombinedLoss(nn.Module):
    def __init__(self, weight_dice=0.5, weight_ce=0.5, smooth=1e-6):
        super(CombinedLoss, self).__init__()
        self.dice_loss = DiceLoss(smooth=smooth)
        self.cross_entropy = nn.BCEWithLogitsLoss()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce

    def forward(self, inputs, targets):
        dice_loss = self.dice_loss(inputs, targets)
        ce_loss = self.cross_entropy(inputs, targets)
        combined_loss = self.weight_dice * dice_loss + self.weight_ce * ce_loss
        return combined_loss

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
        self.mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

        if self.image_filenames != self.mask_filenames:
            raise ValueError("Image and mask filenames do not match!")

        self.preprocess = transforms.Compose([
            # transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        self.transform = transform
        self.color_transform = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        image = np.array(image).astype(np.float32)
        mask = Image.open(mask_path).convert("L")
        mask = np.array(mask)
        mask = np.where(mask > 0, 0.999999, 0).astype(np.float32)

        if self.transform:
          t = self.transform(image=image, mask=mask)
          image = t['image']
          mask = t['mask']

        # image = self.preprocess(image)
        # mask = self.preprocess(mask)

        # if self.transform:
        #     image = self.color_transform(image)

        #     image_and_mask = torch.cat((image, mask))
        #     image_and_mask = self.transform(torch.cat((image, mask), dim=0))
        #     image = image_and_mask[0:3]
        #     mask = image_and_mask[3].unsqueeze(0)

        return image, mask


class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.enc1 = self.double_conv(3, 8)
        self.enc2 = self.double_conv(8, 8)
        self.enc3 = self.double_conv(8, 8)
        self.enc4 = self.double_conv(8, 8)
        self.enc5 = self.double_conv(8, 8)

        self.up5 = self.up_trans(8, 8)
        self.att5 = AttentionBlock(F_g=8, F_l=8, F_int=4)
        self.dec5 = self.double_conv(16, 8)

        self.up4 = self.up_trans(8, 8)
        self.att4 = AttentionBlock(F_g=8, F_l=8, F_int=4)
        self.dec4 = self.double_conv(16, 8)

        self.up3 = self.up_trans(8, 8)
        self.att3 = AttentionBlock(F_g=8, F_l=8, F_int=4)
        self.dec3 = self.double_conv(16, 8)

        self.up2 = self.up_trans(8, 8)
        self.att2 = AttentionBlock(F_g=8, F_l=8, F_int=4)
        self.dec2 = self.double_conv(16, 8)

        self.final_conv = nn.Conv2d(8, 1, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(0.3)
        )

    def up_trans(self, in_channels, out_channels):
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size = 2,
            stride = 2
        )

    def crop_and_concat(self, upsampled, bypass):
        return torch.cat((upsampled, bypass), dim=1)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(nn.MaxPool2d(2)(enc1))
        enc3 = self.enc3(nn.MaxPool2d(2)(enc2))
        enc4 = self.enc4(nn.MaxPool2d(2)(enc3))
        enc5 = self.enc5(nn.MaxPool2d(2)(enc4))

        x = self.up5(enc5)
        x = addPadding(enc4, x)
        x = self.att5(x, enc4)
        x = self.dec5(self.crop_and_concat(x, enc4))

        x = self.up4(x)
        x = addPadding(enc3, x)
        x = self.att4(x, enc3)
        x = self.dec4(self.crop_and_concat(x, enc3))

        x = self.up3(x)
        x = addPadding(enc2, x)
        x = self.att3(x, enc2)
        x = self.dec3(self.crop_and_concat(x, enc2))

        x = self.up2(enc2)
        x = addPadding(enc1, x)
        x = self.att2(x, enc1)
        x = self.dec2(self.crop_and_concat(x, enc1))

        output = torch.sigmoid(self.final_conv(x))
        return output


def calculate_iou(output, target, threshold=0.5):
    output = output > threshold
    target = target > 0

    output_flat = output.view(-1)
    target_flat = target.view(-1)

    intersection = (output_flat & target_flat).sum().float()
    union = (output_flat | target_flat).sum().float()
    if union == 0:
        print("UNION SHOULD NOT BE ZERO, THERE MUST BE SOME MISTAKE")
        return 1

    iou = intersection / union
    return iou.item()


def tensor_to_required_image(image, input_path, input_index):
    image = image > 0.5

    ref_image = Image.open(f"{input_path}/{input_index}.png")
    ref_shape = ref_image.size
    ref_shape = (ref_shape[0] * 2, ref_shape[1])

    image = Image.fromarray((image * 255).astype(np.uint8))
    image = image.resize(ref_shape)
    return image

In [16]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.Resize(512, 512),
    A.RandomSizedCrop(min_max_height=(128, 512), height=512, width=512, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
    A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.3),
    A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=0.3),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.GaussianBlur(blur_limit=(3, 5), p=0.3),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
    A.ChannelShuffle(p=0.1),
    ToTensorV2()
])

# train_transform = transforms.Compose([
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(30),
#     transforms.RandomResizedCrop(256, scale=(0.3, 1.0)),
#     transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
# ])

directory_path = "/content/drive/MyDrive/dip_final/training_dataset"
train_dataset = SegmentationDataset(f"{directory_path}/image", f"{directory_path}/mask", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

directory_path = "/content/drive/MyDrive/dip_final/testing_dataset"
test_dataset = SegmentationDataset(f"{directory_path}/image", f"{directory_path}/mask")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)


model = UNet().to(device)
criterion = CombinedLoss(weight_dice=0.5, weight_ce=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

  A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.3),


In [17]:
phases = ["train", "valid"]
data_loader = {"train": train_dataloader, "valid": test_dataloader}
best_mean_iou = 0
best_epoch = 0

num_epochs = 100
for epoch in range(num_epochs):
    start_time = time.time()
    for phase in phases:
        if phase == "train":
            model.train()
            train_loss = 0
            train_iou_scores = []
        elif phase == "valid":
            model.eval()
            val_loss = 0
            val_iou_scores = []

        for images, masks in data_loader[phase]:
            images, masks = images.to(device), masks.to(device)

            if phase == "train":
                outputs = model(images)

                masks = masks.unsqueeze(1)
                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                iou = calculate_iou(outputs, masks)
                train_iou_scores.append(iou)
            elif phase == "valid":
                with torch.no_grad():
                    images = images.permute(0,3,1,2)
                    masks = masks.unsqueeze(1)
                    outputs = model(images)

                    loss = criterion(outputs, masks)

                    val_loss += loss.item()

                    iou = calculate_iou(outputs, masks)
                    val_iou_scores.append(iou)

    end_time = time.time()
    train_mean_iou = sum(train_iou_scores) / len(train_iou_scores)
    val_mean_iou = sum(val_iou_scores) / len(val_iou_scores)
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {end_time - start_time:.2f} Training Loss: {train_loss:.4f}, Training Mean IoU: {train_mean_iou:.4f}, Validation Loss: {val_loss:.4f}, Validation Mean IoU: {val_mean_iou:.4f}")

    if val_mean_iou > best_mean_iou:
        best_mean_iou = val_mean_iou
        best_epoch = epoch
        torch.save(model.state_dict(), '/content/drive/MyDrive/dip_final/best_epoch.pth')

print("Training Complete!")
print(f"Best Mean IoU: {best_mean_iou:.4f}, at epoch {best_epoch}")

Epoch 1/100, Time: 5.89 Training Loss: 16.5976, Training Mean IoU: 0.2843, Validation Loss: 13.7168, Validation Mean IoU: 0.0676
Epoch 2/100, Time: 6.23 Training Loss: 16.4250, Training Mean IoU: 0.2860, Validation Loss: 13.6904, Validation Mean IoU: 0.1489
Epoch 3/100, Time: 5.60 Training Loss: 16.3409, Training Mean IoU: 0.3915, Validation Loss: 13.6497, Validation Mean IoU: 0.2246
Epoch 4/100, Time: 6.20 Training Loss: 16.2108, Training Mean IoU: 0.4565, Validation Loss: 13.6404, Validation Mean IoU: 0.2173
Epoch 5/100, Time: 6.63 Training Loss: 16.1509, Training Mean IoU: 0.4465, Validation Loss: 13.6697, Validation Mean IoU: 0.1287
Epoch 6/100, Time: 6.23 Training Loss: 16.0816, Training Mean IoU: 0.5124, Validation Loss: 13.6635, Validation Mean IoU: 0.1403
Epoch 7/100, Time: 5.58 Training Loss: 15.8407, Training Mean IoU: 0.5309, Validation Loss: 13.6344, Validation Mean IoU: 0.2055
Epoch 8/100, Time: 5.45 Training Loss: 15.8094, Training Mean IoU: 0.5770, Validation Loss: 13.63

In [18]:
model.load_state_dict(torch.load('/content/drive/MyDrive/dip_final/best_epoch.pth', weights_only=True))
model.eval()
iou_scores = []
output_dir = "/content/drive/MyDrive/dip_final/output"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    for idx, (images, masks) in enumerate(test_dataloader):
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        iou = calculate_iou(outputs, masks)

        iou_scores.append(iou)

        output_image_path = os.path.join(output_dir, f"output_{idx + 1}.png")

        output_and_mask = torch.cat((masks, outputs), dim=3).squeeze().cpu().numpy()
        output_and_mask_image = tensor_to_required_image(output_and_mask, f"{directory_path}/mask", idx + 1)
        output_and_mask_image.save(output_image_path)

        print(f"Saved output and mask for image {idx + 1}. IoU: {iou:.4f}")

mean_iou = sum(iou_scores) / len(iou_scores)
print(f"Mean IoU: {mean_iou:.4f}")

RuntimeError: Given groups=1, weight of size [8, 3, 3, 3], expected input[1, 1152, 1536, 3] to have 3 channels, but got 1152 channels instead