In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from vrnn import utils
from vrnn.models import VanillaModule, HillThermModule
from vrnn.normalization import NormalizedDataset, NormalizationModule
from vrnn.data_thermal import DatasetThermal, VoigtReussThermNormalization
import numpy as np


# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
dtype = torch.float32

# Data root, resolved by the project's own helper.
data_dir = utils.get_data_dir()

import shutil
import matplotlib
# Typeset figure text with LaTeX only when a `latex` executable is found on PATH.
plt.rcParams["text.usetex"] = bool(shutil.which('latex'))
matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"

In [None]:
# Serialized networks evaluated in this notebook
# VRNN model
model_norm_file = data_dir / 'Thermal2D_models/vrnn_therm2D_norm_20250121_173200.pt'
# Vanilla model
model_vanilla_file = data_dir / 'Thermal2D_models/vann_therm2D_20250121_202555.pt'

# HDF5 file that holds the features and targets
ms_file = data_dir / 'feature_engineering_thermal_2D.h5'
print(ms_file)

# Options shared by the training and the validation dataset
feature_idx = None
_dataset_options = dict(input_mode='descriptors', feature_idx=feature_idx,
                        feature_key='feature_vector', ndim=2)

# Training set: twelve phase contrasts between 1/100 and 100
R_range_train = [1/100., 1/50., 1/20., 1/10., 1/5., 1/2., 2, 5, 10, 20, 50, 100]
train_data = DatasetThermal(file_name=ms_file, R_range=R_range_train, group='train_set',
                            **_dataset_options)

# Validation set: the integers 2..100 followed by the reciprocals 1/1..1/100
_denominators = np.arange(1, 101, dtype=int)
R_range_val = np.concatenate([np.arange(2, 101, dtype=int), 1. / _denominators])
val_data = DatasetThermal(file_name=ms_file, R_range=R_range_val, group='val_set',
                          **_dataset_options)

# Unshuffled loaders, so row i of every fetched tensor refers to the same sample
batch_size = 30000
_loader_options = dict(batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_data, **_loader_options)
val_loader = DataLoader(val_data, **_loader_options)

# Pull the raw features / targets onto `device` as `dtype`
train_x, train_y = utils.get_data(train_loader, device=device, dtype=dtype)
val_x, val_y = utils.get_data(val_loader, device=device, dtype=dtype)


# Per-feature extrema for the normalization.
# NOTE(review): they are taken over training AND validation features together.
_all_features = torch.cat([train_x, val_x], dim=0)
features_max = _all_features.max(dim=0).values
features_min = _all_features.min(dim=0).values
# Leave the first feature (volume fraction) unscaled by pinning its range to [0, 1]
features_min[0] = 0
features_max[0] = 1
# (Alternative kept from the original notebook: features_min, features_max = None, None)
normalization = VoigtReussThermNormalization(dim=2, features_min=features_min, features_max=features_max)

# Normalized views of the same datasets, fetched the same way as the raw ones
train_data_norm = NormalizedDataset(train_data, normalization)
val_data_norm = NormalizedDataset(val_data, normalization)
train_loader_norm = DataLoader(train_data_norm, **_loader_options)
val_loader_norm = DataLoader(val_data_norm, **_loader_options)
train_x_norm, train_y_norm = utils.get_data(train_loader_norm, device=device, dtype=dtype)
val_x_norm, val_y_norm = utils.get_data(val_loader_norm, device=device, dtype=dtype)

In [None]:
from vrnn.tensortools import unpack_sym
from vrnn.losses import VoigtReussNormalizedLoss

# Feature and target widths, read off the trailing axis of the training tensors.
in_dim = train_x.shape[-1]
out_dim = train_y.shape[-1]
# Project loss object for the 2D problem (only instantiated in the cells shown here).
loss_fn = VoigtReussNormalizedLoss(dim=2)

def compute_extended_errors(x, y, y_pred, f1_range=None):
    """
    Per-phase-contrast error statistics of a prediction, measured in the Frobenius norm.

    Samples are grouped by the value in the last column of ``x`` (the phase contrast).
    Within each group ``y`` and ``y_pred`` are expanded with ``unpack_sym(..., dim=2)``;
    the absolute error is the Frobenius norm of the difference and the relative error
    is that value divided by the Frobenius norm of the expanded ground truth.

    Args:
    - x: Input tensor containing data features; column 0 is used for the optional
         ``f1_range`` filter and the last column as the grouping key.
    - y: Ground truth tensor.
    - y_pred: Predicted output tensor.
    - f1_range: Optional tuple (f1_min, f1_max). When given, only samples with
                f1_min < x[:, 0] < f1_max (both strict) are kept.

    Returns:
    - stats: Dictionary keyed by phase contrast (Python scalar). Each value is a dict
             of 0-dim CPU tensors: mean, std, median, min, max of the relative and
             absolute errors ('mean_rel_err', ..., 'max_abs_err').
    - all_rel_errs: Dictionary with the full relative errors (CPU tensor) per phase contrast.
    - all_abs_errs: Dictionary with the full absolute errors (CPU tensor) per phase contrast.

    Phase contrasts that keep no sample after the ``f1_range`` filter are left out of
    all three dictionaries; previously the median/min/max reductions raised on the
    empty selection. A group with a single sample still yields a NaN standard
    deviation, and a ground truth with zero norm yields a non-finite relative error.
    """
    # The f1 filter is independent of the phase contrast, so build it once.
    f1_window = None
    if f1_range is not None:
        f1_min, f1_max = f1_range
        f1_window = (x[:, 0] > f1_min) & (x[:, 0] < f1_max)

    stats = {}
    all_rel_errs = {}
    all_abs_errs = {}

    for phase_contrast in torch.unique(x[..., -1]).tolist():
        # Exact match on the stored value; the key is that same value as a Python scalar.
        mask = x[:, -1] == phase_contrast
        if f1_window is not None:
            mask = mask & f1_window

        # Nothing survived the filter for this contrast: skip it rather than crash.
        if not bool(mask.any()):
            continue

        # Expand the ground truth once; it is needed for both numerator and denominator.
        truth = unpack_sym(y[mask], dim=2)
        estimate = unpack_sym(y_pred[mask], dim=2)

        abs_err = torch.norm(truth - estimate, 'fro', dim=(1, 2))
        rel_err = abs_err / torch.norm(truth, 'fro', dim=(1, 2))

        # Store the full set of errors
        all_rel_errs[phase_contrast] = rel_err.cpu()
        all_abs_errs[phase_contrast] = abs_err.cpu()

        # Summary statistics, relative errors first, each in the order
        # mean, std, median, min, max.
        summary = {}
        for tag, err in (('rel', rel_err), ('abs', abs_err)):
            for reducer in ('mean', 'std', 'median', 'min', 'max'):
                summary[f'{reducer}_{tag}_err'] = getattr(err, reducer)().cpu()
        stats[phase_contrast] = summary

    return stats, all_rel_errs, all_abs_errs

