In [None]:
import os
import glob
import json
import random
import numpy as np
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import segmentation_models_pytorch as smp

In [None]:
# Dataset locations: the original dataset and the folder holding the
# augmented training copy.
DATASET_ROOT, AUG_DATASET_ROOT = (
    Path("dataset/final_dataset"),
    Path("dataset/temp_aug_training"),
)

# Experiment run folders that are compared in this notebook.
# NOTE(review): the trailing "xxxxx" looks like a placeholder for the real
# run suffix - confirm and replace before running.
PATH_BASELINE_EXP, PATH_AUG_EXP = (
    Path("experiments/baseline_run/mit_FPN_20251227_xxxxx"),
    Path("experiments/augmentation_run/mit_FPN_AUG_PHYSICAL_20251227_xxxxx"),
)

In [None]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seaborn theme for the figures drawn after this point.
sns.set_style("whitegrid")

print(f"Device: {DEVICE}")

In [None]:
def count_files(root_dir, split):
    """Count the PNG files stored in ``<root_dir>/<split>/images``.

    Args:
        root_dir: Dataset root as a ``pathlib.Path`` (it is combined with ``/``).
        split: Name of the split sub-folder, e.g. ``"train"``.

    Returns:
        Number of paths matched by the ``*.png`` glob pattern (0 when the
        folder does not exist or holds no PNG files).
    """
    image_dir = root_dir / split / "images"
    png_pattern = str(image_dir / "*.png")
    matches = glob.glob(png_pattern)
    return len(matches)

# Number of PNG images in the "train" split of the original dataset.
n_train_base = count_files(root_dir=DATASET_ROOT, split="train")

In [None]:
if AUG_DATASET_ROOT.exists():
    # Every PNG in the augmented training folder (originals + augmented copies).
    all_aug_files = glob.glob(str(AUG_DATASET_ROOT / "train" / "images" / "*.png"))
    n_train_aug = len(all_aug_files)

    # Originals are the files whose *file name* lacks the "_aug_" marker.
    # The check must not run on the full path: AUG_DATASET_ROOT
    # ("dataset/temp_aug_training") itself contains "_aug_", which made every
    # file look augmented and the original count come out as 0.
    n_train_base_actual = sum("_aug_" not in Path(f).name for f in all_aug_files)

    data_counts = pd.DataFrame({
        "Dataset": ["Baseline (Original)", "Augmented (Physical)"],
        "Total Images": [n_train_base_actual, n_train_aug]
    })

    plt.figure(figsize=(6, 4))
    sns.barplot(data=data_counts, x="Dataset", y="Total Images", palette="viridis")
    plt.title("Perbandingan Jumlah Data Training")
    plt.ylabel("Jumlah Gambar")
    # Write the exact count slightly above each bar.
    for i, v in enumerate(data_counts["Total Images"]):
        plt.text(i, v + 5, str(v), ha='center', fontweight='bold')
    plt.show()
else:
    print("Folder Augmented Temp tidak ditemukan. Pastikan path benar.")

In [None]:
def visualize_overlay(img, mask, title="", color=(255, 0, 0), alpha=0.5, threshold=127):
    """Blend a colour onto the masked pixels of a grayscale image.

    Args:
        img: 2-D grayscale image (H x W); an H x W x 1 array is also accepted.
        mask: 2-D array with the same H x W shape; pixels strictly greater
            than ``threshold`` are treated as foreground.
        title: Unused; kept so existing calls keep working.
        color: RGB colour blended onto foreground pixels (default red).
        alpha: Weight of ``color`` in the blend; the image keeps ``1 - alpha``.
        threshold: Foreground cut-off applied to ``mask``.

    Returns:
        A new H x W x 3 array with the dtype of ``img``; ``img`` itself is not
        modified.

    Raises:
        ValueError: If ``img`` or ``mask`` is None (``cv2.imread`` returns None
            for a file it cannot read), if ``img`` is not 2-D, or if the
            shapes differ.
    """
    if img is None or mask is None:
        raise ValueError(
            "visualize_overlay received None for img or mask; "
            "check that the image and mask files exist and are readable"
        )

    img = np.asarray(img)
    mask = np.asarray(mask)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim != 2:
        raise ValueError(f"img must be a 2-D grayscale array, got shape {img.shape}")
    if mask.shape != img.shape:
        raise ValueError(f"mask shape {mask.shape} does not match img shape {img.shape}")

    # Gray -> RGB by copying the single channel three times (np.repeat returns
    # a new array, so the caller's image stays untouched).
    overlay = np.repeat(img[:, :, np.newaxis], 3, axis=2)

    # Semi-transparent colour on the foreground pixels. The blend is computed
    # in float and cast back explicitly to the image dtype.
    selected = mask > threshold
    blended = overlay[selected] * (1.0 - alpha) + np.asarray(color, dtype=np.float64) * alpha
    overlay[selected] = blended.astype(overlay.dtype)
    return overlay

