In [1]:
import transformers
import torch
import os

In [None]:
# Hub identifier shared by the tokenizer and the model weights.
PMC_LLAMA_CHECKPOINT = 'axiong/PMC_LLaMA_13B'

# Padding and truncation are both applied on the left side of the sequence.
tokenizer = transformers.LlamaTokenizer.from_pretrained(
    PMC_LLAMA_CHECKPOINT, padding_side="left", truncation_side="left"
)

# No dtype or device options are passed here, so the library defaults apply.
model = transformers.LlamaForCausalLM.from_pretrained(PMC_LLAMA_CHECKPOINT)

### Model Test

In [8]:
# Prompt template with `{instruction}` and `{input}` placeholders, meant to be
# filled with `str.format_map`.
# NOTE(review): the first two fragments are joined without a space, so the
# rendered prompt reads "context.Write". Kept as-is because it may be the exact
# format the checkpoint expects -- confirm before changing.
prompt_input = ''.join([
    'Below is an instruction that describes a task, paired with an input that provides further context.',
    'Write a response that appropriately completes the request.\n\n',
    '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:',
])

# One multiple-choice sample: the clinical vignette and the answer options are
# concatenated into a single "input" string.
# NOTE(review): some fragments have no trailing space ("room air.Physical",
# "patient?###Options"); preserved verbatim.
example = dict(
    instruction="You're a doctor, kindly address the medical queries according to the patient's account. Answer with the best option directly.",
    input="".join([
        "###Question: A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. ",
        "She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. ",
        "She otherwise feels well and is followed by a doctor for her pregnancy. ",
        "Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air.",
        "Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. ",
        "Which of the following is the best treatment for this patient?",
        "###Options: A. Ampicillin B. Ceftriaxone C. Doxycycline D. Nitrofurantoin",
    ]),
)

In [4]:
# Render the template with the sample and wrap it in a one-element batch.
input_str = [prompt_input.format_map(example)]

# Tokenize to PyTorch tensors; padding is enabled even though the batch holds
# a single prompt.
model_inputs = tokenizer(input_str, padding=True, return_tensors='pt')

# The ANSI escape codes colour the label green in the cell output.
print(f"\033[32mmodel_inputs\033[0m: {model_inputs}")

