In [58]:
import sys
sys.path.append('../python/')
from json2graph import jsonFile2graph
import networkx as nx
from vocabulary import Vocabulary
from graphUtils import plot_graph, graph2data, data2graph

# generator must belong to {VIATRA,RANDOMEMF,ALLOY,RAND}
generator = 'VIATRA'
# modelType must belong to {Ecore,RDS,Yakindu}
modelType = 'Ecore'

# Generate dataset

Vocabularies:

In [59]:
vocab_edges = Vocabulary()
vocab_nodes = Vocabulary()

In [60]:
def addOpposite(G):
    if (modelType == 'Ecore'):
        to_add = []
        for n1,n2,e in G.edges:
            typee = G[n1][n2][e]['type']
            if (typee == 'eSuperTypes' or
               typee == 'eType' or typee == 'eOpposite'):
                to_add.append((n2,n1,typee+'_inv'))
        for n1,n2,t in to_add:
            G.add_edge(n1,n2,type=t)
        return G
    elif (modelType == 'RDS'):
        to_add = []
        for n1,n2,e in G.edges:
            typee = G[n1][n2][e]['type']
            if (typee == 'elements' or
               typee == 'columns' or typee == 'indexes'
               or typee == 'column' or typee == 'indexColumns'):
                to_add.append((n2,n1,typee+'_inv'))
        for n1,n2,t in to_add:
            G.add_edge(n1,n2,type=t)
        return G
    elif (modelType == 'Yakindu'):
        to_add = []
        for n1,n2,e in G.edges:
            typee = G[n1][n2][e]['type']
            if (typee == 'vertices' or
               typee == 'regions'):
                to_add.append((n2,n1,typee+'_inv'))
        for n1,n2,t in to_add:
            G.add_edge(n1,n2,type=t)
        return G

In [61]:
def passFilter(G):
    return len(G) >= 4

In [62]:
import glob
#import random
#RandomEmfEcore
files = glob.glob("../syntheticGraphs/"+generator+"/"+modelType+"/*.json")
#files = glob.glob('/home/antolin/wakame/randomStuff/RealisticModels/YakinduRandomEmf/All/*.json')
#files = random.sample(files,221)
mine = []
for f in files:
    try:
        G = jsonFile2graph(f)
        G = addOpposite(G)
        if not passFilter(G):
            continue
        data = graph2data(G,0,vocab_nodes,vocab_edges)
        mine.append(data)
    except:
        continue
#random.shuffle(mine)
#mine = mine[0:268]

In [63]:
len(files)

368

In [64]:
files = glob.glob("../realGraphs/"+modelType+"/R1/*.json")
real = []
for f in files:
    G = jsonFile2graph(f)
    G = addOpposite(G)
    if not passFilter(G):
            continue
    data = graph2data(G,1,vocab_nodes,vocab_edges)
    real.append(data)

In [65]:
import random
random.seed(123)

if len(mine) > len(real):
    mine = random.sample(mine,len(real))
elif len(mine) < len(real):
    real = random.sample(real,len(mine))

In [66]:
import random
random.seed(3)
dataset = mine + real
random.shuffle(dataset)

In [67]:
print('Len train:', len(dataset))

Len train: 368


In [68]:
vocab_edges.word2id_names

{'regions': 0,
 'vertices': 1,
 'regions_inv': 2,
 'outgoingTransitions': 3,
 'vertices_inv': 4,
 'incomingTransitions': 5,
 'target': 6,
 'source': 7}

In [69]:
from torch.utils.data import random_split
import torch
train_len = int(0.6*len(dataset))
val_len = int(0.15*len(dataset))
test_len = len(dataset) - int(0.6*len(dataset)) - int(0.15*len(dataset))
train, val, test = random_split(dataset, [train_len, val_len ,test_len], 
                                generator=torch.Generator().manual_seed(42))

# Neural network

In [70]:
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter.composite import scatter_softmax

