# Model training

This notebook can be used to train the models used in the replication experiments. It relies heavily on predefined configuration files that describe the parameter settings of each model. Pretrained models using exactly these parameters are already available, so retraining is not needed if you only wish to reproduce the replication experiments.

If you do wish to run the replication experiments with your own retrained models, retraining them with this notebook is not sufficient on its own. To prevent the training script from accidentally overwriting the pretrained models, trained models are saved to a different location than the one the pretrained models are loaded from.

**To replace the pretrained models in the replication study**, you therefore need to copy each trained model from `checkpoints` to `ExplanationEvaluation/models/pretrained/<_model>/<_dataset>`, where `_model` and `_dataset` are defined as in the code below.

In [1]:
from ExplanationEvaluation.configs.selector import Selector
from ExplanationEvaluation.tasks.training import train_node, train_graph

import torch
import numpy as np

In [2]:
# Dataset whose model should be trained.
_dataset = 'bashapes'  # One of: bashapes, bacommunity, treecycles, treegrids, ba2motifs, mutag

# Parameters below should only be changed if you want to run any of the experiments in the supplementary
_folder = 'replication'  # One of: replication, batchnorm
# The 'replication' folder uses the 'gnn' model configurations; any other folder uses 'ori'.
_model = 'gnn' if _folder == 'replication' else 'ori'

# Fail early with a readable message instead of a missing-file error further down.
_KNOWN_DATASETS = ('bashapes', 'bacommunity', 'treecycles', 'treegrids', 'ba2motifs', 'mutag')
if _dataset not in _KNOWN_DATASETS:
    raise ValueError(f"Unknown dataset {_dataset!r}; expected one of {_KNOWN_DATASETS}")

# Path to the *model* (not explainer) configuration; it is parsed in the next cell.
config_path = f"./ExplanationEvaluation/configs/{_folder}/models/model_{_model}_{_dataset}.json"

# NOTE(review): always False for the folders listed above; kept in case later cells read it.
extension = (_folder == 'extension')

In [3]:
# Parse the model configuration, then seed torch (CPU and CUDA) and numpy with the
# seed stored in it so that training runs are repeatable.
config = Selector(config_path).args
_seed = config.model.seed
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_rng(_seed)

In [4]:
# Train the model described by the parsed configuration. The dataset name stored in the
# configuration (not the `_dataset` chosen above; judging by the log, "syn1" for bashapes)
# decides which training routine is used.
_dataset = config.model.dataset
_explainer = config.model.paper

if _dataset.startswith("syn"):
    train_node(_dataset, _explainer, config.model)
elif _dataset in ("ba2", "mutag"):
    train_graph(_dataset, _explainer, config.model)
else:
    # Fail loudly rather than silently finishing the cell without training anything.
    raise ValueError(f"Don't know how to train a model for dataset {_dataset!r}")

Loading syn1 dataset
NodeGCN(
  (conv1): GCNConv(10, 20)
  (relu1): ReLU()
  (conv2): GCNConv(20, 20)
  (relu2): ReLU()
  (conv3): GCNConv(20, 20)
  (relu3): ReLU()
  (lin): Linear(in_features=60, out_features=4, bias=True)
)
Epoch: 0, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.4170
Val improved
Epoch: 1, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.4121
Epoch: 2, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.4074
Epoch: 3, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.4039
Epoch: 4, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.4007
Epoch: 5, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.3975
Epoch: 6, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.3945
Epoch: 7, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.3916
Epoch: 8, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.3887
Epoch: 9, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.3859
Epoch: 10, train_acc: 0.2161, val_acc: 0.3000, train_loss: 1.3832
Epoch: 11, train_acc: 0.2161, val_acc: 0.3000

