In [1]:
import os
import re
import pandas as pd
from datasets import load_dataset, Dataset
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments



In [2]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Tokenizer and Model

In [3]:
model_name = "gpt2-medium"

In [4]:
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = model.config.eos_token_id

Downloading config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [5]:
df = pd.read_csv('/kaggle/input/med-question-answer/medQA.csv', index_col='Unnamed: 0')
df.shape

(16407, 3)

In [6]:
df.head()

Unnamed: 0,qtype,Question,Answer
0,susceptibility,Who is at risk for Lymphocytic Choriomeningiti...,LCMV infections can occur after exposure to fr...
1,symptoms,What are the symptoms of Lymphocytic Choriomen...,LCMV is most commonly recognized as causing ne...
2,susceptibility,Who is at risk for Lymphocytic Choriomeningiti...,Individuals of all ages who come into contact ...
3,exams and tests,How to diagnose Lymphocytic Choriomeningitis (...,"During the first phase of the disease, the mos..."
4,treatment,What are the treatments for Lymphocytic Chorio...,"Aseptic meningitis, encephalitis, or meningoen..."


In [7]:
def create_text(row):
    return f""" Question: {row['Question']} \n
    Answer: {row['Answer']}
    """

In [8]:
df['text'] = df.apply(create_text, axis=1)

In [9]:
df['text'][0]

' Question: Who is at risk for Lymphocytic Choriomeningitis (LCM)? ? \n\n    Answer: LCMV infections can occur after exposure to fresh urine, droppings, saliva, or nesting materials from infected rodents.  Transmission may also occur when these materials are directly introduced into broken skin, the nose, the eyes, or the mouth, or presumably, via the bite of an infected rodent. Person-to-person transmission has not been reported, with the exception of vertical transmission from infected mother to fetus, and rarely, through organ transplantation.\n    '

In [10]:
data = Dataset.from_pandas(df)

  if _pandas_api.is_sparse(col):


In [11]:
data

Dataset({
    features: ['qtype', 'Question', 'Answer', 'text', '__index_level_0__'],
    num_rows: 16407
})

In [12]:
tokenized_dataset = data.map(lambda x: tokenizer(x['text'], padding=True, truncation=True, max_length=512), batched=True)

  0%|          | 0/17 [00:00<?, ?ba/s]

In [13]:
os.mkdir('/kaggle/working/logs')

In [15]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [17]:
training_args = TrainingArguments(
        output_dir='/kaggle/working/',
        overwrite_output_dir=True,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=4,
        save_steps=100,
        max_steps = 5000,
        save_total_limit=2,
        logging_dir='/kaggle/working/logs',
        report_to=[],
    )

In [18]:
trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_dataset
    )

In [19]:
trainer.train()

Step,Training Loss
500,1.6194
1000,1.4481
1500,1.4032
2000,1.3478
2500,1.3109
3000,1.2819
3500,1.3075
4000,1.2675
4500,1.1691
5000,1.1679


TrainOutput(global_step=5000, training_loss=1.3323400390625, metrics={'train_runtime': 5157.8879, 'train_samples_per_second': 3.878, 'train_steps_per_second': 0.969, 'total_flos': 1.857308518986547e+16, 'train_loss': 1.3323400390625, 'epoch': 1.22})

In [20]:
prompt = """ Question: Who is at risk for Lymphocytic Choriomeningitis (LCM)? \n
    Answer: """

In [21]:
sample = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)

In [22]:
sample

{'input_ids': tensor([[18233,    25,  5338,   318,   379,  2526,   329,   406, 20896, 13733,
         13370,   609, 10145,   296,  3101, 11815,   357,  5639,    44, 19427,
           220,   628,   220,   220,   220, 23998,    25,   220]])}

In [23]:
sample = {key: value.to("cuda:0") for key, value in sample.items()}

In [24]:
outputs = model.generate(**sample, max_length=512)
text = tokenizer.batch_decode(outputs)[0]
print(text)



 Question: Who is at risk for Lymphocytic Choriomeningitis (LCM)? 

    Answer:  People who have been exposed to the human immunodeficiency virus (HIV) or other sexually transmitted diseases (STDs) are at risk for LCM.  People who have sex with an infected partner are at risk for LCM.  People who have been exposed to the human immunodeficiency virus (HIV) or other sexually transmitted diseases (STDs) are at risk for LCM.  People who have been exposed to the human immunodeficiency virus (HIV) or other sexually transmitted diseases (STDs) are at risk for LCM.  People who have sex with an infected partner are at risk for LCM.  People who have been exposed to the human immunodeficiency virus (HIV) or other sexually transmitted diseases (STDs) are at risk for LCM.  People who have been exposed to the human immunodeficiency virus (HIV) or other sexually transmitted diseases (STDs) are at risk for LCM.  People who have been exposed to the human immunodeficiency virus (HIV) or other sexually t

In [25]:
model.push_to_hub('ErnestBeckham/gpt-2-medqa')

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/ErnestBeckham/gpt-2-medqa/commit/6b33d0522a4e019c97ca09efb9bdeee9bc24e5e0', commit_message='Upload model', commit_description='', oid='6b33d0522a4e019c97ca09efb9bdeee9bc24e5e0', pr_url=None, pr_revision=None, pr_num=None)