# Importing Necessary Libraries

In [None]:
!pip install -q datasets
!pip install -q transformers peft accelerate bitsandbytes datasets


import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
     # Set up training parameters and manage the training loop.
    TrainingArguments, Trainer,
 # Configure quantization (4‑bit precision) to reduce memory usage.
    BitsAndBytesConfig
)
# Set up and apply LoRA to the base model for efficient fine‑tuning.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset, concatenate_datasets
import os
import re
import pandas as pd

In [None]:
# Mount Google Drive at /content/drive. BASE_DIR (set in the configuration cell)
# lives under this mount, so input data and trained models persist across Colab sessions.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Interactive Hugging Face Hub login (prompts for an access token).
# NOTE(review): presumably required because MODEL_NAME is a gated checkpoint —
# confirm the account has been granted access.
from huggingface_hub import notebook_login
notebook_login()

# Device & Quantization Configuration

In [None]:
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# 1. Directory layout on the mounted Google Drive: raw inputs go under data/,
#    checkpoints and adapters under trained_models/. Both directories are
#    created up front (a no-op when they already exist).
BASE_DIR = "/content/drive/MyDrive/ClinicConnect"
DATA_DIR = os.path.join(BASE_DIR, "data")
OUTPUT_DIR = os.path.join(BASE_DIR, "trained_models/mistral-clinicconnect")
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 2. Report which accelerator is visible. Model placement itself is handled by
#    device_map="auto" at load time; in the cells shown, `device` is only printed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# 3. 4-bit NF4 weight quantization with nested (double) quantization enabled;
#    computation runs in float16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Model and Tokenizer Loading
We use the Mistral 7B Instruct v0.2 model because it has a 32k-token context window (vs. 8k in v0.1). Mistral 7B is a highly efficient, open-source language model that achieves competitive performance on NLP benchmarks despite its relatively small number of parameters.

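MPS (Apple Metal) acceleration has to be used when the model runs on a Mac without a dedicated GPU. Note, however, that the configuration cell above targets CUDA on Google Colab, so MPS is only relevant when running this notebook locally on a Mac.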


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token for batched tokenization.
tokenizer.pad_token = tokenizer.eos_token

# Load the base model with the 4-bit quantization config from the previous cell;
# device_map="auto" spreads the layers over whatever devices are available.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    use_safetensors=True,
    quantization_config=quant_config,
)

# LoRA (Low-Rank Adaptation) Configuration and Application
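We fine-tune Mistral 7B with LoRA (Low-Rank Adaptation), which trains only a small number of additional parameters and therefore significantly reduces memory and compute requirements, typically with little loss in performance.
Rather than updating a large weight matrix $W$ directly, LoRA freezes $W$ and learns its update as the product of two much smaller low-rank matrices, $\Delta W = BA$. Here the adapters are attached to the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`).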

In [None]:
# 4. Configure LoRA for PEFT on top of the 4-bit base model.
# prepare_model_for_kbit_training is the QLoRA preparation step: it freezes the
# base weights, upcasts the remaining non-quantized parameters to fp32 for
# numerical stability and (by default) enables gradient checkpointing. It must
# run before the adapters are attached.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,               # rank of the low-rank update matrices
    lora_alpha=32,      # adapter updates are scaled by lora_alpha / r
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections only
    lora_dropout=0.05,
    bias="none",        # bias terms stay frozen
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Sanity check: only the LoRA adapter weights should be reported as trainable.
model.print_trainable_parameters()

# Dataset Preparation and Tokenization

In [None]:
def _join_care_plan_items(items):
    """Join care-plan entries with "\\n- " bullet separators, or return "None" if empty.

    ``None`` (e.g. a null struct field in the Arrow-backed dataset) is treated
    as an empty list. A bare string counts as one entry instead of being split
    into characters. Non-string entries are converted with ``str()``.
    """
    if items is None:
        items = []
    elif isinstance(items, str):
        items = [items]
    return "\n- ".join(str(item) for item in items) or "None"


def format_care_plan(plan):
    """Render a care-plan mapping as the markdown answer section of a prompt.

    Args:
        plan: Mapping with optional ``monitoring`` and ``lifestyle`` lists and
            an optional ``medications`` mapping holding ``oral`` /
            ``injectable`` lists. Any of these may be missing or ``None``.

    Returns:
        The formatted markdown block, or ``"**No care plan available**"`` when
        ``plan`` is ``None`` or empty. Empty sections are rendered as ``None``.
    """
    if not plan:
        return "**No care plan available**"

    # Use `or` rather than a .get() default: keys can be present with a None value.
    medications = plan.get('medications') or {}
    if not isinstance(medications, dict):
        medications = {}

    monitoring = _join_care_plan_items(plan.get('monitoring'))
    oral_meds = _join_care_plan_items(medications.get('oral'))
    injectable_meds = _join_care_plan_items(medications.get('injectable'))
    lifestyle = _join_care_plan_items(plan.get('lifestyle'))

    return f"""
**Monitoring:**
- {monitoring}

**Medications:**
- Oral: {oral_meds}
- Injectable: {injectable_meds}

**Lifestyle Recommendations:**
- {lifestyle}
"""

def _care_plan_prompt(example, care_plan_data):
    """Build the training prompt for a row normalized by ``structure_careplan``."""
    # Patient metadata may sit on the row itself or inside the care-plan dict.
    # NOTE(review): the 'primary_condition' / 'demographics' column names are
    # assumed — confirm against clinical_care_plans.csv.
    condition = example.get('primary_condition') or care_plan_data.get('primary_condition')
    if isinstance(condition, dict):
        condition = condition.get('name')
    condition = condition or 'unknown'

    demographics = (
        example.get('demographics')
        or care_plan_data.get('demographics')
        or 'No demographics'
    )

    # structure_careplan returns 'comorbidities' as a top-level column.
    comorbidities = example.get('comorbidities') or care_plan_data.get('comorbidities') or []
    if isinstance(comorbidities, str):
        comorbidities = [comorbidities]
    comorbidity_text = ', '.join(str(c) for c in comorbidities) or 'None'

    # structure_careplan stores monitoring/medications/lifestyle directly under
    # 'care_plan'; a doubly nested {'care_plan': {...}} dict is also accepted.
    nested_plan = care_plan_data.get('care_plan')
    plan = nested_plan if isinstance(nested_plan, dict) else care_plan_data

    return (
        f"<s>[INST] Generate care plan for {condition}:\n"
        f"Patient: {demographics}\n"
        f"Comorbidities: {comorbidity_text}\n"
        "[/INST]\n"
        f"{format_care_plan(plan)}</s>"
    )


def format_instruction(example):
    """Format one dataset row into a Mistral ``[INST] ... [/INST]`` training prompt.

    Rows come from three sources with different columns:
      * care-plan rows: a non-empty ``care_plan`` dict (from ``structure_careplan``);
      * Q&A rows: ``input`` and ``output`` (plus ``instruction`` from ``normalize_qa``);
      * OASIS rows: ``text``.
    After ``concatenate_datasets`` a row may also carry the other sources'
    columns set to ``None``. Each branch therefore checks for a usable *value*;
    a plain ``key in example`` test would route every row to the first branch.

    Returns:
        ``{"text": prompt}``. ``prompt`` is ``""`` for rows that match no
        source or fail to format, so the downstream length filter drops them
        instead of training on a placeholder.
    """
    try:
        care_plan_data = example.get('care_plan')
        if isinstance(care_plan_data, dict) and care_plan_data:
            return {"text": _care_plan_prompt(example, care_plan_data)}

        if example.get('input') is not None and example.get('output') is not None:
            instruction = example.get('instruction') or 'Medical question:'
            return {
                "text": (
                    f"<s>[INST] {instruction}\n"
                    f"{example['input']}\n"
                    "[/INST]\n"
                    f"{example['output']}</s>"
                )
            }

        if example.get('text'):
            # NOTE(review): the target is a fixed placeholder, so the model only
            # learns to echo it for OASIS prompts — replace with real annotations.
            return {
                "text": (
                    "<s>[INST] Analyze clinical documentation:\n"
                    f"{example['text'][:2000]}\n"
                    "[/INST]\n"
                    "Key considerations: [Extracted from OASIS data]</s>"
                )
            }
    except Exception as e:  # best-effort per row: one malformed row must not abort the map
        print(f"Error formatting example: {e}")

    return {"text": ""}



## Minor Data Formatting required for the datasets

In [None]:
# 5. Data Processing Functions
def clean_oasis(text):
    """Strip boilerplate from OASIS documentation and keep substantive fragments.

    Removes the PRA disclosure block, the "Adapted from: ... NACHC." attribution
    and form-feed characters. It then splits on the section markers
    "Section X:", "Enter Code" and "↓" (the markers are captured as fragments
    too). Every fragment whose stripped length is 10 characters or fewer is
    discarded, and the survivors are re-joined with newlines. The result is
    truncated to 3000 characters.
    """
    boilerplate = (
        r"PRA Disclosure Statement.*?Baltimore, Maryland 21244-1850\.",
        r"Adapted from:.*?NACHC\.\n",
        r"\f",  # form feed characters
    )
    for regex in boilerplate:
        text = re.sub(regex, "", text, flags=re.DOTALL)

    fragments = re.split(r"(Section [A-Z]+:|Enter Code|↓)", text)
    kept = (frag for frag in fragments if len(frag.strip()) > 10)
    return "\n".join(kept)[:3000]

def normalize_qa(example):
    """Standardize drug names in a Doctor Q&A row and attach a physician instruction.

    Brand or abbreviated names in ``input`` and ``output`` are replaced with
    standard terms (whole-word, case-insensitive match). Non-string values,
    e.g. an empty CSV cell loaded as ``None``, are left untouched instead of
    raising.

    Returns:
        The same ``example`` dict, modified in place, with ``instruction`` set.
    """
    medical_mapping = {
        "panadol": "acetaminophen",
        "Z&D": "zinc supplementation"
    }

    for field in ('input', 'output'):
        value = example.get(field)
        if not isinstance(value, str):
            continue
        for term, replacement in medical_mapping.items():
            # re.escape so terms containing regex metacharacters match literally.
            value = re.sub(rf"\b{re.escape(term)}\b", replacement, value, flags=re.I)
        example[field] = value

    example['instruction'] = "As a board-certified physician, provide evidence-based recommendations for:"
    return example

import ast
import json

# Exceptions ast.literal_eval can raise on malformed input.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _as_str_list(value):
    """Coerce a field to a list of strings (None -> [], str -> [str], scalar -> [str(x)])."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return [str(value)]


def _parse_care_plan(raw):
    """Return the care-plan cell as a dict; {} when missing, "None" or unparseable."""
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}
    try:
        # literal_eval only accepts Python literals, so CSV content cannot execute code.
        parsed = ast.literal_eval(raw.strip())
    except _LITERAL_EVAL_ERRORS:
        return {}
    # A literal that is not a dict (None, list, str, ...) carries no usable plan.
    return parsed if isinstance(parsed, dict) else {}


def _parse_comorbidities(raw):
    """Parse a comorbidities cell into a list of strings without 'none' markers."""
    if isinstance(raw, str):
        try:
            parsed = ast.literal_eval(raw)
        except _LITERAL_EVAL_ERRORS:
            # Not a Python literal (e.g. "diabetes, hypertension"): keep the
            # comma-separated values instead of silently discarding them.
            parsed = [part.strip() for part in raw.split(",")]
        items = _as_str_list(parsed)
    else:
        items = _as_str_list(raw)
    # 'none' (any case) means "no comorbidities", not a condition; drop blanks too.
    return [item for item in items if item.strip() and item.strip().lower() != 'none']


def structure_careplan(example):
    """Normalize a clinical-care-plan CSV row into a consistently typed structure.

    Args:
        example: Row dict whose ``care_plan`` and ``comorbidities`` cells may be
            Python-literal strings, already-parsed objects, or missing.

    Returns:
        ``{"care_plan": {"monitoring": [str], "medications": dict,
        "lifestyle": [str]}, "comorbidities": [str]}``. Every row gets the same
        field types, so ``Dataset.map`` can store all rows under one column
        schema.
    """
    plan = _parse_care_plan(example.get('care_plan'))
    medications = plan.get('medications')
    return {
        "care_plan": {
            "monitoring": _as_str_list(plan.get('monitoring')),
            "medications": medications if isinstance(medications, dict) else {},
            "lifestyle": _as_str_list(plan.get('lifestyle')),
        },
        "comorbidities": _parse_comorbidities(example.get('comorbidities')),
    }

## Load and Concatenate Data

In [None]:
# Load each source and apply its source-specific cleanup.
# NOTE(review): by default the "text" loader yields one row per line (without
# line breaks), so the multi-line patterns in clean_oasis (PRA disclosure,
# "Adapted from: ... NACHC.\n") cannot match here. If those blocks still occur in
# clean_oasis.txt, consider load_dataset("text", ..., sample_by="paragraph").
oasis = load_dataset("text", data_files=os.path.join(DATA_DIR, "clean_oasis.txt")).map(
    lambda x: {"text": clean_oasis(x["text"])}
)
doctor_qa = load_dataset("csv", data_files=os.path.join(DATA_DIR, "Doctor-HealthCare-100k.csv")).map(normalize_qa)
careplans = load_dataset("csv", data_files=os.path.join(DATA_DIR, "clinical_care_plans.csv")).map(structure_careplan)

# Mix the sources: cap OASIS at 5k rows and care plans at 10k rows, keep every Q&A row.
oasis_size = min(5000, len(oasis["train"]))
careplans_size = min(10000, len(careplans["train"]))
combined = concatenate_datasets([
    oasis["train"].shuffle(seed=42).select(range(oasis_size)),
    doctor_qa["train"].shuffle(seed=42),
    careplans["train"].shuffle(seed=42).select(range(careplans_size)),
])

# Format into prompts, drop the raw columns, and discard rows whose prompt is
# empty or too short to be a real example.
formatted = combined.map(
    lambda x: format_instruction(x) or {"text": ""},
    remove_columns=combined.column_names
).filter(lambda x: len(x["text"]) > 50)

# Tokenize the prompts. They already begin with a literal "<s>", which the
# tokenizer recognizes as BOS, so add_special_tokens=False prevents a second BOS.
# No return_tensors: Dataset.map stores plain Python lists either way.
tokenized_data = formatted.map(
    lambda x: tokenizer(
        x["text"],
        max_length=2048,
        padding="max_length",
        truncation=True,
        add_special_tokens=False,
    ),
    batched=True
)




# Splitting Data into Training and Testing Sets

In [None]:
# Hold out a reproducible (seeded) 10% of the tokenized examples for evaluation.
split_data = tokenized_data.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_data["train"], split_data["test"]

# Training Setup & Training the Model

In [None]:
def _causal_lm_collator(features):
    """Batch tokenized rows for causal-LM training.

    Each row's padding (positions where attention_mask == 0) is stripped, and
    the batch is re-padded on the right only up to its longest real sequence.
    Short examples therefore no longer cost the full 2048 positions.

    Labels copy input_ids, with padding set to -100 so the loss ignores it.
    Because pad_token is EOS here, leaving pads in the labels would train the
    model to emit long runs of EOS. The genuine trailing "</s>" has
    attention_mask == 1 and is still learned.
    """
    pad_id = tokenizer.pad_token_id
    sequences = [
        [tok for tok, keep in zip(row["input_ids"], row["attention_mask"]) if keep]
        for row in features
    ]
    width = max(len(seq) for seq in sequences)

    input_ids, attention_mask, labels = [], [], []
    for seq in sequences:
        n_pad = width - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
        labels.append(seq + [-100] * n_pad)

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }


torch.cuda.empty_cache()
model.gradient_checkpointing_enable()
# With gradient checkpointing and a frozen base model, the inputs must require
# grads so gradients can flow back to the LoRA adapters.
model.enable_input_require_grads()
# The KV cache is not used for training and is incompatible with gradient checkpointing.
model.config.use_cache = False

# 6. Training Setup with Evaluation Strategy
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,   # effective batch size of 4 per device
    optim="paged_adamw_8bit",
    # NOTE(review): 2e-5 is low for LoRA; QLoRA recipes typically use ~1e-4 to 2e-4.
    learning_rate=2e-5,
    fp16=True,
    logging_steps=50,
    save_strategy="epoch",
    eval_strategy="steps",  # Updated to avoid FutureWarning
    eval_steps=500,
    report_to="none",
    push_to_hub=False
)

# 7. Initialize Trainer with a collator that masks padding out of the loss
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=_causal_lm_collator,
)

# 8. Start Training
print("Starting training...")
trainer.train()

# Save the final model (for a PEFT model, the adapter weights and config) plus
# the tokenizer at the root of OUTPUT_DIR, in addition to the per-epoch checkpoints.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Test Inference

In [None]:
# 12. Inference Function
def generate_care_plan(patient_profile, max_new_tokens=512, temperature=0.7):
    """Generate a care plan for a free-text patient profile.

    Args:
        patient_profile: Free-text description of the patient.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        Only the newly generated text; the prompt is not echoed back.
    """
    prompt = f"""<s>[INST] <<care_plan_generation>>
    Generate comprehensive care plan for:
    Patient Profile: {patient_profile}
    [/INST]"""

    # The prompt already starts with a literal "<s>" (BOS); don't let the
    # tokenizer prepend a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # The model may still be in train mode after trainer.train(); switch to eval
    # so LoRA dropout is disabled during generation.
    model.eval()
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # The training cell sets model.config.use_cache = False; turn the KV
        # cache back on so each step does not re-encode the whole sequence.
        use_cache=True,
    )

    # generate() returns prompt + continuation; decode only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Test with sample input
test_profile = "65yo Male, T2DM, HbA1c 8.5%, CKD Stage 3, Hypertension"
print("Sample Care Plan:")
print(generate_care_plan(test_profile))