### Supervised fine-tuning (SFT) of Llama for event extraction

In [None]:
import pandas as pd
from tqdm import tqdm
import datasets
from peft import LoraConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import torch
from transformers import LlamaTokenizer,LlamaForCausalLM,TrainingArguments
# Register tqdm's pandas integration (progress_apply & co.).
tqdm.pandas()

# Human-annotated data; the `sentence` and `human_label` columns are consumed
# by formatting_prompts_func below.
df_event_extraction = pd.read_csv('../Corpus/event_extraction_human_annotation.csv')

# The first 100 rows are used for fine-tuning, everything after is held out.
ds_train = datasets.Dataset.from_pandas(df_event_extraction[:100])
ds_test = datasets.Dataset.from_pandas(df_event_extraction[100:])

model_name = '../models/Llama-2-13b-chat-hf'

tokenizer = LlamaTokenizer.from_pretrained(model_name)
# Pad with a token that is distinct from EOS. The language-modeling collator
# replaces every label equal to pad_token_id with -100, so padding with EOS
# would also drop the EOS that formatting_prompts_func appends to each answer
# from the loss, and the model would never be trained to stop generating.
tokenizer.pad_token = tokenizer.unk_token
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # Per-GPU memory cap used when sharding the model across four devices.
    max_memory={0: "15GB", 1: "15GB", 2: "15GB", 3: "15GB"},
)

def formatting_prompts_func(example):
    """Build one training text per row of a batched dataset example.

    Args:
        example: Mapping with parallel sequences under ``'sentence'`` and
            ``'human_label'``.

    Returns:
        A list of strings, one per sentence, each consisting of the question
        prompt, the sentence, the ``### Answer:`` marker, the human label and
        the tokenizer's EOS token. The wording matches the prompt used for
        inference later in this file, so it must not be altered on one side
        only.
    """
    labels = example['human_label']
    formatted = []
    for idx, sentence in enumerate(example['sentence']):
        body = (
            "### Question: An event is defined as somebody/something has some actions/states."
            "List all events can be infered from the given sentence. "
            "If no event exists, simply respond 'no event.' "
            f"The sentence: {sentence} \n "
            f"### Answer: {labels[idx]}"
        )
        formatted.append(body + tokenizer.eos_token)
    return formatted

# Marker that separates the prompt from the answer inside each training text.
response_template_with_context = "\n ### Answer:"
# Encode the marker together with its leading context and drop the first two
# ids, keeping only the ids handed to the collator as the response template.
# NOTE(review): the [2:] offset presumably removes context-dependent prefix
# tokens produced by the leading newline - confirm against the tokenizer.
response_template_ids = tokenizer.encode(
    response_template_with_context,
    add_special_tokens=False,
)[2:]
data_collator = DataCollatorForCompletionOnlyLM(
    response_template_ids,
    tokenizer=tokenizer,
)

# LoRA adapter configuration for causal language modeling.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# One sample per device, accumulated over 12 steps before each update.
training_args = TrainingArguments(
    output_dir="../models/tmp",
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=12,
    logging_steps=10,
    remove_unused_columns=True,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    formatting_func=formatting_prompts_func,
    data_collator=data_collator,
    max_seq_length=512,
)

# Fine-tune, then write the result where the inference cell loads it from.
trainer.train()
trainer.save_model("../models/llama-2-13b-hf-for-event-extraction")

### Event extraction on Wiki corpus

In [None]:
import os
from dataclasses import dataclass, field
from transformers import LlamaTokenizer,LlamaForCausalLM
import datasets
from tqdm import tqdm
import pandas as pd
import torch


# Directory holding the raw corpus files, one text file per entry.
corpus_path = '../Corpus/Wiki_Corpus'
# Fine-tuned model written by the training cell.
model_name = "../models/llama-2-13b-hf-for-event-extraction"
# Inference prompt; `{}` is filled with the sentence. The wording has to stay
# identical to the prompt used during fine-tuning.
prompt = "### Question: An event is defined as somebody/something has some actions/states.\
List all events can be infered from the given sentence. If no event exists, simply respond 'no event.' \
The sentence: {} \n ### Answer:"

# Create the output directory (and any missing parents) without a separate
# existence check, which would be racy and fails if the parent is absent.
os.makedirs("../Corpus/llama_labeled_events", exist_ok=True)

corpus_file_list = os.listdir(corpus_path)

tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # Per-GPU memory cap used when sharding the model across four devices.
    max_memory={0: "15GB", 1: "15GB", 2: "15GB", 3: "15GB"},
)

def prompt_llama(model, tokenizer, prompt, max_new_tokens=100):
    """Generate the model's answer for one fully formatted prompt.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``; it is called to encode
            ``prompt`` and its ``decode`` method turns generated ids into text.
        prompt: Complete prompt text, normally ending in ``### Answer:``.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated answer with surrounding whitespace removed. If the model
        keeps going and starts another ``### Question`` / ``### Answer:``
        section, everything from that marker on is discarded.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Follow the model's own placement instead of assuming a fixed "cuda"
    # device, so the function also works on CPU or a non-default GPU.
    device = model.device
    input_ids = inputs["input_ids"].to(device)
    generate_kwargs = {"input_ids": input_ids, "max_new_tokens": max_new_tokens}
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(**generate_kwargs)

    # The returned sequence starts with the prompt ids; decode only what was
    # generated after them. This avoids searching the decoded text for the
    # answer marker, which fails when the prompt lacks it and picks the wrong
    # span when the sentence itself contains it.
    new_token_ids = outputs[0][input_ids.shape[-1]:]
    completion = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    for marker in ("### Answer:", "### Question"):
        completion = completion.split(marker)[0]
    # Strip last so whitespace in front of a removed marker does not leak out.
    return completion.strip()
    
# Label every corpus file sentence by sentence and write one CSV per file.
for corpus_file in tqdm(corpus_file_list):
    corpus_file_path = os.path.join(corpus_path, corpus_file)
    if not os.path.isfile(corpus_file_path):
        # os.listdir also returns sub-directories, which cannot be opened here.
        continue

    answer_list = []
    # Explicit encoding so decoding does not depend on the platform locale.
    # NOTE(review): assumes the corpus files are UTF-8 - confirm.
    with open(corpus_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Only lines with more than 30 space-separated tokens are labeled.
            if len(line.split(" ")) <= 30:
                continue
            for sentence in line.split('.'):
                # Skip fragments shorter than 10 characters.
                if len(sentence) < 10:
                    continue
                full_prompt = prompt.format(sentence)
                outputs_pred = prompt_llama(model, tokenizer, full_prompt)
                answer_list.append({'sentence': sentence, 'llama_labeled_events': outputs_pred})
                print(sentence)
                print(outputs_pred)
                print("-")

    # Fixing the columns keeps the CSV header present even when no sentence of
    # the file qualified.
    df = pd.DataFrame(answer_list, columns=['sentence', 'llama_labeled_events'])
    # splitext removes only the final extension, so names such as "a.b.txt"
    # and "a.c.txt" no longer both collapse to "a.csv".
    output_name = os.path.splitext(corpus_file)[0] + '.csv'
    df.to_csv(os.path.join("../Corpus/llama_labeled_events", output_name))