Epoch: 135, train_acc: 0.6768, val_acc: 0.6286, train_loss: 0.9186
Epoch: 136, train_acc: 0.6821, val_acc: 0.6143, train_loss: 0.9110
Epoch: 137, train_acc: 0.6714, val_acc: 0.6000, train_loss: 0.9054
Epoch: 138, train_acc: 0.6732, val_acc: 0.6286, train_loss: 0.8984
Epoch: 139, train_acc: 0.6821, val_acc: 0.6143, train_loss: 0.8919
Epoch: 140, train_acc: 0.6804, val_acc: 0.6429, train_loss: 0.8870
Val improved
Epoch: 141, train_acc: 0.6679, val_acc: 0.6286, train_loss: 0.8806
Epoch: 142, train_acc: 0.6661, val_acc: 0.6286, train_loss: 0.8761
Epoch: 143, train_acc: 0.6893, val_acc: 0.6286, train_loss: 0.8707
Epoch: 144, train_acc: 0.7107, val_acc: 0.6571, train_loss: 0.8647
Val improved
Epoch: 145, train_acc: 0.7000, val_acc: 0.6286, train_loss: 0.8608
Epoch: 146, train_acc: 0.6875, val_acc: 0.6286, train_loss: 0.8553
Epoch: 147, train_acc: 0.6893, val_acc: 0.6286, train_loss: 0.8506
Epoch: 148, train_acc: 0.6982, val_acc: 0.6286, train_loss: 0.8460
Epoch: 149, train_acc: 0.7071, val_a

Epoch: 257, train_acc: 0.8536, val_acc: 0.9000, train_loss: 0.6279
Epoch: 258, train_acc: 0.8536, val_acc: 0.9000, train_loss: 0.6319
Epoch: 259, train_acc: 0.8589, val_acc: 0.9000, train_loss: 0.6302
Epoch: 260, train_acc: 0.8589, val_acc: 0.9000, train_loss: 0.6233
Epoch: 261, train_acc: 0.8571, val_acc: 0.9000, train_loss: 0.6224
Epoch: 262, train_acc: 0.8571, val_acc: 0.9000, train_loss: 0.6207
Epoch: 263, train_acc: 0.8571, val_acc: 0.9000, train_loss: 0.6219
Epoch: 264, train_acc: 0.8607, val_acc: 0.9000, train_loss: 0.6180
Epoch: 265, train_acc: 0.8625, val_acc: 0.9000, train_loss: 0.6220
Epoch: 266, train_acc: 0.8625, val_acc: 0.9000, train_loss: 0.6216
Epoch: 267, train_acc: 0.8589, val_acc: 0.9000, train_loss: 0.6186
Epoch: 268, train_acc: 0.8500, val_acc: 0.9000, train_loss: 0.6146
Epoch: 269, train_acc: 0.8536, val_acc: 0.9000, train_loss: 0.6152
Epoch: 270, train_acc: 0.8554, val_acc: 0.9000, train_loss: 0.6146
Epoch: 271, train_acc: 0.8607, val_acc: 0.9000, train_loss: 0.

Epoch: 389, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.4987
Epoch: 390, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.5088
Epoch: 391, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.5109
Epoch: 392, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.5023
Epoch: 393, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.5022
Epoch: 394, train_acc: 0.9625, val_acc: 0.9714, train_loss: 0.5007
Epoch: 395, train_acc: 0.9607, val_acc: 0.9714, train_loss: 0.4966
Epoch: 396, train_acc: 0.9625, val_acc: 0.9714, train_loss: 0.4972
Epoch: 397, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.4947
Epoch: 398, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.4950
Epoch: 399, train_acc: 0.9625, val_acc: 0.9714, train_loss: 0.4939
Epoch: 400, train_acc: 0.9625, val_acc: 0.9714, train_loss: 0.4925
Epoch: 401, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.4923
Epoch: 402, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.4897
Epoch: 403, train_acc: 0.9643, val_acc: 0.9714, train_loss: 0.

