# Classification model

In [41]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [42]:
!unzip /content/drive/MyDrive/daic_data/daic_data.zip

Archive:  /content/drive/MyDrive/daic_data/daic_data.zip
   creating: daic_data/
   creating: daic_data/labels/
  inflating: daic_data/labels/dev.csv  
  inflating: daic_data/labels/train.csv  
  inflating: daic_data/labels/test.csv  
   creating: daic_data/transcripts/
  inflating: daic_data/transcripts/408_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/457_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/378_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/488_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/475_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/336_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/332_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/361_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/419_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/452_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/396_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/435_TRANSCRIPT.csv  
  inflating: daic_data/transcripts/418_TRANSCRIPT.csv  
 

## Libraries

In [43]:
import pandas as pd
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModel, EarlyStoppingCallback
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from peft import LoraConfig, get_peft_model

In [44]:
MODEL_NAME = "allenai/longformer-base-4096"
DATA_DIR = "/content/daic_data"

## Data loader

In [45]:
augmented_df = pd.read_csv("/content/augmented_dataset_745.csv")

In [46]:
## Save Augmented Dataset in Original Format

def save_augmented_dataset_in_original_format(
    augmented_df,
    output_dir="./augmented_daic_data/",
    use_augmented=True,
    keep_original_labels=True
):
    """
    Save augmented dataset in the same folder and CSV structure as the original dataset.
    """

    import os
    import shutil
    import pandas as pd

    # --- Create output directory structure ---
    transcripts_dir = os.path.join(output_dir, "transcripts")
    labels_dir = os.path.join(output_dir, "labels")

    os.makedirs(transcripts_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    print(f"Saving augmented dataset to: {output_dir}")

    # --- Prepare label DataFrame ---
    label_df = pd.DataFrame(columns=["Participant_ID", "PHQ8_Binary"])

    # --- Generate transcripts by participant ---
    participants = augmented_df['participant_id'].unique()
    print(f"\nCreating transcript files for {len(participants)} participants...")

    for participant_id in participants:
        participant_data = augmented_df[augmented_df['participant_id'] == participant_id].copy()

        # First label (binary)
        participant_label = participant_data['depression_label'].iloc[0]

        # Each augmentation index = a new synthetic participant ID
        augmentation_indexes = participant_data['augmentation_index'].unique()

        for augmentation_index in augmentation_indexes:

            # Subset this augmented version
            data = participant_data[participant_data['augmentation_index'] == augmentation_index].copy()
            data = data.sort_values(by='start_time').reset_index(drop=True)

            transcript_entries = []

            for idx, row in data.iterrows():

                question = str(row['question']).strip()
                answer = str(row['augmented_answer']).strip()

                if not answer:
                    continue

                start_time = float(row['start_time'])
                stop_time = start_time + 5.0  # naive duration approximation

                # Ellie question entry
                transcript_entries.append({
                    'speaker': 'Ellie',
                    'start_time': start_time - 2.0,
                    'stop_time': start_time - 0.5,
                    'value': question
                })

                # Participant answer entry
                transcript_entries.append({
                    'speaker': 'Participant',
                    'start_time': start_time,
                    'stop_time': stop_time,
                    'value': answer
                })

            # If transcript has valid entries, save it
            if transcript_entries:

                transcript_df = pd.DataFrame(transcript_entries)
                transcript_df = transcript_df.sort_values(by='start_time').reset_index(drop=True)

                transcript_filename = f"{int(participant_id)}{int(augmentation_index)}_TRANSCRIPT.csv"
                transcript_path = os.path.join(transcripts_dir, transcript_filename)

                transcript_df.to_csv(transcript_path, sep='\t', index=False, encoding="utf-8-sig")

                # Create new synthetic participant ID
                new_participant_id = int(f"{int(participant_id)}{int(augmentation_index)}")

                # Append label row
                label_df.loc[len(label_df)] = {
                    "Participant_ID": new_participant_id,
                    "PHQ8_Binary": participant_label
                }

    print(f"\nSaved {len(participants)} participants' transcript files in:")
    print(f"  → {transcripts_dir}")

    # --- Save labels ---
    print("\nMerging new labels with original ones...")

    # Caminho dos labels originais
    old_train_labels = os.path.join(DATA_DIR, "labels", "train.csv")

    # 1. Load old labels
    if os.path.exists(old_train_labels):
        old_df = pd.read_csv(old_train_labels)
    else:
        print("⚠️ Warning: old train.csv not found! Creating a new one.")
        old_df = pd.DataFrame(columns=label_df.columns)

    # 2. Concatenate old + new
    merged_labels = pd.concat([old_df, label_df], ignore_index=True)

    # 3. Remove duplicates (optional but recommended)
    merged_labels = merged_labels.drop_duplicates(subset=["Participant_ID"], keep="first")

    # 4. Save final merged labels in the new folder
    label_filename = "train.csv"
    label_path = os.path.join(labels_dir, label_filename)

    merged_labels.to_csv(label_path, index=False, encoding="utf-8-sig")

    print("\n✅ Labels merged and saved successfully!")
    print(f"   Transcripts directory: {transcripts_dir}")
    print(f"   Labels directory: {labels_dir}")


    return output_dir

# Save augmented dataset in original format
# Configuration


save_augmented_dataset_in_original_format(
    augmented_df,
    output_dir=DATA_DIR,
)


Saving augmented dataset to: /content/daic_data

Creating transcript files for 30 participants...

Saved 30 participants' transcript files in:
  → /content/daic_data/transcripts

Merging new labels with original ones...

✅ Labels merged and saved successfully!
   Transcripts directory: /content/daic_data/transcripts
   Labels directory: /content/daic_data/labels


'/content/daic_data'

In [47]:
def process_daic_data(data_dir):
  transcripts_dir = os.path.join(data_dir, "topic_transcripts/transcripts")
  # transcripts_dir = os.path.join(data_dir, "transcripts")
  labels_dir = os.path.join(data_dir, "labels")

  df = pd.DataFrame()

  for file in os.listdir(labels_dir):
    if not file.endswith(".csv"):
      continue

    split_name = file.replace(".csv", "")
    split_df = pd.read_csv(os.path.join(labels_dir, file))
    split_df = split_df.rename(columns={
      "PHQ_Binary": "depression_label",
      "PHQ_Score": "depression_severity",
      "PHQ8_Binary": "depression_label",
      "PHQ8_Score": "depression_severity",
      "Participant_ID": "participant_id",
    })

    transcripts_df = create_dataframe(split_df, transcripts_dir)
    transcripts_df["split"] = split_name

    df = pd.concat([df, transcripts_df], ignore_index=True)

  return df

def create_dataframe(split_df, transcripts_dir):
  df = {"text": [], "depression_label": []}

  for _, row in split_df.iterrows():
    participant_id = str(int(float(row.participant_id)))
    depression_label = int(row.depression_label)

    participant_text = ""
    transcript_file = os.path.join(transcripts_dir, f"{participant_id}_TRANSCRIPT.csv")
    if not os.path.exists(transcript_file):
      print(f"Transcript file not found for participant {participant_id}")
      continue

    transcripts = pd.read_csv(transcript_file, sep="\t")
    participant_transcripts = transcripts[transcripts['speaker'] == 'Participant']

    for _, transcript_row in participant_transcripts.iterrows():
      participant_text += str(transcript_row.value) + " "

    df["text"].append(participant_text.strip())
    df["depression_label"].append(depression_label)

  return pd.DataFrame(df)


## Train classification model

In [48]:
class TranscriptsDataset(Dataset):
  def __init__(self, dataframe, tokenizer, max_length=4096):
    self.data = dataframe
    self.tokenizer = tokenizer
    self.max_length = max_length

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    text = str(self.data.iloc[idx]["text"])
    label = int(self.data.iloc[idx]["depression_label"])

    encoding = self.tokenizer(
      text,
      truncation=True,
      padding="max_length",
      max_length=self.max_length,
      return_tensors="pt"
    )

    return {
      "input_ids": encoding["input_ids"].squeeze(),
      "attention_mask": encoding["attention_mask"].squeeze(),
      "labels": torch.tensor(label, dtype=torch.long),
    }

class TextFeaturizer(nn.Module):
  def __init__(self, model_name, dropout=0.5, dense_size=256,
               lora_r=8, lora_alpha=16, lora_dropout=0.1):
    super().__init__()

    # Load Longformer encoder
    self.encoder = AutoModel.from_pretrained(model_name)
    hidden_size = self.encoder.config.hidden_size

    self.projection = nn.Sequential(
      nn.Linear(hidden_size, dense_size),
      nn.ReLU(),
      nn.Dropout(dropout)
    )

    lora_config = LoraConfig(
      r=lora_r,
      lora_alpha=lora_alpha,
      target_modules=["query", "key", "value"],
      lora_dropout=lora_dropout,
      bias="none",
      task_type="FEATURE_EXTRACTION"
    )
    self.encoder = get_peft_model(self.encoder, lora_config)

    for name, param in self.encoder.named_parameters():
      if 'lora' not in name:
        param.requires_grad = False

  def forward(self, input_ids, attention_mask):
    outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
    cls_token = outputs.last_hidden_state[:, 0]
    return self.projection(cls_token)

class FocalLoss(nn.Module):
  """Focal Loss for addressing class imbalance by focusing on hard examples."""
  def __init__(self, alpha=0.75, gamma=2.0, reduction='mean'):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction

  def forward(self, inputs, targets):
    ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
    pt = torch.exp(-ce_loss)
    focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

    if self.reduction == 'mean':
      return focal_loss.mean()
    elif self.reduction == 'sum':
      return focal_loss.sum()
    else:
      return focal_loss

class TextClassifier(nn.Module):
  def __init__(self, model_name, num_labels=2, class_weights=None, use_focal_loss=True, focal_alpha=0.75, focal_gamma=2.0):
    super().__init__()
    self.featurizer = TextFeaturizer(model_name)
    self.classifier = nn.Linear(256, num_labels)
    self.use_focal_loss = use_focal_loss
    self.focal_alpha = focal_alpha
    self.focal_gamma = focal_gamma

    # Store class weights for loss calculation (if not using focal loss)
    if class_weights is not None and not use_focal_loss:
      self.register_buffer('class_weights', torch.tensor(class_weights, dtype=torch.float32))
    else:
      self.class_weights = None

    # Initialize focal loss if using it
    if use_focal_loss:
      self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

  def forward(self, input_ids, attention_mask, labels=None):
    features = self.featurizer(input_ids, attention_mask)
    logits = self.classifier(features)

    # Add regularization to prevent extreme logits (helps prevent all-one-class predictions)
    # Penalize large logit differences between classes to encourage more balanced predictions
    logit_diff = torch.abs(logits[:, 0] - logits[:, 1])
    logit_regularization = 0.01 * torch.mean(logit_diff ** 2)  # Penalize large differences
    # Also add L2 regularization on logits to keep them from becoming too extreme
    logit_l2 = 0.001 * torch.mean(logits ** 2)
    total_regularization = logit_regularization + logit_l2

    if labels is not None:
      if self.use_focal_loss:
        # Use Focal Loss for better handling of imbalanced datasets
        loss = self.focal_loss(logits, labels)
      else:
        # Fallback to weighted CrossEntropyLoss if not using focal loss
        if self.class_weights is not None:
          loss_fn = nn.CrossEntropyLoss(weight=self.class_weights)
        else:
          loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(logits, labels)

      # Add regularization to loss
      loss = loss + total_regularization
      return {"loss": loss, "logits": logits}
    return {"logits": logits}

def compute_metrics(eval_pred):
  """Compute metrics for evaluation during training with balanced predictions.
  Uses adaptive threshold to prevent all-one-class predictions."""
  predictions, labels = eval_pred

  # Convert logits to probabilities
  probs = torch.softmax(torch.tensor(predictions), dim=-1).numpy()
  probs_class1 = probs[:, 1]

  # Calculate true class distribution
  class_1_ratio = np.mean(labels)

  # Strategy 1: Try to match the class distribution with quantile threshold
  if class_1_ratio > 0.05 and class_1_ratio < 0.95:
    threshold = np.percentile(probs_class1, (1 - class_1_ratio) * 100)
    preds = (probs_class1 >= threshold).astype(int)
  else:
    # For extreme distributions, start with median
    threshold = np.median(probs_class1)
    preds = (probs_class1 >= threshold).astype(int)

  # Strategy 2: Ensure we have predictions for both classes
  unique_preds = np.unique(preds)
  if len(unique_preds) == 1:
    # If all predictions are the same class, find a threshold that gives diversity
    # Try a range of thresholds to find one that predicts both classes
    best_threshold = 0.5
    best_f1 = -1

    for trial_threshold in np.arange(0.1, 0.9, 0.05):
      preds_trial = (probs_class1 >= trial_threshold).astype(int)
      unique_trial = np.unique(preds_trial)

      # Prefer thresholds that predict both classes
      if len(unique_trial) > 1:
        # Calculate F1 with this threshold
        f1_trial = f1_score(labels, preds_trial, zero_division=0)
        if f1_trial > best_f1:
          best_f1 = f1_trial
          best_threshold = trial_threshold
          preds = preds_trial

    # If still all one class after trying thresholds, force some diversity
    if len(np.unique(preds)) == 1:
      # Sort probabilities and predict top k% as class 1
      k = max(10, int(class_1_ratio * 100))  # At least 10% or match distribution
      k = min(90, k)  # At most 90%
      sorted_indices = np.argsort(probs_class1)
      preds = np.zeros_like(labels)
      preds[sorted_indices[-k:]] = 1

  accuracy = accuracy_score(labels, preds)
  precision = precision_score(labels, preds, zero_division=0, average='binary')
  recall = recall_score(labels, preds, zero_division=0, average='binary')
  f1 = f1_score(labels, preds, zero_division=0, average='binary')

  return {
    "accuracy": accuracy,
    "precision": precision,
    "recall": recall,
    "f1": f1
  }

def find_optimal_threshold(trainer, val_dataset):
  """Find optimal threshold that maximizes F1 score on validation set, with constraints to avoid extreme predictions."""
  predictions = trainer.predict(val_dataset)
  probs = torch.softmax(torch.tensor(predictions.predictions), dim=-1)
  probs_class1 = probs[:, 1].numpy()
  labels = predictions.label_ids

  best_threshold = 0.5
  best_f1 = 0
  best_metrics = {}

  # Get true class distribution
  true_class1_ratio = np.mean(labels)

  # Try different thresholds with wider search range to avoid extreme predictions
  # Search more thoroughly across the probability range
  for threshold in np.arange(0.1, 0.9, 0.01):
    preds = (probs_class1 >= threshold).astype(int)

    # Calculate metrics
    f1 = f1_score(labels, preds, zero_division=0)
    precision = precision_score(labels, preds, zero_division=0)
    recall = recall_score(labels, preds, zero_division=0)

    # Check if predictions are balanced (not all one class)
    pred_class1_count = np.sum(preds)
    total_samples = len(preds)
    pred_class1_ratio = pred_class1_count / total_samples

    # Skip thresholds that predict all as one class
    if pred_class1_count == 0 or pred_class1_count == total_samples:
      continue

    # Prefer thresholds that give reasonable class balance
    # Add small bonus for thresholds close to true distribution
    balance_bonus = 1.0
    if abs(pred_class1_ratio - true_class1_ratio) < 0.2:  # Within 20% of true ratio
      balance_bonus = 1.05  # 5% bonus

    # Score combines F1 with balance consideration
    score = f1 * balance_bonus

    if score > best_f1 or (score == best_f1 and abs(pred_class1_ratio - true_class1_ratio) < abs(best_metrics.get('pred_class1_ratio', 1) - true_class1_ratio)):
      best_f1 = f1
      best_threshold = threshold
      best_metrics = {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'pred_class1_ratio': pred_class1_ratio,
        'true_class1_ratio': true_class1_ratio
      }

  # Fallback if no good threshold found
  if best_threshold == 0.5 and best_f1 == 0:
    print("Warning: No balanced threshold found in search, trying alternative approaches...")

    # Try using quantile-based threshold
    threshold = np.percentile(probs_class1, (1 - true_class1_ratio) * 100)
    preds = (probs_class1 >= threshold).astype(int)

    if len(np.unique(preds)) > 1:  # If we get both classes
      best_threshold = threshold
      best_f1 = f1_score(labels, preds, zero_division=0)
      best_metrics = {
        'f1': best_f1,
        'precision': precision_score(labels, preds, zero_division=0),
        'recall': recall_score(labels, preds, zero_division=0),
        'pred_class1_ratio': np.sum(preds) / len(preds),
        'true_class1_ratio': true_class1_ratio
      }
    else:
      # Last resort: use median
      best_threshold = np.median(probs_class1)
      preds = (probs_class1 >= best_threshold).astype(int)
      best_f1 = f1_score(labels, preds, zero_division=0)
      best_metrics = {
        'f1': best_f1,
        'precision': precision_score(labels, preds, zero_division=0),
        'recall': recall_score(labels, preds, zero_division=0),
        'pred_class1_ratio': np.sum(preds) / len(preds),
        'true_class1_ratio': true_class1_ratio
      }

  print(f"Optimal threshold: {best_threshold:.3f}")
  print(f"  F1: {best_metrics.get('f1', 0):.4f}, Precision: {best_metrics.get('precision', 0):.4f}, Recall: {best_metrics.get('recall', 0):.4f}")
  print(f"  Predicted class 1 ratio: {best_metrics.get('pred_class1_ratio', 0):.2%}")
  print(f"  True class 1 ratio: {best_metrics.get('true_class1_ratio', 0):.2%}")
  return best_threshold

def evaluate_model(trainer, test_dataset, threshold=None):
  from sklearn.metrics import confusion_matrix  # Import here if not already imported
  predictions = trainer.predict(test_dataset)
  probs = torch.softmax(torch.tensor(predictions.predictions), dim=-1)
  probs_class1 = probs[:, 1].numpy()
  labels = predictions.label_ids

  # Use optimal threshold if provided, otherwise use argmax (threshold=0.5)
  if threshold is not None:
    preds = (probs_class1 >= threshold).astype(int)
    print(f"Using threshold: {threshold:.3f}")
  else:
    preds = np.argmax(predictions.predictions, axis=1)
    print("Using default threshold: 0.5 (argmax)")

  # Show first 10 predictions as examples
  print("Sample predictions (first 10):")
  for i, (label, pred) in enumerate(zip(labels[:10], preds[:10])):
    print(f"  True: {label}, Predicted: {pred}, Prob(class1): {probs_class1[i]:.3f}")
  if len(labels) > 10:
    print(f"  ... ({len(labels) - 10} more predictions)")

  accuracy = accuracy_score(labels, preds)
  precision = precision_score(labels, preds, zero_division=0)
  recall = recall_score(labels, preds, zero_division=0)
  f1 = f1_score(labels, preds, zero_division=0)

  # Calculate and print confusion matrix
  cm = confusion_matrix(labels, preds)

  print(f"\nConfusion Matrix:")
  print(f"                Predicted")
  print(f"              Non-Dep  Depressed")
  print(f"Actual Non-Dep    {cm[0,0]:4d}      {cm[0,1]:4d}")
  print(f"       Depressed   {cm[1,0]:4d}      {cm[1,1]:4d}")
  print(f"\nConfusion Matrix (detailed):")
  print(f"  True Negatives (TN): {cm[0,0]} - Correctly predicted non-depressed")
  print(f"  False Positives (FP): {cm[0,1]} - Non-depressed predicted as depressed")
  print(f"  False Negatives (FN): {cm[1,0]} - Depressed predicted as non-depressed")
  print(f"  True Positives (TP): {cm[1,1]} - Correctly predicted depressed")

  print(f"\nTest Metrics:")
  print(f"  Accuracy: {accuracy:.4f}")
  print(f"  Precision: {precision:.4f}")
  print(f"  Recall: {recall:.4f}")
  print(f"  F1 Score: {f1:.4f}")

  return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "confusion_matrix": cm}

