### Ver3 Optimized

In [1]:
# ===========================================
# Cell 1. Import Library dan Setup Environment
# ===========================================

import os
import json
import random
from pathlib import Path
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.preprocessing import MinMaxScaler

# Gunakan GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device yang digunakan:", device)

# Monitor GPU memory
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Device yang digunakan: cuda
GPU: NVIDIA GeForce RTX 4080 SUPER
Total Memory: 17.17 GB


In [2]:
# ===========================================
# Cell 2. Fungsi Bantuan Umum
# ===========================================

from matplotlib.colors import ListedColormap

def seed_everything(seed=42):
    """Menetapkan seed random agar hasil eksperimen bisa direplikasi"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(42)

def visualize_tile(x_tile, y_true=None, y_pred=None, json_path=None, class_names=None, idx=0):
    """Menampilkan citra tile beserta mask ground-truth dan prediksi"""
    if isinstance(x_tile, torch.Tensor):
        x = x_tile.cpu().numpy()
        x = np.transpose(x, (1,2,0))  # ubah dari [B,H,W] -> [H,W,B]
    else:
        x = x_tile

    # menampilkan pseudo-RGB (karena data hyperspectral)
    B = x.shape[2]
    b1, b2, b3 = int(B*0.05), int(B*0.5), int(B*0.9)
    rgb = x[..., [b1, b2, b3]]
    rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-9)

    # Coba baca colormap dari file JSON
    if json_path and os.path.exists(json_path):
        with open(json_path, "r") as f:
            label_info = json.load(f)
        custom_colors = [c["color"][:7] for c in label_info]
        cmap = ListedColormap(custom_colors)
    else:
        print("File json tidak terbaca, menggunakan cmap tab20")
        cmap = "tab20"  # fallback

    # Visualisasi
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1); plt.imshow(rgb_norm); plt.title("Citra (Pseudo-RGB)")
    if y_true is not None:
        plt.subplot(1,3,2); plt.imshow(y_true, cmap=cmap); plt.title("Ground Truth")
    if y_pred is not None:
        plt.subplot(1,3,3); plt.imshow(y_pred, cmap=cmap); plt.title("Prediksi")
    plt.show()

In [3]:
# ===========================================
# Cell 3. Dataset Loader
# ===========================================

def load_label_mapping(json_path):
    """Membaca file label_classes.json untuk mapping id ke nama kelas"""
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    idx_to_name = {i: item["name"] for i, item in enumerate(data)}
    return idx_to_name

def normalize_reflectance(cube):
    """Menormalkan reflektansi 0–1 per tile, hemat RAM, aman untuk mmap read-only."""
    if not cube.flags.writeable:
        cube = cube.astype(np.float32, copy=True)

    if cube.dtype != np.float32:
        cube = cube.astype(np.float32, copy=False)

    np.nan_to_num(cube, copy=False)

    min_val = np.nanmin(cube)
    max_val = np.nanmax(cube)
    if max_val > min_val:
        cube -= min_val
        cube /= (max_val - min_val + 1e-8)

    return cube


class SeaweedDataset(Dataset):
    """Dataset hemat memori berbasis file .npy hasil konversi."""
    def __init__(self, data_files, label_map, tile_size=64, normalize=True, label_remap=None):
        self.data_files = data_files
        self.label_map = label_map
        self.tile_size = tile_size
        self.normalize = normalize
        self.label_remap = label_remap

        # Daftar pasangan (file_x, file_y)
        self.pairs = []
        for f in data_files:
            if f.endswith("_x.npy"):
                fy = f.replace("_x.npy", "_y.npy")
                if os.path.exists(fy):
                    self.pairs.append((f, fy))
        
        # SOLUSI 3: Pre-filter tile kosong
        self.index = []
        print("[INFO] Pre-filtering empty tiles...")
        
        empty_count = 0
        valid_count = 0
        
        for file_idx, (fx, fy) in enumerate(self.pairs):
            x = np.load(fx, mmap_mode="r")
            y = np.load(fy, mmap_mode="r")
            H, W, _ = x.shape
            
            for i in range(0, H - tile_size + 1, tile_size):
                for j in range(0, W - tile_size + 1, tile_size):
                    y_tile = y[i:i+tile_size, j:j+tile_size]
                    
                    if np.any(y_tile > 0):
                        self.index.append((file_idx, i, j))
                        valid_count += 1
                    else:
                        empty_count += 1
            del x, y
            
        print(f"[INFO] Total tile valid: {valid_count}")
        print(f"[INFO] Total tile empty (filtered): {empty_count}")
        print(f"[INFO] Ratio valid/total: {valid_count/(valid_count+empty_count)*100:.2f}%")

        print(f"[INFO] Total tile terdaftar: {len(self.index)} dari {len(self.pairs)} file")

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file_idx, i, j = self.index[idx]
        fx, fy = self.pairs[file_idx]
        
        # Memuat tile menggunakan mmap (sudah melalui pre-filtered guna mendapatkan kondisi non-empty)
        x = np.load(fx, mmap_mode="r")[i:i+self.tile_size, j:j+self.tile_size, :]
        y = np.load(fy, mmap_mode="r")[i:i+self.tile_size, j:j+self.tile_size]

        if self.normalize:
            x = normalize_reflectance(x)

        # REMAP label bila mapping diberikan
        if self.label_remap is not None:
            y_remap = np.zeros_like(y, dtype=np.int64)
            for orig_label, new_idx in self.label_remap.items():
                y_remap[y == orig_label] = new_idx
            y = y_remap
        else:
            y = y.astype(np.int64)

        # Konversi ke tensor
        x_tensor = torch.tensor(x.transpose(2, 0, 1), dtype=torch.float32)
        y_tensor = torch.tensor(y, dtype=torch.long)
        return x_tensor, y_tensor


def detect_actual_classes(pairs):
    """Scan semua file y.npy untuk mendeteksi kelas yang benar-benar ada"""
    found = set()
    for _, fy in pairs:
        y = np.load(fy, mmap_mode="r")
        found |= set(np.unique(y))
    found = sorted(list(found))
    print(f"[INFO] Kelas AKTUAL yang ditemukan di dataset: {found}")
    return found

In [4]:
# ===========================================
# Cell 4. Load Dataset dan Splitting
# ===========================================

data_dir = "../data/npy_converted"
label_json_path = "../data/annotation/segmentation_masks/label_classes.json"

label_map = load_label_mapping(label_json_path)
print(f"Jumlah total kelas di JSON: {len(label_map)}")

# Ambil semua file _x.npy
all_x_files = sorted([os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith("_x.npy")])
pairs = [(fx, fx.replace("_x.npy", "_y.npy")) for fx in all_x_files if os.path.exists(fx.replace("_x.npy", "_y.npy"))]

print(f"Total pasangan file X-Y ditemukan: {len(pairs)}")

# Split deterministik berbasis urutan nama (11 train, 5 val, 2 test)
train_pairs = pairs[:11]
val_pairs   = pairs[11:16]
test_pairs  = pairs[16:]

print("\n=== FINAL SPLIT PER FILE ===")
print(f"Train : {len(train_pairs)}")
print(f"Val   : {len(val_pairs)}")
print(f"Test  : {len(test_pairs)}")

# DETEKSI kelas aktual
actual_classes = detect_actual_classes(train_pairs + val_pairs + test_pairs)
orig_classes = [int(x) for x in actual_classes]
label_remap = {orig: idx for idx, orig in enumerate(orig_classes)}
print(f"[INFO] Label remap (orig -> new): {label_remap}")

# PENTING: Ubah tile_size ke 64 untuk menghemat memory
TILE_SIZE = 32  # Turun dari 64

train_dataset = SeaweedDataset([p[0] for p in train_pairs], label_map, tile_size=TILE_SIZE, label_remap=label_remap)
val_dataset   = SeaweedDataset([p[0] for p in val_pairs], label_map, tile_size=TILE_SIZE, label_remap=label_remap)
test_dataset  = SeaweedDataset([p[0] for p in test_pairs], label_map, tile_size=TILE_SIZE, label_remap=label_remap, normalize=False)

# Hitung class weights
counter = Counter()
for _, fy in train_pairs:
    y = np.load(fy, mmap_mode="r")
    for orig, new in label_remap.items():
        cnt = int((y == orig).sum())
        counter[new] += cnt

print(f"[INFO] Pixel counts per class: {dict(counter)}")

counts = np.array([counter.get(i, 0) for i in range(len(label_remap))], dtype=np.float64)
eps = 1e-6
inv_freq = 1.0 / (counts + eps)
inv_freq = inv_freq / np.mean(inv_freq)
inv_freq[0] = 0.0  # ignore background

class_weights_np = inv_freq.astype(np.float32)
print(f"[INFO] Class weights: {class_weights_np}")

num_classes_actual = len(label_remap)
print(f"\nTotal TILE train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")

Jumlah total kelas di JSON: 41
Total pasangan file X-Y ditemukan: 18

=== FINAL SPLIT PER FILE ===
Train : 11
Val   : 5
Test  : 2
[INFO] Kelas AKTUAL yang ditemukan di dataset: [np.int32(0), np.int32(8), np.int32(12), np.int32(13), np.int32(14), np.int32(18), np.int32(38)]
[INFO] Label remap (orig -> new): {0: 0, 8: 1, 12: 2, 13: 3, 14: 4, 18: 5, 38: 6}
[INFO] Pre-filtering empty tiles...
[INFO] Total tile valid: 5395
[INFO] Total tile empty (filtered): 12805
[INFO] Ratio valid/total: 29.64%
[INFO] Total tile terdaftar: 5395 dari 11 file
[INFO] Pre-filtering empty tiles...
[INFO] Total tile valid: 5698
[INFO] Total tile empty (filtered): 2898
[INFO] Ratio valid/total: 66.29%
[INFO] Total tile terdaftar: 5698 dari 5 file
[INFO] Pre-filtering empty tiles...
[INFO] Total tile valid: 1710
[INFO] Total tile empty (filtered): 950
[INFO] Ratio valid/total: 64.29%
[INFO] Total tile terdaftar: 1710 dari 2 file
[INFO] Pixel counts per class: {0: 14185638, 1: 840140, 2: 1566138, 3: 808104, 4: 369

In [5]:
# ===========================================
# Cell 5. Model OPTIMIZED HybridSN
# ===========================================

class OptimizedFCHybridSN(nn.Module):
    """Versi optimized dari HybridSN dengan spectral pooling"""
    def __init__(self, in_bands=300, num_classes=7):
        super().__init__()
        # 3D Convolution layers
        self.conv3d_1 = nn.Conv3d(1, 16, (7,3,3), padding=(0,1,1))
        self.bn3d_1 = nn.BatchNorm3d(16)
        
        self.conv3d_2 = nn.Conv3d(16, 32, (5,3,3), padding=(0,1,1))
        self.bn3d_2 = nn.BatchNorm3d(32)
        
        self.conv3d_3 = nn.Conv3d(32, 64, (3,3,3), padding=(0,1,1))
        self.bn3d_3 = nn.BatchNorm3d(64)

        # PERBAIKAN: Tambahkan spectral pooling
        self.spectral_pool = nn.AdaptiveAvgPool3d((8, None, None))
        
        # 2D Convolution layers (input channel jauh lebih kecil sekarang)
        self.conv2d_1 = nn.Conv2d(64 * 8, 256, 3, padding=1)
        self.bn2d_1 = nn.BatchNorm2d(256)
        self.dropout1 = nn.Dropout2d(0.3)
        
        self.conv2d_2 = nn.Conv2d(256, 128, 3, padding=1)
        self.bn2d_2 = nn.BatchNorm2d(128)
        self.dropout2 = nn.Dropout2d(0.3)
        
        self.conv2d_3 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn2d_3 = nn.BatchNorm2d(64)
        
        self.classifier = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        B, Bands, H, W = x.shape
        
        # 3D CNN processing
        x3 = x.unsqueeze(1)
        x3 = F.relu(self.bn3d_1(self.conv3d_1(x3)))
        x3 = F.relu(self.bn3d_2(self.conv3d_2(x3)))
        x3 = F.relu(self.bn3d_3(self.conv3d_3(x3)))
        
        # PERBAIKAN: Spectral pooling
        x3 = self.spectral_pool(x3)
        
        # Reshape ke 2D
        B, C3, reduced_spec, H, W = x3.shape
        x2 = x3.view(B, C3 * reduced_spec, H, W)
        
        # 2D CNN processing
        x2 = self.dropout1(F.relu(self.bn2d_1(self.conv2d_1(x2))))
        x2 = self.dropout2(F.relu(self.bn2d_2(self.conv2d_2(x2))))
        x2 = F.relu(self.bn2d_3(self.conv2d_3(x2)))
        
        return self.classifier(x2)


# Ambil jumlah band dari data
sample_x = np.load(train_pairs[0][0], mmap_mode="r")
in_bands_actual = sample_x.shape[2]
print(f"Band input aktual: {in_bands_actual}")

model = OptimizedFCHybridSN(in_bands=in_bands_actual, num_classes=num_classes_actual).to(device)

# Hitung parameter
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
print(model)

Band input aktual: 300
Total parameters: 1,629,767
OptimizedFCHybridSN(
  (conv3d_1): Conv3d(1, 16, kernel_size=(7, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  (bn3d_1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3d_2): Conv3d(16, 32, kernel_size=(5, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  (bn3d_2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3d_3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  (bn3d_3): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (spectral_pool): AdaptiveAvgPool3d(output_size=(8, None, None))
  (conv2d_1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2d_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout2d(p=0.3, inplace=False)
  (conv2d_2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2d_2): Batch

In [None]:
print(f"[INFO] Original class weights: {class_weights_np}")

In [6]:
# ===========================================
# Cell 6. Loss, Optimizer, dan Metrics (STABLE + SQRT SMOOTHING)
# ===========================================

import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Enable cudnn benchmark
torch.backends.cudnn.benchmark = True

# Hyperparameters
LR = 5e-5 # turun dari 1e-4
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 1
ACCUMULATION_STEPS = 8  # Efektif batch size = 8
CLIP_NORM = 0.5 # Turun dari 1.0

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# PERBAIKAN: Smoothed class weights untuk stabilitas
print(f"[INFO] Original class weights: {class_weights_np}")

# Smoothing varian sqrt untuk smoothing
class_weights_smoothed = np.sqrt(class_weights_np)
class_weights_smoothed[0] = 0.0

print(f"[INFO] Smoothed class weights: {class_weights_smoothed}")

# Versi loss function dengan class weights
weight_tensor = torch.from_numpy(class_weights_smoothed).to(device)
criterion = nn.CrossEntropyLoss(weight=weight_tensor, ignore_index=0, label_smoothing=0.1)

# Versi loss function criterion non-class-weight
# criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

# Optimizer & scheduler
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, eps=1e-8)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, verbose=True, min_lr=1e-7)


# PERBAIKAN: Metrik evaluasi yang benar
class SegmentationMetrics:
    def __init__(self, num_classes, ignore_index=0):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.reset()
    
    def reset(self):
        self.total_intersection = torch.zeros(self.num_classes)
        self.total_union = torch.zeros(self.num_classes)
        self.total_correct = 0
        self.total_pixels = 0
    
    def update(self, pred, target):
        """Update metrics dengan batch baru"""
        valid = (target != self.ignore_index)
        pred = pred[valid]
        target = target[valid]
        
        # Pixel accuracy
        self.total_correct += (pred == target).sum().item()
        self.total_pixels += valid.sum().item()
        
        # IoU per class
        for cls in range(self.num_classes):
            pred_i = (pred == cls)
            target_i = (target == cls)
            intersection = (pred_i & target_i).sum().item()
            union = (pred_i | target_i).sum().item()
            
            self.total_intersection[cls] += intersection
            self.total_union[cls] += union
    
    def get_metrics(self):
        """Hitung metrik final"""
        pixel_acc = self.total_correct / (self.total_pixels + 1e-9)
        
        iou_per_class = self.total_intersection / (self.total_union + 1e-9)
        # Exclude background (index 0) dan kelas yang tidak muncul
        valid_ious = []
        for i in range(1, self.num_classes):
            if self.total_union[i] > 0:
                valid_ious.append(iou_per_class[i].item())
        
        mean_iou = np.mean(valid_ious) if valid_ious else 0.0
        
        return pixel_acc, mean_iou, iou_per_class.numpy()

print("Setup selesai!")

[INFO] Original class weights: [0.         0.2185292  0.11722793 0.22719245 4.9649825  1.3176336
 0.14149204]
[INFO] Smoothed class weights: [0.         0.46747106 0.34238565 0.47664708 2.228224   1.1478822
 0.37615427]
Setup selesai!




In [None]:
# TEST MEMORY - Jalankan sebelum training
model.eval()
with torch.no_grad():
    dummy_x = torch.randn(1, in_bands_actual, TILE_SIZE, TILE_SIZE).to(device)
    dummy_y = model(dummy_x)
    print(f"Test passed! Output shape: {dummy_y.shape}")
    print(f"Memory used: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    del dummy_x, dummy_y
    torch.cuda.empty_cache()

In [7]:
# ===========================================
# Cell 7. Training Loop (OPTIMIZED)
# ===========================================

from tqdm import tqdm
import time

START_EPOCH = 1
NUM_EPOCHS = 2
best_val_miou = 0.0

checkpoint_path = "hybridsn_sgmt_ver3_checkpoint.pth"
best_model_path = "hybridsn_sgmt_ver3_best_model.pth"

# Load checkpoint jika ada
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    START_EPOCH = checkpoint["epoch"] + 1
    best_val_miou = checkpoint.get("best_val_miou", 0.0)
    print(f"[INFO] Resume dari epoch {START_EPOCH}")
else:
    print("[INFO] Training dari awal")

history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "val_miou": []}

def train_one_epoch(model, loader, criterion, optimizer, metrics, device):
    model.train()
    running_loss = 0.0
    metrics.reset()
    
    pbar = tqdm(loader, desc="Training", leave=False)
    optimizer.zero_grad()

    nan_count = 0  # Counter NaN
    
    for i, (xb, yb) in enumerate(pbar):
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)

        # SOLUSI 1: Skip jika semua background (safety net)
        if torch.all(yb == 0):
            continue
        
        logits = model(xb)
        loss = criterion(logits, yb) / ACCUMULATION_STEPS

        # SOLUSI 1: CHECK NaN
        if torch.isnan(loss) or torch.isinf(loss):
            nan_count += 1
            continue
        
        loss.backward()
        
        # Gradient accumulation
        if (i + 1) % ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            optimizer.step()
            optimizer.zero_grad()
        
        running_loss += loss.item() * ACCUMULATION_STEPS * xb.size(0)
        
        preds = logits.argmax(dim=1)
        metrics.update(preds, yb)
        
        pbar.set_postfix({"loss": f"{loss.item()*ACCUMULATION_STEPS:.4f}"})
        
    if nan_count > 0:
        print(f"  [INFO] Skipped {nan_count} NaN batches in training")
    
    avg_loss = running_loss / len(loader.dataset)
    pixel_acc, _, _ = metrics.get_metrics()
    
    return avg_loss, pixel_acc

def validate(model, loader, criterion, metrics, device):
    model.eval()
    running_loss = 0.0
    valid_batches = 0
    metrics.reset()

    nan_count = 0  # Counter NaN
    
    with torch.no_grad():
        for xb, yb in tqdm(loader, desc="Validation", leave=False):
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            # SOLUSI 1: Skip jika semua background (safety net)
            if torch.all(yb == 0):
                continue
            
            logits = model(xb)
            loss = criterion(logits, yb)

            # SOLUSI 1: CHECK NaN
            if torch.isnan(loss) or torch.isinf(loss):
                nan_count += 1
                continue
            
            running_loss += loss.item() * xb.size(0)
            valid_batches += 1
            
            preds = logits.argmax(dim=1)
            metrics.update(preds, yb)

    if nan_count > 0:
        print(f"  [INFO] Skipped {nan_count} NaN batches in validation")
    
    # Hitung average loss
    if valid_batches > 0:
        avg_loss = running_loss / (valid_batches * loader.batch_size)
    else:
        avg_loss = float('nan')
        print("  [WARNING] Semua validation batches di-skip!")
    
    pixel_acc, mean_iou, iou_per_class = metrics.get_metrics()
    
    return avg_loss, pixel_acc, mean_iou, iou_per_class


# Training loop
for epoch in range(START_EPOCH, NUM_EPOCHS + 1):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{NUM_EPOCHS}")
    print(f"{'='*60}")
    
    start_time = time.time()
    
    # Monitor GPU memory
    if torch.cuda.is_available():
        print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB allocated")
    
    # Training
    train_metrics = SegmentationMetrics(num_classes_actual, ignore_index=0)
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, train_metrics, device)
    
    # Validation
    val_metrics = SegmentationMetrics(num_classes_actual, ignore_index=0)
    val_loss, val_acc, val_miou, val_iou_per_class = validate(model, val_loader, criterion, val_metrics, device)
    
    # Scheduler step
    scheduler.step(val_miou)
    
    # Logging
    elapsed = time.time() - start_time
    print(f"\nResults:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss  : {val_loss:.4f} | Val Acc  : {val_acc:.4f}")
    print(f"  Val mIoU  : {val_miou:.4f}")
    print(f"  Time      : {elapsed/60:.2f} min")
    print(f"  IoU per class: {val_iou_per_class[1:]}")
    
    # Save history
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)
    history["val_miou"].append(val_miou)
    
    # Checkpoint
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_val_miou": best_val_miou,
        "history": history
    }
    torch.save(checkpoint, checkpoint_path)
    
    # Save best model
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        torch.save(checkpoint, best_model_path)
        print(f"[OK] Best model saved! (mIoU: {best_val_miou:.4f})")
    
    # Clear cache
    torch.cuda.empty_cache()

print("\n" + "="*60)
print("Training selesai!")
print(f"Best validation mIoU: {best_val_miou:.4f}")

[INFO] Training dari awal

Epoch 1/2
GPU Memory: 0.01 GB allocated


                                                                     


Results:
  Train Loss: 1.8991 | Train Acc: 0.3833
  Val Loss  : 1.8985 | Val Acc  : 0.4634
  Val mIoU  : 0.1077
  Time      : 4.15 min
  IoU per class: [0.         0.03565709 0.         0.         0.         0.50302017]
[OK] Best model saved! (mIoU: 0.1077)

Epoch 2/2
GPU Memory: 0.03 GB allocated


                                                                     


Results:
  Train Loss: 1.6945 | Train Acc: 0.5737
  Val Loss  : 1.8632 | Val Acc  : 0.0623
  Val mIoU  : 0.0263
  Time      : 3.43 min
  IoU per class: [0.         0.03460516 0.00362831 0.         0.         0.04056756]

Training selesai!
Best validation mIoU: 0.1077




In [None]:
# ===========================================
# Cell 8. Plot Training History
# ===========================================

fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Loss
axes[0].plot(history["train_loss"], label="Train Loss", marker='o')
axes[0].plot(history["val_loss"], label="Val Loss", marker='s')
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training vs Validation Loss")
axes[0].legend()
axes[0].grid(True)

# Accuracy
axes[1].plot(history["train_acc"], label="Train Acc", marker='o')
axes[1].plot(history["val_acc"], label="Val Acc", marker='s')
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Pixel Accuracy")
axes[1].set_title("Training vs Validation Accuracy")
axes[1].legend()
axes[1].grid(True)

# mIoU
axes[2].plot(history["val_miou"], label="Val mIoU", marker='d', color='green')
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Mean IoU")
axes[2].set_title("Validation mIoU")
axes[2].legend()
axes[2].grid(True)

plt.tight_layout()
plt.savefig("training_history_ver3.png", dpi=150)
plt.show()

In [None]:
# ===========================================
# Cell 9. Testing dan Visualisasi
# ===========================================

# Load best model
best_checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(best_checkpoint["model_state"])
print(f"Loaded best model from epoch {best_checkpoint['epoch']}")

# Test evaluation
test_metrics = SegmentationMetrics(num_classes_actual, ignore_index=0)
test_loss, test_acc, test_miou, test_iou_per_class = validate(model, test_loader, criterion, test_metrics, device)

print("\n" + "="*60)
print("TEST RESULTS")
print("="*60)
print(f"Test Loss     : {test_loss:.4f}")
print(f"Test Accuracy : {test_acc:.4f}")
print(f"Test mIoU     : {test_miou:.4f}")
print(f"IoU per class : {test_iou_per_class[1:]}")

# Visualisasi beberapa prediksi
model.eval()
num_vis = 3
vis_samples = []

with torch.no_grad():
    for i, (xb, yb) in enumerate(test_loader):
        if i >= num_vis:
            break
        xb = xb.to(device)
        logits = model(xb)
        preds = logits.argmax(dim=1)
        
        vis_samples.append((xb[0], yb[0], preds[0]))

# Plot visualisasi
for i, (x, y_true, y_pred) in enumerate(vis_samples):
    visualize_tile(x, y_true.cpu().numpy(), y_pred.cpu().numpy(), 
                   json_path=label_json_path, idx=i)

print("\nSelesai!")

In [None]:
# ===========================================
# Cell 10. Confusion Matrix
# ===========================================

from sklearn.metrics import confusion_matrix
import seaborn as sns

# Collect predictions untuk confusion matrix
all_preds = []
all_targets = []

model.eval()
with torch.no_grad():
    for xb, yb in tqdm(test_loader, desc="Computing CM"):
        xb = xb.to(device)
        logits = model(xb)
        preds = logits.argmax(dim=1)
        
        # Flatten dan filter valid pixels
        preds_flat = preds.cpu().numpy().flatten()
        targets_flat = yb.numpy().flatten()
        
        valid = targets_flat != 0  # Exclude background
        all_preds.extend(preds_flat[valid])
        all_targets.extend(targets_flat[valid])

# Compute confusion matrix
cm = confusion_matrix(all_targets, all_preds, labels=list(range(1, num_classes_actual)))

# Plot
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=range(1, num_classes_actual),
            yticklabels=range(1, num_classes_actual))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Test Set - Excluding Background)')
plt.tight_layout()
plt.savefig('confusion_matrix_ver3.png', dpi=150)
plt.show()

print("Confusion matrix saved!")