In [49]:
import os
import torch
import logging
import gc
import pandas as pd
import numpy as np
from datasets import Dataset
import time
from tqdm import tqdm
from typing import Dict, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel


**FINAL V**

In [50]:
# Set environment variables for memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Ensure using only one GPU

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # Fixed from __main__

# Check for GPU availability and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


INFO:__main__:Using device: cpu


In [51]:
def load_medical_model(model_name="malhajar/meditron-7b-chat"):
    """Load the medical model and tokenizer with 4-bit quantization and prepare for fine-tuning."""
    logger.info(f"Loading model: {model_name}")
    
    # First load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
        trust_remote_code=True,
        use_fast=True
    )
    
    # Define prompt template for this specific model
    if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
        logger.info("Setting chat template for Meditron")
        tokenizer.chat_template = """{% for message in messages %}
{% if message['role'] == 'system' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'user' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'assistant' %}### Response:
{{ message['content'] }}
{% endif %}
{% if loop.last and add_generation_prompt %}### Response:
{% endif %}
{% endfor %}"""
    
    # Clear memory before loading model
    torch.cuda.empty_cache()
    gc.collect()
    
    # Configure quantization with CPU offloading enabled
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    
    # Load model with quantization config and auto device map
    logger.info("Loading 4-bit quantized model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # Simplified device mapping
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    
    # Return the base model and tokenizer without PEFT modifications
    # We'll apply PEFT in a separate step
    return model, tokenizer


In [52]:
def apply_peft_to_model(model):
    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model."""
    logger.info("Applying PEFT to the model...")
    
    # CRITICAL STEP 1: Properly prepare model for kbit training
    model = prepare_model_for_kbit_training(model)
    
    # CRITICAL STEP 2: Ensure input gradients are enabled
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    
    # CRITICAL STEP 3: Define LoRA configuration with more comprehensive targets
    lora_config = LoraConfig(
        r=8,                    # Rank for LoRA
        lora_alpha=32,          # Alpha parameter for LoRA
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Target attention modules
        lora_dropout=0.05,     
        bias="none",           
        task_type="CAUSAL_LM"  
    )
    
    # CRITICAL STEP 4: Apply LoRA to create a trainable model using get_peft_model
    peft_model = get_peft_model(model, lora_config)
    
    # Verify trainable parameters
    trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in peft_model.parameters())
    logger.info(f"Trainable parameters: {trainable_params}")
    logger.info(f"All parameters: {all_params}")
    logger.info(f"Trainable%: {100 * trainable_params / all_params:.4f}%")
    
    return peft_model


In [53]:
def process_dataset_for_fine_tuning(csv_file):
    """Process the dataset for fine-tuning using original_question and ideal_answer columns."""
    logger.info(f"Processing dataset from {csv_file}")
    
    try:
        df = pd.read_csv(csv_file)
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        # Create a small test dataset for debugging
        logger.info("Creating a small sample dataset for testing")
        test_data = {
            'original_question': [
                "What are the symptoms of diabetes?",
                "How is hypertension diagnosed?",
                "What are common treatments for migraine?"
            ],
            'ideal_answer': [
                "Common symptoms of diabetes include frequent urination, increased thirst, unexplained weight loss, extreme hunger, blurred vision, tingling in the extremities, and frequent infections.",
                "Hypertension is diagnosed when blood pressure readings consistently show systolic pressure above 130 mmHg or diastolic pressure above 80 mmHg. Diagnosis typically requires multiple readings over time.",
                "Common treatments for migraine include pain relievers, triptans, anti-nausea medications, preventive medications like beta blockers, and lifestyle changes such as stress management and regular sleep."
            ]
        }
        df = pd.DataFrame(test_data)
    
    # Check if required columns exist
    if 'original_question' not in df.columns or 'ideal_answer' not in df.columns:
        logger.error("Dataset missing required columns (original_question and ideal_answer)")
        return None
    
    # Filter rows that have both question and ideal answer
    df = df.dropna(subset=['original_question', 'ideal_answer'])
    
    # For low memory, limit dataset size
    if len(df) > 50:  # Reduced from 100 to 50
        logger.info(f"Limiting dataset to 50 examples for memory efficiency (from {len(df)})")
        df = df.sample(50, random_state=42)
    
    logger.info(f"Dataset has {len(df)} valid training examples")
    
    # Create a system prompt for all examples
    system_prompt = "You are an AI Medical Assistant. Give accurate and helpful answers to medical questions."
    
    # Create formatted examples for fine-tuning
    train_data = []
    
    for _, row in df.iterrows():
        # Format each example as a conversation with proper tokens
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": row['original_question']},
            {"role": "assistant", "content": row['ideal_answer']}
        ]
        
        # Format the conversation according to the model's expected format
        example = {"conversation": conversation}
        train_data.append(example)
    
    # Convert to HuggingFace Dataset
    dataset = Dataset.from_pandas(pd.DataFrame(train_data))
    
    return dataset


In [54]:
def process_dataset_for_fine_tuning(csv_file):
    """Process the dataset for fine-tuning using original_question and ideal_answer columns."""
    logger.info(f"Processing dataset from {csv_file}")
    
    try:
        df = pd.read_csv(csv_file)
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        # Create a small test dataset for debugging
        logger.info("Creating a small sample dataset for testing")
        test_data = {
            'original_question': [
                "What are the symptoms of diabetes?",
                "How is hypertension diagnosed?",
                "What are common treatments for migraine?"
            ],
            'ideal_answer': [
                "Common symptoms of diabetes include frequent urination, increased thirst, unexplained weight loss, extreme hunger, blurred vision, tingling in the extremities, and frequent infections.",
                "Hypertension is diagnosed when blood pressure readings consistently show systolic pressure above 130 mmHg or diastolic pressure above 80 mmHg. Diagnosis typically requires multiple readings over time.",
                "Common treatments for migraine include pain relievers, triptans, anti-nausea medications, preventive medications like beta blockers, and lifestyle changes such as stress management and regular sleep."
            ]
        }
        df = pd.DataFrame(test_data)
    
    # Check if required columns exist
    if 'original_question' not in df.columns or 'ideal_answer' not in df.columns:
        logger.error("Dataset missing required columns (original_question and ideal_answer)")
        return None
    
    # Filter rows that have both question and ideal answer
    df = df.dropna(subset=['original_question', 'ideal_answer'])
    
    # For low memory, limit dataset size
    if len(df) > 50:  # Reduced from 100 to 50
        logger.info(f"Limiting dataset to 50 examples for memory efficiency (from {len(df)})")
        df = df.sample(50, random_state=42)
    
    logger.info(f"Dataset has {len(df)} valid training examples")
    
    # Create a system prompt for all examples
    system_prompt = "You are an AI Medical Assistant. Give accurate and helpful answers to medical questions."
    
    # Create formatted examples for fine-tuning
    train_data = []
    
    for _, row in df.iterrows():
        # Format each example as a conversation with proper tokens
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": row['original_question']},
            {"role": "assistant", "content": row['ideal_answer']}
        ]
        
        # Format the conversation according to the model's expected format
        example = {"conversation": conversation}
        train_data.append(example)
    
    # Convert to HuggingFace Dataset
    dataset = Dataset.from_pandas(pd.DataFrame(train_data))
    
    return dataset


In [55]:
class MedicalDataCollator:
    def __init__(self, tokenizer, max_length=192):  # Reduced max length for memory savings
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __call__(self, examples):
        # Extract conversations
        conversations = [ex["conversation"] for ex in examples]
        
        # Format according to the tokenizer's chat template if available
        if hasattr(self.tokenizer, "apply_chat_template"):
            input_texts = [
                self.tokenizer.apply_chat_template(
                    conv[:-1],  # Exclude the assistant's response
                    tokenize=False,
                    add_generation_prompt=True
                )
                for conv in conversations
            ]
            
            target_texts = [
                self.tokenizer.apply_chat_template(
                    conv,  # Full conversation
                    tokenize=False,
                    add_generation_prompt=False
                )
                for conv in conversations
            ]
        else:
            # Fallback formatting
            input_texts = []
            target_texts = []
            
            for conv in conversations:
                system = next((msg["content"] for msg in conv if msg["role"] == "system"), "")
                user = next((msg["content"] for msg in conv if msg["role"] == "user"), "")
                assistant = next((msg["content"] for msg in conv if msg["role"] == "assistant"), "")
                
                input_text = f"### Instruction:\n{system}\n### Instruction:\n{user}\n### Response:"
                target_text = f"### Instruction:\n{system}\n### Instruction:\n{user}\n### Response:\n{assistant}"
                
                input_texts.append(input_text)
                target_texts.append(target_text)
        
        # Tokenize inputs
        model_inputs = self.tokenizer(
            input_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        # Tokenize targets
        labels = self.tokenizer(
            target_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )["input_ids"]
        
        # Create labels, replacing padding tokens with -100
        labels_with_ignore_index = labels.clone()
        labels_with_ignore_index[labels == self.tokenizer.pad_token_id] = -100
        
        # Replace padding in input portion with -100 for loss calculation
        for i, (inp, full) in enumerate(zip(model_inputs["input_ids"], labels)):
            # Find end of input by comparing with full target
            input_len = len(inp)
            for j in range(min(input_len, len(full))):
                if j < input_len:
                    labels_with_ignore_index[i, j] = -100
        
        model_inputs["labels"] = labels_with_ignore_index
        return model_inputs



In [56]:
def fine_tune_model(model, tokenizer, dataset, output_dir="fine_tuned_model"):
    """Fine-tune the model using LoRA."""
    logger.info("Starting fine-tuning process")
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Define training arguments optimized for 4GB VRAM
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,              # Minimal epochs
        per_device_train_batch_size=1,   # Minimal batch size
        gradient_accumulation_steps=16,  # Increased for stability with small batches
        learning_rate=2e-4,              
        weight_decay=0.01,
        warmup_ratio=0.03,               # Shorter warmup to save time
        logging_steps=1,
        save_strategy="epoch",
        save_total_limit=1,              # Keep only the best model
        fp16=True,                       # Use mixed precision
        report_to="none",                # Disable reporting to save memory
        push_to_hub=False,
        gradient_checkpointing=True,     # Enable gradient checkpointing
        optim="paged_adamw_8bit",        # Use 8-bit optimizer for memory savings
        max_grad_norm=0.3,               # Reduce gradient norm for stability
        dataloader_num_workers=0,        # No parallel loading
        dataloader_pin_memory=False,     # Disable pinned memory
        max_steps=50,                    # Limit training steps
    )
    
    # Define trainer with the custom data collator
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=MedicalDataCollator(tokenizer, max_length=400),  # Reduced max length
    )
    
    # Start training
    logger.info("Starting training...")
    trainer.train()
    
    # Save the fine-tuned model
    logger.info(f"Saving fine-tuned model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    return model, tokenizer


In [57]:
def process_questions_file(input_csv, output_csv, model, tokenizer, batch_size=8):
    """Process the questions file with the fine-tuned model."""
    import time
    from concurrent.futures import ThreadPoolExecutor
    
    start_time = time.time()
    
    try:
        df = pd.read_csv(input_csv)
        logger.info(f"Loaded dataset with {len(df)} questions")
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return
    
    answerer = MedicalQuestionAnswerer(model, tokenizer)

    if 'original_answer' not in df.columns:
        df['original_answer'] = ""
    if 'faq_answer' not in df.columns:
        df['faq_answer'] = ""

    last_processed = 0
    for i, row in df.iterrows():
        if pd.notna(row['original_answer']) and row['original_answer'] != "":
            last_processed = i

    if last_processed > 0:
        logger.info(f"Resuming from question {last_processed+1}")

    save_interval = 50

    # Using ThreadPoolExecutor for concurrent processing of questions in batches
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        
        for i in range(last_processed, len(df), batch_size):
            batch_end = min(i + batch_size, len(df))
            batch_df = df.iloc[i:batch_end].copy()
            
            for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {i//batch_size + 1}/{(len(df)-last_processed)//batch_size + 1}"):
                if pd.notna(row['original_answer']) and row['original_answer'] != "":
                    continue
                
                original_question = row['original_question']
                generated_question = row['generated_question'] if 'generated_question' in row else None
                
                futures.append(executor.submit(answerer.answer_question, original_question))
                
                if generated_question:
                    futures.append(executor.submit(answerer.answer_question, generated_question))
                
                if idx % save_interval == 0:
                    df.to_csv(output_csv, index=False)

            # Collect results from futures
            result_idx = 0
            for future in tqdm(futures):
                try:
                    result = future.result()
                    # Update answers in dataframe - need to handle this better with specific indices
                    if result_idx % 2 == 0:  # Even indices are original questions
                        df.at[last_processed + result_idx//2, 'original_answer'] = result
                    else:  # Odd indices are generated questions
                        df.at[last_processed + result_idx//2, 'faq_answer'] = result
                    result_idx += 1
                except Exception as e:
                    logger.error(f"Error processing question: {e}")

            df.to_csv(output_csv, index=False)

            elapsed = time.time() - start_time
            questions_processed = batch_end - last_processed
            avg_time_per_q = elapsed / max(1, questions_processed)
            remaining_qs = len(df) - batch_end
            
            est_time_remaining = avg_time_per_q * remaining_qs
            
            logger.info(f"Processed {batch_end}/{len(df)} questions. "
                        f"Avg: {avg_time_per_q:.2f}s per question. "
                        f"Est. remaining: {est_time_remaining/60:.1f} minutes")

            # Clear GPU memory
            torch.cuda.empty_cache()
            gc.collect()
            futures = []  # Reset futures for the next batch

    df.to_csv(output_csv, index=False)
    total_time = time.time() - start_time
    logger.info(f"Completed in {total_time/60:.1f} minutes. Generated answers saved to {output_csv}")


In [58]:
class MedicalQuestionAnswerer:
    """Class to answer medical questions using a medical model."""
    
    def __init__(self, model, tokenizer):
        """Initialize with pre-loaded model and tokenizer."""
        self.model = model
        self.tokenizer = tokenizer
        
        # Get the device where the model is located
        self.device = next(model.parameters()).device
        
        # Set up model for inference
        self.model.eval()
        
        # Pre-compile prompt template for speed
        self.sys_message = "You are an AI Medical Assistant. Give brief answers as a medical professional."

    def answer_question(self, question):
        """Generate an answer for a given question."""
        try:
            # Clear cache before inference
            torch.cuda.empty_cache()
            
            # Check if apply_chat_template is available
            if hasattr(self.tokenizer, "apply_chat_template"):
                messages = [
                    {"role": "system", "content": self.sys_message},
                    {"role": "user", "content": question}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                # Fallback to manual formatting
                prompt = f"<|system|>\n{self.sys_message}\n</s>\n<|user|>\n{question}\n</s>\n<|assistant|>"
            
            # Tokenize directly to the correct device
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
           
            with torch.no_grad():
                # Use memory-efficient generation settings
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=100,   # Reduced token count
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,      
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # Extract the generated text
            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract only the assistant's response
            if "<|assistant|>" in response_text:
                answer = response_text.split("<|assistant|>")[-1].strip()
            elif "### Response:" in response_text:
                answer = response_text.split("### Response:")[-1].strip()
            else:
                answer = response_text.split(prompt)[-1].strip()
            
            return answer
        except Exception as e:
            logger.error(f"Error generating answer: {str(e)}")
            return f"Error generating answer: {str(e)}"
        finally:
            # Clean up memory after generation
            torch.cuda.empty_cache()
            gc.collect()


In [59]:
def process_questions_file(input_csv, output_csv, model, tokenizer, batch_size=1):  # Single question at a time
    """Process the questions file with the fine-tuned model."""
    start_time = time.time()
    
    try:
        df = pd.read_csv(input_csv)
        logger.info(f"Loaded dataset with {len(df)} questions")
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return
    
    answerer = MedicalQuestionAnswerer(model, tokenizer)

    if 'original_answer' not in df.columns:
        df['original_answer'] = ""
    if 'faq_answer' not in df.columns:
        df['faq_answer'] = ""

    last_processed = 0
    for i, row in df.iterrows():
        if pd.notna(row['original_answer']) and row['original_answer'] != "":
            last_processed = i

    if last_processed > 0:
        logger.info(f"Resuming from question {last_processed+1}")

    save_interval = 2  # Save more frequently
    
    # Process questions in smaller batches (single-threaded for stability)
    for i in range(last_processed, len(df), batch_size):
        batch_end = min(i + batch_size, len(df))
        batch_df = df.iloc[i:batch_end].copy()
        
        for idx, row in batch_df.iterrows():
            try:
                # Process original question if not already answered
                if not pd.notna(row['original_answer']) or row['original_answer'] == "":
                    original_question = row['original_question']
                    answer = answerer.answer_question(original_question)
                    df.at[idx, 'original_answer'] = answer
                    logger.info(f"Processed question {idx}")
                
                # Process generated question if available and not already answered
                if 'generated_question' in row and pd.notna(row['generated_question']):
                    if not pd.notna(row['faq_answer']) or row['faq_answer'] == "":
                        generated_question = row['generated_question']
                        faq_answer = answerer.answer_question(generated_question)
                        df.at[idx, 'faq_answer'] = faq_answer
                
                # Save progress frequently
                if idx % save_interval == 0:
                    df.to_csv(output_csv, index=False)
                    logger.info(f"Saved progress at index {idx}")
                    torch.cuda.empty_cache()
            except Exception as e:
                logger.error(f"Error processing question {idx}: {str(e)}")
                continue
        
        # Save progress after each batch
        df.to_csv(output_csv, index=False)
        
        elapsed = time.time() - start_time
        questions_processed = batch_end - last_processed
        avg_time_per_q = elapsed / max(1, questions_processed)
        remaining_qs = len(df) - batch_end
        
        est_time_remaining = avg_time_per_q * remaining_qs
        
        logger.info(f"Processed {batch_end}/{len(df)} questions. "
                    f"Avg: {avg_time_per_q:.2f}s per question. "
                    f"Est. remaining: {est_time_remaining/60:.1f} minutes")

        # Clear GPU memory
        torch.cuda.empty_cache()
        gc.collect()

    df.to_csv(output_csv, index=False)
    total_time = time.time() - start_time
    logger.info(f"Completed in {total_time/60:.1f} minutes. Generated answers saved to {output_csv}")


In [60]:
def main():
    # Import necessary libraries at the beginning
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
    from transformers import BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
    
    input_file = "T5_FAQS1.csv"  
    output_file = "medical_answers_finetuned.csv"
    fine_tuned_model_dir = "fine_tuned_medical_model"
    
    # Create offload directory
    os.makedirs("offload_folder", exist_ok=True)
    
    try:
        # Clear any existing cached memory
        torch.cuda.empty_cache()
        gc.collect()
        
        # Load the base model and tokenizer
        base_model_name = "malhajar/meditron-7b-chat"
        try:
            model, tokenizer = load_medical_model(base_model_name)
        except Exception as e:
            logger.warning(f"Failed to load 7B model: {e}")
            logger.info("Falling back to smaller model...")
            model, tokenizer = load_medical_model("malhajar/meditron-3b-chat")
        
        # Process dataset for fine-tuning
        dataset = process_dataset_for_fine_tuning(input_file)
        
        if dataset:
            # Apply PEFT to the model - THIS IS THE KEY CHANGE
            model = apply_peft_to_model(model)
            
            # Fine-tune the model with PEFT applied
            model, tokenizer = fine_tune_model(model, tokenizer, dataset, fine_tuned_model_dir)
            
            # Free up memory before inference
            torch.cuda.empty_cache()
            gc.collect()
            
            # Process questions using the fine-tuned model  
            process_questions_file(input_file, output_file, model, tokenizer)
        else:
            logger.error("Could not prepare dataset for fine-tuning. Check if the required columns exist.")
    except Exception as e:
        logger.error(f"An error occurred in the main process: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

if __name__ == "__main__":
    main()




INFO:__main__:Loading model: malhajar/meditron-7b-chat
INFO:__main__:Setting chat template for Meditron
INFO:__main__:Falling back to smaller model...
INFO:__main__:Loading model: malhajar/meditron-3b-chat
ERROR:__main__:An error occurred in the main process: malhajar/meditron-3b-chat is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`
ERROR:__main__:Traceback (most recent call last):
  File "/tmp/ipykernel_177421/1019782692.py", line 22, in main
    model, tokenizer = load_medical_model(base_model_name)
  File "/tmp/ipykernel_177421/4154498051.py", line 32, in load_medical_model
    quantization_config = BitsAndBytesConfig(
  File "/home/vjti/.local/lib/python3.10/site-packages/transformers/utils/quantization_config.py", line 433, in __init__
    self.post_init()
  

In [61]:
import os
import torch
import logging
import gc
import pandas as pd
import numpy as np
from datasets import Dataset
import time
from tqdm import tqdm
from typing import Dict, List, Union
import psutil # Import psutil to check system RAM
import traceback # For detailed error logging

# Try importing necessary libraries early
try:
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        Trainer,
        BitsAndBytesConfig,
        DataCollatorForSeq2Seq # Keep import, though using custom collator
    )
    # Import PeftModel explicitly for type checking
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
    import accelerate # Ensure accelerate is available
    import bitsandbytes # Ensure bitsandbytes is available
except ImportError as e:
    print(f"Error importing libraries: {e}")
    print("Please ensure transformers, peft, datasets, accelerate, bitsandbytes, and psutil are installed.")
    exit()

# --- Configuration & Setup ---

# Set environment variables for memory optimization
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Control GPU visibility if needed

# Initialize logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Check for GPU availability and set device
if torch.cuda.is_available():
    device = torch.device("cuda")
    logger.info(f"CUDA available. Using device: {device}")
    try:
        gpu_index = torch.cuda.current_device()
        gpu_name = torch.cuda.get_device_name(gpu_index)
        logger.info(f"CUDA device name: {gpu_name}")
        t = torch.cuda.get_device_properties(gpu_index).total_memory
        r = torch.cuda.memory_reserved(gpu_index)
        a = torch.cuda.memory_allocated(gpu_index)
        f = r - a
        logger.info(f"Initial GPU Memory (Bytes): Total={t}, Reserved={r}, Allocated={a}")
        logger.info(f"Initial GPU Memory (GB): Total={t/1e9:.2f}GB, Reserved={r/1e9:.2f}GB, Allocated={a/1e9:.2f}GB, FreeReserved={f/1e9:.2f}GB")
    except Exception as e:
        logger.error(f"Could not get GPU details: {e}")
else:
    device = torch.device("cpu")
    logger.info("CUDA not available. Using device: CPU")

# Global configuration
INPUT_CSV = "T5_FAQS1.csv"
OUTPUT_CSV = "medical_answers_finetuned_v5.csv" # Incremented version
FINE_TUNED_MODEL_DIR = "fine_tuned_medical_model_v5" # Incremented version
BASE_MODEL_NAME = "malhajar/meditron-7b-chat"
MAX_DATASET_EXAMPLES = 50 # Limit examples for faster testing/demo
MAX_TRAINING_STEPS = 50   # Limit training steps for faster testing/demo
MAX_SEQ_LENGTH_COLLATOR = 128 # Max length for sequences during training (affects memory)
MAX_SEQ_LENGTH_INFERENCE = 500 # Max length for input sequence during inference
INFERENCE_BATCH_SIZE = 1 # Process one question at a time for inference
SAVE_INTERVAL_INFERENCE = 5 # Save progress every N questions during inference

# --- Function Definitions ---

def clear_gpu_memory():
    """Clears GPU memory."""
    logger.debug("Clearing GPU Cache...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        logger.debug("GPU Cache Cleared.")
    else:
        logger.debug("No GPU available, skipping cache clearing.")

def load_base_model_and_tokenizer(model_name):
    """Loads the base quantized model and tokenizer with explicit memory limits."""
    logger.info(f"Attempting to load base model: {model_name}")
    clear_gpu_memory()

    # 1. Load Tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=True
        )
        logger.info("Tokenizer loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer for {model_name}: {e}")
        raise

    # 2. Set Chat Template & Padding Token
    if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
        logger.info("Setting chat template for Meditron model.")
        # Basic template structure - adjust if needed based on model card
        tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'user' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'assistant' %}### Response:
{{ message['content'] }}
{% endif %}{% if loop.last and add_generation_prompt %}### Response:
{% endif %}{% endfor %}"""
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id # Explicitly set ID

    # 3. Configure Quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    logger.info(f"Using quantization config: {quantization_config}")

    # 4. Define Memory Limits (Using INTEGER key for GPU)
    max_memory = {}
    if torch.cuda.is_available():
        gpu_index = 0 # Assuming device 0
        total_vram_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
        # Leave ~2GB buffer for safety, adjust if needed
        gpu_mem_limit_bytes = total_vram_bytes - int(2 * 1024**3)
        max_memory[gpu_index] = f"{gpu_mem_limit_bytes // (1024**2)}MiB" # Use integer index
        logger.info(f"Calculated GPU memory limit for device {gpu_index}: {max_memory[gpu_index]}")

    total_ram_bytes = psutil.virtual_memory().total
    # Limit CPU RAM usage to 80% to avoid system freeze
    cpu_mem_limit_bytes = int(total_ram_bytes * 0.80)
    max_memory['cpu'] = f"{cpu_mem_limit_bytes // (1024**2)}MiB"

    logger.info(f"Setting max_memory for accelerate: {max_memory}")

    # 5. Load Model
    try:
        logger.info("Loading 4-bit quantized model with device_map='auto' and max_memory...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto", # Automatically distribute model layers across devices
            quantization_config=quantization_config,
            torch_dtype=torch.float16, # Use float16 for memory efficiency
            low_cpu_mem_usage=True, # Try to load shards sequentially to save CPU RAM
            max_memory=max_memory, # Apply memory limits per device
            trust_remote_code=True
        )
        logger.info("Base model loaded successfully.")
        logger.info(f"Model device map: {model.hf_device_map}")
    except Exception as e:
        logger.error(f"Failed to load base model {model_name} even with max_memory: {e}")
        logger.error(traceback.format_exc()) # Log full traceback
        raise

    return model, tokenizer

def apply_peft_to_model(model, tokenizer):
    """Applies PEFT/LoRA adapters to the loaded base model."""
    logger.info("Applying PEFT/LoRA adapters to the model...")
    clear_gpu_memory()

    # 1. Prepare model for k-bit training
    logger.info("Preparing model for k-bit training...")
    # Ensure gradient checkpointing is enabled here if desired for training
    try:
         # use_gradient_checkpointing=True can sometimes cause issues, let TrainingArguments handle it primarily.
         # Set it to False here, rely on TrainingArguments.gradient_checkpointing=True
         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
         logger.info("prepare_model_for_kbit_training successful.")
    except Exception as e:
         logger.error(f"Error during prepare_model_for_kbit_training: {e}")
         raise

    # 2. Ensure input embeddings require gradients (can be crucial for some models/setups)
    if hasattr(model, "enable_input_require_grads"):
        logger.info("Enabling input require grads using model.enable_input_require_grads().")
        model.enable_input_require_grads()
    else:
        # Fallback method if the direct function isn't available
        logger.info("Attempting to enable input require grads using forward hook.")
        try:
            def make_inputs_require_grad(module, input, output):
                 if isinstance(output, torch.Tensor) and output.is_floating_point():
                     output.requires_grad_(True)
            embed_module = model.get_input_embeddings()
            if embed_module:
                 embed_module.register_forward_hook(make_inputs_require_grad)
                 logger.info("Gradient hook attached to input embeddings.")
            else:
                 logger.warning("Could not find input embeddings module.")
        except Exception as e:
             logger.warning(f"Failed to attach gradient hook: {e}. This might be okay.")

    # 3. Define LoRA configuration
    lora_config = LoraConfig(
        r=8, # Rank of the update matrices (lower value = fewer parameters)
        lora_alpha=16, # LoRA scaling factor
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # Modules to apply LoRA to (check model architecture if needed)
        lora_dropout=0.05, # Dropout probability for LoRA layers
        bias="none", # Whether to train biases ('none', 'all', or 'lora_only')
        task_type="CAUSAL_LM" # Task type for PEFT
    )
    logger.info(f"Using LoRA config: {lora_config}")

    # 4. Apply LoRA using get_peft_model
    try:
        peft_model = get_peft_model(model, lora_config)
        logger.info("PEFT model created successfully using get_peft_model.")
    except Exception as e:
        logger.error(f"Failed to apply PEFT to the model using get_peft_model: {e}")
        raise

    # 5. Verify trainable parameters
    peft_model.print_trainable_parameters()

    # 6. <<< ADDED STEP: Explicitly mark as PEFT model >>>
    # This might help older/buggy Trainer checks, although usually not needed.
    if not isinstance(peft_model, PeftModel):
         logger.warning("get_peft_model did not return a PeftModel instance!")
    else:
         logger.info("Model is instance of PeftModel. Setting is_peft_model=True attribute just in case.")
         # Use setattr for safety in case attribute doesn't exist on all versions
         setattr(peft_model, 'is_peft_model', True)

    logger.info(f"Model type after get_peft_model: {type(peft_model)}")
    return peft_model

def process_dataset_for_fine_tuning(csv_file):
    """Loads and formats the dataset for fine-tuning."""
    logger.info(f"Processing dataset from {csv_file}")

    try:
        df = pd.read_csv(csv_file)
        logger.info(f"Loaded dataframe with {len(df)} rows.")
    except FileNotFoundError:
        logger.error(f"Dataset file not found: {csv_file}")
        return None
    except Exception as e:
        logger.error(f"Error loading dataset CSV: {e}")
        return None

    # Ensure required columns exist
    required_cols = ['original_question', 'ideal_answer']
    if not all(col in df.columns for col in required_cols):
        logger.error(f"Dataset missing required columns: {required_cols}. Found: {df.columns.tolist()}")
        return None

    # Drop rows where essential columns are missing
    df.dropna(subset=required_cols, inplace=True)
    logger.info(f"Rows after dropping NA in required columns: {len(df)}")

    if df.empty:
        logger.error("No valid data remaining after filtering.")
        return None

    # Limit dataset size if configured
    if len(df) > MAX_DATASET_EXAMPLES:
        logger.info(f"Limiting dataset to {MAX_DATASET_EXAMPLES} examples (from {len(df)})")
        df = df.sample(n=MAX_DATASET_EXAMPLES, random_state=42) # Use sampling with fixed state

    logger.info(f"Using {len(df)} examples for fine-tuning.")

    # Define the system prompt
    system_prompt = "You are an AI Medical Assistant. Provide accurate and concise answers to medical questions based on the context provided."

    # Format data into conversation structure
    formatted_data = []
    for _, row in df.iterrows():
        question = str(row['original_question']).strip()
        answer = str(row['ideal_answer']).strip()
        if not question or not answer: # Skip rows with empty question or answer
            continue
        # Structure expected by tokenizer.apply_chat_template
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer} # The expected completion
        ]
        formatted_data.append({"conversation": conversation})

    if not formatted_data:
         logger.error("No valid conversation pairs found after formatting.")
         return None

    # Create Hugging Face Dataset object
    try:
        dataset = Dataset.from_list(formatted_data)
        logger.info("HuggingFace Dataset created successfully.")
    except Exception as e:
        logger.error(f"Failed to create HuggingFace Dataset: {e}")
        return None

    return dataset

