In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from vrnn import utils
from vrnn.models import VanillaModule
from vrnn.normalization import NormalizedDataset, NormalizationModule
from vrnn.data_thermal import DatasetThermal, VoigtReussThermNormalization
import numpy as np
from vrnn.tensortools import unpack_sym


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
dtype = torch.float32

# Seed torch so that the random restarts of the inverse-design optimisations
# (which start from torch.rand samples) are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

data_dir = utils.get_data_dir()

import shutil
import matplotlib
# Use LaTeX text rendering only when a `latex` executable is on the PATH.
plt.rcParams["text.usetex"] = shutil.which('latex') is not None
# The figure labels in this notebook use \text{...} (amsmath) and \Lsh (amssymb),
# so load both packages.
matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}\usepackage{amssymb}"

In [None]:
import os

# VRNN model: applied to normalised inputs (x̄ -> ȳ), wrapped below as model_norm / model
model_norm_file = data_dir / 'Thermal2D_models/alpha/vrnn_therm2D_norm_20250121_173200.pt'
# Vanilla model: applied directly to raw inputs, wrapped below as model_vanilla
model_vanilla_file = data_dir / 'Thermal2D_models/alpha/vann_therm2D_20250121_202555.pt'

# HDF5 file holding the feature vectors and microstructure images of all splits.
# The default is a machine-specific absolute path; set the THERMAL_MS_FILE
# environment variable to use the notebook on another machine.
ms_file = os.environ.get('THERMAL_MS_FILE', '/media/ssd/keshav/feature_engineering_thermal_2D.h5')
print(ms_file)

# Load data
feature_idx = None  # presumably None keeps every descriptor feature — see DatasetThermal
R_range_train = [1/100., 1/50., 1/20., 1/10., 1/5., 1/2., 2, 5, 10, 20, 50, 100]
train_data = DatasetThermal(file_name=ms_file, R_range=R_range_train, group='train_set',
                            input_mode='descriptors', feature_idx=feature_idx, feature_key='feature_vector', ndim=2)

# Validation uses a denser set of phase contrasts: R = 2..100 and R = 1/1..1/100.
R_range_val = np.concatenate([np.arange(2, 101, dtype=int), 1. / np.arange(1, 101, dtype=int)])
val_data = DatasetThermal(file_name=ms_file, R_range=R_range_val, group='val_set',
                          input_mode='descriptors', feature_idx=feature_idx, feature_key='feature_vector', ndim=2)

# Create dataloaders (no shuffling, so sample indices stay stable)
batch_size = 30000
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Fetch the full splits as tensors on `device`
train_x, train_y = utils.get_data(train_loader, device=device, dtype=dtype)
val_x, val_y = utils.get_data(val_loader, device=device, dtype=dtype)


# Define normalization. The feature bounds are taken over train AND val; this
# presumably mirrors the normalisation the saved models were trained with —
# confirm before changing it, otherwise the models would see shifted inputs.
all_x = torch.cat([train_x, val_x], dim=0)
features_max = all_x.max(dim=0).values
features_min = all_x.min(dim=0).values
del all_x
features_min[0], features_max[0] = 0, 1  # Don't normalize the first feature (volume fraction)
normalization = VoigtReussThermNormalization(dim=2, features_min=features_min, features_max=features_max)

# Normalize data
train_data_norm = NormalizedDataset(train_data, normalization)
val_data_norm = NormalizedDataset(val_data, normalization)
train_loader_norm = DataLoader(train_data_norm, batch_size=batch_size, shuffle=False)
val_loader_norm = DataLoader(val_data_norm, batch_size=batch_size, shuffle=False)
train_x_norm, train_y_norm = utils.get_data(train_loader_norm, device=device, dtype=dtype)
val_x_norm, val_y_norm = utils.get_data(val_loader_norm, device=device, dtype=dtype)

# VRNN normalized model
# SECURITY: weights_only=False unpickles arbitrary Python objects; only load
# checkpoint files from a trusted source.
ann_model = torch.load(model_norm_file, map_location=device, weights_only=False).to(device=device, dtype=dtype)
model_norm = VanillaModule(ann_model).to(device=device, dtype=dtype)
model_norm.eval()

with torch.inference_mode():
    train_pred_norm = model_norm(train_x_norm)
    val_pred_norm = model_norm(val_x_norm)

# VRNN model operating on raw inputs/outputs (normalisation wrapped around model_norm)
model = NormalizationModule(normalized_module=model_norm, normalization=normalization).to(device=device, dtype=dtype)
model.eval()

with torch.inference_mode():
    train_pred = model(train_x)
    val_pred = model(val_x)

# Vanilla model (same trusted-file caveat for weights_only=False)
ann_model = torch.load(model_vanilla_file, map_location=device, weights_only=False).to(device=device, dtype=dtype)
model_vanilla = VanillaModule(ann_model).to(device=device, dtype=dtype)
model_vanilla.eval()

with torch.inference_mode():
    train_pred_vanilla = model_vanilla(train_x)
    val_pred_vanilla = model_vanilla(val_x)


