In [None]:
# VERDICT MLP Model Evaluation
# This notebook evaluates the trained MLP model on patient brain data

import nibabel as nib
import numpy as np
import torch
import pickle
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler

# Import MLP model
from models.mlp import MLPRegressor

# Load the patient's diffusion-weighted NIfTI volume.
# The path can be overridden with the VERDICT_NII_PATH environment variable so the
# notebook can run outside the original author's machine; the default is the
# original location.
NII_PATH = os.environ.get(
    'VERDICT_NII_PATH',
    r'D:\AiProjects\UCLmaster\OneDrive_1_2025-5-2\data\Patient08\Patient08_mc_normb0.nii.gz',
)
img = nib.load(NII_PATH)
# get_fdata() returns floating-point voxel data (float64 unless a dtype is requested)
data = img.get_fdata()

print("=== VERDICT MLP Model Evaluation ===")
print(f"Loaded brain data shape: {data.shape}")
print(f"Data type: {data.dtype}")

In [None]:
# Data preprocessing: flatten the 4D image into a (voxels x volumes) matrix and
# sanitize it for the network.
# get_fdata() already yields float64, so astype(..., copy=False) returns that array
# as-is instead of unconditionally allocating another full-size copy of the 4D
# volume (the original `.astype(np.float64)` always copied).
I = img.get_fdata().astype(np.float64, copy=False)

# Three spatial axes plus one axis of acquired volumes
sx, sy, sz, vol = I.shape
print(f"Image dimensions: {sx} x {sy} x {sz}, Volumes: {vol}")

# One row per voxel, one column per volume (C-order flattening of x, y, z)
ROI = I.reshape((sx * sy * sz, vol))
print(f"ROI shape after reshape: {ROI.shape}")

# Replace NaN/±inf with 0, then clamp negatives to 0. nan_to_num returns a copy by
# default, so the in-place clamp below never writes into the image's data array.
signal = np.nan_to_num(ROI, nan=0.0, posinf=0.0, neginf=0.0)
signal[signal < 0] = 0

# Keep every voxel so the flat row index still maps back to (x, y, z);
# background voxels are excluded later by the inference mask.
signal_filtered = signal
print(f"Total voxels: {signal_filtered.shape[0]:,}")
print(f"Non-zero voxels: {np.count_nonzero(np.any(signal_filtered > 0, axis=1)):,}")
print(f"Signal range: [{signal_filtered.min():.3f}, {signal_filtered.max():.3f}]")

Shape of data: (112, 112, 60, 153)
Data type of data: float64


