In [12]:
# -*- coding: utf-8 -*-
"""
IMPROVED Multi-Label Arabic Mental Health Classification Model
- MODIFIED: Can now run in "evaluation-only" mode to skip training and
  predict directly on a test set using an existing checkpoint.
"""
import os
import pandas as pd
from tqdm.auto import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import glob

# Import Hugging Face Transformers components
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, classification_report

# --- Configuration ---
# Base checkpoint. Its classification head has a different number of outputs than our
# label set, so the head is re-initialized with `num_labels` outputs when the model is built.
MODEL_NAME = "CAMeL-Lab/bert-base-arabic-camelbert-mix-sentiment"
# Device the training model is moved to in main() (GPU when available).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# =================================================================================
# --- START OF MODE CONFIGURATION ---
# Set to True to skip training and only run evaluation on the test set.
# Set to False to run the full training process first.
# In both modes the label set is still learned from DATA_PATH / LABELS_PATH.
EVALUATE_ONLY = True
# =================================================================================
# --- END OF MODE CONFIGURATION ---
# =================================================================================

# --- File Paths ---
# Training data paths (still needed in EVALUATE_ONLY mode to learn the label set and
# label co-occurrence statistics). The two files are paired line by line, so they must
# have the same number of lines.
# NOTE(review): the data file is the *dev* split while the labels file is named
# *train*_label — confirm these two files really belong together.
DATA_PATH = '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/dev_data.tsv'
LABELS_PATH = '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/train_label.tsv'
# Directory where the Trainer writes (and main() later searches for) `checkpoint-*` folders
TRAINING_OUTPUT_DIR = '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/improved_camelbert_checkpoints'

# Test data paths: unlabeled test input and the file the predictions are written to
TEST_DATA_PATH = '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/subtask1_input_test.tsv'
TEST_PREDICTION_OUTPUT_PATH = '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/predictions_on_test_set.tsv'


# --- Custom Model with Focal Loss ---
class ImprovedMultiLabelModel(nn.Module):
    """Multi-label classifier: a pretrained sequence-classification model trained with focal loss.

    The wrapped Hugging Face model is stored as ``self.bert``. That attribute name is part
    of the state-dict keys written to training checkpoints (``bert.bert.*``,
    ``bert.classifier.*``), so it must not be renamed or saved checkpoints will not load.

    Args:
        model_name: Hugging Face model id or local path passed to ``from_pretrained``.
        num_labels: Number of labels, i.e. output size of the classifier head.
        alpha: Scalar weight multiplied into every element of the focal loss.
        gamma: Focusing exponent; larger values down-weight confidently-correct elements.
    """

    def __init__(self, model_name, num_labels, alpha=1.0, gamma=2.0):
        super().__init__()
        # ignore_mismatched_sizes lets the classifier head be rebuilt with `num_labels`
        # outputs when the base checkpoint ships a head of another size. forward() below
        # computes its own loss, so the wrapped model's built-in loss is not used.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, problem_type="multi_label_classification", ignore_mismatched_sizes=True
        )
        self.alpha, self.gamma, self.num_labels = alpha, gamma, num_labels

    def focal_loss(self, logits, labels):
        """Return the mean element-wise focal loss of ``logits`` against multi-hot ``labels``.

        ``labels`` is cast to the dtype of ``logits`` first: ``BCEWithLogitsLoss`` rejects
        integer targets, so integer multi-hot tensors would otherwise crash here.

        Args:
            logits: Raw scores, same shape as ``labels``.
            labels: 0/1 targets (float or integer tensor).

        Returns:
            Scalar tensor: mean over all elements of ``alpha * (1 - pt)**gamma * BCE``.
        """
        labels = labels.to(logits.dtype)
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(logits, labels)
        # exp(-BCE) is the probability the model assigns to each element's true value.
        pt = torch.exp(-bce_loss)
        return (self.alpha * (1 - pt) ** self.gamma * bce_loss).mean()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute multi-label logits, plus the focal loss when ``labels`` is given.

        Runs the wrapped model's encoder (``self.bert.bert``) and feeds the last-layer
        hidden state of the first token directly into ``self.bert.classifier``; the
        wrapped model's own forward() is never called. Checkpoints were trained through
        this exact path, so changing it requires retraining. Extra keyword arguments
        (e.g. ``token_type_ids`` produced by the tokenizer) are accepted and ignored.
        """
        outputs = self.bert.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        pooled_output = sequence_output[:, 0]  # first-token ([CLS] for BERT tokenizers) vector
        logits = self.bert.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss = self.focal_loss(logits, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

# --- Helper Functions ---
def robust_read_lines(file_path):
    """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    The file is decoded with ``utf-8-sig`` so that a leading byte-order mark (written by
    some editors) is dropped. With plain ``utf-8`` the BOM survives ``str.strip()`` and
    ends up glued to the first line. In a labels file that would turn the first label
    into a distinct, bogus label. Files without a BOM decode exactly as with ``utf-8``.

    Args:
        file_path: Path of the file to read.

    Returns:
        list[str]: One entry per line in file order; blank lines become ``""``.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return [line.strip() for line in f]