In [None]:
import torch, torch.nn.functional as F

# ─────────────────────────────────────────────────────────────
# 1.  normalisation wrappers 
# ─────────────────────────────────────────────────────────────
def encode_x(x_raw):
    """
    Normalise a single raw feature vector: raw (53,) -> x̄ (53,).

    The sample is given a leading batch axis for ``normalization.normalize_x``;
    that axis is dropped from the result.
    """
    return normalization.normalize_x(x_raw[None])[0]

def decode_x(x_bar):
    """
    Inverse of ``encode_x`` for one sample: x̄ (53,) -> raw (53,) in physical units.

    Uses ``normalization.reconstruct_x`` on a batch of one and strips the batch axis.
    """
    return normalization.reconstruct_x(x_bar[None])[0]

def decode_y(x_raw, y_norm):
    """
    Map one normalised prediction back to physical units: (x_raw, ȳ) -> y_raw.

    ``normalization.reconstruct`` needs the raw features alongside ȳ. No
    gradient is blocked here, so the result stays differentiable.
    """
    return normalization.reconstruct(x_raw[None], y_norm[None])[0]

# ─────────────────────────────────────────────────────────────
# 2.  forward pass
# ─────────────────────────────────────────────────────────────
def forward_raw(x_raw, use_norm=True):
    """
    Predict κ_raw (κ11, κ22, κ12) for one raw feature vector x_raw (53,).

    use_norm=True evaluates ``model`` (the VRNN wrapped with the normalisation);
    use_norm=False evaluates ``model_vanilla``. The input is moved to ``device``
    and run as a batch of one. Gradients are not blocked.
    """
    net = model if use_norm else model_vanilla
    batch = x_raw.to(device)[None]
    return net(batch)[0]

def raw_forward_from_xbar(x_bar):
    """
    Evaluate the normalised VRNN on x̄ and return the prediction in both spaces.

    x_bar (53,) --model_norm--> ȳ, and decode_y(decode_x(x̄), ȳ) --> κ_raw
    (κ11, κ22, κ12) in physical units. Nothing is wrapped in no_grad, so
    gradients flow back to x_bar.

    Returns:
        (κ_raw, ȳ)
    """
    y_bar = model_norm(x_bar.unsqueeze(0)).squeeze(0)
    x_phys = decode_x(x_bar)
    return decode_y(x_phys, y_bar), y_bar

# ─────────────────────────────────────────────────────────────
# 3.  utility metrics
# ─────────────────────────────────────────────────────────────
def anisotropy(k):
    """
    Eigenvalue ratio λ_max / λ_min of packed 2D symmetric κ.

    ``k`` is one packed tensor (1-D) or a batch of them (2-D). It is unpacked
    with ``unpack_sym(..., dim=2)``, and the eigenvalues come from
    ``torch.linalg.eigvalsh`` (ascending). λ_min is clamped from below to 1e-8
    before dividing. A non-positive λ_min therefore yields a very large ratio
    instead of a division by zero or a negative value.

    Returns:
        A 0-d tensor for 1-D input, otherwise a 1-D tensor with one ratio per row.
    """
    eps = 1e-8
    single_sample = k.dim() == 1
    k_batch = k.unsqueeze(0) if single_sample else k
    lam = torch.linalg.eigvalsh(unpack_sym(k_batch, dim=2))  # (B, 2), ascending
    lam_min = lam[:, 0].clamp(min=eps)
    lam_max = lam[:, 1]
    ratio = lam_max / lam_min
    if single_sample:
        return ratio[0]
    return ratio

def kbar(k):
    """Arithmetic mean of the first two entries of k (κ11 and κ22)."""
    k11, k22 = k[0], k[1]
    return (k11 + k22) / 2

def random_raw_sample():
    """
    Draw x̄ ~ U[0, 1)^53 and map it to raw feature space.

    Returns:
        (x_raw, x_bar), both on ``device`` with dtype ``dtype``.
    """
    x_bar = torch.rand(53, device=device, dtype=dtype)
    return decode_x(x_bar), x_bar