class MedicalChatDataCollator:
    """Formats conversations and prepares inputs/labels for Causal LM fine-tuning."""
    def __init__(self, tokenizer, max_length=MAX_SEQ_LENGTH_COLLATOR):
        self.tokenizer = tokenizer
        self.max_length = max_length
        if not hasattr(self.tokenizer, "apply_chat_template"):
             raise ValueError("Tokenizer must have `apply_chat_template` method.")
        if self.tokenizer.pad_token_id is None:
            logger.warning("Collator: Tokenizer pad_token_id is None. Using eos_token_id.")
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def __call__(self, examples: List[Dict[str, List[Dict[str, str]]]]) -> Dict[str, torch.Tensor]:
        batch_conversations = [ex["conversation"] for ex in examples]

        try:
             # Prepare Inputs (prompt part only: system + user)
             # We add `add_generation_prompt=True` which typically adds the assistant prompt start (e.g., "### Response:")
             input_formatted = [
                  self.tokenizer.apply_chat_template(conv[:-1], tokenize=False, add_generation_prompt=True)
                  for conv in batch_conversations
             ]
             # Tokenize the inputs
             model_inputs = self.tokenizer(
                  input_formatted,
                  max_length=self.max_length,
                  padding="max_length", # Pad to max_length
                  truncation=True,
                  return_tensors="pt"
             )

             # Prepare Labels (full conversation: system + user + assistant)
             # We do NOT add generation prompt here, as we want the full sequence including the assistant's answer.
             labels_formatted = [
                  self.tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                  for conv in batch_conversations
             ]
             # Tokenize the full conversation to get the target labels
             labels = self.tokenizer(
                  labels_formatted,
                  max_length=self.max_length,
                  padding="max_length", # Pad to max_length
                  truncation=True,
                  return_tensors="pt"
             )["input_ids"]

             # --- Label Masking ---
             # We only want to compute loss on the assistant's response tokens.
             # Mask tokens belonging to the system prompt, user query, and padding.
             masked_labels = labels.clone()

             for i in range(len(model_inputs["input_ids"])):
                  # Calculate the length of the input prompt (system + user + assistant prompt start)
                  # Use attention_mask sum, as input_ids might contain padding AFTER truncation
                  input_ids_len = model_inputs["attention_mask"][i].sum().item()

                  # Mask all tokens up to the end of the input prompt
                  masked_labels[i, :input_ids_len] = -100 # -100 is the standard ignore index for loss calculation

                  # Also mask padding tokens in the labels
                  # Find indices where labels are the padding token ID
                  label_pad_indices = (labels[i] == self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
                  if len(label_pad_indices) > 0:
                      # Mask all padding tokens (usually at the end)
                      # Get the index of the first padding token
                      first_pad_index = label_pad_indices[0].item()
                      masked_labels[i, first_pad_index:] = -100

             # Check if any example ended up with all labels masked (problematic)
             if torch.all(masked_labels == -100, dim=1).any():
                  logger.warning("Warning: An example has all labels masked. This might indicate issues with sequence lengths, truncation, or the collator logic.")
                  # Consider adding more detailed logging here if this happens frequently

             model_inputs["labels"] = masked_labels
             return model_inputs
        except Exception as e:
             logger.error(f"Error in Data Collator: {e}")
             logger.error(traceback.format_exc())
             # Return empty batch or raise error? Returning empty might cause issues later.
             raise # Let the error propagate

def fine_tune_model(model, tokenizer, dataset, output_dir):
    """Fine-tunes the PEFT model using the Trainer API."""
    logger.info("Starting fine-tuning process...")
    clear_gpu_memory()

    os.makedirs(output_dir, exist_ok=True)

    # Training Arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1, # Single epoch for quick demo/test
        per_device_train_batch_size=1, # Small batch size due to model size/memory
        gradient_accumulation_steps=16, # Accumulate gradients to simulate larger batch size (effective batch size = 1 * 16 = 16)
        learning_rate=2e-4, # Common learning rate for LoRA
        weight_decay=0.01, # Regularization
        warmup_ratio=0.03, # Warmup steps as a fraction of total steps
        logging_steps=5, # Log metrics every 5 steps
        save_strategy="steps", # Save checkpoints based on steps
        save_steps=max(1, MAX_TRAINING_STEPS // 2), # Save halfway through (adjust as needed)
        save_total_limit=1, # Keep only the latest checkpoint
        fp16=torch.cuda.is_available(), # Enable mixed-precision training if CUDA is available
        gradient_checkpointing=True, # Use gradient checkpointing to save memory (at cost of compute) - RELY ON THIS
        optim="paged_adamw_8bit", # Use paged AdamW optimizer for memory efficiency with QLoRA
        max_grad_norm=0.3, # Gradient clipping
        dataloader_num_workers=0, # Set to 0 or small number, can cause issues otherwise
        dataloader_pin_memory=False, # Often False is better with device_map='auto'
        max_steps=MAX_TRAINING_STEPS, # Limit total training steps
        report_to="none", # Disable external reporting (like wandb) for simplicity
        push_to_hub=False, # Don't push to Hugging Face Hub
        remove_unused_columns=True, # Let Trainer remove columns not used by the model
        # Added for potential stability
        # ddp_find_unused_parameters=False # Sometimes needed with PEFT/gradient checkpointing
    )
    logger.info(f"Using Training Arguments: {training_args}")

    # Data Collator Instance
    data_collator = MedicalChatDataCollator(tokenizer, max_length=MAX_SEQ_LENGTH_COLLATOR)

    # === Debugging Logs Before Trainer Init ===
    logger.info(f"--- Preparing to initialize Trainer ---")
    logger.info(f"Model object ID: {id(model)}")
    logger.info(f"Model type passed to fine_tune_model: {type(model)}")
    logger.info(f"Is model instance of PeftModel? {isinstance(model, PeftModel)}")
    logger.info(f"Does model have 'peft_config' attribute? {hasattr(model, 'peft_config')}")
    logger.info(f"Does model have 'is_peft_model' attribute set? {getattr(model, 'is_peft_model', 'Not Set')}")
    if hasattr(model, 'hf_device_map'):
         logger.info(f"Model device map: {model.hf_device_map}")
    else:
         logger.info("Model does not have 'hf_device_map' attribute.")
    if hasattr(model, 'is_quantized'):
         logger.info(f"Model is_quantized: {model.is_quantized}")
    else:
         logger.info("Model does not have 'is_quantized' attribute.")
    # Check base model properties if it's a PeftModel
    if isinstance(model, PeftModel) and hasattr(model, 'base_model'):
         base = model.base_model
         logger.info(f"Base model type: {type(base)}")
         logger.info(f"Base model is_quantized: {getattr(base, 'is_quantized', 'Not Set')}")
         logger.info(f"Base model has quantization_config: {hasattr(base.config, 'quantization_config')}")
         if hasattr(base.config, 'quantization_config'):
              logger.info(f"Base model quantization_config type: {type(base.config.quantization_config)}")
    logger.info(f"--- End Pre-Trainer Init Logs ---")
    # === End Debugging Logs ===

    # Trainer Initialization
    trainer = None # Initialize to None
    try:
        trainer = Trainer(
            model=model, # Should be the PeftModel instance from apply_peft_to_model
            args=training_args,
            train_dataset=dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            # callbacks=[...] # Add custom callbacks if needed
        )
        logger.info("Trainer initialized successfully.")
    except ValueError as ve:
        logger.error(f"ValueError during Trainer initialization: {ve}")
        logger.error("This likely means the Trainer still doesn't recognize the model as PEFT-compatible, possibly due to issues with quantization or PEFT setup.")
        logger.error(traceback.format_exc())
        return None, tokenizer # Cannot proceed if Trainer fails
    except Exception as e:
        logger.error(f"Unexpected error during Trainer initialization: {e}")
        logger.error(traceback.format_exc())
        return None, tokenizer

    # Training
    logger.info("Starting training...")
    trained_model = None
    try:
        # Start training
        train_result = trainer.train()
        logger.info("Training completed.")
        # Log metrics
        metrics = train_result.metrics
        logger.info(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trained_model = model # Assign model if training successful
    except Exception as e:
        logger.error(f"Error during training: {e}")
        logger.error(traceback.format_exc())
        # Attempt to save state even if training failed mid-way
        try:
            logger.warning("Attempting to save model state after training error...")
            save_path = f"{output_dir}/error_save"
            if trainer is not None:
                 trainer.save_model(output_dir=save_path)
                 tokenizer.save_pretrained(save_path)
                 logger.info(f"Model state saved to {save_path}")
            else:
                 logger.error("Trainer was not initialized, cannot save model.")
        except Exception as save_e:
            logger.error(f"Could not save model after error: {save_e}")
        # Return None for model if training failed critically
        return None, tokenizer

    # Save final model adapters (LoRA weights) and tokenizer
    if trained_model is not None:
        logger.info(f"Saving fine-tuned PEFT adapters and tokenizer to {output_dir}")
        try:
            # save_model() with PEFT model saves only the adapters by default
            trainer.save_model(output_dir)
            # Save the tokenizer configuration as well
            tokenizer.save_pretrained(output_dir)
            logger.info("Model adapters and tokenizer saved successfully.")
        except Exception as e:
            logger.error(f"Error saving final model/tokenizer: {e}")
            pass # Return the model object anyway, even if saving failed
    else:
         logger.warning("Training did not complete successfully, final model not saved via trainer.save_model.")

    # Cleanup
    del trainer
    clear_gpu_memory()

    return trained_model, tokenizer # Return the model (potentially with trained adapters)

class MedicalQuestionAnswerer:
    """Generates answers using the fine-tuned PEFT model."""
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        try:
             # If using device_map, the device might be complex, find a parameter's device
             self.device = next(model.parameters()).device
        except Exception:
             logger.warning("Could not automatically determine model device. Assuming CPU or first CUDA device if available.")
             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.model.eval() # Set the model to evaluation mode
        self.sys_message = "You are an AI Medical Assistant. Give concise and accurate answers." # System prompt for inference
        logger.info(f"Question Answerer initialized on device: {self.device}")

        # Ensure pad token is set for generation
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    @torch.inference_mode() # Disable gradient calculations for inference
    def answer_question(self, question):
        """Generates an answer for a single question."""
        if not question or not isinstance(question, str):
             logger.warning(f"Invalid question provided: '{question}'. Skipping.")
             return "Error: Invalid question provided."

        logger.debug(f"Answering question: '{question[:50]}...'")
        clear_gpu_memory() # Clear cache before generating

        try:
            # Prepare input using the chat template
            messages = [
                {"role": "system", "content": self.sys_message},
                {"role": "user", "content": question}
            ]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True, # Add the prompt for the assistant's turn
                return_tensors="pt"
            ).to(self.device) # Move inputs to the model's device

            input_length = inputs.shape[1]

            # Optional: Check if input exceeds inference length limit
            if input_length >= MAX_SEQ_LENGTH_INFERENCE:
                 logger.warning(f"Input sequence length ({input_length}) is >= MAX_SEQ_LENGTH_INFERENCE ({MAX_SEQ_LENGTH_INFERENCE}). Input might be truncated by model implicitly or cause issues.")

            # --- Generate the answer ---
            outputs = self.model.generate(
                input_ids=inputs,
                # *** MODIFIED PARAMETER FOR LONGER ANSWERS ***
                max_new_tokens=400, # Increased from 50 to allow for ~7-8 sentences
                temperature=0.6,    # Controls randomness. Lower is more deterministic.
                top_p=0.9,          # Nucleus sampling parameter
                do_sample=True,     # Enable sampling strategies (temperature, top_p)
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id, # Ensure generation stops at EOS
                use_cache=True      # Speed up generation
                # Add other generation parameters if needed (e.g., repetition_penalty)
            )

            # Decode only the newly generated tokens
            output_tokens = outputs[0, input_length:]
            answer = self.tokenizer.decode(output_tokens, skip_special_tokens=True).strip()

            logger.debug(f"Generated answer: '{answer[:100]}...'")
            return answer

        except Exception as e:
            logger.error(f"Error generating answer for question '{question[:50]}...': {str(e)}")
            logger.debug(traceback.format_exc()) # Log traceback for debugging
            return f"Error: Processing failed during generation."
        finally:
            clear_gpu_memory() # Clear cache after generating

def process_questions_file(input_csv, output_csv, model, tokenizer):
    """Processes questions from CSV sequentially with resuming and saving."""
    logger.info(f"Starting sequential question processing from {input_csv}...")
    start_time = time.time()

    try:
        df = pd.read_csv(input_csv)
        logger.info(f"Loaded {len(df)} questions from {input_csv}.")
    except FileNotFoundError:
        logger.error(f"Input CSV not found: {input_csv}")
        return
    except Exception as e:
        logger.error(f"Error loading input CSV {input_csv}: {e}")
        return

    # Initialize the question answerer class
    answerer = MedicalQuestionAnswerer(model, tokenizer)

    # Ensure output columns exist, initialize with NA if not
    if 'original_answer' not in df.columns: df['original_answer'] = pd.NA
    # Check if 'faq_answer' should exist based on input columns
    if 'generated_question' in df.columns and 'faq_answer' not in df.columns:
        df['faq_answer'] = pd.NA

    # --- Resuming Logic ---
    last_processed_index = -1
    if os.path.exists(output_csv):
        logger.info(f"Output file {output_csv} found. Attempting to resume.")
        try:
            df_existing = pd.read_csv(output_csv)
            # Find the last row where 'original_answer' is NOT NA
            valid_indices = df_existing['original_answer'].dropna().index
            if not valid_indices.empty:
                 last_processed_index = valid_indices[-1]
                 logger.info(f"Resuming from index {last_processed_index + 1}.")
                 # Update the current dataframe with already processed answers from the existing file
                 # Only update up to the last processed index to avoid overwriting potentially newer data
                 df.update(df_existing.iloc[:last_processed_index+1])
            else:
                logger.info("No previously processed answers found in output file, starting fresh.")
        except pd.errors.EmptyDataError:
             logger.warning(f"Output file {output_csv} is empty. Starting fresh.")
             last_processed_index = -1
        except Exception as e:
            logger.warning(f"Could not read or parse existing output file {output_csv}: {e}. Starting fresh.")
            last_processed_index = -1

    questions_processed_since_resume = 0
    total_to_process = len(df) - (last_processed_index + 1)
    if total_to_process <= 0:
        logger.info("No new questions to process based on existing output file.")
        return # Nothing left to do

    # --- Processing Loop ---
    for idx in tqdm(range(last_processed_index + 1, len(df)), desc="Processing Questions", total=total_to_process, unit="q"):
        row = df.iloc[idx]
        row_changed = False # Flag to check if we need to save

        try:
            # Process original question if its answer is missing
            if pd.isna(df.at[idx, 'original_answer']):
                original_question = str(row['original_question']).strip() if pd.notna(row['original_question']) else None
                if original_question:
                    answer = answerer.answer_question(original_question)
                    df.at[idx, 'original_answer'] = answer
                    row_changed = True
                elif pd.notna(row['original_question']): # Handle case where question exists but was empty string
                    df.at[idx, 'original_answer'] = "Error: Missing/empty original question"
                    row_changed = True
                # If original_question was NaN, leave original_answer as NaN

            # Process generated question (if exists and its answer is missing)
            if 'generated_question' in df.columns and 'faq_answer' in df.columns and pd.isna(df.at[idx, 'faq_answer']):
                 generated_question = str(row['generated_question']).strip() if pd.notna(row['generated_question']) else None
                 if generated_question:
                      faq_answer = answerer.answer_question(generated_question)
                      df.at[idx, 'faq_answer'] = faq_answer
                      row_changed = True
                 elif pd.notna(row['generated_question']): # Handle case where generated_question exists but was empty string
                      df.at[idx, 'faq_answer'] = "Error: Missing/empty generated question"
                      row_changed = True
                 # If generated_question was NaN, leave faq_answer as NaN

            if row_changed:
                questions_processed_since_resume += 1

                # Save progress periodically or at the very end
                if (questions_processed_since_resume % SAVE_INTERVAL_INFERENCE == 0) or (idx == len(df) - 1):
                     logger.info(f"\nSaving progress at index {idx}...")
                     try:
                          df.to_csv(output_csv, index=False)
                     except Exception as save_e:
                          logger.error(f"Failed to save progress to {output_csv}: {save_e}")

        except KeyboardInterrupt:
             logger.warning("\nKeyboardInterrupt detected. Saving progress and exiting.")
             try:
                 df.to_csv(output_csv, index=False)
             except Exception as save_e:
                 logger.error(f"Failed to save progress during KeyboardInterrupt exit: {save_e}")
             raise # Re-raise interrupt
        except Exception as e:
            logger.error(f"Critical error processing index {idx}: {e}. Recording error and saving progress.")
            logger.error(traceback.format_exc())
            # Record error in the specific row that failed, if possible
            if pd.isna(df.at[idx, 'original_answer']): df.at[idx, 'original_answer'] = f"Error: Processing Failed - {e}"
            if 'faq_answer' in df.columns and pd.isna(df.at[idx, 'faq_answer']): df.at[idx, 'faq_answer'] = f"Error: Processing Failed - {e}"
            # Try to save the state including the error message
            try:
                df.to_csv(output_csv, index=False)
            except Exception as save_e:
                logger.error(f"Failed to save error state to {output_csv}: {save_e}")
            continue # Attempt to continue with the next row

    # --- Final Save ---
    logger.info("Saving final results...")
    try:
        df.to_csv(output_csv, index=False)
    except Exception as save_e:
        logger.error(f"Failed to save final results to {output_csv}: {save_e}")

    total_time = time.time() - start_time
    logger.info(f"Completed processing {total_to_process} questions. Total time: {total_time / 60:.1f} minutes.")
    logger.info(f"Generated answers saved to {output_csv}")


# --- Main Execution Logic ---

def main():
    logger.info("--- Starting Medical FAQ Fine-Tuning and Processing Script ---")

    # --- 1. Load Base Model and Tokenizer ---
    model = None
    tokenizer = None
    model_for_inference = None # This will hold the model to be used for answering

    try:
        model, tokenizer = load_base_model_and_tokenizer(BASE_MODEL_NAME)
        # Initially, the model for inference is the base model
        model_for_inference = model
    except Exception as e:
        logger.critical(f"Failed to load base model '{BASE_MODEL_NAME}'. Cannot proceed with fine-tuning or inference.")
        # load_base_model_and_tokenizer already logs traceback
        return # Exit script if base model fails

    # --- 2. Prepare Dataset for Fine-Tuning ---
    dataset = process_dataset_for_fine_tuning(INPUT_CSV)

    model_for_training = None # Initialize to None
    if dataset is None:
        logger.warning("Dataset preparation failed or resulted in no data. Skipping fine-tuning.")
    else:
        logger.info("Dataset prepared successfully.")

        # --- 3. Apply PEFT Adapters (only if dataset is valid) ---
        try:
            # We apply PEFT adapters to the 'model' object loaded earlier
            peft_model = apply_peft_to_model(model, tokenizer)
            model_for_training = peft_model # This PEFT model will be trained
            # Update the inference model to use the PEFT version IF PEFT application succeeds
            model_for_inference = peft_model
            logger.info("PEFT adapters applied successfully. Will use PEFT model for training and potentially inference.")
        except Exception as e:
            logger.error(f"Failed to apply PEFT adapters: {e}. Skipping fine-tuning.")
            logger.error(traceback.format_exc())
            model_for_training = None # Ensure this is None if PEFT fails
            # If PEFT fails, model_for_inference remains the original base 'model'

    # --- 4. Fine-tune Model (only if dataset and PEFT model are ready) ---
    # Check both dataset and model_for_training validity
    if dataset is not None and model_for_training is not None:
        logger.info("Proceeding with fine-tuning...")
        trained_model, tokenizer = fine_tune_model(model_for_training, tokenizer, dataset, FINE_TUNED_MODEL_DIR)

        if trained_model is not None:
             logger.info("Fine-tuning process completed (or attempted). Using the resulting model for inference.")
             # Update the inference model to the one returned by fine_tune_model
             # This could be the model with trained adapters or the state before a training crash
             model_for_inference = trained_model
             # Optional: cleanup the reference used just for training if different
             if model_for_training is not trained_model:
                 del model_for_training
             clear_gpu_memory()
        else:
             logger.warning("Fine-tuning function returned None (likely due to critical error). Inference will use the model state from *before* the fine_tune_model call.")
             # model_for_inference is already set to peft_model (if PEFT succeeded) or base model (if PEFT failed)
    elif dataset is None:
         logger.warning("Skipping fine-tuning because dataset preparation failed.")
         # model_for_inference remains the base model 'model' or peft_model if PEFT applied but dataset failed later
    else: # model_for_training must be None because PEFT failed
         logger.warning("Skipping fine-tuning because PEFT adapter application failed.")
         # model_for_inference remains the base model 'model'


    # --- 5. Inference Phase ---
    logger.info("--- Starting Inference Phase ---")
    if model_for_inference is None:
        # This case should ideally not be reached if base model loading succeeded
        logger.error("No valid model available for inference (should have at least the base model). Exiting.")
        return

    # Ensure the final model for inference is in evaluation mode
    model_for_inference.eval()

    # Log which model configuration is being used for inference
    logger.info(f"Preparing for inference using model type: {type(model_for_inference)}")
    if isinstance(model_for_inference, PeftModel):
         logger.info("Inference will use the PEFT model (either freshly adapted or fine-tuned).")
    else:
         logger.info("Inference will use the original BASE model (fine-tuning was skipped or failed).")

    # Run the inference process
    clear_gpu_memory()
    process_questions_file(INPUT_CSV, OUTPUT_CSV, model_for_inference, tokenizer)

    logger.info("--- Script Finished ---")

if __name__ == "__main__":
    # Basic check for psutil, often needed for memory monitoring/limits
    try:
        import psutil
    except ImportError:
        print("Error: psutil library not found. Please install it: pip install psutil")
        # Optionally exit, or let the script fail later if psutil is strictly needed
        # exit()

    main()


INFO:__main__:CUDA not available. Using device: CPU
INFO:__main__:--- Starting Medical FAQ Fine-Tuning and Processing Script ---
INFO:__main__:Attempting to load base model: malhajar/meditron-7b-chat
INFO:__main__:Tokenizer loaded successfully.
INFO:__main__:Setting chat template for Meditron model.
CRITICAL:__main__:Failed to load base model 'malhajar/meditron-7b-chat'. Cannot proceed with fine-tuning or inference.


In [62]:
import os
import torch
import logging
import gc
import pandas as pd
import numpy as np
from datasets import Dataset
import time
from tqdm import tqdm
from typing import Dict, List, Union
import psutil
import traceback

# Try importing necessary libraries early
try:
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        Trainer,
        BitsAndBytesConfig,
        DataCollatorForSeq2Seq
    )
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
    import accelerate
    import bitsandbytes
