### Imports

In [67]:
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
import networks
import datasetLoader
import evaluation
import importlib
import matplotlib.pyplot as plt

importlib.reload(datasetLoader)
importlib.reload(networks)
importlib.reload(evaluation)


<module 'evaluation' from 'c:\\Users\\trist\\Git_repos\\BT-ML-PGESAT\\code\\PGExplainer\\evaluation.py'>

### Testing PyG

In [2]:
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
print(data)

Data(x=[3, 1], edge_index=[2, 4])


### Parameters

In [65]:
batch_size = 64                 # 64 used for Graphs in PGE (PGExplainer/codes/forgraph/config.py)      1 takes forever in current model



learning_rate_gnn = 0.001        #0.001 on ADAM
epochs_gnn = 1000
early_stopping = 500


learning_rate_mlp = 0.003        #0.003 on ADAM
coefficientSizeReg = 0.05
entropyReg = 1
epochs_mlp = 30

temperature =  0        #???


# TODO: Xavier initialization (torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None) or torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None))
# lr scheduler? => not used in orig
# softmax after linear layer?       -> check code

# DONE: dropout     NOT USED IN OG
# DONE: early stopping with validation set

# cross validation?!


In [15]:
def weights_init(module):
    # TODO: GraphConv has no attribute weight!!
    #if isinstance(module, gnn.GraphConv):
    #    nn.init.xavier_normal_(module.weight.data)
    #    if module.bias is not None:
     #       nn.init.xavier_normal_(module.bias.data)
    if isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight.data)
        #if module.bias is not None:            TODO: This does not work either
        #    nn.init.xavier_normal_(module.bias.data)


## Training Loop GraphGNN

In [None]:
train_loader, val_loader, test_loader = datasetLoader.loadDataset('MUTAG', batch_size)               # TODO: adjs matrix is not compatible with GraphConv. Needs to be converted to edge_index (see RE_PGE datasets/utils)

temp = next(iter(train_loader))
gnn = networks.GraphGNN(features = temp.x.shape[1], labels=2)       # temp.y.shape[0] is wrong!!! TODO: how do I get #labels from loader?? take from loader.dataset?

gnn.apply(weights_init)

gnn_optimizer = torch.optim.Adam(params = gnn.parameters(), lr = learning_rate_gnn)         # TODO: understand params

print(f"Training on {len(train_loader.dataset)} graphs with batch size {batch_size}")

loss = nn.CrossEntropyLoss()           # cross entropy loss?!

early_stop_counter = 0
min_val_loss = 1000.0
best_val_acc = 0
best_epoch = 0

for epoch in range(0, epochs_gnn) :
    print(f'\n------------------ EPOCH {epoch + 1} -------------------')

    gnn.train()

    train_acc_sum = 0
    num_batches = 0.0
    train_loss = 0.0
    
    for batch_index, data in enumerate(train_loader):
        batch_size_ratio = len(data)/batch_size
        num_batches += batch_size_ratio
        
        gnn_optimizer.zero_grad()       # Reset parameters

        # real label
        label = data.y

        # get model embeddings (node representations)?
        # predicted label
        out = gnn.forward(data.x, data.edge_index, data.batch)

        # calc cross entropy(???)loss between real label and predicted label
        # needs to be calculated across batch
        currLoss = loss(out, label)

        #print(currLoss)

        # loss backward
        currLoss.backward()
        
        torch.nn.utils.clip_grad_norm_(gnn.parameters(), max_norm=2)    # clip gradient above 2(for ba2motfis according to reimplementation) to stop "overlearning"?

        # optimizer step
        gnn_optimizer.step()

        preds = out.argmax(dim=1)
        train_acc_sum += torch.sum(preds == data.y)                     # TODO: works with batches?
        
        train_loss += batch_size_ratio * currLoss.item()                       # use currLoss instead of currLoss.item() for batches

    final_train_acc = train_acc_sum/(num_batches*batch_size)                # num_batches*batch_size = len(train_loader.dataset) = #graphs

    gnn.eval()

    # avg loss
    print(f"average training loss: {train_loss/num_batches}, training acc: {final_train_acc}")

    val_acc, valLoss, test_acc, testLoss = evaluation.evaluateGNN(gnn, val_loader, test_loader)
    print(f"validation loss: {valLoss}, validation acc: {val_acc}, test loss: {testLoss}, test acc: {test_acc}")

    if(val_acc > best_val_acc):
        best_val_acc = val_acc
        best_epoch = epoch
    if(valLoss < min_val_loss):
        min_val_loss = valLoss
        early_stop_counter = 0
    elif(valLoss > min_val_loss):
        early_stop_counter += 1
        if early_stop_counter >= early_stopping:
            print("Stopping training due to early stopping threshold")
            print(f"highest validation accuracy: {best_val_acc} in epoch {best_epoch}")
            break

print(f"highest validation accuracy: {best_val_acc} in epoch {best_epoch}")


# move training loop Explainer
#mlp_optimizer = torch.optim.Adam(lr = learning_rate_mlp)


"""for i in enumerate(adjs):
    #out = gnn.forward(feas[i], adjs[i].nonzero().t().contiguous())

for epoch in epochs_graphgnn:
    for graph in adjs:
        # calculate latent variables? MLP?
        for k in # k in monte carlo sampling?!
            # sammple graph
            # pred label on sampled graph

    # compute loss
    # update params with backprop"""

    

