In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import sys
import itertools
from pathlib import Path
from scipy.io import loadmat
from matplotlib.colors import ListedColormap
import os

# --- SKLEARN for Metrics ---
try:
    from sklearn.metrics import roc_auc_score
except ImportError:
    print("WARNING: scikit-learn not found. AUROC calculation will be skipped.")
    roc_auc_score = None

# --- Project Setup ---
PROJECT_ROOT = Path.cwd().parent
sys.path.append(str(PROJECT_ROOT))

# Import your model
from func.Models import MultiTaskNet_ag as MultiTaskNet 
from func.Models import VAE
# ---------------------

# --- Configuration ---
NUM_CLASSES = 4
LATENT_DIM = 512
FULL_VOLUME_SIZE = 256
PATCH_SIZE = 128
SLICES_PER_AXIS = FULL_VOLUME_SIZE // PATCH_SIZE 

# Weights Path (Updated to your AG model)
MODEL_PATH = PROJECT_ROOT / "Trained_models" / "VAE_val_best.pth"

# Data Directory
BLACKHOLE_PATH = os.environ.get('BLACKHOLE', '.')
BASE_DATA_DIR = Path(os.path.join(BLACKHOLE_PATH, 'deep_learning_214776', 'extracted_datasets', 'datasets_processed_latest'))

# Test Column (e.g., 35)
TEST_COLUMN_ID = np.array([1, 2, 37,38])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")


Using device: cpu
--- Using device: cpu ---


--- Using device: cpu ---


In [2]:
def get_custom_colormap():
    """
    0: Blue (Background)
    1: Red
    2: Yellow
    3: Turquoise (Cyan)
    """
    colors = ['blue', 'red', 'yellow', 'cyan']
    return ListedColormap(colors)


def load_raw_volume(column_num, half_type='top'):
    """Loads the FULL 256^3 volume from disk."""
    column_dir = BASE_DATA_DIR / f'Column_{column_num}'
    x_filepath = column_dir / 'B' / f'{half_type}.mat'
    y_filepath = column_dir / f'gt_{half_type}.mat'
    
    print(f"Loading: {x_filepath}")
    try:
        X_full = np.squeeze(loadmat(str(x_filepath))[half_type]).astype(np.float32) / 255.0
        Y_full = np.squeeze(loadmat(str(y_filepath))[f'gt_{half_type}']).astype(np.int64)
        
        X_tensor = torch.from_numpy(X_full).unsqueeze(0).unsqueeze(0).to(device)
        Y_tensor = torch.from_numpy(Y_full).unsqueeze(0).to(device)
        return X_tensor, Y_tensor
    except Exception as e:
        print(f"ERROR loading raw volume: {e}")
        sys.exit(1)


