<a href="https://colab.research.google.com/github/Derinhelm/graph_syntax_parsing/blob/main/Parser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installation

In [1]:
!pip install -q transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m37.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m42.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
!pip install torch_geometric

from IPython.display import clear_output

clear_output()

# Logging

In [3]:

import logging
logger = logging.getLogger('my_logger')

# Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    filename='app.log', # write to this file
    filemode='a', # open in append mode
    format='%(name)s - %(levelname)s - %(message)s'
    )

logging.getLogger().setLevel(logging.INFO)
logging.getLogger().setLevel(logging.DEBUG)

logging.getLogger("urllib3.connectionpool").disabled = True
logging.getLogger("filelock").disabled = True

logging.warning('This will get logged to a file')

In [4]:
logging.warning('New warning')
logging.debug('New debug')
logging.info('New info')


In [5]:
from pathlib import Path
print(Path('/content/app.log').read_text())

root - DEBUG - New debug
root - INFO - New info



#uuparser/utils.py

In [6]:
from collections import defaultdict, Counter
import re
import os,time
from operator import itemgetter
import random
import json
import pathlib
import subprocess
import sys

import tqdm



class ConllEntry:
    def __init__(self, id, form, lemma, pos, cpos, feats=None, parent_id=None, relation=None,
        deps=None, misc=None):

        self.id = id
        self.form = form
        self.cpos = cpos
        self.pos = pos
        self.parent_id = parent_id
        self.relation = relation

        self.lemma = lemma
        self.feats = feats
        self.deps = deps
        self.misc = misc

        self.pred_parent_id = None
        self.pred_relation = None

        self.pred_pos = None
        self.pred_cpos = None


    def __str__(self):
        '''values = [str(self.id), self.form, self.lemma, \
                  self.pred_cpos if self.pred_cpos else self.cpos,\
                  self.pred_pos if self.pred_pos else self.pos,\
                  self.feats, str(self.pred_parent_id) if self.pred_parent_id \
                  is not None else str(self.parent_id), self.pred_relation if\
                  self.pred_relation is not None else self.relation, \
                  self.deps, self.misc]
        return '\t'.join(['_' if v is None else v for v in values])'''
        return self.form + " " + str(self.id)

class ParseForest:
    def __init__(self, sentence):
        self.roots = list(sentence)

        for root in self.roots:
            root.children = []
            root.scores = None # TODO: зачем?
            root.parent = None
            root.pred_parent_id = None
            root.pred_relation = None
            root.vecs = None

    def __len__(self):
        return len(self.roots)


    def Attach(self, parent_index, child_index):
        parent = self.roots[parent_index]
        child = self.roots[child_index]

        child.pred_parent_id = parent.id
        del self.roots[child_index]

    def __str__(self):
        return " ".join(map(str, self.roots))


def isProj(sentence):
    forest = ParseForest(sentence)
    unassigned = {entry.id: sum([1 for pentry in sentence if pentry.parent_id == entry.id]) for entry in sentence}

    for _ in xrange(len(sentence)):
        for i in xrange(len(forest.roots) - 1):
            if forest.roots[i].parent_id == forest.roots[i+1].id and unassigned[forest.roots[i].id] == 0:
                unassigned[forest.roots[i+1].id]-=1
                forest.Attach(i+1, i)
                break
            if forest.roots[i+1].parent_id == forest.roots[i].id and unassigned[forest.roots[i+1].id] == 0:
                unassigned[forest.roots[i].id]-=1
                forest.Attach(i, i+1)
                break

    return len(forest.roots) == 1


def get_irels(data):
    """
    Collect frequencies of words, cpos, pos and deprels + languages.
    """

    # could use sets directly rather than counters for most of these,
    # but having the counts might be useful in the future or possibly for debugging etc
    relCount = Counter()

    for sentence in data:
        for node in sentence:
            if isinstance(node, ConllEntry):
                relCount.update([node.relation])

    return list(relCount.keys())


def generate_root_token():
    return ConllEntry(0, '*root*', '*root*', 'ROOT-POS', 'ROOT-CPOS', '_', -1,
        'rroot', '_', '_')


