In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
import torch
import numpy as np

In [22]:
# Base checkpoint; the tokenizer and the model weights are both loaded from it.
model_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

# The tokenizer is configured to pad on the left, so in a padded batch every
# prompt ends at the last position, right where generation continues.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# Plain (non-adapted) causal LM; the LoRA adapter is attached in a later cell.
inference_model = AutoModelForCausalLM.from_pretrained(model_name)

In [23]:
from peft import PeftModel

In [24]:
# Size the embedding matrix to the tokenizer's vocabulary before the adapter
# weights are attached.
inference_model.resize_token_embeddings(len(tokenizer))

# lora_path = '../llama_medqa_clm_lora/checkpoint-25440'
lora_path = '../models'

# NOTE: re-running this cell wraps the already-wrapped model again; reload the
# base model first if the adapter has to be swapped.
inference_model = PeftModel.from_pretrained(inference_model, lora_path)

# Prefer the GPU this notebook was developed on ('cuda:2'), but do not crash on
# machines with fewer GPUs: fall back to the first GPU, then to the CPU.
if torch.cuda.is_available():
    device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0'
else:
    device = 'cpu'

inference_model.to(device)
# eval() puts the model in inference mode and returns it, so the model
# structure is displayed as the cell output.
inference_model.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 2048)
        (layers): ModuleList(
          (0-21): 22 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear(in_features=2048, out_features=256, bias=False)
              (v_proj): lora.

In [151]:
# Evaluation file currently in use. Alternatives kept for quick switching:
# data_files = '../Data/MIMICIII/test.final.json'
# data_files = '../Data/MedQA/data_clean/questions/US/test.jsonl'
# data_files = '../Data/MedQA/data_clean/questions/US/type1_test.jsonl'
# data_files = '../Data/MedQA/data_clean/questions/US/subset_output.jsonl'
data_files = '../Data/MedQA/data_clean/questions/US/test_type1_removed.jsonl'

# Parse the JSON-lines file into a DatasetDict; the rows are read from its
# 'train' split in the cells below.
dataset = load_dataset(path='json', data_files=data_files)
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'options', 'meta_info', 'answer_idx'],
        num_rows: 1249
    })
})

In [152]:
# c_q = f"Context: {dataset['train']['data'][0][0]['paragraphs'][0]['context']} Question: {dataset['train']['data'][0][0]['paragraphs'][0]['qas'][0]['question']}? Answer:"
# c_q = f"Context: {dataset['train'][0]['question']} Options: {dataset['train'][0]['options']}"
# print(c_q)
# dataset['train'][0]['question']

In [153]:
# Half-open window [start_index, end_index) of examples to evaluate.
start_index = 0
end_index = 200

# Take the window once and read the individual columns from it.
subset = dataset['train'][start_index:end_index]

questions = subset['question']
answers = subset['options']  # despite the name, these are the answer *options*

# Prompt layout: "Context: <question> Options: <options> Answer:"
context = [
    f"Context: {question_text} Options: {option_set} Answer:"
    for question_text, option_set in zip(questions, answers)
]

# Gold label layout: "<answer_idx>: <answer text>"
true_index = subset['answer_idx']
true_answers = subset['answer']
target = [
    f"{answer_key}: {answer_text}"
    for answer_key, answer_text in zip(true_index, true_answers)
]

# MIMIC variant of the prompt/target, kept for reference:
# context = f"Context: {dataset['train']['data'][0][0]['paragraphs'][0]['context']} Question: {dataset['train']['data'][0][0]['paragraphs'][0]['qas'][0]['question']}? Answer:"
# target = dataset['train']['data'][0][0]['paragraphs'][0]['qas'][0]['answers'][0]['text']

# Tokenize all prompts as a single padded batch of PyTorch tensors.
batch = tokenizer(context, truncation=True, padding=True, return_tensors="pt")

In [154]:
# Rebuild the batch as a plain dict whose tensors live on the model's device.
batch = dict((field, tensor.to(device)) for field, tensor in batch.items())

# Shape of the padded prompt batch.
batch['input_ids'].shape

torch.Size([200, 601])

