## CTM Sync Filter Fine-tuning (Minimal ImageNet via Streaming)

This notebook is meant for **Colab (free tier)** or **VS Code/Cursor + Colab kernel**.

What it does:
- Installs minimal dependencies
- Downloads the provided Drive checkpoint zip
- Fine-tunes **only** the new synchronization filter parameters (FIR, IIR, or MultiBand) on a tiny streamed ImageNet subset
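
As a rough illustration of what "fine-tunes only the filter parameters" means, the sketch below freezes every parameter whose name does not match a keyword. It is not the repo's code: the helper `freeze_all_but`, the keyword `"sync_filter"`, and the parameter names are all hypothetical, and stand-in objects replace real tensors so the snippet runs without PyTorch.

```python
from types import SimpleNamespace


def freeze_all_but(named_params, keywords):
    """Enable gradients only for parameters whose name contains a keyword.

    `named_params` is any iterable of (name, param) pairs where `param` has a
    `requires_grad` attribute, e.g. `model.named_parameters()` in PyTorch.
    Returns the list of names left trainable.
    """
    trainable = []
    for name, param in named_params:
        keep = any(key in name for key in keywords)
        param.requires_grad = keep
        if keep:
            trainable.append(name)
    return trainable


# Stand-in parameters (hypothetical names, not taken from the CTM-sync repo).
fake_params = [
    ("backbone.conv1.weight", SimpleNamespace(requires_grad=True)),
    ("synapses.0.weight", SimpleNamespace(requires_grad=True)),
    ("sync_filter_out.taps", SimpleNamespace(requires_grad=True)),
    ("sync_filter_action.taps", SimpleNamespace(requires_grad=True)),
]

trainable_names = freeze_all_but(fake_params, keywords=("sync_filter",))
print(trainable_names)
```

With a real model, the same helper could be called on `model.named_parameters()`, after which the optimizer would be given only the parameters that still have `requires_grad=True`.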

Notes:
- The Google Drive file is a **zip** that contains the actual `.pt` checkpoint.
- ImageNet is streamed rather than downloaded in full; only `N` samples are taken (set with `--n_train` and `--n_val` in the fine-tuning cell).
- This is for **sanity-check / prototyping**, not full ImageNet benchmarking.
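
The "take only `N` samples from a stream" idea in the notes above can be sketched as follows. This is an assumption about the general approach, not the repo's loader: `take_n`, `stream_imagenet_subset`, and `fake_stream` are hypothetical names, and the dataset id `ILSVRC/imagenet-1k` is assumed to be the gated Hugging Face copy. The Hugging Face call is wrapped in a function that is not executed here, since it needs network access and an approved token (read from `HF_TOKEN`).

```python
from itertools import islice


def take_n(stream, n):
    """Lazily yield at most `n` items from any iterable; the rest is never read."""
    return islice(stream, n)


def stream_imagenet_subset(n, split="train"):
    """Sketch only: first `n` examples of a streamed ImageNet split.

    Assumes the `datasets` package and access to the gated dataset.
    """
    from datasets import load_dataset  # imported lazily; not needed for the demo

    ds = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True)
    return take_n(ds, n)


def fake_stream():
    """Endless stand-in for a streamed dataset (hypothetical records)."""
    i = 0
    while True:
        yield {"image": None, "label": i % 1000}
        i += 1


subset = list(take_n(fake_stream(), 5))
print(len(subset), [example["label"] for example in subset])
```

Because the stream is consumed lazily, only about `N` examples are ever fetched, which is what keeps the download small.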



In [1]:
# If you're in Colab: Runtime -> Change runtime type -> GPU
# If you're in VS Code/Cursor with the Colab kernel: select a Colab GPU runtime.

!nvidia-smi -L || true



GPU 0: Tesla T4 (UUID: GPU-3ab31fc9-d008-645c-59b8-cba842c7269a)


In [2]:
# Install minimal deps for running the CTM code + streaming ImageNet.
# (If you already have the repo environment, you can skip.)

!pip -q install --upgrade pip
!pip -q install torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip -q install datasets huggingface_hub safetensors tqdm pillow gdown python-dotenv



In [3]:
# Clone the repo (or skip if you're already in it)
# IMPORTANT: use your repo (it contains the new sync-filter + colab files)

import getpass
import os
import subprocess

# Prompt for the Hugging Face token (needed for gated ImageNet-1k); input is hidden.
os.environ["HF_TOKEN"] = getpass.getpass("Paste HF token (hidden): ")

REPO_URL = "https://github.com/aryangoyal7/CTM-sync.git"
REPO_DIR = "CTM-sync"

# Always start from /content to avoid nesting CTM-sync/CTM-sync/... if this cell is re-run.
os.chdir("/content")

if not os.path.exists(REPO_DIR):
    subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, REPO_DIR])

os.chdir(REPO_DIR)
print("cwd:", os.getcwd())
print("has colab/:", os.path.exists("colab"))
print("has models/ctm_sync_filters.py:", os.path.exists("models/ctm_sync_filters.py"))

# Alternative ways to provide the HF token for gated ImageNet-1k
# (uncomment to use instead of the getpass prompt above).
# Recommended: create /content/CTM-sync/colab/.env with: HF_TOKEN=hf_...
# (Do NOT commit this file.)
# from dotenv import load_dotenv
# load_dotenv("colab/.env", override=False)
# print("HF_TOKEN present:", bool(os.environ.get("HF_TOKEN")))
# Optional interactive login if you didn't set HF_TOKEN:
# if not os.environ.get("HF_TOKEN"):
#     from huggingface_hub import notebook_login
#     notebook_login()



cwd: /content/CTM-sync
has colab/: True
has models/ctm_sync_filters.py: True


In [4]:
# Download the Drive checkpoint zip (your link) into ./checkpoints

!mkdir -p checkpoints
!pip -q install gdown

CHECKPOINT_URL = "https://drive.google.com/file/d/1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ/view?usp=drive_link"
!gdown --fuzzy "{CHECKPOINT_URL}" -O checkpoints/ctm_checkpoint.pt
!ls -lh checkpoints/ctm_checkpoint.pt



Downloading...
From (original): https://drive.google.com/uc?id=1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ
From (redirected): https://drive.google.com/uc?id=1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ&confirm=t&uuid=16b34c5a-ac71-4a32-a419-322d9eb5344d
To: /content/CTM-sync/checkpoints/ctm_checkpoint.pt
100% 691M/691M [00:04<00:00, 149MB/s]  
-rw-r--r-- 1 root root 659M May 11  2025 checkpoints/ctm_checkpoint.pt
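
The intro notes that the Drive file is a zip wrapping the actual `.pt` checkpoint, while the cell above saves it directly under a `.pt` name. A hedged way to handle either case is sketched below. `resolve_checkpoint` is a hypothetical helper, not part of the repo. It extracts an inner `.pt` member if one exists and otherwise returns the path unchanged. The second case matters because files written by recent versions of `torch.save` are themselves zip archives, but contain no inner `.pt` member.

```python
import os
import shutil
import zipfile


def resolve_checkpoint(path, out_dir):
    """Return a path to a `.pt` file, extracting it if `path` is a zip wrapping one."""
    if not zipfile.is_zipfile(path):
        return path  # not a zip at all: use as-is
    with zipfile.ZipFile(path) as zf:
        members = [m for m in zf.namelist() if m.endswith(".pt")]
        if not members:
            # Likely a plain torch.save archive: nothing to unwrap.
            return path
        os.makedirs(out_dir, exist_ok=True)
        # basename() drops any directory components stored inside the zip.
        target = os.path.join(out_dir, os.path.basename(members[0]))
        if os.path.abspath(target) == os.path.abspath(path):
            raise ValueError("out_dir would overwrite the zip; choose another directory")
        with zf.open(members[0]) as src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)
        return target
```

For example, `resolve_checkpoint("checkpoints/ctm_checkpoint.pt", "checkpoints/extracted")` would return a path that can be passed to `--checkpoint_path`. The output directory name here is only an illustration.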


In [5]:
# Fix import path in Colab notebooks (ensures repo root is on sys.path)
import os
import sys

sys.path.insert(0, os.getcwd())

import torch
from models.ctm_sync_filters import ContinuousThoughtMachineFIR, ContinuousThoughtMachineIIR, ContinuousThoughtMachineMultiBand

device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep this small if backbone_type='none' (attention tokens = H*W)
x = torch.randn(1, 3, 32, 32, device=device)

m = ContinuousThoughtMachineFIR(
    iterations=2,
    d_model=128,
    d_input=64,
    heads=2,
    n_synch_out=32,
    n_synch_action=32,
    synapse_depth=1,
    memory_length=4,
    deep_nlms=False,
    memory_hidden_dims=4,
    do_layernorm_nlm=False,
    backbone_type="none",
    positional_embedding_type="none",
    out_dims=1000,
    prediction_reshaper=[-1],
    dropout=0.0,
    dropout_nlm=None,
    neuron_select_type="random-pairing",
    n_random_pairing_self=0,
    fir_k=4,
).to(device).eval()

with torch.no_grad():
    preds, certs, sync = m(x)

preds.shape, certs.shape, sync.shape

Using neuron select type: random-pairing
Synch representation size action: 32
Synch representation size out: 32


(torch.Size([1, 1000, 2]), torch.Size([1, 2, 2]), torch.Size([1, 32]))
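
For readers unfamiliar with the FIR / IIR terminology, the sketch below shows the textbook definitions on a plain Python list. It is a generic signal-processing illustration, not the repo's implementation. How `ContinuousThoughtMachineFIR` actually applies its filter is not shown in this notebook, and reading `fir_k=4` as "four taps" is an assumption.

```python
def fir_filter(x, taps):
    """FIR: y[t] = sum_k taps[k] * x[t - k], with samples before t=0 treated as zero."""
    out = []
    for t in range(len(x)):
        acc = 0.0
        for k, h in enumerate(taps):
            if t - k >= 0:
                acc += h * x[t - k]
        out.append(acc)
    return out


def iir_filter(x, alpha):
    """One-pole IIR: y[t] = alpha * y[t-1] + (1 - alpha) * x[t], with y[-1] = 0."""
    out, prev = [], 0.0
    for value in x:
        prev = alpha * prev + (1.0 - alpha) * value
        out.append(prev)
    return out


impulse = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# Four equal taps (assumed analogue of fir_k=4): the response ends after 4 steps.
fir_response = fir_filter(impulse, taps=[0.25, 0.25, 0.25, 0.25])

# The recursive filter never fully forgets the impulse; it decays geometrically.
iir_response = iir_filter(impulse, alpha=0.5)

print(fir_response)
print(iir_response)
```

The practical difference is memory: an FIR filter looks at a fixed window of past values, while an IIR filter carries an exponentially decaying summary of the whole history.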

In [12]:
# Run fine-tuning (FIR / IIR / MultiBand) on a tiny streamed ImageNet subset.
# Only about N samples' worth of images are downloaded (N = n_train + n_val).
# NOTE: We should already be inside the cloned repo (CTM-sync) from the cell above.

import os
print("cwd:", os.getcwd())

# Sanity checks
!ls -la | head
!ls -la colab | head
!ls -la checkpoints | head

# FIR (filter-only)
# Running as a module keeps the repo root on sys.path
!python -m colab.run_finetune_sync_fir_colab \
  --checkpoint_path checkpoints/ctm_checkpoint.pt \
  --n_train 2000 --n_val 500 \
  --batch_size 4 \
  --epochs 2 \
  --lr 1e-3

# IIR (filter-only) (uncomment)
# !python -m colab.run_finetune_sync_iir_colab \
#   --checkpoint_path checkpoints/ctm_checkpoint.pt \
#   --n_train 2000 --n_val 500 \
#   --batch_size 4 \
#   --epochs 2 \
#   --lr 1e-4

# MultiBand (filter-only, reduced back to original sync dimensionality) (uncomment)
# !python -m colab.run_finetune_sync_multiband_colab \
#   --checkpoint_path checkpoints/ctm_checkpoint.pt \
#   --n_train 2000 --n_val 500 \
#   --batch_size 4 \
#   --epochs 2 \
#   --lr 1e-3 \
#   --band_ks 8 16 32



cwd: /content/CTM-sync
total 84
drwxr-xr-x 13 root root  4096 Jan 12 06:16 .
drwxr-xr-x  1 root root  4096 Jan 12 06:15 ..
drwxr-xr-x  2 root root  4096 Jan 12 06:15 assets
drwxr-xr-x  3 root root  4096 Jan 12 06:58 checkpoints
drwxr-xr-x  3 root root  4096 Jan 12 06:16 colab
drwxr-xr-x  2 root root  4096 Jan 12 06:15 data
drwxr-xr-x  2 root root  4096 Jan 12 06:15 examples
drwxr-xr-x  8 root root  4096 Jan 12 06:15 .git
-rw-r--r--  1 root root   343 Jan 12 06:15 .gitignore
total 68
drwxr-xr-x  3 root root  4096 Jan 12 06:16 .
drwxr-xr-x 13 root root  4096 Jan 12 06:16 ..
-rw-r--r--  1 root root 17236 Jan 12 06:15 finetune_sync_imagenet_minimal.ipynb
-rw-r--r--  1 root root   187 Jan 12 06:15 __init__.py
drwxr-xr-x  2 root root  4096 Jan 12 06:16 __pycache__
-rw-r--r--  1 root root  1424 Jan 12 06:15 README.md
-rw-r--r--  1 root root  6908 Jan 12 06:15 run_finetune_sync_fir_colab.py
-rw-r--r--  1 root root  6766 Jan 12 06:15 run_finetune_sync_iir_colab.py
-rw-r--r--  1 root root  7173 