In [None]:
# ===========================================
# 1. Import Library dan Setup Environment
# ===========================================

import os
import json
import random
from pathlib import Path
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.preprocessing import MinMaxScaler

# Select the compute device: the CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device yang digunakan:", device)

In [None]:
# ===========================================
# 2. Fungsi Bantuan Umum
# ===========================================

def seed_everything(seed=42, deterministic=True):
    """Seed every random number generator used in this notebook so runs are reproducible.

    Args:
        seed: value passed to Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).
        deterministic: when True, also request deterministic cuDNN kernels and disable the
            cuDNN autotuner. Seeding alone does not make cuDNN convolutions repeatable on GPU,
            because the autotuner may pick different (non-deterministic) algorithms per run.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_everything(42)

def visualize_tile(x_tile, y_true=None, y_pred=None, class_names=None, idx=0):
    """Show a hyperspectral tile as pseudo-RGB next to its ground-truth and predicted masks.

    Args:
        x_tile: the tile, either a torch tensor laid out (bands, H, W) or an array laid out
            (H, W, bands).
        y_true: optional ground-truth label mask (numpy array or torch tensor).
        y_pred: optional predicted label mask (numpy array or torch tensor).
        class_names: optional sized collection of class names (e.g. the index -> name dict);
            when given, its length fixes the colour range of the mask panels.
        idx: unused; kept for backward compatibility.

    Both mask panels share one colour scale (vmin=0, vmax=max label), so a given label id is
    drawn in the same colour in the ground truth and in the prediction.
    """
    def _as_numpy(arr):
        # detach() so tensors that still track gradients can be converted too
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    if isinstance(x_tile, torch.Tensor):
        x = np.transpose(_as_numpy(x_tile), (1, 2, 0))  # (bands, H, W) -> (H, W, bands)
    else:
        x = np.asarray(x_tile)

    # Pseudo-RGB from three bands spread across the spectrum (the data is hyperspectral)
    n_bands = x.shape[2]
    band_idx = [int(n_bands * 0.05), int(n_bands * 0.5), int(n_bands * 0.9)]
    rgb = x[..., band_idx]
    rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-9)

    masks = {"Ground Truth": y_true, "Prediksi": y_pred}
    masks = {title: _as_numpy(m) for title, m in masks.items() if m is not None}

    # Shared colour range; without it imshow rescales each panel to its own min/max and the
    # same label can appear in different colours in the two panels.
    if class_names is not None and len(class_names) > 0:
        vmax = len(class_names) - 1
    elif masks:
        vmax = max(int(m.max()) for m in masks.values())
    else:
        vmax = 1
    vmax = max(vmax, 1)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(rgb_norm)
    axes[0].set_title("Citra (Pseudo-RGB)")
    for ax, title in zip(axes[1:], ("Ground Truth", "Prediksi")):
        if title in masks:
            ax.imshow(masks[title], cmap='tab20', vmin=0, vmax=vmax, interpolation='nearest')
            ax.set_title(title)
        else:
            ax.axis("off")  # keep the layout fixed but hide panels with no data
    plt.show()



In [None]:
# ====================================================
# 3. Dataset Loader (SeaweedDataset) dan Label Mapping
# ====================================================

def load_label_mapping(json_path):
    """Read ``label_classes.json`` and map each class index to its name.

    The JSON file holds a list of objects with a ``"name"`` key; the index assigned to each
    class is its position in that list (0, 1, 2, ...).

    Returns:
        dict[int, str]: class index -> class name.
    """
    with open(json_path, encoding='utf-8') as handle:
        entries = json.load(handle)
    return dict(enumerate(entry["name"] for entry in entries))

def normalize_reflectance(cube):
    """Min-max normalise each band (last axis) of a reflectance cube to the range [0, 1].

    Non-finite readings (NaN, +inf, -inf) are replaced by 0 before scaling. Replacing inf by
    the largest float, as ``np.nan_to_num`` does by default, would make that single value the
    band maximum and squash every other pixel of the band to ~0.
    Bands with a constant value map to 0 (same convention as sklearn's MinMaxScaler).

    Args:
        cube: array whose last axis indexes spectral bands (e.g. (H, W, bands)).

    Returns:
        Array of the same shape; floating input keeps its dtype, other input becomes float64.
    """
    cube = np.asarray(cube)
    if not np.issubdtype(cube.dtype, np.floating):
        cube = cube.astype(np.float64)
    cube = np.nan_to_num(cube, nan=0.0, posinf=0.0, neginf=0.0)
    flat = cube.reshape(-1, cube.shape[-1])
    band_min = flat.min(axis=0)
    band_range = flat.max(axis=0) - band_min
    band_range[band_range == 0] = 1  # constant band -> all zeros instead of division by zero
    return ((flat - band_min) / band_range).reshape(cube.shape)

class SeaweedDataset(Dataset):
    """Tile-level segmentation dataset built from preprocessed ``.npz`` files.

    Each file provides ``x`` (sliced below as (H, W, bands)) and ``y`` (sliced as (H, W)).
    Every cube is min-max normalised per band with ``normalize_reflectance``, then cut into
    ``tile_size`` x ``tile_size`` tiles; tiles without any label > 0 are skipped.

    Args:
        data_files: paths of the ``.npz`` files to load.
        label_map: class index -> name mapping (stored for reference only).
        tile_size: edge length of the square tiles.
        pad_partial: tiles on the right/bottom border that are smaller than ``tile_size`` are
            zero-padded (mask value 0) when True, or dropped when False. Either way all stored
            tiles share one shape, which the default DataLoader collate needs to stack a batch.
    """

    def __init__(self, data_files, label_map, tile_size=128, pad_partial=True):
        self.data_files = data_files
        self.label_map = label_map
        self.tile_size = tile_size
        self.pad_partial = pad_partial
        self.tiles = []

        for f in data_files:
            # The context manager closes the archive; indexing already loaded the arrays.
            with np.load(f) as data:
                x = data["x"]
                y = data["y"]
            x = normalize_reflectance(x)
            H, W = x.shape[:2]

            # Split the image into tile_size x tile_size tiles
            for i in range(0, H, tile_size):
                for j in range(0, W, tile_size):
                    y_tile = y[i:i + tile_size, j:j + tile_size]
                    if not np.any(y_tile > 0):  # skip tiles with no labelled pixel
                        continue
                    x_tile = x[i:i + tile_size, j:j + tile_size, :]
                    if y_tile.shape[:2] != (tile_size, tile_size):
                        if not pad_partial:
                            continue
                        x_tile, y_tile = self._pad_tile(x_tile, y_tile, tile_size)
                    self.tiles.append((x_tile, y_tile))

    @staticmethod
    def _pad_tile(x_tile, y_tile, tile_size):
        """Zero-pad a border tile at the bottom/right up to tile_size x tile_size.

        Padded mask pixels get label 0, the value this notebook's loss ignores (ignore_index=0).
        """
        pad_h = tile_size - y_tile.shape[0]
        pad_w = tile_size - y_tile.shape[1]
        x_padded = np.pad(x_tile, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")
        y_padded = np.pad(y_tile, ((0, pad_h), (0, pad_w)), mode="constant")
        return x_padded, y_padded

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, idx):
        """Return (x, y): x as float32 (bands, H, W), y as int64 (H, W)."""
        x_tile, y_tile = self.tiles[idx]
        x_tile = torch.tensor(x_tile.transpose(2, 0, 1), dtype=torch.float32)
        y_tile = torch.tensor(y_tile, dtype=torch.long)
        return x_tile, y_tile


In [None]:
# ===========================================
# 4. Load Dataset dan Splitting
# ===========================================

data_dir = "../data/processed"  # directory holding the preprocessed .npz files
label_json_path = "../data/annotation/segmentation_masks/label_classes.json"  # path to label_classes.json

label_map = load_label_mapping(label_json_path)
print(f"Jumlah total kelas: {len(label_map)}")
print("Contoh nama kelas:", list(label_map.values())[:5])

# Collect every preprocessed .npz file (sorted so the file order, and thus the split, is stable)
all_files = sorted(os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".npz"))
print(f"\nTotal file ditemukan: {len(all_files)}")
if not all_files:
    raise FileNotFoundError(f"No .npz files found in {data_dir!r}")

# Show the unique mask labels of the first few files
for f in all_files[:3]:
    with np.load(f) as data:  # close the archive once the mask has been read
        mask = data["y"]
    unique_labels = np.unique(mask)
    print(f"{os.path.basename(f)} -> Label unik: {unique_labels}")

# Build the tiled dataset
dataset = SeaweedDataset(all_files, label_map, tile_size=128)
n = len(dataset)
if n == 0:
    raise RuntimeError(f"No labelled tiles were produced from {data_dir!r}; check the masks")

num_classes = len(label_map)
# CrossEntropyLoss needs every mask value in [0, num_classes); check here so a mismatch with
# label_classes.json fails with a readable message instead of deep inside training.
max_label = max(int(y.max()) for _, y in dataset.tiles)
if max_label >= num_classes:
    raise ValueError(f"Mask label {max_label} is outside [0, {num_classes}) defined by {label_json_path}")

# Random 80%/20% train/val split of tile indices (global `random`, seeded by seed_everything)
# NOTE(review): tiles of the same image can land in both splits, so validation scores may be
# optimistic because neighbouring tiles are correlated; consider splitting by file instead.
idxs = list(range(n))
random.shuffle(idxs)
split = int(0.8 * n)
train_idx, val_idx = idxs[:split], idxs[split:]

train_loader = DataLoader(dataset, batch_size=4, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
val_loader   = DataLoader(dataset, batch_size=4, sampler=torch.utils.data.SubsetRandomSampler(val_idx))

print(f"\nJumlah tile train: {len(train_idx)}, val: {len(val_idx)}")



In [None]:
# =================================================
# 5. Model Fully Convolutional HybridSN (3D+2D CNN)
# =================================================

class FCHybridSN(nn.Module):
    """Fully convolutional HybridSN: a 3-D spectral-spatial stage followed by a 2-D spatial stage.

    Input:  tensor of shape (batch, in_bands, H, W).
    Output: per-pixel class logits of shape (batch, num_classes, H, W).

    The 3-D convolutions use no padding along the spectral axis, so kernel depths 7, 5 and 3
    shrink the band axis by 6 + 4 + 2 = 12. The 2-D stage therefore expects
    64 * (in_bands - 12) input channels. All convolutions keep H and W (3x3 kernels, padding 1).
    """

    def __init__(self, in_bands=300, num_classes=41):
        super().__init__()
        # Module names and creation order are kept stable: saved state_dicts and seeded
        # initialisation depend on them.
        self.conv3d_1 = nn.Conv3d(1, 16, kernel_size=(7, 3, 3), padding=(0, 1, 1))
        self.bn3d_1 = nn.BatchNorm3d(16)
        self.conv3d_2 = nn.Conv3d(16, 32, kernel_size=(5, 3, 3), padding=(0, 1, 1))
        self.bn3d_2 = nn.BatchNorm3d(32)
        self.conv3d_3 = nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=(0, 1, 1))
        self.bn3d_3 = nn.BatchNorm3d(64)

        self._out_spec = in_bands - 12  # spectral depth left after the three 3-D convs
        hidden = 256
        flattened_channels = 64 * max(1, self._out_spec)
        self.conv2d_1 = nn.Conv2d(flattened_channels, hidden, kernel_size=3, padding=1)
        self.bn2d_1 = nn.BatchNorm2d(hidden)
        self.conv2d_2 = nn.Conv2d(hidden, 128, kernel_size=3, padding=1)
        self.bn2d_2 = nn.BatchNorm2d(128)
        self.conv2d_3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn2d_3 = nn.BatchNorm2d(64)
        self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Unpacking also enforces a 4-D input
        batch, _bands, height, width = x.shape
        feats = x.unsqueeze(1)  # add a channel axis for Conv3d: (batch, 1, bands, H, W)
        for conv, norm in ((self.conv3d_1, self.bn3d_1),
                           (self.conv3d_2, self.bn3d_2),
                           (self.conv3d_3, self.bn3d_3)):
            feats = F.relu(norm(conv(feats)))
        # Merge (channels, spectral depth) into one channel axis for the 2-D stage
        feats = feats.view(batch, -1, height, width)
        for conv, norm in ((self.conv2d_1, self.bn2d_1),
                           (self.conv2d_2, self.bn2d_2),
                           (self.conv2d_3, self.bn2d_3)):
            feats = F.relu(norm(conv(feats)))
        return self.classifier(feats)

# Take the band count from the data rather than assuming 300, so the first 2-D conv always
# matches the cubes actually loaded (falls back to the previous default if the dataset is empty).
in_bands = dataset[0][0].shape[0] if len(dataset) > 0 else 300
model = FCHybridSN(in_bands=in_bands, num_classes=num_classes).to(device)
print(model)


In [None]:
# ===========================================
# 6. Fungsi Training dan Evaluasi
# ===========================================

# Mask value excluded from the loss. 0 is the value the dataset treats as "no label"
# (tiles are kept only if some label is > 0) — presumably background/unannotated.
IGNORE_LABEL = 0
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_LABEL)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

def pixel_accuracy(pred, target, ignore_index=0):
    """Fraction of scored pixels whose predicted class equals the target class.

    Pixels with a negative label, or with label ``ignore_index``, are not scored. The default
    of 0 matches the ``nn.CrossEntropyLoss(ignore_index=0)`` this notebook trains with, so the
    metric measures exactly the pixels the model is trained on. Pass ``ignore_index=None`` to
    score every non-negative label.

    Args:
        pred: tensor of predicted class ids.
        target: tensor of true class ids, same shape as ``pred``.
        ignore_index: label to exclude, or None.

    Returns:
        float accuracy in [0, 1]; 0.0 when no pixel is scored.
    """
    valid = target >= 0
    if ignore_index is not None:
        valid = valid & (target != ignore_index)
    correct = (pred[valid] == target[valid]).sum()
    total = valid.sum()
    return (correct.float() / (total.float() + 1e-9)).item()

def iou_per_class(pred, target, num_classes, ignore_index=0):
    """Intersection-over-union for every class id in ``range(num_classes)``.

    Pixels with a negative label or label ``ignore_index`` are excluded before counting. The
    default of 0 matches the loss's ``ignore_index=0``; otherwise predictions on unlabelled
    pixels would count as false positives for the predicted class. Pass ``ignore_index=None``
    to use every non-negative label.

    Args:
        pred: tensor of predicted class ids.
        target: tensor of true class ids, same shape as ``pred``.
        num_classes: number of classes to report.
        ignore_index: label to exclude, or None.

    Returns:
        list of ``num_classes`` floats. An entry is NaN when the class never appears in either
        prediction or target among scored pixels, and for the ignored class itself.
    """
    valid = target >= 0
    if ignore_index is not None:
        valid = valid & (target != ignore_index)
    pred_v, target_v = pred[valid], target[valid]

    ious = []
    for cls in range(num_classes):
        if cls == ignore_index:
            ious.append(float('nan'))
            continue
        pred_i = pred_v == cls
        target_i = target_v == cls
        inter = (pred_i & target_i).sum()
        union = (pred_i | target_i).sum()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append((inter.float() / union.float()).item())
    return ious



In [None]:
# ===========================================
# 7. Loop Training Utama
# ===========================================

num_epochs = 20
# Start below any real score so the first epoch always writes the checkpoint the
# visualization cell reloads (with 0.0, an all-zero accuracy would never save one).
best_val_acc = float("-inf")
checkpoint_path = "hybridsn_seg_best.pth"
# Score validation on exactly the pixels the loss trains on
ignore_label = criterion.ignore_index

for epoch in range(1, num_epochs + 1):
    # --- training pass ---
    model.train()
    running_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_loss = running_loss / max(1, len(train_loader))

    # --- validation pass ---
    # Accumulate one confusion matrix (rows = true class, cols = predicted class) over the
    # whole split. Averaging per-batch scores instead would weight small batches equally and
    # would not give a true mean IoU.
    model.eval()
    conf_mat = torch.zeros(num_classes * num_classes, dtype=torch.long, device=device)
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = model(xb).argmax(dim=1)
            valid = (yb >= 0) & (yb != ignore_label)
            flat_idx = yb[valid] * num_classes + preds[valid]
            conf_mat += torch.bincount(flat_idx, minlength=num_classes * num_classes)
    conf_mat = conf_mat.view(num_classes, num_classes).double().cpu()
    inter = conf_mat.diag()
    union = conf_mat.sum(dim=0) + conf_mat.sum(dim=1) - inter
    scored = union > 0  # classes seen in prediction or target
    if 0 <= ignore_label < num_classes:
        scored[ignore_label] = False
    n_pixels = conf_mat.sum().item()
    mean_val_acc = inter.sum().item() / n_pixels if n_pixels > 0 else 0.0
    mean_iou = (inter[scored] / union[scored]).mean().item() if scored.any() else float("nan")
    print(f"Epoch {epoch}/{num_epochs} | Loss={avg_loss:.4f} | ValAcc={mean_val_acc:.4f} | mIoU={mean_iou:.4f}")

    if mean_val_acc > best_val_acc:
        best_val_acc = mean_val_acc
        torch.save(model.state_dict(), checkpoint_path)
        print("OK, Model terbaik disimpan.")



In [None]:
# ===========================================
# 8. Visualisasi Hasil Prediksi
# ===========================================

# Reload the best checkpoint; map_location puts the tensors on the current device even if the
# weights were saved from a different one (e.g. saved on GPU, reloaded on a CPU-only kernel).
model.load_state_dict(torch.load("hybridsn_seg_best.pth", map_location=device))
model.eval()
xb, yb = next(iter(val_loader))  # one (randomly sampled) validation batch
xb, yb = xb.to(device), yb.to(device)
with torch.no_grad():
    preds = model(xb).argmax(dim=1).cpu().numpy()
yb_np = yb.cpu().numpy()
visualize_tile(xb[0], y_true=yb_np[0], y_pred=preds[0], class_names=label_map)


In [None]:
# ===========================================
# 9. Inferensi Citra Penuh
# ===========================================

def infer_full_image(model, cube, tile_size=128, stride=96):
    model.eval()
    H, W, B = cube.shape
    out_logits = np.zeros((num_classes, H, W), dtype=np.float32)
    count = np.zeros((H, W), dtype=np.float32)
    for i in range(0, max(1, H - tile_size + 1), stride):
        for j in range(0, max(1, W - tile_size + 1), stride):
            tile = cube[i:i+tile_size, j:j+tile_size, :]
            if tile.shape[0] < tile_size or tile.shape[1] < tile_size:
                continue
            x = torch.from_numpy(np.transpose(tile, (2,0,1))).unsqueeze(0).to(device).float()
            with torch.no_grad():
                probs = F.softmax(model(x), dim=1).cpu().numpy()[0]
            out_logits[:, i:i+tile_size, j:j+tile_size] += probs
            count[i:i+tile_size, j:j+tile_size] += 1
    count[count==0] = 1.0
    pred_map = (out_logits / count[np.newaxis,...]).argmax(axis=0)
    return pred_map

p = all_files[0]
with np.load(p) as sample_npz:  # plain .npz arrays; no pickle support needed
    cube = sample_npz["x"]
# The model was trained on per-band min-max normalised cubes (SeaweedDataset applies
# normalize_reflectance), so inference must see the same input scaling.
cube = normalize_reflectance(cube)
predmap = infer_full_image(model, cube)
plt.figure(figsize=(8,6))
plt.imshow(predmap, cmap='tab20', interpolation='nearest')
plt.title(f"Hasil Inferensi: {os.path.basename(p)}")
plt.show()