In [None]:
# ─────────────────────────────────────────────────────────────
# A.  maximise anisotropy at fixed R 
# ─────────────────────────────────────────────────────────────
def design_most_anisotropic(R,
                            restarts=2,
                            lr=1e-1,
                            steps=200,
                            use_norm=True,
                            verbose=True):
    """
    Maximise the predicted anisotropy λ_max/λ_min at a fixed phase contrast R.

    Runs Adam on -anisotropy directly over the raw feature vector. Each of the
    ``restarts`` runs starts from a random sample. The last two features are
    fixed to (1/R, R) and excluded from the update. The volume fraction x[0]
    is clamped to [0, 1] after every step.

    Args:
        R: phase contrast written into features (-2, -1) as (1/R, R).
        restarts: number of independent random starts.
        lr: initial Adam learning rate; it is halved every max(1, steps // 10) steps.
        steps: optimisation steps per restart.
        use_norm: evaluate ``model`` (VRNN) if True, else ``model_vanilla``.
        verbose: print one progress line per step (default True, as before).

    Returns:
        best_x: raw feature vector of the best restart (its final iterate).
        best_k: κ_raw predicted for exactly ``best_x``.
        best_val: anisotropy of ``best_k``.
        log_data_A: one dict per restart, {"restart": int, "steps": [...]}.
            Each step entry holds loss, anisotropy, lr, x_raw and k_raw
            evaluated *before* that step's update.
    """
    best_val, best_x, best_k = -float('inf'), None, None
    log_data_A = []
    # StepLR with step_size=0 raises ZeroDivisionError on its second step.
    lr_step_size = max(1, steps // 10)

    for restart in range(restarts):
        # Initialize logging for this restart.
        log_data_A.append({"restart": restart + 1, "steps": []})

        x_raw, _ = random_raw_sample()  # random start inside the feature bounding box
        x_raw[-2:] = torch.tensor([1.0 / R, R], device=device, dtype=dtype)  # enforce (1/R, R)
        x_raw.requires_grad_(True)
        opt = torch.optim.Adam([x_raw], lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=lr_step_size, gamma=0.5)

        for step in range(steps):
            opt.zero_grad()
            k_raw = forward_raw(x_raw, use_norm=use_norm)
            loss = -anisotropy(k_raw)
            current_lr = opt.param_groups[0]['lr']

            log_data_A[restart]["steps"].append({
                "restart": restart + 1,
                "step": step + 1,
                "loss": loss.item(),
                "anisotropy": -loss.item(),  # loss is defined as -anisotropy
                "lr": current_lr,
                "x_raw": x_raw.detach().cpu().numpy(),
                "k_raw": k_raw.detach().cpu().numpy(),
            })
            if verbose:
                print(f"Restart {restart + 1}/{restarts} - Step {step + 1}/{steps}: "
                      f"anisotropy = {-loss.item():.4f}, lr = {current_lr:.6f}")

            loss.backward()
            x_raw.grad[-2:] = 0  # keep the last two features (1/R, R) fixed
            opt.step()
            scheduler.step()

            # Clamp the volume fraction to [0, 1] outside the autograd graph.
            with torch.no_grad():
                x_raw[0].clamp_(0.0, 1.0)

        # `k_raw`/`loss` above belong to the iterate *before* the last update,
        # so re-evaluate the final iterate to return a consistent (x, κ, value).
        with torch.no_grad():
            k_final = forward_raw(x_raw, use_norm=use_norm)
            val = anisotropy(k_final).item()
        if val > best_val:
            best_val, best_x, best_k = val, x_raw.detach().clone(), k_final.detach().clone()

    return (best_x,      # raw feature vector
            best_k,      # κ_raw of best_x
            best_val,    # anisotropy of best_k
            log_data_A)  # per-restart optimisation log

# contrastR_A = 1/100
# bx, bk, a, log_data_A  = design_most_anisotropic(R=contrastR_A, restarts=2, lr=1e-1, steps=200, use_norm=True)
# print("anisotropy =", a, "\nκ =", bk.cpu().numpy())
# print("x =", bx.cpu().numpy())

# best_dataset, best_sample, best_counterpart, best_idx = find_closest_sample(bk, [train_data, val_data], compare='targets')
# print(f"Closest sample (idx {best_idx}) in dataset {best_dataset}:")
# print(f"x: {best_sample.cpu().numpy()}")
# print(f"y: {best_counterpart.cpu().numpy()}")

In [None]:
import matplotlib.pyplot as plt

# Global settings
fontsize = 8  # base font size (pt) for the figures below
plt.rcParams['font.size'] = fontsize
plt.style.use('seaborn-v0_8-paper')  # publication-style theme

def _sym2d_eigvals(k):
    """Ascending eigenvalues of the 2x2 symmetric matrices [[k0, k2], [k2, k1]]
    built from the rows of k, flattened row-wise into a NumPy array (length 2*len(k))."""
    mats = torch.stack([torch.stack([k[:, 0], k[:, 2]], dim=-1),
                        torch.stack([k[:, 2], k[:, 1]], dim=-1)], dim=-2)
    # The matrices are symmetric, so eigvalsh returns real eigenvalues directly.
    return torch.linalg.eigvalsh(mats).reshape(-1).cpu().numpy()


def plot_predictions_v2(ax, x, y, pred, phase_contrast, is_last_row=False, show_predictions=False):
    """
    Plot eigenvalues of κ for one phase contrast inside the Voigt–Reuss bounds.

    Args:
        ax: matplotlib Axes to draw on.
        x: (N, F) features. x[:, 0] is the volume fraction (horizontal axis);
            x[:, -1] selects the phase contrast.
        y: (N, 3) reference κ laid out as (κ11, κ22, κ12).
        pred: (N, 3) predicted κ with the same layout. Only drawn when
            ``show_predictions`` is True.
        phase_contrast: value of x[:, -1] to select. It is matched with
            torch.isclose, so float round-off does not drop samples.
        is_last_row: attach the bound (and prediction) legend labels.
        show_predictions: also scatter the eigenvalues of ``pred``; those
            outside the bounds are highlighted in red.
    """
    target_R = torch.as_tensor(phase_contrast, dtype=x.dtype, device=x.device)
    mask = torch.isclose(x[:, -1], target_R)

    vf_samples = x[mask, 0].cpu().numpy()
    vf_points = np.repeat(vf_samples, 2)  # one abscissa per eigenvalue
    y_eig = _sym2d_eigvals(y[mask])

    # Bound curves on a fine grid. On this symmetric grid np.flip turns f(vf)
    # into f(1 - vf), which matches the per-sample bounds below.
    vf = np.linspace(0, 1, 1500)
    voigt_curve = np.flip(vf - (vf - 1) / phase_contrast)
    reuss_curve = np.flip(1 / (phase_contrast + vf - phase_contrast * vf))
    ax.fill_between(vf, voigt_curve, y2=reuss_curve, alpha=0.2, color='tab:blue')
    ax.plot(vf, voigt_curve, 'k--', linewidth=0.5, alpha=0.3, label='Voigt-Reuss bounds' if is_last_row else None)
    ax.plot(vf, reuss_curve, 'k--', linewidth=0.5, alpha=0.3)

    ax.scatter(vf_points, y_eig, s=1.5, color='limegreen',
               edgecolor='black', linewidth=0.05, alpha=0.2)
    # Empty, fully opaque handle so the legend marker stays readable.
    ax.scatter([], [], s=3.0, color='limegreen', edgecolor='black', linewidth=0.05, alpha=1.0,
               label=r'training data $\lambda({\overline{\underline{\underline{\kappa}}}})$')

    if not show_predictions:
        return

    pred_eig = _sym2d_eigvals(pred[mask])
    # Bounds at each sample's volume fraction (same 1 - x[0] mapping as the curves).
    c = 1 - vf_points
    voigt_pts = c - (c - 1) / phase_contrast
    reuss_pts = 1 / (phase_contrast + c - phase_contrast * c)
    rel_tol = 1e-5  # do not flag float round-off at the bounds as a violation
    violation = (pred_eig > voigt_pts * (1 + rel_tol)) | (pred_eig < reuss_pts * (1 - rel_tol))

    ax.scatter(vf_points[~violation], pred_eig[~violation], s=1.5, color='tab:orange',
               edgecolor='black', linewidth=0.05, alpha=0.2)
    ax.scatter([], [], s=3.0, color='tab:orange', edgecolor='black', linewidth=0.05, alpha=1.0,
               label=r'predicted $\lambda({\overline{\underline{\underline{\kappa}}}})$' if is_last_row else None)
    ax.scatter(vf_points[violation], pred_eig[violation], s=3, color='red',
               edgecolor='black', linewidth=0.05, alpha=1.0,
               label=r'predicted $\lambda({\overline{\underline{\underline{\kappa}}}})$ violation' if is_last_row else None)

def plot_optimization_trajectories(ax, log_data, bx, bk):
    """
    Overlay inverse-design trajectories and the returned optimum on ``ax``.

    For every restart in ``log_data`` (the log returned by
    ``design_most_anisotropic``), both eigenvalues of the logged κ are plotted
    against the logged volume fraction x[0]: λ_1 in red, λ_2 in blue. The
    optimum (bx, bk) is then marked once with two grey diamonds, one per
    eigenvalue. Finally the axes are formatted (log y-scale, grid, limits).

    Args:
        ax: matplotlib Axes to draw on.
        log_data: list of {"restart": int, "steps": [entry, ...]}; each entry
            holds NumPy arrays "x_raw" and "k_raw".
        bx: optimised raw feature vector; bx[0] is the volume fraction.
        bk: κ belonging to ``bx``, in the packed layout accepted by
            ``unpack_sym``. Nothing is marked if ``bx`` or ``bk`` is None.
    """
    line_style = dict(marker='o', linestyle='--', linewidth=0.5, markersize=3)
    for restart in log_data:
        entries = restart["steps"]
        if not entries:  # nothing logged for this restart
            continue
        is_first = restart["restart"] == 1  # label the trajectories only once
        vf = np.array([entry["x_raw"][0] for entry in entries])
        # np.stack + from_numpy avoids torch's slow list-of-ndarrays path.
        k_hist = torch.from_numpy(np.stack([entry["k_raw"] for entry in entries]))
        k_eig = torch.linalg.eigvalsh(unpack_sym(k_hist, dim=2)).cpu().numpy()  # ascending
        ax.plot(vf, k_eig[:, 0], color='red', **line_style,
                label=r"optimization trajectory $\widehat{\lambda}_1$" if is_first else "")
        ax.plot(vf, k_eig[:, 1], color='blue', **line_style,
                label=r"optimization trajectory $\widehat{\lambda}_2$" if is_first else "")

    # Mark the optimum once (it does not depend on the restart).
    if bx is not None and bk is not None:
        bk_mat = unpack_sym(torch.as_tensor(bk, device=device, dtype=dtype), dim=2)
        bk_eig = torch.linalg.eigvalsh(bk_mat)
        best_style = dict(s=30, marker='D', color='lightgrey', edgecolor='black',
                          linewidth=1.5, alpha=1.0, zorder=10)
        vf_best = bx[0].item()
        ax.scatter(vf_best, bk_eig[0].item(), **best_style,
                   label=r"optimized candidate $\lambda(\widehat{\overline{\underline{\underline{\kappa}}}})$")
        ax.scatter(vf_best, bk_eig[1].item(), **best_style)

    ax.set_xlabel(r'volume fraction $[-]$', fontsize=fontsize)
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    ax.set_axisbelow(True)
    ax.set_box_aspect(1.0)
    ax.set_xlim(-0.1, 1.1)
    ax.minorticks_on()  # Add minor ticks
    ax.set_yscale('log')
    
    

from pathlib import Path

# Design A: maximise anisotropy at fixed R with both surrogates.
contrastR_A = 1/100
bx_vrnn, bk_vrnn, _, log_data_A_vrnn = design_most_anisotropic(R=contrastR_A, restarts=4, lr=1e-1, steps=200, use_norm=True)
bx_vann, bk_vann, _, log_data_A_vann = design_most_anisotropic(R=contrastR_A, restarts=4, lr=1e-1, steps=200, use_norm=False)

# Title suffix; round() guards against int() truncating e.g. 49.999999 to 49.
if contrastR_A < 1:
    R_label = f"R = 1/{round(1 / contrastR_A)}"
else:
    R_label = f"R = {contrastR_A}"

fig, ax = plt.subplots(1, 2, figsize=(6.3, 3.5), dpi=300)
panels_A = (
    ("Voigt-Reuss net -", train_pred, log_data_A_vrnn, bx_vrnn, bk_vrnn),
    ("Vanilla NN -", train_pred_vanilla, log_data_A_vann, bx_vann, bk_vann),
)
for panel_ax, (panel_name, panel_pred, panel_log, panel_bx, panel_bk) in zip(ax, panels_A):
    plot_predictions_v2(panel_ax, train_x, train_y, panel_pred, contrastR_A, is_last_row=False)
    plot_optimization_trajectories(panel_ax, panel_log, panel_bx, panel_bk)
    # Both panels get a title for every R (previously only ax[0] when R >= 1).
    panel_ax.set_title(f"$\\text{{{panel_name}}}\\, {R_label}$")

handles, labels = ax[1].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.05),
           ncol=4, fontsize=fontsize)
