In [10]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import time

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          pipeline)


def get_pipeline(path, tokenizer, torch_dtype=torch.float16, device_map='cuda',
                 extra_eos_token_ids=(128001,)):
    """Load a causal language model and wrap it in a text-generation pipeline.

    Args:
        path: Model identifier or local directory passed to
            ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: Tokenizer used by the pipeline; its ``pad_token_id`` and
            ``eos_token_id`` are forwarded as generation defaults.
        torch_dtype: Dtype the weights are loaded in (default: ``torch.float16``).
        device_map: Device placement passed to ``from_pretrained``
            (default: ``'cuda'``).
        extra_eos_token_ids: Additional token ids that should end generation
            besides ``tokenizer.eos_token_id``. Defaults to ``(128001,)``,
            the id this notebook has always passed explicitly.

    Returns:
        The object returned by ``transformers.pipeline("text-generation", ...)``.
    """
    # NOTE(review): trust_remote_code=True lets the model repository supply its
    # own Python code at load time; only use it with repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=True)

    print('Model loaded')

    # Build the stop-token list without duplicates or None entries. The
    # original passed [tokenizer.eos_token_id, 128001] verbatim, which repeats
    # the same id whenever the tokenizer's EOS already is 128001.
    stop_token_ids = []
    for token_id in (tokenizer.eos_token_id, *(extra_eos_token_ids or ())):
        if token_id is not None and token_id not in stop_token_ids:
            stop_token_ids.append(token_id)

    return pipeline("text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=stop_token_ids)


In [11]:
# System prompt that opens every conversation; the history-building cell
# skips it when this is None.
SYS_PROMPT = (
    "You are a medical expert AI. "
    "Provide accurate and professional information to patients' questions or requests. "
    "AI is not a real doctor, so it cannot directly treat. "
    "In such cases, recommend a related specialized department so that the patient can receive appropriate treatment. "
    "Use a natural and friendly tone in your responses, and answer simply and clearly."
)
model_path = "aaditya/Llama3-OpenBioLLM-8B"
max_new_tokens = 512  # intended cap on newly generated tokens per reply

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Reuse the end-of-sequence token as the padding token (string and id).
tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

# Show the special tokens that delimit a sequence for this tokenizer.
for attr in ('bos_token', 'bos_token_id', 'eos_token', 'eos_token_id'):
    print(attr + ':', getattr(tokenizer, attr))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


bos_token: <|begin_of_text|>
bos_token_id: 128000
eos_token: <|end_of_text|>
eos_token_id: 128001


In [12]:
# Load the model and bundle it with the tokenizer into a text-generation pipeline.
pipe = get_pipeline(path=model_path, tokenizer=tokenizer)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded


In [20]:
# Start a fresh conversation history, seeded with the system prompt when one is configured.
messages = [{"role": 'system', 'content': SYS_PROMPT}] if SYS_PROMPT is not None else []

In [21]:
# First user turn: the patient's complaint followed by the instruction that
# steers how the assistant should conduct the interview.
input_ = (
    "\n[Patient's Message]\n"
    "My eyes are blurry.\n"
    "\n[Request]\n"
    "Ask brief questions one at a time to infer the suspected disease based on the patient's symptoms. "
    "After asking less than five questions, infer the possible disease based on the information collected.\n"
)

# NOTE: re-running this cell appends the same user turn again.
messages.append(dict(role='user', content=input_))

In [22]:
# Echo the conversation so far: every turn before the newest one uses ANSI
# colour code 92, and the newest user input uses colour code 93.
earlier_turns = messages[:-1]
for turn in earlier_turns:
    print('\033[92m' + turn['role'] + ': ', turn['content'])
print('\033[93mUser: ' + input_)

[92msystem:  You are a medical expert AI. Provide accurate and professional information to patients' questions or requests. AI is not a real doctor, so it cannot directly treat. In such cases, recommend a related specialized department so that the patient can receive appropriate treatment. Use a natural and friendly tone in your responses, and answer simply and clearly.
[93mUser: 
[Patient's Message]
My eyes are blurry.

[Request]
Ask brief questions one at a time to infer the suspected disease based on the patient's symptoms. After asking less than five questions, infer the possible disease based on the information collected.



In [23]:
# Render the whole conversation through the tokenizer's chat template as a
# plain string, with the generation prompt appended so the model continues as
# the assistant.
text = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)

# Generate with sampling disabled and keep only the newly generated text.
# The token budget comes from the `max_new_tokens` config value instead of a
# second hard-coded 512, so changing the config actually takes effect here.
outputs = pipe(
    text,
    return_full_text=False,
    clean_up_tokenization_spaces=True,
    max_new_tokens=max_new_tokens,
    do_sample=False,
)
result = outputs[0]['generated_text']

# Record the reply so the next turn sees it as context, then display it.
# NOTE: re-running this cell appends another assistant turn to `messages`.
messages.append({'role': 'assistant', 'content': result})
print('\033[95m' + 'Assistant: ' + result)

[95mAssistant: I'm sorry to hear that your eyes are blurry. To better understand your situation, I have a few questions for you. Have you experienced any pain or discomfort in your eyes? Do you have a history of eye conditions or have you been exposed to any irritants recently?


In [24]:
# Second user turn: the patient's answers plus the same interview instruction.
input_ = """
[Patient's Message]
There is no pain, just light glares. I have no history of eye disease or recent exposure to irritants.

[Request]
Ask brief questions one at a time to infer the suspected disease based on the patient's symptoms. After asking less than five questions, infer the possible disease based on the information collected.
"""

# NOTE: re-running this cell appends the user and assistant turns again.
messages.append({'role': 'user', 'content': input_})

# Re-render the full history (including the previous assistant reply) through
# the chat template, with the generation prompt appended.
text = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)

# Same generation settings as the first turn. The token budget is read from
# the `max_new_tokens` config value rather than a duplicated literal 512.
outputs = pipe(
    text,
    return_full_text=False,
    clean_up_tokenization_spaces=True,
    max_new_tokens=max_new_tokens,
    do_sample=False,
)
result = outputs[0]['generated_text']

messages.append({'role': 'assistant', 'content': result})
print('\033[95m' + 'Assistant: ' + result)

[95mAssistant: Based on your symptoms of blurry vision without pain or discomfort, and no history of eye conditions or recent exposure to irritants, it is possible that you are experiencing a refractive error such as nearsightedness, farsightedness, or astigmatism. These conditions can cause blurry vision when the shape of the eye does not properly focus light onto the retina. I would recommend scheduling an eye examination with an optometrist or ophthalmologist to get a proper diagnosis and discuss treatment options.