def train_model(df, save_model=True, model_save_path="./depression_classifier_model"):
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

  train_df = df[df['split'] == 'train'].reset_index(drop=True)
  val_df = df[df['split'] == 'dev'].reset_index(drop=True)
  test_df = df[df['split'] == 'test'].reset_index(drop=True)

  print(f"Training samples: {len(train_df)}")
  print(f"Validation samples: {len(val_df)}")
  print(f"Test samples: {len(test_df)}")

  # Calculate class weights to handle imbalanced dataset (for reference, though we use Focal Loss)
  from sklearn.utils.class_weight import compute_class_weight
  labels = train_df['depression_label'].values
  classes = np.unique(labels)
  class_weights_balanced = compute_class_weight('balanced', classes=classes, y=labels)

  # Apply multiplier to strengthen minority class weight (step 4)
  weight_multiplier = 1.8  # Increase weight for minority class
  class_weights = class_weights_balanced.copy()
  # Find minority class (class with fewer samples)
  class_counts = train_df['depression_label'].value_counts().sort_index()
  minority_class = class_counts.idxmin()
  minority_class_idx = list(classes).index(minority_class)
  class_weights[minority_class_idx] *= weight_multiplier

  class_weights_dict = dict(zip(classes, class_weights))
  print(f"\nClass distribution in training set:")
  print(train_df['depression_label'].value_counts().sort_index())
  print(f"Balanced class weights: {dict(zip(classes, class_weights_balanced))}")
  print(f"Adjusted class weights (multiplier={weight_multiplier}x for minority): {class_weights_dict}")
  print(f"Using Focal Loss (alpha=0.85, gamma=2.5) for better imbalance handling")

  train_dataset = TranscriptsDataset(train_df, tokenizer)
  val_dataset = TranscriptsDataset(val_df, tokenizer)
  test_dataset = TranscriptsDataset(test_df, tokenizer)

  # Use Focal Loss with adjusted parameters for better learning and preventing all-one-class predictions
  # Higher alpha (0.9) gives more weight to minority class, higher gamma (3.0) focuses more on hard examples
  # These parameters help prevent the model from collapsing to predicting only one class
  model = TextClassifier(MODEL_NAME, num_labels=2, use_focal_loss=True, focal_alpha=0.9, focal_gamma=3.0)

  # Initialize classifier weights with small values to prevent extreme initial predictions
  # This helps avoid the model starting with all-one-class predictions
  with torch.no_grad():
    nn.init.normal_(model.classifier.weight, mean=0.0, std=0.02)
    nn.init.zeros_(model.classifier.bias)

  print("\nModel initialized with balanced weights to prevent extreme initial predictions")

  training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # Evaluate every epoch
    logging_strategy="steps",
    logging_steps=10,  # Log more frequently to see training progress
    learning_rate=1e-5,  # Slightly lower learning rate for more stable training
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=10,  # More epochs to allow better learning
    gradient_accumulation_steps=4,
    fp16=True,
    save_strategy="epoch",  # Save every epoch
    load_best_model_at_end=True,  # Load the best model at the end
    metric_for_best_model="eval_f1",  # Use F1 score instead of eval_loss as the primary metric
    greater_is_better=True,  # F1 score should be maximized (higher is better)
    warmup_steps=10,  # Add warmup for better training stability
    weight_decay=0.01,  # Add weight decay for regularization
    save_total_limit=3,  # Keep only the best 3 models based on F1 score
    report_to="none",  # Disable wandb/tensorboard to reduce overhead
  )

  # Create early stopping callback - it will automatically use eval_f1 from TrainingArguments
  # The callback will use the metric specified in metric_for_best_model
  early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,  # Stop if F1 doesn't improve for 3 epochs
    early_stopping_threshold=0.001  # Minimum improvement threshold
  )

  # training_args = TrainingArguments(
  #   output_dir="./results",

  #   # --- LOGGING ---
  #   logging_strategy="steps",
  #   logging_steps=1,                 # Log every step
  #   report_to="none",                # or "tensorboard"
  #   log_level="info",
  #   log_level_replica="info",

  #   # --- EVAL DURING TRAINING ---
  #   eval_strategy="steps",     # Run evaluation more frequently
  #   eval_steps=10,                   # Eval every 10 steps
  #   save_strategy="steps",           # Save more often
  #   save_steps=10,

  #   # --- TRAINING HYPERPARAMETERS ---
  #   learning_rate=1e-5,
  #   per_device_train_batch_size=1,
  #   per_device_eval_batch_size=1,
  #   gradient_accumulation_steps=4,
  #   num_train_epochs=1,
  #   warmup_steps=10,
  #   weight_decay=0.01,
  #   fp16=True,

  #   # --- BEST MODEL ---
  #   load_best_model_at_end=True,
  #   metric_for_best_model="eval_f1",
  #   greater_is_better=True,

  #   # --- MISC ---
  #   save_total_limit=3,    # Keep only last 3 checkpoints
  # )

  trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,  # Compute F1, precision, recall, accuracy
    callbacks=[early_stopping],  # Add early stopping based on F1 score
  )

  print("Training configuration:")
  print(f"  Primary metric: eval_f1 (F1 score)")
  print(f"  Model selection: Best model selected based on highest F1 score")
  print(f"  Early stopping: Stops if F1 doesn't improve for 3 epochs")
  print(f"  Loss function: Focal Loss (for training, but model selection uses F1)")

  print("\nStarting training...")
  trainer.train()

  # Find optimal threshold on validation set (step 3)
  print("\nFinding optimal threshold on validation set...")

  # First, check probability distribution on validation set
  val_predictions = trainer.predict(val_dataset)
  val_probs = torch.softmax(torch.tensor(val_predictions.predictions), dim=-1)
  val_probs_class1 = val_probs[:, 1].numpy()
  print(f"\nValidation set probability distribution (class 1):")
  print(f"  Min: {val_probs_class1.min():.4f}, Max: {val_probs_class1.max():.4f}")
  print(f"  Mean: {val_probs_class1.mean():.4f}, Std: {val_probs_class1.std():.4f}")
  print(f"  Median: {np.median(val_probs_class1):.4f}")

  optimal_threshold = find_optimal_threshold(trainer, val_dataset)

  print("\nEvaluating on test set...")
  metrics = evaluate_model(trainer, test_dataset, threshold=optimal_threshold)

  if save_model:
    print(f"\nSaving model to {model_save_path}...")
    trainer.save_model(model_save_path)
    tokenizer.save_pretrained(model_save_path)
    print("Model saved successfully!")

  return trainer, metrics


