In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging, TextStreamer
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model 
import transformers
import torch
from datasets import load_dataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_model = "mistralai/Mistral-7B-v0.1" 


In [5]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

In [6]:
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",

)


Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.29s/it]


In [7]:
ds = load_dataset(path="facebook/belebele",split="spa_Latn")

In [8]:
# Select the first five rows of the dataset for example prompts
ds_examples=ds.select(range(0,5))
ds_prompts=ds.select(range(5,len(ds)))

prompt_template="""{flores_passage}
Question: {question}
Answer A: {mc_answer1}
Answer B: {mc_answer2}
Answer C: {mc_answer3}
Answer D: {mc_answer4}
Correct answer: {correct_answer}"""

# Prepare example prompts for 5-shot prompting
choices=["A","B","C","D"]
prompt_examples = "\n\n".join([ prompt_template.format(**d,correct_answer=choices[int(d["correct_answer_num"])-1]) for d in ds_examples])

In [9]:
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.2, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=256, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=256, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
          )
          (k_proj): Linear4bit(
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.2, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_featur

In [10]:
tokenizer_path = "mistralai/Mistral-7B-v0.1" 

In [11]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)


In [12]:
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

In [13]:
# parse model response and extract the model'schoice
def parse_choice(response):
    choices=["A","B","C","D"]
    
    if len(response)==1:
        return choices.index(response[0]) + 1 if response[0] in choices else None
    elif response[0] in choices and not response[1].isalpha():
        return choices.index(response[0]) + 1
    else:
        return None

# sampling parameters: llama-precise
gen_config = {
    "temperature": 0.7,
    "top_p": 0.1,
    "repetition_penalty": 1.18,
    "top_k": 40,
    "do_sample": True,
    "max_new_tokens": 5,
    "pad_token_id": pipeline.tokenizer.eos_token_id,
}

# Loop through prompts and evaluate model responses
q_correct = q_total = 0
for rowNo, row in enumerate(tqdm(ds_prompts)):        
    # Construct the prompt by combining the example prompts and the current row's question
    prompt=(prompt_examples + "\n\n" + prompt_template.format(**row, correct_answer="")).strip()

    # Generate a response from the model
    response=pipeline(prompt, **gen_config)[0]["generated_text"][len(prompt):]
    if "\n" in response:
        response=response.split("\n")[0]

    # Parse the model's choice and compare it to the correct answer
    choice=parse_choice(response.strip())
    if choice==int(row["correct_answer_num"]):
        q_correct+=1 
    q_total+=1

print(f"{q_total} questions, {q_correct} correct ({round(q_correct/q_total*100,1)}%)")  

  0%|          | 0/895 [00:00<?, ?it/s]

100%|██████████| 895/895 [22:25<00:00,  1.50s/it]

895 questions, 362 correct (40.4%)



