In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Pick the compute device in order of preference: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Start from the stock GPT-2 tokenizer and extend its vocabulary.
# "<pad>", "<|user|>" and "<|eod|>" are registered as *special* tokens, so
# decode(..., skip_special_tokens=True) drops them; "<|ai|>" and "<|eos|>" are
# added as ordinary tokens and therefore survive decoding.
# NOTE(review): the tokens and their insertion order presumably mirror the
# fine-tuning setup (token ids depend on it) -- confirm against the training code.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>", 
                                "bos_token": "<|user|>",
                                "eos_token": "<|eod|>"})
tokenizer.add_tokens(["<|ai|>"])
tokenizer.add_tokens(["<|eos|>"])

# Base GPT-2 weights; the embedding matrix is resized to the extended
# vocabulary before the fine-tuned state dict is loaded into it.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))

# map_location="cpu" makes the checkpoint loadable regardless of the device it
# was saved from (e.g. a CUDA-saved file on a CPU/MPS-only machine); the model
# is moved to the selected device right after.
# SECURITY: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('fine-tuned-med-model.pt', map_location="cpu"))
model = model.to(device)
model.eval()

_END_OF_TURN = "<|eos|>"


def _trim_after_reply(text, marker=_END_OF_TURN):
    """Cut `text` right after the second `marker`; return it unchanged if there is no second one."""
    first = text.find(marker)
    second = text.find(marker, first + 1)
    if second == -1:
        # No closing marker (generation hit the length limit before finishing).
        # Slicing with -1 here would keep only the first few characters.
        return text
    return text[:second + len(marker)]


def runModel(inp):
    """Generate the model's answer to a user question.

    The question is wrapped in the chat template
    ``"<|user|> {inp} <|eos|> <|ai|>: "``, tokenized, and passed to
    ``model.generate`` using beam search. The decoded text (prompt plus
    completion) is cut just after the second ``<|eos|>``, i.e. the marker that
    closes the model's reply; the first one belongs to the prompt.

    Args:
        inp: The user's question as plain text.

    Returns:
        The decoded text up to and including the ``<|eos|>`` that ends the
        reply. If the model never produced that closing marker, the full
        decoded text is returned instead of a truncated fragment.
    """
    prompt = "<|user|> "+inp+" <|eos|> <|ai|>: "
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # max_length bounds prompt + generated tokens together (see the
    # transformers generate() docs), so long questions leave less room
    # for the answer.
    generated = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, attention_mask=attention_mask)

    # skip_special_tokens removes tokens registered as special on the tokenizer;
    # "<|eos|>" was added as a regular token, so it is still present in the text
    # for the trimming step below.
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    return _trim_after_reply(decoded)

In [3]:
# Spot-check the model on a "what is X" question.
answer = runModel('What is bronchiolitis?')
print(answer)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What is bronchiolitis? <|eos|> <|ai|> :  bronchiectasis is a condition where the airways become inflamed due to a build-up of mucus.
References:
- https://www.nhs.uk/conditions/cornea-transplant/what-happens-at-an-cromium-resection/ <|eos|>


In [4]:
# Spot-check the model on a question about complications of a condition.
answer = runModel('What are the complications of a heart attack?')
print(answer)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What are the complications of a heart attack? <|eos|> <|ai|> : Heart attacks can lead to serious complications such as heart failure, stroke, and death.
References:
- https://www.nhs.uk/conditions/Heart-attack/Pages/Introduction.aspx <|eos|>


In [6]:
# Spot-check the model on another "what is X" question.
answer = runModel('What is dehydration?')
print(answer)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What is dehydration? <|eos|> <|ai|> :  dehydration is a condition where your body loses too much fluid.
References:
- https://www.nhs.uk/conditions/alcohol-related-liver-disease-arld/complications/ <|eos|>


In [5]:
# Spot-check the model on a "what should I do" question.
answer = runModel('What should I do if someone has inhaled poisonous fumes?')
print(answer)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What should I do if someone has inhaled poisonous fumes? <|eos|> <|ai|> : Call 999 or go to your nearest accident and emergency (A&E) department immediately if you or someone you look after is seriously ill or has breathing difficulties.
References:
- https://www.nhs.uk/conditions/poisoning/symptoms-of-illness/ <|eos|>


In [7]:
# Spot-check the model on a "how long" question.
answer = runModel('How long does it take for jet lag to improve?')
print(answer)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


How long does it take for jet lag to improve? <|eos|> <|ai|> : Jet lag usually improves on its own within a few weeks, but it can take longer.
References:
- https://www.nhs.uk/conditions/jet-lag/ <|eos|>