# Row-mask helpers: each compares the last feature column of one of the
# module-level feature tensors with the requested phase contrast.
def tmask(contrast):
    """Boolean mask over the rows of ``train_x`` whose last column equals ``contrast``."""
    return train_x[:, -1] == contrast


def vmask(contrast):
    """Boolean mask over the rows of ``val_x`` whose last column equals ``contrast``."""
    return val_x[:, -1] == contrast


# NOTE(review): these two compare against the *normalized* last column, so
# `contrast` has to be given in normalized units - confirm against the normalization.
def tnmask(contrast):
    """Boolean mask over the rows of ``train_x_norm`` whose last column equals ``contrast``."""
    return train_x_norm[:, -1] == contrast


def vnmask(contrast):
    """Boolean mask over the rows of ``val_x_norm`` whose last column equals ``contrast``."""
    return val_x_norm[:, -1] == contrast


In [None]:
# VRNN network, evaluated directly on normalized inputs / targets.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so the checkpoint file must come from a trusted source.
ann_model = torch.load(model_norm_file, map_location=device, weights_only=False)
ann_model = ann_model.to(device=device, dtype=dtype)
model_norm = VanillaModule(ann_model)
model_norm = model_norm.to(device=device, dtype=dtype)
model_norm.eval()

with torch.inference_mode():
    train_pred_norm, val_pred_norm = model_norm(train_x_norm), model_norm(val_x_norm)

# Error statistics in the normalized space (validation first, then training)
(vrnn_norm_val_stats,
 vrnn_norm_val_all_rel_errs,
 vrnn_norm_val_all_abs_errs) = compute_extended_errors(val_x_norm, val_y_norm, val_pred_norm)
(vrnn_norm_train_stats,
 vrnn_norm_train_all_rel_errs,
 vrnn_norm_train_all_abs_errs) = compute_extended_errors(train_x_norm, train_y_norm, train_pred_norm)



# The same network wrapped with the normalization, evaluated on the raw tensors.
model = NormalizationModule(normalized_module=model_norm, normalization=normalization)
model = model.to(device=device, dtype=dtype)
model.eval()

with torch.inference_mode():
    train_pred, val_pred = model(train_x), model(val_x)

# Relative Frobenius error per sample: ||y - y_pred||_F / ||y||_F
_train_true = unpack_sym(train_y, dim=2)
_val_true = unpack_sym(val_y, dim=2)
rel_err_train = (torch.norm(_train_true - unpack_sym(train_pred, dim=2), 'fro', dim=(1, 2))
                 / torch.norm(_train_true, 'fro', dim=(1, 2)))
rel_err_val = (torch.norm(_val_true - unpack_sym(val_pred, dim=2), 'fro', dim=(1, 2))
               / torch.norm(_val_true, 'fro', dim=(1, 2)))

# One summary line per statistic, in the order median, mean, max
for _stat_name, _stat_fn in (('median', torch.median), ('mean', torch.mean), ('max', torch.max)):
    print(f'{_stat_name} rel. error (training) {_stat_fn(rel_err_train):.4f}, '
          f'{_stat_name} rel. error (validation) {_stat_fn(rel_err_val):.4f}')

(vrnn_val_stats,
 vrnn_val_all_rel_errs,
 vrnn_val_all_abs_errs) = compute_extended_errors(val_x, val_y, val_pred)
(vrnn_train_stats,
 vrnn_train_all_rel_errs,
 vrnn_train_all_abs_errs) = compute_extended_errors(train_x, train_y, train_pred)

In [None]:
# Vanilla network, evaluated on the raw (un-normalized) tensors.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so the checkpoint file must come from a trusted source.
ann_model = torch.load(model_vanilla_file, map_location=device, weights_only=False)
ann_model = ann_model.to(device=device, dtype=dtype)
model_vanilla = VanillaModule(ann_model)
model_vanilla = model_vanilla.to(device=device, dtype=dtype)
model_vanilla.eval()

with torch.inference_mode():
    train_pred_vanilla, val_pred_vanilla = model_vanilla(train_x), model_vanilla(val_x)

# Relative Frobenius error per sample: ||y - y_pred||_F / ||y||_F
_train_true = unpack_sym(train_y, dim=2)
_val_true = unpack_sym(val_y, dim=2)
rel_err_train_vanilla = (torch.norm(_train_true - unpack_sym(train_pred_vanilla, dim=2), 'fro', dim=(1, 2))
                         / torch.norm(_train_true, 'fro', dim=(1, 2)))
rel_err_val_vanilla = (torch.norm(_val_true - unpack_sym(val_pred_vanilla, dim=2), 'fro', dim=(1, 2))
                       / torch.norm(_val_true, 'fro', dim=(1, 2)))

