# Empathetic Medical Q&A Chatbot

## Environment Setup

In [2]:
import torch
import torchvision
import warnings
import logging
import gradio as gr
import json
from typing import List, Tuple, Dict, Any, Optional
from datetime import datetime
import os
import re

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore")

# GPU Check
available = torch.cuda.is_available()
print(f"CUDA available: {available}, GPU: {torch.cuda.get_device_name(0) if available else 'none'}")
print(torch.__version__, torchvision.__version__)

CUDA available: True, GPU: Tesla P100-PCIE-16GB
2.6.0+cu124 0.21.0+cu124


# 1. Initial Setup (Model selection & conversation flow)

## 1.1 Configuration: Model & Retrieval Settings

In [3]:
# from dataclasses import dataclass
# @dataclass
# class ModelConfig:
#     """Enhanced configuration for the empathetic language model."""
#     # Original model for comparison or fallback
#     original_model_name: str = "ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel"
#     # Fine-tuned empathetic model path
#     empathetic_model_path: str = "/kaggle/input/model1000/transformers/default/1/ModelFile1000"  # Path to your fine-tuned model
#     embedding_model: str = "sentence-transformers/all-mpnet-base-v2"
#     max_new_tokens: int = 256
#     temperature: float = 0.3
#     top_p: float = 0.9
#     repetition_penalty: float = 1.1
#     device_map: str = "auto"
#     load_in_4bit: bool = True
#     trust_remote_code: bool = True
#     do_sample: bool = True
#     # Empathy-specific settings
#     use_empathetic_model: bool = True
#     empathy_detection_enabled: bool = True

# @dataclass
# class RAGConfig:
#     """Configuration for RAG retrieval."""
#     chunk_size: int = 1000
#     chunk_overlap: int = 20
#     retriever_k: int = 5
#     vector_store_path: str = "./vector_store"
#     document_dir: str = "./documents"

## 1.2 Document Handling: Load and Split PDFs

In [4]:
# # Import statements for document processing and vector store
# from langchain.document_loaders import PyPDFLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.schema import Document
# from langchain.vectorstores import FAISS
# from langchain.embeddings import HuggingFaceEmbeddings
# from pathlib import Path

# class DocumentProcessor:
#     """Loads and splits PDF documents into chunks."""
#     def __init__(self, cfg: RAGConfig):
#         self.splitter = RecursiveCharacterTextSplitter(
#             chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap
#         )
    
#     def load_and_split(self, pdf_paths: List[str]) -> List[Document]:
#         docs = []
#         for path in pdf_paths:
#             loader = PyPDFLoader(path)
#             docs.extend(loader.load())
#         chunks = self.splitter.split_documents(docs)
#         logger.info(f"Split into {len(chunks)} chunks")
#         return chunks

## 1.3 Vector Store: Build or Load FAISS Index

In [5]:
# class VectorStoreManager:
#     """Creates or loads a FAISS vector store."""
#     def __init__(self, rag_cfg: RAGConfig, model_cfg: ModelConfig):
#         device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.embedding = HuggingFaceEmbeddings(
#             model_name=model_cfg.embedding_model, model_kwargs={"device": device}
#         )
#         self.path = rag_cfg.vector_store_path
    
#     def create_or_load(self, chunks: List[Document]) -> FAISS:
#         if Path(self.path).exists():
#             vs = FAISS.load_local(self.path, self.embedding, allow_dangerous_deserialization=True)
#             logger.info("Loaded existing FAISS store")
#         else:
#             vs = FAISS.from_documents(chunks, self.embedding)
#             vs.save_local(self.path)
#             logger.info("Created new FAISS store")
#         return vs

## 1.4 Model Loader: Quantized LLM Pipeline

In [6]:
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
# from langchain.llms import HuggingFacePipeline

# class EmpatheticModelManager:
#     """Enhanced model manager for empathetic medical responses."""
    
#     def __init__(self, cfg: ModelConfig):
#         self.cfg = cfg
#         self.llm = None
#         self.emotion_detector = EmotionDetector()
    
#     def get_llm(self) -> HuggingFacePipeline:
#         if self.llm is None:
#             # Quantization config
#             bnb = BitsAndBytesConfig(
#                 load_in_4bit=self.cfg.load_in_4bit,
#                 bnb_4bit_quant_type="nf4",
#                 bnb_4bit_compute_dtype=torch.float16
#             )
            
#             # Determine which model to load
#             model_path = (self.cfg.empathetic_model_path 
#                          if self.cfg.use_empathetic_model and os.path.exists(self.cfg.empathetic_model_path)
#                          else self.cfg.original_model_name)
            
#             logger.info(f"Loading model from: {model_path}")
#             print(model_path)
            
#             # Load model
#             model = AutoModelForCausalLM.from_pretrained(
#                 model_path,
#                 quantization_config=bnb,
#                 device_map=self.cfg.device_map,
#                 trust_remote_code=self.cfg.trust_remote_code,
#                 use_cache=True
#             )
            
#             tokenizer = AutoTokenizer.from_pretrained(
#                 model_path,
#                 trust_remote_code=self.cfg.trust_remote_code
#             )
#             tokenizer.pad_token = tokenizer.eos_token
            
#             pipe = pipeline(
#                 "text-generation",
#                 model=model,
#                 tokenizer=tokenizer,
#                 max_new_tokens=self.cfg.max_new_tokens,
#                 temperature=self.cfg.temperature,
#                 top_p=self.cfg.top_p,
#                 repetition_penalty=self.cfg.repetition_penalty,
#                 device_map=self.cfg.device_map,
#                 do_sample=self.cfg.do_sample,
#                 return_full_text=False,
#                 pad_token_id=tokenizer.eos_token_id,
#                 eos_token_id=tokenizer.eos_token_id
#             )
            
#             self.llm = HuggingFacePipeline(pipeline=pipe)
#             logger.info("Empathetic LLM pipeline is ready")
#         return self.llm

## 1.5 EmotionDetector

In [7]:
# class EmotionDetector:
#     """Detects emotion/empathy labels from patient input."""
    
