In [None]:
import torch
import glob, os
import pandas as pd
import gc
from transformers import pipeline
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from transformers import TrainingArguments, Trainer, StoppingCriteria, StoppingCriteriaList
import torch.nn.functional as F


# Data loading

In [None]:
# Colab-only step: mount Google Drive at /content/drive so files stored there
# can be read by the following cells.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the dataset archive from the mounted Drive into the current working directory.
# NOTE(review): presumably this yields the ./daic_data/ folder that load_daic_data reads by default - confirm.
!unzip /content/drive/MyDrive/daic_data/daic_data.zip

## Data processing

In [None]:
def get_questions_answers_df(transcripts_dir):
  """Build a question/answer table from the transcript CSVs in a directory.

  Every ``*.csv`` file in ``transcripts_dir`` is read as a tab-separated
  transcript with ``speaker``, ``start_time``, ``stop_time`` and ``value``
  columns. Consecutive rows spoken by the same speaker are merged into one
  turn, and each ``Participant`` turn that directly follows an ``Ellie`` turn
  in the same file becomes one (question, answer) row.

  Args:
    transcripts_dir: Directory containing the transcript files. The part of
      each file name before the first underscore is parsed as the integer
      participant id.

  Returns:
    DataFrame with columns ``participant_id``, ``question``, ``answer`` and
    ``start_time`` (start time of the answer turn).

  Raises:
    ValueError: If ``transcripts_dir`` contains no ``*.csv`` files.
  """
  # Sorted so the concatenation order does not depend on filesystem order.
  transcripts_files = sorted(glob.glob(os.path.join(transcripts_dir, "*.csv")))
  if not transcripts_files:
    raise ValueError(f"No transcript CSV files found in {transcripts_dir!r}")

  # Load and concatenate all transcript files, remembering the file of origin.
  df = pd.concat(
    (
      pd.read_csv(file, sep="\t", encoding="utf-8-sig").assign(source=os.path.basename(file))
      for file in transcripts_files
    ),
    ignore_index=True
  )

  # A new block starts whenever the speaker changes. The counter runs across
  # files, which is harmless because blocks are grouped together with `source`.
  df['block_id'] = (df['speaker'] != df['speaker'].shift(1)).cumsum()

  # Merge each contiguous same-speaker block into a single turn.
  turns = (
    df.groupby(['source', 'block_id'])
    .agg(
      speaker=('speaker', 'first'),
      start_time=('start_time', 'min'),
      stop_time=('stop_time', 'max'),
      # Skip missing utterances instead of joining them as the literal "nan".
      value=('value', lambda x: ' '.join(x.dropna().astype(str)))
    )
    .reset_index()
    .sort_values(by=['source', 'start_time'])
    .reset_index(drop=True)
  )

  # Previous turn within the same file; the first turn of a file gets NaN,
  # so a question can never be taken from another participant's transcript.
  by_source = turns.groupby('source')
  turns['prev_speaker'] = by_source['speaker'].shift(1)
  turns['prev_value'] = by_source['value'].shift(1)

  is_answer = (turns['speaker'] == 'Participant') & (turns['prev_speaker'] == 'Ellie')

  qa = (
    turns[is_answer]
    .rename(columns={
      'prev_value': 'question',  # The previous Ellie turn is the question
      'value': 'answer',         # The current Participant turn is the answer
    })
    .assign(participant_id=lambda d: d['source'].str.split("_").str[0].astype(int))
  )

  return qa[['participant_id', 'question', 'answer', 'start_time']]

def add_labels_to_df(qa_df, labels_dir):
  """Attach depression labels and the dataset split to each question/answer row.

  Reads ``train.csv``, ``dev.csv`` and ``test.csv`` from ``labels_dir``,
  normalises their column names and left-joins them onto ``qa_df``.

  Args:
    qa_df: DataFrame with a ``participant_id`` column.
    labels_dir: Directory containing the three split CSV files. Each file must
      provide ``Participant_ID`` plus either the ``PHQ8_Binary``/``PHQ8_Score``
      or the ``PHQ_Binary``/``PHQ_Score`` column pair.

  Returns:
    ``qa_df`` with ``depression_label``, ``depression_severity`` and ``split``
    columns added. Because this is a left join, participants without a label
    row keep their rows with missing values in those columns.
  """
  splits = ['train', 'dev', 'test']

  # The split files use two naming schemes for the same two quantities.
  column_aliases = {
    "Participant_ID": "participant_id",
    "PHQ8_Binary": "depression_label",
    "PHQ8_Score": "depression_severity",
    "PHQ_Binary": "depression_label",
    "PHQ_Score": "depression_severity",
  }
  label_columns = ["participant_id", "depression_label", "depression_severity"]

  # Build every split frame first and concatenate once, rather than growing a
  # frame inside the loop; `.assign` avoids writing into a column slice.
  split_frames = [
    pd.read_csv(os.path.join(labels_dir, f"{split}.csv"))
    .rename(columns=column_aliases)[label_columns]
    .assign(split=split)
    for split in splits
  ]
  all_labels_df = pd.concat(split_frames, ignore_index=True)

  return pd.merge(qa_df, all_labels_df, on="participant_id", how="left")

