# ReDSM5 LLM Fine-tuning on Google Colab

This notebook enables training decoder-only LLMs (Llama/Qwen) with **TPU/GPU** acceleration on Google Colab.

## Features
- ✅ Automatic TPU/GPU/CPU detection
- ✅ LoRA/QLoRA support for efficient fine-tuning
- ✅ Multi-label DSM-5 symptom classification
- ✅ Sliding window for long documents
- ✅ Threshold optimization and model export
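
The sliding-window and pooling features can be pictured with a minimal sketch. It rests on two assumptions: that the stride is the step between consecutive window starts, and that `logit_sum` means summing window logits before applying the sigmoid. The repository's training code may define both differently, and the helper names `sliding_windows` and `pool_window_logits` are hypothetical.

```python
import math

def sliding_windows(token_ids, max_length, stride):
    """Split token_ids into windows of up to max_length tokens.

    Assumption: `stride` is the step between consecutive window starts,
    so neighbouring windows overlap by max_length - stride tokens.
    """
    if max_length <= 0 or stride <= 0:
        raise ValueError("max_length and stride must be positive")
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + max_length])
        if start + max_length >= len(token_ids):
            break  # this window already reaches the end of the document
        start += stride
    return windows

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def pool_window_logits(window_logits, pooler="mean"):
    """Combine per-window logits (one list per window) into document-level probabilities."""
    per_label = list(zip(*window_logits))  # transpose: one tuple of logits per label
    if pooler == "max":
        return [max(sigmoid(z) for z in logits) for logits in per_label]
    if pooler == "mean":
        return [sum(sigmoid(z) for z in logits) / len(logits) for logits in per_label]
    if pooler == "logit_sum":
        return [sigmoid(sum(logits)) for logits in per_label]
    raise ValueError(f"unknown pooler: {pooler}")

# Toy example: a 10-token document, windows of 4 tokens, step of 3.
windows = sliding_windows(list(range(10)), max_length=4, stride=3)
print(windows)

# Hypothetical logits for 3 windows and 2 labels.
toy_logits = [[2.0, -3.0], [0.0, -2.0], [-1.0, -4.0]]
print(pool_window_logits(toy_logits, pooler="max"))
```

With `max` pooling a single confident window is enough to flag a symptom, while `mean` pooling dilutes it across the document, which is one reason the pooler is worth treating as a tunable setting.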

## Before Starting
1. Go to **Runtime > Change runtime type**
2. Select **T4 GPU** or **TPU v2** as the hardware accelerator
3. Click **Save**

---

In [None]:
# Cell 1: Setup and Installation
print("📦 Installing dependencies...")
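# Quote version specifiers: an unquoted `>=` is treated by the shell as output redirection
!pip install -q "transformers>=4.36.0" "datasets>=2.16.0" "accelerate>=0.25.0"
!pip install -q "peft>=0.7.0" "bitsandbytes>=0.41.0" scipy scikit-learn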
!pip install -q wandb optuna pyyaml pandas matplotlib seaborn

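# For TPU support (optional). `!pip` never raises a Python exception, so check
# pip's exit code instead of wrapping the install in try/except.
import subprocess, sys
tpu_install = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-q', 'cloud-tpu-client', 'torch-xla']
)
if tpu_install.returncode == 0:
    print("✅ TPU libraries installed")
else:
    print("⚠️  TPU libraries not available (using GPU/CPU)")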

# Clone repository
import os
if not os.path.exists('LLM_Agents_ReDSM5'):
    !git clone https://github.com/OscarTsao/LLM_Agents_ReDSM5.git
    print("✅ Repository cloned")
else:
    print("✅ Repository already exists")

%cd LLM_Agents_ReDSM5
print("\n✅ Installation complete!")

In [None]:
# Cell 2: Hardware Detection
import torch
import sys

print("🔍 Detecting hardware...\n")

# Try TPU detection
USE_TPU = False
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    USE_TPU = True
    print(f"✅ TPU detected: {device}")
    print(f"   TPU cores: {xm.xrt_world_size()}")
except (ImportError, RuntimeError):  # torch_xla missing, or installed without a TPU attached
    USE_TPU = False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    if torch.cuda.is_available():
        print(f"✅ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        print(f"   Compute capability: {torch.cuda.get_device_capability(0)}")
    else:
        print("⚠️  CPU only - training will be slow")
        print("   Recommendation: Use Runtime > Change runtime type to enable GPU/TPU")

print(f"\n📊 Environment Info:")
print(f"   Device: {device}")
print(f"   Python: {sys.version.split()[0]}")
print(f"   PyTorch: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")
print(f"   TPU mode: {USE_TPU}")

In [None]:
# Cell 3: Mount Google Drive (Optional - for data/checkpoints)
from google.colab import drive
import os

MOUNT_DRIVE = False  # Set to True to use Google Drive

if MOUNT_DRIVE:
    drive.mount('/content/drive')
    DATA_DIR = '/content/drive/MyDrive/redsm5_data'
    OUTPUT_BASE = '/content/drive/MyDrive/redsm5_outputs'
    print(f"✅ Drive mounted")
    print(f"   Data directory: {DATA_DIR}")
    print(f"   Output directory: {OUTPUT_BASE}")
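else:
    # Local defaults so later cells can reference these names (assumed paths; adjust as needed)
    DATA_DIR = '/content/data'
    OUTPUT_BASE = '/content/outputs'
    print("ℹ️  Using local storage (data will be lost after session ends)")
    print("   Set MOUNT_DRIVE=True to use Google Drive for persistent storage")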

In [None]:
# Cell 4: Generate Sample Data (or load your own)
from pathlib import Path

USE_SAMPLE_DATA = True  # Set to False if using real data

if USE_SAMPLE_DATA:
    print("📝 Generating synthetic sample data...")
    from tests.fixtures.data import generate_synthetic_dataset
    
    # Create sample dataset
    DATA_DIR = Path('/content/sample_data')
    generate_synthetic_dataset(DATA_DIR, num_samples=200, seed=42)
    
    print(f"\n✅ Generated sample data in {DATA_DIR}")
    print(f"   Train samples: 140")
    print(f"   Dev samples: 30")
    print(f"   Test samples: 30")
    print("\n⚠️  Note: This is synthetic data for demonstration.")
    print("   For real training, set USE_SAMPLE_DATA=False and provide your data.")
else:
    print(f"📂 Using data from: {DATA_DIR}")
    # Verify data exists
    data_dir = Path(DATA_DIR)
    for split in ['train', 'dev', 'test']:
        files = list(data_dir.glob(f"{split}.*"))
        if files:
            print(f"   ✅ {split}: {files[0].name}")
        else:
            print(f"   ❌ {split}: NOT FOUND")

In [None]:
# Cell 5: Hugging Face Login (for gated models)
from huggingface_hub import notebook_login

USE_GATED_MODEL = False  # Set to True if using Llama-2 or other gated models

if USE_GATED_MODEL:
    print("🔐 Please log in to Hugging Face...")
    notebook_login()
    print("✅ Logged in to Hugging Face")
else:
    print("ℹ️  Skipping HF login (not using gated models)")
    print("   Set USE_GATED_MODEL=True if using Llama-2 or similar models")

In [None]:
# Cell 6: Configure Training
import yaml
from pathlib import Path

print("⚙️  Configuring training parameters...\n")

# Choose model (use smaller models for faster experimentation)
MODEL_OPTIONS = {
    'tiny': 'hf-internal-testing/tiny-random-LlamaForSequenceClassification',  # For testing
    'small': 'meta-llama/Llama-2-7b-hf',  # 7B model
    'medium': 'Qwen/Qwen2.5-7B',  # Alternative 7B
    'large': 'meta-llama/Llama-2-13b-hf'  # 13B model
}

MODEL_SIZE = 'tiny'  # Change to 'small', 'medium', or 'large' for production

config = {
    # Model settings
    'model_id': MODEL_OPTIONS[MODEL_SIZE],
    'method': 'lora',  # 'full_ft', 'lora', or 'qlora'
    
    # Training hyperparameters (optimized for TPU/GPU)
    'num_train_epochs': 3,
    'per_device_train_batch_size': 8 if USE_TPU else (4 if torch.cuda.is_available() else 2),
    'per_device_eval_batch_size': 16 if USE_TPU else (8 if torch.cuda.is_available() else 4),
    'gradient_accumulation_steps': 2,
    'learning_rate': 2e-5,
    'warmup_ratio': 0.1,
    'weight_decay': 0.01,
    'max_grad_norm': 1.0,
    
    # LoRA settings (if method='lora' or 'qlora')
    'lora_r': 16,
    'lora_alpha': 32,
    'lora_dropout': 0.05,
    'lora_target_modules': ['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    
    # Document processing
    'max_length': 2048 if USE_TPU else (1024 if torch.cuda.is_available() else 512),
    'doc_stride': 512,
    'truncation_strategy': 'window_pool',
    'pooler': 'mean',  # 'max', 'mean', or 'logit_sum'
    
    # Loss settings
    'loss_type': 'bce',  # 'bce' or 'focal'
    'class_weighting': 'sqrt_inv',  # 'none', 'inv', or 'sqrt_inv'
    'label_smoothing': 0.0,
    'focal_gamma': 2.0,
    
    # Optimization (hardware-aware)
    'bf16': USE_TPU or (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
    'fp16': not USE_TPU and torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
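    # TF32 needs an Ampere-or-newer GPU (compute capability >= 8.0); a T4 does not support it
    'tf32': torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8,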
    'gradient_checkpointing': True,
    
    # Evaluation and checkpointing
    'evaluation_strategy': 'epoch',
    'save_strategy': 'epoch',
    'save_total_limit': 2,
    'load_best_model_at_end': True,
    'metric_for_best_model': 'macro_f1',
    
    # Data
    'data_dir': str(DATA_DIR),
    'train_split': 'train',
    'dev_split': 'dev',
    'test_split': 'test',
    'seed': 42,
    
    # Optional: Limit samples for quick testing
    # 'max_train_samples': 50,
    # 'max_eval_samples': 20,
}

# Save config
config_path = Path('/content/colab_config.yaml')
with open(config_path, 'w') as f:
    yaml.dump(config, f)

print("✅ Configuration saved\n")
print("📋 Key settings:")
print(f"   Model: {config['model_id']}")
print(f"   Method: {config['method']}")
print(f"   Epochs: {config['num_train_epochs']}")
print(f"   Batch size: {config['per_device_train_batch_size']}")
print(f"   Max length: {config['max_length']}")
print(f"   Precision: {'bf16' if config['bf16'] else ('fp16' if config['fp16'] else 'fp32')}")
print(f"   Device: {'TPU' if USE_TPU else 'GPU' if torch.cuda.is_available() else 'CPU'}")

In [None]:
# Cell 7: Start Training
from pathlib import Path
import time

OUTPUT_DIR = Path('/content/outputs')
LABELS_PATH = Path('configs/labels.yaml')

print("🚀 Starting training...\n")
start_time = time.time()

!python -m src.train \
    --config {config_path} \
    --labels {LABELS_PATH} \
    --out_dir {OUTPUT_DIR} \
    --use_wandb false

elapsed_time = time.time() - start_time
print(f"\n✅ Training command finished in {elapsed_time/60:.2f} minutes (check the log above for errors)")
print(f"📁 Outputs saved to: {OUTPUT_DIR}")

In [None]:
# Cell 8: View Training Results
import json
import pandas as pd

print("📊 Loading training results...\n")

# Load metrics
metrics_path = OUTPUT_DIR / 'metrics_dev.json'
if metrics_path.exists():
    with open(metrics_path) as f:
        metrics = json.load(f)
    
    print("🎯 Development Set Results:")
    print(f"   Macro F1:    {metrics['macro_f1']:.4f}")
    print(f"   Micro F1:    {metrics['micro_f1']:.4f}")
    print(f"   Weighted F1: {metrics['weighted_f1']:.4f}")
    
    # Load per-label report
    report_path = OUTPUT_DIR / 'label_report_dev.csv'
    if report_path.exists():
        df = pd.read_csv(report_path)
        print("\n📋 Per-Label Performance:")
        print(df[['label', 'f1', 'precision', 'recall', 'support']].to_string(index=False))
else:
    print("❌ Metrics file not found. Training may have failed.")

In [None]:
# Cell 9: Evaluate on Test Set
BEST_CKPT = OUTPUT_DIR / 'best'

print("🧪 Evaluating on test set...\n")

!python -m src.eval \
    --ckpt {BEST_CKPT} \
    --labels {LABELS_PATH} \
    --data_dir {DATA_DIR} \
    --split test

print("\n✅ Test evaluation complete!")

# Load test metrics
test_metrics_path = BEST_CKPT / 'eval_test' / 'metrics.json'
if test_metrics_path.exists():
    with open(test_metrics_path) as f:
        test_metrics = json.load(f)
    
    print("\n🎯 Test Set Results:")
    print(f"   Macro F1:    {test_metrics['macro_f1']:.4f}")
    print(f"   Micro F1:    {test_metrics['micro_f1']:.4f}")
    print(f"   Weighted F1: {test_metrics['weighted_f1']:.4f}")

In [None]:
# Cell 10: Visualize Predictions
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 8)

print("📈 Generating visualizations...\n")

# Load predictions
pred_path = BEST_CKPT / 'eval_test' / 'predictions.csv'
if pred_path.exists():
    pred_df = pd.read_csv(pred_path)
    
    # Get label columns
    label_cols = [
        'depressed_mood', 'diminished_interest', 'weight_appetite_change',
        'sleep_disturbance', 'psychomotor', 'fatigue',
        'worthlessness_guilt', 'concentration_indecision', 'suicidality'
    ]
    
    # Check if prediction columns exist
    prob_cols = [f'{label}_prob' for label in label_cols]
    if all(col in pred_df.columns for col in prob_cols):
        # Plot 1: Probability distribution heatmap
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        
        # Heatmap of prediction probabilities
        probs_matrix = pred_df[prob_cols].values[:50]  # First 50 samples
        sns.heatmap(probs_matrix.T, ax=ax1, cmap='YlOrRd', 
                   yticklabels=[l.replace('_', ' ').title() for l in label_cols],
                   xticklabels=False, cbar_kws={'label': 'Probability'})
        ax1.set_title('Prediction Probabilities (First 50 samples)', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Sample Index')
        ax1.set_ylabel('DSM-5 Symptom')
        
        # Plot 2: Average probability per label
        avg_probs = pred_df[prob_cols].mean().values
        colors = plt.cm.viridis(np.linspace(0, 1, len(label_cols)))
        bars = ax2.barh([l.replace('_', ' ').title() for l in label_cols], avg_probs, color=colors)
        ax2.set_xlabel('Average Probability', fontsize=12)
        ax2.set_title('Average Prediction Probability by Symptom', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)
        
        # Add value labels
        for i, (bar, prob) in enumerate(zip(bars, avg_probs)):
            ax2.text(prob + 0.02, i, f'{prob:.3f}', va='center', fontsize=10)
        
        plt.tight_layout()
        plt.show()
        
        # Plot 3: Predicted-positive counts per label
        if 'doc_id' in pred_df.columns:
            fig, ax = plt.subplots(figsize=(12, 6))
            
            # Count positive predictions per label (fixed 0.5 cutoff, not the tuned thresholds)
            pred_counts = (pred_df[prob_cols] > 0.5).sum().values
            
            x = np.arange(len(label_cols))
            ax.bar(x, pred_counts, color='steelblue', alpha=0.7, label='Predicted Positive')
            
            ax.set_xlabel('DSM-5 Symptom', fontsize=12, fontweight='bold')
            ax.set_ylabel('Count', fontsize=12, fontweight='bold')
            ax.set_title('Predicted Positive Cases per Symptom', fontsize=14, fontweight='bold')
            ax.set_xticks(x)
            ax.set_xticklabels([l.replace('_', ' ').title() for l in label_cols], 
                              rotation=45, ha='right')
            ax.legend()
            ax.grid(axis='y', alpha=0.3)
            
            plt.tight_layout()
            plt.show()
    else:
        print("⚠️  Probability columns not found in predictions file")
else:
    print("❌ Predictions file not found")

print("\n✅ Visualizations complete!")

In [None]:
# Cell 11: Export Model to Drive
from shutil import copytree, make_archive
import os

print("💾 Exporting model...\n")

# Option 1: Save to Google Drive (if mounted)
if MOUNT_DRIVE and BEST_CKPT.exists():  # skip the copy if training produced no checkpoint
    DRIVE_OUTPUT = '/content/drive/MyDrive/redsm5_models/best_model'
    os.makedirs(os.path.dirname(DRIVE_OUTPUT), exist_ok=True)
    copytree(BEST_CKPT, DRIVE_OUTPUT, dirs_exist_ok=True)
    print(f"✅ Model saved to Google Drive: {DRIVE_OUTPUT}")

# Option 2: Create ZIP for download
if BEST_CKPT.exists():
    zip_path = OUTPUT_DIR / 'best_model'
    make_archive(str(zip_path), 'zip', BEST_CKPT)
    print(f"✅ Model packaged as: {zip_path}.zip")
    print(f"   Size: {os.path.getsize(str(zip_path) + '.zip') / 1e6:.2f} MB")
    
    # Uncomment to download automatically
    # from google.colab import files
    # files.download(str(zip_path) + '.zip')
    # print("📥 Download started...")
else:
    print("❌ Best checkpoint not found")

# Also save key artifacts
artifacts = ['thresholds.json', 'config_used.yaml', 'metrics_dev.json', 'metrics_test.json']
print("\n📦 Key artifacts:")
for artifact in artifacts:
    artifact_path = OUTPUT_DIR / artifact
    if artifact_path.exists():
        print(f"   ✅ {artifact}")
    else:
        print(f"   ⚠️  {artifact} (not found)")

In [None]:
# Cell 12: Inference Example
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

print("🔮 Loading model for inference...\n")

# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained(BEST_CKPT)
tokenizer = AutoTokenizer.from_pretrained(BEST_CKPT)
model.eval()

# Move to device
if USE_TPU:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Load thresholds
thresholds_path = OUTPUT_DIR / 'thresholds.json'
with open(thresholds_path) as f:
    thresholds_data = json.load(f)
    thresholds = torch.tensor(thresholds_data['thresholds']).to(device)

print("✅ Model loaded\n")

# Define label names
label_cols = [
    'depressed_mood', 'diminished_interest', 'weight_appetite_change',
    'sleep_disturbance', 'psychomotor', 'fatigue',
    'worthlessness_guilt', 'concentration_indecision', 'suicidality'
]

def predict_symptoms(text):
    """Predict DSM-5 symptoms from text."""
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.sigmoid(outputs.logits).cpu()
        preds = (probs > thresholds.cpu()).int()
    
    return probs[0], preds[0]

# Example 1: Depressive symptoms
print("🔍 Example 1: Depression-related text")
text1 = "I feel so sad and hopeless. I can't sleep and have no energy to do anything. Nothing brings me joy anymore."
probs, preds = predict_symptoms(text1)

print(f"\nText: {text1}")
print("\nPredicted Symptoms:")
for i, label in enumerate(label_cols):
    if preds[i] == 1:
        print(f"  ✓ {label.replace('_', ' ').title()} (prob: {probs[i]:.3f})")

# Example 2: Neutral text
print("\n" + "="*60)
print("🔍 Example 2: Neutral text")
text2 = "I went to the store today and bought some groceries. The weather was nice."
probs, preds = predict_symptoms(text2)

print(f"\nText: {text2}")
print("\nPredicted Symptoms:")
detected = False
for i, label in enumerate(label_cols):
    if preds[i] == 1:
        print(f"  ✓ {label.replace('_', ' ').title()} (prob: {probs[i]:.3f})")
        detected = True
if not detected:
    print("  (No symptoms detected)")

# Example 3: Custom input
print("\n" + "="*60)
print("🔍 Example 3: Try your own text!")
print("\nAdd a new cell like the following to test your own text:")
print("""\ntext_custom = "Your text here..."
probs, preds = predict_symptoms(text_custom)
# ... process results ...""")

## 🎉 Training Complete!

### Next Steps

1. **Improve Performance:**
   - Use larger models (7B/13B)
   - Increase training epochs
   - Use real ReDSM5 data
   - Tune hyperparameters (learning rate, batch size)

2. **Experiment with Settings:**
   - Try different pooling strategies: `'max'`, `'mean'`, `'logit_sum'`
   - Test Focal loss: `loss_type='focal'`
   - Adjust class weighting: `'inv'`, `'sqrt_inv'`
   - Use QLoRA for 13B models: `method='qlora'`

3. **Production Deployment:**
   - Save best model to Google Drive
   - Export thresholds for inference
   - Document your results
   - Set up monitoring

4. **Analysis:**
   - Check per-label F1 scores
   - Analyze false positives/negatives
   - Review threshold values
   - Validate on held-out data
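
For the analysis steps above, per-label thresholds can be explored with a small grid search. This is a sketch under assumptions: the repository's own threshold optimization may use a different objective or grid, and `f1_score_binary` and `best_threshold` are hypothetical helpers run on toy data, not on the contents of `thresholds.json`.

```python
def f1_score_binary(y_true, y_pred):
    """F1 for one binary label; returns 0.0 when there are no positives at all."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def best_threshold(y_true, probs, grid=None):
    """Return (threshold, f1) maximizing F1 on the given data for a single label."""
    if grid is None:
        grid = [i / 100 for i in range(5, 96, 5)]  # 0.05, 0.10, ..., 0.95
    best_t, best_f1 = 0.5, -1.0
    for t in grid:
        preds = [1 if p > t else 0 for p in probs]
        score = f1_score_binary(y_true, preds)
        if score > best_f1:  # strict '>' keeps the lowest threshold among ties
            best_t, best_f1 = t, score
    return best_t, best_f1

# Toy dev-set labels and predicted probabilities for a single symptom.
y_true = [1, 1, 0, 0, 1, 0]
probs = [0.92, 0.37, 0.31, 0.12, 0.63, 0.44]
threshold, f1 = best_threshold(y_true, probs)
print(f"threshold={threshold:.2f}, F1={f1:.3f}")
```

Thresholds chosen this way should be tuned on the dev split only and then applied unchanged to the test split, otherwise the reported test F1 is optimistic.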

### 📚 Resources

- **GitHub:** https://github.com/OscarTsao/LLM_Agents_ReDSM5
- **Paper:** [ReDSM5 Dataset](https://arxiv.org/abs/xxxx.xxxxx)
- **Documentation:** See `README.md` in repository

### 💡 Tips

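- A TPU runtime can speed up training (Runtime > Change runtime type > TPU); prefer a GPU runtime for QLoRA, since bitsandbytes quantization is built for GPUs rather than TPUs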
- Mount Google Drive for persistent storage
- Monitor training with wandb (pass `--use_wandb true` to `src.train`)
- Save checkpoints regularly

---

**Happy Fine-tuning! 🚀**