# **Filtered Back Projection (slice-wise 2D FBP, stacked into a 3D volume)**

In [None]:
import h5py
import numpy as np
import odl
import matplotlib.pyplot as plt

# ==========================================
# 1. KONFIGURASI FILE
# ==========================================
path_observation = "/content/dataset_sementara/observation_validation_000.hdf5"
path_ground_truth = "/content/dataset_sementara/ground_truth_validation_000.hdf5"

# ==========================================
# 2. ODL GEOMETRY SETUP
# ==========================================
# 2D parallel-beam geometry used slice by slice: a 362x362 image on
# [-0.13, 0.13]^2, 1000 projection angles over [0, pi) and a 513-bin detector
# on [-0.19, 0.19]. The ray transform therefore expects sinograms of shape
# (1000, 513), and `fbp` maps such a sinogram back to a 362x362 image.
print("‚öôÔ∏è Menyiapkan Geometri & GPU...")

reco_space = odl.uniform_discr(
    min_pt=[-0.13, -0.13], max_pt=[0.13, 0.13], shape=[362, 362], dtype='float32'
)
angle_partition = odl.uniform_partition(0, np.pi, 1000)
detector_partition = odl.uniform_partition(-0.19, 0.19, 513)
geometry = odl.tomo.Parallel2dGeometry(angle_partition, detector_partition)

try:
    operator = odl.tomo.RayTransform(reco_space, geometry, impl='astra_cuda')
    print("‚úÖ GPU T4 Aktif (astra_cuda).")
except Exception as err:
    # Fall back to the CPU backend only on real backend errors. A bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit, and the
    # reason for the fallback is reported instead of hidden.
    print("‚ö†Ô∏è GPU Gagal, pakai CPU. Akan lambat.")
    print(f"   Alasan: {err!r}")
    operator = odl.tomo.RayTransform(reco_space, geometry, impl='astra_cpu')

# Filtered back-projection operator built from the ray transform above.
fbp = odl.tomo.fbp_op(operator)

# ==========================================
# 3. PROSES REKONSTRUKSI VOLUME
# ==========================================
def reconstruct_volume_3d(path=None):
    """Reconstruct every sinogram slice of an observation HDF5 file with 2D FBP.

    Each slice of the ``'data'`` dataset is reconstructed independently with
    the module-level ``fbp`` operator. The result is min-max normalised to
    [0, 1] on its own, then stacked along the first axis.

    Args:
        path: HDF5 file containing a ``'data'`` dataset of stacked sinograms.
            Defaults to the module-level ``path_observation``.

    Returns:
        float32 array of shape ``(num_slices, *reco_space.shape)``, i.e.
        ``(Z, Y, X)``.
    """
    if path is None:
        path = path_observation
    print(f"üöÄ Memproses file: {path}")

    # Read all sinograms up front so the HDF5 handle is closed before the
    # slow reconstruction loop instead of being held open during it.
    with h5py.File(path, 'r') as f_obs:
        data_sinogram = f_obs['data'][:]
    num_slices = data_sinogram.shape[0]

    print(f"   Jumlah Slice dalam file: {num_slices}")
    print("   ‚è≥ Sedang merekonstruksi seluruh volume (mohon tunggu)...")

    # Slice shape comes from the reconstruction space rather than a
    # hard-coded 362x362, so the buffer follows any change to `reco_space`.
    volume_3d = np.zeros((num_slices, *reco_space.shape), dtype=np.float32)

    for i, sino in enumerate(data_sinogram):
        rec = fbp(sino).asarray()

        # Per-slice min-max normalisation (eps avoids 0/0 on flat slices).
        # NOTE(review): this discards the relative intensity scale between
        # slices, while the GT in the patching cell is normalised per volume;
        # confirm the mismatch is intended.
        lo, hi = np.min(rec), np.max(rec)
        volume_3d[i] = (rec - lo) / (hi - lo + 1e-8)

        if (i + 1) % 20 == 0:
            print(f"   Processed {i+1}/{num_slices}...")

    print(f"‚úÖ Selesai! Volume 3D terbentuk.")
    print(f"üì¶ Dimensi Akhir: {volume_3d.shape} -> (Z, Y, X)")
    return volume_3d


volume_hasil = reconstruct_volume_3d()

# ==========================================
# 4. VISUALISASI PERBEDAAN (ORTHOGONAL VIEW)
# ==========================================