def format_input(df, row, n_context=3):
  """Build the instruction prompt for one question, with recent dialogue as context.

  Args:
    df: DataFrame with ``participant_id``, ``question`` and ``answer`` columns.
      Earlier rows of the same participant must have smaller index values
      than ``row`` (InstructionDataset guarantees this by sorting and
      resetting the index).
    row: The row (a Series taken from ``df``) whose question is to be asked;
      it must also carry ``depression_severity``.
    n_context: Maximum number of preceding question/answer pairs of the same
      participant to include.

  Returns:
    The prompt text, ending in ``"A:"`` for the current question.
  """
  same_participant = df['participant_id'] == row['participant_id']
  earlier = df.index < row.name
  past_pairs = df[same_participant & earlier].tail(n_context)

  context_lines = []
  for _, past_row in past_pairs.iterrows():
    q = str(past_row.get("question", "")).strip()
    a = str(past_row.get("answer", "")).strip()
    context_lines.append(f"Q: {q}\nA: {a}")

  # With no history this is just "[START]\n", same as the original special case.
  context = "[START]\n" + "\n".join(context_lines)

  # The left merge with the label files upcasts scores to float as soon as one
  # participant has no label, which would render "3.0" in the prompt. Render
  # integer-valued scores as integers so the wording is stable.
  score = row['depression_severity']
  if isinstance(score, float) and score.is_integer():
    score = int(score)

  instruction = (
    "### Instruction:\n"
    "You are analyzing a therapeutic interview between a virtual interviewer (Ellie) and a participant.\n"
    "The participant has a PHQ-8 score ranging from 0 (no depression) to 24 (severe depression). "
    f"This participant’s score is {score}. "
    "Scores of 10 or higher are typically considered indicative of depression.\n"
    "Given the participant’s previous responses and their PHQ score, "
    "predict how they might answer the next question in a coherent and realistic way. "
    "Use natural, casual language. Avoid overly formal styles. "
    "Tolerate some irregularities (omissions, repetitions, filler words).\n\n"
  )

  question = str(row.get("question", "")).strip()

  input_text = f"### Input:\n{context}\nQ: {question}\nA:"

  return instruction + input_text

class InstructionDataset(Dataset):
    """Prompt/response pairs tokenized for causal-LM instruction tuning.

    Each sample is the text ``"{prompt}\\n\\n### Response:\\n{answer} [END]"``
    where ``prompt`` comes from ``format_input``. Items are padded to
    ``max_length``; labels are ``-100`` on the prompt tokens and on padding,
    so only the tokens after the prompt contribute to the loss.

    Args:
        dataframe: Rows with ``participant_id``, ``start_time``, ``question``,
            ``answer`` and ``depression_severity`` columns.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: Fixed sequence length after truncation/padding.

    NOTE(review): when a prompt alone reaches ``max_length`` the response is
    truncated away and every label of that item is ``-100`` - confirm that
    this is acceptable for the training loss.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        # Chronological order per participant; the fresh 0..n-1 index is what
        # format_input relies on to find "earlier" rows.
        dataframe = dataframe.sort_values(
            by=['participant_id', 'start_time']
        ).reset_index(drop=True)

        self.df = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.samples = []

        # format_input only looks at rows of the same participant, so hand it
        # that participant's slice instead of rescanning the whole frame for
        # every row. Group slices keep the index labels of the sorted frame.
        for _, participant_df in self.df.groupby('participant_id', sort=False):
            for _, row in participant_df.iterrows():
                prompt = format_input(participant_df, row)
                response = str(row.get("answer", "")).strip()

                full = f"{prompt}\n\n### Response:\n{response} [END]"

                self.samples.append((prompt, full))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``labels`` and ``attention_mask`` tensors for one sample."""
        prompt, full_text = self.samples[idx]

        encoded_full = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Only the prompt's token count is needed, so no padding here.
        encoded_prompt = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        input_ids = encoded_full["input_ids"].squeeze(0)
        attention_mask = encoded_full["attention_mask"].squeeze(0)

        # Count via the attention mask rather than `!= pad_token_id`: the pad
        # token may be the EOS token, which can also occur as a real token.
        prompt_len = int(encoded_prompt["attention_mask"].sum())

        # Number of padding positions before the first real token (0 for
        # right padding), so the prompt mask also holds for left padding.
        leading_pad = int((attention_mask.cumsum(0) == 0).sum())

        labels = input_ids.clone()
        labels[:leading_pad + prompt_len] = -100  # ignore instruction & input tokens
        labels[attention_mask == 0] = -100        # never compute loss on padding

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask
        }

