# Import and compare results of multiple GNNs


In [None]:
import mlflow
import matplotlib.pyplot as plt

## Compare the number of message-passing (MP) steps

In [None]:
# Pick the runs of experiment 5 that take part in the MP-step comparison:
# names starting with '(' or 'globalAttr', plus 'no_bifurcation_5_globalAttr'.
# For every selected run, print its id, its name and (if any) its MLflow note.
# NOTE(review): `mlflow.list_run_infos` was removed in MLflow 2.0; newer
# versions need `mlflow.search_runs` / `MlflowClient().search_runs` instead.
ids = []
for info in mlflow.list_run_infos(experiment_id='5'):
    run = mlflow.get_run(run_id=info.run_id)
    tags = run.data.tags
    if 'mlflow.runName' not in tags:
        # runs without a name cannot match any of the patterns
        continue
    name = tags['mlflow.runName']
    wanted = (name.startswith(('(', 'globalAttr'))
              or name == 'no_bifurcation_5_globalAttr')
    if not wanted:
        continue
    ids.append(info.run_id)
    print(info.run_id, end='\t')
    print(name, end='\t')
    print(tags.get('mlflow.note.content', ''), end='')
    print('')
client = mlflow.tracking.MlflowClient()

# Split the selected runs into those whose name contains 'globalAttr' and the
# rest. For each group, collect the validation MSEs stored on the run together
# with its reuse_layers setting and the resulting number of MP steps.
columns = ['MSE_pos', 'MSE_W', 'reuse_layers', 'nMPsteps']
data_gA = {column: [] for column in columns}
data_no_gA = {column: [] for column in columns}

for run_id in ids:
    run = mlflow.get_run(run_id=run_id)
    is_global_attr = 'globalAttr' in run.data.tags['mlflow.runName']
    group = data_gA if is_global_attr else data_no_gA

    group['MSE_W'].append(run.data.metrics['val MSE W'])
    group['MSE_pos'].append(run.data.metrics['val MSE pos'])
    # reuse_layers is logged as a string; evaluating it and summing the result
    # gives the number of MP steps.
    # NOTE(review): eval of a logged param - ast.literal_eval would be safer if
    # the value is always a plain literal; confirm before switching.
    reuse_layers = run.data.params['reuse_layers']
    n_steps = sum(eval(reuse_layers))
    group['reuse_layers'].append(reuse_layers)
    group['nMPsteps'].append(n_steps)

# Render text through LaTeX and load the AMS packages, so the \mathfrak used
# in the MSE W axis label is available.
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amssymb} \usepackage{{amsmath}}'
import numpy as np

%matplotlib qt
# plot MSE W
plt.plot(data_gA['nMPsteps'], data_gA['MSE_W'], label='with global attribute')

inds = np.argsort(data_no_gA['nMPsteps'])
plt.plot([data_no_gA['nMPsteps'][ind] for ind in inds], [data_no_gA['MSE_W'][ind] for ind in inds], label='without global attribute')

plt.grid()
plt.xlabel('MP steps')
plt.ylabel('MSE $\mathfrak{W}$')
plt.legend()
plt.xticks(ticks=data_gA['nMPsteps'])
plt.gca().set_xticklabels(data_gA['reuse_layers'])
plt.gcf().autofmt_xdate()
plt.subplots_adjust(bottom=0.2)
%matplotlib qt
# plot MSE pos
plt.plot(data_gA['nMPsteps'], data_gA['MSE_pos'], label='with global attribute')
plt.plot([data_no_gA['nMPsteps'][ind] for ind in inds], [data_no_gA['MSE_pos'][ind] for ind in inds], label='without global attribute')
plt.grid()
plt.xlabel('MP steps')
plt.ylabel('MSE position')
plt.legend()
plt.xticks(ticks=data_gA['nMPsteps'])
plt.gca().set_xticklabels(data_gA['reuse_layers'])
plt.gcf().autofmt_xdate()
plt.subplots_adjust(bottom=0.2)

# Import and compare results of multiple GNNs

In [None]:
import mlflow
import matplotlib.pyplot as plt
import matplotlib