def show_3d_slices(vol):
    """Plot the central axial, coronal and sagittal slices of a (Z, Y, X) volume.

    The coronal slice ``vol[:, y//2, :]`` has shape (Z, X) and the sagittal
    slice ``vol[:, :, x//2]`` has shape (Z, Y). Both are shown unrotated, so
    rows run along Z and columns along X (or Y). This matches the axis labels;
    rotating them by 90 degrees would swap the axes and leave the labels
    describing the transposed image.
    """
    z, y, x = vol.shape

    # (image, title, x-axis label, y-axis label) for each orthogonal view
    views = [
        (vol[z // 2, :, :], f"1. Axial View (Atas)\nSlice Z={z//2}",
         "X axis", "Y axis"),
        (vol[:, y // 2, :], f"2. Coronal View (Depan)\nSlice Y={y//2}",
         "X axis", "Z axis (Tumpukan)"),
        (vol[:, :, x // 2], f"3. Sagittal View (Samping)\nSlice X={x//2}",
         "Y axis", "Z axis (Tumpukan)"),
    ]

    plt.figure(figsize=(15, 5))
    for col, (img, title, xlabel, ylabel) in enumerate(views, start=1):
        plt.subplot(1, 3, col)
        plt.imshow(img, cmap='gray')
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)

    plt.tight_layout()
    plt.show()


print("\nüìä Menampilkan Visualisasi 3D Orthogonal...")
show_3d_slices(volume_hasil)

# **Micro-patching**

The resulting patches were used only to train the Res-Att U-Net.

In [None]:
import numpy as np
import os
import glob
import h5py
import gc

# ==========================================
# 1. PATH CONFIGURATION
# ==========================================

# A. FBP input volumes (.npy files on Google Drive)
input_drive_folder = "/content/drive/MyDrive/Finpro Pencit/HDF5/Hasil FBP/"
# B. Ground-truth volumes (.hdf5 files on the local Colab disk)
gt_local_folder = "/content/dataset_sementara_GT/"
# C. Destination for the extracted patch pairs (.npy files on Google Drive)
output_folder = "/content/drive/MyDrive/Finpro Pencit/HDF5/Training_Patches/"
os.makedirs(output_folder, exist_ok=True)

# Patching parameters, in (Z, Y, X) voxel order
PATCH_SIZE = (32, 32, 32)   # size of each extracted cube
STRIDE = (16, 16, 16)       # step between patch origins (half the patch size)
THRESHOLD_AIR = 0.1         # keep a patch only if its normalised GT mean exceeds this

print("üìÇ Input FBP (Drive): " + input_drive_folder)
print("üìÇ Input GT (Lokal): " + gt_local_folder)
print("üíæ Output Patches: " + output_folder)

# ==========================================
# 2. EKSEKUSI BATCH HYBRID
# ==========================================
def _minmax_normalize(vol):
    """Rescale `vol` to [0, 1] using its global min/max (eps-guarded against flat input)."""
    lo = np.min(vol)
    hi = np.max(vol)
    return (vol - lo) / (hi - lo + 1e-8)


def _extract_patch_pairs(vol_input, vol_target, patch_size, stride, threshold):
    """Cut aligned patches from two equally shaped volumes on a regular grid.

    A grid position is kept only when the mean of its target patch exceeds
    `threshold`, which skips air/background. Returns two float32 arrays of
    shape (N, pz, py, px); N is 0 when no position passes.
    """
    pz, py, px = patch_size
    sz, sy, sx = stride
    z_len, y_len, x_len = vol_input.shape

    patches_in, patches_gt = [], []
    for z in range(0, z_len - pz + 1, sz):
        for y in range(0, y_len - py + 1, sy):
            for x in range(0, x_len - px + 1, sx):
                window = (slice(z, z + pz), slice(y, y + py), slice(x, x + px))
                p_gt = vol_target[window]
                if np.mean(p_gt) > threshold:
                    patches_in.append(vol_input[window])
                    patches_gt.append(p_gt)

    if not patches_in:
        empty = np.empty((0, pz, py, px), dtype=np.float32)
        return empty, empty.copy()
    return (np.array(patches_in, dtype=np.float32),
            np.array(patches_gt, dtype=np.float32))


def process_patching_hybrid():
    """Turn every FBP/GT volume pair into thresholded training patch pairs.

    For each ``processed_input_<id>.npy`` in `input_drive_folder`, loads the
    matching ``ground_truth_validation_<id>.hdf5`` from `gt_local_folder`
    (first as many slices as the FBP volume has). It min-max normalises the
    GT over the whole volume, extracts PATCH_SIZE patches every STRIDE voxels
    whose GT mean exceeds THRESHOLD_AIR, and saves them as
    ``patch_in_<id>.npy`` / ``patch_gt_<id>.npy`` in `output_folder`. Files
    with a missing GT, a shape mismatch or a load error are reported and
    skipped (best-effort batch job).
    """
    search_pattern = os.path.join(input_drive_folder, "processed_input_*.npy")
    list_files = sorted(glob.glob(search_pattern))

    if not list_files:
        print("‚ùå Tidak ada file Input FBP (.npy) di Drive.")
        return

    n_files = len(list_files)
    print(f"\nüéØ Memulai Patching untuk {n_files} file pasangan.\n")

    for i, input_path in enumerate(list_files):
        # 1. The file id is the last '_'-separated token of the name
        file_id = os.path.basename(input_path).split('_')[-1].replace('.npy', '')
        print(f"üî™ [{i+1}/{n_files}] Memproses ID: {file_id}")

        # 2. Find the matching GT HDF5 on local disk
        gt_filename = f"ground_truth_validation_{file_id}.hdf5"
        gt_path = os.path.join(gt_local_folder, gt_filename)
        if not os.path.exists(gt_path):
            print(f"   ‚ö†Ô∏è SKIP: File GT {gt_filename} tidak ada di {gt_local_folder}")
            continue

        vol_input = vol_target = None
        try:
            vol_input = np.load(input_path)
            with h5py.File(gt_path, 'r') as f_gt:
                # Read only as many GT slices as the FBP volume has; the raw
                # copy is discarded immediately after normalisation.
                vol_target = _minmax_normalize(f_gt['data'][:vol_input.shape[0]])

            if vol_input.shape != vol_target.shape:
                print(f"   ‚ùå Dimensi beda! Input:{vol_input.shape} vs Target:{vol_target.shape}")
                continue

            np_in, np_gt = _extract_patch_pairs(
                vol_input, vol_target, PATCH_SIZE, STRIDE, THRESHOLD_AIR)

            if len(np_in) > 0:
                np.save(os.path.join(output_folder, f"patch_in_{file_id}.npy"), np_in)
                np.save(os.path.join(output_folder, f"patch_gt_{file_id}.npy"), np_gt)
                print(f"   ‚úÖ Disimpan: {len(np_in)} patches.")
            else:
                print("   ‚ö†Ô∏è File ini kosong/hanya udara.")

        except Exception as e:
            # Best-effort: report the failing file and continue with the next.
            print(f"   ‚ùå Error: {e}")
        finally:
            # Release the large volumes on every path (saved, shape mismatch,
            # or error), not only after a successful save.
            del vol_input, vol_target
            gc.collect()

    print("\nüéâ SELESAI! Semua patch siap training di Drive.")


process_patching_hybrid()

# **3D Residual-Attention U-shaped Convolutional Neural Network (U-Net)**

## **Training & Validation**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from pytorch_msssim import ssim
import numpy as np
import os
import glob
import time
import bisect
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

# ==========================================
# 1. KONFIGURASI
# ==========================================
# Training patches written by the micro-patching step
data_folder = "/content/drive/MyDrive/Finpro Pencit/HDF5/Training_Patches/"

# Each experiment keeps its checkpoints, CSV log and plots in its own folder
experiment_name = "Run_Final_V6_LimitBatch"
output_dir = os.path.join("/content/drive/MyDrive/Finpro Pencit/Training_Logs", experiment_name) + "/"
os.makedirs(output_dir, exist_ok=True)

# Hyperparameters
BATCH_SIZE, NUM_WORKERS = 2, 0
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50

# Data split fractions (test, validation); the remainder is used for training
TEST_SPLIT, VAL_SPLIT = 0.10, 0.20

# Per-epoch caps on the number of batches, so each epoch stays short
TRAIN_BATCH_LIMIT, VAL_BATCH_LIMIT = 2000, 200

print("üìÇ Folder Output: " + output_dir)
print(f"‚è±Ô∏è Limit Batch: {TRAIN_BATCH_LIMIT} per Epoch")

# ==========================================
# 2. KOMPONEN DATASET & MODEL
# ==========================================
class LoDoPaBDataset(Dataset):
    """Lazily indexed dataset over paired ``patch_in_*.npy`` / ``patch_gt_*.npy`` files.

    Each ``patch_in_<id>.npy`` stores a stack of input patches (first axis =
    patch index) and must have a sibling ``patch_gt_<id>.npy`` with the same
    number of patches. Files are memory-mapped, so ``__getitem__`` reads only
    the requested pair. Items are ``(input, target)`` float32 tensors with a
    leading channel axis of size 1.

    Raises:
        RuntimeError: if no input files exist, a target file is missing, or
            input/target patch counts differ. These are checked once here
            instead of surfacing mid-epoch.
    """

    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.input_files = sorted(glob.glob(os.path.join(folder_path, "patch_in_*.npy")))
        if len(self.input_files) == 0:
            raise RuntimeError(f"‚ùå Error: Tidak ada file .npy di {folder_path}")

        print(f"üîÑ Mengindeks {len(self.input_files)} file... (Lazy Loading)")
        self.file_indices = []         # number of patches in each input file
        self.cumulative_indices = [0]  # file k owns global indices [cum[k], cum[k+1])
        for f_path in self.input_files:
            num = np.load(f_path, mmap_mode='r').shape[0]
            gt_path = self._target_path(f_path)
            if not os.path.exists(gt_path):
                raise RuntimeError(
                    f"Target file missing for {os.path.basename(f_path)}: {gt_path}")
            num_gt = np.load(gt_path, mmap_mode='r').shape[0]
            if num_gt != num:
                raise RuntimeError(
                    f"Patch count mismatch for {os.path.basename(f_path)}: "
                    f"input={num}, target={num_gt}")
            self.file_indices.append(num)
            self.cumulative_indices.append(self.cumulative_indices[-1] + num)
        self.total_patches = self.cumulative_indices[-1]
        print(f"‚úÖ Total Data: {self.total_patches} patches")

    def _target_path(self, input_path):
        """Map ``.../patch_in_<id>.npy`` to ``<folder_path>/patch_gt_<id>.npy``."""
        filename = os.path.basename(input_path)
        return os.path.join(self.folder_path, filename.replace("patch_in", "patch_gt"))

    def __len__(self):
        return self.total_patches

    def __getitem__(self, idx):
        # Support negative indices and reject out-of-range ones. Without this,
        # bisect would silently route them to the wrong (or a nonexistent) file.
        if idx < 0:
            idx += self.total_patches
        if not 0 <= idx < self.total_patches:
            raise IndexError(f"index {idx} out of range for {self.total_patches} patches")

        file_idx = bisect.bisect_right(self.cumulative_indices, idx) - 1
        local_idx = idx - self.cumulative_indices[file_idx]

        input_path = self.input_files[file_idx]
        d_in = np.load(input_path, mmap_mode='r')
        d_gt = np.load(self._target_path(input_path), mmap_mode='r')

        # Copy the single patch out of the memmap before handing it to torch
        p_in = np.array(d_in[local_idx]).astype(np.float32)
        p_gt = np.array(d_gt[local_idx]).astype(np.float32)

        return torch.from_numpy(np.expand_dims(p_in, axis=0)), torch.from_numpy(np.expand_dims(p_gt, axis=0))

# --- MODEL ARSITEKTUR (Fixed) ---
class ResBlock(nn.Module):
    """Residual unit: (3x3x3 conv -> BN -> ReLU -> 3x3x3 conv -> BN) + shortcut, then ReLU.

    The shortcut is a 1x1x1 convolution when the channel count changes and an
    identity (empty Sequential) otherwise. The attribute names `conv`, `relu`
    and `sc`, and the layer order inside `conv`, define the state_dict keys,
    so they must stay stable for checkpoint compatibility.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = [
            nn.Conv3d(in_c, out_c, 3, padding=1),
            nn.BatchNorm3d(out_c),
            nn.ReLU(True),
            nn.Conv3d(out_c, out_c, 3, padding=1),
            nn.BatchNorm3d(out_c),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(True)
        if in_c != out_c:
            self.sc = nn.Conv3d(in_c, out_c, 1)
        else:
            self.sc = nn.Sequential()

    def forward(self, x):
        main_path = self.conv(x)
        return self.relu(main_path + self.sc(x))

class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights skip features `x` by a gating signal `g`.

    `g` (F_g channels) and `x` (F_l channels) are each projected to F_int
    channels with a 1x1x1 conv + BN. The projections are summed, passed
    through ReLU, and mapped to a one-channel sigmoid mask that multiplies
    `x`. `g` and `x` must share spatial size.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels_in):
            return nn.Sequential(nn.Conv3d(channels_in, F_int, 1), nn.BatchNorm3d(F_int))

        self.Wg = project(F_g)
        self.Wx = project(F_l)
        self.psi = nn.Sequential(nn.Conv3d(F_int, 1, 1), nn.BatchNorm3d(1), nn.Sigmoid())
        self.relu = nn.ReLU(True)

    def forward(self, g, x):
        combined = self.Wg(g)
        combined = combined + self.Wx(x)
        mask = self.psi(self.relu(combined))
        return x * mask

class Lightweight3DUNet(nn.Module):
    """Three-level residual U-Net with additive attention gates on the skip connections.

    Encoder widths are 16/32/64 with a 128-channel bottleneck; each level
    halves the spatial size with 2x max-pooling. The decoder upsamples
    trilinearly, gates the matching encoder features, concatenates
    [gated skip, upsampled] and refines with a ResBlock. Dropout3d(p=0.2) is
    applied before the final 1x1x1 output convolution.

    Inputs whose spatial sizes are not divisible by 8 are supported: when
    pooling floors an odd size, the upsampled tensor is resized to the skip
    tensor's spatial size first. For sizes divisible by 8 (e.g. 32^3 patches)
    no resize happens, so existing checkpoints behave identically.
    """

    def __init__(self, in_c=1, out_c=1):
        super().__init__()
        f = [16, 32, 64, 128]
        self.enc1 = ResBlock(in_c, f[0]); self.p1 = nn.MaxPool3d(2)
        self.enc2 = ResBlock(f[0], f[1]); self.p2 = nn.MaxPool3d(2)
        self.enc3 = ResBlock(f[1], f[2]); self.p3 = nn.MaxPool3d(2)
        self.bot = ResBlock(f[2], f[3])

        self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att3 = AttentionBlock(f[3], f[2], f[2])
        self.dec3 = ResBlock(f[3]+f[2], f[2])

        self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att2 = AttentionBlock(f[2], f[1], f[1])
        self.dec2 = ResBlock(f[2]+f[1], f[1])

        self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att1 = AttentionBlock(f[1], f[0], f[0])
        self.dec1 = ResBlock(f[1]+f[0], f[0])

        self.out = nn.Conv3d(f[0], out_c, 1)
        self.dropout = nn.Dropout3d(0.2)

    @staticmethod
    def _match_spatial(up, skip):
        """Resize `up` to `skip`'s spatial size when they differ (odd input sizes)."""
        if up.shape[2:] != skip.shape[2:]:
            up = nn.functional.interpolate(up, size=skip.shape[2:], mode='trilinear', align_corners=True)
        return up

    def forward(self, x):
        e1 = self.enc1(x); p1 = self.p1(e1)
        e2 = self.enc2(p1); p2 = self.p2(e2)
        e3 = self.enc3(p2); p3 = self.p3(e3)
        b = self.bot(p3)

        d3 = self._match_spatial(self.up3(b), e3)
        x3 = self.att3(d3, e3); d3 = self.dec3(torch.cat([x3, d3], 1))
        d2 = self._match_spatial(self.up2(d3), e2)
        x2 = self.att2(d2, e2); d2 = self.dec2(torch.cat([x2, d2], 1))
        d1 = self._match_spatial(self.up1(d2), e1)
        x1 = self.att1(d1, e1); d1 = self.dec1(torch.cat([x1, d1], 1))
        return self.out(self.dropout(d1))

class GradientLoss3D(nn.Module):
    """Mean absolute difference between |Sobel response| of prediction and target.

    Uses one 3x3x3 Sobel kernel whose derivative runs along the last (X/W)
    axis, so it compares edge strength along that axis only. Inputs are
    (N, 1, D, H, W) tensors (single input channel, matching the kernel shape).
    """

    def __init__(self):
        super().__init__()
        k = torch.FloatTensor([[[[-1,0,1],[-2,0,2],[-1,0,1]],[[-2,0,2],[-4,0,4],[-2,0,2]],[[-1,0,1],[-2,0,2],[-1,0,1]]]])
        # Stored as a buffer, not a frozen Parameter: it is not trainable,
        # follows module.to(device) automatically, and keeps the state_dict
        # key 'k'. Assigning `.to()`'s result back onto a registered Parameter
        # raises TypeError, so that route is avoided.
        self.register_buffer('k', k.unsqueeze(1))

    def forward(self, p, t):
        # Match the input's device and dtype (a no-op in the common case).
        k = self.k.to(device=p.device, dtype=p.dtype)
        gp = torch.abs(nn.functional.conv3d(p, k, padding=1))
        gt = torch.abs(nn.functional.conv3d(t, k, padding=1))
        return torch.mean(torch.abs(gp - gt))

def composite_loss(p, t, g_fn):
    """Weighted training loss: MSE + 0.1 * (1 - SSIM) + 0.01 * gradient term.

    `p` and `t` are prediction/target batches. SSIM is computed with
    data_range=1.0, so they are presumably in [0, 1]. `g_fn(p, t)` must
    return a scalar tensor (e.g. a GradientLoss3D instance).
    """
    mse_term = nn.functional.mse_loss(p, t)
    ssim_term = 1 - ssim(p, t, data_range=1.0, size_average=True)
    grad_term = g_fn(p, t)
    return mse_term + 0.1 * ssim_term + 0.01 * grad_term

# ==========================================
# 3. PERSIAPAN DATA & AUTO-RESUME
# ==========================================
full_ds = LoDoPaBDataset(data_folder)
total_len = len(full_ds)
test_len = int(total_len * TEST_SPLIT)
val_len = int(total_len * VAL_SPLIT)
train_len = total_len - val_len - test_len

# Seeded split. The testing cell rebuilds the held-out test subset with
# torch.Generator().manual_seed(42); an unseeded split would draw a different
# partition on every run or resume, leaking "test" patches into training.
SPLIT_SEED = 42
train_ds, val_ds, test_ds = random_split(
    full_ds, [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"üìä SPLIT DATA: Train={len(train_ds)} | Val={len(val_ds)} | Test={len(test_ds)} (Disisihkan)")

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Lightweight3DUNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
grad_fn = GradientLoss3D().to(device)

# --- RESUME LOGIC ---
start_epoch = 0
best_val_loss = float('inf')
log_history = []
csv_path = os.path.join(output_dir, "training_log.csv")


def _checkpoint_epoch(path):
    """Return N from a 'checkpoint_ep<N>.pth' file name, or -1 if it is not a plain integer."""
    stem = os.path.basename(path)[len("checkpoint_ep"):-len(".pth")]
    return int(stem) if stem.isdigit() else -1


checkpoints = glob.glob(os.path.join(output_dir, "checkpoint_ep*.pth"))
if len(checkpoints) > 0:
    # Resume from the highest epoch in the file name. File ctimes change when
    # files are copied or synced, so they need not track the training epoch.
    latest = max(checkpoints, key=_checkpoint_epoch)
    print(f"\nüîÑ Melanjutkan dari: {os.path.basename(latest)}")
    model.load_state_dict(torch.load(latest, map_location=device))
    start_epoch = max(_checkpoint_epoch(latest), 0)
    # NOTE(review): only model weights are checkpointed, so Adam's moment
    # estimates restart from zero on every resume.

    if os.path.exists(csv_path):
        df = pd.read_csv(csv_path)
        log_history = df.to_dict('records')
        # The best loss considers every logged epoch (best_model.pth may come
        # from any of them). Rows past the resumed checkpoint are dropped so
        # re-run epochs are not logged twice.
        if len(log_history) > 0:
            best_val_loss = min(rec['val_loss'] for rec in log_history)
        log_history = [rec for rec in log_history if rec['epoch'] <= start_epoch]
        print(f"üìà History: {len(log_history)} epoch terpulihkan.")
else:
    print("\nüÜï Memulai Training Baru.")

# ==========================================
# 4. TRAINING LOOP (VERBOSE & LIMITED)
# ==========================================
print(f"üî• START TRAINING (Epoch {start_epoch+1}/{NUM_EPOCHS})")

for epoch in range(start_epoch, NUM_EPOCHS):
    t0 = time.time()
    model.train()
    train_loss = 0.0
    count = 0

    # Train Loop (Limit Batch)
    for i, (x, y) in enumerate(train_loader):
        if i >= TRAIN_BATCH_LIMIT: break # Limit check

        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = composite_loss(out, y, grad_fn)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        count += 1

        if (i+1) % 100 == 0:
            print(f"   [Epoch {epoch+1}] Batch {i+1}/{TRAIN_BATCH_LIMIT} | Loss: {loss.item():.5f}")

    avg_train = train_loss / count if count > 0 else 0

    # Validation Loop
    print(f"   ‚è≥ Validasi...")
    model.eval()
    val_loss = 0.0
    v_count = 0
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            if i >= VAL_BATCH_LIMIT: break
            x, y = x.to(device), y.to(device)
            val_loss += composite_loss(model(x), y, grad_fn).item()
            v_count += 1

    avg_val = val_loss / v_count if v_count > 0 else 0
    dt = (time.time() - t0)/60

    print(f"‚úÖ Epoch {epoch+1} Selesai ({dt:.1f}m) | Train: {avg_train:.5f} | Val: {avg_val:.5f}")

    # Save & Log
    log_history.append({'epoch': epoch+1, 'train_loss': avg_train, 'val_loss': avg_val, 'time_m': dt})
    pd.DataFrame(log_history).to_csv(csv_path, index=False)

    torch.save(model.state_dict(), os.path.join(output_dir, f"checkpoint_ep{epoch+1}.pth"))
    if avg_val < best_val_loss:
        best_val_loss = avg_val
        torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
        print(f"   üèÜ New Best Model Saved!")

    if epoch > 3:
        old = os.path.join(output_dir, f"checkpoint_ep{epoch-2}.pth")
        if os.path.exists(old): os.remove(old)

# ==========================================
# 5. POST-TRAINING REPORT (SPECS & PLOT)
# ==========================================
print("\nüéâ TRAINING SELESAI! Menampilkan Laporan Lengkap...")

# A. Spesifikasi Model
total_params = sum(p.numel() for p in model.parameters())
print("\n" + "="*30)
print(f"ü§ñ SPESIFIKASI MODEL")
print("="*30)
print(f"Model Name      : Lightweight 3D Res-Att U-Net")
print(f"Total Parameter : {total_params:,} parameters")
print(f"Input Shape     : (Batch, 1, 32, 32, 32)")
print(f"Filters Config  : [16, 32, 64, 128]")
print(f"Dropout Rate    : 0.2 (Monte Carlo Ready)")
print(f"Loss Function   : Composite (MSE + SSIM + Gradient)")
print("="*30 + "\n")

# B. Plot Grafik Training
if len(log_history) > 0:
    df = pd.DataFrame(log_history)
    plt.figure(figsize=(10, 6))
    plt.plot(df['epoch'], df['train_loss'], label='Training Loss', marker='o', linestyle='-')
    plt.plot(df['epoch'], df['val_loss'], label='Validation Loss', marker='s', linestyle='--')

    plt.title(f"Training History: {experiment_name}")
    plt.xlabel("Epoch")
    plt.ylabel("Composite Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plot_path = os.path.join(output_dir, "final_training_plot.png")
    plt.savefig(plot_path, dpi=300)
    print(f"üìä Grafik disimpan di: {plot_path}")
    plt.show()
else:
    print("‚ö†Ô∏è Tidak ada data log untuk di-plot.")

## **Testing with Patch Data**

In [None]:
# ==========================================
# 0. SETUP & LIBRARY
# ==========================================
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import skew, kurtosis, entropy
from tqdm import tqdm
import bisect

# Install Library SSIM CPU
try:
    from skimage.metrics import structural_similarity as ssim
except ImportError:
    !pip install scikit-image -q
    from skimage.metrics import structural_similarity as ssim

# ==========================================
# 1. KONFIGURASI PATH (GOOGLE DRIVE)
# ==========================================
experiment_name = "Run_Final_V6_LimitBatch"
base_output_dir = f"/content/drive/MyDrive/Finpro Pencit/Output_Testing_Final/"
img_save_dir = os.path.join(base_output_dir, "Visual_Evidence")

# Create the output folders (parent first); existing folders are fine.
for _folder in (base_output_dir, img_save_dir):
    os.makedirs(_folder, exist_ok=True)

print(f"üìÇ Hasil Analisis (Excel/Grafik) akan disimpan di: {base_output_dir}")
print(f"üìÇ Bukti Gambar (PNG) akan disimpan di: {img_save_dir}")

# Data and model locations.
# NOTE(review): the training cell reads patches from
# ".../Finpro Pencit/HDF5/Training_Patches/", but this path has no "HDF5/".
# Confirm the folder was moved; otherwise the dataset finds no files here
# and falls back to "/content/data_lokal_patches/" if that exists.
data_folder = "/content/drive/MyDrive/Finpro Pencit/Training_Patches"
checkpoint_path = f"/content/drive/MyDrive/Finpro Pencit/Training_Logs/{experiment_name}/best_model.pth"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==========================================
# 2. DEFINISI MODEL
# ==========================================
class ResBlock(nn.Module):
    """Residual unit: (3x3x3 conv -> BN -> ReLU -> 3x3x3 conv -> BN) + shortcut, then ReLU.

    The shortcut is a 1x1x1 convolution when the channel count changes and an
    identity (empty Sequential) otherwise. Attribute names and the layer order
    in `conv` must match the training definition so checkpoints load.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = [
            nn.Conv3d(in_c, out_c, 3, padding=1),
            nn.BatchNorm3d(out_c),
            nn.ReLU(True),
            nn.Conv3d(out_c, out_c, 3, padding=1),
            nn.BatchNorm3d(out_c),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(True)
        if in_c != out_c:
            self.sc = nn.Conv3d(in_c, out_c, 1)
        else:
            self.sc = nn.Sequential()

    def forward(self, x):
        main_path = self.conv(x)
        return self.relu(main_path + self.sc(x))

class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights skip features `x` by a gating signal `g`.

    Both inputs are projected to F_int channels (1x1x1 conv + BN), summed,
    passed through ReLU, and turned into a one-channel sigmoid mask that
    multiplies `x`. `g` and `x` must share spatial size.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels_in):
            return nn.Sequential(nn.Conv3d(channels_in, F_int, 1), nn.BatchNorm3d(F_int))

        self.Wg = project(F_g)
        self.Wx = project(F_l)
        self.psi = nn.Sequential(nn.Conv3d(F_int, 1, 1), nn.BatchNorm3d(1), nn.Sigmoid())
        self.relu = nn.ReLU(True)

    def forward(self, g, x):
        combined = self.Wg(g)
        combined = combined + self.Wx(x)
        mask = self.psi(self.relu(combined))
        return x * mask

class Lightweight3DUNet(nn.Module):
    """Three-level residual U-Net with attention-gated skips (testing-cell variant).

    Encoder widths are 16/32/64 with a 128-channel bottleneck, 2x max-pooling
    per level, a trilinear-upsampling decoder, and Dropout3d(0.2) before the
    1x1x1 output convolution. Layer attribute names and creation order match
    the training definition, so the same state_dict loads.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        widths = [16, 32, 64, 128]

        # Encoder: residual stage followed by 2x max-pooling
        self.enc1 = ResBlock(in_channels, widths[0])
        self.p1 = nn.MaxPool3d(2)
        self.enc2 = ResBlock(widths[0], widths[1])
        self.p2 = nn.MaxPool3d(2)
        self.enc3 = ResBlock(widths[1], widths[2])
        self.p3 = nn.MaxPool3d(2)
        self.bot = ResBlock(widths[2], widths[3])

        # Decoder: upsample, gate the skip, concatenate, refine
        self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att3 = AttentionBlock(widths[3], widths[2], widths[2])
        self.dec3 = ResBlock(widths[3] + widths[2], widths[2])
        self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att2 = AttentionBlock(widths[2], widths[1], widths[1])
        self.dec2 = ResBlock(widths[2] + widths[1], widths[1])
        self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att1 = AttentionBlock(widths[1], widths[0], widths[0])
        self.dec1 = ResBlock(widths[1] + widths[0], widths[0])
        self.out = nn.Conv3d(widths[0], out_channels, 1)
        self.dropout = nn.Dropout3d(0.2)

    @staticmethod
    def _fit(up, skip):
        # Compares full sizes (channels included, which always differ here),
        # so the tensor is always resampled to skip's spatial size. With equal
        # spatial sizes, align_corners=True resampling reproduces the input.
        if up.size() != skip.size():
            up = F.interpolate(up, size=skip.shape[2:], mode='trilinear', align_corners=True)
        return up

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.p1(e1))
        e3 = self.enc3(self.p2(e2))
        b = self.bot(self.p3(e3))

        d3 = self._fit(self.up3(b), e3)
        d3 = self.dec3(torch.cat([self.att3(d3, e3), d3], 1))
        d2 = self._fit(self.up2(d3), e2)
        d2 = self.dec2(torch.cat([self.att2(d2, e2), d2], 1))
        d1 = self._fit(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([self.att1(d1, e1), d1], 1))
        return self.out(self.dropout(d1))


UNet3D = Lightweight3DUNet

# ==========================================
# 3. DATASET & SPLIT
# ==========================================
class LoDoPaBDataset(Dataset):
    """Lazily indexed, [0, 1]-clipped dataset over paired patch_in/patch_gt .npy files.

    If `folder_path` has no ``patch_in_*.npy`` files, ``/content/data_lokal_patches/``
    is tried instead. Input files that cannot be opened are reported and
    skipped. Items are ``(input, target)`` float32 tensors with a leading
    channel axis of size 1, both clipped to [0, 1].
    """

    def __init__(self, folder_path):
        self.folder_path = folder_path
        candidates = sorted(glob.glob(os.path.join(folder_path, "patch_in_*.npy")))
        if len(candidates) == 0:
            local_path = "/content/data_lokal_patches/"
            if os.path.exists(local_path):
                candidates = sorted(glob.glob(os.path.join(local_path, "patch_in_*.npy")))
                self.folder_path = local_path

        # `input_files` and `cumulative_indices` must stay in lock-step. An
        # unreadable file is dropped from BOTH; keeping it in `input_files`
        # would make the bisect in __getitem__ route later indices to the
        # wrong file.
        self.input_files = []
        self.cumulative_indices = [0]
        for f_path in candidates:
            try:
                num = np.load(f_path, mmap_mode='r').shape[0]
            except Exception as err:
                print(f"   ‚ö†Ô∏è Skip {os.path.basename(f_path)}: {err}")
                continue
            self.input_files.append(f_path)
            self.cumulative_indices.append(self.cumulative_indices[-1] + num)
        self.total_patches = self.cumulative_indices[-1]

    def __len__(self): return self.total_patches

    def __getitem__(self, idx):
        file_idx = bisect.bisect_right(self.cumulative_indices, idx) - 1
        local_idx = idx - self.cumulative_indices[file_idx]

        input_path = self.input_files[file_idx]
        filename = os.path.basename(input_path)
        target_path = os.path.join(self.folder_path, filename.replace("patch_in", "patch_gt"))

        d_in = np.load(input_path, mmap_mode='r')
        d_gt = np.load(target_path, mmap_mode='r')

        p_in = np.clip(np.array(d_in[local_idx]).astype(np.float32), 0, 1)
        p_gt = np.clip(np.array(d_gt[local_idx]).astype(np.float32), 0, 1)

        return torch.from_numpy(np.expand_dims(p_in, axis=0)), torch.from_numpy(np.expand_dims(p_gt, axis=0))

# SETUP DATA SPLIT
full_ds = LoDoPaBDataset(data_folder)
total_len = len(full_ds)
if total_len == 0:
    # Fail with a clear message instead of an obscure random_split error.
    raise RuntimeError(f"Tidak ada patch ditemukan di {data_folder}")

# Split sizes. For >= 20 patches they match the training cell's 70/20/10
# split, so the seeded split below reproduces its held-out test subset.
if total_len < 20:
    # Tiny dataset: hold out up to 2 patches, never more than exist
    # (a fixed 2 would make train_len negative for a single patch).
    test_len = min(2, total_len)
    val_len = 0
    train_len = total_len - test_len
else:
    test_len = int(total_len * 0.10)
    val_len = int(total_len * 0.20)
    train_len = total_len - val_len - test_len

generator = torch.Generator().manual_seed(42)
_, _, test_ds = random_split(full_ds, [train_len, val_len, test_len], generator=generator)

print(f"üìä Total Data Test (10%): {len(test_ds)} Patches")
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

# ==========================================
# 4. METRICS & RADIOMICS FUNCTIONS
# ==========================================
def calculate_metrics_numpy(pred, gt, data_range=1.0):
    """Return ``(psnr, ssim, rmse)`` of `pred` against `gt` (CPU / NumPy).

    Args:
        pred, gt: arrays of equal shape (D, H, W). SSIM is the mean of 2D
            slice-wise SSIMs along the first axis.
        data_range: dynamic range of the data; the default 1.0 suits
            [0, 1]-normalised volumes.

    Returns:
        PSNR in dB (100 when the volumes are identical), mean slice SSIM
        (NaN when D == 0), and RMSE.
    """
    pred = np.asarray(pred, dtype=np.float32)
    gt = np.asarray(gt, dtype=np.float32)

    # 1. MSE & RMSE
    mse = np.mean((pred - gt) ** 2)
    rmse = np.sqrt(mse)

    # 2. PSNR
    psnr = 100 if mse == 0 else 20 * np.log10(data_range / rmse)

    # 3. Slice-wise SSIM with a 3x3 window, the smallest odd window, so small
    # patch slices work. No fallback is tried: skimage rejects any window
    # larger than the image, so its default 7x7 window cannot succeed where
    # 3x3 fails.
    depth = gt.shape[0]
    if depth == 0:
        return psnr, float('nan'), rmse
    ssim_sum = 0
    for gt_slice, pred_slice in zip(gt, pred):
        ssim_sum += ssim(gt_slice, pred_slice, data_range=data_range, win_size=3)

    return psnr, ssim_sum / depth, rmse

def get_radiomics(img_vol):
    """Simple first-order intensity statistics of a volume.

    Returns a dict with keys, in order: 'Mean', 'Std' (population),
    'Skewness', 'Kurtosis' (scipy default, i.e. excess), and 'Entropy'
    (scipy `entropy`, natural log, of a 50-bin histogram with 1e-8 added
    to every bin).
    """
    values = np.ravel(img_vol)
    counts, _ = np.histogram(values, bins=50)

    stats = {}
    stats['Mean'] = np.mean(values)
    stats['Std'] = np.std(values)
    stats['Skewness'] = skew(values)
    stats['Kurtosis'] = kurtosis(values)
    stats['Entropy'] = entropy(counts + 1e-8)
    return stats

# ==========================================
# 5. EXECUTION LOOP
# ==========================================
# Load Model
model = UNet3D(in_channels=1, out_channels=1).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict'])
else: model.load_state_dict(checkpoint)

metrics_log = []
radiomics_log = []

SAVE_IMG_LIMIT = 50
MC_SAMPLES = 5  # stochastic forward passes per sample for the uncertainty map
img_saved_count = 0


def _enable_mc_dropout(net):
    """Put `net` in eval mode but keep only its dropout layers stochastic.

    Calling net.train() would also switch BatchNorm to per-batch statistics.
    It would also update BatchNorm's running mean/var (which happens even
    under torch.no_grad()), altering every later "deterministic" prediction.
    """
    net.eval()
    for module in net.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            module.train()


print("üöÄ Memulai Pengujian...")

with torch.no_grad():
    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        patient_id = f"Test_Sample_{i+1:04d}"
        inputs, targets = inputs.to(device), targets.to(device)

        # A. Deterministic prediction (dropout off, BN running statistics)
        model.eval()
        output_std = model(inputs)

        # B. Monte Carlo dropout uncertainty, only for samples whose figure
        # will be saved; BatchNorm stays in eval mode throughout.
        uncertainty_vol = None
        if img_saved_count < SAVE_IMG_LIMIT:
            _enable_mc_dropout(model)
            mc_stack = np.stack([model(inputs).cpu().numpy() for _ in range(MC_SAMPLES)])
            uncertainty_vol = mc_stack.std(axis=0)[0, 0]
            model.eval()

        # C. Convert to NumPy (batch size 1, single channel)
        vol_in = inputs[0, 0].cpu().numpy()
        vol_gt = targets[0, 0].cpu().numpy()
        vol_out = output_std[0, 0].cpu().numpy()

        # D. Image-quality metrics
        psnr, ssim_val, rmse = calculate_metrics_numpy(vol_out, vol_gt)

        # E. Radiomics: absolute error of each statistic vs. ground truth
        rad_gt = get_radiomics(vol_gt)
        rad_out = get_radiomics(vol_out)

        metrics_log.append({'ID': patient_id, 'PSNR': psnr, 'SSIM': ssim_val, 'RMSE': rmse})

        rad_entry = {'ID': patient_id}
        for k in rad_gt:
            rad_entry[f'{k}_Error'] = abs(rad_gt[k] - rad_out[k])
        radiomics_log.append(rad_entry)

        # F. Save the evidence figure (central slice along the first axis)
        if img_saved_count < SAVE_IMG_LIMIT and uncertainty_vol is not None:
            mid = vol_gt.shape[0] // 2

            plt.figure(figsize=(20, 5))
            # 1. Input
            plt.subplot(1, 5, 1); plt.imshow(vol_in[mid], cmap='gray'); plt.title("Input Low Dose")
            plt.axis('off')
            # 2. Output
            plt.subplot(1, 5, 2); plt.imshow(vol_out[mid], cmap='gray'); plt.title(f"Output AI\nPSNR: {psnr:.2f}")
            plt.axis('off')
            # 3. GT
            plt.subplot(1, 5, 3); plt.imshow(vol_gt[mid], cmap='gray'); plt.title("Ground Truth")
            plt.axis('off')
            # 4. Uncertainty
            plt.subplot(1, 5, 4); plt.imshow(uncertainty_vol[mid], cmap='jet'); plt.title("Uncertainty Map")
            plt.colorbar(fraction=0.046); plt.axis('off')
            # 5. Difference
            plt.subplot(1, 5, 5); plt.imshow(np.abs(vol_gt[mid] - vol_out[mid]), cmap='inferno'); plt.title("Difference Error")
            plt.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(img_save_dir, f"{patient_id}_evidence.png"))
            plt.close()
            img_saved_count += 1

# ==========================================
# 6. SIMPAN HASIL AKHIR
# ==========================================
df_metrics = pd.DataFrame(metrics_log)
df_radiomics = pd.DataFrame(radiomics_log)

# Write both per-sample tables next to each other in the output folder
for frame, csv_name in ((df_metrics, "FINAL_Metrics.csv"), (df_radiomics, "FINAL_Radiomics.csv")):
    frame.to_csv(os.path.join(base_output_dir, csv_name), index=False)

banner = "=" * 40
print("\n" + banner)
print("‚úÖ TESTING SELESAI!")
print(banner)
mean_psnr = df_metrics['PSNR'].mean()
print(f"üìä Rata-rata PSNR: {mean_psnr:.2f} dB")
mean_ssim = df_metrics['SSIM'].mean()
print(f"üìä Rata-rata SSIM: {mean_ssim:.4f}")
print(f"üì∏ {img_saved_count} Gambar bukti tersimpan di Drive.")
print(f"üìÇ Lokasi: {base_output_dir}")

## **Testing with Full Volume Data**

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
import math
from google.colab import drive
from tqdm.auto import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# ==========================================
# 1. KONFIGURASI PATH
# ==========================================
drive.mount('/content/drive')

# Fine-tuned checkpoint to evaluate
MODEL_PATH = '/content/drive/MyDrive/Finpro Pencit/Training_Logs/Run_Final_V6_FineTuned_Texture/best_model_finetuned.pth'

# Folder with the full FBP volumes.
# NOTE(review): the patching cell reads FBP volumes from
# ".../Finpro Pencit/HDF5/Hasil FBP/"; confirm this path without "HDF5/".
DATA_FOLDER = '/content/drive/MyDrive/Finpro Pencit/Hasil FBP/'

# Evaluation outputs
OUTPUT_DIR = '/content/drive/MyDrive/Finpro Pencit/Hasil_Evaluasi_LinearBlend_Fixed/'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Patch settings; presumably for overlapping sliding-window inference further
# down (not visible here). Cube edge in voxels and overlap fraction.
PATCH_SIZE, OVERLAP_PERCENT = 32, 0.5
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==========================================
# 2. DEFINISI MODEL
# ==========================================
class ResBlock(nn.Module):
    """Residual unit: (3x3x3 conv -> BN -> ReLU -> 3x3x3 conv -> BN) + shortcut, then ReLU.

    The shortcut is a 1x1x1 convolution when the channel count changes and an
    identity (empty Sequential) otherwise. Attribute names and layer order
    must match the training definition so checkpoints load.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = [
            nn.Conv3d(in_c, out_c, 3, padding=1),
            nn.BatchNorm3d(out_c),
            nn.ReLU(True),
            nn.Conv3d(out_c, out_c, 3, padding=1),
            nn.BatchNorm3d(out_c),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(True)
        if in_c != out_c:
            self.sc = nn.Conv3d(in_c, out_c, 1)
        else:
            self.sc = nn.Sequential()

    def forward(self, x):
        main_path = self.conv(x)
        return self.relu(main_path + self.sc(x))

class AttentionBlock(nn.Module):
    """Additive attention gate for 3D skip connections.

    ``g`` (gating features, ``F_g`` channels) and ``x`` (skip features,
    ``F_l`` channels) are each projected to ``F_int`` channels by a 1x1x1
    conv + BatchNorm, summed, passed through ReLU, and mapped to a
    single-channel sigmoid mask. The output is ``x`` multiplied by that mask.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Attribute names are state_dict keys; keep them unchanged.
        self.Wg = nn.Sequential(nn.Conv3d(F_g, F_int, kernel_size=1), nn.BatchNorm3d(F_int))
        self.Wx = nn.Sequential(nn.Conv3d(F_l, F_int, kernel_size=1), nn.BatchNorm3d(F_int))
        self.psi = nn.Sequential(
            nn.Conv3d(F_int, 1, kernel_size=1),
            nn.BatchNorm3d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(True)

    def forward(self, g, x):
        combined = self.relu(self.Wg(g) + self.Wx(x))
        mask = self.psi(combined)
        return x * mask

class Lightweight3DUNet(nn.Module):
    """Three-level 3D U-Net with residual blocks and attention-gated skips.

    Encoder widths are 16 / 32 / 64 with a 128-channel bottleneck. Each level
    halves the spatial size with max-pooling and the decoder doubles it with
    trilinear upsampling, so input spatial sizes should be divisible by 8 for
    the skip connections to line up. At every decoder level the skip features
    are filtered by an ``AttentionBlock`` gated by the upsampled decoder
    features, then concatenated in front of them. Dropout3d(0.2) is applied
    before the final 1x1x1 projection to ``out_c`` channels.

    Attribute names double as state_dict keys and are kept unchanged so saved
    weights load.
    """

    def __init__(self, in_c=1, out_c=1):
        super().__init__()
        w1, w2, w3, w4 = 16, 32, 64, 128

        # Encoder
        self.enc1 = ResBlock(in_c, w1)
        self.p1 = nn.MaxPool3d(2)
        self.enc2 = ResBlock(w1, w2)
        self.p2 = nn.MaxPool3d(2)
        self.enc3 = ResBlock(w2, w3)
        self.p3 = nn.MaxPool3d(2)
        self.bot = ResBlock(w3, w4)

        # Decoder level 3
        self.up3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att3 = AttentionBlock(w4, w3, w3)
        self.dec3 = ResBlock(w4 + w3, w3)

        # Decoder level 2
        self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att2 = AttentionBlock(w3, w2, w2)
        self.dec2 = ResBlock(w3 + w2, w2)

        # Decoder level 1
        self.up1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.att1 = AttentionBlock(w2, w1, w1)
        self.dec1 = ResBlock(w2 + w1, w1)

        self.out = nn.Conv3d(w1, out_c, 1)
        self.dropout = nn.Dropout3d(0.2)

    def forward(self, x):
        # Encoder path: keep each level's features for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.p1(skip1))
        skip3 = self.enc3(self.p2(skip2))
        bottleneck = self.bot(self.p3(skip3))

        # Decoder path: upsample, gate the skip, concatenate [gated_skip, up].
        up = self.up3(bottleneck)
        dec = self.dec3(torch.cat([self.att3(up, skip3), up], 1))
        up = self.up2(dec)
        dec = self.dec2(torch.cat([self.att2(up, skip2), up], 1))
        up = self.up1(dec)
        dec = self.dec1(torch.cat([self.att1(up, skip1), up], 1))
        return self.out(self.dropout(dec))

# ==========================================
# 3. HELPER FUNCTIONS
# ==========================================
def normalize_data(data):
    """Min-max scale ``data`` to the [0, 1] range.

    When the value range is not positive (constant array, or NaNs present),
    the input object is returned unchanged.
    """
    low = data.min()
    span = data.max() - low
    return (data - low) / span if span > 0 else data

def calculate_cnr(vol):
    """Contrast-to-noise ratio using the global mean as the signal threshold.

    Values above the mean count as signal, the rest as background. Returns
    ``|mean(signal) - mean(background)| / (std(background) + 1e-6)``, or 0.0
    when either group has fewer than 5 values.
    """
    values = np.ravel(vol)
    cutoff = values.mean()
    signal = values[values > cutoff]
    background = values[values <= cutoff]
    if min(len(signal), len(background)) < 5:
        return 0.0
    contrast = abs(signal.mean() - background.mean())
    return contrast / (background.std() + 1e-6)

# --- FUNGSI BOBOT LINEAR (PYRAMID) ---
def get_linear_weight_map(patch_size):
    """Return a (P, P, P) pyramid-shaped blending weight map, P = ``patch_size``.

    Along each axis the weight is a triangle that is 0 at both ends and peaks
    in the middle. The peak is ~1.0 only for odd ``patch_size``; for even
    sizes it is slightly below 1. The 3D map is the product of the three axis
    profiles, floored at 1e-4 so no voxel gets zero weight. When overlapping
    patch predictions are averaged with these weights, voxels near a patch
    border are dominated by neighbouring patches whose centre is closer.
    """
    ramp = np.linspace(0, 1, patch_size)
    tent = np.minimum(ramp, 1 - ramp) * 2
    # Broadcast outer product, same element order as meshgrid(indexing='ij').
    cube = tent[:, None, None] * tent[None, :, None] * tent[None, None, :]
    return np.maximum(cube, 1e-4)

# ==========================================
# 4. INFERENCE ENGINE (LINEAR BLENDING)
# ==========================================
def predict_linear_blending(model, vol_in, patch_size, overlap=0.5):
    """Run patch-wise 3D inference and stitch the patches by weighted averaging.

    The volume is reflect-padded so that cubic windows of size ``patch_size``
    taken with step ``stride = int(patch_size * (1 - overlap))`` cover every
    original voxel. Each patch prediction is multiplied by the pyramid map from
    ``get_linear_weight_map`` and accumulated. The sum is divided by the
    accumulated weights and cropped back to the input shape.

    Args:
        model: network taking a (1, 1, P, P, P) tensor; its output must
            ``squeeze()`` to (P, P, P).
        vol_in: 3D numpy array of shape (D, H, W).
        patch_size: edge length P of the cubic patches.
        overlap: overlap fraction between neighbouring patches (e.g. 0.5).

    Returns:
        numpy array with the same shape as ``vol_in``.

    Raises:
        ValueError: if ``overlap`` makes the stride smaller than 1.
    """
    model.eval()
    d, h, w = vol_in.shape
    stride = int(patch_size * (1 - overlap))
    if stride < 1:
        # range() would otherwise fail (stride 0) or yield no patches (stride < 0).
        raise ValueError(
            f"overlap={overlap} with patch_size={patch_size} gives stride {stride}; "
            "stride must be >= 1"
        )

    # Pad each axis up to a multiple of the stride...
    pad_d = (stride - d % stride) % stride
    pad_h = (stride - h % stride) % stride
    pad_w = (stride - w % stride) % stride

    # ...plus one extra patch so the last windows still reach past the original extent.
    pad_extra = patch_size
    vol_padded = np.pad(
        vol_in,
        ((0, pad_d + pad_extra), (0, pad_h + pad_extra), (0, pad_w + pad_extra)),
        mode='reflect',
    )
    d_p, h_p, w_p = vol_padded.shape

    # Accumulators for the weighted predictions and the weights themselves.
    output_sum = np.zeros_like(vol_padded)
    weight_sum = np.zeros_like(vol_padded)

    # Build the weight map once: a float32 numpy copy for the CPU-side weight
    # accumulator and a tensor on DEVICE to weight the model output. This avoids
    # a device->host copy of the (constant) weights for every patch.
    weight_np = get_linear_weight_map(patch_size).astype(np.float32)
    weight_t = torch.from_numpy(weight_np).to(DEVICE)

    print("   ‚è≥ Stitching with Linear Blending (Seamless)...")

    with torch.no_grad():
        for z in range(0, d_p - patch_size, stride):
            for y in range(0, h_p - patch_size, stride):
                for x in range(0, w_p - patch_size, stride):
                    window = np.s_[z:z + patch_size, y:y + patch_size, x:x + patch_size]
                    patch_in = vol_padded[window]

                    if patch_in.shape != (patch_size,) * 3:
                        continue

                    t_in = torch.from_numpy(patch_in).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
                    t_out = model(t_in).squeeze()  # (P, P, P)

                    output_sum[window] += (t_out * weight_t).cpu().numpy()
                    weight_sum[window] += weight_np

    # Weighted average; the epsilon guards voxels that received no patch.
    reconstructed = output_sum / (weight_sum + 1e-8)

    return reconstructed[:d, :h, :w]

# ==========================================
# 5. MAIN EVALUATION LOOP
# ==========================================
def run_evaluation():
    """Evaluate the fine-tuned 3D U-Net on the last three input volumes.

    For each of the last three ``processed_input_*.npy`` files in DATA_FOLDER
    (sorted order), the function does the following:

    1. Min-max normalise the input and, if a matching
       ``processed_target_*.npy`` exists, the target. Both are cropped to
       their common extent when their shapes differ.
    2. Predict with ``predict_linear_blending``.
    3. Compute PSNR / SSIM / RMSE against the target (left at 0 when there is
       no target) and CNR for input, output and target.
    4. Save and show a 3x4 figure of mid-slices (along axis 0), density maps,
       a CNR bar chart and an error map.

    A CSV of all metrics is written to OUTPUT_DIR at the end. The function
    returns early, after printing a message, when MODEL_PATH does not exist.
    """
    print(f"üß† Loading Model: {os.path.basename(MODEL_PATH)}")

    # Load the lightweight architecture and its weights.
    model = Lightweight3DUNet().to(DEVICE)
    if os.path.exists(MODEL_PATH):
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        print("‚úÖ Model loaded successfully!")
    else:
        print("‚ùå Model path not found!"); return

    # Find input volumes; only the last three (sorted) are evaluated.
    all_files = sorted(glob.glob(os.path.join(DATA_FOLDER, "processed_input_*.npy")))
    input_files = all_files[-3:]
    n_files = len(input_files)

    results_log = []

    for i, in_path in enumerate(input_files):
        filename = os.path.basename(in_path)
        # Report the real count: fewer than three files may be available.
        print(f"\nProcessing [{i+1}/{n_files}]: {filename}")

        # Load data
        raw_in = np.load(in_path)
        vol_in = normalize_data(raw_in)

        target_path = os.path.join(DATA_FOLDER, filename.replace("processed_input_", "processed_target_"))
        vol_gt = normalize_data(np.load(target_path)) if os.path.exists(target_path) else None

        if vol_gt is not None and vol_gt.shape != vol_in.shape:
            # Crop BOTH volumes to their common extent. Cropping only the
            # target fails whenever the target is smaller than the input
            # along some axis (metrics and error map need equal shapes).
            common = tuple(min(a, b) for a, b in zip(vol_in.shape, vol_gt.shape))
            crop = tuple(slice(0, n) for n in common)
            vol_gt = vol_gt[crop]
            vol_in = vol_in[crop]

        # --- PREDICT ---
        vol_out = predict_linear_blending(model, vol_in, PATCH_SIZE, OVERLAP_PERCENT)

        # --- METRICS --- (stay 0 when no ground truth is available)
        metrics = {'PSNR': 0, 'SSIM': 0, 'RMSE': 0}
        if vol_gt is not None:
            metrics['PSNR'] = psnr(vol_gt, vol_out, data_range=1.0)
            metrics['SSIM'] = ssim(vol_gt, vol_out, data_range=1.0)
            metrics['RMSE'] = math.sqrt(np.mean((vol_gt - vol_out)**2))

        cnr_in = calculate_cnr(vol_in)
        cnr_out = calculate_cnr(vol_out)
        cnr_gt = calculate_cnr(vol_gt) if vol_gt is not None else 0

        results_log.append({
            'Filename': filename,
            'PSNR': metrics['PSNR'], 'SSIM': metrics['SSIM'], 'RMSE': metrics['RMSE'],
            'CNR_In': cnr_in, 'CNR_Out': cnr_out, 'CNR_GT': cnr_gt,
            'CNR_Improv': cnr_out - cnr_in
        })

        print(f"   -> PSNR: {metrics['PSNR']:.2f} | SSIM: {metrics['SSIM']:.4f} | RMSE: {metrics['RMSE']:.4f}")
        print(f"   -> CNR Improv: {cnr_out - cnr_in:.4f}")

        # --- VISUALIZATION --- (middle slice along axis 0)
        mid = vol_out.shape[0] // 2
        fig = plt.figure(figsize=(20, 12))
        gs = fig.add_gridspec(3, 4)

        # Row 1: images
        ax1 = fig.add_subplot(gs[0, 0]); ax1.imshow(vol_in[mid], cmap='gray'); ax1.set_title("Input (FBP)")
        ax2 = fig.add_subplot(gs[0, 1]); ax2.imshow(vol_out[mid], cmap='gray'); ax2.set_title("AI Output (Natural)")
        ax3 = fig.add_subplot(gs[0, 2])
        if vol_gt is not None:
            ax3.imshow(vol_gt[mid], cmap='gray'); ax3.set_title("Ground Truth")
        else:
            ax3.axis('off')

        # CNR chart
        ax_chart = fig.add_subplot(gs[0, 3])
        ax_chart.bar(['In', 'Out', 'GT'], [cnr_in, cnr_out, cnr_gt], color=['gray', 'blue', 'green'])
        ax_chart.set_title("CNR Comparison"); ax_chart.grid(axis='y', linestyle='--', alpha=0.3)

        # Row 2: density maps (heatmaps)
        ax4 = fig.add_subplot(gs[1, 0]); im4 = ax4.imshow(vol_in[mid], cmap='jet'); ax4.set_title("Density (In)"); plt.colorbar(im4, ax=ax4)
        ax5 = fig.add_subplot(gs[1, 1]); im5 = ax5.imshow(vol_out[mid], cmap='jet'); ax5.set_title("Density (Out)"); plt.colorbar(im5, ax=ax5)
        ax6 = fig.add_subplot(gs[1, 2])
        if vol_gt is not None:
            im6 = ax6.imshow(vol_gt[mid], cmap='jet'); ax6.set_title("Density (GT)"); plt.colorbar(im6, ax=ax6)
        else:
            ax6.axis('off')

        # Row 3: error map (all zeros when there is no ground truth)
        ax7 = fig.add_subplot(gs[2, 1])
        err = np.abs(vol_gt - vol_out) if vol_gt is not None else np.zeros_like(vol_out)
        im7 = ax7.imshow(err[mid], cmap='inferno'); ax7.set_title(f"Error Map (RMSE: {metrics['RMSE']:.4f})"); plt.colorbar(im7, ax=ax7)

        for ax in [ax1, ax2, ax3, ax4, ax5, ax6, ax7]: ax.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f"Eval_Natural_{filename.replace('.npy','.png')}"))
        plt.show()
        # One figure is created per volume; release it once shown.
        plt.close(fig)

    if results_log:
        pd.DataFrame(results_log).to_csv(os.path.join(OUTPUT_DIR, "Final_Report_Natural.csv"), index=False)
        print("\n‚úÖ DONE.")
        print(pd.DataFrame(results_log)[['Filename', 'PSNR', 'SSIM', 'RMSE']].to_string())

# Inside a notebook cell __name__ is "__main__", so the evaluation runs as
# soon as this cell is executed.
if __name__ == "__main__":
    run_evaluation()

# **2D Multi-Level Wavelet Convolutional Neural Network (MWCNN)**

## **Training & Validation**

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import os

# ===============================
# 1. Dataset 2D FBP
# ===============================
class FBP2DDataset(Dataset):
    """Paired 2D FBP / ground-truth slices stored as ``.npy`` files.

    The regular ``.npy`` files of ``fbp_folder`` and ``gt_folder`` are paired
    by position after sorting by path, so both folders must contain the same
    number of files with names that sort into matching order. Each item is
    min-max normalised to [0, 1] and returned as a pair of float32 tensors of
    shape (1, H, W).

    Raises:
        ValueError: if the two folders contain a different number of ``.npy``
            files (the positional pairing would then be misaligned).
    """

    def __init__(self, fbp_folder, gt_folder):
        self.fbp_files = self._list_npy(fbp_folder)
        self.gt_files = self._list_npy(gt_folder)
        if len(self.fbp_files) != len(self.gt_files):
            raise ValueError(
                f"Found {len(self.fbp_files)} FBP files in {fbp_folder!r} but "
                f"{len(self.gt_files)} ground-truth files in {gt_folder!r}"
            )

    @staticmethod
    def _list_npy(folder):
        """Return the sorted full paths of the regular ``.npy`` files in ``folder``."""
        # Skip sub-folders (e.g. Colab's .ipynb_checkpoints) and other files:
        # np.load cannot read them and they would shift the sorted pairing.
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith(".npy") and os.path.isfile(os.path.join(folder, name))
        )

    @staticmethod
    def _normalize(arr):
        """Min-max scale to [0, 1]; the epsilon avoids 0/0 on constant slices."""
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    def __len__(self):
        return len(self.fbp_files)

    def __getitem__(self, idx):
        x = self._normalize(np.load(self.fbp_files[idx]))
        y = self._normalize(np.load(self.gt_files[idx]))

        # Add the channel dimension: (H, W) -> (1, H, W).
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
        y = torch.tensor(y, dtype=torch.float32).unsqueeze(0)

        return x, y

