## Model setup
Quantization (bitsandbytes) and Accelerate make it possible to run inference even on a gaming laptop (NVIDIA 3080 laptop GPU with 8 GB of VRAM).
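Some back-of-the-envelope arithmetic shows why 4-bit quantization is what makes an 8 GB card workable. This is a minimal sketch. It assumes roughly 7.24 billion parameters for Mistral-7B and counts only the weights. Activations, the KV cache, quantization constants and any layers kept in higher precision all add to the real footprint.

```python
# Rough estimate of the memory taken by the model *weights* alone.
# Assumption: Mistral-7B has about 7.24 billion parameters; everything else
# (activations, KV cache, quantization constants) is ignored, so real usage is higher.
N_PARAMS = 7.24e9


def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate size of the weights in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


fp16_gb = weight_memory_gb(N_PARAMS, 16)  # unquantized float16 weights
nf4_gb = weight_memory_gb(N_PARAMS, 4)    # 4-bit quantized weights

print(f"float16 weights: ~{fp16_gb:.1f} GB")  # too large for 8 GB of VRAM
print(f"4-bit weights:   ~{nf4_gb:.1f} GB")   # leaves headroom for activations
```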

In [3]:
# Helper packages
import pandas as pd
from textwrap import fill
from IPython.display import Markdown, display  # render Markdown output in the notebook

# Model loading
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline

# Langchain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline




In [4]:
# Hugging Face Hub ID of the Mistral model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# Quantization reduces the memory and compute requirements of a model by storing
# its weights with fewer bits: here 4-bit NF4 with double quantization, while
# the computations run in float16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Initialize the tokenizer for Mistral-7B, needed to turn text into model inputs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# The Mistral tokenizer defines no padding token, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

# Load the pre-trained Mistral-7B model, 4-bit quantized and placed on devices by Accelerate
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=quantization_config
)

# Print the device map to check that the whole model fits on the GPU
print(model.hf_device_map)

Loading checkpoint shards: 100%|██████████| 3/3 [00:20<00:00,  6.79s/it]


OrderedDict([('', 0)])
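In `hf_device_map`, each key is a module name and each value is the device Accelerate assigned to it. The empty key `''` stands for the whole model, so `OrderedDict([('', 0)])` means everything landed on GPU 0. When the model does not fit, some entries become `"cpu"` or `"disk"`. Below is a small sketch for spotting that case. The `offloaded_modules` helper and the second map are hypothetical, for illustration only.

```python
from collections import OrderedDict


def offloaded_modules(device_map):
    """Return the modules that were placed on the CPU or on disk instead of a GPU."""
    return [name for name, device in device_map.items() if device in ("cpu", "disk")]


# The map printed above: '' (the whole model) is on GPU 0
fits_on_gpu = OrderedDict([("", 0)])
# Hypothetical map for a model that did NOT fit entirely in VRAM
partially_offloaded = {"model.embed_tokens": 0, "model.layers.31": "cpu", "lm_head": "cpu"}

print(offloaded_modules(fits_on_gpu))          # []
print(offloaded_modules(partially_offloaded))  # ['model.layers.31', 'lm_head']
```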


In [5]:
# Generation settings, starting from the defaults shipped with the model
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.max_new_tokens = 1024  # maximum number of tokens the model may generate per call
generation_config.temperature = 0.2  # low temperature sharpens the distribution: more deterministic output
generation_config.top_p = 0.1  # nucleus sampling: keep only the most likely tokens covering 10% of the probability mass
generation_config.do_sample = True  # sample instead of greedy decoding (temperature and top_p only apply when sampling)
generation_config.repetition_penalty = 1.15  # values > 1 penalize tokens that already appear in the text

# A pipeline bundles the tokenizer, the model and the generation and
# post-processing settings behind a single call. With return_full_text=True
# it returns the prompt followed by the generated text
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    generation_config=generation_config,
)

# Wrap the transformers pipeline so LangChain can use it as an LLM
llm = HuggingFacePipeline(pipeline=pipe)
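To see what these settings do, here is a simplified, single-step sketch in plain Python of temperature scaling, top-p (nucleus) filtering and the repetition-penalty rule. It works under simplifying assumptions (one toy vector of logits, no batching). The real implementations in transformers (`TemperatureLogitsWarper`, `TopPLogitsWarper`, `RepetitionPenaltyLogitsProcessor`) operate on tensors and handle more edge cases.

```python
import math


def sampling_candidates(logits, temperature, top_p):
    """Return the (token_index, probability) pairs left after temperature + top-p filtering."""
    # Temperature: divide the logits before the softmax; T < 1 sharpens the distribution
    scaled = [value / temperature for value in logits]
    largest = max(scaled)
    exps = [math.exp(value - largest) for value in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-p: keep the most likely tokens until their cumulative probability reaches top_p
    # (the sampler then renormalizes these probabilities and draws one token)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    kept, cumulative = [], 0.0
    for index, prob in ranked:
        kept.append((index, prob))
        cumulative += prob
        if cumulative >= top_p:
            break
    return kept


def apply_repetition_penalty(logits, seen_tokens, penalty):
    """Make already-generated tokens less likely: positive logits are divided, negative ones multiplied."""
    penalized = list(logits)
    for token in set(seen_tokens):
        score = penalized[token]
        penalized[token] = score / penalty if score > 0 else score * penalty
    return penalized


toy_logits = [2.0, 1.5, 0.5]  # toy scores for a 3-token vocabulary
print(sampling_candidates(toy_logits, temperature=1.0, top_p=0.9))  # all three tokens are kept
print(sampling_candidates(toy_logits, temperature=0.2, top_p=0.1))  # only token 0 is kept
print(apply_repetition_penalty(toy_logits, seen_tokens=[0], penalty=1.15))  # token 0: 2.0 -> about 1.74
```

With `temperature=0.2` and `top_p=0.1`, as configured above, only the single most likely token survives in this toy example, so sampling behaves almost like greedy decoding.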

### Test the model
To use the instruction fine-tuning, prompts should be surrounded by `[INST]` and `[/INST]` tags (instruction format: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format).
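The same format extends to multi-turn conversations. Below is a minimal sketch of assembling it by hand. `build_mistral_prompt` is a hypothetical helper that follows the layout documented on the model card, and `tokenizer.apply_chat_template` remains the authoritative way to produce it. Whether the `<s>` BOS token has to be written out depends on how the text is tokenized, so it is optional here.

```python
def build_mistral_prompt(turns, include_bos=False):
    """Build a Mistral-Instruct prompt from (user_message, assistant_reply) pairs.

    The last pair may have assistant_reply=None: that is the turn the model should answer.
    Hypothetical helper; the tokenizer's chat template is the authoritative format.
    """
    parts = ["<s>"] if include_bos else []
    for user_message, assistant_reply in turns:
        parts.append(f"[INST] {user_message} [/INST]")
        if assistant_reply is not None:
            parts.append(f"{assistant_reply}</s>")  # each finished answer ends with EOS
    return "".join(parts)


print(build_mistral_prompt([("Explain ChatGPT in two lines.", None)]))
# [INST] Explain ChatGPT in two lines. [/INST]
print(build_mistral_prompt([("Hi", "Hello!"), ("Tell me a joke.", None)], include_bos=True))
# <s>[INST] Hi [/INST]Hello!</s>[INST] Tell me a joke. [/INST]
```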

In [6]:
def generate(llm, text, template=None, format_instructions=None):
    """Fill the prompt template with `text` and return the model's answer."""
    if template is None:
        template = "[INST]{text}[/INST]"  # Mistral instruction format

    prompt = PromptTemplate.from_template(template)

    # .invoke() replaces the deprecated direct call llm(...)
    response = llm.invoke(prompt.format(text=text, format_instructions=format_instructions))
    return response.strip()

def generate_and_display(llm, text, template=None, format_instructions=None):
    result = generate(llm, text, template, format_instructions)

    # Only echo simple prompts: a filled-in template would be too long to be useful
    if template is None:
        display(Markdown(f"<b>{text}</b>"))

    display(Markdown(result))

# Test with a simple prompt
generate_and_display(llm, "Explain the fundamentals of ChatGPT in a couple of lines.")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<b>Explain the fundamentals of ChatGPT in a couple of lines.</b>

ChatGPT is a model from OpenAI that interacts with users in natural language text, providing responses based on context and information provided. It uses deep learning techniques to understand input, generate appropriate responses, and learn from interactions to improve performance over time. The goal is to create conversational experiences that mimic human-like interaction.

## Medical transcripts dataset
Source: https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions

In [7]:
df = pd.read_csv('../data/medical-transcripts/mtsamples.csv', index_col=0)
df.head(5)

| Unnamed: 0 | description | medical_specialty | sample_name | transcription | keywords |
|---|---|---|---|---|---|
| 0 | A 23-year-old white female presents with comp... | Allergy / Immunology | Allergic Rhinitis | SUBJECTIVE:, This 23-year-old white female pr... | allergy / immunology, allergic rhinitis, aller... |
| 1 | Consult for laparoscopic gastric bypass. | Bariatrics | Laparoscopic Gastric Bypass Consult - 2 | PAST MEDICAL HISTORY:, He has difficulty climb... | bariatrics, laparoscopic gastric bypass, weigh... |
| 2 | Consult for laparoscopic gastric bypass. | Bariatrics | Laparoscopic Gastric Bypass Consult - 1 | HISTORY OF PRESENT ILLNESS: , I have seen ABC ... | bariatrics, laparoscopic gastric bypass, heart... |
| 3 | 2-D M-Mode. Doppler. | Cardiovascular / Pulmonary | 2-D Echocardiogram - 1 | 2-D M-MODE: , ,1. Left atrial enlargement wit... | cardiovascular / pulmonary, 2-d m-mode, dopple... |
| 4 | 2-D Echocardiogram | Cardiovascular / Pulmonary | 2-D Echocardiogram - 2 | 1. The left ventricular cavity size and wall ... | cardiovascular / pulmonary, 2-d, doppler, echo... |


In [8]:
df.describe()

| Unnamed: 0 | description | medical_specialty | sample_name | transcription | keywords |
|---|---|---|---|---|---|
| count | 4999 | 4999 | 4999 | 4966 | 3931.0 |
| unique | 2348 | 40 | 2377 | 2357 | 3849.0 |
| top | An example/template for a routine normal male... | Surgery | Lumbar Discogram | PREOPERATIVE DIAGNOSIS: , Low back pain.,POSTO... | |
| freq | 12 | 1103 | 5 | 5 | 81.0 |
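`describe()` reports a count of 4966 for `transcription` against 4999 rows, so 33 transcripts are missing, and `keywords` has even more gaps. The first five rows used below all have a transcript, but a run over the whole dataset would feed missing values into the prompt. Below is a minimal sketch of filtering them out with `dropna`, shown on a hypothetical three-row frame so it runs standalone. On the notebook's data, the equivalent would be `df = df.dropna(subset=["transcription"])`.

```python
import pandas as pd

# Tiny stand-in for the transcripts DataFrame (hypothetical values)
df_demo = pd.DataFrame({
    "medical_specialty": ["Allergy / Immunology", "Surgery", "Surgery"],
    "transcription": ["SUBJECTIVE:, ...", None, "PREOPERATIVE DIAGNOSIS: , ..."],
})

# Count the missing values per column, then keep only the rows that have a transcript
print(df_demo.isna().sum())
df_clean = df_demo.dropna(subset=["transcription"]).reset_index(drop=True)
print(len(df_demo), "->", len(df_clean))  # 3 -> 2
```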


In [9]:
# Display the first transcript to get an idea of the content
display(Markdown(f"<p>{df['transcription'].iloc[0]}</p>"))

<p>SUBJECTIVE:,  This 23-year-old white female presents with complaint of allergies.  She used to have allergies when she lived in Seattle but she thinks they are worse here.  In the past, she has tried Claritin, and Zyrtec.  Both worked for short time but then seemed to lose effectiveness.  She has used Allegra also.  She used that last summer and she began using it again two weeks ago.  It does not appear to be working very well.  She has used over-the-counter sprays but no prescription nasal sprays.  She does have asthma but doest not require daily medication for this and does not think it is flaring up.,MEDICATIONS: , Her only medication currently is Ortho Tri-Cyclen and the Allegra.,ALLERGIES: , She has no known medicine allergies.,OBJECTIVE:,Vitals:  Weight was 130 pounds and blood pressure 124/78.,HEENT:  Her throat was mildly erythematous without exudate.  Nasal mucosa was erythematous and swollen.  Only clear drainage was seen.  TMs were clear.,Neck:  Supple without adenopathy.,Lungs:  Clear.,ASSESSMENT:,  Allergic rhinitis.,PLAN:,1.  She will try Zyrtec instead of Allegra again.  Another option will be to use loratadine.  She does not think she has prescription coverage so that might be cheaper.,2.  Samples of Nasonex two sprays in each nostril given for three weeks.  A prescription was written as well.</p>
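The transcript shows that the original line breaks were flattened into commas: section headings such as `SUBJECTIVE:,` and `MEDICATIONS: ,` run straight into the text. To inspect or prompt on individual sections, one option is a heuristic regex split. This is a minimal sketch. The pattern and the `split_sections` helper are hypothetical and are only checked against the heading style visible above.

```python
import re

# Headings look like "SUBJECTIVE:," or "MEDICATIONS: ,": an upper-case label,
# a colon, then a comma left over from the original line break.
SECTION_PATTERN = re.compile(r"([A-Z][A-Z /]+):\s*,\s*")


def split_sections(transcript):
    """Split a transcript into {heading: text} using upper-case headings (heuristic)."""
    parts = SECTION_PATTERN.split(transcript)
    # With one capture group, re.split returns [before, heading1, body1, heading2, body2, ...]
    return {heading.strip(): body.strip(" ,") for heading, body in zip(parts[1::2], parts[2::2])}


sample = ("SUBJECTIVE:,  This 23-year-old white female presents with complaint of allergies.,"
          "MEDICATIONS: , Her only medication currently is Ortho Tri-Cyclen and the Allegra.,"
          "ASSESSMENT:,  Allergic rhinitis.")
for heading, body in split_sections(sample).items():
    print(f"{heading}: {body}")
```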

### Extract patient information

In [10]:
format_instructions = """
We want to extract ONLY the following information, if it is mentioned in the transcript:
- The age of the patient
- The gender of the patient

The output should be valid JSON with this structure:

```json
{
    "age": 99,
    "gender": "Male"
}
```

Include ONLY fields that are in the example above, or the schema will be invalid and the code will fail.

Set a field to null if you can't find the information.

Respond only with valid JSON within a markdown code block, add no further comments to the response.
"""

In [11]:
template = """[INST]
You are a medical expert reading through transcripts. Your expertise spans the whole medical domain.

Your job is to extract structured information from the transcript.

This is the transcript: ```{text}```

{format_instructions}
[/INST]
"""

for text in df['transcription'].head(5):
    generate_and_display(llm, text, template, format_instructions)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


```json
{
    "age": 23,
    "gender": "Female"
}
```

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


```json
{
    "age": null,
    "gender": null
}
```

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


```json
{
    "age": 42,
    "gender": "Male"
}
```

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


```json
{
    "age": null,
    "gender": null
}
```

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


```json
{
    "age": null,
    "gender": null
}
```
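The responses above are JSON wrapped in a markdown code block, as requested. Nothing forces the model to comply, so before the results are used downstream they should be parsed and checked against the schema. Below is a minimal sketch with the standard library. `parse_extraction` is a hypothetical helper, and its strictness rules (exactly the `age` and `gender` keys, an integer or null age) are assumptions drawn from the format instructions.

```python
import json
import re

# Capture the JSON object inside a ```json ... ``` block (the "json" tag is optional)
FENCE_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
EXPECTED_KEYS = {"age", "gender"}


def parse_extraction(response):
    """Return the extracted fields as a dict, or None if the response is not usable."""
    match = FENCE_PATTERN.search(response)
    raw = match.group(1) if match else response.strip()  # fall back to a bare JSON reply
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict) or set(data) != EXPECTED_KEYS:
        return None  # missing or extra fields: reject
    age = data["age"]
    if age is not None and (isinstance(age, bool) or not isinstance(age, int)):
        return None  # age must be an integer or null
    return data


reply = '```json\n{\n    "age": 23,\n    "gender": "Female"\n}\n```'
print(parse_extraction(reply))             # {'age': 23, 'gender': 'Female'}
print(parse_extraction("Sorry, no JSON"))  # None
```

In the notebook, the helper would be applied to the string returned by `generate(llm, text, template, format_instructions)` instead of displaying it.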