## Load and Explore Data

## Generate topics

In [49]:
def train_index(train_dir, val_dir, test_dir):
  train_df = pd.read_csv(train_dir)
  val_df = pd.read_csv(val_dir)
  test_df = pd.read_csv(test_dir)

  train_df = train_df[["Participant_ID", "Gender"]].copy()
  test_df = test_df[["Participant_ID", "Gender"]].copy()
  val_df = val_df[["Participant_ID", "Gender"]].copy()

  return train_df, val_df, test_df


def topic_selection(transcript):
  interest = [
    "recently that you really enjoy",
    "traveling",
    "travel alot",
    "family",
    "fun",
    "best friend",
    "weekend",
  ]
  sleep = [
    "good night's sleep",
    "don't sleep well"
  ]
  feeling_depressed = [
    "really happy",
    "behavior",
    " disturbing thought",
    "feel_lately",
  ]
  failure = [
    "regret",
    "guilty",
    "proud",
    "being_parent",
    "best_quality"
  ]
  personality = [
    "introvert",
    "shyoutgoing"
  ]
  dignose = [
    "ptsd",
    "depression",
    "therapy is useful"
  ]
  parent = [
    "hard_parent",
    "best_parent",
    "easy_parent",
    "your_kid",
    "differnet_parent",
  ]

  ques = [interest, sleep, feeling_depressed, failure, personality, dignose, parent]
  topic_name = []
  for topic_count, topic in enumerate(ques):
    for sub_topic_count, sub_topic in enumerate(topic):
      # remove nan
      if type(transcript) == float:
        print(transcript)
        return "problem"
      if type(transcript) != float:
        if sub_topic in transcript:
          topic_name.append([topic_count, sub_topic_count])

  return topic_name


