## CTM Sync Filter Fine-tuning (Minimal ImageNet via Streaming)

This notebook is meant for **Colab (free tier)** or **VS Code/Cursor + Colab kernel**.

What it does:
- Installs minimal dependencies
- Downloads the provided Drive checkpoint zip
- Fine-tunes **only** the new synchronization filter parameters (FIR / IIR) on a tiny streamed ImageNet subset (the MultiBand variant additionally trains `q_proj` and `output_projector`)
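The two filter families differ in how many past internal ticks feed each synchronization value. The following is a minimal pure-Python sketch for a single neuron pair. `iir_sync` follows the exponential-decay recurrence that, as we understand it, the original CTM uses: a decayed sum of the pairwise products `z_i * z_j` divided by the square root of a decayed tick count. `fir_sync` is a hypothetical K-tap alternative, assuming `fir_k` sets the number of taps. Neither function is taken from `models/ctm_sync_filters.py`; both are illustrations only.

```python
import math
from typing import List, Sequence


def iir_sync(products: Sequence[float], decay: float) -> List[float]:
    """Exponential-decay (first-order IIR) synchronization for one neuron pair.

    products[t] is the pairwise product z_i(t) * z_j(t) at internal tick t.
    decay >= 0 stands in for a learnable decay rate; r = exp(-decay).
    """
    r = math.exp(-decay)
    alpha, beta = 0.0, 0.0
    out = []
    for p in products:
        alpha = r * alpha + p   # decayed running sum of products
        beta = r * beta + 1.0   # decayed count of ticks, used to normalize
        out.append(alpha / math.sqrt(beta))
    return out


def fir_sync(products: Sequence[float], taps: Sequence[float]) -> List[float]:
    """Hypothetical K-tap FIR synchronization for one neuron pair (K = len(taps)).

    taps[0] weights the current tick, taps[1] the previous tick, and so on.
    Ticks before the start of the sequence count as zero. No normalization
    is applied here, which is a simplification.
    """
    out = []
    for t in range(len(products)):
        acc = 0.0
        for k, w in enumerate(taps):
            if t - k >= 0:
                acc += w * products[t - k]
        out.append(acc)
    return out


products = [1.0, 0.5, -0.25, 2.0]
print(iir_sync(products, decay=0.5))
print(fir_sync(products, taps=[0.5, 0.3, 0.2, 0.0]))  # K = 4, as with fir_k=4
```

An FIR filter only sees the last K ticks, while the IIR recurrence keeps an exponentially fading memory of every earlier tick.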

Notes:
- The Google Drive file is a **zip** that contains the actual `.pt` checkpoint.
- ImageNet is streamed, so only the requested number of samples (`N`, set via `--n_train` / `--n_val`) is downloaded.
- This is for **sanity-check / prototyping**, not full ImageNet benchmarking.
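Taking only `N` samples from a stream comes down to materializing a bounded prefix of an iterator. The sketch below uses `itertools.islice` and assumes the stream yields one example per item. `take_n` and `split_stream` are hypothetical helpers, and the actual `colab/` scripts may draw validation samples from a separate split rather than from the next items of the same stream. With Hugging Face `datasets`, a dataset loaded with `streaming=True` is an `IterableDataset`, which also offers `.take(n)` for the same purpose.

```python
from itertools import islice
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def take_n(stream: Iterable[T], n: int) -> List[T]:
    """Materialize at most n items; nothing beyond those n is fetched."""
    return list(islice(stream, n))


def split_stream(stream: Iterable[T], n_train: int, n_val: int) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: first n_train items for training, next n_val for validation."""
    it = iter(stream)  # share one iterator so the two splits do not overlap
    train = take_n(it, n_train)
    val = take_n(it, n_val)
    return train, val


def fake_stream():
    """Stand-in for a streamed dataset: an endless generator of labeled examples."""
    i = 0
    while True:
        yield {"image_id": i, "label": i % 1000}
        i += 1


train, val = split_stream(fake_stream(), n_train=2000, n_val=500)
print(len(train), len(val), val[0]["image_id"])  # 2000 500 2000
```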



In [14]:
# If you're in Colab: Runtime -> Change runtime type -> GPU
# If you're in VS Code/Cursor with the Colab kernel: select a Colab GPU runtime.

!nvidia-smi -L || true



GPU 0: Tesla T4 (UUID: GPU-3ab31fc9-d008-645c-59b8-cba842c7269a)


In [15]:
# Install minimal deps for running the CTM code + streaming ImageNet.
# (If you already have the repo environment, you can skip.)

!pip -q install --upgrade pip
!pip -q install torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip -q install datasets huggingface_hub safetensors tqdm pillow gdown
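Before skipping this install cell, you can confirm that the packages are already importable. The sketch below uses `importlib.util.find_spec`, which reports whether a top-level module can be found without importing it. The module list mirrors the `pip install` lines above. Note that the `pillow` distribution is imported as `PIL`. `missing_modules` is a hypothetical helper.

```python
import importlib.util
from typing import Iterable, List

# Import names for the packages installed above (pillow is imported as PIL).
REQUIRED = ["torch", "torchvision", "datasets", "huggingface_hub",
            "safetensors", "tqdm", "PIL", "gdown"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED)
print("missing:", missing or "none")
```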



In [16]:
# Clone the repo (or skip if you're already in it)
# IMPORTANT: use your repo (it contains the new sync-filter + colab files)

import os
import subprocess

REPO_URL = "https://github.com/aryangoyal7/CTM-sync.git"
REPO_DIR = "CTM-sync"

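# Clone once, then cd into the repo (safe to re-run from inside CTM-sync)
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.exists(REPO_DIR):
        subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR])
    os.chdir(REPO_DIR)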
print("cwd:", os.getcwd())
print("has colab/:", os.path.exists("colab"))
print("has models/ctm_sync_filters.py:", os.path.exists("models/ctm_sync_filters.py"))



cwd: /content/continuous-thought-machines/CTM-sync
has colab/: True
has models/ctm_sync_filters.py: True


In [17]:
# Download the checkpoint from Google Drive into ./checkpoints

!mkdir -p checkpoints
!pip -q install gdown

CHECKPOINT_URL = "https://drive.google.com/file/d/1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ/view?usp=drive_link"
!gdown --fuzzy "{CHECKPOINT_URL}" -O checkpoints/ctm_checkpoint.pt
!ls -lh checkpoints/ctm_checkpoint.pt



Downloading...
From (original): https://drive.google.com/uc?id=1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ
From (redirected): https://drive.google.com/uc?id=1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ&confirm=t&uuid=c7d5b89d-418e-4796-83be-3e409a58a63f
To: /content/continuous-thought-machines/CTM-sync/checkpoints/ctm_checkpoint.pt
100% 691M/691M [00:10<00:00, 68.2MB/s] 
-rw-r--r-- 1 root root 659M May 11  2025 checkpoints/ctm_checkpoint.pt
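The notes say the Drive file is a zip that wraps the `.pt` checkpoint, but this cell saves the download directly as `ctm_checkpoint.pt`. You can tell the two cases apart with `zipfile`. Checkpoints written by `torch.save` (the default format since PyTorch 1.6) are themselves zip archives with a `<prefix>/data.pkl` entry, and `torch.load` can open them directly. An outer zip instead holds a `.pt` member that must be extracted first. This is a minimal sketch: `resolve_checkpoint` is a hypothetical helper, and the `data.pkl` check is a heuristic based on our understanding of the torch serialization layout.

```python
import os
import zipfile


def resolve_checkpoint(path: str, extract_dir: str) -> str:
    """Return a path that torch.load should be able to open (hypothetical helper).

    - Not a zip at all: return path unchanged (e.g. a legacy pickle checkpoint).
    - A torch.save archive (has a 'data.pkl' entry): return path unchanged.
    - An outer zip wrapping a .pt/.pth file: extract that member, return its path.
    """
    if not zipfile.is_zipfile(path):
        return path
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
        if any(n == "data.pkl" or n.endswith("/data.pkl") for n in names):
            return path  # already a torch.save-format checkpoint
        members = [n for n in names if n.endswith((".pt", ".pth"))]
        if not members:
            raise ValueError(f"no .pt/.pth file found inside {path}")
        os.makedirs(extract_dir, exist_ok=True)
        # ZipFile.extract keeps the member's internal directory structure
        # and returns the path of the extracted file.
        return zf.extract(members[0], extract_dir)
```

For example, `resolve_checkpoint("checkpoints/ctm_checkpoint.pt", "checkpoints/extracted")` returns the path to pass as `--checkpoint_path`.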


In [19]:
import torch
from models.ctm_sync_filters import ContinuousThoughtMachineFIR, ContinuousThoughtMachineIIR, ContinuousThoughtMachineMultiBand

device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep the input small when backbone_type='none': the number of attention tokens equals H*W
x = torch.randn(1, 3, 32, 32, device=device)

m = ContinuousThoughtMachineFIR(
    iterations=2,
    d_model=128,
    d_input=64,
    heads=2,
    n_synch_out=32,
    n_synch_action=32,
    synapse_depth=1,
    memory_length=4,
    deep_nlms=False,
    memory_hidden_dims=4,
    do_layernorm_nlm=False,
    backbone_type="none",
    positional_embedding_type="none",
    out_dims=1000,
    prediction_reshaper=[-1],
    dropout=0.0,
    dropout_nlm=None,
    neuron_select_type="random-pairing",
    n_random_pairing_self=0,
    fir_k=4,
).to(device).eval()

with torch.no_grad():
    preds, certs, sync = m(x)

preds.shape, certs.shape, sync.shape

ModuleNotFoundError: No module named 'models.ctm_sync_filters'
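This traceback contradicts the earlier check, which found `models/ctm_sync_filters.py` on disk. One plausible cause, which the output does not confirm, is that the kernel started in the parent `continuous-thought-machines` directory and imported that checkout's own `models` package earlier. Python would then look up `models.ctm_sync_filters` inside the cached parent package. The sketch below works around this by putting the repo root first on `sys.path` and dropping any cached `models` modules before re-importing. `use_repo_package` is a hypothetical helper.

```python
import importlib
import os
import sys


def use_repo_package(repo_dir: str, package: str = "models") -> None:
    """Make `package` resolve from repo_dir (hypothetical helper).

    1. Put repo_dir at the front of sys.path so it wins over other checkouts.
    2. Forget any already-imported copy of `package` and its submodules,
       because submodules are looked up inside the cached package's __path__.
    3. Clear the import system's finder caches (the repo was cloned mid-session).
    """
    repo_dir = os.path.abspath(repo_dir)
    if repo_dir in sys.path:
        sys.path.remove(repo_dir)
    sys.path.insert(0, repo_dir)
    for name in list(sys.modules):
        if name == package or name.startswith(package + "."):
            del sys.modules[name]
    importlib.invalidate_caches()
```

From inside `CTM-sync`, call `use_repo_package(os.getcwd())` and then re-run the import cell.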

In [20]:
# Run fine-tuning (FIR / IIR / MultiBand) on a tiny streamed ImageNet subset.
# This will download only ~N samples worth of images.
# NOTE: We should already be inside the cloned repo (CTM-sync) from the cell above.

import os
print("cwd:", os.getcwd())

# Sanity checks
!ls -la | head
!ls -la colab | head
!ls -la checkpoints | head

# FIR (filter-only)
!python colab/run_finetune_sync_fir_colab.py \
  --checkpoint_path checkpoints/ctm_checkpoint.pt \
  --n_train 2000 --n_val 500 \
  --batch_size 4 \
  --epochs 2 \
  --lr 1e-3

# IIR (filter-only) (uncomment)
# !python colab/run_finetune_sync_iir_colab.py \
#   --checkpoint_path checkpoints/ctm_checkpoint.pt \
#   --n_train 2000 --n_val 500 \
#   --batch_size 4 \
#   --epochs 2 \
#   --lr 1e-4

# MultiBand (filters + q_proj + output_projector) (uncomment)
# !python colab/run_finetune_sync_multiband_colab.py \
#   --checkpoint_path checkpoints/ctm_checkpoint.pt \
#   --n_train 2000 --n_val 500 \
#   --batch_size 4 \
#   --epochs 2 \
#   --lr 1e-3 \
#   --band_ks 8 16 32



cwd: /content/continuous-thought-machines/CTM-sync
total 80
drwxr-xr-x 12 root root  4096 Jan 12 06:02 .
drwxr-xr-x 12 root root  4096 Jan 12 06:02 ..
drwxr-xr-x  2 root root  4096 Jan 12 06:02 assets
drwxr-xr-x  2 root root  4096 Jan 12 06:02 checkpoints
drwxr-xr-x  2 root root  4096 Jan 12 06:02 colab
drwxr-xr-x  2 root root  4096 Jan 12 06:02 data
drwxr-xr-x  2 root root  4096 Jan 12 06:02 examples
drwxr-xr-x  8 root root  4096 Jan 12 06:02 .git
-rw-r--r--  1 root root   343 Jan 12 06:02 .gitignore
total 52
drwxr-xr-x  2 root root  4096 Jan 12 06:02 .
drwxr-xr-x 12 root root  4096 Jan 12 06:02 ..
-rw-r--r--  1 root root 10714 Jan 12 06:02 finetune_sync_imagenet_minimal.ipynb
-rw-r--r--  1 root root  1424 Jan 12 06:02 README.md
-rw-r--r--  1 root root  6588 Jan 12 06:02 run_finetune_sync_fir_colab.py
-rw-r--r--  1 root root  6446 Jan 12 06:02 run_finetune_sync_iir_colab.py
-rw-r--r--  1 root root  6853 Jan 12 06:02 run_finetune_sync_multiband_colab.py
-rw-r--r--  1 root root  3203 Ja