In [9]:
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    # BitsAndBytesConfig,
)

import json
import evaluate

In [2]:
# !pip install torch transformers peft bitsandbytes evaluate datasets

Collecting torch
  Using cached torch-2.4.0-cp312-none-macosx_11_0_arm64.whl.metadata (26 kB)
Collecting transformers
  Using cached transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
Collecting peft
  Using cached peft-0.12.0-py3-none-any.whl.metadata (13 kB)
Collecting bitsandbytes
  Using cached bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)
Collecting evaluate
  Using cached evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)
Collecting datasets
  Using cached datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting filelock (from torch)
  Using cached filelock-3.15.4-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Using cached typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Collecting sympy (from torch)
  Using cached sympy-1.13.2-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.3-py3-none-any.whl.metadata (5.1 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.4-py

In [4]:
from dataclasses import dataclass
import os

@dataclass
class DataClass:
    """Central configuration namespace for this notebook.

    None of the attributes below carry a type annotation, so ``@dataclass``
    generates no fields for them: they are plain class attributes and are
    read directly as ``DataClass.NAME`` without instantiating the class.
    """
    # Index [0] selects the first entry (the checkpoint directory); switch to
    # [1] to load the hub model id instead.
    MODEL_PATH = ["weights/checkpoint-897", "Qwen/Qwen2-0.5B-Instruct"][0]
    MAX_LENGTH = 96
    EPOCH = 3
    LORA_RANK = 2
    LORA_ALPHA = 2 * LORA_RANK
    LORA_DROPOUT = 0.5
    # NOTE(review): "qjv_proj" looks like a typo for "qkv_proj" — confirm
    # against the target model's module names before using it for LoRA.
    LORA_MODULES = ["o_proj", "qjv_proj", "gate_up_proj"]
    LR = 5e-5
    MODEL_SAVE_FOLDER = '/content/drive/MyDrive/weights'
    # Prefer CUDA, then Apple MPS, and finally fall back to CPU. The previous
    # two-way choice selected 'mps' on every machine without CUDA, including
    # hosts that have no MPS backend at all.
    DEVICE = (
        'cuda' if torch.cuda.is_available()
        else 'mps' if torch.backends.mps.is_available()
        else 'cpu'
    )

# Apple-silicon (MPS) runs: PYTORCH_ENABLE_MPS_FALLBACK=1 asks PyTorch to run
# operators that have no MPS implementation on the CPU instead of raising.
# NOTE(review): torch has already been imported by the time this executes —
# confirm the variable is still honoured when set after the import.
if DataClass.DEVICE == 'mps':
    os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK='1')

In [5]:
# --- Tokenizer ---------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(DataClass.MODEL_PATH, trust_remote_code=True)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# --- Model config ------------------------------------------------------------
# Loaded for inspection; it is not passed to the model constructor below.
model_config = AutoConfig.from_pretrained(
    DataClass.MODEL_PATH,
    attn_implementation='eager',  # alternative: 'flash_attention_2'
    trust_remote_code=True,
)

# --- Optional 4-bit quantization (disabled) ----------------------------------
# Requires the BitsAndBytesConfig import to be re-enabled as well.
# quant_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
# )

# --- Model -------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    DataClass.MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # NOTE: MPS does not support torch.bfloat16 finetuning
    attn_implementation='eager',  # alternative: 'flash_attention_2'
    low_cpu_mem_usage=True,
    device_map=DataClass.DEVICE,
    # Disabled quantization switches, kept for reference:
    # quantization_config=quant_config,
    # load_in_8bit=True,
    # load_in_4bit=True,
)