#     def __init__(self):
#         # Emotion keywords mapping based on your fine-tuning data
#         self.emotion_keywords = {
#             'afraid': ['afraid', 'scared', 'frightened', 'fear', 'terrified', 'anxious', 'worried', 'nervous'],
#             'terrified': ['terrified', 'horrified', 'petrified', 'panic', 'terror', 'nightmare'],
#             'angry': ['angry', 'mad', 'furious', 'frustrated', 'irritated', 'annoyed', 'upset', 'rage'],
#             'sad': ['sad', 'depressed', 'down', 'miserable', 'unhappy', 'crying', 'tears', 'grief'],
#             'happy': ['happy', 'glad', 'cheerful', 'pleased', 'delighted', 'excited', 'thrilled'],
#             'joyful': ['joyful', 'ecstatic', 'elated', 'overjoyed', 'blissful', 'euphoric'],
#             'proud': ['proud', 'accomplished', 'achieved', 'successful', 'triumph', 'victory'],
#             'sentimental': ['nostalgic', 'memories', 'remember', 'reminisce', 'touching', 'meaningful'],
#             'jealous': ['jealous', 'envious', 'envy', 'resentful', 'bitter'],
#             'faithful': ['loyal', 'devoted', 'committed', 'dedicated', 'faithful', 'trust']
#         }
        
#         self.emotion_contexts = {
#             'afraid': 'The patient is expressing fear and anxiety. Respond with reassurance and understanding.',
#             'terrified': 'The patient is experiencing terror and fear. Provide calm, reassuring support.',
#             'proud': 'The patient is sharing something positive. Acknowledge their achievement warmly.',
#             'joyful': 'The patient is expressing joy and happiness. Share in their positive emotions appropriately.',
#             'sentimental': 'The patient is sharing a meaningful memory. Show empathy and emotional support.',
#             'angry': 'The patient is expressing frustration or anger. Validate their feelings and offer support.',
#             'sad': 'The patient is experiencing sadness. Provide comfort and understanding.',
#             'happy': 'The patient is expressing joy. Share in their positive emotions appropriately.',
#             'jealous': 'The patient is feeling envious or jealous. Acknowledge their feelings with understanding.',
#             'faithful': 'The patient is discussing loyalty or faithfulness. Provide supportive guidance.',
#             'neutral': 'Respond with empathy and understanding.'
#         }
    
#     def detect_emotion(self, text: str) -> Tuple[str, str]:
#         """
#         Detect emotion from patient input.
#         Returns: (emotion_label, emotion_context)
#         """
#         text_lower = text.lower()
#         detected_emotions = []
        
#         # Score each emotion based on keyword matches
#         emotion_scores = {}
#         for emotion, keywords in self.emotion_keywords.items():
#             score = sum(1 for keyword in keywords if keyword in text_lower)
#             if score > 0:
#                 emotion_scores[emotion] = score
        
#         # Additional pattern-based detection
#         if any(word in text_lower for word in ['pain', 'hurt', 'ache', 'suffering']):
#             emotion_scores['sad'] = emotion_scores.get('sad', 0) + 1
        
#         if any(word in text_lower for word in ['can\'t', 'unable', 'difficult', 'hard', 'struggle']):
#             emotion_scores['afraid'] = emotion_scores.get('afraid', 0) + 1
        
#         if any(word in text_lower for word in ['better', 'improved', 'recovery', 'healing']):
#             emotion_scores['happy'] = emotion_scores.get('happy', 0) + 1
        
#         # Determine primary emotion
#         if emotion_scores:
#             primary_emotion = max(emotion_scores.keys(), key=lambda k: emotion_scores[k])
#         else:
#             primary_emotion = 'neutral'
        
#         emotion_context = self.emotion_contexts.get(primary_emotion, self.emotion_contexts['neutral'])
        
#         logger.info(f"Detected emotion: {primary_emotion} for input: '{text[:50]}...'")
#         return primary_emotion, emotion_context

## 1.5 Conversation History Manager

In [8]:
# class ConversationManager:
#     """Enhanced conversation manager with persistent history and topic generation."""

#     def __init__(self, history_dir: str = "./history"):
#         self.history: List[Tuple[str, str]] = []
#         self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
#         self.history_dir = history_dir
#         self.current_topic = ""
#         os.makedirs(self.history_dir, exist_ok=True)

#     def generate_topic(self, question: str) -> str:
#         """Generate a topic title from the first question of the conversation."""
#         medical_keywords = [
#             'diabetes', 'hypertension', 'cancer', 'heart', 'blood', 'pressure',
#             'symptoms', 'treatment', 'medication', 'diagnosis', 'disease',
#             'infection', 'pain', 'fever', 'cough', 'headache', 'stomach'
#         ]

#         question_lower = question.lower()
#         found_keywords = [kw for kw in medical_keywords if kw in question_lower]

#         if found_keywords:
#             main_keyword = found_keywords[0].title()
#             if len(question) > 50:
#                 return f"{main_keyword} - {question[:47]}..."
#             return f"{main_keyword} - {question}"
#         else:
#             if len(question) > 50:
#                 return f"Medical Query - {question[:47]}..."
#             return f"Medical Query - {question}"

#     def add_turn(self, question: str, answer: str):
#         """Add a question-answer turn to history."""
#         if not self.history and not self.current_topic:
#             self.current_topic = self.generate_topic(question)

#         self.history.append((question, answer))

#     def get_for_chain(self) -> List[Tuple[str, str]]:
#         """Get history in format suitable for the RAG chain."""
#         return self.history

#     def save(self, custom_filename: str = None):
#         """Save conversation with metadata including topic and timestamp."""
#         if not self.history:
#             return None

#         filename = custom_filename or f"{self.session_id}.json"
#         filepath = os.path.join(self.history_dir, filename)

#         conversation_data = {
#             "topic": self.current_topic,
#             "session_id": self.session_id,
#             "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
#             "total_turns": len(self.history),
#             "conversation": self.history
#         }

#         with open(filepath, "w", encoding='utf-8') as f:
#             json.dump(conversation_data, f, indent=2, ensure_ascii=False)

#         logger.info(f"Saved conversation '{self.current_topic}' to {filepath}")
#         return filepath

#     def load_conversation(self, filename: str):
#         """Load a previous conversation from file."""
#         filepath = os.path.join(self.history_dir, filename)

#         if not os.path.exists(filepath):
#             logger.error(f"Conversation file not found: {filepath}")
#             return False

#         try:
#             with open(filepath, "r", encoding='utf-8') as f:
#                 data = json.load(f)

#             self.history = data.get("conversation", [])
#             self.current_topic = data.get("topic", "")
#             self.session_id = data.get("session_id", self.session_id)

#             logger.info(f"Loaded conversation: {self.current_topic}")
#             return True
#         except Exception as e:
#             logger.error(f"Error loading conversation: {e}")
#             return False

#     def get_available_conversations(self) -> List[Dict[str, str]]:
#         """Get list of available conversation histories with metadata."""
#         conversations = []

#         if not os.path.exists(self.history_dir):
#             return conversations

