In [None]:
import torch
import glob, os
import pandas as pd
import gc
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from transformers import StoppingCriteria, StoppingCriteriaList

# Use expandable segments in the CUDA caching allocator to reduce fragmentation.
# NOTE: PyTorch reads this variable when the allocator initializes, so it is set
# here, before any CUDA allocation happens in this notebook.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Start from an empty CUDA cache and report which GPU (if any) is in use.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"CUDA available. GPU: {gpu_name}")
    print(f"Total GPU memory: {gpu_total_gb:.2f} GB")


# Dataset Augmentation

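This notebook loads a TinyLlama model finetuned with a PEFT adapter and uses it to generate an augmented (synthetic) participant answer for every interviewer question in the DAIC interview transcripts. Each prompt includes the participant's PHQ-8 score and up to three of their preceding question–answer pairs. The original and augmented answers are saved side by side to a CSV file.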


In [None]:
# Utility functions from daic_finetune.ipynb

def get_questions_answers_df(transcripts_dir):
  """Build a table of (interviewer question, participant answer) pairs.

  Reads every ``*.csv`` transcript in ``transcripts_dir``. Each file is
  tab-separated, with ``speaker``, ``start_time``, ``stop_time`` and ``value``
  columns. Consecutive utterances by the same speaker are merged into one turn.
  Every ``Participant`` turn that directly follows an ``Ellie`` turn in the same
  file is kept.

  Args:
    transcripts_dir: Directory containing the transcript CSV files. The
      participant id is parsed from the file-name prefix before the first "_".

  Returns:
    DataFrame with these columns:
      ``participant_id`` (int).
      ``question``: Ellie's merged turn.
      ``answer``: the participant's merged reply.
      ``start_time``: start of the answer turn.

  Raises:
    ValueError: If ``transcripts_dir`` contains no ``*.csv`` files.
  """
  # Sorted so the concatenation order (and therefore block ids) is deterministic.
  transcripts_files = sorted(glob.glob(os.path.join(transcripts_dir, "*.csv")))
  if not transcripts_files:
    raise ValueError(f"No transcript CSV files found in {transcripts_dir!r}")

  # Load and concatenate all transcript files, remembering the originating file
  df = pd.concat(
    (
      pd.read_csv(file, sep="\t", encoding="utf-8-sig").assign(source=os.path.basename(file))
      for file in transcripts_files
    ),
    ignore_index=True
  )

  # A new block starts whenever the speaker changes. Grouping by source below
  # keeps blocks from different files apart even when the speaker does not
  # change across a file boundary.
  df['block_id'] = (df['speaker'] != df['speaker'].shift(1)).cumsum()

  # Merge contiguous segments by the same speaker into one turn
  df = df.groupby(['source', 'block_id']).agg(
    speaker=('speaker', 'first'),
    start_time=('start_time', 'min'),
    stop_time=('stop_time', 'max'),
    # Missing utterances are skipped so they do not end up as the literal text "nan".
    value=('value', lambda x: ' '.join(x.dropna().astype(str)))
  )

  # Sort by participant file and time
  df = df.sort_values(by=['source', 'start_time']).reset_index()

  # Per-source shift: the first turn of each file has no previous turn (NaN)
  df['prev_speaker'] = df.groupby('source')['speaker'].shift(1)
  df['prev_value'] = df.groupby('source')['value'].shift(1)

  is_answer = (
    (df['speaker'] == 'Participant') &
    (df['prev_speaker'] == 'Ellie') &
    (df['source'] == df['source'].shift(1))
  )

  df = df[is_answer].copy()
  df = df.rename(columns={
    'prev_value': 'question',  # The previous Ellie turn is the question
    'value': 'answer',         # The current Participant turn is the answer
  })

  df['participant_id'] = df['source'].str.split("_").str[0].astype(int)
  return df[['participant_id', 'question', 'answer', 'start_time']]

def add_labels_to_df(qa_df, labels_dir, splits=('train', 'dev', 'test')):
  """Attach PHQ depression labels and the split name to each QA row.

  Args:
    qa_df: Frame with a ``participant_id`` column (e.g. the output of
      ``get_questions_answers_df``).
    labels_dir: Directory holding one ``<split>.csv`` per split. Each file has
      a ``Participant_ID`` column and either the
      ``PHQ8_Binary``/``PHQ8_Score`` pair or the ``PHQ_Binary``/``PHQ_Score``
      pair of columns.
    splits: Split names to load. Defaults to train/dev/test.

  Returns:
    ``qa_df`` left-joined with ``depression_label``, ``depression_severity``
    and ``split``. Rows whose participant has no label row get NaN in those
    columns.

  Raises:
    pandas.errors.MergeError: If a participant id appears in more than one
      label row. Without this check the QA rows would be silently duplicated.
  """
  # The label files name the PHQ columns differently depending on the split.
  column_map = {
    "Participant_ID": "participant_id",
    "PHQ8_Binary": "depression_label",
    "PHQ8_Score": "depression_severity",
    "PHQ_Binary": "depression_label",
    "PHQ_Score": "depression_severity",
  }

  # Collect one frame per split and concatenate once (no concat-in-a-loop).
  label_frames = []
  for split in splits:
    split_labels_df = (
      pd.read_csv(os.path.join(labels_dir, f"{split}.csv"))
      .rename(columns=column_map)
      .loc[:, ["participant_id", "depression_label", "depression_severity"]]
      .assign(split=split)
    )
    label_frames.append(split_labels_df)
  all_labels_df = pd.concat(label_frames, ignore_index=True)

  return pd.merge(qa_df, all_labels_df, on="participant_id", how="left", validate="many_to_one")

def format_input(df, row, n_context=3):
  """Build the instruction-style generation prompt for one QA row.

  The prompt has two parts:
    - An instruction block that states the participant's PHQ-8 score.
    - An input block holding up to ``n_context`` earlier question/answer pairs
      of the same participant, then the current question. It ends at ``A:`` so
      the model continues with the answer.

  Args:
    df: Frame of QA rows. "Earlier" means a smaller index label, so the index
      is expected to follow interview order.
    row: The row to build the prompt for (e.g. from ``df.iterrows()`` or
      ``df.loc``). Its ``name`` is its index label in ``df``.
    n_context: Maximum number of preceding QA pairs to include.

  Returns:
    The prompt string.
  """
  def _clean(record, key):
    # Absent keys become "" and surrounding whitespace is dropped.
    return str(record.get(key, "")).strip()

  same_participant = df['participant_id'] == row['participant_id']
  earlier = df.index < row.name
  history = df[same_participant & earlier].tail(n_context)

  turns = [
    f"Q: {_clean(prev, 'question')}\nA: {_clean(prev, 'answer')}"
    for _, prev in history.iterrows()
  ]

  # The exact layout is kept, including the trailing newline after "[START]"
  # when there is no history. That newline produces a blank line before the
  # question.
  # NOTE(review): presumably this must match the prompt format used during
  # finetuning in daic_finetune.ipynb. Confirm before changing it.
  if turns:
    context = "[START]\n" + "\n".join(turns)
  else:
    context = "[START]\n"

  score = row['depression_severity']
  instruction = (
    "### Instruction:\n"
    "You are analyzing a therapeutic interview between a virtual interviewer (Ellie) and a participant.\n"
    "The participant has a PHQ-8 score ranging from 0 (no depression) to 24 (severe depression). "
    f"This participant's score is {score}. "
    "Scores of 10 or higher are typically considered indicative of depression.\n"
    "Given the participant's previous responses and their PHQ score, "
    "predict how they might answer the next question in a coherent and realistic way. "
    "Use natural, casual language. Avoid overly formal styles. "
    "Tolerate some irregularities (omissions, repetitions, filler words).\n\n"
  )

  current_question = _clean(row, "question")
  return instruction + f"### Input:\n{context}\nQ: {current_question}\nA:"

# Stopping criteria for [END] token
class EndTokenStoppingCriteria(StoppingCriteria):
    """Stop generation once a sequence's most recent token is the end-of-answer token.

    Args:
        end_token_id: Vocabulary id of the end-of-answer token (``[END]``).
    """

    def __init__(self, end_token_id):
        self.end_token_id = end_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per sequence, shape (batch_size,), as the transformers
        # StoppingCriteria contract describes. The previous version looked only
        # at the first sequence, so in a batch every sequence was stopped when
        # sequence 0 emitted [END].
        return input_ids[:, -1] == self.end_token_id

def create_stopping_criteria(tokenizer, end_token="[END]"):
    """Create stopping criteria that halt generation at ``end_token``.

    Args:
        tokenizer: Tokenizer whose vocabulary must already contain ``end_token``.
        end_token: Token string that marks the end of an answer (default ``"[END]"``).

    Returns:
        A StoppingCriteriaList holding a single EndTokenStoppingCriteria.

    Raises:
        ValueError: If ``end_token`` is not in the tokenizer's vocabulary.
            Without this check, ``convert_tokens_to_ids`` would not return a
            real id for it (typically the unknown-token id). Generation would
            then silently stop on the wrong token, or run to the length limit.
    """
    if end_token not in tokenizer.get_vocab():
        raise ValueError(
            f"Tokenizer has no {end_token!r} token; add it before creating stopping criteria."
        )
    end_token_id = tokenizer.convert_tokens_to_ids(end_token)
    return StoppingCriteriaList([EndTokenStoppingCriteria(end_token_id)])

def clear_memory():
    """Run Python garbage collection and release cached CUDA memory.

    The CUDA steps (empty the allocator cache, collect IPC memory, synchronize)
    are skipped when no GPU is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()

def reset_cuda_memory():
    """Collect garbage, release cached CUDA memory and reset peak-memory stats.

    Performs the same steps as ``clear_memory`` and additionally resets CUDA
    peak memory statistics. All CUDA steps are skipped when no GPU is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    for step in (
        torch.cuda.empty_cache,
        torch.cuda.ipc_collect,
        torch.cuda.synchronize,
        torch.cuda.reset_peak_memory_stats,
    ):
        step()

def load_finetuned_model(model_name, tokenizer, checkpoint_path=None,
                         adapter_dir="./tiny_llama_instruction_tuned"):
    """Load the base model and apply a finetuned PEFT adapter on top of it.

    Args:
        model_name: Hugging Face id or local path of the base causal LM.
        tokenizer: Tokenizer used for finetuning (must have the [END] token).
            The base model's input embeddings are resized to ``len(tokenizer)``
            when the sizes differ.
        checkpoint_path: Adapter checkpoint to load. If falsy, ``adapter_dir``
            is used instead.
        adapter_dir: Location of the final adapter, used when no checkpoint is
            given. Defaults to the directory previously hard-coded here.

    Returns:
        The PeftModel in eval mode. It is on "cuda:0" in float16 when CUDA is
        available, otherwise on CPU in float32.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"

    # Load without device_map and move to the device explicitly, to avoid
    # offloading issues.
    print("Loading base model...")
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map=None,
    )
    base = base.to(device)

    # The tokenizer may contain tokens added for finetuning (e.g. [END]). The
    # embedding matrix must match its size before the adapter is attached.
    vocab_size = len(tokenizer)
    model_vocab_size = base.get_input_embeddings().weight.shape[0]
    if vocab_size != model_vocab_size:
        print(f"Resizing token embeddings from {model_vocab_size} to {vocab_size}...")
        base.resize_token_embeddings(vocab_size)

    print("Loading PEFT adapter...")
    adapter_path = checkpoint_path if checkpoint_path else adapter_dir
    lora = PeftModel.from_pretrained(
        base,
        adapter_path,
        device_map=None,  # Don't use device_map (see above)
    )

    # Ensure the whole model is on the target device and in eval mode
    lora = lora.to(device)
    lora.eval()

    return lora

def generate_response(model, tokenizer, prompt, max_new_tokens=150, stopping_criteria=None):
    """Run one text-generation pass and return the pipeline's ``generated_text``.

    A new ``text-generation`` pipeline is built for each call. ``clear_memory()``
    runs afterwards in a ``finally`` block, even if generation raises.

    Args:
        model: Causal LM to generate with.
        tokenizer: Matching tokenizer.
        prompt: Prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.
        stopping_criteria: Optional StoppingCriteriaList. When None,
            ``create_stopping_criteria(tokenizer)`` is used.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    generator = None
    try:
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        criteria = stopping_criteria if stopping_criteria is not None else create_stopping_criteria(tokenizer)
        outputs = generator(prompt, max_new_tokens=max_new_tokens, stopping_criteria=criteria)
        first_result = outputs[0]
        return first_result["generated_text"]
    finally:
        # Drop the pipeline reference before clearing memory
        del generator
        clear_memory()

def extract_response_only(full_output, prompt):
    """Return the text generated after ``prompt``.

    If ``full_output`` begins with ``prompt``, the prefix is removed and the
    remainder is stripped of surrounding whitespace. Otherwise ``full_output``
    is returned unchanged.
    """
    if not full_output.startswith(prompt):
        return full_output
    continuation = full_output[len(prompt):]
    return continuation.strip()

def clean_augmented_answer(answer):
    """Normalize a generated answer.

    Removes every literal "[END]" marker and strips surrounding whitespace.
    Falsy input (None or "") yields "".
    """
    if not answer:
        return ""
    return answer.replace("[END]", "").strip()


## Load Finetuned Model


In [None]:
# Configuration
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
checkpoint_path = None  # Set to a specific checkpoint path, or None to use the final adapter in model_dir
model_dir = "./tiny_llama_instruction_tuned"  # Directory where the finetuned adapter and tokenizer are saved
data_dir = "./daic_data/"
output_csv_path = "./augmented_dataset.csv"

# Clear memory before loading
print("Clearing GPU memory...")
reset_cuda_memory()

if torch.cuda.is_available():
    # mem_get_info reports device-wide free memory, so memory held by other
    # processes is accounted for (total - reserved only covers this process).
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    total = total_bytes / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    free = free_bytes / 1024**3
    print(f"GPU memory - Total: {total:.2f} GB, Reserved: {reserved:.2f} GB, Free: {free:.2f} GB")

print("\nLoading tokenizer...")
# Prefer the tokenizer saved with the finetuned model (it already has [END])
if os.path.exists(os.path.join(model_dir, "tokenizer_config.json")):
    print(f"Loading tokenizer from {model_dir}...")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
else:
    print(f"Loading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Add [END] as a special token if not already present
    if "[END]" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({"additional_special_tokens": ["[END]"]})

# Ensure a pad token whichever source the tokenizer came from. This does not
# change the vocabulary, and is a no-op when one is already set.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("\nLoading finetuned model...")
# Fall back to model_dir (instead of a path hard-coded in the loader) so the
# adapter always comes from the configured directory.
model = load_finetuned_model(model_name, tokenizer, checkpoint_path or model_dir)

# Check memory after loading
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"\nModel loaded. Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
print("Model loaded successfully!")


## Load Original Dataset


In [None]:
print("Loading original dataset...")
transcripts_dir = os.path.join(data_dir, "transcripts")
labels_dir = os.path.join(data_dir, "labels")

# Load questions and answers, then attach PHQ labels and split names
qa_df = get_questions_answers_df(transcripts_dir)
qa_df = add_labels_to_df(qa_df, labels_dir)

# Sort by participant and time to maintain interview order. format_input picks
# earlier context by index label, hence the fresh RangeIndex.
qa_df = qa_df.sort_values(by=['participant_id', 'start_time']).reset_index(drop=True)

print(f"Total samples: {len(qa_df)}")
print(f"Participants: {qa_df['participant_id'].nunique()}")
print(f"\nSplit distribution:")
# dropna=False so rows of participants missing from every label file are counted too
print(qa_df['split'].value_counts(dropna=False))

# Unlabeled rows still get augmented, but their prompt states a score of "nan"
unlabeled = qa_df['depression_severity'].isna()
if unlabeled.any():
    n_unlabeled_participants = qa_df.loc[unlabeled, 'participant_id'].nunique()
    print(f"\nWARNING: {unlabeled.sum()} samples from {n_unlabeled_participants} participants "
          f"have no PHQ label; their prompts will state a score of nan.")

print(f"\nFirst few rows:")
print(qa_df.head())


## Augment Dataset


In [None]:
# Column order of the augmented output frame
_AUGMENTED_COLUMNS = [
    'participant_id', 'question', 'original_answer', 'augmented_answer',
    'depression_severity', 'depression_label', 'split', 'start_time',
]


def _augmented_record(row, augmented_answer):
    """Build one output record for ``row`` with the given augmented answer."""
    return {
        'participant_id': row['participant_id'],
        'question': row['question'],
        'original_answer': row['answer'],
        'augmented_answer': augmented_answer,
        'depression_severity': row['depression_severity'],
        'depression_label': row.get('depression_label', None),
        'split': row.get('split', None),
        'start_time': row['start_time'],
    }


def augment_dataset(df, model, tokenizer, progress_interval=50):
    """
    Augment dataset by generating a new answer for each question.

    Rows are processed in ``df`` order. For each row:
      - A prompt is built with up to 3 earlier QA pairs of the same participant.
      - A response of at most 150 new tokens is generated.
      - The cleaned continuation becomes ``augmented_answer``.
    If any step fails for a row, the error is printed and the row is kept with
    an empty ``augmented_answer``.

    Args:
        df: DataFrame with ``participant_id``, ``question``, ``answer``,
            ``depression_severity`` and ``start_time`` columns.
            ``depression_label`` and ``split`` are optional. The index should
            follow interview order, because earlier context is selected by
            index label.
        model: Finetuned model for generation.
        tokenizer: Tokenizer instance (must contain the [END] token).
        progress_interval: Print progress every N processed rows. 0 or None
            disables progress output.

    Returns:
        DataFrame with one record per input row holding the original and
        augmented answers plus label metadata.
    """
    stopping_criteria = create_stopping_criteria(tokenizer)

    # format_input only looks at earlier rows of the same participant. Passing
    # just that participant's rows (original index labels preserved) yields the
    # same prompt without scanning the whole frame for every row.
    rows_by_participant = {pid: group for pid, group in df.groupby('participant_id', sort=False)}

    augmented_data = []
    total_samples = len(df)
    successful = 0
    failed = 0

    print(f"Starting augmentation for {total_samples} samples...\n")

    # Positions (1-based) are used for reporting so output does not depend on the index labels
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        try:
            context_df = rows_by_participant.get(row['participant_id'], df)
            prompt = format_input(context_df, row, n_context=3)

            full_output = generate_response(
                model,
                tokenizer,
                prompt,
                max_new_tokens=150,
                stopping_criteria=stopping_criteria,
            )

            # Keep only the generated continuation, without the [END] marker
            augmented_answer = clean_augmented_answer(extract_response_only(full_output, prompt))
            successful += 1
        except Exception as e:
            # Best-effort: log and keep the row with an empty augmented answer
            print(f"Error processing sample {position} (participant {row['participant_id']}): {e}")
            augmented_answer = ''
            failed += 1

        augmented_data.append(_augmented_record(row, augmented_answer))

        # Reported after failures too, so the progress cadence never skips
        if progress_interval and position % progress_interval == 0:
            print(f"Progress: {position}/{total_samples} ({100 * position / total_samples:.1f}%) - "
                  f"Successful: {successful}, Failed: {failed}")

    print("\nAugmentation complete!")
    print(f"Total: {total_samples}, Successful: {successful}, Failed: {failed}")

    # Explicit columns keep the schema even when df is empty
    return pd.DataFrame(augmented_data, columns=_AUGMENTED_COLUMNS)

# Run augmentation
augmented_df = augment_dataset(qa_df, model, tokenizer, progress_interval=50)


## Save Augmented Dataset


In [None]:
# Report how many rows received a non-empty generated answer
answers = augmented_df['augmented_answer']
print("Augmented Dataset Statistics:")
print(f"Total samples: {len(augmented_df)}")
print(f"Non-empty augmented answers: {(answers != '').sum()}")
print(f"Empty augmented answers: {(answers == '').sum()}")

# Show the first three rows that did get an augmented answer, truncated to 100 characters
print("\nSample augmented responses:")
examples = augmented_df.loc[answers != ''].head(3)
for label, example in examples.iterrows():
    print(f"\n--- Sample {label + 1} ---")
    for caption, column in (("Question", 'question'),
                            ("Original", 'original_answer'),
                            ("Augmented", 'augmented_answer')):
        print(f"{caption}: {example[column][:100]}...")

# Persist with the same utf-8-sig encoding used when reading the transcripts
print(f"\nSaving augmented dataset to {output_csv_path}...")
augmented_df.to_csv(output_csv_path, index=False, encoding="utf-8-sig")
print("Dataset saved successfully!")


## Optional: Preview Augmented Dataset


In [None]:
# Print the first ten rows of the augmented dataset
print("Augmented Dataset Preview:")
preview_df = augmented_df.head(10)
print(preview_df)