Downloading https://www.chrsmrrs.com/graphkerneldatasets/Mutagenicity.zip


Training on 3470 graphs with batch size 64

------------------ EPOCH 1 -------------------
average training loss: 0.6657513978501906, training acc: 0.5812680125236511
validation loss: 0.6298005177128699, validation acc: 0.6096997857093811, test loss: 0.6109024220347954, test acc: 0.6221198439598083

------------------ EPOCH 2 -------------------
average training loss: 0.6205855575693444, training acc: 0.6524495482444763
validation loss: 0.5869775555924885, validation acc: 0.6859122514724731, test loss: 0.562876262697756, test acc: 0.7096773982048035

------------------ EPOCH 3 -------------------
average training loss: 0.588515118357084, training acc: 0.6878962516784668
validation loss: 0.5572813193094895, validation acc: 0.6882216930389404, test loss: 0.5370394480393229, test acc: 0.7027649879455566

------------------ EPOCH 4 -------------------
average training loss: 0.5658357291812512, training acc: 0.7020173072814941
validation loss: 0.5395018528408718, validation acc: 0.734411060

KeyboardInterrupt: 

### eval values with batch size 64 and learning rate 0,001 on 1000 Epochs MUTAG
average training loss: 0.12170956794397975, training acc: 0.9521613717079163
validation loss: 0.7258547251949662, validation acc: 0.8198614120483398, test loss: 0.8935344749332024, test acc: 0.7995391488075256

### GCNConv instead of GraphConv
average training loss: 0.3523677970215635, training acc: 0.8458213210105896
validation loss: 0.47687009828431265, validation acc: 0.7921478152275085, test loss: 0.5093637616952993, test acc: 0.7695852518081665
highest validation accuracy: 0.8152424693107605 in epoch 452

##### Saving Models

In [None]:
torch.save(gnn.state_dict(), f"models/BA2Motif")

##### Loading Models

In [None]:
#model = TheModelClass(*args, **kwargs)
#model.load_state_dict(torch.load(PATH, weights_only=True))
#model.eval()

In [None]:
print(len(test_loader))

print(next(iter(train_loader)))         # edge_index = "map" for edges, x = features, y = labels
#print(next(iter(train_loader)).x)

print("test_loader contains 100 graphs with following labels:")
for i, curr in enumerate(test_loader):
    print(curr.y)
    print(curr.x.shape[1])
    print(curr.batch)

    out = gnn.forward(curr.x, curr.edge_index, curr.batch)
    print(out.)

100
DataBatch(x=[25, 10], edge_index=[2, 52], y=[1], batch=[25], ptr=[2])
test_loader contains 100 graphs with following labels:
tensor([0])
10
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
tensor([[0.2440]], grad_fn=<AddmmBackward0>)
tensor([0])
10
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
tensor([[0.2392]], grad_fn=<AddmmBackward0>)
tensor([1])
10
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
tensor([[0.2350]], grad_fn=<AddmmBackward0>)
tensor([0])
10
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
tensor([[0.2379]], grad_fn=<AddmmBackward0>)
tensor([1])
10
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
tensor([[0.2447]], grad_fn=<AddmmBackward0>)
tensor([0])
10
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
tens

## Testing

In [34]:
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph
from torch_geometric.datasets.motif_generator import HouseMotif
from torch_geometric.datasets.motif_generator import CycleMotif
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

dataset1 = ExplainerDataset(
            graph_generator=BAGraph(20, 1),
            motif_generator=HouseMotif(),
            num_motifs=1,
            num_graphs=400,
            transform=T.Constant()      # appends value 1 node feature for every node
        )

dataset2 = ExplainerDataset(
            graph_generator=BAGraph(20, 1),
            motif_generator=CycleMotif(5),
            num_motifs=1,
            num_graphs=400,
            transform=T.Constant()
        )

dataset = torch.utils.data.ConcatDataset([dataset1, dataset2])

print(dataset)
dataset[0].y = torch.tensor([0])

print(dataset[0].y)

train_loader = DataLoader(dataset1, batch_size = 1, shuffle = True)

print(next(iter(train_loader)).y)

<torch.utils.data.dataset.ConcatDataset object at 0x000001ED56EB32C0>
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2,
        3])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2,
        3])


In [40]:
generator1 = torch.Generator().manual_seed(42)
generator2 = torch.Generator().manual_seed(42)
set1, set2 = torch.utils.data.random_split(range(10), [3, 7])
set3, set4, set5 = torch.utils.data.random_split(range(30), [0.3, 0.3, 0.4])

print("Set 1:")
print(set1[0])
print(set1[1])
print(set1[2])
print("Set 2:")
print(set2[0])
print(set2[1])
print(set2[2])
print(set2[3])
print(set2[4])
print(set2[5])
print(set2[6])

Set 1:
9
3
5
Set 2:
2
6
4
8
7
0
1


In [91]:
train_loader, val_loader, test_loader = datasetLoader.loadDataset('MUTAG', batch_size) 
temp = next(iter(train_loader))
print(temp.y.argmax(dim=-1))

Downloading https://www.chrsmrrs.com/graphkerneldatasets/Mutagenicity.zip


tensor(0)
