In [1]:
import textwrap
from datasets import load_dataset
from transformers import pipeline
from truthfulqa import MultipleChoicePipeline

In [24]:
medqa = load_dataset("bigbio/med_qa", split="validation[:100]")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [25]:
medqa

Dataset({
    features: ['meta_info', 'question', 'answer_idx', 'answer', 'options'],
    num_rows: 100
})

In [26]:
q1 = medqa[0]
print(f"Question: {q1['question']}\n\nAnswer Choices:")
for i, c in enumerate(q1["options"]):
    print(f" {i}. {textwrap.fill(c['value'], subsequent_indent=' ' * 4)}")
print(f"\nCorrect Answer: {q1['answer_idx']}")

Question: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?

Answer Choices:
 0. Chloramphenicol
 1. Gentamicin
 2. Ciprofloxacin
 3. Ceftriaxone
 4. Trimethoprim

Correct Answer: D


In [27]:
# Use 
generator = pipeline("text-generation", model="dmis-lab/biobert-base-cased-v1.1")

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [28]:
lm = MultipleChoicePipeline(model="dmis-lab/biobert-base-cased-v1.1")

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [29]:
def extract_values(row):
    return [item['value'] for item in row]

choice_to_numeric = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

import pandas as pd
df = pd.DataFrame({'question' : medqa['question'], 'options' : medqa['options'], 'answer': medqa['answer_idx']})
df['choices'] = df['options'].apply(extract_values)
df['label'] = df['answer'].map(choice_to_numeric)
df

Unnamed: 0,question,options,answer,choices,label
0,A 21-year-old sexually active male complains o...,"[{'key': 'A', 'value': 'Chloramphenicol'}, {'k...",D,"[Chloramphenicol, Gentamicin, Ciprofloxacin, C...",3
1,A 5-year-old girl is brought to the emergency ...,"[{'key': 'A', 'value': 'Cyclic vomiting syndro...",A,"[Cyclic vomiting syndrome, Gastroenteritis, Hy...",0
2,A 40-year-old woman presents with difficulty f...,"[{'key': 'A', 'value': 'Diazepam'}, {'key': 'B...",E,"[Diazepam, St. John’s Wort, Paroxetine, Zolpid...",4
3,A 37-year-old female with a history of type II...,"[{'key': 'A', 'value': 'Obtain an abdominal CT...",C,"[Obtain an abdominal CT scan, Obtain blood cul...",2
4,A 19-year-old boy presents with confusion and ...,"[{'key': 'A', 'value': 'Hypoperfusion'}, {'key...",A,"[Hypoperfusion, Hyperglycemia, Metabolic acido...",0
...,...,...,...,...,...
95,A 28-year-old man presents with episodic abdom...,"[{'key': 'A', 'value': 'Goblet cell hyperplasi...",C,"[Goblet cell hyperplasia, Paneth cells metapla...",2
96,A 14-year-old boy is brought to the emergency ...,"[{'key': 'A', 'value': 'Parvovirus B19'}, {'ke...",A,"[Parvovirus B19, Medication-induced hemolysis,...",0
97,A 44-year-old woman comes to the office becaus...,"[{'key': 'A', 'value': 'Increased pulmonary ca...",A,"[Increased pulmonary capillary wedge pressure,...",0
98,A 43-year-old woman is brought to the emergenc...,"[{'key': 'A', 'value': 'CT pulmonary angiograp...",E,"[CT pulmonary angiography, D-dimer levels, Cat...",4


In [30]:
lm._get_input_texts(df[0:2])

['Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: Chloramphenicol',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: Gentamicin',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A

In [31]:
lm.set_system_prompt("In fact,")
lm._get_input_texts(df[0:1])

['Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: In fact, Chloramphenicol',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given?\nA: In fact, Gentamicin',
 'Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in

In [32]:
lm.preprocess(df[0:1])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'input_ids': tensor([[  101,   186,   131,   170,  1626,   118,  1214,   118,  1385, 13014,
           2327,  2581, 19073,  1116,  1104, 10880,   117,  2489,  1219,   190,
           9324,  2116,   117,  1105, 24970,  1105,  2489,  1107,  1103,  1268,
           5656,   119,   170,  2754,  1104,  1103,  4091,  8240,  2196,   170,
          10548,  1115,  1674,  1136,   175,  1200,  1880, 12477,  6066,  6787,
           1105,  1144,  1185,   185, 23415,  3202, 19515, 16234,  2007, 21749,
            119,  1103,  7454,  3791,  2848, 22400,  7606,  1111,  1103,  5351,
            119,  1103,  6978,  1104,  2168,  1104,  2168,  1104,  1103, 15683,
           1549,  5511,  2765,  2095, 11362,   117,  1134,  1104,  1103,  1378,
           1108,  1549,   136,   170,   131,  1107,  1864,   117, 22572, 24171,
           8223, 10436, 10658,  1233,   102,     0],
         [  101,   186,   131,   170,  1626,   118,  1214,   118,  1385, 13014,
           2327,  2581, 19073,  1116,  1104, 10880,   

In [33]:
input_ = lm.preprocess(df)
outputs = lm._forward(input_)

for k, v in outputs.items():
    print(f"Shape of {k}:\t{v.shape}")

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 5976024000 bytes.

In [15]:
prediction = lm.postprocess(outputs)
prediction

array([4, 1], dtype=int64)

In [16]:
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(df['label'][:2], prediction)

print("Accuracy:", accuracy)

Accuracy: 0.0
