# Figures comparing deformation

In [None]:
import pickle
import gc
import os
import sys
import importlib

import numpy as np
import funcs_helpers as fh
import mlflow
import matplotlib.pyplot as plt
import torch

## Load data

In [None]:
# load data: the reference dataset and the same structure on a larger RVE.
# Paths are assembled with os.path.join so they resolve on any OS (the former
# backslash-separated literals only worked on Windows).
data_dir = os.path.join("data", "coarseMesh_noBifurcation_5")
data_path0 = os.path.join(
    data_dir, "graphs_coarseMesh_noBifurcation_diameter_0.9_5_noBulkNodes_2.pkl")
data_path1 = os.path.join(
    data_dir, "graphs_coarseMesh_noBifurcation_diameter_0.9_5_noBulkNodes_largerRVE.pkl")

data = {}
for key, path in [('reference', data_path0), ('extended RVE', data_path1)]:
    # pickle.load can execute arbitrary code: only load trusted local files
    with open(path, 'rb') as f:
        data[key] = pickle.load(f)
    print(data[key][0])
    print(len(data[key]))


In [None]:

# %%
# # only keep validation data
# trajs = [graph.traj.numpy() for graph in data['reference']]
# train_inds, val_inds = fh.split_trajs(trajs, 2, split_sizes=(0.7, 0.3))

# for key in ['reference', 'extended RVE']:
#     data[key] = [data[key][ind] for ind in val_inds]

# %%
# Build transformed copies of the reference dataset: rotated (45 degrees),
# reflected (x -> -x) and isotropically scaled by `scale_factor`.
scale_factor = 1.5
R_arr = [torch.tensor([[0.707107, -0.707107],[0.707107, 0.707107]]),  # rot
         torch.tensor([[-1.0, -0],[0, 1.0]]),  # refl
         torch.tensor([[scale_factor, 0],[0, scale_factor]])]   # scale
keys = ['rotated', 'reflected', 'scaled']
for R, key in zip(R_arr, keys):
    data[key] = []
    for graph in data['reference']:
        graph = graph.clone()
        R_dev = R.to(graph.y.device)

        # vector-valued quantities: v' = R v (row-wise)
        graph.y = torch.matmul(R_dev, graph.y.T).T
        graph.pos = torch.matmul(R_dev, graph.pos.T).T
        graph.r = torch.matmul(R_dev, graph.r.T).T
        if key == 'scaled':
            # only lengths change; P, D and F are left untransformed
            graph.d = scale_factor * graph.d
        else:
            # P' = R P R^T, F' = R F R^T per sample; D is rotated in all four indices
            graph.P = torch.einsum('lj,ijk,km->ilm', R_dev, graph.P, R_dev.T)
            graph.D = torch.einsum('nj,ok,pl,qm,ijklm->inopq', R_dev, R_dev, R_dev, R_dev, graph.D)
            graph.F = torch.einsum('ij,lk,mjk->mil', R_dev, R_dev, graph.F)

        # recomputed from the transformed targets rather than transformed itself
        graph.mean_pos = torch.mean(graph.y, dim=0, keepdim=True)
        data[key].append(graph)

# %%
# shifted RVE: translate all nodes by `shift_vec`, then wrap every node whose
# coordinate exceeds 1.6 back by one period along that axis.
shift_vec = torch.tensor([0.8, 0.8])

# periodicity vectors of the RVE, one row per axis
basis_vecs = torch.tensor([[3.2, 0], [0, 3.2]])
data['shifted RVE'] = []
for source_graph in data['reference']:
    shifted = source_graph.clone()
    basis_vecs = basis_vecs.to(shifted.y.device)
    shift_vec = shift_vec.to(shifted.y.device)

    # deformed positions move by F applied to the same shift
    shifted.pos += shift_vec
    shifted.y += torch.matmul(shifted.F, shift_vec)

    # wrap x first, then y
    for axis in range(2):
        outside = shifted.pos[:, axis] > 1.6
        shifted.pos[outside] -= basis_vecs[axis]
        shifted.y[outside] -= torch.matmul(shifted.F, basis_vecs[axis])

    shifted.mean_pos = torch.mean(shifted.y, dim=0, keepdim=True)
    data['shifted RVE'].append(shifted)

# %% noisy distances (to check how errors grow)
# Tiny Gaussian perturbations of `d` (and of `r` in the second set).
# torch.randn_like draws noise with the same shape, dtype and device as the
# perturbed tensor; torch.randn(*shape) always produced default-dtype CPU noise.
noise_std = 1e-7

data['noisy distances'] = []
for graph in data['reference']:
    graph = graph.clone()
    graph.d += torch.randn_like(graph.d) * noise_std
    data['noisy distances'].append(graph)

data['noisy d and r'] = []
for graph in data['reference']:
    graph = graph.clone()
    graph.d += torch.randn_like(graph.d) * noise_std
    graph.r += torch.randn_like(graph.r) * noise_std
    data['noisy d and r'].append(graph)


## Find specific F

In [None]:

# %%
# Find specific F: select one representative sample per deformation mode.

F = np.concatenate([graph.F.numpy() for graph in data['reference']])
F00, F01, F10, F11 = F[:, 0, 0], F[:, 0, 1], F[:, 1, 0], F[:, 1, 1]


def _pick(mask, score, largest=True):
    """Return the index into F of the masked sample with the largest (or smallest) score."""
    candidates = np.flatnonzero(mask)
    scores = score[candidates]
    return candidates[np.argmax(scores) if largest else np.argmin(scores)]


inds_to_plot = [
    # biaxial compression (rotational bifurcation only)
    # NOTE(review): F01 == 0 for every candidate, so this argmin simply returns
    # the first match -- confirm that is the intended selection.
    _pick((F01 == 0) & (F00 < 0.9) & (F11 < 0.9), F01, largest=False),
    # biaxial tension
    _pick((F00 == F11) & (F01 == 0) & (F10 == 0) & (F00 != 1.0), F00),
    # only shear
    _pick((F00 == 1) & (F11 == 1), F01),
    # tension + compression (left/right bifurcation only)
    _pick((F01 == 0) & (F00 > 1.23) & (F00 < 1.27) & (F11 < 0.77) & (F11 > 0.73),
          F00 - F11),
]

# double bifurcation: the sample whose F is closest (mean squared difference) to goal_F
goal_F = np.array([[1.05, 0], [0, 0.8]])
inds_to_plot.append(np.argmin(np.mean((F - goal_F)**2, axis=(1, 2))))

# # + 4 random cases
# rng = np.random.default_rng(seed=42)
# inds_to_plot.extend(rng.integers(0, len(F), size=4))

print(inds_to_plot)
print(F[inds_to_plot])


## All models, rows: models, columns: test cases

In [None]:

# %%
# %matplotlib qt
# create comparison plots deformation: one figure per selected F,
# rows = models, columns = test cases

client = mlflow.tracking.MlflowClient()

# with open(data_path3, 'rb') as f:
#     data['diameter 0.8'] = pickle.load(f)
# del data['diameter 0.8 (buckled)']
# del data['diameter 0.8 (unbuckled)']

plt.rcParams.update({
    "text.usetex": True,
    'text.latex.preamble': r'\usepackage{{amsmath}}',
    'font.size': 20
})
# run_ids = ['0cb58709c97f428b9bdff611b554c2df',
#            '855011c4a49340f6a6482a24bae0f874', 'f09812771fec478eb1195216f3ae018d', 'f41d9fe09c2b469f8e70416ff0cc2b8e']
# names = ['GNN', 'EGNNmod1', 'EGNNmod2', 'EGNNmod2_bigger']
# run_ids = ['0cb58709c97f428b9bdff611b554c2df',
#            '2810e07598f748819b99f6b41b1b2423',
#            '39c43c406253405fb16182ea6316652a',
#            '855011c4a49340f6a6482a24bae0f874',
#            'f09812771fec478eb1195216f3ae018d'
#            ]
# names = ['GNN',
#          'GNN, DA ×1',
#          'GNN, DA ×2',
#          'EGNNmod1',
#          'EGNNmod2',
#         #  'EGNNmod2_bigger'
#         ]

run_ids = ['9453aa73b9ee4b42ac9560ff37693d6f',
           'f25933db5e5545388265e2e9261edca3',
           '64e45fceb1eb46b6a977851c7123bca8',
           '0f54a8a568094a3085cc387fe21c3d38',
           'bdca87c393db4da3b387a805426d25d2',
           'a49f45e0cc3743ab914283c4ea0e6d60',
           'f09812771fec478eb1195216f3ae018d',
           ]
model_names = ['GNN',
               'GNN, DA ×1',
               'GNN, DA ×2',
               'EGNN',
               'EGNN, DA ×1',
               'EGNN, DA ×2',
               'SimEGNN',
               #  'EGNNmod2_bigger'
               ]

cases = ['reference', 'shifted RVE',
         'extended RVE', 'reflected',
         'rotated', 'scaled',
         # 'diameter 0.8', 'finer mesh'
         ]
figs = []
axes = []
# One figure per selected F. squeeze=False keeps every `ax` a 2-D array
# (models x cases) even for a single model or test case, so the
# `axes[k][i][j]` indexing in the plotting loop always works.
for _ in inds_to_plot:
    fig, ax = plt.subplots(len(run_ids), len(cases), figsize=(20, 25), squeeze=False)
    figs.append(fig)
    axes.append(ax)

# Evaluate every model on every test case and draw target vs. predicted
# deformed geometry into the pre-created grids: axes[k][i][j] with
# k = selected F, i = model (row), j = test case (column).
for i, (run_id, name) in enumerate(zip(run_ids, model_names)):
    print('================================================')
    print('testing model', name)
    torch.cuda.empty_cache()
    gc.collect()

    # ---- locate this run's artifacts ----
    # drop the first 8 characters (presumably the 'file:///' prefix of a local
    # artifact URI) to obtain a filesystem path
    art_path = client.get_run(run_id).info.artifact_uri[8:]
    print(art_path)

    # Reset per run: otherwise a run without weights / init params would
    # silently reuse the paths found for the previous model.
    model_path = None
    init_params_path = None
    for artifact in client.list_artifacts(run_id):
        if 'weights.pt' in artifact.path:
            model_path = os.path.join(art_path, artifact.path)
            print(model_path)
        elif artifact.path.endswith('model_init_params.pkl'):
            init_params_path = os.path.join(art_path, artifact.path)
            print(init_params_path)
    if model_path is None or init_params_path is None:
        raise FileNotFoundError(f'run {run_id}: weights or model_init_params.pkl not found')

    with open(init_params_path, 'rb') as f:
        init_params = pickle.load(f)

    # ---- import the model definition logged with the run ----
    sys.path.insert(0, art_path)
    fleur_GNN_definition = [file for file in os.listdir(art_path) if file.startswith('fleur_GNN')]
    if len(fleur_GNN_definition) > 1:
        raise Exception('Multiple fleur_GNN definitions found')
    elif len(fleur_GNN_definition) == 0:
        raise Exception('No fleur_GNN definition found')
    # Runs may share a module name; reload() re-runs the finder, so the file in
    # this run's art_path (now first on sys.path) replaces any cached module.
    fG = importlib.import_module(fleur_GNN_definition[0].split('.')[0])
    fG = importlib.reload(fG)

    # ---- build the model and load its weights ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = fG.MyGNN(**init_params)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(model)

    # NOTE(review): eval() executes the logged parameter string; the value is
    # only printed here -- ast.literal_eval suffices if it is a plain literal.
    scaling_factors = eval(mlflow.get_run(run_id).data.params['scaling_factors'])
    print('scaling_factors:', scaling_factors)
    model.to(device)
    # Inference mode: a freshly constructed nn.Module is in training mode, so
    # any dropout / batch-norm layers would otherwise act as during training.
    model.eval()

    # ---- one panel per (test case, selected F) ----
    for j, key in enumerate(cases):
        print(key)

        for k, ind in enumerate(inds_to_plot):
            ax = axes[k][i][j]

            sample_ind = ind
            if key in ['finer mesh', 'diameter 0.8']:
                # differently indexed datasets: use the sample whose F is
                # closest (mean squared difference) to the selected one
                F2 = np.concatenate([g.F.numpy() for g in data[key]])
                MSE = np.mean((F2 - F[ind])**2, axis=(1, 2))
                sample_ind = np.argmin(MSE)

            graph = data[key][sample_ind].clone().to(device)
            graph.batch = torch.zeros(len(graph.x), dtype=torch.long).to(device)
            pos_pred = model(graph)[0].detach().cpu().numpy()  # predicted
            pos_target = graph.y.detach().cpu().numpy()  # target
            pos_ref = graph.pos.detach().cpu().numpy().reshape(-1, 2)  # undeformed
            graph_indices = graph.edge_index.cpu().numpy()

            # nodes
            ax.scatter(*(pos_target.T), label='final position', s=1)
            ax.scatter(*(pos_pred.T), label='GNN prediction', s=1)

            # Edges: drop periodic wraparound edges (spanning 1.6 or more in x
            # or y in the undeformed configuration) and keep only edges whose
            # edge_attr equals -1.
            x, y = np.transpose(pos_ref[graph_indices], axes=[2, 0, 1])
            bools = ((np.abs(np.diff(x, axis=0)) < 1.6)
                     & (np.abs(np.diff(y, axis=0)) < 1.6)
                     & (graph.edge_attr.cpu().numpy() == -1).T
                     ).flatten()
            x, y = np.transpose(pos_target[graph_indices], axes=[2, 0, 1])
            ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:green')
            x, y = np.transpose(pos_pred[graph_indices], axes=[2, 0, 1])
            ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:orange')

            # original RVE corners, transformed consistently with the test case
            corner_coords = np.array([[-1.6, -1.6, 1.6, 1.6], [-1.6, 1.6, 1.6, -1.6]])
            if key == 'scaled':
                corner_coords *= scale_factor
            elif key == 'extended RVE':
                corner_coords = np.array([[-1.6, -1.6, 1.6 + 3.2, 1.6 + 3.2],
                                          [-1.6, 1.6 + 3.2, 1.6 + 3.2, -1.6]])
            elif key == 'rotated':
                R = np.array([[0.707107, -0.707107], [0.707107, 0.707107]])
                corner_coords = np.matmul(R, corner_coords)
            ax.scatter(*corner_coords, marker='x',
                       label='original corners', c='red', s=20, zorder=10)

            # affinely deformed RVE outline; graph.F has a leading axis of
            # size 1 (one F per graph), hence the [0]
            corner_coords = np.matmul(graph.F.cpu().numpy(), corner_coords)[0]
            ax.fill(*corner_coords, facecolor='lightgray')

            if i == 0:  # first model
                ax.set_title(f'{key}', size=22)
            if j == 0:  # first test case
                ax.set_ylabel(f'{name}', size=22)

            # make the plot square: identical limits on both axes
            xlims = ax.get_xlim()
            ylims = ax.get_ylim()
            low = min(xlims[0], ylims[0])
            high = max(xlims[1], ylims[1])
            ax.set_xlim([low, high])
            ax.set_ylim([low, high])

            ax.set_aspect('equal')

# Save one PNG per selected F. `dpi` is the savefig keyword for the raster
# resolution; the output directory is created if it does not exist yet.
out_dir = 'results/final_results/compare_deformations'
os.makedirs(out_dir, exist_ok=True)
for fig, ind in zip(figs, inds_to_plot):
    fig.subplots_adjust(wspace=0.4)
    # fig.suptitle(f'{F[ind]}')
    path = f'{out_dir}/compare_deformations_all_models_F=[{F[ind][0]},{F[ind][1]}]'
    path = path.replace(' ', '_')
    # fig.savefig(path + '.svg', dpi=300)
    # fig.savefig(path + '.pdf', dpi=300)
    fig.savefig(path + '.png', dpi=600, bbox_inches='tight')

plt.close('all')


## GNN, EGNN and SimEGNN (no data augmentation), without the RVE in-/equivariance cases; rows: test cases, columns: models

In [None]:

# %%
# %matplotlib qt
# create comparison plots deformation: a single figure for the first selected
# F, rows = test cases, columns = models

client = mlflow.tracking.MlflowClient()

# with open(data_path3, 'rb') as f:
#     data['diameter 0.8'] = pickle.load(f)
# del data['diameter 0.8 (buckled)']
# del data['diameter 0.8 (unbuckled)']

plt.rcParams.update({
    "text.usetex": True,
    'text.latex.preamble': r'\usepackage{{amsmath}}',
    'font.size': 22
})
# run_ids = ['0cb58709c97f428b9bdff611b554c2df',
#            '855011c4a49340f6a6482a24bae0f874', 'f09812771fec478eb1195216f3ae018d', 'f41d9fe09c2b469f8e70416ff0cc2b8e']
# names = ['GNN', 'EGNNmod1', 'EGNNmod2', 'EGNNmod2_bigger']
# run_ids = ['0cb58709c97f428b9bdff611b554c2df',
#            '2810e07598f748819b99f6b41b1b2423',
#            '39c43c406253405fb16182ea6316652a',
#            '855011c4a49340f6a6482a24bae0f874',
#            'f09812771fec478eb1195216f3ae018d'
#            ]
# names = ['GNN',
#          'GNN, DA ×1',
#          'GNN, DA ×2',
#          'EGNNmod1',
#          'EGNNmod2',
#         #  'EGNNmod2_bigger'
#         ]

run_ids = [
    '9453aa73b9ee4b42ac9560ff37693d6f',    # GNN
    # 'f25933db5e5545388265e2e9261edca3',  # GNN, DA ×1
    # '64e45fceb1eb46b6a977851c7123bca8',  # GNN, DA ×2
    '0f54a8a568094a3085cc387fe21c3d38',    # EGNN
    # 'bdca87c393db4da3b387a805426d25d2',  # EGNN, DA ×1
    # 'a49f45e0cc3743ab914283c4ea0e6d60',  # EGNN, DA ×2
    'f09812771fec478eb1195216f3ae018d',    # SimEGNN
]
model_names = [
    'GNN',
    # 'GNN, DA ×1',
    # 'GNN, DA ×2',
    'EGNN',
    # 'EGNN, DA ×1',
    # 'EGNN, DA ×2',
    'SimEGNN',
    # 'EGNNmod2_bigger'
]

cases = ['reference',
         # 'shifted RVE',
         # 'extended RVE',
         'reflected',
         'rotated',
         'scaled',
         # 'diameter 0.8', 'finer mesh'
         ]
figs = []
axes = []
# One figure, for the first selected F only. squeeze=False keeps `ax` a 2-D
# array (cases x models) so `axes[k][j][i]` works for any grid size.
for _ in inds_to_plot[:1]:
    fig, ax = plt.subplots(len(cases), len(run_ids), figsize=(12, 15), squeeze=False)
    figs.append(fig)
    axes.append(ax)

# Evaluate every model on every test case for the first selected F and draw
# target vs. predicted deformed geometry: axes[k][j][i] with k = selected F,
# j = test case (row), i = model (column).
for i, (run_id, name) in enumerate(zip(run_ids, model_names)):
    print('================================================')
    print('testing model', name)
    torch.cuda.empty_cache()
    gc.collect()

    # ---- locate this run's artifacts ----
    # drop the first 8 characters (presumably the 'file:///' prefix of a local
    # artifact URI) to obtain a filesystem path
    art_path = client.get_run(run_id).info.artifact_uri[8:]
    print(art_path)

    # Reset per run: otherwise a run without weights / init params would
    # silently reuse the paths found for the previous model.
    model_path = None
    init_params_path = None
    for artifact in client.list_artifacts(run_id):
        if 'weights.pt' in artifact.path:
            model_path = os.path.join(art_path, artifact.path)
            print(model_path)
        elif artifact.path.endswith('model_init_params.pkl'):
            init_params_path = os.path.join(art_path, artifact.path)
            print(init_params_path)
    if model_path is None or init_params_path is None:
        raise FileNotFoundError(f'run {run_id}: weights or model_init_params.pkl not found')

    with open(init_params_path, 'rb') as f:
        init_params = pickle.load(f)

    # ---- import the model definition logged with the run ----
    sys.path.insert(0, art_path)
    fleur_GNN_definition = [file for file in os.listdir(art_path) if file.startswith('fleur_GNN')]
    if len(fleur_GNN_definition) > 1:
        raise Exception('Multiple fleur_GNN definitions found')
    elif len(fleur_GNN_definition) == 0:
        raise Exception('No fleur_GNN definition found')
    # Runs may share a module name; reload() re-runs the finder, so the file in
    # this run's art_path (now first on sys.path) replaces any cached module.
    fG = importlib.import_module(fleur_GNN_definition[0].split('.')[0])
    fG = importlib.reload(fG)

    # ---- build the model and load its weights ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = fG.MyGNN(**init_params)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(model)

    # NOTE(review): eval() executes the logged parameter string; the value is
    # only printed here -- ast.literal_eval suffices if it is a plain literal.
    scaling_factors = eval(mlflow.get_run(run_id).data.params['scaling_factors'])
    print('scaling_factors:', scaling_factors)
    model.to(device)
    # Inference mode: a freshly constructed nn.Module is in training mode, so
    # any dropout / batch-norm layers would otherwise act as during training.
    model.eval()

    # ---- one panel per (test case, selected F) ----
    for j, key in enumerate(cases):
        print(key)

        for k, ind in enumerate(inds_to_plot[:1]):
            ax = axes[k][j][i]

            sample_ind = ind
            if key in ['finer mesh', 'diameter 0.8']:
                # differently indexed datasets: use the sample whose F is
                # closest (mean squared difference) to the selected one
                F2 = np.concatenate([g.F.numpy() for g in data[key]])
                MSE = np.mean((F2 - F[ind])**2, axis=(1, 2))
                sample_ind = np.argmin(MSE)

            graph = data[key][sample_ind].clone().to(device)
            graph.batch = torch.zeros(len(graph.x), dtype=torch.long).to(device)
            pos_pred = model(graph)[0].detach().cpu().numpy()  # predicted
            pos_target = graph.y.detach().cpu().numpy()  # target
            pos_ref = graph.pos.detach().cpu().numpy().reshape(-1, 2)  # undeformed
            graph_indices = graph.edge_index.cpu().numpy()

            # nodes
            ax.scatter(*(pos_target.T), label='final position', s=1)
            ax.scatter(*(pos_pred.T), label='GNN prediction', s=1)

            # Edges: drop periodic wraparound edges (spanning 1.6 or more in x
            # or y in the undeformed configuration) and keep only edges whose
            # edge_attr equals -1.
            x, y = np.transpose(pos_ref[graph_indices], axes=[2, 0, 1])
            bools = ((np.abs(np.diff(x, axis=0)) < 1.6)
                     & (np.abs(np.diff(y, axis=0)) < 1.6)
                     & (graph.edge_attr.cpu().numpy() == -1).T
                     ).flatten()
            x, y = np.transpose(pos_target[graph_indices], axes=[2, 0, 1])
            ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:green')
            x, y = np.transpose(pos_pred[graph_indices], axes=[2, 0, 1])
            ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:orange')

            # original RVE corners, transformed consistently with the test case
            corner_coords = np.array([[-1.6, -1.6, 1.6, 1.6], [-1.6, 1.6, 1.6, -1.6]])
            if key == 'scaled':
                corner_coords *= scale_factor
            elif key == 'extended RVE':
                corner_coords = np.array([[-1.6, -1.6, 1.6 + 3.2, 1.6 + 3.2],
                                          [-1.6, 1.6 + 3.2, 1.6 + 3.2, -1.6]])
            elif key == 'rotated':
                R = np.array([[0.707107, -0.707107], [0.707107, 0.707107]])
                corner_coords = np.matmul(R, corner_coords)
            ax.scatter(*corner_coords, marker='x',
                       label='original corners', c='red', s=20, zorder=10)

            # affinely deformed RVE outline; graph.F has a leading axis of
            # size 1 (one F per graph), hence the [0]
            corner_coords = np.matmul(graph.F.cpu().numpy(), corner_coords)[0]
            ax.fill(*corner_coords, facecolor='lightgray')

            if i == 0:  # first model
                ax.set_ylabel(f'{key}', size=24)
            if j == 0:  # first test case
                ax.set_title(f'{name}', size=24)

            # make the plot square: identical limits on both axes
            xlims = ax.get_xlim()
            ylims = ax.get_ylim()
            low = min(xlims[0], ylims[0])
            high = max(xlims[1], ylims[1])
            ax.set_xlim([low, high])
            ax.set_ylim([low, high])

            ax.set_aspect('equal')

# Save the figure(s) of this section. zip() stops after the figures that were
# created, so with a single figure only inds_to_plot[0] is used. `dpi` is the
# savefig keyword for the raster resolution.
out_dir = 'results/final_results/compare_deformations'
os.makedirs(out_dir, exist_ok=True)
for fig, ind in zip(figs, inds_to_plot):
    fig.subplots_adjust(wspace=0.3, hspace=0.35)
    # fig.suptitle(f'{F[ind]}')
    path = f'{out_dir}/compare_deformations_3_models_F=[{F[ind][0]},{F[ind][1]}]'
    path = path.replace(' ', '_')
    # fig.savefig(path + '.svg', dpi=300)
    # fig.savefig(path + '.pdf', dpi=300)
    fig.savefig(path + '.png', dpi=600, bbox_inches='tight')

plt.close('all')


## No data augmentation, rows: models, columns: test cases

In [None]:
import sys
import os
import importlib
import gc
import pickle

import numpy as np
import torch
import funcs_helpers as fh
import mlflow
import matplotlib.pyplot as plt


In [None]:

# %%
# load data: the untransformed dataset and the same structure on a larger RVE.
# Paths are assembled with os.path.join so they resolve on any OS (the former
# backslash-separated literals only worked on Windows).
data_dir = os.path.join("data", "coarseMesh_noBifurcation_5")
data_path0 = os.path.join(
    data_dir, "graphs_coarseMesh_noBifurcation_diameter_0.9_5_noBulkNodes_2.pkl")
data_path1 = os.path.join(
    data_dir, "graphs_coarseMesh_noBifurcation_diameter_0.9_5_noBulkNodes_largerRVE.pkl")

data = {}
for key, path in [('untransformed', data_path0), ('extended RVE', data_path1)]:
    # pickle.load can execute arbitrary code: only load trusted local files
    with open(path, 'rb') as f:
        data[key] = pickle.load(f)
    print(data[key][0])
    print(len(data[key]))

# %%
for key in data:
    print(key, len(data[key]))

# %%
# only keep validation data
# NOTE(review): val_inds are derived from the 'untransformed' dataset and are
# also applied to 'extended RVE' -- this assumes both pickles list the same
# samples in the same order; confirm.
trajs = [graph.traj.numpy() for graph in data['untransformed']]
train_inds, val_inds = fh.split_trajs(trajs, 2, split_sizes=(0.7, 0.3))

for key in ['untransformed', 'extended RVE']:
    data[key] = [data[key][ind] for ind in val_inds]

# %%
# Build transformed copies of the untransformed dataset: rotated (45 degrees),
# reflected (x -> -x) and isotropically scaled by `scale_factor`.
scale_factor = 1.5
R_arr = [torch.tensor([[0.707107, -0.707107],[0.707107, 0.707107]]),  # rot
         torch.tensor([[-1.0, -0],[0, 1.0]]),  # refl
         torch.tensor([[scale_factor, 0],[0, scale_factor]])]   # scale
keys = ['rotated', 'reflected', 'scaled']
for R, key in zip(R_arr, keys):
    data[key] = []
    for graph in data['untransformed']:
        graph = graph.clone()
        R_dev = R.to(graph.y.device)

        # vector-valued quantities: v' = R v (row-wise)
        graph.y = torch.matmul(R_dev, graph.y.T).T
        graph.pos = torch.matmul(R_dev, graph.pos.T).T
        graph.r = torch.matmul(R_dev, graph.r.T).T
        if key == 'scaled':
            # only lengths change; P, D and F are left untransformed
            graph.d = scale_factor * graph.d
        else:
            # P' = R P R^T, F' = R F R^T per sample; D is rotated in all four indices
            graph.P = torch.einsum('lj,ijk,km->ilm', R_dev, graph.P, R_dev.T)
            graph.D = torch.einsum('nj,ok,pl,qm,ijklm->inopq', R_dev, R_dev, R_dev, R_dev, graph.D)
            graph.F = torch.einsum('ij,lk,mjk->mil', R_dev, R_dev, graph.F)

        # recomputed from the transformed targets rather than transformed itself
        graph.mean_pos = torch.mean(graph.y, dim=0, keepdim=True)
        data[key].append(graph)

# %%
# shifted RVE: translate all nodes by `shift_vec`, then wrap every node whose
# coordinate exceeds 1.6 back by one period along that axis.
shift_vec = torch.tensor([0.8, 0.8])

# periodicity vectors of the RVE, one row per axis
basis_vecs = torch.tensor([[3.2, 0], [0, 3.2]])
data['shifted RVE'] = []
for source_graph in data['untransformed']:
    shifted = source_graph.clone()
    basis_vecs = basis_vecs.to(shifted.y.device)
    shift_vec = shift_vec.to(shifted.y.device)

    # deformed positions move by F applied to the same shift
    shifted.pos += shift_vec
    shifted.y += torch.matmul(shifted.F, shift_vec)

    # wrap x first, then y
    for axis in range(2):
        outside = shifted.pos[:, axis] > 1.6
        shifted.pos[outside] -= basis_vecs[axis]
        shifted.y[outside] -= torch.matmul(shifted.F, basis_vecs[axis])

    shifted.mean_pos = torch.mean(shifted.y, dim=0, keepdim=True)
    data['shifted RVE'].append(shifted)

In [None]:
# Models compared in this section as (mlflow run id, display name) pairs, kept
# side by side so ids and labels cannot drift out of sync.
# ('EGNNmod2_bigger' was also considered; its run id is not recorded here.)
_run_table = [
    ('9453aa73b9ee4b42ac9560ff37693d6f', 'GNN'),
    ('0f54a8a568094a3085cc387fe21c3d38', 'EGNN'),
    ('f09812771fec478eb1195216f3ae018d', 'SimEGNN'),
]
run_ids = [run_id for run_id, _ in _run_table]
model_names = [label for _, label in _run_table]

In [None]:

# %%
# Find specific F: select one representative sample per deformation mode,
# plus four random ones.

F = np.concatenate([graph.F.numpy() for graph in data['untransformed']])
F00, F01, F10, F11 = F[:, 0, 0], F[:, 0, 1], F[:, 1, 0], F[:, 1, 1]


def _pick(mask, score, largest=True):
    """Return the index into F of the masked sample with the largest (or smallest) score."""
    candidates = np.flatnonzero(mask)
    scores = score[candidates]
    return candidates[np.argmax(scores) if largest else np.argmin(scores)]


inds_to_plot = [
    # biaxial tension
    _pick((F00 == F11) & (F01 == 0) & (F10 == 0) & (F00 != 1.0), F00),
    # only shear
    _pick((F00 == 1) & (F11 == 1), F01),
    # biaxial compression
    # NOTE(review): F01 == 0 for every candidate, so this argmin simply returns
    # the first match -- confirm that is the intended selection.
    _pick((F01 == 0) & (F00 < 0.9) & (F11 < 0.9), F01, largest=False),
    # tension + compression
    _pick((F01 == 0) & (F00 > 1.2) & (F11 < 0.9), F00 - F11),
]