def data_retrieve(working_dir, train_id, should_save_no_topic=False):
  participants = train_id
  transcripts = pd.DataFrame()

  for index_p, row in participants.iterrows():
    filename = str(int(row.Participant_ID)) + "_TRANSCRIPT.csv"

    location = os.path.join(working_dir, filename)
    if not os.path.exists(location):
      continue

    temp = pd.read_csv(location, sep="\t")
    temp = temp.dropna(subset=["value"])
    temp = temp.reset_index(drop=True)  # Reset index to ensure sequential 0-based indexing
    temp["topic"] = pd.Series(dtype="object")
    temp["topic_value"] = pd.Series(dtype="object")
    temp["sub_topic"] = pd.Series(dtype="object")
    temp["participant"] = row.Participant_ID

    found_any_topic = False
    max_index = len(temp) - 1  # Get maximum valid index

    for index_t, row_t in temp.iterrows():
      if row_t.speaker == "Ellie":
        topic = topic_selection(row_t.value)

        if topic != [] and len(topic) > 1:
          df_try = row_t
          temp.append([df_try] * len(topic), ignore_index=True)

        for words in topic:
          found_any_topic = True
          check = False

          # Check bounds before accessing index_t + 1
          if index_t + 1 <= max_index and temp.iloc[index_t + 1]["speaker"] == "Participant":
            temp.loc[index_t + 1, "topic"] = words[0]
            temp.loc[index_t + 1, "sub_topic"] = words[1]
            temp.loc[index_t + 1, "topic_value"] = row_t.value

          # Check bounds before accessing index_t + 2
          if index_t + 2 <= max_index and temp.iloc[index_t + 2]["speaker"] == "Participant":
            check = True
            temp.loc[index_t + 2, "topic"] = words[0]
            temp.loc[index_t + 2, "sub_topic"] = words[1]
            temp.loc[index_t + 2, "topic_value"] = row_t.value

          # Check bounds before accessing index_t + 3
          if index_t + 3 <= max_index and check and temp.iloc[index_t + 3]["speaker"] == "Participant":
            temp.loc[index_t + 3, "topic"] = words[0]
            temp.loc[index_t + 3, "sub_topic"] = words[1]
            temp.loc[index_t + 3, "topic_value"] = row_t.value

    if not found_any_topic and should_save_no_topic:
      temp["topic"] = -1
      temp["sub_topic"] = -1
      temp["topic_value"] = "No_topic_found"

    temp.dropna(inplace=True)
    transcripts = pd.concat([transcripts, temp], axis=0)

  return transcripts