def read_conll(filename, drop_nproj=False, train=True):
    fh = open(filename,'r',encoding='utf-8')
    logging.info(f"Reading {filename}")
    ts = time.time()
    dropped = 0
    sents_read = 0
    sentences = []
    tokens = [generate_root_token()]
    words = [] # all words from the dataset
    for line in fh:
        tok = line.strip().split('\t')
        if not tok or line.strip() == '': # empty line, add sentence to list or yield
            if len(tokens) > 1:
                sents_read += 1
                conll_tokens = [t for t in tokens if isinstance(t,ConllEntry)]
                if not drop_nproj or isProj(conll_tokens):
                    # keep going if it's projective or we're not dropping non-projective sents
                    if train:
                        inorder_tokens = inorder(conll_tokens)
                        for i,t in enumerate(inorder_tokens):
                            t.projective_order = i
                        for tok in conll_tokens:
                            tok.rdeps = [i.id for i in conll_tokens if i.parent_id == tok.id]
                            if tok.id != 0:
                                tok.parent_entry = [i for i in conll_tokens if i.id == tok.parent_id][0]
                    sentences.append(tokens)
                else:
                    logging.debug('Non-projective sentence dropped')
                    dropped += 1
            tokens = [generate_root_token()]
        else:
            if line[0] == '#' or '-' in tok[0] or '.' in tok[0]: # a comment line, add to tokens as is
                tokens.append(line.strip())
            else: # an actual ConllEntry, add to tokens
                if tok[2] == "_":
                    tok[2] = tok[1].lower()
                lemma = tok[2]
                words.append(lemma)
                token = ConllEntry(int(tok[0]), tok[1], lemma, tok[4], tok[3], tok[5], \
                    int(tok[6]) if tok[6] != '_' else -1, tok[7], tok[8], tok[9])

                tokens.append(token)

# deal with case where there are still tokens, that aren`t in sentences list
# e.g. when there is no newline at end of file
    if len(tokens) > 1:
        sentences.append(tokens)

    logging.debug(f'{sents_read} sentences read')

    te = time.time()
    logging.info(f'Time: {te-ts:.2g}s')
    return sentences, words


def write_conll(fn, conll_gen):
    logging.info(f"Writing to {fn}")
    sents = 0
    with open(fn, 'w', encoding='utf-8') as fh:
        for sentence in conll_gen:
            sents += 1
            for entry in sentence[1:]:
                fh.write(str(entry) + '\n')
            fh.write('\n')
        logging.debug(f"Wrote {sents} sentences")


numberRegex = re.compile("[0-9]+|[0-9]+\\.[0-9]+|[0-9]+[0-9,]+");
def normalize(word):
    return 'NUM' if numberRegex.match(word) else word.lower()


def inorder(sentence):
    queue = [sentence[0]]
    def inorder_helper(sentence,i):
        results = []
        left_children = [entry for entry in sentence[:i] if entry.parent_id == i]
        for child in left_children:
            results += inorder_helper(sentence,child.id)
        results.append(sentence[i])

        right_children = [entry for entry in sentence[i:] if entry.parent_id == i ]
        for child in right_children:
            results += inorder_helper(sentence,child.id)
        return results
    return inorder_helper(sentence,queue[0].id)


def set_seeds():
    python_seed = 1
    logging.debug("Using default Python seed")
    random.seed(python_seed)


def generate_seed():
    return random.randint(0,10**9) # this range seems to work for Dynet and Python's random function


# uuparser/multilayer_perceptron.py


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import SAGEConv, to_hetero
import torch

class GNNBlock(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x




# uuparser/arc_hybrid.py

In [8]:
from operator import itemgetter
from itertools import chain
import time, random
import numpy as np
from copy import deepcopy
from collections import defaultdict
import json

import torch
import torch.nn as nn
import torch.optim as optim

import tqdm


from jax.numpy import int32
from torch_geometric.data import HeteroData

import torch

## Config

In [9]:
def create_stack_edges(stack):
    if len(stack) == 0:
        return torch.stack((torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)), dim=0)
    stack_edges = []
    if len(stack) == 1:
        stack_edges.append((stack[0].id - 1, stack[0].id - 1)) # temporary solution
    else:
        for i in range(len(stack) - 1): # Represents every two consecutive stack nodes as an edge
            stack_edges.append((stack[i].id - 1, stack[i + 1].id - 1))
    stack_edges = tuple(zip(*stack_edges))
    stack_edges = [torch.tensor(stack_edges[0]), torch.tensor(stack_edges[1])]
    return torch.stack(stack_edges, dim=0)

def create_buffer_edges(buffer):
    if len(buffer) == 0 or len(buffer) == 1: # Last element is a technical root element.
        return torch.stack((torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)), dim=0)
    buffer_edges = []
    if len(buffer) == 2: # Last element is a technical root element.
        buffer_edges.append((buffer[0].id - 1, buffer[0].id - 1)) # temporary solution
    else:
        for i in range(len(buffer) - 2): # Last element is a technical root element.
        # Represents every two consecutive buffer nodes as an edge
            buffer_edges.append((buffer[i].id - 1, buffer[i + 1].id - 1))
    buffer_edges = tuple(zip(*buffer_edges))
    buffer_edges = [torch.tensor(buffer_edges[0]), torch.tensor(buffer_edges[1])]
    return torch.stack(buffer_edges, dim=0)

