In [None]:
import sys
sys.path.append('../python')

from plotMethods import *
from constants import *
from distanceLossMethods import *
import plotParameters

# General Settings

In [2]:
import os

# Root of the local dsps tree holding the training data and model predictions.
# Set the DSPS_ROOT environment variable to point elsewhere instead of editing
# the user-specific default below.
dspsRoot = os.environ.get('DSPS_ROOT', '/Users/noah-everett/Documents_non-iCloud/dsps')

# Ground-truth volumes: one '<fileNumber>.h5' per event, truth stored under key 'y'.
dataDir = f'{dspsRoot}/mldata/trainData_6'
# Name of the trained-model run; also used to tag every saved figure.
predsDir = 'model_12_12hrs_TverskyLoss_lr=0.01,a=0.7,b=0.3,fmaps=32,64_last'
# Model outputs: one '<fileNumber>_predictions.h5' per event, stored under key 'predictions'.
outputDir = f'{dspsRoot}/preds/{predsDir}/'
# Destination directory for the PDF figures written by the plotting cells.
figuresDir = '../figures/UNetPreds/'

# Single File Analysis

In [3]:
# Event inspected in the single-file section (alternatives tried: 56, 45, 30) and
# the number of coarse bins per axis each volume is block-averaged down to.
fileNumber, reshapeSize = 217, 40

In [4]:
import h5py

# Load the model prediction and the ground truth for one event. The context
# managers close each HDF5 file even if a read raises.
with h5py.File(f'{outputDir}/{fileNumber}_predictions.h5', 'r') as f:
    pred = f['predictions'][:]
pred = pred[0, :, :, :]  # keep only the first entry along the leading axis

with h5py.File(f'{dataDir}/{fileNumber}.h5', 'r') as f:
    true = f['y'][:]

# Truth is downsampled with the prediction's block size below, so both volumes
# must share a shape and every axis must split evenly into reshapeSize blocks.
assert pred.shape == true.shape, f'prediction shape {pred.shape} != truth shape {true.shape}'
assert (pred.shape[0] % reshapeSize == 0) and (pred.shape[1] % reshapeSize == 0) and (pred.shape[2] % reshapeSize == 0)

# Block-average each volume to reshapeSize bins per axis: split every axis into
# (reshapeSize, step) and average over the step sub-axes.
step = (pred.shape[0]//reshapeSize, pred.shape[1]//reshapeSize, pred.shape[2]//reshapeSize)
pred = pred.reshape(reshapeSize, step[0], reshapeSize, step[1], reshapeSize, step[2]).mean(axis=(1,3,5))
true = true.reshape(reshapeSize, step[0], reshapeSize, step[1], reshapeSize, step[2]).mean(axis=(1,3,5))

In [5]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111, projection='3d')

gridSize = pred.shape

# Voxel edge coordinates in mm, centred on the detector (one more edge than bins per axis).
xEdges = np.linspace(-DETECTOR_SIZE_MM[0]/2, DETECTOR_SIZE_MM[0]/2, gridSize[0] + 1)
yEdges = np.linspace(-DETECTOR_SIZE_MM[1]/2, DETECTOR_SIZE_MM[1]/2, gridSize[1] + 1)
zEdges = np.linspace(-DETECTOR_SIZE_MM[2]/2, DETECTOR_SIZE_MM[2]/2, gridSize[2] + 1)
# indexing='ij' yields arrays of shape gridSize+1 with xEdges[i, j, k] == x[i], matching
# the axis order of `pred`. The default 'xy' indexing swaps the first two axes, and
# renaming the outputs to compensate would put the y coordinates into xEdges, which is
# only correct when the x and y extents and bin counts happen to be equal.
xEdges, yEdges, zEdges = np.meshgrid(xEdges, yEdges, zEdges, indexing='ij')

print('np.min(pred):', np.min(pred))
print('np.max(pred):', np.max(pred))

### Alternative: use percentiles to pick the shown voxels
# minPercentile = 99.5
# maxPercentile = 100
# minVal = np.percentile(pred, minPercentile)
# maxVal = np.percentile(pred, maxPercentile)
# print(f'minVal ({minPercentile}th percentile):', minVal)
# print(f'maxVal ({maxPercentile}th percentile):', maxVal)