# One summary line per statistic, in the order median, mean, max
for _stat_name, _stat_fn in (('median', torch.median), ('mean', torch.mean), ('max', torch.max)):
    print(f'{_stat_name} rel. error (training, Vanilla) {_stat_fn(rel_err_train_vanilla):.4f}, '
          f'{_stat_name} rel. error (validation, Vanilla) {_stat_fn(rel_err_val_vanilla):.4f}')

(vanilla_val_stats,
 vanilla_val_all_rel_errs,
 vanilla_val_all_abs_errs) = compute_extended_errors(val_x, val_y, val_pred_vanilla)
(vanilla_train_stats,
 vanilla_train_all_rel_errs,
 vanilla_train_all_abs_errs) = compute_extended_errors(train_x, train_y, train_pred_vanilla)

In [None]:
# Hill reference model (built without a checkpoint), evaluated on the raw tensors.
model_hill = HillThermModule(dim=2)
model_hill.eval()

with torch.inference_mode():
    train_hill, val_hill = model_hill(train_x), model_hill(val_x)

# Relative Frobenius error per sample: ||y - y_hill||_F / ||y||_F
_train_true = unpack_sym(train_y, dim=2)
_val_true = unpack_sym(val_y, dim=2)
rel_err_train_hill = (torch.norm(_train_true - unpack_sym(train_hill, dim=2), 'fro', dim=(1, 2))
                      / torch.norm(_train_true, 'fro', dim=(1, 2)))
rel_err_val_hill = (torch.norm(_val_true - unpack_sym(val_hill, dim=2), 'fro', dim=(1, 2))
                    / torch.norm(_val_true, 'fro', dim=(1, 2)))

# One summary line per statistic, in the order median, mean, max
for _stat_name, _stat_fn in (('median', torch.median), ('mean', torch.mean), ('max', torch.max)):
    print(f'{_stat_name} rel. error (training, Hill) {_stat_fn(rel_err_train_hill):.4f}, '
          f'{_stat_name} rel. error (validation, Hill) {_stat_fn(rel_err_val_hill):.4f}')

(hill_val_stats,
 hill_val_all_rel_errs,
 hill_val_all_abs_errs) = compute_extended_errors(val_x, val_y, val_hill)
(hill_train_stats,
 hill_train_all_rel_errs,
 hill_train_all_abs_errs) = compute_extended_errors(train_x, train_y, train_hill)

In [None]:
# Base font size for the error figure; it is set before and again after loading
# the style sheet, exactly as in the original sequence of calls.
fontsize = 8
plt.rcParams['font.size'] = fontsize
plt.style.use('seaborn-v0_8-paper')
plt.rcParams['font.size'] = fontsize

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def plot_stacked_kde(all_errors, ax, min_variance_threshold=1e-6, fontsize=20):
    """
    Draw one peak-normalized KDE of the errors per phase contrast as a heat map.

    Args:
    - all_errors: Dict mapping phase contrast -> 1D CPU tensor of errors (fractions;
                  they are multiplied by 100 and shown in percent).
    - ax: Matplotlib axes to draw on. The layout of the figure that owns ``ax`` is
          tightened at the end.
    - min_variance_threshold: Phase contrasts whose percent errors have a variance
                              below this value are skipped (a message is printed).
    - fontsize: Font size for axis labels and major tick labels.

    Returns:
    - ax: The axes that was passed in.

    The x axis (phase contrast) is logarithmic; the y axis spans the global min/max
    of all percent errors on a 1000-point grid. Each column is the Gaussian KDE of
    one phase contrast divided by its own maximum, so columns are comparable in
    shape but not in absolute density. An empty ``all_errors`` prints a message and
    returns without drawing (it used to raise inside ``np.concatenate``).
    """
    phase_contrasts = sorted(all_errors.keys())
    if not phase_contrasts:
        print("No valid phase contrasts to plot")
        return ax

    # Convert every group to percent once and reuse it for the grid and the KDE.
    percent_errors = {pc: all_errors[pc].numpy() * 100 for pc in phase_contrasts}
    pooled = np.concatenate([percent_errors[pc] for pc in phase_contrasts])
    error_grid = np.linspace(np.min(pooled), np.max(pooled), 1000)

    kept_contrasts, densities = [], []
    for pc in phase_contrasts:
        error_values = percent_errors[pc]
        # A (near-)constant sample has a singular covariance, which gaussian_kde rejects.
        if np.var(error_values) < min_variance_threshold:
            print(f"Skipping phase contrast {pc} due to low variance")
            continue
        density = gaussian_kde(error_values)(error_grid)
        density /= np.max(density)  # normalize each column to a peak of 1
        kept_contrasts.append(pc)
        densities.append(density)

    if not kept_contrasts:
        print("No valid phase contrasts to plot")
        return ax

    valid_pcs = np.asarray(kept_contrasts, dtype=error_grid.dtype)
    Z = np.array(densities)
    X, Y = np.meshgrid(valid_pcs, error_grid)

    ax.pcolormesh(X, Y, Z.T, cmap='Blues', shading='auto')

    ax.set_xscale('log')
    ax.set_xlabel(r"Phase contrast $R \;[-]$", fontsize=fontsize)
    ax.set_ylabel(r"Relative Frobenius error $[\%]$", fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.3)

    # Tighten the figure that owns `ax`, not whatever pyplot considers current.
    ax.figure.tight_layout()
    return ax

fig, ax = plt.subplots(1, 3, figsize=(6.3, 2.5), dpi=600)
plt.style.use('seaborn-v0_8-paper')
plt.rcParams.update({'axes.edgecolor': 'black', 'axes.linewidth': 1})


def _percent_curve(stats_by_contrast, quantity):
    """Return the sorted phase contrasts and ``100 * stats[k][quantity]`` as two arrays."""
    ordered = sorted(stats_by_contrast.keys())
    values = [100 * stats_by_contrast[k][quantity].item() for k in ordered]
    return np.array(ordered), np.array(values)


