<a href="https://colab.research.google.com/github/TonyWilson96/ChatDoctor/blob/main/OpenGPT_The_making_of_Dum_E.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# OpenGPT - The making of Dum-E

We will fine-tune a GPT-2 model to make a very small version of ChatGPT for healthcare. Because the base model (GPT-2) is very small, it will not be the best; in other words, we are making Dum-E for healthcare (it will try its best, but...).


<p align='center'>
<img height="400px" src='https://i.pinimg.com/originals/0e/b5/db/0eb5db395528f3a8689c306795e9745a.jpg' />
</p>

----

OpenGPT GitHub: https://github.com/CogStack/OpenGPT

If you want to learn more about AI in healthcare, please subscribe: https://aiforhealthcare.substack.com/

In [16]:
# **!!! Switch the runtime to GPU !!!**

In [None]:
# Install OpenGPT. This cell automatically restarts the notebook by killing the process (you will see notifications that the session crashed, but that is fine)
! pip install --upgrade opengpt
! pip install accelerate

import os
os.kill(os.getpid(), 9)



In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline
import pickle
import pandas as pd
import datasets


from opengpt.config import Config
from opengpt.model_utils import add_tokens_to_model_and_tokenizer
from opengpt.dataset_utils import create_labels, pack_examples
from opengpt.data_collator import DataCollatorWithPadding

In [2]:
# Download the configs and data from OpenGPT repo https://github.com/CogStack/OpenGPT
!wget https://raw.githubusercontent.com/CogStack/OpenGPT/main/data/nhs_uk_full/prepared_generated_data_for_nhs_uk_qa.csv
!wget https://raw.githubusercontent.com/CogStack/OpenGPT/main/data/nhs_uk_full/prepared_generated_data_for_nhs_uk_conversations.csv
!wget https://raw.githubusercontent.com/CogStack/OpenGPT/main/data/medical_tasks_gpt4/prepared_generated_data_for_medical_tasks.csv
!wget https://raw.githubusercontent.com/CogStack/OpenGPT/main/configs/example_train_config.yaml

--2023-10-24 15:27:08--  https://raw.githubusercontent.com/CogStack/OpenGPT/main/data/nhs_uk_full/prepared_generated_data_for_nhs_uk_qa.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10370226 (9.9M) [text/plain]
Saving to: ‘prepared_generated_data_for_nhs_uk_qa.csv’


2023-10-24 15:27:09 (101 MB/s) - ‘prepared_generated_data_for_nhs_uk_qa.csv’ saved [10370226/10370226]

--2023-10-24 15:27:09--  https://raw.githubusercontent.com/CogStack/OpenGPT/main/data/nhs_uk_full/prepared_generated_data_for_nhs_uk_conversations.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request s

In [3]:
# Load the config
config = Config(yaml_path='./example_train_config.yaml')

# This config can be used as a template, feel free to change whatever you like
config.train.to_dict()

{'model': 'olm/olm-gpt2-oct-2022',
 'datasets': ['../data/example_project_data/prepared_generated_data_for_example_project.csv',
  '../data/nhs_uk_full/prepared_generated_data_for_nhs_uk_qa.csv',
  '../data/nhs_uk_full/prepared_generated_data_for_nhs_uk_conversations.csv',
  '../data/medical_tasks_gpt4/prepared_generated_data_for_medical_tasks.csv'],
 'ignore_index': -100,
 'max_seq_len': 512,
 'packing_type': 'partial',
 'shuffle_dataset': True,
 'hf_training_arguments': {'output_dir': '../data/results/',
  'gradient_accumulation_steps': 16,
  'per_device_eval_batch_size': 1,
  'per_device_train_batch_size': 1,
  'load_best_model_at_end': False,
  'learning_rate': 2e-05,
  'weight_decay': 0.1,
  'adam_beta1': 0.9,
  'adam_beta2': 0.95,
  'adam_epsilon': 1e-07,
  'max_grad_norm': 1,
  'num_train_epochs': 1,
  'lr_scheduler_type': 'cosine',
  'warmup_ratio': 0.03,
  'logging_strategy': 'steps',
  'logging_steps': 100,
  'save_strategy': 'steps',
  'save_steps': 30000,
  'seed': 11,
  'o

In [4]:
# Load the GPT-2 model and tokenizer
model = AutoModelForCausalLM.from_pretrained(config.train.model)
tokenizer = AutoTokenizer.from_pretrained(config.train.model)
tokenizer.model_max_length = config.train.max_seq_len

Downloading (…)lve/main/config.json:   0%|          | 0.00/907 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/510M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/528 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/805k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/463k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.12M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

In [5]:
# Expand tokenizer/model with the necessary special tokens for conversational LLMs
add_tokens_to_model_and_tokenizer(config, tokenizer, model)



In [6]:
# Define the datasets to be used (a list of CSV paths, which we downloaded above). The datasets we use here are high-quality medical datasets;
# you can use the same datasets with e.g. LLaMA or StableLM to train a nice LLM for healthcare (e.g. NHS-LLM, see OpenGPT repo: https://github.com/CogStack/OpenGPT).
config.train.datasets = ['./prepared_generated_data_for_nhs_uk_qa.csv',
                         './prepared_generated_data_for_nhs_uk_conversations.csv',
                         './prepared_generated_data_for_medical_tasks.csv']

In [7]:
# Load datasets and shuffle if needed
train_dataset = datasets.Dataset.from_csv(config.train.datasets)
if config.train.shuffle_dataset:
    train_dataset = train_dataset.shuffle()
    print("Shuffling dataset!")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Shuffling dataset!


In [8]:
# Check the size of the dataset and columns
train_dataset

Dataset({
    features: ['text', 'raw_data_id'],
    num_rows: 31708
})

In [9]:
# the text part will be used to train our model, note the special tokens that are required for training
# models with OpenGPT (check the config file in the OpenGPT repo for explanations)
train_dataset[0]['text']

'<|user|> What is the maximum amount of salt that one should consume to reduce the risk of developing CHD? <|eos|> <|ai|> One should limit the amount of salt they eat to less than 6g (0.2oz) a day if they want to reduce the risk of developing coronary heart disease (CHD). This is because too much salt can increase blood pressure.\nReferences:\n- https://www.nhs.uk/conditions/coronary-heart-disease/prevention/ <|eos|> <|eod|>'
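The training text above follows a fixed pattern: each question is wrapped as `<|user|> ... <|eos|>`, each answer as `<|ai|> ... <|eos|>`, and the whole document ends with `<|eod|>`. A minimal sketch of how such a string could be assembled is shown below. The helper name `format_conversation` is hypothetical and not part of OpenGPT; the token layout is only inferred from the printed example.

```python
def format_conversation(turns):
    """Join (question, answer) pairs into one training string.

    Hypothetical helper: the layout mirrors the example printed above,
    it is not taken from the OpenGPT source code.
    """
    parts = []
    for question, answer in turns:
        parts.append(f"<|user|> {question} <|eos|> <|ai|> {answer} <|eos|>")
    # A single end-of-document token closes the whole conversation
    return " ".join(parts) + " <|eod|>"


example = format_conversation([("What is HTN?", "HTN stands for hypertension.")])
print(example)
```

Multi-turn conversations simply repeat the user/AI pattern before the final `<|eod|>`.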

In [10]:
# Remove everything but text
to_remove = list(train_dataset.column_names)
to_remove.remove('text')
train_dataset = train_dataset.remove_columns(to_remove)

In [11]:
# Ignore max_seq_len warning, it is handled by the packer or data_collator
train_dataset = train_dataset.map(
    lambda examples: tokenizer(examples['text'], add_special_tokens=False),
    batched=True,
    num_proc=1,
    remove_columns=["text"])
# Create labels for supervised training (meaning we do not train on questions, but only on answers)
train_dataset = train_dataset.map(
    lambda examples: create_labels(examples, config, tokenizer),
    batched=True,
    batch_size=1000,
    num_proc=1,
)
# We only do packing for the train set
train_dataset = train_dataset.map(
    lambda examples: pack_examples(examples, config.train.max_seq_len, packing_type=config.train.packing_type),
    batched=True,
    batch_size=1000,
    num_proc=1,
)

Map:   0%|          | 0/31708 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (886 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/31708 [00:00<?, ? examples/s]

Map:   0%|          | 0/31708 [00:00<?, ? examples/s]
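Packing concatenates several short tokenized examples into sequences of at most `max_seq_len` tokens, so less compute is wasted on padding. The sketch below shows a simplified greedy version of that idea in plain Python. It is an assumption-laden illustration: the function `pack_greedy` is hypothetical, and OpenGPT's `pack_examples` (in particular `packing_type='partial'`) may split or combine examples differently.

```python
def pack_greedy(examples, max_seq_len):
    """Greedily concatenate token-id lists into sequences of at most max_seq_len.

    Hypothetical simplification, not OpenGPT's actual packing logic.
    """
    packed, current = [], []
    for ids in examples:
        ids = ids[:max_seq_len]  # truncate any single example that is too long
        # Start a new packed sequence if this example would overflow the current one
        if current and len(current) + len(ids) > max_seq_len:
            packed.append(current)
            current = []
        current = current + ids
    if current:
        packed.append(current)
    return packed


toy_examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
print(pack_greedy(toy_examples, max_seq_len=5))
```

In a real pipeline the `labels` and `attention_mask` lists would have to be packed in exactly the same way as `input_ids`.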

In [12]:
# Check the new train_dataset (take note of how the labels look). The USER (Question) part of the input should have a label of -100,
# and the AI part (Answer) should have labels equal to input_ids
for i in range(100):
    print(train_dataset[0]['input_ids'][i], train_dataset[0]['labels'][i], train_dataset[0]['attention_mask'][i])

50265 -100 1
1838 -100 1
318 -100 1
225 -100 1
88 -100 1
76 -100 1
73 -100 1
5530 -100 1
2487 -100 1
292 -100 1
6515 -100 1
345 -100 1
554 -100 1
905 -100 1
13230 -100 1
286 -100 1
3881 -100 1
225 -100 1
88 -100 1
76 -100 1
73 -100 1
2972 -100 1
292 -100 1
4450 -100 1
8440 -100 1
40 -100 1
35 -100 1
225 -100 1
50267 -100 1
225 -100 1
50266 50266 1
2060 2060 1
905 905 1
4813 4813 1
225 225 1
88 88 1
76 76 1
73 73 1
2487 2487 1
292 292 1
6515 6515 1
544 544 1
3835 3835 1
286 286 1
1490 1490 1
720 720 1
1025 1025 1
75 75 1
396 396 1
20 20 1
18 18 1
22 22 1
9747 9747 1
13 13 1
225 225 1
69 69 1
927 927 1
691 691 1
544 544 1
813 813 1
286 286 1
3881 3881 1
225 225 1
88 88 1
76 76 1
73 73 1
2972 2972 1
292 292 1
4450 4450 1
36455 36455 1
2328 2328 1
4447 4447 1
396 396 1
5429 5429 1
40 40 1
859 859 1
738 738 1
318 318 1
952 952 1
1248 1248 1
940 940 1
6515 6515 1
424 424 1
2688 2688 1
3695 3695 1
3946 3946 1
18 18 1
203 203 1
1987 1987 1
30 30 1
203 203 1
17 17 1
38324 38324 1
15892 15892 1
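The printout shows the masking pattern: every token of the user turn gets the label -100 (the `ignore_index`, so it does not contribute to the loss), while tokens from `<|ai|>` onwards keep their own ID as the label. A minimal sketch of that rule follows. It assumes that 50265 is `<|user|>` and 50266 is `<|ai|>`, which is only read off the output above, and the function `mask_user_tokens` is hypothetical rather than OpenGPT's `create_labels`.

```python
IGNORE_INDEX = -100
USER_ID, AI_ID = 50265, 50266  # assumed special-token IDs, inferred from the printout


def mask_user_tokens(input_ids, user_id=USER_ID, ai_id=AI_ID, ignore_index=IGNORE_INDEX):
    """Return labels that ignore user turns and keep AI turns (hypothetical helper)."""
    labels, in_answer = [], False
    for tok in input_ids:
        if tok == user_id:
            in_answer = False  # a user turn starts: stop training on tokens
        elif tok == ai_id:
            in_answer = True   # an AI turn starts: train from this token on
        labels.append(tok if in_answer else ignore_index)
    return labels


toy_ids = [50265, 1838, 318, 50267, 225, 50266, 2060, 905, 50267]
print(mask_user_tokens(toy_ids))
```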


In [14]:
# Load the HF training arguments and create the data collator. You can try increasing the LR or the number of epochs, which could improve the performance of GPT-2.
# If you have access to bigger GPUs on Google Colab, you can also try bigger models from the HF Hub and get better results.
training_args = TrainingArguments(**config.train.hf_training_arguments.to_dict())
dc = DataCollatorWithPadding(tokenizer.pad_token_id, config.train.ignore_index, max_seq_len=config.train.max_seq_len)

# We are using HF trainer to train our models
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,
    data_collator=dc,
)
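The data collator receives the pad token ID, the ignore index, and the maximum sequence length, which suggests that it pads every batch to a common length and pads the labels with the ignore index so that padding never contributes to the loss. The plain-Python sketch below illustrates that general technique. It is an assumption about the behaviour, not the implementation of OpenGPT's `DataCollatorWithPadding`, and `pad_batch` is a hypothetical name.

```python
def pad_batch(batch, pad_token_id, ignore_index=-100, max_seq_len=512):
    """Right-pad a list of examples to the longest one (capped at max_seq_len).

    Hypothetical illustration of padding with an ignore index for the labels.
    """
    longest = min(max(len(ex["input_ids"]) for ex in batch), max_seq_len)
    out = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in batch:
        ids = ex["input_ids"][:longest]
        labels = ex["labels"][:longest]
        n_pad = longest - len(ids)
        out["input_ids"].append(ids + [pad_token_id] * n_pad)
        out["labels"].append(labels + [ignore_index] * n_pad)      # ignored by the loss
        out["attention_mask"].append([1] * len(ids) + [0] * n_pad)  # 0 marks padding
    return out


toy_batch = [
    {"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},
    {"input_ids": [8], "labels": [8]},
]
padded = pad_batch(toy_batch, pad_token_id=0)
print(padded)
```

A real collator would additionally convert these lists to tensors before handing them to the trainer.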

In [15]:
# Run training, should take around 30min. Ignore AdamW warnings if they appear.
trainer.train()



Step,Training Loss
100,1.612
200,1.4338
300,1.4032
400,1.3714
500,1.348
600,1.3426
700,1.3311


TrainOutput(global_step=736, training_loss=1.4030505781588347, metrics={'train_runtime': 1582.9199, 'train_samples_per_second': 7.441, 'train_steps_per_second': 0.465, 'total_flos': 2506650503040000.0, 'train_loss': 1.4030505781588347, 'epoch': 1.0})
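The numbers in `TrainOutput` can be cross-checked against the config with a little arithmetic. The sketch below assumes a single GPU, so the effective batch size is just `per_device_train_batch_size * gradient_accumulation_steps`; the values are copied from the config and the output above.

```python
# Values copied from the config and from TrainOutput above
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
global_step = 736
train_runtime_s = 1582.9199

# Assumption: one GPU, so no extra multiplication by the number of devices
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
sequences_seen = global_step * effective_batch
runtime_min = train_runtime_s / 60

print(effective_batch, sequences_seen, round(runtime_min, 1))
# prints: 16 11776 26.4
```

So one epoch covers roughly 11.8k packed sequences, far fewer than the 31708 original rows, which is consistent with packing having merged several short examples into each sequence.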

# Testing

In [17]:
gen = pipeline(model=model, tokenizer=tokenizer, task='text-generation', device=model.device)

In [18]:
t = "<|user|> What is diabetes? <|eos|> <|ai|>" # The format with special tokens is required, because the model was fine-tuned on it
print(gen(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> What is diabetes? <|eos|> <|ai|> Diabetes is a condition where the body's cells become unable to produce enough insulin, which can lead to high blood sugar levels.
References:
- https://www.nhs.uk/conditions/type-1-diabetes/living-with-type-1-diabetes/  <|eod|><|eod|><|eod|><|eod|><|ai|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|>
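The raw output repeats the prompt and then keeps emitting `<|eod|>` until `max_length` is reached. One simple way to obtain a clean answer is to strip the prompt and cut the text at the first stop token. The sketch below does this with plain string operations; `extract_reply` is a hypothetical helper and the sample text is a made-up placeholder, not model output.

```python
def extract_reply(generated_text, prompt, stop_tokens=("<|eod|>", "<|eos|>")):
    """Drop the prompt and everything from the first stop token onwards."""
    reply = generated_text
    if generated_text.startswith(prompt):
        reply = generated_text[len(prompt):]
    cut = len(reply)
    for tok in stop_tokens:
        idx = reply.find(tok)
        if idx != -1:
            cut = min(cut, idx)  # keep the earliest stop position
    return reply[:cut].strip()


prompt = "<|user|> What is diabetes? <|eos|> <|ai|>"
raw = prompt + " Placeholder answer.\nReferences:\n- https://www.nhs.uk/  <|eod|><|eod|><|ai|><|eod|>"
print(extract_reply(raw, prompt))
```

With the pipeline above, the same function could be applied to `gen(t, ...)[0]['generated_text']`.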


In [19]:
# Let's try some examples we used to test the NHS-LLM. The result is clearly not the best: the reference is correct, but the explanation makes no sense
t = "<|user|> What is vitamin d3 and should I take it? <|eos|> <|ai|>"
print(gen(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> What is vitamin d3 and should I take it? <|eos|> <|ai|> Vitamin D is a vitamin that helps your body absorb calcium, which is a vital nutrient for your bones, muscles, and skin.
References:
- https://www.nhs.uk/conditions/vitamins-and-minerals/vitamin-d/  <|eod|><|eod|><|ai|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|>


In [20]:
# A bit better: it even correctly expanded the abbreviation for hypertension (though the explanation that follows is not accurate)
t = "<|user|> What is HTN? <|eos|> <|ai|>"
print(gen(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> What is HTN? <|eos|> <|ai|> HTN stands for Hypertension, which is a condition where the blood vessels in the legs become narrowed, causing blood to flow to the heart.
References:
- https://www.nhs.uk/conditions/high-tension-nhs/  <|eod|><|ai|> Hi, I'm sorry to hear that. Can you tell me more about your symptoms?
References:
- https://www.nhs.uk/conditions/high-tension-nhs/  <|eod|> Hi, I'm


In [21]:
# Let us try a question that has nothing to do with healthcare. The answer suggests that the model is generalising a bit (but the reference makes no sense)
t = "<|user|> What is the capital of France? <|eos|> <|ai|>"
print(gen(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> What is the capital of France? <|eos|> <|ai|> The capital of France is Paris.
References:
- https://www.nhs.uk/conditions/social-care-and-support-guide/care-services-equipment-and-care-homes/living-with-social-care-and-support-guide/  <|eod|><|eod|><|eod|><|eod|><|eod|><|ai|><|ai|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|><|eod|>


In [22]:
# And a couple more healthcare questions, e.g. an MCQ (the correct answer is A)
t = "<|user|> Choose the correct statement about laparoscopic liver resection efficacy from the following options: a) there is no difference in the overall patient survival rate or disease-free survival rate between laparoscopic liver resection and open resection, b) laparoscopic liver resection has a higher patient survival rate than open resection, c) laparoscopic liver resection has a lower disease-free survival rate than open resection. <|eos|> <|ai|>"
print(gen(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> Choose the correct statement about laparoscopic liver resection efficacy from the following options: a) there is no difference in the overall patient survival rate or disease-free survival rate between laparoscopic liver resection and open resection, b) laparoscopic liver resection has a higher patient survival rate than open resection, c) laparoscopic liver resection has a lower disease-free survival rate than open resection. <|eos|> <|ai|> a) There is no difference in the overall patient survival rate or disease-free survival rate between lap


In [23]:
t = "<|user|> I had a high fever for the past 3 days, what should i do? <|eos|> <|ai|>"
print(gen(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> I had a high fever for the past 3 days, what should i do? <|eos|> <|ai|> I have a high fever for the past 3 days. I would recommend you to stay at home and avoid contact with other people. You can also try to avoid contact with other people, such as other animals, and avoid contact with other people who have a high fever.
References:
- https://www.nhs.uk/conditions/flu/  <|eod|><|ai|> If you have a high fever, you should avoid contact with other people, and avoid contact with


In [24]:
# Note that we can also continue the conversation (I'm doing it manually here, but this can easily be turned into a simple chatbot)
t = """<|user|> I had a high fever for the past 3 days, what should i do? <|eos|> <|ai|> I'm sorry to hear that. I understand your concern. Have you been feeling any other symptoms?
References:
- https://www.nhs.uk/conditions/coronavirus-covid-19/coronavirus-vaccination/how-to-get-your-first-dose-for-coronavirus/ <|eos|> <|user|> No, only fever. <|eos|> <|ai|>"""
print(gen(t, do_sample=True, max_length=256, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<|user|> I had a high fever for the past 3 days, what should i do? <|eos|> <|ai|> I'm sorry to hear that. I understand your concern. Have you been feeling any other symptoms?
References:
- https://www.nhs.uk/conditions/coronavirus-covid-19/coronavirus-vaccination/how-to-get-your-first-dose-for-coronavirus/ <|eos|> <|user|> No, only fever. <|eos|> <|ai|> I see. Have you been feeling any other symptoms?
References:
- https://www.nhs.uk/conditions/coronavirus-covid-19/coronavirus-vaccination/how-to-get-your-first-dose-for-coronavirus/  <|eod|><|ai|><|ai|><|ai|><|ai|><|eod|>,<|eod|>,<|eod|>,<|eod|>,<|eod|>,<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>:<|eod|>
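The manual continuation above can be automated by keeping a history of (question, answer) pairs and rebuilding the prompt on every turn. The sketch below is one possible way to do it. The functions `build_prompt` and `chat_turn` are hypothetical, and `fake_generate` is a stand-in so the example runs without a model.

```python
def build_prompt(history, user_message):
    """Rebuild the full prompt from past turns plus the new user message."""
    parts = []
    for question, answer in history:
        parts.append(f"<|user|> {question} <|eos|> <|ai|> {answer} <|eos|>")
    parts.append(f"<|user|> {user_message} <|eos|> <|ai|>")
    return " ".join(parts)


def chat_turn(history, user_message, generate):
    """Run one turn: generate, keep only the new reply, and store it in history."""
    prompt = build_prompt(history, user_message)
    full_text = generate(prompt)
    # Keep only the text after the prompt, up to the first stop token
    reply = full_text[len(prompt):].split("<|eod|>")[0].split("<|eos|>")[0].strip()
    history.append((user_message, reply))
    return reply


def fake_generate(prompt):
    # Stand-in for the real model call, used only to make the sketch runnable
    return prompt + " Placeholder reply. <|eod|><|eod|>"


history = []
print(chat_turn(history, "Hi", fake_generate))
print(build_prompt(history, "No, only fever."))
```

With the notebook's pipeline, `generate` could be `lambda p: gen(p, do_sample=True, max_length=256, temperature=0.2)[0]['generated_text']`. Note that `max_length` counts the prompt tokens too, so a long history leaves less room for the reply.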


In [25]:
# Finally, to show that our training works, we will also query the base model that has not been fine-tuned
model_not_trained = AutoModelForCausalLM.from_pretrained(config.train.model)
tokenizer_not_trained = AutoTokenizer.from_pretrained(config.train.model)
gen_not_trained = pipeline(model=model_not_trained, tokenizer=tokenizer_not_trained, task='text-generation', device=model.device)

In [26]:
t = "What is HTN?" # No special tokens, as this model is not fine-tuned.
print(gen_not_trained(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What is HTN?
Hat Trick Network is a network of over 1,000,000 registered users. It is a free service that allows users to create and share their own content. It is a great way to share your content with friends and family. It is also a great way to connect with other people who share similar interests.
What is HTN?
Hat Trick Network is a free service that allows users to create and share their own content. It is a great way to connect with other people who share similar interests. It is a great way to share your content with friends


In [27]:
t = "What is vitamin d3 and should I take it?" # No special tokens, as this model is not fine-tuned.
print(gen_not_trained(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What is vitamin d3 and should I take it?
Vitamin D3 is a vitamin that is found in the skin and is essential for the health of our body. It is also important for our immune system and our overall health. Vitamin D3 is a good source of vitamin D, which is essential for our body’s health.
What is vitamin D3 deficiency?
Vitamin D deficiency is a problem that can occur in many different ways. It can be caused by a number of different things, including a lack of vitamin D, a lack of vitamin D3,


In [28]:
t = "What is the capital of France?" # Let's try a general question
print(gen_not_trained(t, do_sample=True, max_length=128, temperature=0.2)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What is the capital of France?
The capital of France is Paris. It is the capital of France. It is the capital of France.
What is the capital of the United States?
The capital of the United States is Washington, D.C. It is the capital of the United States.
What is the capital of the United Kingdom?
The capital of the United Kingdom is London. It is the capital of the United


In [29]:
# The model without fine-tuning handles general questions to some extent, but things get very messy when we ask a medical question.
# This suggests that our training works, even with a small model.

## Conclusion

We managed to create a Dum-E version of ChatGPT in 30 min on Google Colab. Honestly, I'm surprised how well it works, given how small this model is.

----

**The code for OpenGPT is open source, please consider contributing. If you find any bugs or problems, please open an issue or submit a PR.**

OpenGPT GitHub: https://github.com/CogStack/OpenGPT

If you want to learn more about AI in healthcare, please subscribe: https://aiforhealthcare.substack.com/