<a href="https://colab.research.google.com/github/Heo-JuYeong/Data_Analysis_Team3_Project/blob/main/Unet_FineTunning_30000.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **0. Basic Setting**

In [None]:
# Mount Google Drive at /content/drive (requires the Google Colab runtime).

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# pip install

!pip install segmentation-models-pytorch
!pip install albumentations

In [None]:
# Library imports

import os
import cv2
import torch
import numpy as np

from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn, optim
from torchvision import transforms
# from torchvision.transforms import ToTensor

import segmentation_models_pytorch as smp

from PIL import Image
from tqdm import tqdm
from glob import glob
import matplotlib.pyplot as plt

In [None]:
# Device selection: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# **1. Define Classes & Methods**

In [None]:
# Mapping from raw grayscale label values to contiguous class indices (0-8).

GRAYSCALE_CLASS_MAP = {
    10: 0, 20: 1, 30: 2, 40: 3, 50: 4, 60: 5, 70: 6, 80: 7, 100: 8
}

def convert_mask(mask, unknown_index=0):
    """Translate a raw grayscale label mask into class indices.

    Args:
        mask: NumPy array of raw label values; the recognised values are the
            keys of ``GRAYSCALE_CLASS_MAP``.
        unknown_index: Value (0-255) written to every pixel whose gray value
            is not a key of ``GRAYSCALE_CLASS_MAP``. The default of 0 keeps
            the historical behaviour, where such pixels collapse into
            class 0. Pass e.g. 255 together with
            ``nn.CrossEntropyLoss(ignore_index=255)`` to keep unlabeled
            pixels out of the loss instead of mislabeling them.

    Returns:
        ``np.uint8`` array with the same shape as ``mask`` holding class
        indices.
    """
    converted = np.zeros_like(mask, dtype=np.uint8)
    # Start from the "unknown" label; mapped gray values overwrite it below.
    converted[...] = unknown_index
    for gray_val, class_idx in GRAYSCALE_CLASS_MAP.items():
        converted[mask == gray_val] = class_idx
    return converted

In [None]:
# Custom Dataset class