plt.tight_layout()
plt.subplots_adjust(bottom=0.15)  # Make room for legend

# Save figure (create the output directory so a missing folder does not
# crash the cell after the optimisation has already run).
FIG_DIR = Path("../../overleaf/gfx")
FIG_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(FIG_DIR / "therm2d_inverse_design_ex_A.png",
            bbox_inches='tight',
            dpi=600,
            metadata={'Creator': '', 'Producer': ''})
plt.show()


In [None]:
def find_closest_sample(x, datasets, compare='features', vf_weight=10.0):
    """
    Find the sample closest to ``x`` across several datasets.

    Args:
        x (torch.Tensor): reference sample; it is moved to ``device``.
        datasets (list): dataset objects exposing the attribute named by
            ``compare`` and its counterpart (features <-> targets,
            normalized_x <-> normalized_y) as indexable tensors.
        compare (str): 'features', 'targets', 'normalized_x' or 'normalized_y'.
            - 'targets': distance is |anisotropy(candidate) - anisotropy(x)|.
            - 'features' / 'normalized_x': Euclidean distance plus
              ``vf_weight`` * (gap in feature 0, the volume fraction)^2.
            - 'normalized_y': plain Euclidean distance.
        vf_weight (float): weight of the extra volume-fraction penalty
            (default 10.0, the previously hard-coded value).

    Returns:
        best_dataset: dataset containing the closest sample.
        best_sample: that sample's value of the ``compare`` attribute.
        best_counterpart: the same sample's value of the counterpart attribute.
        best_idx (int): index of the sample within its dataset.
        All four are None when no dataset contains any sample.

    Raises:
        ValueError: if ``compare`` is not one of the supported names.
    """
    counterpart_of = {
        'features': 'targets',
        'targets': 'features',
        'normalized_x': 'normalized_y',
        'normalized_y': 'normalized_x',
    }
    if compare not in counterpart_of:
        raise ValueError(f"compare must be one of {sorted(counterpart_of)}, got {compare!r}")
    other_attr = counterpart_of[compare]

    x = x.to(device)
    # Loop-invariant reference anisotropy for the 'targets' comparison.
    x_anisotropy = anisotropy(x.unsqueeze(0)) if compare == 'targets' else None

    best_distance = float('inf')
    best_dataset = best_sample = best_counterpart = best_idx = None

    for dataset in datasets:
        comp_attr = getattr(dataset, compare).to(device)
        if compare == 'targets':
            # Targets are compared by their anisotropy only (a scalar).
            distances = torch.abs(anisotropy(comp_attr) - x_anisotropy)
        elif compare in ('features', 'normalized_x'):
            # Extra penalty so candidates with a matching volume fraction win.
            distances = torch.norm(comp_attr - x, dim=1) + vf_weight * (comp_attr[:, 0] - x[0]) ** 2
        else:  # 'normalized_y'
            distances = torch.norm(comp_attr - x, dim=1)

        if distances.numel() == 0:  # empty dataset: nothing to compare against
            continue
        idx = torch.argmin(distances)
        dist = distances[idx].item()

        if dist < best_distance:
            best_distance = dist
            best_dataset = dataset
            best_sample = comp_attr[idx]
            best_counterpart = getattr(dataset, other_attr)[idx]
            best_idx = idx.item()

    return best_dataset, best_sample, best_counterpart, best_idx