except ImportError as e:
    print(f"Error importing libraries: {e}")
    print("Please ensure transformers, peft, datasets, accelerate, bitsandbytes, and psutil are installed.")
    exit()

# --- Configuration & Setup ---

# Set environment variables for memory optimization
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Control GPU visibility if needed

# Initialize logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Check for GPU availability and set device
if torch.cuda.is_available():
    device = torch.device("cuda")
    logger.info(f"CUDA available. Using device: {device}")
    try:
        gpu_index = torch.cuda.current_device()
        gpu_name = torch.cuda.get_device_name(gpu_index)
        logger.info(f"CUDA device name: {gpu_name}")
        t = torch.cuda.get_device_properties(gpu_index).total_memory
        r = torch.cuda.memory_reserved(gpu_index)
        a = torch.cuda.memory_allocated(gpu_index)
        f = r - a
        logger.info(f"Initial GPU Memory (Bytes): Total={t}, Reserved={r}, Allocated={a}")
        logger.info(f"Initial GPU Memory (GB): Total={t/1e9:.2f}GB, Reserved={r/1e9:.2f}GB, Allocated={a/1e9:.2f}GB, FreeReserved={f/1e9:.2f}GB")
    except Exception as e:
        logger.error(f"Could not get GPU details: {e}")
else:
    device = torch.device("cpu")
    logger.info("CUDA not available. Using device: CPU")

