<a href="https://colab.research.google.com/github/Derinhelm/graph_syntax_parsing/blob/main/Parser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installation

In [None]:
!pip install -q transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m48.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
!pip install torch_geometric

# Clear this cell's output area (the pip installation log printed above).
from IPython.display import clear_output

clear_output()

# Logging

In [None]:

import logging

# Named logger used by the parser code in this notebook.
logger = logging.getLogger('my_logger')

# logging.basicConfig() does nothing while the root logger still has handlers,
# so detach whatever is attached before configuring.
while logging.root.handlers:
    logging.root.removeHandler(logging.root.handlers[0])

# Append records to app.log formatted as "<logger name> - <level> - <message>".
# No level is passed, so the root logger keeps its default threshold.
logging.basicConfig(
    format='%(name)s - %(levelname)s - %(message)s',
    filemode='a',
    filename='app.log',
)

logging.warning('This will get logged to a file')

In [None]:
# Emit a second record through the root logger configured above.
logging.warning('New warning')

In [None]:
# Show what has been written to the log file so far.
# NOTE(review): logging writes to the relative path 'app.log'; this absolute path
# presumably relies on the working directory being /content — confirm.
from pathlib import Path
print(Path('/content/app.log').read_text())

# uuparser/utils.py

In [None]:
from collections import defaultdict, Counter
import re
import os,time
from operator import itemgetter
import random
import json
import pathlib
import subprocess
import sys

import tqdm


# Directory holding the auxiliary "utils" resources, resolved next to this file.
# ``__file__`` is not defined when the code runs inside a notebook kernel, so
# fall back to the current working directory instead of raising NameError.
try:
    _BASE_DIR = pathlib.Path(__file__).parent
except NameError:
    _BASE_DIR = pathlib.Path.cwd()

UTILS_PATH = _BASE_DIR/"utils"


class ConllEntry:
    """One token row of a CoNLL-U sentence with its gold and predicted annotations."""

    def __init__(self, id, form, lemma, pos, cpos, feats=None, parent_id=None, relation=None,
        deps=None, misc=None, treebank_id=None, proxy_tbank=None, char_rep=None):
        # Surface columns.
        self.id = id
        self.form = form
        self.lemma = lemma
        # Character-level representation defaults to the surface form.
        self.char_rep = char_rep or form
        self.norm = normalize(self.char_rep)

        # Gold annotation columns.
        self.pos = pos
        self.cpos = cpos
        self.feats = feats
        self.parent_id = parent_id
        self.relation = relation
        self.deps = deps
        self.misc = misc

        # Treebank bookkeeping, stored as given.
        self.treebank_id = treebank_id
        self.proxy_tbank = proxy_tbank

        # Predictions, filled in later by the parser.
        self.pred_parent_id = None
        self.pred_relation = None
        self.pred_pos = None
        self.pred_cpos = None

    def __str__(self):
        """Render the entry as a tab-separated row, preferring predicted values over gold ones."""
        # Head and relation fall back to gold only when the prediction is None
        # (a predicted head of 0 is a valid value and must be kept).
        head = self.parent_id if self.pred_parent_id is None else self.pred_parent_id
        deprel = self.relation if self.pred_relation is None else self.pred_relation
        columns = (
            str(self.id),
            self.form,
            self.lemma,
            self.pred_cpos or self.cpos,
            self.pred_pos or self.pos,
            self.feats,
            str(head),
            deprel,
            self.deps,
            self.misc,
        )
        # Missing columns are written as the underscore placeholder.
        return '\t'.join('_' if column is None else column for column in columns)

class ParseForest:
    """Ordered list of subtree roots (used both as parser stack and as buffer)."""

    def __init__(self, sentence):
        self.roots = list(sentence)

        # Every entry starts out as a bare, unattached root: any previous
        # parsing state on it (including predictions) is wiped.
        for node in self.roots:
            node.children = []
            for attribute in ('scores', 'parent', 'pred_parent_id', 'pred_relation', 'vecs'):
                setattr(node, attribute, None)

    def __len__(self):
        return len(self.roots)

    def Attach(self, parent_index, child_index):
        """Record the root at ``parent_index`` as predicted head of the root at ``child_index`` and drop the child from the forest."""
        head = self.roots[parent_index]
        dependent = self.roots.pop(child_index)
        dependent.pred_parent_id = head.id


def isProj(sentence):
    """Return True when the gold tree of ``sentence`` is projective.

    The check repeatedly attaches an adjacent pair of forest roots whenever one
    is the gold head of the other and the dependent has already collected all
    of its own dependents; the tree is projective when a single root remains.

    Note: the entries are wrapped in a ``ParseForest``, which resets their
    parsing attributes, and each attachment sets ``pred_parent_id`` on the child.
    """
    forest = ParseForest(sentence)
    # Number of dependents each entry still has to collect before it may be attached.
    unassigned = {entry.id: sum(1 for other in sentence if other.parent_id == entry.id) for entry in sentence}

    # ``range`` replaces the Python 2-only ``xrange``, which raises NameError on Python 3.
    for _ in range(len(sentence)):
        for i in range(len(forest.roots) - 1):
            left, right = forest.roots[i], forest.roots[i + 1]
            if left.parent_id == right.id and unassigned[left.id] == 0:
                unassigned[right.id] -= 1
                forest.Attach(i + 1, i)
                break
            if right.parent_id == left.id and unassigned[right.id] == 0:
                unassigned[left.id] -= 1
                forest.Attach(i, i + 1)
                break

    return len(forest.roots) == 1