class GNN_MoRec(nn.Module):
    
    def __init__(self,dim_input,hidden_dim,dropout):
        super(GNN_MoRec, self).__init__()
        
        
        self.emb_nodes = nn.Embedding(len(vocab_nodes.word2id_names), dim_input)
        
        
        self.conv_1 =  pyg_nn.RGCNConv(in_channels = dim_input, out_channels = hidden_dim, 
                                num_relations = len(vocab_edges.word2id_names))
        
        self.conv_2 =  pyg_nn.RGCNConv(in_channels = hidden_dim, out_channels = hidden_dim, 
                                num_relations = len(vocab_edges.word2id_names))
                
        
        
        self.d_1 = nn.Dropout(dropout)
        
        self.lin = nn.Linear(hidden_dim, 1)
        
        self.attention_vector = nn.Linear(hidden_dim,1,bias=False)
    
    def forward(self,nodeTypes,edge_index, edge_attr, bs):
        
        
        nodeTypes = self.emb_nodes(nodeTypes)
        
        
        
        nodes_mess_1 = self.conv_1(nodeTypes, edge_index, edge_attr)
        nodes_mess_1 = self.d_1(F.relu(nodes_mess_1))
        
        nodes_mess_1 = F.relu(self.conv_2(nodes_mess_1, edge_index, edge_attr))
        
        
        attentions = scatter_softmax(torch.squeeze(self.attention_vector(nodes_mess_1)), bs)
        
        nodes_mess_1 = torch.unsqueeze(attentions,dim=1) * nodes_mess_1
        
        graph_emb = pyg_nn.global_add_pool(nodes_mess_1, bs)
        
        rtu = self.lin(graph_emb)
        
        return F.sigmoid(rtu)
    
    def getAttentions(self,nodeTypes,edge_index, edge_attr, bs):
        
        nodeTypes = self.emb_nodes(nodeTypes)
        nodes_mess_1 = self.conv_1(nodeTypes, edge_index, edge_attr)
        nodes_mess_1 = self.d_1(F.relu(nodes_mess_1))
        
        nodes_mess_1 = F.relu(self.conv_2(nodes_mess_1, edge_index, edge_attr))
        
        
        attentions = scatter_softmax(torch.squeeze(self.attention_vector(nodes_mess_1)), bs)
        
        return attentions
        

# Training

In [71]:
from torch_geometric.data import DataLoader
train_loader = DataLoader(train, batch_size=32, num_workers = 5, shuffle=True)
val_loader = DataLoader(val, batch_size=1, num_workers = 5, shuffle=True)

In [15]:
def evaluation(model, loader):
    model.eval()
    count = 0.0
    with torch.no_grad():
        for data in loader:
            pred = model(data.x.cuda(), data.edge_index.cuda(),
          torch.squeeze(data.edge_attr.cuda(),dim=1),data.batch.cuda())
            if pred[0].item() > 0.5:
                pred = 1
            else:
                pred = 0
            if pred == data.y.long().item():
                count = count + 1
    return count/len(loader)

In [39]:
from nnUtils import EarlyStopping
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

model = GNN_MoRec(64,64,0.0).cuda()

epochs = 200
criterion = nn.BCELoss()

opt = torch.optim.Adam(model.parameters(), lr=0.001)
es = EarlyStopping(opt, model, modelType+'-'+generator+'-GNN',mode='max',patience=50)


for e in range(epochs):
    total_loss = 0.0
    b = 1
    model.train()
    for data in train_loader:
        
        opt.zero_grad()
        
        pred = model(data.x.cuda(), data.edge_index.cuda(),
          torch.squeeze(data.edge_attr.cuda(),dim=1),data.batch.cuda())
        loss = criterion(torch.squeeze(pred), data.y.float().cuda())
        total_loss += loss.item()
        
        loss.backward()
        opt.step()
        b = b + 1
        
    val_acc = evaluation(model, val_loader)
    print('Epoch',e,'Loss',total_loss/b)
    print('Eval',val_acc)
    
    if es.step(val_acc,e):
        break