from torch.utils.data import Dataset
import h5py

class MicrostructureImageDataset(Dataset):
    """
    Map-style dataset over one HDF5 dataset of microstructure images.

    The HDF5 file is opened read-only in ``__init__`` and stays open until
    ``close()`` is called (or the instance is used as a context manager).
    Items are returned as float32 tensors.
    """

    def __init__(self, file_path, dataset_name):
        self.file = h5py.File(file_path, 'r')
        self.data = self.file[dataset_name]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Convert the array read from HDF5 to float32 in a single step.
        return torch.as_tensor(self.data[idx], dtype=torch.float32)

    def close(self):
        """Close the underlying HDF5 file; safe to call more than once."""
        handle = getattr(self, "file", None)
        if handle is not None:
            handle.close()
            self.file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
    
# Load the dataset
MS = MicrostructureImageDataset(file_path=ms_file, dataset_name='train_set/image_data')
MS_bench = MicrostructureImageDataset(file_path=ms_file, dataset_name='benchmark_set/image_data')

import matplotlib.pyplot as plt
import math
import numpy as np
from pathlib import Path

# Global settings
fontsize = 8
plt.rcParams.update({'font.size': fontsize})
plt.style.use('seaborn-v0_8-paper')  # Clean, publication-style theme
# Keep amsmath (\text{...} in other figure titles) alongside amssymb (\Lsh below);
# setting only amssymb breaks those figures when their cells are re-run.
plt.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}\usepackage{amssymb}"

