In [10]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import pandas as pd # Added for nice tables
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, mean_squared_error

# --- 1. Project Setup ---
# The project root is the nearest directory -- the working directory first,
# then each of its ancestors in turn -- that contains the `func` package.
current_path = Path.cwd()
_root_candidates = (current_path, *current_path.parents)
PROJECT_ROOT = next((cand for cand in _root_candidates if (cand / "func").exists()), None)
if PROJECT_ROOT is None:
    raise FileNotFoundError("Could not find project root containing 'func' folder.")

# Make the project's `func.*` modules importable.
sys.path.append(str(PROJECT_ROOT))

from func.Models import VAE, MultiTaskNet_ag, MultiTaskNet_big
from func.dataloaders import VolumetricPatchDataset 

# --- 2. Configuration ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
INPUT_SHAPE = (128, 128, 128)  # spatial size of one patch (not referenced elsewhere in this cell)
NUM_CLASSES = 4
LATENT_DIM = 256

# Trained checkpoints, all stored under <project root>/Trained_models
_MODEL_DIR = PROJECT_ROOT / "Trained_models"
PATHS = {
    "VAE": _MODEL_DIR / "VAE_val_best.pth",
    "AG_Net": _MODEL_DIR / "AG_val_best.pth",
    "Multi_Big": _MODEL_DIR / "multi_big_best.pth",
}

# Column ids passed to VolumetricPatchDataset(selected_columns=...) for the test set
TEST_COLS = [35, 36, 37, 38]

# --- 3. Metric Functions ---

def calculate_iou_per_class(pred, target, num_classes):
    """Calculates the Intersection over Union of every class over the whole volume.

    Args:
        pred: tensor of predicted integer labels (any shape, e.g. (B, D, H, W)).
        target: tensor of ground-truth integer labels with the same number of
            elements as `pred`.
        num_classes: number of classes; labels 0 .. num_classes - 1 are scored.

    Returns:
        np.ndarray of shape (num_classes,) holding the IoU of each class. A class
        that appears in neither `pred` nor `target` gets NaN, so callers can skip
        it with np.nanmean.
    """
    # reshape (unlike view) also works on non-contiguous tensors, e.g. slices or
    # transposed/permuted views of a volume.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_mask = pred == cls
        target_mask = target == cls

        intersection = (pred_mask & target_mask).sum().item()
        union = (pred_mask | target_mask).sum().item()

        # Undefined when the class is absent from both tensors -> NaN (ignored by nanmean)
        ious.append(intersection / union if union else float('nan'))

    return np.array(ious)

def calculate_auroc(logits, target, num_classes):
    """
    Calculates weighted One-vs-Rest AUROC for multi-class segmentation.

    NOTE: This flattens the whole volume to (N_voxels, C) on the CPU, which can
    be heavy for large patches.

    Args:
        logits: raw class scores of shape (B, C, D, H, W) with C == num_classes.
        target: integer labels of shape (B, D, H, W).
        num_classes: number of classes C.

    Returns:
        The AUROC as a float, or NaN when it is undefined (sklearn raises
        ValueError, e.g. when not every class is present in the patch). NaN,
        rather than 0.0, keeps such a patch distinguishable from a genuinely bad
        score in the results table.
    """
    probs = F.softmax(logits, dim=1)
    # (B, C, D, H, W) -> (B, D, H, W, C) -> (N_voxels, C): one row per voxel,
    # in the same voxel order as the flattened target below.
    probs = probs.movedim(1, -1).reshape(-1, num_classes).detach().cpu().numpy()
    target = target.reshape(-1).detach().cpu().numpy()

    try:
        return roc_auc_score(target, probs, multi_class='ovr', average='weighted')
    except ValueError:
        return float('nan')