# ---------------------- VRNN and Vanilla ANN panels --------------------------------
# Both panels share one recipe: KDE heat map of the validation errors, validation
# mean / median curves, and the training means as markers. The second panel
# additionally hides its y tick labels and y label.
_panels = (
    (ax[0], 'red', 'Voigt-Reuss net', vrnn_val_all_rel_errs, vrnn_val_stats, vrnn_train_stats,
     {}, False),
    (ax[1], 'lime', 'Vanilla NN', vanilla_val_all_rel_errs, vanilla_val_stats, vanilla_train_stats,
     {'linestyle': '-'}, True),
)

for _panel, _colour, _heading, _val_errs, _val_stats, _train_stats, _mean_style, _hide_y in _panels:
    plot_stacked_kde(_val_errs, ax=_panel, fontsize=8)

    # Validation curves (mean solid, median dashed), in percent
    x_val, mean_rel_val = _percent_curve(_val_stats, 'mean_rel_err')
    _unused_x, median_rel_val = _percent_curve(_val_stats, 'median_rel_err')
    _panel.plot(x_val, mean_rel_val, color=_colour, linewidth=1,
                label="validation - mean", **_mean_style)
    _panel.plot(x_val, median_rel_val, color=_colour, linewidth=0.7,
                label="validation - median", linestyle='--')

    # Training means as markers drawn on top of everything else
    x_train, mean_rel_train = _percent_curve(_train_stats, 'mean_rel_err')
    _panel.scatter(x_train, mean_rel_train, color=_colour, edgecolor='black',
                   s=10, linewidth=0.3, label="train - mean", zorder=10)

    _panel.set_xscale('log')
    _panel.set_xlim(1e-2, 1e2)
    _panel.set_ylim(0, 30)
    _panel.legend(fontsize=8, loc='upper center')
    _panel.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.3)
    _panel.set_box_aspect(1)
    if _hide_y:
        _panel.set_yticklabels([])
        _panel.set_ylabel('')
    _panel.minorticks_off()
    _panel.set_title(_heading)


# ---------------------- Compare VRNN and Vanilla ANN --------------------------------

# Keys read from the stats dictionaries: band centre and band half-width
stat_quantity = 'mean_rel_err'
std_quantity = 'std_rel_err'
_comparison = ax[2]

# --- Vanilla ANN --- (drawn first, so it is also first in the legend)
x_vanilla, mean_vals_vanilla = _percent_curve(vanilla_val_stats, stat_quantity)
_unused_x, std_vals_vanilla = _percent_curve(vanilla_val_stats, std_quantity)
upper_bound_vanilla = mean_vals_vanilla + std_vals_vanilla
lower_bound_vanilla = mean_vals_vanilla - std_vals_vanilla

_comparison.fill_between(x_vanilla, lower_bound_vanilla, upper_bound_vanilla,
                         color='lime', alpha=0.2,
                         label=fr"Vanilla - $\mu \pm 1\sigma$", edgecolor='black')
_comparison.plot(x_vanilla, mean_vals_vanilla, color='lime', linewidth=1)

# Training means are computed but not drawn in this panel
x_train_vanilla, train_mean_vanilla = _percent_curve(vanilla_train_stats, stat_quantity)

# --- VRNN ---
x_vrnn, mean_vals_vrnn = _percent_curve(vrnn_val_stats, stat_quantity)
_unused_x, std_vals_vrnn = _percent_curve(vrnn_val_stats, std_quantity)
upper_bound_vrnn = mean_vals_vrnn + std_vals_vrnn
lower_bound_vrnn = mean_vals_vrnn - std_vals_vrnn

_comparison.fill_between(x_vrnn, lower_bound_vrnn, upper_bound_vrnn,
                         color='red', alpha=0.2,
                         label=fr"Voigt-Reuss - $\mu \pm 1\sigma$", edgecolor='black')
_comparison.plot(x_vrnn, mean_vals_vrnn, color='red', linewidth=1)

# Training means are computed but not drawn in this panel
x_train_vrnn, train_mean_vrnn = _percent_curve(vrnn_train_stats, stat_quantity)

_comparison.set_xscale('log')
_comparison.legend(loc = 'upper center')
_comparison.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.3)
_comparison.set_box_aspect(1)
_comparison.set_xlim(1e-2, 1e2)
_comparison.set_ylim(0, 30)
_comparison.minorticks_off()
_comparison.set_title('Comparison on validation set')
_comparison.set_yticklabels([])
_comparison.set_ylabel('')
_comparison.set_xlabel(fr'Phase contrast $R \;[-]$', fontsize=fontsize)


plt.subplots_adjust(wspace=0.15)
plt.show()

# Write the figure to disk at print resolution (path is relative to the working directory)
fig.savefig('../../overleaf/gfx/2D_thermal_error_comparison.png', dpi=600, bbox_inches='tight')


In [None]:
# Shared appearance settings for the histogram grid
fontsize = 8
nbins = 150       # number of histogram bins
linewidth = 1.0   # outline width of the step histograms
plt.style.use('seaborn-v0_8-paper')

# 2 rows (raw targets / normalized targets) x 4 columns (one per phase contrast)
fig, ax = plt.subplots(2, 4, figsize=[6.3, 3.5], dpi=300)

# Phase contrasts shown, with their LaTeX column titles
R_values = [1./100, 1./10, 10, 100]
titles = [r'$R = 1/100$', r'$R = 1/10$', r'$R = 10 $', r'$R = 100$']
labels_top = [r'$\overline{\kappa}_{11}$', r'$\overline{\kappa}_{22}$', r'$\overline{\kappa}_{12}$']
labels_bottom = [r'$\xi_{\lambda1}$', r'$\xi_{\lambda2}$', r'$\xi_{\mathrm{q}1}$']