def load_and_prepare_data(data_path, labels_path):
    """Load a question file and its line-aligned label file into a DataFrame.

    Line i of ``data_path`` is paired with line i of ``labels_path``. Both files are
    read with ``robust_read_lines``.

    Args:
        data_path: Path of the file with one question per line.
        labels_path: Path of the file with one label string per line (comma-separated
            labels, as later parsed by ``process_label_strings``).

    Returns:
        pandas.DataFrame with columns ``text`` (question lines) and ``labels_str``
        (raw label lines).

    Raises:
        ValueError: if the two files have different numbers of lines.
    """
    questions = robust_read_lines(data_path)
    labels = robust_read_lines(labels_path)
    if len(questions) != len(labels):
        # Report both counts and paths so a mispaired or truncated file is easy to spot.
        raise ValueError(
            f"Mismatch in line count between data and labels: "
            f"{len(questions)} lines in '{data_path}' vs {len(labels)} lines in '{labels_path}'."
        )
    return pd.DataFrame({'text': questions, 'labels_str': labels})

def process_label_strings(label_series):
    """Split every comma-separated label string into a list of clean label names.

    Args:
        label_series: Iterable of strings such as ``"A, B"`` (e.g. a pandas Series).

    Returns:
        list[list[str]]: One list per input string, holding the whitespace-trimmed
        labels in their original order; empty pieces are discarded.
    """
    def _split_one(raw):
        trimmed = (piece.strip() for piece in raw.split(','))
        return [piece for piece in trimmed if piece]

    return [_split_one(raw) for raw in label_series]

def analyze_label_cooccurrence(labels_matrix, label_names, min_prob=0.3):
    """Find strongly co-occurring label pairs in a multi-hot label matrix.

    For every ordered pair (label1, label2) with label1 != label2, this estimates
    P(label2 | label1) = count(samples with both) / count(samples with label1). It keeps
    the pairs whose estimate is strictly greater than ``min_prob``. Labels that never
    occur are skipped because their conditional probability is undefined.

    Args:
        labels_matrix: 2-D array-like, one row per sample and one column per label,
            with 1 where the label is present (e.g. ``MultiLabelBinarizer`` output).
        label_names: Label names aligned with the matrix columns.
        min_prob: Exclusive lower bound on P(label2 | label1). Defaults to 0.3, the
            value that was previously hard-coded.

    Returns:
        dict mapping ``(label1, label2)`` to P(label2 | label1), inserted in
        row-major label order.
    """
    labels_matrix = np.asarray(labels_matrix)
    # cooccurrence[i, j] = number of samples that carry both label i and label j.
    cooccurrence = labels_matrix.T @ labels_matrix
    label_frequencies = labels_matrix.sum(axis=0)
    cooccurrence_prob = {}
    for i, label1 in enumerate(label_names):
        if label_frequencies[i] <= 0:
            continue
        for j, label2 in enumerate(label_names):
            if i == j:
                continue
            prob = cooccurrence[i, j] / label_frequencies[i]
            if prob > min_prob:
                cooccurrence_prob[(label1, label2)] = prob
    return cooccurrence_prob