### Show top N voxels
N = 30
topNIndices = np.unravel_index(np.argsort(pred.ravel())[-N:], pred.shape)
minVal = pred[topNIndices].min()
maxVal = pred[topNIndices].max()
print('minVal (top N):', minVal)
print('maxVal (top N):', maxVal)

alpha_filled = 0.5
globalColorNorm = cm.colors.Normalize(vmin=minVal, vmax=maxVal)
# Zero every voxel below the N-th largest value. Ties at minVal are kept, so more
# than N voxels may remain.
pred = np.where(pred < minVal, 0, pred)
colors = cm.viridis(globalColorNorm(pred))

ax = plot_grid(
    ax,
    xEdges,
    yEdges,
    zEdges,
    recoGrid=pred,
    recoGridFaceColors=colors,
    recoGridEdgeColors=np.clip(colors*2-0.5, 0, 1),
    recoGridAlpha=alpha_filled,
    trueGrid=true,
    trueGridEdgeColors='red',
    trueGridAlpha=0,
    nullGridAlpha=0.3,
    linewidth=0.5,
    cbar=True,
    colorNorm=globalColorNorm,
    cmap=cm.viridis,
)

ax.set_axis_off()
ax.set_aspect('equal')

plt.savefig(f'{figuresDir}/{fileNumber}_[{predsDir}].pdf', bbox_inches='tight')
plt.show()

np.min(pred): 0.0
np.max(pred): 1.0
minVal (top N): 0.25
maxVal (top N): 1.0


  plt.show()


In [6]:
# Release the single-event prediction volume; the multi-file section below reloads predictions from disk.
del pred

# Multi-File Analysis

In [7]:
import numpy as np

# File numbers loaded for the multi-file analysis: every event, 0..1137.
# Narrower ranges used for quick runs: np.arange(200, 220) or np.arange(0, 20).
fileNumbers = np.arange(1138)
# Coarse bins per axis used when block-averaging volumes for the grid plots.
reshapeSize = 20

In [8]:
import h5py
import numpy as np
from tqdm import tqdm

fileNames_recos = [f'{outputDir}/{i}_predictions.h5' for i in fileNumbers]
fileNames_trues = [f'{dataDir}/{i}.h5' for i in fileNumbers]


def _read_h5(path, key, selection):
    """Read ``selection`` of dataset ``key`` from the HDF5 file at ``path``, closing the file afterwards."""
    with h5py.File(path, 'r') as h5:
        return h5[key][selection]


# Load data. Each file is closed right after its read instead of leaving the
# File objects for the garbage collector (two handles per event otherwise).
data_recos = [_read_h5(p, 'predictions', np.s_[0, :, :, :]) for p in tqdm(fileNames_recos)]
data_trues = [_read_h5(p, 'y', np.s_[:, :, :]) for p in tqdm(fileNames_trues)]
assert len(data_recos) == len(data_trues)
assert all(r.shape == t.shape for r, t in zip(data_recos, data_trues))
assert all(data_recos[0].shape == r.shape for r in data_recos)

100%|██████████| 1138/1138 [00:22<00:00, 51.65it/s]
100%|██████████| 1138/1138 [00:02<00:00, 494.26it/s]


In [9]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Flatten voxel arrays
recos_flat = [arr.ravel() for arr in data_recos]

# Fixed x bounds: the near-zero bulk of voxel values (left) and the high tail (right)
main_lo, main_hi = 0.0, 0.001
tail_lo, tail_hi = 0.10, 1.0

# Binning
bins_main = 100
bins_tail = 100
bin_edges_main = np.linspace(main_lo, main_hi, bins_main + 1)
bin_edges_tail = np.linspace(tail_lo, tail_hi, bins_tail + 1)


