In [1]:
!pip install transformers==4.41.2 peft==0.10.0 accelerate==0.30.1 datasets==2.19.1 bitsandbytes==0.43.0


[0m

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from datasets import Dataset

print("CUDA Available:", torch.cuda.is_available())


  from .autonotebook import tqdm as notebook_tqdm


CUDA Available: True


In [3]:
model_ckpt = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForCausalLM.from_pretrained(model_ckpt, device_map="auto")


In [4]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)


In [5]:
import pandas as pd

train_df = pd.read_csv('train.csv').dropna(subset=['Conversation']).reset_index(drop=True)

def format_data(row):
    return {
        "text": f"[INST] Summarize the following clinical conversation:\n{row['Conversation']} [/INST]"
    }

formatted_data = train_df.apply(format_data, axis=1).tolist()
dataset = Dataset.from_list(formatted_data)

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True)


Map: 100%|██████████| 106556/106556 [00:11<00:00, 9211.07 examples/s]


In [6]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir="./tinyllama_lora_results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    logging_dir="./logs",
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)


In [7]:
trainer.train()


Step,Training Loss
10,2.9121
20,2.6071
30,2.3606
40,2.3482
50,2.3311
60,2.3142
70,2.293
80,2.2785
90,2.2739
100,2.2786




TrainOutput(global_step=39957, training_loss=1.9815270925369528, metrics={'train_runtime': 14228.9462, 'train_samples_per_second': 22.466, 'train_steps_per_second': 2.808, 'total_flos': 1.0180860262093947e+18, 'train_loss': 1.9815270925369528, 'epoch': 2.9998873831600283})

In [8]:
model.save_pretrained("./tinyllama_lora_finetuned")
tokenizer.save_pretrained("./tinyllama_lora_finetuned")




('./tinyllama_lora_finetuned/tokenizer_config.json',
 './tinyllama_lora_finetuned/special_tokens_map.json',
 './tinyllama_lora_finetuned/tokenizer.model',
 './tinyllama_lora_finetuned/added_tokens.json',
 './tinyllama_lora_finetuned/tokenizer.json')

In [9]:
def generate_summary(text):
    prompt = f"[INST] Summarize the following clinical conversation:\n{text} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output = model.generate(**inputs, max_new_tokens=150, temperature=0.7)
    return tokenizer.decode(output[0], skip_special_tokens=True)

print(generate_summary("Patient reports chest pain and shortness of breath."))




[INST] Summarize the following clinical conversation:
Patient reports chest pain and shortness of breath. [/INST] Patient reports chest pain and shortness of breath. He is 65 years old, 5 10, 180 lbs. He has had a history of heart disease, hypertension, and diabetes. He has been taking medication for these conditions for many years. He has been to the ER twice in the past 3 months with chest pain and shortness of breath. The first time he was given a chest x-ray and a chest CT scan. The second time he was given a chest x-ray and a chest CT scan. The CT scan showed no abnormalities. He has been told that he has a narrowed artery in his heart.


### Batch Summarization with Fine-Tuned TinyLlama

In [2]:
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model
tokenizer = AutoTokenizer.from_pretrained("./tinyllama_lora_finetuned")
model = AutoModelForCausalLM.from_pretrained("./tinyllama_lora_finetuned", device_map="auto")

# Load test data
test_df = pd.read_csv('test.csv').dropna(subset=['Conversation']).reset_index(drop=True)

# Batch Summarization
batch_size = 8
summaries = []

for i in tqdm(range(0, len(test_df), batch_size)):
    batch_texts = test_df['Conversation'][i:i+batch_size].tolist()
    batch_prompts = [f"[INST] Summarize the following clinical conversation:\n{text} [/INST]" for text in batch_texts]
    
    batch_inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to("cuda")
    
    with torch.no_grad():
        output = model.generate(**batch_inputs, max_new_tokens=150, temperature=0.7, do_sample=True)
    
    batch_summaries = tokenizer.batch_decode(output, skip_special_tokens=True)
    summaries.extend(batch_summaries)

test_df['Generated_Summary'] = summaries


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 702/702 [41:46<00:00,  3.57s/it]