In [155]:
# Generate up to 64 new tokens for every prompt in the batch.
#
# Greedy decoding (do_sample=False) is used so the evaluation is reproducible.
# The previous settings (do_sample=True with temperature=0.01 plus top_p/top_k)
# approximated greedy decoding through unseeded sampling, so repeated runs
# could disagree. The sampling-only knobs are dropped because they have no
# effect without sampling.
output_tokens = inference_model.generate(
    **batch,
    max_new_tokens=64,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

In [156]:
def extract_answer(decoded_text, marker='Answer: ', end_marker="<|endcontext|>"):
    """Return the answer portion of one decoded prompt+completion string.

    The text is first cut at ``end_marker``; the answer is then the segment
    between the first and the second occurrence of ``marker`` (or the end of
    the text).

    Args:
        decoded_text: Full decoded sequence (prompt followed by completion).
        marker: Text that introduces the answer.
        end_marker: Text after which everything is discarded.

    Returns:
        The extracted answer, or an empty string when ``marker`` is absent.
        The prompt ends with "Answer:" and no trailing space, so the marker is
        missing whenever the completion does not start with a space. An empty
        string scores as a miss instead of raising IndexError and discarding
        the whole generation run.
    """
    segments = decoded_text.split(end_marker)[0].split(marker)
    return segments[1] if len(segments) > 1 else ''


target_predicted = [
    extract_answer(tokenizer.decode(tok, skip_special_tokens=True))
    for tok in output_tokens
]
target_predicted

['A: Disclose the error to the patient but leave it out of the operative report',
 'A: Inhibition of thymidine synthesis',
 'B: Allergic interstitial nephritis',
 'A: Coagulase-positive, gram-positive cocci forming mauve-colored colonies on methicillin-containing agar',
 'D: Fluorometholone eye drops',
 'D: Propranolol',
 'A: Renal artery stenosis',
 'D: Spironolactone',
 'C: Active or recurrent pelvic inflammatory disease (PID)',
 'A: Silvery plaques on extensor surfaces',
 'A: It determines the genotype of the virus',
 'D: Ruxolitinib',
 'A: Renal cell carcinoma',
 'B: An increase in left ventricular end-diastolic pressure',
 'B: Epstein-Barr virus',
 'B: Gallbladder cancer',
 'B: IL-2',
 'C: Restriction Options: A: Stratified analysis B: Blinding C: Restriction D: Randomization E: Matching',
 'D: Pericardiocentesis',
 'E: Benzodiazepine intoxication\n"',
 'C: Previous radiation therapy',
 'D: Maternal alcohol consumption',
 'B: Aspergillus fumigatus infection',
 'B: Streptococcus pn

In [157]:
# One boolean per example: True when the generated answer is an exact string
# match for the gold "<answer_idx>: <answer>" label.
matching_answers = [guess == gold for guess, gold in zip(target_predicted, target)]

# Number of exact matches (displayed as the cell output).
matching_answers.count(True)

35

In [148]:
# Dataset row numbers of the exactly-matched examples; counting starts at
# start_index so the values index into the full dataset, not the window.
correct_indices = [
    row_number
    for row_number, is_match in enumerate(matching_answers, start=start_index)
    if is_match
]
correct_indices

[0,
 1,
 2,
 4,
 5,
 6,
 7,
 8,
 10,
 16,
 20,
 22,
 24,
 26,
 27,
 29,
 30,
 31,
 32,
 34,
 35,
 38,
 48,
 52,
 64,
 75,
 76,
 77,
 79,
 81]

In [149]:
# Debug dump of prompts, predictions and gold labels, each rendered as
# "name=<repr>" and separated by a blank line.
print(" \n\n ".join((f"{context=}", f"{target_predicted=}", f"{target=}")))

# Container type returned by generate().
type(output_tokens)

context=["Context: A 37-year-old-woman presents to her primary care physician requesting a new form of birth control. She has been utilizing oral contraceptive pills (OCPs) for the past 8 years, but asks to switch to an intrauterine device (IUD). Her vital signs are: blood pressure 118/78 mm Hg, pulse 73/min and respiratory rate 16/min. She is afebrile. Physical examination is within normal limits. Which of the following past medical history statements would make copper IUD placement contraindicated in this patient? Options: {'A': 'A history of stroke or venous thromboembolism', 'B': 'Current tobacco use', 'C': 'Active or recurrent pelvic inflammatory disease (PID)', 'D': 'Past medical history of breast cancer', 'E': 'Known liver neoplasm'} Answer:", "Context: A 23-year-old woman comes to the physician because she is embarrassed about the appearance of her nails. She has no history of serious illness and takes no medications. She appears well. A photograph of the nails is shown. Which 

torch.Tensor

In [None]:
# Score straight from matching_answers rather than from correct_indices.
# correct_indices is built in a separate cell and can be stale when cells are
# run out of order, which would silently report the accuracy of an earlier run.
num_evaluated = len(matching_answers)
# An empty evaluation window has no defined accuracy; report NaN instead of
# raising ZeroDivisionError.
accuracy = sum(matching_answers) / num_evaluated if num_evaluated else float('nan')
print("accuracy: " + str(accuracy))