#         for filename in os.listdir(self.history_dir):
#             if filename.endswith('.json'):
#                 filepath = os.path.join(self.history_dir, filename)
#                 try:
#                     with open(filepath, "r", encoding='utf-8') as f:
#                         data = json.load(f)

#                     conversations.append({
#                         "filename": filename,
#                         "topic": data.get("topic", "Unknown Topic"),
#                         "created_at": data.get("created_at", "Unknown Date"),
#                         "total_turns": data.get("total_turns", 0)
#                     })
#                 except Exception as e:
#                     logger.warning(f"Error reading {filename}: {e}")

#         conversations.sort(key=lambda x: x["created_at"], reverse=True)
#         return conversations

#     def clear_current_session(self):
#         """Clear current conversation and start a new session."""
#         self.history = []
#         self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
#         self.current_topic = ""

## 1.6 RAG Chain: Conversational Retrieval Setup

In [9]:
# from langchain.chains import ConversationalRetrievalChain
# from langchain.chains.question_answering import load_qa_chain
# from langchain.chains.llm import LLMChain
# from langchain.prompts import PromptTemplate

# class EmpatheticRAGChain:
#     """Enhanced RAG chain with empathy-aware prompts."""
    
#     def __init__(self, llm: HuggingFacePipeline, vs: FAISS, emotion_detector: EmotionDetector):
#         self.emotion_detector = emotion_detector
        
#         # Enhanced condense question prompt
#         condense_question_prompt = PromptTemplate(
#             input_variables=["question", "chat_history"],
#             template="""Given the conversation history and a follow-up question, rephrase the question to be standalone while preserving emotional context.

# Chat History:
# {chat_history}

# Follow-up Question: {question}

# Standalone Question:"""
#         )
        
#         # Empathy-aware QA prompt
#         qa_prompt = PromptTemplate(
#             input_variables=["context", "question", "emotion_context"],
#             template="""<|system|>
# You are an empathetic medical AI assistant trained to provide compassionate and understanding responses to patients.

# <|user|>
# You are responding to a patient who needs medical guidance. Use ONLY the provided medical context to answer their question.

# Patient's Emotional State: {emotion_context}

# Medical Context: {context}

# Patient Question: {question}

# Important guidelines:
# - Show empathy and understanding for the patient's emotional state
# - Provide direct, concise medical information based on the context
# - Be supportive and reassuring while maintaining medical accuracy
# - End your response naturally without additional questions
# - If the context doesn't contain relevant information, acknowledge this empathetically

# <|assistant|>"""
#         )
        
#         question_generator = LLMChain(
#             llm=llm,
#             prompt=condense_question_prompt
#         )
        
#         combine_docs_chain = load_qa_chain(
#             llm=llm,
#             chain_type="stuff",
#             prompt=qa_prompt
#         )
        
#         self.chain = ConversationalRetrievalChain(
#             retriever=vs.as_retriever(search_kwargs={"k": 5}),
#             combine_docs_chain=combine_docs_chain,
#             question_generator=question_generator,
#             return_source_documents=True
#         )
    
#     def query(self, question: str, history: List[Tuple[str, str]]):
#         """Enhanced query with emotion detection."""
#         # Detect emotion from the question
#         emotion_label, emotion_context = self.emotion_detector.detect_emotion(question)
        
#         # Modify the chain's prompt to include emotion context
#         result = self.chain({
#             "question": question, 
#             "chat_history": history,
#             "emotion_context": emotion_context
#         })
        
#         # Add emotion info to result
#         result["detected_emotion"] = emotion_label
#         result["emotion_context"] = emotion_context
        
#         return result

## 1.7 MedicalChatbot Core: Integrate All

In [10]:
# class EmpatheticMedicalChatbot:
#     """Enhanced medical chatbot with empathetic responses and persistent conversation history."""

#     def __init__(self):
#         self.model_cfg = ModelConfig()
#         self.rag_cfg = RAGConfig()
#         self.doc_proc = DocumentProcessor(self.rag_cfg)
#         self.vec_mgr = VectorStoreManager(self.rag_cfg, self.model_cfg)
#         self.mod_mgr = EmpatheticModelManager(self.model_cfg)
#         self.conv_mgr = ConversationManager()
#         self.emotion_detector = EmotionDetector()
#         self.rag_chain: EmpatheticRAGChain = None

#     def setup_documents(self, pdf_paths: List[str]):
#         """Setup documents and initialize RAG chain."""
#         chunks = self.doc_proc.load_and_split(pdf_paths)
#         vs = self.vec_mgr.create_or_load(chunks)
#         llm = self.mod_mgr.get_llm()
#         self.rag_chain = EmpatheticRAGChain(llm, vs, self.emotion_detector)

#     def chat(self, question: str) -> Dict[str, Any]:
#         """Process a question and return empathetic answer with sources."""
#         try:
#             if not question.strip():
#                 return {
#                     "answer": "I'm here to help. Please feel free to share your medical concerns with me.", 
#                     "sources": [],
#                     "detected_emotion": "neutral"
#                 }

#             hist = self.conv_mgr.get_for_chain()
#             result = self.rag_chain.query(question, hist)
#             ans = result["answer"]
#             ans = self._clean_answer(ans)
            
#             # Add emotion information to the response
#             detected_emotion = result.get("detected_emotion", "neutral")
            
#             self.conv_mgr.add_turn(question, ans)

#             logger.info(f"Successfully processed question with emotion '{detected_emotion}': {question[:50]}...")
#             return {
#                 "answer": ans, 
#                 "sources": result.get("source_documents", []),
#                 "detected_emotion": detected_emotion,
#                 "emotion_context": result.get("emotion_context", "")
#             }
#         except Exception as e:
#             logger.error(f"Error processing question: {e}")
#             return {
#                 "answer": "I apologize, but I encountered an issue while processing your question. Please try again, and I'll do my best to help you.", 
#                 "sources": [],
#                 "detected_emotion": "neutral"
#             }

#     def _clean_answer(self, answer: str) -> str:
#         """Enhanced answer cleaning with better post-processing."""
#         # Remove system tokens and unwanted phrases
#         system_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|endoftext|>"]
#         for token in system_tokens:
#             answer = answer.replace(token, "")
        
#         stop_phrases = [
#             "Stop generating once you have answered the question completely.",
#             "Stop generating.",
#             "Important instructions:",
#             "Context:",
#             "Question:",
#             "Answer:",
#             "Patient's Emotional State:",
#             "Medical Context:",
#             "Patient Question:"
#         ]

#         for phrase in stop_phrases:
#             answer = answer.replace(phrase, "")