def calculate_reconstruction_mse(recon, input_img):
    """Calculates the Mean Squared Error between a reconstruction and its input.

    Args:
        recon: reconstructed tensor; must contain the same number of elements as
            `input_img` (normally the same shape).
        input_img: original input tensor.

    Returns:
        The MSE over all elements, as a Python float.

    Raises:
        ValueError: if the tensors differ in element count or are empty.
    """
    # detach + reshape: works for tensors that require grad and for
    # non-contiguous tensors (where .view(-1) would raise). float64 keeps the
    # accumulation over millions of voxels accurate.
    recon_flat = recon.detach().reshape(-1).cpu().numpy().astype(np.float64)
    input_flat = input_img.detach().reshape(-1).cpu().numpy().astype(np.float64)
    if recon_flat.shape != input_flat.shape:
        raise ValueError(
            f"recon has {recon_flat.size} elements but input_img has {input_flat.size}"
        )
    if recon_flat.size == 0:
        raise ValueError("Cannot compute MSE of empty tensors.")
    return float(np.mean((input_flat - recon_flat) ** 2))

# --- 4. Model Loading Helper ---
def load_model(model_class, path, model_name):
    """Builds `model_class`, loads its state_dict from `path`, and returns it in eval mode on DEVICE.

    Args:
        model_class: model class to instantiate. It is called with
            in_channels=1 and latent_dim=LATENT_DIM, plus the class-count keyword
            its constructor expects (see below).
        path: pathlib.Path of a saved state_dict.
        model_name: display name used in the log messages.

    Returns:
        The loaded model, or None if the checkpoint is missing or loading fails.
        Failures are printed rather than raised, so one bad checkpoint does not
        stop the comparison of the other models.
    """
    print(f"Loading {model_name}...")
    try:
        # Check for the checkpoint first, so a missing file does not cost a
        # network construction and a device transfer.
        if not path.exists():
            print(f"  ❌ Warning: File not found at {path}")
            return None

        # The VAE constructor takes `NUM_CLASSES=`, the multitask nets take
        # `num_classes=`. Decide by display name (as before) or by class name.
        is_vae = model_name == "VAE" or getattr(model_class, "__name__", "") == "VAE"
        if is_vae:
            model = model_class(in_channels=1, latent_dim=LATENT_DIM, NUM_CLASSES=NUM_CLASSES)
        else:
            model = model_class(in_channels=1, latent_dim=LATENT_DIM, num_classes=NUM_CLASSES)
        model.to(DEVICE)

        # The checkpoint is a plain state_dict: weights_only=True refuses to
        # unpickle arbitrary objects (and silences torch's FutureWarning).
        state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"  ❌ Error loading {model_name}: {e}")
        return None

# --- 5. Main Execution ---



In [38]:
# A. Load Data
# A fixed-seed generator keeps shuffle=True but makes the evaluated patch the
# same on every kernel restart, so the metrics table is reproducible.
SAMPLE_SEED = 0
print("Loading Test Data...")
try:
    ds = VolumetricPatchDataset(selected_columns=TEST_COLS, augment=False, is_labeled=True)
    dl = DataLoader(
        ds,
        batch_size=1,
        shuffle=True,
        generator=torch.Generator().manual_seed(SAMPLE_SEED),
    )
    x_batch, y_batch = next(iter(dl))
    x_batch = x_batch.to(DEVICE)
    y_batch = y_batch.squeeze(1).to(DEVICE)  # drop a singleton dim 1 if present -> (1, D, H, W)
except Exception as e:
    # Re-raise instead of sys.exit(): in a notebook sys.exit only raises
    # SystemExit and hides the original traceback.
    print(f"Error loading data: {e}")
    raise

# B. Load Models (None for any checkpoint that is missing or fails to load)
models = {
    "VAE": load_model(VAE, PATHS["VAE"], "VAE"),
    "AG_Net": load_model(MultiTaskNet_ag, PATHS["AG_Net"], "AG_Net"),
    "Multi_Big": load_model(MultiTaskNet_big, PATHS["Multi_Big"], "Multi_Big"),
}

# C. Run Inference & Calculate Metrics
metrics_data = []
results = {}