In [4]:
from tqdm import tqdm
tqdm.pandas()


### Clinical QA with PubMedBERT

In [6]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

qa_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
qa_model = AutoModelForQuestionAnswering.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract").to("cuda")

def answer_with_pubmedbert(context, question):
    inputs = qa_tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length"
    ).to("cuda")
    
    with torch.no_grad():
        outputs = qa_model(**inputs)
    
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits)
    
    return qa_tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx+1])


# Example questions
questions = [
    "What symptoms does the patient report?",
    "What diagnosis is mentioned?",
    "Is there any treatment discussed?"
]

# Apply QA to generated summaries
for q in questions:
    test_df[q] = test_df['Generated_Summary'].progress_apply(lambda text: answer_with_pubmedbert(text, q))


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 5609/5609 [00:33<00:00, 165.59it/s]
100%|██████████| 5609/5609 [00:34<00:00, 164.48it/s]
100%|██████████| 5609/5609 [00:34<00:00, 163.56it/s]


###  Save the Final Output CSV

In [7]:
test_df.to_csv('test_with_summaries_and_answers.csv', index=False)


In [8]:
print(test_df[['Conversation', 'Generated_Summary', questions[0], questions[1], questions[2]]].head())


                                        Conversation  \
0  The conversation between human and AI assistan...   
1  The conversation between human and AI assistan...   
2  The conversation between human and AI assistan...   
3  The conversation between human and AI assistan...   
4  The conversation between human and AI assistan...   

                                   Generated_Summary  \
0  [INST] Summarize the following clinical conver...   
1  [INST] Summarize the following clinical conver...   
2  [INST] Summarize the following clinical conver...   
3  [INST] Summarize the following clinical conver...   
4  [INST] Summarize the following clinical conver...   

              What symptoms does the patient report?  \
0  and legs to mid calf feeling like they are on ...   
1  we want to take treatment from you. for your k...   
2  , from what you have described, it appears tha...   
3  i have a real tight neck that wants to keep dr...   
4  , but is there anything else i can do? do y

### Save Fine-Tuned TinyLLaMA

In [9]:
model.save_pretrained("./deployed_tinyllama_lora")
tokenizer.save_pretrained("./deployed_tinyllama_lora")




('./deployed_tinyllama_lora/tokenizer_config.json',
 './deployed_tinyllama_lora/special_tokens_map.json',
 './deployed_tinyllama_lora/tokenizer.model',
 './deployed_tinyllama_lora/added_tokens.json',
 './deployed_tinyllama_lora/tokenizer.json')

### Save PubMedBERT (if you fine-tuned QA)

In [10]:
qa_model.save_pretrained("./deployed_pubmedbert_qa")
qa_tokenizer.save_pretrained("./deployed_pubmedbert_qa")


('./deployed_pubmedbert_qa/tokenizer_config.json',
 './deployed_pubmedbert_qa/special_tokens_map.json',
 './deployed_pubmedbert_qa/vocab.txt',
 './deployed_pubmedbert_qa/added_tokens.json',
 './deployed_pubmedbert_qa/tokenizer.json')

In [11]:
!pip install streamlit transformers torch


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting streamlit
  Downloading streamlit-1.47.0-py3-none-any.whl.metadata (9.0 kB)
Collecting altair<6,>=4.0 (from streamlit)
  Downloading altair-5.5.0-py3-none-any.whl.metadata (11 kB)
Collecting blinker<2,>=1.5.0 (from streamlit)
  Downloading blinker-1.9.0-py3-none-any.whl.metadata (1.6 kB)
Collecting cachetools<7,>=4.0 (from streamlit)
  Downloading cachetools-6.1.0-py3-none-any.whl.metadata (5.4 kB)
Collecting protobuf<7,>=3.20 (from streamlit)
  Downloading protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Collecting tenacity<10,>=8.1.0 (from streamlit)
  Downloading tenacity-9.1.2-py3-none-any.whl.metadata (1.2 kB)
Collecting toml<2,>=0.10.1 (from streamlit)
  Downloading toml-0.10.2-py2.py3-none-any.whl.metadata (7.1 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m5.4 MB/s[0m eta [36