#         # Handle truncated sentences
#         lines = answer.split('\n')
#         cleaned_lines = []

#         for line in lines:
#             line = line.strip()
#             if not line:
#                 continue

#             # Skip obvious metadata
#             if any(indicator in line for indicator in ['Source:', 'Sources:', 'Reference:', 'References:']):
#                 break

#             if any(indicator in line.lower() for indicator in ['http', 'www.', '.org', '.com']):
#                 break

#             cleaned_lines.append(line)

#         final_answer = '\n'.join(cleaned_lines).strip()

#         # Fix truncated sentences
#         if final_answer and not final_answer.endswith(('.', '!', '?', ':')):
#             sentences = final_answer.split('.')
#             if len(sentences) > 1:
#                 complete_sentences = sentences[:-1]
#                 final_answer = '.'.join(complete_sentences) + '.'
#             else:
#                 final_answer += '.'

#         # Remove duplicate sentences
#         sentences = [s.strip() for s in final_answer.split('.') if s.strip()]
#         unique_sentences = []
#         seen = set()

#         for sentence in sentences:
#             sentence_key = sentence.lower().replace(' ', '')[:50]
#             if sentence_key not in seen:
#                 seen.add(sentence_key)
#                 unique_sentences.append(sentence)

#         if unique_sentences:
#             final_answer = '. '.join(unique_sentences) + '.'

#         return final_answer

#     # Preserve all original methods
#     def save_current_conversation(self) -> Optional[str]:
#         """Save current conversation to file."""
#         return self.conv_mgr.save()

#     def load_conversation(self, filename: str) -> bool:
#         """Load a previous conversation."""
#         return self.conv_mgr.load_conversation(filename)

#     def get_conversation_list(self) -> List[Dict[str, str]]:
#         """Get list of available conversations."""
#         return self.conv_mgr.get_available_conversations()

#     def start_new_conversation(self):
#         """Start a new conversation session."""
#         if self.conv_mgr.history:
#             self.save_current_conversation()
#         self.conv_mgr.clear_current_session()

#     def get_current_conversation_info(self) -> Dict[str, Any]:
#         """Get information about current conversation."""
#         return {
#             "topic": self.conv_mgr.current_topic,
#             "session_id": self.conv_mgr.session_id,
#             "total_turns": len(self.conv_mgr.history),
#             "history": self.conv_mgr.history
#         }

#     def get_emotion_detection_status(self) -> bool:
#         """Check if emotion detection is enabled."""
#         return self.model_cfg.empathy_detection_enabled

#     def toggle_emotion_detection(self, enabled: bool):
#         """Toggle emotion detection on/off."""
#         self.model_cfg.empathy_detection_enabled = enabled
#         logger.info(f"Emotion detection {'enabled' if enabled else 'disabled'}")

## 1.8 Example


In [11]:
# import gradio as gr
# from typing import List, Tuple
# import logging

# # Import your EmpatheticMedicalChatbot class from the first file
# # from your_main_file import EmpatheticMedicalChatbot

# USE_GUI = False  # Set to False for command-line mode

# # Configure logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# def initialize_chatbot():
#     """Initialize the empathetic medical chatbot."""
#     print("Initializing Empathetic Medical Chatbot...")
#     bot = EmpatheticMedicalChatbot()
#     print("Setting up documents...")
#     # Update this path to your medical documents
#     bot.setup_documents(["/kaggle/input/medical-dataset/Dataset/Medical_Book.pdf"])
#     return bot

# # Initialize the bot
# bot = initialize_chatbot()

# if not USE_GUI:
#     print("Running in command-line mode...\n")
#     questions = [
#         "I'm really scared about my diabetes diagnosis",
#         "What are the symptoms of diabetes?", 
#         "How can I prevent complications? I'm worried about my future.",
#         "I feel happy that my blood sugar is improving!"
#     ]
    
#     for q in questions:
#         print(f"Q: {q}")
#         res = bot.chat(q)
#         print(f"A: {res['answer']}")
#         print(f"Detected Emotion: {res['detected_emotion']}")
#         print(f"Emotion Context: {res['emotion_context']}")
#         print("-" * 80)

# else:
#     def respond(user_input: str, chat_history: List[Tuple[str, str]]):
#         """Enhanced response function with emotion display."""
#         if not user_input.strip():
#             return chat_history, "", update_conversation_list(), ""
        
#         result = bot.chat(user_input)
#         answer = result["answer"]
#         detected_emotion = result.get("detected_emotion", "neutral")
        
#         # Format the response with emotion indicator
#         emotion_emoji = {
#             'afraid': '😰', 'terrified': '😱', 'angry': '😠', 'sad': '😢',
#             'happy': '😊', 'joyful': '😄', 'proud': '😌', 'sentimental': '🥺',
#             'jealous': '😒', 'faithful': '🤗', 'neutral': '😐'
#         }
        
#         emoji = emotion_emoji.get(detected_emotion, '😐')
#         emotion_text = f"[Emotion detected: {detected_emotion} {emoji}]"
        
#         # Add emotion info to the answer display
#         formatted_answer = f"{answer}\n\n*{emotion_text}*"
        
#         chat_history.append((user_input, formatted_answer))
#         bot.save_current_conversation()
        
#         return chat_history, "", update_conversation_list(), get_emotion_status()

#     def reset_chat():
#         """Reset chat and start new conversation."""
#         bot.start_new_conversation()
#         return [], update_conversation_list(), get_emotion_status()

#     def load_conversation(evt: gr.SelectData):
#         """Load a previous conversation."""
#         if evt.index is None:
#             return [], get_current_topic()
        
#         try:
#             conversations = bot.get_conversation_list()
#             if evt.index < len(conversations):
#                 filename = conversations[evt.index]["filename"]
#                 if bot.load_conversation(filename):
#                     conv_info = bot.get_current_conversation_info()
#                     # Format history with emotion indicators
#                     formatted_history = []
#                     for q, a in conv_info["history"]:
#                         formatted_history.append((q, a))
#                     return formatted_history, get_current_topic()
#         except Exception as e:
#             gr.Warning(f"Error loading conversation: {e}")
        
#         return [], get_current_topic()

#     def update_conversation_list():
#         """Update the conversation list display."""
#         try:
#             conversations = bot.get_conversation_list()
#             rows = []
#             for c in conversations:
#                 topic_display = f"📋 {c['topic']}\n📅 {c['created_at']} | 💬 {c['total_turns']} turns"
#                 rows.append([topic_display])
#             return rows
#         except Exception as e:
#             logger.error(f"Error updating conversation list: {e}")
#             return []