class ImprovedMentalQADataset(Dataset):
    """Map-style dataset that pairs tokenizer encodings with per-sample label vectors.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``) to a per-sample
            sequence, such as the result of a Hugging Face tokenizer call.
        labels: Per-sample label vectors; their count defines the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {}
        for feature_name, per_sample in self.encodings.items():
            item[feature_name] = torch.tensor(per_sample[idx])
        # Labels are emitted as float tensors, the dtype BCE-with-logits losses expect.
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

def adaptive_threshold_prediction(logits, label_names, cooccurrence_prob, base_threshold=0.3):
    """Convert per-label logits into sorted lists of predicted label names.

    NOTE: this file defines ``adaptive_threshold_prediction`` a second time further
    down; that later definition replaces this one when the module is executed.

    Steps for each sample:
      1. Predict every label whose sigmoid probability is >= ``base_threshold``.
      2. For each label from step 1, also predict co-occurring labels whose probability
         clears the relaxed threshold ``base_threshold * (1 - 0.5 * P(other | label))``.
      3. If nothing was predicted, fall back to the single most probable label.
      4. Keep at most the 4 most probable labels.

    Args:
        logits: Array of shape (num_samples, num_labels) with raw scores.
        label_names: Label names aligned with the logit columns (list or NumPy array).
        cooccurrence_prob: Mapping ``(label, other_label) -> P(other_label | label)``.
        base_threshold: Probability threshold for step 1.

    Returns:
        list[list[str]]: One alphabetically sorted list of label names per sample.
    """
    # label_names may be a NumPy array (e.g. MultiLabelBinarizer.classes_), which has no
    # `.index` method; the truncation step below needs one, so work on a list.
    label_names = list(label_names)
    probs = 1 / (1 + np.exp(-logits))
    predictions = []
    for i in range(len(probs)):
        sample_probs = probs[i]
        predicted_labels = {label_names[idx] for idx in np.where(sample_probs >= base_threshold)[0]}
        for label in list(predicted_labels):
            for idx, other_label in enumerate(label_names):
                if other_label not in predicted_labels and (label, other_label) in cooccurrence_prob:
                    cooccur_prob = cooccurrence_prob[(label, other_label)]
                    adjusted_threshold = base_threshold * (1 - cooccur_prob * 0.5)
                    if sample_probs[idx] >= adjusted_threshold:
                        predicted_labels.add(other_label)
        if not predicted_labels:
            predicted_labels.add(label_names[np.argmax(sample_probs)])
        if len(predicted_labels) > 4:
            label_prob_pairs = sorted([(label, sample_probs[label_names.index(label)]) for label in predicted_labels], key=lambda x: x[1], reverse=True)
            predicted_labels = {pair[0] for pair in label_prob_pairs[:4]}
        predictions.append(sorted(list(predicted_labels)))
    return predictions



def adaptive_threshold_prediction(logits, label_names, cooccurrence_prob, base_threshold=0.3, max_labels=4):
    """Convert per-label logits into sorted lists of predicted label names.

    Steps for each sample:
      1. Predict every label whose sigmoid probability is >= ``base_threshold``.
      2. For each label from step 1, also predict co-occurring labels whose probability
         clears the relaxed threshold ``base_threshold * (1 - 0.5 * P(other | label))``.
      3. If nothing was predicted, fall back to the single most probable label.
      4. If more than ``max_labels`` labels remain, keep only the ``max_labels`` most
         probable. Probability ties are broken by label name, so the output does not
         depend on set iteration order.

    Args:
        logits: Array-like of shape (num_samples, num_labels) with raw scores.
        label_names: Label names aligned with the logit columns. This may be a list or
            a NumPy array such as ``MultiLabelBinarizer.classes_``.
        cooccurrence_prob: Mapping ``(label, other_label) -> P(other_label | label)``,
            as built by ``analyze_label_cooccurrence``.
        base_threshold: Probability threshold for step 1.
        max_labels: Maximum number of labels per sample (step 4); ``None`` disables the
            cap. Defaults to 4, the previously hard-coded limit.

    Returns:
        list[list[str]]: One alphabetically sorted list of label names per sample.
    """
    # A NumPy array of names has no `.index`; use a list plus a name -> column map
    # (first occurrence wins, matching list.index) for constant-time lookups.
    label_names_list = list(label_names)
    column_of = {}
    for idx, name in enumerate(label_names_list):
        column_of.setdefault(name, idx)

    logits = np.asarray(logits)
    # Very negative logits make np.exp overflow to inf; 1 / (1 + inf) is still the
    # correct probability 0.0, so only the overflow warning is silenced.
    with np.errstate(over='ignore'):
        probs = 1 / (1 + np.exp(-logits))

    predictions = []
    for sample_probs in probs:
        # Step 1: labels that clear the base threshold.
        predicted_labels = {label_names_list[idx] for idx in np.flatnonzero(sample_probs >= base_threshold)}

        # Step 2: relax the threshold for labels that often co-occur with a step-1 label.
        for label in list(predicted_labels):
            for idx, other_label in enumerate(label_names_list):
                if other_label in predicted_labels or (label, other_label) not in cooccurrence_prob:
                    continue
                adjusted_threshold = base_threshold * (1 - cooccurrence_prob[(label, other_label)] * 0.5)
                if sample_probs[idx] >= adjusted_threshold:
                    predicted_labels.add(other_label)

        # Step 3: never return an empty prediction.
        if not predicted_labels:
            predicted_labels.add(label_names_list[int(np.argmax(sample_probs))])

        # Step 4: cap the number of labels, most probable first, ties by name.
        if max_labels is not None and len(predicted_labels) > max_labels:
            ranked = sorted(predicted_labels, key=lambda name: (-sample_probs[column_of[name]], name))
            predicted_labels = set(ranked[:max_labels])

        predictions.append(sorted(predicted_labels))

    return predictions



# --- Main Execution ---
def _checkpoint_step(path):
    """Return the integer step encoded in a ``checkpoint-<step>`` path, or -1 if there is none."""
    suffix = os.path.basename(os.path.normpath(path)).rsplit('-', 1)[-1]
    return int(suffix) if suffix.isdecimal() else -1


def _find_latest_checkpoint(output_dir):
    """Return the highest-step ``checkpoint-*`` directory inside ``output_dir``, or None.

    Checkpoints are ordered by the step number in their directory name, which is how
    the Hugging Face Trainer names them (``checkpoint-<global_step>``). Modification
    times can be changed by copying or syncing (e.g. Google Drive), so they are used
    only to break ties between equal or non-numeric steps.
    """
    candidates = [
        path for path in glob.glob(os.path.join(output_dir, 'checkpoint-*'))
        if os.path.isdir(path)
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda path: (_checkpoint_step(path), os.path.getmtime(path)))


def main():
    """Learn the label set, optionally fine-tune, then predict on the test set.

    1. Load DATA_PATH / LABELS_PATH to build the MultiLabelBinarizer label set and
       the label co-occurrence statistics. This is needed in both modes.
    2. If EVALUATE_ONLY is False, fine-tune ImprovedMultiLabelModel with the HF Trainer,
       saving checkpoints under TRAINING_OUTPUT_DIR.
    3. Pick the highest-step checkpoint in TRAINING_OUTPUT_DIR and pass it to
       evaluate_on_test_set, which writes predictions to TEST_PREDICTION_OUTPUT_PATH.
    """
    print("\n--- Loading Training Data for Label Information ---")
    full_df = load_and_prepare_data(DATA_PATH, LABELS_PATH)
    if full_df is None:  # defensive: the loader in this file raises instead of returning None
        return

    print("\n--- Preprocessing Labels ---")
    # Parse the label strings once; the result feeds both the label set and the binarizer.
    parsed_labels = process_label_strings(full_df['labels_str'])
    all_labels = sorted({label for sample_labels in parsed_labels for label in sample_labels})
    print(f"Discovered {len(all_labels)} unique labels: {all_labels}")
    mlb = MultiLabelBinarizer(classes=all_labels)
    train_labels_for_fitting = mlb.fit_transform(parsed_labels)
    print("Label processing complete.")

    print("\n--- Analyzing Label Co-occurrence ---")
    cooccurrence_prob = analyze_label_cooccurrence(train_labels_for_fitting, all_labels)
    print(f"Found {len(cooccurrence_prob)} strong label co-occurrence patterns")

    if not EVALUATE_ONLY:
        # --- TRAINING MODE ---
        print("\n--- RUNNING IN TRAINING MODE ---")
        print("Tokenizing training text...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        train_encodings = tokenizer(full_df['text'].tolist(), truncation=True, padding=True, max_length=256)
        train_dataset = ImprovedMentalQADataset(train_encodings, train_labels_for_fitting)

        print("Initializing model for training...")
        model = ImprovedMultiLabelModel(MODEL_NAME, len(all_labels), alpha=1.0, gamma=2.0).to(DEVICE)
        training_args = TrainingArguments(
            output_dir=TRAINING_OUTPUT_DIR, num_train_epochs=15, per_device_train_batch_size=8,
            gradient_accumulation_steps=2, learning_rate=2e-5, warmup_steps=100,
            weight_decay=0.01, logging_dir='./logs', logging_steps=20, save_strategy="epoch",
            save_total_limit=3, dataloader_num_workers=2,
            fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        )
        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

        print("Starting fine-tuning...")
        trainer.train()
        print("Fine-tuning complete.")

    # --- EVALUATION ON TEST SET (runs in both modes) ---
    print("\n--- Preparing for Test Set Evaluation ---")
    print(f"Searching for checkpoints in '{TRAINING_OUTPUT_DIR}'...")
    latest_checkpoint_path = _find_latest_checkpoint(TRAINING_OUTPUT_DIR)

    if latest_checkpoint_path is None:
        print(f"FATAL: No checkpoints found in '{TRAINING_OUTPUT_DIR}'. Cannot run evaluation.")
        return

    print(f"Found latest checkpoint to use: {latest_checkpoint_path}")

    # NOTE(review): evaluate_on_test_set is not defined in this cell; presumably another
    # notebook cell defining it must be run first — confirm.
    evaluate_on_test_set(
        checkpoint_path=latest_checkpoint_path,
        test_data_path=TEST_DATA_PATH,
        output_path=TEST_PREDICTION_OUTPUT_PATH,
        mlb=mlb,
        cooccurrence_prob=cooccurrence_prob
    )

if __name__ == "__main__":
    main()

Using device: cpu

--- Loading Training Data for Label Information ---

--- Preprocessing Labels ---
Discovered 7 unique labels: ['A', 'B', 'C', 'D', 'E', 'F', 'Z']
Label processing complete.

--- Analyzing Label Co-occurrence ---
Found 14 strong label co-occurrence patterns

--- Preparing for Test Set Evaluation ---
Searching for checkpoints in '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/improved_camelbert_checkpoints'...
Found latest checkpoint to use: /content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/improved_camelbert_checkpoints/checkpoint-330

--- Starting Evaluation on Test Set ---
Loading model and tokenizer from: /content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/improved_camelbert_checkpoints/checkpoint-330
Instantiating custom model structure...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at CAMeL-Lab/bert-base-arabic-camelbert-mix-sentiment and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading weights from: /content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/improved_camelbert_checkpoints/checkpoint-330/model.safetensors
Loading tokenizer from original base model: 'CAMeL-Lab/bert-base-arabic-camelbert-mix-sentiment'
Loading test data from: /content/drive/MyDrive/AraHealthQA/MentalQA/Task1/subtask1_input_test.tsv
Loaded 150 samples from the test set.
Tokenizing test data...
Generating predictions on the test set...


Generated 150 predictions for the test set.
Test set predictions successfully saved to '/content/drive/MyDrive/AraHealthQA/MentalQA/Task1/output/predictions_on_test_set.tsv'


In [2]:
# Mount Google Drive at /content/drive so the '/content/drive/MyDrive/...' paths used
# by the configuration cell resolve.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [9]:
from safetensors.torch import load_file