# DTC Agent - TPU Training in Google Colab

This notebook runs the Dual-Timescale Competence (DTC) reinforcement learning agent on Google Cloud TPU.

## Prerequisites

**IMPORTANT:** Make sure you have selected TPU as your runtime:
1. Go to `Runtime` -> `Change runtime type`
2. Select `TPU` as the hardware accelerator
3. Click `Save`

## What is DTC?

DTC is a model-based RL agent that learns through intrinsic motivation:
- **Learning Progress** (competence/mastery tracking)
- **Clean Curiosity** (epistemic uncertainty)
- **Empowerment** (control/optionality maximization)
- **Episodic Rehearsal** (long-term skill maintenance)
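The DTC implementation of learning progress is not shown in this notebook. One common way to turn a "dual-timescale" notion of progress into a number is to compare a fast and a slow moving average of prediction error: when recent errors drop below older ones, the agent is improving. The sketch below illustrates that idea only. The class name, the EMA formulation, and the default rates are assumptions, not DTC's actual code.

```python
from typing import Optional


class LearningProgressTracker:
    """Hypothetical sketch: learning progress as the gap between a slow and a
    fast exponential moving average (EMA) of a prediction-error signal."""

    def __init__(self, fast_rate: float = 0.1, slow_rate: float = 0.01) -> None:
        if not 0.0 < slow_rate < fast_rate <= 1.0:
            raise ValueError("expected 0 < slow_rate < fast_rate <= 1")
        self.fast_rate = fast_rate
        self.slow_rate = slow_rate
        self.fast_ema: Optional[float] = None  # reacts quickly to recent errors
        self.slow_ema: Optional[float] = None  # retains older errors longer

    def update(self, error: float) -> float:
        """Fold in a new error value and return the current learning progress."""
        if self.fast_ema is None or self.slow_ema is None:
            # Initialise both averages at the first observed error.
            self.fast_ema = self.slow_ema = error
        else:
            self.fast_ema += self.fast_rate * (error - self.fast_ema)
            self.slow_ema += self.slow_rate * (error - self.slow_ema)
        # Positive when recent errors are lower than older ones (improving),
        # negative when performance is getting worse.
        return self.slow_ema - self.fast_ema


tracker = LearningProgressTracker()
for err in [1.0, 0.8, 0.6, 0.4]:
    print(f"error={err:.1f}  progress={tracker.update(err):+.4f}")
```

Using the difference rather than the raw error means a task the agent has already mastered (low, flat error) and a task it cannot learn (high, flat error) both yield progress near zero.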

It uses:
- Slot attention for object-centric representations
- World model ensemble for predictions
- Global workspace theory for cognitive routing
- Temporal self-awareness for adaptive exploration
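Curiosity based on epistemic uncertainty is often derived from a world-model ensemble: where the members disagree about a prediction, the agent knows little about that situation. Below is a minimal plain-Python sketch of that disagreement signal. The function name and the choice of mean per-dimension variance are assumptions; DTC may compute its curiosity term differently.

```python
from statistics import mean, pvariance
from typing import Sequence


def ensemble_disagreement(predictions: Sequence[Sequence[float]]) -> float:
    """Mean per-dimension variance across ensemble members' predictions.

    `predictions[k]` is member k's predicted feature vector for the same
    input; all members must predict vectors of the same length.
    """
    if len(predictions) < 2:
        raise ValueError("need at least two ensemble members")
    dim = len(predictions[0])
    if any(len(p) != dim for p in predictions):
        raise ValueError("all predictions must have the same length")
    # Population variance across members for each output dimension, then averaged.
    per_dim = [pvariance([p[d] for p in predictions]) for d in range(dim)]
    return mean(per_dim)


print(ensemble_disagreement([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]))  # members agree
print(ensemble_disagreement([[0.0, 2.0], [2.0, 2.0]]))  # members disagree on dim 0
```

Because the signal measures disagreement rather than raw prediction error, it tends to shrink as the ensemble learns a region, even if that region contains inherently noisy outcomes.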

## 1. Setup Environment

In [None]:
# Clone the repository
!git clone https://github.com/YOUR_USERNAME/DTC-agent-tpu.git
%cd DTC-agent-tpu

In [None]:
# Install TPU dependencies and DTC agent
!pip install -q torch torch-xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
!pip install -q -r requirements-colab.txt
!pip install -q -e .

print("\n✓ Installation complete!")
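To confirm the installation before moving on, you can query the installed package versions with the standard library. This is a small sketch; the helper name `installed_versions` is hypothetical and the package list is just the main dependencies used in this notebook.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for pkg, ver in installed_versions(["torch", "torch_xla", "wandb"]).items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```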

In [None]:
# Verify TPU is available
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

device = xm.xla_device()
print(f"TPU device: {device}")

# Run a small operation on the TPU
x = torch.randn(3, 3, device=device)
y = x + 2
xm.mark_step()  # Cut the lazy-tensor graph and dispatch it to the TPU
print(y.cpu())  # Copying to the host waits for the result

print("\n✓ TPU is working correctly!")
print(f"Number of TPU devices: {xr.global_runtime_device_count()}")

## 2. Configuration

The TPU configuration is optimized for Cloud TPU v2/v3 with 8 cores.

In [None]:
# View the TPU configuration
!cat configs/tpu.yaml
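To experiment with settings without editing the tracked `configs/tpu.yaml`, you can write a modified copy. The sketch below assumes the YAML is a nested mapping where keys such as `batch_size` and `episodic_memory.capacity` (both mentioned in the tips at the end of this notebook) can be addressed by dotted paths; the helper names and the output filename are hypothetical, and writing the file requires PyYAML.

```python
import copy
from typing import Any, Dict


def apply_overrides(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `config` with dotted-path keys replaced.

    {"episodic_memory.capacity": 50000} sets config["episodic_memory"]["capacity"].
    Unknown keys raise KeyError so typos are caught instead of adding new settings.
    """
    updated = copy.deepcopy(config)  # leave the caller's dict untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = updated
        for key in parents:
            node = node[key]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError(dotted_key)
        node[leaf] = value
    return updated


def write_overridden_config(src: str, dst: str, overrides: Dict[str, Any]) -> None:
    """Load YAML from `src`, apply overrides, and save the result to `dst`."""
    import yaml  # PyYAML; imported here so apply_overrides works without it

    with open(src) as f:
        config = yaml.safe_load(f)
    with open(dst, "w") as f:
        yaml.safe_dump(apply_overrides(config, overrides), f, sort_keys=False)
```

For example, `write_overridden_config("configs/tpu.yaml", "configs/tpu_small.yaml", {"episodic_memory.capacity": 50000})` writes a variant you can pass to the training cell via `--config configs/tpu_small.yaml` (the value 50000 is only illustrative).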

## 3. Weights & Biases Setup (Optional)

W&B tracks training metrics and visualizations. If you don't have an account, skip this cell.

In [None]:
# Optional: Login to Weights & Biases
import wandb

# Uncomment and run to login:
# wandb.login()

# Or set your API key directly:
# import os
# os.environ['WANDB_API_KEY'] = 'your-api-key-here'

print("To use W&B, uncomment the code above and add your API key.")
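The training cell below takes `--log-to-wandb true` or `--log-to-wandb false`. A small sketch that picks the value based on whether an API key has been set; the helper name `wandb_flag` is hypothetical. Note that it only checks the `WANDB_API_KEY` environment variable, so a key stored by `wandb.login()` in `~/.netrc` is not detected.

```python
import os
from typing import Mapping, Optional


def wandb_flag(env: Optional[Mapping[str, str]] = None) -> str:
    """Return "true" if WANDB_API_KEY is set and non-empty, else "false"."""
    env = os.environ if env is None else env
    return "true" if env.get("WANDB_API_KEY") else "false"


print(f"Use: --log-to-wandb {wandb_flag()}")
```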

## 4. Training

This will start training the DTC agent on TPU using the Crafter environment.

In [None]:
# Start training on TPU
!python -m dtc_agent.training \
    --config configs/tpu.yaml \
    --max-episodes 100 \
    --log-to-wandb false

# Notes:
# - Set --log-to-wandb true if you configured W&B above
# - Adjust --max-episodes based on how long you want to train
# - Training will automatically use TPU when available
# - Checkpoints are saved in /content/DTC-agent-tpu/
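If you prefer to launch training from Python (for example, a short smoke test with fewer episodes), the sketch below assembles the same command. It uses only the three flags shown in the cell above; `training_command` is a hypothetical helper, and the actual launch line is left commented out.

```python
import shlex
from typing import List


def training_command(config: str = "configs/tpu.yaml", max_episodes: int = 100,
                     log_to_wandb: bool = False) -> List[str]:
    """Build the argv for the training run using the flags from the cell above."""
    return [
        "python", "-m", "dtc_agent.training",
        "--config", config,
        "--max-episodes", str(max_episodes),
        "--log-to-wandb", "true" if log_to_wandb else "false",
    ]


if __name__ == "__main__":
    cmd = training_command(max_episodes=10)
    print("Would run:", shlex.join(cmd))
    # import subprocess; subprocess.run(cmd, check=True)  # uncomment to launch
```

Passing an argument list (rather than one shell string) avoids quoting problems if a config path contains spaces.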

## 5. Monitor Training

View training progress and metrics.

In [None]:
# Flush buffered XLA logs from the training run
from dtc_agent.utils import flush_xla_logs

print("Flushing XLA logs...")
flush_xla_logs()

# If you logged in to W&B (section 3) and trained with --log-to-wandb true,
# open your W&B dashboard to see live metrics.
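The training cell prints its output directly into the notebook. If you instead redirect it to a file, for example by appending `> train.log 2>&1` to that `!python` command (the filename `train.log` is an assumption), you can inspect the most recent lines with a small helper like this hypothetical `tail` function:

```python
from collections import deque
from pathlib import Path
from typing import List, Union


def tail(path: Union[str, Path], n: int = 20) -> List[str]:
    """Return the last `n` lines of a text file without holding it all in memory."""
    with open(path, encoding="utf-8", errors="replace") as f:
        # deque with maxlen keeps only the most recent n lines while streaming
        return [line.rstrip("\n") for line in deque(f, maxlen=n)]


log_file = Path("train.log")  # hypothetical: created by redirecting training output
if log_file.exists():
    print("\n".join(tail(log_file)))
else:
    print(f"{log_file} not found; redirect the training output to it first.")
```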

## 6. Save and Download Checkpoints

Download trained models to continue training locally or for evaluation.

In [None]:
# List available checkpoints
from pathlib import Path

checkpoint_dir = Path("/content/DTC-agent-tpu")
checkpoints = sorted(checkpoint_dir.glob("*.pt"))

print("Available checkpoints:")
for ckpt in checkpoints:
    size_mb = ckpt.stat().st_size / 1024 / 1024
    print(f"  {ckpt.name} ({size_mb:.2f} MB)")
if not checkpoints:
    print("  (none found)")
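To pick the most recent checkpoint automatically, you can parse the step number from the filename. This sketch assumes checkpoints follow the `..._step_<N>.pt` pattern of the example filename used in the download cell; the helper name is hypothetical. Parsing the number matters because a plain alphabetical sort would place `step_900` after `step_1000`.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Assumed naming scheme, based on "dtc_agent_checkpoint_step_1000.pt".
STEP_PATTERN = re.compile(r"_step_(\d+)\.pt$")


def latest_checkpoint(paths: Iterable[Path]) -> Optional[Path]:
    """Return the checkpoint with the highest step number, or None if none match."""
    best_step, best_path = -1, None
    for path in paths:
        match = STEP_PATTERN.search(path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


latest = latest_checkpoint(Path("/content/DTC-agent-tpu").glob("*.pt"))
print(f"Latest checkpoint: {latest.name if latest else 'none found'}")
```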

In [None]:
# Download a specific checkpoint
from pathlib import Path
from google.colab import files

# Replace with your checkpoint filename
checkpoint_path = Path("/content/DTC-agent-tpu") / "dtc_agent_checkpoint_step_1000.pt"

if checkpoint_path.exists():
    files.download(str(checkpoint_path))
    print(f"Download started for {checkpoint_path.name}")
else:
    print(f"Checkpoint {checkpoint_path.name} not found")
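`files.download` fetches one file per call. To grab several checkpoints at once, you can bundle them into a single zip first. A minimal sketch (the helper name is hypothetical); it stores files uncompressed, since compression usually gains little on binary model weights and only adds time.

```python
import zipfile
from pathlib import Path
from typing import Union

PathLike = Union[str, Path]


def bundle_checkpoints(src_dir: PathLike, archive_path: PathLike,
                       pattern: str = "*.pt") -> Path:
    """Write every file in `src_dir` matching `pattern` into one zip archive."""
    src_dir, archive_path = Path(src_dir), Path(archive_path)
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_STORED) as zf:
        for ckpt in sorted(src_dir.glob(pattern)):
            # Store under the bare filename so the archive has no /content/... prefix.
            zf.write(ckpt, arcname=ckpt.name)
    return archive_path
```

In Colab you would then run `files.download(str(bundle_checkpoints("/content/DTC-agent-tpu", "/content/dtc_checkpoints.zip")))`.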

## Tips and Troubleshooting

### Performance Tips
- **Batch size**: The TPU config uses `batch_size=128` per core (1024 in total across 8 cores)
- **Memory**: If you hit out-of-memory (OOM) errors, reduce `episodic_memory.capacity` in the config
- **Speed**: TPUs are fastest with large batch sizes and minimal host-device transfers
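The batch-size and memory numbers above are simple products, which makes it easy to reason about changes before editing the config. A small sketch (helper names hypothetical); the buffer estimate assumes each entry is a raw 64x64 RGB `uint8` frame and a capacity of 100,000, both purely illustrative, since this notebook does not say what the episodic memory actually stores.

```python
def global_batch_size(per_core_batch: int, num_cores: int) -> int:
    """Total samples per training step when each core processes its own batch."""
    return per_core_batch * num_cores


def per_core_batch_for(global_batch: int, num_cores: int) -> int:
    """Per-core batch needed to keep a given global batch; must divide evenly."""
    if global_batch % num_cores:
        raise ValueError(f"{global_batch} is not divisible by {num_cores} cores")
    return global_batch // num_cores


def buffer_megabytes(capacity: int, bytes_per_entry: int) -> float:
    """Rough footprint (MiB) of a buffer holding `capacity` fixed-size entries."""
    return capacity * bytes_per_entry / 1024 / 1024


print(global_batch_size(128, 8))  # 1024, matching the TPU config
# Hypothetical entry size: one 64x64 RGB uint8 frame = 64 * 64 * 3 bytes.
print(f"{buffer_megabytes(100_000, 64 * 64 * 3):.1f} MiB")
```

Halving the capacity halves the estimate, which is why reducing `episodic_memory.capacity` is the first knob to try on OOM.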

### Common Issues

1. **TPU not available**
   - Check Runtime -> Change runtime type -> Hardware accelerator = TPU
   - Restart runtime if needed

2. **Out of memory**
   - Reduce batch_size in configs/tpu.yaml
   - Reduce episodic_memory.capacity
   - Reduce world_model_ensemble size

3. **Slow training**
   - Make sure `compile_modules` is false (XLA handles compilation)
   - Avoid `print` statements in the training loop (use `xla_print` instead); printing tensor values forces the TPU to sync with the host
   - Check that threading is disabled on TPU

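4. **Graph breaks**
   - XLA compiles a separate graph for each distinct set of tensor shapes, so it performs best when shapes are static
   - Dynamic shapes trigger recompilation, and printing tensor values forces an early graph execution and host sync
   - The DTC training code handles this automatically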
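If you add your own code with variable-length inputs (for example, episodes of different lengths), a common way to keep shapes static is to pad each input up to one of a few fixed bucket sizes, so XLA compiles one graph per bucket instead of one per length. A minimal plain-Python sketch of the bucketing logic; this is a generic technique, not necessarily what the DTC code does, and the bucket sizes are arbitrary. In practice you would also keep a mask marking the padded positions.

```python
from typing import List, Sequence


def bucket_length(length: int, buckets: Sequence[int]) -> int:
    """Smallest bucket size >= length; each bucket is one shape to compile."""
    for size in sorted(buckets):
        if length <= size:
            return size
    raise ValueError(f"length {length} exceeds largest bucket {max(buckets)}")


def pad_to_bucket(seq: Sequence[float], buckets: Sequence[int],
                  pad_value: float = 0.0) -> List[float]:
    """Right-pad a variable-length sequence up to its bucket size."""
    target = bucket_length(len(seq), buckets)
    return list(seq) + [pad_value] * (target - len(seq))


buckets = [16, 32, 64]
lengths = [3, 17, 20, 31, 40, 64]
shapes = {len(pad_to_bucket([1.0] * n, buckets)) for n in lengths}
print(sorted(shapes))  # [16, 32, 64]: three distinct shapes instead of six
```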

### Resources
- [DTC Documentation](https://github.com/YOUR_USERNAME/DTC-agent-tpu)
- [PyTorch XLA](https://pytorch.org/xla/release/2.0/index.html)
- [Google Cloud TPU](https://cloud.google.com/tpu/docs)

### Support
If you encounter issues, please open an issue on GitHub with:
- Error message
- TPU type (v2/v3)
- Configuration used
- Steps to reproduce

## 7. Cleanup

Free up resources when done.

In [None]:
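# Optional: free memory held by this notebook session
import gc
import torch

# Collect Python objects that are no longer referenced
gc.collect()

# Only relevant on GPU runtimes; this call does not affect TPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# TPU memory is released when the runtime stops:
# Runtime -> Disconnect and delete runtime

print("✓ Cleanup complete")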