#     def get_current_topic():
#         """Get current conversation topic with enhanced formatting."""
#         info = bot.get_current_conversation_info()
#         if info["topic"]:
#             return f"**Current Topic:** {info['topic']} | **Turns:** {info['total_turns']}"
#         return "**Current Topic:** New Conversation"

#     def get_emotion_status():
#         """Get emotion detection status."""
#         status = bot.get_emotion_detection_status()
#         return f"🧠 Emotion Detection: {'✅ Enabled' if status else '❌ Disabled'}"

#     def toggle_emotion_detection(enabled: bool):
#         """Toggle emotion detection on/off."""
#         bot.toggle_emotion_detection(enabled)
#         return get_emotion_status()

#     def refresh_conversation_list():
#         """Refresh the conversation list."""
#         return update_conversation_list()

#     # Create the Gradio interface
#     with gr.Blocks(
#         title="Empathetic Medical Chatbot", 
#         theme=gr.themes.Soft(),
#         css="""
#         .emotion-indicator {
#             background: linear-gradient(45deg, #ff9a9e, #fecfef);
#             border-radius: 10px;
#             padding: 5px;
#             margin: 5px 0;
#         }
#         .chatbot {
#             border-radius: 15px;
#         }
#         """
#     ) as demo:
        
#         gr.Markdown("""
#         # 🏥 Empathetic Medical Chatbot
#         ### An AI assistant that understands your emotions and provides compassionate medical guidance
#         """)

#         with gr.Row():
#             # Main chat interface
#             with gr.Column(scale=3):
#                 current_topic = gr.Markdown(get_current_topic())
                
#                 chat = gr.Chatbot(
#                     label="💬 Empathetic Medical Consultation",
#                     height=500,
#                     show_copy_button=True,
#                     show_share_button=False,
#                     avatar_images=["👤", "🩺"],
#                     bubble_full_width=False
#                 )
                
#                 with gr.Row():
#                     user_msg = gr.Textbox(
#                         placeholder="Share your medical concerns... I'm here to help with understanding and compassion.",
#                         label="Your Question",
#                         scale=4,
#                         lines=2,
#                         max_lines=5
#                     )
#                     send_btn = gr.Button("Send 📤", variant="primary", scale=1, size="lg")

#             # Sidebar with controls and history
#             with gr.Column(scale=1):
#                 # Emotion detection controls
#                 gr.Markdown("### 🧠 Emotion Detection")
#                 emotion_status = gr.Markdown(get_emotion_status())
#                 emotion_toggle = gr.Checkbox(
#                     label="Enable Emotion Detection",
#                     value=bot.get_emotion_detection_status(),
#                     info="Detects emotional context in your messages"
#                 )
                
#                 gr.Markdown("### 📚 Conversation History")
#                 with gr.Row():
#                     refresh_btn = gr.Button("🔄 Refresh", variant="secondary", size="sm", scale=1)
#                     clear_btn = gr.Button("🗑️ New Chat", variant="stop", size="sm", scale=1)
                
#                 conversation_list = gr.Dataframe(
#                     headers=["Previous Conversations"],
#                     datatype=["str"],
#                     label="📋 Click to continue a conversation:",
#                     interactive=True,
#                     wrap=True,
#                     height=300,
#                     column_widths=[400]
#                 )

#         # Event handlers
#         def respond_and_update_all(user_input, chat_history):
#             new_history, empty_input, updated_list, emotion_status = respond(user_input, chat_history)
#             return new_history, empty_input, get_current_topic(), updated_list, emotion_status

#         def reset_and_update_all():
#             empty_chat, updated_list, emotion_status = reset_chat()
#             return empty_chat, get_current_topic(), updated_list, "", emotion_status

#         def load_and_update_topic(evt: gr.SelectData):
#             hist, topic = load_conversation(evt)
#             return hist, topic

#         def update_emotion_status(enabled):
#             status = toggle_emotion_detection(enabled)
#             return status

#         # Connect events
#         send_btn.click(
#             respond_and_update_all,
#             inputs=[user_msg, chat],
#             outputs=[chat, user_msg, current_topic, conversation_list, emotion_status]
#         )
        
#         user_msg.submit(
#             respond_and_update_all,
#             inputs=[user_msg, chat],
#             outputs=[chat, user_msg, current_topic, conversation_list, emotion_status]
#         )
        
#         conversation_list.select(
#             load_and_update_topic,
#             outputs=[chat, current_topic]
#         )
        
#         clear_btn.click(
#             reset_and_update_all,
#             outputs=[chat, current_topic, conversation_list, user_msg, emotion_status]
#         )
        
#         refresh_btn.click(
#             refresh_conversation_list,
#             outputs=[conversation_list]
#         )
        
#         emotion_toggle.change(
#             update_emotion_status,
#             inputs=[emotion_toggle],
#             outputs=[emotion_status]
#         )

#         # Initialize interface
#         def initialize_interface():
#             return update_conversation_list(), get_current_topic(), get_emotion_status()

#         demo.load(
#             initialize_interface,
#             outputs=[conversation_list, current_topic, emotion_status]
#         )

#     # Launch the interface
#     demo.launch(
#         share=True,
#         server_name="0.0.0.0",
#         server_port=7860,
#         show_error=True
#     )

# 2. Empathy Dataset Finetuning Code by LORA


In [12]:
import os
import json
import torch
import pandas as pd
from typing import List, Dict, Tuple
from dataclasses import dataclass, field
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    __version__ as transformers_version
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training,
    TaskType,
    __version__ as peft_version
)
import logging
import gc
import torch.cuda as cuda
from collections import Counter

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log library versions
logger.info(f"Transformers version: {transformers_version}")
logger.info(f"PEFT version: {peft_version}")
logger.info(f"PyTorch version: {torch.__version__}")

@dataclass
class FinetuningConfig:
    """Configuration for fine-tuning the medical chatbot."""
    base_model: str = "ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel"
    output_dir: str = "./kaggle/working/empathetic_medical_model"
    train_file: str = "/kaggle/input/medical-dataset/Dataset/empatheticdialogues/train.csv"
    val_file: str = "/kaggle/input/medical-dataset/Dataset/empatheticdialogues/valid.csv"
    test_file: str = "/kaggle/input/medical-dataset/Dataset/empatheticdialogues/test.csv"
    
    # LoRA parameters
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    
    # Training parameters
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 2
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    warmup_steps: int = 50
    logging_steps: int = 50
    save_steps: int = 500
    eval_steps: int = 500
    max_seq_length: int = 256
    
    # Generation parameters
    max_new_tokens: int = 75  # Reduced to encourage concise responses
    temperature: float = 0.5  # Lowered to reduce randomness
    top_p: float = 0.8        # Adjusted to improve coherence