def create_graph_edges(sentence):
    graph_edges = []
    for node in sentence:
        if node.pred_parent_id is not None and node.pred_parent_id != 0 and node.pred_parent_id != -1:
            graph_edges.append((node.pred_parent_id - 1, node.id - 1))
    if len(graph_edges) == 0:
        return torch.stack((torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)), dim=0)
    graph_edges = tuple(zip(*graph_edges))
    graph_edges = [torch.tensor(graph_edges[0]), torch.tensor(graph_edges[1])]
    return torch.stack(graph_edges, dim=0)

class Configuration:
    def __init__(self, sentence, irels):
        self.sentence = deepcopy(sentence)
        # ensures we are working with a clean copy of sentence and allows memory to be recycled each time round the loop
        self.sentence = [entry for entry in self.sentence if isinstance(entry, ConllEntry)]
        self.sentence = self.sentence[1:] + [self.sentence[0]]
        self.stack = ParseForest([])
        self.buffer = ParseForest(self.sentence)
        for root in self.sentence:
            root.relation = root.relation if root.relation in irels else 'runk'


    def config_to_graph(self, embeds):
        word_embeds = torch.empty((len(self.sentence), 768))
        for i in range(len(self.sentence) - 1): # Last element is a technical root element.
            word_embeds[i] = embeds[self.sentence[i].lemma]

        data = HeteroData()
        data['node']['x'] = word_embeds

        data[('node', 'graph', 'node')].edge_index = create_graph_edges(self.sentence)
        data[('node', 'stack', 'node')].edge_index = create_stack_edges(self.stack.roots)
        data[('node', 'buffer', 'node')].edge_index = create_buffer_edges(self.buffer.roots)
        return data

    def apply_transition(self, best):
        if best[1] == SHIFT:
            self.stack.roots.append(self.buffer.roots[0])
            del self.buffer.roots[0]

        elif best[1] == SWAP:
            child = self.stack.roots.pop()
            self.buffer.roots.insert(1,child)

        elif best[1] == LEFT_ARC:
            child = self.stack.roots.pop()
            parent = self.buffer.roots[0]

        elif best[1] == RIGHT_ARC:
            child = self.stack.roots.pop()
            parent = self.stack.roots[-1]

        if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            #attach
            child.pred_parent_id = parent.id
            child.pred_relation = best[0]

    def get_stack_ids(self):
        return [sitem.id for sitem in self.stack.roots]

    def is_stack_not_empty(self):
        return len(self.stack) > 0

    def get_stack_last_element(self):
        return self.stack.roots[-1] # Last stack element

    def get_stack_penultimate_element(self):
        return self.stack.roots[-2] # Penultimate stack element

    def get_buffer_head(self):
        return self.buffer.roots[0] # Head buffer element

    def get_buffer_tail(self):
        return self.buffer.roots[1:] if len(self.buffer) > 1 else [] # Tail of buffer

    def get_sentence(self):
        return self.sentence

    def is_end(self):
        return len(self.buffer) == 1 and len(self.stack) == 0

    def check_left_arc_conditions(self):
        return len(self.stack) > 0

    def check_not_train_left_arc_conditions(self):
            #(avoiding the multiple roots problem: disallow left-arc from root
            #if stack has more than one element
        return self.check_left_arc_conditions() and not (self.buffer.roots[0].id == 0 and len(self.stack) > 1)

    def check_right_arc_conditions(self):
        return len(self.stack) > 1

    def check_shift_conditions(self):
        return self.buffer.roots[0].id != 0

    def check_swap_conditions(self):
        return len(self.stack) > 0 and self.stack.roots[-1].id < self.buffer.roots[0].id

    def __str__(self):
        return "stack:" + str(self.stack) + "\n" + "buffer:" + str(self.buffer) + "\n"
    
    def calculate_left_cost(self):
        if not self.check_left_arc_conditions():
            return 1
        
        s0 = self.get_stack_last_element() # Last stack element
        b = self.get_buffer_head() # Head buffer element
        left_cost = len(s0.rdeps) + int(s0.parent_id != b.id and s0.id in s0.parent_entry.rdeps)

        if self.check_swap_conditions() and s0.projective_order > b.projective_order:
            left_cost = 1
        return left_cost
    
    def calculate_right_cost(self):
        if not self.check_right_arc_conditions():
            return 1

        s1 = self.get_stack_penultimate_element() # Penultimate stack element
        s0 = self.get_stack_last_element() # Last stack element
        b = self.get_buffer_head() # Head buffer element 
        
        right_cost = len(s0.rdeps) + int(s0.parent_id != s1.id and s0.id in s0.parent_entry.rdeps)

        if self.check_swap_conditions() and s0.projective_order > b.projective_order:
            right_cost = 1

        return right_cost

    def calculate_shift_cost(self):
        if not self.check_shift_conditions():
            shift_cost = 1
            shift_case = 0
            return shift_cost, shift_case
        
        b = self.get_buffer_head() # Head buffer element
        beta = self.get_buffer_tail() # Tail (list) of buffer
        
        if len([item for item in beta if item.projective_order < b.projective_order and item.id > b.id ])> 0:
            shift_cost = 0
            shift_case = 1
        else:
            stack_ids = self.get_stack_ids()
            shift_cost = len([d for d in b.rdeps if d in stack_ids]) + \
                int(self.is_stack_not_empty() and b.parent_id in stack_ids[:-1] and b.id in b.parent_entry.rdeps)
            shift_case = 2

        if self.check_swap_conditions():
            s0 = self.get_stack_last_element() # Last stack element
            if s0.projective_order > b.projective_order:
                shift_cost = 1

        return shift_cost, shift_case

    def calculate_swap_cost(self):
        if not self.check_swap_conditions():
            return 1

        s0 = self.get_stack_last_element() # Last stack element
        b = self.get_buffer_head() # Head buffer element

        if s0.projective_order > b.projective_order:
            swap_cost = 0
        else:
            swap_cost = 1

        return swap_cost
    
    def dynamic_oracle_updates(self, best, shift_case):
        stack_ids = self.get_stack_ids()
        if best[1] == SHIFT:
            if shift_case == 2:
                b = self.get_buffer_head() # Head buffer element
                if b.parent_entry.id in stack_ids[:-1] and b.id in b.parent_entry.rdeps:
                    b.parent_entry.rdeps.remove(b.id)
                blocked_deps = [d for d in b.rdeps if d in stack_ids]
                for d in blocked_deps:
                    b.rdeps.remove(d)

        elif best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            s0 = self.get_stack_last_element() # Last stack element
            s0.rdeps = []
            if s0.id in s0.parent_entry.rdeps:
                s0.parent_entry.rdeps.remove(s0.id)

