# WEEK 1: DATA PREPROCESSING

## Step 1: Install Required Packages for Week 1 Data Preprocessing

This cell installs the Python libraries needed for the Week 1 data preprocessing tasks of the STEMRESEARCH project: data manipulation (`pandas`), dataset loading from Hugging Face (`datasets`), natural language processing (`spacy`, `medspacy`, `scispacy`), machine learning (`scikit-learn`), progress tracking (`tqdm`), sentence embeddings (`sentence-transformers`), and deep learning (`torch`). Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Command
```bash
!pip install pandas spacy medspacy datasets scikit-learn tqdm sentence-transformers torch scispacy
```


In [None]:
!pip install pandas spacy medspacy datasets scikit-learn tqdm sentence-transformers torch scispacy

Collecting scispacy
  Downloading scispacy-0.5.5-py3-none-any.whl.metadata (18 kB)
Collecting conllu (from scispacy)
  Downloading conllu-6.0.0-py3-none-any.whl.metadata (21 kB)
Collecting nmslib-metabrainz==2.1.3 (from scispacy)
  Downloading nmslib_metabrainz-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (956 bytes)
Collecting pybind11>=2.2.3 (from nmslib-metabrainz==2.1.3->scispacy)
  Downloading pybind11-2.13.6-py3-none-any.whl.metadata (9.5 kB)
Downloading scispacy-0.5.5-py3-none-any.whl (46 kB)
Downloading nmslib_metabrainz-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.1 MB)
Downloading conllu-6.0.0-py3-none-any.whl (16 kB)
Downloading pybind11-2.13.6-py3-none-any.whl (243 kB)

In [None]:
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz

Collecting https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz
  Downloading https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz (119.8 MB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: en_ner_bc5cdr_md
  Building wheel for en_ner_bc5cdr_md (setup.py) ... done
  Created wheel for en_ner_bc5cdr_md: filename=en_ner_bc5cdr_md-0.5.4-py3-none-any.whl size=119787677 sha256=7232c794e89d6d47888efd6330c0d77904e46d8536ee34c5950adeaae6453e8b
  Stored in directory: /root/.cache/pip/wheels/6e/a6/d6/bd15a41e2ff02a62f0a0a48dddbc07d048307db7199a1538f7
Successfully built en_ner_bc5cdr_md
Installing collected packages: en_ner_bc5cdr_md
Successfully installed en_ner_bc5cdr_md-0.5.4


## Step 2: Enhanced Medical Data Preparation Pipeline

This cell contains the main script for the **Enhanced Medical Data Preparation Pipeline** as part of the Week 1 data preprocessing tasks for the STEMRESEARCH project. The pipeline processes two datasets (`MentalChat16K` and `MedQuAD`) with a focus on biomedical named entity recognition (NER), protected health information (PHI) de-identification, and context-aware anonymization. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.
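Whatever detects a sensitive span (SciSpaCy, medSpaCy, or a regex), the last step of `combined_ner_anonymize` splices a placeholder into the text, working from the end of the string backwards so earlier character offsets stay valid. The sketch below shows only that mechanism, using two of the script's regex categories (`EMAIL`, `IP_ADDRESS`) and the standard `re` module; `replace_spans` is a hypothetical helper, and the real script additionally merges spaCy entities, applies context checks, and uses `[ANONYMIZED]` for non-regex labels.

```python
import re

# Two of the pipeline's regex categories, copied from REGEX_PATTERNS
PATTERNS = {
    "EMAIL": r"\b[\w.-]+@[\w.-]+\.\w+\b",
    "IP_ADDRESS": r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
}

def replace_spans(text, patterns):
    """Replace regex matches with [LABEL] placeholders (hypothetical helper)."""
    spans = []
    for label, pattern in patterns.items():
        for m in re.finditer(pattern, text, re.IGNORECASE):
            spans.append((m.start(), m.end(), label))

    # Keep the earliest (and, on ties, longest) span when matches overlap
    spans.sort(key=lambda s: (s[0], -s[1]))
    kept, prev_end = [], -1
    for start, end, label in spans:
        if start >= prev_end:
            kept.append((start, end, label))
            prev_end = end

    # Replace from the end of the string so earlier offsets stay valid
    for start, end, label in sorted(kept, reverse=True):
        text = text[:start] + f"[{label}]" + text[end:]
    return text

print(replace_spans("Mail jane.doe@example.com from 192.168.1.1", PATTERNS))
# prints: Mail [EMAIL] from [IP_ADDRESS]
```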

### Execution Context

**Note**: I did not run this script on Google Colab Pro+ during the first week because I did not have access to it at the time. Instead, I ran it on my local machine, where the anonymization step took approximately **3 hours** because of hardware limitations. I am therefore not re-running it here; instead, I import the preprocessed data generated on my local machine.


### Key Features

- **Biomedical NER**: Uses `SciSpaCy` (`en_ner_bc5cdr_md`) to identify medical entities like diseases and chemicals.
- **PHI De-identification**: Uses `medSpaCy` to anonymize sensitive information (e.g., names, dates, IDs).
- **Context-Aware Anonymization**: Employs `SentenceTransformer` (`all-MiniLM-L6-v2`) to classify the text around each candidate entity as medical, personal, or neutral, and uses that label to decide whether the entity is kept or anonymized.
- **Regex Patterns**: Applies regex patterns to identify and anonymize structured identifiers such as emails, phone numbers, SSNs, IP addresses, credit card numbers, dates, and patient or medical record numbers.
- **Mental Health Focus**: Preserves relevant terms, organizations, and locations (e.g., "depression", "NIH", "Boston") while anonymizing personal data.
- **Error Handling**: Includes robust logging with suppression of specific regex errors (`"\." is not a eligible syntax`).
- **Data Splitting**: Splits the processed data into training, validation, and test sets (80/10/10 ratio), stratified by source dataset.
- **Output**: Saves processed data as JSONL for fine-tuning, CSVs for analysis, statistics in JSON, and logs.
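The medical/personal/neutral decision in `classify_context` compares one text embedding against two sets of reference embeddings and keeps the best cosine similarity from each set. The rule can be sketched in plain Python; here made-up 3-dimensional vectors stand in for the `all-MiniLM-L6-v2` embeddings, while the 0.4 threshold and the tie-breaking (ties go to "medical") mirror the script. The names `cosine` and `classify` are hypothetical.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def classify(text_vec, medical_refs, personal_refs, threshold=0.4):
    """Same decision rule as classify_context, applied to precomputed vectors."""
    max_med = max(cosine(text_vec, r) for r in medical_refs)
    max_per = max(cosine(text_vec, r) for r in personal_refs)
    if max_med > threshold and max_med >= max_per:
        return "medical"
    if max_per > threshold and max_per > max_med:
        return "personal"
    return "neutral"

# Toy reference vectors (illustrative only, not real sentence embeddings)
medical_refs = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0]]
personal_refs = [[0.0, 1.0, 0.0]]

print(classify([0.8, 0.2, 0.0], medical_refs, personal_refs))  # medical
print(classify([0.1, 0.9, 0.0], medical_refs, personal_refs))  # personal
print(classify([0.0, 0.0, 1.0], medical_refs, personal_refs))  # neutral
```

Taking the maximum over several reference sentences lets one strong match decide the label, rather than averaging it away across unrelated reference topics.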

### Prerequisites

- Ensure the required packages are installed (Step 1 cell above).
- Enable GPU in Colab Pro+: `Runtime > Change runtime type > GPU` (optional, if you decide to re-run later).
- Install the `SciSpaCy` model `en_ner_bc5cdr_md` in a separate cell if it is not already installed (this is not needed if you are only importing the preprocessed outputs):

```bash
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz
```
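Because spaCy models such as `en_ner_bc5cdr_md` are installed as ordinary Python packages, one way to decide whether the install cell needs to run is to ask `importlib` whether the package can be found. A small sketch; the helper name `is_installed` is hypothetical.

```python
import importlib.util

def is_installed(package_name):
    """Return True if a top-level package with this name can be imported."""
    return importlib.util.find_spec(package_name) is not None

# Check the SciSpaCy model before deciding to run the pip install cell
if is_installed("en_ner_bc5cdr_md"):
    print("en_ner_bc5cdr_md is already installed.")
else:
    print("en_ner_bc5cdr_md not found; run the install cell above.")
```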

## Important: Dataset Loading Issue in Colab

**Warning**: Loading the datasets from Hugging Face in Google Colab can fail with the following error, which refers to the `**` (globstar) pattern:

```
Error: Invalid pattern: '**' can only be an entire path component
```

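The `**` pattern does not appear in this script; the `datasets` library uses it internally when resolving a dataset's files. The error is commonly attributed to an incompatibility between the `datasets` and `fsspec` versions preinstalled in Colab, which would explain why the same code runs fine locally but fails in Colab.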
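Since this error is widely reported as a version mismatch between `datasets` and `fsspec` (an assumption worth confirming on your own runtime), a quick diagnostic is to print the installed versions in Colab and on a machine where loading works, using `importlib.metadata`. The helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None

# Compare these between Colab and a machine where loading works
for name in ("datasets", "fsspec", "huggingface_hub"):
    print(f"{name}: {installed_version(name)}")
```

If the versions differ, upgrading `datasets` (for example `!pip install -U datasets`) and restarting the runtime is the fix most often suggested for this error.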

### Recommendations

**Option 1: Run on Local Machine**: It's better to run this script on your local machine, as I did initially, to avoid this issue. The script executed successfully on my local machine, albeit taking 3 hours for anonymization.

**Option 2: Download Data First in Colab**: Alternatively, you can manually download the MentalChat16K and MedQuAD datasets from Hugging Face, upload them to Colab, and modify the script to load the local files instead of directly fetching from Hugging Face. For example:

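```python
# Example: load locally uploaded copies of the datasets in Colab
import pandas as pd

# Replace the placeholder paths with your uploaded files;
# pass lines=True if a file is JSON Lines (one record per line).
# preprocess_data expects MentalChat16K columns 'input'/'output'
# and MedQuAD columns 'question'/'answer'.
mentalchat_df = pd.read_json('path_to_mentalchat16k.json')
medquad_df = pd.read_json('path_to_medquad.json')
```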
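If the downloaded files are JSON Lines (one record per line, an assumption about the export format), the renaming and cleaning that `preprocess_data` performs can also be sketched with the standard library: map MentalChat16K's `input`/`output` to `question`/`answer`, tag the source, drop records with missing fields, and drop duplicate (question, answer) pairs. The helpers `load_jsonl_records` and `clean` and the demo rows are hypothetical.

```python
import json
import os
import tempfile

# Column mapping applied to MentalChat16K in preprocess_data
MENTALCHAT_RENAME = {"input": "question", "output": "answer"}

def load_jsonl_records(path, rename, source):
    """Read a JSONL file into dicts with question/answer/source keys (hypothetical helper)."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines
            raw = json.loads(line)
            rec = {rename.get(k, k): v for k, v in raw.items()}
            records.append({"question": rec.get("question"),
                            "answer": rec.get("answer"),
                            "source": source})
    return records

def clean(records):
    """Drop records with missing fields and duplicate (question, answer) pairs."""
    seen, out = set(), []
    for r in records:
        key = (r["question"], r["answer"])
        if None in key or key in seen:
            continue
        seen.add(key)
        out.append(r)
    return out

# Tiny demo file in the MentalChat16K column layout (made-up rows)
rows = [
    {"instruction": "Be supportive.", "input": "I feel anxious.", "output": "That sounds hard."},
    {"instruction": "Be supportive.", "input": "I feel anxious.", "output": "That sounds hard."},
    {"instruction": "Be supportive.", "input": "Can't sleep.", "output": None},
]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False, encoding="utf-8") as tmp:
    for r in rows:
        tmp.write(json.dumps(r) + "\n")

records = clean(load_jsonl_records(tmp.name, MENTALCHAT_RENAME, "mentalchat"))
os.remove(tmp.name)
print(records)  # one record survives: the duplicate and the missing answer are dropped
```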

## Code Overview

The script performs the following steps:

1. **Model Loading**: Loads SciSpaCy, medSpaCy, and SentenceTransformer models.
2. **Regex Validation**: Validates and tests regex patterns for identifying sensitive data.
3. **Dataset Loading**: Loads MentalChat16K and MedQuAD datasets from Hugging Face (requires modification if running in Colab due to the `**` issue).
4. **Preprocessing**:
   - Combines datasets and cleans them (removes duplicates, NaN values).
   - Applies context-aware anonymization using SciSpaCy and medSpaCy.
   - Validates question-answer pairs for quality.
5. **Statistics**: Computes and saves dataset statistics (e.g., record counts, text lengths).
6. **Data Splitting**: Splits data into training, validation, and test sets.
7. **Output Formatting**: Saves data in JSONL format for fine-tuning, CSVs for analysis, and logs processing details.
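`split_data` calls scikit-learn's `train_test_split` twice with `stratify` on the `source` column: 80% goes to training, and the remaining 20% is split in half (`val_ratio / (1 - train_ratio)` = 0.1 / 0.2 = 0.5), giving 80/10/10 with the same MentalChat/MedQuAD mix in every split. A dependency-free sketch of the same idea shuffles each source group separately with a fixed seed; `stratified_split` is hypothetical and will not reproduce scikit-learn's exact row assignment.

```python
import random
from collections import defaultdict

def stratified_split(records, key, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Split records into train/val/test, preserving the share of each key value."""
    groups = defaultdict(list)
    for r in records:
        groups[r[key]].append(r)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for items in groups.values():
        items = items[:]  # copy so the caller's data is not reordered
        rng.shuffle(items)
        n_train = round(len(items) * train_ratio)
        n_val = round(len(items) * val_ratio)
        train += items[:n_train]
        val += items[n_train:n_train + n_val]
        test += items[n_train + n_val:]
    return train, val, test

# 100 toy records: 60 "mentalchat", 40 "medquad"
data = [{"id": i, "source": "mentalchat" if i < 60 else "medquad"} for i in range(100)]
train, val, test = stratified_split(data, "source")
print(len(train), len(val), len(test))  # 80 10 10
```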

## Expected Output

The script generates the following files in the `combined_medical_data/` directory (already available from local execution):

- `train.jsonl`, `validation.jsonl`, `test.jsonl`: Formatted data for fine-tuning.
- `train.csv`, `validation.csv`, `test.csv`: Raw processed data for analysis.
- `dataset_statistics.json`: Detailed statistics of the dataset.
- `data_preparation.log`: Logs of the processing steps.
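Each line of `train.jsonl`, `validation.jsonl`, and `test.jsonl` is a JSON object written by `format_for_finetuning`, whose `text` field has the form `Question: ...\nAnswer: ...`. When importing the locally generated files, a sketch like the following can turn lines back into question/answer pairs; it assumes the question text itself never contains `\nAnswer: `, and `parse_entry`/`read_jsonl` are hypothetical helpers.

```python
import json

def parse_entry(line):
    """Turn one JSONL line from format_for_finetuning back into a dict."""
    entry = json.loads(line)
    text = entry["text"]
    if not text.startswith("Question: "):
        raise ValueError("unexpected text format")
    # Split on the first answer marker (assumes it does not occur inside the question)
    question, sep, answer = text[len("Question: "):].partition("\nAnswer: ")
    if not sep:
        raise ValueError("missing answer marker")
    return {"question": question, "answer": answer,
            "source": entry["source"], "metadata": entry.get("metadata", {})}

def read_jsonl(path):
    """Yield parsed entries from a file such as combined_medical_data/train.jsonl."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield parse_entry(line)

# Example line in the same shape the script writes (values are made up)
sample = json.dumps({
    "text": "Question: What is insomnia?\nAnswer: Difficulty falling or staying asleep.",
    "source": "medquad",
    "timestamp": "2024-01-01T00:00:00",
    "metadata": {"question_length": 17, "answer_length": 37},
})
print(parse_entry(sample))
```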

## Notes

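- **Runtime**: On my local machine, the anonymization took 3 hours. Running on Colab Pro+ with a GPU would likely reduce this time if re-run, mainly by speeding up the SentenceTransformer encoding (the script does not enable GPU for the spaCy models), but the `**` issue needs to be resolved first.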
- **Storage**: Ensure sufficient storage in Colab for importing and working with output files (use `!df -h` to check disk space).
- **Error Handling**: The script includes robust logging to `data_preparation.log`. Check this file (from local run) if issues arise.
- **Model Limitations**: SentenceTransformer required significant memory locally; the larger memory available in Colab Pro+ would likely mitigate this if re-run.
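`!df -h` reports usage for every mount; from Python, `shutil.disk_usage` returns the same total/used/free figures for the filesystem holding a given path, which is convenient before copying the output files into Colab. A minimal sketch; the helper names and the 1 GB default threshold are arbitrary examples, not measured requirements.

```python
import shutil

def free_gb(path="."):
    """Free disk space on the filesystem containing `path`, in gigabytes (10**9 bytes)."""
    return shutil.disk_usage(path).free / 1e9

def has_space(path=".", required_gb=1.0):
    """True if at least `required_gb` GB is free at `path` (threshold is an example)."""
    return free_gb(path) >= required_gb

print(f"Free space: {free_gb():.1f} GB")
print("Enough space for outputs:", has_space())
```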

In [None]:
import pandas as pd
import spacy
import medspacy
import re
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import os
import json
import logging
from datetime import datetime
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import warnings
import torch

warnings.filterwarnings("ignore")

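# Custom logging filter to suppress a specific noisy message
class SuppressDotSyntaxFilter(logging.Filter):
    """Drop log records containing the '"\\." is not a eligible syntax.' message."""
    def filter(self, record):
        # Raw string keeps the backslash literal without an invalid-escape warning
        return r'"\." is not a eligible syntax.' not in record.getMessage()

# Set up logging with reduced verbosity and custom filter.
# force=True (Python 3.8+) replaces handlers a notebook kernel may already have
# installed; otherwise basicConfig silently does nothing and no log file is written.
logging.basicConfig(
    filename='data_preparation.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filemode='w',  # Overwrite log file each run
    force=True
)
logger = logging.getLogger()
# Attach the filter to the handlers: a filter on the root logger itself does not
# see records propagated from library (child) loggers, but handler filters do.
for handler in logger.handlers:
    handler.addFilter(SuppressDotSyntaxFilter())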

# Global counter for regex errors to prevent log spam
regex_error_counts = {}
max_regex_errors_per_pattern = 5

# Load models
try:
    nlp_bio = spacy.load("en_ner_bc5cdr_md")  # SciSpaCy for biomedical NER
    nlp_phi = medspacy.load()  # medSpaCy for PHI de-identification
    context_model = SentenceTransformer('all-MiniLM-L6-v2', device='cuda' if torch.cuda.is_available() else 'cpu')
    print("Successfully loaded SciSpaCy, medSpaCy, and SentenceTransformer models.")
except Exception as e:
    logging.error(f"Failed to load models: {e}")
    print(f"Error: Failed to load models - {e}")
    raise

# Medical terms that should NEVER be anonymized
MEDICAL_TERMS = {
    'melatonin', 'ambien', 'zolpidem', 'prozac', 'lexapro', 'xanax', 'klonopin',
    'sertraline', 'escitalopram', 'fluoxetine', 'paroxetine', 'venlafaxine',
    'duloxetine', 'bupropion', 'mirtazapine', 'trazodone', 'buspirone',
    'lorazepam', 'clonazepam', 'diazepam', 'alprazolam', 'temazepam',
    'depression', 'anxiety', 'ptsd', 'ocd', 'schizophrenia', 'adhd', 'autism',
    'bipolar', 'mania', 'hypomania', 'panic', 'phobia', 'agoraphobia',
    'insomnia', 'narcolepsy', 'diabetes', 'hypothyroidism', 'fibromyalgia',
    'arthritis', 'hypertension', 'hypotension', 'migraine', 'epilepsy',
    'therapy', 'cbt', 'dbt', 'emdr', 'psychotherapy', 'counseling',
    'medication', 'antidepressant', 'antipsychotic', 'anxiolytic', 'sedative',
    'stimulant', 'mood stabilizer', 'anticonvulsant', 'beta blocker',
    'ssri', 'snri', 'tricyclic', 'maoi', 'benzodiazepine', 'barbiturate',
    # Mental health specific terms
    'reddit', 'subreddit', 'throwaway', 'anonymous', 'support group',
    'online community', 'forum', 'chat room', 'helpline', 'crisis line',
    'peer support', 'self help', 'coping mechanism', 'trigger warning',
    'mental health professional', 'licensed therapist', 'counselor'
}

# Medical organizations that should NEVER be anonymized
MEDICAL_ORGANIZATIONS = {
    'nih', 'national institutes of health', 'nimh', 'cdc', 'fda', 'who',
    'mayo clinic', 'cleveland clinic', 'johns hopkins', 'ama', 'apa', 'nami',
    'samhsa', 'american psychiatric association', 'american medical association',
    'world health organization', 'centers for disease control', 'nhs',
    'kaiser permanente', 'veterans affairs', 'va hospital', 'psychiatry today',
    'psychology today'
}

# Medical locations/institutions that should be preserved
MEDICAL_LOCATIONS = {
    'boston', 'baltimore', 'houston', 'atlanta', 'new york', 'chicago',
    'massachusetts', 'california', 'texas', 'pennsylvania', 'maryland',
    'florida', 'washington', 'oregon', 'michigan', 'minnesota'
}

# Context embeddings for classification
MEDICAL_CONTEXTS = [
    "medical research clinical trial study disease treatment diagnosis symptoms",
    "hospital clinic doctor physician medical institution healthcare",
    "medication drug prescription therapy treatment pharmacology",
    "neurological disorder brain nervous system condition syndrome",
    "genetic hereditary syndrome disease mutation biomarkers",
    "gastrointestinal digestive tract inflammation chronic disease",
    "mental health psychiatric psychological therapy disorder"
]

PERSONAL_CONTEXTS = [
    "my friend family member personal relationship spouse partner",
    "someone I know personal story experience individual",
    "private individual person name identity patient"
]

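# Regex patterns for structured identifiers; each is compiled and checked by validate_regex_patterns()
REGEX_PATTERNS = {
    'EMAIL': r'\b[\w.-]+@[\w.-]+\.\w+\b',
    # No \b before "(": a word boundary cannot occur between a space and "(",
    # so "(555) 123-4567" would otherwise never match
    'PHONE': r'(?:\(\d{3}\)\s*|\b\d{3}[.-]?)\d{3}[.-]?\d{4}\b',
    'SSN': r'\b\d{3}[.-]?\d{2}[.-]?\d{4}\b',
    'PATIENT_ID': r'\b(?:patient|case|record|id)\s*(?:id|#|num|number)?\s*:?\s*\d+\b',
    # (?<!\w) instead of \b: matches "@user" after whitespace and skips the "@domain" part of emails
    'USERNAME': r'(?<!\w)@[A-Za-z0-9_]{3,20}\b',
    'IP_ADDRESS': r'\b(?:\d{1,3}\.){3}\d{1,3}\b',
    'CREDIT_CARD': r'\b(?:\d{4}[.\s-]?){3}\d{4}\b',
    'DATE': r'\b\d{1,2}[/-]\d{1,2}[/-]\d{4}\b',
    'MEDICAL_RECORD': r'\b(?:mrn|medical record|chart)\s*#?\s*:?\s*\d+\b'
}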

def validate_regex_patterns():
    """Validate all regex patterns before use and return only valid ones."""
    valid_patterns = {}
    invalid_count = 0

    print("Validating regex patterns...")
    for label, pattern in REGEX_PATTERNS.items():
        try:
            re.compile(pattern)
            valid_patterns[label] = pattern
            print(f"✓ {label}: Pattern valid")
        except re.error as e:
            print(f"✗ {label}: Pattern invalid - {e}")
            logging.error(f"Invalid regex pattern for {label}: {pattern} - {e}")
            invalid_count += 1
        except Exception as e:
            print(f"✗ {label}: Unexpected error - {e}")
            logging.error(f"Unexpected error validating pattern for {label}: {pattern} - {e}")
            invalid_count += 1

    print(f"Pattern validation complete: {len(valid_patterns)} valid, {invalid_count} invalid")
    return valid_patterns

def test_regex_patterns(patterns):
    """Test regex patterns with sample text."""
    test_text = """
    Test email@example.com phone (555) 123-4567 SSN 123-45-6789
    Patient ID: 12345 @username
    IP: 192.168.1.1
    Credit card: 1234-5678-9012-3456 DOB: 12/25/1980
    MRN: 987654
    Phone: 555-123-4567 555.123.4567
    """

    print("\nTesting regex patterns with sample text...")
    for label, pattern in patterns.items():
        try:
            matches = re.findall(pattern, test_text, re.IGNORECASE)
            print(f"{label}: {len(matches)} matches found - {matches}")
        except Exception as e:
            print(f"{label}: ERROR during testing - {e}")

def get_context_embeddings():
    """Pre-compute embeddings for context classification."""
    try:
        medical_embeddings = context_model.encode(MEDICAL_CONTEXTS, show_progress_bar=False)
        personal_embeddings = context_model.encode(PERSONAL_CONTEXTS, show_progress_bar=False)
        return medical_embeddings, personal_embeddings
    except Exception as e:
        logging.error(f"Error computing context embeddings: {e}")
        return None, None

def classify_context(text_window, medical_embeddings, personal_embeddings, threshold=0.4):
    """Classify context as medical, personal, or neutral."""
    try:
        if medical_embeddings is None or personal_embeddings is None:
            return 'neutral'

        text_embedding = context_model.encode([text_window], show_progress_bar=False)
        medical_similarities = cosine_similarity(text_embedding, medical_embeddings)[0]
        personal_similarities = cosine_similarity(text_embedding, personal_embeddings)[0]

        max_medical_sim = np.max(medical_similarities)
        max_personal_sim = np.max(personal_similarities)

        if max_medical_sim > threshold and max_medical_sim >= max_personal_sim:
            return 'medical'
        elif max_personal_sim > threshold and max_personal_sim > max_medical_sim:
            return 'personal'
        else:
            return 'neutral'
    except Exception as e:
        logging.warning(f"Context classification error: {e}")
        return 'neutral'

def is_medical_term(entity_text):
    """Check if the entity is a medical term or condition."""
    entity_lower = entity_text.lower().strip()
    if entity_lower in MEDICAL_TERMS:
        return True

    medical_patterns = [
        r'\b\w+\s*(?:disease|disorder|syndrome|condition)\b',
        r'\b\w+\s*(?:colitis|hepatitis|itis)\b',
        r'\b(?:mg|ml|mcg|units?|dosage|tablet|capsule|injection)\b'
    ]
    return any(re.search(pattern, entity_lower, re.IGNORECASE) for pattern in medical_patterns)

def is_medical_organization(entity_text, context_window, context_type):
    """Check for medical organizations using context."""
    entity_lower = entity_text.lower().strip()
    if entity_lower in MEDICAL_ORGANIZATIONS:
        return True

    if context_type == 'medical':
        medical_keywords = [
            'institute', 'hospital', 'clinic', 'medical', 'health', 'research',
            'association', 'foundation', 'center', 'university', 'college',
            'department', 'division', 'school of medicine'
        ]
        if any(keyword in entity_lower for keyword in medical_keywords):
            return True
    return False

def is_medical_location(entity_text, context_window, context_type):
    """Check if a location should be preserved in medical context."""
    entity_lower = entity_text.lower().strip()
    if entity_lower in MEDICAL_LOCATIONS and context_type == 'medical':
        return True

    context_lower = context_window.lower()
    medical_location_indicators = [
        'medical center', 'hospital', 'clinic', 'university', 'institute',
        'research facility', 'healthcare system', 'medical school'
    ]
    if any(indicator in context_lower for indicator in medical_location_indicators):
        return True
    return False

def is_real_person_name(entity_text, context_window, context_type):
    """Enhanced person name detection using context."""
    entity_lower = entity_text.lower().strip()

    if is_medical_term(entity_lower):
        return False

    common_non_names = {
        'i', 'me', 'my', 'dr', 'mr', 'ms', 'mrs', 'prof', 'professor',
        'patient', 'doctor', 'nurse', 'therapist', 'psychiatrist', 'user',
        'reddit', 'anonymous', 'throwaway'
    }
    if entity_lower in common_non_names:
        return False

    if any(pattern in entity_lower for pattern in ['user', 'anon', 'throwaway', '_', '123', '456', '789']):
        return False

    if context_type == 'medical':
        personal_indicators = [
            'my friend', 'my family', 'my spouse', 'my partner', 'my mother',
            'my father', 'my sister', 'my brother', 'someone i know',
            'a person', 'individual named', 'patient named'
        ]
        context_lower = context_window.lower()
        if any(indicator in context_lower for indicator in personal_indicators):
            return True
        return False
    elif context_type == 'personal':
        return True

    if len(entity_text) <= 2:
        return False

    return True

def safe_regex_search(pattern, text, flags=re.IGNORECASE):
    """Safely apply regex patterns with error handling and limited logging."""
    global regex_error_counts, max_regex_errors_per_pattern

    try:
        return list(re.finditer(pattern, text, flags))
    except re.error as e:
        error_msg = str(e)
        if r'"\." is not a eligible syntax.' in error_msg:
            return []  # Silently skip this specific error
        if pattern not in regex_error_counts:
            regex_error_counts[pattern] = 0

        regex_error_counts[pattern] += 1
        if regex_error_counts[pattern] <= max_regex_errors_per_pattern:
            logging.warning(f"Regex error for pattern '{pattern}': {e}")
        elif regex_error_counts[pattern] == max_regex_errors_per_pattern + 1:
            logging.warning(f"Regex error limit reached for pattern '{pattern}', suppressing further errors")
        return []
    except Exception as e:
        error_key = f"unexpected_{pattern}"
        if error_key not in regex_error_counts:
            regex_error_counts[error_key] = 0

        regex_error_counts[error_key] += 1
        if regex_error_counts[error_key] <= max_regex_errors_per_pattern:
            logging.warning(f"Unexpected error for pattern '{pattern}': {e}")
        return []

def combined_ner_anonymize(text, nlp_bio, nlp_phi, medical_embeddings, personal_embeddings, valid_patterns):
    """Combined SciSpaCy and medSpaCy anonymization with context awareness."""
    if not isinstance(text, str):
        return text

    try:
        entities = []

        try:
            doc_bio = nlp_bio(text)
            for ent in doc_bio.ents:
                entities.append((ent.start_char, ent.end_char, ent.label_, ent.text))
        except Exception as e:
            logging.warning(f"SciSpaCy processing error: {e}")

        try:
            doc_phi = nlp_phi(text)
            for ent in doc_phi.ents:
                entities.append((ent.start_char, ent.end_char, ent.label_, ent.text))
        except Exception as e:
            logging.warning(f"medSpaCy processing error: {e}")

        for label, pattern in valid_patterns.items():
            matches = safe_regex_search(pattern, text)
            for match in matches:
                entities.append((match.start(), match.end(), label, match.group()))

        entities.sort(key=lambda x: (x[0], -x[1]))
        filtered_entities = []
        prev_end = -1

        for start, end, label, entity_text in entities:
            if start >= prev_end:
                filtered_entities.append((start, end, label, entity_text))
                prev_end = end
            elif label in ['DISEASE', 'CHEMICAL']:
                if filtered_entities:
                    filtered_entities[-1] = (start, end, label, entity_text)
                else:
                    filtered_entities.append((start, end, label, entity_text))
                prev_end = end

        entities_to_replace = []
        for start, end, label, entity_text in filtered_entities:
            if is_medical_term(entity_text) or label in ['DISEASE', 'CHEMICAL']:
                continue

            context_start = max(0, start - 200)
            context_end = min(len(text), end + 200)
            context_window = text[context_start:context_end]
            entity_context_type = classify_context(context_window, medical_embeddings, personal_embeddings)

            should_anonymize = False

            if label == 'PERSON' or label in ['PATIENT', 'DOCTOR', 'NAME']:
                should_anonymize = is_real_person_name(entity_text, context_window, entity_context_type)
            elif label == 'ORG' or label == 'ORGANIZATION':
                should_anonymize = not is_medical_organization(entity_text, context_window, entity_context_type)
            elif label == 'GPE' or label in ['LOCATION', 'CITY', 'STATE', 'COUNTRY']:
                should_anonymize = not is_medical_location(entity_text, context_window, entity_context_type)
            elif label in valid_patterns:
                should_anonymize = True
            elif label in ['WORK_OF_ART', 'PRODUCT', 'EVENT']:
                should_anonymize = not (entity_context_type == 'medical' or is_medical_term(entity_text))

            if should_anonymize:
                entities_to_replace.append((start, end, label, entity_text))

        entities_to_replace.sort(key=lambda x: x[0], reverse=True)

        anonymized_text = text
        for start, end, label, orig_text in entities_to_replace:
            placeholder = f"[{label}]" if label in valid_patterns else "[ANONYMIZED]"
            logging.info(f"Anonymizing: {orig_text} (Label: {label}) -> {placeholder}")
            anonymized_text = anonymized_text[:start] + placeholder + anonymized_text[end:]

        return anonymized_text

    except Exception as e:
        logging.error(f"Error in combined anonymization: {e}")
        return text

def anonymize_row(row, nlp_bio, nlp_phi, medical_embeddings, personal_embeddings, valid_patterns):
    """Apply combined anonymization to a DataFrame row."""
    try:
        row['question'] = combined_ner_anonymize(row['question'], nlp_bio, nlp_phi, medical_embeddings, personal_embeddings, valid_patterns)
        row['answer'] = combined_ner_anonymize(row['answer'], nlp_bio, nlp_phi, medical_embeddings, personal_embeddings, valid_patterns)
        return row
    except Exception as e:
        logging.error(f"Error anonymizing row: {e}")
        return row

def validate_pair(row, min_question_len=10, max_question_len=12000, min_answer_len=20, max_answer_len=30000):
    """Validate question-answer pair for quality."""
    try:
        question = row['question']
        answer = row['answer']

        # Type/missing check first: NaN is a float, so it is caught here
        # instead of being measured as the 3-character string "nan"
        if not isinstance(question, str) or not isinstance(answer, str):
            return False, "Missing or non-string question or answer"

        q_len = len(question)
        a_len = len(answer)

        if not (min_question_len <= q_len <= max_question_len):
            return False, f"Question length out of range ({q_len})"
        if not (min_answer_len <= a_len <= max_answer_len):
            return False, f"Answer length out of range ({a_len})"
        return True, ""
    except Exception as e:
        logging.error(f"Error validating pair: {e}")
        return False, str(e)

def load_datasets():
    """Load datasets from Hugging Face."""
    try:
        print("Loading MentalChat16K dataset...")
        mentalchat_ds = load_dataset("ShenLab/MentalChat16K", split="train")
        print("Loading MedQuAD dataset...")
        medquad_ds = load_dataset("lavita/MedQuAD", split="train")

        print(f"Successfully loaded MentalChat16K: {len(mentalchat_ds)} records.")
        print(f"Successfully loaded MedQuAD: {len(medquad_ds)} records.")
        return pd.DataFrame(mentalchat_ds), pd.DataFrame(medquad_ds)
    except Exception as e:
        logging.error(f"Error loading datasets: {e}")
        print(f"Error: Failed to load datasets - {e}")
        return None, None

def preprocess_data(mentalchat_df, medquad_df, nlp_bio, nlp_phi, valid_patterns):
    """Preprocess data with combined anonymization."""
    try:
        mentalchat_df = mentalchat_df.rename(columns={'input': 'question', 'output': 'answer'})
        medquad_df = medquad_df.rename(columns=lambda x: x.lower().strip())

        mentalchat_df['source'] = 'mentalchat'
        medquad_df['source'] = 'medquad'

        combined_df = pd.concat([mentalchat_df[['question', 'answer', 'source']],
                                medquad_df[['question', 'answer', 'source']]],
                                ignore_index=True)

        initial_count = len(combined_df)
        combined_df = combined_df.dropna()
        combined_df = combined_df.drop_duplicates(subset=['question', 'answer'], keep='first')
        print(f"Dataset cleaned: {len(combined_df)} records (removed {initial_count - len(combined_df)} invalid/duplicate entries)")

        print("Preparing context classification models...")
        medical_embeddings, personal_embeddings = get_context_embeddings()

        if medical_embeddings is None or personal_embeddings is None:
            print("Warning: Context classification unavailable, proceeding with basic anonymization")

        print("Starting combined SciSpaCy and medSpaCy anonymization...")
        tqdm.pandas(desc="Anonymizing with context awareness")
        combined_df = combined_df.progress_apply(
            lambda row: anonymize_row(row, nlp_bio, nlp_phi, medical_embeddings, personal_embeddings, valid_patterns),
            axis=1
        )
        print("Successfully completed combined anonymization.")

        print("Validating processed data...")
        valid_rows = []
        filtered_count = 0

        for _, row in tqdm(combined_df.iterrows(), total=len(combined_df), desc="Validating"):
            is_valid, reason = validate_pair(row)
            if is_valid:
                valid_rows.append(row)
            else:
                filtered_count += 1
                if filtered_count <= 10:
                    logging.info(f"Filtered record: {reason}")

        combined_df = pd.DataFrame(valid_rows)
        print(f"Final dataset: {len(combined_df)} records (filtered {filtered_count} invalid entries)")

        return combined_df

    except Exception as e:
        logging.error(f"Error in preprocessing: {e}")
        print(f"Error: Failed to preprocess data - {e}")
        return None

def compute_statistics(df, output_dir):
    """Compute and save dataset statistics."""
    try:
        os.makedirs(output_dir, exist_ok=True)

        stats = {
            "total_records": len(df),
            "mentalchat_records": len(df[df['source'] == 'mentalchat']),
            "medquad_records": len(df[df['source'] == 'medquad']),
            "question_length_mean": float(df['question'].str.len().mean()),
            "question_length_std": float(df['question'].str.len().std()),
            "question_length_min": int(df['question'].str.len().min()),
            "question_length_max": int(df['question'].str.len().max()),
            "answer_length_mean": float(df['answer'].str.len().mean()),
            "answer_length_std": float(df['answer'].str.len().std()),
            "answer_length_min": int(df['answer'].str.len().min()),
            "answer_length_max": int(df['answer'].str.len().max()),
            "unique_questions": df['question'].nunique(),
            "unique_answers": df['answer'].nunique(),
            "processing_timestamp": datetime.now().isoformat(),
            "regex_error_summary": dict(regex_error_counts)
        }

        stats_path = os.path.join(output_dir, 'dataset_statistics.json')
        with open(stats_path, 'w') as f:
            json.dump(stats, f, indent=4)

        print(f"Dataset statistics saved to {stats_path}")
        print("Dataset Statistics:")
        for key, value in stats.items():
            if key == "regex_error_summary":
                continue
            if isinstance(value, float):
                print(f"  {key}: {value:.2f}")
            else:
                print(f"  {key}: {value}")

        if regex_error_counts:
            print("\nRegex Error Summary:")
            for pattern, count in regex_error_counts.items():
                print(f"  {pattern}: {count} errors")

        return stats
    except Exception as e:
        logging.error(f"Error computing statistics: {e}")
        return None

def split_data(combined_df, output_dir, train_ratio=0.8, val_ratio=0.1):
    """Split data into training, validation, and test sets."""
    try:
        os.makedirs(output_dir, exist_ok=True)

        train_df, temp_df = train_test_split(
            combined_df,
            train_size=train_ratio,
            random_state=42,
            stratify=combined_df['source']
        )

        val_size = val_ratio / (1 - train_ratio)
        val_df, test_df = train_test_split(
            temp_df,
            train_size=val_size,
            random_state=42,
            stratify=temp_df['source']
        )

        print(f"Data split - Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")
        print(f"Train split by source: MentalChat={len(train_df[train_df['source']=='mentalchat'])}, MedQuAD={len(train_df[train_df['source']=='medquad'])}")
        print(f"Validation split by source: MentalChat={len(val_df[val_df['source']=='mentalchat'])}, MedQuAD={len(val_df[val_df['source']=='medquad'])}")
        print(f"Test split by source: MentalChat={len(test_df[test_df['source']=='mentalchat'])}, MedQuAD={len(test_df[test_df['source']=='medquad'])}")

        return train_df, val_df, test_df
    except Exception as e:
        logging.error(f"Error splitting data: {e}")
        return None, None, None

def format_for_finetuning(df, output_path, source_label):
    """Format data as JSONL for fine-tuning."""
    try:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            for _, row in df.iterrows():
                entry = {
                    "text": f"Question: {row['question']}\nAnswer: {row['answer']}",
                    "source": row['source'],
                    "timestamp": datetime.now().isoformat(),
                    "metadata": {
                        "question_length": len(str(row['question'])),
                        "answer_length": len(str(row['answer']))
                    }
                }
                json.dump(entry, f, ensure_ascii=False)
                f.write('\n')

        print(f"Successfully saved {source_label} JSONL to {output_path} ({len(df)} records)")
    except Exception as e:
        logging.error(f"Error formatting JSONL for {source_label}: {e}")

def main(output_dir="combined_medical_data"):
    """Main function with combined SciSpaCy and medSpaCy anonymization."""
    print("Starting ENHANCED Medical Data Preparation Pipeline...")
    print("Key features:")
    print("- SciSpaCy (en_ner_bc5cdr_md) for biomedical NER")
    print("- medSpaCy for clinical PHI de-identification")
    print("- Context-aware anonymization with SentenceTransformers")
    print("- Reduced regex patterns for stricter anonymization")
    print("- Mental health community specific patterns")
    print("- FIXED regex patterns with validation")
    print("- Limited error logging to prevent log spam")
    print("- Preserves medical terms, organizations, and locations")
    print(r"- Suppresses invalid escape sequence '\.' syntax warnings")
    print("- Removed CHAT_HANDLE, SOCIAL_HANDLE, THROWAWAY_ACCOUNT, URL, INSURANCE_ID, REDDIT_USER patterns")
    print("-" * 60)

    # Validate and test regex patterns
    valid_patterns = validate_regex_patterns()
    if not valid_patterns:
        print("ERROR: No valid regex patterns. Exiting.")
        return

    test_regex_patterns(valid_patterns)

    # Load datasets
    mentalchat_df, medquad_df = load_datasets()
    if mentalchat_df is None or medquad_df is None:
        print("ERROR: Failed to load datasets. Exiting.")
        return

    # Preprocess data
    combined_df = preprocess_data(mentalchat_df, medquad_df, nlp_bio, nlp_phi, valid_patterns)
    if combined_df is None:
        print("ERROR: Preprocessing failed. Exiting.")
        return

    # Compute statistics
    stats = compute_statistics(combined_df, output_dir)
    if stats is None:
        print("ERROR: Failed to compute statistics. Exiting.")
        return

    # Split data
    train_df, val_df, test_df = split_data(combined_df, output_dir)
    if train_df is None:
        print("ERROR: Data splitting failed. Exiting.")
        return

    # Save formatted data
    format_for_finetuning(train_df, os.path.join(output_dir, 'train.jsonl'), 'train')
    format_for_finetuning(val_df, os.path.join(output_dir, 'validation.jsonl'), 'validation')
    format_for_finetuning(test_df, os.path.join(output_dir, 'test.jsonl'), 'test')

    # Save raw dataframes as well
    train_df.to_csv(os.path.join(output_dir, 'train.csv'), index=False)
    val_df.to_csv(os.path.join(output_dir, 'validation.csv'), index=False)
    test_df.to_csv(os.path.join(output_dir, 'test.csv'), index=False)

    print("\n" + "="*60)
    print("ENHANCED PIPELINE COMPLETED SUCCESSFULLY!")
    print("="*60)
    print("Benefits of this approach:")
    print("✓ Accurate biomedical NER with SciSpaCy")
    print("✓ Robust PHI de-identification with medSpaCy")
    print("✓ Context-aware decisions using SentenceTransformers")
    print("✓ Reduced regex patterns for stricter anonymization")
    print("✓ Mental health community specific anonymization")
    print("✓ Fixed regex patterns prevent syntax errors")
    print("✓ Preserves medical research value while protecting privacy")
    print("✓ Improved error handling and logging")
    print(r"✓ Suppressed invalid escape sequence '\.' syntax warnings")
    print("✓ Removed CHAT_HANDLE, SOCIAL_HANDLE, THROWAWAY_ACCOUNT, URL, INSURANCE_ID, REDDIT_USER patterns")

    print(f"\nOutput files saved to {output_dir}/:")
    print("  - train.jsonl, validation.jsonl, test.jsonl (for fine-tuning)")
    print("  - train.csv, validation.csv, test.csv (for analysis)")
    print("  - dataset_statistics.json (detailed statistics)")
    print("  - data_preparation.log (processing logs)")

if __name__ == "__main__":
    OUTPUT_DIR = "combined_medical_data"
    main(OUTPUT_DIR)
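`split_data` produces the 80/10/10 split in two stages. It first holds out `1 - train_ratio` of the rows, then splits that remainder again. For that reason the second call needs `val_ratio / (1 - train_ratio)` rather than `val_ratio` itself. The stdlib-only sketch below works through that arithmetic. The helper names are hypothetical, and the row counts are approximate because `train_test_split` applies its own rounding and stratification.

```python
def second_stage_fraction(train_ratio: float, val_ratio: float) -> float:
    """Share of the held-out remainder that should become the validation set."""
    remainder = 1.0 - train_ratio
    if remainder <= 0 or val_ratio > remainder:
        raise ValueError("val_ratio must fit inside 1 - train_ratio")
    return val_ratio / remainder


def approximate_split_sizes(n_rows: int, train_ratio: float = 0.8, val_ratio: float = 0.1):
    """Approximate (train, val, test) row counts for the two-stage split."""
    n_train = round(n_rows * train_ratio)
    n_temp = n_rows - n_train  # rows passed to the second train_test_split call
    n_val = round(n_temp * second_stage_fraction(train_ratio, val_ratio))
    return n_train, n_val, n_temp - n_val


# With the defaults, half of the 20% remainder goes to validation and half to test
print(second_stage_fraction(0.8, 0.1))  # ~0.5 (tiny floating-point error is harmless)
print(approximate_split_sizes(1000))    # (800, 100, 100)
```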


## Step 3: Upload Preprocessed Data to Colab

This cell uploads the preprocessed data files (previously generated on my local machine) into the Google Colab Pro+ environment for further analysis or fine-tuning. The script creates a directory `combined_medical_data/` (if it doesn't already exist) and moves the uploaded files into it. This step follows the preprocessing pipeline executed locally, as described in Step 2, where I processed the `MentalChat16K` and `MedQuAD` datasets (anonymization took 3 hours on my local machine due to hardware limitations). Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Code Overview
The script performs the following steps:
1. **Directory Creation**: Creates the `combined_medical_data/` directory to store the uploaded files.
2. **File Upload**: Prompts to upload files from the local machine using Colab's `files.upload()` interface.
3. **File Organization**: Moves the uploaded files into the `combined_medical_data/` directory.
4. **Verification**: Lists the files in the `combined_medical_data/` directory to confirm successful upload.

### Instructions
1. Run the cell below.
2. When prompted, upload the preprocessed files from your local `combined_medical_data/` folder (e.g., `train.csv`, `validation.csv`, `test.csv`, `train.jsonl`, `validation.jsonl`, `test.jsonl`, `dataset_statistics.json`, `data_preparation.log`).
3. The script will move the files to `combined_medical_data/` and display the list of uploaded files.

### Expected Output
After running the cell and uploading the files, you should see:
- A confirmation message for each file moved to `combined_medical_data/`.
- A list of all files in the `combined_medical_data/` directory, as in the sample output shown after the code cell.

In [None]:
# Upload files directly in Colab
from google.colab import files
import os

# Create the directory structure
os.makedirs('combined_medical_data', exist_ok=True)

# Upload files one by one
print("Upload your files from combined_medical_data folder:")
uploaded = files.upload()

# Move uploaded files to the correct directory
for filename in uploaded.keys():
    os.rename(filename, f'combined_medical_data/{filename}')
    print(f'Moved {filename} to combined_medical_data/')

# Verify the upload
print("\nFiles in combined_medical_data:")
for file in os.listdir('combined_medical_data'):
    print(f"  - {file}")

Upload your files from combined_medical_data folder:


Saving dataset_statistics.json to dataset_statistics.json
Saving test.csv to test.csv
Saving test.jsonl to test.jsonl
Saving train.csv to train.csv
Saving train.jsonl to train.jsonl
Saving validation.csv to validation.csv
Saving validation.jsonl to validation.jsonl
Moved dataset_statistics.json to combined_medical_data/
Moved test.csv to combined_medical_data/
Moved test.jsonl to combined_medical_data/
Moved train.csv to combined_medical_data/
Moved train.jsonl to combined_medical_data/
Moved validation.csv to combined_medical_data/
Moved validation.jsonl to combined_medical_data/

Files in combined_medical_data:
  - validation.csv
  - train.jsonl
  - validation.jsonl
  - test.csv
  - test.jsonl
  - dataset_statistics.json
  - train.csv
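The listing above contains seven files. `data_preparation.log`, which the instructions also mention, is not among them. A small check like the following can report anything missing from `combined_medical_data/` before Week 2 starts. It is a stdlib-only sketch with a hypothetical helper name. The expected list is taken from the instructions, minus the log file. The Week 2 script appears to read only the JSONL splits, so the other files are worth treating as optional.

```python
import os
from typing import Iterable, List

# Files the Week 1 pipeline writes (per the instructions above); adjust as needed
EXPECTED_FILES = [
    "train.jsonl", "validation.jsonl", "test.jsonl",
    "train.csv", "validation.csv", "test.csv",
    "dataset_statistics.json",
]


def missing_files(directory: str, expected: Iterable[str] = EXPECTED_FILES) -> List[str]:
    """Return the expected file names that are not present in `directory`."""
    present = set(os.listdir(directory)) if os.path.isdir(directory) else set()
    return [name for name in expected if name not in present]


missing = missing_files("combined_medical_data")
if missing:
    print("Missing from combined_medical_data/:", ", ".join(missing))
else:
    print("All expected files are present.")
```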


# WEEK 2: FINE-TUNING

## Step 1: Install Requirements for Week 2 Fine-Tuning

This cell installs the Python libraries needed for Week 2, which focuses on fine-tuning a language model for the STEMRESEARCH project. The packages cover transformer models (`transformers`), datasets (`datasets`), parameter-efficient fine-tuning (`peft`), deep learning (`torch`), data manipulation (`pandas`, `numpy`), visualization (`matplotlib`, `seaborn`), progress tracking (`tqdm`), experiment tracking (`wandb`), device placement and distributed training (`accelerate`), and quantization (`bitsandbytes`). Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Code Overview
The command installs the following packages:
- `transformers`: For loading and fine-tuning pre-trained transformer models.
- `datasets`: For managing and loading datasets (e.g., preprocessed data from Week 1).
- `peft`: For parameter-efficient fine-tuning techniques.
- `torch`: The PyTorch framework for deep learning.
- `pandas` and `numpy`: For data manipulation and numerical operations.
- `matplotlib` and `seaborn`: For data visualization.
- `tqdm`: For progress bars during training.
- `wandb`: For experiment tracking and visualization.
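- `accelerate`: For device placement, mixed precision, and distributed training (the `Trainer` API relies on it).
- `bitsandbytes`: For 8-bit optimizers and 8-bit/4-bit quantization (QLoRA uses the 4-bit path) to reduce memory usage.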

### Command
```bash
!pip install transformers datasets peft torch pandas numpy matplotlib seaborn tqdm wandb accelerate bitsandbytes
```


In [None]:
!pip install transformers datasets peft torch pandas numpy matplotlib seaborn tqdm wandb accelerate bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 MB 137.2 MB/s eta 0:00:00
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.46.0
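The pip output above only lists newly installed packages. For reproducibility, it can help to record the versions actually present in the runtime, for example in the training log. The following is a minimal stdlib sketch using `importlib.metadata`. The package list mirrors the install command above, and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

PACKAGES = [
    "transformers", "datasets", "peft", "torch", "pandas", "numpy",
    "matplotlib", "seaborn", "tqdm", "wandb", "accelerate", "bitsandbytes",
]


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>14}: {ver or 'NOT INSTALLED'}")
```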


## Step 2: Install PyTorch with CUDA Support for Week 2 Fine-Tuning

This cell installs PyTorch, torchvision, and torchaudio with CUDA 12.1 support, which is needed for GPU acceleration during the Week 2 fine-tuning tasks of the STEMRESEARCH project. The goal is to make the deep learning framework (`torch`) from the previous step CUDA-compatible for faster computation. Note that pip keeps an already-installed `torch` unless `--upgrade` or `--force-reinstall` is passed. On a runtime that already ships a CUDA build of PyTorch, this command may therefore leave that build in place; the check in Step 3 reports `2.6.0+cu124`. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Code Overview
The command installs the following:
- `torch`: The PyTorch library for deep learning.
- `torchvision`: For computer vision tasks and datasets (included for completeness).
- `torchaudio`: For audio processing (included for completeness).
- The `--index-url https://download.pytorch.org/whl/cu121` flag specifies the PyTorch wheel with CUDA 12.1 support.

### Command
```bash
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
```


In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121


## Step 3: Verify PyTorch Installation and CUDA Setup for Week 2 Fine-Tuning

This cell verifies the installation of PyTorch with CUDA support, ensuring that the environment is correctly set up for GPU-accelerated fine-tuning in Week 2 of the STEMRESEARCH project. The script checks the PyTorch version, confirms CUDA availability, and retrieves the name of the GPU device. This step follows the installation of PyTorch with CUDA 12.1 support (Step 2) and the required packages for fine-tuning (Step 1). Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Code Overview
The script performs the following checks:
1. **PyTorch Version**: Prints the installed version of PyTorch (`torch.__version__`).
2. **CUDA Availability**: Confirms whether CUDA is available for GPU acceleration (`torch.cuda.is_available()`).
3. **GPU Device Name**: Retrieves the name of the GPU device (`torch.cuda.get_device_name(0)`).

### Command
```python
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
```


In [None]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))


2.6.0+cu124
True
NVIDIA A100-SXM4-40GB
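The reported build is `2.6.0+cu124`, not the `cu121` wheel requested in Step 2, which suggests the runtime's preinstalled PyTorch was kept. The local-version suffix of `torch.__version__` encodes the CUDA build, and `torch.version.cuda` exposes the same information as a version number. The sketch below shows one way to flag a mismatch early. `cuda_build_tag` is a hypothetical helper that does plain string parsing, so it runs without a GPU.

```python
import re
from typing import Optional


def cuda_build_tag(torch_version: str) -> Optional[str]:
    """Extract the CUDA tag (e.g. 'cu124') from a PyTorch version string, if any."""
    match = re.search(r"\+(cu\d+)$", torch_version)
    return match.group(1) if match else None


# In the notebook, pass torch.__version__; the literal below is the value printed above
tag = cuda_build_tag("2.6.0+cu124")
if tag != "cu121":
    print(f"Running a {tag or 'CPU/unknown'} build rather than the requested cu121 wheel")
```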


## Step 4: Fine-Tune Mistral-7B with QLoRA and Safety Guardrails for Week 2

This cell contains the main script for fine-tuning the **Mistral-7B-Instruct-v0.1** model using QLoRA (Quantized Low-Rank Adaptation) for the Week 2 tasks of the STEMRESEARCH project. The pipeline integrates a DistilBERT-based query routing classifier, advanced safety guardrails for mental health and medical responses, and processes preprocessed data from Week 1 (stored in `combined_medical_data/`). The model is optimized for dual-purpose use (mental health support and medical information) with BF16 precision and domain-specific weighting. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.
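The script's `create_mistral_format` wraps the role instruction and the user query inside a single `[INST] ... [/INST]` turn, followed by the answer and `</s>`. The minimal sketch below reproduces that string layout and adds a matching inference prompt that stops at `[/INST]`, so the model generates the answer itself. The helper names are hypothetical. One caveat, stated as an assumption worth checking on a tokenized sample: Mistral's tokenizer already prepends `<s>` when called with its default `add_special_tokens=True`, so the literal `<s>` in the text may produce a doubled BOS token.

```python
def build_training_text(instruction: str, query: str, answer: str) -> str:
    """Same layout as create_mistral_format: one [INST] turn, then the answer and </s>."""
    return f"<s>[INST] {instruction}\n\nQuery: {query} [/INST] {answer}</s>"


def build_inference_prompt(instruction: str, query: str) -> str:
    """Generation prompt: the identical prefix, ending right after [/INST]."""
    return f"<s>[INST] {instruction}\n\nQuery: {query} [/INST]"


def extract_answer(text: str) -> str:
    """Recover the answer portion from a formatted training example."""
    answer = text.split("[/INST]", 1)[1]
    return answer.removesuffix("</s>").strip()


instruction = "You are a knowledgeable medical information assistant."
example = build_training_text(instruction, "What is a fever?", "Example answer text.")
prompt = build_inference_prompt(instruction, "What is a fever?")

print(example.startswith(prompt))  # True: training and inference share the same prefix
print(extract_answer(example))     # Example answer text.
```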

### Execution Context
This script builds on the preprocessed data generated locally in Week 1 (anonymization took ~3 hours) and uploaded to Colab (Step 3). It leverages Colab Pro+'s GPU support for efficient training.
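Each line of the Week 1 JSONL files stores the pair as `"Question: ...\nAnswer: ..."` in its `text` field, as written by `format_for_finetuning`. The Week 2 code recovers the two halves by splitting on `Answer:`. Splitting only at the first occurrence keeps answers that themselves contain `Answer:` intact. The stdlib sketch below shows that round trip. The record contents and the `parse_qa` helper are illustrative.

```python
import json


def parse_qa(text: str):
    """Split a 'Question: ...\\nAnswer: ...' string at the first 'Answer:' marker."""
    head, sep, tail = text.partition("Answer:")
    question = head.replace("Question:", "", 1).strip()
    answer = tail.strip() if sep else ""
    return question, answer


# One JSONL line in the Week 1 layout (timestamp and metadata fields omitted for brevity)
line = json.dumps({
    "text": "Question: How do I cope with exam stress?\nAnswer: Try short breaks. Answer: it helps.",
    "source": "mentalchat",
})
record = json.loads(line)
question, answer = parse_qa(record["text"])
print(question)  # How do I cope with exam stress?
print(answer)    # Try short breaks. Answer: it helps.
```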

### Key Features
- **Model Fine-Tuning**: Uses Mistral-7B with QLoRA for parameter-efficient fine-tuning.
- **Query Routing**: Trains a DistilBERT classifier to route queries to mental health or medical domains.
- **Safety Guardrails**: Implements crisis detection, professional guidance prompts, and medical disclaimers.
- **Data Processing**: Converts JSONL data into an instruction-following format with empathy/accuracy markers.
- **Training Optimization**: Employs BF16 precision, gradient checkpointing, and domain-weighted loss.
- **Output**: Saves the fine-tuned model, routing classifier, configuration, and logs in `mistral_mental_medical_chatbot/`.
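The safety logic matches crisis language in two different ways. `SafetyGuardrails` uses regular expressions with `\b` word boundaries, while the crisis check inside `preprocess_data` uses plain substring tests. The difference matters for short words such as `die`, which also occurs inside `diet`. The sketch below reuses the first red-flag pattern and the substring list verbatim to show the contrast, using stdlib `re` only.

```python
import re

# First pattern from SafetyGuardrails.mental_health_red_flags (copied verbatim)
RED_FLAG = r'\b(kill|suicide|die|death|hurt myself|end it all|not worth living)\b'


def regex_flag(text: str) -> bool:
    """Word-boundary match, as in SafetyGuardrails.detect_crisis_language."""
    return re.search(RED_FLAG, text.lower()) is not None


def substring_flag(text: str, indicators=("suicide", "kill myself", "end it all", "hurt myself", "die")) -> bool:
    """Plain substring test, as in AdvancedDataProcessor.preprocess_data."""
    return any(indicator in text.lower() for indicator in indicators)


for sample in ["I feel like I want to die", "Is a low-carb diet safe?"]:
    print(f"{sample!r}: regex={regex_flag(sample)}, substring={substring_flag(sample)}")
```

Word boundaries remove false positives such as `diet`. However, they still flag whole words like `death` in purely informational medical questions, so neither approach removes the need for the routing classifier and human review.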

### Prerequisites
- Required packages installed (Steps 1 and 2).
- GPU enabled (`Runtime > Change runtime type > GPU`) for CUDA support.
- Preprocessed data uploaded to `combined_medical_data/`.
- A Hugging Face access token for the gated Mistral model, passed to `login()` at the start of the script.
- Optional: Set `WANDB_API_KEY` as an environment variable for experiment tracking:
  ```python
  import os
  os.environ["WANDB_API_KEY"] = "your_wandb_api_key"
  ```


In [None]:
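import os
from getpass import getpass
from huggingface_hub import login, whoami

# Read the token from the HF_TOKEN environment variable (or ask for it) instead of hardcoding it
login(token=os.environ.get("HF_TOKEN") or getpass("Hugging Face token: "))

# `chcp 65001` switches a Windows console to UTF-8; the command does not exist on Linux/Colab
if os.name == "nt":
    os.system("chcp 65001")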
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, DataCollatorForLanguageModeling,
    EarlyStoppingCallback, get_linear_schedule_with_warmup,
    pipeline
)
from datasets import Dataset, DatasetDict
import pandas as pd
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, prepare_model_for_kbit_training
from transformers.trainer_utils import EvalPrediction
import numpy as np
import logging
import wandb
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional, Union
import warnings
import re
from tqdm import tqdm
import pickle
warnings.filterwarnings("ignore")

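# Check that BitsAndBytesConfig and the bitsandbytes package itself are importable
# (the config class imports from transformers even when bitsandbytes is not installed)
try:
    from transformers import BitsAndBytesConfig
    import bitsandbytes  # noqa: F401
    BITSANDBYTES_AVAILABLE = True
except (ImportError, RuntimeError):
    BITSANDBYTES_AVAILABLE = False
    print("bitsandbytes not available, will use regular model loading")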

# Configure advanced logging with proper encoding
def setup_logging():
    """Setup comprehensive logging system with UTF-8 encoding, no console output."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)

    # Create file handler with UTF-8 encoding for INFO level and above
    file_handler = logging.FileHandler(f'{log_dir}/mistral_training_{timestamp}.log', encoding='utf-8')
    file_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    # Get logger for __main__ and configure it
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    logger.handlers = []  # Clear any existing handlers
    logger.addHandler(file_handler)
    logger.propagate = False  # Prevent propagation to parent loggers

    # Clear all handlers from the root logger and add file handler only
    root_logger = logging.getLogger('')
    root_logger.handlers = []  # Remove all existing handlers (including console)
    root_logger.addHandler(file_handler)
    root_logger.setLevel(logging.INFO)
    root_logger.propagate = False

    # Suppress console output for specific libraries
    for logger_name in ['transformers', 'datasets', 'wandb', 'torch', 'huggingface_hub', 'peft', 'trl']:
        lib_logger = logging.getLogger(logger_name)
        lib_logger.handlers = []  # Clear any existing handlers
        lib_logger.addHandler(file_handler)  # Direct to file only
        lib_logger.setLevel(logging.INFO)
        lib_logger.propagate = False

    return logger

logger = setup_logging()

class SafetyGuardrails:
    """Advanced safety and moderation system for mental health and medical responses."""

    def __init__(self):
        self.mental_health_red_flags = [
            r'\b(kill|suicide|die|death|hurt myself|end it all|not worth living)\b',
            r'\b(harm|violence|dangerous|weapon)\b',
            r'\b(overdose|pills|medication abuse)\b'
        ]

        self.medical_disclaimers = [
            "This information is for educational purposes only and should not replace professional medical advice.",
            "Please consult with a healthcare professional for personalized medical guidance.",
            "Always seek immediate medical attention for serious health concerns."
        ]

        self.crisis_resources = {
            "suicide_prevention": "If you're having thoughts of suicide, please contact the 988 Suicide & Crisis Lifeline: 988",
            "crisis_text": "Text HOME to 741741 for the Crisis Text Line",
            "emergency": "For immediate emergencies, call 911"
        }

    def detect_crisis_language(self, text: str) -> Tuple[bool, str]:
        """Detect crisis language in user input."""
        text_lower = text.lower()

        for pattern in self.mental_health_red_flags:
            if re.search(pattern, text_lower):
                return True, "crisis_mental_health"

        return False, "safe"

    def apply_safety_filter(self, response: str, query_type: str) -> str:
        """Apply safety filters and add appropriate disclaimers."""
        if query_type == "mental_health":
            if not any(resource in response for resource in self.crisis_resources.values()):
                response += f"\n\nRemember: {self.crisis_resources['suicide_prevention']}"

        elif query_type == "medical":
            if not any(disclaimer in response for disclaimer in self.medical_disclaimers):
                response += f"\n\n{np.random.choice(self.medical_disclaimers)}"

        return response

    def moderate_response(self, response: str, user_input: str) -> str:
        """Comprehensive response moderation."""
        is_crisis, crisis_type = self.detect_crisis_language(user_input)

        if is_crisis:
            crisis_response = (
                "I'm concerned about what you've shared. Your safety and wellbeing are important. "
                f"{self.crisis_resources['suicide_prevention']} "
                f"{self.crisis_resources['emergency']} "
                "Please reach out to a mental health professional who can provide the support you need."
            )
            return crisis_response

        return response

class QueryRoutingClassifier:
    """DistilBERT-based classifier to route queries between mental health and medical domains."""

    def __init__(self, model_name: str = "distilbert-base-uncased"):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.is_trained = False

    def prepare_routing_data(self, train_data: List[Dict], val_data: List[Dict]) -> Tuple[List[str], List[int], List[str], List[int]]:
        """Prepare data for training the classifier."""
        train_texts, train_labels = [], []
        val_texts, val_labels = [], []

        for item in train_data:
            question = item['text'].split('Answer:')[0].replace('Question:', '').strip()
            source = item['source']
            train_texts.append(question)
            train_labels.append(0 if source == 'mentalchat' else 1)

        for item in val_data:
            question = item['text'].split('Answer:')[0].replace('Question:', '').strip()
            source = item['source']
            val_texts.append(question)
            val_labels.append(0 if source == 'mentalchat' else 1)

        return train_texts, train_labels, val_texts, val_labels

    def train_classifier(self, train_texts: List[str], train_labels: List[int], val_texts: List[str], val_labels: List[int], output_dir: str, test_data: List[Dict] = None) -> dict:
        """Train the routing classifier."""
        logger.info("Training DistilBERT query routing classifier...")

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2,
            dropout=0.3  # Added dropout for regularization
        )

        # Tokenize data
        train_encodings = self.tokenizer(train_texts, truncation=True, padding=True, max_length=128)
        val_encodings = self.tokenizer(val_texts, truncation=True, padding=True, max_length=128)

        # Create datasets
        train_dataset = Dataset.from_dict({
            'input_ids': train_encodings['input_ids'],
            'attention_mask': train_encodings['attention_mask'],
            'labels': train_labels
        })

        val_dataset = Dataset.from_dict({
            'input_ids': val_encodings['input_ids'],
            'attention_mask': val_encodings['attention_mask'],
            'labels': val_labels
        })

        # Training arguments
        training_args = TrainingArguments(
            output_dir=f"{output_dir}/routing_classifier",
            num_train_epochs=3,
            per_device_train_batch_size=8,  # Adjusted for DistilBERT
            per_device_eval_batch_size=8,
            warmup_steps=200,
            weight_decay=0.5,
            learning_rate=2e-5,
            logging_dir=f"{output_dir}/routing_classifier/logs",
            eval_strategy="epoch",  # Changed from evaluation_strategy
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            report_to="none"  # Set to none for simplicity in Colab
        )

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            accuracy = accuracy_score(labels, predictions)
            precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
            return {
                'accuracy': accuracy,
                'f1': f1,
                'precision': precision,
                'recall': recall
            }

        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics,
        )

        # Train
        trainer.train()

        # Evaluate on validation set
        eval_results = trainer.evaluate()
        logger.info(f"Routing classifier validation results: {eval_results}")

        # Evaluate on test set if test_data is provided
        if test_data is not None:
            # prepare_routing_data expects (train, val) lists; pass test_data twice and keep the first pair
            test_texts, test_labels, _, _ = self.prepare_routing_data(test_data, test_data)
            test_encodings = self.tokenizer(test_texts, truncation=True, padding=True, max_length=128)
            test_dataset = Dataset.from_dict({
                'input_ids': test_encodings['input_ids'],
                'attention_mask': test_encodings['attention_mask'],
                'labels': test_labels
            })
            test_results = trainer.evaluate(test_dataset)
            logger.info(f"Routing classifier test results: {test_results}")
        else:
            test_results = None
            logger.info("No test data provided, skipping test set evaluation")

        # Save classifier
        self.model.save_pretrained(f"{output_dir}/routing_classifier")
        self.tokenizer.save_pretrained(f"{output_dir}/routing_classifier")

        self.is_trained = True
        return {'validation': eval_results, 'test': test_results}

    def predict_domain(self, query: str) -> Tuple[str, float]:
        """Predict whether query is mental health or medical."""
        if not self.is_trained or self.model is None:
            return self._keyword_based_routing(query)

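        # Trainer leaves the classifier on the GPU when one is available, so move the inputs to match
        self.model.eval()
        inputs = self.tokenizer(query, return_tensors='pt', truncation=True, padding=True, max_length=128)
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = F.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()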

        domain = "mental_health" if predicted_class == 0 else "medical"
        return domain, confidence

    def _keyword_based_routing(self, query: str) -> Tuple[str, float]:
        """Fallback keyword-based routing."""
        mental_health_keywords = [
            'anxiety', 'depression', 'stress', 'mental', 'therapy', 'counseling',
            'mood', 'emotional', 'panic', 'ptsd', 'trauma', 'suicide', 'self-harm'
        ]

        medical_keywords = [
            'symptoms', 'disease', 'treatment', 'medication', 'diagnosis', 'doctor',
            'hospital', 'surgery', 'pain', 'infection', 'vaccine', 'prescription'
        ]

        query_lower = query.lower()
        mental_score = sum(1 for keyword in mental_health_keywords if keyword in query_lower)
        medical_score = sum(1 for keyword in medical_keywords if keyword in query_lower)

        if mental_score > medical_score:
            return "mental_health", 0.7
        elif medical_score > mental_score:
            return "medical", 0.7
        else:
            return "medical", 0.5

class AdvancedDataProcessor:
    """Advanced data processing for preprocessed JSONL datasets."""

    def __init__(self):
        self.label_mapping = {"mental_health": 0, "medical": 1}

    def load_jsonl(self, file_path: str) -> List[Dict]:
        """Load JSONL file with error handling."""
        try:
            data = []
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    try:
                        data.append(json.loads(line.strip()))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping malformed JSON at line {line_num}: {e}")
            logger.info(f"Loaded {len(data)} samples from {file_path}")
            return data
        except FileNotFoundError:
            logger.error(f"File not found: {file_path}")
            return []

    def preprocess_data(self, data: List[Dict], domain: str) -> List[Dict]:
        """Preprocess JSONL data with empathy or accuracy markers."""
        processed = []
        empathy_prefixes = [
            "I understand this can be really challenging, and I want you to know that your feelings are completely valid. ",
            "Thank you for sharing this with me. It takes courage to reach out. ",
            "I hear you, and I want you to know that you're not alone in feeling this way. ",
            "Your experience matters, and I'm here to support you through this. ",
            "It's completely normal to feel this way, and seeking support shows real strength. "
        ]
        medical_prefixes = [
            "Based on current medical knowledge: ",
            "According to established medical literature: ",
            "From a clinical perspective: ",
            "Medical research indicates that: ",
            ""  # Sometimes no prefix for natural flow
        ]

        for item in data:
            try:
                # Extract question and answer from text field
                text = item['text']
                question = text.split('Answer:')[0].replace('Question:', '').strip()
                # Split at the first 'Answer:' only, so answers containing the word stay intact
                answer = text.split('Answer:', 1)[1].strip() if 'Answer:' in text else ''
                source = item['source']

                if not question or not answer:
                    continue

                # Determine domain
                item_domain = 'mental_health' if source == 'mentalchat' else 'medical'

                # Apply domain-specific processing
                if item_domain == 'mental_health':
                    prefix = np.random.choice(empathy_prefixes) if np.random.random() < 0.4 else ""
                    crisis_indicators = ['suicide', 'kill myself', 'end it all', 'hurt myself', 'die']
                    if any(indicator in question.lower() for indicator in crisis_indicators):
                        professional_guidance = " Please remember that professional mental health support is crucial, and I encourage you to reach out to a counselor, therapist, or crisis helpline."
                        answer += professional_guidance
                    instruction = (
                        "You are an empathetic mental health support assistant. Provide compassionate, "
                        "understanding, and helpful responses while always encouraging professional help "
                        "when appropriate. Be supportive but never provide medical diagnoses."
                    )
                    output = f"{prefix}{answer}"
                else:
                    prefix = np.random.choice(medical_prefixes)
                    consultation_reminder = " Please consult with a healthcare professional for personalized medical advice and treatment."
                    if "consult" not in answer.lower() and "doctor" not in answer.lower():
                        answer += consultation_reminder
                    instruction = (
                        "You are a knowledgeable medical information assistant. Provide accurate, "
                        "evidence-based medical information while always emphasizing the importance "
                        "of consulting healthcare professionals. Never provide specific medical diagnoses."
                    )
                    output = f"{prefix}{answer}"

                processed_item = {
                    'instruction': instruction,
                    'input': question,
                    'output': output,
                    'domain': item_domain,
                    'label': self.label_mapping[item_domain]
                }
                processed.append(processed_item)
            except Exception as e:
                logger.warning(f"Error processing item: {e}")
                continue

        return processed

    def create_mistral_format(self, data: List[Dict]) -> List[Dict]:
        """Convert data to Mistral instruction format."""
        formatted_data = []

        for item in data:
            conversation = f"<s>[INST] {item['instruction']}\n\nQuery: {item['input']} [/INST] {item['output']}</s>"
            formatted_item = {
                'text': conversation,
                'domain': item['domain'],
                'label': item['label'],
                'instruction': item['instruction'],
                'input': item['input'],
                'output': item['output']
            }
            formatted_data.append(formatted_item)

        return formatted_data

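def setup_mistral_model_and_tokenizer(model_name: str = "mistralai/Mistral-7B-Instruct-v0.1", token: Union[str, None] = None):
    # Read the Hugging Face token from the HF_TOKEN environment variable unless one is passed in
    token = token or os.getenv("HF_TOKEN")
    logger.info(f"Loading Mistral model: {model_name} in BF16 precision")
    if not token:
        raise ValueError("No Hugging Face token provided. Set the HF_TOKEN environment variable.")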

    from huggingface_hub import login, whoami
    login(token=token)
    user_info = whoami(token=token)
    logger.info(f"Token verified! Logged in as: {user_info['name']}")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=token,
        padding_side="right"
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Load model with BF16 precision
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # Use BF16 precision
        device_map="auto",
        trust_remote_code=True,
        token=token
    )
    logger.info(f"Loaded {model_name} with BF16 precision")

    return model, tokenizer


def setup_mistral_qlora_config():
    """LoRA configuration for Mistral-7B.

    The base model is loaded in BF16 without 4-bit quantization, so this is
    standard LoRA rather than QLoRA.
    """
    return LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False
    )

class MistralTrainer(Trainer):
    def __init__(self, *args, domain_weights=None, safety_guardrails=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.domain_weights = domain_weights or {"mental_health": 1.2, "medical": 1.0}
        self.safety_guardrails = safety_guardrails or SafetyGuardrails()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Causal-LM loss with domain weighting, averaged over non-ignored tokens only."""
        labels = inputs.get("labels")
        domain_labels = inputs.get("domain_labels", None)
        outputs = model(**{k: v for k, v in inputs.items() if k != "domain_labels"})

        if labels is not None:
            # Shift so that the logits at position i predict the token at position i + 1
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
            losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            losses = losses.view(shift_labels.size())

            # Ignored positions (-100) get a loss of 0; exclude them from the average as well
            token_mask = (shift_labels != -100).to(losses.dtype)
            if domain_labels is not None:
                # Label 0 is treated as mental_health, any other label as medical
                weights = torch.tensor(
                    [self.domain_weights["mental_health"] if l == 0 else self.domain_weights["medical"] for l in domain_labels],
                    device=losses.device, dtype=losses.dtype
                )
                losses = losses * weights.unsqueeze(-1)
            loss = (losses * token_mask).sum() / token_mask.sum().clamp(min=1.0)

            # Log loss details
            logger.info(f"Loss: {loss.item()}, requires_grad: {loss.requires_grad}")
            return (loss, outputs) if return_outputs else loss
        return outputs

def create_mistral_training_args(output_dir: str):
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,  # Effective batch size per device: 2 x 4 = 8
        warmup_ratio=0.03,
        weight_decay=0.01,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        logging_dir=os.path.join(output_dir, 'logs'),
        logging_steps=100,  # Increased to reduce logging overhead
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=1000,  # Must be a multiple of eval_steps; load_best_model_at_end can only restore saved steps
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=False,  # Disable FP16 for BF16 training
        bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,  # Ensure BF16 is supported
        dataloader_pin_memory=True,
        remove_unused_columns=False,
        report_to="none",
        run_name=f"mistral_mental_medical_{datetime.now().strftime('%Y%m%d_%H%M')}",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch",  # Use standard AdamW for BF16 training
        max_grad_norm=1.0,
    )

def tokenize_mistral_data(examples, tokenizer, max_length=1024):
    """Tokenize data in Mistral format with padding and truncation."""
    tokenized = tokenizer(
        examples['text'],
        truncation=True,  # Enable truncation
        padding='max_length',  # Pad to max_length
        max_length=max_length,
        return_tensors=None
    )

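    # Labels start as a copy of input_ids. DataCollatorForLanguageModeling(mlm=False) later rebuilds
    # labels from input_ids and sets every pad_token_id position (here also eos_token_id) to -100.
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    tokenized["domain_labels"] = examples['label']

    return tokenized


class MistralChatbot: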
    def __init__(self, model_path: str, routing_classifier_path: str):
        self.model_path = model_path
        self.routing_classifier_path = routing_classifier_path
        self.safety_guardrails = SafetyGuardrails()
        self.routing_classifier = QueryRoutingClassifier()
        self.model = None
        self.tokenizer = None
        self.pipeline = None

    def load_models(self):
        logger.info("Loading Mistral chatbot models in BF16 precision...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # Load model with BF16 precision
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,  # Use BF16 precision
            device_map="auto",
            trust_remote_code=True
        )

        # Load routing classifier
        try:
            self.routing_classifier.tokenizer = AutoTokenizer.from_pretrained(self.routing_classifier_path)
            self.routing_classifier.model = AutoModelForSequenceClassification.from_pretrained(self.routing_classifier_path)
            self.routing_classifier.is_trained = True
        except Exception as e:
            logger.warning(f"Could not load routing classifier: {e}. Using keyword-based routing")

        # Setup pipeline with BF16
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,  # Use BF16 for pipeline
            device_map="auto"
        )

        logger.info("All models loaded successfully")

    def generate_response(self, user_query: str) -> Dict[str, Union[str, float]]:
        """Generate response with routing and safety checks."""
        domain, confidence = self.routing_classifier.predict_domain(user_query)

        if domain == "mental_health":
            instruction = (
                "You are an empathetic mental health support assistant. Provide compassionate, "
                "understanding, and helpful responses while always encouraging professional help "
                "when appropriate. Be supportive but never provide medical diagnoses."
            )
        else:
            instruction = (
                "You are a knowledgeable medical information assistant. Provide accurate, "
                "evidence-based medical information while always emphasizing the importance "
                "of consulting healthcare professionals. Never provide specific medical diagnoses."
            )

        prompt = f"<s>[INST] {instruction}\n\nQuery: {user_query} [/INST]"

        outputs = self.pipeline(
            prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=self.tokenizer.eos_token_id
        )

        generated_text = outputs[0]['generated_text']
        response = generated_text.split('[/INST]')[-1].strip()

        safe_response = self.safety_guardrails.moderate_response(response, user_query)
        final_response = self.safety_guardrails.apply_safety_filter(safe_response, domain)

        return {
            'response': final_response,
            'domain': domain,
            'confidence': confidence,
            'safety_applied': safe_response != response
        }

def test_chatbot(output_dir: str):
    """Test the chatbot with sample queries."""
    logger.info("Testing chatbot with sample queries...")
    chatbot = MistralChatbot(
        model_path=os.path.join(output_dir, 'mistral_model'),
        routing_classifier_path=os.path.join(output_dir, 'routing_classifier')
    )
    chatbot.load_models()

    sample_queries = [
        "I'm feeling really anxious about work and can't sleep. What can I do?",
        "What are the symptoms of type 2 diabetes?",
        "I feel like life isn't worth living anymore.",
        "Can you explain what causes migraines?"
    ]

    for query in sample_queries:
        response_dict = chatbot.generate_response(query)
        logger.info(f"Query: {query}")
        logger.info(f"Response: {response_dict['response']}")
        logger.info(f"Domain: {response_dict['domain']} (Confidence: {response_dict['confidence']:.2f})")
        logger.info(f"Safety Applied: {response_dict['safety_applied']}")
        logger.info("-" * 50)

def main():
    """Main training pipeline for Mistral-7B dual-purpose chatbot."""
    logger.info("Starting Advanced Mistral-7B Fine-Tuning Pipeline")
    print("Starting Advanced Mistral-7B Fine-Tuning Pipeline")

    if os.getenv("WANDB_API_KEY"):
        wandb.init(
            project="mistral-mental-health-medical-qa",
            name="mistral-7b-qlora-finetune",
            config={
                "model": "mistralai/Mistral-7B-Instruct-v0.1",
                "classifier": "distilbert-base-uncased",
                "technique": "LoRA",
                "datasets": ["MentalChat16K", "MedQuAD"],
                "features": ["routing_classifier", "safety_guardrails", "domain_weighting"]
            }
        )

    output_dir = "mistral_mental_medical_chatbot"
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Load datasets
        print("Loading preprocessed JSONL datasets...")
        logger.info("Loading preprocessed JSONL datasets...")
        data_processor = AdvancedDataProcessor()
        train_data = data_processor.load_jsonl('combined_medical_data/train.jsonl')
        val_data = data_processor.load_jsonl('combined_medical_data/validation.jsonl')
        test_data = data_processor.load_jsonl('combined_medical_data/test.jsonl')
        logger.info(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples, and {len(test_data)} test samples.")
        print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples, and {len(test_data)} test samples.")

        if not train_data or not val_data or not test_data:
            logger.error("Failed to load datasets. Please check file paths.")
            print("Failed to load datasets. Please check file paths.")
            return

        # Process data
        print("Processing datasets...")
        logger.info("Processing datasets...")
        processed_train = data_processor.preprocess_data(train_data, None)
        processed_val = data_processor.preprocess_data(val_data, None)

        # Balance datasets
        mental_train = [item for item in processed_train if item['domain'] == 'mental_health']
        medical_train = [item for item in processed_train if item['domain'] == 'medical']
        min_size = min(len(mental_train), len(medical_train), 7500)
        mental_balanced = np.random.choice(mental_train, min_size, replace=False).tolist()
        medical_balanced = np.random.choice(medical_train, min_size, replace=False).tolist()
        logger.info(f"Balanced datasets: {min_size} samples each")
        print(f"Balanced datasets: {min_size} samples each")

        # Train routing classifier
        logger.info("Training DistilBERT routing classifier...")
        print("Training DistilBERT routing classifier...")
        routing_classifier = QueryRoutingClassifier()
        train_texts, train_labels, val_texts, val_labels = routing_classifier.prepare_routing_data(train_data, val_data)
        routing_results = routing_classifier.train_classifier(train_texts, train_labels, val_texts, val_labels, output_dir, test_data=test_data)

        # Combine and format data for Mistral
        all_train_data = mental_balanced + medical_balanced
        np.random.shuffle(all_train_data)
        formatted_train = data_processor.create_mistral_format(all_train_data)
        formatted_val = data_processor.create_mistral_format(processed_val)

        # Setup Mistral model
        logger.info("Loading Mistral-7B model...")
        print("Loading Mistral-7B model...")
        model, tokenizer = setup_mistral_model_and_tokenizer()

        # Apply LoRA (the base model is in BF16, not 4-bit, so this is not QLoRA)
        logger.info("Applying LoRA configuration...")
        print("Applying LoRA configuration...")
        lora_config = setup_mistral_qlora_config()
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        # Verify trainable parameters
        if not any(param.requires_grad for param in model.parameters()):
            logger.error("No parameters set to require gradients. Check LoRA configuration.")
            print("No parameters set to require gradients. Check LoRA configuration.")
            raise ValueError("No trainable parameters found.")

        # Create datasets
        train_dataset = Dataset.from_list(formatted_train)
        val_dataset = Dataset.from_list(formatted_val)

        # Tokenize datasets
        def tokenize_batch(examples):
            return tokenize_mistral_data(examples, tokenizer, max_length=1024)

        train_dataset = train_dataset.map(
            tokenize_batch,
            batched=True,
            remove_columns=[col for col in train_dataset.column_names if col not in ['input_ids', 'attention_mask', 'labels', 'domain_labels']]
        )
        val_dataset = val_dataset.map(
            tokenize_batch,
            batched=True,
            remove_columns=[col for col in val_dataset.column_names if col not in ['input_ids', 'attention_mask', 'labels', 'domain_labels']]
        )

        # Data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )

        # Setup training
        logger.info("🏋️ Setting up training configuration...")
        print("🏋️ Setting up training configuration...")
        training_args = create_mistral_training_args(output_dir)

        model.train()  # Ensure model is in training mode
        trainer = MistralTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=1, early_stopping_threshold=0.0001)],
            domain_weights={"mental_health": 1.2, "medical": 1.0},
            safety_guardrails=SafetyGuardrails()
        )

        # Train model
        logger.info("🚀 Starting Mistral-7B fine-tuning...")
        print("🚀 Starting Mistral-7B fine-tuning...")
        trainer.train()

        # Evaluate
        logger.info("📊 Evaluating model...")
        print("📊 Evaluating model...")
        eval_results = trainer.evaluate()
        logger.info(f"Final evaluation results: {eval_results}")
        print(f"Final evaluation results: {eval_results}")

        # Save the LoRA adapter (PeftModel.save_pretrained stores only the adapter weights) and the tokenizer
        logger.info("💾 Saving fine-tuned model...")
        print("💾 Saving fine-tuned model...")
        model.save_pretrained(os.path.join(output_dir, 'mistral_model'))
        tokenizer.save_pretrained(os.path.join(output_dir, 'mistral_model'))

        # Save configuration
        config = {
            'model_name': "mistralai/Mistral-7B-Instruct-v0.1",
            'classifier_name': "distilbert-base-uncased",
            'lora_config': {
                'r': lora_config.r,
                'lora_alpha': lora_config.lora_alpha,
                'lora_dropout': lora_config.lora_dropout,
                'target_modules': list(lora_config.target_modules)  # Ensure list
            },
            'training_args': training_args.to_dict(),
            'final_metrics': eval_results,
            'routing_classifier_metrics': routing_results,
            'dataset_sizes': {
                'mental_health': len(mental_balanced),
                'medical': len(medical_balanced),
                'total_train': len(all_train_data),
                'total_val': len(processed_val)
            },
            'dataset_paths': {
                'train': 'combined_medical_data/train.jsonl',
                'validation': 'combined_medical_data/validation.jsonl',
                'test': 'combined_medical_data/test.jsonl'
            },
            'timestamp': datetime.now().isoformat(),
            'features': [
                'LoRA fine-tuning',
                'DistilBERT routing classifier',
                'Safety guardrails',
                'Crisis detection',
                'Professional guidance prompts',
                'Medical disclaimers',
                'BF16 training',
                'Instruction-following conversational format',
                'Comprehensive evaluation and monitoring'
            ]
        }

        # Save config to JSON
        with open(os.path.join(output_dir, 'mistral_config.json'), 'w') as f:
            json.dump(config, f, indent=2)

        # Test the complete system
        logger.info("🧪 Testing complete chatbot system...")
        print("🧪 Testing complete chatbot system...")
        test_chatbot(output_dir)

        # Success message
        logger.info("✅ Training completed successfully!")

        print("✅ Training completed successfully!")

        print("\n" + "="*80)
        print("🎉 MISTRAL-7B FINE-TUNING COMPLETED SUCCESSFULLY! 🎉")
        print(f"📁 Model saved to: {output_dir}/mistral_model")
        accuracy = routing_results['validation'].get('eval_accuracy', 'N/A')
        if isinstance(accuracy, (int, float)):
            print(f"🎯 Routing Classifier Accuracy: {accuracy:.4f}")
        else:
            print(f"🎯 Routing Classifier Accuracy: {accuracy}")
        loss = eval_results.get('eval_loss', 'N/A')
        if isinstance(loss, (int, float)):
            print(f"📊 Final Validation Loss: {loss:.4f}")
        else:
            print(f"📊 Final Validation Loss: {loss}")
        print("\n🔬 Advanced Features Implemented:")
        print("✅ Mistral-7B-Instruct model with LoRA fine-tuning (BF16 base weights)")
        print("✅ DistilBERT routing classifier for query domain detection")
        print("✅ MentalChat16K integration for empathetic responses")
        print("✅ MedQuAD integration for accurate medical information")
        print("✅ Advanced safety guardrails and crisis detection")
        print("✅ Professional guidance and medical disclaimers")
        print("✅ Domain-specific loss weighting")
        print("✅ BF16 training")
        print("✅ Instruction-following conversational format")
        print("✅ Comprehensive evaluation and monitoring")
        print("\n📋 Usage Instructions:")
        print(f"1. Load the chatbot: MistralChatbot('{output_dir}/mistral_model', '{output_dir}/routing_classifier')")
        print("2. Call chatbot.load_models() to initialize")
        print("3. Use chatbot.generate_response('your query') for responses")
        print("="*80)

    except Exception as e:
        logger.error(f"Training failed: {e}")
        import traceback
        traceback.print_exc()
        raise

    finally:
        if os.getenv("WANDB_API_KEY"):
            wandb.finish()

if __name__ == "__main__":
    main()


Starting Advanced Mistral-7B Fine-Tuning Pipeline
Loading preprocessed JSONL datasets...
Loaded 25854 training samples, 3232 validation samples, and 3232 test samples.
Processing datasets...
Balanced datasets: 7500 samples each
Training DistilBERT routing classifier...


| Epoch | Training Loss | Validation Loss | Accuracy | F1 | Precision | Recall |
|---|---|---|---|---|---|---|
| 1 | 0.0022 | 0.001152 | 0.999381 | 0.999381 | 0.999382 | 0.999381 |
| 2 | 0.0 | 2e-06 | 1.0 | 1.0 | 1.0 | 1.0 |
| 3 | 0.0 | 1e-06 | 1.0 | 1.0 | 1.0 | 1.0 |


Loading Mistral-7B model...


Applying QLoRA configuration...
trainable params: 167,772,160 || all params: 7,409,504,256 || trainable%: 2.2643


🏋️ Setting up training configuration...
🚀 Starting Mistral-7B fine-tuning...


| Step | Training Loss | Validation Loss |
|---|---|---|
| 500 | 1.5804 | 0.402093 |
| 1000 | 1.6208 | 0.386127 |
| 1500 | 1.5836 | 0.376473 |
| 2000 | 1.2074 | 0.374338 |
| 2500 | 1.2593 | 0.373366 |
| 3000 | 1.1787 | 0.367831 |
| 3500 | 1.2357 | 0.362774 |
| 4000 | 0.8567 | 0.383185 |


📊 Evaluating model...


Final evaluation results: {'eval_loss': 0.3678305745124817, 'eval_runtime': 393.0881, 'eval_samples_per_second': 8.222, 'eval_steps_per_second': 4.111, 'epoch': 2.1333333333333333}
💾 Saving fine-tuned model...
🧪 Testing complete chatbot system...


✅ Training completed successfully!

🎉 MISTRAL-7B FINE-TUNING COMPLETED SUCCESSFULLY! 🎉
📁 Model saved to: mistral_mental_medical_chatbot/mistral_model
🎯 Routing Classifier Accuracy: 1.0000
📊 Final Training Loss: 0.3678

🔬 Advanced Features Implemented:
✅ Mistral-7B base model with QLoRA fine-tuning
✅ DistilBERT routing classifier for query domain detection
✅ MentalChat16K integration for empathetic responses
✅ MedQuAD integration for accurate medical information
✅ Advanced safety guardrails and crisis detection
✅ Professional guidance and medical disclaimers
✅ Domain-specific loss weighting
✅ FP16 full precision training
✅ Instruction-following conversational format
✅ Comprehensive evaluation and monitoring

📋 Usage Instructions:
1. Load the chatbot: MistralChatbot('mistral_mental_medical_chatbot/mistral_model', 'mistral_mental_medical_chatbot/routing_classifier')
2. Call chatbot.load_models() to initialize
3. Use chatbot.generate_response('your query') for responses
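
The trainable-parameter count in the log above can be checked by hand. For each targeted linear layer, LoRA adds a matrix A of shape `r x d_in` and a matrix B of shape `d_out x r`, i.e. `r * (d_in + d_out)` parameters. The sketch below assumes the Mistral-7B dimensions from the model's `config.json` (hidden size 4096, MLP size 14336, 32 layers, 8 key/value heads of size 128); those values are not defined anywhere in this notebook. The `all_params` figure is copied from the `print_trainable_parameters()` output.

```python
# Mistral-7B dimensions, assumed from the model's config.json (not defined in this notebook)
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size
LAYERS = 32            # num_hidden_layers
KV_DIM = 8 * 128       # num_key_value_heads * head_dim (grouped-query attention)

# (in_features, out_features) for each module listed in target_modules
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(r: int) -> int:
    """Each targeted linear layer gets A (r x in_features) and B (out_features x r)."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in PROJECTIONS.values())
    return per_layer * LAYERS


trainable = lora_param_count(r=64)
all_params = 7_409_504_256  # "all params" from the print_trainable_parameters() log above

print(f"trainable params: {trainable:,}")                  # 167,772,160
print(f"trainable%: {100 * trainable / all_params:.4f}")  # 2.2643
print(f"LoRA scaling lora_alpha / r = {16 / 64}")          # 0.25
```

With `r=64` this gives 5,242,880 adapter parameters per layer and 167,772,160 in total, matching the logged value; the remaining 7,241,732,096 BF16 base parameters stay frozen.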

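The `epoch` value in the final evaluation and the checkpoint that ended up loaded both follow from the training arguments. Assuming a single GPU, 15,000 balanced examples with `per_device_train_batch_size=2` and `gradient_accumulation_steps=4` give 1,875 optimizer steps per epoch, so stopping at step 4000 corresponds to epoch 2.1333. Because `save_steps=1000` while `eval_steps=500`, the lowest validation loss (step 3500) fell between saves; the final `eval_loss` of 0.36783 matches step 3000, which suggests checkpoint 3000 is the one `load_best_model_at_end` restored. The sketch below recomputes this from the logged table; `NUM_DEVICES = 1` is an assumption.

```python
import math

# Values from create_mistral_training_args and the balanced training set (7,500 + 7,500)
TRAIN_EXAMPLES = 15_000
PER_DEVICE_BATCH = 2
GRAD_ACCUM = 4
NUM_DEVICES = 1   # assumption: single-GPU Colab runtime
SAVE_STEPS = 1000

# Validation losses copied from the training log above (step -> eval_loss)
EVAL_LOSS = {500: 0.402093, 1000: 0.386127, 1500: 0.376473, 2000: 0.374338,
             2500: 0.373366, 3000: 0.367831, 3500: 0.362774, 4000: 0.383185}

batches_per_epoch = math.ceil(TRAIN_EXAMPLES / (PER_DEVICE_BATCH * NUM_DEVICES))
steps_per_epoch = batches_per_epoch // GRAD_ACCUM  # optimizer steps per epoch
epoch_at_stop = 4000 / steps_per_epoch

best_eval_step = min(EVAL_LOSS, key=EVAL_LOSS.get)
saved_steps = [step for step in EVAL_LOSS if step % SAVE_STEPS == 0]
best_saved_step = min(saved_steps, key=EVAL_LOSS.get)

print(steps_per_epoch, round(epoch_at_stop, 4))                # 1875 2.1333
print(best_eval_step, best_saved_step, EVAL_LOSS[best_saved_step])  # 3500 3000 0.367831
```

Setting `save_steps` equal to `eval_steps` (500) would make every evaluated step a candidate for restoring, at the cost of more checkpoint writes.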

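With `CrossEntropyLoss(reduction='none', ignore_index=-100)`, ignored positions come back with a loss of 0. Averaging those values with a plain `.mean()` divides by padding as well as by real target tokens, which matters here because every example is padded to 1,024 tokens. Dividing the (domain-weighted) sum by the number of non-ignored tokens avoids that. The stdlib-only sketch below compares the two reductions on a toy batch; the token losses and label IDs are made up, and treating domain label 0 as `mental_health` mirrors the assumption in `compute_loss`.

```python
from typing import List

IGNORE_INDEX = -100
# Assumption mirrored from compute_loss: domain label 0 = mental_health, 1 = medical
DOMAIN_WEIGHTS = {0: 1.2, 1: 1.0}


def mean_over_all_positions(losses: List[List[float]], weights: List[float]) -> float:
    """Weighted mean over every position, padding included (what a plain .mean() computes)."""
    flat = [w * l for row, w in zip(losses, weights) for l in row]
    return sum(flat) / len(flat)


def mean_over_target_tokens(losses: List[List[float]], labels: List[List[int]],
                            weights: List[float]) -> float:
    """Weighted sum over non-ignored positions divided by the number of those positions."""
    total, count = 0.0, 0
    for row_losses, row_labels, w in zip(losses, labels, weights):
        for l, lab in zip(row_losses, row_labels):
            if lab != IGNORE_INDEX:
                total += w * l
                count += 1
    return total / max(count, 1)


# Toy batch: two sequences of length 4; ignored positions carry a loss of 0.0,
# as CrossEntropyLoss(reduction="none", ignore_index=-100) returns them.
token_losses = [[2.0, 1.0, 0.0, 0.0], [3.0, 0.0, 0.0, 0.0]]
shifted_labels = [[11, 12, IGNORE_INDEX, IGNORE_INDEX],
                  [13, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX]]
domain_labels = [0, 1]
weights = [DOMAIN_WEIGHTS[d] for d in domain_labels]

print(round(mean_over_all_positions(token_losses, weights), 6))                  # 0.825
print(round(mean_over_target_tokens(token_losses, shifted_labels, weights), 6))  # 2.2
```

In PyTorch the second reduction is `(losses * mask).sum() / mask.sum()` with `mask = (shift_labels != -100)`, after multiplying `losses` by the per-example weights.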

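With `mlm=False`, `DataCollatorForLanguageModeling` rebuilds `labels` from `input_ids` and sets every position equal to `tokenizer.pad_token_id` to `-100`, so the `labels` built in `tokenize_mistral_data` are overwritten. Because the notebook sets `pad_token = eos_token`, the closing `</s>` of each answer is masked along with the padding, and the model receives no training signal to emit `</s>`. The pure-Python sketch below mimics that masking and contrasts it with masking by `attention_mask`; IDs 1 and 2 are the `<s>`/`</s>` IDs in Mistral's tokenizer, while the content token IDs are made up.

```python
from typing import List

IGNORE_INDEX = -100
BOS_ID, EOS_ID = 1, 2  # <s> and </s> IDs in Mistral's tokenizer
PAD_ID = EOS_ID        # the notebook sets tokenizer.pad_token = tokenizer.eos_token


def mask_by_pad_id(input_ids: List[int]) -> List[int]:
    """Mimics DataCollatorForLanguageModeling(mlm=False): every pad-id token becomes -100."""
    return [IGNORE_INDEX if tok == PAD_ID else tok for tok in input_ids]


def mask_by_attention(input_ids: List[int], attention_mask: List[int]) -> List[int]:
    """Alternative: mask only the positions that are real padding (attention_mask == 0)."""
    return [tok if keep == 1 else IGNORE_INDEX for tok, keep in zip(input_ids, attention_mask)]


# <s>, four made-up content token IDs, the closing </s>, then two right-padding tokens
ids = [BOS_ID, 1001, 1002, 1003, 1004, EOS_ID, PAD_ID, PAD_ID]
attn = [1, 1, 1, 1, 1, 1, 0, 0]

print(mask_by_pad_id(ids))           # [1, 1001, 1002, 1003, 1004, -100, -100, -100]
print(mask_by_attention(ids, attn))  # [1, 1001, 1002, 1003, 1004, 2, -100, -100]
```

One option is to build `labels` from `attention_mask` inside `tokenize_mistral_data` and use a collator that leaves them untouched, such as `transformers.default_data_collator`, which is workable here because every example is already padded to `max_length`.
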
## Step 5: Copy Files from Temporary Storage to Google Drive

This step provides a guide to transfer the fine-tuned Mistral-7B model, routing classifier, configuration files, and logs (generated in Step 4) from the temporary Colab storage to your Google Drive for long-term storage and future use. The files are located in the `mistral_mental_medical_chatbot/` directory, and this process ensures they are safely backed up. This step is part of the Week 2 workflow for the STEMRESEARCH project. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.



### Step-by-Step Guide

1. **Mount Google Drive**  
   To access your Google Drive from Colab, you need to mount it. Add and run the following code in a new cell in your Colab notebook:
   ```python
   from google.colab import drive
   drive.mount('/content/drive')
   ```
   Follow the prompt to authorize Colab to access your Google account (depending on the Colab version, this is either a pop-up window or an authorization code to paste into a text box).  
   This mounts your Google Drive at `/content/drive/My Drive/`. Verify the mount by listing the contents:
   ```python
   !ls "/content/drive/My Drive/"
   ```

2. **Verify Temporary Files**  
   Check that the files exist in the temporary Colab storage. Your code saved files to `mistral_mental_medical_chatbot/`. Run the following to confirm:
   ```python
   !ls mistral_mental_medical_chatbot
   !ls mistral_mental_medical_chatbot/mistral_model
   !ls mistral_mental_medical_chatbot/routing_classifier
   !ls mistral_mental_medical_chatbot/logs
   ```
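   **Expected Output**:
   - `mistral_model/`: The LoRA adapter written by `model.save_pretrained` (`adapter_config.json`, `adapter_model.safetensors`) plus tokenizer files such as `tokenizer.json` and `tokenizer_config.json`.
   - `routing_classifier/`: Files like `model.safetensors` (or `pytorch_model.bin` with older Transformers versions), `config.json`, `vocab.txt`, `tokenizer_config.json`, etc.
   - `mistral_config.json`: The configuration file.
   - `logs/`: Log files like `mistral_training_*.log`.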

3. **Create a Folder in Google Drive**  
   To keep things organized, create a folder in your Google Drive to store these files:
   ```python
   import os
   # Create a folder in Google Drive
   !mkdir -p "/content/drive/My Drive/mistral_mental_medical_chatbot"
   ```
   This creates a folder named `mistral_mental_medical_chatbot` in your Google Drive’s root directory. You can change the path if needed.

4. **Copy Files to Google Drive**  
   Use the `!cp` command to copy the entire directory from the temporary storage to Google Drive:
   ```python
   !cp -r mistral_mental_medical_chatbot "/content/drive/My Drive/"
   ```
   The `-r` flag ensures recursive copying, including all subdirectories and their contents.

5. **Verify Files in Google Drive**  
   After copying, confirm the files are in Google Drive:
   ```python
   !ls "/content/drive/My Drive/mistral_mental_medical_chatbot"
   !ls "/content/drive/My Drive/mistral_mental_medical_chatbot/mistral_model"
   !ls "/content/drive/My Drive/mistral_mental_medical_chatbot/routing_classifier"
   !ls "/content/drive/My Drive/mistral_mental_medical_chatbot/logs"
   ```
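
If you prefer to stay in Python instead of shelling out, `shutil.copytree` with `dirs_exist_ok=True` (Python 3.8+) performs the same recursive copy, and comparing relative paths and file sizes gives a quick check that nothing was skipped. This is a minimal sketch: `backup_directory`, `verify_backup`, and `file_sizes` are hypothetical helper names, and the paths in the usage comment assume the layout above.

```python
import os
import shutil
from typing import Dict


def file_sizes(root: str) -> Dict[str, int]:
    """Map each file's path (relative to root) to its size in bytes."""
    sizes = {}
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            full = os.path.join(dirpath, name)
            sizes[os.path.relpath(full, root)] = os.path.getsize(full)
    return sizes


def backup_directory(src: str, dst: str) -> str:
    """Recursively copy src to dst, merging into dst if it already exists."""
    return shutil.copytree(src, dst, dirs_exist_ok=True)


def verify_backup(src: str, dst: str) -> bool:
    """True if every file under src exists under dst with the same size."""
    src_sizes, dst_sizes = file_sizes(src), file_sizes(dst)
    return all(dst_sizes.get(path) == size for path, size in src_sizes.items())


# Usage in Colab, after mounting Drive:
# backup_directory("mistral_mental_medical_chatbot",
#                  "/content/drive/My Drive/mistral_mental_medical_chatbot")
# print(verify_backup("mistral_mental_medical_chatbot",
#                     "/content/drive/My Drive/mistral_mental_medical_chatbot"))
```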

### Notes
- **Storage**: Ensure sufficient space in Google Drive for the files (use `!df -h` to check Colab storage and Google Drive quota).


In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!ls "/content/drive/My Drive/"

In [None]:
!ls mistral_mental_medical_chatbot
!ls mistral_mental_medical_chatbot/mistral_model
!ls mistral_mental_medical_chatbot/routing_classifier
!ls mistral_mental_medical_chatbot/logs

In [None]:
import os
# Create a folder in Google Drive
!mkdir -p "/content/drive/My Drive/mistral_mental_medical_chatbot"

In [None]:
!cp -r mistral_mental_medical_chatbot "/content/drive/My Drive/"

In [None]:
!ls "/content/drive/My Drive/mistral_mental_medical_chatbot"
!ls "/content/drive/My Drive/mistral_mental_medical_chatbot/mistral_model"
!ls "/content/drive/My Drive/mistral_mental_medical_chatbot/routing_classifier"
!ls "/content/drive/My Drive/mistral_mental_medical_chatbot/logs"

## Step 6: Test MistralChatbot with Sample Queries for Week 2

This cell contains the script to test the fine-tuned **MistralChatbot** model, loaded from the Google Drive backup created in Step 5, with sample queries for the Week 2 tasks of the STEMRESEARCH project. The chatbot integrates a DistilBERT-based query routing classifier, enhanced safety guardrails (including link removal and content moderation), and generates domain-specific responses (mental health or medical). The script uses the preprocessed data and models saved from Step 4, now stored in Google Drive. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Execution Context

- **Data Source**: The script loads the fine-tuned model and routing classifier from `/content/drive/My Drive/mistral_mental_medical_chatbot/`, copied in Step 5.

### Future Use
- **Reusable module**: I saved the chatbot logic (the code up to the `MistralChatbot` class definition) as `mistral_chatbot.py` and uploaded it to Google Drive (`/content/drive/My Drive/`) so it can be reused across sessions.

### Key Features
- **Model Loading**: Loads the Mistral-7B model and DistilBERT classifier with BF16 precision from Google Drive.
- **Query Routing**: Uses the trained DistilBERT classifier to determine the domain (mental health or medical), falling back to keyword-based routing if the classifier cannot be loaded.
- **Safety Guardrails**: Detects crisis language, applies disclaimers, removes unwanted links/promotional content, and cleans repetitive or incomplete responses.
- **Response Generation**: Generates responses with a maximum of 512 new tokens, optimized for empathy (mental health) or accuracy (medical).
- **Logging**: Records operations and errors to the console with detailed logging.
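
The preprocessing code in Step 4 flags crisis language with plain substring checks (`indicator in question.lower()`), whereas the `SafetyGuardrails` patterns in this step use `\b` word boundaries. The difference matters for short indicators such as `'die'`, which as a substring also matches "diet" or "studied". The sketch below contrasts the two; `CRISIS_INDICATORS` copies the Step 4 list, and the function names are hypothetical.

```python
import re
from typing import List

# Indicator list copied from the Step 4 preprocessing code
CRISIS_INDICATORS = ['suicide', 'kill myself', 'end it all', 'hurt myself', 'die']


def substring_match(text: str, indicators: List[str]) -> bool:
    """Step 4 behaviour: plain substring test on the lowercased text."""
    lowered = text.lower()
    return any(ind in lowered for ind in indicators)


def word_boundary_match(text: str, indicators: List[str]) -> bool:
    """Match indicators only as whole words or phrases."""
    pattern = r"\b(?:" + "|".join(re.escape(ind) for ind in indicators) + r")\b"
    return re.search(pattern, text.lower()) is not None


for query in ["What diet helps with type 2 diabetes?",
              "Sometimes I just want to end it all."]:
    print(query, substring_match(query, CRISIS_INDICATORS),
          word_boundary_match(query, CRISIS_INDICATORS))
```

Word boundaries trade false positives for misses on inflected forms such as "dying", so the indicator list may need those forms added explicitly.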

### Prerequisites
- Required packages installed (Steps 1 and 2).
- GPU enabled (`Runtime > Change runtime type > GPU`) for CUDA support.
- Google Drive mounted and files copied (Step 5).
- A Hugging Face token for model access, supplied through the `HF_TOKEN` environment variable (or the interactive `login()` prompt) rather than hard-coded in the notebook.

### Code Overview
The script performs the following steps:
1. **Imports and Logging**: Sets up PyTorch, Transformers, and logging configurations.
2. **SafetyGuardrails**: Extends the safety checks with link removal and content moderation.
3. **QueryRoutingClassifier**: Provides domain prediction (keyword-based if classifier fails to load).
4. **MistralChatbot**: Loads models from Google Drive, generates responses, and applies safety filters.
5. **Main Function**: Tests the chatbot with sample queries and logs results.
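
The keyword fallback inside `QueryRoutingClassifier` is not shown in this section, so the following is only a sketch of what keyword-based routing with a `(domain, confidence)` result can look like. The keyword lists, the default to `medical`, and the confidence formula are all hypothetical choices, not the project's actual values.

```python
import re
from typing import Tuple

# Hypothetical keyword lists; the project's real fallback lists may differ
MENTAL_HEALTH_KEYWORDS = {
    "anxious", "anxiety", "depressed", "depression", "stress", "stressed",
    "lonely", "panic", "therapy", "therapist", "hopeless", "worthless", "suicide",
}
MENTAL_HEALTH_PHRASES = ("worth living", "hurt myself", "end it all")


def predict_domain_keywords(query: str) -> Tuple[str, float]:
    """Return (domain, confidence) from simple keyword and phrase matches."""
    text = query.lower()
    words = set(re.findall(r"[a-z']+", text))
    hits = len(words & MENTAL_HEALTH_KEYWORDS)
    hits += sum(phrase in text for phrase in MENTAL_HEALTH_PHRASES)
    if hits == 0:
        # No mental-health cues: default to the medical domain at low confidence
        return "medical", 0.5
    # Arbitrary heuristic: more cues give higher confidence, capped at 0.95
    return "mental_health", min(0.5 + 0.15 * hits, 0.95)


for q in [
    "I'm feeling really anxious about work and can't sleep. What can I do?",
    "What are the symptoms of type 2 diabetes?",
    "I feel like life isn't worth living anymore.",
    "Can you explain what causes migraines?",
]:
    print(predict_domain_keywords(q), q)
```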

### Sample Queries
- "I'm feeling really anxious about work and can't sleep. What can I do?"
- "What are the symptoms of high blood glucose?"
- "I feel like life isn't worth living anymore."
- "Can you explain what causes migraines?"

### Expected Output
For each query, the output will include:
- The query text.
- The generated response (cleaned and moderated).
- The predicted domain (mental_health or medical) with confidence score.
- Whether safety measures (crisis detection or disclaimers) were applied.
Example:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline
import logging
import os
import re
import numpy as np
from typing import Dict, List, Tuple, Union
from huggingface_hub import login
# Read the token from the HF_TOKEN environment variable instead of hard-coding it in the notebook
login(token=os.getenv("HF_TOKEN"))

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# SafetyGuardrails class (enhanced with link removal)
class SafetyGuardrails:
    def __init__(self):
        self.mental_health_red_flags = [
            r'\b(kill|suicide|die|death|hurt myself|end it all|not worth living)\b',
            r'\b(harm|violence|dangerous|weapon)\b',
            r'\b(overdose|pills|medication abuse)\b'
        ]
        self.medical_disclaimers = [
            "This information is for educational purposes only and should not replace professional medical advice.",
            "Please consult with a healthcare professional for personalized medical guidance.",
            "Always seek immediate medical attention for serious health concerns."
        ]
        self.crisis_resources = {
            "suicide_prevention": "If you're having thoughts of suicide, please contact the 988 Suicide & Crisis Lifeline: 988",
            "crisis_text": "Text HOME to 741741 for the Crisis Text Line",
            "emergency": "For immediate emergencies, call 911"
        }

        # Patterns to remove unwanted content
        self.unwanted_patterns = [
            r'Click here to visit.*?for more.*?information.*?\.',
            r'Click here to.*?Web site.*?\.',
            r'Visit.*?website.*?for more information.*?\.',
            r'Please contact us today to schedule a consultation.*?\.',
            r'Contact us.*?to discuss.*?\.',
            r'For more information.*?visit.*?website.*?\.',
            r'More details can be found at.*?\.',
            r'Additional resources.*?available at.*?\.'
        ]

    def detect_crisis_language(self, text: str) -> Tuple[bool, str]:
        text_lower = text.lower()
        for pattern in self.mental_health_red_flags:
            if re.search(pattern, text_lower):
                return True, "crisis_mental_health"
        return False, "safe"

    def apply_safety_filter(self, response: str, query_type: str) -> Tuple[str, bool]:
        original_response = response

        if query_type == "mental_health":
            # Check if crisis resource is already present
            if not any(resource in response for resource in self.crisis_resources.values()):
                response += f"\n\nRemember: {self.crisis_resources['suicide_prevention']}"
        elif query_type == "medical":
            # Check if any disclaimer is already present
            if not any(disclaimer in response for disclaimer in self.medical_disclaimers):
                response += f"\n\n{np.random.choice(self.medical_disclaimers)}"

        return response, response != original_response

    def moderate_response(self, response: str, user_input: str) -> Tuple[str, bool]:
        is_crisis, crisis_type = self.detect_crisis_language(user_input)
        if is_crisis:
            crisis_response = (
                "I'm concerned about what you've shared. Your safety and wellbeing are important. "
                f"{self.crisis_resources['suicide_prevention']} "
                f"{self.crisis_resources['emergency']} "
                "Please reach out to a mental health professional who can provide the support you need."
            )
            return crisis_response, True
        return response, False

    def remove_unwanted_content(self, response: str) -> str:
        """Remove unwanted links and promotional content"""
        cleaned_response = response

        # Remove unwanted patterns
        for pattern in self.unwanted_patterns:
            cleaned_response = re.sub(pattern, '', cleaned_response, flags=re.IGNORECASE | re.DOTALL)

        # Remove any remaining "Click here" references
        cleaned_response = re.sub(r'Click here.*?\.\s*', '', cleaned_response, flags=re.IGNORECASE)

        # Remove multiple consecutive spaces and clean up
        cleaned_response = re.sub(r'\s+', ' ', cleaned_response)
        cleaned_response = cleaned_response.strip()

        return cleaned_response

    def clean_response(self, response: str) -> str:
        """Remove repetitive sentences, unwanted content, and clean up the response"""

        # First remove unwanted content
        response = self.remove_unwanted_content(response)

        sentences = response.split('. ')
        unique_sentences = []
        seen = set()

        for sentence in sentences:
            # Clean the sentence
            clean_sentence = sentence.strip()
            if clean_sentence and clean_sentence not in seen:
                unique_sentences.append(clean_sentence)
                seen.add(clean_sentence)

        # Rejoin sentences
        cleaned_response = '. '.join(unique_sentences)

        # Ensure proper ending punctuation
        if cleaned_response and not cleaned_response.endswith(('.', '!', '?')):
            # Find the last complete sentence
            last_period = cleaned_response.rfind('.')
            last_exclamation = cleaned_response.rfind('!')
            last_question = cleaned_response.rfind('?')

            last_punct = max(last_period, last_exclamation, last_question)

            if last_punct > 0:
                cleaned_response = cleaned_response[:last_punct + 1]
            else:
                cleaned_response += '.'

        # Format numbered lists and bullet points with proper line breaks
        cleaned_response = self.format_lists(cleaned_response)

        return cleaned_response

    def format_lists(self, text: str) -> str:
        """Format numbered lists and bullet points with proper line breaks"""

        # Format numbered lists (1., 2., 3., etc.)
        text = re.sub(r'(\d+\.\s)', r'\n\n\1', text)

        # Format bullet points with dashes
        text = re.sub(r'(\s-\s)', r'\n\n- ', text)

        # Format colons followed by lists
        text = re.sub(r':\s*([A-Z])', r':\n\n\1', text)

        # Clean up multiple consecutive newlines
        text = re.sub(r'\n{3,}', '\n\n', text)

        # Remove leading/trailing whitespace
        text = text.strip()

        return text

# QueryRoutingClassifier class (unchanged)
class QueryRoutingClassifier:
    def __init__(self, model_name: str = "distilbert-base-uncased"):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.is_trained = False

    def predict_domain(self, query: str) -> Tuple[str, float]:
        if not self.is_trained or self.model is None:
            return self._keyword_based_routing(query)
        inputs = self.tokenizer(query, return_tensors='pt', truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = F.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()
        domain = "mental_health" if predicted_class == 0 else "medical"
        return domain, confidence

    def _keyword_based_routing(self, query: str) -> Tuple[str, float]:
        mental_health_keywords = [
            'anxiety', 'depression', 'stress', 'mental', 'therapy', 'counseling',
            'mood', 'emotional', 'panic', 'ptsd', 'trauma', 'suicide', 'self-harm'
        ]
        medical_keywords = [
            'symptoms', 'disease', 'treatment', 'medication', 'diagnosis', 'doctor',
            'hospital', 'surgery', 'pain', 'infection', 'vaccine', 'prescription'
        ]
        query_lower = query.lower()
        mental_score = sum(1 for keyword in mental_health_keywords if keyword in query_lower)
        medical_score = sum(1 for keyword in medical_keywords if keyword in query_lower)
        if mental_score > medical_score:
            return "mental_health", 0.7
        elif medical_score > mental_score:
            return "medical", 0.7
        else:
            return "medical", 0.5

# Updated MistralChatbot class
class MistralChatbot:
    def __init__(self, model_path: str, routing_classifier_path: str):
        self.model_path = model_path
        self.routing_classifier_path = routing_classifier_path
        self.safety_guardrails = SafetyGuardrails()
        self.routing_classifier = QueryRoutingClassifier()
        self.model = None
        self.tokenizer = None
        self.pipeline = None

    def load_models(self):
        logger.info("Loading Mistral chatbot models in BF16 precision...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        try:
            self.routing_classifier.tokenizer = AutoTokenizer.from_pretrained(self.routing_classifier_path)
            self.routing_classifier.model = AutoModelForSequenceClassification.from_pretrained(self.routing_classifier_path)
            self.routing_classifier.is_trained = True
        except Exception as e:
            logger.warning(f"Could not load routing classifier: {e}. Using keyword-based routing")
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        logger.info("All models loaded successfully")

    def generate_response(self, user_query: str) -> Dict[str, Union[str, float, bool]]:
        domain, confidence = self.routing_classifier.predict_domain(user_query)

        if domain == "mental_health":
            instruction = (
                "You are an empathetic mental health support assistant. Provide compassionate, "
                "understanding, and helpful responses while always encouraging professional help "
                "when appropriate. Be supportive but never provide medical diagnoses. "
                "Do not include website links or promotional content."
            )
        else:
            instruction = (
                "You are a knowledgeable medical information assistant. Provide accurate, "
                "evidence-based medical information while always emphasizing the importance "
                "of consulting healthcare professionals. Never provide specific medical diagnoses. "
                "Do not include website links or promotional content."
            )

        prompt = f"<s>[INST] {instruction}\n\nQuery: {user_query} [/INST]"

        outputs = self.pipeline(
            prompt,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            return_full_text=False  # Only return the new generated text
        )

        generated_text = outputs[0]['generated_text']
        response = generated_text.split('[/INST]')[-1].strip()

        # Clean the response to remove repetitions, unwanted content, and fix incomplete sentences
        response = self.safety_guardrails.clean_response(response)

        # Apply safety moderation
        moderated_response, crisis_applied = self.safety_guardrails.moderate_response(response, user_query)

        # Apply safety filter (only if not already applied)
        final_response, disclaimer_applied = self.safety_guardrails.apply_safety_filter(moderated_response, domain)

        return {
            'response': final_response,
            'domain': domain,
            'confidence': confidence,
            'safety_applied': crisis_applied or disclaimer_applied
        }


# Main function to test the chatbot
def main():
    output_dir = "/content/drive/My Drive/mistral_mental_medical_chatbot"
    model_path = os.path.join(output_dir, "mistral_model")
    routing_classifier_path = os.path.join(output_dir, "routing_classifier")

    chatbot = MistralChatbot(model_path=model_path, routing_classifier_path=routing_classifier_path)

    try:
        chatbot.load_models()
    except Exception as e:
        logger.error(f"Failed to load models: {e}")
        print(f"Error: Failed to load models. Ensure the model and classifier paths exist: {model_path}, {routing_classifier_path}")
        return

    sample_queries = [
        "I'm feeling really anxious about work and can't sleep. What can I do?",
        "What are the symptoms of high blood glucose?",
        "I feel like life isn't worth living anymore.",
        "Can you explain what causes migraines?"
    ]

    for query in sample_queries:
        try:
            response_dict = chatbot.generate_response(query)
            print(f"\nQuery: {query}")
            print(f"Response: {response_dict['response']}")
            print(f"Domain: {response_dict['domain']} (Confidence: {response_dict['confidence']:.2f})")
            print(f"Safety Applied: {response_dict['safety_applied']}")
            print("-" * 50)
        except Exception as e:
            logger.error(f"Error processing query '{query}': {e}")
            print(f"Error processing query '{query}': {e}")

if __name__ == "__main__":
    main()

Device set to use cuda:0



Query: I'm feeling really anxious about work and can't sleep. What can I do?
Response: It sounds like you're experiencing a lot of anxiety related to your work and it's impacting both your productivity at work and your ability to get quality rest. Here are some suggestions that might be helpful for managing these challenges: 

1. Prioritize self-care:

Taking care of yourself physically, mentally, and emotionally is crucial in managing stress and reducing anxiety. Make sure you are getting enough sleep by establishing a consistent bedtime routine and creating a comfortable sleeping environment. Engaging in activities such as exercise, meditation, deep breathing exercises, or hobbies that bring you joy can also help alleviate anxiety symptoms. 

2. Break tasks into smaller steps:

When faced with overwhelming workloads, breaking down larger projects into smaller, more manageable tasks can make them feel less daunting. Focus on completing one task at a time, which may give you a sense o
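The blank lines inside the response above come from `SafetyGuardrails.format_lists`. `remove_unwanted_content` first collapses all whitespace into single spaces, then four regex substitutions reinsert the breaks. The sketch below copies those substitutions into a standalone function so their effect on a flattened answer can be seen in isolation. The input strings are made up for illustration.

```python
import re


def format_lists(text: str) -> str:
    """Same four substitutions as SafetyGuardrails.format_lists."""
    text = re.sub(r'(\d+\.\s)', r'\n\n\1', text)     # blank line before "1. ", "2. ", ...
    text = re.sub(r'(\s-\s)', r'\n\n- ', text)       # " - " becomes a bullet on its own line
    text = re.sub(r':\s*([A-Z])', r':\n\n\1', text)  # break after a colon followed by a capital
    text = re.sub(r'\n{3,}', '\n\n', text)           # collapse runs of 3+ newlines
    return text.strip()


flat = ("Here are some suggestions: 1. Prioritize self-care: Taking care of yourself "
        "matters. 2. Break tasks into smaller steps.")
print(format_lists(flat))

# Side effect: any number followed by ". " is treated as a list marker.
print(repr(format_lists("It started in 2020. Then it stopped.")))
```

The same rules explain the trailing space before each blank line in the output above. They also show a limitation: a year or other number at the end of a sentence is split onto its own line.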

## Step 6.5: Mount Google Drive for Access
This cell mounts your Google Drive to the Colab environment, enabling access to the `mistral_chatbot.py` file and the fine-tuned model files (`mistral_model` and `routing_classifier`) stored in `/content/drive/My Drive/mistral_mental_medical_chatbot/`. This step must be executed before running the interactive conversational interface (Step 7) to ensure all necessary files are available. It is part of the Week 2 workflow for the STEMRESEARCH project. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Execution Context
- **Purpose**: Ensures Google Drive is mounted (if it is not already), giving access to files copied or created in earlier steps (e.g., Step 5 for file copying and Step 6 for the `mistral_chatbot.py` upload).

### Key Features
- **Drive Mounting**: Uses `google.colab.drive` to mount Google Drive at `/content/drive/`.
- **Authorization**: Prompts for Google account sign-in and authorization code if not previously mounted in the session.

### Code Overview
The script performs the following:
1. **Import and Mount**: Imports the `drive` module and mounts Google Drive to `/content/drive/`.
2. **Verification**: Allows manual verification of the mount (e.g., using `!ls "/content/drive/My Drive/"` in a subsequent cell).
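As an alternative to a manual `!ls`, the mount can be checked from Python. The sketch below assumes the three entries used by Steps 6 and 7 (`mistral_model`, `routing_classifier`, `mistral_chatbot.py`) all live under `/content/drive/My Drive/mistral_mental_medical_chatbot/`. The helper `check_files` is hypothetical.

```python
from pathlib import Path

# Base directory used by Steps 5-7 (assumed layout).
BASE_DIR = Path("/content/drive/My Drive/mistral_mental_medical_chatbot")
# Entries the later steps load or import.
EXPECTED = ["mistral_model", "routing_classifier", "mistral_chatbot.py"]


def check_files(base: Path = BASE_DIR) -> dict:
    """Return {entry name: exists?} for each expected entry under `base`."""
    return {name: (base / name).exists() for name in EXPECTED}


for name, ok in check_files().items():
    print(f"{'OK' if ok else 'MISSING':8} {name}")
```

If every entry reports `MISSING`, the drive is probably not mounted or Step 5 was skipped.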

### Command
```python
from google.colab import drive
drive.mount('/content/drive')
```


In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Step 7: Interactive Conversational Interface with MistralChatbot for Week 2

This cell provides an interactive conversational interface to test and use the fine-tuned **MistralChatbot** loaded from the Google Drive backup created in Step 5. The chatbot, defined in the `mistral_chatbot.py` file uploaded to Google Drive, integrates a DistilBERT-based query routing classifier and enhanced safety guardrails for mental health and medical responses. This step concludes the Week 2 tasks of the STEMRESEARCH project, allowing real-time interaction with the model. Refer to the GitHub repository [https://github.com/I-VAGAT/STEMRESEARCH](https://github.com/I-VAGAT/STEMRESEARCH) for project details.

### Execution Context
- **Preparation**: The `mistral_chatbot.py` file was created and uploaded to `/content/drive/My Drive/mistral_mental_medical_chatbot/` (Step 6) so it can be reused across sessions. The cell below adds this folder to `sys.path` before importing it.
- **Data Source**: The script loads the fine-tuned model and routing classifier from `/content/drive/My Drive/mistral_mental_medical_chatbot/`, copied in Step 5.

### Key Features
- **Model Loading**: Loads the Mistral-7B model in BF16 precision and the DistilBERT routing classifier from Google Drive.
- **Query Routing**: Uses the fine-tuned DistilBERT classifier to determine the domain (mental_health or medical), falling back to keyword matching if the classifier cannot be loaded.
- **Safety Guardrails**: Detects crisis language, applies disclaimers, removes unwanted links, and cleans responses.
- **Interactive Interface**: Allows users to input queries and receive real-time responses, with logging to Google Drive.
- **Logging**: Records conversations and errors to `/content/drive/My Drive/mistral_mental_medical_chatbot/logs/mistral_conversation_YYYYMMDD_HHMMSS.log`.
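When the classifier cannot be loaded, routing falls back to keyword counting. The sketch below copies the keyword lists from `QueryRoutingClassifier._keyword_based_routing` in Step 6, on the assumption that `mistral_chatbot.py` uses the same lists. The function name `keyword_route` is hypothetical. The confidences 0.7 and 0.5 are fixed values from that code, not probabilities.

```python
from typing import Tuple

# Keyword lists as they appear in QueryRoutingClassifier._keyword_based_routing (Step 6).
MENTAL_HEALTH_KEYWORDS = ['anxiety', 'depression', 'stress', 'mental', 'therapy', 'counseling',
                          'mood', 'emotional', 'panic', 'ptsd', 'trauma', 'suicide', 'self-harm']
MEDICAL_KEYWORDS = ['symptoms', 'disease', 'treatment', 'medication', 'diagnosis', 'doctor',
                    'hospital', 'surgery', 'pain', 'infection', 'vaccine', 'prescription']


def keyword_route(query: str) -> Tuple[str, float]:
    """Count substring hits per domain; ties (including zero hits) go to medical."""
    q = query.lower()
    mental = sum(keyword in q for keyword in MENTAL_HEALTH_KEYWORDS)
    medical = sum(keyword in q for keyword in MEDICAL_KEYWORDS)
    if mental > medical:
        return "mental_health", 0.7
    if medical > mental:
        return "medical", 0.7
    return "medical", 0.5


for q in ["Symptoms of cholera",
          "I'm feeling really anxious about work and can't sleep.",
          "I've been under a lot of stress lately"]:
    print(keyword_route(q), q)
```

Matching is by substring, so "anxious" does not hit "anxiety", and the second query falls through to the `("medical", 0.5)` default.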

### Prerequisites
- Required packages installed (Steps 1 and 2).
- GPU enabled (`Runtime > Change runtime type > GPU`) for CUDA support.
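- Google Drive mounted and files copied (Step 5).
- `mistral_chatbot.py` uploaded to `/content/drive/My Drive/mistral_mental_medical_chatbot/` (Step 6).
- A Hugging Face access token available for model access (for example, stored as an `HF_TOKEN` Colab secret or environment variable rather than hardcoded in the notebook).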
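Rather than pasting the token into a cell, it can be read at runtime. The sketch below assumes the token is stored under the name `HF_TOKEN`, either as an environment variable or as a Colab notebook secret read through `google.colab.userdata`. The helper `get_hf_token` is hypothetical.

```python
import os
from typing import Optional


def get_hf_token() -> Optional[str]:
    """Return a Hugging Face token without hardcoding it in the notebook.

    Checks the HF_TOKEN environment variable first, then (inside Colab only)
    a notebook secret named HF_TOKEN. Returns None if neither is available.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    try:
        from google.colab import userdata  # only importable inside Colab
        return userdata.get("HF_TOKEN")
    except Exception:  # not in Colab, secret missing, or notebook access not granted
        return None


token = get_hf_token()
print("HF token found" if token else "No HF token found; login() will prompt for one")
```

The result can then be passed to `huggingface_hub.login(token=get_hf_token())`. With `None`, `login()` falls back to asking for the token interactively.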

### Code Overview
The script performs the following steps:
1. **Imports and Path Setup**: Adds `/content/drive/My Drive/mistral_mental_medical_chatbot` to `sys.path` and imports `MistralChatbot` from `mistral_chatbot.py`.
2. **Logging Setup**: Configures UTF-8 logging to a timestamped file in Google Drive.
3. **Chatbot Initialization**: Initializes the chatbot with model and classifier paths from Google Drive.
4. **Model Loading**: Loads the fine-tuned Mistral model and routing classifier; any loading error is logged and re-raised.
5. **Conversational Loop**: Runs an interactive loop where users input queries, receive responses, and can exit with 'quit'.

### Usage Instructions
- Run the cell below.
- Enter a query (e.g., "Symptoms of cholera") when prompted with `💬 You:`.
- Review the chatbot's response, domain prediction, and safety status.
- Type 'quit' to end the conversation.
- Example interaction:
🤖 Chatbot: According to established medical literature:

  What are the signs and symptoms of cholera? In acute cholera, the most common symptom is severe diarrhea that can be watery or contain blood and pieces of stool (rice water). This diarrhea often occurs suddenly and lasts for 24 hours. The diarrhea may then become bloody and the person becomes dehydrated. Other symptoms may include vomiting, leg cramps, and abdominal pain. Dehydration in infants appears as excessive crying with few tears; sunken eyes; dry skin; and poor elasticity when the skin is pinched. If untreated, dehydration may cause death due to shock or organ failure. It is estimated that without treatment, half of those affected by cholera will die from dehydration. Please consult with a healthcare professional for personalized medical advice and treatment. A licensed health care provider can evaluate your condition and recommend an appropriate course of action.

  Always seek immediate medical attention for serious health concerns.
  * 🧠 Domain: medical (Confidence: 1.00)
  * 🛡️ Safety Applied: True
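Because `input()` blocks, the loop is awkward to exercise in automated runs. The sketch below drives the same loop logic from a scripted list of queries. `StubChatbot` and `run_scripted` are hypothetical. The stub only mimics the dictionary keys returned by `MistralChatbot.generate_response`. Unlike the cell below, this sketch strips surrounding whitespace before comparing with `quit`.

```python
from typing import Dict, Iterable, List, Union


class StubChatbot:
    """Hypothetical stand-in returning the same keys as MistralChatbot.generate_response."""

    def generate_response(self, query: str) -> Dict[str, Union[str, float, bool]]:
        return {"response": f"(echo) {query}", "domain": "medical",
                "confidence": 0.5, "safety_applied": False}


def run_scripted(chatbot, queries: Iterable[str]) -> List[str]:
    """Feed queries through the Step 7 loop logic and collect the printed lines."""
    transcript = []
    for query in queries:
        if query.strip().lower() == "quit":
            transcript.append("🤖 Goodbye! Stay safe and take care.")
            break
        if not query.strip():
            transcript.append("🤖 Please enter a valid question or concern.")
            continue
        result = chatbot.generate_response(query)
        transcript.append(f"🤖 Chatbot: {result['response']}")
        transcript.append(f"🧠 Domain: {result['domain']} (Confidence: {result['confidence']:.2f})")
    return transcript


for line in run_scripted(StubChatbot(), ["Symptoms of cholera", "  ", "quit", "never reached"]):
    print(line)
```

Swapping `StubChatbot()` for the loaded `chatbot` would run real queries without typing them.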


### Expected Output
- A welcome message and prompt for input.
- For each query, a response with domain, confidence, and safety status.
- Logs saved to a timestamped file in `/content/drive/My Drive/mistral_mental_medical_chatbot/logs/`.
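Each session writes a new `mistral_conversation_YYYYMMDD_HHMMSS.log` file, so finding the most recent conversation is a matter of sorting file names. The sketch below assumes the log directory used by `setup_logging`. The helper `latest_log` is hypothetical.

```python
from pathlib import Path
from typing import Optional

LOG_DIR = Path("/content/drive/My Drive/mistral_mental_medical_chatbot/logs")


def latest_log(log_dir: Path = LOG_DIR) -> Optional[Path]:
    """Return the newest mistral_conversation_*.log file in log_dir, or None."""
    if not log_dir.is_dir():
        return None
    # The YYYYMMDD_HHMMSS stamp sorts chronologically as plain text,
    # so the last name in sorted order belongs to the most recent session.
    logs = sorted(log_dir.glob("mistral_conversation_*.log"))
    return logs[-1] if logs else None


path = latest_log()
if path is None:
    print(f"No conversation logs found in {LOG_DIR}")
else:
    print(f"Latest log: {path.name}")
    print(path.read_text(encoding="utf-8"))
```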

### Notes
- **Runtime**: No training happens in this step; the fine-tuned models are only used for inference, so each response typically takes a few seconds.
- **Model Loading Errors**: If models fail to load, verify the paths (`mistral_model`, `routing_classifier`) and ensure Google Drive is mounted.
- **Safety Features**: Queries that match the crisis patterns in `SafetyGuardrails` (e.g., mentions of suicide or self-harm) trigger a crisis-resource response; medical queries include a disclaimer.
- **Persistence**: Logs are written to Google Drive, so conversation history is preserved after the Colab session ends.
- **Future Use**: Reuse this script by mounting Google Drive and running it in new sessions:
```python
from google.colab import drive
drive.mount('/content/drive')
```


In [None]:
import sys
sys.path.append('/content/drive/My Drive/mistral_mental_medical_chatbot')

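# Import the chatbot class from mistral_chatbot.py
from mistral_chatbot import MistralChatbot
import logging
import os
from datetime import datetime
from huggingface_hub import login

# Read the Hugging Face token from the HF_TOKEN environment variable instead of
# hardcoding it; if it is not set, login() prompts for a token interactively.
login(token=os.environ.get("HF_TOKEN"))
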
# Setup logging
def setup_logging():
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = "/content/drive/My Drive/mistral_mental_medical_chatbot/logs"
    os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(f'{log_dir}/mistral_conversation_{timestamp}.log', encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    logger.handlers = []
    logger.addHandler(file_handler)
    logger.propagate = False
    return logger

logger = setup_logging()

# Initialize the chatbot
output_dir = "/content/drive/My Drive/mistral_mental_medical_chatbot"
try:
    chatbot = MistralChatbot(
        model_path=f"{output_dir}/mistral_model",
        routing_classifier_path=f"{output_dir}/routing_classifier"
    )
    logger.info("Chatbot initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize chatbot: {e}")
    print(f"Error: Failed to initialize chatbot ({e}). Check that mistral_chatbot.py and the model paths are correct.")
    raise

# Load the models
try:
    chatbot.load_models()
    logger.info("Models loaded successfully")
    print("🤖 Chatbot loaded successfully! Ready to assist.")
except Exception as e:
    logger.error(f"Failed to load models: {e}")
    print(f"Error: Failed to load models ({e}). Check the model paths and try again.")
    raise

# Conversational interface
print("\n=== Welcome to the Mental Health & Medical Chatbot ===")
print("I'm here to provide empathetic and accurate responses.")
print("Type your question or concern (e.g., 'I'm feeling anxious' or 'What are symptoms of diabetes?').")
print("Type 'quit' to exit the conversation.\n")

while True:
    query = input("💬 You: ").strip()
    if query.lower() == 'quit':
        print("🤖 Goodbye! Stay safe and take care.")
        logger.info("Conversation ended by user")
        break
    if not query.strip():
        print("🤖 Please enter a valid question or concern.")
        continue

    try:
        response = chatbot.generate_response(query)
        print(f"\n🤖 Chatbot: {response['response']}")
        print(f"🧠 Domain: {response['domain']} (Confidence: {response['confidence']:.2f})")
        print(f"🛡️ Safety Applied: {response['safety_applied']}\n")
        logger.info(f"Query: {query}")
        logger.info(f"Response: {response['response']}")
        logger.info(f"Domain: {response['domain']} (Confidence: {response['confidence']:.2f})")
        logger.info(f"Safety Applied: {response['safety_applied']}")
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        print(f"🤖 Sorry, an error occurred: {e}")
        print("Please try another question or check the logs for details.")

Device set to use cuda:0


🤖 Chatbot loaded successfully! Ready to assist.

=== Welcome to the Mental Health & Medical Chatbot ===
I'm here to provide empathetic and accurate responses.
Type your question or concern (e.g., 'I'm feeling anxious' or 'What are symptoms of diabetes?').
Type 'quit' to exit the conversation.

💬 You: i am hungry i am feeling tired i can't cook food right now what can i do

🤖 Chatbot: Your experience matters, and I'm here to support you through this. It sounds like you're experiencing a difficult time right now, with the added stress of preparing meals on top of everything else. Here are some suggestions that might help make things easier for you: 

1. Reach out to friends or family members who could potentially bring you a meal or offer their assistance in any way they can. They may be happy to help and it could alleviate some of the pressure from you. 

2. Consider ordering takeout or delivery from nearby restaurants if you have access to those services. This option allows you to enjo