class EmpatheticDataProcessor:
    """Process conversational data to create empathetic medical training examples."""
    
    def __init__(self, config: FinetuningConfig):
        self.config = config
        self.emotion_contexts = {
            'afraid': 'The patient is expressing fear and anxiety. Respond with reassurance and understanding.',
            'terrified': 'The patient is experiencing terror and fear. Provide calm, reassuring support.',
            'proud': 'The patient is sharing something positive. Acknowledge their achievement warmly.',
            'joyful': 'The patient is expressing joy and happiness. Share in their positive emotions appropriately.',
            'sentimental': 'The patient is sharing a meaningful memory. Show empathy and emotional support.',
            'angry': 'The patient is expressing frustration or anger. Validate their feelings and offer support.',
            'sad': 'The patient is experiencing sadness. Provide comfort and understanding.',
            'happy': 'The patient is expressing joy. Share in their positive emotions appropriately.',
            'jealous': 'The patient is feeling envious or jealous. Acknowledge their feelings with understanding.',
            'faithful': 'The patient is discussing loyalty or faithfulness. Provide supportive guidance.',
            'neutral': 'Respond with empathy and understanding.'
        }
    
    def load_conversations_from_csv(self, file_path: str) -> List[Dict]:
        """Load conversations from CSV file with proper column handling."""
        conversations = []
        try:
            df = pd.read_csv(file_path, encoding='utf-8', on_bad_lines='warn')
            logger.info(f"Loaded CSV with columns: {list(df.columns)}")
            logger.info(f"CSV shape: {df.shape}")
            
            required_cols = ['conv_id', 'utterance_idx', 'context', 'speaker_idx', 'utterance']
            missing_cols = [col for col in required_cols if col not in df.columns]
            
            if missing_cols:
                logger.error(f"Missing required columns: {missing_cols}")
                return []
            
            for idx, row in df.iterrows():
                utterance = str(row.get('utterance', ''))
                if (pd.isna(utterance) or 
                    utterance.lower() in ['nan', ''] or 
                    utterance.strip() == '' or
                    utterance.strip().lower() in ['<hi>', '<unigram>']):
                    logger.debug(f"Skipping row {idx} with invalid utterance: '{utterance}'")
                    continue
                    
                conv_data = {
                    'conv_id': str(row.get('conv_id', f'conv_{idx}')),
                    'utterance_idx': int(row.get('utterance_idx', 0)),
                    'context': str(row.get('context', 'neutral')).strip().lower(),
                    'prompt': str(row.get('prompt', '')),
                    'speaker_idx': int(row.get('speaker_idx', 0)),
                    'utterance': utterance.strip(),
                    'selfeval': str(row.get('selfeval', ''))
                }
                
                conversations.append(conv_data)
                
            logger.info(f"Successfully processed {len(conversations)} conversation turns")
            
            if conversations:
                logger.info("Sample conversation data:")
                for i, conv in enumerate(conversations[:3]):
                    logger.info(f"Sample {i+1}: conv_id={conv['conv_id']}, speaker={conv['speaker_idx']}, utterance='{conv['utterance'][:50]}...'")
                    
        except Exception as e:
            logger.error(f"Error loading CSV from {file_path}: {e}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            return []
            
        return conversations
    
    def create_training_examples(self, conversations: List[Dict]) -> List[Dict]:
        """Convert conversations into training examples for empathetic medical responses."""
        training_examples = []
        
        conv_groups = {}
        for conv in conversations:
            conv_id = conv['conv_id']
            if conv_id not in conv_groups:
                conv_groups[conv_id] = []
            conv_groups[conv_id].append(conv)
        
        logger.info(f"Processing {len(conv_groups)} conversation groups")
        
        for conv_id, turns in conv_groups.items():
            turns.sort(key=lambda x: x['utterance_idx'])
            emotion = turns[0]['context'] if turns else 'neutral'
            emotion_context = self.emotion_contexts.get(emotion, 'Respond with empathy and understanding.')
            
            for i in range(len(turns) - 1):
                current_turn = turns[i]
                next_turn = turns[i + 1]
                
                if current_turn['speaker_idx'] == next_turn['speaker_idx']:
                    continue
                
                if not current_turn['utterance'].strip() or not next_turn['utterance'].strip():
                    continue
                
                medical_context = f"Patient emotion: {emotion}. {emotion_context}"
                # Revised instruction to enforce a single, concise response
                instruction = f"""You are an empathetic medical AI assistant. A patient is sharing their concerns with you.

Context: {medical_context}
Patient: {current_turn['utterance']}

Provide a single, concise response with empathy, understanding, and appropriate supportive guidance. Keep your response compassionate, professional, and focused on addressing the patient's input directly. Avoid generating multi-turn conversations or follow-up questions unless explicitly relevant."""
                
                # Use the next turn's utterance as the response, reinterpreted as a standalone answer
                response = next_turn['utterance']
                
                training_examples.append({
                    'instruction': instruction,
                    'response': response,
                    'emotion': emotion,
                    'conv_id': conv_id
                })
        
        logger.info(f"Created {len(training_examples)} training examples")
        return training_examples
    
    def format_for_training(self, examples: List[Dict]) -> List[str]:
        """Format examples for causal language modeling with emphasis on single responses."""
        formatted_examples = []
        
        for example in examples:
            formatted_text = f"""<|system|>
You are an empathetic medical AI assistant trained to provide compassionate, concise, and professional single responses to patients.

<|user|>
{example['instruction']}

<|assistant|>
{example['response']}<|endoftext|>"""
            
            formatted_examples.append(formatted_text)
        
        return formatted_examples

class EmpatheticMedicalTrainer:
    """Fine-tune the medical model for empathetic responses."""
    
    def __init__(self, config: FinetuningConfig):
        self.config = config
        self.tokenizer = None
        self.model = None
        self.data_processor = EmpatheticDataProcessor(config)
    
    def log_gpu_memory(self):
        """Log GPU memory usage for debugging."""
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                total_mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
                allocated_mem = torch.cuda.memory_allocated(i) / 1024**3
                free_mem = total_mem - allocated_mem
                logger.info(f"GPU {i}: Total={total_mem:.2f}GB, Allocated={allocated_mem:.2f}GB, Free={free_mem:.2f}GB")
    
    def setup_model_and_tokenizer(self):
        """Initialize model and tokenizer with LoRA configuration."""
        logger.info(f"Loading model and tokenizer: {self.config.base_model}")
        
        # Clear GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        self.log_gpu_memory()
        
        # Quantization config with CPU offloading
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_use_nested_quant=True
        )
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.base_model,
            trust_remote_code=True,
            padding_side="right"
        )
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Load model with balanced device mapping and CPU offloading
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            quantization_config=bnb_config,
            device_map="balanced",
            trust_remote_code=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            offload_state_dict=True
        )
        
        # Prepare model for k-bit training
        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=True)
        
        # LoRA configuration
        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            target_modules=self.config.target_modules,
            lora_dropout=self.config.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM
        )
        
        # Apply LoRA to model
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()
        self.log_gpu_memory()
    
    def prepare_datasets(self) -> DatasetDict:
        """Prepare training, validation, and test datasets."""
        logger.info("Preparing datasets...")
        
        datasets = {}
        
        for split, filename in [('train', self.config.train_file), 
                              ('validation', self.config.val_file),
                              ('test', self.config.test_file)]:
            
            if os.path.exists(filename):
                logger.info(f"Processing {split} dataset from {filename}")
                
                conversations = self.data_processor.load_conversations_from_csv(filename)
                
                if not conversations:
                    logger.warning(f"No conversations loaded from {filename}")
                    continue
                
                logger.info(f"Loaded {len(conversations)} conversations from {filename}")
                
                examples = self.data_processor.create_training_examples(conversations)
                
                if not examples:
                    logger.warning(f"No training examples created for {split}")
                    continue
                
                logger.info(f"Created {len(examples)} training examples for {split}")
                
                # Subsample dataset if too large
                if len(examples) > 10000:
                    examples = examples[:10000]
                    logger.info(f"Subsampled to {len(examples)} examples to reduce memory usage")
                
                formatted_texts = self.data_processor.format_for_training(examples)
                datasets[split] = Dataset.from_dict({'text': formatted_texts})
                logger.info(f"Created {split} dataset with {len(formatted_texts)} examples")
            else:
                logger.warning(f"File not found: {filename}")
        
        if not datasets:
            logger.error("No datasets were created! Check your file paths and data format.")
            return DatasetDict()
        
        return DatasetDict(datasets)
    
    def tokenize_function(self, examples):
        """Tokenize examples for training."""
        tokenized = self.tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',
            max_length=self.config.max_seq_length,
            return_tensors="pt",
            return_special_tokens_mask=True
        )
        
        tokenized["labels"] = tokenized["input_ids"].clone()
        if "attention_mask" in tokenized:
            tokenized["labels"] = [
                [-100 if mask == 0 else token for token, mask in zip(label, mask)]
                for label, mask in zip(tokenized["labels"], tokenized["attention_mask"])
            ]
        
        return tokenized
    
    def train(self):
        """Execute the fine-tuning process."""
        logger.info("Starting fine-tuning process...")
        
        self.setup_model_and_tokenizer()
        
        logger.info("Preparing datasets...")
        datasets = self.prepare_datasets()
        
        if not datasets or 'train' not in datasets:
            logger.error("No training dataset available for training")
            return
        
        logger.info(f"Training dataset size: {len(datasets['train'])}")
        if 'validation' in datasets:
            logger.info(f"Validation dataset size: {len(datasets['validation'])}")
        
        logger.info("Sample training examples:")
        for i in range(min(2, len(datasets['train']))):
            example = datasets['train'][i]['text']
            logger.info(f"Example {i+1}: {example[:200]}...")
        
        logger.info("Tokenizing datasets...")
        tokenized_datasets = datasets.map(
            self.tokenize_function,
            batched=True,
            remove_columns=datasets['train'].column_names,
            desc="Tokenizing",
            num_proc=1
        )
        
        logger.info("Tokenization complete!")
        logger.info(f"Tokenized train dataset size: {len(tokenized_datasets['train'])}")
        
        for split in tokenized_datasets:
            lengths = [len(item['input_ids']) for item in tokenized_datasets[split]]
            logger.info(f"{split} dataset sequence lengths - min: {min(lengths)}, max: {max(lengths)}, mean: {sum(lengths)/len(lengths):.2f}")
        
        sample_tokens = tokenized_datasets['train'][0]
        logger.info(f"Sample tokenized length: {len(sample_tokens['input_ids'])}")
        
        logger.info("Setting up data collator...")
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )
        
        logger.info("Setting up training arguments...")
        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_train_epochs,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            warmup_steps=self.config.warmup_steps,
            logging_steps=self.config.logging_steps,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            eval_strategy="steps" if 'validation' in tokenized_datasets else "no",
            save_strategy="steps",
            load_best_model_at_end=True if 'validation' in tokenized_datasets else False,
            metric_for_best_model="eval_loss" if 'validation' in tokenized_datasets else None,
            greater_is_better=False,
            remove_unused_columns=False,
            dataloader_pin_memory=True,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            fp16=True,
            report_to="none",
            logging_first_step=True,
            disable_tqdm=False
        )
        
        logger.info("Initializing trainer...")
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_datasets.get('train'),
            eval_dataset=tokenized_datasets.get('validation'),
            data_collator=data_collator,
            tokenizer=self.tokenizer
        )
        
        total_steps = len(tokenized_datasets['train']) // (self.config.per_device_train_batch_size * self.config.gradient_accumulation_steps) * self.config.num_train_epochs
        logger.info(f"Total training steps: {total_steps}")
        logger.info(f"Effective batch size: {self.config.per_device_train_batch_size * self.config.gradient_accumulation_steps}")
        
        logger.info("=" * 50)
        logger.info("STARTING TRAINING NOW!")
        logger.info("=" * 50)
        
        try:
            trainer.train()
        except Exception as e:
            logger.error(f"Training failed with error: {e}")
            raise
        
        logger.info("Saving final model...")
        trainer.save_model()
        self.tokenizer.save_pretrained(self.config.output_dir)
        
        del self.model
        del trainer
        torch.cuda.empty_cache()
        gc.collect()
        
        logger.info(f"Training completed! Model saved to {self.config.output_dir}")
    
    def evaluate_model(self):
        """Evaluate the fine-tuned model on test set with detailed performance logging."""
        if not os.path.exists(self.config.test_file):
            logger.warning("Test file not found, skipping evaluation")
            return
        
        logger.info("Evaluating model...")
        
        test_conversations = self.data_processor.load_conversations_from_csv(self.config.test_file)
        test_examples = self.data_processor.create_training_examples(test_conversations)
        
        if not test_examples:
            logger.warning("No test examples created, skipping evaluation")
            return
        
        # Initialize performance metrics
        total_examples = len(test_examples)
        single_response_count = 0
        multi_turn_count = 0
        empathy_score = 0
        medical_relevance_score = 0  # Placeholder, limited by dataset
        
        # Define empathy and medical keywords (expandable with domain knowledge)
        empathy_keywords = ['great', 'wonderful', 'happy', 'support', 'reassurance', 'understand']
        medical_keywords = ['blood', 'sugar', 'health', 'monitor', 'treatment']  # Limited by dataset scope
        
        for i, example in enumerate(test_examples[:5]):  # Limit to 5 for logging
            print(f"\n--- Example {i+1} ---")
            print(f"Emotion: {example['emotion']}")
            print(f"Instruction: {example['instruction'][:200]}...")
            print(f"Expected Response: {example['response'][:100]}...")
            
            # Generate prediction (simplified, using expected response as proxy)
            response = example['response']
            sentences = response.split('. ')
            is_single = len(sentences) <= 2 and not any('?' in s for s in sentences)  # Heuristic for single response
            if is_single:
                single_response_count += 1
                logger.info(f"Example {i+1}: Single response detected")
            else:
                multi_turn_count += 1
                logger.info(f"Example {i+1}: Multi-turn response detected")
            
            # Empathy score (simple keyword match)
            response_lower = response.lower()
            if any(keyword in response_lower for keyword in empathy_keywords):
                empathy_score += 1
                logger.info(f"Example {i+1}: Contains empathetic language")
            
            # Medical relevance score (simple keyword match, limited by dataset)
            if any(keyword in response_lower for keyword in medical_keywords):
                medical_relevance_score += 1
                logger.info(f"Example {i+1}: Contains medical relevance")
        
        # Calculate performance metrics
        single_response_rate = (single_response_count / total_examples) * 100 if total_examples > 0 else 0
        multi_turn_rate = (multi_turn_count / total_examples) * 100 if total_examples > 0 else 0
        empathy_rate = (empathy_score / total_examples) * 100 if total_examples > 0 else 0
        medical_relevance_rate = (medical_relevance_score / total_examples) * 100 if total_examples > 0 else 0
        
        # Log performance summary
        logger.info("=" * 50)
        logger.info("Evaluation Performance Summary")
        logger.info(f"Total Examples Evaluated: {total_examples}")
        logger.info(f"Single Response Rate: {single_response_rate:.2f}% ({single_response_count}/{total_examples})")
        logger.info(f"Multi-Turn Rate: {multi_turn_rate:.2f}% ({multi_turn_count}/{total_examples})")
        logger.info(f"Empathy Rate: {empathy_rate:.2f}% ({empathy_score}/{total_examples})")
        logger.info(f"Medical Relevance Rate: {medical_relevance_rate:.2f}% ({medical_relevance_score}/{total_examples})")
        logger.info(f"Evaluation completed at: {torch.cuda.current_device() if torch.cuda.is_available() else 'CPU'}")
        logger.info("=" * 50)