In [None]:
# Print one tab-separated line per run in experiment 5: run id, the first 15
# characters of the 'layers' param, the 'reuse_layers' param, and then the
# run name and note when those tags are present.
for info in mlflow.list_run_infos(experiment_id='5'):
    print(info.run_id, end='\t')
    run = mlflow.get_run(run_id=info.run_id)
    params, tags = run.data.params, run.data.tags
    print(params['layers'][:15], end='\t')
    print(params['reuse_layers'], end='\t')
    if 'mlflow.runName' in tags:
        print(tags['mlflow.runName'], end='\t')
    if 'mlflow.note.content' in tags:
        print(tags['mlflow.note.content'], end='')
    print('')

## Compare batch sizes

In [None]:
# MLflow run ids compared in the batch-size plot below.
# Earlier selection, kept for reference:
# ids = ['c82a464bd54a4933ab85f6a71bdba2cd', 'e033798fb36c4bc59b149255ce43a7fe', '70b3d708f9534d1288112614331c0a3e', 'fec4d3292a524fd095488697b5f3f5ef', 'fe639eeb96d84839ac056b129148e547', '308029a8ef724573b79bc8b5ba3be5a9', '969223051e974a168efa12b64cc38738']
ids = [
    'cd8e6f5c223f4178a27f316c59cd66b7',
    '5fe30eae970d402097abf91ac8eff04d',
    'de28ea97dc754df9ba75f313c17a6385',
]

In [None]:
client = mlflow.tracking.MlflowClient()

# Per-run validation histories for the batch-size comparison plot.
mse_val1 = []  # 'val MSE pos' history of each run
mse_val2 = []  # 'val MSE W_scaled' history of each run, times scaling_factor_W**2
bs_arr = []    # batch size of each run
for run_id in ids:
    run = mlflow.get_run(run_id=run_id)
    print(run.data.tags['mlflow.runName'])

    hist = client.get_metric_history(run_id, 'val MSE pos')
    mse_val1.append([elem.value for elem in hist])

    # MLflow stores params as strings, so the factor must be converted before
    # doing arithmetic with it (a str raised to a power is a TypeError).
    # NOTE(review): multiplying by the squared factor assumes
    # W_scaled = W / scaling_factor_W - confirm against the training code.
    scale_factor = float(run.data.params['scaling_factor_W'])
    hist = client.get_metric_history(run_id, 'val MSE W_scaled')
    mse_val2.append([elem.value * scale_factor**2 for elem in hist])

    # NOTE(review): eval of a logged param. Presumably lr_schedule is a literal
    # sequence whose first entry holds the batch size at index 2 - confirm.
    bs = eval(run.data.params['lr_schedule'])[0][2]
    bs_arr.append(bs)

print(mse_val1)
print(mse_val2)
print(bs_arr)

In [None]:
%matplotlib qt
colors = plt.cm.tab20(np.arange(20))
plt.gca().set_prop_cycle('color', colors)
for mse1, mse2, bs in zip(mse_val1, mse_val2, bs_arr): #, ['tab:blue', 'tab:orange', 'tab:green']):
    x = np.linspace(0, 100, len(mse1))
    plt.plot(x, mse1, label=f'batch size {bs}, MSE position')
    plt.plot(x, mse2, label=f'batch size {bs}' + ' MSE $\mathfrak{W}$') #, linestyle='--')
# plt.plot(mp_steps, mse_val, label='validation', marker='o')
plt.grid()
plt.yscale('log')
plt.xlabel('training progress')
plt.ylabel('MSE')
plt.legend()
plt.gca().xaxis.set_major_formatter(matplotlib.ticker.PercentFormatter())
# plt.gca().xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))

## Compare the number of message-passing (MP) steps