# Global configuration
INPUT_CSV = "T5_FAQS1.csv"
OUTPUT_CSV = "medical_answers_finetuned_v7.csv" # Incremented version
FINE_TUNED_MODEL_DIR = "fine_tuned_medical_model_v6" # Incremented version
BASE_MODEL_NAME = "malhajar/meditron-7b-chat"
MAX_DATASET_EXAMPLES = 50 # Limit examples for faster testing/demo
MAX_TRAINING_STEPS = 50   # Limit training steps for faster testing/demo
MAX_SEQ_LENGTH_COLLATOR = 128 # Max length for sequences during training (affects memory)
MAX_SEQ_LENGTH_INFERENCE = 500 # Max length for input sequence during inference
INFERENCE_BATCH_SIZE = 1 # Process one question at a time for inference
SAVE_INTERVAL_INFERENCE = 5 # Save progress every N questions during inference

# --- Function Definitions ---

def clear_gpu_memory():
    """Clears GPU memory."""
    logger.debug("Clearing GPU Cache...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        logger.debug("GPU Cache Cleared.")
    else:
        logger.debug("No GPU available, skipping cache clearing.")

def load_base_model_and_tokenizer(model_name):
    """Loads the base quantized model and tokenizer with explicit memory limits."""
    logger.info(f"Attempting to load base model: {model_name}")
    clear_gpu_memory()

    # 1. Load Tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=True
        )
        logger.info("Tokenizer loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer for {model_name}: {e}")
        raise

    # 2. Set Chat Template & Padding Token
    if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
        logger.info("Setting chat template for Meditron model.")
        # Basic template structure - adjust if needed based on model card
        tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'user' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'assistant' %}### Response:
{{ message['content'] }}
{% endif %}{% if loop.last and add_generation_prompt %}### Response:
{% endif %}{% endfor %}"""
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id # Explicitly set ID

    # 3. Configure Quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    logger.info(f"Using quantization config: {quantization_config}")

    # 4. Define Memory Limits (Using INTEGER key for GPU)
    max_memory = {}
    if torch.cuda.is_available():
        gpu_index = 0 # Assuming device 0
        total_vram_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
        # Leave ~2GB buffer for safety, adjust if needed
        gpu_mem_limit_bytes = total_vram_bytes - int(2 * 1024**3)
        max_memory[gpu_index] = f"{gpu_mem_limit_bytes // (1024**2)}MiB" # Use integer index
        logger.info(f"Calculated GPU memory limit for device {gpu_index}: {max_memory[gpu_index]}")
    total_ram_bytes = psutil.virtual_memory().total
    # Limit CPU RAM usage to 80% to avoid system freeze
    cpu_mem_limit_bytes = int(total_ram_bytes * 0.80)
    max_memory['cpu'] = f"{cpu_mem_limit_bytes // (1024**2)}MiB"

    logger.info(f"Setting max_memory for accelerate: {max_memory}")

    # 5. Load Model
    try:
        logger.info("Loading 4-bit quantized model with device_map='auto' and max_memory...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto", # Automatically distribute model layers across devices
            quantization_config=quantization_config,
            torch_dtype=torch.float16, # Use float16 for memory efficiency
            low_cpu_mem_usage=True, # Try to load shards sequentially to save CPU RAM
            max_memory=max_memory, # Apply memory limits per device
            trust_remote_code=True
        )
        logger.info("Base model loaded successfully.")
        logger.info(f"Model device map: {model.hf_device_map}")
    except Exception as e:
        logger.error(f"Failed to load base model {model_name} even with max_memory: {e}")
        logger.error(traceback.format_exc()) # Log full traceback
        raise

    return model, tokenizer

def apply_peft_to_model(model, tokenizer):
    """Applies PEFT/LoRA adapters to the loaded base model."""
    logger.info("Applying PEFT/LoRA adapters to the model...")
    clear_gpu_memory()

    # 1. Prepare model for k-bit training
    logger.info("Preparing model for k-bit training...")
    # Ensure gradient checkpointing is enabled here if desired for training
    try:
         # use_gradient_checkpointing=True can sometimes cause issues, let TrainingArguments handle it primarily.
         # Set it to False here, rely on TrainingArguments.gradient_checkpointing=True
         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
         logger.info("prepare_model_for_kbit_training successful.")
    except Exception as e:
         logger.error(f"Error during prepare_model_for_kbit_training: {e}")
         raise

    # 2. Ensure input embeddings require gradients (can be crucial for some models/setups)
    if hasattr(model, "enable_input_require_grads"):
        logger.info("Enabling input require grads using model.enable_input_require_grads().")
        model.enable_input_require_grads()
    else:
        # Fallback method if the direct function isn't available
        logger.info("Attempting to enable input require grads using forward hook.")
        try:
            def make_inputs_require_grad(module, input, output):
                 if isinstance(output, torch.Tensor) and output.is_floating_point():
                     output.requires_grad_(True)
            embed_module = model.get_input_embeddings()
            if embed_module:
                 embed_module.register_forward_hook(make_inputs_require_grad)
                 logger.info("Gradient hook attached to input embeddings.")
            else:
                 logger.warning("Could not find input embeddings module.")
        except Exception as e:
             logger.warning(f"Failed to attach gradient hook: {e}. This might be okay.")

    # 3. Define LoRA configuration
    lora_config = LoraConfig(
        r=8, # Rank of the update matrices (lower value = fewer parameters)
        lora_alpha=16, # LoRA scaling factor
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # Modules to apply LoRA to (check model architecture if needed)
        lora_dropout=0.05, # Dropout probability for LoRA layers
        bias="none", # Whether to train biases ('none', 'all', or 'lora_only')
        task_type="CAUSAL_LM" # Task type for PEFT
    )
    logger.info(f"Using LoRA config: {lora_config}")

    # 4. Apply LoRA using get_peft_model
    try:
        peft_model = get_peft_model(model, lora_config)
        logger.info("PEFT model created successfully using get_peft_model.")
    except Exception as e:
        logger.error(f"Failed to apply PEFT to the model using get_peft_model: {e}")
        raise

    # 5. Verify trainable parameters
    peft_model.print_trainable_parameters()

    # 6. <<< ADDED STEP: Explicitly mark as PEFT model >>>
    # This might help older/buggy Trainer checks, although usually not needed.
    if not isinstance(peft_model, PeftModel):
         logger.warning("get_peft_model did not return a PeftModel instance!")
    else:
         logger.info("Model is instance of PeftModel. Setting is_peft_model=True attribute just in case.")
         # Use setattr for safety in case attribute doesn't exist on all versions
         setattr(peft_model, 'is_peft_model', True)

    logger.info(f"Model type after get_peft_model: {type(peft_model)}")
    return peft_model

def process_dataset_for_fine_tuning(csv_file):
    """Loads and formats the dataset for fine-tuning, focusing on ideal_answer format."""
    logger.info(f"Processing dataset from {csv_file}")

    try:
        df = pd.read_csv(csv_file)
        logger.info(f"Loaded dataframe with {len(df)} rows.")
    except FileNotFoundError:
        logger.error(f"Dataset file not found: {csv_file}")
        return None
    except Exception as e:
        logger.error(f"Error loading dataset CSV: {e}")
        return None

    # Ensure required columns exist
    required_cols = ['original_question', 'ideal_answer']
    if not all(col in df.columns for col in required_cols):
        logger.error(f"Dataset missing required columns: {required_cols}. Found: {df.columns.tolist()}")
        return None

    # Drop rows where essential columns are missing
    df.dropna(subset=required_cols, inplace=True)
    logger.info(f"Rows after dropping NA in required columns: {len(df)}")

    if df.empty:
        logger.error("No valid data remaining after filtering.")
        return None

    # Limit dataset size if configured
    if len(df) > MAX_DATASET_EXAMPLES:
        logger.info(f"Limiting dataset to {MAX_DATASET_EXAMPLES} examples (from {len(df)})")
        df = df.sample(n=MAX_DATASET_EXAMPLES, random_state=42) # Use sampling with fixed state

    logger.info(f"Using {len(df)} examples for fine-tuning.")

    # Define the system prompt - updated to match BioASQ ideal answer format
    system_prompt = "You are a biomedical expert. Provide accurate, concise, and comprehensive answers to medical questions. Your answers should be one-paragraph summaries that address the question completely, based on current medical knowledge."

    # Format data into conversation structure
    formatted_data = []
    for _, row in df.iterrows():
        question = str(row['original_question']).strip()
        answer = str(row['ideal_answer']).strip()
        if not question or not answer: # Skip rows with empty question or answer
            continue
        # Structure expected by tokenizer.apply_chat_template
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer} # The expected completion
        ]
        formatted_data.append({"conversation": conversation})

    if not formatted_data:
         logger.error("No valid conversation pairs found after formatting.")
         return None

    # Create Hugging Face Dataset object
    try:
        dataset = Dataset.from_list(formatted_data)
        logger.info("HuggingFace Dataset created successfully.")
    except Exception as e:
        logger.error(f"Failed to create HuggingFace Dataset: {e}")
        return None

    return dataset

class MedicalChatDataCollator:
    """Formats conversations and prepares inputs/labels for Causal LM fine-tuning."""
    def __init__(self, tokenizer, max_length=MAX_SEQ_LENGTH_COLLATOR):
        self.tokenizer = tokenizer
        self.max_length = max_length
        if not hasattr(self.tokenizer, "apply_chat_template"):
             raise ValueError("Tokenizer must have `apply_chat_template` method.")
        if self.tokenizer.pad_token_id is None:
            logger.warning("Collator: Tokenizer pad_token_id is None. Using eos_token_id.")
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def __call__(self, examples: List[Dict[str, List[Dict[str, str]]]]) -> Dict[str, torch.Tensor]:
        batch_conversations = [ex["conversation"] for ex in examples]

        try:
             # Prepare Inputs (prompt part only: system + user)
             # We add `add_generation_prompt=True` which typically adds the assistant prompt start (e.g., "### Response:")
             input_formatted = [
                  self.tokenizer.apply_chat_template(conv[:-1], tokenize=False, add_generation_prompt=True)
                  for conv in batch_conversations
             ]
             # Tokenize the inputs
             model_inputs = self.tokenizer(
                  input_formatted,
                  max_length=self.max_length,
                  padding="max_length", # Pad to max_length
                  truncation=True,
                  return_tensors="pt"
             )

             # Prepare Labels (full conversation: system + user + assistant)
             # We do NOT add generation prompt here, as we want the full sequence including the assistant's answer.
             labels_formatted = [
                  self.tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                  for conv in batch_conversations
             ]
             # Tokenize the full conversation to get the target labels
             labels = self.tokenizer(
                  labels_formatted,
                  max_length=self.max_length,
                  padding="max_length", # Pad to max_length
                  truncation=True,
                  return_tensors="pt"
             )["input_ids"]

             # --- Label Masking ---
             # We only want to compute loss on the assistant's response tokens.
             # Mask tokens belonging to the system prompt, user query, and padding.
             masked_labels = labels.clone()

             for i in range(len(model_inputs["input_ids"])):
                  # Calculate the length of the input prompt (system + user + assistant prompt start)
                  # Use attention_mask sum, as input_ids might contain padding AFTER truncation
                  input_ids_len = model_inputs["attention_mask"][i].sum().item()

                  # Mask all tokens up to the end of the input prompt
                  masked_labels[i, :input_ids_len] = -100 # -100 is the standard ignore index for loss calculation

                  # Also mask padding tokens in the labels
                  # Find indices where labels are the padding token ID
                  label_pad_indices = (labels[i] == self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
                  if len(label_pad_indices) > 0:
                      # Mask all padding tokens (usually at the end)
                      # Get the index of the first padding token
                      first_pad_index = label_pad_indices[0].item()
                      masked_labels[i, first_pad_index:] = -100

             # Check if any example ended up with all labels masked (problematic)
             if torch.all(masked_labels == -100, dim=1).any():
                  logger.warning("Warning: An example has all labels masked. This might indicate issues with sequence lengths, truncation, or the collator logic.")
                  # Consider adding more detailed logging here if this happens frequently

             model_inputs["labels"] = masked_labels
             return model_inputs
        except Exception as e:
             logger.error(f"Error in Data Collator: {e}")
             logger.error(traceback.format_exc())
             # Return empty batch or raise error? Returning empty might cause issues later.
             raise # Let the error propagate

def fine_tune_model(model, tokenizer, dataset, output_dir):
    """Fine-tunes the PEFT model using the Trainer API."""
    logger.info("Starting fine-tuning process...")
    clear_gpu_memory()

    os.makedirs(output_dir, exist_ok=True)

    # Training Arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1, # Single epoch for quick demo/test
        per_device_train_batch_size=1, # Small batch size due to model size/memory
        gradient_accumulation_steps=16, # Accumulate gradients to simulate larger batch size (effective batch size = 1 * 16 = 16)
        learning_rate=2e-4, # Common learning rate for LoRA
        weight_decay=0.01, # Regularization
        warmup_ratio=0.03, # Warmup steps as a fraction of total steps
        logging_steps=5, # Log metrics every 5 steps
        save_strategy="steps", # Save checkpoints based on steps
        save_steps=max(1, MAX_TRAINING_STEPS // 2), # Save halfway through (adjust as needed)
        save_total_limit=1, # Keep only the latest checkpoint
        fp16=torch.cuda.is_available(), # Enable mixed-precision training if CUDA is available
        gradient_checkpointing=True, # Use gradient checkpointing to save memory (at cost of compute) - RELY ON THIS
        optim="paged_adamw_8bit", # Use paged AdamW optimizer for memory efficiency with QLoRA
        max_grad_norm=0.3, # Gradient clipping
        dataloader_num_workers=0, # Set to 0 or small number, can cause issues otherwise
        dataloader_pin_memory=False, # Often False is better with device_map='auto'
        max_steps=MAX_TRAINING_STEPS, # Limit total training steps
        report_to="none", # Disable external reporting (like wandb) for simplicity
        push_to_hub=False, # Don't push to Hugging Face Hub
        remove_unused_columns=True, # Let Trainer remove columns not used by the model
        # Added for potential stability
        # ddp_find_unused_parameters=False # Sometimes needed with PEFT/gradient checkpointing
    )
    logger.info(f"Using Training Arguments: {training_args}")

    # Data Collator Instance
    data_collator = MedicalChatDataCollator(tokenizer, max_length=MAX_SEQ_LENGTH_COLLATOR)

    # === Debugging Logs Before Trainer Init ===
    logger.info(f"--- Preparing to initialize Trainer ---")
    logger.info(f"Model object ID: {id(model)}")
    logger.info(f"Model type passed to fine_tune_model: {type(model)}")
    logger.info(f"Is model instance of PeftModel? {isinstance(model, PeftModel)}")
    logger.info(f"Does model have 'peft_config' attribute? {hasattr(model, 'peft_config')}")
    logger.info(f"Does model have 'is_peft_model' attribute set? {getattr(model, 'is_peft_model', 'Not Set')}")
    if hasattr(model, 'hf_device_map'):
         logger.info(f"Model device map: {model.hf_device_map}")
    else:
         logger.info("Model does not have 'hf_device_map' attribute.")
    if hasattr(model, 'is_quantized'):
         logger.info(f"Model is_quantized: {model.is_quantized}")
    else:
         logger.info("Model does not have 'is_quantized' attribute.")
    # Check base model properties if it's a PeftModel
    if isinstance(model, PeftModel) and hasattr(model, 'base_model'):
         base = model.base_model
         logger.info(f"Base model type: {type(base)}")
         logger.info(f"Base model is_quantized: {getattr(base, 'is_quantized', 'Not Set')}")
         logger.info(f"Base model has quantization_config: {hasattr(base.config, 'quantization_config')}")
         if hasattr(base.config, 'quantization_config'):
              logger.info(f"Base model quantization_config type: {type(base.config.quantization_config)}")
    logger.info(f"--- End Pre-Trainer Init Logs ---")
    # === End Debugging Logs ===

    # Trainer Initialization
    trainer = None # Initialize to None
    try:
        trainer = Trainer(
            model=model, # Should be the PeftModel instance from apply_peft_to_model
            args=training_args,
            train_dataset=dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            # callbacks=[...] # Add custom callbacks if needed
        )
        logger.info("Trainer initialized successfully.")
    except ValueError as ve:
        logger.error(f"ValueError during Trainer initialization: {ve}")
        logger.error("This likely means the Trainer still doesn't recognize the model as PEFT-compatible, possibly due to issues with quantization or PEFT setup.")
        logger.error(traceback.format_exc())
        return None, tokenizer # Cannot proceed if Trainer fails
    except Exception as e:
        logger.error(f"Unexpected error during Trainer initialization: {e}")
        logger.error(traceback.format_exc())
        return None, tokenizer

    # Training
    logger.info("Starting training...")
    trained_model = None
    try:
        # Start training
        train_result = trainer.train()
        logger.info("Training completed.")
        # Log metrics
        metrics = train_result.metrics
        logger.info(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trained_model = model # Assign model if training successful
    except Exception as e:
        logger.error(f"Error during training: {e}")
        logger.error(traceback.format_exc())
        # Attempt to save state even if training failed mid-way
        try:
            logger.warning("Attempting to save model state after training error...")
            save_path = f"{output_dir}/error_save"
            if trainer is not None:
                 trainer.save_model(output_dir=save_path)
                 tokenizer.save_pretrained(save_path)
                 logger.info(f"Model state saved to {save_path}")
            else:
                 logger.error("Trainer was not initialized, cannot save model.")
        except Exception as save_e:
            logger.error(f"Could not save model after error: {save_e}")
        # Return None for model if training failed critically
        return None, tokenizer

    # Save final model adapters (LoRA weights) and tokenizer
    if trained_model is not None:
        logger.info(f"Saving fine-tuned PEFT adapters and tokenizer to {output_dir}")
        try:
            # save_model() with PEFT model saves only the adapters by default
            trainer.save_model(output_dir)
            # Save the tokenizer configuration as well
            tokenizer.save_pretrained(output_dir)
            logger.info("Model adapters and tokenizer saved successfully.")
        except Exception as e:
            logger.error(f"Error saving final model/tokenizer: {e}")
            pass # Return the model object anyway, even if saving failed
    else:
         logger.warning("Training did not complete successfully, final model not saved via trainer.save_model.")

    # Cleanup
    del trainer
    clear_gpu_memory()

    return trained_model, tokenizer # Return the model (potentially with trained adapters)

class MedicalQuestionAnswerer:
    """Generates answers using the fine-tuned PEFT model, formatted as ideal answers."""
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        try:
             # If using device_map, the device might be complex, find a parameter's device
             self.device = next(model.parameters()).device
        except Exception:
             logger.warning("Could not automatically determine model device. Assuming CPU or first CUDA device if available.")
             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.eval() # Set the model to evaluation mode
        
        # Updated system prompt to match BioASQ ideal answer format
        self.sys_message = "You are a biomedical expert. Provide a concise, one-paragraph answer that fully addresses the medical question. Your answer should be factual, accurate, and based on current medical knowledge, written for other experts in the field."
        
        logger.info(f"Question Answerer initialized on device: {self.device}")

        # Ensure pad token is set for generation
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    @torch.inference_mode() # Disable gradient calculations for inference
    def answer_question(self, question):
        """Generates an ideal answer for a medical question."""
        if not question or not isinstance(question, str):
             logger.warning(f"Invalid question provided: '{question}'. Skipping.")
             return "Error: Invalid question provided."

        logger.debug(f"Answering question: '{question[:50]}...'")
        clear_gpu_memory() # Clear cache before generating

        try:
            # Prepare input using the chat template
            messages = [
                {"role": "system", "content": self.sys_message},
                {"role": "user", "content": question}
            ]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True, # Add the prompt for the assistant's turn
                return_tensors="pt"
            ).to(self.device) # Move inputs to the model's device

            input_length = inputs.shape[1]

            # Optional: Check if input exceeds inference length limit
            if input_length >= MAX_SEQ_LENGTH_INFERENCE:
                 logger.warning(f"Input sequence length ({input_length}) is >= MAX_SEQ_LENGTH_INFERENCE ({MAX_SEQ_LENGTH_INFERENCE}). Input might be truncated by model implicitly or cause issues.")

            # --- Generate the answer ---
            outputs = self.model.generate(
                input_ids=inputs,
                # Parameters adjusted for BioASQ ideal answer format
                max_new_tokens=450, # Ideal answers are typically one paragraph
                temperature=0.5,    # Lower temperature for more factual responses
                top_p=0.9,          # Nucleus sampling parameter
                do_sample=True,     # Enable sampling strategies
                num_beams=3,        # Use beam search for better coherence
                no_repeat_ngram_size=3, # Avoid repetition of phrases
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=True      # Speed up generation
            )

            # Decode only the newly generated tokens
            output_tokens = outputs[0, input_length:]
            answer = self.tokenizer.decode(output_tokens, skip_special_tokens=True).strip()
            
            # Format the answer as a single paragraph (BioASQ ideal answer format)
            answer = ' '.join(answer.split())

            logger.debug(f"Generated ideal answer: '{answer[:100]}...'")
            return answer

        except Exception as e:
            logger.error(f"Error generating answer for question '{question[:50]}...': {str(e)}")
            logger.debug(traceback.format_exc()) # Log traceback for debugging
            return f"Error: Processing failed during generation."
        finally:
            clear_gpu_memory() # Clear cache after generating

def process_questions_file(input_csv, output_csv, model, tokenizer):
    """Processes questions from CSV sequentially with resuming and saving."""
    logger.info(f"Starting sequential question processing from {input_csv}...")
    start_time = time.time()

    try:
        df = pd.read_csv(input_csv)
        logger.info(f"Loaded {len(df)} questions from {input_csv}.")
    except FileNotFoundError:
        logger.error(f"Input CSV not found: {input_csv}")
        return
    except Exception as e:
        logger.error(f"Error loading input CSV {input_csv}: {e}")
        return

    # Initialize the question answerer class
    answerer = MedicalQuestionAnswerer(model, tokenizer)

    # Ensure output columns exist, initialize with NA if not
    if 'original_answer' not in df.columns: df['original_answer'] = pd.NA
    # Check if 'ideal_answer' column exists, if not create it
    if 'ideal_answer' not in df.columns: df['ideal_answer'] = pd.NA
    # Check if 'faq_answer' should exist based on input columns
    if 'generated_question' in df.columns and 'faq_answer' not in df.columns:
        df['faq_answer'] = pd.NA

    # --- Resuming Logic ---
    last_processed_index = -1
    if os.path.exists(output_csv):
        logger.info(f"Output file {output_csv} found. Attempting to resume.")
        try:
            df_existing = pd.read_csv(output_csv)
            # Find the last row where 'ideal_answer' is NOT NA
            valid_indices = df_existing['ideal_answer'].dropna().index
            if not valid_indices.empty:
                 last_processed_index = valid_indices[-1]
                 logger.info(f"Resuming from index {last_processed_index + 1}.")
                 # Update the current dataframe with already processed answers from the existing file
                 # Only update up to the last processed index to avoid overwriting potentially newer data
                 df.update(df_existing.iloc[:last_processed_index+1])
            else:
                logger.info("No previously processed answers found in output file, starting fresh.")
        except pd.errors.EmptyDataError:
             logger.warning(f"Output file {output_csv} is empty. Starting fresh.")
             last_processed_index = -1
        except Exception as e:
            logger.warning(f"Could not read or parse existing output file {output_csv}: {e}. Starting fresh.")
            last_processed_index = -1

    questions_processed_since_resume = 0
    total_to_process = len(df) - (last_processed_index + 1)
    if total_to_process <= 0:
        logger.info("No new questions to process based on existing output file.")
        return # Nothing left to do

    # --- Processing Loop ---
    for idx in tqdm(range(last_processed_index + 1, len(df)), desc="Processing Questions", total=total_to_process, unit="q"):
        row = df.iloc[idx]
        row_changed = False # Flag to check if we need to save

        try:
            # Process original question if its ideal answer is missing
            if pd.isna(df.at[idx, 'ideal_answer']):
                original_question = str(row['original_question']).strip() if pd.notna(row['original_question']) else None
                if original_question:
                    # Generate ideal answer format response
                    ideal_answer = answerer.answer_question(original_question)
                    df.at[idx, 'ideal_answer'] = ideal_answer
                    row_changed = True
                elif pd.notna(row['original_question']): # Handle case where question exists but was empty string
                    df.at[idx, 'ideal_answer'] = "Error: Missing/empty original question"
                    row_changed = True
                # If original_question was NaN, leave ideal_answer as NaN

            # Process original question if its answer is missing (for backward compatibility)
            if pd.isna(df.at[idx, 'original_answer']):
                original_question = str(row['original_question']).strip() if pd.notna(row['original_question']) else None
                if original_question:
                    # If we already generated an ideal_answer, use it for original_answer too
                    if pd.notna(df.at[idx, 'ideal_answer']):
                        df.at[idx, 'original_answer'] = df.at[idx, 'ideal_answer']
                    else:
                        answer = answerer.answer_question(original_question)
                        df.at[idx, 'original_answer'] = answer
                    row_changed = True
                elif pd.notna(row['original_question']): # Handle case where question exists but was empty string
                    df.at[idx, 'original_answer'] = "Error: Missing/empty original question"
                    row_changed = True
                # If original_question was NaN, leave original_answer as NaN

            # Process generated question (if exists and its answer is missing)
            if 'generated_question' in df.columns and 'faq_answer' in df.columns and pd.isna(df.at[idx, 'faq_answer']):
                 generated_question = str(row['generated_question']).strip() if pd.notna(row['generated_question']) else None
                 if generated_question:
                      faq_answer = answerer.answer_question(generated_question)
                      df.at[idx, 'faq_answer'] = faq_answer
                      row_changed = True
                 elif pd.notna(row['generated_question']): # Handle case where generated_question exists but was empty string
                      df.at[idx, 'faq_answer'] = "Error: Missing/empty generated question"
                      row_changed = True
                 # If generated_question was NaN, leave faq_answer as NaN

            if row_changed:
                 questions_processed_since_resume += 1
                 if (questions_processed_since_resume % SAVE_INTERVAL_INFERENCE == 0) or (idx == len(df) - 1):
                     logger.info(f"\nSaving progress at index {idx}...")
                     try:
                          df.to_csv(output_csv, index=False)
                     except Exception as save_e:
                          logger.error(f"Failed to save progress to {output_csv}: {save_e}")

        except KeyboardInterrupt:
             logger.warning("\nKeyboardInterrupt detected. Saving progress and exiting.")
             try:
                 df.to_csv(output_csv, index=False)
             except Exception as save_e:
                 logger.error(f"Failed to save progress during KeyboardInterrupt exit: {save_e}")
             raise # Re-raise interrupt
        except Exception as e:
            logger.error(f"Critical error processing index {idx}: {e}. Recording error and saving progress.")
            logger.error(traceback.format_exc())
            # Record error in the specific row that failed, if possible
            if pd.isna(df.at[idx, 'ideal_answer']): df.at[idx, 'ideal_answer'] = f"Error: Processing Failed - {e}"
            if pd.isna(df.at[idx, 'original_answer']): df.at[idx, 'original_answer'] = f"Error: Processing Failed - {e}"
            if 'faq_answer' in df.columns and pd.isna(df.at[idx, 'faq_answer']): df.at[idx, 'faq_answer'] = f"Error: Processing Failed - {e}"
            # Try to save the state including the error message
            try:
                df.to_csv(output_csv, index=False)
            except Exception as save_e:
                logger.error(f"Failed to save error state to {output_csv}: {save_e}")
            continue # Attempt to continue with the next row

    # --- Final Save ---
    logger.info("Saving final results...")
    try:
        df.to_csv(output_csv, index=False)
    except Exception as save_e:
        logger.error(f"Failed to save final results to {output_csv}: {save_e}")

    total_time = time.time() - start_time
    logger.info(f"Completed processing {total_to_process} questions. Total time: {total_time / 60:.1f} minutes.")
    logger.info(f"Generated answers saved to {output_csv}")


# --- Main Execution Logic ---

def main():
    logger.info("--- Starting Medical FAQ Fine-Tuning and Processing Script ---")

    # --- 1. Load Base Model and Tokenizer ---
    model = None
    tokenizer = None
    model_for_inference = None # This will hold the model to be used for answering

    try:
        model, tokenizer = load_base_model_and_tokenizer(BASE_MODEL_NAME)
        # Initially, the model for inference is the base model
        model_for_inference = model
    except Exception as e:
        logger.critical(f"Failed to load base model '{BASE_MODEL_NAME}'. Cannot proceed with fine-tuning or inference.")
        # load_base_model_and_tokenizer already logs traceback
        return # Exit script if base model fails

    # --- 2. Prepare Dataset for Fine-Tuning ---
    dataset = process_dataset_for_fine_tuning(INPUT_CSV)

    model_for_training = None # Initialize to None
    if dataset is None:
        logger.warning("Dataset preparation failed or resulted in no data. Skipping fine-tuning.")
    else:
        logger.info("Dataset prepared successfully.")

        # --- 3. Apply PEFT Adapters (only if dataset is valid) ---
        try:
            # We apply PEFT adapters to the 'model' object loaded earlier
            peft_model = apply_peft_to_model(model, tokenizer)
            model_for_training = peft_model # This PEFT model will be trained
            # Update the inference model to use the PEFT version IF PEFT application succeeds
            model_for_inference = peft_model
            logger.info("PEFT adapters applied successfully. Will use PEFT model for training and potentially inference.")
        except Exception as e:
            logger.error(f"Failed to apply PEFT adapters: {e}. Skipping fine-tuning.")
            logger.error(traceback.format_exc())
            model_for_training = None # Ensure this is None if PEFT fails
            # If PEFT fails, model_for_inference remains the original base 'model'

    # --- 4. Fine-tune Model (only if dataset and PEFT model are ready) ---
    # Check both dataset and model_for_training validity
    if dataset is not None and model_for_training is not None:
        logger.info("Proceeding with fine-tuning...")
        trained_model, tokenizer = fine_tune_model(model_for_training, tokenizer, dataset, FINE_TUNED_MODEL_DIR)

        if trained_model is not None:
             logger.info("Fine-tuning process completed (or attempted). Using the resulting model for inference.")
             # Update the inference model to the one returned by fine_tune_model
             # This could be the model with trained adapters or the state before a training crash
             model_for_inference = trained_model
             # Optional: cleanup the reference used just for training if different
             if model_for_training is not trained_model:
                 del model_for_training
             clear_gpu_memory()
        else:
             logger.warning("Fine-tuning function returned None (likely due to critical error). Inference will use the model state from *before* the fine_tune_model call.")
             # model_for_inference is already set to peft_model (if PEFT succeeded) or base model (if PEFT failed)
    elif dataset is None:
         logger.warning("Skipping fine-tuning because dataset preparation failed.")
         # model_for_inference remains the base model 'model' or peft_model if PEFT applied but dataset failed later
    else: # model_for_training must be None because PEFT failed
         logger.warning("Skipping fine-tuning because PEFT adapter application failed.")
         # model_for_inference remains the base model 'model'


    # --- 5. Inference Phase ---
    logger.info("--- Starting Inference Phase ---")
    if model_for_inference is None:
        # This case should ideally not be reached if base model loading succeeded
        logger.error("No valid model available for inference (should have at least the base model). Exiting.")
        return

    # Ensure the final model for inference is in evaluation mode
    model_for_inference.eval()

    # Log which model configuration is being used for inference
    logger.info(f"Preparing for inference using model type: {type(model_for_inference)}")
    if isinstance(model_for_inference, PeftModel):
         logger.info("Inference will use the PEFT model (either freshly adapted or fine-tuned).")
    else:
         logger.info("Inference will use the original BASE model (fine-tuning was skipped or failed).")

    # Run the inference process
    clear_gpu_memory()
    process_questions_file(INPUT_CSV, OUTPUT_CSV, model_for_inference, tokenizer)

    logger.info("--- Script Finished ---")

if __name__ == "__main__":
    # Basic check for psutil, often needed for memory monitoring/limits
    try:
        import psutil
    except ImportError:
        print("Error: psutil library not found. Please install it: pip install psutil")
        # Optionally exit, or let the script fail later if psutil is strictly needed
        # exit()

    main()


INFO:__main__:CUDA not available. Using device: CPU
INFO:__main__:--- Starting Medical FAQ Fine-Tuning and Processing Script ---
INFO:__main__:Attempting to load base model: malhajar/meditron-7b-chat
INFO:__main__:Tokenizer loaded successfully.
INFO:__main__:Setting chat template for Meditron model.
CRITICAL:__main__:Failed to load base model 'malhajar/meditron-7b-chat'. Cannot proceed with fine-tuning or inference.


In [63]:
 pip install bitsandbytes==0.42.0 --no-cache-dir


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Defaulting to user installation because normal site-packages is not writeable
Collecting bitsandbytes==0.42.0
  Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)
Downloading bitsandbytes-0.42.0-py3-none-any.whl (105.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m58.2 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hInstalling collected packages: bitsandbytes
[31mERROR: Could not install packages due to an OSError: [Errno 28] No space left on device
[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


*Final*

In [16]:
import os
import torch
import logging
import gc
import pandas as pd
import numpy as np
from datasets import Dataset
import time
from tqdm import tqdm
from typing import Dict, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

# FORCE CPU USAGE - Set before any other imports
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["OMP_NUM_THREADS"] = "20"
os.environ["MKL_NUM_THREADS"] = "20"
os.environ["NUMEXPR_NUM_THREADS"] = "20"

# Configure PyTorch for CPU only
torch.set_num_threads(20)
torch.cuda.is_available = lambda: False  # Force PyTorch to think CUDA is not available

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Force CPU usage
device = torch.device("cpu")
logger.info(f"Using device: {device}")
logger.info(f"CPU threads configured: {torch.get_num_threads()}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")

def load_medical_model(model_name="malhajar/meditron-7b-chat"):
    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
        trust_remote_code=True,
        use_fast=True
    )
    if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
        logger.info("Setting chat template for Meditron")
        tokenizer.chat_template = """{% for message in messages %}
{% if message['role'] == 'system' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'user' %}### Instruction:
{{ message['content'] }}
{% elif message['role'] == 'assistant' %}### Response:
{{ message['content'] }}
{% endif %}
{% if loop.last and add_generation_prompt %}### Response:
{% endif %}
{% endfor %}"""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Set pad_token to eos_token")
    gc.collect()
    logger.info("Loading model for CPU execution...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    # Ensure model is on CPU
    model = model.to(device)
    logger.info("Model loaded successfully")
    return model, tokenizer

def apply_peft_to_model(model):
    logger.info("Applying PEFT to the model...")
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, 
        bias="none", 
        task_type="CAUSAL_LM" 
    )
    peft_model = get_peft_model(model, lora_config)
    # Ensure PEFT model is on CPU
    peft_model = peft_model.to(device)
    trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in peft_model.parameters())
    logger.info(f"Trainable parameters: {trainable_params}")
    logger.info(f"All parameters: {all_params}")
    logger.info(f"Trainable%: {100 * trainable_params / all_params:.4f}%")
    return peft_model

def process_dataset_for_fine_tuning(csv_file):
    logger.info(f"Processing dataset from {csv_file}")
    try:
        df = pd.read_csv(csv_file)
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        logger.info("Creating a small sample dataset for testing")
        test_data = {
            'original_question': [
                "What are the symptoms of diabetes?",
                "How is hypertension diagnosed?",
                "What are common treatments for migraine?"
            ],
            'ideal_answer': [
                "Common symptoms of diabetes include frequent urination, increased thirst, unexplained weight loss, extreme hunger, blurred vision, tingling in the extremities, and frequent infections.",
                "Hypertension is diagnosed when blood pressure readings consistently show systolic pressure above 130 mmHg or diastolic pressure above 80 mmHg. Diagnosis typically requires multiple readings over time.",
                "Common treatments for migraine include pain relievers, triptans, anti-nausea medications, preventive medications like beta blockers, and lifestyle changes such as stress management and regular sleep."
            ]
        }
        df = pd.DataFrame(test_data)
    if 'original_question' not in df.columns or 'ideal_answer' not in df.columns:
        logger.error("Dataset missing required columns (original_question and ideal_answer)")
        return None
    df = df.dropna(subset=['original_question', 'ideal_answer'])
    # Reduce dataset size more aggressively for CPU training
    if len(df) > 50:
        logger.info(f"Limiting dataset to 50 examples for CPU efficiency (from {len(df)})")
        df = df.sample(50, random_state=42)
    logger.info(f"Dataset has {len(df)} valid training examples")
    system_prompt = "You are an AI Medical Assistant. Give accurate and helpful answers to medical questions."
    train_data = []
    for _, row in df.iterrows():
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": row['original_question']},
            {"role": "assistant", "content": row['ideal_answer']}
        ]
        example = {"conversation": conversation}
        train_data.append(example)
    dataset = Dataset.from_pandas(pd.DataFrame(train_data))
    return dataset

class MedicalDataCollator:
    def __init__(self, tokenizer, max_length=128):  # Reduced max_length for CPU
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __call__(self, examples):
        conversations = [ex["conversation"] for ex in examples]
        if hasattr(self.tokenizer, "apply_chat_template"):
            input_texts = [
                self.tokenizer.apply_chat_template(
                    conv[:-1],
                    tokenize=False,
                    add_generation_prompt=True
                )
                for conv in conversations
            ]
            target_texts = [
                self.tokenizer.apply_chat_template(
                    conv,
                    tokenize=False,
                    add_generation_prompt=False
                )
                for conv in conversations
            ]
        else:
            input_texts, target_texts = [], []
            for conv in conversations:
                system = next((msg["content"] for msg in conv if msg["role"] == "system"), "")
                user = next((msg["content"] for msg in conv if msg["role"] == "user"), "")
                assistant = next((msg["content"] for msg in conv if msg["role"] == "assistant"), "")
                input_text = f"### Instruction:\n{system}\n### Instruction:\n{user}\n### Response:"
                target_text = f"### Instruction:\n{system}\n### Instruction:\n{user}\n### Response:\n{assistant}"
                input_texts.append(input_text)
                target_texts.append(target_text)
        
        model_inputs = self.tokenizer(
            input_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        labels = self.tokenizer(
            target_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )["input_ids"]
        
        labels_with_ignore_index = labels.clone()
        labels_with_ignore_index[labels == self.tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels_with_ignore_index
        
        # Ensure all tensors are on CPU
        for key in model_inputs:
            if isinstance(model_inputs[key], torch.Tensor):
                model_inputs[key] = model_inputs[key].to(device)
        
        return model_inputs

def fine_tune_model(model, tokenizer, dataset, output_dir):
    data_collator = MedicalDataCollator(tokenizer)
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=2,  # Reduced epochs for CPU
        per_device_train_batch_size=1,  # Reduced batch size
        gradient_accumulation_steps=8,  # Increased to maintain effective batch size
        logging_steps=5,
        save_steps=25,
        save_total_limit=2,
        eval_strategy="no",
        learning_rate=3e-5,  # Slightly reduced learning rate
        weight_decay=0.01,
        warmup_steps=5,
        fp16=False,  # Keep disabled for CPU
        bf16=False,  # Ensure disabled for CPU
        push_to_hub=False,
        report_to="none",
        remove_unused_columns=False,
        use_cpu=True,  # Explicitly force CPU usage
        no_cuda=True,  # Explicitly disable CUDA
        dataloader_pin_memory=False,  # Disable pin memory for CPU
        dataloader_num_workers=0,  # Single worker for CPU
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        processing_class=tokenizer
    )
    
    # Ensure model stays on CPU before training
    model = model.to(device)
    logger.info(f"Model device before training: {next(model.parameters()).device}")
    
    trainer.train()
    trainer.save_model(output_dir)
    return model, tokenizer

class MedicalQuestionAnswerer:
    def __init__(self, model, tokenizer, max_new_tokens=128):  # Reduced max tokens
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        # Ensure model is on CPU
        self.model = self.model.to(device)
    
    def answer_question(self, question, system_prompt="You are an AI Medical Assistant."):
        try:
            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question}
            ]
            if hasattr(self.tokenizer, "apply_chat_template"):
                prompt = self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
            else:
                prompt = f"### Instruction:\n{system_prompt}\n### Instruction:\n{question}\n### Response:"
            
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
            
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    max_new_tokens=self.max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    temperature=0.7,
                    top_p=0.9
                )
            
            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if prompt in response_text:
                answer = response_text.split(prompt)[-1].strip()
            else:
                answer = response_text.strip()
            return answer
        except Exception as e:
            logger.error(f"Error generating answer: {str(e)}")
            return f"Error generating answer: {str(e)}"
        finally:
            gc.collect()

def process_questions_file(input_csv, output_csv, model, tokenizer, batch_size=1):
    start_time = time.time()
    try:
        df = pd.read_csv(input_csv)
        logger.info(f"Loaded dataset with {len(df)} questions")
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return
    
    answerer = MedicalQuestionAnswerer(model, tokenizer)
    
    if 'original_answer' not in df.columns:
        df['original_answer'] = ""
    if 'faq_answer' not in df.columns:
        df['faq_answer'] = ""
    
    last_processed = 0
    for i, row in df.iterrows():
        if pd.notna(row['original_answer']) and row['original_answer'] != "":
            last_processed = i
    
    if last_processed > 0:
        logger.info(f"Resuming from question {last_processed+1}")
    
    save_interval = 5  # More frequent saves
    for i in range(last_processed, len(df), batch_size):
        batch_end = min(i + batch_size, len(df))
        batch_df = df.iloc[i:batch_end].copy()
        
        for idx, row in batch_df.iterrows():
            try:
                if not pd.notna(row['original_answer']) or row['original_answer'] == "":
                    original_question = row['original_question']
                    answer = answerer.answer_question(original_question)
                    df.at[idx, 'original_answer'] = answer
                    logger.info(f"Processed question {idx}")
                
                if 'generated_question' in row and pd.notna(row['generated_question']):
                    if not pd.notna(row['faq_answer']) or row['faq_answer'] == "":
                        generated_question = row['generated_question']
                        faq_answer = answerer.answer_question(generated_question)
                        df.at[idx, 'faq_answer'] = faq_answer
                
                if idx % save_interval == 0:
                    df.to_csv(output_csv, index=False)
                    logger.info(f"Saved progress at index {idx}")
                    gc.collect()
                    
            except Exception as e:
                logger.error(f"Error processing question {idx}: {str(e)}")
                continue
        
        df.to_csv(output_csv, index=False)
        elapsed = time.time() - start_time
        questions_processed = batch_end - last_processed
        avg_time_per_q = elapsed / max(1, questions_processed)
        remaining_qs = len(df) - batch_end
        est_time_remaining = avg_time_per_q * remaining_qs
        
        logger.info(f"Processed {batch_end}/{len(df)} questions. "
                    f"Avg: {avg_time_per_q:.2f}s per question. "
                    f"Est. remaining: {est_time_remaining/60:.1f} minutes")
        gc.collect()
    
    df.to_csv(output_csv, index=False)
    total_time = time.time() - start_time
    logger.info(f"Completed in {total_time/60:.1f} minutes. Generated answers saved to {output_csv}")

def clear_gpu_memory():
    """Clear any GPU memory that might be allocated"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