# double bifurcation: the sample whose F is closest (mean squared difference) to goal_F
goal_F = np.array([[1.05, 0], [0, 1-0.2]])
inds_to_plot.append(np.argmin(np.mean((F - goal_F)**2, axis=(1, 2))))

# + 4 random cases (seeded, hence reproducible)
rng = np.random.default_rng(seed=42)
inds_to_plot.extend(rng.integers(0, len(F), size=4))

print(inds_to_plot)
print(F[inds_to_plot])

# %%
# %matplotlib qt
# create comparison plots deformation: one figure per selected F,
# rows = models, columns = test cases

# Update rcParams BEFORE creating the figures, as the other sections do: e.g.
# the default tick formatters read `text.usetex` when the axes are created.
plt.rcParams.update({
    "text.usetex": True,
    'text.latex.preamble': r'\usepackage{{amsmath}}',
    'font.size': 16
})

# Test cases shown as columns (same list as iterated in the plotting loop).
# Sizing the grid with len(cases) instead of len(data.keys()) keeps it
# independent of any extra datasets stored in `data`.
cases = ['untransformed', 'shifted RVE',
         'extended RVE', 'reflected',
         'rotated', 'scaled',
         # 'diameter 0.8', 'finer mesh'
         ]

figs = []
axes = []
# One figure per selected F. squeeze=False keeps every `ax` a 2-D array
# (models x cases) so `axes[k][i][j]` works for any grid size.
for _ in inds_to_plot:
    fig, ax = plt.subplots(len(run_ids), len(cases), figsize=(15, 9), squeeze=False)
    figs.append(fig)
    axes.append(ax)

client = mlflow.tracking.MlflowClient()

# with open(data_path3, 'rb') as f:
#     data['diameter 0.8'] = pickle.load(f)
# del data['diameter 0.8 (buckled)']
# del data['diameter 0.8 (unbuckled)']

# run_ids = ['0cb58709c97f428b9bdff611b554c2df',
#            '855011c4a49340f6a6482a24bae0f874', 'f09812771fec478eb1195216f3ae018d', 'f41d9fe09c2b469f8e70416ff0cc2b8e']
# names = ['GNN', 'EGNNmod1', 'EGNNmod2', 'EGNNmod2_bigger']
# run_ids = ['0cb58709c97f428b9bdff611b554c2df',
#            '2810e07598f748819b99f6b41b1b2423',
#            '39c43c406253405fb16182ea6316652a',
#            '855011c4a49340f6a6482a24bae0f874',
#            'f09812771fec478eb1195216f3ae018d'
#            ]
# names = ['GNN',
#          'GNN, DA ×1',
#          'GNN, DA ×2',
#          'EGNNmod1',
#          'EGNNmod2',
#         #  'EGNNmod2_bigger'
#         ]
# Evaluate every model on every test case and draw target vs. predicted
# deformed geometry into the pre-created grids: axes[k][i][j] with
# k = selected F, i = model (row), j = test case (column).
for i, (run_id, name) in enumerate(zip(run_ids, model_names)):
    print('================================================')
    print('testing model', name)
    torch.cuda.empty_cache()
    gc.collect()

    # ---- locate this run's artifacts ----
    # drop the first 8 characters (presumably the 'file:///' prefix of a local
    # artifact URI) to obtain a filesystem path
    art_path = client.get_run(run_id).info.artifact_uri[8:]
    print(art_path)

    # Reset per run: otherwise a run without weights / init params would
    # silently reuse the paths found for the previous model.
    model_path = None
    init_params_path = None
    for artifact in client.list_artifacts(run_id):
        if 'weights.pt' in artifact.path:
            model_path = os.path.join(art_path, artifact.path)
            print(model_path)
        elif artifact.path.endswith('model_init_params.pkl'):
            init_params_path = os.path.join(art_path, artifact.path)
            print(init_params_path)
    if model_path is None or init_params_path is None:
        raise FileNotFoundError(f'run {run_id}: weights or model_init_params.pkl not found')

    with open(init_params_path, 'rb') as f:
        init_params = pickle.load(f)

    # ---- import the model definition logged with the run ----
    sys.path.insert(0, art_path)
    fleur_GNN_definition = [file for file in os.listdir(art_path) if file.startswith('fleur_GNN')]
    if len(fleur_GNN_definition) > 1:
        raise Exception('Multiple fleur_GNN definitions found')
    elif len(fleur_GNN_definition) == 0:
        raise Exception('No fleur_GNN definition found')
    # Runs may share a module name; reload() re-runs the finder, so the file in
    # this run's art_path (now first on sys.path) replaces any cached module.
    fG = importlib.import_module(fleur_GNN_definition[0].split('.')[0])
    fG = importlib.reload(fG)

    # ---- build the model and load its weights ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = fG.MyGNN(**init_params)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(model)

    # NOTE(review): eval() executes the logged parameter string; the value is
    # only printed here -- ast.literal_eval suffices if it is a plain literal.
    scaling_factors = eval(mlflow.get_run(run_id).data.params['scaling_factors'])
    print('scaling_factors:', scaling_factors)
    model.to(device)
    # Inference mode: a freshly constructed nn.Module is in training mode, so
    # any dropout / batch-norm layers would otherwise act as during training.
    model.eval()

    # ---- one panel per (test case, selected F) ----
    for j, key in enumerate(['untransformed', 'shifted RVE',
                             'extended RVE', 'reflected',
                             'rotated', 'scaled',
                             # 'diameter 0.8', 'finer mesh'
                             ]):
        print(key)

        for k, ind in enumerate(inds_to_plot):
            ax = axes[k][i][j]

            sample_ind = ind
            if key in ['finer mesh', 'diameter 0.8']:
                # differently indexed datasets: use the sample whose F is
                # closest (mean squared difference) to the selected one
                F2 = np.concatenate([g.F.numpy() for g in data[key]])
                MSE = np.mean((F2 - F[ind])**2, axis=(1, 2))
                sample_ind = np.argmin(MSE)

            graph = data[key][sample_ind].clone().to(device)
            graph.batch = torch.zeros(len(graph.x), dtype=torch.long).to(device)
            pos_pred = model(graph)[0].detach().cpu().numpy()  # predicted
            pos_target = graph.y.detach().cpu().numpy()  # target
            pos_ref = graph.pos.detach().cpu().numpy().reshape(-1, 2)  # undeformed
            graph_indices = graph.edge_index.cpu().numpy()

            # nodes
            ax.scatter(*(pos_target.T), label='final position', s=1)
            ax.scatter(*(pos_pred.T), label='GNN prediction', s=1)

            # Edges: drop periodic wraparound edges (spanning 1.6 or more in x
            # or y in the undeformed configuration) and keep only edges whose
            # edge_attr equals -1.
            x, y = np.transpose(pos_ref[graph_indices], axes=[2, 0, 1])
            bools = ((np.abs(np.diff(x, axis=0)) < 1.6)
                     & (np.abs(np.diff(y, axis=0)) < 1.6)
                     & (graph.edge_attr.cpu().numpy() == -1).T
                     ).flatten()
            x, y = np.transpose(pos_target[graph_indices], axes=[2, 0, 1])
            ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:green')
            x, y = np.transpose(pos_pred[graph_indices], axes=[2, 0, 1])
            ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:orange')

            # original RVE corners, transformed consistently with the test case
            corner_coords = np.array([[-1.6, -1.6, 1.6, 1.6], [-1.6, 1.6, 1.6, -1.6]])
            if key == 'scaled':
                corner_coords *= scale_factor
            elif key == 'extended RVE':
                corner_coords = np.array([[-1.6, -1.6, 1.6 + 3.2, 1.6 + 3.2],
                                          [-1.6, 1.6 + 3.2, 1.6 + 3.2, -1.6]])
            elif key == 'rotated':
                R = np.array([[0.707107, -0.707107], [0.707107, 0.707107]])
                corner_coords = np.matmul(R, corner_coords)
            ax.scatter(*corner_coords, marker='x',
                       label='original corners', c='red', s=20, zorder=10)

            # affinely deformed RVE outline; graph.F has a leading axis of
            # size 1 (one F per graph), hence the [0]
            corner_coords = np.matmul(graph.F.cpu().numpy(), corner_coords)[0]
            ax.fill(*corner_coords, facecolor='lightgray')

            if i == 0:  # first model
                ax.set_title(f'{key}', size=16)
            if j == 0:  # first test case
                ax.set_ylabel(f'{name}', size=16)

            # make the plot square: identical limits on both axes
            xlims = ax.get_xlim()
            ylims = ax.get_ylim()
            low = min(xlims[0], ylims[0])
            high = max(xlims[1], ylims[1])
            ax.set_xlim([low, high])
            ax.set_ylim([low, high])

            ax.set_aspect('equal')