def _plot_voxel_hists(ax, arrays, bin_edges, lo, hi, line_colors, desc, labels=None):
    """Overlay one faint step histogram per voxel array on ``ax`` and style the panel.

    ``line_colors[i]`` and, if given, ``labels[i]`` are used for ``arrays[i]``.
    The x-axis is limited to [lo, hi] and the title quotes that range.
    """
    for idx, vals in tqdm(enumerate(arrays), total=len(arrays), desc=desc):
        ax.hist(vals, bins=bin_edges, histtype='step',
                linewidth=1.2, alpha=0.1, color=line_colors[idx],
                label=None if labels is None else labels[idx])
    ax.set_title(f"Voxel Distributions ({lo} to {hi})", fontsize=16, pad=10)
    ax.set_xlabel("Voxel Value", fontsize=14)
    ax.set_ylabel("Count", fontsize=14)
    ax.set_xlim(lo, hi)
    ax.tick_params(axis='both', which='major', labelsize=12, length=6)
    ax.tick_params(axis='both', which='minor', labelsize=10, length=4)
    ax.grid(True, linestyle='--', alpha=0.4)


# Figure: one colour per file, shared by both panels
fig, (ax_main, ax_tail) = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
file_colors = plt.cm.viridis(np.linspace(0, 1, len(recos_flat)))

_plot_voxel_hists(ax_main, recos_flat, bin_edges_main, main_lo, main_hi, file_colors,
                  "Plotting left histogram", labels=[f'File {n}' for n in fileNumbers])
_plot_voxel_hists(ax_tail, recos_flat, bin_edges_tail, tail_lo, tail_hi, file_colors,
                  "Plotting right histogram")

# Per-file labels are attached to the left panel; to show them as a shared legend:
# fig.legend(*ax_main.get_legend_handles_labels(), fontsize=8, ncol=3, frameon=False, loc='upper center')

plt.savefig(f'{figuresDir}/voxelHist_[{predsDir}].pdf', bbox_inches='tight')
plt.show()

# Free memory
del recos_flat

Plotting left histogram: 100%|██████████| 1138/1138 [00:32<00:00, 35.28it/s]
Plotting right histogram: 100%|██████████| 1138/1138 [00:31<00:00, 35.83it/s]
  plt.show()


# Train, Test, and Validation

In [10]:
# --- Histogram appearance ---
n_bins      = 50
xmax        = 200          # set to None to auto-scale from data
line_width  = 2
density     = True
# --- Which panels to draw ---
show_means  = True         # plots[0]
show_pervox = True         # plots[1]
# --- Settings forwarded to the distance helpers (eps / optimize / minEntries) ---
eps_init    = 15
optimize    = True
min_entries = 15
save_path   = f'{figuresDir}/meanDistances_[{predsDir}].pdf'

# --- Reproducible split of event indices into disjoint test / validation / train sets ---
# The two random draws must stay in this order for a given seed to reproduce the split.
seed, nTest, nVal = 42, 20, 10
np.random.seed(seed)
indices = np.arange(len(data_recos))
test_indices = np.random.choice(indices, nTest, replace=False)
remaining = np.setdiff1d(indices, test_indices)            # every index not in test
val_indices = np.random.choice(remaining, nVal, replace=False)
train_indices = np.setdiff1d(remaining, val_indices)       # NOTE(review): val_indices is not read by later visible cells

In [11]:
import numpy as np

# Voxel spacing per axis, ordered (dz, dy, dx) to match the array layout; unit spacing by default.
VOXEL_SIZE = (1.0, 1.0, 1.0)

# Axis lengths of the first volume (the loading cell asserted that all volumes share it).
_Z, _Y, _X = data_recos[0].shape
# Coordinate of every voxel index along each axis.
z, y, x = (np.arange(length, dtype=np.float32) * spacing
           for length, spacing in zip((_Z, _Y, _X), VOXEL_SIZE))

# One (z, y, x) row per voxel, enumerated in C order, so row r is the voxel at
# np.unravel_index(r, data_recos[0].shape).
grid_pos = np.column_stack(
    [axis.ravel() for axis in np.meshgrid(z, y, x, indexing='ij')]
).astype(np.float32)
# NOTE(review): no later visible cell reads grid_pos; it is deleted after the evaluation plots.

In [13]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Line colours sampled from a 40-entry discrete "twilight" colormap. pyplot.get_cmap
# replaces matplotlib.cm.get_cmap, which is deprecated and removed in Matplotlib 3.9.
tw = plt.get_cmap('twilight', 40)
colors = [tw(5), tw(40-5), tw(12), tw(40-12)]  # [test TR, test RT, train TR, train RT]