In [None]:
class GNNNet:
    def __init__(self, options, out_irels_dims):
        self.hidden_dims = options["hidden_dims"]
        self.out_irels_dims = out_irels_dims

        self.metadata = (['node'], [('node', 'graph', 'node'), ('node', 'stack', 'node'), ('node', 'buffer', 'node')])
        self.unlabeled_GNN = GNNBlock(hidden_channels=self.hidden_dims, out_channels=4)
        self.unlabeled_GNN = to_hetero(self.unlabeled_GNN, self.metadata, aggr='sum')

        self.labeled_GNN = GNNBlock(hidden_channels=self.hidden_dims, out_channels=2*self.out_irels_dims+2)
        self.labeled_GNN = to_hetero(self.labeled_GNN, self.metadata, aggr='sum')

        self.unlabeled_optimizer = optim.Adam(self.unlabeled_GNN.parameters(), lr=options["learning_rate"])
        self.labeled_optimizer = optim.Adam(self.labeled_GNN.parameters(), lr=options["learning_rate"])

    def Load(self, epoch):
        unlab_path = 'model_unlab' + '_' + str(epoch)
        lab_path = 'model_lab' + '_' + str(epoch)

        self.unlabeled_GNN = GNNBlock(hidden_channels=self.hidden_dims, out_channels=4)
        self.labeled_GNN = GNNBlock(hidden_channels=self.hidden_dims, out_channels=2*self.out_irels_dims+2)

        unlab_checkpoint = torch.load(unlab_path)
        self.unlabeled_GNN.load_state_dict(unlab_checkpoint['model_state_dict'], strict=False)

        self.unlabeled_GNN = to_hetero(self.unlabeled_GNN, self.metadata, aggr='sum')
        self.labeled_GNN = to_hetero(self.labeled_GNN, self.metadata, aggr='sum')

        lab_checkpoint = torch.load(lab_path)
        self.labeled_GNN.load_state_dict(lab_checkpoint['model_state_dict'], strict=False)

    def Save(self, epoch):
        unlab_path = 'model_unlab' + '_' + str(epoch)
        lab_path = 'model_lab' + '_' + str(epoch)
        logging.info(f'Saving unlabeled model to {unlab_path}')
        torch.save({'epoch': epoch, 'model_state_dict': self.unlabeled_GNN.state_dict()}, unlab_path)
        logging.info(f'Saving labeled model to {lab_path}')
        torch.save({'epoch': epoch, 'model_state_dict': self.labeled_GNN.state_dict()}, lab_path)

    def evaluate(self, config, embeds):
        graph = config.config_to_graph(embeds)
        uscrs = self.unlabeled_GNN(graph.x_dict, graph.edge_index_dict)
        uscrs = torch.sum(uscrs['node'], dim=0)
        scrs = self.labeled_GNN(graph.x_dict, graph.edge_index_dict)
        scrs = torch.sum(scrs['node'], dim=0)
        return scrs, uscrs
    
    def error_processing(self, errs):    
        self.labeled_optimizer.zero_grad()
        self.unlabeled_optimizer.zero_grad()
        eerrs = torch.sum(torch.tensor(errs, requires_grad=True))
        eerrs.backward()
        self.labeled_optimizer.step() # TODO Какой из оптимизаторов ???
        self.unlabeled_optimizer.step()

## Parser

In [10]:
class Parser:
    def __init__(self, options, irels):

        global LEFT_ARC, RIGHT_ARC, SHIFT, SWAP
        LEFT_ARC, RIGHT_ARC, SHIFT, SWAP = 0, 1, 2, 3 # TODO: сделать глобальными переменными, используются в Configuration.

        self.oracle = options["oracle"]
        self.net = GNNNet(options, len(irels))
        self.irels = irels

    def Load(self, epoch):
        self.net.Load(epoch)

    def Save(self, epoch):
        self.net.Save(epoch)

    def test_evaluate(self, config, embeds):
        """
        ret = [left arc,
               right arc
               shift]

        RET[i] = (rel, transition, score1) for shift, l_arc and r_arc
         shift = 2 (==> rel=None) ; l_arc = 0; r_acr = 1
        """
        scrs, uscrs = self.net.evaluate(config, embeds)

        #transition conditions
        right_arc_conditions = config.check_right_arc_conditions()
        shift_conditions = config.check_shift_conditions()
        swap_conditions = config.check_swap_conditions()

        #(avoiding the multiple roots problem: disallow left-arc from root
        #if stack has more than one element
        left_arc_conditions = config.check_not_train_left_arc_conditions()

        s1,r1 = max(zip(scrs[2::2], self.irels))
        s2,r2 = max(zip(scrs[3::2], self.irels))
        s1 = s1 + uscrs[2]
        s2 = s2 + uscrs[3]
        ret = [ [ (r1, LEFT_ARC, s1) ] if left_arc_conditions else [],
                [ (r2, RIGHT_ARC, s2) ] if right_arc_conditions else [],
                [ (None, SHIFT, scrs[0] + uscrs[0]) ] if shift_conditions else [] ,
                [ (None, SWAP, scrs[1] + uscrs[1]) ] if swap_conditions else [] ]
        return ret

    def Predict(self, data, embeds):
        reached_max_swap = 0

        pbar = tqdm.tqdm(
            data,
            desc="Parsing",
            unit="sentences",
            mininterval=1.0,
            leave=False,
            disable=False,
        )

        for iSentence, osentence in enumerate(pbar,1):
            config = Configuration(osentence, self.irels)
            max_swap = 2*len(osentence)
            reached_swap_for_i_sentence = False
            iSwap = 0

            while not config.is_end():
                scores = self.test_evaluate(config, embeds)
                best = max(chain(*(scores if iSwap < max_swap else scores[:3] )), key = itemgetter(2) )
                if iSwap == max_swap and not reached_swap_for_i_sentence:
                    reached_max_swap += 1
                    reached_swap_for_i_sentence = True
                    logging.debug(f"reached max swap in {reached_max_swap:d} out of {iSentence:d} sentences")
                config.apply_transition(best)
                if best[1] == SWAP:
                    iSwap += 1

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [entry for entry in osentence if isinstance(entry, ConllEntry)]
            oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
            conll_sentence = config.get_sentence()
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence

    def calculate_left_scores(self, config, scrs, uscrs):
        left_arc_conditions = config.check_left_arc_conditions()
        if not left_arc_conditions:
            return [], []
        left_cost = config.calculate_left_cost()
        left_scores = [(rel, LEFT_ARC, scrs[2 + j * 2] + uscrs[2]) \
                    for j, rel in enumerate(self.irels)]
        if left_cost == 0:
            left_valid_scores = [(rel, trans, sc) for (rel, trans, sc) in left_scores \
                if rel == config.get_stack_last_element().relation]
            left_wrong_scores = [(rel, trans, sc) for (rel, trans, sc) in left_scores \
                if rel != config.get_stack_last_element().relation]

        else:
            left_valid_scores = []
            left_wrong_scores = left_scores
        return left_valid_scores, left_wrong_scores

    def calculate_right_scores(self, config, scrs, uscrs):
        right_arc_conditions = config.check_right_arc_conditions()
        if not right_arc_conditions:
            return [], []
        right_cost = config.calculate_right_cost()
        right_scores = [ (rel, RIGHT_ARC, scrs[3 + j * 2] + uscrs[3]) \
                    for j, rel in enumerate(self.irels) ]

        if right_cost == 0:
            right_valid_scores = [(rel, trans, sc) for (rel, trans, sc) in right_scores \
                if rel == config.get_stack_last_element().relation]
            right_wrong_scores = [(rel, trans, sc) for (rel, trans, sc) in right_scores \
                if rel != config.get_stack_last_element().relation]
        else:
            right_valid_scores = []
            right_wrong_scores = right_scores

        return right_valid_scores, right_wrong_scores

    def calculate_shift_scores(self, config, scrs, uscrs):
        shift_cost, shift_case = config.calculate_shift_cost()
        shift_conditions = config.check_shift_conditions()
        if not shift_conditions:
             return [], [], shift_case

        shift_scores = [ (None, SHIFT, scrs[0] + uscrs[0]) ]
        
        if shift_cost == 0:
            shift_valid_scores = shift_scores
            shift_wrong_scores = []
        else:
            shift_valid_scores = []
            shift_wrong_scores = shift_scores

        return shift_valid_scores, shift_wrong_scores, shift_case

    def calculate_swap_scores(self, config, scrs, uscrs):
        swap_conditions = config.check_swap_conditions()
        swap_cost = config.calculate_swap_cost()
        if not swap_conditions:
            return [], [], swap_cost
        
        swap_scores = [(None, SWAP, scrs[1] + uscrs[1])]

        if swap_cost == 0:
            swap_valid_scores = swap_scores
            swap_wrong_scores = []
        else:
            swap_valid_scores = []
            swap_wrong_scores = swap_scores

        return swap_valid_scores, swap_wrong_scores, swap_cost

    def error_append(self, info, best, bestValid, bestWrong, config):
        #labeled errors
        if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            child = config.get_stack_last_element() # Last stack element
            if (child.pred_parent_id != child.parent_id or child.pred_relation != child.relation):
                info["lerrors"] += 1
                #attachment error
                if child.pred_parent_id != child.parent_id:
                    info["eerrors"] += 1

        if bestValid[2] < bestWrong[2] + 1.0:
            loss = bestWrong[2] - bestValid[2]
            info["mloss"] += 1.0 + bestWrong[2] - bestValid[2]
            info["eloss"] += 1.0 + bestWrong[2] - bestValid[2]
            info["errs"].append(loss)

        #??? when did this happen and why?
        if best[1] == 0 or best[1] == 2:
            info["etotal"] += 1

    def create_valid_wrong(self, config, scrs, uscrs):
        left_valid, left_wrong = self.calculate_left_scores(config, scrs, uscrs)
        right_valid, right_wrong = self.calculate_right_scores(config, scrs, uscrs)
        shift_valid, shift_wrong, shift_case = self.calculate_shift_scores(config, scrs, uscrs)
        swap_valid, swap_wrong, swap_cost = self.calculate_swap_scores(config, scrs, uscrs)

        valid = chain(left_valid, right_valid, shift_valid, swap_valid)
        wrong = chain(left_wrong, right_wrong, shift_wrong, swap_wrong, [(None, 4, -float('inf'))])
        # (None, 4, -float('inf')) is used to ensure that at least one element will be.
        return valid, wrong, shift_case, swap_cost

    def create_best(self, bestValid, bestWrong, swap_cost):
        #force swap
        if swap_cost == 0:
            best = bestValid
        else:
        #select a transition to follow
        # + aggresive exploration
        #1: might want to experiment with that parameter
            if bestWrong[1] == SWAP:
                best = bestValid
            else:
                best = bestValid if ( (not self.oracle) or (bestValid[2] - bestWrong[2] > 1.0) \
                    or (bestValid[2] > bestWrong[2] and random.random() > 0.1) ) else bestWrong
        return best

    def train_sentence(self, sentence, info, embeds):
            config = Configuration(sentence, self.irels)

            while not config.is_end():
                scrs, uscrs = self.net.evaluate(config, embeds)
                valid, wrong, shift_case, swap_cost = self.create_valid_wrong(config, scrs, uscrs)

                best_valid = max(valid, key=itemgetter(2))
                best_wrong = max(wrong, key=itemgetter(2))
                best = self.create_best(best_valid, best_wrong, swap_cost)

                #updates for the dynamic oracle
                if self.oracle: # TODO: проверить, что значит True/False (где dynamic/static)
                    config.dynamic_oracle_updates(best, shift_case)

                self.error_append(info, best, best_valid, best_wrong, config)
                
                config.apply_transition(best)

            return info
      
    def create_info(self):
        info = {}
        info["mloss"], info["eloss"], info["eerrors"], info["lerrors"], info["etotal"]  = 0.0, 0.0, 0, 0, 0
        info["errs"] = []
        info["iSentence"] = -1
        info["start"] = time.time()
        return info

    def train_logging(self, info):
        loss_message = (
            f'Processing sentence number: {info["iSentence"]}'
            f' Loss: {info["eloss"] / info["etotal"]:.3f}'
            f' Errors: {info["eerrors"] / info["etotal"]:.3f}'
            f' Labeled Errors: {info["lerrors"] / info["etotal"]:.3f}'
            f' Time: {time.time()-info["start"]:.3f}s'
        )
        logging.debug(loss_message)
        info["start"] = time.time() # TODO: зачем этот параметр ?
        info["eerrors"], info["eloss"], info["etotal"], info["lerrors"] = 0, 0.0, 0, 0 # TODO: Почему здесь зануляем?

    def Train(self, trainData, embeds):
        random.shuffle(trainData)
        # in certain cases the data will already have been shuffled after being read from file or while creating dev data
        logging.info(f"Length of training data: {len(trainData)}")

        beg = time.time()
        info = self.create_info()

        pbar = tqdm.tqdm(
            trainData, desc="Training", unit="sentences",
            mininterval=1.0, leave=False, disable=False,
        )

        for iSentence, sentence in enumerate(pbar,1):
            #print("-----------------------------------------------")
            #print("Sentence №", iSentence, sentence[2])
            info["iSentence"] = iSentence
            if iSentence % 100 == 0:
                self.train_logging(info)

            info = self.train_sentence(sentence, info, embeds)

            #footnote 8 in Eli's original paper
            if len(info["errs"]) > 50: # or True:
                self.net.error_processing(info["errs"])
                info["errs"] = []

        if len(info["errs"]) > 0:
            self.net.error_processing(info["errs"])
            info["errs"] = []

        logging.info(f"Loss: {info['mloss']/info['iSentence']}")
        logging.info(f"Total Training Time: {time.time()-beg:.2g}s")


# uuparser/parser.py

In [11]:
import pickle, os, time, sys, copy, itertools, re, random

from shutil import copyfile

def evaluate_uas(sentence_descr):
    #sentence_descr is a list, in which elements 0, 1, 2 are auxiliary
    right_parent_tokens = 0
    for token in sentence_descr[3:]:
        if isinstance(token, ConllEntry): # TODO: изучить случаи, когда не ConllEntry - ошибка считывания?
          if token.pred_parent_id == token.parent_id:
              right_parent_tokens += 1
        #print("pred_parent:", token.pred_parent_id, "real_parent:", token.parent_id)
    uas = right_parent_tokens / (len(sentence_descr) - 3)
    return uas

def evaluate_uas_epoche(sentence_list):
    summ_uas = 0
    for sent in sentence_list:
        summ_uas += evaluate_uas(sent)
    return summ_uas / len(sentence_list)

def run(traindata, valdata, testdata, embeds, options):

    irels = get_irels(traindata)
    logging.debug('Initializing the model')
    parser = Parser(options, irels)

    dev_best = [options["epochs"],-1.0] # best epoch, best score

    for epoch in range(options["first_epoch"], options["epochs"] + 1):
        # Training
        logging.info(f'Starting epoch {epoch} (training)')
        parser.Train(traindata, embeds)
        logging.info(f'Finished epoch {epoch} (training)')

        parser.Save(epoch)

        logging.info(f"Predicting on dev data")
        dev_pred = list(parser.Predict(valdata, embeds))
        mean_dev_score = evaluate_uas_epoche(dev_pred)
        logging.info(f"Dev score {mean_dev_score:.2f} at epoch {epoch:d}")
        print(f"Dev score {mean_dev_score:.2f} at epoch {epoch:d}")

        if mean_dev_score > dev_best[1]:
            dev_best = [epoch,mean_dev_score] # update best dev score

    logging.info(f"Loading best model from epoche{dev_best[0]:d}")
    # Loading best_models to parser.labeled_GNN and parser.unlabeled_GNN
    parser.Load(epoch)

    logging.info(f"Predicting on test data")

    test_pred = list(parser.Predict(testdata, embeds))
    mean_test_score = evaluate_uas_epoche(test_pred)

    logging.info(f"On test obtained UAS score of {mean_test_score:.2f}")
    print(f"On test obtained UAS score of {mean_test_score:.2f}")


    logging.debug('Finished predicting')