print("\nRunning Inference & Metrics Calculation...")
with torch.no_grad():
    for name, model in models.items():
        if model is None:
            continue

        # Forward pass. The models return tuples of 2-4 items whose first two
        # are (segmentation logits, reconstruction); extra items are ignored.
        output = model(x_batch)
        seg_logits, recon = output[0], output[1]

        # Save for plotting
        results[name] = {"seg": seg_logits, "recon": recon}

        # --- METRICS ---
        # 1. Segmentation prediction (argmax over the class channel)
        pred_seg = torch.argmax(seg_logits, dim=1)

        # 2. IoU per class, and the mean over the foreground classes 1..C-1
        ious = calculate_iou_per_class(pred_seg, y_batch, NUM_CLASSES)
        mIoU = np.nanmean(ious[1:])

        # 3. AUROC (segmentation), computed over the full patch
        auroc = calculate_auroc(seg_logits, y_batch, NUM_CLASSES)

        # 4. ME (reconstruction MSE)
        mse_recon = calculate_reconstruction_mse(recon, x_batch)

        metrics_data.append({
            "Model": name,
            "mIoU (Foreground)": mIoU,
            "AUROC": auroc,
            "Recon MSE (ME)": mse_recon,
            **{f"IoU Class {c}": ious[c] for c in range(1, NUM_CLASSES)},
        })

# D. Display Metrics
df = pd.DataFrame(metrics_data)
print("\n--- Model Performance Comparison ---")
display(df)  # Jupyter rich table display

# E. Visualization (central depth slice of the patch)
slc = x_batch.shape[2] // 2
img_in = x_batch[0, 0, slc].cpu().numpy()
img_gt = y_batch[0, slc].cpu().numpy()

n_cols = 1 + len(results)
# squeeze=False keeps `axes` 2-D even when no model loaded (single column).
fig, axes = plt.subplots(3, n_cols, figsize=(4 * n_cols, 12), squeeze=False)

# Column 0: inputs
axes[0, 0].imshow(img_in, cmap='gray')
axes[0, 0].set_title("Input Image")
axes[1, 0].imshow(img_gt, cmap='viridis', vmin=0, vmax=NUM_CLASSES-1)
axes[1, 0].set_title("Ground Truth")

# Columns 1..N: one per model
for i, (name, res) in enumerate(results.items(), start=1):
    # Reconstruction
    recon_slice = res["recon"][0, 0, slc].cpu().numpy()
    axes[0, i].imshow(recon_slice, cmap='gray')
    axes[0, i].set_title(f"{name}\nRecon")

    # Segmentation
    seg_slice = torch.argmax(res["seg"], dim=1)[0, slc].cpu().numpy()
    axes[1, i].imshow(seg_slice, cmap='viridis', vmin=0, vmax=NUM_CLASSES-1)
    axes[1, i].set_title(f"{name}\nSeg Prediction")

    # Error map: voxels where the prediction differs from the ground truth
    diff = (seg_slice != img_gt)
    axes[2, i].imshow(diff, cmap='Reds', vmin=0, vmax=1)
    axes[2, i].set_title(f"{name}\nError Map")

for ax in axes.flatten():
    ax.axis('off')
plt.tight_layout()
plt.show()


Loading Test Data...
Loading VAE...
Loading AG_Net...
Loading Multi_Big...

Running Inference & Metrics Calculation...


  state_dict = torch.load(path, map_location=DEVICE)


KeyboardInterrupt: 

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import sys
import itertools
from pathlib import Path
from scipy.io import loadmat
import os

# --- Project Path Setup ---
# The project root is the nearest directory (the working directory or one of its
# ancestors) that contains the `func` package, so the notebook works regardless
# of how deep inside the project it is located.
_cwd = Path.cwd()
PROJECT_ROOT = next((d for d in (_cwd, *_cwd.parents) if (d / "func").exists()), None)
if PROJECT_ROOT is None:
    raise FileNotFoundError("Could not find project root containing 'func' folder.")
# Append only once so re-running the cell does not keep growing sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

print(f"Project Root set to: {PROJECT_ROOT}")

# --- Import Model ---
# Change this import if you want to use MultiTaskNet_big instead
from func.Models import MultiTaskNet_ag as MultiTaskNet 
from func.Models import VAE
# from func.Models import MultiTaskNet_big as MultiTaskNet 