def get_irels(data):
    """
    Return the distinct dependency relations found in ``data``, in order of first appearance.

    Only ``ConllEntry`` items are inspected; other items in a sentence are skipped.
    """

    # A Counter (rather than a set) keeps the frequencies at hand for debugging
    # and preserves first-seen order.
    relation_counts = Counter(
        node.relation
        for sentence in data
        for node in sentence
        if isinstance(node, ConllEntry)
    )
    return list(relation_counts.keys())


def generate_root_token(treebank_id, proxy_tbank):
    """Build the artificial root entry (id 0, head -1) that is placed in front of every sentence."""
    return ConllEntry(
        id=0,
        form='*root*',
        lemma='*root*',
        pos='ROOT-POS',
        cpos='ROOT-CPOS',
        feats='_',
        parent_id=-1,
        relation='rroot',
        deps='_',
        misc='_',
        treebank_id=treebank_id,
        proxy_tbank=proxy_tbank,
    )


def _prepare_training_annotations(conll_tokens):
    """Attach the bookkeeping attributes the training oracle reads to every entry.

    Sets ``projective_order`` (position in the traversal returned by ``inorder``),
    ``rdeps`` (ids of the entry's dependents) and, for every entry whose id is
    not 0, ``parent_entry`` (the entry whose id equals its ``parent_id``).
    """
    for order, entry in enumerate(inorder(conll_tokens)):
        entry.projective_order = order
    for entry in conll_tokens:
        entry.rdeps = [other.id for other in conll_tokens if other.parent_id == entry.id]
        if entry.id != 0:
            entry.parent_entry = [other for other in conll_tokens if other.id == entry.parent_id][0]


def read_conll(filename, treebank_id=None, proxy_tbank=None, maxSize=-1, hard_lim=False, vocab_prep=False, drop_nproj=False, train=True):
    """Yield the sentences of a CoNLL-U file as lists of tokens.

    Each yielded list starts with an artificial root entry, followed by the
    sentence's lines in file order: ``ConllEntry`` objects for regular tokens
    and raw strings for comment lines and for rows whose id contains ``-`` or
    ``.``.

    Args:
        filename: path of the file to read (UTF-8).
        treebank_id, proxy_tbank: stored on every created entry.
        maxSize: maximum number of sentences to return; values <= 0 mean no cap.
        hard_lim: with a positive ``maxSize``, stop after the first ``maxSize``
            kept sentences (hard cap). Otherwise (soft cap) the whole file is
            read and a random sample of ``maxSize`` sentences is yielded.
        vocab_prep: when set together with a soft cap, the cap is ignored so the
            whole corpus is returned.
        drop_nproj: skip sentences for which ``isProj`` returns False.
        train: also attach the oracle bookkeeping attributes to the entries.
    """
    # Local import so this function does not depend on imports made elsewhere in the file.
    from itertools import chain

    # hard lim means capping the corpus size across the whole training procedure
    # soft lim means using a sample of the whole corpus at each epoch
    if vocab_prep and not hard_lim:
        maxSize = -1 # when preparing the vocab with a soft limit we need to use the whole corpus
    dropped = 0
    sents_read = 0
    yield_count = 0
    sents = [] # only filled when sampling under a soft limit
    tokens = [generate_root_token(treebank_id, proxy_tbank)]

    # The context manager closes the file even when the consumer stops early.
    with open(filename, 'r', encoding='utf-8') as fh:
        logger.info(f"Reading {filename}")
        ts = time.time()
        # The trailing empty "line" flushes a final sentence that is not followed
        # by a blank line (e.g. a file without a newline at its end).
        for line in chain(fh, ('',)):
            stripped = line.strip()
            if stripped == '': # sentence boundary: add sentence to list or yield
                if len(tokens) > 1:
                    sents_read += 1
                    conll_tokens = [t for t in tokens if isinstance(t, ConllEntry)]
                    if not drop_nproj or isProj(conll_tokens): # keep going if it's projective or we're not dropping non-projective sents
                        if train:
                            _prepare_training_annotations(conll_tokens)
                        if maxSize > 0:
                            if not hard_lim:
                                sents.append(tokens)
                            else:
                                yield tokens
                                yield_count += 1
                                if yield_count == maxSize:
                                    logger.info(f"Capping size of corpus at {yield_count} sentences")
                                    break
                        else:
                            yield tokens
                    else:
                        logger.debug('Non-projective sentence dropped')
                        dropped += 1
                tokens = [generate_root_token(treebank_id, proxy_tbank)]
                continue

            tok = stripped.split('\t')
            if line[0] == '#' or '-' in tok[0] or '.' in tok[0]: # a comment or multi-word/empty-node line, add to tokens as is
                tokens.append(stripped)
            else: # an actual ConllEntry, add to tokens
                char_rep = tok[1] # representation to use in character model
                if tok[2] == "_":
                    tok[2] = tok[1].lower()
                token = ConllEntry(int(tok[0]), tok[1], tok[2], tok[4], tok[3], tok[5], int(tok[6]) if tok[6] != '_' else -1, tok[7], tok[8], tok[9],treebank_id=treebank_id,proxy_tbank=proxy_tbank,char_rep=char_rep)

                tokens.append(token)

    if hard_lim and yield_count < maxSize:
        logger.warning(f'Unable to yield {maxSize} sentences, only {yield_count} found')

    logger.debug(f'{sents_read} sentences read')

    if maxSize > 0 and not hard_lim:
        if len(sents) > maxSize:
            sents = random.sample(sents,maxSize)
            logger.debug(f"Yielding {len(sents)} random sentences")
        for toks in sents:
            yield toks

    te = time.time()
    logger.info(f'Time: {te-ts:.2g}s')