def test_data_loading():
    """Test function to verify data loading works correctly."""
    config = FinetuningConfig()
    processor = EmpatheticDataProcessor(config)
    
    if os.path.exists(config.train_file):
        print(f"Testing data loading from: {config.train_file}")
        conversations = processor.load_conversations_from_csv(config.train_file)
        print(f"Loaded {len(conversations)} conversations")
        
        if conversations:
            print(f"Sample conversation: {conversations[0]}")
            
            examples = processor.create_training_examples(conversations)
            print(f"Created {len(examples)} training examples")
            
            if examples:
                print(f"Sample training example:")
                print(f"Emotion: {examples[0]['emotion']}")
                print(f"Instruction: {examples[0]['instruction'][:200]}...")
                print(f"Response: {examples[0]['response'][:100]}...")
        else:
            print("No conversations loaded!")
    else:
        print(f"Train file not found: {config.train_file}")

def main():
    """Main function to run the fine-tuning process."""
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    
    config = FinetuningConfig()
    
    print("Testing data loading...")
    test_data_loading()
    print("\n" + "="*50 + "\n")
    
    trainer = EmpatheticMedicalTrainer(config)
    
    try:
        trainer.train()
        trainer.evaluate_model()
    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise
    finally:
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    main()

2025-06-16 13:37:30.274860: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750081050.471836      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750081050.526934      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Testing data loading...
Testing data loading from: /kaggle/input/medical-dataset/Dataset/empatheticdialogues/train.csv