# ===============================
# 2. Model MultiWave CNN 2D
# ===============================
class MultiWaveBlock2D(nn.Module):
    """Multi-scale 2D convolution block.

    Runs three parallel convolutions (3x3, 5x5, 7x7, each with "same"
    padding) from ``in_ch`` to ``out_ch`` channels and concatenates them along
    the channel axis (giving ``3 * out_ch`` channels), then applies BatchNorm
    and ReLU.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names are state_dict keys; keep them for checkpoint loading.
        self.c3 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.c5 = nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2)
        self.c7 = nn.Conv2d(in_ch, out_ch, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm2d(3 * out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branches = [conv(x) for conv in (self.c3, self.c5, self.c7)]
        stacked = torch.cat(branches, dim=1)
        return self.relu(self.bn(stacked))

class MultiWaveCNN2D(nn.Module):
    """Three stacked ``MultiWaveBlock2D`` stages followed by a 1x1 conv head.

    Maps a 1-channel image to a 1-channel image of the same spatial size.
    Each block concatenates three branches, so channel widths go
    1 -> 48 -> 96 -> 192 -> 1. Note: despite the "MWCNN" label used in this
    notebook, no wavelet transform is involved; the blocks are parallel
    multi-kernel convolutions.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are state_dict keys; keep them for checkpoint loading.
        self.b1 = MultiWaveBlock2D(1, 16)        # -> 48 channels
        self.b2 = MultiWaveBlock2D(3 * 16, 32)   # -> 96 channels
        self.b3 = MultiWaveBlock2D(3 * 32, 64)   # -> 192 channels
        self.out = nn.Conv2d(3 * 64, 1, kernel_size=1)

    def forward(self, x):
        for stage in (self.b1, self.b2, self.b3):
            x = stage(x)
        return self.out(x)

# ===============================
# 3. Load data & split
# ===============================
# Drive folders with the 2D FBP slices (inputs) and the matching ground-truth slices.
fbp_folder = "/content/drive/MyDrive/Tugas SMT 5/Pencit/Hasil FBP 2D"
gt_folder = "/content/drive/MyDrive/Tugas SMT 5/Pencit/Ground Truth FBP 2D"

dataset = FBP2DDataset(fbp_folder, gt_folder)

# Random 90/10 split of slice indices (fixed seed for reproducibility).
# NOTE(review): if several slices come from the same volume, train and
# validation may share volumes — confirm whether a per-volume split is needed.
train_idx, val_idx = train_test_split(np.arange(len(dataset)), test_size=0.1, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

BATCH_SIZE = 2
# Only the training loader shuffles; the validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ===============================
# 4. Setup model, loss, optimizer, scheduler
# ===============================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiWaveCNN2D().to(device)
# Pixel-wise MSE between the prediction and the normalised ground truth.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate after 5 epochs without validation-loss improvement
# (scheduler.step(val_loss) is called once per epoch in the training loop).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# ===============================
# 5. Training loop dengan checkpoint
# ===============================
EPOCHS = 25
patience = 10            # early stopping: epochs allowed without val improvement
early_stop_counter = 0
best_val = float('inf')

# Checkpoint written after every epoch; if present, training resumes from it.
CHECKPOINT_PATH = "/content/drive/MyDrive/Tugas SMT 5/Pencit/checkpoint_latest.pth"

start_epoch = 0

if os.path.exists(CHECKPOINT_PATH):
    print("üîÅ Resuming training from checkpoint...")

    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    # "epoch" is the number of completed epochs, i.e. the next 0-based index.
    start_epoch = checkpoint["epoch"]
    best_val = checkpoint["best_val"]
    # Restore the early-stopping counter so a resumed run does not get a fresh
    # patience budget; checkpoints without this key fall back to 0.
    early_stop_counter = checkpoint.get("early_stop_counter", 0)

    print(f"‚úÖ Resumed from epoch {start_epoch}")
else:
    print("üÜï No checkpoint found, starting from epoch 1")

for e in range(start_epoch, EPOCHS):

    # Stop before training if the early-stopping budget is already used up
    # (e.g. when the counter was carried over from a previous run).
    if early_stop_counter >= patience:
        print(f"‚õî No improvement in {patience} epochs, stopping early.")
        break

    # =======================
    # TRAIN
    # =======================
    model.train()
    train_loss = 0.0

    for x, y in tqdm(train_loader, desc=f"Epoch {e+1} Training"):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses.
    train_loss /= len(train_loader)

    # =======================
    # VALIDATION
    # =======================
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(model(x), y).item()

    val_loss /= len(val_loader)

    # =======================
    # SCHEDULER (ReduceLROnPlateau on validation loss)
    # =======================
    scheduler.step(val_loss)

    # =======================
    # BEST MODEL
    # =======================
    if val_loss < best_val:
        best_val = val_loss
        # NOTE: saved relative to the current working directory, not next to
        # CHECKPOINT_PATH.
        torch.save(model.state_dict(), "best_model.pth")
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    # =======================
    # SAVE CHECKPOINT (EVERY EPOCH)
    # =======================
    # The early-stopping counter is stored too, so resuming keeps its state.
    torch.save({
        "epoch": e + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_val": best_val,
        "early_stop_counter": early_stop_counter,
    }, CHECKPOINT_PATH)

    # =======================
    # LOG
    # =======================
    print(
        f"Epoch {e+1:02d} | "
        f"Train: {train_loss:.6f} | "
        f"Val: {val_loss:.6f} | "
        f"Best: {best_val:.6f}"
    )

    # =======================
    # EARLY STOPPING
    # =======================
    if early_stop_counter >= patience:
        print(f"‚õî No improvement in {patience} epochs, stopping early.")
        break

## **Testing**

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import numpy as np
import torch

def evaluate(model, loader, device, max_batches=3):
    """Print per-image metrics and show comparison plots for a few batches.

    For every image in the first ``max_batches`` batches of ``loader``, this
    computes PSNR and SSIM (both with ``data_range=1.0``), RMSE against the
    ground truth, and a CNR of the model output over fixed ROI and background
    windows. The values are printed and a four-panel figure is shown (input,
    output, ground truth, absolute difference). Afterwards the metric means
    (and the PSNR standard deviation) are printed. If no image was evaluated
    (empty loader or ``max_batches <= 0``), a message is printed instead of
    averaging empty lists.

    NOTE(review): ``compute_cnr(img, roi, bg)`` is not defined in this part
    of the notebook; it is assumed to come from an earlier cell — confirm.
    """
    model.eval()

    psnr_list, ssim_list, rmse_list, cnr_list = [], [], [], []

    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            if i >= max_batches:
                break

            x = x.to(device)
            y = y.to(device)

            pred = model(x)

            x = x.cpu().numpy()
            y = y.cpu().numpy()
            pred = pred.cpu().numpy()

            for b in range(x.shape[0]):
                inp = x[b, 0]
                gt  = y[b, 0]
                out = pred[b, 0]

                diff = np.abs(gt - out)

                # ======================
                # METRICS
                # ======================
                psnr_val = psnr(gt, out, data_range=1.0)
                ssim_val = ssim(gt, out, data_range=1.0)
                rmse_val = np.sqrt(np.mean((gt - out) ** 2))

                # ======================
                # CNR: fixed ROI window vs. top-left background corner
                # ======================
                h, w = gt.shape
                roi = np.s_[h//3:h//2, w//3:w//2]
                bg  = np.s_[0:h//5, 0:w//5]

                cnr_val = compute_cnr(out, roi, bg)

                psnr_list.append(psnr_val)
                ssim_list.append(ssim_val)
                rmse_list.append(rmse_val)
                cnr_list.append(cnr_val)

                # ======================
                # PRINT
                # ======================
                print(
                    f"PSNR: {psnr_val:.2f} dB | "
                    f"SSIM: {ssim_val:.4f} | "
                    f"RMSE: {rmse_val:.5f} | "
                    f"CNR: {cnr_val:.3f}"
                )

                # ======================
                # VISUALISATION
                # ======================
                plt.figure(figsize=(16,4))

                plt.subplot(1,4,1)
                plt.imshow(inp, cmap='gray')
                plt.title("Input (FBP)")
                plt.axis('off')

                plt.subplot(1,4,2)
                plt.imshow(out, cmap='gray')
                plt.title("Output (MWCNN)")
                plt.axis('off')

                plt.subplot(1,4,3)
                plt.imshow(gt, cmap='gray')
                plt.title("Ground Truth")
                plt.axis('off')

                plt.subplot(1,4,4)
                plt.imshow(diff, cmap='hot')
                plt.title("Difference Map")
                plt.colorbar(fraction=0.046)
                plt.axis('off')

                plt.show()

    # Averaging empty lists would only print nan with RuntimeWarnings.
    if not psnr_list:
        print("No images evaluated (empty loader or max_batches <= 0).")
        return

    # ======================
    # AVERAGES
    # ======================
    print("\n===== RATA-RATA EVALUASI =====")
    print(f"PSNR  : {np.mean(psnr_list):.2f} ¬± {np.std(psnr_list):.2f}")
    print(f"SSIM  : {np.mean(ssim_list):.4f}")
    print(f"RMSE  : {np.mean(rmse_list):.5f}")
    print(f"CNR   : {np.mean(cnr_list):.3f}")

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np
import torch

def compare_fbp_vs_mwcnn(model, loader, device, max_batches=3):
    """Compare the FBP input and the network output against the ground truth.

    Over the first ``max_batches`` batches of ``loader``, PSNR, SSIM (both with
    ``data_range=1.0``), RMSE and CNR are computed for both the FBP input and
    the model output. The mean of each metric is then printed side by side
    with a Better/Worse label: higher is better for PSNR/SSIM/CNR, lower for
    RMSE. If no image was evaluated (empty loader or ``max_batches <= 0``), a
    message is printed instead of averaging empty lists.

    NOTE(review): ``compute_cnr(img, roi, bg)`` is not defined in this part
    of the notebook; it is assumed to come from an earlier cell — confirm.
    """
    model.eval()

    metrics = {
        "FBP":  {"psnr": [], "ssim": [], "rmse": [], "cnr": []},
        "MWCNN":{"psnr": [], "ssim": [], "rmse": [], "cnr": []}
    }

    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            if i >= max_batches:
                break

            x = x.to(device)
            y = y.to(device)

            pred = model(x)

            x = x.cpu().numpy()
            y = y.cpu().numpy()
            pred = pred.cpu().numpy()

            for b in range(x.shape[0]):
                fbp = x[b,0]
                gt  = y[b,0]
                out = pred[b,0]

                # ======================
                # FBP METRICS
                # ======================
                metrics["FBP"]["psnr"].append(psnr(gt, fbp, data_range=1.0))
                metrics["FBP"]["ssim"].append(ssim(gt, fbp, data_range=1.0))
                metrics["FBP"]["rmse"].append(np.sqrt(np.mean((gt - fbp)**2)))

                # ======================
                # MWCNN METRICS
                # ======================
                metrics["MWCNN"]["psnr"].append(psnr(gt, out, data_range=1.0))
                metrics["MWCNN"]["ssim"].append(ssim(gt, out, data_range=1.0))
                metrics["MWCNN"]["rmse"].append(np.sqrt(np.mean((gt - out)**2)))

                # ======================
                # CNR: fixed ROI window vs. top-left background corner
                # ======================
                h, w = gt.shape
                roi = np.s_[h//3:h//2, w//3:w//2]
                bg  = np.s_[0:h//5, 0:w//5]

                metrics["FBP"]["cnr"].append(compute_cnr(fbp, roi, bg))
                metrics["MWCNN"]["cnr"].append(compute_cnr(out, roi, bg))

    # Averaging empty lists would only print nan with RuntimeWarnings.
    if not metrics["MWCNN"]["psnr"]:
        print("No images evaluated (empty loader or max_batches <= 0).")
        return

    # ======================
    # AVERAGES
    # ======================
    print("\n========== PERBANDINGAN KINERJA ==========")

    for m in ["psnr", "ssim", "rmse", "cnr"]:
        fbp_val = np.mean(metrics["FBP"][m])
        mw_val  = np.mean(metrics["MWCNN"][m])

        if m in ["psnr", "ssim", "cnr"]:  # higher is better
            label = "‚úÖ Better" if mw_val > fbp_val else "‚ùå Worse"
        else:  # RMSE: lower is better
            label = "‚úÖ Better" if mw_val < fbp_val else "‚ùå Worse"

        print(f"{m.upper():5s} | FBP: {fbp_val:.4f} | MWCNN: {mw_val:.4f} ‚Üí {label}")

In [None]:
# Compare the FBP input against the network output on the validation split.
# NOTE(review): nothing above reloads best_model.pth, so `model` holds the
# weights from the last trained epoch (or from the resumed checkpoint), not
# necessarily the best-validation weights; reload them first if intended.
compare_fbp_vs_mwcnn(model, val_loader, device)