# ======================================================
# Utilities: prediction source + NaN cleanup conveniences
# ======================================================

def ensure_numpy(x):
    """Return ``x`` as a NumPy array.

    Objects exposing a ``.numpy()`` method (e.g. TensorFlow/PyTorch tensors) are
    converted through it. Anything without that attribute goes through
    ``np.asarray``, which returns an existing ndarray unchanged.
    """
    try:
        converted = x.numpy()
    except AttributeError:
        converted = np.asarray(x)
    return converted

def clean_nans_1d(arr):
    """Flatten ``arr`` to one dimension and drop every NaN entry."""
    flat = np.ravel(np.asarray(arr))
    keep = np.logical_not(np.isnan(flat))
    return flat[keep]

def clean_list_of_1d(list_of_arrays):
    """Apply ``clean_nans_1d`` to each array and return the flattened, NaN-free results as a list."""
    return list(map(clean_nans_1d, list_of_arrays))

# ==================================================================
# Core evaluation: compute mean and per-voxel distances (TR and RT)
# ==================================================================

def compute_means_and_distances(Y, R,
                                eps=eps_init, optimize=optimize,
                                min_entries=min_entries,
                                desc=""):
    """
    Evaluate matched truth/reco volumes in both directions.

    For every aligned pair ``(Y[i], R[i])`` this calls
    ``meanAndDistancesFromTrueToReco`` and ``meanAndDistancesFromRecoToTrue``
    and collects each call's mean distance and per-voxel distance array.

    Parameters
    ----------
    Y, R : sequences of truth and reco volumes, aligned by index (equal length).
    eps, optimize, min_entries : forwarded to both distance helpers
        (``min_entries`` as ``minEntries``). The defaults are the notebook
        settings captured when this cell was run.
    desc : if non-empty, show a tqdm progress bar with this description.

    Returns
    -------
    dict with
      "mean_TR", "mean_RT"   : 1D float arrays of per-event means with NaNs dropped.
                               NaNs are removed independently in each, so the two
                               arrays need not stay index-aligned.
      "dists_TR", "dists_RT" : lists (one entry per event) of 1D arrays with NaNs dropped.

    Raises
    ------
    ValueError
        If ``Y`` and ``R`` have different lengths.
    """
    if len(Y) != len(R):
        # Iterating over Y's positions would silently ignore surplus reco volumes
        # (or fail mid-loop with IndexError), so reject misaligned inputs up front.
        raise ValueError(f"Y and R must have the same length, got {len(Y)} and {len(R)}")

    mean_TR, mean_RT = [], []
    dists_TR, dists_RT = [], []

    pairs = zip(Y, R)
    if desc:
        pairs = tqdm(pairs, total=len(Y), desc=desc)

    for y_i, r_i in pairs:
        # True → Reco
        m_tr, d_tr = meanAndDistancesFromTrueToReco(
            y_i, r_i, eps=eps, optimize=optimize, minEntries=min_entries
        )
        mean_TR.append(m_tr)
        dists_TR.append(ensure_numpy(d_tr))

        # Reco → True
        m_rt, d_rt = meanAndDistancesFromRecoToTrue(
            y_i, r_i, eps=eps, optimize=optimize, minEntries=min_entries
        )
        mean_RT.append(m_rt)
        dists_RT.append(ensure_numpy(d_rt))

    # Convert to arrays and drop NaNs
    return {
        "mean_TR": clean_nans_1d(np.asarray(mean_TR, dtype=float)),
        "mean_RT": clean_nans_1d(np.asarray(mean_RT, dtype=float)),
        "dists_TR": clean_list_of_1d(dists_TR),  # list of 1D arrays (per event)
        "dists_RT": clean_list_of_1d(dists_RT),  # list of 1D arrays (per event)
    }



# ========================================
# Gather the train and test volumes (truth, reco) by split index
# ========================================
Ytr, R_train = ([data_trues[i] for i in train_indices],
                [data_recos[i] for i in train_indices])