N_TILES = 6
FIGSIZE = (6.5, 3.5)
STEP_INDICES = [0, 1, 2, 3, 4, 49]  # 0-based optimisation steps to visualise

log = log_data_A_vrnn[1]["steps"]

images, titles = [], []
c1_values = []
unique_idxs = set()
for i in STEP_INDICES:
    entry = log[i]

    # Closest training sample in normalised feature space.
    x_raw = torch.tensor(entry["x_raw"], device=device, dtype=dtype)
    x_bar = encode_x(x_raw)
    _, best_x_bar, best_k_bar, best_idx = find_closest_sample(x_bar, [train_data_norm], compare='normalized_x')
    best_x = decode_x(best_x_bar)  # raw features of the closest sample
    best_k_raw = decode_y(best_x.to(device), best_k_bar.to(device))  # raw κ of the closest sample
    closest_anisotropy = anisotropy(best_k_raw).item()

    print(best_idx, best_k_raw.detach(), closest_anisotropy, entry['anisotropy'])
    if best_idx is None or best_idx in unique_idxs:
        continue
    unique_idxs.add(best_idx)
    # Presumably the training set repeats the same microstructures once per R
    # value, so wrap the flat index onto the image set (previously % 30000).
    img = MS[best_idx % len(MS)]
    images.append(img.squeeze().cpu().numpy())
    titles.append(
        f"Opt step - ${entry['step']}$\n"
        f"$\\mathcal{{J}}_\\mathrm{{opt}} = {entry['anisotropy']:.2f}$\n"
        f"$\\mathcal{{J}}_\\mathrm{{closest}} = {closest_anisotropy:.2f} \\Lsh$"
    )
    # Volume fraction shown under the image
    c1_values.append(f"$c_1 = {1-best_x[0].item():.2f}$")

# One row; at least one column so plt.subplots does not fail when no image was found.
rows = 1
cols = max(1, math.ceil(len(images) / rows))

fig, axes = plt.subplots(rows, cols, figsize=FIGSIZE, squeeze=False, dpi=300)

for idx, ax in enumerate(axes.ravel()):
    if idx < len(images):
        ax.imshow(images[idx], cmap='plasma')
        ax.set_title(titles[idx], fontsize=8)
        # Add volume fraction text below the image
        ax.text(0.5, -0.05, c1_values[idx], transform=ax.transAxes,
                ha='center', va='top', fontsize=8)
    ax.axis('off')

plt.tight_layout()

# Save before showing so the file matches the displayed figure.
FIG_DIR = Path("../../overleaf/gfx")
FIG_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(FIG_DIR / "therm2d_inverse_design_microstructures.png",
            bbox_inches='tight',
            dpi=600,
            metadata={'Creator': '', 'Producer': ''})
plt.show()


In [None]:
# Many short design-A runs, to see how consistently the optimiser approaches
# the theoretical maximum.
N_RESTARTS_ALL = 100
N_STEPS_ALL = 50
_, _, _, log_data_A_vrnn_all = design_most_anisotropic(R=contrastR_A, restarts=N_RESTARTS_ALL, lr=1e-1,
                                                       steps=N_STEPS_ALL, use_norm=True)

