## Import Dependencies

In [None]:
!pip uninstall -y transformers accelerate tokenizers numpy

!pip install numpy==1.26.4
!pip install -U transformers accelerate tokenizers evaluate torchmetrics

In [None]:
import numpy
import scipy

# Show which NumPy / SciPy builds the kernel picked up after the reinstall cell.
print(
    f"Numpy version: {numpy.__version__}",
    f"Scipy version: {scipy.__version__}",
    sep="\n",
)

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')
from glob import glob
from tqdm import tqdm
import time

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

## Dataset Loader

In [None]:
ls /kaggle/input

NUM_CLASSES = 10
BATCH_SIZE = 16
EPOCHS = 25

MODEL_NAME = "best_model_mask2former.pth"
MODEL_NAME_FINETUNED = "best_model_mask2former_finetuned.pth"
DIR_MODEL = "/kaggle/input/mask2former/pytorch/default/1/"

In [None]:
import glob
import re
def sort_files_numerically(directory):
    """Return the paths of the entries in ``directory`` ordered by the first number in each name.

    Entries are ordered by the integer value of the first run of digits in the
    file name, so ``2.png`` comes before ``10.png``. Entries that share the same
    number are ordered by name, which keeps the result independent of the order
    in which ``os.listdir`` happens to report them.

    Args:
        directory: path of the directory to list.

    Returns:
        A list of ``os.path.join(directory, name)`` strings; empty for an empty directory.

    Raises:
        ValueError: if an entry's name contains no digits and therefore cannot be ordered.
    """
    first_number = re.compile(r'\d+')

    keyed_names = []
    for name in os.listdir(directory):
        match = first_number.search(name)
        if match is None:
            # Fail with the offending name instead of an AttributeError on None.
            raise ValueError(
                f"Cannot sort {name!r} in {directory!r}: file name contains no digits"
            )
        keyed_names.append((int(match.group()), name))

    # Tuples sort by number first and by name second (the deterministic tie-break).
    keyed_names.sort()
    return [os.path.join(directory, name) for _, name in keyed_names]

# Root folder of the attached segmentation dataset.
ROOT_INP = "/kaggle/input/indo-flood-segmentation-dataset"

# Each split has an image folder and a label folder; both are listed in numeric order.
# The dataset class pairs entry i of the image list with entry i of the mask list, which
# presumes that images and labels share the same numbering -- verify against the data.
train_image_paths, train_mask_paths = (
    sort_files_numerically(ROOT_INP + subdir)
    for subdir in ('/train/train-org-img', '/train/train-label-img')
)

val_image_paths, val_mask_paths = (
    sort_files_numerically(ROOT_INP + subdir)
    for subdir in ('/val/val-org-img', '/val/val-label-img')
)

test_image_paths, test_mask_paths = (
    sort_files_numerically(ROOT_INP + subdir)
    for subdir in ('/test/test-org-img', '/test/test-label-img')
)

In [None]:
class FloodDataset(Dataset):
    """Dataset that pairs each image file with the label mask at the same list position.

    Args:
        image_path: sequence of image file paths.
        mask_path: sequence of mask file paths; entry ``i`` is used with image ``i``.
        transform: optional callable applied to the RGB PIL image. When it is
            omitted the image is returned as a PIL image and is not resized.
        image_size: size handed to ``PIL.Image.resize`` for the mask (PIL reads it
            as ``(width, height)``).
        num_classes: number of label ids. Mask values are clipped to the range
            ``[0, num_classes - 1]``. Defaults to 10, matching the previous
            hard-coded upper bound of 9.
    """

    def __init__(self, image_path, mask_path, transform=None, image_size=(512, 512), num_classes=10):
        self.image_path = image_path
        self.mask_path = mask_path
        self.transform = transform
        self.image_size = image_size
        self.num_classes = num_classes

    def __len__(self):
        """Number of samples, taken from the image list."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        ``mask`` is a ``torch.long`` tensor of label ids. ``image`` is whatever
        ``transform`` returns, or the RGB PIL image when no transform was given.
        """
        # The context managers close the underlying files as soon as the pixel data
        # has been converted into independent in-memory images.
        with Image.open(self.image_path[idx]) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(self.mask_path[idx]) as raw_mask:
            mask = raw_mask.convert('L')

        if self.transform:
            image = self.transform(image)

        # Nearest-neighbour resampling keeps label ids intact (no blended values).
        mask = mask.resize(self.image_size, Image.NEAREST)
        mask = np.array(mask, dtype=np.int64)
        # NOTE(review): every value above the last class id -- including a 255
        # "ignore" label, if the masks use one -- is folded into the last class here;
        # confirm that this is intended.
        mask = np.clip(mask, 0, self.num_classes - 1)
        mask = torch.from_numpy(mask).long()

        return image, mask

## Dataset Prep

In [None]:
# Per-channel statistics used to normalise the input images.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# A single deterministic pipeline shared by all three splits (no augmentation).
train_test_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])


def _build_split(image_paths, mask_paths, shuffle):
    """Wrap one split's file lists in a FloodDataset and a DataLoader with no worker processes."""
    dataset = FloodDataset(image_paths, mask_paths, transform=train_test_transform)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=0)
    return dataset, loader


# Only the training split is shuffled; validation and test keep file order.
train_dataset, train_loader = _build_split(train_image_paths, train_mask_paths, shuffle=True)
val_dataset, val_loader = _build_split(val_image_paths, val_mask_paths, shuffle=False)
test_dataset, test_loader = _build_split(test_image_paths, test_mask_paths, shuffle=False)

## Modelling

In [None]:
# Mask2FormerConfig is used below, so it has to be imported with the other two classes.
from transformers import (
    Mask2FormerConfig,
    Mask2FormerForUniversalSegmentation,
    Mask2FormerImageProcessor,
)

model_id = "facebook/mask2former-swin-large-ade-semantic"

# The processor is told not to resize or rescale; presumably because the dataset
# transform already resizes the images and converts them to tensors.
processor = Mask2FormerImageProcessor.from_pretrained(
    model_id,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
)

# Start from the published architecture, then resize the label space to NUM_CLASSES.
config = Mask2FormerConfig.from_pretrained(model_id)
config.num_labels = NUM_CLASSES
config.id2label = {i: f"LABEL_{i}" for i in range(NUM_CLASSES)}
config.label2id = {f"LABEL_{i}": i for i in range(NUM_CLASSES)}

# Built from the config only: the weights come from the checkpoint loaded below.
model = Mask2FormerForUniversalSegmentation(config)

# SECURITY: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(DIR_MODEL + MODEL_NAME, map_location="cpu")

# Alternative kept for reference: start from the hub weights instead of the checkpoint.
# model = Mask2FormerForUniversalSegmentation.from_pretrained(
#     model_id,
#     num_labels=NUM_CLASSES,
#     ignore_mismatched_sizes=True
# )

# The weights may be stored directly or nested under 'state_dict' / 'model'.
state_dict = checkpoint
for wrapper_key in ('state_dict', 'model'):
    if wrapper_key in checkpoint:
        state_dict = checkpoint[wrapper_key]
        break

# Drop the class-predictor weights so that layer keeps the initialisation it received
# from the constructor; strict=False then tolerates the keys that are missing.
new_state_dict = {
    key: value for key, value in state_dict.items() if "class_predictor" not in key
}

msg = model.load_state_dict(new_state_dict, strict=False)

print(f"Model Mask2Former (Swin Large) loaded. \nLog: {msg}")

# Assigning the result keeps the cell from printing the full module repr as its output.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")


## Train and Evaluate

In [None]:
# Use the GPU when one is visible, otherwise the CPU, and move the model there.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)
# The class count comes from NUM_CLASSES so the message cannot drift from the
# configuration (it previously hard-coded 11 classes and another dataset's name).
print(f"Model {model_id} siap untuk training {NUM_CLASSES} kelas.")

In [None]:
# Sanity check: each split should list as many masks as images.
print(
    f"Train Images: {len(train_image_paths)}, Train Masks: {len(train_mask_paths)}",
    f"Val Images: {len(val_image_paths)}, Val Masks: {len(val_mask_paths)}",
    sep="\n",
)

In [None]:
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
from torchmetrics import JaccardIndex
from torch.optim.lr_scheduler import LambdaLR
import numpy as np

# IoU metric accumulated over the validation batches; pixels labelled 255 are ignored.
val_iou_metric = JaccardIndex(
    task="multiclass",
    num_classes=NUM_CLASSES,
    ignore_index=255
).to(device)

# Only `lr` is passed to the optimizer; `betas` and `weight_decay` are kept for the
# commented-out arguments below.
betas = (0.9, 0.999)
weight_decay = 0.05
lr = 1e-5

# EPOCHS = 6

optimizer = AdamW(model.parameters(),
                  # weight_decay=weight_decay,
                  # betas=betas,
                  lr=lr)

best_val_miou = 0.0
best_epoch = -1
global_iter = 0

# Per-epoch curves; written to JSON by a later cell.
history = {
    "train_loss": [],
    "val_loss": [],
    "val_miou": []
}


def _prepare_batch(images, masks):
    """Run the processor on one batch and move the model inputs to ``device``.

    Returns ``(pixel_values, mask_labels, class_labels)`` as produced by the
    processor for the given images and segmentation maps.

    NOTE(review): the dataset transform already applies ``Normalize``; check whether
    the processor's own normalisation is active too, which would normalise twice.
    NOTE(review): ``task_inputs`` is forwarded exactly as in the original call;
    confirm that the installed Mask2FormerImageProcessor accepts or ignores it.
    """
    inputs = processor(
        images=list(images),
        segmentation_maps=list(masks),
        task_inputs=["semantic"] * len(images),
        return_tensors="pt"
    )
    pixel_values = inputs["pixel_values"].to(device)
    mask_labels = [m.to(device) for m in inputs["mask_labels"]]
    class_labels = [c.to(device) for c in inputs["class_labels"]]
    return pixel_values, mask_labels, class_labels


print("🚀 Mulai Training Mask2Former...")

for epoch in range(EPOCHS):

    # ----------------------------
    # TRAINING
    # ----------------------------
    model.train()

    epoch_train_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"[Train] Epoch {epoch+1}/{EPOCHS}")

    for images, masks in train_bar:

        pixel_values, mask_labels, class_labels = _prepare_batch(images, masks)

        outputs = model(
            pixel_values=pixel_values,
            mask_labels=mask_labels,
            class_labels=class_labels
        )

        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()

        # torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()
        global_iter += 1

        epoch_train_loss += loss.item()
        train_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Drop references to this step's tensors before releasing cached GPU memory.
        del pixel_values, mask_labels, class_labels, outputs, loss
        torch.cuda.empty_cache()

    avg_train_loss = epoch_train_loss / len(train_loader)
    print(f"🎯 Train Loss: {avg_train_loss:.4f}")
    history["train_loss"].append(avg_train_loss)

    # ----------------------------
    # VALIDATION
    # ----------------------------

    model.eval()
    val_iou_metric.reset()
    epoch_val_loss = 0.0

    val_bar = tqdm(val_loader, desc=f"[Val] Epoch {epoch+1}/{EPOCHS}")

    with torch.no_grad():
        for images, masks in val_bar:

            pixel_values, mask_labels, class_labels = _prepare_batch(images, masks)

            # Labels are passed as well so that the model also returns a validation loss.
            outputs = model(
                pixel_values=pixel_values,
                mask_labels=mask_labels,
                class_labels=class_labels
            )

            loss = outputs.loss
            epoch_val_loss += loss.item()

            # Resize the predictions to each ground-truth mask's own height and width.
            target_sizes = [(m.shape[0], m.shape[1]) for m in masks]
            preds = processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes
            )

            preds_tensor = torch.stack(preds).to(device)
            target_tensor = masks.to(device)

            val_iou_metric.update(preds_tensor, target_tensor)

            del pixel_values, mask_labels, class_labels, outputs, loss
            torch.cuda.empty_cache()

    avg_val_loss = epoch_val_loss / len(val_loader)
    val_miou = val_iou_metric.compute().mean().item()
    val_iou_metric.reset()

    print(f"📌 Val Loss: {avg_val_loss:.4f} | Val mIoU: {val_miou:.4f}")
    print("-" * 50)

    history["val_loss"].append(avg_val_loss)
    history["val_miou"].append(val_miou)

    # Keep only the weights of the epoch with the best validation mIoU so far.
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_epoch = epoch + 1
        # Same file name as before, now taken from the configuration constant.
        torch.save(model.state_dict(), MODEL_NAME_FINETUNED)
        print(f"💾 New best model ! Epoch {epoch+1}, mIoU={val_miou:.4f}")

print(f"\nTraining Done Om! Best model in epoch {best_epoch} with mIoU={best_val_miou:.4f}")


In [None]:
import json 

# Write the per-epoch training curves to the notebook's working directory.
WORKDIR = "/kaggle/working"
output_path = WORKDIR + "/history_mask2former_finetuned.json"

with open(output_path, "w") as f:
    # Serialise with a 4-space indent, then write the text in one call.
    f.write(json.dumps(history, indent=4))

print("File saved to:", output_path)

In [None]:
import torch
from tqdm.auto import tqdm
from torchmetrics import JaccardIndex  

def test_model(model, test_loader, device, processor):
    """Evaluate ``model`` on ``test_loader`` and print per-class IoU and their mean.

    Args:
        model: segmentation model; it is moved to ``device`` and put in eval mode.
        test_loader: iterable that yields ``(images, masks)`` batches.
        device: device on which inference and the metric run.
        processor: image processor used to build the model inputs and to turn the
            model outputs into per-pixel class maps.

    Returns:
        ``(mIoU, iou_per_class)``: the mean of the per-class IoU values as a float,
        and the per-class IoU tensor returned by the metric.
    """
    # average="none" keeps one IoU value per class; pixels labelled 255 are ignored.
    iou_metric = JaccardIndex(
        task="multiclass",
        num_classes=NUM_CLASSES,
        ignore_index=255,
        average="none",
    ).to(device)

    model.to(device)
    model.eval()
    print("Mulai Testing (menggunakan JaccardIndex)...")

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(test_loader, desc="Testing"):
            encoded = processor(
                images=list(batch_images),
                return_tensors="pt",
                do_resize=False,
                do_rescale=False,
            )
            encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

            outputs = model(**encoded)

            # Predictions are resized to each ground-truth mask's height and width.
            mask_sizes = [(m.shape[0], m.shape[1]) for m in batch_masks]
            predicted_maps = processor.post_process_semantic_segmentation(
                outputs, target_sizes=mask_sizes
            )

            iou_metric.update(
                torch.stack(predicted_maps).to(device),
                batch_masks.to(device),
            )

    iou_per_class = iou_metric.compute()
    mIoU = iou_per_class.mean().item()

    print("\n=== HASIL TESTING ===")
    print(f"Mean IoU (mIoU): {mIoU:.4f}")
    print("-" * 30)

    class_names = ["Background", "Building Flooded", "Building Non-Flooded",
                   "Road Flooded", "Road Non-Flooded", "Water", "Tree", "Vehicle", "Pool", "Grass"]

    for class_idx, class_iou in enumerate(iou_per_class):
        # Fall back to a generic label if the metric reports more classes than names.
        if class_idx < len(class_names):
            label = class_names[class_idx]
        else:
            label = f"Class {class_idx}"
        print(f"{label:25s}: {class_iou.item():.4f}")

    iou_metric.reset()
    return mIoU, iou_per_class

In [None]:
import torch
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

model_id = "facebook/mask2former-swin-large-ade-semantic"
WORKDIR = "/kaggle/working"

class_names = ["Background", "Building Flooded", "Building Non-Flooded",
               "Road Flooded", "Road Non-Flooded", "Water", "Tree", "Vehicle", "Pool", "Grass"]

# num_labels = len(class_names)
# num_classes_internal = num_labels + 1

# Rebuild the processor and a NUM_CLASSES-label model before loading the fine-tuned weights.
processor = Mask2FormerImageProcessor.from_pretrained(
    model_id,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
)

model = Mask2FormerForUniversalSegmentation.from_pretrained(
    model_id,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

ckpt = torch.load(WORKDIR + "/best_model_mask2former_finetuned.pth", map_location="cpu")

# Accept either a bare state dict or one nested under a "model" key.
ckpt_state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

# Copy only the checkpoint tensors whose name and shape match the freshly built model.
model_state = model.state_dict()
compatible = {}
for _key, _tensor in ckpt_state.items():
    if _key not in model_state:
        continue
    if _tensor.shape != model_state[_key].shape:
        continue
    compatible[_key] = _tensor

print(f"Total model keys: {len(model_state)}; Compatible keys from ckpt: {len(compatible)}; Skipped keys: {len(ckpt_state)-len(compatible)}")

model_state.update(compatible)
model.load_state_dict(model_state)

test_mIoU, test_iou_per_class = test_model(model, test_loader, device, processor)

## Testing

In [None]:
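# NOTE: this visualisation cell is disabled. The CLASS_NAMES / LABEL_COLORS tables below
# list 11 classes (building-damage and road-clear/blocked categories), not the 10 flood
# classes printed by test_model; update both tables before re-enabling the cell.
# import matplotlib.pyplot as plt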
# import numpy as np
# import torch
# import os
# from matplotlib.patches import Patch

# CLASS_NAMES = [
#     "Background",               
#     "Water",                    
#     "Building No Damage",       
#     "Building Minor Damage",    
#     "Building Major Damage",    
#     "Building Total Destruction",
#     "Road-Clear",               
#     "Road-Blocked",             
#     "Vehicle",                  
#     "Tree",                     
#     "Pool"                      
# ]

# LABEL_COLORS = np.array([
#     [0, 0, 0],         # Background 
#     [30, 230, 255],    # Water 
#     [184, 115, 117],   # Building No Damage
#     [216, 255, 0],     # Building Minor Damage
#     [252, 199, 0],     # Building Major Damage
#     [255, 0, 0],       # Building Total Destruction
#     [140, 140, 140],   # Road-Clear
#     [151, 0, 255],     # Road-Blocked
#     [255, 0, 246],     # Vehicle 
#     [0, 255, 0],       # Tree
#     [244, 255, 0]      # Pool
# ])
# def decode_segmap(mask):
#     r = np.zeros_like(mask).astype(np.uint8)
#     g = np.zeros_like(mask).astype(np.uint8)
#     b = np.zeros_like(mask).astype(np.uint8)
    
#     for l in range(0, len(LABEL_COLORS)):
#         idx = mask == l
#         r[idx] = LABEL_COLORS[l, 0]
#         g[idx] = LABEL_COLORS[l, 1]
#         b[idx] = LABEL_COLORS[l, 2]
        
#     rgb = np.stack([r, g, b], axis=2)
#     return rgb

# def find_indices_by_filename(dataset, target_ids):
#     found_indices = []
#     for target in target_ids:
#         found = False
#         for idx, path in enumerate(dataset.image_path):
#             if str(target) in os.path.basename(path):
#                 found_indices.append(idx)
#                 found = True
#                 break
#         if not found:
#             return 
#     return found_indices

# def visualize_specific_images(model, dataset, target_ids, device, processor):
#     model.eval()
    
#     indices = find_indices_by_filename(dataset, target_ids)

#     num_samples = len(indices)
#     fig, axes = plt.subplots(num_samples, 3, figsize=(18, 6 * num_samples))
    
#     if num_samples == 1:
#         axes = axes.reshape(1, -1)

#     for row_idx, idx in enumerate(indices):
#         image, mask = dataset[idx] 
        
#         filename = os.path.basename(dataset.image_path[idx])
        
#         inputs = processor(
#             images=[image], 
#             return_tensors="pt",
#             do_resize=False, 
#             do_rescale=False
#         )
#         inputs = {k: v.to(device) for k, v in inputs.items()}
        
#         with torch.no_grad():
#             outputs = model(**inputs)
        
#         target_sizes = [(mask.shape[0], mask.shape[1])]
#         pred_map = processor.post_process_semantic_segmentation(
#             outputs, target_sizes=target_sizes
#         )[0] 
        
#         img_np = image.permute(1, 2, 0).numpy()
        
#         mask_rgb = decode_segmap(mask.numpy())
#         pred_rgb = decode_segmap(pred_map.cpu().numpy())
        
#         axes[row_idx, 0].imshow(img_np)
#         axes[row_idx, 0].set_title(f"ID: {filename}\nOriginal Image")
#         axes[row_idx, 0].axis("off")
        
#         axes[row_idx, 1].imshow(mask_rgb)
#         axes[row_idx, 1].set_title("Ground Truth")
#         axes[row_idx, 1].axis("off")
        
#         axes[row_idx, 2].imshow(pred_rgb)
#         axes[row_idx, 2].set_title("Mask2Former Prediction")
#         axes[row_idx, 2].axis("off")

#     handles = [Patch(color=LABEL_COLORS[i]/255.0, label=CLASS_NAMES[i]) for i in range(len(CLASS_NAMES))]
#     fig.legend(handles=handles, loc='lower center', ncol=6, bbox_to_anchor=(0.5, 0.0), fontsize=12)

#     plt.savefig('visualisasi_prediksi_rescuenet.png', bbox_inches='tight', dpi=300)
    
#     plt.tight_layout()
#     plt.subplots_adjust(bottom=0.08) 
#     plt.show()

# target_ids = ["10794", "10801", "10807"]

# visualize_specific_images(model, test_dataset, target_ids, device, processor)

In [None]:
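# NOTE: disabled legacy snippet. It reads `cfg.MODEL.DEVICE` and `outputs[0]["sem_seg"]`;
# `cfg` is not defined in the cells shown here, and the model is called with processor
# outputs everywhere else in this notebook, so this needs rewriting before it can run.
# model.eval()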
# import matplotlib.pyplot as plt

# test_imgs, test_masks = next(iter(test_loader))

# with torch.no_grad():
#     inputs = [{"image": test_imgs[0].to(cfg.MODEL.DEVICE), "height": 512, "width": 512}]
    
#     outputs = model(inputs)
    
#     pred_mask = outputs[0]["sem_seg"].argmax(dim=0).cpu().numpy()

# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1); plt.title("Prediction"); plt.imshow(pred_mask)
# plt.subplot(1, 2, 2); plt.title("Ground Truth"); plt.imshow(test_masks[0])
# plt.show()