This notebook fine-tunes an autoregressive language model (such as OPT or GPT-2) to recognize mentions of medical symptoms in patient records.

In [None]:
# Confirm that a GPU is attached to this runtime before going further:
# the training and inference cells below move tensors to CUDA.
!nvidia-smi

In [None]:
! pip install datasets transformers

import random
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset
import torch

# Format data

In [7]:
# Load the labelled patient records (parsed with pandas' default CSV settings).
df = pd.read_csv("/content/labeled_patient_records.txt")

# Every distinct symptom label present in the data, lower-cased.
# value_counts() skips missing values and orders labels by frequency.
symptom_label_list = [
    label.lower() for label in df['symptom_label'].value_counts().index
]

# Build one training prompt per record.
#
# Negative examples are paired with a randomly chosen symptom. A dedicated,
# seeded generator keeps that choice reproducible across kernel restarts and
# cell re-runs without touching the global `random` state.
NEGATIVE_SAMPLING_SEED = 42
negative_sampler = random.Random(NEGATIVE_SAMPLING_SEED)

text_list = []
for i, row in df.iterrows():
  symptom = ''
  person_mention = ''
  if row['no_symptom']:
    # if the record does not mention any symptom, set the prompt to a random 
    # symptom and train the model to respond that the symptom was not found
    symptom = negative_sampler.choice(symptom_label_list)
    person_mention = 'not mentioned'
  else:
    # NOTE(review): `person_mention` stays empty on this branch, so positive
    # records are trained with a blank answer after
    # "Mention in relation to patient:". Confirm whether a column of the
    # data frame is supposed to supply this value.
    symptom = row['symptom_label']

    symptom = symptom.lower()

  # Keep this layout in sync with the prompt used at inference time.
  text = """Patient medical file:
{}

Symptom: {}
Mention in relation to patient: {} 
""".format(row["transcription"], symptom, person_mention)

  text_list.append(text)


# Training

In [None]:
# Checkpoint name handed to `from_pretrained` for both the tokenizer and the model.
BASE_MODEL = "facebook/opt-350m"
# Passed to the tokenizer as `model_max_length`; with padding="max_length" and
# truncation enabled in `encode`, every training example is padded or cut to
# this many tokens.
MODEL_MAX_LEN = 2048
OUTPUT_PATH = "/models/final" # where final model will be saved

In [None]:
from datasets import DatasetDict
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    Trainer,
    TrainingArguments,
)


# Tokenizer matching the base checkpoint, limited to MODEL_MAX_LEN tokens.
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    model_max_length=MODEL_MAX_LEN,
)
# Pretrained causal language model that will be fine-tuned.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

# One pass over the data, one example per device per step. Intermediate
# Trainer output goes to a scratch directory that may be overwritten.
training_args = TrainingArguments(
    learning_rate=1e-5,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    output_dir="/tmp/model",
    overwrite_output_dir=True,
)

# Label value that the causal-LM loss skips.
_IGNORE_INDEX = -100


def _mask_padding(token_ids, attention_mask):
    """Return a copy of `token_ids` with padded positions set to `_IGNORE_INDEX`."""
    return [
        token if keep else _IGNORE_INDEX
        for token, keep in zip(token_ids, attention_mask)
    ]


def encode(batch):
    """Tokenize prompt text and attach causal-LM training labels.

    Args:
        batch: mapping with a "text" entry holding either a single string
            (unbatched `Dataset.map`) or a list of strings (batched `map`).

    Returns:
        The tokenizer output with an extra "labels" entry. Labels equal the
        input ids, except that padded positions (attention mask 0) are set to
        -100 so the padding added by `padding="max_length"` does not
        contribute to the loss.
    """
    encodings = tokenizer(batch["text"], padding="max_length", truncation=True)
    input_ids = encodings["input_ids"]
    attention_mask = (
        encodings["attention_mask"] if "attention_mask" in encodings else None
    )

    if attention_mask is None:
        # No mask available: fall back to plain copies of the input ids.
        labels = input_ids.copy()
    elif input_ids and isinstance(input_ids[0], (list, tuple)):
        # Batched call: one id list per example.
        labels = [
            _mask_padding(ids, mask)
            for ids, mask in zip(input_ids, attention_mask)
        ]
    else:
        # Unbatched call: a single flat id list.
        labels = _mask_padding(input_ids, attention_mask)

    encodings["labels"] = labels
    return encodings

