In [None]:
import networkx as nx
from networkx.readwrite import json_graph
from collections import Counter, defaultdict
from sklearn.model_selection import train_test_split
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
import os
import json
import itertools
import torch
from gensim.models import FastText
import numpy as np
from copy import copy

In [None]:
#!wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.bin.gz

--2021-05-03 08:27:16--  https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.bin.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 172.67.9.4, 104.22.74.142, 104.22.75.142, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|172.67.9.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4496459151 (4.2G) [application/octet-stream]
Saving to: ‘cc.ru.300.bin.gz.1’

cc.ru.300.bin.gz.1    0%[                    ]   3.63M  2.77MB/s               ^C


In [None]:
# Colab-only: mount Google Drive under /content/gdrive so data and checkpoints stored there are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Relative path: resolves only while the working directory is the parent of the mount point
# ('/content'), so re-running this cell after it has succeeded once will fail.
os.chdir('gdrive/MyDrive')

In [None]:
# Load the pre-trained fastText binary 'cc.ru.300.bin' from the current working directory.
# NOTE(review): `FastText.load_fasttext_format` is deprecated in newer gensim releases in favour of
# `gensim.models.fasttext.load_facebook_model` — confirm against the installed gensim version.
fasttext_model = FastText.load_fasttext_format('cc.ru.300.bin')

In [None]:
def choose_tag(fragment_noun_tags):
    """Collapse all syntactic roles a noun has inside one fragment into a single edge label.

    Priority is subject > object > anything else.

    Args:
        fragment_noun_tags: collection of role tags observed for the noun. The
            graph builder stores 'subject' / 'object' / 'other_dep'; the short
            forms 'subj' / 'obj' are accepted as aliases.

    Returns:
        One of 'subject', 'object' or 'other_dep' — the same strings that are
        listed in ``edge_labels``, so the result can be used directly as an
        edge label.
    """
    if 'subject' in fragment_noun_tags or 'subj' in fragment_noun_tags:
        return 'subject'
    if 'object' in fragment_noun_tags or 'obj' in fragment_noun_tags:
        return 'object'
    return 'other_dep'

def quest_path2entity_graph(path, morphograph, morphology_key='node_morphodata', text_key='fragment_text'):
    """Build the entity graph for one quest path.

    The resulting undirected graph contains:
      * a 'global' node connected to every fragment (edge label 'global');
      * one node per fragment of ``path`` (attribute ``fragment_text``);
      * one 'noun_<lemma>' node for every noun lemma that occurs more than once
        along the path, connected to each fragment mentioning it with the
        label returned by ``choose_tag``;
      * 'fragments_pair' edges between every two fragments sharing a noun lemma.

    Args:
        path: sequence of node ids of ``morphograph`` that form the quest path.
        morphograph: graph whose nodes carry ``morphology_key`` (a dict with a
            'nouns' list of ``{'lemma': ..., 'deprel': ...}`` items) and ``text_key``.
        morphology_key: node attribute holding the morphological analysis.
        text_key: node attribute holding the fragment text.

    Returns:
        ``(G, [noun_node_ids])`` — the graph and a one-element list wrapping the
        list of noun node ids (the nesting is kept because consumers of the
        stored samples index it with ``[0]``).
    """
    G = nx.Graph()
    nodes = morphograph.nodes()

    noun_lemmas_counter = Counter()
    node2chosen_syntactic_tags = {}
    nouns2nodes = defaultdict(set)

    for node in path:
        current_node2all_syntactic_tags = defaultdict(set)
        for noun_data in nodes[node][morphology_key]['nouns']:
            noun_lemma = noun_data['lemma']
            noun_lemmas_counter.update([noun_lemma])
            nouns2nodes['noun_' + noun_lemma].add(node)
            dependency = noun_data['deprel']
            if 'subj' in dependency:
                dep_type = 'subject'
            elif 'obj' in dependency:
                dep_type = 'object'
            else:
                dep_type = 'other_dep'
            current_node2all_syntactic_tags[noun_lemma].add(dep_type)

        node2chosen_syntactic_tags[node] = {
            noun: choose_tag(tags) for noun, tags in current_node2all_syntactic_tags.items()
        }

    # Only nouns mentioned more than once can link fragments together.
    relevant_nouns = [noun for noun in noun_lemmas_counter if noun_lemmas_counter[noun] > 1]
    relevant_nouns_set = set(relevant_nouns)  # O(1) membership checks in the loop below

    G.add_node('global', fragment_text='UNK')
    for relevant_noun in relevant_nouns:
        G.add_node('noun_' + relevant_noun, fragment_text=relevant_noun)
    for node in path:
        G.add_node(node, fragment_text=nodes[node][text_key])
        G.add_edge(node, 'global', label='global')
        for noun, syntactic_role in node2chosen_syntactic_tags[node].items():
            if noun in relevant_nouns_set:
                G.add_edge(node, 'noun_' + noun, label=syntactic_role)
    for fragments_with_noun in nouns2nodes.values():
        for pair in itertools.combinations(fragments_with_noun, 2):
            G.add_edge(*pair, label='fragments_pair')

    # Fix: the comprehension previously used the leaked loop variable
    # `relevant_noun`, so every returned id was the *last* relevant noun.
    return G, [['noun_' + noun for noun in relevant_nouns]]