In [None]:
# Load MLP model and scaler
def load_mlp_model(checkpoint_dir=r'd:\AiProjects\verdict_benchmark\checkpoints',
                   model_filename='mlp_best.pt',
                   scaler_filename='mlp_scaler.pkl'):
    """Load the trained MLP model and its scaler.

    Parameters
    ----------
    checkpoint_dir : str
        Directory holding the checkpoint files. Defaults to the original
        hard-coded location so existing calls keep working.
    model_filename : str
        File name of the saved ``state_dict`` inside ``checkpoint_dir``.
    scaler_filename : str
        File name of the pickled scaler inside ``checkpoint_dir``.

    Returns
    -------
    tuple
        ``(model, scaler)``. The model has the state dict loaded with
        ``map_location='cpu'`` and is switched to eval mode. The scaler is
        whatever object was pickled; this notebook later calls its
        ``inverse_transform`` on the model outputs.

    Raises
    ------
    FileNotFoundError
        If the model or the scaler file does not exist.
    """
    # The architecture must match the one used at training time; otherwise
    # load_state_dict raises on mismatched keys/shapes.
    model = MLPRegressor(input_dim=153, output_dim=8,
                         hidden_dims=[150, 150, 150], activation='relu')

    model_path = os.path.join(checkpoint_dir, model_filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # weights_only=True restricts unpickling to tensors and plain containers
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print("✓ Loaded MLP model successfully")

    scaler_path = os.path.join(checkpoint_dir, scaler_filename)
    if not os.path.exists(scaler_path):
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")
    # SECURITY: pickle.load can execute arbitrary code; only load scaler files
    # from a trusted source.
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    print("✓ Loaded MLP scaler successfully")

    return model, scaler

# Instantiate the trained network (eval mode) together with its scaler
mlp_model, mlp_scaler = load_mlp_model()
print("Model parameters: {:,}".format(sum(t.numel() for t in mlp_model.parameters())))

Image shape: (112, 112, 60, 153)


In [None]:
# VERDICT parameter names; this order is used to label the 8 output columns
# throughout the rest of the notebook.
param_names = ['fic', 'fee', 'Dic', 'R', 'Dpar', 'Dtra', 'theta', 'phi']
print(f"VERDICT parameters: {param_names}")

print("\n=== Running MLP Inference ===")

# Only voxels with some positive signal are sent to the network; this mask is
# reused later to scatter predictions back into the 3D volume.
non_zero_mask = np.any(signal_filtered > 0, axis=1)
signal_nonzero = signal_filtered[non_zero_mask]
print(f"Processing {signal_nonzero.shape[0]:,} non-zero voxels...")

# Process in batches to bound peak memory
batch_size = 1000
all_predictions = []

# Re-assert inference mode in case the model object was touched after loading
mlp_model.eval()
with torch.no_grad():
    for i in range(0, signal_nonzero.shape[0], batch_size):
        batch = signal_nonzero[i:i + batch_size]
        # torch.as_tensor with an explicit float32 dtype replaces the legacy
        # torch.FloatTensor constructor (same float32 result).
        batch_tensor = torch.as_tensor(batch, dtype=torch.float32)

        # Model outputs are in the scaler's normalized space
        batch_pred_scaled = mlp_model(batch_tensor)

        # NOTE(review): the scaler is applied to the outputs only, i.e. this
        # assumes the network was trained on unscaled input signals -- confirm
        # against the training code.
        batch_pred = mlp_scaler.inverse_transform(batch_pred_scaled.numpy())
        all_predictions.append(batch_pred)

        if (i // batch_size + 1) % 20 == 0:
            print(f"  Processed {i + batch.shape[0]:,} voxels...")

if all_predictions:
    predictions_p1_p8 = np.concatenate(all_predictions, axis=0)
else:
    # np.concatenate raises on an empty list; keep a well-shaped empty result
    predictions_p1_p8 = np.empty((0, len(param_names)))
print("✓ Completed MLP inference!")
print(f"Predictions shape: {predictions_p1_p8.shape}")
if predictions_p1_p8.size:
    # min()/max() raise on empty arrays
    print(f"Parameter range: [{predictions_p1_p8.min():.4f}, {predictions_p1_p8.max():.4f}]")

Image dimensions: 112 x 112 x 60, Volumes: 153
ROI shape after reshape: (752640, 153)
ROI shape after reshape: (752640, 153)


In [None]:
# Map the raw network outputs p1..p8 onto the VERDICT parameters
print("\n=== Parameter Transformation ===")


def _to_verdict_params(p):
    """Convert the columns p1..p8 of ``p`` into VERDICT parameter columns.

    fic   = cos(p1)^2
    fee   = (1 - cos(p1)^2) * cos(p2)^2
    Dic   = p3    (intracellular diffusivity)
    R     = p4    (cell radius)
    Dpar  = p5    (parallel diffusivity)
    Dtra  = p6 * p5  (transverse diffusivity)
    theta = p7    (polar angle)
    phi   = p8    (azimuthal angle)

    The result has the same dtype as ``p``.
    """
    cos2_p1 = np.cos(p[:, 0]) ** 2
    columns = (
        cos2_p1,                               # fic
        (1 - cos2_p1) * np.cos(p[:, 1]) ** 2,  # fee
        p[:, 2],                               # Dic
        p[:, 3],                               # R
        p[:, 4],                               # Dpar
        p[:, 5] * p[:, 4],                     # Dtra
        p[:, 6],                               # theta
        p[:, 7],                               # phi
    )
    return np.column_stack(columns).astype(p.dtype, copy=False)


predictions_verdict = _to_verdict_params(predictions_p1_p8)

print("✓ Parameter transformation completed!")
print(f"Transformed VERDICT range: [{predictions_verdict.min():.6f}, {predictions_verdict.max():.6f}]")

# All downstream analysis uses the transformed values
predictions_mlp = predictions_verdict

Signal shape: (752640, 153)
Signal data type: float64


In [None]:
# Rebuild a 3D volume of predictions so the maps can be displayed as images
print("\n=== Creating 3D Visualization Volume ===")

# Scatter per-voxel predictions back onto the full flattened grid; voxels
# excluded by non_zero_mask keep the value 0 in every parameter map.
prediction_volume = np.zeros((sx * sy * sz, 8))
prediction_volume[non_zero_mask] = predictions_mlp

# Undo the (voxels, params) flattening done during preprocessing
prediction_3d = prediction_volume.reshape((sx, sy, sz, 8))

# Central slice index along each spatial axis
mid_slice_sagittal, mid_slice_coronal, mid_slice_axial = sx // 2, sy // 2, sz // 2

print(f"3D volume shape: {prediction_3d.shape}")
print(f"Middle slices - Axial: {mid_slice_axial}, Coronal: {mid_slice_coronal}, Sagittal: {mid_slice_sagittal}")
print("✓ 3D volume ready for visualization")

Total voxels kept: 752640
Signal shape: (752640, 153)
Non-zero voxels: 150682
Zero voxels: 601958


In [None]:
# Visualization 1: one axial map per VERDICT parameter (2 x 4 grid)
print("\n=== Brain Parameter Maps ===")
plt.style.use('default')
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle('VERDICT MRI Parameter Maps - Axial View (MLP Model)', fontsize=16, y=0.95)

# axes.flat walks the grid row by row, matching parameter order
for idx, (ax, name) in enumerate(zip(axes.flat, param_names)):
    im = ax.imshow(prediction_3d[:, :, mid_slice_axial, idx], cmap='viridis', aspect='equal')
    ax.set_title(f'{name}\n(Axial slice {mid_slice_axial})', fontsize=12)
    ax.axis('off')
    plt.colorbar(im, ax=ax, shrink=0.8).ax.tick_params(labelsize=10)

plt.tight_layout()
plt.show()

Processing signal with shape: (752640, 153)
Non-zero voxels: 150682
Data range: [0.000, 1.999]


In [None]:
# Visualization 2: Parameter Statistical Analysis
print("\n=== Parameter Statistical Analysis ===")
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
n_params = len(param_names)

# 2a/2b. Overlaid distribution histograms, four parameters per panel.
# Titles are built from the parameters actually plotted so they cannot drift out
# of sync with the grouping. NOTE: the parameters live on different scales, so
# overlaid histograms share an x-axis that can compress some of them.
for ax, group in ((axes[0, 0], range(0, 4)), (axes[0, 1], range(4, 8))):
    for i in group:
        ax.hist(predictions_mlp[:, i], alpha=0.6, bins=50, label=param_names[i])
    ax.set_xlabel('Parameter Value')
    ax.set_ylabel('Frequency')
    ax.set_title('Parameter Distributions (' + ', '.join(param_names[j] for j in group) + ')')
    ax.legend()
    ax.grid(True, alpha=0.3)

# 2c. Pearson correlation between parameters (rows/cols follow param_names)
ax = axes[1, 0]
correlation_matrix = np.corrcoef(predictions_mlp.T)
im = ax.imshow(correlation_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set_title('Parameter Correlation Matrix')
ax.set_xticks(range(n_params))
ax.set_xticklabels(param_names, rotation=45, ha='right')
ax.set_yticks(range(n_params))
ax.set_yticklabels(param_names)
plt.colorbar(im, ax=ax, shrink=0.8)

# Annotate each cell; white text on strongly colored cells for contrast
for i in range(n_params):
    for j in range(n_params):
        ax.text(j, i, f'{correlation_matrix[i, j]:.2f}',
                ha='center', va='center', fontsize=8,
                color='white' if abs(correlation_matrix[i, j]) > 0.5 else 'black')

# 2d. Box plots for parameter ranges
ax = axes[1, 1]
box_data = [predictions_mlp[:, i] for i in range(n_params)]
# Matplotlib 3.9 renamed boxplot's `labels` argument to `tick_labels` and
# deprecated the old name; boxes sit at positions 1..n by default, so set the
# tick labels explicitly, which works on old and new versions alike.
bp = ax.boxplot(box_data, patch_artist=True)
ax.set_xticks(range(1, n_params + 1))
ax.set_xticklabels(param_names)
for patch in bp['boxes']:
    patch.set_facecolor('lightblue')
    patch.set_alpha(0.7)
ax.set_ylabel('Parameter Value')
ax.set_title('Parameter Value Ranges')
ax.tick_params(axis='x', rotation=45)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

VERDICT parameters: ['fic', 'fee', 'Dic', 'R', 'Dpar', 'Dtra', 'theta', 'phi']


In [None]:
# Visualization 3: Multi-orientation views of the first parameter map (fic)
print("\n=== Multi-orientation Brain Views ===")
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Intracellular Volume Fraction (fIC) - Multiple Brain Views', fontsize=16, y=0.95)

# 3a. Axial view (fixed z)
ax = axes[0, 0]
axial_slice = prediction_3d[:, :, mid_slice_axial, 0]
im = ax.imshow(axial_slice, cmap='viridis', aspect='equal')
ax.set_title(f'Axial View (slice {mid_slice_axial})')
ax.axis('off')
plt.colorbar(im, ax=ax, shrink=0.8)

# 3b. Coronal view (fixed y); transposed so z runs vertically, bottom-up
ax = axes[0, 1]
coronal_slice = prediction_3d[:, mid_slice_coronal, :, 0]
im = ax.imshow(coronal_slice.T, cmap='viridis', aspect='auto', origin='lower')
ax.set_title(f'Coronal View (slice {mid_slice_coronal})')
ax.set_xlabel('X direction')
ax.set_ylabel('Z direction')
plt.colorbar(im, ax=ax, shrink=0.8)

# 3c. Sagittal view (fixed x); transposed so z runs vertically, bottom-up
ax = axes[1, 0]
sagittal_slice = prediction_3d[mid_slice_sagittal, :, :, 0]
im = ax.imshow(sagittal_slice.T, cmap='viridis', aspect='auto', origin='lower')
ax.set_title(f'Sagittal View (slice {mid_slice_sagittal})')
ax.set_xlabel('Y direction')
ax.set_ylabel('Z direction')
plt.colorbar(im, ax=ax, shrink=0.8)

# 3d. Tissue mask: exactly the voxels that were fed to the model. It is reshaped
# with the same C-order flattening used to build prediction_3d, rather than
# re-derived from the predictions (which would drop processed voxels whose
# outputs are all <= 0).
ax = axes[1, 1]
brain_mask = non_zero_mask.reshape((sx, sy, sz))
brain_slice = brain_mask[:, :, mid_slice_axial]
ax.imshow(brain_slice, cmap='gray', aspect='equal')
ax.set_title(f'Brain Tissue Mask (Axial slice {mid_slice_axial})')
ax.axis('off')

plt.tight_layout()
plt.show()

Available models: ['mlp', 'cnn', 'residual_mlp', 'rnn', 'transformer', 'moe_regressor', 'vae_regressor', 'tabnet_regressor']


In [None]:
# Visualization 4: Summary Statistics and Analysis
print("\n=== Summary Statistics ===")
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('VERDICT Parameter Analysis Summary', fontsize=16, y=0.95)

# Parameters reported with extra precision / scientific notation
precise_params = {'Dic', 'R', 'Dpar', 'Dtra'}
n_params = len(param_names)

# 4a. Signal quality: per-voxel mean vs coefficient of variation across volumes
ax = axes[0, 0]
signal_std = np.std(signal_nonzero, axis=1)
signal_mean = np.mean(signal_nonzero, axis=1)
signal_cv = signal_std / (signal_mean + 1e-8)  # epsilon guards against division by zero

ax.scatter(signal_mean, signal_cv, alpha=0.5, s=1, c='blue')
ax.set_xlabel('Signal Mean')
ax.set_ylabel('Coefficient of Variation')
ax.set_title('Signal Quality Assessment')
ax.grid(True, alpha=0.3)

# 4b. Parameter means with ±1 std error bars
ax = axes[0, 1]
param_means = [predictions_mlp[:, i].mean() for i in range(n_params)]
param_stds = [predictions_mlp[:, i].std() for i in range(n_params)]
bars = ax.bar(param_names, param_means, yerr=param_stds, capsize=5, alpha=0.7, color='green')
ax.set_xlabel('Parameters')
ax.set_ylabel('Mean Value ± Std')
ax.set_title('Parameter Summary Statistics')
ax.tick_params(axis='x', rotation=45)
ax.grid(True, alpha=0.3)

# Label each bar just above its error bar. The gap is given in screen points, so it
# stays a constant visual distance whatever the y-axis scale; a fixed 0.02
# data-unit offset is only sensible when the shared y-axis spans roughly O(1).
for bar, name, mean, std in zip(bars, param_names, param_means, param_stds):
    label = f'{mean:.6f}' if name in precise_params else f'{mean:.3f}'
    ax.annotate(label, xy=(bar.get_x() + bar.get_width() / 2, mean + std),
                xytext=(0, 3), textcoords='offset points',
                ha='center', va='bottom', fontsize=9)

# 4c. Min-max range around each mean
ax = axes[1, 0]
param_mins = [predictions_mlp[:, i].min() for i in range(n_params)]
param_maxs = [predictions_mlp[:, i].max() for i in range(n_params)]

x_pos = np.arange(n_params)
for i in range(n_params):
    # Asymmetric error bar: lower = mean - min, upper = max - mean
    ax.errorbar(x_pos[i], param_means[i],
                yerr=[[param_means[i] - param_mins[i]], [param_maxs[i] - param_means[i]]],
                fmt='o', capsize=5, markersize=8)

ax.set_xlabel('Parameters')
ax.set_ylabel('Value Range')
ax.set_title('Parameter Value Ranges (Min-Max)')
ax.set_xticks(x_pos)
ax.set_xticklabels(param_names, rotation=45, ha='right')
ax.grid(True, alpha=0.3)

# 4d. Summary statistics table rendered as monospace text
ax = axes[1, 1]
ax.axis('off')

stats_text = "VERDICT Parameter Statistics Summary:\n\n"
stats_text += f"{'Parameter':<8} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}\n"
stats_text += "-" * 65 + "\n"

for i, param in enumerate(param_names):
    values = predictions_mlp[:, i]
    v_mean, v_std, v_min, v_max = values.mean(), values.std(), values.min(), values.max()
    if param in precise_params:
        # Scientific notation for small magnitudes; abs() so negative means are
        # judged by size rather than sign.
        fmt = '<12.2e' if abs(v_mean) < 0.001 else '<12.6f'
    else:
        fmt = '<12.3f'
    stats_text += f"{param:<8} {v_mean:{fmt}} {v_std:{fmt}} {v_min:{fmt}} {v_max:{fmt}}\n"

stats_text += "\n" + "-" * 65 + "\n"
stats_text += f"Brain volume: {sx} × {sy} × {sz} = {sx*sy*sz:,} voxels\n"
stats_text += f"Brain tissue: {predictions_mlp.shape[0]:,} voxels ({100*predictions_mlp.shape[0]/(sx*sy*sz):.1f}%)\n"
stats_text += f"MLP Model: {sum(p.numel() for p in mlp_model.parameters()):,} parameters\n"

ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, fontsize=9,
        verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.tight_layout()
plt.show()

print("\n✓ MLP Evaluation Complete!")

MLP Model architecture from weights:
Input dim: 153
Hidden layers: torch.Size([150, 150])
Output dim: 8


  checkpoint = torch.load(r'd:\AiProjects\verdict_benchmark\checkpoints\mlp_best.pt', map_location='cpu')