In [None]:
# Pick the runs of experiment 5 that take part in the MP-step comparison:
# names starting with '(' or 'globalAttr', plus 'no_bifurcation_5_globalAttr'.
# For every selected run, print its id, its name and (if any) its MLflow note.
# NOTE(review): `mlflow.list_run_infos` was removed in MLflow 2.0; newer
# versions need `mlflow.search_runs` / `MlflowClient().search_runs` instead.
ids = []
for info in mlflow.list_run_infos(experiment_id='5'):
    run = mlflow.get_run(run_id=info.run_id)
    tags = run.data.tags
    if 'mlflow.runName' not in tags:
        # runs without a name cannot match any of the patterns
        continue
    name = tags['mlflow.runName']
    wanted = (name.startswith(('(', 'globalAttr'))
              or name == 'no_bifurcation_5_globalAttr')
    if not wanted:
        continue
    ids.append(info.run_id)
    print(info.run_id, end='\t')
    print(name, end='\t')
    print(tags.get('mlflow.note.content', ''), end='')
    print('')

In [None]:
client = mlflow.tracking.MlflowClient()

# Split the selected runs into those whose name contains 'globalAttr' and the
# rest. For each group, collect the validation MSEs stored on the run together
# with its reuse_layers setting and the resulting number of MP steps.
columns = ['MSE_pos', 'MSE_W', 'reuse_layers', 'nMPsteps']
data_gA = {column: [] for column in columns}
data_no_gA = {column: [] for column in columns}

for run_id in ids:
    run = mlflow.get_run(run_id=run_id)
    is_global_attr = 'globalAttr' in run.data.tags['mlflow.runName']
    group = data_gA if is_global_attr else data_no_gA

    group['MSE_W'].append(run.data.metrics['val MSE W'])
    group['MSE_pos'].append(run.data.metrics['val MSE pos'])
    # reuse_layers is logged as a string; evaluating it and summing the result
    # gives the number of MP steps.
    # NOTE(review): eval of a logged param - ast.literal_eval would be safer if
    # the value is always a plain literal; confirm before switching.
    reuse_layers = run.data.params['reuse_layers']
    n_steps = sum(eval(reuse_layers))
    group['reuse_layers'].append(reuse_layers)
    group['nMPsteps'].append(n_steps)

In [None]:
# Render text through LaTeX and load the AMS packages, so the \mathfrak used
# in the MSE W axis label is available.
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amssymb} \usepackage{{amsmath}}'

In [None]:
import numpy as np

In [None]:
%matplotlib qt
# plot MSE W
plt.plot(data_gA['nMPsteps'], data_gA['MSE_W'], label='with global attribute')

inds = np.argsort(data_no_gA['nMPsteps'])
plt.plot([data_no_gA['nMPsteps'][ind] for ind in inds], [data_no_gA['MSE_W'][ind] for ind in inds], label='without global attribute')

plt.grid()
plt.xlabel('MP steps')
plt.ylabel('MSE $\mathfrak{W}$')
plt.legend()
plt.xticks(ticks=data_gA['nMPsteps'])
plt.gca().set_xticklabels(data_gA['reuse_layers'])
plt.gcf().autofmt_xdate()
plt.subplots_adjust(bottom=0.2)

In [None]:
%matplotlib qt
# plot MSE pos
plt.plot(data_gA['nMPsteps'], data_gA['MSE_pos'], label='with global attribute')
plt.plot([data_no_gA['nMPsteps'][ind] for ind in inds], [data_no_gA['MSE_pos'][ind] for ind in inds], label='without global attribute')
plt.grid()
plt.xlabel('MP steps')
plt.ylabel('MSE position')
plt.legend()
plt.xticks(ticks=data_gA['nMPsteps'])
plt.gca().set_xticklabels(data_gA['reuse_layers'])
plt.gcf().autofmt_xdate()
plt.subplots_adjust(bottom=0.2)

# Loss plot

In [None]:
import mlflow
import numpy as np

In [None]:

# MLflow run ids of the models in the loss plot, paired by position with the
# display names in model_names.
run_ids = [
    '9453aa73b9ee4b42ac9560ff37693d6f',  # GNN
    'f25933db5e5545388265e2e9261edca3',  # GNN, DA ×1
    '64e45fceb1eb46b6a977851c7123bca8',  # GNN, DA ×2
    '0f54a8a568094a3085cc387fe21c3d38',  # EGNN
    'bdca87c393db4da3b387a805426d25d2',  # EGNN, DA ×1
    'a49f45e0cc3743ab914283c4ea0e6d60',  # EGNN, DA ×2
    'f09812771fec478eb1195216f3ae018d',  # SimEGNN
]
model_names = [
    'GNN',
    'GNN, DA ×1',
    'GNN, DA ×2',
    'EGNN',
    'EGNN, DA ×1',
    'EGNN, DA ×2',
    'SimEGNN',
    # 'EGNNmod2_bigger',  # currently excluded
]

