In [1]:
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
from tqdm import tqdm
import numpy as np

from mention import Mention, decode_mentions, encode_bio, decode_bio, split_text
from score import ScoringCounts, score_mentions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "meta-llama/Llama-3.2-1B"
hf_auth = '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'

In [3]:
train_data_path = 'dataset/train_annotation.jsonl'
dev_data_path = 'dataset/dev_annotation.jsonl'
test_data_path = 'dataset/test_annotation.jsonl'

In [4]:
# Model parameters
batch_size = 32
learning_rate = 2e-5
epochs = 3
lora_rank = 16
lora_alpha = 32

In [5]:
# quatization parameters
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype='bfloat16'
)

In [6]:
# lora parameters
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

In [8]:
args = SFTConfig(
        output_dir='Llama_3.2_1B_Zelda_NER_3',
        max_seq_length=256,
        packing=True,
        bf16=True,
        save_strategy="steps",
        save_steps=25,
        learning_rate=learning_rate,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs = {'use_reentrant': True},
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
)

In [9]:
def prompt_format(instance):
    prompt = (
        f'Task: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\n'
        f'Input: {instance["text"]}\n'
        f'Output: {instance["candidates"][0]}\n'
    )
    return prompt

In [10]:
def prompt_format_eval(instance):
    prompt = (
        f'Task: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\n'
        f'Input: {instance["text"]}\n'
        f'Output: '
    )
    return prompt

***Start Processing***

In [11]:
with open(train_data_path, 'r', encoding='utf-8') as f:
    train_dataset = Dataset.from_list(list(map(json.loads, f)))

In [12]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    max_length=256,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
    token=hf_auth
)
tokenizer.pad_token = tokenizer.eos_token

In [13]:
# load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    quantization_config=quantization_config,
    token=hf_auth
)

In [14]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    peft_config=lora_config,
    formatting_func=prompt_format,
    args=args
)

  trainer = SFTTrainer(
Applying formatting function to train dataset: 100%|██████████| 1471/1471 [00:00<00:00, 37711.00 examples/s]
Converting train dataset to ChatML: 100%|██████████| 1471/1471 [00:00<00:00, 54465.23 examples/s]
Applying chat template to train dataset: 100%|██████████| 1471/1471 [00:00<00:00, 52519.40 examples/s]
Tokenizing train dataset: 100%|██████████| 1471/1471 [00:00<00:00, 5945.93 examples/s]
Packing train dataset: 100%|██████████| 1471/1471 [00:00<00:00, 13017.76 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [15]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B.

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B.


TrainOutput(global_step=48, training_loss=2.1935609181722007, metrics={'train_runtime': 297.0188, 'train_samples_per_second': 5.141, 'train_steps_per_second': 0.162, 'total_flos': 2286483491782656.0, 'train_loss': 2.1935609181722007})

In [16]:
torch.cuda.empty_cache()

***Evaluation***

In [17]:
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    quantization_config=quantization_config,
    token=hf_auth
)

model_dev = PeftModel.from_pretrained(
    base_model,
    "Llama_3.2_1B_Zelda_NER_3/checkpoint-48",
    device_map='auto'
)

In [18]:
with open(dev_data_path, 'r', encoding='utf-8') as f:
    dev_dataset = Dataset.from_list(list(map(json.loads, f)))
dev_text = [instance['text'] for instance in dev_dataset]
dev_prompts = [prompt_format_eval(instance) for instance in dev_dataset]
dev_labels = [instance['candidates'][0] for instance in dev_dataset]

In [19]:
test_text = dev_prompts[:2]

In [32]:
print(test_text)

["Task: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\nInput: B-b-because! Just beyond Goron City, they're rainin' down from the sky!\nOutput: ", "Task: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\nInput: The Fang and Bone only opens up at nighttime. Apparently it's at Skull Lake, but I don't know where that is.\nOutput: "]


In [20]:
batch_size = 1

In [21]:
model_dev.eval()

responses = []
for i in tqdm(range(0, len(test_text), batch_size), desc="Processing", unit="batch"):
    batch = test_text[i:i + batch_size]
    inputs = tokenizer(
        batch,
        return_tensors="pt",  # Ensure output is in tensor format
        padding=True,
        truncation=True,
        max_length=256,
    ).to(device)

    with torch.no_grad():
        generated_ids = model_dev.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id, # Explicitly set pad token id
            eos_token_id=tokenizer.eos_token_id
        )
        
    '''
    input_lengths = inputs.input_ids.shape[1]
    response_ids = generated_ids[:, input_lengths:]
    generated_text = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
    '''

    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    for text in generated_text:
        responses.append(text)

Processing: 100%|██████████| 2/2 [00:11<00:00,  5.99s/batch]


In [22]:
print(responses)

["Task: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\nInput: B-b-because! Just beyond Goron City, they're rainin' down from the sky!\nOutput: 1. B 2. b 3. b 4. b 5. b 6. b 7. b 8. b 9. b 10. b\nTask: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\nInput: B-b-because! Just beyond Goron City, they're rainin' down from the sky!\nOutput: 1. B 2. b 3. b 4. b 5. b 6. b 7. b 8. b 9. b 10. b\nTask: You are an expert Named Entity Recognition system specialized in The Legend of Zelda: breath of the wild. Your task is to identify and tag entities in the provided text.\nInput: B-b-because! Just beyond Goron City, they're rainin' down from the sky!\nOutput: 1. B 2. b 3. b 4. b 5. b 6. b 7. b 8. b 9. b 10. b\nTask: You are an expert Named Entity Recognitio

In [None]:
dev_clean_text = [split_text(item) for item in dev_text]
target_mentions = []
for item, clean_text in zip(dev_labels, dev_clean_text):
    mention = decode_mentions(item, clean_text)
    target_mentions.append(mention)

In [None]:
predict_mentions = []
for item, clean_text in zip(responses, dev_clean_text):
    mention = decode_mentions(item, clean_text) # may need to be changed since the roduce text may be different
    predict_mentions.append(mention)

In [None]:
results = []
for i in range(len(dev_dataset)):
    result = score_mentions(target_mentions[i], predict_mentions[i])
    results.append(result)

In [None]:
total_true_positive = [score_counts[0] for score_counts in results]
total_false_positive = [score_counts[1] for score_counts in results]
total_false_negative = [score_counts[2] for score_counts in results]

In [None]:
precision = sum(total_true_positive) / (sum(total_true_positive) + sum(total_false_positive)) if (sum(total_true_positive) + sum(total_false_positive)) > 0 else 0
recall = sum(total_true_positive) / (sum(total_true_positive) + sum(total_false_negative)) if (sum(total_true_positive) + sum(total_false_negative)) > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

In [None]:
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")