<a href="https://colab.research.google.com/github/ElizavetaNosova/Ordering-text-quest-fragments/blob/main/Attention_%2B_pointer_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import json
from networkx.readwrite import json_graph
import math
import numpy as np
from tqdm import tqdm
import random
from gensim.models import FastText
from nltk import wordpunct_tokenize

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
#@title Текст заголовка по умолчанию
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
import os

In [None]:
os.chdir('gdrive/MyDrive')

In [None]:
fasttext_model = FastText.load_fasttext_format('cc.ru.300.bin')

In [None]:
fasttext_model['кошка'].shape

  """Entry point for launching an IPython kernel.


(300,)

In [None]:
np.zeros(300).shape

(300,)

In [None]:
class BatchEmbedder:
    def __init__(self, fasttext_model):
        self.fasttext_model = fasttext_model
    
    def embed_tokens(self, tokens:list):
        return [self.fasttext_model[token] for token in tokens if token in self.fasttext_model]
    
    def __call__(self, tokenized_texts):
        max_len = max([len(tokenized_text) for tokenized_text in tokenized_texts])
        embedded_texts = [self.embed_tokens(tokenized_text) for tokenized_text in tokenized_texts]
        embedded_texts = [text for text in embedded_texts if text]
        padded_embedded_tokens = [self.pad(embedded_text, max_len) for embedded_text in embedded_texts]
        return torch.tensor(padded_embedded_tokens)
        
    def pad(self, embeddings:list, max_len:int):
        sequence_beginning = embeddings[:max_len]
        embedding_dim = len(embeddings[0])
        pads = [np.zeros(embedding_dim) for i in range(max_len-len(sequence_beginning))]
        return pads + sequence_beginning
  


In [None]:
ORDERING_DATA_DIRECTORY = 'tokenized_ordering_train'

In [None]:
ORDERING_DATA_DIRECTORY_TEST = 'tokenized_ordering_test_joined'

In [None]:
class QuestOrderDataset(torch.utils.data.Dataset):
    def __init__(self, paths_directory, embedder):
        super().__init__()
        self.directory = paths_directory
        self.files = os.listdir(paths_directory)
        self.embedder=embedder
      
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = os.path.join(self.directory, self.files[idx])
        _, order, graph_json_data, __ = json.load(open(file, encoding='utf-8'))
        graph = json_graph.node_link_graph(graph_json_data)
        nodes = graph.nodes()
        tokenized_fragments = [self.get_tokenized_text(nodes[node]) for node in order]
        return torch.squeeze(self.embedder(tokenized_fragments))



    def get_tokenized_text(self, node):
        if 'tokenized_text' in node:
            return node['tokenized_text']
        elif 'fragment_text' in node and isinstance(node['fragment_text'], str):
            return wordpunct_tokenize(node['fragment_text'])
        else:
            return []

In [None]:
class SourceTextQuestOrderDataset(torch.utils.data.Dataset):
    def __init__(self, paths_directory):
        super().__init__()
        self.directory = paths_directory
        self.files = os.listdir(paths_directory)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = os.path.join(self.directory, self.files[idx])
        _, order, graph_json_data, __ = json.load(open(file, encoding='utf-8'))
        graph = json_graph.node_link_graph(graph_json_data)
        nodes = graph.nodes()
        fragments = [self.get_text(nodes[node]) for node in order]
        return fragments



    def get_text(self, node):
        if 'joined_fragment_text' in node and isinstance(node['joined_fragment_text'], str):
            return node['joined_fragment_text']
        elif  'fragment_text' in node and  isinstance(node['fragment_text'], str):
            return node['fragment_text']
        else:
            return ''

In [None]:
embedder = BatchEmbedder(fasttext_model)

In [None]:
train_dataset = QuestOrderDataset('tokenized_ordering_train', embedder)

In [None]:
test_dataset =  QuestOrderDataset('tokenized_ordering_test', embedder)

In [None]:
test_dataset_texts = SourceTextQuestOrderDataset(ORDERING_DATA_DIRECTORY_TEST)

In [None]:
class DeepAttentiveOrderingNetwork(torch.nn.Module):
    def __init__(self, embedding_dim=300, encoder_lstm_hidden_dim=50, num_heads=5, attention_iterations=5, dropout=0.1, max_path_len=40, dim_feed_forward=500, pointer_embedding_dim=15):
        super().__init__()
        self.iterations = attention_iterations
  
        self.fragment_encoder_lstm = torch.nn.LSTM(embedding_dim, encoder_lstm_hidden_dim, bidirectional=True, batch_first=True)
        lstm_encoder_output_dim = encoder_lstm_hidden_dim*4
        self.paragraph_encoder_layer1 = torch.nn.TransformerEncoderLayer(d_model=lstm_encoder_output_dim, nhead=num_heads, dim_feedforward=dim_feed_forward)
        self.paragraph_encoder_layer2 = torch.nn.TransformerEncoderLayer(d_model=lstm_encoder_output_dim, nhead=num_heads, dim_feedforward=dim_feed_forward)

       #pointer module
        self.history_lstm = torch.nn.LSTM(lstm_encoder_output_dim, lstm_encoder_output_dim, batch_first=True)
        self.history_linear = torch.nn.Linear(lstm_encoder_output_dim, pointer_embedding_dim)
        self.candidates_linear = torch.nn.Linear(lstm_encoder_output_dim, pointer_embedding_dim)
        self.pointing_linear = torch.nn.Linear(pointer_embedding_dim, 1)


    def forward(self, fragments_in_correct_order):
        fragments_representation = self.encode(fragments_in_correct_order)
        global_representation = self.get_global_representation(fragments_representation)
        history = torch.cat((torch.unsqueeze(global_representation, 0), fragments_representation.clone()))
        history = torch.unsqueeze(history, 0)
        scores = self.point(history, fragments_representation)
        return scores
  
    def encode(self, embedded_fragments):
        representation = self.lstm_encode(embedded_fragments)
        representation = torch.unsqueeze(representation, 0)
        for i in range(self.iterations-1):
            representation = self.paragraph_encoder_layer1(representation)
        representation = self.paragraph_encoder_layer2(representation)
        return torch.squeeze(representation)

    def lstm_encode(self, embeddings):
        _, representation = self.fragment_encoder_lstm(embeddings.double())
        representation_vector = torch.cat((representation[0].clone()[0].clone(), representation[0].clone()[1].clone(), representation[1].clone()[0].clone(), representation[1].clone()[1].clone()),1)
        return torch.tanh(representation_vector)

    def get_global_representation(self, fragments_representations):
        return torch.mean(fragments_representations, 0)

    def point(self, history, candidates, shuffle=True):
        history, _ = self.history_lstm(history)
        history = torch.tanh(torch.squeeze(self.history_linear(history)))
        if shuffle:
             candidates, new_order = self.controlled_shuffle(candidates)
        else:
            new_order = [i for i in range(candidates.shape[0])]
        candidates = torch.tanh(self.candidates_linear(candidates))
        scores = torch.zeros(history.shape[0]-1, candidates.shape[0])
        for step_idx in range(history.shape[0]-1):
            #history_step = history[step_idx, :].clone()
            history_step = history[step_idx].clone()
            for candidate_idx in range(candidates.shape[0]):
                candidate = candidates[candidate_idx, :].clone()   
                scores[step_idx, candidate_idx] = torch.squeeze(torch.tanh(self.pointing_linear(candidate+history_step)))
        return torch.softmax(scores, 0), new_order

    def controlled_shuffle(self, candidates_tensor):
        ordering_list = [(i, random.random(), row) for i, row in enumerate(candidates_tensor)]
        ordering_list = sorted(ordering_list, key=lambda x: x[1])
        new_order = [row_data[0] for row_data in ordering_list]
        new_candidates_tensor = torch.cat(tuple([torch.unsqueeze(row_data[2], 0) for row_data in ordering_list]), 0)
        return new_candidates_tensor, new_order

In [None]:
model = DeepAttentiveOrderingNetwork()
model.double()


model.train()
optimizer =  torch.optim.Adam([{'params': model.parameters()}], 
                               lr = 1e-3)
criterion = torch.nn.BCELoss()

In [None]:
losses = []
for i in tqdm(range(len(train_dataset))):
    path = train_dataset[i]
    prediction, correct_order = model(path)
    correct = torch.zeros(prediction.shape)
    for i, position in enumerate(correct_order):
        correct[position][i] += 1
    loss = criterion(prediction, correct)
    losses.append(float(loss))
    loss.backward()
    if len(losses)%10 == 0:
        optimizer.step()
        optimizer.zero_grad()
    if len(losses)%150==0:
        torch.save(model.state_dict(), 'ordering_attention_fix_activation.pth')

  
  
100%|██████████| 7283/7283 [2:16:56<00:00,  1.13s/it]


In [None]:
model.load_state_dict(torch.load('ordering_attention_fix_activation.pth'))

<All keys matched successfully>

In [None]:
def predict_order_beam_search(test_sample, model, num_best_candidates=3):
    model.eval()
    encoded_sample = model.encode(test_sample)
    history = torch.unsqueeze(torch.unsqueeze(model.get_global_representation(encoded_sample), 0), 0)
    #history = torch.unsqueeze(model.get_global_representation(encoded_sample), 0)
    predicted_orders = [{'history':history, 'predicted_order':[], 'probability':1}]
    for i in range(test_sample.shape[0]-2):
        for order_data in predicted_orders:
            current_history = order_data['history']
            order_candidate = order_data['predicted_order']
            padded_current_history = torch.cat((current_history, torch.zeros(1, len(test_sample)-current_history.shape[0], current_history.shape[-1])),1)
            scores, _ = model.point(padded_current_history, encoded_sample, shuffle=False)
            #print(scores)
            current_step_scores = scores[i]
            for already_used_sent in order_candidate:
                current_step_scores[already_used_sent] *= -math.inf
            for i in range(min(num_best_candidates, len(test_sample)-len(order_candidate))):
                best_candidate_idx = int(current_step_scores.argmax(-1))
                order_data['predicted_order'].append(best_candidate_idx)

                order_data['history'] = torch.cat((order_data['history'], torch.unsqueeze(torch.unsqueeze(encoded_sample[best_candidate_idx], 0), 0)),1)

                order_data['probability'] *= float(current_step_scores[best_candidate_idx])
                current_step_scores[best_candidate_idx] *= -math.inf

        predicted_orders = sorted(predicted_orders, key=lambda x: x['probability'], reverse=True)[:num_best_candidates]
    return predicted_orders[0]['predicted_order']   

In [None]:
def scoring_predict_order_beam_search(texts, model, num_best_candidates=3, epsilon=10**(-9)):
    tokenized_texts = [wordpunct_tokenize(text) for text in texts]
    test_sample = embedder(tokenized_texts)
    model.eval()
    encoded_sample = model.encode(test_sample)
    history = torch.unsqueeze(torch.unsqueeze(model.get_global_representation(encoded_sample), 0), 0)
    #history = torch.unsqueeze(model.get_global_representation(encoded_sample), 0)
    predicted_orders = [{'history':history, 'predicted_order':[], 'probability':0}]
    for i in range(test_sample.shape[0]-2):
        for order_data in predicted_orders:
            current_history = order_data['history']
            order_candidate = order_data['predicted_order']
            padded_current_history = torch.cat((current_history, torch.zeros(1, len(test_sample)-current_history.shape[0], current_history.shape[-1])),1)
            scores, _ = model.point(padded_current_history, encoded_sample, shuffle=False)
            #print(scores)
            current_step_scores = scores[i]
            for already_used_sent in order_candidate:
                current_step_scores[already_used_sent] *= -math.inf
            for i in range(min(num_best_candidates, len(test_sample)-len(order_candidate))):
                best_candidate_idx = int(current_step_scores.argmax(-1))
                order_data['predicted_order'].append(best_candidate_idx)

                order_data['history'] = torch.cat((order_data['history'], torch.unsqueeze(torch.unsqueeze(encoded_sample[best_candidate_idx], 0), 0)),1)

                score = float(current_step_scores[best_candidate_idx]) if float(current_step_scores[best_candidate_idx]) > 0 else epsilon
                order_data['probability'] += float(current_step_scores[best_candidate_idx])
                current_step_scores[best_candidate_idx] *= -math.inf

        predicted_orders = sorted(predicted_orders, key=lambda x: x['probability'], reverse=True)[:num_best_candidates]
    return {'order':predicted_orders[0]['predicted_order'], 'score':predicted_orders[0]['probability']}

In [None]:
with open('sanity_check_data.json') as f:
    sanity_check_data = json.load(f)

In [None]:
real_quests_predictions = [scoring_predict_order_beam_search(path, model) for path in sanity_check_data['real']]
random_quests_predictions = [scoring_predict_order_beam_search(path, model) for path in sanity_check_data['random']]

  
  


In [None]:
real_quests_predictions

[{'order': [1, 3, 5, 4, 2, 0, 7, 6, 8], 'score': 0.8619930446147919},
 {'order': [0, 1, 3, 2], 'score': 1.166579708456993},
 {'order': [14, 10, 0, 4, 5, 6, 8, 7, 12, 9, 15, 16, 1, 2, 3, 11, 17, 13],
  'score': 0.769480399787426},
 {'order': [6, 8, 9, 2, 3, 4, 7, 5, 0, 1], 'score': 0.8388739749789238},
 {'order': [6, 7, 10, 0, 2, 3, 13, 11, 9, 14, 15, 1, 12, 4, 16, 17, 5, 8, 18],
  'score': 0.7643137294799089}]

In [None]:
from statistics import mean
mean([sample['score'] for sample in real_quests_predictions])

0.8802481714636088

In [None]:
mean([sample['score'] for sample in random_quests_predictions])

0.8802481662482023

In [None]:
result = []
for sample in tqdm(test_dataset):
    shuffled_sample, correct_order = model.controlled_shuffle(sample)
    prediction = predict_order_beam_search(sample, model)
    result.append({'correct': correct_order, 'predicted': prediction})

  
  
100%|██████████| 3430/3430 [4:57:58<00:00,  5.21s/it]


In [None]:
with open('basic_transformer_ordering_new.json', 'w', encoding='utf-8') as f:
    json.dump(result, f)

In [None]:
with open('basic_transformer_ordering_new.json') as f:
    result = json.load(f)

In [None]:
def longest_correct_subsequence(predicted, correct):
    correct_transitions = set([(item_from, item_to) for item_from, item_to in zip(correct[:-1], correct[1:])])
    predicted_transitions = [(item_from, item_to) for item_from, item_to in zip(predicted[:-1], predicted[1:])]
    predicted_transitions_are_correct = [transition in correct_transitions for transition in predicted_transitions]
    
    longest_correct_transitions_subsequence = 0
    current_correct_transitions_subsequence = 0
    #Add False as last item to include last real item checking into the loop
    for predicted_transition_is_correct in predicted_transitions_are_correct + [False]:
        if predicted_transition_is_correct:
            current_correct_transitions_subsequence += 1
        else:
            if current_correct_transitions_subsequence > longest_correct_transitions_subsequence:
                longest_correct_transitions_subsequence = current_correct_transitions_subsequence
            current_correct_transitions_subsequence = 0
    #return number of items in longest correct sequence (not number of transitions)
    return longest_correct_transitions_subsequence + 1 if longest_correct_transitions_subsequence else 0  

In [None]:
from scipy.stats import kendalltau
import pandas as pd

In [None]:
simple_transformer_pointer_df = pd.DataFrame(columns = ["sequence length", "Kendall's  tau", "Longest correct subsequence"])
for sequence_data in result:
    predicted = sequence_data['predicted']
    correct = sequence_data['correct']
    tau = kendalltau(predicted, correct).correlation
    lcs = longest_correct_subsequence(predicted, correct)
    simple_transformer_pointer_df.loc[len(simple_transformer_pointer_df)] = [len(correct), tau, lcs]
    

In [None]:
simple_transformer_pointer_df_aggr = simple_transformer_pointer_df[simple_transformer_pointer_df["sequence length"]<=30].groupby("sequence length").describe()[[("Kendall's  tau", 'count'),  ("Kendall's  tau",  'mean'), ('Longest correct subsequence',  'mean')]]
simple_transformer_pointer_df_aggr 

Unnamed: 0_level_0,Kendall's tau,Kendall's tau,Longest correct subsequence
Unnamed: 0_level_1,count,mean,mean
sequence length,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
4.0,98.0,0.027211,1.336735
5.0,97.0,-0.051546,1.14433
6.0,149.0,-0.022819,1.375839
7.0,151.0,-0.002208,1.317881
8.0,130.0,-0.001648,1.307692
9.0,192.0,0.039352,1.307292
10.0,175.0,-0.005206,1.194286
11.0,225.0,0.015919,1.342222
12.0,226.0,-0.014079,1.283186
13.0,184.0,-0.004041,1.336957


In [None]:
with open('latex_draft.txt', 'w') as f:
    f.write(simple_transformer_pointer_df_aggr.to_latex())