In [11]:
import networkx as nx

import numpy as np
import torch
from tqdm import tqdm

from rostok.graph_grammar.node import GraphGrammar
from rostok.graph_grammar.node_vocabulary import NodeVocabulary
from rostok.graph_grammar.rule_vocabulary import RuleVocabulary
import rule_without_chrono as re


In [12]:

def get_input_layer(node, dict_id_label_nodes):
    input = torch.zeros(len(dict_id_label_nodes)).long()
    input[node] = 1
    return input

def vocabulary2batch_graph(rule_vocabulary: RuleVocabulary, max_rules: int):

    batch_graph = GraphGrammar()
    amount_rules = np.random.randint(1, max_rules)
    for _ in range(amount_rules):
        rules = rule_vocabulary.get_list_of_applicable_rules(batch_graph)
        if len(rules) > 0:
            rule = rule_vocabulary.get_rule(rules[np.random.choice(len(rules))])
            batch_graph.apply_rule(rule)
        else:
            break
    return batch_graph

def random_batch(skip_grams):
    random_inputs = []
    random_labels = []
    random_index = np.random.choice(range(len(skip_grams)), 2, replace=False)

    for i in random_index:
        random_inputs.append(skip_grams[i][0])  # target
        random_labels.append(skip_grams[i][1])  # context word

    return random_inputs, random_labels


In [13]:

class skipgramm_model(torch.nn.Module):

    def __init__(self, vocabulary_size: int, embedding_size: int):
        super().__init__()

        self.embedding = torch.nn.Embedding(vocabulary_size, embedding_size)
        self.W = torch.nn.Linear(embedding_size, embedding_size, bias=False)
        self.WT = torch.nn.Linear(embedding_size, vocabulary_size, bias=False)

    def forward(self, x):
        embdedings = self.embedding(x)
        hidden_layer = torch.nn.functional.relu(self.W(embdedings))
        output_layer = self.WT(hidden_layer)

        return output_layer

    def get_node_embedding(self, node, sorted_node_labels, dict_label_id_nodes):
        input = torch.zeros(len(sorted_node_labels)).float()
        input[dict_label_id_nodes[node]] = 1
        return self.embedding(input).view(1, -1)


def skipgram(paths, dict_label_id_nodes, window_size=1):
    idx_pairs = []
    for path in paths:
        indices = [dict_label_id_nodes[node_label] for node_label in path]
        for pos_center_node, node_index in enumerate(indices):
            for i in range(-window_size, window_size + 1):
                pos_context_node = pos_center_node + i

                if pos_context_node < 0 or pos_context_node >= len(
                        indices) or pos_center_node == pos_context_node:
                    continue
                context_id_node = indices[pos_context_node]
                idx_pairs.append((node_index, context_id_node))

    return np.array(idx_pairs)


def create_dict_node_labels(node_vocabulary: NodeVocabulary):

    sorted_node_labels = sorted(node_vocabulary.node_dict.keys())

    dict_id_label_nodes = dict(enumerate(sorted_node_labels))
    dict_label_id_nodes = {w: idx for (idx, w) in enumerate(sorted_node_labels)}

    return dict_id_label_nodes, dict_label_id_nodes


In [14]:
rule_vocab = re.init_extension_rules()
node_vocabulary = rule_vocab.node_vocab

id2label, label2id = create_dict_node_labels(node_vocabulary)

model = skipgramm_model(len(id2label), 2)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in tqdm(range(150000)):
    graph = vocabulary2batch_graph(rule_vocab, 15)
    pairs = skipgram(graph.get_uniq_representation(),label2id)
    input_batch, target_batch = random_batch(pairs)
    input_batch = get_input_layer(input_batch, id2label)
    target_batch = get_input_layer(target_batch, id2label)

    optimizer.zero_grad()
    output = model(input_batch)

    # output : [batch_size, voc_size], target_batch : [batch_size] (LongTensor, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 10000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), ' cost =', '{:.6f}'.format(loss))

    loss.backward(retain_graph=True)
    optimizer.step()

  7%|▋         | 10010/150000 [02:06<25:55, 89.99it/s]

Epoch: 10000  cost = 4.025352


 13%|█▎        | 20014/150000 [04:06<27:03, 80.06it/s] 

Epoch: 20000  cost = 4.025352


 20%|██        | 30010/150000 [06:12<26:46, 74.71it/s] 

Epoch: 30000  cost = 4.025352


 27%|██▋       | 40011/150000 [09:27<19:25, 94.39it/s]  

Epoch: 40000  cost = 4.025352


 33%|███▎      | 50011/150000 [11:52<25:07, 66.32it/s]

Epoch: 50000  cost = 4.025352


 40%|████      | 60012/150000 [14:17<24:20, 61.63it/s]

Epoch: 60000  cost = 4.025352


 47%|████▋     | 70008/150000 [16:41<18:35, 71.70it/s]

Epoch: 70000  cost = 4.025352


 53%|█████▎    | 80017/150000 [19:08<13:59, 83.34it/s]

Epoch: 80000  cost = 4.025352


 60%|██████    | 90014/150000 [21:29<12:33, 79.63it/s] 

Epoch: 90000  cost = 4.025352


 67%|██████▋   | 100013/150000 [23:33<09:38, 86.48it/s]

Epoch: 100000  cost = 4.025352


 73%|███████▎  | 110012/150000 [25:42<09:33, 69.75it/s] 

Epoch: 110000  cost = 4.025352


 80%|████████  | 120019/150000 [28:07<05:02, 99.08it/s] 

Epoch: 120000  cost = 4.025352


 87%|████████▋ | 130010/150000 [30:13<04:36, 72.26it/s] 

Epoch: 130000  cost = 4.025352


 93%|█████████▎| 140008/150000 [32:23<01:55, 86.14it/s] 

Epoch: 140000  cost = 4.025352


100%|██████████| 150000/150000 [34:22<00:00, 72.72it/s] 

Epoch: 150000  cost = 4.025352





In [17]:
def Skipgram_test(test_data, model):
    correct_ct = 0

    for i in range(len(test_data)):
        input_batch, target_batch = random_batch(test_data)
        input_batch = torch.LongTensor(input_batch)
        target_batch = torch.LongTensor(target_batch)

        model.zero_grad()
        _, predicted = torch.max(model(input_batch), 1)




        if predicted[0] == target_batch[0]:
                correct_ct += 1

    print('Accuracy: {:.1f}% ({:d}/{:d})'.format(correct_ct/len(test_data)*100, correct_ct, len(test_data)))

In [19]:
Skipgram_test(skipgram(graph.get_uniq_representation(), label2id), model)

Accuracy: 0.0% (0/16)
