In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


ds = load_dataset("medalpaca/medical_meadow_medical_flashcards")

In [3]:
print(ds)

DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'instruction'],
        num_rows: 33955
    })
})


In [4]:
print(ds["train"][0])

{'input': 'What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?', 'output': 'Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.', 'instruction': 'Answer this question truthfully'}


In [5]:
def format_row(row):
    return {
        "text": "Instruction: " + row['instruction'] + 
                " Question: " + row['input'] + 
                " Response: " + row['output']
    }

formatted_ds = ds["train"].map(format_row)


In [6]:
# test new format
print(formatted_ds[0]["text"])

Instruction: Answer this question truthfully Question: What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels? Response: Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.


In [7]:
# load the model and tokenizer
model = "Featherless-Chat-Models/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer.pad_token = tokenizer.eos_token



In [8]:
def tokenize_row(row):
    return tokenizer(
        row["text"],
        truncation=True,
        padding="max_length",
        max_length=512
    )

tokenized_ds = formatted_ds.map(tokenize_row, batched=True)


In [9]:
#test tokenizer
print(tokenized_ds[0])

{'input': 'What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?', 'output': 'Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.', 'instruction': 'Answer this question truthfully', 'text': 'Instruction: Answer this question truthfully Question: What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels? Response: Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.', 'input_ids': [1, 3133, 3112, 28747, 26307, 456, 2996, 5307, 3071, 22478, 28747, 1824, 349, 272, 3758, 1444, 1215, 2859, 351, 28721, 28750, 28806, 6157, 28725, 367, 3151, 6157, 28725, 304, 11013, 28750, 28806, 6157, 28804, 12107, 28747, 13649, 2859, 351, 28721, 28750, 28806, 6157, 10384, 298, 2859, 367, 3151, 6157, 690, 297, 1527, 2903, 297, 2859, 11013, 28750, 28806, 6157, 28723, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,

In [10]:
#remoce any strings
tokenized_ds = tokenized_ds.remove_columns(["text", "input", "output", "instruction"])


In [11]:
#test tokenizer
print(tokenized_ds[0])

{'input_ids': [1, 3133, 3112, 28747, 26307, 456, 2996, 5307, 3071, 22478, 28747, 1824, 349, 272, 3758, 1444, 1215, 2859, 351, 28721, 28750, 28806, 6157, 28725, 367, 3151, 6157, 28725, 304, 11013, 28750, 28806, 6157, 28804, 12107, 28747, 13649, 2859, 351, 28721, 28750, 28806, 6157, 10384, 298, 2859, 367, 3151, 6157, 690, 297, 1527, 2903, 297, 2859, 11013, 28750, 28806, 6157, 28723, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2

In [12]:
def add_labels(row):
    row["labels"] = row["input_ids"].copy()  # labels = input_ids for causal LM
    return row

tokenized_ds = tokenized_ds.map(add_labels)


In [13]:
#test labels
print(tokenized_ds[0])


{'input_ids': [1, 3133, 3112, 28747, 26307, 456, 2996, 5307, 3071, 22478, 28747, 1824, 349, 272, 3758, 1444, 1215, 2859, 351, 28721, 28750, 28806, 6157, 28725, 367, 3151, 6157, 28725, 304, 11013, 28750, 28806, 6157, 28804, 12107, 28747, 13649, 2859, 351, 28721, 28750, 28806, 6157, 10384, 298, 2859, 367, 3151, 6157, 690, 297, 1527, 2903, 297, 2859, 11013, 28750, 28806, 6157, 28723, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2