In [None]:
# Raw strings: the original literals relied on invalid escape sequences (`\Д`, `\D`, `\q`, `\Q`)
# being kept verbatim, which newer Python versions warn about. The resulting values are unchanged.
# NOTE(review): these are absolute Windows paths, while the surrounding cells mount Google Drive in
# Colab — this cell can only run on the original local machine.
BOOK_GRAPHS_DIRECTORY = r'D:\Диплом_текстовые_квесты\Data\quest_books_graphs_morphology'
ONLINE_GRAPHS_DIRECTORY = r'D:\Диплом_текстовые_квесты\Data\Questbook_online_grammar'

# 60/40 dev/test split of the online quests.
# NOTE(review): `os.listdir` order is not guaranteed, so `random_state` alone does not make
# the split reproducible across machines.
dev_online_graphs, test_online_graphs = train_test_split(
    os.listdir(ONLINE_GRAPHS_DIRECTORY), random_state=42, test_size=0.4
)

# Development data = all book quests + the dev part of the online quests.
book_graph_paths = [
    os.path.join(BOOK_GRAPHS_DIRECTORY, book_graph)
    for book_graph in os.listdir(BOOK_GRAPHS_DIRECTORY)
]
dev_online_graph_paths = [
    os.path.join(ONLINE_GRAPHS_DIRECTORY, online_graph)
    for online_graph in dev_online_graphs
]
all_dev_graphs = book_graph_paths + dev_online_graph_paths
test_graphs_paths = [os.path.join(ONLINE_GRAPHS_DIRECTORY, test_graph) for test_graph in test_online_graphs]

# 90/10 train/validation split of the development data.
train_paths, valid_paths = train_test_split(all_dev_graphs, random_state=42, test_size=0.1)

In [None]:
def correct_path(path, graph, morphodata_field, text_field='fragment_text'):
    """Drop fragments of ``path`` that cannot be used downstream.

    A fragment is kept only if its node in ``graph`` has the ``morphodata_field``
    attribute and a ``text_field`` attribute whose value is a ``str``.
    The relative order of the kept fragments is preserved.
    """
    node_view = graph.nodes()
    usable_fragments = []
    for fragment_id in path:
        attributes = node_view[fragment_id]
        if morphodata_field not in attributes:
            continue
        if text_field not in attributes:
            continue
        if not isinstance(attributes[text_field], str):
            continue
        usable_fragments.append(fragment_id)
    return usable_fragments

In [None]:
# Directory (relative to the current working directory) with the prepared training samples, one JSON file each.
COLAB_TRAIN_DATA_DIRECTORY = 'tokenized_ordering_train'

In [None]:
# Directory (relative to the current working directory) with the prepared test samples, one JSON file each.
COLAB_TEST_DATA_DIRECTORY = 'tokenized_ordering_test_joined'

In [None]:
# Edge-label vocabulary: a label's position in this list is its integer id in the transition table.
# `None` sits at index 0, which is also the value of "no edge" cells and the default padding id of the GRN.
edge_labels = [None, 'global', 'fragments_pair','subject', 'object', 'other_dep']

