<a href="https://colab.research.google.com/github/Derinhelm/parser_stat/blob/llm_taiga/udeppllamarunning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code downloading

In [1]:
!git clone https://github.com/Derinhelm/parser_stat.git

Cloning into 'parser_stat'...
remote: Enumerating objects: 312, done.[K
remote: Counting objects: 100% (81/81), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 312 (delta 36), reused 45 (delta 17), pack-reused 231 (from 1)[K
Receiving objects: 100% (312/312), 53.53 MiB | 17.47 MiB/s, done.
Resolving deltas: 100% (163/163), done.
Updating files: 100% (49/49), done.


In [2]:
import pickle

import sys
sys.path.append('/content/parser_stat')

from data_classes import ConllEntry, Sentence

# Preparing

In [3]:
from IPython.display import clear_output

In [4]:
import time
import traceback

In [5]:
from data_classes import ConllEntry, Sentence

def get_dataset_sentences(dataset_path):
    fh = open(dataset_path,'r',encoding='utf-8')
    sents_read = 0
    sents = []
    comments = set()

    sent = Sentence()
    for line in fh:
        tok = line.strip().split('\t')
        if not tok or line.strip() == '': # empty line, add sentence to list
            if sent.is_not_empty:
                sents_read += 1
                sents.append(sent)
            sent = Sentence()
        else:
            if line[0] == '#' or '-' in tok[0]: # a comment line
                line = line.strip()
                if line[:12] == "# sent_id = ":
                    sent.set_sent_id(line[12:])
                elif line[:9] == "# text = ":
                    sent.set_text(line[9:])
                else:
                    comments.add(line)

            else: # an actual ConllEntry, add to tokens
                if tok[2] == "_":
                    tok[2] = tok[1].lower()

                word = ConllEntry(*tok)
                sent.add_token(word)
    fh.close()
    return sents

In [8]:
treebank_name = 'taiga'
taiga_data = get_dataset_sentences(f"/content/parser_stat/treebank_test_sets/ru_{treebank_name}-ud-test.conllu")
print(treebank_name, len(taiga_data))

taiga 881


In [11]:
taiga_data[1].text

'Она решила попытаться остановить машину — хотя выйдя под дождь, сразу же промокла насквозь.'

# UDepPLLaMA running

In [None]:
!pip install peft transformers bitsandbytes

In [13]:
import transformers
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import PeftModel

In [14]:
OP = '['
CP = ']'

class UDepPLLaMAParser:
    def __init__(self):
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model_from = AutoModelForCausalLM.from_pretrained(
            "NousResearch/Llama-2-7b-hf",
            #load_in_4bit=True,
            quantization_config=quant_config,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map={"": 0},
        )

        model = PeftModel.from_pretrained(
            model_from,
            "sag-uniroma2/u-depp-llama-2-7b"
        )

        generation_config = GenerationConfig(
            num_beams=4,
            do_sample=False,
            early_stopping=True,
        )
        tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf", trust_remote_code=True)
        self.model = model
        self.generation_config = generation_config
        self.tokenizer = tokenizer


    def get_llm_output(self, input):
        prompt = f"""
        ### Input:
        {input}
        ### Answer:"""
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
        input_ids = inputs["input_ids"].to(self.model.device)
        with torch.no_grad():
            gen_outputs = self.model.generate(
                input_ids=input_ids,
                generation_config=self.generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=1024,
                use_cache=True,
            )
        s = gen_outputs.sequences[0]
        output = self.tokenizer.decode(s, skip_special_tokens=True)

        response = output.split("### Answer:")[1].rstrip().lstrip()
        #print(response)
        return response

    def parseExpression(self, expression):
        nodeMap = dict()
        counter = 1
        node = ""
        retExp =""
        for char in expression:
            if char == OP or char == CP :
                if (len(node) > 0):
                    nodeMap[str(counter)] = node;
                    retExp += str(counter)
                    counter +=1
                retExp += char
                node =""
            elif char == ' ': continue
            else :
                node += char
        return retExp,nodeMap

    def toTree(self, expression):
        tree = dict()
        msg =""
        stack = list()
        for char in expression:
            if(char == OP):
                stack.append(msg)
                msg = ""
            elif char == CP:
                parent = stack.pop()
                if parent not in tree:
                    tree[parent] = list()
                tree[parent].append(msg)
                msg = parent
            else:
                msg += char
        return tree


    def _decode(self, tree, representation_type, node, nodeMap, parent, grand_parent, tid2treenodeMap, res):
        if node not in tree:
            tid = 1
            if res:
                tid = int(max(res.keys())) + 1

            grand_parent_label = "ROOT"
            if grand_parent in nodeMap:
                grand_parent_label = nodeMap[grand_parent]

            if representation_type == "lct":
                res[tid] = { "id": tid, "form": nodeMap[parent], "to": grand_parent_label, "toid" : grand_parent, "deprel": nodeMap[node] }
            elif representation_type == "grct":
                res[tid] = { "id": tid, "form": nodeMap[node], "to": grand_parent_label, "toid" : grand_parent, "deprel": nodeMap[parent] }
            else:
                raise Exception("The representation_type\t" + representation_type + "\t is not supported in decoding.")

            tid2treenodeMap[parent] = str(tid)

            return

        for child in tree[node]:
            self._decode(tree, representation_type, child, nodeMap, node, parent, tid2treenodeMap, res)

    def decode(self, tree, nodeMap, representation_type="lct"):
        res = dict()
        tid2treenodeMap = dict()
        #print(tree[''][0])
        self._decode(tree, representation_type, "1", nodeMap, None, None, tid2treenodeMap, res)

        for i in range(1, len(res)+1):
            if res[i]["toid"] is None:
                res[i]["toid"] = '0'
            else:
                try:
                    res[i]["toid"] = tid2treenodeMap[res[i]["toid"]]
                except:
                    res[i]["toid"] = '0'

        return res

    def _parse(self, s):
        llm_output = self.get_llm_output(s)
        retExp, nodeMap = self.parseExpression(llm_output)
        tree = self.toTree(retExp)
        res = self.decode(tree, nodeMap)
        return res

    def parse(self, sent):
        parsing_res = self._parse(sent)
        res = []
        for token in parsing_res.values():
          t =  { 'id': str(token['id']), 'form': token['deprel'],
                 'parent_id': token['toid'], 'relation': token['form'] }
          res.append(t)
        return res

In [None]:
parser = UDepPLLaMAParser()


In [16]:
ts = time.time()
parser.parse("Мама мыла раму.")
print(time.time() - ts)


17.610638856887817


# Experiments

In [17]:
#taiga_data = taiga_data[:5] # Uncomment for testing on small dataset version

In [18]:
import gc

start_i = 0
finish_i = len(taiga_data)

t_res = {}
print("\n", treebank_name)
t_time = []
for i in range(start_i, finish_i):
    if i % 20 == 0:
        with open(f"{treebank_name}_{start_i}_{i}.pickle", 'wb') as f:
            pickle.dump(t_res, f)
        gc.collect()
    sent = taiga_data[i]
    ts = time.time()
    try:
        token_list = parser.parse(sent.text)
    except Exception as err:
        t_res[i] = (err, )
        print(i, err)
    else:
        te = time.time()
        t_time.append(te - ts)
        cur_res = Sentence()
        cur_res.set_sent_id(sent.sent_id)
        cur_res.set_text(sent.text)
        for t in token_list:
            cur_res.add_token(t)
        t_res[i] = (cur_res, t_time[-1])
        print(i, t_time[-1])

all_time = sum(t_time)

print(f"Time: {all_time:5.3f} (s)")

with open(treebank_name + ".pickle", 'wb') as f:
    pickle.dump(t_res, f)


 taiga
0 69.96101880073547
1 69.03670978546143
2 34.33293128013611
3 58.96249794960022
4 5.033560752868652
Time: 237.327 (s)