In [12]:
options = {}
options["hidden_dims"] = 100 # MLP hidden layer dimensions
options["learning_rate"] = 0.001 # Learning rate for neural network optimizer

options["oracle"] = True # Use the static oracle instead of the dynamic oracle

options["epochs"] = 10 # Number of epochs
options["first_epoch"] = 1

# really important to do this before anything else to make experiments reproducible
set_seeds()

train_dir = 'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu'
val_dir = 'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-dev.conllu'
test_dir = 'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-test.conllu'

train, train_words = read_conll(train_dir)
val, val_words = read_conll(val_dir)
test, test_words = read_conll(test_dir)
all_words = train_words + val_words + test_words
all_words = set(all_words)

In [13]:
from transformers import AutoTokenizer, BertModel
def get_embed(tokenizer, model, word):
    inputs = tokenizer(word, return_tensors="pt")
    outputs = model(**inputs)

    last_hidden_states = outputs.last_hidden_state[0][0]
    return last_hidden_states.detach().cpu()

embeds = {}
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

logging.debug('Creating embeddings')
ts = time.time()
for word in all_words:
    embeds[word] = get_embed(tokenizer, model, word)
logging.debug(f'{len(embeds)} embeddings were created')
te = time.time()
logging.info(f'Time of embedding creation: {te-ts:.2g}s')

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [14]:
len(train)