# Separate palettes for the two rows
colors_top = ['#1f77b4','#d62728', '#2ca02c' ]      # Blue, Red, Green
colors_bottom = ['#9467bd', '#ff7f0e', '#17becf']    # Purple, Orange, Cyan

# Per-row recipe: (row index, target tensor, labels, colors, grid alpha, x-limits or None)
_row_recipes = (
    (0, train_y, labels_top, colors_top, 0.3, None),
    (1, train_y_norm, labels_bottom, colors_bottom, 0.4, (0, 1)),
)

for i, (R, title) in enumerate(zip(R_values, titles)):
    # Both rows select samples with the mask built from the raw features (tmask).
    _rows_at_R = tmask(R)
    for _row, _targets, _labels, _colors, _grid_alpha, _xlim in _row_recipes:
        _panel = ax[_row, i]
        for j, (label, color) in enumerate(zip(_labels, _colors)):
            _panel.hist(_targets[_rows_at_R, j].ravel().cpu(), bins=nbins,
                        histtype=u'step', label=label, linewidth=linewidth,
                        color=color, alpha=0.8)
        _panel.set_title(title, fontsize=fontsize, pad=8)
        _panel.grid(True, alpha=_grid_alpha, linewidth=0.5, linestyle='--')
        _panel.set_box_aspect(1.0)
        if _xlim is not None:
            _panel.set_xlim(*_xlim)
        # Keep all four spines so every panel is drawn as a closed box
        for spine in _panel.spines.values():
            spine.set_visible(True)

# Uniform tick / label sizes, plus minor ticks, on every panel (row-major order)
for a in ax.flat:
    a.tick_params(axis='both', which='major', labelsize=fontsize)
    a.tick_params(axis='both', which='minor', labelsize=fontsize)
    a.xaxis.label.set_size(fontsize)
    a.yaxis.label.set_size(fontsize)
    a.minorticks_on()

# One figure-level legend per row, taken from the first panel of that row
_legend_style = dict(loc='center', ncol=3, fontsize=fontsize, frameon=True,
                     edgecolor='black', fancybox=False)
legend1 = fig.legend(*ax[0,0].get_legend_handles_labels(),
                     bbox_to_anchor=(0.5, 0.50), **_legend_style)
legend2 = fig.legend(*ax[1,0].get_legend_handles_labels(),
                     bbox_to_anchor=(0.5, -0.02), **_legend_style)

# Layout: tighten first, then override the spacing explicitly
fig.tight_layout()
plt.subplots_adjust(hspace=0.475, wspace=0.4, bottom=0.08)

plt.show()

# Write the figure as PDF with blank Creator / Producer metadata
fig.savefig("../../overleaf/gfx/therm2d_histograms.pdf",
            bbox_inches='tight',
            dpi=300,
            metadata={'Creator': '', 'Producer': ''})



In [None]:
import matplotlib.cm as cm
from sklearn.metrics import r2_score

# Publication style sheet and the font size used by the annotations below
plt.style.use('seaborn-v0_8-paper')
fontsize = 8

def plot_predictions_v3(ax, y, pred, idx=0, error_type='absolute_log', norm=None):
    """
    Scatter predicted against true values of one component, colored by the error.

    Args:
    - ax: Matplotlib axes to draw on.
    - y: Ground-truth tensor; column ``idx`` is plotted on the x axis.
    - pred: Prediction tensor; column ``idx`` is plotted on the y axis.
    - idx: Column (component) to plot.
    - error_type: 'absolute' colors by |pred - y|; any other value (default
                  'absolute_log') colors by |log|pred| - log(y)|.
    - norm: Callable mapping errors to [0, 1] for the colormap. Defaults to
            ``plt.Normalize(vmin=0, vmax=2)``.

    The log error takes the absolute value of the prediction only, so a non-positive
    ground-truth value produces a non-finite color value. A dashed red identity line
    spanning the range of the true values is drawn on top of the points.
    """
    truth = y[:, idx]
    estimate = pred[:, idx]

    # Both error types share the same default color scale.
    if norm is None:
        norm = plt.Normalize(vmin=0, vmax=2)

    # Stay in torch until the final conversion: the original wrapped the tensor in
    # np.abs inside torch.log, which depends on NumPy<->torch interop and breaks for
    # tensors that NumPy cannot view (e.g. on a GPU).
    if error_type == 'absolute':
        deviation = torch.abs(estimate - truth)
    else:  # absolute_log
        deviation = torch.abs(torch.log(torch.abs(estimate)) - torch.log(truth))
    error_plt = deviation.cpu().numpy()

    # Color every point by its (normalized) error using viridis
    colors = cm.viridis(norm(error_plt))
    ax.scatter(truth.cpu().numpy(),
               estimate.cpu().numpy(),
               color=colors,
               marker='.',
               s=5,
               alpha=0.7)

    # Identity line over the range of the true values
    lims = [truth.min().item(), truth.max().item()]
    ax.plot(lims, lims, 'r--', alpha=0.8, linewidth=0.8)

# 2 rows (one per model) x 3 columns (one per tensor component); columns share x
fig, axs = plt.subplots(2, 3, figsize=[6.3, 3.05], dpi=300, sharex='col', squeeze=False)

# Color scales: one for the log errors (columns 0 and 1), one for the absolute error (column 2)
norm_log = plt.Normalize(vmin=0, vmax=2)
norm_abs = plt.Normalize(vmin=0, vmax=12)

# LaTeX labels for the three tensor components
component_labels = [r"$\kappa_{11}$", r"$\kappa_{22}$", r"$\kappa_{12}$"]

# Validation samples with a positive value in the last feature column
mask = val_x[:, -1] > 0
models = [
    (val_pred[mask].cpu(), "VRNN"),
    (val_pred_vanilla[mask].cpu(), "Vanilla")
]

# R² per model (outer list) and component (inner list)
r2_scores = []
for pred, model_name in models:
    r2_scores.append([r2_score(val_y[mask][:, i].cpu(), pred[:, i].cpu())
                      for i in range(3)])