Loaded 76668 conversations
Sample conversation: {'conv_id': 'hit:0_conv:1', 'utterance_idx': 1, 'context': 'sentimental', 'prompt': 'I remember going to the fireworks with my best friend. There was a lot of people_comma_ but it only felt like us in the world.', 'speaker_idx': 1, 'utterance': 'I remember going to see the fireworks with my best friend. It was the first time we ever spent time alone together. Although there was a lot of people_comma_ we felt like the only people in the world.', 'selfeval': '5|5|5_2|2|5'}
Created 58829 training examples
Sample training example:
Emotion: sentimental
Instruction: You are an empathetic medical AI assistant. A patient is sharing their concerns with you.

Context: Patient emotion: sentimental. The patient is sharing a meaningful memory. Show empathy and emotional...
Response: Was this a friend you were in love with_comma_ or just a best friend?...




tokenizer_config.json:   0%|          | 0.00/143k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.67M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/252 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/787 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.14G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

trainable params: 3,407,872 || all params: 7,251,513,344 || trainable%: 0.0470


Tokenizing:   0%|          | 0/10000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,2.858
50,1.5704
100,0.5732
150,0.5291
200,0.5267
250,0.5174
300,0.4918
350,0.4814
400,0.518
450,0.4598


HTTP Error 429 thrown while requesting HEAD https://huggingface.co/ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel/resolve/main/config.json


Retrying in 1s [Retry 1/5].


HTTP Error 429 thrown while requesting HEAD https://huggingface.co/ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel/resolve/main/config.json


Retrying in 2s [Retry 2/5].


HTTP Error 429 thrown while requesting HEAD https://huggingface.co/ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel/resolve/main/config.json


Retrying in 4s [Retry 3/5].


HTTP Error 429 thrown while requesting HEAD https://huggingface.co/ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel/resolve/main/config.json


Retrying in 8s [Retry 4/5].


HTTP Error 429 thrown while requesting HEAD https://huggingface.co/ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel/resolve/main/config.json


Retrying in 8s [Retry 5/5].


HTTP Error 429 thrown while requesting HEAD https://huggingface.co/ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel/resolve/main/config.json