# Plot all optimization trajectories (x: step, y: anisotropy): the best one in
# bold, all the others faintly.
import matplotlib.pyplot as plt

# Global settings
fontsize = 8
plt.rcParams.update({'font.size': fontsize})
plt.style.use('seaborn-v0_8-paper')  # Clean, publication-style theme
# Keep amsmath (\text{...} in other figure titles) alongside amssymb.
plt.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}\usepackage{amssymb}"


fig, ax = plt.subplots(figsize=(3.0, 1.5), dpi=300)

# Identify the best trajectory by its final anisotropy (largest; first wins on ties).
best_final = -float('inf')
best_traj = None
for traj in log_data_A_vrnn_all:
    if traj["steps"] and traj["steps"][-1]["anisotropy"] > best_final:
        best_final = traj["steps"][-1]["anisotropy"]
        best_traj = traj

# Plot each trajectory; best one in bold blue, others in faint gray
for traj in log_data_A_vrnn_all:
    steps = [step["step"] - 1 for step in traj["steps"]]  # 0-based step index
    anisotropies = [step["anisotropy"] for step in traj["steps"]]
    if traj is best_traj:
        ax.plot(steps, anisotropies, '-X', color="tab:blue", linewidth=1, markersize=2.5, label="Best trajectory", zorder=10)
    else:
        ax.plot(steps, anisotropies, color="gray", linewidth=0.25, alpha=0.25)

# Theoretical anisotropy maximum (R + 1)^2 / (4 R); it does not depend on the
# trajectory, so draw it once.
theoretical_max = (contrastR_A + 1)**2 / (4 * contrastR_A)
ax.axhline(theoretical_max, color='red', linestyle='--', linewidth=0.5, alpha=0.5,
           label=r'$\mathcal{J}_\mathrm{max}$')

ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
ax.set_axisbelow(True)
ax.set_xlim(0, N_STEPS_ALL)
ax.set_ylim(0, 1.05 * theoretical_max)
ax.minorticks_on()  # Add minor ticks
ax.set_xlabel("optimization steps $[-]$", fontsize=fontsize)
ax.set_ylabel("Anisotropy $[-]$", fontsize=fontsize)
ax.legend(loc='lower right', fontsize=fontsize)
plt.show()

In [None]:
# ─────────────────────────────────────────────────────────────
# B. design least phase contrast given target κ
# ─────────────────────────────────────────────────────────────

def design_least_phase_contrast_given_k(k_target,
                                        restarts=3,
                                        lr=1e-2,
                                        steps=500,
                                        beta=0.1,
                                        use_norm=True):
    """
    Search for a raw sample whose predicted κ matches ``k_target`` with a small R.

    The optimisation variables are the normalised features x̄ and a scalar R.
    Each step decodes x̄ to raw features, overwrites the last two features with
    (1/R, R), predicts κ and minimises MSE(κ, k_target) + beta * R with Adam.
    After every step x̄[0] is clamped to [0, 1] and R to [1e-6, 1].

    NOTE(review): with R confined to (0, 1], penalising R drives 1/R (feature
    -2) up; confirm this is the intended notion of "least phase contrast".

    Args:
        k_target: target κ (tensor-like; converted to ``device``/``dtype``).
        restarts: number of random starts.
        lr: Adam learning rate.
        steps: optimisation steps per restart.
        beta: weight of the R penalty.
        use_norm: predict with ``model`` (VRNN) if True, else ``model_vanilla``.

    Returns:
        best_x: raw 53-D sample of the best restart (including its 1/R, R).
        best_k: predicted κ for exactly ``best_x``.
        best_R: R of ``best_x`` (0-d tensor).
        best_loss: loss of ``best_x``.
        The first three are None when no restart produced a finite loss.
    """
    k_target = torch.as_tensor(k_target, device=device, dtype=dtype)
    best_loss = float('inf')
    best_x = best_k = best_R_val = None

    def _assemble(x_bar, R_param):
        # Decode x̄ and overwrite the last two features with (1/R, R).
        x_dec = decode_x(x_bar)
        return torch.cat([x_dec[:-2], torch.stack([1.0 / R_param, R_param])])

    for r in range(restarts):
        # Start from a random raw sample, optimised in normalised space.
        x_rand, _ = random_raw_sample()
        x_bar = encode_x(x_rand).detach().clone().requires_grad_(True)
        # R starts at 0.5 inside its admissible range [1e-6, 1].
        R_param = torch.tensor(0.5, device=device, dtype=dtype, requires_grad=True)

        opt = torch.optim.Adam([x_bar, R_param], lr=lr)

        for i in range(steps):
            opt.zero_grad()
            x_mod = _assemble(x_bar, R_param)
            k_pred = forward_raw(x_mod, use_norm=use_norm)
            # Match the target κ and push R down.
            loss_val = F.mse_loss(k_pred, k_target) + beta * R_param
            loss_val.backward()
            opt.step()

            # Project back onto the constraints outside the autograd graph.
            with torch.no_grad():
                x_bar[0].clamp_(0.0, 1.0)
                R_param.clamp_(1e-6, 1.0)  # keeps 1/R finite

            if (i+1) % 100 == 0:
                print(f"Restart {r+1}, Step {i+1}/{steps}: Loss={loss_val.item():.4f}, R={R_param.item():.4f}")

        # The tensors from the last loop iteration predate the final update and
        # clamp; re-evaluate so that (x, κ, R, loss) all describe the same point.
        with torch.no_grad():
            x_final = _assemble(x_bar, R_param)
            k_final = forward_raw(x_final, use_norm=use_norm)
            final_loss = (F.mse_loss(k_final, k_target) + beta * R_param).item()

        if final_loss < best_loss:
            best_loss = final_loss
            best_x = x_final.detach().clone()
            best_k = k_final.detach().clone()
            best_R_val = R_param.detach().clone()

    return best_x, best_k, best_R_val, best_loss