if AUG_DATASET_ROOT.exists():
    train_dir = AUG_DATASET_ROOT / "train" / "images"
    # Look for samples that have an augmented version. The listing is sorted
    # so that the seeded pick below selects the same file on every run.
    all_files = sorted(os.listdir(train_dir))
    aug_files = [f for f in all_files if "_aug_" in f]

    if aug_files:
        # Pick one sample with a private, seeded RNG: reproducible and it does
        # not disturb the global `random` state. Change the seed to inspect a
        # different sample.
        sample_aug_name = random.Random(42).choice(aug_files)
        # Rebuild the original file name: everything before "_aug_" plus ".png"
        base_name = sample_aug_name.split("_aug_")[0] + ".png"

        # Image / mask paths of the original and the augmented sample
        path_orig_img = train_dir / base_name
        path_orig_msk = AUG_DATASET_ROOT / "train" / "masks" / base_name
        path_aug_img  = train_dir / sample_aug_name
        path_aug_msk  = AUG_DATASET_ROOT / "train" / "masks" / sample_aug_name

        # Read everything as single-channel images
        img_o = cv2.imread(str(path_orig_img), cv2.IMREAD_GRAYSCALE)
        msk_o = cv2.imread(str(path_orig_msk), cv2.IMREAD_GRAYSCALE)
        img_a = cv2.imread(str(path_aug_img), cv2.IMREAD_GRAYSCALE)
        msk_a = cv2.imread(str(path_aug_msk), cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None instead of raising when a file is missing or
        # unreadable; report that here rather than failing inside the plot.
        unreadable = [
            str(p)
            for p, data in (
                (path_orig_img, img_o),
                (path_orig_msk, msk_o),
                (path_aug_img, img_a),
                (path_aug_msk, msk_a),
            )
            if data is None
        ]

        if unreadable:
            print(f"Could not read: {', '.join(unreadable)}")
        else:
            plt.figure(figsize=(15, 8))

            # Row 1: original
            plt.subplot(2, 3, 1); plt.imshow(img_o, cmap='gray'); plt.title("Original Image (Raw)")
            plt.subplot(2, 3, 2); plt.imshow(msk_o, cmap='gray'); plt.title("Original Mask")
            plt.subplot(2, 3, 3); plt.imshow(visualize_overlay(img_o, msk_o)); plt.title("Original Overlay")

            # Row 2: augmented
            plt.subplot(2, 3, 4); plt.imshow(img_a, cmap='gray'); plt.title("Augmented Image (Rotated/Shifted)")
            plt.subplot(2, 3, 5); plt.imshow(msk_a, cmap='gray'); plt.title("Augmented Mask")
            plt.subplot(2, 3, 6); plt.imshow(visualize_overlay(img_a, msk_a)); plt.title("Augmented Overlay")

            plt.tight_layout()
            plt.show()
    else:
        print("Tidak ada file augmented ditemukan.")

In [None]:
def read_history(path):
    """Load the ``history.csv`` stored directly inside an experiment folder.

    Args:
        path: Experiment directory as a ``pathlib.Path``.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when ``history.csv`` does
        not exist in ``path``.
    """
    history_file = path / "history.csv"
    return pd.read_csv(history_file) if history_file.exists() else None

df_base = read_history(PATH_BASELINE_EXP)
df_aug  = read_history(PATH_AUG_EXP)

if df_base is not None and df_aug is not None:
    # One 1-based epoch axis per run. The runs may have different lengths, and
    # plotting against the default 0-based row index mislabels the "Epoch" axis.
    epochs_base = range(1, len(df_base) + 1)
    epochs_aug  = range(1, len(df_aug) + 1)

    # Train-minus-validation Dice: the smaller the gap, the less overfitting.
    gap_base = df_base['train_dice'] - df_base['val_dice']
    gap_aug  = df_aug['train_dice'] - df_aug['val_dice']

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # 1. Validation Dice
    axes[0].plot(epochs_base, df_base['val_dice'].to_numpy(), label='Baseline', linestyle='--', color='blue')
    axes[0].plot(epochs_aug, df_aug['val_dice'].to_numpy(), label='Augmented', color='red', linewidth=2)
    axes[0].set_title("Validation Dice Score (Higher is Better)")

    # 2. Validation Loss
    axes[1].plot(epochs_base, df_base['val_loss'].to_numpy(), label='Baseline', linestyle='--', color='blue')
    axes[1].plot(epochs_aug, df_aug['val_loss'].to_numpy(), label='Augmented', color='red', linewidth=2)
    axes[1].set_title("Validation Loss (Lower is Better)")

    # 3. Gap analysis (Train Dice - Val Dice), with a zero reference line
    axes[2].plot(epochs_base, gap_base.to_numpy(), label='Baseline Gap', linestyle='--', color='blue')
    axes[2].plot(epochs_aug, gap_aug.to_numpy(), label='Augmented Gap', color='red')
    axes[2].axhline(0, color='black', linewidth=0.5)
    axes[2].set_title("Overfitting Gap (Train - Val Dice)")

    for ax in axes:
        ax.set_xlabel("Epoch")
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()
else:
    # Say which run is missing its log instead of silently drawing nothing.
    missing_runs = [
        str(run_dir)
        for run_dir, history in ((PATH_BASELINE_EXP, df_base), (PATH_AUG_EXP, df_aug))
        if history is None
    ]
    print(f"history.csv not found in: {', '.join(missing_runs)}")

In [None]:
# --- Load Models ---
class GrayToRGB(torch.nn.Module):
    """Parameter-free layer that tiles its input three times along dim 1.

    The other three dimensions are left as they are, so a 4-D input of shape
    (N, 1, H, W) becomes (N, 3, H, W).
    """

    def forward(self, x):
        # Repeat counts per dimension: only dim 1 is tiled.
        repeats_per_dim = (1, 3, 1, 1)
        return x.repeat(*repeats_per_dim)

def load_inference_model(path):
    """Rebuild the segmentation network and load trained weights for inference.

    The architecture is ``GrayToRGB`` followed by an ``smp.FPN`` with a
    ``mit_b5`` encoder (3 input channels, 1 output class); it is created
    without pretrained encoder weights because the state dict from ``path``
    replaces them.

    Args:
        path: Location of the saved state dict (``str`` or ``pathlib.Path``).

    Returns:
        The model on ``DEVICE`` in eval mode, or ``None`` (after printing a
        message) when ``path`` does not exist.
    """
    path = Path(path)

    # Check the weights file first so that a missing file does not cost the
    # construction of the full network.
    if not path.exists():
        print(f"Model not found at: {path}")
        return None

    # Re-create the model structure so the saved weights can be loaded into it.
    model = smp.FPN(encoder_name="mit_b5", encoder_weights=None, in_channels=3, classes=1)
    full_model = torch.nn.Sequential(GrayToRGB(), model)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    state = torch.load(path, map_location=DEVICE)
    full_model.load_state_dict(state)
    full_model.to(DEVICE)
    full_model.eval()
    return full_model

# Best checkpoint of each experiment; either value is None when its weights
# file is missing.
model_base = load_inference_model(path=PATH_BASELINE_EXP / "weights/best_model.pt")
model_aug = load_inference_model(path=PATH_AUG_EXP / "weights/best_model.pt")

In [None]:
# Adjust this path if a separate test folder is used.
# NOTE(review): the training count earlier in this notebook reads
# DATASET_ROOT / "train" / "images", while this cell reads
# DATASET_ROOT / "images" - confirm this folder really is the held-out test set.
test_img_dir = DATASET_ROOT / "images"
test_images = sorted(glob.glob(str(test_img_dir / "*.png")))

if not test_images:
    print(f"No PNG images found in: {test_img_dir}")

# Draw up to 3 random samples with a fixed seed. random.sample raises
# ValueError when asked for more items than exist, so cap the request.
random.seed(42)
sample_indices = random.sample(range(len(test_images)), min(3, len(test_images)))

@torch.no_grad()
def predict_mask(model, img_tensor):
    """Run ``model`` on ``img_tensor`` and return a binary mask as NumPy.

    The model output is passed through a sigmoid and thresholded at 0.5
    (strictly greater), giving 1.0 / 0.0 floats. Gradient tracking is disabled
    for the whole call.

    Returns:
        The ``[0, 0]`` slice of the thresholded output as a NumPy array
        (presumably first batch item, first channel of an (N, C, H, W)
        output - confirm against the model).
    """
    probabilities = torch.sigmoid(model(img_tensor))
    binary = probabilities.gt(0.5).float()
    as_numpy = binary.cpu().numpy()
    return as_numpy[0, 0]

# Compare against None explicitly: truth-testing an nn.Sequential goes through
# its __len__, which is not what this guard is meant to ask.
if model_base is not None and model_aug is not None and sample_indices:
    n_rows = len(sample_indices)
    fig, axes = plt.subplots(n_rows, 5, figsize=(15, 4 * n_rows), squeeze=False)
    for ax in axes.ravel():
        ax.axis('off')

    for row, idx in enumerate(sample_indices):
        path = test_images[idx]
        name = Path(path).stem

        # Read the image and its ground-truth mask
        img_raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        mask_path = DATASET_ROOT / "masks" / f"{name}_mask.png"
        gt_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None for a missing/unreadable file; leave the row
        # empty instead of crashing in cv2.resize.
        if img_raw is None or gt_raw is None:
            print(f"Skipping {name}: could not read image or mask ({mask_path})")
            continue

        # Resize to 512x512 for prediction and scale pixel values to [0, 1]
        img_in = cv2.resize(img_raw, (512, 512))
        img_t = torch.from_numpy(img_in).float()/255.0
        # Add two leading singleton dimensions before moving to DEVICE
        img_t = img_t.unsqueeze(0).unsqueeze(0).to(DEVICE)

        gt = cv2.resize(gt_raw, (512, 512))

        # Predict with both models
        p_base = predict_mask(model_base, img_t)
        p_aug  = predict_mask(model_aug, img_t)

        # Difference map: green where only the augmented model predicts
        # foreground, red where only the baseline does, white where both agree
        # on foreground, black elsewhere.
        diff = np.zeros((512, 512, 3))
        diff[(p_aug==1) & (p_base==0)] = [0, 1, 0] # green
        diff[(p_aug==0) & (p_base==1)] = [1, 0, 0] # red
        diff[(p_aug==1) & (p_base==1)] = [1, 1, 1] # white (agreement)

        panels = [
            (img_in, "Input"),
            (gt, "Ground Truth"),
            (p_base, "Baseline Pred"),
            (p_aug, "Augmented Pred"),
            (diff, "Diff (G:Aug, R:Base)"),
        ]
        for ax, (panel, panel_title) in zip(axes[row], panels):
            if panel.ndim == 2:
                ax.imshow(panel, cmap='gray')
            else:
                ax.imshow(panel)
            ax.set_title(panel_title)

    fig.tight_layout()
    plt.show()

In [None]:
def evaluate_dataset(model, img_paths, mask_root):
    """Compute the per-image Dice score of ``model`` over a list of images.

    Each image is read as grayscale, resized to 512x512, scaled to [0, 1] and
    passed to ``predict_mask``. The ground truth is ``<mask_root>/<stem>_mask.png``,
    resized to 512x512 and binarised at > 127.

    Args:
        model: Network passed to ``predict_mask``.
        img_paths: Iterable of image file paths.
        mask_root: Directory holding the mask files (``str`` or ``Path``).

    Returns:
        ``(dices, worst_cases)``: the Dice scores in evaluation order, and a
        list of ``(dice, img_path)`` tuples sorted by ascending Dice. Images
        without a mask file, or whose image/mask cannot be read, are skipped.
    """
    # `tqdm` is not imported at the top of this notebook, so the bare name
    # raised NameError here. Import it locally and fall back to plain
    # iteration when the package is unavailable.
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:
        def progress(iterable, **_kwargs):
            return iterable

    mask_root = Path(mask_root)
    dices = []
    worst_cases = []  # (score, img_path) pairs, sorted on return

    for path in progress(img_paths, desc="Evaluating"):
        name = Path(path).stem
        mask_path = mask_root / f"{name}_mask.png"

        if not mask_path.exists():
            continue

        # Load image and ground truth
        img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        gt = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None for unreadable files; skip them rather than
        # crash in cv2.resize.
        if img is None or gt is None:
            print(f"Skipping unreadable sample: {path}")
            continue

        img = cv2.resize(img, (512, 512))
        gt  = cv2.resize(gt, (512, 512))
        gt_bin = (gt > 127).astype(float)

        img_t = torch.from_numpy(img).float()/255.0
        # Add two leading singleton dimensions before moving to DEVICE
        img_t = img_t.unsqueeze(0).unsqueeze(0).to(DEVICE)

        pred = predict_mask(model, img_t)

        # Dice score
        intersection = (pred * gt_bin).sum()
        total = pred.sum() + gt_bin.sum()
        if total == 0:
            # Empty prediction on an empty ground truth is a perfect match;
            # the plain formula would report 0 and rank it as a worst case.
            dice = 1.0
        else:
            dice = float((2. * intersection) / (total + 1e-7))

        dices.append(dice)
        worst_cases.append((dice, path))

    return dices, sorted(worst_cases, key=lambda x: x[0])

# Score both models on the same images against the same mask folder.
print("Calculating Full Statistics...")
scores_base, worst_base = evaluate_dataset(
    model=model_base,
    img_paths=test_images,
    mask_root=DATASET_ROOT / "masks",
)
scores_aug, worst_aug = evaluate_dataset(
    model=model_aug,
    img_paths=test_images,
    mask_root=DATASET_ROOT / "masks",
)

In [None]:
# Box plot of the per-image Dice scores of both models
plt.figure(figsize=(8, 6))
plt.boxplot([scores_base, scores_aug], patch_artist=True)
# Name the boxes through xticks (boxes sit at positions 1 and 2 by default):
# the `labels=` keyword of boxplot is deprecated since Matplotlib 3.9.
plt.xticks([1, 2], ['Baseline', 'Augmented'])
plt.title("Distribusi Dice Score pada Test Set")
plt.ylabel("Dice Score")
plt.grid(True, axis='y')
plt.show()

# --- B. Show the worst cases of the augmented model ---
print("\n--- WORST CASES (AUGMENTED MODEL) ---")
# worst_aug is sorted by ascending Dice; slicing tolerates fewer than 3
# evaluated samples, where indexing worst_aug[i] raised IndexError.
worst_shown = worst_aug[:3]
if worst_shown:
    plt.figure(figsize=(12, 4))
    for i, (score, path) in enumerate(worst_shown):
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        name = Path(path).stem

        # cv2.imread returns None for an unreadable file; skip that panel.
        if img is None:
            print(f"Could not read: {path}")
            continue

        plt.subplot(1, 3, i+1)
        plt.imshow(img, cmap='gray')
        plt.title(f"{name}\nDice: {score:.4f}")
        plt.axis('off')
    plt.show()
else:
    print("No evaluated samples to display.")

In [None]:
from scipy.stats import wilcoxon

# Wilcoxon signed-rank test on the paired per-image Dice scores
# (x = baseline model, y = augmented model); only the p-value is printed.
stat, p = wilcoxon(x=scores_base, y=scores_aug)
print(f"Wilcoxon p-value: {p}")