def generate_topic_transcripts(train_file, val_file, test_file, transcritps_dir, out_dir):
  """
  Generate topic-labeled transcripts and save them in the same structure as daic_data.
  Creates individual transcript files per participant in out_dir/transcripts/ folder.
  """
  # Create output directory structure matching daic_data
  out_transcripts_dir = os.path.join(out_dir, "transcripts")
  os.makedirs(out_transcripts_dir, exist_ok=True)

  train_df, val_df, test_df = train_index(train_file, val_file, test_file)

  print("Processing training transcripts...")
  train_transcripts = data_retrieve(transcritps_dir, train_df)

  print("Processing validation transcripts...")
  val_transcripts = data_retrieve(transcritps_dir, val_df, should_save_no_topic=True)

  print("Processing test transcripts...")
  test_transcripts = data_retrieve(transcritps_dir, test_df, should_save_no_topic=True)

  # Combine all transcripts and group by participant
  # This ensures we handle each participant only once, even if they appear in multiple splits
  all_transcripts_combined = pd.concat([train_transcripts, val_transcripts, test_transcripts], axis=0, ignore_index=True)

  files_saved = 0
  total_rows = 0

  # Group by participant and save individual files (matching daic_data structure)
  if len(all_transcripts_combined) > 0:
    for participant_id in all_transcripts_combined['participant'].unique():
      participant_data = all_transcripts_combined[all_transcripts_combined['participant'] == participant_id].copy()

      # Save transcript file with same naming convention as original
      filename = f"{int(participant_id)}_TRANSCRIPT.csv"
      file_path = os.path.join(out_transcripts_dir, filename)

      # Select and order columns to match original structure, adding topic columns
      columns_to_save = ['speaker', 'start_time', 'stop_time', 'value']
      if 'topic' in participant_data.columns:
        columns_to_save.extend(['topic', 'sub_topic', 'topic_value'])
      if 'participant' in participant_data.columns:
        columns_to_save.append('participant')

      # Only save columns that exist in the dataframe
      available_columns = [col for col in columns_to_save if col in participant_data.columns]
      participant_data = participant_data[available_columns]

      # Sort by start_time to maintain chronological order
      if 'start_time' in participant_data.columns:
        participant_data = participant_data.sort_values(by='start_time').reset_index(drop=True)

      participant_data.to_csv(file_path, index=False, sep="\t", encoding="utf-8")
      files_saved += 1
      total_rows += len(participant_data)

  print(f"\n✅ Topic transcripts generated successfully!")
  print(f"   Files saved: {files_saved} transcript files")
  print(f"   Total rows: {total_rows}")
  print(f"   Training: {len(train_transcripts)} rows")
  print(f"   Validation: {len(val_transcripts)} rows")
  print(f"   Test: {len(test_transcripts)} rows")
  print(f"\n   Output directory: {out_transcripts_dir}")


