In [1]:
import textwrap
from datasets import load_dataset
from transformers import pipeline
from truthfulqa import MultipleChoicePipeline

In [52]:
medqa = load_dataset("bigbio/med_qa", split="validation[:10]")

In [53]:
medqa

Dataset({
    features: ['meta_info', 'question', 'answer_idx', 'answer', 'options'],
    num_rows: 10
})

In [54]:
def update_dataset(example):
    # Numeric conversion for answer indices
    choice_to_numeric = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
    if example['answer_idx'] in choice_to_numeric:
        example['label'] = choice_to_numeric[example['answer_idx']]
    else:
        example['label'] = None  # Or some default value
    
    # Extracting choice values
    example['choices'] = [item['value'] for item in example['options']]
    return example

updated_dataset = medqa.map(update_dataset)

In [4]:
q1 = medqa[0]
print(f"Question: {q1['question']}\n\nAnswer Choices:")
for i, c in enumerate(q1["options"]):
    print(f" {i}. {textwrap.fill(c['value'], subsequent_indent=' ' * 4)}")
print(f"\nCorrect Answer: {q1['answer_idx']}")

Question: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?

Answer Choices:
 0. Chloramphenicol
 1. Gentamicin
 2. Ciprofloxacin
 3. Ceftriaxone
 4. Trimethoprim

Correct Answer: D


In [5]:
# Use 
generator = pipeline("text-generation", model="dmis-lab/biobert-base-cased-v1.1")

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [6]:
lm = MultipleChoicePipeline(model="dmis-lab/biobert-base-cased-v1.1")

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [23]:
def extract_values(row):
    return [item['value'] for item in row]

choice_to_numeric = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

import pandas as pd
df = pd.DataFrame({'question' : medqa['question'], 'options' : medqa['options'], 'answer': medqa['answer_idx']})
df['choices'] = df['options'].apply(extract_values)
df['label'] = df['answer'].map(choice_to_numeric)
df