5

In [15]:
print(Path('/content/app.log').read_text())

root - DEBUG - New debug
root - INFO - New info
jaxlib.mlir._mlir_libs - DEBUG - Initializing MLIR with module: _site_initialize_0
jaxlib.mlir._mlir_libs - DEBUG - Registering dialects from initializer <module 'jaxlib.mlir._mlir_libs._site_initialize_0' from '/usr/local/lib/python3.10/dist-packages/jaxlib/mlir/_mlir_libs/_site_initialize_0.so'>
jax._src.xla_bridge - DEBUG - No jax_plugins namespace packages available
jax._src.path - DEBUG - etils.epath found. Using etils.epath for file I/O.
root - DEBUG - Using default Python seed
root - INFO - Reading sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu
root - DEBUG - 5 sentences read
root - INFO - Time: 0.0007s
root - INFO - Reading sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-dev.conllu
root - DEBUG - 25 sentences read
root - INFO - Time: 0.019s
root - INFO - Reading sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-test.conllu
root - DEBUG - 27 sentences read
root - INFO - Time: 0.012s
root - DEBUG - Creating embeddings

In [16]:
run(train, val, test, embeds, options)



Dev score 0.26 at epoch 1




Dev score 0.26 at epoch 2




Dev score 0.26 at epoch 3




Dev score 0.26 at epoch 4




Dev score 0.26 at epoch 5




Dev score 0.26 at epoch 6




Dev score 0.26 at epoch 7




Dev score 0.26 at epoch 8




Dev score 0.26 at epoch 9




Dev score 0.26 at epoch 10


                                                               

On test obtained UAS score of 0.30




In [None]:
r

{'stack_ids': [],
 's1': [],
 's0': [],
 'b': [<__main__.ConllEntry at 0x7cf5c478f280>],
 'beta': [<__main__.ConllEntry at 0x7cf5c478f220>,
  <__main__.ConllEntry at 0x7cf5c478ef20>,
  <__main__.ConllEntry at 0x7cf5c478ce20>,
  <__main__.ConllEntry at 0x7cf5c478f3a0>,
  <__main__.ConllEntry at 0x7cf5c478f190>,
  <__main__.ConllEntry at 0x7cf5c478f730>,
  <__main__.ConllEntry at 0x7cf5c478f340>,
  <__main__.ConllEntry at 0x7cf5c478f760>,
  <__main__.ConllEntry at 0x7cf5c478e3e0>,
  <__main__.ConllEntry at 0x7cf5c478d480>,
  <__main__.ConllEntry at 0x7cf5c478f3d0>,
  <__main__.ConllEntry at 0x7cf5c478f7c0>,
  <__main__.ConllEntry at 0x7cf5c478ee60>,
  <__main__.ConllEntry at 0x7cf5c478dde0>,
  <__main__.ConllEntry at 0x7cf5c478dc00>,
  <__main__.ConllEntry at 0x7cf5c478e800>,
  <__main__.ConllEntry at 0x7cf5c478cf70>,
  <__main__.ConllEntry at 0x7cf5c478efe0>,
  <__main__.ConllEntry at 0x7cf5c478f0d0>,
  <__main__.ConllEntry at 0x7cf5c478f400>,
  <__main__.ConllEntry at 0x7cf5c478d840>,


# TODO

TODO:
В sentence последний элемент -

{'id': 0,
 'form': '*root*',
 'char_rep': '*root*',
 'norm': '*root*',
 'cpos': 'ROOT-CPOS',
 'pos': 'ROOT-POS',
 'parent_id': -1,
 'relation': 'rroot',
 'lemma': '*root*',
 'feats': '_',
 'deps': '_',
 'misc': '_',
 'pred_parent_id': None,
 'pred_relation': None,
 'treebank_id': None,
 'proxy_tbank': None,
 'pred_pos': None,
 'pred_cpos': None,
 'projective_order': 0,
 'rdeps': [8],
 'children': [],
 'scores': None,
 'parent': None,
 'vecs': None}


В какую сторону стек в коде сейчас ?
Используют stack[-1], stack[-2].
Стек или очередь ?

Разобраться, какие метрики считают при обучении (на train)