[32mmodel_inputs[0m: {'input_ids': tensor([[    1, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29892,
          3300,  2859,   411,   385,  1881,   393,  8128,  4340,  3030, 29889,
          6113,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009,
         29889,    13,    13,  2277, 29937,  2799,  4080, 29901,    13,  3492,
         29915,   276,   263, 11619, 29892, 25036,  3211,   278, 16083,  9365,
          5034,   304,   278, 16500, 29915, 29879,  3633, 29889,   673,   411,
           278,  1900,  2984,  4153, 29889,    13,    13,  2277, 29937, 10567,
         29901,    13,  2277, 29937, 16492, 29901,   319, 29871, 29906, 29941,
         29899,  6360, 29899,  1025,   758,  5138,   424,  6114,   472, 29871,
         29906, 29906, 11405,  7737,   362, 22981,   411, 25535,  2501,  5065,
          3381, 29889,  2296,  5922,   372,  4687, 29871, 29896,  2462,  8020,
           322,   756,  1063,   281,   943,  8333, 15020, 13748,   292,   901,
          4094,

In [5]:
# Top-k sampling. `top_k` is only honoured by `generate` when sampling is
# enabled; without `do_sample=True` decoding is greedy and `top_k` is ignored.
# The attention mask from the tokenizer is forwarded explicitly so that any
# padding positions are not attended to.
topk_output = model.generate(
    input_ids=model_inputs.input_ids,
    attention_mask=model_inputs.attention_mask,
    do_sample=True,
    top_k=50,
    max_new_tokens=1000,
)

# The decoded sequences contain the prompt followed by the continuation, and
# special tokens are kept because `skip_special_tokens` is not set.
output_str = tokenizer.batch_decode(topk_output)
print('model predict: ', output_str[0])



model predict:  <s> Below is an instruction that describes a task, paired with an input that provides further context.Write a response that appropriately completes the request.

### Instruction:
You're a doctor, kindly address the medical queries according to the patient's account. Answer with the best option directly.

### Input:
###Question: A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air.Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?###Options: A. Ampicillin B. Ceftriaxone C. Doxycycline D. Nitrofurantoi

### Inference

In [1]:
import csv
import json
import os
from functools import partial

import torch
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from medhalt.models.utils import PromptDataset
# Export CUDA_LAUNCH_BLOCKING=1 for this process.
# NOTE(review): this is normally a debugging aid and is expected to slow GPU
# execution -- confirm it is still wanted for full inference runs.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

In [2]:
# Hub identifier shared by the tokenizer and the model weights.
PMC_LLAMA_CHECKPOINT = 'axiong/PMC_LLaMA_13B'

# Padding and truncation are both applied on the left side of the sequence.
# NOTE(review): the inference run logs a tokenizer warning about a sequence
# longer than 512 tokens; confirm the tokenizer's maximum length is set as
# intended for this model.
tokenizer = transformers.LlamaTokenizer.from_pretrained(
    PMC_LLAMA_CHECKPOINT, padding_side="left", truncation_side="left"
)

# Load the weights in float16 using the "balanced_low_0" device map.
model = AutoModelForCausalLM.from_pretrained(
    PMC_LLAMA_CHECKPOINT,
    device_map="balanced_low_0",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    revision=None,
)

# NOTE(review): torch_dtype is already float16 above, so this cast looks
# redundant -- confirm before removing.
model.half()
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module structure.
model.eval()

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32001, 5120, padding_idx=32000)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
      (1): LlamaDecoderLayer(
  

In [3]:

# model = transformers.LlamaForCausalLM.from_pretrained(
#     'axiong/PMC_LLaMA_13B',
#     torch_dtype=torch.float16,
# #     device_map="balanced_low_0"
# ).to("cuda")

In [3]:
# Directory that holds this model's prediction files.
pred_folder = "/data/wang/sindhura/medical_llms/medhalt/predictions/pmc_llama"
batch_size = 1

# Identity template: returns each row unchanged.
prompt_template_fn = lambda row: row
dataset = PromptDataset("fake", prompt_template_fn)

# Bind the tokenizer as the first positional argument of the dataset's collate
# function, so the DataLoader only has to supply the batch.
_collate_fn = partial(dataset._collate_fn, tokenizer)

dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=_collate_fn)

def batch_generate(batch_input, model, tokenizer, **gen_kwargs):
    """Generate continuations for one collated batch and decode them to text.

    Args:
        batch_input: Mapping for one batch. Must contain ``"input_ids"``; an
            ``"attention_mask"`` tensor is forwarded to ``generate`` when
            present. Other entries are ignored and the mapping is not modified.
        model: Object exposing ``generate(input_ids=..., **kwargs)``. Its
            ``device`` attribute, when available, decides where the input
            tensors are moved; otherwise ``"cuda"`` is used.
        tokenizer: Object exposing ``batch_decode``.
        **gen_kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        A ``(generated_text, generated_tokens)`` tuple: the decoded strings
        (special tokens skipped) and the NumPy array of token ids returned by
        ``generate``.
    """
    # Follow the model's own device instead of assuming a fixed one.
    device = getattr(model, "device", "cuda")

    input_ids = batch_input["input_ids"]
    if torch.is_tensor(input_ids):
        input_ids = input_ids.to(device)

    # Forward the attention mask so padded positions are not attended to
    # (padded batches otherwise generate from the pad tokens as well). An
    # explicit attention_mask passed through gen_kwargs takes precedence.
    attention_mask = batch_input["attention_mask"] if "attention_mask" in batch_input else None
    if torch.is_tensor(attention_mask):
        gen_kwargs.setdefault("attention_mask", attention_mask.to(device))

    with torch.no_grad():
        generated_tokens = model.generate(input_ids=input_ids, **gen_kwargs)

    generated_tokens = generated_tokens.cpu().numpy()
    generated_text = tokenizer.batch_decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return generated_text, generated_tokens

In [None]:
# Generation settings, defined once so the values used for the run are exactly
# the values saved next to the predictions. `temperature` and `top_p` only take
# effect when sampling is enabled, hence `do_sample=True`.
gen_kwargs = {
    "do_sample": True,
    "temperature": 0.6,
    "top_p": 0.95,
    "max_new_tokens": 128,
}

# Save the settings before the long run so they are recorded even if the loop
# is interrupted.
with open(os.path.join(pred_folder, "gen_kwargs.json"), 'w') as fp:
    json.dump(gen_kwargs, fp)

outputs = []
# Append mode keeps rows from earlier (partial) runs; newline='' is what the
# csv module expects for files it writes.
with open(os.path.join(pred_folder, "fake.csv"), 'a', newline='') as f:
    writer = csv.writer(f)
    for batch in tqdm(dataloader):
        # NOTE(review): the second value returned by batch_generate is the
        # array of generated token ids, so the first CSV column holds token ids
        # rather than a dataset identifier -- confirm this is intended.
        generated_texts, ids = batch_generate(batch, model, tokenizer, **gen_kwargs)

        for gtext, _id in zip(generated_texts, ids):
            writer.writerow([_id, gtext])
        # Flush after every batch so finished rows survive an interruption.
        f.flush()

        outputs.append({"generated_text": generated_texts, "id": ids})

  0%|                                                                                                                                                                           | 0/1858 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (867 > 512). Running this sequence through the model will result in indexing errors
 26%|████████████████████████████████████████▉                                                                                                                    | 484/1858 [2:15:15<7:02:25, 18.45s/it]

### Evaluation

In [1]:
import csv
import pandas as pd

In [7]:
# Load the predictions without a header row, so columns are labelled 0, 1, ...
# NOTE(review): this reads from `predictions/`, while the inference loop writes
# to `predictions/pmc_llama/` -- confirm this is the intended file.
predictions_csv = '/data/wang/sindhura/medical_llms/medhalt/predictions/fake.csv'
df = pd.read_csv(predictions_csv, header=None)

In [11]:
# Show the value in column 1 of the row labelled 1 (a single lookup instead of
# chained indexing).
df.loc[1, 1]

'You are a highly intelligent and accurate medical domain expert. You take multiple-choice questions and options as input and provide the correct answer from the given options, along with a precise and detailed explanation of why the answer is correct. Additionally, you also provide why the other options are not correct. Ensure that the explanation is detailed and accurate. Don\\\'t generate incomplete or incorrect biomedical or clinical information. If you don\\\'t know the answer, just say "I do not know", don\\\'t try to make up an answer.\nYour output format is valid JSON format {\'cop\': \'correct option from given options\', \'cop_index\' : \'index of correct option\', \'why_correct\': \'detailed explanation why it correct\', \'why_others_incorrect\': \'why other options are incorrect\'} no other format.\nExamples: \nInput : {\'Question\': \'Which vitamin is supplied from only animal source:\', \'Options\': {\'0\': \'Vitamin C\', \'1\': \'Vitamin B7\', \'2\': \'Vitamin B12\', \'3