Epoch: 515, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.4058
Epoch: 516, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.4031
Epoch: 517, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.4041
Epoch: 518, train_acc: 0.9696, val_acc: 0.9714, train_loss: 0.4042
Epoch: 519, train_acc: 0.9696, val_acc: 0.9714, train_loss: 0.4011
Epoch: 520, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.4017
Epoch: 521, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.3997
Epoch: 522, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.3996
Epoch: 523, train_acc: 0.9696, val_acc: 0.9714, train_loss: 0.3965
Epoch: 524, train_acc: 0.9696, val_acc: 0.9714, train_loss: 0.3951
Epoch: 525, train_acc: 0.9696, val_acc: 0.9714, train_loss: 0.3998
Epoch: 526, train_acc: 0.9696, val_acc: 0.9714, train_loss: 0.3957
Epoch: 527, train_acc: 0.9661, val_acc: 0.9571, train_loss: 0.3947
Epoch: 528, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.3967
Epoch: 529, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.

Epoch: 649, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.3289
Epoch: 650, train_acc: 0.9661, val_acc: 0.9571, train_loss: 0.3295
Epoch: 651, train_acc: 0.9732, val_acc: 0.9714, train_loss: 0.3246
Epoch: 652, train_acc: 0.9679, val_acc: 0.9571, train_loss: 0.3233
Epoch: 653, train_acc: 0.9679, val_acc: 0.9571, train_loss: 0.3295
Epoch: 654, train_acc: 0.9643, val_acc: 0.9571, train_loss: 0.3260
Epoch: 655, train_acc: 0.9661, val_acc: 0.9571, train_loss: 0.3270
Epoch: 656, train_acc: 0.9714, val_acc: 0.9571, train_loss: 0.3262
Epoch: 657, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.3219
Epoch: 658, train_acc: 0.9661, val_acc: 0.9571, train_loss: 0.3199
Epoch: 659, train_acc: 0.9643, val_acc: 0.9571, train_loss: 0.3278
Epoch: 660, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.3266
Epoch: 661, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.3237
Epoch: 662, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.3222
Epoch: 663, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.

Epoch: 776, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2705
Epoch: 777, train_acc: 0.9696, val_acc: 0.9571, train_loss: 0.2739
Epoch: 778, train_acc: 0.9679, val_acc: 0.9571, train_loss: 0.2716
Epoch: 779, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2710
Epoch: 780, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2701
Epoch: 781, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2698
Epoch: 782, train_acc: 0.9714, val_acc: 0.9571, train_loss: 0.2691
Epoch: 783, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2655
Epoch: 784, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2746
Epoch: 785, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2738
Epoch: 786, train_acc: 0.9714, val_acc: 0.9571, train_loss: 0.2660
Epoch: 787, train_acc: 0.9679, val_acc: 0.9714, train_loss: 0.2647
Epoch: 788, train_acc: 0.9679, val_acc: 0.9571, train_loss: 0.2684
Epoch: 789, train_acc: 0.9714, val_acc: 0.9429, train_loss: 0.2671
Epoch: 790, train_acc: 0.9696, val_acc: 0.9429, train_loss: 0.

Epoch: 906, train_acc: 0.9732, val_acc: 0.9714, train_loss: 0.2343
Epoch: 907, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2335
Epoch: 908, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2333
Epoch: 909, train_acc: 0.9714, val_acc: 0.9429, train_loss: 0.2329
Epoch: 910, train_acc: 0.9714, val_acc: 0.9429, train_loss: 0.2285
Epoch: 911, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2286
Epoch: 912, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2260
Epoch: 913, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2262
Epoch: 914, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2276
Epoch: 915, train_acc: 0.9732, val_acc: 0.9714, train_loss: 0.2231
Epoch: 916, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.2321
Epoch: 917, train_acc: 0.9714, val_acc: 0.9429, train_loss: 0.2309
Epoch: 918, train_acc: 0.9714, val_acc: 0.9429, train_loss: 0.2252
Epoch: 919, train_acc: 0.9679, val_acc: 0.9571, train_loss: 0.2234
Epoch: 920, train_acc: 0.9732, val_acc: 0.9571, train_loss: 0.