def load_daic_data(tokenizer, data_dir="./daic_data/", should_create_csv=False, return_splits=False):
  """Build instruction dataset(s) from the transcripts and labels under ``data_dir``.

  Transcripts are read from ``<data_dir>/transcripts`` and labels from
  ``<data_dir>/labels``.

  Args:
    tokenizer: Tokenizer handed to every InstructionDataset.
    data_dir: Root directory of the data.
    should_create_csv: When true, also write the labelled question/answer
      table to ``questions_and_answers.csv`` in the working directory.
    return_splits: When true, build separate datasets for the 'train' and
      'dev' rows instead of one dataset over all rows.

  Returns:
    A single InstructionDataset, or ``(train_dataset, val_dataset)`` when
    ``return_splits`` is true.
  """
  labelled_df = add_labels_to_df(
    get_questions_answers_df(os.path.join(data_dir, "transcripts")),
    os.path.join(data_dir, "labels"),
  )

  if should_create_csv:
    labelled_df.to_csv("questions_and_answers.csv", index=False, encoding="utf-8-sig")

  # Simple case first: one dataset over every row, whatever its split.
  if not return_splits:
    return InstructionDataset(labelled_df, tokenizer)

  # The 'dev' split plays the role of the validation set.
  split_column = labelled_df['split']
  train_dataset = InstructionDataset(labelled_df[split_column == 'train'].copy(), tokenizer)
  val_dataset = InstructionDataset(labelled_df[split_column == 'dev'].copy(), tokenizer)

  print(f"Train samples: {len(train_dataset)}")
  print(f"Validation samples: {len(val_dataset)}")

  return train_dataset, val_dataset