# Fill the grid
for row, (pred, model_name) in enumerate(models):
    for col, ax in enumerate(axs[row]):
        # Diagonal components use the log error, the off-diagonal one the absolute error
        _uses_log = col < 2
        error_type = 'absolute_log' if _uses_log else 'absolute'
        norm = norm_log if _uses_log else norm_abs
        plot_predictions_v3(ax, val_y[mask].cpu(), pred, idx=col,
                            error_type=error_type, norm=norm)

        # Annotation boxes: R² in the upper-left corner, component in the lower-right
        _box = dict(facecolor='white', edgecolor='none', alpha=0.8, pad=2)
        ax.text(0.05, 0.95, f"$R^2 = {r2_scores[row][col]:.4f}$",
                transform=ax.transAxes, fontsize=fontsize, verticalalignment='top',
                bbox=dict(_box))
        ax.text(0.95, 0.05, component_labels[col],
                transform=ax.transAxes, fontsize=fontsize,
                horizontalalignment='right', verticalalignment='bottom',
                bbox=dict(_box))

        # Log-log axes for the columns that use the log error
        if _uses_log:
            ax.set_xscale('log')
            ax.set_yscale('log')

        # Show x tick labels on every panel even though the columns share x
        ax.tick_params(labelbottom=True)
        ax.minorticks_on()

        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.set_axisbelow(True)
        ax.set_box_aspect(0.65)

        ax.set_xlabel('')
        ax.set_ylabel('')

# One horizontal colorbar per column, placed on its own axes below the grid.
# Columns 0 and 1 show the log error, column 2 the absolute error.
# (The kappa_22 label previously lacked the closing parenthesis of its first \log.)
_colorbar_labels = [
    r'$|\log(\widehat{\overline{\kappa}}_{11}) - \log(\overline{\kappa}_{11})|$',
    r'$|\log(\widehat{\overline{\kappa}}_{22}) - \log(\overline{\kappa}_{22})|$',
    r'$|\widehat{\overline{\kappa}}_{12} - \overline{\kappa}_{12}|$',
]
for col, error_label in enumerate(_colorbar_labels):
    norm = norm_log if col < 2 else norm_abs
    sm = plt.cm.ScalarMappable(cmap=cm.viridis, norm=norm)
    # [left, bottom, width, height] in figure coordinates; shifted right per column
    cax = fig.add_axes([0.1 + col*0.315, 0.02, 0.23, 0.02])
    cbar = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cbar.ax.tick_params(labelsize=fontsize)
    cbar.set_label(error_label, fontsize=fontsize)

plt.tight_layout()
plt.subplots_adjust(hspace=0.0, wspace=0.25)
plt.show()

# Save the figure with high quality
fig.savefig("../../overleaf/gfx/therm2d_predictions_comparision_v2.png",
            bbox_inches='tight',
            dpi=600,
            metadata={'Creator': '', 'Producer': ''})


In [None]:
# Font size first, then the publication style sheet (same order as before)
fontsize = 8
plt.rcParams['font.size'] = fontsize
plt.style.use('seaborn-v0_8-paper')

def plot_predictions_v2(ax, x, y, pred, phase_contrast, is_last_row=False):
    """
    Plot eigenvalues of predicted 2x2 tensors against volume fraction, with bounds.

    Args:
    - ax: Matplotlib axes to draw on.
    - x: Feature tensor; column 0 is used as volume fraction and the last column
         as phase contrast.
    - y: Ground-truth tensor with (at least) three columns per sample.
    - pred: Prediction tensor, same layout as ``y``.
    - phase_contrast: Only rows whose last feature equals this value are plotted.
    - is_last_row: When True, the ground-truth eigenvalues are drawn as well and
                   all artists receive legend labels; otherwise only predictions
                   are drawn and nothing is labeled.

    For each selected sample the symmetric matrix [[c0, c2], [c2, c1]] is built from
    columns 0, 1, 2 and both of its eigenvalues are plotted at the sample's volume
    fraction. Predicted eigenvalues strictly above the upper curve or strictly below
    the lower curve (no tolerance) are drawn in red as violations.

    NOTE(review): column 2 is used as the off-diagonal entry without any scaling,
    whereas other cells go through ``unpack_sym``; confirm both agree on the packing.
    """
    # Rows of the requested phase contrast
    selected = x[:, -1] == phase_contrast
    volume_fraction = x[selected, 0].cpu().numpy()

    def _eigenvalues(packed):
        # Assemble [[c0, c2], [c2, c1]] per sample. The matrices are symmetric by
        # construction, so the symmetric solver applies and returns real values
        # directly (no imaginary part to discard).
        c0, c1, c2 = packed[:, 0], packed[:, 1], packed[:, 2]
        matrices = torch.stack([torch.stack([c0, c2], dim=-1),
                                torch.stack([c2, c1], dim=-1)], dim=-2)
        return torch.linalg.eigvalsh(matrices).cpu().numpy()

    def _voigt(f):
        # Upper curve as a function of the fraction f
        return f - (f - 1) / phase_contrast

    def _reuss(f):
        # Lower curve as a function of the fraction f
        return 1 / (phase_contrast + f - phase_contrast * f)

    # Two eigenvalues per sample, flattened sample by sample; the volume fraction
    # is repeated accordingly.
    stacked_y = _eigenvalues(y[selected]).reshape(-1)
    stacked_pred = _eigenvalues(pred[selected]).reshape(-1)
    vf_repeated = np.repeat(volume_fraction, 2)

    # Bound curves on a dense grid. Flipping the symmetric grid is the same as
    # evaluating the formulas at 1 - vf, which is what the per-point bounds use.
    vf = np.linspace(0, 1, 1500)
    upper_curve = np.flip(_voigt(vf))
    lower_curve = np.flip(_reuss(vf))

    ax.fill_between(vf, upper_curve, y2=lower_curve, alpha=0.2, color='tab:blue')
    ax.plot(vf, upper_curve, 'k--', linewidth=0.5, alpha=0.3,
            label='Voigt-Reuss bounds' if is_last_row else None)
    ax.plot(vf, lower_curve, 'k--', linewidth=0.5, alpha=0.3)

    # Bounds evaluated at every plotted point
    voigt_bound = _voigt(1 - vf_repeated)
    reuss_bound = _reuss(1 - vf_repeated)

    if is_last_row:
        ax.scatter(vf_repeated, stacked_y, s=1.5, color='royalblue',
                   edgecolor='black', linewidth=0.05, alpha=0.2)
        # Empty, opaque proxy artist so the legend entry is visible
        ax.scatter([], [], s=3.0, color='royalblue', edgecolor='black', linewidth=0.05, alpha=1.0,
                   label=r'ground truth $\lambda({\overline{\underline{\underline{\kappa}}}})$')

    violation_mask = (stacked_pred > voigt_bound) | (stacked_pred < reuss_bound)

    # Predictions inside the bounds
    ax.scatter(vf_repeated[~violation_mask], stacked_pred[~violation_mask], s=1.5, color='limegreen',
               edgecolor='black', linewidth=0.05, alpha=0.2)
    # Empty, opaque proxy artist so the legend entry is visible
    ax.scatter([], [], s=3.0, color='limegreen', edgecolor='black', linewidth=0.05, alpha=1.0,
               label=r'predicted $\lambda(\widehat{\overline{\underline{\underline{\kappa}}}})$' if is_last_row else None)

    # Predictions outside the bounds
    ax.scatter(vf_repeated[violation_mask], stacked_pred[violation_mask], s=3, color='red',
               edgecolor='black', linewidth=0.05, alpha=1.0,
               label=r'predicted $\lambda(\widehat{\overline{\underline{\underline{\kappa}}}})$ violation' if is_last_row else None)

    
    