In [6]:
def inference(input_text, max_new_tokens=100):
    """Greedy-decode a continuation of ``input_text`` and return only the new text.

    Args:
        input_text: Prompt string; it is tokenized exactly as given.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        The decoded continuation, with the prompt tokens sliced off and
        special tokens skipped.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    prompt_ids = encoded['input_ids'].to(DataClass.DEVICE)
    prompt_mask = encoded['attention_mask'].to(DataClass.DEVICE)

    outputs = model.generate(
        input_ids=prompt_ids,
        attention_mask=prompt_mask,
        max_new_tokens=max_new_tokens,
        # Deterministic decoding: no sampling, single beam. The sampling
        # knobs are explicitly cleared (presumably to silence warnings about
        # sampling parameters being set while do_sample=False).
        do_sample=False,
        num_beams=1,
        temperature=None,
        top_k=None,
        top_p=None,
    )
    # The returned sequence starts with the prompt tokens; keep only what
    # comes after them.
    input_token_len = prompt_ids.shape[-1]
    return tokenizer.decode(outputs[0][input_token_len:], skip_special_tokens=True)

In [7]:
def prompter(question):
    """Wrap ``question`` in the chat template and return the model's reply.

    The prompt consists of a fixed system message, the user turn holding
    ``question`` and an opened assistant turn; it is passed to ``inference``.
    """
    system_message = (
        "You are an advanced language model adept at interpreting and refining noisy or imperfect user inputs.\n"
        "Given user data, your task is to accurately extract the intended question and provide precise answers or predictions, even if the input contains errors or discontinuities."
    )
    chat_prompt = (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return inference(chat_prompt)

In [6]:
import json
# Load the validation examples. JSON text is UTF-8, so pin the encoding
# instead of relying on the platform's default locale encoding.
with open("valid.json", "r", encoding="utf-8") as f:
    valid_data = json.load(f)

In [7]:
# Qualitative spot check: print the model output next to the reference for
# the first few validation examples. Slicing (rather than indexing with
# range(10)) keeps the cell working when fewer than 10 examples are loaded.
for sample in valid_data[:10]:
    question = sample['input_disfluent']
    reference = sample['output_original']

    print("LLM:", prompter(question))
    print("ANS:", reference)
    print('---')

LLM: What did the government want Thoreau to do?
ANS: What did the government want Thoreau to do?
---
LLM: What makes the Wells Fargo Center stand out?
ANS: What makes the Wells Fargo Center stand out?
---
LLM: What was the Colonia Agrippina's original name?
ANS: What was the Colonia Agrippina's original name?
---
LLM: Extended authorization networking benefits helped those that could not connect to what platform?
ANS: Extended networking benefits helped those that could not connect to what platform? 
---
LLM: Who is the emphasis on when there is a private finance initiative?
ANS: Who is the emphasis on when there is a private finance initiative?
---
LLM:  What dynasties inspired the Chinese-like elements of Kublai's government?
ANS: What dynasties inspired the Chinese-like elements of Kublai's government?
---
LLM: What is the density of all primes compatible with a modulo 9?
ANS: What is the density of all primes compatible with a modulo 9?
---
LLM: What did European empires rely on t

In [8]:
# Mean per-example BLEU on the validation set; every (input, prediction,
# reference) triple is also saved to model_io.json for later inspection.
bleu_metric = evaluate.load("bleu")
tot_bleu = 0.
model_io = []

for data in valid_data:
    ques = data['input_disfluent']
    ref = data['output_original']
    pred = prompter(ques)
    # A blank prediction has zero tokens, and the BLEU brevity-penalty
    # computation then divides by zero. Such a prediction scores 0, so it
    # simply contributes nothing to the running total.
    if pred.strip():
        tot_bleu += bleu_metric.compute(predictions=[pred], references=[ref])['bleu']
    model_io.append({"input_ques": ques, "output_pred": pred, "output_ground": ref})

with open("model_io.json", "w") as f:
    json.dump(model_io, f, indent=2)

# max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
print(tot_bleu / max(len(valid_data), 1))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


0.8716455833892353


In [11]:
# Mean per-example BLEU on the training set (same procedure as for the
# validation set, without saving the predictions).
bleu_metric = evaluate.load("bleu")
tot_bleu = 0.

# JSON text is UTF-8; pin the encoding so the read does not depend on the
# platform's default locale encoding.
with open("train.json", "r", encoding="utf-8") as f:
    train_data = json.load(f)

for data in train_data:
    ques = data['input_disfluent']
    ref = data['output_original']
    pred = prompter(ques)
    # A blank prediction has zero tokens, and the BLEU brevity-penalty
    # computation then divides by zero. Such a prediction scores 0, so it
    # simply contributes nothing to the running total.
    if pred.strip():
        tot_bleu += bleu_metric.compute(predictions=[pred], references=[ref])['bleu']

# max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
print(tot_bleu / max(len(train_data), 1))