def stitch_inference(model, X_full):
    """
    Cuts the 256^3 volume into 8 patches, runs inference, and stitches them back.
    """
    model.eval()
    
    stitched_seg = torch.zeros((1, NUM_CLASSES, FULL_VOLUME_SIZE, FULL_VOLUME_SIZE, FULL_VOLUME_SIZE), device=device)
    stitched_recon = torch.zeros((1, 1, FULL_VOLUME_SIZE, FULL_VOLUME_SIZE, FULL_VOLUME_SIZE), device=device)
    
    print("Starting Stitching Loop...")
    for pd, ph, pw in itertools.product(range(SLICES_PER_AXIS), repeat=3):
        d_start, d_end = pd * PATCH_SIZE, (pd + 1) * PATCH_SIZE
        h_start, h_end = ph * PATCH_SIZE, (ph + 1) * PATCH_SIZE
        w_start, w_end = pw * PATCH_SIZE, (pw + 1) * PATCH_SIZE
        
        X_patch = X_full[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
        
        with torch.no_grad():
            seg_patch, recon_patch,_,_ = model(X_patch)
            
        stitched_seg[:, :, d_start:d_end, h_start:h_end, w_start:w_end] = seg_patch
        stitched_recon[:, :, d_start:d_end, h_start:h_end, w_start:w_end] = recon_patch
        
        print(f"  Processed Patch: [{d_start}:{d_end}, {h_start}:{h_end}, {w_start}:{w_end}]")

    return stitched_seg, stitched_recon


def calculate_metrics(stitched_logits, Y_gt):
    """Calculates IoU, ME, and AUROC."""
    print("\n--- Calculating Full Volume Metrics ---")
    probs = torch.softmax(stitched_logits, dim=1)
    preds = torch.argmax(stitched_logits, dim=1)
    
    y_true_flat = Y_gt.cpu().numpy().flatten()
    y_pred_flat = preds.cpu().numpy().flatten()
    
    # Mean Error
    me = 1.0 - np.mean(y_true_flat == y_pred_flat)
    
    # IoU
    class_ious = []
    class_dice = []
    print(f"\n[Intersection over Union]")
    for c in range(NUM_CLASSES):
        intersection = np.sum((y_true_flat == c) & (y_pred_flat == c))
        union = np.sum((y_true_flat == c) | (y_pred_flat == c))
        
        if union == 0:
            iou = 1.0 if intersection == 0 else 0.0
        else:
            iou = intersection / union
            dice = (2*iou)/(1+iou)
        print(f"  Class {c} IoU: {iou:.4f} | Dice {dice:.4f} ")
        class_ious.append(iou)
        class_dice.append(dice)

        
    mIoU_all = np.mean(class_ious)
    Dice_all = np.mean(class_dice)
    
    # AUROC
    auroc = 0.0
    if roc_auc_score is not None:
        print("\n[AUROC Calculation...]")
        try:
            y_probs = probs.cpu().numpy()[0] # (C, D, H, W)
            n_classes = y_probs.shape[0]
            # Flatten to (N_voxels, C)
            y_probs_flat = np.transpose(y_probs, (1, 2, 3, 0)).reshape(-1, n_classes)
            auroc = roc_auc_score(y_true_flat, y_probs_flat, multi_class='ovr', average='macro')
        except Exception as e:
            print(f"  AUROC Error: {e}")
            auroc = -1.0
    
    print("\n" + "="*30)
    print(f"FINAL METRICS for Column {TEST_COLUMN_ID}")
    print(f"  Mean Error (ME):       {me:.4f}")
    print(f"  Mean IoU :             {mIoU_all:.4f}")
    print(f"  Mean Dice :            {Dice_all:.4f}")
    print(f"  AUROC (Macro):         {auroc:.4f}")
    print("="*30 + "\n")


def visualize_full_slice(X_tensor, Y_tensor, stitched_seg, stitched_recon, save_path, slice_idx=FULL_VOLUME_SIZE // 2):
    """Saves a central slice using the CUSTOM COLORMAP."""
    pred_seg_tensor = torch.argmax(stitched_seg, dim=1) 
    
    x_np = X_tensor[0, 0, slice_idx, :, :].cpu().numpy()
    y_np = Y_tensor[0, slice_idx, :, :].cpu().numpy()
    recon_np = stitched_recon[0, 0, slice_idx, :, :].cpu().numpy()
    pred_np = pred_seg_tensor[0, slice_idx, :, :].cpu().numpy()
    
    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    titles = ['Input', 'Ground Truth', 'Reconstruction', 'Prediction']
    data = [x_np, y_np, recon_np, pred_np]
    
    # Use custom map for GT and Pred
    custom_cmap = get_custom_colormap()
    cmaps = ['gray', custom_cmap, 'gray', custom_cmap]
    
    for i, ax in enumerate(axes):
        # Vmax=3 for the 4 classes (0,1,2,3)
        vmax = NUM_CLASSES - 1 if i == 1 or i == 3 else None
        vmin = 0 if i == 1 or i == 3 else None
        interp = 'nearest' if i == 1 or i == 3 else None
        
        im = ax.imshow(data[i], cmap=cmaps[i], interpolation=interp, vmin=vmin, vmax=vmax)
        ax.set_title(titles[i], fontsize=18)
        ax.axis('off')
        
        if i == 1 or i == 3:
             cbar = plt.colorbar(im, ax=ax, ticks=range(NUM_CLASSES), fraction=0.046, pad=0.04)
             cbar.ax.set_yticklabels(['Water', 'Oil', 'Solids', 'Gas'])

    plt.suptitle(f"Full Resolution Inference - Test Column {TEST_COLUMN_ID} - Slice {slice_idx}", fontsize=24)
    plt.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.close(fig)
    print(f"Visualization saved to: {save_path}")



In [3]:
if __name__ == "__main__":
    
    # 1. Load Model
    #model = MultiTaskNet(in_channels=1, num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)
    model = VAE(in_channels=1, latent_dim=LATENT_DIM, NUM_CLASSES=NUM_CLASSES)

    if MODEL_PATH.exists():
        print(f"Loading weights from: {MODEL_PATH}")
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    else:
        print(f"❌ Model weights not found at {MODEL_PATH}")
        sys.exit(1)
    for i in TEST_COLUMN_ID: 
        TEST_COLUMN_ID = i
        # 2. Load Data (Column 35, Top Half)
        X_full_tensor, Y_full_tensor = load_raw_volume(column_num=TEST_COLUMN_ID, half_type='top')
    
        # 3. Stitching Inference
        stitched_seg, stitched_recon = stitch_inference(model, X_full_tensor)

        # 4. Metrics Calculation
        calculate_metrics(stitched_seg, Y_full_tensor)

        # 5. Visualization
        save_full_path = Path.cwd() / f"full_vae_best_col{i}.png"
        visualize_full_slice(X_full_tensor, Y_full_tensor, stitched_seg, stitched_recon, save_full_path)

❌ Model weights not found at /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/Trained_models/VAE_val_finalVAE1.pth


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
if __name__ == "__main__":
    
    # 1. Load Model
    #model = MultiTaskNet(in_channels=1, num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)
    model = VAE(in_channels=1, latent_dim=LATENT_DIM, NUM_CLASSES=NUM_CLASSES)

    if MODEL_PATH.exists():
        print(f"Loading weights from: {MODEL_PATH}")
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    else:
        print(f"❌ Model weights not found at {MODEL_PATH}")
        sys.exit(1)
    for i in TEST_COLUMN_ID: 
        TEST_COLUMN_ID = i
        # 2. Load Data (Column 35, Top Half)
        X_full_tensor, Y_full_tensor = load_raw_volume(column_num=TEST_COLUMN_ID, half_type='top')
    
        # 3. Stitching Inference
        stitched_seg, stitched_recon = stitch_inference(model, X_full_tensor)

        # 4. Metrics Calculation
        calculate_metrics(stitched_seg, Y_full_tensor)

        # 5. Visualization
        save_full_path = Path.cwd() / f"full_vae_final_col{i}.png"
        visualize_full_slice(X_full_tensor, Y_full_tensor, stitched_seg, stitched_recon, save_full_path)

Loading weights from: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/Trained_models/VAE_val_finalVAE1.pth


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


Loading: /dtu/blackhole/1b/167803/deep_learning_214776/extracted_datasets/datasets_processed_latest/Column_1/B/top.mat
Starting Stitching Loop...
  Processed Patch: [0:128, 0:128, 0:128]
  Processed Patch: [0:128, 0:128, 128:256]
  Processed Patch: [0:128, 128:256, 0:128]
  Processed Patch: [0:128, 128:256, 128:256]
  Processed Patch: [128:256, 0:128, 0:128]
  Processed Patch: [128:256, 0:128, 128:256]
  Processed Patch: [128:256, 128:256, 0:128]
  Processed Patch: [128:256, 128:256, 128:256]

--- Calculating Full Volume Metrics ---

[Intersection over Union]
  Class 0 IoU: 0.8816 | Dice 0.9371 
  Class 1 IoU: 0.9200 | Dice 0.9583 
  Class 2 IoU: 0.9403 | Dice 0.9692 
  Class 3 IoU: 0.0018 | Dice 0.0036 

[AUROC Calculation...]

FINAL METRICS for Column 1
  Mean Error (ME):       0.0458
  Mean IoU :             0.6859
  Mean Dice :            0.7171
  AUROC (Macro):         0.9583

✅ Visualization saved to: /zhome/d2/4/167803/Desktop/Deep_project/02456-final-project/scripts/full_vae_be

In [None]:
for i in range(50):
    img_save = f"vae_{i}.png"
    visualize_full_slice(X_full_tensor, Y_full_tensor, stitched_seg, stitched_recon, img_save, slice_idx=i)

✅ Visualization saved to: vae_0.png
✅ Visualization saved to: vae_1.png
✅ Visualization saved to: vae_2.png
✅ Visualization saved to: vae_3.png
✅ Visualization saved to: vae_4.png
✅ Visualization saved to: vae_5.png
✅ Visualization saved to: vae_6.png
✅ Visualization saved to: vae_7.png
✅ Visualization saved to: vae_8.png
✅ Visualization saved to: vae_9.png
✅ Visualization saved to: vae_10.png
✅ Visualization saved to: vae_11.png
✅ Visualization saved to: vae_12.png
✅ Visualization saved to: vae_13.png
✅ Visualization saved to: vae_14.png
✅ Visualization saved to: vae_15.png
✅ Visualization saved to: vae_16.png
✅ Visualization saved to: vae_17.png
✅ Visualization saved to: vae_18.png
✅ Visualization saved to: vae_19.png
✅ Visualization saved to: vae_20.png
✅ Visualization saved to: vae_21.png
✅ Visualization saved to: vae_22.png
✅ Visualization saved to: vae_23.png
✅ Visualization saved to: vae_24.png
✅ Visualization saved to: vae_25.png
✅ Visualization saved to: vae_26.png
✅ Visualiza

: 