# 3 rows (VRNN, vanilla, Hill predictions) x 4 columns (phase contrasts)
fig, ax = plt.subplots(3, 4, figsize=(6.3, 5.1), dpi=300)

# Columns and rows of the grid
contrasts = [1./75, 1./30, 30, 75]
datasets = [val_pred, val_pred_vanilla, val_hill]

for row, pred_data in enumerate(datasets):
    _is_top = row == 0
    _is_bottom = row == 2
    for col, contrast in enumerate(contrasts):
        _panel = ax[row, col]
        # Only the bottom row draws the ground truth and carries legend labels
        plot_predictions_v2(_panel, val_x, val_y, pred_data, contrast, is_last_row=_is_bottom)

        if _is_top:
            # First two columns are titled as fractions, the last two as plain numbers
            if col < 2:
                _panel.set_title(fr'$R = 1/{int(1/contrast)}$', fontsize=fontsize)
            else:
                _panel.set_title(fr'$R = {int(contrast)}$', fontsize=fontsize)
        if _is_bottom:
            _panel.set_xlabel(fr'volume fraction $[-]$', fontsize=fontsize)

        _panel.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        _panel.set_axisbelow(True)
        _panel.set_box_aspect(1.0)
        _panel.set_xlim(-0.1, 1.1)
        _panel.minorticks_on()


# Single figure-level legend, built from the labeled artists of the bottom-left panel
handles, labels = ax[2, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.05),
           ncol=4, fontsize=fontsize)

# Layout: tighten first, then pull the rows together and leave room for the legend
plt.tight_layout()
plt.subplots_adjust(hspace = -0.2, bottom=0.025)

# Write the figure to disk before showing it
fig.savefig("../../overleaf/gfx/therm2d_predictions_VRbounds.png",
            bbox_inches='tight',
            dpi=600,
            metadata={'Creator': '', 'Producer': ''})
plt.show()


In [None]:
training_history_file = data_dir / 'Thermal2D_models/vrnn_therm2D_training_history_20250121_173200.out'

import matplotlib.pyplot as plt
import re
from matplotlib.ticker import LogLocator, NullFormatter

# One entry per parsed epoch line
epochs, training_loss, validation_loss, learning_rates = [], [], [], []

# Expected line format:
#   "Epoch <int>: training loss <float>, validation loss <float>, learning rate <float>"
# The loss groups accept an optional exponent. With the plain `[\d.]+` used before,
# a loss printed in scientific notation (e.g. 9.5e-03) made the whole line fail to
# match, and the epoch was silently dropped.
_history_line = re.compile(
    r'Epoch (\d+): '
    r'training loss ([\d.]+(?:[eE][+-]?\d+)?), '
    r'validation loss ([\d.]+(?:[eE][+-]?\d+)?), '
    r'learning rate ([\d.e+-]+)'
)

# Read the file; lines that do not start with the expected pattern are ignored
with open(training_history_file, 'r') as file:
    for line in file:
        match = _history_line.match(line)
        if match:
            epochs.append(int(match.group(1)))
            training_loss.append(float(match.group(2)))
            validation_loss.append(float(match.group(3)))
            learning_rates.append(float(match.group(4)))

# Hacky... scale training loss and validation loss by sqrt(2)
# NOTE(review): the reason for this factor is not visible in this notebook - confirm.
training_loss = np.array(training_loss) / np.sqrt(2)
validation_loss = np.array(validation_loss) / np.sqrt(2)

# Increase font size for readability
fontsize = 8
linewidth = 0.5

# Create a single figure with one axes
fig, ax1 = plt.subplots(figsize=(6.3/2, 2.5), dpi=600)