_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"Using device: {device}")

Project Root set to: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project
Using device: cpu
Using device: cpu


In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import sys
import itertools
from pathlib import Path
from scipy.io import loadmat
from matplotlib.colors import ListedColormap
import os

# --- SKLEARN for Metrics ---
try:
    from sklearn.metrics import roc_auc_score
except ImportError:
    print("WARNING: scikit-learn not found. AUROC calculation will be skipped.")
    roc_auc_score = None

# --- Project Setup ---
# The project root is the nearest directory (the working directory or one of its
# ancestors) that contains the `func` package, rather than assuming the notebook
# sits exactly one level below it.
_cwd = Path.cwd()
PROJECT_ROOT = next((d for d in (_cwd, *_cwd.parents) if (d / "func").exists()), None)
if PROJECT_ROOT is None:
    raise FileNotFoundError("Could not find project root containing 'func' folder.")
# Append only once so re-running the cell does not keep growing sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

# Import your model
from func.Models import MultiTaskNet_ag as MultiTaskNet 
from func.Models import VAE
# ---------------------

# --- Configuration ---
NUM_CLASSES = 4
# NOTE(review): the patch-level evaluation cell builds models for the same
# AG_val_best.pth checkpoint with LATENT_DIM = 256 -- confirm which value the
# checkpoint expects.
LATENT_DIM = 512
FULL_VOLUME_SIZE = 256  # edge length of a full cubic volume
PATCH_SIZE = 128        # edge length of the cubic patches fed to the model
if FULL_VOLUME_SIZE % PATCH_SIZE != 0:
    raise ValueError(
        f"FULL_VOLUME_SIZE ({FULL_VOLUME_SIZE}) must be a multiple of PATCH_SIZE ({PATCH_SIZE})"
    )
SLICES_PER_AXIS = FULL_VOLUME_SIZE // PATCH_SIZE  # patches per axis (2 -> 8 patches in total)

# Test column to evaluate
TEST_COLUMN_ID = 36

# Weights of the AG model (MultiTaskNet_ag)
MODEL_PATH = PROJECT_ROOT / "Trained_models" / "AG_val_best.pth"
# Derived from TEST_COLUMN_ID so the figure name always matches the evaluated column
SAVE_FILENAME = f"full_ag_best_col{TEST_COLUMN_ID}.png"

# Data directory: $BLACKHOLE/deep_learning_214776/extracted_datasets/datasets_processed_latest
# (falls back to the current directory when BLACKHOLE is unset)
BLACKHOLE_PATH = os.environ.get('BLACKHOLE', '.')
BASE_DATA_DIR = Path(BLACKHOLE_PATH) / 'deep_learning_214776' / 'extracted_datasets' / 'datasets_processed_latest'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")


Using device: cpu
--- Using device: cpu ---


In [None]:
def get_custom_colormap():
    """
    Build the discrete 4-colour map used for label images.

    Label index -> colour:
        0 -> blue  (documented here as background)
        1 -> red
        2 -> yellow
        3 -> cyan (turquoise)

    NOTE(review): the colour bar in the slice figure labels class 0 as 'Water';
    confirm whether class 0 is background or water.
    """
    palette = ('blue', 'red', 'yellow', 'cyan')
    return ListedColormap(list(palette))


def load_raw_volume(column_num, half_type='top'):
    """Loads the FULL 256^3 volume and its ground-truth labels from disk.

    Files read, relative to BASE_DATA_DIR:
        Column_<column_num>/B/<half_type>.mat   -> variable `<half_type>` (image)
        Column_<column_num>/gt_<half_type>.mat  -> variable `gt_<half_type>` (labels)

    Args:
        column_num: column id used in the directory name.
        half_type: which half to load (e.g. 'top'); also the .mat file stem and
            the variable name inside it.

    Returns:
        (X_tensor, Y_tensor) on `device`: X is float32 of shape (1, 1, *vol)
        scaled by 1/255, Y is int64 of shape (1, *vol), where vol is the
        squeezed array shape (presumably (D, H, W) -- confirm).

    Raises:
        Any error from reading the files (missing file or variable, ...) is
        printed and re-raised, instead of calling sys.exit(), which in a notebook
        only raises SystemExit and hides the traceback.
        ValueError: if the image and label arrays have different shapes.
    """
    column_dir = BASE_DATA_DIR / f'Column_{column_num}'
    x_filepath = column_dir / 'B' / f'{half_type}.mat'
    y_filepath = column_dir / f'gt_{half_type}.mat'

    print(f"Loading: {x_filepath}")
    try:
        # NOTE(review): dividing by 255 assumes 8-bit intensities -- confirm for this data.
        X_full = np.squeeze(loadmat(str(x_filepath))[half_type]).astype(np.float32) / 255.0
        Y_full = np.squeeze(loadmat(str(y_filepath))[f'gt_{half_type}']).astype(np.int64)
        if X_full.shape != Y_full.shape:
            raise ValueError(
                f"Image shape {X_full.shape} does not match label shape {Y_full.shape}"
            )

        X_tensor = torch.from_numpy(X_full).unsqueeze(0).unsqueeze(0).to(device)
        Y_tensor = torch.from_numpy(Y_full).unsqueeze(0).to(device)
        return X_tensor, Y_tensor
    except Exception as e:
        print(f"ERROR loading raw volume: {e}")
        raise


def stitch_inference(model, X_full, patch_size=None):
    """
    Cuts the volume into non-overlapping cubic patches, runs inference on each,
    and stitches the outputs back together (256^3 with the default config -> 8 patches).

    Args:
        model: network whose forward pass returns a tuple/list whose first two
            items are (segmentation logits, reconstruction); any further items
            (e.g. latent statistics) are ignored.
        X_full: input tensor of shape (B, C_in, D, H, W).
        patch_size: edge length of the cubic patches; defaults to PATCH_SIZE.

    Returns:
        (stitched_seg, stitched_recon) with shapes (B, C_seg, D, H, W) and
        (B, C_recon, D, H, W), on the device/dtype of the model outputs.

    Raises:
        ValueError: if patch_size is not positive, or if D, H or W is not a
            positive multiple of patch_size (the uncovered remainder would
            otherwise silently stay zero).
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    model.eval()

    spatial = tuple(X_full.shape[2:])
    if any(dim == 0 or dim % patch_size for dim in spatial):
        raise ValueError(
            f"Volume spatial shape {spatial} is not a positive multiple of patch_size={patch_size}."
        )
    n_d, n_h, n_w = (dim // patch_size for dim in spatial)

    # Allocated lazily from the first patch so channel counts come from the model.
    stitched_seg = None
    stitched_recon = None

    print("Starting Stitching Loop...")
    # (i_d, i_h, i_w) index the patch grid; named so they do not shadow pandas' `pd`.
    for i_d, i_h, i_w in itertools.product(range(n_d), range(n_h), range(n_w)):
        d_start, d_end = i_d * patch_size, (i_d + 1) * patch_size
        h_start, h_end = i_h * patch_size, (i_h + 1) * patch_size
        w_start, w_end = i_w * patch_size, (i_w + 1) * patch_size
        window = (slice(None), slice(None),
                  slice(d_start, d_end), slice(h_start, h_end), slice(w_start, w_end))

        with torch.no_grad():
            outputs = model(X_full[window])
        seg_patch, recon_patch = outputs[0], outputs[1]

        if stitched_seg is None:
            stitched_seg = seg_patch.new_zeros((seg_patch.shape[0], seg_patch.shape[1], *spatial))
            stitched_recon = recon_patch.new_zeros((recon_patch.shape[0], recon_patch.shape[1], *spatial))

        stitched_seg[window] = seg_patch
        stitched_recon[window] = recon_patch

        print(f"  Processed Patch: [{d_start}:{d_end}, {h_start}:{h_end}, {w_start}:{w_end}]")

    return stitched_seg, stitched_recon


def calculate_metrics(stitched_logits, Y_gt):
    """Calculates (and prints) IoU, ME, and AUROC for a stitched full volume.

    Metrics:
        ME:    "Mean Error", the fraction of misclassified voxels (1 - accuracy).
        IoU:   per-class Intersection over Union. A class absent from both the
               ground truth and the prediction has undefined IoU; it is reported
               as NaN and skipped by the means instead of being counted as 1.0.
        AUROC: macro-averaged one-vs-rest AUROC on the softmax probabilities.
               NaN if scikit-learn is unavailable, -1.0 if the computation fails.

    Args:
        stitched_logits: logits of shape (B, C, D, H, W).
        Y_gt: integer labels of shape (B, D, H, W).

    Returns:
        dict with keys "ME", "IoU_per_class", "mIoU_all", "mIoU_fg", "AUROC".
    """
    print("\n--- Calculating Full Volume Metrics ---")
    num_classes = stitched_logits.shape[1]
    probs = torch.softmax(stitched_logits, dim=1)
    preds = torch.argmax(stitched_logits, dim=1)

    y_true_flat = Y_gt.cpu().numpy().flatten()
    y_pred_flat = preds.cpu().numpy().flatten()

    # Mean Error: fraction of misclassified voxels
    me = 1.0 - np.mean(y_true_flat == y_pred_flat)

    # IoU
    class_ious = []
    print("\n[Intersection over Union]")
    for c in range(num_classes):
        true_c = y_true_flat == c
        pred_c = y_pred_flat == c
        intersection = np.sum(true_c & pred_c)
        union = np.sum(true_c | pred_c)

        iou = intersection / union if union > 0 else float('nan')
        print(f"  Class {c}: {iou:.4f}")
        class_ious.append(iou)

    mIoU_all = np.nanmean(class_ious)
    mIoU_fg = np.nanmean(class_ious[1:])

    # AUROC
    auroc = float('nan')
    if roc_auc_score is not None:
        print("\n[AUROC Calculation...]")
        try:
            # (B, C, D, H, W) -> (B, D, H, W, C) -> (N_voxels, C): one row per
            # voxel, in the same order as y_true_flat (all batches included).
            y_probs_flat = probs.movedim(1, -1).reshape(-1, num_classes).cpu().numpy()
            auroc = roc_auc_score(y_true_flat, y_probs_flat, multi_class='ovr', average='macro')
        except Exception as e:
            print(f"  AUROC Error: {e}")
            auroc = -1.0

    print("\n" + "="*30)
    print(f"FINAL METRICS for Column {TEST_COLUMN_ID}")
    print(f"  Mean Error (ME):       {me:.4f}")
    print(f"  Mean IoU (All):        {mIoU_all:.4f}")
    print(f"  Mean IoU (Foreground): {mIoU_fg:.4f}")
    print(f"  AUROC (Macro):         {auroc:.4f}")
    print("="*30 + "\n")

    return {
        "ME": me,
        "IoU_per_class": class_ious,
        "mIoU_all": mIoU_all,
        "mIoU_fg": mIoU_fg,
        "AUROC": auroc,
    }


def visualize_full_slice(X_tensor, Y_tensor, stitched_seg, stitched_recon, save_path, slice_idx=FULL_VOLUME_SIZE // 2):
    """
    Save a 1x4 figure of one depth slice: input, ground truth, reconstruction, prediction.

    Label panels (ground truth, prediction) use the custom discrete colormap
    with a colour bar naming the classes; image panels are grayscale. The
    figure is written to `save_path` and then closed, so it is not shown inline.

    Args:
        X_tensor: input volume, indexed as [0, 0, slice_idx, :, :].
        Y_tensor: label volume, indexed as [0, slice_idx, :, :].
        stitched_seg: class logits; argmax over dim 1 gives the prediction.
        stitched_recon: reconstruction, indexed as [0, 0, slice_idx, :, :].
        save_path: output image path.
        slice_idx: depth index to display (defaults to the central slice).
    """
    label_cmap = get_custom_colormap()

    def _as_numpy(tensor):
        return tensor.cpu().numpy()

    # argmax over the class channel of the chosen slice only; this yields the
    # same labels as a full-volume argmax followed by slicing.
    pred_slice = torch.argmax(stitched_seg[:, :, slice_idx], dim=1)[0]

    # (title, 2-D image, is_label_map)
    panels = [
        ('Input', _as_numpy(X_tensor[0, 0, slice_idx, :, :]), False),
        ('Ground Truth', _as_numpy(Y_tensor[0, slice_idx, :, :]), True),
        ('Reconstruction', _as_numpy(stitched_recon[0, 0, slice_idx, :, :]), False),
        ('Prediction', _as_numpy(pred_slice), True),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    for ax, (title, image, is_label) in zip(axes, panels):
        if is_label:
            # Discrete labels 0..NUM_CLASSES-1, no smoothing between classes
            im = ax.imshow(image, cmap=label_cmap, interpolation='nearest',
                           vmin=0, vmax=NUM_CLASSES - 1)
        else:
            im = ax.imshow(image, cmap='gray', interpolation=None, vmin=None, vmax=None)
        ax.set_title(title, fontsize=14)
        ax.axis('off')

        if is_label:
            cbar = plt.colorbar(im, ax=ax, ticks=range(NUM_CLASSES), fraction=0.046, pad=0.04)
            cbar.ax.set_yticklabels(['Water', 'Oil', 'Solids', 'Gas'])

    plt.suptitle(f"Full Resolution Inference - Test Column {TEST_COLUMN_ID} - Slice {slice_idx}", fontsize=16)
    plt.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.close(fig)
    print(f"✅ Visualization saved to: {save_path}")



In [3]:
if __name__ == "__main__":

    # 1. Load the AG multitask model and its trained weights
    model = MultiTaskNet(in_channels=1, num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)

    if not MODEL_PATH.exists():
        print(f"❌ Model weights not found at {MODEL_PATH}")
        # Raise instead of sys.exit(): in a notebook sys.exit only raises SystemExit.
        raise FileNotFoundError(MODEL_PATH)
    print(f"Loading weights from: {MODEL_PATH}")
    # The checkpoint is a plain state_dict: weights_only=True avoids unpickling
    # arbitrary objects (and torch's FutureWarning).
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

    # 2. Load data (test column TEST_COLUMN_ID, top half)
    X_full_tensor, Y_full_tensor = load_raw_volume(column_num=TEST_COLUMN_ID, half_type='top')

    # 3. Patch-wise inference, stitched back into the full volume
    stitched_seg, stitched_recon = stitch_inference(model, X_full_tensor)

    # 4. Metrics Calculation
    calculate_metrics(stitched_seg, Y_full_tensor)

    # 5. Visualization
    save_full_path = Path.cwd() / SAVE_FILENAME
    visualize_full_slice(X_full_tensor, Y_full_tensor, stitched_seg, stitched_recon, save_full_path)

Loading weights from: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/Trained_models/AG_val_best.pth
Loading: /dtu/blackhole/1b/167803/deep_learning_214776/extracted_datasets/datasets_processed_latest/Column_36/B/top.mat


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


Starting Stitching Loop...
  Processed Patch: [0:128, 0:128, 0:128]
  Processed Patch: [0:128, 0:128, 128:256]
  Processed Patch: [0:128, 128:256, 0:128]
  Processed Patch: [0:128, 128:256, 128:256]
  Processed Patch: [128:256, 0:128, 0:128]
  Processed Patch: [128:256, 0:128, 128:256]
  Processed Patch: [128:256, 128:256, 0:128]
  Processed Patch: [128:256, 128:256, 128:256]

--- Calculating Full Volume Metrics ---

[Intersection over Union]
  Class 0: 0.7987
  Class 1: 0.9055
  Class 2: 0.8821
  Class 3: 0.5043

[AUROC Calculation...]

FINAL METRICS for Column 36
  Mean Error (ME):       0.0888
  Mean IoU (All):        0.7726
  Mean IoU (Foreground): 0.7639
  AUROC (Macro):         0.9814

✅ Visualization saved to: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/scripts/full_ag_best_col36.png


In [None]:
if __name__ == "__main__":

    # 1. Load the VAE with its own checkpoint and output filename.
    # MODEL_PATH / SAVE_FILENAME in the configuration cell point at the AG model;
    # reusing them here would load the AG weights into the VAE and overwrite the
    # AG figure.
    # NOTE(review): the recorded output of this cell loaded VAE_val_final.pth and
    # saved full_vae_best_col36.png, while the patch-level cell uses
    # VAE_val_best.pth -- confirm which VAE checkpoint should be reported.
    vae_model_path = PROJECT_ROOT / "Trained_models" / "VAE_val_final.pth"
    vae_save_filename = f"full_vae_best_col{TEST_COLUMN_ID}.png"

    # .to(device) so the weights live on the same device as the input volume.
    model = VAE(in_channels=1, latent_dim=LATENT_DIM, NUM_CLASSES=NUM_CLASSES).to(device)

    if not vae_model_path.exists():
        print(f"❌ Model weights not found at {vae_model_path}")
        # Raise instead of sys.exit(): in a notebook sys.exit only raises SystemExit.
        raise FileNotFoundError(vae_model_path)
    print(f"Loading weights from: {vae_model_path}")
    # The checkpoint is a plain state_dict: weights_only=True avoids unpickling
    # arbitrary objects (and torch's FutureWarning).
    model.load_state_dict(torch.load(vae_model_path, map_location=device, weights_only=True))

    # 2. Load data (test column TEST_COLUMN_ID, top half)
    X_full_tensor, Y_full_tensor = load_raw_volume(column_num=TEST_COLUMN_ID, half_type='top')

    # 3. Patch-wise inference, stitched back into the full volume
    stitched_seg, stitched_recon = stitch_inference(model, X_full_tensor)

    # 4. Metrics Calculation
    calculate_metrics(stitched_seg, Y_full_tensor)

    # 5. Visualization
    save_full_path = Path.cwd() / vae_save_filename
    visualize_full_slice(X_full_tensor, Y_full_tensor, stitched_seg, stitched_recon, save_full_path)

Loading weights from: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/Trained_models/VAE_val_final.pth
Loading: /dtu/blackhole/1b/167803/deep_learning_214776/extracted_datasets/datasets_processed_latest/Column_36/B/top.mat


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


Starting Stitching Loop...
  Processed Patch: [0:128, 0:128, 0:128]
  Processed Patch: [0:128, 0:128, 128:256]
  Processed Patch: [0:128, 128:256, 0:128]
  Processed Patch: [0:128, 128:256, 128:256]
  Processed Patch: [128:256, 0:128, 0:128]
  Processed Patch: [128:256, 0:128, 128:256]
  Processed Patch: [128:256, 128:256, 0:128]
  Processed Patch: [128:256, 128:256, 128:256]

--- Calculating Full Volume Metrics ---

[Intersection over Union]
  Class 0: 0.7878
  Class 1: 0.9043
  Class 2: 0.8575
  Class 3: 0.3947

[AUROC Calculation - This may take a moment for 16M voxels...]

FINAL METRICS for Column 36
  Mean Error (ME):       0.1118
  Mean IoU (All):        0.7361
  Mean IoU (Foreground): 0.7188
  AUROC (Macro):         0.9778

✅ Visualization saved to: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/scripts/full_vae_best_col36.png


In [35]:
 # Re-render the slice figure from an earlier run's results. This cell depends
 # on hidden kernel state: X_full_tensor, Y_full_tensor, stitched_seg,
 # stitched_recon and save_full_path must still exist from a previous evaluation cell.
 # NOTE(review): the leading space works only because IPython strips common
 # leading indentation from a cell; as plain Python it is an IndentationError.
 visualize_full_slice(X_full_tensor, Y_full_tensor, stitched_seg, stitched_recon, save_full_path)

✅ Visualization saved to: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/scripts/full_vae_best_col1.png
