<a href="https://colab.research.google.com/github/alexstj0hn/task-arithmetic/blob/main/notebooks/colab_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Protein Task Vectors — Phase 1 Training (Colab)

Run on a free T4 GPU. Train one property at a time.

**Before starting:** Runtime → Change runtime type → **T4 GPU**

## Step 1: Mount Google Drive (for persistent storage)

In [1]:
from google.colab import drive
drive.mount('/content/drive')

# Create persistent directory on Google Drive
!mkdir -p /content/drive/MyDrive/protein-task-vectors/checkpoints
!mkdir -p /content/drive/MyDrive/protein-task-vectors/zero_shot
!mkdir -p /content/drive/MyDrive/protein-task-vectors/phase1_metrics
!mkdir -p /content/drive/MyDrive/protein-task-vectors/task_vectors
print('Google Drive mounted. Checkpoints will persist between sessions.')

Mounted at /content/drive
Google Drive mounted. Checkpoints will persist between sessions.


## Step 2: Clone repo and install dependencies

In [5]:
# CHANGE THIS to your GitHub repo URL
REPO_URL = "https://github.com/alexstj0hn/task-arithmetic.git"

import os
if os.path.exists('/content/task-arithmetic'):
    %cd /content/task-arithmetic
    !git pull
else:
    !git clone {REPO_URL} /content/task-arithmetic
    %cd /content/task-arithmetic

!pip install -e . -q
print('\nDependencies installed.')

Cloning into '/content/task-arithmetic'...
remote: Enumerating objects: 55, done.
remote: Counting objects: 100% (55/55), done.
remote: Compressing objects: 100% (49/49), done.
remote: Total 55 (delta 3), reused 55 (delta 3), pack-reused 0 (from 0)
Receiving objects: 100% (55/55), 93.33 KiB | 18.67 MiB/s, done.
Resolving deltas: 100% (3/3), done.
/content/task-arithmetic
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 76.5 MB/s eta 0:00:00
  Building editable for protein-property-vectors (pyproject.toml) ... done

Dependencies installed.


In [8]:
# Install MMseqs2
!cd /tmp && wget -q https://mmseqs.com/latest/mmseqs-linux-avx2.tar.gz && tar xzf mmseqs-linux-avx2.tar.gz && cp mmseqs/bin/mmseqs /usr/local/bin/
!mmseqs version

01683a607f83878e95436632d73e1d7d9ae30955


In [10]:
# Verify GPU
import torch
print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
print(f'bfloat16: {torch.cuda.is_bf16_supported()}')

GPU: Tesla T4
VRAM: 15.6 GB
bfloat16: True
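The `bfloat16: True` line does not necessarily mean the T4 has fast native bf16. The Tesla T4 is a Turing GPU (compute capability 7.5), while native bf16 tensor-core support arrived with Ampere (8.0), and depending on the PyTorch version `torch.cuda.is_bf16_supported()` may also count emulated support. That is consistent with the Colab overrides selecting `fp16`. A minimal sketch of a capability-based check; `pick_mixed_precision` is a hypothetical helper, not part of the repo:

```python
import torch

def pick_mixed_precision(capability):
    """Hypothetical helper: suggest a mixed-precision mode from a (major, minor) CUDA capability.

    Native bf16 tensor cores start at Ampere (major version 8); older GPUs such as
    the Turing T4 (7.5) are better served by fp16.
    """
    major, _minor = capability
    return 'bf16' if major >= 8 else 'fp16'

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)  # e.g. (7, 5) on a T4
    print(f'Compute capability {cap} -> suggested mixed_precision: {pick_mixed_precision(cap)}')
```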


In [11]:
# Create merged config with T4-safe settings
# (base config assumes A100 80GB; T4 has only 16GB)
import yaml

with open('configs/train_config.yaml') as f:
    config = yaml.safe_load(f)

with open('configs/colab_overrides.yaml') as f:
    overrides = yaml.safe_load(f)

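# Recursively merge overrides into config: nested dicts are merged key by key,
# any other value (scalar, list) in overrides replaces the base value
def deep_merge(base, override):
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_merge(base[key], value)
        else:
            base[key] = value
    return base

deep_merge(config, overrides)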

# Write merged config
with open('configs/train_config_colab.yaml', 'w') as f:
    yaml.dump(config, f, default_flow_style=False, sort_keys=False)

print('Created configs/train_config_colab.yaml with T4-safe settings:')
print(f'  mixed_precision: {config["training"]["mixed_precision"]}')
print(f'  batch_size: {config["training"]["batch_size"]}')
print(f'  list_size: {config["training"]["list_size"]}')
print(f'  grad_accum: {config["training"]["gradient_accumulation_steps"]}')
print(f'  eval_batch_size: {config["evaluation"]["eval_batch_size"]}')

Created configs/train_config_colab.yaml with T4-safe settings:
  mixed_precision: fp16
  batch_size: 2
  list_size: 16
  grad_accum: 16
  eval_batch_size: 16
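With gradient accumulation, gradients from `gradient_accumulation_steps` micro-batches are accumulated before each optimizer update, so the T4 settings keep a reasonably large effective batch while each forward pass stays small. A sketch of the arithmetic, assuming `list_size` is the number of variants per training list (an interpretation of the key name, not something the config states); `effective_sizes` is hypothetical:

```python
def effective_sizes(training_cfg):
    """Hypothetical helper: per-step sizes implied by the training config.

    Assumes each training example is a list of `list_size` variants.
    """
    lists_per_step = training_cfg['batch_size'] * training_cfg['gradient_accumulation_steps']
    return {
        'lists_per_optimizer_step': lists_per_step,
        'variants_per_optimizer_step': lists_per_step * training_cfg['list_size'],
        'variants_per_forward_pass': training_cfg['batch_size'] * training_cfg['list_size'],
    }

# Values printed by the cell above
colab_training = {'batch_size': 2, 'list_size': 16, 'gradient_accumulation_steps': 16}
print(effective_sizes(colab_training))
# {'lists_per_optimizer_step': 32, 'variants_per_optimizer_step': 512, 'variants_per_forward_pass': 32}
```

In the notebook, `effective_sizes(config['training'])` gives the same numbers for the merged config.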


## Step 3: Symlink results to Google Drive

This way, checkpoints survive Colab disconnects.

In [13]:
import os
import shutil

DRIVE_DIR = '/content/drive/MyDrive/protein-task-vectors'
REPO_DIR = '/content/task-arithmetic'

os.makedirs(os.path.join(REPO_DIR, 'results'), exist_ok=True)

# Symlink results subdirs to Google Drive
for subdir in ['checkpoints', 'zero_shot', 'phase1_metrics', 'task_vectors']:
    local = os.path.join(REPO_DIR, 'results', subdir)
    remote = os.path.join(DRIVE_DIR, subdir)
    if os.path.islink(local):
        print(f'  {subdir}: already symlinked')
    else:
        os.makedirs(remote, exist_ok=True)  # in case Step 1 was skipped
        if os.path.isdir(local):
            # Copy local files to Drive (skipping names already there), then delete the local dir
            for name in os.listdir(local):
                src = os.path.join(local, name)
                dst = os.path.join(remote, name)
                if os.path.exists(dst):
                    continue
                if os.path.isdir(src):
                    shutil.copytree(src, dst)
                else:
                    shutil.copy2(src, dst)
            shutil.rmtree(local)
        os.symlink(remote, local)
        print(f'  {subdir}: symlinked to Drive')

print('\nResults will be saved to Google Drive automatically.')

  checkpoints: symlinked to Drive
  zero_shot: symlinked to Drive
  phase1_metrics: symlinked to Drive
  task_vectors: symlinked to Drive

Results will be saved to Google Drive automatically.
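If a link is missing, checkpoints silently land on the ephemeral runtime disk, so it can be worth re-checking the links at the start of each session. A minimal sketch; `check_drive_links` is a hypothetical helper, not part of the repo:

```python
import os

def check_drive_links(repo_dir, drive_dir, subdirs):
    """Return a list of problems; an empty list means every results subdir points at Drive."""
    problems = []
    for subdir in subdirs:
        local = os.path.join(repo_dir, 'results', subdir)
        expected = os.path.realpath(os.path.join(drive_dir, subdir))
        if not os.path.islink(local):
            problems.append(f'{subdir}: not a symlink')
        elif os.path.realpath(local) != expected:
            problems.append(f'{subdir}: points to {os.path.realpath(local)}, expected {expected}')
    return problems
```

In this notebook: `check_drive_links(REPO_DIR, DRIVE_DIR, ['checkpoints', 'zero_shot', 'phase1_metrics', 'task_vectors'])` should return `[]`.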


## Step 4: Download data

Downloads ProteinGym (~500MB). Only runs once — skips if already downloaded.

In [14]:
!python -m src.data.download --config configs/train_config.yaml


Downloading reference file from https://raw.githubusercontent.com/OATML-Markslab/ProteinGym/main/reference_files/DMS_substitutions.csv
Reference CSV: 213kB [00:00, 710kB/s]                 
Downloaded to: data/raw/DMS_substitutions.csv

Downloading DMS assays from https://marks.hms.harvard.edu/proteingym/ProteinGym_v1.3/DMS_ProteinGym_substitutions.zip
This may take a few minutes (~500MB)...
DMS Assays ZIP: 43.0MB [00:04, 8.78MB/s]                
Downloaded to: data/raw/DMS_ProteinGym_substitutions.zip

Extracting ZIP to data/raw/DMS_ProteinGym_substitutions...
Removed ZIP file: data/raw/DMS_ProteinGym_substitutions.zip

Validating download...
✓ Validation successful!
  Found 217 / 217 assay files

✓ Download complete. Data saved to: data/raw
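Unlike `results/`, `data/raw` is not symlinked to Drive, so a fresh runtime may need to download again. A quick completeness check, as a sketch: `missing_assays` is hypothetical and assumes each assay is a CSV named exactly as in the reference file's `DMS_filename` column.

```python
import csv
from pathlib import Path

def missing_assays(reference_csv, assay_dir):
    """Return DMS_filename entries from the reference CSV with no matching file under assay_dir."""
    with open(reference_csv, newline='') as fh:
        expected = {row['DMS_filename'] for row in csv.DictReader(fh)}
    # rglob so the check works whether or not the ZIP extracted into a nested folder
    present = {p.name for p in Path(assay_dir).rglob('*.csv')}
    return sorted(expected - present)
```

For example, `missing_assays('data/raw/DMS_substitutions.csv', 'data/raw/DMS_ProteinGym_substitutions')` should return `[]` after a complete download.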


## Step 5: Categorize and split (if not already done)

In [15]:
import os

if not os.path.exists('data/processed/category_assignments.json'):
    !python -m src.data.categorize --config configs/train_config.yaml
else:
    print('Already categorized.')

if not os.path.exists('data/splits/train_assays.json'):
    !python -m src.data.splits --config configs/train_config.yaml
else:
    print('Splits already created.')

Already categorized.
Splits already created.


## Step 6: Zero-shot baseline

Scores all assays with ESM-2 masked marginal likelihood.
This takes ~2-4 hours for all 217 assays on T4. Skips already-scored assays.
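For a substitution from wild-type residue `wt` to `mt` at position i, the masked-marginal score is the log-odds log p(mt) - log p(wt) read from ESM-2's output at position i while that position is masked; for multi-mutants the per-site log-odds are summed. Higher scores mean the model finds the mutant more plausible, and ProteinGym typically evaluates such scores by Spearman correlation with the DMS measurements. A pure-Python sketch of the scoring arithmetic, assuming the per-position log-probabilities are already computed; the names are hypothetical and `scripts/04_zero_shot.py` may be organized differently:

```python
import math

def parse_mutant(mutant):
    """Split a ProteinGym-style mutant like 'A23G:L45P' into (wt, pos, mt) tuples (1-based pos)."""
    return [(m[0], int(m[1:-1]), m[-1]) for m in mutant.split(':')]

def masked_marginal_score(mutant, wt_seq, log_probs):
    """Sum of log p(mt) - log p(wt) over mutated sites.

    Assumes log_probs[i][aa] = log p(aa at position i | sequence with position i masked).
    """
    score = 0.0
    for wt, pos, mt in parse_mutant(mutant):
        i = pos - 1  # 0-based index into the sequence
        if wt_seq[i] != wt:
            raise ValueError(f'{wt}{pos}: wild type at that position is {wt_seq[i]}')
        score += log_probs[i][mt] - log_probs[i][wt]
    return score

# Toy example with made-up probabilities for a 3-residue sequence
wt_seq = 'MKT'
log_probs = [
    {'M': math.log(0.9), 'K': math.log(0.1)},
    {'K': math.log(0.5), 'R': math.log(0.25)},
    {'T': math.log(0.4), 'S': math.log(0.4)},
]
print(masked_marginal_score('K2R', wt_seq, log_probs))  # log(0.25 / 0.5) ≈ -0.693
```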

In [None]:
!python scripts/04_zero_shot.py --config configs/train_config_colab.yaml

2026-02-18 08:35:46,349 [INFO] numexpr.utils: NumExpr defaulting to 2 threads.
Loaded reference file: data/raw/DMS_substitutions.csv
  Shape: (217, 46)
  Columns: ['DMS_index', 'DMS_id', 'DMS_filename', 'UniProt_ID', 'taxon', 'source_organism', 'target_seq', 'seq_len', 'includes_multiple_mutants', 'DMS_total_number_mutants']...
2026-02-18 08:35:46,546 [INFO] __main__: Scoring 217 assays...
2026-02-18 08:35:59,264 [INFO] __main__: Loading facebook/esm2_t33_650M_UR50D...
2026-02-18 08:35:59,562 [INFO] httpx: HTTP Request: HEAD https://huggingface.co/facebook/esm2_t33_650M_UR50D/resolve/main/config.json "HTTP/1.1 307 Temporary Redirect"
2026-02-18 08:35:59,567 [INFO] httpx: HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/facebook/esm2_t33_650M_UR50D/08e4846e537177426273712802403f7ba8261b6c/config.json "HTTP/1.1 200 OK"
2026-02-18 08:35:59,574 [INFO] httpx: HTTP Request: GET https://huggingface.co/api/resolve-cache/models/facebook/esm2_t33_650M_UR50D/08e4846e537177426273

## Step 7: Train property models

Train ONE property per Colab session.
Change `PROPERTY` below and run a new session for each.

Order: stability → binding → expression → activity

Each takes ~2-4 hours on T4.
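Since each session trains a single property, a small guard helps catch a typo in `PROPERTY` before a multi-hour run and makes the recommended order explicit. A sketch with hypothetical helpers (not part of the repo); which properties are finished has to be supplied by you:

```python
# Recommended training order from the notes above
PROPERTY_ORDER = ['stability', 'binding', 'expression', 'activity']

def check_property(prop):
    """Raise early if prop is not one of the four supported properties."""
    if prop not in PROPERTY_ORDER:
        raise ValueError(f'PROPERTY must be one of {PROPERTY_ORDER}, got {prop!r}')
    return prop

def next_property(completed):
    """First property in PROPERTY_ORDER not yet in `completed`, or None when all are done."""
    for prop in PROPERTY_ORDER:
        if prop not in completed:
            return prop
    return None
```

After setting `PROPERTY` below, `check_property(PROPERTY)` fails fast on an invalid name; `next_property({'stability'})` returns `'binding'`.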

In [None]:
#########################################
# CHANGE THIS for each training session #
#########################################
PROPERTY = "stability"  # stability | binding | expression | activity

In [None]:
!python scripts/05_train_property_models.py \
    --config configs/train_config_colab.yaml \
    --property {PROPERTY} \
    --resume

## Step 8: Evaluate (after all 4 properties are trained)

In [None]:
# Per-property evaluation
for prop in ['stability', 'binding', 'expression', 'activity']:
    print(f'\n=== Evaluating {prop} ===')
    !python scripts/06_evaluate.py --config configs/train_config_colab.yaml --property {prop}

In [None]:
# Cross-property matrix (THE key result)
!python scripts/06_evaluate.py --config configs/train_config_colab.yaml --cross-property

In [None]:
# View the result
import pandas as pd
matrix = pd.read_csv('results/phase1_metrics/cross_property_matrix.csv', index_col=0)
print('Cross-Property Evaluation Matrix (Spearman correlation)')
print(matrix.to_string())
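Each cell of the matrix is a Spearman correlation. One plausible way such a matrix is assembled: compute a Spearman rho per assay, then average per (trained model, evaluated property) pair. The column names, the mean aggregation, and the row/column orientation below are assumptions, not necessarily what `scripts/06_evaluate.py` does. Under this layout the diagonal shows in-property performance and off-diagonal cells show transfer.

```python
import pandas as pd

def spearman(pred, target):
    """Spearman rho as the Pearson correlation of average ranks (ties share their mean rank)."""
    return pd.Series(pred).rank().corr(pd.Series(target).rank())

def cross_property_matrix(results):
    """results: one row per assay with columns train_property, eval_property, spearman."""
    return results.pivot_table(index='train_property', columns='eval_property',
                               values='spearman', aggfunc='mean')

# Toy per-assay results (made-up predictions and measurements)
rows = [
    ('stability', 'stability', 'assay_a', spearman([1, 2, 3, 4], [10, 20, 30, 40])),
    ('stability', 'binding', 'assay_b', spearman([1, 2, 3, 4], [40, 30, 20, 10])),
]
results = pd.DataFrame(rows, columns=['train_property', 'eval_property', 'assay', 'spearman'])
print(cross_property_matrix(results))
```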

## Step 9: Extract task vectors

In [None]:
!python scripts/07_extract_vectors.py --config configs/train_config_colab.yaml

In [None]:
# View cosine similarity between task vectors
import pandas as pd
sim = pd.read_csv('results/task_vectors/cosine_similarity_matrix.csv', index_col=0)
print('Task Vector Cosine Similarity')
print(sim.to_string())
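In task arithmetic, a task vector is the element-wise difference between fine-tuned and pre-trained weights, τ = θ_ft − θ_pre, and the cosine similarity between two task vectors measures how aligned the weight updates for two properties are (values near 0 suggest roughly orthogonal updates). A minimal sketch, assuming full fine-tuning checkpoints whose state-dict keys match the base ESM-2 model; non-floating buffers are skipped. The function names are hypothetical, and `scripts/07_extract_vectors.py` may differ, for example if only some layers were trained.

```python
import torch

def task_vector(pretrained_state, finetuned_state):
    """tau = theta_finetuned - theta_pretrained for every floating-point tensor present in both."""
    return {
        name: finetuned_state[name] - param
        for name, param in pretrained_state.items()
        if name in finetuned_state and torch.is_floating_point(param)
    }

def cosine_similarity(tv_a, tv_b):
    """Cosine similarity of two task vectors, flattened over their shared parameter names."""
    shared = sorted(set(tv_a) & set(tv_b))
    a = torch.cat([tv_a[k].flatten().float() for k in shared])
    b = torch.cat([tv_b[k].flatten().float() for k in shared])
    return torch.nn.functional.cosine_similarity(a, b, dim=0).item()

# Toy state dicts; the integer 'step' buffer is ignored
pre = {'w': torch.zeros(2, 2), 'step': torch.tensor(0)}
ft_a = {'w': torch.tensor([[1.0, 0.0], [0.0, 0.0]]), 'step': torch.tensor(5)}
ft_b = {'w': torch.tensor([[1.0, 1.0], [0.0, 0.0]]), 'step': torch.tensor(9)}
print(cosine_similarity(task_vector(pre, ft_a), task_vector(pre, ft_b)))  # 1/sqrt(2) ≈ 0.7071
```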

## Done!

All results are saved to your Google Drive at:
- `My Drive/protein-task-vectors/checkpoints/`: trained models
- `My Drive/protein-task-vectors/zero_shot/`: zero-shot baseline scores
- `My Drive/protein-task-vectors/phase1_metrics/`: evaluation results
- `My Drive/protein-task-vectors/task_vectors/`: extracted vectors

You can close this notebook. Everything persists on Drive.