In [None]:
class GraphFastTextPreprocessor(torch.nn.Module):
    """Turn an entity graph into node vectors and an edge-label adjacency table.

    Every node text is embedded token by token with fastText, encoded with a
    bidirectional LSTM and squashed with a sigmoid. The node vector is the
    concatenation of the final hidden and cell states of both directions,
    hence ``output_size == 4 * lstm_hidden_size``.

    Args:
        fasttext_model: object supporting ``token in model`` and ``model[token]``
            (returning a vector of length ``fasttext_embedding_dim``).
        edge_labels: list of edge labels; a label's index is its integer id.
        fasttext_embedding_dim: dimensionality of the fastText vectors.
        lstm_hidden_size: hidden size of each LSTM direction.
    """

    def __init__(self, fasttext_model, edge_labels, fasttext_embedding_dim=300, lstm_hidden_size=50):
        super().__init__()
        self.edges_vocab = {label: idx for idx, label in enumerate(edge_labels)}
        self.fasttext_model = fasttext_model
        self.fasttext_embedding_dim = fasttext_embedding_dim
        self.lstm_hidden_size = lstm_hidden_size
        self.output_size = lstm_hidden_size * 4
        self.encoding_lstm = torch.nn.LSTM(
            input_size=fasttext_embedding_dim,
            hidden_size=lstm_hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.non_linear = torch.sigmoid

    def forward(self, graph, quest_correct_path, quest_nouns):
        """Encode the graph.

        Args:
            graph: networkx-like graph with a 'global' node, the fragment nodes of
                ``quest_correct_path`` and the noun nodes of ``quest_nouns``;
                every edge carries a 'label' attribute from ``edge_labels``.
            quest_correct_path: fragment node ids.
            quest_nouns: noun node ids.

        Returns:
            ``(node_vectors, transition_table)`` where ``node_vectors`` has shape
            ``(num_nodes, output_size)`` (row 0, the global node, is all zeros) and
            ``transition_table`` is a symmetric int32 ``(num_nodes, num_nodes)``
            table of edge-label ids (0 = no edge). Rows follow the order
            ``['global'] + quest_correct_path + quest_nouns``.

        Raises:
            KeyError: if an edge touches a node outside that order or carries a
                label missing from ``edge_labels``.
        """
        # Use the module's own device instead of a notebook-level global.
        target_device = next(self.parameters()).device

        nodes_order = ['global'] + list(quest_correct_path) + list(quest_nouns)
        node2id = {node: idx for idx, node in enumerate(nodes_order)}
        nodes = graph.nodes()
        edges = graph.edges()

        graph_transition_table = torch.zeros(
            len(nodes_order), len(nodes_order), dtype=torch.int32, device=target_device
        )
        for node1, node2 in edges:
            node1_idx = node2id[node1]
            node2_idx = node2id[node2]
            label_idx = self.edges_vocab[edges[(node1, node2)]['label']]
            # The graph is undirected, so the table is kept symmetric.
            graph_transition_table[node1_idx][node2_idx] = label_idx
            graph_transition_table[node2_idx][node1_idx] = label_idx

        texts = []
        for node in nodes_order:
            attributes = nodes[node]
            if 'tokenized_text' in attributes:
                # Fix: the key was misspelled ('tokenizen_text'), so pre-tokenized
                # text was never used; the token list was also wrapped in an
                # extra list.
                texts.append(attributes['tokenized_text'])
            elif 'fragment_text' in attributes:
                texts.append(attributes['fragment_text'])
            else:
                # Keep one row per node so rows stay aligned with `nodes_order`.
                texts.append([])

        text_embedded_nodes = self.embed_text_batch(texts).to(target_device)
        text_encoded_nodes = self.lstm_encode(text_embedded_nodes)
        # The global node is initialised with zeros.
        global_row = torch.zeros(
            1, text_encoded_nodes.shape[1],
            dtype=text_encoded_nodes.dtype, device=text_encoded_nodes.device,
        )
        text_encoded_nodes = torch.cat((global_row, text_encoded_nodes[1:]), 0)
        return text_encoded_nodes, graph_transition_table

    def embed_text_batch(self, tokenized_texts):
        """Embed a batch of texts and left-pad them with zero vectors.

        Args:
            tokenized_texts: iterable of token lists. A plain string is split on
                whitespace (iterating a string would yield single characters).

        Returns:
            float32 tensor of shape ``(batch, max_len, fasttext_embedding_dim)``;
            ``max_len`` is at least 1 so the LSTM never receives an empty sequence.
            Tokens unknown to the fastText model are skipped.
        """
        embedded_batch = []
        for tokenized_text in tokenized_texts:
            tokens = tokenized_text.split() if isinstance(tokenized_text, str) else tokenized_text
            embeddings = [
                np.asarray(self.fasttext_model[token], dtype=np.float32)
                for token in tokens
                if token in self.fasttext_model
            ]
            embedded_batch.append(embeddings)

        max_len = max((len(embeddings) for embeddings in embedded_batch), default=0)
        max_len = max(max_len, 1)

        # Fill one pre-allocated array instead of building a tensor from nested Python lists.
        padded_batch = np.zeros(
            (len(embedded_batch), max_len, self.fasttext_embedding_dim), dtype=np.float32
        )
        for row, embeddings in enumerate(embedded_batch):
            if embeddings:
                # Left padding: real tokens occupy the end of the sequence.
                padded_batch[row, max_len - len(embeddings):] = np.stack(embeddings)
        return torch.from_numpy(padded_batch)

    def lstm_encode(self, embeddings):
        """Encode ``(batch, seq, dim)`` embeddings into ``(batch, 4 * lstm_hidden_size)`` vectors."""
        _, (hidden_state, cell_state) = self.encoding_lstm(embeddings)
        # Index 0 / 1 of the state tensors are the forward / backward directions.
        representation_vector = torch.cat(
            (hidden_state[0], hidden_state[1], cell_state[0], cell_state[1]), 1
        )
        return self.non_linear(representation_vector)

In [None]:
class GRN(torch.nn.Module):
    """Graph recurrent network that iteratively mixes node vectors along labelled edges.

    Each edge label is embedded into a single scalar weight. On every iteration a
    node receives the weighted sum of its neighbours' vectors, "forgets" a learned
    part of its own state and is squashed with tanh. The same weights are reused
    on every iteration.

    Args:
        num_edge_embeddings: size of the edge-label vocabulary.
        edge_embedding_dim: dimensionality of the *node* vectors (the width of
            the internal linear layers); the edge embedding itself is always
            one-dimensional.
        edge_pad_id: label id meaning "no edge"; its embedding is fixed to zero.
        num_iterations: number of message-passing steps.
    """

    def __init__(self, num_edge_embeddings, edge_embedding_dim, edge_pad_id=0, num_iterations=2):
        super().__init__()
        self.num_iterations = num_iterations
        self.edge_embedding_layer = torch.nn.Embedding(
            num_embeddings=num_edge_embeddings, embedding_dim=1, padding_idx=edge_pad_id
        )

        self.forgetting_linear = torch.nn.Linear(edge_embedding_dim, edge_embedding_dim)
        self.forgetting_nonlinear = torch.sigmoid

        self.incoming_linear = torch.nn.Linear(edge_embedding_dim, edge_embedding_dim)
        self.incoming_nonlinear = torch.tanh

        self.aggregation_nonlinear = torch.tanh

    def forward(self, nodes_embeddings, transition_tables):
        """Run message passing.

        Args:
            nodes_embeddings: ``(num_nodes, dim)`` node vectors.
            transition_tables: ``(num_nodes, num_nodes)`` integer table of edge-label ids.

        Returns:
            Updated ``(num_nodes, dim)`` node vectors.
        """
        # Squeeze only the embedding axis: a bare `squeeze()` would also collapse
        # the node axes of a single-node graph and break the matmul below.
        embedded_transitions = self.edge_embedding_layer(transition_tables).squeeze(-1).float()
        for _ in range(self.num_iterations):
            incoming_signal = torch.matmul(embedded_transitions, nodes_embeddings)
            incoming_signal = self.incoming_nonlinear(self.incoming_linear(incoming_signal))
            signal_to_forget = self.forgetting_linear(self.forgetting_nonlinear(nodes_embeddings))
            nodes_embeddings = self.aggregation_nonlinear(
                nodes_embeddings - signal_to_forget + incoming_signal
            )
        return nodes_embeddings

In [None]:
class OrderingNetwork(torch.nn.Module):
    """Pointer-style network that scores which fragment should come next.

    Pipeline: ``GraphFastTextPreprocessor`` -> ``GRN`` -> history LSTM ->
    additive scoring of every candidate fragment against every history state.

    Args:
        fasttext_model: fastText vectors passed to the preprocessor.
        edge_labels: edge-label vocabulary shared by the preprocessor and the GRN.
        fasttext_embedding_dim: dimensionality of the fastText vectors.
        encoder_lstm_hidden_size: hidden size of the text-encoding LSTM.
        pointer_network_lstm_hidden_size: stored for reference; the history LSTM
            uses the node-vector size as its hidden size.
        pointer_network_hidden_dim: size of the space in which history and
            candidates are compared.
    """

    def __init__(self, fasttext_model, edge_labels, fasttext_embedding_dim=300, encoder_lstm_hidden_size=50, pointer_network_lstm_hidden_size=50, pointer_network_hidden_dim=25):
        super().__init__()
        self.graph_preprocessor = GraphFastTextPreprocessor(fasttext_model, edge_labels, fasttext_embedding_dim, encoder_lstm_hidden_size)
        self.text_embedding_dim = self.graph_preprocessor.output_size
        self.GRN = GRN(num_edge_embeddings=len(edge_labels), edge_embedding_dim=self.text_embedding_dim)
        self.pointer_network_lstm_hidden_size = pointer_network_lstm_hidden_size
        self.pointer_network_hidden_dim = pointer_network_hidden_dim

        self.history_lstm = torch.nn.LSTM(input_size=self.text_embedding_dim, hidden_size=self.text_embedding_dim, batch_first=True)
        self.history_linear = torch.nn.Linear(self.text_embedding_dim, self.pointer_network_hidden_dim)
        self.candidate_linear = torch.nn.Linear(self.text_embedding_dim, self.pointer_network_hidden_dim)
        self.scoring_linear = torch.nn.Linear(self.pointer_network_hidden_dim, 1)

    def forward(self, graph, quest_correct_path, quest_nouns):
        """Score every (step, candidate) pair under teacher forcing.

        Args:
            graph: entity graph of the quest path.
            quest_correct_path: fragment node ids in their correct order.
            quest_nouns: noun node ids of the graph.

        Returns:
            ``(n, n)`` tensor of sigmoid scores, ``n = len(quest_correct_path)``.
            Row ``t`` scores all fragments as the continuation of the history
            ``[global, fragment_0, ..., fragment_{t-1}]``; with the correct order
            as input, the target is the identity matrix.
        """
        text_encoded_nodes, graph_transition_table = self.graph_preprocessor(graph, quest_correct_path, quest_nouns)
        text_encoded_nodes = self.GRN(text_encoded_nodes, graph_transition_table)

        num_fragments = len(quest_correct_path)
        if num_fragments == 0:
            return torch.zeros(0, 0, device=text_encoded_nodes.device)

        # Row 0 is the global state; rows 1..n are the fragment nodes.
        # History before step t ends with row t, so the teacher-forced history
        # sequence is rows 0..n-1 and the candidates are rows 1..n.
        # During training previous nodes stay among the candidates; repetitions
        # are prohibited only at decoding time.
        correct_history = text_encoded_nodes[:num_fragments]
        candidates = text_encoded_nodes[1:num_fragments + 1]

        # Fix: run the history through `history_lstm`, exactly as the beam-search
        # decoders do. Previously the LSTM was skipped here, so it was never
        # trained while inference still depended on its (random) weights.
        # Checkpoints trained before this change need further training.
        history_states, _ = self.history_lstm(correct_history.unsqueeze(0))
        history_representation = torch.tanh(self.history_linear(history_states.squeeze(0)))
        candidates_representation = torch.tanh(self.candidate_linear(candidates))

        # Broadcast to (step, candidate, hidden) and score all pairs at once.
        pairwise = history_representation.unsqueeze(1) + candidates_representation.unsqueeze(0)
        scores = torch.sigmoid(self.scoring_linear(pairwise)).squeeze(-1)
        return scores

In [None]:
# Build the ordering model on top of the loaded fastText vectors and put it into training mode.
ordering_model = OrderingNetwork(fasttext_model, edge_labels)
ordering_model.train()

In [None]:
# Adam over all model parameters; binary cross-entropy because the model outputs
# sigmoid scores that are compared against a 0/1 target matrix.
optimizer =  torch.optim.Adam([{'params': ordering_model.parameters()}], 
                               lr = 1e-3)
criterion = torch.nn.BCELoss()

In [None]:
# Resume from the checkpoint files written by the training loop (model weights and optimizer state).
# On a first run these files do not exist yet and this cell has to be skipped.
ordering_model.load_state_dict(torch.load('graph_based_ordering_model.pth'))
optimizer.load_state_dict(torch.load('graph_based_ordering_model_optimizer.pth'))

In [None]:
# NOTE(review): leftover debugging cell — `text_encoded_nodes` and `graph_transition_table` are not
# defined at notebook level anywhere above, so this raises NameError on a fresh kernel.
text_encoded_nodes_GRN = ordering_model.GRN(text_encoded_nodes, graph_transition_table)

In [None]:
# Global device name read by the model classes and the training loop at call time.
# It must be defined before the first forward pass.
device='cpu'

In [None]:
# Training loop with gradient accumulation, resumed from sample index 5001.
# NOTE(review): resuming by index relies on `os.listdir` returning the files in the same order
# as in the previous session — confirm, or sort the listing before the first run.
train_paths = os.listdir(COLAB_TRAIN_DATA_DIRECTORY)
real_i = 5001  # number of samples that actually contributed gradients so far
for i in range(5001, len(train_paths)):
    prepared_path_file = train_paths[i]

    path_path = os.path.join(COLAB_TRAIN_DATA_DIRECTORY, prepared_path_file)
    # Context manager so the file handle is closed (it was leaked before).
    with open(path_path, encoding='utf-8') as path_file:
        path_data = json.load(path_file)
    path_data[2] = json_graph.node_link_graph(path_data[2])
    path_data[3] = [node for node in path_data[2].nodes() if isinstance(node, str) and 'noun_' in node]

    try:
        _, correct_path, graph, nouns = path_data
        # Skip very large graphs to keep memory and time per sample bounded.
        if not (len(nouns) < 200 and len(correct_path) + len(nouns) < 220):
            continue
        print('num_fragments', len(correct_path), 'num_nouns', len(nouns))
        scores = ordering_model(graph, correct_path, nouns)
        # With the fragments given in the correct order, step t must point at fragment t.
        correct_scores = torch.eye(scores.shape[0], device=scores.device)

        loss = criterion(scores, correct_scores)
        loss.backward()
    except Exception as e:
        print(e)
        continue

    real_i += 1

    # Fix: step/save only right after a sample was accumulated. Previously these checks ran on
    # every iteration, so while `real_i` sat on a multiple of 10/50 each skipped or failed sample
    # triggered another optimizer step (with stale/zero gradients) and another checkpoint write.
    if real_i % 10 == 0:
        optimizer.step()
        optimizer.zero_grad()
    if real_i % 50 == 0:
        torch.save(ordering_model.state_dict(), 'graph_based_ordering_model.pth')
        torch.save(optimizer.state_dict(), 'graph_based_ordering_model_optimizer.pth')
        print(real_i, i)

In [None]:
def entity_graph_beam_search(graph, path, nouns, num_candidates = 3):
    """Decode a fragment order with beam search using the global ``ordering_model``.

    Args:
        graph: entity graph of the quest path.
        path: fragment node ids; the returned order consists of indices into this list.
        nouns: noun node ids of the graph.
        num_candidates: beam width (also the number of expansions per hypothesis).

    Returns:
        List of indices into ``path`` — the order of the most probable hypothesis.
    """
    text_encoded_nodes, graph_transition_table = ordering_model.graph_preprocessor(graph, path, nouns)
    text_encoded_nodes = ordering_model.GRN(text_encoded_nodes, graph_transition_table)

    # Row 0 is the global node and seeds the history; rows 1..n are the fragments.
    candidates = text_encoded_nodes[1:len(path)+1]
    initial_history = text_encoded_nodes[0].unsqueeze(0)
    unordered_fragments_representation = torch.tanh(ordering_model.candidate_linear(candidates))

    order_candidates = [{'order': [], 'history': initial_history, 'probability': 1}]
    for _ in range(len(candidates)):
        next_step_order_candidates = []
        for order_candidate in order_candidates:
            history_representation = ordering_model.history_lstm(order_candidate['history'].unsqueeze(0))[0]
            current_history_step = torch.tanh(
                ordering_model.history_linear(history_representation[:, -1, :])
            ).squeeze(0)

            # Score all fragments at once, then mask the ones already used with a value
            # strictly below any sigmoid output so they can never be selected.
            probabilities = torch.sigmoid(
                ordering_model.scoring_linear(unordered_fragments_representation + current_history_step)
            ).squeeze(-1).detach().clone()
            if order_candidate['order']:
                probabilities[order_candidate['order']] = -1.0

            current_step_num_candidates = min(num_candidates, len(candidates) - len(order_candidate['order']))
            next_idxs = torch.argsort(probabilities, descending=True)[:current_step_num_candidates]
            for next_idx in next_idxs:
                next_idx = int(next_idx)
                next_step_order_candidates.append({
                    'order': order_candidate['order'] + [next_idx],
                    'history': torch.cat((order_candidate['history'], candidates[next_idx].unsqueeze(0)), 0),
                    'probability': order_candidate['probability'] * float(probabilities[next_idx]),
                })
        # Fix: the sort key used to be the constant `['probability']`, so hypotheses were never
        # ranked. Keep the `num_candidates` most probable ones, best first.
        order_candidates = sorted(
            next_step_order_candidates, key=lambda candidate: candidate['probability'], reverse=True
        )[:num_candidates]
    return order_candidates[0]['order']

In [None]:
import math

In [None]:
def scoring_entity_graph_beam_search(graph, path, nouns, num_candidates = 3, epsilon=10**(-9)):
    """Beam search that also returns the log-probability of the best order.

    Args:
        graph: entity graph of the quest path.
        path: fragment node ids; the returned order consists of indices into this list.
        nouns: noun node ids of the graph.
        num_candidates: beam width (also the number of expansions per hypothesis).
        epsilon: replacement for a step probability that is not positive, so that
            its logarithm stays finite.

    Returns:
        ``{'order': [...], 'score': float}`` — the best order and the sum of the
        log step probabilities along it.
    """
    text_encoded_nodes, graph_transition_table = ordering_model.graph_preprocessor(graph, path, nouns)
    text_encoded_nodes = ordering_model.GRN(text_encoded_nodes, graph_transition_table)

    # Row 0 is the global node and seeds the history; rows 1..n are the fragments.
    candidates = text_encoded_nodes[1:len(path)+1]
    initial_history = text_encoded_nodes[0].unsqueeze(0)
    unordered_fragments_representation = torch.tanh(ordering_model.candidate_linear(candidates))

    # 'probability' holds a log-probability here, hence the initial 0.
    order_candidates = [{'order': [], 'history': initial_history, 'probability': 0}]
    for _ in range(len(candidates)):
        next_step_order_candidates = []
        for order_candidate in order_candidates:
            history_representation = ordering_model.history_lstm(order_candidate['history'].unsqueeze(0))[0]
            current_history_step = torch.tanh(
                ordering_model.history_linear(history_representation[:, -1, :])
            ).squeeze(0)

            # Score all fragments at once, then mask the ones already used with a value
            # strictly below any sigmoid output so they can never be selected.
            probabilities = torch.sigmoid(
                ordering_model.scoring_linear(unordered_fragments_representation + current_history_step)
            ).squeeze(-1).detach().clone()
            if order_candidate['order']:
                probabilities[order_candidate['order']] = -1.0

            current_step_num_candidates = min(num_candidates, len(candidates) - len(order_candidate['order']))
            next_idxs = torch.argsort(probabilities, descending=True)[:current_step_num_candidates]
            for next_idx in next_idxs:
                next_idx = int(next_idx)
                step_probability = float(probabilities[next_idx])
                score = step_probability if step_probability > 0 else epsilon
                next_step_order_candidates.append({
                    'order': order_candidate['order'] + [next_idx],
                    'history': torch.cat((order_candidate['history'], candidates[next_idx].unsqueeze(0)), 0),
                    'probability': order_candidate['probability'] + math.log(score),
                })
        # Fix: the sort key used to be the constant `['probability']`, so hypotheses were never
        # ranked. Keep the `num_candidates` hypotheses with the highest log-probability, best first.
        order_candidates = sorted(
            next_step_order_candidates, key=lambda candidate: candidate['probability'], reverse=True
        )[:num_candidates]
    return {'order':order_candidates[0]['order'], 'score':order_candidates[0]['probability']}

In [None]:
# Re-create the model, load the trained weights from the checkpoint and switch to evaluation mode.
# This rebinds the global `ordering_model` that the beam-search functions read.
ordering_model = OrderingNetwork(fasttext_model, edge_labels)
ordering_model.load_state_dict(torch.load('graph_based_ordering_model.pth'))
ordering_model.eval()

OrderingNetwork(
  (graph_preprocessor): GraphFastTextPreprocessor(
    (encoding_lstm): LSTM(300, 50, batch_first=True, bidirectional=True)
  )
  (GRN): GRN(
    (edge_embedding_layer): Embedding(6, 1, padding_idx=0)
    (forgetting_linear): Linear(in_features=200, out_features=200, bias=True)
    (incoming_linear): Linear(in_features=200, out_features=200, bias=True)
  )
  (history_lstm): LSTM(200, 200, batch_first=True)
  (history_linear): Linear(in_features=200, out_features=25, bias=True)
  (candidate_linear): Linear(in_features=200, out_features=25, bias=True)
  (scoring_linear): Linear(in_features=25, out_features=1, bias=True)
)

In [None]:
from tqdm import tqdm

In [None]:
# Predict an order for every prepared test sample and store the predictions as JSON.
predictions = []
for prepared_path_file in tqdm(os.listdir(COLAB_TEST_DATA_DIRECTORY)):
    try:
        path_path = os.path.join(COLAB_TEST_DATA_DIRECTORY, prepared_path_file)
        # Fix: the file used to be opened through the undefined name `file`, which raised
        # NameError for every sample (silently printed by the handler below), and a second
        # handle was opened without being closed.
        with open(path_path, encoding='utf-8') as path_file:
            path_data = json.load(path_file)
        path_data[2] = json_graph.node_link_graph(path_data[2])
        path_data[3] = [node for node in path_data[2].nodes() if isinstance(node, str) and 'noun_' in node]

        _, path, graph, nouns = path_data
        # Only short paths / small graphs are evaluated.
        if len(path) <= 30 and len(path)+len(nouns) <= 120:
            with torch.no_grad():
                prediction = entity_graph_beam_search(graph, path, nouns)
            predictions.append(prediction)

            # Periodic dump so a long run can be inspected or recovered.
            if len(predictions) % 100 == 0:
                with open('entity_graph_predictions.json', 'w') as f:
                    json.dump(predictions, f)
    except Exception as e:
        print(e)


with open('entity_graph_predictions.json', 'w') as f:
    json.dump(predictions, f)

In [None]:
# Reload the stored predictions so the evaluation below can run without repeating inference.
with open('entity_graph_predictions.json') as f:
    predictions = json.load(f)

In [None]:
from scipy.stats import kendalltau
import pandas as pd
def longest_correct_subsequence(predicted, correct):
    """Length (in items) of the longest run of consecutive correct transitions.

    A transition is a pair of neighbouring items; it is correct when the same
    pair is adjacent, in the same direction, in ``correct``. A run of ``k``
    correct transitions covers ``k + 1`` items. If no transition of
    ``predicted`` is correct, 0 is returned.
    """
    gold_transitions = set(zip(correct, correct[1:]))

    best_run = 0
    running = 0
    for transition in zip(predicted, predicted[1:]):
        # Extend the current run on a hit, restart it on a miss.
        running = running + 1 if transition in gold_transitions else 0
        best_run = max(best_run, running)

    # Convert the number of transitions into the number of items.
    if best_run == 0:
        return 0
    return best_run + 1

In [None]:
# One row per prediction: sequence length, Kendall's tau and longest correct run.
# NOTE(review): rows are appended one by one with `.loc`, which is slow for many predictions.
entity_graph_df = pd.DataFrame(columns = ["sequence length", "Kendall's  tau", "Longest correct subsequence"])
for prediction in predictions:
    # The gold order is taken to be 0..n-1, i.e. this assumes the stored path lists
    # the fragments in their true order.
    correct = list(range(len(prediction)))
    tau = kendalltau(prediction, correct).correlation
    lcs = longest_correct_subsequence(prediction, correct)
    entity_graph_df.loc[len(entity_graph_df)] = [len(correct), tau, lcs]

In [None]:
# Per sequence length (up to 30): number of samples, mean Kendall's tau and mean longest correct run.
entity_graph_df_aggr = entity_graph_df[entity_graph_df["sequence length"]<=30].groupby("sequence length").describe()[[("Kendall's  tau", 'count'),  ("Kendall's  tau",  'mean'), ('Longest correct subsequence',  'mean')]]
entity_graph_df_aggr

Unnamed: 0_level_0,Kendall's tau,Kendall's tau,Longest correct subsequence
Unnamed: 0_level_1,count,mean,mean
sequence length,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
4.0,100.0,-0.133333,1.3
5.0,97.0,-0.08866,1.329897
6.0,150.0,0.032889,1.406667
7.0,152.0,0.187343,1.421053
8.0,125.0,0.073714,1.576
9.0,189.0,0.124927,1.47619
10.0,163.0,-0.0394,1.515337
11.0,222.0,0.114824,1.477477
12.0,212.0,0.113922,1.584906
13.0,174.0,0.049366,1.5


In [None]:
# Export the aggregated metrics table as LaTeX source.
with open('latex_table.txt', 'w') as f:
    f.write(entity_graph_df_aggr.to_latex())

In [None]:
from copy import copy

In [None]:
# Load the sanity-check samples; the file holds a dict with the keys 'real' and 'random' (see the output below).
with open('sanity_check_data_entity_graph.json') as f:
    sanity_check_data_entity_graph = json.load(f)

In [None]:
# Inspect which sample groups the sanity-check file contains.
sanity_check_data_entity_graph.keys()

dict_keys(['real', 'random'])

In [None]:
# Score every sanity-check sample with the log-probability beam search.
# Each sample is indexed as [1] = path, [2] = node-link graph data, [3][0] = noun node ids.
# NOTE(review): unlike the test loop, these calls do not run under `torch.no_grad()`.
real_prediction =  [scoring_entity_graph_beam_search(json_graph.node_link_graph(sample[2]), sample[1], sample[3][0]) for sample in sanity_check_data_entity_graph['real']]
random_prediction = [scoring_entity_graph_beam_search(json_graph.node_link_graph(sample[2]), sample[1], sample[3][0]) for sample in sanity_check_data_entity_graph['random']]



In [None]:
# Show the predicted orders and log-probability scores for the 'real' samples.
real_prediction

[{'order': [4, 8, 5, 2, 6, 3, 7, 1, 0], 'score': -17.67351146782644},
 {'order': [0, 2, 3, 1], 'score': -7.718962635598061},
 {'order': [5, 0, 4, 1, 6, 3, 7, 9, 15, 11, 2, 14, 17, 13, 16, 10, 12, 8],
  'score': -35.530161927976174},
 {'order': [7, 0, 3, 2, 6, 8, 5, 1, 9, 4], 'score': -19.238074785558357},
 {'order': [2, 1, 10, 3, 16, 6, 18, 8, 0, 5, 13, 9, 17, 12, 11, 7, 15, 4, 14],
  'score': -36.96463414296116}]

In [None]:
# Mean best-hypothesis log-probability over the 'real' samples.
from statistics import mean
mean(sample['score'] for sample in real_prediction)

-23.425068991984038

In [None]:
# Mean best-hypothesis log-probability over the 'random' samples, for comparison with the 'real' mean.
# NOTE(review): the recorded output is higher for 'random' than for 'real', i.e. the scores do not
# separate the two groups in the expected direction.
from statistics import mean
mean(sample['score'] for sample in random_prediction)

-21.887044893920315