# Generate topic transcripts
train_file = os.path.join(DATA_DIR, "labels", "train.csv")
val_file = os.path.join(DATA_DIR, "labels", "dev.csv")
test_file = os.path.join(DATA_DIR, "labels", "test.csv")
transcripts_dir = os.path.join(DATA_DIR, "transcripts")
out_dir = os.path.join(DATA_DIR, "topic_transcripts")

generate_topic_transcripts(train_file, val_file, test_file, transcripts_dir, out_dir)

Processing training transcripts...
Processing validation transcripts...
Processing test transcripts...

✅ Topic transcripts generated successfully!
   Files saved: 219 transcript files
   Total rows: 5615
   Training: 2954 rows
   Validation: 1292 rows
   Test: 1369 rows

   Output directory: /content/daic_data/topic_transcripts/transcripts


In [50]:
# Load the data
df = process_daic_data(DATA_DIR)

# Display basic information about the dataset
print(f"Total samples: {len(df)}")
print(f"\nSplit distribution:")
print(df['split'].value_counts())
print(f"\nLabel distribution:")
print(df['depression_label'].value_counts())
print(f"\nLabel distribution by split:")
print(df.groupby(['split', 'depression_label']).size())

# Display first few rows
print(f"\nFirst few rows:")
df.head()