Yte, R_test = ([data_trues[i] for i in test_indices],
               [data_recos[i] for i in test_indices])

# ==========================
# Compute metrics per split
# ==========================
results_test = compute_means_and_distances(Yte, R_test, desc="Evaluating TEST")
if len(Ytr):
    results_train = compute_means_and_distances(Ytr, R_train, desc="Evaluating TRAIN")
else:
    # No training events: same keys with empty contents, so the plotting code stays unconditional.
    results_train = {
        "mean_TR": np.array([]), "mean_RT": np.array([]),
        "dists_TR": [], "dists_RT": []
    }

# Quick peek at the first test event in each direction (if any)
if len(results_test["dists_RT"]):
    print("Example TEST RT shape:", results_test["dists_RT"][0].shape)
    print("Example TEST RT values (first 10):", results_test["dists_RT"][0][:10])
if len(results_test["dists_TR"]):
    print("Example TEST TR shape:", results_test["dists_TR"][0].shape)
    print("Example TEST TR values (first 10):", results_test["dists_TR"][0][:10])

# The split volume lists are no longer needed once the metrics are computed.
del Ytr, Yte, R_train, R_test

# ======================
# Build the figure axes
# ======================
if show_means and show_pervox:
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
elif show_means or show_pervox:
    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    axes = [ax, ax]  # whichever panel is enabled draws on the single axis
else:
    raise ValueError("At least one of show_means/show_pervox must be True.")


def _resolve_xmax(samples):
    """Upper histogram bound: the fixed ``xmax`` setting, or, when it is None,
    the largest value over the non-empty ``samples`` (None if all are empty)."""
    if xmax is not None:
        return xmax
    maxima = [np.max(s) for s in samples if np.size(s)]
    return np.max(maxima) if maxima else None


def _concat_events(per_event):
    """Concatenate a list of per-event 1D arrays; an empty list yields an empty array."""
    return np.concatenate(per_event) if len(per_event) else np.array([])


def _draw_split_hists(ax, series, x_hi, xlabel):
    """Overlay a step histogram over [0, x_hi] for every non-empty sample in ``series``
    (tuples of values, edge colour, hatch, legend label), then style ``ax``."""
    for values, edge, hatch, label in series:
        if not np.size(values):
            continue
        add_step_hist(
            values, bins=n_bins, hist_range=(0, x_hi),
            density=density, ax=ax, linewidth=line_width, edgecolor=edge,
            facecolor='None', hatch=hatch, label=label, alpha=1.0
        )
    ax.set_yscale('log')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Density' if density else 'Counts')
    # A zero or missing bound would give identical/undefined limits; leave the right edge automatic then.
    ax.set_xlim(0, x_hi if x_hi else None)
    ax.grid(True, alpha=0.3)
    ax.legend(frameon=False)


# ==========================
# Panel 1: mean distances
# ==========================
if show_means:
    _xmax_means = _resolve_xmax([results_test["mean_TR"], results_test["mean_RT"],
                                 results_train["mean_TR"], results_train["mean_RT"]])
    _draw_split_hists(axes[0], [
        (results_train["mean_TR"], colors[2], '\\\\', r'Train True $\to$ Reco'),
        (results_train["mean_RT"], colors[3], '\\\\', r'Train Reco $\to$ True'),
        (results_test["mean_TR"],  colors[0], '//',   r'Test True $\to$ Reco'),
        (results_test["mean_RT"],  colors[1], '//',   r'Test Reco $\to$ True'),
    ], _xmax_means, 'Mean Distance [mm]')

# =======================================
# Panel 2: per-voxel distance distributions
# =======================================
if show_pervox:
    # Pool the per-event distance vectors of each split and direction
    trn_TR = _concat_events(results_train["dists_TR"])
    trn_RT = _concat_events(results_train["dists_RT"])
    tst_TR = _concat_events(results_test["dists_TR"])
    tst_RT = _concat_events(results_test["dists_RT"])

    _xmax_vox = _resolve_xmax([trn_TR, trn_RT, tst_TR, tst_RT])
    _draw_split_hists(axes[1], [
        (trn_TR, colors[2], '\\\\', r'Train: True $\to$ Reco'),
        (trn_RT, colors[3], '\\\\', r'Train: Reco $\to$ True'),
        (tst_TR, colors[0], '//',   r'Test: True $\to$ Reco'),
        (tst_RT, colors[1], '//',   r'Test: Reco $\to$ True'),
    ], _xmax_vox, 'Distance [mm]')