# Plot training and validation losses on the left y-axis with markers
l1, = ax1.plot(epochs, training_loss, '-', label='Training Loss', linewidth=linewidth)# , color='blue')
l2, = ax1.plot(epochs, validation_loss, '-', label='Validation Loss', linewidth=linewidth)# , color='red')
ax1.set_xlabel('Epoch $[-]$', fontsize=fontsize)
ax1.set_ylabel('Voigt-Reuss Normalized Loss $[-]$', fontsize=fontsize)
ax1.set_yscale('log')
ax1.set_ylim(7e-3, 1)
ax1.grid(which="both", ls="--", alpha=0.5, linewidth=0.3)
ax1.tick_params(axis='both', which='major', labelsize=fontsize)
ax1.legend(loc='lower left', fontsize=fontsize)

# Add minor ticks to the left axis for a smoother log scale
ax1.yaxis.set_minor_locator(LogLocator(subs='auto'))
ax1.yaxis.set_minor_formatter(NullFormatter())

# Create a twin axes to plot the learning rate decay on the right y-axis
ax2 = ax1.twinx()
l3, = ax2.plot(epochs, learning_rates, '-', color="grey", label='Learning Rate', linewidth=linewidth)
ax2.set_ylabel('Learning Rate $[-]$', fontsize=fontsize)
ax2.set_yscale('log')
ax2.set_ylim(1e-5, 2e-1)
ax2.tick_params(axis='both', which='major', labelsize=fontsize)
ax2.legend(loc='upper right', fontsize=fontsize)

ax2.set_xlim(-10, epochs[-1])
# Apply consistent style and layout
plt.style.use('seaborn-v0_8-paper')
plt.rcParams['axes.edgecolor'] = 'black'
plt.rcParams['axes.linewidth'] = 1
plt.tight_layout()

plt.show()

# Save the figure with high quality
fig.savefig("../../overleaf/gfx/therm2d_vrnn_training_history.pdf", 
            bbox_inches='tight', 
            dpi=600,
            metadata={'Creator': '', 'Producer': ''})

In [None]:
# Plot the training history of the vanilla model (losses and learning rate per
# epoch) parsed from the text log written during training.
# NOTE: relies on `re`, `LogLocator` and `NullFormatter` imported in the
# preceding training-history cell.
training_history_file = data_dir / 'Thermal2D_models/vann_therm2D_training_history_20250121_202555.out'

# Apply the style BEFORE the figure is created. Most rcParams (e.g. spine
# colour/width, tick sizes) are read when the artists are created, so setting
# them after plotting would not affect this figure and would only leak into
# figures created later, making the result depend on how often the cell ran.
plt.style.use('seaborn-v0_8-paper')
plt.rcParams['axes.edgecolor'] = 'black'
plt.rcParams['axes.linewidth'] = 1

# Expected log line:
#   "Epoch <int>: training loss <float>, validation loss <float>, learning rate <float>"
# All three floats may be written in plain or scientific notation; a loss such
# as "1.2e-05" must not cause the whole line to be skipped.
history_pattern = re.compile(
    r'Epoch (\d+): training loss ([\d.eE+-]+), validation loss ([\d.eE+-]+), learning rate ([\d.eE+-]+)')

# Initialize lists to store the values
epochs, training_loss, validation_loss, learning_rates = [], [], [], []

# Read the file; lines that do not match the pattern are ignored
with open(training_history_file, 'r') as file:
    for line in file:
        match = history_pattern.match(line)
        if match:
            epochs.append(int(match.group(1)))
            training_loss.append(float(match.group(2)))
            validation_loss.append(float(match.group(3)))
            learning_rates.append(float(match.group(4)))

# Fail early with a clear message instead of an IndexError on epochs[-1] below
if not epochs:
    raise ValueError(f"No training history entries found in {training_history_file}")

# Hacky... training and validation loss are divided by sqrt(2) before plotting
# (ad-hoc rescaling kept from the original analysis; the reason is not
# documented in this cell).
training_loss = np.array(training_loss) / np.sqrt(2)
validation_loss = np.array(validation_loss) / np.sqrt(2)

# Font size and line width used throughout the figure
fontsize = 8
linewidth = 0.5

# Create a single figure with one axes
fig, ax1 = plt.subplots(figsize=(6.3/2, 2.5), dpi=600)

# Plot training and validation losses on the left y-axis as solid lines
l1, = ax1.plot(epochs, training_loss, '-', label='Training Loss', linewidth=linewidth)
l2, = ax1.plot(epochs, validation_loss, '-', label='Validation Loss', linewidth=linewidth)
ax1.set_xlabel('Epoch $[-]$', fontsize=fontsize)
ax1.set_ylabel('MSE Loss', fontsize=fontsize)
ax1.set_yscale('log')
ax1.set_ylim(2.5e-2, 3e3)
ax1.grid(which="both", ls="--", alpha=0.5, linewidth=0.3)
ax1.tick_params(axis='both', which='major', labelsize=fontsize)
ax1.legend(loc='lower left', fontsize=fontsize)

# Add unlabelled minor ticks to the left axis for a smoother log scale
ax1.yaxis.set_minor_locator(LogLocator(subs='auto'))
ax1.yaxis.set_minor_formatter(NullFormatter())

# Create a twin axes to plot the learning rate decay on the right y-axis
ax2 = ax1.twinx()
l3, = ax2.plot(epochs, learning_rates, '-', color="grey", label='Learning Rate', linewidth=linewidth)
ax2.set_ylabel('Learning Rate $[-]$', fontsize=fontsize)
ax2.set_yscale('log')
ax2.set_ylim(1e-5, 2e-1)
ax2.tick_params(axis='both', which='major', labelsize=fontsize)
ax2.legend(loc='upper right', fontsize=fontsize)

# Small negative left margin so the first epoch is not drawn on the axis frame
ax2.set_xlim(-10, epochs[-1])
plt.tight_layout()

plt.show()

# Save the figure with high quality
fig.savefig("../../overleaf/gfx/therm2d_vann_training_history.pdf",
            bbox_inches='tight',
            dpi=600,
            metadata={'Creator': '', 'Producer': ''})