Total samples: 219

Split distribution:
split
train    137
test      47
dev       35
Name: count, dtype: int64

Label distribution:
depression_label
0    133
1     86
Name: count, dtype: int64

Label distribution by split:
split  depression_label
dev    0                   23
       1                   12
test   0                   33
       1                   14
train  0                   77
       1                   60
dtype: int64

First few rows:


Unnamed: 0,text,depression_label,split
0,i'm very close sometimes too close no mm not o...,0,dev
1,um uh going out with friends going to bars goi...,0,dev
2,no no i'm much <mu> much of more <m> more of a...,0,dev
3,no i don't um the last time i felt really ha...,1,dev
4,i just like um going to new places and just di...,1,dev


## Train the Model

In [51]:
trainer, metrics = train_model(df, save_model=True, model_save_path="./depression_classifier_model")

Training samples: 137
Validation samples: 35
Test samples: 47

Class distribution in training set:
depression_label
0    77
1    60
Name: count, dtype: int64
Balanced class weights: {np.int64(0): np.float64(0.8896103896103896), np.int64(1): np.float64(1.1416666666666666)}
Adjusted class weights (multiplier=1.8x for minority): {np.int64(0): np.float64(0.8896103896103896), np.int64(1): np.float64(2.055)}
Using Focal Loss (alpha=0.85, gamma=2.5) for better imbalance handling