In [None]:
def get_tokenizer_and_early_model(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    """Load the tokenizer and base causal LM and register the ``[END]`` token.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        Tuple ``(tokenizer, model, model_name)``. The tokenizer has a pad
        token and the extra ``[END]`` special token; the model's embedding
        matrix is resized to the tokenizer length and its config carries the
        tokenizer's pad token id.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The datasets pad every item to a fixed length, so a pad token must exist.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Add [END] as a special token marking the end of a response.
    tokenizer.add_special_tokens({"additional_special_tokens": ["[END]"]})

    # `load_in_8bit=False` was removed: it only restated the default, and the
    # keyword is deprecated in favour of `quantization_config`.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # Resize token embeddings to accommodate the new [END] token.
    model.resize_token_embeddings(len(tokenizer))

    model.config.pad_token_id = tokenizer.pad_token_id

    return tokenizer, model, model_name

def get_lora_model(model, modules_to_save=("embed_tokens", "lm_head")):
  """Wrap ``model`` with LoRA adapters on the attention query/value projections.

  The tokenizer setup in this notebook adds a new ``[END]`` token and resizes
  the embeddings. With LoRA on ``q_proj``/``v_proj`` only, the new token's
  input-embedding row and output row are never updated, so the model has no
  way to learn to emit ``[END]``. ``modules_to_save`` therefore makes the
  embedding and output layers fully trainable and stores them with the
  adapter.

  Args:
    model: Base causal LM to wrap.
    modules_to_save: Names of extra modules to train and save alongside the
      adapter. Pass ``None`` (or an empty sequence) to train the LoRA
      matrices only. The default names assume a Llama-style architecture.

  Returns:
    The PEFT-wrapped model. Its trainable-parameter summary is printed.
  """
  lora_config = LoraConfig(
    r=8, # rank
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"], # depends on model architecture
    lora_dropout=0.05,
    bias="none",
    modules_to_save=list(modules_to_save) if modules_to_save else None,
    task_type=TaskType.CAUSAL_LM
  )

  model = get_peft_model(model, lora_config)
  model.print_trainable_parameters()
  return model

def fine_tune_model(
    model,
    tokenizer,
    train_dataset,
    output_dir="./tiny_llama_instruction_tuned",
    eval_dataset=None,
):
    """Fine-tune ``model`` and save the result together with the tokenizer.

    Args:
        model: Model to fine-tune.
        tokenizer: Tokenizer instance; saved next to the model.
        train_dataset: Training dataset whose items are dicts of
            ``input_ids``, ``labels`` and ``attention_mask`` tensors.
        output_dir: Directory for checkpoints and the final saved model.
        eval_dataset: Optional validation dataset. When given (and
            non-empty), evaluation runs every 200 steps and the checkpoint
            with the lowest ``eval_loss`` is reloaded at the end.
    """
    use_eval = bool(eval_dataset)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,  # 10% of training steps for warmup (scales with dataset size)
        num_train_epochs=1,  # One pass over the training set
        learning_rate=1e-4,
        lr_scheduler_type="cosine",  # Cosine decay for smooth learning rate reduction
        # fp16 mixed precision needs a CUDA device; fall back to full
        # precision elsewhere instead of failing at argument validation.
        fp16=torch.cuda.is_available(),
        logging_steps=50,
        save_steps=200,
        eval_strategy="steps" if use_eval else "no",
        eval_steps=200,  # Same cadence as save_steps, required to reload the best checkpoint
        save_total_limit=3,  # Keep only a few checkpoints to save space
        load_best_model_at_end=use_eval,
        metric_for_best_model="eval_loss" if use_eval else None,
        greater_is_better=False,  # Lower loss is better
    )

    def collator(batch):
        # Every item is already padded to the same length, so stacking suffices.
        return {
            "input_ids": torch.stack([x["input_ids"] for x in batch]),
            "labels": torch.stack([x["labels"] for x in batch]),
            "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
    )

    trainer.train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

def use_tokenizer(tokenizer, text):
  """Tokenize ``text``, truncated and padded to a fixed length of 512 tokens."""
  encode_options = {"truncation": True, "padding": 'max_length', "max_length": 512}
  return tokenizer(text, **encode_options)

# Stopping criteria for [END] token
class EndTokenStoppingCriteria(StoppingCriteria):
    """Stop generation for each sequence whose latest token is the end token.

    Args:
        end_token_id: Vocabulary id of the token that terminates a response.
    """

    def __init__(self, end_token_id):
        self.end_token_id = end_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the last token of every row, not only row 0, and return one
        # flag per sequence so batched generation is handled correctly.
        # assumes input_ids is (batch, seq_len), as produced by `generate`.
        return input_ids[:, -1] == self.end_token_id

def create_stopping_criteria(tokenizer):
    """Build a stopping-criteria list that halts generation at the [END] token."""
    stop_on_end = EndTokenStoppingCriteria(tokenizer.convert_tokens_to_ids("[END]"))
    return StoppingCriteriaList([stop_on_end])

def clear_memory():
    """Run the garbage collector and, if CUDA is present, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Release cached blocks first, then wait for outstanding GPU work.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def load_finetuned_model(model_name, tokenizer, checkpoint_path=None):
    """Load a fresh base model and attach a saved adapter to it.

    The base model is loaded anew on every call rather than reused, because
    attaching an adapter modifies the base model in place.

    Args:
        model_name: Identifier of the base model.
        tokenizer: Tokenizer whose length decides the embedding size.
        checkpoint_path: Adapter directory; when empty or None the final
            model in "./tiny_llama_instruction_tuned" is used.

    Returns:
        Tuple ``(lora, base)`` so the caller can release both.
    """
    adapter_dir = checkpoint_path if checkpoint_path else "./tiny_llama_instruction_tuned"

    base = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # The tokenizer may hold extra tokens (e.g. [END]) the base model lacks.
    embedding_rows = base.get_input_embeddings().weight.shape[0]
    if embedding_rows != len(tokenizer):
        base.resize_token_embeddings(len(tokenizer))

    return PeftModel.from_pretrained(base, adapter_dir), base

def generate_response(model, tokenizer, prompt, max_new_tokens=100, stopping_criteria=None):
    """Run one text-generation call and return the generated text.

    A pipeline is created for this call only; it is dropped and memory is
    cleared afterwards, whether or not generation succeeded. When no
    stopping criteria are given, the [END]-token criteria are used.
    """
    generator = None
    try:
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

        criteria = stopping_criteria
        if criteria is None:
            criteria = create_stopping_criteria(tokenizer)

        outputs = generator(prompt, max_new_tokens=max_new_tokens, stopping_criteria=criteria)
        return outputs[0]["generated_text"]
    finally:
        # Drop the pipeline reference before clearing caches.
        generator = None
        clear_memory()

def extract_response_only(full_output, prompt):
    """Return the text generated after ``prompt``.

    If ``full_output`` begins with ``prompt``, the remainder is returned with
    surrounding whitespace stripped; otherwise ``full_output`` is returned
    unchanged.
    """
    if not full_output.startswith(prompt):
        return full_output
    generated = full_output[len(prompt):]
    return generated.strip()

def unload_model(model):
    """Move ``model`` to the CPU (when it supports that) and clear memory caches.

    Does nothing for ``None``. Only this function's own reference is deleted;
    references held by the caller are unaffected.
    """
    if model is None:
        return
    can_move = hasattr(model, 'cpu')
    if can_move:
        model.cpu()
    del model
    clear_memory()

In [None]:
# End-to-end training run: base model + tokenizer -> datasets -> LoRA wrapping -> fine-tuning.
print('Loading tokenizer and model...')
tokenizer, model, model_name = get_tokenizer_and_early_model()

# With return_splits=True the 'train' rows form the training set and the 'dev' rows the validation set.
print('Loading datasets (train and validation)...')
train_dataset, val_dataset = load_daic_data(tokenizer, should_create_csv=False, return_splits=True)

# Rebinds `model` to the PEFT-wrapped model returned by get_lora_model.
print('Getting LoRA model...')
model = get_lora_model(model)

# Checkpoints, the final model and the tokenizer are written to fine_tune_model's default output_dir.
print('Fine-tuning model with validation monitoring...')
fine_tune_model(model, tokenizer, train_dataset, eval_dataset=val_dataset)

Loading tokenizer and model...
Loading dataset...
Getting LoRA model...
trainable params: 1,126,400 || all params: 1,101,174,784 || trainable%: 0.1023
Fine-tuning model...


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
10,3.2168
20,2.5304
30,1.9921
40,2.0987
50,1.8073
60,1.9004
70,1.8776
80,1.6385
90,1.7208
100,1.8204


### Analyzing models

In [None]:
# Evaluation prompt. It follows the exact template used for training
# (format_input plus the "### Response:" suffix added by InstructionDataset):
# "### Instruction:" header, identical instruction wording and spacing, three
# context pairs, and the open "A:" slot before "### Response:". A prompt that
# deviates from the training template puts the model off-distribution.
request = """\
### Instruction:
You are analyzing a therapeutic interview between a virtual interviewer (Ellie) and a participant.
The participant has a PHQ-8 score ranging from 0 (no depression) to 24 (severe depression). This participant’s score is 3. Scores of 10 or higher are typically considered indicative of depression.
Given the participant’s previous responses and their PHQ score, predict how they might answer the next question in a coherent and realistic way. Use natural, casual language. Avoid overly formal styles. Tolerate some irregularities (omissions, repetitions, filler words).

### Input:
[START]
Q: right there are always trade offs in life aren't there
A: yeah
Q: what made you decide to do that
A: so um i think i think my in my life i knew that i there's a lot of things i have <ha> there's more dislikes <laughter> than likes so i kinda narrowed it down to what am i good at and what am i not good at and what am i gonna work well or who who am i gonna work well with and who will i not work well with so i kind of i kinda sorted out and then the list kind of mmm kind of answered itself so
Q: that sounds really hard
A: no it i don't think it was hard but it was just but i think it was a real reality check and i think it it's kind of a good thing 'cause sometimes trying to conform to doing things that doesn't really fit you doesn't make sense it's like trying to shove a a round peg into a square a square hole and it's like it just no matter how you try to shove it in it's not gonna go in so sometimes it's just might as well go down a path that seems to work better for you
Q: right that makes sense what's one of your most memorable experiences
A:

### Response:
"""

### Final model

In [None]:
# Load the final saved adapter (no checkpoint path given) on a fresh base model.
# This rebinds the notebook-level name `model`.
print("Loading final model...")
model, base_model = load_finetuned_model(model_name, tokenizer)

try:
    print("\nGenerating response...")
    # full_output is the pipeline's generated text; response_only strips the prompt prefix when present.
    full_output = generate_response(model, tokenizer, request, max_new_tokens=100)
    response_only = extract_response_only(full_output, request)
    
    print("=" * 80)
    print("FULL OUTPUT:")
    print("=" * 80)
    print(full_output)
    print("\n" + "=" * 80)
    print("RESPONSE ONLY:")
    print("=" * 80)
    print(response_only)
finally:
    # Clean up models from memory
    # NOTE(review): unload_model only deletes its own local reference; the
    # notebook-level names `model` and `base_model` still point at the objects
    # afterwards - confirm whether they should be deleted here as well.
    unload_model(model)
    unload_model(base_model)
    print("\nModels unloaded from memory.")


Loading best model...


Device set to use cuda:0



You are analyzing a therapeutic interview between a virtual interviewer (Ellie) and a participant.
The participant has a PHQ-8 score ranging from 0 (no depression) to 24 (severe depression). This participant’s score is 3. Scores of 10 or higher are typically considered indicative of depression.
Given the participant’s previous responses and their PHQ score, predict how they might answer the next question in a coherent and realistic way.Use natural, casual language. Avoid overly formal styles.Tolerate some irregularities (omissions, repetitions, filler words) given the conversational context.

### Input:
[START]
Q: right there are always trade offs in life aren't there
A: yeah
Q: what made you decide to do that
A: so um i think i think my in my life i knew that i there's a lot of things i have <ha> there's more dislikes <laughter> than likes so i kinda narrowed it down to what am i good at and what am i not good at and what am i gonna work well or who who am i gonna work well with and 

In [None]:
def test_all_checkpoints(model_name, tokenizer, prompt, output_dir="./tiny_llama_instruction_tuned"):
    """Test all checkpoint models and return results. Memory-efficient version.
    Each model is loaded, tested, and immediately unloaded to save memory.
    
    Args:
        model_name: Base model name
        tokenizer: Tokenizer instance
        prompt: Prompt to test with
        output_dir: Directory containing checkpoints
    """
    checkpoint_folders = sorted([
        f for f in os.listdir(output_dir) 
        if f.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, f))
    ])
    
    if not checkpoint_folders:
        print("No checkpoints found.")
        return {}
    
    results = {}
    stopping_criteria = create_stopping_criteria(tokenizer)
    
    print(f"Found {len(checkpoint_folders)} checkpoints. Testing each (memory-efficient mode)...\n")
    
    for i, folder in enumerate(checkpoint_folders, 1):
        checkpoint_path = os.path.join(output_dir, folder)
        print(f"[{i}/{len(checkpoint_folders)}] Testing {folder}...")
        
        model = None
        base_model = None
        try:
            # Load model (returns both lora and base for cleanup)
            model, base_model = load_finetuned_model(model_name, tokenizer, checkpoint_path)
            
            # Generate response
            full_output = generate_response(model, tokenizer, prompt, max_new_tokens=100, stopping_criteria=stopping_criteria)
            response_only = extract_response_only(full_output, prompt)
            
            results[folder] = {
                "full_output": full_output,
                "response_only": response_only
            }
            
            print(f"✓ {folder} completed")
            
        except Exception as e:
            print(f"✗ Error testing {folder}: {e}")
            results[folder] = {"error": str(e)}
        finally:
            # Always unload models after each checkpoint to free memory
            if model is not None:
                unload_model(model)
            if base_model is not None:
                unload_model(base_model)
            print(f"  Memory freed after {folder}\n")
    
    return results

# Test all checkpoints
# Uses the default output_dir of test_all_checkpoints and the shared `request` prompt.
checkpoint_results = test_all_checkpoints(model_name, tokenizer, request)

# Display results
print("\n" + "=" * 80)
print("CHECKPOINT COMPARISON")
print("=" * 80)
for checkpoint_name, result in checkpoint_results.items():
    # A failed checkpoint is stored as {"error": message} instead of outputs.
    if "error" in result:
        print(f"\n{checkpoint_name}: ERROR - {result['error']}")
    else:
        print(f"\n{checkpoint_name}:")
        print("-" * 80)
        print(result["response_only"])
        print()