# Save every figure as SVG, PDF and PNG; the output directory is created if
# it does not exist yet.
out_dir = 'results/final_results'
os.makedirs(out_dir, exist_ok=True)
for fig, ind in zip(figs, inds_to_plot):
    fig.subplots_adjust(wspace=0.4)
    # fig.suptitle(f'{F[ind]}')
    path = f'{out_dir}/compare_deformations2_F=[{F[ind][0]},{F[ind][1]}]'
    # str.replace returns a new string: assign it, otherwise the filenames
    # keep the spaces from the printed F rows
    path = path.replace(' ', '_')
    fig.savefig(path + '.svg')
    fig.savefig(path + '.pdf')
    fig.savefig(path + '.png')

plt.close('all')


# Spider plots

In [None]:

# %%
# Define spider plots
import pickle
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D


def radar_factory(num_vars, frame='circle'):
    """
    Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it with
    matplotlib under the name 'radar', so that axes can be requested with
    ``subplot_kw=dict(projection='radar')``.

    Parameters
    ----------
    num_vars : int
        Number of variables (spokes) for radar chart.
    frame : {'circle', 'polygon'}
        Shape of frame surrounding axes. An unknown value raises
        ValueError when a radar axes is created (not when this function
        is called).

    Returns
    -------
    theta : numpy.ndarray
        The `num_vars` evenly spaced spoke angles in radians, starting at 0.
    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarTransform(PolarAxes.PolarTransform):

        def transform_path_non_affine(self, path):
            # Paths with non-unit interpolation steps correspond to gridlines,
            # in which case we force interpolation (to defeat PolarTransform's
            # autoconversion to circular arcs).
            if path._interpolation_steps > 1:
                path = path.interpolated(num_vars)
            return Path(self.transform(path.vertices), path.codes)

    class RadarAxes(PolarAxes):

        name = 'radar'
        PolarTransform = RadarTransform

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that the polygon is closed by default."""
            return super().fill(*args, closed=closed, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that every line is closed by default.

            Returns the list of lines created by ``PolarAxes.plot`` so callers
            can use the result exactly like ``Axes.plot``.
            """
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            """Append the first point to the end of *line* if it is open."""
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            # An empty line has nothing to close (x[0] would raise IndexError).
            if len(x) and x[0] != x[-1]:
                x = np.append(x, x[0])
                y = np.append(y, y[0])
                line.set_data(x, y)

        def set_varlabels(self, labels):
            """Label the spokes, one label per angle in `theta`."""
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("Unknown value for 'frame': %s" % frame)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)
                return {'polar': spine}
            else:
                raise ValueError("Unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta


## Define quantities for spider plot

In [None]:
metric = 'FVU' # 'MSE' #

# FVU is plotted as 1 - R2, so both 'FVU' and 'R2' read the R2 results file.
results_paths = {
    'FVU': 'results/final_results/results_R2_4.pkl',
    'R2': 'results/final_results/results_R2_4.pkl',
    'MSE': 'results/final_results/results_MSE_4.pkl',
}
if metric not in results_paths:
    raise ValueError(f'{metric} is an invalid metric')
with open(results_paths[metric], 'rb') as f:
    results = pickle.load(f)

# Render text with LaTeX. The preamble is a plain (non-f) raw string, so it
# needs single braces; doubled braces would be passed to LaTeX verbatim.
plt.rcParams.update({
    "text.usetex": True,
    'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}',
    # 'font.size': 14
})

# Spokes of each radar chart: keys into results[model] and their display labels
spoke_quantities = ['reference', 'scaled', 'rotated',
                    'reflected', 'larger RVE', 'shifted RVE']
spoke_labels = ['reference', 'scaled', 'rotated',
                'reflected', 'extended RVE', 'shifted RVE']
# Target quantities, keys into results[model][spoke]
quantities = ['y', 'W', 'P', 'D']

# One line per model; models sharing a base architecture share a colour and
# differ in linestyle.
models = ['GNN',
          'GNN, DA ×1', 'GNN, DA ×2',
          'EGNN',
          'EGNN, DA ×1', 'EGNN, DA ×2',
          'SimEGNN']
colors = ['tab:blue',
          'tab:blue', 'tab:blue',
          'tab:orange',
          'tab:orange', 'tab:orange',
          'tab:green']
linestyles = ['solid',
              'dashed', 'dotted',
              'solid',
              'dashed', 'dotted',
              'solid']

# legend labels
model_names = models  #['GNN', 'GNN, Data Augmentation', 'Improved GNN']

# spoke angles; also registers the 'radar' projection used below
N = len(spoke_labels)
theta = radar_factory(N, frame='polygon')

## Create spider plots (2×2)

In [None]:
%matplotlib qt
# 2x2 grid of radar charts, one per target quantity: one closed line per
# model, one spoke per test case. Transparent figure background.
fig, axs = plt.subplots(figsize=(7, 7), nrows=2, ncols=2,
                        subplot_kw=dict(projection='radar'), frameon=False)
fig.subplots_adjust(wspace=0.05, hspace=0.4, top=0.85, bottom=0.05,
                    left=0.03, right=0.95)
fig.patch.set_facecolor("None")

panel_titles = [
    r'microfluctuation $\vec{w}$',
    r'energy $\mathfrak{W}$',
    r'stress $\textbf{\textrm{P}}$',
    r'stiffness $\textbf{\textrm{D}}$',
]

for ax, quantity, title in zip(axs.flat, quantities, panel_titles):
    ax.set_title(title, weight='bold', size='large', position=(0.5, 1.3),
                 horizontalalignment='center', verticalalignment='center')

    for model, color, linestyle in zip(models, colors, linestyles):
        r = [results[model][key][quantity] for key in spoke_quantities]
        if metric == 'FVU':
            # the loaded results hold R2; FVU = 1 - R2
            r = [1 - value for value in r]
        ax.plot(theta, r, color=color, linestyle=linestyle)
    ax.set_varlabels(spoke_labels)

    # symmetric-log radial axis; linear threshold chosen per metric/quantity
    if metric == 'FVU':
        fvu_linthresh = {'D': 1e-4, 'W': 1e-6, 'y': 1e-4}
        ax.set_rscale('symlog', linthresh=fvu_linthresh.get(quantity, 1e-6))
    elif metric == 'MSE':
        mse_linthresh = {'y': 1e-6, 'P': 1e-2}
        if quantity in mse_linthresh:
            ax.set_rscale('symlog', linthresh=mse_linthresh[quantity])
        else:
            ax.set_rscale('symlog')
    else:
        raise NotImplementedError(f'metric {metric} not implemented')

    # move the radial tick labels sideways
    for label in ax.get_yticklabels():
        label.set_x(-0.39)

# legend placed relative to the top-left chart
legend = axs[0, 0].legend(model_names, loc=(0.9, .95),
                          labelspacing=0.1, fontsize='small')

plt.show()

fig.savefig(f'results/final_results/spiderplots_{metric}.pdf')
fig.savefig(f'results/final_results/spiderplots_{metric}.png', dpi=600)
fig.savefig(f'results/final_results/spiderplots_{metric}.svg' )





## Create spider plots (1×4)

In [None]:
%matplotlib qt
spoke_labels = 6*['']  # hide the spoke (test-case) labels in this figure

# One row of four radar charts, one per target quantity.
fig, axs = plt.subplots(figsize=(12, 4), nrows=1, ncols=4,
                        subplot_kw=dict(projection='radar'))
fig.subplots_adjust(wspace=0.4, hspace=0.40, top=0.85, bottom=0.05)
fig.patch.set_facecolor("None")

# For MSE the first quantity is labelled as a position, otherwise as the
# microfluctuation.
first_title = (r'position $\vec{x}$' if metric == 'MSE'
               else r'microfluctuation $\vec{w}$')
panel_titles = [first_title,
                r'energy $\mathfrak{W}$',
                r'stress $\textbf{\textrm{P}}$',
                r'stiffness $\textbf{\textrm{D}}$']

# NOTE(review): zip() stops at the shortest iterable, so with this
# 3-element alpha list only the first three entries of `models` are drawn
# (the third one fully transparent). Confirm this is intended now that
# `models` has 7 entries.
model_alphas = [1, 1, 0]

for ax, quantity, title in zip(axs.flat, quantities, panel_titles):
    ax.set_title(title, weight='bold', size='large', y=1.35, pad=-14,
                 horizontalalignment='center', verticalalignment='center')

    for model, color, linestyle, alpha in zip(models, colors, linestyles,
                                              model_alphas):
        r = [results[model][key][quantity] for key in spoke_quantities]
        if metric == 'FVU':
            # the loaded results hold R2; FVU = 1 - R2
            r = [1 - value for value in r]
        ax.plot(theta, r, color=color, linestyle=linestyle, alpha=alpha)
    ax.set_varlabels(spoke_labels)

    # symmetric-log radial axis; linear threshold chosen per metric/quantity
    if metric == 'FVU':
        fvu_linthresh = {'D': 1e-4, 'W': 1e-6, 'y': 1e-4}
        ax.set_rscale('symlog', linthresh=fvu_linthresh.get(quantity, 1e-6))
    elif metric == 'MSE':
        mse_linthresh = {'y': 1e-6, 'P': 1e-2}
        if quantity in mse_linthresh:
            ax.set_rscale('symlog', linthresh=mse_linthresh[quantity])
        else:
            ax.set_rscale('symlog')
    else:
        raise NotImplementedError(f'metric {metric} not implemented')

    # move the radial tick labels sideways
    for label in ax.get_yticklabels():
        label.set_x(-0.39)

# legend placed relative to the left-most chart
legend = axs[0].legend(model_names, loc=(0.9, .95),
                       labelspacing=0.1, fontsize='small')

plt.show()

fig.savefig(f'results/final_results/spiderplots1×4_{metric}_3.pdf')
fig.savefig(f'results/final_results/spiderplots1×4_{metric}_3.png', dpi=600)
fig.savefig(f'results/final_results/spiderplots1×4_{metric}_3.svg' )



## Create spider plots (1×2), models appear one by one

In [None]:
%matplotlib qt
spoke_labels = 6*['']  # hide the spoke (test-case) labels in these figures

# Three versions of a 1x2 figure (stress, stiffness). Version i draws the
# first i+1 models visibly and the rest fully transparent, so the saved
# files show the models appearing one by one.
for i in range(3):
    fig, axs = plt.subplots(figsize=(6, 4), nrows=1, ncols=2,
                            subplot_kw=dict(projection='radar'))
    fig.subplots_adjust(wspace=0.3, hspace=0.40, top=0.85, bottom=0.05)
    fig.patch.set_facecolor("None")

    # zip() below stops at these 3 alphas, so only the first three models
    # are ever drawn.
    alphas = [1 if k <= i else 0 for k in range(3)]
    panel_titles = [r'stress $\textbf{\textrm{P}}$',
                    r'stiffness $\textbf{\textrm{D}}$']

    for ax, quantity, title in zip(axs.flat, quantities[2:], panel_titles):
        ax.set_title(title, weight='bold', size='large', y=1.35, pad=-14,
                     horizontalalignment='center', verticalalignment='center')

        for model, color, linestyle, alpha in zip(models, colors, linestyles,
                                                  alphas):
            r = [results[model][key][quantity] for key in spoke_quantities]
            if metric == 'FVU':
                # the loaded results hold R2; FVU = 1 - R2
                r = [1 - value for value in r]
            ax.plot(theta, r, color=color, linestyle=linestyle, alpha=alpha)
        ax.set_varlabels(spoke_labels)

        # symmetric-log radial axis; linear threshold per metric/quantity
        if metric == 'FVU':
            fvu_linthresh = {'D': 1e-4, 'W': 1e-6, 'y': 1e-4}
            ax.set_rscale('symlog', linthresh=fvu_linthresh.get(quantity, 1e-6))
        elif metric == 'MSE':
            mse_linthresh = {'y': 1e-6, 'P': 1e-2}
            if quantity in mse_linthresh:
                ax.set_rscale('symlog', linthresh=mse_linthresh[quantity])
            else:
                ax.set_rscale('symlog')
        else:
            raise NotImplementedError(f'metric {metric} not implemented')

        # move the radial tick labels sideways
        for label in ax.get_yticklabels():
            label.set_x(-0.39)

    plt.show()

    fig.savefig(f'results/final_results/spiderplots1×2_{metric}_{i}.pdf')
    fig.savefig(f'results/final_results/spiderplots1×2_{metric}_{i}.png', dpi=600)
    fig.savefig(f'results/final_results/spiderplots1×2_{metric}_{i}.svg' )



## Create spider plots (1 figure per target quantity)

In [None]:
%matplotlib qt
# One separate radar-chart figure per target quantity; the legend is only
# added to the first figure.
panel_titles = [r'microfluctuation $\vec{w}$',
                r'energy $\mathfrak{W}$',
                r'stress $\textbf{\textrm{P}}$',
                r'stiffness $\textbf{\textrm{D}}$']
short_titles = ['w', 'W', 'P', 'D']  # used in the output file names

for i, (quantity, title, short_title) in enumerate(
        zip(quantities, panel_titles, short_titles)):

    fig, ax = plt.subplots(figsize=(4,4),
                           subplot_kw=dict(projection='radar'))
    fig.subplots_adjust(wspace=0.4, hspace=0.40, top=0.8, bottom=0.1,
                        left=0.3, right=0.7)
    fig.patch.set_facecolor("None")

    ax.set_title(title, weight='bold', size='large', position=(0.5, 1.3),
                 horizontalalignment='center', verticalalignment='center')

    for model, color, linestyle in zip(models, colors, linestyles):
        r = [results[model][key][quantity] for key in spoke_quantities]
        if metric == 'FVU':
            # the loaded results hold R2; FVU = 1 - R2
            r = [1 - value for value in r]
        ax.plot(theta, r, color=color, linestyle=linestyle)
    ax.set_varlabels(spoke_labels)

    # symmetric-log radial axis; linear threshold chosen per metric/quantity
    if metric == 'FVU':
        fvu_linthresh = {'D': 1e-4, 'W': 1e-6, 'y': 1e-4}
        ax.set_rscale('symlog', linthresh=fvu_linthresh.get(quantity, 1e-6))
    elif metric == 'MSE':
        mse_linthresh = {'y': 1e-6, 'P': 1e-2}
        if quantity in mse_linthresh:
            ax.set_rscale('symlog', linthresh=mse_linthresh[quantity])
        else:
            ax.set_rscale('symlog')
    else:
        raise NotImplementedError(f'metric {metric} not implemented')

    # move the radial tick labels sideways
    for label in ax.get_yticklabels():
        label.set_x(-0.39)

    if not i:
        # legend only on the first figure
        legend = ax.legend(model_names, loc=(0.87, .95),
                           labelspacing=0.1, fontsize='small')

    plt.show()

    fig.savefig(f'results/final_results/spiderplots_{metric}_{short_title}.pdf')
    fig.savefig(f'results/final_results/spiderplots_{metric}_{short_title}.png', dpi=600)
    fig.savefig(f'results/final_results/spiderplots_{metric}_{short_title}.svg' )



# LaTeX table MSE/FVU/R2


In [None]:
import pandas as pd
metric = 'MSE'

# FVU is derived from the stored R2 values (FVU = 1 - R2), so both read the
# R2 results file.
if metric in ('FVU', 'R2'):
    results_file = 'results/final_results/results_R2_4.pkl'
elif metric == 'MSE':
    results_file = 'results/final_results/results_MSE_4.pkl'
else:
    raise ValueError(f'{metric} is an invalid metric')
with open(results_file, 'rb') as f:
    results = pickle.load(f)

# Regroup results[model][testcase][var] into results2[var][model], a list of
# scores ordered like the test cases in results[model].
results2 = {}
for model, per_case in results.items():
    for key, per_var in per_case.items():
        for var, score in per_var.items():
            column = results2.setdefault(var, {}).setdefault(model, [])
            column.append(1 - score if metric == 'FVU' else score)

# %%
def formatter(number):
    r"""Format *number* as inline LaTeX math with two significant digits.

    Values whose decimal exponent lies in [-3, 2] are written in positional
    notation (e.g. ``0.012`` -> ``$0.012$``); all others in scientific
    notation (e.g. ``1.234e-5`` -> ``$1.2\times 10^{-5}$``). NaN and
    infinities become ``$\mathrm{NaN}$`` and ``$\infty$`` / ``$-\infty$``.
    The result always starts and ends with ``$``.
    """
    # '{:.1e}' renders these as 'nan'/'inf', which have no 'e' to split on
    if np.isnan(number):
        return r'$\mathrm{NaN}$'
    if np.isinf(number):
        return r'$-\infty$' if number < 0 else r'$\infty$'

    string = f'{number:.1e}'
    sign = ''
    if string.startswith('-'):
        sign = '-'
        string = string[1:]
    mantissa, exponent = string.split('e')
    exponent = int(exponent)  # int() accepts '+05' as well as '-05'
    if -3 <= exponent <= 2:
        positional = np.format_float_positional(
            number, precision=2, fractional=False, unique=False)
        return f'${positional}$'
    # Raw f-string: in a plain f-string the '\t' of '\times' would become a
    # TAB character.
    return rf'${sign}{mantissa}\times 10^{{{exponent}}}$'

# Row order of the table (test cases); commented-out names are left out.
testcase_names = ['reference', 'noisy distances', 'shifted RVE', 'extended RVE', 'reflected', 'rotated', 'scaled']  #, 'diameter 0.8 (unbuckled)', 'diameter 0.8 (buckled)', 'finer mesh']

# Column labels, assigned positionally to the models in the order of `results`.
model_names = ['GNN',
               'GNN, DA ×1',
               'GNN, DA ×2',
               'EGNN',
               'EGNN, DA ×1',
               'EGNN, DA ×2',
               'SimEGNN',
               #  'EGNNmod2_bigger'
               ]

# A higher R2 is better; for MSE and FVU (= 1 - R2) lower is better.
pick_best = np.argmax if metric == 'R2' else np.argmin

print(r'\begin{table}')
for var, symbol, name in zip(['y', 'W', 'P', 'D'],
                             [r'$\vec{w}$',
                              r'$\mathfrak{W}$',
                              r'$\textbf{\textrm{P}}$',
                              r'$\textbf{\textrm{D}}$'
                              ],
                             ['microfluctuation', 'strain energy density',
                              'first Piola-Kirchhoff stress tensor',
                              'stiffness tensor']):

    # rows: test cases (in the order of results['GNN']), columns: models
    df = pd.DataFrame(results2[var], index=results['GNN'].keys())

    # reorder the rows into the display order
    df = df.loc[testcase_names]

    # rename the columns
    df.columns = model_names

    # column index of the best model for every test case
    inds = pick_best(df.values, axis=1)

    # apply formatting (DataFrame.applymap was renamed DataFrame.map in
    # pandas 2.1; fall back for older versions)
    elementwise = df.map if hasattr(df, 'map') else df.applymap
    df = elementwise(formatter)

    # make best result per row boldface: strip the surrounding '$' and re-wrap
    for i, ind in enumerate(inds):
        df.iloc[i, ind] = r'$\mathbf{' + df.iloc[i, ind][1:-1] + '}$'

    print(r'\centering')
    print(r'\tiny')
    print(r'\caption{', metric, ' of the ', f'{name} {symbol}', r'\label{tab:results', var, '}}', sep='')
    print(df.to_latex(escape=False))
    print(r'\bigskip')
    print('')

print(r'\end{table}')

# LaTeX table of relative errors (results_frob.pkl)


In [None]:
import pandas as pd
import pickle
import numpy as np

with open('results/final_results/results_frob.pkl', 'rb') as f:
    results = pickle.load(f)

# Regroup results[model][testcase][var] into results2[var][model], a list of
# relative errors ordered like the test cases in results[model].
results2 = {}
for model, per_case in results.items():
    for key, per_var in per_case.items():
        for var, rel_err in per_var.items():
            results2.setdefault(var, {}).setdefault(model, []).append(rel_err)

# %%
def formatter(number):
    r"""Format a fraction as an inline LaTeX percentage.

    *number* is multiplied by 100 and written in positional notation with two
    significant digits, e.g. ``0.0123`` -> ``$1.2\%$``. The result always
    starts with ``$`` and ends with ``\%$``.
    """
    string = np.format_float_positional(number*100, precision=2,
                                        fractional=False, unique=False)
    # Raw f-string: '\%' is not a valid Python escape sequence (SyntaxWarning
    # on Python 3.12+).
    return rf'${string}\%$'

# Row order of the table (test cases); commented-out names are left out.
testcase_names = ['reference', #'noisy distances',
                  'shifted RVE', 'extended RVE', 'reflected', 'rotated', 'scaled']  #, 'diameter 0.8 (unbuckled)', 'diameter 0.8 (buckled)', 'finer mesh']

# Column labels, assigned positionally to the models in the order of `results`.
model_names = ['GNN',
               'GNN, DA ×1',
               'GNN, DA ×2',
               'EGNN',
               'EGNN, DA ×1',
               'EGNN, DA ×2',
               'SimEGNN',
               #  'EGNNmod2_bigger'
               ]

print(r'\begin{table}')
for var, symbol, name in zip(['y', 'W', 'P', 'D'],
                             [r'$\vec{w}$',
                              r'$\mathfrak{W}$',
                              r'$\textbf{\textrm{P}}$',
                              r'$\textbf{\textrm{D}}$'
                              ],
                             ['microfluctuation', 'strain energy density',
                              'first Piola-Kirchhoff stress tensor',
                              'stiffness tensor']):

    # rows: test cases (in the order of results['GNN']), columns: models
    df = pd.DataFrame(results2[var], index=results['GNN'].keys())

    # reorder the rows into the display order
    df = df.loc[testcase_names]

    # rename the columns
    df.columns = model_names

    # the lowest relative error is best
    inds = np.argmin(df.values, axis=1)

    # apply formatting (DataFrame.applymap was renamed DataFrame.map in
    # pandas 2.1; fall back for older versions)
    elementwise = df.map if hasattr(df, 'map') else df.applymap
    df = elementwise(formatter)

    # make best result per row boldface: strip the leading '$' and trailing
    # '\%$' (3 characters) and re-wrap
    for i, ind in enumerate(inds):
        df.iloc[i, ind] = r'$\mathbf{' + df.iloc[i, ind][1:-3] + r'}\%$'

    print(r'\centering')
    print(r'\tiny')
    print(r'\caption{Relative error of the ', f'{name} {symbol}', r'\label{tab:resultsfrob', var, '}}', sep='')
    print(df.to_latex(escape=False))
    print(r'\bigskip')
    print('')

print(r'\end{table}')