## __Fine-Tuning Versi 1__
---

In [1]:
# CELL 1: IMPORT STANDARD & ATUR DEVICE

import os
import time
import numpy as np
import scipy.io as sio
from tqdm import tqdm
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Sklearn untuk metrik
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, classification_report

# Atur device (periksa GPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Menggunakan device:", device)


Menggunakan device: cuda


In [2]:
# CELL 2: IMPLEMENTASI zeroPadding_3D

def zeroPadding_3D(old_matrix, pad_length, pad_depth=0):
    """Zero-pad a 3-D array on both sides of every axis.

    Args:
        old_matrix: 3-D numpy array; callers in this file pass an (H, W, B) cube.
        pad_length: number of zeros added on each side of axis 0 and axis 1.
        pad_depth: number of zeros added on each side of axis 2 (default 0).

    Returns:
        A new array of shape
        (H + 2 * pad_length, W + 2 * pad_length, B + 2 * pad_depth).
    """
    spatial_pad = (pad_length, pad_length)
    depth_pad = (pad_depth, pad_depth)
    pad_widths = (spatial_pad, spatial_pad, depth_pad)
    return np.pad(old_matrix, pad_widths, mode='constant', constant_values=0)


In [3]:
# CELL 3: DEFINISI MODEL (encoder, transformer, projection head)

class SpectralSpatialEncoder3D(nn.Module):
    def __init__(self, embedding_dim=256, init_channels=32):
        super().__init__()
        # Konvolusi pertama (non-overlapping subpatch)
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=init_channels,
                               kernel_size=(20,3,3), stride=(20,3,3), padding=0)
        self.bn1 = nn.BatchNorm3d(init_channels)
        self.relu1 = nn.ReLU(inplace=True)
        # Konvolusi kedua: linear projection ke embedding_dim
        self.conv2 = nn.Conv3d(in_channels=init_channels, out_channels=embedding_dim,
                               kernel_size=(1,1,1), stride=(1,1,1), padding=0)
        self.bn2 = nn.BatchNorm3d(embedding_dim)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: (B,1,224,9,9)
        x = self.relu1(self.bn1(self.conv1(x)))   # -> (B, init_ch, 11, 3, 3)
        x = self.relu2(self.bn2(self.conv2(x)))   # -> (B, 256, 11, 3, 3)
        B, C, D, H, W = x.shape
        # Permute dan flatten token axis -> (B, 99, 256)
        x = x.permute(0,2,3,4,1).contiguous().view(x.size(0), -1, x.size(1))
        return x

class SimpleTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` standard Transformer encoder layers (batch-first)."""

    def __init__(self, embed_dim=256, num_heads=8, num_layers=5, mlp_dim=512, dropout=0.1):
        super().__init__()
        layer_kwargs = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=num_layers
        )

    def forward(self, x):
        """Encode a (B, tokens, embed_dim) sequence; the output has the same shape."""
        encoded = self.transformer(x)
        return encoded

class ProjectionHead_A(nn.Module):
    """Projection head, variant A: average the tokens first, then project."""

    def __init__(self, in_dim=256, proj_dim=128):
        super().__init__()
        self.net = nn.Linear(in_features=in_dim, out_features=proj_dim)

    def forward(self, x):
        """Map (B, tokens, in_dim) to (B, proj_dim)."""
        pooled = torch.mean(x, dim=1)  # global average over the token axis -> (B, in_dim)
        return self.net(pooled)
    
class ProjectionHead_B(nn.Module):
    """Projection head, variant B: project every token, then average the tokens."""

    def __init__(self, in_dim=256, proj_dim=128):
        super().__init__()
        self.net = nn.Linear(in_features=in_dim, out_features=proj_dim)

    def forward(self, x):
        """Map (B, tokens, in_dim) to (B, proj_dim)."""
        per_token = self.net(x)            # (B, tokens, proj_dim)
        return torch.mean(per_token, dim=1)  # average over tokens -> (B, proj_dim)

class ProjectionHead_C(nn.Module):
    """Projection head, variant C: average inside each token, then project.

    Every token is reduced to a single scalar (mean over its embedding
    dimension); the resulting per-token vector is mapped to ``proj_dim``.

    Args:
        proj_dim: size of the output projection.
        num_tokens: number of tokens per sample. Defaults to 99, the token
            count produced by the encoder in this file for 224 x 9 x 9 inputs.
    """

    def __init__(self, proj_dim=128, num_tokens=99):
        super().__init__()
        self.net = nn.Linear(num_tokens, proj_dim)

    def forward(self, x):
        """Map (B, num_tokens, embed_dim) to (B, proj_dim)."""
        # mean(dim=2) already drops the embedding axis and returns (B, num_tokens).
        # No squeeze is needed, and a squeeze would wrongly collapse a
        # single-token input to shape (B,).
        x = x.mean(dim=2)
        return self.net(x)

In [4]:
# CELL 4: LOAD HASIL PRETRAINING dan FREEZE semua kecuali classifier nanti

def build_frozen_parts_from_best_pretrained_model(variant, device):
    """Build the pretrained encoder, transformer and projection head, frozen.

    The weights are read from ``best_sst_ver3{variant}.pt`` in the current
    working directory. The checkpoint is expected to be a dict that may hold
    the keys ``"encoder_state"``, ``"transformer_state"`` and
    ``"proj_head_state"``. A missing key is tolerated, but a warning is
    printed because that module then keeps its random initialisation.

    Args:
        variant: 'A', 'B' or 'C'; selects the projection-head class and the
            checkpoint file.
        device: device the modules are moved to and the checkpoint is mapped to.

    Returns:
        Tuple ``(encoder, transformer, proj_head)``. Every parameter has
        ``requires_grad=False`` and every module is in eval mode.

    Raises:
        ValueError: if ``variant`` is not 'A', 'B' or 'C'.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    encoder = SpectralSpatialEncoder3D(embedding_dim=256).to(device)
    transformer = SimpleTransformerEncoder(embed_dim=256).to(device)

    if variant == 'A':
        proj_head = ProjectionHead_A(in_dim=256, proj_dim=128)
    elif variant == 'B':
        proj_head = ProjectionHead_B(in_dim=256, proj_dim=128)
    elif variant == 'C':
        proj_head = ProjectionHead_C(proj_dim=128)
    else:
        raise ValueError("variant must be 'A'/'B'/'C'")
    proj_head = proj_head.to(device)

    # The path depends only on the variant letter, so build it once.
    best_model_path = f"best_sst_ver3{variant}.pt"

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(best_model_path):
        raise FileNotFoundError(f"Best model tidak ditemukan: {best_model_path}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    bm_point = torch.load(best_model_path, map_location=device)

    # Load each state dict if present. A missing key is reported, not skipped
    # silently, because it leaves a randomly initialised module frozen.
    for key, module in (("encoder_state", encoder), ("transformer_state", transformer)):
        if key in bm_point:
            module.load_state_dict(bm_point[key])
        else:
            print(f"[PERINGATAN] '{key}' tidak ada di {best_model_path}; "
                  "modul tetap memakai inisialisasi acak.")

    if "proj_head_state" in bm_point:
        try:
            proj_head.load_state_dict(bm_point["proj_head_state"])
        except RuntimeError as e:
            # load_state_dict reports key/shape mismatches as RuntimeError;
            # retry non-strictly to tolerate minor key differences.
            print("[PERINGATAN] proj_head.load_state_dict error -> mencoba strict=False. Error:", e)
            proj_head.load_state_dict(bm_point["proj_head_state"], strict=False)
    else:
        print(f"[PERINGATAN] 'proj_head_state' tidak ada di {best_model_path}; "
              "modul tetap memakai inisialisasi acak.")

    # Freeze the parameters and switch to eval mode, so BatchNorm uses its
    # stored statistics and Dropout is disabled for feature extraction.
    for module in (encoder, transformer, proj_head):
        for p in module.parameters():
            p.requires_grad = False
        module.eval()

    return encoder, transformer, proj_head


In [5]:
# CELL 5: FTClassifier (hanya classifier yang trainable)

class FTClassifier(nn.Module):
    """Frozen feature extractor followed by a trainable linear classifier.

    Args:
        encoder, transformer, proj_head: pretrained modules, expected to be
            frozen already by the caller.
        num_classes: number of output classes.
        feat_dim: size of the feature vector returned by ``proj_head``
            (128 for the projection heads in this file).
    """

    def __init__(self, encoder, transformer, proj_head, num_classes=2, feat_dim=128):
        super().__init__()
        # Frozen components.
        self.encoder = encoder
        self.transformer = transformer
        self.proj_head = proj_head
        # The only trainable part: a single linear layer.
        self.classifier = nn.Linear(feat_dim, num_classes)

    def train(self, mode=True):
        """Set the training mode of the classifier only.

        The frozen parts always stay in eval mode. Otherwise ``model.train()``
        would let BatchNorm layers update their running statistics and would
        re-enable Dropout in the feature extractor, even though the forward
        pass runs under ``torch.no_grad()``.
        """
        super().train(mode)
        for frozen in (self.encoder, self.transformer, self.proj_head):
            frozen.eval()
        return self

    def forward(self, x):
        """Return class logits of shape (B, num_classes)."""
        # Feature extraction runs without autograd to save memory.
        with torch.no_grad():
            feat = self.encoder(x)
            feat = self.transformer(feat)
            proj = self.proj_head(feat)   # (B, feat_dim)
        return self.classifier(proj)



In [6]:
# CELL 6: HELPERS - memuat dataset patch

def load_patch_dataset(data_dir="../data/processed", batch_size=32, val_split=0.2, seed=42):
    """Load both patch files and return shuffled train/validation loaders.

    Reads ``patch_class0.npy`` (label 0) and ``patch_class1.npy`` (label 1)
    from ``data_dir``, concatenates them, moves the last array axis to
    position 1 behind a new singleton channel axis, and splits the samples
    with a seeded random permutation.

    Args:
        data_dir: directory containing the two ``.npy`` files.
        batch_size: batch size of both loaders.
        val_split: fraction of samples assigned to validation (rounded down).
        seed: seed of the permutation that defines the split.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    class_paths = [os.path.join(data_dir, f"patch_class{label}.npy") for label in (0, 1)]
    assert all(os.path.exists(p) for p in class_paths), "File patch_class*.npy tidak ditemukan"

    # Presumably each file holds (N, 9, 9, 224) patches — confirm against preprocessing.
    patches0, patches1 = (np.load(p) for p in class_paths)
    X_all = np.concatenate((patches0, patches1), axis=0)
    y_all = np.repeat([0, 1], [len(patches0), len(patches1)])

    # (N, H, W, B) -> (N, B, H, W) -> (N, 1, B, H, W), the layout Conv3d expects.
    X_tensor = torch.tensor(X_all, dtype=torch.float32).permute(0, 3, 1, 2).unsqueeze(1)
    y_tensor = torch.tensor(y_all, dtype=torch.long)

    # Seeded permutation so the split is reproducible across runs.
    n_total = X_tensor.shape[0]
    shuffled = torch.randperm(n_total, generator=torch.Generator().manual_seed(seed))
    n_val = int(val_split * n_total)
    val_idx, train_idx = shuffled[:n_val], shuffled[n_val:]

    def make_loader(idx, shuffle):
        subset = TensorDataset(X_tensor[idx], y_tensor[idx])
        return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(train_idx, True)
    val_loader = make_loader(val_idx, False)
    return train_loader, val_loader



In [10]:
# CELL 7: TRAINING LOOP untuk fine-tuning classifier (hanya classifier param yang dioptimasi)

def _evaluate_classifier(model, loader, criterion, device):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader``, without gradients."""
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for xv, yv in loader:
            xv = xv.to(device)
            yv = yv.to(device)
            logits = model(xv)
            batch = xv.size(0)
            loss_sum += criterion(logits, yv).item() * batch
            n_seen += batch
            n_correct += (logits.argmax(dim=1) == yv).sum().item()
    if n_seen == 0:
        raise ValueError("val_loader produced no samples; cannot compute validation metrics")
    return loss_sum / n_seen, n_correct / n_seen


def train_finetune(model, variant, train_loader, val_loader, device,
                   num_epochs=200, lr=1e-3, weight_decay=1e-4, patience=50):
    """Train only ``model.classifier``, with checkpointing and early stopping.

    Two files are written to the current working directory:
    ``checkpoint_sst_finetuned_ver3{variant}.pt`` after every epoch (used to
    resume) and ``best_finetuned_ver3{variant}.pt`` whenever the validation
    loss improves. Both hold the keys ``epoch``, ``classifier_state``,
    ``optimizer_state``, ``best_val_loss`` and ``no_improve``.

    Args:
        model: module exposing a ``classifier`` sub-module. Only that
            sub-module's parameters are given to the optimizer.
        variant: string used in the checkpoint file names.
        train_loader, val_loader: iterables of ``(inputs, labels)`` batches.
        device: device the batches are moved to.
        num_epochs: last epoch number to run (epochs are 1-based).
        lr, weight_decay: AdamW hyper-parameters.
        patience: number of consecutive epochs without a validation-loss
            improvement after which training stops.

    Raises:
        ValueError: if either loader yields no samples.
    """
    optimizer = optim.AdamW(model.classifier.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    no_improve = 0
    start_epoch = 1

    checkpoint_finetuned_path = f"checkpoint_sst_finetuned_ver3{variant}.pt"
    best_finetuned_path = f"best_finetuned_ver3{variant}.pt"

    # Resume from the per-epoch checkpoint when one exists.
    if os.path.exists(checkpoint_finetuned_path):
        checkpoint = torch.load(checkpoint_finetuned_path, map_location=device)
        model.classifier.load_state_dict(checkpoint["classifier_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"] + 1
        best_val_loss = checkpoint["best_val_loss"]
        # Older checkpoints do not carry the early-stopping counter.
        no_improve = checkpoint.get("no_improve", 0)
        print(f"[OK] Checkpoint ditemukan. Melanjutkan dari epoch {start_epoch}.")
    else:
        print("[MAAF] Tidak ditemukan checkpoint. Memulai training dari awal.")

    for epoch in range(start_epoch, num_epochs+1):
        start_time = time.time()

        # Only the classifier is optimised, so only it goes into training mode.
        # The rest stays in eval mode, so that BatchNorm running statistics and
        # Dropout in the frozen feature extractor are not active during fine-tuning.
        model.eval()
        model.classifier.train()

        total_loss = 0.0
        n = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} - Train")
        for x_batch, y_batch in pbar:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)
            loss = criterion(logits, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * x_batch.size(0)
            n += x_batch.size(0)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        if n == 0:
            raise ValueError("train_loader produced no samples; nothing to train on")
        avg_train_loss = total_loss / n

        avg_val_loss, val_acc = _evaluate_classifier(model, val_loader, criterion, device)

        epoch_time = time.time() - start_time

        # Adjacent f-strings are concatenated verbatim, so the separators are
        # written explicitly.
        print(f"Epoch [{epoch}/{num_epochs}] "
              f"TrainLoss: {avg_train_loss:.4f} | "
              f"ValLoss: {avg_val_loss:.4f} | "
              f"ValAcc: {val_acc:.4f} | "
              f"Time: {epoch_time:.2f}s")

        # Update the early-stopping state BEFORE building the checkpoint, so the
        # saved best_val_loss / no_improve describe this epoch and not the
        # previous one.
        improved = avg_val_loss < best_val_loss
        if improved:
            best_val_loss = avg_val_loss
            no_improve = 0
        else:
            no_improve += 1

        checkpoint = {
            "epoch": epoch,
            "classifier_state": model.classifier.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
            "no_improve": no_improve,
        }
        torch.save(checkpoint, checkpoint_finetuned_path)

        if improved:
            torch.save(checkpoint, best_finetuned_path)
            print(">> Model fine-tuned terbaik disimpan:", best_finetuned_path)
        elif no_improve >= patience:
            print("Early stopping triggered.")
            break

    print("Selesai training fine-tuning.")


In [11]:
# CELL 8: RUN THE DATA LOADER

# (uses load_patch_dataset from cell 6)
train_loader, val_loader = load_patch_dataset(data_dir="../data/processed", batch_size=32, val_split=0.2)
# Read the sample counts from the datasets directly. Summing batch sizes would
# iterate (and, for the training loader, shuffle) the whole dataset just to count it.
print("Jumlah sampel train:", len(train_loader.dataset), "| jumlah batch train:", len(train_loader))
print("Jumlah sampel val:", len(val_loader.dataset), "| jumlah batch val:", len(val_loader))

Jumlah sampel train: 8 | jumlah batch train: 1
Jumlah sampel val: 2 | jumlah batch val: 1


In [12]:
# CELL 9: BUILD THE MODEL AND RUN TRAINING

# (from cell 4: build the frozen parts)
variant = 'A'  # may be set to 'A', 'B' or 'C'
encoder_frozen, transformer_frozen, proj_head_frozen = \
    build_frozen_parts_from_best_pretrained_model(variant, device)
print("Komponen pra-trained telah dimuat dan dibekukan.")

# (from cell 5: classifier)
# Wrap the frozen parts with the trainable linear classifier.
model = FTClassifier(encoder_frozen, transformer_frozen, proj_head_frozen, num_classes=2)
model = model.to(device)
# Sanity check: only the classifier parameters should still require gradients.
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
print("Jumlah parameter yang dilatih (harus hanya classifier):",
      sum(p.numel() for p in trainable_params))

# (from cell 7: training loop)
train_finetune(
    model, variant, train_loader, val_loader, device,
    num_epochs=200, lr=1e-3, weight_decay=1e-4, patience=50,
)

  bm_point = torch.load(best_model_path, map_location=device) #bm_point untuk menampung best model yang di-load


Komponen pra-trained telah dimuat dan dibekukan.
Jumlah parameter yang dilatih (harus hanya classifier): 258
[MAAF] Tidak ditemukan checkpoint. Memulai training dari awal.


Epoch 1/200 - Train:   0%|                   | 0/1 [00:00<?, ?it/s]


RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 256 n 792 k 256 mat1_ld 256 mat2_ld 256 result_ld 256 abcType 0 computeType 68 scaleType 0

In [None]:
# CELL 10: LOAD THE BEST CLASSIFIER WEIGHTS (if present)

saved_best_finetuned_path = f"best_finetuned_ver3{variant}.pt"

if not os.path.exists(saved_best_finetuned_path):
    print("Tidak ditemukan fine-tuned checkpoint. Pastikan training selesai dan file tersimpan.")
else:
    # Only the classifier weights are restored; the frozen parts are untouched.
    svb = torch.load(saved_best_finetuned_path, map_location=device)
    model.classifier.load_state_dict(svb["classifier_state"])
    print("Loaded best fine-tuned classifier from", saved_best_finetuned_path)


============================== 
just separator
==============================

In [None]:
# CELL 11: INFERENCE PETA KLASIFIKASI (SLIDING WINDOW PATCH-CENTER) - versi batch untuk efisiensi

def inference_map_patch_center(full_image, model, device, patch_size=9, batch_size=256, pad_mode='zero'):
    """Classify every pixel of an image from the patch centred on it.

    The image is zero-padded by ``patch_size // 2`` on each spatial side, so
    every pixel has a full neighbourhood. Patches are cut lazily, one batch at
    a time, from a sliding-window view of the padded image. Materialising all
    H*W patches at once needs memory proportional to
    H * W * patch_size**2 * bands.

    Args:
        full_image: numpy array of shape (H, W, B) with B == 224.
        model: classifier taking (N, 1, B, p, p) tensors and returning
            (N, num_classes) logits; it is switched to eval mode.
        device: device the batches are moved to.
        patch_size: spatial patch size; the window actually cut is
            ``2 * (patch_size // 2) + 1`` pixels wide.
        batch_size: number of patches per forward pass.
        pad_mode: informational only; zero padding (``zeroPadding_3D``) is
            always used.

    Returns:
        Integer numpy array of shape (H, W) with the arg-max class per pixel.

    Raises:
        ValueError: if the image does not have 224 bands.
    """
    model.eval()
    H, W, n_bands = full_image.shape
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if n_bands != 224:
        raise ValueError("Diharapkan 224 band; sesuaikan bila berbeda.")

    n_pixels = H * W
    if n_pixels == 0:
        # Nothing to classify; return an empty map of the right shape.
        return np.zeros((H, W), dtype=np.int64)

    half = patch_size // 2
    window = 2 * half + 1
    # Same zero padding as the preprocessing step.
    padded = zeroPadding_3D(full_image, half)  # (H + 2*half, W + 2*half, B)

    # Zero-copy view of shape (H, W, B, window, window):
    # windows[r, c, b, i, j] == padded[r + i, c + j, b], i.e. the patch centred
    # on pixel (r, c), already in the band-first layout the model expects.
    windows = np.lib.stride_tricks.sliding_window_view(padded, (window, window), axis=(0, 1))

    preds = []
    with torch.no_grad():
        for start in tqdm(range(0, n_pixels, batch_size), desc="Inferensi peta (batch)"):
            flat_idx = np.arange(start, min(start + batch_size, n_pixels))
            rows, cols = np.divmod(flat_idx, W)  # row-major pixel order
            # Fancy indexing copies only this batch: (b, B, window, window).
            batch = np.ascontiguousarray(windows[rows, cols], dtype=np.float32)
            xb = torch.from_numpy(batch).unsqueeze(1).to(device)  # (b, 1, B, window, window)
            logits = model(xb)
            preds.append(logits.argmax(dim=1).cpu().numpy())

    return np.concatenate(preds, axis=0).reshape(H, W)


In [None]:
# CELL 12: LOAD GM01.mat -> RUN INFERENCE -> SAVE THE MAP

# mat_path = "D:/CurrentlyActiveResearch/oilspill_project/data/raw/GM01.mat" # absolute path
mat_path = "../data/raw/GM01.mat"  # relative path
assert os.path.exists(mat_path), f"File GM01.mat tidak ditemukan di {mat_path}"

mat = sio.loadmat(mat_path)
# "img" is the (H, W, B) image cube, "map" the (H, W) ground-truth map.
img, gt_map = mat["img"], mat["map"]

print("GM01 shapes -> img:", img.shape, "| gt:", gt_map.shape)

# Run inference (warning: this may take considerable memory and time).
pred_map_gm01 = inference_map_patch_center(
    img, model, device,
    patch_size=9, batch_size=512,
)

# Save the predicted map.
save_dir = "../data/result/"
os.makedirs(save_dir, exist_ok=True)
np.save(os.path.join(save_dir, "pred_map_GM01.npy"), pred_map_gm01)

print("Peta prediksi GM01 disimpan ke pred_map_GM01.npy")


In [None]:
# CELL 13: EVALUATE THE PREDICTED MAP

saved_pred_map_path = os.path.join(save_dir, "pred_map_GM01.npy")
assert os.path.exists(saved_pred_map_path), "File pred_map_GM01.npy tidak ditemukan"

pred_map = np.load(saved_pred_map_path)
gt = gt_map.astype(int)

# The metrics expect 1-D label vectors.
y_true, y_pred = gt.flatten(), pred_map.flatten()

oa = accuracy_score(y_true, y_pred)
f1_per_class = f1_score(y_true, y_pred, average=None)  # one score per class
cm = confusion_matrix(y_true, y_pred)

for label, value in (("Overall Accuracy (OA):", oa), ("F1 per kelas:", f1_per_class)):
    print(label, value)
print("Confusion matrix:\n", cm)
print("\nReport klasifikasi (per-class precision/recall/f1):")
print(classification_report(y_true, y_pred, digits=4))


In [None]:
# CELL 14: VISUALISE THE PREDICTED MAP AND THE GROUND TRUTH

# False-colour composite from three example bands, min-max scaled for display.
b1, b2, b3 = 30, 20, 10
rgb = img[:, :, [b1, b2, b3]]
rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min())

# (title, image, colormap) for each of the three side-by-side panels.
panels = [
    ("Citra (band visualisasi contoh)", rgb_norm, None),
    ("Ground Truth GM01", gt, 'gray'),
    ("Prediksi GM01", pred_map_gm01, 'gray'),
]

plt.figure(figsize=(14,6))
for position, (panel_title, panel_image, panel_cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(panel_image, cmap=panel_cmap)
    plt.axis('off')

plt.tight_layout()
plt.show()