In [None]:
# Wrap the formatted prompts in a Dataset and tokenize them; the raw "text"
# column is dropped so only the encoded fields remain.
dataset_load = Dataset.from_dict({"text": text_list})
tokenized_datasets = dataset_load.map(encode, remove_columns=["text"])

# Move the model to the GPU before training.
model.cuda()

# Fine-tune on the tokenized prompts, then write the final model to OUTPUT_PATH.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(OUTPUT_PATH)

# Inference

In [None]:
from transformers import GenerationConfig

In [None]:
# Example prompt: a patient file followed by the symptom to look for. It stops
# after the "Symptom:" line, so the text that follows is left for the model
# to generate.
text = """Patient medical file:
PREOPERATIVE DIAGNOSIS: , Missed abortion.,POSTOPERATIVE DIAGNOSIS:  ,Missed abortion.,PROCEDURE PERFORMED: , Suction, dilation, and curettage.,ANESTHESIA: , Spinal.,ESTIMATED BLOOD LOSS:,  50 mL.,COMPLICATIONS: , None.,FINDINGS: , Products of conception consistent with a 6-week intrauterine pregnancy.,INDICATIONS: , The patient is a 28-year-old gravida 4, para 3 female at 13 weeks by her last menstrual period and 6 weeks by an ultrasound today in the emergency room who presents with heavy bleeding starting today.  A workup done in the emergency room revealed a beta-quant level of 1931 and an ultrasound showing an intrauterine pregnancy with a crown-rump length consistent with a 6-week and 2-day pregnancy.  No heart tones were visible.  On examination in the emergency room, a moderate amount of bleeding was noted.,Additionally, the cervix was noted to be 1 cm dilated.  These findings were discussed with the patient and options including surgical management via dilation and curettage versus management with misoprostol versus expected management were discussed with the patient.  After discussion of these options, the patient opted for a suction, dilation, and curettage.  The patient was described to the patient in detail including risks of infection, bleeding, injury to surrounding organs including risk of perforation.  Informed consent was obtained prior to proceeding with the procedure.,PROCEDURE NOTE:  ,The patient was taken to the operating room where spinal anesthesia was administered without difficulty.  The patient was prepped and draped in usual sterile fashion in lithotomy position.  A weighted speculum was placed.  The anterior lip of the cervix was grasped with a single tooth tenaculum.  At this time, a 7-mm suction curettage was advanced into the uterine cavity without difficulty and was used to suction contents of the uterus.  
Following removal of the products of conception, a sharp curette was advanced into the uterine cavity and was used to scrape the four walls of the uterus until a gritty texture was noted.  At this time, the suction curette was advanced one additional time to suction any remaining products.  All instruments were removed.  Hemostasis was visualized.  The patient was stable at the completion of the procedure.  Sponge, lap, and instrument counts were correct.

Symptom: joint pain
"""

In [None]:
# The model may still be in training mode after `trainer.train()`; switch to
# eval mode so dropout is disabled and generation is deterministic.
model.eval()

# Tokenize the prompt and place it on whichever device the model lives on.
input_ids = tokenizer.encode(text, return_tensors='pt').to(model.device)

# Greedy decoding is requested explicitly with do_sample=False (a temperature
# of 0 is not a valid sampling temperature). `max_new_tokens` bounds only the
# continuation, so long prompts still leave room for the answer; `max_length`
# would count the prompt tokens as well.
sample_output = model.generate(
    input_ids,
    generation_config=GenerationConfig(do_sample=False, max_new_tokens=64),
)



In [None]:
# Convert the first returned sequence of token ids back to text, leaving out
# special tokens; the string is shown as the cell's output.
generated_ids = sample_output[0]
tokenizer.decode(generated_ids, skip_special_tokens=True)