class SegmentationDataset(Dataset):
    """Paired image / mask dataset backed by two directories of ``.tif`` files.

    Images and masks are paired by position after sorting each directory's
    file list, so both directories must hold the same number of files and
    their names must sort in the same order.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        """Collect the sorted ``.tif`` paths of both directories.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the label masks.
            transform: Optional callable applied to the RGB PIL image.
            mask_transform: Optional callable applied to the PIL mask. When
                omitted, the mask is returned as an int64 tensor of its
                stored pixel values.

        Raises:
            ValueError: If the two directories hold a different number of
                ``.tif`` files, since positional pairing would be wrong.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Glob the directories given to THIS instance. Reading the
        # module-level test_* paths here would make every dataset, including
        # the training one, point at the test split.
        self.image_files = sorted(glob(os.path.join(image_dir, '*.tif')))
        self.mask_files = sorted(glob(os.path.join(mask_dir, '*.tif')))
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.image_files)} images in {image_dir!r} but "
                f"{len(self.mask_files)} masks in {mask_dir!r}; "
                "images and masks are paired by sorted position."
            )
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair and return ``(image, mask)``."""
        # glob() already returned paths that include the directory, so they
        # are used as-is; joining the directory again breaks relative paths.
        img_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        else:
            # NOTE(review): the mask keeps its stored pixel values and size.
            # If the files hold raw gray codes (10..100) rather than class
            # indices (0..8), they must be remapped before a 9-class loss;
            # confirm against the data.
            mask = np.array(mask, dtype=np.int64)
            mask = torch.from_numpy(mask)

        return image, mask

In [None]:
# EarlyStopping class

class EarlyStopping:
    """Track validation loss and raise a flag once it stops improving.

    A call counts as an improvement when the new loss is lower than the best
    loss seen so far by more than ``min_delta``. After ``patience``
    consecutive calls without improvement, ``early_stop`` becomes True.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = np.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Update the tracker with the latest validation loss."""
        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            # Improvement: remember the new best and restart the count.
            self.counter = 0
            self.best_loss = val_loss
            return

        self.counter += 1
        print(f"EarlyStopping counter: {self.counter} / {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
# Path configuration
# NOTE(review): '/content/your_path' looks like a placeholder; point these at
# the real dataset location before running.

train_image_dir = '/content/your_path/train/images'
train_mask_dir = '/content/your_path/train/masks'
test_image_dir = '/content/your_path/test/images'
test_mask_dir = '/content/your_path/test/masks'

In [None]:
# Transform definition: resize each image to 512x512 and convert it to a tensor.
# NOTE(review): the datasets receive this as `transform` only, which
# SegmentationDataset applies to the image and not to the mask, so masks keep
# their stored size. Presumably the source tiles are already 512x512; confirm.

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

In [None]:
# Dataset definition: one dataset over the training directories (split into
# train/validation subsets in a later cell) and one over the test directories.

full_train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, transform=transform)
test_dataset = SegmentationDataset(test_image_dir, test_mask_dir, transform=transform)

In [None]:
# Train/validation split (80% / 20% of the training dataset)

val_ratio = 0.2
train_size = int(len(full_train_dataset) * (1 - val_ratio))
val_size = len(full_train_dataset) - train_size
# Seed the split so the same samples land in the validation subset on every
# run. Without a fixed generator the membership changes per run, which makes
# validation losses (and the saved "best" checkpoint) incomparable across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

In [None]:
# DataLoader creation: all three loaders share the batch size and worker
# count; only the training loader shuffles.

_loader_options = {"batch_size": 8, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# **2. 모델 정의 및 학습(Fine-Tuning)**

In [None]:
# U-Net with an ImageNet-pretrained ResNet-34 encoder, 3 input channels and
# 9 output channels (one per entry of GRAYSCALE_CLASS_MAP).
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=9
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# torch.nn.Module has no Keras-style summary() method (calling it raises
# AttributeError), so report the parameter counts instead.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,} (trainable: {trainable_params:,})")

In [None]:
# Fine-tuning loop: train for up to `num_epochs`, validate after every epoch,
# checkpoint the weights with the lowest validation loss, and stop early after
# `patience` consecutive epochs without improvement.
best_val_loss = float('inf')
patience = 7
# NOTE(review): the original note here said to stop after 3 epochs without
# improvement (about 1/3 of the 10 training epochs), but `patience` is 7.
# Confirm which value is intended.
trigger_times = 0
save_path = '/content/best_unet_model.pth'
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for images, masks in loop:
        images = images.to(device)
        masks = masks.to(device)

        # Standard optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(train_loss=loss.item())

    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)

    # Validation step
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_images, val_masks in val_loader:
            val_images = val_images.to(device)
            val_masks = val_masks.to(device)

            val_outputs = model(val_images)
            val_loss += criterion(val_outputs, val_masks).item()

    val_loss /= len(val_loader)
    print(f"\nEpoch {epoch+1}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")

    # EarlyStopping & Checkpoint
    if val_loss < best_val_loss:
        # New best validation loss: overwrite the checkpoint and reset the counter.
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print(f"Validation loss improved. Saving model to {save_path}")
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"No improvement in validation loss for {trigger_times} epochs.")

    if trigger_times >= patience:
        print("Early stopping triggered. Training stopped.")
        break

# **3. Evaluation**

In [None]:
# Intersection over Union (IoU)

def compute_iou(pred, target, num_classes):
    """Return the mean IoU over the classes ``0 .. num_classes - 1``.

    A class that appears in neither ``pred`` nor ``target`` contributes NaN
    and is therefore left out of the mean.
    """
    def _class_iou(cls):
        in_pred = (pred == cls)
        in_target = (target == cls)
        union = (in_pred | in_target).sum()
        if union == 0:
            # Class absent from both masks.
            return float('nan')
        overlap = (in_pred & in_target).sum()
        return float(overlap) / float(union)

    per_class = [_class_iou(cls) for cls in range(num_classes)]
    return np.nanmean(per_class)  # mean IoU

In [None]:
# Dice Coefficient

def dice_coefficient(pred, target, num_classes):
    """Return the mean Dice score over the classes ``0 .. num_classes - 1``.

    Args:
        pred: Array of predicted class indices.
        target: Array of ground-truth class indices, same shape as ``pred``.
        num_classes: Number of classes to evaluate.

    Returns:
        Mean of the per-class Dice scores. A class that appears in neither
        ``pred`` nor ``target`` contributes NaN and is left out of the mean,
        matching ``compute_iou``. Scoring such classes as 0 would drag the
        mean down on every image that does not contain all classes. NaN is
        returned when no class is present at all.
    """
    dices = []
    for cls in range(num_classes):
        pred_inds = (pred == cls)
        target_inds = (target == cls)
        denominator = pred_inds.sum() + target_inds.sum()
        if denominator == 0:
            dices.append(float('nan'))  # class absent from both masks
            continue
        intersection = (pred_inds & target_inds).sum()
        dices.append(2.0 * float(intersection) / float(denominator))
    return np.nanmean(dices)

In [None]:
# Pixel Accuracy

def pixel_accuracy(pred, target):
    """Return the fraction of positions where ``pred`` equals ``target``.

    Uses ``pred.size`` as the element count, i.e. expects NumPy arrays.
    """
    matches = (pred == target)
    return matches.sum() / pred.size

In [None]:
# Per-Class Accuracy

def per_class_accuracy(pred, target, num_classes):
    """Return a list with one accuracy per class.

    Entry ``cls`` is the share of ground-truth pixels of class ``cls`` that
    were predicted as ``cls``; it is NaN when ``target`` has no such pixel.
    """
    accuracies = []
    for cls in range(num_classes):
        in_target = (target == cls)
        support = in_target.sum()
        if support == 0:
            accuracies.append(float('nan'))
        else:
            hits = ((pred == cls) & in_target).sum()
            accuracies.append(hits / support)
    return accuracies  # type : list

In [None]:
# Define input_tensor
# NOTE(review): this rebinds the module-level `transform` (512x512 earlier in
# the notebook) to a 256x256 variant. Datasets that were already built keep
# the object they were given, but re-running the dataset cell after this one
# would pick up the 256x256 version.
test_image = Image.open("path/to/image.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
input_tensor = transform(test_image).unsqueeze(0)  # shape: [1, 3, 256, 256]

# Define ground_truth
gt_mask = Image.open("path/to/mask.png")
# Bring the mask to the same 256x256 grid as the network input; otherwise the
# prediction and the ground truth differ in shape and the metric functions
# cannot compare them. NEAREST keeps label values discrete (no blended
# in-between classes).
gt_mask = gt_mask.resize((256, 256), resample=Image.NEAREST)
gt_mask = np.array(gt_mask)  # shape: [256, 256]
ground_truth = torch.tensor(gt_mask, dtype=torch.long)  # for loss / metric

In [None]:
# Model evaluation

# The training loop writes the weights with the lowest validation loss to
# `save_path`, while `model` still holds the last-epoch weights. Restore the
# checkpoint (when it exists) so the metrics describe the best model.
if os.path.exists(save_path):
    model.load_state_dict(torch.load(save_path, map_location=device))

model.eval()

with torch.no_grad():
    output = model(input_tensor.to(device))  # shape: [1, num_classes, H, W]
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # shape: [H, W]

true_mask = ground_truth.squeeze().numpy()  # shape: [H, W]

# Metric computation
iou = compute_iou(pred_mask, true_mask, num_classes=9)
dice = dice_coefficient(pred_mask, true_mask, num_classes=9)
pixel_acc = pixel_accuracy(pred_mask, true_mask)

print(f"IoU: {iou:.4f}, Dice: {dice:.4f}, Pixel Accuracy: {pixel_acc:.4f}")

# **4. Visualization**

In [None]:
# Visualization preprocessing

# Fixed colormap for the 9 classes: row i is the RGB color of class index i.
CLASS_COLORS = np.array(
    [
        (0, 0, 255),      # building (blue)
        (128, 0, 128),    # parking lot (purple)
        (0, 255, 255),    # road (cyan)
        (0, 255, 0),      # roadside trees (green)
        (255, 255, 0),    # rice paddy (yellow)
        (255, 200, 0),    # field (orange)
        (0, 128, 0),      # forest (dark green)
        (139, 69, 19),    # bare land (brown)
        (128, 128, 128),  # non-target area (gray)
    ],
    dtype=np.uint8,
)

def decode_segmap(mask):
    """Turn a class-index mask into an RGB image by looking up CLASS_COLORS."""
    rgb_image = CLASS_COLORS[mask]
    return rgb_image

In [None]:
# Visualization: show test sample 10 next to its ground truth and prediction.

def _draw_panel(slot, picture, title):
    """Draw one of the three side-by-side panels of the current figure."""
    plt.subplot(1, 3, slot)
    plt.imshow(picture)
    plt.title(title)

with torch.no_grad():
    sample_img, sample_mask = test_dataset[10]
    sample_img_cuda = sample_img.unsqueeze(0).to(device)

    pred = model(sample_img_cuda)
    pred_mask = pred.squeeze().argmax(dim=0).cpu().numpy()  # [C,H,W] -> [H,W]
    gt_mask = sample_mask.cpu().numpy()

    plt.figure(figsize=(12, 4))
    # imshow expects channels last, so move them from [C,H,W] to [H,W,C].
    _draw_panel(1, sample_img.permute(1, 2, 0).cpu().numpy(), "Image")
    _draw_panel(2, decode_segmap(gt_mask), "Ground Truth")
    _draw_panel(3, decode_segmap(pred_mask), "Prediction")
    plt.show()