Epoch 0 Loss 0.4895031899213791
Eval 0.9454545454545454
Epoch 1 Loss 0.3030615411698818
Eval 0.9818181818181818
Epoch 2 Loss 0.1776052014902234
Eval 0.9818181818181818
Epoch 3 Loss 0.14664837252348661
Eval 0.9818181818181818
Epoch 4 Loss 0.12768686283379793
Eval 1.0
Epoch 5 Loss 0.12687270483002067
Eval 1.0
Epoch 6 Loss 0.1126244105398655
Eval 1.0
Epoch 7 Loss 0.1062348005361855
Eval 0.9818181818181818
Epoch 8 Loss 0.10387372504919767
Eval 0.9818181818181818
Epoch 9 Loss 0.10398070886731148
Eval 1.0
Epoch 10 Loss 0.09625882259570062
Eval 1.0
Epoch 11 Loss 0.09696727618575096
Eval 1.0
Epoch 12 Loss 0.09078394481912255
Eval 1.0
Epoch 13 Loss 0.0904529350809753
Eval 1.0
Epoch 14 Loss 0.0807613474316895
Eval 0.9818181818181818
Epoch 15 Loss 0.08490354102104902
Eval 1.0
Epoch 16 Loss 0.08626751345582306
Eval 1.0
Epoch 17 Loss 0.08578318450599909
Eval 1.0
Epoch 18 Loss 0.07748268451541662
Eval 1.0
Epoch 19 Loss 0.07563338987529278
Eval 1.0
Epoch 20 Loss 0.06929029361344874
Eval 1.0
Epoch 21 

In [40]:
if e != 50:
    model2 = GNN_MoRec(64,64,0.0).cuda()



    checkpoint = torch.load(modelType+'-'+generator+'-GNN')
    model2.load_state_dict(checkpoint['model_state_dict'])

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    model2.eval()
else:
    torch.save({
            'epoch': 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': 0,
            'loss': 0,
            }, modelType+'-'+generator+'-GNN')
    model2=model


Load model directly

In [72]:
model2 = GNN_MoRec(64,64,0.0).cuda()
checkpoint = torch.load(modelType+'-'+generator+'-GNN')
model2.load_state_dict(checkpoint['model_state_dict'])

epoch = checkpoint['epoch']
loss = checkpoint['loss']

model2.eval()

GNN_MoRec(
  (emb_nodes): Embedding(9, 64)
  (conv_1): RGCNConv(64, 64, num_relations=8)
  (conv_2): RGCNConv(64, 64, num_relations=8)
  (d_1): Dropout(p=0.0, inplace=False)
  (lin): Linear(in_features=64, out_features=1, bias=True)
  (attention_vector): Linear(in_features=64, out_features=1, bias=False)
)

# Testing

In [73]:
test_loader = DataLoader(test, batch_size=1, num_workers = 5, shuffle=True)

In [74]:
model2.eval()
count = 0
i0 = 0
i1 = 0
for data in test_loader:
    
    pred = model2(data.x.cuda(), data.edge_index.cuda(),
          torch.squeeze(data.edge_attr,dim=1).cuda(),data.batch.cuda())
    if pred[0].item() > 0.5:
        pred = 1
    else:
        pred = 0
    if pred == data.y.long().item():
        count = count + 1
    
print('Acc', count/len(test_loader))

Acc 1.0


Acc 0.956989247311828 <- VIATRA

Acc 0.956989247311828 <- Alloy

Acc 0.7634408602150538 <- RANDOMEMF

Acc 1.0 <- Random

In [75]:
from C2ST import C2ST_pvalue

acc =  count/len(test_loader)
n_test = len(test_loader)
print('p-value', C2ST_pvalue(acc,n_test))

p-value 2.6147170013952424e-22


In [76]:
n_test

93

# Interpreting

In [23]:
from interpretation import heatMap, plot_graph_attention, importantSubgraph, getMapAttention
i = 0
for data in test:
    G = data2graph(data,vocab_nodes,vocab_edges)
    batch = torch.zeros(len(G)).long()
    atts = model2.getAttentions(data.x.cuda(), data.edge_index.cuda(),
          torch.squeeze(data.edge_attr.cuda(),dim=1),batch.cuda())
    map_colors = getMapAttention(G,atts)
    
    pred = model2(data.x.cuda(), data.edge_index.cuda(),
          torch.squeeze(data.edge_attr.cuda(),dim=1),batch.cuda())
    if pred[0].item() < 0.1 and data.y.item() == 0:
        #plot_graph_attention(G,map_colors)
        #plot_graph_attention(importantSubgraph(G, atts.detach().cpu().numpy(), 0.2, 2),map_colors)
        heatMap(G,atts,str(i),'./interpretation/'+modelType+'/'+generator+'/')
        heatMap(importantSubgraph(G, atts.detach().cpu().numpy(), 0.2, 2),atts,str(i),'./interpretation/'+modelType+'/'+generator+'/subgraph/')
        i = i + 1
        print('--'*80)

----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------




----------------------------------------------------------------------------------------------------------------------------------------------------------------
