<a href="https://colab.research.google.com/github/RaphJean/PicassBot/blob/main/colab_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Picassbot Training on Google Colab

This notebook allows you to train the Picassbot Joint Model (Encoder + Predictor + Policy) using Colab's free GPUs.

## 1. Set Up the Environment

In [None]:
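# Clone the repository (if not already present) and move into it.
# Absolute paths keep this cell safe to re-run after %cd has changed the working directory.
import os
if not os.path.exists('/content/PicassBot'):
    !git clone https://github.com/RaphJean/PicassBot.git /content/PicassBot
%cd /content/PicassBot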

# Install dependencies
!pip install -r requirements.txt
!pip install tensorboard

In [None]:
# Download QuickDraw Dataset
!python download_all_data.py

In [None]:
import torch
if torch.cuda.is_available():
    print(f"✅ GPU Available: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ No GPU found. Go to Runtime > Change runtime type > T4 GPU")

## 2. Mount Google Drive (Optional)
Useful for saving checkpoints to your Drive: files on the Colab VM's local disk are deleted when the runtime disconnects or is recycled.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create a checkpoint directory in Drive
!mkdir -p /content/drive/MyDrive/Picassbot/checkpoints
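The training commands in section 3 do not point at this Drive folder, so one option is to copy checkpoints over after (or between) runs. A minimal sketch: `sync_checkpoints` is a hypothetical helper, and it assumes the trainer writes `.pth` files to `joint_checkpoints/`, the folder the download cell in section 5 searches.

```python
import glob
import os
import shutil


def sync_checkpoints(src_dir="joint_checkpoints",
                     dst_dir="/content/drive/MyDrive/Picassbot/checkpoints"):
    """Copy .pth files from src_dir to dst_dir when the copy is missing or older."""
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for path in sorted(glob.glob(os.path.join(src_dir, "*.pth"))):
        target = os.path.join(dst_dir, os.path.basename(path))
        # Skip files whose Drive copy is already at least as recent as the local one
        if os.path.exists(target) and os.path.getmtime(target) >= os.path.getmtime(path):
            continue
        # copy2 also preserves the modification time, so re-running skips unchanged files
        shutil.copy2(path, target)
        copied.append(target)
    return copied


# Only sync when Drive is actually mounted
if os.path.isdir("/content/drive/MyDrive"):
    print(sync_checkpoints())
```

Because unchanged files are skipped, the cell can be re-run as often as you like while training progresses.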

## 3. Train Joint Model

### Option A: Train from Scratch

In [None]:
# Run training
!PYTHONPATH=src:. python -m picassbot.policy.train_joint --config config.yaml

### Option B: Train with Pre-trained Policy (Recommended)
First upload your `policy_epoch_X.pth` to the cloned `PicassBot` folder or to your Drive.

In [None]:
# Example command: uncomment and adjust the checkpoint path as needed
# !PYTHONPATH=src:. python -m picassbot.policy.train_joint --config config.yaml --pretrained_policy /content/drive/MyDrive/Picassbot/policy_epoch_10.pth
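If you saved several policy checkpoints, picking the newest one by file name is easy to get wrong: sorted as strings, `policy_epoch_9.pth` comes after `policy_epoch_10.pth`. A minimal sketch of a hypothetical `latest_epoch_checkpoint` helper that parses the epoch number instead, assuming the `policy_epoch_<N>.pth` naming used above and the Drive folder from section 2:

```python
import glob
import os
import re


def latest_epoch_checkpoint(directory, prefix="policy_epoch_"):
    """Return the checkpoint path with the highest epoch number, or None if there is none."""
    # Accept only names of the exact form <prefix><digits>.pth
    pattern = re.compile(re.escape(prefix) + r"(\d+)\.pth")
    best_path, best_epoch = None, -1
    for path in glob.glob(os.path.join(directory, prefix + "*.pth")):
        match = pattern.fullmatch(os.path.basename(path))
        # Compare epochs as integers, not strings
        if match and int(match.group(1)) > best_epoch:
            best_path, best_epoch = path, int(match.group(1))
    return best_path


ckpt = latest_epoch_checkpoint("/content/drive/MyDrive/Picassbot")
print(ckpt)
```

IPython expands `{ckpt}` inside `!` commands, so once you have checked that `ckpt` is not `None` it can be passed as `--pretrained_policy {ckpt}` in the Option B command.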

## 4. Visualize with TensorBoard

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs/joint

## 5. Download Checkpoints

In [None]:
import glob
import os
from google.colab import files

# Find the most recently written checkpoint
checkpoints = glob.glob("joint_checkpoints/*.pth")
if checkpoints:
    latest = max(checkpoints, key=os.path.getmtime)
    print(f"Downloading {latest}...")
    files.download(latest)
else:
    print("No checkpoints found.")
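`files.download` fetches one file per call. To grab every checkpoint at once, one option is to zip the folder first with the standard library's `shutil.make_archive`. A minimal sketch: `zip_checkpoints` is a hypothetical helper, and `joint_checkpoints/` is the same assumed output folder the cell above searches.

```python
import os
import shutil


def zip_checkpoints(src_dir="joint_checkpoints", archive_base="joint_checkpoints"):
    """Zip the contents of src_dir into <archive_base>.zip; return its path, or None."""
    if not os.path.isdir(src_dir):
        return None
    # root_dir stores entries relative to src_dir; the archive itself is written
    # next to the working directory, not inside src_dir
    return shutil.make_archive(archive_base, "zip", root_dir=src_dir)


archive = zip_checkpoints()
if archive:
    from google.colab import files  # only available inside Colab
    files.download(archive)
else:
    print("No checkpoint folder found.")
```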