plt.tight_layout()
if save_path:
    plt.savefig(save_path, bbox_inches='tight')
plt.show()



  tw = get_cmap('twilight', 40)   # 40-sample discrete twilight
Evaluating TEST: 100%|██████████| 20/20 [00:00<00:00, 55.77it/s]
Evaluating TRAIN: 100%|██████████| 1108/1108 [00:07<00:00, 138.74it/s]


Example TEST RT shape: (0,)
Example TEST RT values (first 10): []
Example TEST TR shape: (13,)
Example TEST TR values (first 10): [1.4142135 0.        0.        0.        0.        0.        0.
 0.        0.        0.       ]


  plt.show()


In [14]:
# The voxel coordinate axes and grid_pos are not read by the plotting cells below; free them.
del z, y, x, grid_pos

# Multi-File Plot

In [15]:
from tqdm import tqdm
import numpy as np


def _block_mean(vol, size):
    """Downsample a 3D volume to ``size`` bins per axis by averaging non-overlapping blocks.

    Every axis length must be divisible by ``size``. A volume that already has
    ``size`` bins per axis keeps its values, so re-running this cell is harmless.
    """
    sx, sy, sz = (n // size for n in vol.shape)
    return vol.reshape(size, sx, size, sy, size, sz).mean(axis=(1, 3, 5))


# Resize the data (the loading cell asserted every volume shares data_recos[0]'s shape)
step = tuple(n // reshapeSize for n in data_recos[0].shape)
assert all(n % reshapeSize == 0 for n in data_recos[0].shape), \
    f'volume shape {data_recos[0].shape} is not divisible by reshapeSize={reshapeSize}'
data_recos = [_block_mean(v, reshapeSize) for v in tqdm(data_recos)]
data_trues = [_block_mean(v, reshapeSize) for v in tqdm(data_trues)]

# Data for plotting: voxel edges in metres, centred on the detector
scale = 1000 # mm -> m
xEdges = np.linspace(-DETECTOR_SIZE_MM[0]/2/scale, DETECTOR_SIZE_MM[0]/2/scale, data_recos[0].shape[0] + 1)
yEdges = np.linspace(-DETECTOR_SIZE_MM[1]/2/scale, DETECTOR_SIZE_MM[1]/2/scale, data_recos[0].shape[1] + 1)
zEdges = np.linspace(-DETECTOR_SIZE_MM[2]/2/scale, DETECTOR_SIZE_MM[2]/2/scale, data_recos[0].shape[2] + 1)
# indexing='ij' keeps xEdges[i, j, k] == x[i], aligned with the volume axes. The default
# 'xy' indexing swaps the first two axes, which only looks right for square x/y grids.
xEdges, yEdges, zEdges = np.meshgrid(xEdges, yEdges, zEdges, indexing='ij')

  0%|          | 0/1138 [00:00<?, ?it/s]

100%|██████████| 1138/1138 [00:01<00:00, 624.04it/s]
100%|██████████| 1138/1138 [00:02<00:00, 448.82it/s]


In [16]:
import numpy as np

# Events shown in the 4x5 grid below: all loaded files when exactly 20 are loaded
# (e.g. one of the short ranges in the config cell), otherwise file numbers 200-219.
# NOTE(review): the grid figure is saved as 'allTestEvents_...', but these are not
# test_indices (a 20-event split); confirm which events are intended.
if len(fileNumbers) != 20:
    fileNumbersToPlot = np.arange(200, 220)
else:
    fileNumbersToPlot = fileNumbers

In [17]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from tqdm import tqdm

assert np.all(np.isin(fileNumbersToPlot, fileNumbers)), 'Not all fileNumbersToPlot are in fileNumbers'
fileIndicesToPlot = np.where(np.isin(fileNumbers, fileNumbersToPlot))[0]
dataRecosToPlot = [data_recos[i] for i in fileIndicesToPlot]
dataTruesToPlot = [data_trues[i] for i in fileIndicesToPlot]

nCols = 4
nRows = 5
assert len(fileNumbersToPlot) == nCols*nRows
N = 30  # number of highest-weight voxels kept per event

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=256):
    """Return a colormap spanning only the [minval, maxval] fraction of ``cmap``, sampled at ``n`` colours."""
    new_cmap = cm.colors.LinearSegmentedColormap.from_list(
        f'trunc({cmap.name},{minval},{maxval})',
        cmap(np.linspace(minval, maxval, n))
    )
    return new_cmap
cmap = truncate_colormap(cm.viridis, minval=0.3, maxval=1.0)

halfSize = [s/2/scale for s in DETECTOR_SIZE_MM]  # detector half-extents in m

fig, axs = plt.subplots(nRows, nCols, figsize=(16, 18), subplot_kw={'projection': '3d'})

for ind, (predGrid, trueGrid) in tqdm(enumerate(zip(dataRecosToPlot, dataTruesToPlot)), total=len(dataRecosToPlot)):
    row, col = ind // nCols, ind % nCols

    # Zero everything below the top-N threshold.
    # NOTE(review): `from distanceLossMethods import *` skips underscore-prefixed names
    # unless they are listed in __all__; confirm where _eps_from_topk is defined and
    # what its (1.0, True) arguments mean.
    eps = _eps_from_topk(predGrid, 1.0, True, N)
    predGrid = np.where(predGrid < eps, 0, predGrid)
    shown = predGrid[predGrid > 0]
    if shown.size:
        minVal, maxVal = shown.min(), shown.max()
    else:
        # All-zero prediction: .min()/.max() of an empty array would raise.
        minVal, maxVal = 0.0, 1.0
    globalColorNorm = cm.colors.Normalize(vmin=minVal, vmax=maxVal)
    colors = cmap(globalColorNorm(predGrid))

    ax = plot_grid(
        axs[row, col],
        xEdges,
        yEdges,
        zEdges,
        recoGrid=predGrid,
        recoGridFaceColors=colors,
        recoGridEdgeColors=np.clip(colors*2-0.5, 0, 1),
        recoGridAlpha=0.8,
        trueGrid=trueGrid,
        trueGridFaceColors='r',
        trueGridEdgeColors='k',
        trueGridAlpha=0.1,
        nullGridAlpha=0.1,
        linewidth=0.5,
        cbar=True,
        colorNorm=globalColorNorm,
        cmap=cmap,
    )
    axs[row, col] = ax

    ax.set_xlim(-halfSize[0], halfSize[0])
    ax.set_ylim(-halfSize[1], halfSize[1])
    ax.set_zlim(-halfSize[2], halfSize[2])

    ax.set_xlabel(r'$x$ [m]', labelpad=7)
    ax.set_ylabel(r'$y$ [m]')
    ax.set_zlabel(r'$z$ [m]', labelpad=0)

# Shared colorbar, drawn with the same truncated colormap as the panels.
# NOTE(review): each panel uses its own normalization, so these tick values
# reflect only the last event plotted.
axNew = fig.add_axes([0.1, 0.1, 0.8, 0.8])
axNew.set_visible(False)
cbar = plt.colorbar(cm.ScalarMappable(norm=globalColorNorm, cmap=cmap), ax=axNew, orientation='horizontal', label='Predicted Weight', alpha=0.2, pad=0.1, aspect=40, shrink=0.8, extend='both')
cbar.ax.xaxis.set_label_position('top')
cbar.ax.xaxis.set_ticks_position('top')
cbar.ax.set_position([0.11, 0.1, 0.8, 0.9])

# Empty visible axes placed to the right of the figure; presumably widens the
# bbox_inches='tight' crop region — TODO confirm it is still needed.
axExtend = fig.add_axes([1.1, 0.1, 0.1, 0.8])
axExtend.set_visible(True)

fig.tight_layout()
plt.savefig(f'{figuresDir}/allTestEvents_[{predsDir}].pdf', bbox_inches='tight')
plt.show()

100%|██████████| 20/20 [00:28<00:00,  1.45s/it]
  fig.tight_layout()
  plt.show()
