In [None]:
import glob
import random
import sys

import torch
import torch.nn as nn
from networkx.algorithms.isomorphism import is_isomorphic
from torch_geometric.data import DataLoader

# parent directory on sys.path so the `dmg` imports below resolve
sys.path.append('../')
import dmg.model2graph.model2graph as m2g
import dmg.model2graph.metafilter as mf
import dmg.graphUtils as gu
from dmg.yakindu.yakinduPreprocess import removeLayout
import dmg.yakindu.yakinduPallete as yp
from dmg.deeplearning.dataGeneration import sequence2data, data2graph
from dmg.deeplearning.generativeModel import GenerativeModel

random.seed(123)

# Load dataset

In [None]:
# Meta-filter used when converting Yakindu models to graphs:
#  - references: the structural links kept as graph edges
#  - classes: every key of yp.dic_nodes_yak
#  - attributes: None (presumably "no attribute filter" — confirm in MetaFilter)
metafilter_refs = [
    'Region.vertices',
    'CompositeElement.regions',
    'Vertex.outgoingTransitions',
    'Vertex.incomingTransitions',
    'Transition.target',
    'Transition.source',
]
metafilter_cla = [cls_name for cls_name in yp.dic_nodes_yak]
metafilter_atts = None
metafilterobj = mf.MetaFilter(classes=metafilter_cla,
                              references=metafilter_refs,
                              attributes=metafilter_atts)

# Metamodel files passed to every getGraphFromModel call below.
meta_models = glob.glob("../data/metamodels/yakinduComplete/*")

In [None]:
# Training split: convert every model file into a graph (consider_atts=False).
# glob.glob returns paths in arbitrary, filesystem-dependent order; sorting keeps
# the graph list — and anything later drawn from the seeded RNG per graph — reproducible.
files = sorted(glob.glob("../data/yakinduDataset/train/*"))
assert files, "no model files found in ../data/yakinduDataset/train/"
graphs = [m2g.getGraphFromModel(f, meta_models, metafilterobj, consider_atts=False)
          for f in files]

In [None]:
# Split-specific label so this output is distinguishable from the validation count.
print('Number of training graphs:', len(graphs))

In [None]:
# Validation split: same conversion as the training split.
# Sorted for a reproducible, filesystem-independent order.
files = sorted(glob.glob("../data/yakinduDataset/val/*"))
assert files, "no model files found in ../data/yakinduDataset/val/"
graphs_val = [m2g.getGraphFromModel(f, meta_models, metafilterobj, consider_atts=False)
              for f in files]

In [None]:
# Split-specific label so this output is distinguishable from the training count.
print('Number of validation graphs:', len(graphs_val))

In [None]:
# Build the validation DataLoader once, before training.
# Each graph goes through graphToSequence and then sequence2data; the resulting
# samples are batched without shuffling.
# max_len was previously only defined in the training cell below, so this cell
# raised NameError on a fresh kernel; define it here (keep in sync with training).
max_len = 2
batch_size = 64
listDatas_val = []
print('Preparing seqs')
for g in graphs_val:
    sequence = yp.yakindu_pallete.graphToSequence(g)
    # extend in place: `lst = lst + other` copies the whole list each iteration
    listDatas_val.extend(sequence2data(sequence, yp.yakindu_pallete, max_len))
loader_val = DataLoader(listDatas_val, batch_size=batch_size,
                        num_workers=0,
                        shuffle=False)
print('Seqs finished')

# Training

In [None]:
# Hyper-parameters.
epochs = 100
max_len = 2       # must match the value used to build loader_val
hidden_dim = 128

# random.seed only seeds Python's `random`; seed torch too so weight init is reproducible.
torch.manual_seed(123)

criterion_node = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1)
criterion_action = nn.CrossEntropyLoss(reduction='mean')
model = GenerativeModel(hidden_dim, yp.dic_nodes_yak, yp.dic_edges_yak, yp.dic_operations_yak)
opt = torch.optim.Adam(model.parameters(), lr=0.001)


def _batch_loss(model, data):
    """Return the mean of the node loss and the action loss for one batch.

    Shared by the training and validation passes so both compute the loss identically.
    """
    action, nodes = model(data.x, data.edge_index,
                          torch.squeeze(data.edge_attr, dim=1),
                          data.batch, data.sequence, data.nodes, data.len_seq, data.action)
    # Turn each per-node score p into two columns [1 - p, p] for a 2-class loss
    # (same values as unsqueeze/repeat followed by overwriting column 0).
    # NOTE(review): nn.CrossEntropyLoss expects unnormalised logits; if `nodes`
    # already holds probabilities (e.g. a sigmoid output), log + NLLLoss or BCELoss
    # would be the matching criterion — confirm against GenerativeModel.
    node_scores = torch.stack([1 - nodes, nodes], dim=2)
    L = torch.max(data.len_seq).item()
    g_truth = data.sequence_masked[:, 0:L]
    return (criterion_node(node_scores.reshape(-1, 2), g_truth.flatten())
            + criterion_action(action, data.action)) / 2


for epoch in range(epochs):
    # Rebuild the training samples every epoch from fresh graphToSequence calls
    # (presumably these vary between calls — confirm in yakindu_pallete).
    listDatas = []
    print('Preparing seqs')
    for g in graphs:
        sequence = yp.yakindu_pallete.graphToSequence(g)
        listDatas.extend(sequence2data(sequence, yp.yakindu_pallete, max_len))
    # NOTE(review): training batches are not shuffled, so each batch holds
    # consecutive samples built from neighbouring graphs; consider shuffle=True.
    loader = DataLoader(listDatas, batch_size=batch_size,
                        num_workers=0,
                        shuffle=False)
    print('Seqs finished')

    # Training pass.
    model.train()
    total_loss = 0
    for data in loader:
        opt.zero_grad()
        loss = _batch_loss(model, data)
        total_loss += loss.item()
        loss.backward()
        opt.step()

    # Validation pass (no gradients).
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in loader_val:
            val_loss += _batch_loss(model, data).item()

    print('Epoch', epoch, 'Loss Training', total_loss / len(loader))
    print('Epoch', epoch, 'Loss Val', val_loss / len(loader_val))
        