def write_conll(fn, conll_gen):
    """Write the sentences produced by ``conll_gen`` to ``fn``.

    The first item of every sentence (the artificial root) is skipped; each
    remaining item is written on its own line and sentences are separated by a
    blank line.
    """
    logger.info(f"Writing to {fn}")
    written = 0
    with open(fn, 'w', encoding='utf-8') as out:
        for written, sentence in enumerate(conll_gen, 1):
            out.writelines(str(entry) + '\n' for entry in sentence[1:])
            out.write('\n')
        logger.debug(f"Wrote {written} sentences")


# Pattern for numeric tokens. It is applied with ``match``, i.e. anchored only at
# the start of the word.
numberRegex = re.compile("[0-9]+|[0-9]+\\.[0-9]+|[0-9]+[0-9,]+")


def normalize(word):
    """Map any word that starts with a digit to 'NUM'; lowercase everything else."""
    if numberRegex.match(word) is None:
        return word.lower()
    return 'NUM'


def inorder(sentence):
    """Return the entries of ``sentence`` in in-order traversal of the gold tree.

    Starting from the id of the first entry, each node is emitted after the
    subtrees of its dependents found before its position and before the
    subtrees of the dependents found from its position onwards. Ids are used
    directly as positions in ``sentence``.
    """
    def visit(index):
        ordered = []
        for entry in sentence[:index]:
            if entry.parent_id == index:
                ordered.extend(visit(entry.id))
        ordered.append(sentence[index])
        for entry in sentence[index:]:
            if entry.parent_id == index:
                ordered.extend(visit(entry.id))
        return ordered

    return visit(sentence[0].id)


def set_seeds():
    """Seed Python's ``random`` module with the fixed default seed (1)."""
    logger.debug("Using default Python seed")
    random.seed(1)


def generate_seed():
    """Draw a random integer seed between 0 and 10**9 (both inclusive)."""
    upper_bound = 10**9
    return random.randint(0, upper_bound)


# uuparser/multilayer_perceptron.py


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import SAGEConv, to_hetero
import torch