# Example: reuse the κ optimised in design A with the VRNN (bk_vrnn) as the target.
xB_opt, kB_opt, R_opt, final_loss = design_least_phase_contrast_given_k(
    bk_vrnn, restarts=3, lr=1e-2, steps=500, beta=0.1,
)
print("Optimized sample x (raw):", xB_opt.cpu().numpy())
print("Predicted κ:", kB_opt.cpu().numpy())
print("Optimized phase contrast R:", R_opt.item())
print("Final loss:", final_loss)



In [None]:
# ─────────────────────────────────────────────────────────────
# C.  minimise κ̄ at fixed R  (thermal insulator)
# ─────────────────────────────────────────────────────────────
def design_low_kbar_fixed_R(R,
                            restarts=1,
                            lr     = 1e-2,
                            steps  = 800,
                            use_norm=True,
                            vf_bounds=(0.25, 0.75)):
    """
    Minimise the mean diagonal conductivity κ̄ = (κ11 + κ22) / 2 at fixed R.

    Adam runs on the normalised features x̄. The last two features are set to
    (1/R, R) and excluded from the update. After every step x̄[0] is clamped to
    ``vf_bounds``. Among the restarts, the one with the *lowest* final κ̄ is
    returned.

    Args:
        R: phase contrast written into features (-2, -1) as (1/R, R).
        restarts: number of random starts.
        lr: Adam learning rate.
        steps: optimisation steps per restart.
        use_norm: predict with the normalised VRNN (``model_norm``) if True,
            else with ``model_vanilla`` on the decoded raw features.
        vf_bounds: (low, high) clamp for x̄[0] (default (0.25, 0.75), as before).

    Returns:
        best_x: raw 53-D sample (None if no restart produced a finite κ̄).
        best_k: κ_raw predicted for exactly ``best_x``.
        best_val: κ̄ of ``best_k``.
    """
    vf_lo, vf_hi = vf_bounds

    def _predict(x_bar):
        # κ_raw for one normalised sample, keeping the autograd graph.
        if use_norm:
            k_raw, _ = raw_forward_from_xbar(x_bar)
            return k_raw
        return forward_raw(decode_x(x_bar), use_norm=False)

    # Minimisation: start from +inf and keep the smallest κ̄.
    best_val, best_xbar, best_k = float('inf'), None, None

    for _restart in range(restarts):
        x, _ = random_raw_sample()  # random raw sample
        x[-2:] = torch.tensor([1.0/R, R], device=device, dtype=dtype)  # enforce (1/R, R)
        x_bar = encode_x(x).detach().clone().requires_grad_(True)      # optimise in x̄ space
        opt = torch.optim.Adam([x_bar], lr=lr)

        for step in range(steps):
            opt.zero_grad()
            loss = kbar(_predict(x_bar))
            loss.backward()
            x_bar.grad[-2:] = 0  # keep the last two features (1/R, R) fixed
            opt.step()

            print(f"Step {step+1}/{steps}: "
                  f"kbar = {loss.item():.4f}")

            # Clamp the volume fraction outside the autograd graph.
            with torch.no_grad():
                x_bar[0].clamp_(vf_lo, vf_hi)

        # The last `loss` predates the final update/clamp; re-evaluate the
        # final iterate so the returned (x, κ, κ̄) are consistent.
        with torch.no_grad():
            k_final = _predict(x_bar)
            val = kbar(k_final).item()
        if val < best_val:
            best_val, best_xbar, best_k = val, x_bar.detach().clone(), k_final.detach().clone()

    best_x = decode_x(best_xbar) if best_xbar is not None else None

    return (best_x,    # 53-D (raw)
            best_k,    # κ_raw
            best_val)  # κ̄ = (κ11 + κ22) / 2


# C  – minimise mean κ at R = 1/100
contrastR_C = 1. / 100
xC, kC, kbarC = design_low_kbar_fixed_R(R=contrastR_C)
print("k̄ =", kbarC)
print("reached κ =", kC.detach().cpu().numpy())
print("reached x =", xC.detach().cpu().numpy())