Model initialized with balanced weights to prevent extreme initial predictions
Training configuration:
  Primary metric: eval_f1 (F1 score)
  Model selection: Best model selected based on highest F1 score
  Early stopping: Stops if F1 doesn't improve for 3 epochs
  Loss function: Focal Loss (for training, but model selection uses F1)

Starting training...


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0791,0.080819,0.657143,0.5,0.5,0.5
2,0.0803,0.079599,0.657143,0.5,0.5,0.5
3,0.0793,0.078466,0.657143,0.5,0.5,0.5
4,0.0816,0.077909,0.657143,0.5,0.5,0.5



Finding optimal threshold on validation set...



Validation set probability distribution (class 1):
  Min: 0.5100, Max: 0.5160
  Mean: 0.5123, Std: 0.0013
  Median: 0.5125


Optimal threshold: 0.513
  F1: 0.5000, Precision: 0.5000, Recall: 0.5000
  Predicted class 1 ratio: 34.29%
  True class 1 ratio: 34.29%

Evaluating on test set...


Using threshold: 0.513
Sample predictions (first 10):
  True: 0, Predicted: 0, Prob(class1): 0.512
  True: 0, Predicted: 1, Prob(class1): 0.513
  True: 0, Predicted: 0, Prob(class1): 0.512
  True: 1, Predicted: 1, Prob(class1): 0.514
  True: 1, Predicted: 0, Prob(class1): 0.513
  True: 1, Predicted: 0, Prob(class1): 0.513
  True: 0, Predicted: 1, Prob(class1): 0.514
  True: 0, Predicted: 0, Prob(class1): 0.512
  True: 0, Predicted: 0, Prob(class1): 0.512
  True: 1, Predicted: 0, Prob(class1): 0.512
  ... (37 more predictions)

Confusion Matrix:
                Predicted
              Non-Dep  Depressed
Actual Non-Dep      29         4
       Depressed     10         4

Confusion Matrix (detailed):
  True Negatives (TN): 29 - Correctly predicted non-depressed
  False Positives (FP): 4 - Non-depressed predicted as depressed
  False Negatives (FN): 10 - Depressed predicted as non-depressed
  True Positives (TP): 4 - Correctly predicted depressed

Test Metrics:
  Accuracy: 0.7021
  Precisi