Load a previously trained flow-matching model from MLflow and evaluate it on the test set.

# Imports

In [None]:
import os
import pickle
import sys
import importlib

import mlflow
import torch
import torch.nn as nn
from torch import nn, Tensor
import torch_geometric as tg
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('cuda available' if use_cuda else 'cuda not available')

# Import model

In [None]:
# MLflow run to evaluate. Known EGNN runs:
#   '167fac2522fd465f9207c582ada7c716' - uses symmetric matching (selected)
#   '7d5e596c25b44b06bd90c38f51fd3142' - without symmetric matching
run_id = '167fac2522fd465f9207c582ada7c716'

In [None]:

# MLflow tracking client used below to query run parameters and artifacts.
# Passing tracking_uri=None is the same as the no-argument default.
client = mlflow.tracking.MlflowClient(tracking_uri=None)

In [None]:

from urllib.parse import urlparse
from urllib.request import url2pathname

# Convert the run's artifact URI into a local filesystem path. Stripping a fixed
# 8 characters ('file:///') only works for Windows drive paths; on POSIX it would
# drop the leading '/' and produce a relative path.
artifact_uri = mlflow.get_run(run_id).info.artifact_uri
parsed_uri = urlparse(artifact_uri)
if parsed_uri.scheme == 'file':
    art_path = url2pathname(parsed_uri.path)
elif parsed_uri.scheme == '' or len(parsed_uri.scheme) == 1:
    # already a plain path (a one-letter "scheme" is a Windows drive letter)
    art_path = artifact_uri
else:
    raise ValueError(f'Artifacts are not stored on the local filesystem: {artifact_uri}')
print(art_path)

# Fetch the logged artifacts
artifacts = client.list_artifacts(run_id)


In [None]:
# get logged parameter network_type
# Architecture family the run was trained with, read from its logged MLflow parameters.
run_params = client.get_run(run_id).data.params
network_type = run_params['network_type']
print('network_type:', network_type)

In [None]:

# Find the file holding the trained weights among the run's logged artifacts.
weight_files = [artifact.path for artifact in artifacts if 'model_weights.pt' in artifact.path]
if not weight_files:
    # Fail here instead of with a NameError on model_path in a later cell.
    raise FileNotFoundError(f"No 'model_weights.pt' artifact found for run {run_id}")
if len(weight_files) > 1:
    print('Multiple weight files found, using the last one:', weight_files)
model_path = os.path.join(art_path, weight_files[-1])
print(model_path)


In [None]:

# Make sure the model definition logged with this run is found first on sys.path.
# Remove any earlier copy so re-running the cell does not pile up duplicate entries.
if art_path in sys.path:
    sys.path.remove(art_path)
sys.path.insert(0, art_path)

# The model definition is the artifact file whose name starts with network_type.
files = os.listdir(art_path)
GNN_definition = [file for file in files if file.startswith(network_type)]
if len(GNN_definition) > 1:
    raise RuntimeError(f'Multiple GNN definitions found: {GNN_definition}')
elif len(GNN_definition) == 0:
    raise FileNotFoundError(f'No GNN definition found for {network_type!r} in {art_path}')
print('GNN_definition:', GNN_definition[0])


In [None]:
# Module name = file name up to its first '.'.
module_name = GNN_definition[0].partition('.')[0]
# Reload so a definition that was already imported (e.g. from another run) is re-read from disk.
model_def = importlib.reload(importlib.import_module(module_name))


In [None]:
class Flow(nn.Module):
    """Wraps the trained (E)GNN that predicts the flow-matching velocity.

    The layer configuration is hard-coded and must match the one used for
    training the selected run; otherwise loading the saved state dict fails.
    Reads the module-level ``network_type``, ``model_def`` and ``device``.
    """

    def __init__(self):
        super().__init__()
        if network_type == 'GNNtimeConv':
            # node_in, edge_in, message_size, node_out, edge_out
            # (single-layer variant used earlier: [(4, 3, 1, 2, 0)], reuse_layers=(1,))
            layers = [(4, 3, 32, 32, 32),
                      (32, 32, 32, 32, 32),
                      (32, 32, 1, 2, 0)]
            reuse_layers = (1, 1, 1)
            self.model = model_def.GNN(layers=layers, reuse_layers=reuse_layers).to(device)
        elif network_type == 'EGNNtimeConv':
            # presumably the same tuple layout as above — confirm against model_def.EGNN
            # (single-layer variant used earlier: [(4, 3, 1, 0, 0)], reuse_layers=(1,))
            layers = [(4, 3, 32, 32, 32),
                      (32, 32, 32, 32, 32),
                      (32, 32, 1, 0, 0)]
            reuse_layers = (1, 1, 1)
            self.model = model_def.EGNN(layers=layers, reuse_layers=reuse_layers).to(device)
        else:
            # Without this, self.model would be missing and the first forward
            # pass would fail with an unrelated-looking AttributeError.
            raise ValueError(f'Unsupported network_type: {network_type!r}')

    def forward(self, t, x_t, batch, verbose=False) -> Tensor:
        """Return the model's velocity/shift for positions ``x_t`` at flow time ``t``.

        t:   tensor of shape [batch size], current flow-matching time
        x_t: tensor of shape [batch size, 2, T], current position
        """
        x_shift = self.model(batch=batch, current_pos=x_t, tau=t, verbose=verbose)

        return x_shift

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor, batch) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one midpoint-method step.

        x_t:     tensor of shape [batch size, 2, T], current position
        t_start: 0-dim tensor, current time
        t_end:   0-dim tensor, end time
        """
        n = x_t.shape[0]
        # broadcast the scalar times over the leading dim of x_t
        t_start = t_start.expand(n).view(-1, 1, 1)
        t_end = t_end.expand(n).view(-1, 1, 1)
        dt = t_end - t_start

        # half Euler step to the midpoint, then a full step with the midpoint velocity
        x_mid = x_t + self(x_t=x_t, t=t_start.view(-1), batch=batch) * dt / 2
        t_mid = (t_start + dt / 2).view(-1)
        return x_t + dt * self(t=t_mid, x_t=x_mid, batch=batch)


flow = Flow().to(device)
print(flow)

In [None]:

# map_location lets weights saved from a GPU run be loaded on a CPU-only machine.
flow.load_state_dict(torch.load(model_path, map_location=device))
flow.to(device)
# Inference only from here on: put layers such as dropout/batch-norm (if any) in eval mode.
flow.eval()
print(flow)

# Load data

In [None]:
data_folder = r'path/to/data/folder'  # TODO: set this to your local data folder

data_path = os.path.join(data_folder, 'BucklingBeams_data_fullyConnected.pkl')
# Only print the path: mlflow.log_param without an active run would start a new,
# unrelated MLflow run just for this evaluation notebook.
print('data_path:', data_path)
# NOTE: pickle.load can execute arbitrary code; only load data files you trust.
with open(data_path, 'rb') as f:
    data = pickle.load(f)


In [None]:
# Move the length feature L_i (column 0 of edge_attr) into its own attribute,
# because it has to be scaled separately from the other edge features.
for key in ['data_tr', 'data_te']:
    for graph in data[key]:
        if 'L_init' in graph:
            # Already processed (cell re-run): stripping column 0 again would
            # silently remove a genuine edge feature.
            continue

        N = graph.N[0].item()

        # Disabled: dividing K_i by L_i to get proper scaling.
        # Li = graph.edge_attr[:N, [0]]
        # graph.node_attr[0, 3:] /= Li[0]  # divide K_i node 0 by L_0
        # graph.node_attr[1:N, 3:] /= 0.5*(Li[:N-1] + Li[1:])   # divide K_i nodes 1-N by 1/2(L_i+L_i-1)

        graph.L_init = graph.edge_attr[:, 0]
        graph.edge_attr = graph.edge_attr[:, 1:]  # remove L_i from edge_attr

        # total length of beam (assumes the first N edges are the beam elements — confirm)
        graph.L_tot = graph.L_init[:N].sum().reshape(1,)


# Test model

In [None]:
prior_type = client.get_run(run_id).data.params['prior_type']
print('prior_type:', prior_type)

# Hard-coded prior hyper-parameters (presumably fitted to the training data — confirm).
if prior_type == 'Prior_wide':
    mu_phi, std_phi = -0.00030389729903857826, 0.014231439025614998
    mu_eps, std_eps = 2.0113846990910993e-05, 0.008196196348061435
    a, b = 4.042343191342216e-05, 0.0025632507
    # NOTE(review): mu_phi is defined but not passed to Prior_wide — confirm this is intended.
    prior = model_def.Prior_wide(std_phi=std_phi, mu_eps=mu_eps, std_eps=std_eps, a=a, b=b)
elif prior_type == 'Prior':
    lamb = 94.09494942436929
    mu_eps = 2.0113846990910993e-05
    std_eps = 0.008196196348061435
    prior = model_def.Prior(lamb=lamb, mu_eps=mu_eps, std_eps=std_eps)
else:
    # Otherwise `prior` would be undefined, or stale from an earlier run of the notebook.
    raise ValueError(f'Unsupported prior_type: {prior_type!r}')

In [None]:
test_loader3 = tg.loader.DataLoader(data['data_te'], batch_size=2)

n_steps = 64  # number of midpoint steps (flow.step) from t=0 to t=1
time_steps = torch.linspace(0, 1.0, n_steps+1).to(device)

n_samples = 50

# Find one test example with N=2 and draw n_samples predictions for its first graph.
with torch.no_grad():
    for j, batch in enumerate(test_loader3):
        if batch.N[0].item() != 2:
            print('Skipping batch with N != 2, N =', batch.N[0].item())
            continue

        print('Batch:', j)
        print(batch)
        batch = batch.to(device)

        # Loop-invariant inputs, computed once rather than once per sample.
        e = batch.edge_index
        L_init_temp = batch.L_init[e[0] == e[1]-1]  # use only the beam element edges, not the virtual ones, not the reversed ones
        real = batch.pos[batch.batch==0, ..., 0].cpu().detach().numpy()  # final position (target)

        # (n_samples, n_nodes, dim, n_timesteps), sized from the data instead of hard-coded
        preds = np.empty((n_samples, *real.shape))
        for s in range(n_samples):
            # new draw from the prior for each sample (prior presumably stochastic)
            x = prior(batch.N, batch.d, batch.node_attr[..., :3], L_init_temp, batch.batch, batch.L_tot)

            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], batch=batch)

            pred = x[batch.batch==0].cpu().detach().numpy()
            preds[s] = pred

        break
    else:
        # Without this, later cells would plot uninitialized memory and an undefined `real`.
        raise RuntimeError('No test batch with N == 2 found')

## Bifurcation diagram

In [None]:
import matplotlib.ticker as mtick

In [None]:
fig, ax = plt.subplots(figsize=(4,3), dpi=200)

colors = plt.get_cmap('tab10')(np.arange(real.shape[0]))  # one color per node
# d: relative decrease (in %) of the last node's y-coordinate w.r.t. its value at time step 0
d = (real[-1, 1, 0] - real[-1, 1, :]) / real[-1, 1, 0] * 100

# predicted x-coordinate of every node, one dashed translucent line per sample
for s in range(n_samples):
    pred = preds[s]
    for j in range(pred.shape[0]):
        ax.plot(d, pred[j, 0, :], linestyle='--', color=colors[j], alpha=0.2)

# real x-coordinate of every node and its mirror image (-x)
for j in range(real.shape[0]):
    ax.plot(d, real[j, 0, :], color=colors[j], linewidth=2)
    ax.plot(d, -real[j, 0, :], color=colors[j], linewidth=2, label=f'Node {j}')

ax.xaxis.set_major_formatter(mtick.FormatStrFormatter('%.0f%%'))  # ticks like '40%'

ax.set_xlabel('d')
ax.set_ylabel('x-coordinate')
ax.legend()

## Plot entire trajectory, all predictions

In [None]:
fig, ax = plt.subplots(figsize=(4,3), dpi=200)

# Overlay the predicted configuration at every 10th time step for all samples.
for s in range(n_samples):
    pred = preds[s]
    n_t = pred.shape[-1]

    for t in np.arange(n_t, step=10):
        color = plt.cm.viridis(t / (n_t - 1))  # color encodes the time step
        for j in range(len(pred) - 1):  # one segment between consecutive nodes
            ax.plot(pred[j:j+2, 0, t], pred[j:j+2, 1, t], color=color)
        ax.scatter(pred[:, 0, t], pred[:, 1, t], s=3, color='black')

# Figure-level settings are applied once, after all samples have been drawn.
ax.set_title(r'predictions')
ax.set_aspect('equal')
plt.tight_layout()
plt.show()

## Plot real and predicted configurations, one figure per time step

In [None]:
t_last = real.shape[-1] - 1  # index of the final time step
for t in [20, 100, 150, t_last]:  # loop over time steps
    print('t =', t)

    fig, ax = plt.subplots(figsize=(4,3), dpi=200)

    # plot real configuration and its mirror image (-x)
    ax.plot(real[:, 0, t], real[:, 1, t], color='tab:blue', linewidth=2, label='Real', marker='o', markersize=5, zorder=10)
    ax.plot(-real[:, 0, t], real[:, 1, t], color='tab:blue', linewidth=2, marker='o', markersize=5, zorder=10)

    # plot predictions; labels starting with '_' are skipped by the legend,
    # so only the first sample contributes a 'Predicted' entry
    for s in range(n_samples):
        pred = preds[s]
        ax.plot(pred[:, 0, t], pred[:, 1, t], color='tab:orange', linewidth=2, label=s*'_'+'Predicted', marker='o', markersize=5, alpha=0.3)

    # plot the configuration at time step 0
    ax.plot(real[:, 0, 0], real[:, 1, 0], color='gray', marker='o', markersize=5, label='Initial')

    ax.set_title(f'd = {t/t_last*100:.0f}%')
    ax.set_aspect('equal')

    # Show at least x in [-1, 1] without cutting off data that extends beyond it on one side.
    lims = ax.get_xlim()
    ax.set_xlim(min(lims[0], -1), max(lims[1], 1))

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if t == t_last:
        ax.legend()
    # after the labels are set, so tight_layout accounts for them
    plt.tight_layout()
    plt.show()

# Illustration: plot the real configuration at two time steps

In [None]:
t_last = real.shape[-1] - 1  # index of the final time step
for t1, t2 in [[0, 20], [20, 40], [40, 60]]:  # pairs of time steps to compare

    fig, ax = plt.subplots(figsize=(4,3), dpi=200)

    for t, c, l in zip([t1, t2], ['tab:blue', 'tab:orange'], ['t', 't+1']):
        # plot real configuration and its mirror image (-x)
        ax.plot(real[:, 0, t], real[:, 1, t], color=c, linewidth=2,
                marker='o', markersize=5, zorder=10, label=l)
        ax.plot(-real[:, 0, t], real[:, 1, t], color=c, linewidth=2,
                marker='o', markersize=5, zorder=10)

    # plot the configuration at time step 0
    ax.plot(real[:, 0, 0], real[:, 1, 0], color='gray', marker='o', markersize=5, label='Initial')

    ax.set_title(f'$d$ from {t1/t_last*100:.0f}% to {t2/t_last*100:.0f}%')
    ax.set_aspect('equal')

    ax.set_xlabel('x')
    ax.set_ylabel('y')

    # Show at least x in [-0.8, 0.8] without cutting off data that extends beyond it on one side.
    lims = ax.get_xlim()
    ax.set_xlim(min(lims[0], -0.8), max(lims[1], 0.8))
    plt.tight_layout()

    # Legend intentionally omitted; add ax.legend(loc='lower right') to show it.
    plt.show()

# Wasserstein distance

In [None]:
prior_type = client.get_run(run_id).data.params['prior_type']
print('prior_type:', prior_type)

# Hard-coded prior hyper-parameters (presumably fitted to the training data — confirm).
if prior_type == 'Prior_wide':
    mu_phi, std_phi = -0.00030389729903857826, 0.014231439025614998
    mu_eps, std_eps = 2.0113846990910993e-05, 0.008196196348061435
    a, b = 4.042343191342216e-05, 0.0025632507
    # NOTE(review): mu_phi is defined but not passed to Prior_wide — confirm this is intended.
    prior = model_def.Prior_wide(std_phi=std_phi, mu_eps=mu_eps, std_eps=std_eps, a=a, b=b)
elif prior_type == 'Prior':
    lamb = 94.09494942436929
    mu_eps = 2.0113846990910993e-05
    std_eps = 0.008196196348061435
    prior = model_def.Prior(lamb=lamb, mu_eps=mu_eps, std_eps=std_eps)
else:
    # Otherwise `prior` would be undefined, or stale from an earlier run of the notebook.
    raise ValueError(f'Unsupported prior_type: {prior_type!r}')

In [None]:
from scipy.stats import wasserstein_distance_nd

test_loader3 = tg.loader.DataLoader(data['data_te'], batch_size=64)

n_steps = 8  # number of midpoint steps (flow.step) from t=0 to t=1
time_steps = torch.linspace(0, 1.0, n_steps+1).to(device)

T = data['data_te'][0].pos.shape[-2]  # number of time steps per trajectory
print('T =', T)

n_samples = 100

wd = []  # one Wasserstein distance per test graph
with torch.no_grad():
    for j, batch in enumerate(test_loader3):
        print('Batch:', j)
        print(batch)
        batch = batch.to(device)

        # Loop-invariant: use only the beam element edges, not the virtual ones, not the reversed ones.
        e = batch.edge_index
        L_init_temp = batch.L_init[e[0] == e[1]-1]

        # n_samples predictions per node, shape (n_nodes, dim, T, n_samples)
        preds = np.zeros((len(batch.pos), 2, T, n_samples))

        for s in range(n_samples):
            x = prior(batch.N, batch.d, batch.node_attr[..., :3], L_init_temp, batch.batch, batch.L_tot)

            for i in range(n_steps):
                x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1], batch=batch)

            preds[..., s] = x.cpu().detach().numpy()

        reals = batch.pos.cpu().detach().numpy()  # shape (n_nodes, dim, T, n_solutions)
        bbatch = batch.batch.cpu().detach().numpy()

        print('Calculate Wasserstein distance')
        # Graph ids in a batch run from 0 to num_graphs - 1; iterate over all of them
        # (range(batch.batch.max()) would skip the last graph).
        for i in range(batch.num_graphs):
            pred = preds[bbatch == i].reshape(-1, n_samples)  # shape (n_nodes×dim×T, n_samples)
            real = reals[bbatch == i].reshape(-1, reals.shape[-1])  # shape (n_nodes×dim×T, n_solutions)
            # each sample / solution is one point in n_nodes×dim×T dimensions
            wd.append(wasserstein_distance_nd(pred.T, real.T))

print('Mean Wasserstein distance:', np.mean(wd))