# Importing Necessary Libraries

In [18]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
     # Set up training parameters and manage the training loop.
    TrainingArguments, Trainer,
 # Configure quantization (4‑bit precision) to reduce memory usage.
    BitsAndBytesConfig
)
# Set up and apply LoRA to the base model for efficient fine‑tuning.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset, concatenate_datasets
import os
import re
import pandas as pd

# Device & Quantization Configuration

In [19]:
# 1. Configuration
# Project root. Defaults to the original author's checkout; set the
# CLINICCONNECT_BASE_DIR environment variable to run the notebook from another
# location without editing this cell.
BASE_DIR = os.environ.get(
    "CLINICCONNECT_BASE_DIR",
    "/Users/arjunanand/Documents/SE_Project/ClinicConnect/AI-Model",
)
DATA_DIR = os.path.join(BASE_DIR, "data")
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
OUTPUT_DIR = os.path.join(BASE_DIR, "trained_models/mistral-clinicconnect")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 2. Device Configuration
# Prefer CUDA when it is available, then Apple's MPS backend, then plain CPU.
# On a Mac this still resolves to "mps" (or "cpu") exactly as before.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# 3. Quantization Setup
# 4-bit NF4 weights with double quantization; matmuls are computed in float16.
# NOTE(review): the recorded run of the model-loading cell failed on MPS with a
# bitsandbytes ImportError; 4-bit loading presumably needs a CUDA device and a
# current bitsandbytes install -- confirm before running on a Mac.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)


Using device: mps


# Model and Tokenizer Loading
We use the Mistral 7B Instruct v0.2 model because it has a 32k context window (vs. an 8k context in v0.1). In general, Mistral 7B is a highly efficient, open-source language model that achieves competitive performance on NLP benchmarks despite its relatively small number of parameters.

We have to use MPS optimization because the model is running on a Mac, which does not have a dedicated GPU.


In [20]:
# Tokenizer: reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Base model: quantized according to `quant_config`, with every module placed
# on the single selected `device`, loading weights from safetensors files.
model_load_kwargs = {
    "quantization_config": quant_config,
    "device_map": {"": device},
    "use_safetensors": True,
}
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_kwargs)

# Prepare the quantized model for k-bit (LoRA) training.
model = prepare_model_for_kbit_training(base_model)