In [None]:
client = mlflow.tracking.MlflowClient()

# data[model_name] maps each loss component ('pos', 'W', 'P', 'D') to its
# 'val MSE <var>_scaled' history as a numpy array, plus 'total loss', which is
# the element-wise sum of the four components.
data = {}

for run_id, model_name in zip(run_ids, model_names):
    data[model_name] = {}
    run = mlflow.get_run(run_id=run_id)
    run_name = run.data.tags['mlflow.runName']
    print('run_name:', run_name)
    print('model_name:', model_name)
    # Params are strings; presumably weight_losses looks like "[w0 w1 w2 w3]"
    # (brackets stripped, whitespace-separated floats) - confirm.
    # The weights are currently not applied (see the commented factors below).
    weight_losses = run.data.params['weight_losses']
    weight_losses = [float(elem) for elem in weight_losses[1:-1].split()]

    for i, var in enumerate(['pos', 'W', 'P', 'D']):
        hist = client.get_metric_history(run_id, 'val MSE '+ var + '_scaled')
        hist = np.array([elem.value for elem in hist])
        data[model_name][var] = hist

        print('hist:', hist)
        print('type(hist):', type(hist))

        # Start the total from a *copy*: storing `hist` itself and then using
        # `+=` would modify data[model_name]['pos'] in place, so the 'pos'
        # history would silently end up holding the total loss.
        if 'total loss' not in data[model_name]:
            data[model_name]['total loss'] = hist.copy()  #*weight_losses[i]
        else:
            data[model_name]['total loss'] += hist  #*weight_losses[i]

data

In [None]:
# Density histograms (log y axis) of the per-epoch validation MSE of each
# SimEGNN loss component, using 100 edges between 0 and 0.05.
bin_edges = np.linspace(0, 0.05, 100)
for var in ['pos', 'W', 'P', 'D']:
    plt.hist(data['SimEGNN'][var], bins=bin_edges, label=var, density=True, histtype='step')
plt.legend()
plt.yscale('log')

In [None]:
# One epoch number per logged validation value, counting from -1
# (presumably the first entry is an evaluation before training - confirm).
epochs = np.arange(len(data['GNN']['total loss'])) - 1

In [None]:
# Render text through LaTeX with the AMS packages loaded (so \mathfrak is
# available) and use a large base font size for the final figure.
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amssymb} \usepackage{{amsmath}}'
plt.rcParams['font.size'] = 24

In [None]:
%matplotlib qt

plt.figure(figsize=(13,8))
# plot MSE W
models = ['GNN',
          'GNN, DA ×1', 'GNN, DA ×2',
         'EGNN',
         'EGNN, DA ×1', 'EGNN, DA ×2',
          'SimEGNN']

colors = [plt.get_cmap('tab20c').colors[0],
          plt.get_cmap('tab20c').colors[1],
          plt.get_cmap('tab20c').colors[3],
          plt.get_cmap('tab20c').colors[4],
          plt.get_cmap('tab20c').colors[5],
          plt.get_cmap('tab20c').colors[7],
          'tab:green']
# linestyles = ['solid',
#               'dashed', 'dotted',
#               'solid',
#               'dashed', 'dotted',
#               'solid']

for model_name, c in zip(model_names, colors):
    plt.plot(epochs[1:], data[model_name]['total loss'][1:], label=model_name, c=c)

plt.grid()
plt.xlabel('epoch')
plt.ylabel('validation loss')
plt.yscale('log')
plt.legend()
# plt.xticks(ticks=data_gA['nMPsteps'])
# plt.gca().set_xticklabels(data_gA['reuse_layers'])
# plt.gcf().autofmt_xdate()
# plt.subplots_adjust(bottom=0.2)


plt.gcf().savefig('results/final_results/loss_plot.png', ddpi=600, bbox_inches='tight')