class GNN(torch.nn.Module):
    """Two stacked SAGEConv layers with a ReLU in between.

    Both layers are created with ``(-1, -1)`` input sizes; only the hidden and
    output widths are fixed here.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)




# uuparser/arc_hybrid.py

In [None]:
from operator import itemgetter
from itertools import chain
import time, random
import numpy as np
from copy import deepcopy
from collections import defaultdict
import json

import torch
import torch.nn as nn
import torch.optim as optim

import tqdm


from jax.numpy import int32
from torch_geometric.data import HeteroData

from transformers import AutoTokenizer, BertModel
import torch

In [None]:
def get_embed(tokenizer, model, word):
    """Return a detached CPU vector for ``word``.

    The word is tokenized on its own, passed through ``model``, and the hidden
    state at the first position of the first sequence is returned
    (presumably the [CLS] position for a BERT tokenizer — confirm).
    """
    encoded = tokenizer(word, return_tensors="pt")
    hidden_states = model(**encoded).last_hidden_state
    first_position = hidden_states[0][0]
    return first_position.detach().cpu()


# Tokenizer/model pairs that have already been loaded, keyed by model name.
_BERT_CACHE = {}


def _load_bert(model_name="bert-base-uncased"):
    """Return the (tokenizer, model) pair for ``model_name``, loading it only once."""
    if model_name not in _BERT_CACHE:
        _BERT_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            BertModel.from_pretrained(model_name),
        )
    return _BERT_CACHE[model_name]


def get_embed_for_sentence(sentence):
    """Return a ``(len(sentence), hidden_size)`` tensor with one BERT vector per entry.

    Each entry's ``form`` is embedded independently with ``get_embed``. The
    tokenizer and model are cached across calls: this function runs for every
    parser configuration, and reloading the pretrained weights each time
    dominated the running time.
    """
    tokenizer, model = _load_bert()
    # Take the width from the model instead of hard-coding it (768 for bert-base-uncased).
    word_embeds = torch.empty((len(sentence), model.config.hidden_size))
    for i, entry in enumerate(sentence):
        word_embeds[i] = get_embed(tokenizer, model, entry.form)
    return word_embeds

def create_stack_edges(stack):
    """Return a ``(2, E)`` edge tensor linking consecutive stack entries.

    Node indices are ``entry.id - 1``. An empty stack gives an empty int32
    tensor; a single entry gives one self-loop (temporary solution).
    NOTE(review): an entry with id 0 (the root token) maps to index -1 —
    confirm that the consumers of these edges accept that.
    """
    size = len(stack)
    if size == 0:
        no_edges = torch.tensor([], dtype=torch.int32)
        return torch.stack((no_edges, no_edges), dim=0)

    node_ids = [entry.id - 1 for entry in stack]
    if size == 1:
        sources, targets = node_ids, node_ids # temporary solution: self-loop
    else:
        # Every two consecutive stack nodes form an edge.
        sources, targets = node_ids[:-1], node_ids[1:]
    return torch.stack([torch.tensor(tuple(sources)), torch.tensor(tuple(targets))], dim=0)

def create_buffer_edges(buffer):
    """Return a ``(2, E)`` edge tensor linking consecutive buffer entries.

    Node indices are ``entry.id - 1``. An empty buffer gives an empty int32
    tensor; a single entry gives one self-loop (temporary solution).
    NOTE(review): an entry with id 0 (the root token) maps to index -1 —
    confirm that the consumers of these edges accept that.
    """
    if len(buffer) == 0:
        return torch.stack((torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)), dim=0)

    if len(buffer) == 1:
        only = buffer[0].id - 1
        pairs = [(only, only)] # temporary solution: self-loop
    else:
        # Every two consecutive buffer nodes form an edge.
        pairs = [(left.id - 1, right.id - 1) for left, right in zip(buffer, buffer[1:])]

    sources, targets = zip(*pairs)
    return torch.stack([torch.tensor(sources), torch.tensor(targets)], dim=0)

def create_graph_edges(sentence):
    """Return a ``(2, E)`` int64 tensor of head -> dependent edges for ``sentence``.

    Node indices are ``id - 1``. Only entries whose ``parent_id`` is a positive
    id contribute an edge: entries without a head (``None``), entries attached
    to the root (``0``) and the root entry itself (``-1``) are skipped, so no
    negative source index is produced. When nothing qualifies — an empty
    sentence or, for example, a one-word sentence — an empty ``(2, 0)`` tensor
    is returned instead of raising IndexError.

    NOTE(review): the edges come from the gold ``parent_id``, not from
    ``pred_parent_id`` — confirm this is intended at prediction time.
    """
    heads, dependents = [], []
    for node in sentence:
        head = node.parent_id
        if head is not None and head > 0:
            heads.append(head - 1)
            dependents.append(node.id - 1)
    # An explicit dtype keeps the empty and non-empty results consistent.
    return torch.tensor([heads, dependents], dtype=torch.long)

def config_to_graph(sentence, stack, buffer):
    """Encode a parser configuration as a ``HeteroData`` graph.

    The single node type ``'node'`` carries one embedding per entry of
    ``sentence``; three edge types describe the tree (``graph``), the stack
    order (``stack``) and the buffer order (``buffer``).
    """
    word_embeds = get_embed_for_sentence(sentence)

    data = HeteroData()
    data['node']['x'] = word_embeds

    data[('node', 'graph', 'node')].edge_index = create_graph_edges(sentence)
    # This runs for every transition, so the edge dump goes to the debug log
    # rather than being printed to stdout on each call.
    logger.debug("graph edges: %s", data[('node', 'graph', 'node')].edge_index)
    data[('node', 'stack', 'node')].edge_index = create_stack_edges(stack)
    data[('node', 'buffer', 'node')].edge_index = create_buffer_edges(buffer)
    return data

In [None]:

class ArcHybridLSTM:
    """Transition-based parser (SHIFT / SWAP / LEFT-ARC / RIGHT-ARC) scored by two GNNs.

    ``unlabeled_GNN`` has 4 outputs (one per transition type) and
    ``labeled_GNN`` has ``2 * len(irels) + 2`` outputs (SHIFT, SWAP, then a
    LEFT-ARC / RIGHT-ARC pair per relation). Both are ``GNN`` instances converted
    with ``to_hetero`` over the ``graph`` / ``stack`` / ``buffer`` edge types.
    """

    def __init__(self, irels, options):
        """Build both networks and their Adam optimizers.

        Args:
            irels: list of dependency relation labels.
            options: dict with the keys ``activation``, ``mlp_hidden_dims``,
                ``learning_rate``, ``oracle``, ``headFlag``, ``rlMostFlag``,
                ``rlFlag`` and ``k``.
        """
        # The transition ids are published as module-level globals because the
        # methods below refer to them by bare name.
        global LEFT_ARC, RIGHT_ARC, SHIFT, SWAP
        LEFT_ARC, RIGHT_ARC, SHIFT, SWAP = 0,1,2,3

        self.irels = irels

        self.activation = options["activation"]

        self.hidden_dims = options["mlp_hidden_dims"]

        self.mlp_in_dims = 30 # TODO: Create a logical value.

        self.metadata = (['node'], [('node', 'graph', 'node'), ('node', 'stack', 'node'), ('node', 'buffer', 'node')])
        self.unlabeled_GNN = GNN(hidden_channels=self.hidden_dims, out_channels=4)
        self.unlabeled_GNN = to_hetero(self.unlabeled_GNN, self.metadata, aggr='sum')

        self.labeled_GNN = GNN(hidden_channels=self.hidden_dims, out_channels=2*len(self.irels)+2)
        self.labeled_GNN = to_hetero(self.labeled_GNN, self.metadata, aggr='sum')

        self.unlabeled_optimizer = optim.Adam(self.unlabeled_GNN.parameters(), lr=options["learning_rate"])
        self.labeled_optimizer = optim.Adam(self.labeled_GNN.parameters(), lr=options["learning_rate"])

        self.oracle = options["oracle"]

        self.headFlag = options["headFlag"]
        self.rlMostFlag = options["rlMostFlag"]
        self.rlFlag = options["rlFlag"]
        self.k = options["k"]

    def __evaluate(self, stack, buf, sentence, train):
        """Score the transitions that are allowed in the current configuration.

        Returns a list with four sub-lists, indexed by transition id
        (LEFT_ARC, RIGHT_ARC, SHIFT, SWAP); a sub-list is empty when the
        transition is not allowed. In training mode every candidate is a tuple
        ``(rel, transition, score, score_for_loss)`` with one LEFT-ARC and one
        RIGHT-ARC candidate per relation; otherwise it is
        ``(rel, transition, score)`` with only the best relation per arc type.
        ``rel`` is None for SHIFT and SWAP.
        """

        graph = config_to_graph(sentence, stack.roots, buf.roots)
        output = self.unlabeled_GNN(graph.x_dict, graph.edge_index_dict)
        routput = self.labeled_GNN(graph.x_dict, graph.edge_index_dict)
        # NOTE(review): per the PyG documentation a to_hetero-converted model
        # returns a dict keyed by node type, so the integer indexing below
        # presumably needs ``output['node']`` plus some pooling over the
        # nodes — confirm before relying on these scores.

        #scores, unlabeled scores
        scrs, uscrs = routput, output

        #transition conditions
        left_arc_conditions = len(stack) > 0
        right_arc_conditions = len(stack) > 1
        shift_conditions = buf.roots[0].id != 0
        swap_conditions = len(stack) > 0 and stack.roots[-1].id < buf.roots[0].id

        if not train:
            #(avoiding the multiple roots problem: disallow left-arc from root
            #if stack has more than one element
            left_arc_conditions = left_arc_conditions and not (buf.roots[0].id == 0 and len(stack) > 1)

        uscrs0, uscrs1, uscrs2, uscrs3 = uscrs[0], uscrs[1], uscrs[2], uscrs[3]

        if train:
            output0, output1, output2, output3 = output[0], output[1], output[2], output[3]


            ret = [ [ (rel, LEFT_ARC, scrs[2 + j * 2] + uscrs2, routput[2 + j * 2 ] + output2) for j, rel in enumerate(self.irels) ] if left_arc_conditions else [],
                   [ (rel, RIGHT_ARC, scrs[3 + j * 2] + uscrs3, routput[3 + j * 2 ] + output3) for j, rel in enumerate(self.irels) ] if right_arc_conditions else [],
                   [ (None, SHIFT, scrs[0] + uscrs0, routput[0] + output0) ] if shift_conditions else [] ,
                    [ (None, SWAP, scrs[1] + uscrs1, routput[1] + output1) ] if swap_conditions else [] ]
        else:
            s1,r1 = max(zip(scrs[2::2],self.irels))
            s2,r2 = max(zip(scrs[3::2],self.irels))
            s1 = s1 + uscrs2
            s2 = s2 + uscrs3
            ret = [ [ (r1, LEFT_ARC, s1) ] if left_arc_conditions else [],
                   [ (r2, RIGHT_ARC, s2) ] if right_arc_conditions else [],
                   [ (None, SHIFT, scrs[0] + uscrs0) ] if shift_conditions else [] ,
                    [ (None, SWAP, scrs[1] + uscrs1) ] if swap_conditions else [] ]
        return ret

    def Load(self, epoch):
        """Replace both networks with the weights saved by ``Save(epoch)``.

        NOTE(review): the optimizers created in ``__init__`` keep referring to
        the parameters of the previous networks; recreate them if training is
        to continue after loading.
        """
        unlab_path = 'model_unlab' + '_' + str(epoch)
        lab_path = 'model_lab' + '_' + str(epoch)

        # Save() stores the state dicts of the to_hetero-converted models, so
        # both networks must be converted *before* their weights are restored
        # (previously the unlabeled one was loaded into the unconverted GNN).
        self.unlabeled_GNN = GNN(hidden_channels=self.hidden_dims, out_channels=4)
        self.unlabeled_GNN = to_hetero(self.unlabeled_GNN, self.metadata, aggr='sum')
        self.labeled_GNN = GNN(hidden_channels=self.hidden_dims, out_channels=2*len(self.irels)+2)
        self.labeled_GNN = to_hetero(self.labeled_GNN, self.metadata, aggr='sum')

        unlab_checkpoint = torch.load(unlab_path)
        self.unlabeled_GNN.load_state_dict(unlab_checkpoint['model_state_dict'])

        lab_checkpoint = torch.load(lab_path)
        self.labeled_GNN.load_state_dict(lab_checkpoint['model_state_dict'])


    def Save(self, epoch):
        """Write both networks' state dicts to ``model_unlab_<epoch>`` and ``model_lab_<epoch>``."""
        unlab_path = 'model_unlab' + '_' + str(epoch)
        lab_path = 'model_lab' + '_' + str(epoch)
        logger.info(f'Saving unlabeled model to {unlab_path}')
        torch.save({'epoch': epoch, 'model_state_dict': self.unlabeled_GNN.state_dict()}, unlab_path)
        logger.info(f'Saving labeled model to {lab_path}')
        torch.save({'epoch': epoch, 'model_state_dict': self.labeled_GNN.state_dict()}, lab_path)

    def apply_transition(self,best,stack,buf,hoffset):
        """Apply the transition ``best`` (``best[1]`` is its id, ``best[0]`` its relation) to ``stack`` and ``buf``.

        ``hoffset`` is accepted for interface compatibility and is not used.
        """
        if best[1] == SHIFT:
            stack.roots.append(buf.roots[0])
            del buf.roots[0]

        elif best[1] == SWAP:
            # Move the stack top back into the buffer, behind the buffer front.
            child = stack.roots.pop()
            buf.roots.insert(1,child)

        elif best[1] == LEFT_ARC:
            # The stack top becomes a dependent of the buffer front.
            child = stack.roots.pop()
            parent = buf.roots[0]

        elif best[1] == RIGHT_ARC:
            # The stack top becomes a dependent of the element below it.
            child = stack.roots.pop()
            parent = stack.roots[-1]

        if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            #attach
            child.pred_parent_id = parent.id
            child.pred_relation = best[0]

    def calculate_cost(self,scores,s0,s1,b,beta,stack_ids):
        """Compute the oracle cost of every transition in the current configuration.

        ``s0``, ``s1`` and ``b`` are lists holding at most one entry (stack top,
        second stack element, buffer front); ``beta`` is the rest of the buffer.
        Returns ``(costs, shift_case)`` where ``costs`` is indexed by transition
        id and ends with a constant 1 for the extra sentinel transition (id 4)
        appended during training. A transition with no candidates in ``scores``
        always costs 1.
        """
        if len(scores[LEFT_ARC]) == 0:
            left_cost = 1
        else:
            left_cost = len(s0[0].rdeps) + int(s0[0].parent_id != b[0].id and s0[0].id in s0[0].parent_entry.rdeps)


        if len(scores[RIGHT_ARC]) == 0:
            right_cost = 1
        else:
            right_cost = len(s0[0].rdeps) + int(s0[0].parent_id != s1[0].id and s0[0].id in s0[0].parent_entry.rdeps)


        if len(scores[SHIFT]) == 0:
            shift_cost = 1
            shift_case = 0
        elif len([item for item in beta if item.projective_order < b[0].projective_order and item.id > b[0].id ])> 0:
            shift_cost = 0
            shift_case = 1
        else:
            shift_cost = len([d for d in b[0].rdeps if d in stack_ids]) + int(len(s0)>0 and b[0].parent_id in stack_ids[:-1] and b[0].id in b[0].parent_entry.rdeps)
            shift_case = 2


        if len(scores[SWAP]) == 0 :
            swap_cost = 1
        elif s0[0].projective_order > b[0].projective_order:
            swap_cost = 0
            #disable all the others
            left_cost = right_cost = shift_cost = 1
        else:
            swap_cost = 1

        costs = (left_cost, right_cost, shift_cost, swap_cost,1)
        return costs,shift_case


    def oracle_updates(self,best,b,s0,stack_ids,shift_case):
        """Update the ``rdeps`` bookkeeping after the transition ``best`` has been chosen."""
        if best[1] == SHIFT:
            if shift_case == 2:
                if b[0].parent_entry.id in stack_ids[:-1] and b[0].id in b[0].parent_entry.rdeps:
                    b[0].parent_entry.rdeps.remove(b[0].id)
                blocked_deps = [d for d in b[0].rdeps if d in stack_ids]
                for d in blocked_deps:
                    b[0].rdeps.remove(d)

        elif best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            s0[0].rdeps = []
            if s0[0].id in s0[0].parent_entry.rdeps:
                s0[0].parent_entry.rdeps.remove(s0[0].id)

    def Predict(self, data, datasplit, options):
        """Parse every sentence of ``data`` and yield it with ``pred_parent_id`` / ``pred_relation`` filled in.

        Parsing works on a deep copy; the predictions are then copied back onto
        the original sentence, which is what gets yielded. At most
        ``2 * len(sentence)`` SWAP transitions are allowed per sentence.
        ``datasplit`` and ``options`` are not used.
        """
        reached_max_swap = 0

        pbar = tqdm.tqdm(
            data,
            desc="Parsing",
            unit="sentences",
            mininterval=1.0,
            leave=False,
            disable=False,
        )

        for iSentence, osentence in enumerate(pbar,1):
            sentence = deepcopy(osentence)
            reached_swap_for_i_sentence = False
            max_swap = 2*len(sentence)
            iSwap = 0
            #self.feature_extractor.Init(options)
            conll_sentence = [entry for entry in sentence if isinstance(entry, ConllEntry)]
            # Move the root entry from the front to the end of the sentence.
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
            stack = ParseForest([])
            buf = ParseForest(conll_sentence)

            hoffset = 1 if self.headFlag else 0

            for root in conll_sentence:
                root.relation = root.relation if root.relation in self.irels else 'runk'


            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, conll_sentence, False)
                # Once the swap budget is used up, SWAP candidates are ignored.
                best = max(chain(*(scores if iSwap < max_swap else scores[:3] )), key = itemgetter(2) )
                if iSwap == max_swap and not reached_swap_for_i_sentence:
                    reached_max_swap += 1
                    reached_swap_for_i_sentence = True
                    logger.debug(f"reached max swap in {reached_max_swap:d} out of {iSentence:d} sentences")
                self.apply_transition(best,stack,buf,hoffset)
                if best[1] == SWAP:
                    iSwap += 1

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [entry for entry in osentence if isinstance(entry, ConllEntry)]
            oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence

    def cost_computing(self, stack, buf, scores, info):
        """Pick the transition to follow for one training step and record its loss.

        Returns ``(best, info)``: the chosen candidate and the updated statistics
        dict (``mloss``, ``eloss``, ``eerrors``, ``lerrors``, ``etotal``, ``errs``).
        """
        stack_ids = [sitem.id for sitem in stack.roots]

        s1 = [stack.roots[-2]] if len(stack) > 1 else []
        s0 = [stack.roots[-1]] if len(stack) > 0 else []
        b = [buf.roots[0]] if len(buf) > 0 else []
        beta = buf.roots[1:] if len(buf) > 1 else []

        costs, shift_case = self.calculate_cost(scores,s0,s1,b,beta,stack_ids)

        # Valid: zero-cost candidates (arc candidates must also carry the gold relation).
        bestValid = list(( s for s in chain(*scores) if costs[s[1]] == 0 and ( s[1] == SHIFT or s[1] == SWAP or  s[0] == s0[0].relation ) ))

        bestValid = max(bestValid, key=itemgetter(2))
        bestWrong = max(( s for s in chain(*scores) if costs[s[1]] != 0 or ( s[1] != SHIFT and s[1] != SWAP and s[0] != s0[0].relation ) ), key=itemgetter(2))
        #force swap
        if costs[SWAP]== 0:
            best = bestValid
        else:
        #select a transition to follow
        # + aggresive exploration
        #1: might want to experiment with that parameter
            if bestWrong[1] == SWAP:
                best = bestValid
            else:
                best = bestValid if ( (not self.oracle) or (bestValid[2] - bestWrong[2] > 1.0) or (bestValid[2] > bestWrong[2] and random.random() > 0.1) ) else bestWrong

        #updates for the dynamic oracle
        if self.oracle:
            self.oracle_updates(best,b,s0,stack_ids,shift_case)

        #labeled errors
        if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            child = s0[0]
            if (child.pred_parent_id != child.parent_id or child.pred_relation != child.relation):
                info["lerrors"] += 1
                #attachment error
                if child.pred_parent_id != child.parent_id:
                    info["eerrors"] += 1

        # Margin loss: the best valid candidate should beat the best wrong one by at least 1.0.
        if bestValid[2] < bestWrong[2] + 1.0:
            loss = bestWrong[3] - bestValid[3]
            info["mloss"] += 1.0 + bestWrong[2] - bestValid[2]
            info["eloss"] += 1.0 + bestWrong[2] - bestValid[2]
            info["errs"].append(loss)

        #??? when did this happen and why?
        if best[1] == 0 or best[1] == 2:
            info["etotal"] += 1

        return best, info

    def train_sentence(self, sentence, info):
        """Run the training transition loop over one sentence and return the updated ``info``."""
        sentence = deepcopy(sentence) # ensures we are working with a clean copy of sentence and allows memory to be recycled each time round the loop

        conll_sentence = [entry for entry in sentence if isinstance(entry, ConllEntry)]
        conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
        stack = ParseForest([])
        buf = ParseForest(conll_sentence)
        hoffset = 1 if self.headFlag else 0

        for root in conll_sentence:
            root.relation = root.relation if root.relation in self.irels else 'runk'

        ninf = -float('inf')
        while not (len(buf) == 1 and len(stack) == 0):
            scores = self.__evaluate(stack, buf, conll_sentence, True)
            scores.append([(None, 4, ninf ,None)]) #to ensure that we have at least one wrong operation

            best, info = self.cost_computing(stack, buf, scores, info)

            self.apply_transition(best,stack,buf,hoffset)

        return info

    def error_processing(self, errs):
        """Back-propagate the sum of the collected losses and step both optimizers."""
        self.labeled_optimizer.zero_grad()
        self.unlabeled_optimizer.zero_grad()
        # torch.stack keeps every loss attached to its computation graph;
        # building a new tensor from the values would cut the gradient flow to
        # the networks and the optimizer steps would change nothing.
        eerrs = torch.stack(errs).sum()
        eerrs.backward()
        self.labeled_optimizer.step() # TODO: which of the optimizers ???
        self.unlabeled_optimizer.step()


    def Train(self, trainData, options):
        """Run one training epoch over ``trainData`` (shuffled in place). ``options`` is not used."""
        info = {}
        info["mloss"], info["eloss"], info["eerrors"], info["lerrors"], info["etotal"]  = 0.0, 0.0, 0, 0, 0
        info["errs"] = []

        beg = time.time()
        start = time.time()

        random.shuffle(trainData) # in certain cases the data will already have been shuffled after being read from file or while creating dev data
        logger.info(f"Length of training data: {len(trainData)}")

        pbar = tqdm.tqdm(
            trainData,
            desc="Training",
            unit="sentences",
            mininterval=1.0,
            leave=False,
            disable=False,
        )

        iSentence = 0 # stays 0 when trainData is empty
        for iSentence, sentence in enumerate(pbar,1):
            if iSentence % 100 == 0:
                # Avoid a division by zero when nothing was counted in this window.
                etotal = max(info["etotal"], 1)
                loss_message = (
                    f'Processing sentence number: {iSentence}'
                    f' Loss: {info["eloss"] / etotal:.3f}'
                    f' Errors: {info["eerrors"] / etotal:.3f}'
                    f' Labeled Errors: {info["lerrors"] / etotal:.3f}'
                    f' Time: {time.time()-start:.3f}s'
                )
                logger.debug(loss_message)
                start = time.time() # TODO: why is this parameter needed?
                info["eerrors"], info["eloss"], info["etotal"], info["lerrors"] = 0, 0.0, 0, 0

            info = self.train_sentence(sentence, info)

            #footnote 8 in Eli's original paper
            if len(info["errs"]) > 50: # or True:
                self.error_processing(info["errs"])
                info["errs"] = []

        if len(info["errs"]) > 0:
            self.error_processing(info["errs"])
            info["errs"] = []


        # Single quotes inside the f-string keep this valid before Python 3.12.
        logger.info(f"Loss: {info['mloss']/max(iSentence, 1)}")
        logger.info(f"Total Training Time: {time.time()-beg:.2g}s")


