<a href="https://colab.research.google.com/github/Derinhelm/parser_stat/blob/main/parser_running.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Downloading the code

In [None]:
# Fetch the `llm-parser` branch of the parser_stat repository.
# Later cells expect the checkout to live at /content/parser_stat.
!git clone -b llm-parser https://github.com/Derinhelm/parser_stat.git

Cloning into 'parser_stat'...
remote: Enumerating objects: 63, done.[K
remote: Counting objects: 100% (63/63), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 63 (delta 30), reused 36 (delta 13), pack-reused 0 (from 0)[K
Receiving objects: 100% (63/63), 14.17 MiB | 13.71 MiB/s, done.
Resolving deltas: 100% (30/30), done.


In [None]:
import sys
# Put the cloned repository on the import path so that its modules can be
# imported by later cells (presumably this is where `data_classes` lives — confirm).
sys.path.append('/content/parser_stat')

# Preparing

In [None]:
from IPython.display import clear_output

In [None]:
from google.colab import files

In [None]:
import time
import traceback


# Running U-DepPLLaMA

In [None]:
# Install the third-party packages needed to load the model, then wipe the
# installer log from the cell output.
# NOTE(review): versions are not pinned, so different runs may pick up
# different releases — consider pinning for reproducibility.
!pip install peft transformers bitsandbytes
clear_output()

In [None]:
import pickle

from data_classes import ConllEntry, Sentence

In [None]:
import transformers
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import PeftModel

In [None]:
OP = '['  # opening bracket of the bracketed tree notation produced by the model
CP = ']'  # closing bracket of the bracketed tree notation produced by the model


class UDepPLLaMAParser:
    """Dependency parser that prompts a causal LM and decodes its bracketed answer.

    The model is asked to answer with a bracketed tree such as
    ``[ROOT [NSUBJ [He]] [sleeps]]``.  The answer is then converted into a flat
    list of tokens with head ids and relation labels:

    1. ``parseExpression`` replaces every label by a numeric node id,
    2. ``toTree`` turns the id-only expression into a parent -> children map,
    3. ``decode`` walks that map and emits one entry per leaf.
    """

    def __init__(self,
                 base_model_name="NousResearch/Llama-2-13b-hf",
                 adapter_name="sag-uniroma2/u-depp-llama-2-13b"):
        """Load the 4-bit quantized base model, the PEFT adapter and the tokenizer.

        Args:
            base_model_name: Hub id (or path) of the base causal LM; it is also
                used to load the tokenizer.
            adapter_name: Hub id (or path) of the PEFT adapter applied on top
                of the base model.
        """
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=quant_config,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map={"": 0},  # map the whole model to device 0
        )

        self.model = PeftModel.from_pretrained(base_model, adapter_name)
        # Deterministic beam search (no sampling).
        self.generation_config = GenerationConfig(
            num_beams=4,
            do_sample=False,
            early_stopping=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    @staticmethod
    def _build_prompt(text):
        """Return the prompt that asks the model to parse `text`.

        The prompt is assembled from explicit newlines so that the indentation
        of this source file cannot leak into it: a triple-quoted literal placed
        inside an indented method body would prefix every prompt line with
        spaces.
        NOTE(review): the unindented layout is meant to match the adapter's
        published usage example — confirm against the model card.
        """
        return f"\n### Input:\n{text}\n### Answer:"

    def get_llm_output(self, input):
        """Generate the model's answer for one sentence.

        Args:
            input: Raw sentence text to be parsed.

        Returns:
            The text that follows the first ``### Answer:`` marker in the
            decoded generation, stripped of surrounding whitespace.

        Raises:
            IndexError: If the decoded generation contains no ``### Answer:``
                marker (the prompt is truncated to 512 tokens, so the marker
                can be cut off for very long inputs).
        """
        prompt = self._build_prompt(input)
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
        input_ids = inputs["input_ids"].to(self.model.device)
        with torch.no_grad():
            gen_outputs = self.model.generate(
                input_ids=input_ids,
                generation_config=self.generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=1024,
                use_cache=True,
            )
        decoded = self.tokenizer.decode(gen_outputs.sequences[0], skip_special_tokens=True)

        # The decoded sequence echoes the prompt; keep only what follows the marker.
        response = decoded.split("### Answer:")[1].strip()

        # Release the input tensor and cached GPU blocks between sentences.
        del input_ids
        torch.cuda.empty_cache()
        return response

    def parseExpression(self, expression):
        """Replace every label of a bracketed expression by a numeric node id.

        Space characters are dropped everywhere, including inside labels.

        Args:
            expression: Bracketed tree string, e.g. ``"[ROOT [NSUBJ [He]]]"``.

        Returns:
            A pair ``(retExp, nodeMap)``: the expression with labels replaced by
            consecutive ids starting at 1 (``"[1[2[3]]]"``) and a dict mapping
            each id (as a string) to its original label.
        """
        nodeMap = dict()
        counter = 1
        node = ""
        retExp = ""
        for char in expression:
            if char == ' ':
                continue
            if char == OP or char == CP:
                # A bracket terminates the label collected so far (if any).
                if node:
                    nodeMap[str(counter)] = node
                    retExp += str(counter)
                    counter += 1
                retExp += char
                node = ""
            else:
                node += char
        return retExp, nodeMap

    def toTree(self, expression):
        """Build a parent -> children map from an id-only bracketed expression.

        Args:
            expression: Output of ``parseExpression``, e.g. ``"[1[2[3]]]"``.

        Returns:
            Dict mapping each node id to the list of its children ids in
            left-to-right order.  The outermost node is registered as the child
            of the empty string ``""``.

        Raises:
            IndexError: If the expression has more closing than opening brackets.
        """
        tree = dict()
        msg = ""        # id currently being read / most recently closed node
        stack = list()  # ids of the nodes that are still open
        for char in expression:
            if char == OP:
                stack.append(msg)
                msg = ""
            elif char == CP:
                parent = stack.pop()
                tree.setdefault(parent, []).append(msg)
                # After closing a child, the parent is the current node again.
                msg = parent
            else:
                msg += char
        return tree

    def _decode(self, tree, representation_type, node, nodeMap, parent, grand_parent, tid2treenodeMap, res):
        """Depth-first walk over `tree` that turns every leaf into a token entry.

        A leaf is a node without an entry in `tree`.  Token ids are assigned in
        the order in which leaves are visited, starting at 1.

        Args:
            tree: Parent -> children map produced by ``toTree``.
            representation_type: ``"lct"`` (form = parent's label, deprel =
                leaf's label) or ``"grct"`` (form = leaf's label, deprel =
                parent's label).
            node: Id of the node being visited.
            nodeMap: Node id -> label map produced by ``parseExpression``.
            parent: Id of the parent of `node` (``None`` at the top).
            grand_parent: Id of the grandparent of `node` (``None`` near the top).
            tid2treenodeMap: Filled in place; maps the parent node id of each
                leaf to that leaf's token id (as a string).
            res: Filled in place; maps token id -> token entry dict.

        Raises:
            Exception: If `representation_type` is not supported.
            KeyError: If a required node id is missing from `nodeMap`.
        """
        if node in tree:
            for child in tree[node]:
                self._decode(tree, representation_type, child, nodeMap, node, parent, tid2treenodeMap, res)
            return

        # Ids are handed out consecutively from 1, so the next id is len(res) + 1.
        tid = len(res) + 1
        grand_parent_label = nodeMap[grand_parent] if grand_parent in nodeMap else "ROOT"

        if representation_type == "lct":
            res[tid] = { "id": tid, "form": nodeMap[parent], "to": grand_parent_label, "toid" : grand_parent, "deprel": nodeMap[node] }
        elif representation_type == "grct":
            res[tid] = { "id": tid, "form": nodeMap[node], "to": grand_parent_label, "toid" : grand_parent, "deprel": nodeMap[parent] }
        else:
            raise Exception("The representation_type\t" + representation_type + "\t is not supported in decoding.")

        tid2treenodeMap[parent] = str(tid)

    def decode(self, tree, nodeMap, representation_type="lct"):
        """Convert a bracket tree into token entries with resolved head ids.

        Args:
            tree: Parent -> children map produced by ``toTree``.
            nodeMap: Node id -> label map produced by ``parseExpression``.
            representation_type: ``"lct"`` or ``"grct"``; see ``_decode``.

        Returns:
            Dict mapping token id (int, starting at 1) to a dict with keys
            ``id``, ``form``, ``to``, ``toid`` and ``deprel``.  ``toid`` is the
            head's token id as a string, or ``'0'`` when the token has no
            grandparent node or the grandparent node has no token of its own.
        """
        res = dict()
        tid2treenodeMap = dict()
        self._decode(tree, representation_type, "1", nodeMap, None, None, tid2treenodeMap, res)

        # Until here "toid" holds a tree-node id; translate it into a token id.
        for entry in res.values():
            head_node = entry["toid"]
            if head_node is None:
                entry["toid"] = '0'
            else:
                entry["toid"] = tid2treenodeMap.get(head_node, '0')

        return res

    def _parse(self, s):
        """Run the model on sentence `s` and decode its answer into token entries."""
        llm_output = self.get_llm_output(s)
        retExp, nodeMap = self.parseExpression(llm_output)
        tree = self.toTree(retExp)
        return self.decode(tree, nodeMap)

    def parse(self, sent):
        """Parse a sentence and return its tokens as a list of ``ConllEntry``.

        ``decode`` runs in its default ``"lct"`` mode, where ``form`` holds the
        label of a leaf's parent node and ``deprel`` the label of the leaf
        itself.  The two are deliberately passed to ``ConllEntry`` swapped
        (leaf label as form, parent label as relation) — presumably because the
        model emits relation labels as inner nodes and words as leaves; confirm.

        Args:
            sent: Raw sentence text.

        Returns:
            List of ``ConllEntry`` objects, one per decoded token.
        """
        parsing_res = self._parse(sent)
        res = []
        for token in parsing_res.values():
            t = ConllEntry(str(token['id']), form=token['deprel'], parent_id=token['toid'], relation=token['form'])
            res.append(t)
        return res

In [None]:

pickle_data_path = "/content/parser_stat/treebank_test_sets/treebank_data.pickle"

# NOTE: unpickling can execute arbitrary code — only load this file from a
# checkout you trust.
with open(pickle_data_path, 'rb') as f:
    data = pickle.load(f)  # treebank name -> sentences (objects with .sent_id and .text)


res = {}        # treebank name -> per-sentence results: a Sentence, or (exception, traceback text) on failure
time_dict = {}  # treebank name -> total seconds spent in successful parser.parse() calls
parser = UDepPLLaMAParser()

for treebank_name, treebank_sents in data.items():
    t_res = []
    print("\n", treebank_name)
    t_time = []
    n_failed = 0
    for i, sent in enumerate(treebank_sents):
        if i % 100 == 0:
            print(f"{i:4}/{len(treebank_sents)}")
        try:
            # perf_counter is monotonic, so durations are immune to clock adjustments.
            ts = time.perf_counter()
            token_list = parser.parse(sent.text)
            te = time.perf_counter()
            t_time.append(te - ts)

            cur_res = Sentence()
            cur_res.set_sent_id(sent.sent_id)
            cur_res.set_text(sent.text)
            for t in token_list:
                cur_res.add_token(t)
            t_res.append(cur_res)
        except Exception as e:
            # Best effort: keep the failure in place of the sentence and go on,
            # so one malformed model answer does not abort the whole run.
            n_failed += 1
            t_res.append((e, traceback.format_exc()))
    if n_failed:
        # Make swallowed failures visible instead of leaving them only in the pickle.
        print(f"{n_failed} of {len(treebank_sents)} sentences failed to parse")
    res[treebank_name] = t_res
    time_dict[treebank_name] = sum(t_time)

print("\ntime results (s):")
for p, t in time_dict.items():
    print(f"{p:10}: {t:5.3f} (s)")

# Written relative to the current working directory.
with open('udeppllama.pickle', 'wb') as f:
    pickle.dump(res, f)

In [None]:
# Offer the results pickle for download through the browser.
# NOTE(review): the pickle is written with a relative path, so this assumes the
# notebook's working directory is /content — confirm.
files.download("/content/udeppllama.pickle")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>