ImportError: Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`

# LoRA (Low-Rank Adaptation) Configuration and Application
We fine-tune the Mistral 7B model with LoRA, a method that trains only a small number of additional parameters, significantly reducing memory and computational requirements without compromising performance.
Instead of updating a large weight matrix directly, it represents the update as the product of two smaller low-rank matrices in the attention layers.

In [None]:
# Adapter hyper-parameters: rank-16 LoRA (alpha 32, dropout 0.05, no bias
# terms) attached to the query/key/value/output attention projections.
lora_settings = {
    "task_type": "CAUSAL_LM",
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "bias": "none",
}
lora_config = LoraConfig(**lora_settings)

# Wrap the base model with the adapters.
# NOTE: this rebinds `model`, so re-running the cell wraps the already wrapped model.
model = get_peft_model(model, lora_config)

# Dataset Preparation and Tokenization

In [None]:
def format_instruction(example):
    """Render one dataset row as an ``[INST]``-style training prompt.

    The prompt template is chosen by which source-specific field is populated:

    * ``care_plan`` -> care-plan generation prompt
    * ``input``     -> doctor Q&A prompt
    * otherwise     -> OASIS documentation prompt built from ``text``

    A field counts as populated only when it is present AND not ``None``.
    Rows coming out of a concatenated dataset may carry every column with
    ``None`` for the ones that do not apply, so a plain ``key in example``
    check would send every row down the first branch.

    This function only builds text; tokenization happens in a later step.

    Args:
        example: Mapping holding the fields of a single row.

    Returns:
        dict: ``{"text": <formatted prompt>}``.
    """
    # Clinical Care Plans
    if example.get("care_plan") is not None:
        care_plan = pd.json_normalize(example['care_plan']).to_markdown()
        return {
            "text": f"""<s>[INST] <<care_plan_generation>>
            Generate comprehensive care plan for {example['primary_condition']['name']} ({example['primary_condition']['subtype']}):
            Patient Profile: {example['demographics']}
            Comorbidities: {', '.join(example['comorbidities']) if example['comorbidities'] else 'None'}
            [/INST]
            {care_plan}</s>"""
        }

    # Doctor Q&A
    if example.get("input") is not None:
        return {
            "text": f"""<s>[INST] <<medical_qa>>
            {example['instruction']}
            Patient Description: {example['input']}
            [/INST]
            {example['output']}</s>"""
        }

    # OASIS Documentation
    # Only flag the excerpt as truncated when characters were actually dropped.
    oasis_text = example['text']
    excerpt = oasis_text[:2000]
    if len(oasis_text) > 2000:
        excerpt += "... [truncated]"

    # NOTE(review): the "Key Findings" target below is the same fixed text for
    # every OASIS row -- confirm this placeholder is intended as training data.
    return {
        "text": f"""<s>[INST] <<clinical_analysis>>
        Analyze OASIS documentation and extract key care considerations:
        {excerpt}
        [/INST]
        Key Findings:
        - Requires fall risk assessment
        - Monitor for medication interactions
        - Schedule weekly BP checks</s>"""
    }



## Minor Data Formatting required for the datasets

In [None]:
def clean_oasis(text):
    """Strip boilerplate from OASIS text and keep its substantive fragments.

    Steps:
      1. Delete the PRA disclosure statement, the "Adapted from: ... NACHC."
         attribution line, and form feed (page break) characters.
      2. Split on section headers ("Section X:"), "Enter Code" markers and
         down arrows; the markers themselves are part of the split result.
      3. Keep only fragments longer than 10 characters after stripping and
         join them with newlines.
      4. Cap the result at 3000 characters.

    Args:
        text: Raw OASIS text.

    Returns:
        str: The cleaned text, at most 3000 characters long.
    """
    # Remove legal/disclaimer content
    boilerplate_patterns = [
        r"PRA Disclosure Statement.*?Baltimore, Maryland 21244-1850\.",
        r"Adapted from:.*?NACHC\.\n",
        # An empty pattern matches everywhere and removes nothing, so the
        # form feed has to be spelled out explicitly.
        r"\f",
    ]
    for pattern in boilerplate_patterns:
        # DOTALL lets ".*?" span the line breaks inside the disclosure block.
        text = re.sub(pattern, "", text, flags=re.DOTALL)

    # Extract structured components
    sections = re.split(r"(Section [A-Z]+:|Enter Code|↓)", text)
    cleaned = "\n".join(s for s in sections if len(s.strip()) > 10)

    return cleaned[:3000]  # Keep most relevant parts

def normalize_qa(example):
    """Standardize medication names in a Q&A row and set a fixed instruction.

    Whole-word, case-insensitive occurrences of each key in the mapping are
    replaced by its value in the ``input`` and ``output`` fields. Non-string
    values (such as ``None``) are left untouched instead of raising inside
    ``re.sub``. The ``instruction`` field is always overwritten.

    Args:
        example: Mapping with ``input`` and ``output`` keys. It is modified
            in place.

    Returns:
        The same ``example`` object, after modification.
    """
    # Standardize medical terms
    medical_mapping = {
        "panadol": "acetaminophen",
        "Z&D": "zinc supplementation"
    }

    for field in ("input", "output"):
        value = example[field]
        if not isinstance(value, str):
            continue
        for term, replacement in medical_mapping.items():
            # re.escape keeps terms containing regex metacharacters literal.
            value = re.sub(rf"\b{re.escape(term)}\b", replacement, value, flags=re.I)
        example[field] = value

    # Add clinical context
    example['instruction'] = "As a board-certified physician, provide evidence-based recommendations for:"

    return example

def structure_careplan(example):
    """Flatten the lists inside a row's ``care_plan`` into bullet-style strings.

    Each list is joined with ``"\\n- "`` between items, so the first item
    carries no leading dash.

    Args:
        example: Mapping whose ``care_plan`` entry provides ``monitoring``,
            ``medications`` (with ``oral`` and ``injectable``) and
            ``lifestyle`` lists of strings.

    Returns:
        dict with ``monitoring_schedule``, ``medications`` (itself a dict with
        ``oral`` and ``injectable``) and ``lifestyle_recommendations``.
    """
    as_bullets = "\n- ".join
    care_plan = example['care_plan']

    monitoring_schedule = as_bullets(care_plan['monitoring'])
    medications = {
        route: as_bullets(care_plan['medications'][route])
        for route in ("oral", "injectable")
    }
    lifestyle_recommendations = as_bullets(care_plan['lifestyle'])

    return {
        "monitoring_schedule": monitoring_schedule,
        "medications": medications,
        "lifestyle_recommendations": lifestyle_recommendations,
    }

## Load and Concatenate Data

In [None]:
# Load and process datasets
# Upper bounds on how many rows the down-sampled sources contribute.
OASIS_SAMPLE_SIZE = 5000
CAREPLAN_SAMPLE_SIZE = 10000
MAX_SEQ_LENGTH = 2048
SHUFFLE_SEED = 42


def _data_file(name):
    """Return the path to `name` under DATA_DIR if it exists there, else `name` unchanged.

    Falling back to the bare name keeps the original behavior of resolving
    the file relative to the current working directory.
    """
    candidate = os.path.join(DATA_DIR, name)
    return candidate if os.path.exists(candidate) else name


def _shuffled_sample(dataset, limit=None):
    """Shuffle `dataset` with the fixed seed and keep at most `limit` rows.

    The limit is clamped to the dataset size so a source smaller than the
    requested sample no longer makes `select` fail on out-of-range indices.
    """
    shuffled = dataset.shuffle(seed=SHUFFLE_SEED)
    if limit is None:
        return shuffled
    return shuffled.select(range(min(limit, len(shuffled))))


oasis = load_dataset("text", data_files=_data_file("clean_oasis.txt")) \
    .map(lambda x: {"text": clean_oasis(x["text"])})

doctor_qa = load_dataset("csv", data_files=_data_file("Doctor-Healthcare-100k.csv")) \
    .map(normalize_qa)

careplans = load_dataset("csv", data_files=_data_file("clinical_care_plans.csv")) \
    .map(structure_careplan)

# Combine with weighting: OASIS and care plans are capped, the Q&A set is used in full.
combined = concatenate_datasets([
    _shuffled_sample(oasis["train"], OASIS_SAMPLE_SIZE),
    _shuffled_sample(doctor_qa["train"]),
    _shuffled_sample(careplans["train"], CAREPLAN_SAMPLE_SIZE)
])

# Final formatting: build the prompt text, then tokenize in batches, padding
# or truncating every example to exactly MAX_SEQ_LENGTH tokens.
tokenized_data = combined.map(format_instruction).map(
    lambda x: tokenizer(
        x["text"],
        max_length=MAX_SEQ_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ),
    batched=True
)

# Training Setup & Training the Model

In [None]:
# Training Arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    learning_rate=2e-5,
    fp16=True,
    logging_steps=50,
    save_strategy="epoch",
    evaluation_strategy="steps",
    eval_steps=500,
    report_to="none",
    push_to_hub=False
)

# evaluation_strategy="steps" needs an evaluation set, so hold out a small,
# reproducible slice of the tokenized data for it.
dataset_splits = tokenized_data.train_test_split(test_size=0.05, seed=42)


def collate_causal_lm(batch):
    """Stack a list of tokenized rows into one causal-LM training batch.

    Rows may hold plain Python lists or tensors; `torch.as_tensor` accepts
    both, so the collator works whether or not the dataset has a torch
    format set. Every row is expected to have the same length (the
    tokenization step pads to a fixed length).

    Labels are a copy of `input_ids` with padding positions (attention mask
    0) set to -100, so padding does not contribute to the loss.

    Args:
        batch: List of rows, each with "input_ids" and "attention_mask".

    Returns:
        dict with "input_ids", "attention_mask" and "labels" tensors.
    """
    input_ids = torch.stack([torch.as_tensor(row["input_ids"]) for row in batch])
    attention_mask = torch.stack([torch.as_tensor(row["attention_mask"]) for row in batch])
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_splits["train"],
    eval_dataset=dataset_splits["test"],
    data_collator=collate_causal_lm
)

# Start Training
print("Starting training...")
trainer.train()

# Save Model
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model saved to {OUTPUT_DIR}")

# Test Inference

In [None]:
def generate_care_plan(patient_profile):
    """Generate a care plan for a free-text patient profile.

    The profile is inserted into the care-plan prompt template, tokenized,
    moved to the model's device and passed to `model.generate` with sampling
    (temperature 0.7, up to 512 new tokens).

    The model is switched to evaluation mode for the duration of the call so
    that dropout is inactive during sampling, and its previous mode is
    restored afterwards.

    Args:
        patient_profile: Text describing the patient.

    Returns:
        str: The first generated sequence decoded with special tokens
        skipped (presumably including the prompt text -- confirm).
    """
    prompt = f"""<s>[INST] <<care_plan_generation>>
    Generate comprehensive care plan for:
    Patient Profile: {patient_profile}
    [/INST]"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # After training the model is left in train mode; generate in eval mode
    # and put the original mode back even if generation fails.
    was_training = model.training
    model.eval()
    try:
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    finally:
        model.train(was_training)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Smoke test: generate a plan for one hand-written patient profile.
test_profile = "65yo Male, T2DM, HbA1c 8.5%, CKD Stage 3, Hypertension"
print("Sample Care Plan:")
sample_plan = generate_care_plan(test_profile)
print(sample_plan)