# uuparser/parser.py

In [None]:
import pickle, os, time, sys, copy, itertools, re, random

from shutil import copyfile

def evaluate_uas(sentence_descr):
    """Return the unlabeled attachment score of one parsed sentence.

    ``sentence_descr`` is a sentence as produced by ``read_conll``: the
    artificial root entry (id 0) followed by ``ConllEntry`` tokens and raw
    strings for comment and multi-word-token lines. Only real tokens are
    scored, so the result does not depend on how many comment lines precede
    them. Returns 0.0 when the sentence contains no token.
    """
    scored_tokens = [
        token for token in sentence_descr
        if isinstance(token, ConllEntry) and token.id != 0
    ]
    if not scored_tokens:
        return 0.0
    right_parent_tokens = sum(
        1 for token in scored_tokens if token.pred_parent_id == token.parent_id
    )
    return right_parent_tokens / len(scored_tokens)

def evaluate_uas_epoche(sentence_list):
    """Return the mean per-sentence UAS over ``sentence_list`` (0.0 for an empty list)."""
    if len(sentence_list) == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0
    summ_uas = sum(evaluate_uas(sent) for sent in sentence_list)
    return summ_uas / len(sentence_list)

def run(traindata, valdata, testdata, options):
    """Train for the configured epochs, pick the best epoch on dev data and score it on test data.

    After every epoch the model is saved and evaluated (mean per-sentence UAS)
    on ``valdata``; the checkpoint of the best-scoring epoch is then reloaded
    and evaluated on ``testdata``.

    Args:
        traindata, valdata, testdata: lists of sentences as produced by ``read_conll``.
        options: dict with at least ``first_epoch`` and ``epochs`` plus the
            keys read by ``ArcHybridLSTM``.
    """
    # ArcHybridLSTM is the class defined in this file; importing the one from
    # the ``uuparser`` package here would shadow it (or fail when the package
    # is not installed).
    #logger.info('Working with a transition-based parser')

    irels = get_irels(traindata)
    logger.debug('Initializing the model')
    parser = ArcHybridLSTM(irels, options)

    dev_best = [options["epochs"],-1.0] # best epoch, best score

    for epoch in range(options["first_epoch"], options["epochs"] + 1):
        # Training
        logger.info(f'Starting epoch {epoch} (training)')
        parser.Train(traindata,options)
        logger.info(f'Finished epoch {epoch} (training)')

        parser.Save(epoch)

        logger.info(f"Predicting on dev data")
        dev_pred = list(parser.Predict(valdata,"dev",options))
        mean_dev_score = evaluate_uas_epoche(dev_pred)
        logger.info(f"Dev score {mean_dev_score:.2f} at epoch {epoch:d}")
        print(f"Dev score {mean_dev_score:.2f} at epoch {epoch:d}")

        if mean_dev_score > dev_best[1]:
            dev_best = [epoch,mean_dev_score] # update best dev score

    logger.info(f"Loading best model from epoche{dev_best[0]:d}")
    # Load the checkpoint of the best epoch (not of the last one) into
    # parser.labeled_GNN and parser.unlabeled_GNN.
    parser.Load(dev_best[0])

    logger.info(f"Predicting on test data")

    test_pred = list(parser.Predict(testdata,"test",options))
    mean_test_score = evaluate_uas_epoche(test_pred)

    logger.info(f"On test obtained UAS score of {mean_test_score:.2f}")
    print(f"On test obtained UAS score of {mean_test_score:.2f}")


    logger.debug('Finished predicting')


def main():
    """Build the option dict, seed the RNG, read the three data splits and run the experiment."""
    options = {
        "activation": "tanh",      # stored on the parser
        "mlp_hidden_dims": 100,    # hidden width of both GNNs
        "learning_rate": 0.001,    # learning rate of both Adam optimizers
        "oracle": True,            # True enables the dynamic-oracle updates and exploration
        "headFlag": True,          # stored on the parser
        "rlMostFlag": True,        # stored on the parser
        "rlFlag": False,           # stored on the parser
        "k": 3,                    # stored on the parser
        "epochs": 30,              # last epoch to run (inclusive)
        "first_epoch": 1,          # first epoch to run
        "max_sentences": -1,       # per-file sentence cap; -1 reads everything
    }

    # really important to do this before anything else to make experiments reproducible
    set_seeds()

    split_paths = (
        'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu',
        'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-val.conllu',
        'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-test.conllu',
    )
    # The splits are read in the order train, val, test.
    train, val, test = (
        list(read_conll(path, maxSize=options["max_sentences"]))
        for path in split_paths
    )
    run(train, val, test, options)

if __name__ == '__main__':
    main()