Unnamed: 0,question,options,answer,choices,label
0,A 21-year-old sexually active male complains o...,"[{'key': 'A', 'value': 'Chloramphenicol'}, {'k...",D,"[Chloramphenicol, Gentamicin, Ciprofloxacin, C...",3
1,A 5-year-old girl is brought to the emergency ...,"[{'key': 'A', 'value': 'Cyclic vomiting syndro...",A,"[Cyclic vomiting syndrome, Gastroenteritis, Hy...",0
2,A 40-year-old woman presents with difficulty f...,"[{'key': 'A', 'value': 'Diazepam'}, {'key': 'B...",E,"[Diazepam, St. John’s Wort, Paroxetine, Zolpid...",4
3,A 37-year-old female with a history of type II...,"[{'key': 'A', 'value': 'Obtain an abdominal CT...",C,"[Obtain an abdominal CT scan, Obtain blood cul...",2
4,A 19-year-old boy presents with confusion and ...,"[{'key': 'A', 'value': 'Hypoperfusion'}, {'key...",A,"[Hypoperfusion, Hyperglycemia, Metabolic acido...",0
5,A 41-year-old woman presents to her primary ca...,"[{'key': 'A', 'value': 'Vitamin B12 deficiency...",C,"[Vitamin B12 deficiency, Folate deficiency, Ir...",2
6,A 59-year-old woman with stage IV lung cancer ...,"[{'key': 'A', 'value': 'Lysosomal degradation ...",E,[Lysosomal degradation of endocytosed proteins...,4
7,A 67-year-old man presents to the emergency de...,"[{'key': 'A', 'value': 'Exertional heat stroke...",C,"[Exertional heat stroke, Neuroleptic malignant...",2
8,A 32-year-old man presents to a mission hospit...,"[{'key': 'A', 'value': 'Alpha-ketoglutarate de...",A,"[Alpha-ketoglutarate dehydrogenase, Acyl trans...",0
9,A 60-year-old woman presents to her primary ca...,"[{'key': 'A', 'value': 'Atenolol'}, {'key': 'B...",A,"[Atenolol, Furosemide, Hydrochlorothiazide, Ni...",0


In [9]:
lm._get_input_texts(df[0:2])

['Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: Chloramphenicol',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: Gentamicin',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A

In [10]:
lm.set_system_prompt("In fact,")
lm._get_input_texts(df[0:1])

['Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: In fact, Chloramphenicol',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: In fact, Gentamicin',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in

In [11]:
lm.preprocess(df[0:1])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'input_ids': tensor([[  101,   186,   131,   170,  1626,   118,  1214,   118,  1385, 13014,
           2327,  2581, 19073,  1116,  1104, 10880,   117,  2489,  1219,   190,
           9324,  2116,   117,  1105, 24970,  1105,  2489,  1107,  1103,  1268,
           5656,   119,   170,  2754,  1104,  1103,  4091,  8240,  2196,   170,
          10548,  1115,  1674,  1136,   175,  1200,  1880, 12477,  6066,  6787,
           1105,  1144,  1185,   185, 23415,  3202, 19515, 16234,  2007, 21749,
            119,  1103,  7454,  3791,  2848, 22400,  7606,  1111,  1103,  5351,
            119,  1103,  6978,  1104,  2168,  1104,  2168,  1104,  1103, 15683,
           1549,  5511,  2765,  2095, 11362,   117,  1134,  1104,  1103,  1378,
           1108,  1549,   136,   170,   131,  1107,  1864,   117, 22572, 24171,
           8223, 10436, 10658,  1233,   102,     0],
         [  101,   186,   131,   170,  1626,   118,  1214,   118,  1385, 13014,
           2327,  2581, 19073,  1116,  1104, 10880,   

In [12]:
input_ = lm.preprocess(df)
outputs = lm._forward(input_)

for k, v in outputs.items():
    print(f"Shape of {k}:\t{v.shape}")

Shape of input_ids:	torch.Size([50, 499])
Shape of logits:	torch.Size([50, 499, 28996])


In [13]:
prediction = lm.postprocess(outputs)
prediction

Output(loss=array([[1214.4108, 1196.6061, 1232.9644, 1201.0562, 1192.1794],
       [3028.9084, 3015.8738, 3101.7854, 3097.1274, 3036.8232],
       [2152.7812, 2229.096 , 2155.9805, 2143.7722, 2147.6968],
       [1798.1527, 1769.2864, 1823.5278, 1877.3756, 1777.6583],
       [2907.4917, 2912.1863, 2890.602 , 2912.61  , 2921.9517],
       [3289.5093, 3277.1824, 3264.971 , 3317.948 , 3317.6692],
       [1286.4143, 1292.952 , 1278.8005, 1297.5917, 1308.127 ],
       [5588.8374, 5621.1724, 5607.1807, 5558.1553, 5570.4688],
       [1957.8182, 1873.3116, 1910.833 , 1947.9403, 1907.3387],
       [1653.0619, 1659.1809, 1694.7096, 1650.039 , 1657.2186]],
      dtype=float32), prediction=array([4, 1, 3, 1, 2, 2, 2, 3, 1, 3], dtype=int64))

In [35]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    choice_to_numeric = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

    def __init__(self, questions, choices, labels):
        self.questions = questions
        # The choices are processed by extract_values, which we've adjusted for your data format
        self.choices = [self.extract_values(choice_list) for choice_list in choices]
        # Mapping labels using choice_to_numeric
        self.labels = torch.tensor([self.choice_to_numeric[label] for label in labels], dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        question = self.questions[index]
        choice = self.choices[index]
        label = self.labels[index]
        return {"question": question, "choices": choice, "label": label}

    @staticmethod
    def extract_values(choice_list):
        # Assuming each choice_list is a list of dictionaries
        return [item['value'] for item in choice_list]

data = CustomDataset(medqa['question'], medqa['options'], medqa['answer_idx'])

In [36]:
from truthfulqa import run_model
from tqdm import tqdm
results = [pipeline(data[i:i + 5]) for i in tqdm(range(0, len(df), 10))]



TypeError: unhashable type: 'dict'