def main():
    input_file = "T5_FAQS1.csv" 
    output_file = "medical_answers_finetuned_v8.csv"
    fine_tuned_model_dir = "fine_tuned_medical_model_v6"
    
    # Clear memory at start
    clear_gpu_memory()
    
    os.makedirs("model_cache", exist_ok=True)
    
    try:
        gc.collect()
        base_model_name = "malhajar/meditron-7b-chat"
        
        try:
            model, tokenizer = load_medical_model(base_model_name)
        except Exception as e:
            logger.warning(f"Failed to load 7B model: {e}")
            logger.info("Falling back to smaller model...")
            try:
                model, tokenizer = load_medical_model("microsoft/DialoGPT-medium")
            except Exception as e2:
                logger.warning(f"Failed to load DialoGPT: {e2}")
                logger.info("Falling back to even smaller model...")
                model, tokenizer = load_medical_model("distilgpt2")
        
        dataset = process_dataset_for_fine_tuning(input_file)
        if dataset:
            model = apply_peft_to_model(model)
            model, tokenizer = fine_tune_model(model, tokenizer, dataset, fine_tuned_model_dir)
            
            # Clear memory after training
            clear_gpu_memory()
            
            process_questions_file(input_file, output_file, model, tokenizer)
        else:
            logger.error("Could not prepare dataset for fine-tuning. Check if the required columns exist.")
            
    except Exception as e:
        logger.error(f"An error occurred in the main process: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        clear_gpu_memory()

if __name__ == "__main__":
    main()

INFO:__main__:Using device: cpu
INFO:__main__:CPU threads configured: 20
INFO:__main__:CUDA available: False
INFO:__main__:Loading model: malhajar/meditron-7b-chat
INFO:__main__:Setting chat template for Meditron
INFO:__main__:Set pad_token to eos_token
INFO:__main__:Loading model for CPU execution...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

INFO:__main__:Model loaded successfully
INFO:__main__:Processing dataset from T5_FAQS1.csv
INFO:__main__:Limiting dataset to 50 examples for CPU efficiency (from 16407)
INFO:__main__:Dataset has 50 valid training examples
INFO:__main__:Applying PEFT to the model...
INFO:__main__:Trainable parameters: 8388608
INFO:__main__:All parameters: 6746943488
INFO:__main__:Trainable%: 0.1243%
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
INFO:__main__:Model device before training: cpu


Step,Training Loss


KeyboardInterrupt: 

In [64]:
pip install bitsandbytes


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Defaulting to user installation because normal site-packages is not writeable
Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting torch<3,>=2.2 (from bitsandbytes)
  Downloading torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch<3,>=2.2->bitsandbytes)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (