#Set-up

In [1]:
from google.colab import drive
drive.mount('/content/drive')

!pip install transformers

Mounted at /content/drive


In [2]:
# set seeds
import random
import numpy as np
import torch

def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)

  torch.backends.cudnn.deterministic = True #- Forces cuDNN to use deterministic algorithms
  torch.backends.cudnn.benchmark = False #

set_seed(42)

In [12]:
# Add the specific directory to the Python path
import sys
sys.path.append('/content/drive/MyDrive/generally_useful_scripts')

# Import all functions from utils_addgene
from utils_addgene import *
import utils_addgene
import importlib
importlib.reload(utils_addgene)


## load custom functions from utils.py

import sys
sys.path.append('//content/drive/MyDrive/SAEs_for_Genomics')

import importlib
import utils
importlib.reload(utils)

<module 'utils' from '//content/drive/MyDrive/SAEs_for_Genomics/utils.py'>

# Load NT model

In [4]:
"loading smallest nucleotide transformer (50m params)"
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig
import torch

num_params = 50 ## default 50

# Import the tokenizer and the model
tokenizer_nt = AutoTokenizer.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{num_params}m-multi-species", trust_remote_code=True)
model_nt = AutoModelForMaskedLM.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{num_params}m-multi-species", trust_remote_code=True)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

esm_config.py:   0%|          | 0.00/14.9k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-multi-species:
- esm_config.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_esm.py:   0%|          | 0.00/58.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-multi-species:
- modeling_esm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/224M [00:00<?, ?B/s]

In [5]:
# Optionally we can load weights from extra pretraining
# lets load addgene-pretrained weights
model_nt_extra_pre = model_nt
weights_path = '/content/drive/MyDrive/SAEs_for_Genomics/nt_weights_extra_pretraining/NT50Mm_addgenepretrained_epoch_4.pt'
model_nt_extra_pre.load_state_dict(torch.load(weights_path)['model_state_dict'])

<All keys matched successfully>

# Set-up & Load SAE

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "lr": 5e-5,
    "num_tokens": tokenizer_nt.vocab_size,
    "d_model": model_nt.config.hidden_size,
    "l1_coeff": 1e-1,
    "l0_coeff": 1,
    "beta1": 0.9,
    "beta2": 0.999,
    "dict_mult": 8, # hidden_d = d_model * dict_mult
    "seq_len": 512,
    "d_mlp": model_nt.config.hidden_size,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
    "total_training_steps": 10000,
    "lr_warm_up_steps": 1000,
    "device": "cuda",
    "tempterature":0.05,
    "activation_treshold":0.3,
    "enc_dtype": torch.float32

}

sae_model = utils.AutoEncoder(cfg)

In [None]:
#count params of sae_model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Number of parameters in sae_model: {count_parameters(sae_model)}')

Number of parameters in sae_model: 4198912


# Load and preprocess addgene dataset

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np


# Constants
TEST_DATA_PATH = '/content/drive/MyDrive/NOO_paper/datasets/worldwide/blast_geac_ext_169k_val_random.csv'
TRAIN_DATA_PATH = '/content/drive/MyDrive/NOO_paper/datasets/worldwide/blast_geac_ext_169k_train_random.csv'
INFREQUENT_THRESHOLD = 10

# Preprocess data: each data point has shape (DNA_seq: str, nation_idx: int)
df_train, df_val = utils_addgene.preprocess_data(TRAIN_DATA_PATH, TEST_DATA_PATH)

# Display the split data
print("Train Data Shape:", df_train.shape)
print("Validation Data Shape:", df_val.shape)

# Create torch datasets in which sequences are tokenized
# Each data point has shape (input_ids, attention_masks, labels)
val_dataset = GenomicDataset(df_val, tokenizer_nt=tokenizer_nt)
train_dataset = GenomicDataset(df_train, tokenizer_nt=tokenizer_nt)

BS = 64*6

val_loader_dna = DataLoader(val_dataset, batch_size=BS, shuffle=False, pin_memory=True, num_workers=2)
train_loader_dna = DataLoader(train_dataset, batch_size=BS, shuffle=True, pin_memory=True, num_workers=2)

y_test shape: (15551,)
test_data shape: (15551, 4)
Train Data Shape: (93306, 2)
Validation Data Shape: (15551, 2)


# Test-run SAE



In [8]:
## test run SAE on batch of data
model_nt_extra_pre.eval().cuda()
for batch in train_loader_dna:

    input_ids = batch['input_ids'].cuda()
    attention_mask = batch['attention_mask'].cuda()
    label = batch['label'].cuda()

    # forward pass with model_nt to get acts
    with torch.no_grad():
        mlp_act = utils.get_layer_activations(model_nt_extra_pre, input_ids, attention_mask, layer_N=11)
        acts = mlp_act[0].reshape(-1, model_nt_extra_pre.config.hidden_size)

    with torch.cuda.amp.autocast():
        loss, x_reconstruct, acts_sparse, l2_loss, nmse, l1_loss, true_l0 = sae_model(acts)


    print(f'loss: {loss}')
    print(f'l2_loss: {l2_loss}')
    print(f'nmse: {nmse}')
    print(f'l1_loss: {l1_loss}')
    break


  with torch.cuda.amp.autocast():


loss: 1955.95849609375
l2_loss: 984.78076171875
nmse: 1.386866807937622
l1_loss: 609.4150390625


# Extracting Activations

In [10]:
import torch
from tqdm import tqdm
import h5py
import numpy as np
from torch.cuda.amp import autocast
import utils

def extract_activations(model_nt, dataloader, layer_N=11, save_path='activations.h5'):
    model_nt.eval()
    all_activations = []

    with h5py.File(save_path, 'w') as f:
        # Create dataset with unknown size initially
        dset = None
        current_idx = 0

        with torch.no_grad():
            total_batches = len(dataloader)
            progress_bar = tqdm(total=total_batches, desc="Extracting activations")
            for batch in dataloader:
                input_ids = batch['input_ids'].cuda()
                attention_mask = batch['attention_mask'].cuda()

                with autocast():
                    mlp_act = utils.get_layer_activations(model_nt,
                                                        input_ids,
                                                        attention_mask,
                                                        layer_N=layer_N)

                mlp_act = mlp_act[0].reshape(-1, model_nt.config.hidden_size)
                mlp_act = mlp_act.cpu().numpy()
                print(mlp_act.shape)

                # Initialize dataset with correct shape if first batch
                if dset is None:
                    total_size = len(dataloader) * mlp_act.shape[0]
                    dset = f.create_dataset('activations',
                                          shape=(total_size, mlp_act.shape[1]),
                                          dtype='float32',
                                          chunks=True)

                # Save batch
                dset[current_idx:current_idx + mlp_act.shape[0]] = mlp_act
                current_idx += mlp_act.shape[0]

                # Clear memory
                del mlp_act, input_ids, attention_mask
                torch.cuda.empty_cache()

                progress_bar.update(1)


# Run extraction
model_nt.eval()
extract_activations(model_nt.cuda(),
                    train_loader_dna,
                    save_path='/content/drive/MyDrive/SAEs_for_Genomics/plasmidpretrained_NT50_L12_mlp.activations.h5')

  with autocast():


(196608, 512)


Extracting activations:   0%|          | 1/243 [00:16<1:04:55, 16.10s/it]

(196608, 512)


Extracting activations:   1%|          | 2/243 [00:31<1:03:04, 15.70s/it]

(196608, 512)


Extracting activations:   1%|          | 3/243 [00:47<1:02:30, 15.63s/it]

(196608, 512)


Extracting activations:   2%|▏         | 4/243 [01:01<1:00:25, 15.17s/it]

(196608, 512)


Extracting activations:   2%|▏         | 5/243 [01:15<59:00, 14.88s/it]  

(196608, 512)


Extracting activations:   2%|▏         | 6/243 [01:30<57:51, 14.65s/it]

(196608, 512)


Extracting activations:   3%|▎         | 7/243 [01:45<59:11, 15.05s/it]

(196608, 512)


Extracting activations:   3%|▎         | 8/243 [02:01<59:30, 15.19s/it]

(196608, 512)


Extracting activations:   4%|▎         | 9/243 [02:16<59:31, 15.26s/it]

(196608, 512)


Extracting activations:   4%|▍         | 10/243 [02:32<1:00:05, 15.47s/it]

(196608, 512)


Extracting activations:   5%|▍         | 11/243 [02:47<58:27, 15.12s/it]  

(196608, 512)


Extracting activations:   5%|▍         | 12/243 [03:02<58:46, 15.27s/it]

(196608, 512)


Extracting activations:   5%|▌         | 13/243 [03:18<58:38, 15.30s/it]

(196608, 512)


Extracting activations:   6%|▌         | 14/243 [03:33<58:30, 15.33s/it]

(196608, 512)


Extracting activations:   6%|▌         | 15/243 [03:47<57:05, 15.03s/it]

(196608, 512)


Extracting activations:   7%|▋         | 16/243 [04:02<55:56, 14.79s/it]

(196608, 512)


Extracting activations:   7%|▋         | 17/243 [04:16<55:07, 14.63s/it]

(196608, 512)


Extracting activations:   7%|▋         | 18/243 [04:30<54:23, 14.50s/it]

(196608, 512)


Extracting activations:   8%|▊         | 19/243 [04:46<55:38, 14.90s/it]

(196608, 512)


Extracting activations:   8%|▊         | 20/243 [05:02<56:15, 15.14s/it]

(196608, 512)


Extracting activations:   9%|▊         | 21/243 [05:17<56:27, 15.26s/it]

(196608, 512)


Extracting activations:   9%|▉         | 22/243 [05:31<54:58, 14.93s/it]

(196608, 512)


Extracting activations:   9%|▉         | 23/243 [05:46<54:18, 14.81s/it]

(196608, 512)


Extracting activations:  10%|▉         | 24/243 [06:00<53:12, 14.58s/it]

(196608, 512)


Extracting activations:  10%|█         | 25/243 [06:14<52:59, 14.59s/it]

(196608, 512)


Extracting activations:  11%|█         | 26/243 [06:29<52:34, 14.54s/it]

(196608, 512)


Extracting activations:  11%|█         | 27/243 [06:43<52:02, 14.45s/it]

(196608, 512)


Extracting activations:  12%|█▏        | 28/243 [06:58<51:43, 14.43s/it]

(196608, 512)


Extracting activations:  12%|█▏        | 29/243 [07:12<51:37, 14.47s/it]

(196608, 512)


Extracting activations:  12%|█▏        | 30/243 [07:26<51:11, 14.42s/it]

(196608, 512)


Extracting activations:  13%|█▎        | 31/243 [07:41<50:53, 14.40s/it]

(196608, 512)


Extracting activations:  13%|█▎        | 32/243 [07:55<50:38, 14.40s/it]

(196608, 512)


Extracting activations:  14%|█▎        | 33/243 [08:09<50:08, 14.33s/it]

(196608, 512)


Extracting activations:  14%|█▍        | 34/243 [08:24<50:40, 14.55s/it]

(196608, 512)


Extracting activations:  14%|█▍        | 35/243 [08:40<51:18, 14.80s/it]

(196608, 512)


Extracting activations:  15%|█▍        | 36/243 [08:54<50:24, 14.61s/it]

(196608, 512)


Extracting activations:  15%|█▌        | 37/243 [09:08<49:55, 14.54s/it]

(196608, 512)


Extracting activations:  16%|█▌        | 38/243 [09:23<49:49, 14.58s/it]

(196608, 512)


Extracting activations:  16%|█▌        | 39/243 [09:37<49:09, 14.46s/it]

(196608, 512)


Extracting activations:  16%|█▋        | 40/243 [09:53<49:58, 14.77s/it]

(196608, 512)


Extracting activations:  17%|█▋        | 41/243 [10:08<50:25, 14.98s/it]

(196608, 512)


Extracting activations:  17%|█▋        | 42/243 [10:23<50:23, 15.04s/it]

(196608, 512)


Extracting activations:  18%|█▊        | 43/243 [10:38<49:20, 14.80s/it]

(196608, 512)


Extracting activations:  18%|█▊        | 44/243 [10:52<48:46, 14.70s/it]

(196608, 512)


Extracting activations:  19%|█▊        | 45/243 [11:06<47:55, 14.52s/it]

(196608, 512)


Extracting activations:  19%|█▉        | 46/243 [11:21<48:06, 14.65s/it]

(196608, 512)


Extracting activations:  19%|█▉        | 47/243 [11:37<48:42, 14.91s/it]

(196608, 512)


Extracting activations:  20%|█▉        | 48/243 [11:51<47:45, 14.70s/it]

(196608, 512)


Extracting activations:  20%|██        | 49/243 [12:07<48:40, 15.05s/it]

(196608, 512)


Extracting activations:  21%|██        | 50/243 [12:22<48:59, 15.23s/it]

(196608, 512)


Extracting activations:  21%|██        | 51/243 [12:37<48:33, 15.18s/it]

(196608, 512)


Extracting activations:  21%|██▏       | 52/243 [12:53<48:26, 15.22s/it]

(196608, 512)


Extracting activations:  22%|██▏       | 53/243 [13:07<47:25, 14.98s/it]

(196608, 512)


Extracting activations:  22%|██▏       | 54/243 [13:21<46:36, 14.80s/it]

(196608, 512)


Extracting activations:  23%|██▎       | 55/243 [13:36<46:01, 14.69s/it]

(196608, 512)


Extracting activations:  23%|██▎       | 56/243 [13:50<45:27, 14.59s/it]

(196608, 512)


Extracting activations:  23%|██▎       | 57/243 [14:06<46:09, 14.89s/it]

(196608, 512)


Extracting activations:  24%|██▍       | 58/243 [14:21<45:58, 14.91s/it]

(196608, 512)


Extracting activations:  24%|██▍       | 59/243 [14:35<45:07, 14.71s/it]

(196608, 512)


Extracting activations:  25%|██▍       | 60/243 [14:49<44:18, 14.53s/it]

(196608, 512)


Extracting activations:  25%|██▌       | 61/243 [15:03<43:49, 14.45s/it]

(196608, 512)


Extracting activations:  26%|██▌       | 62/243 [15:18<43:34, 14.45s/it]

(196608, 512)


Extracting activations:  26%|██▌       | 63/243 [15:32<43:11, 14.40s/it]

(196608, 512)


Extracting activations:  26%|██▋       | 64/243 [15:46<42:54, 14.38s/it]

(196608, 512)


Extracting activations:  27%|██▋       | 65/243 [16:02<43:50, 14.78s/it]

(196608, 512)


Extracting activations:  27%|██▋       | 66/243 [16:18<44:18, 15.02s/it]

(196608, 512)


Extracting activations:  28%|██▊       | 67/243 [16:33<43:55, 14.97s/it]

(196608, 512)


Extracting activations:  28%|██▊       | 68/243 [16:47<43:03, 14.76s/it]

(196608, 512)


Extracting activations:  28%|██▊       | 69/243 [17:01<42:14, 14.57s/it]

(196608, 512)


Extracting activations:  29%|██▉       | 70/243 [17:15<41:39, 14.45s/it]

(196608, 512)


Extracting activations:  29%|██▉       | 71/243 [17:31<42:30, 14.83s/it]

(196608, 512)


Extracting activations:  30%|██▉       | 72/243 [17:46<42:45, 15.01s/it]

(196608, 512)


Extracting activations:  30%|███       | 73/243 [18:01<42:39, 15.06s/it]

(196608, 512)


Extracting activations:  30%|███       | 74/243 [18:17<42:34, 15.12s/it]

(196608, 512)


Extracting activations:  31%|███       | 75/243 [18:31<41:48, 14.93s/it]

(196608, 512)


Extracting activations:  31%|███▏      | 76/243 [18:47<42:05, 15.12s/it]

(196608, 512)


Extracting activations:  32%|███▏      | 77/243 [19:01<41:05, 14.85s/it]

(196608, 512)


Extracting activations:  32%|███▏      | 78/243 [19:15<40:21, 14.67s/it]

(196608, 512)


Extracting activations:  33%|███▎      | 79/243 [19:30<39:51, 14.58s/it]

(196608, 512)


Extracting activations:  33%|███▎      | 80/243 [19:44<39:27, 14.53s/it]

(196608, 512)


Extracting activations:  33%|███▎      | 81/243 [19:58<39:01, 14.46s/it]

(196608, 512)


Extracting activations:  34%|███▎      | 82/243 [20:13<38:42, 14.42s/it]

(196608, 512)


Extracting activations:  34%|███▍      | 83/243 [20:27<38:17, 14.36s/it]

(196608, 512)


Extracting activations:  35%|███▍      | 84/243 [20:42<38:54, 14.68s/it]

(196608, 512)


Extracting activations:  35%|███▍      | 85/243 [20:57<38:57, 14.80s/it]

(196608, 512)


Extracting activations:  35%|███▌      | 86/243 [21:12<38:30, 14.72s/it]

(196608, 512)


Extracting activations:  36%|███▌      | 87/243 [21:26<38:03, 14.64s/it]

(196608, 512)


Extracting activations:  36%|███▌      | 88/243 [21:41<37:46, 14.63s/it]

(196608, 512)


Extracting activations:  37%|███▋      | 89/243 [21:56<37:37, 14.66s/it]

(196608, 512)


Extracting activations:  37%|███▋      | 90/243 [22:10<37:03, 14.53s/it]

(196608, 512)


Extracting activations:  37%|███▋      | 91/243 [22:25<36:51, 14.55s/it]

(196608, 512)


Extracting activations:  38%|███▊      | 92/243 [22:39<36:41, 14.58s/it]

(196608, 512)


Extracting activations:  38%|███▊      | 93/243 [22:53<36:06, 14.44s/it]

(196608, 512)


Extracting activations:  39%|███▊      | 94/243 [23:08<36:01, 14.51s/it]

(196608, 512)


Extracting activations:  39%|███▉      | 95/243 [23:24<36:37, 14.85s/it]

(196608, 512)


Extracting activations:  40%|███▉      | 96/243 [23:38<35:50, 14.63s/it]

(196608, 512)


Extracting activations:  40%|███▉      | 97/243 [23:52<35:29, 14.58s/it]

(196608, 512)


Extracting activations:  40%|████      | 98/243 [24:07<35:14, 14.58s/it]

(196608, 512)


Extracting activations:  41%|████      | 99/243 [24:21<34:51, 14.52s/it]

(196608, 512)


Extracting activations:  41%|████      | 100/243 [24:36<34:34, 14.51s/it]

(196608, 512)


Extracting activations:  42%|████▏     | 101/243 [24:50<34:19, 14.50s/it]

(196608, 512)


Extracting activations:  42%|████▏     | 102/243 [25:04<33:47, 14.38s/it]

(196608, 512)


Extracting activations:  42%|████▏     | 103/243 [25:19<33:42, 14.44s/it]

(196608, 512)


Extracting activations:  43%|████▎     | 104/243 [25:34<33:42, 14.55s/it]

(196608, 512)


Extracting activations:  43%|████▎     | 105/243 [25:48<33:17, 14.47s/it]

(196608, 512)


Extracting activations:  44%|████▎     | 106/243 [26:03<33:12, 14.55s/it]

(196608, 512)


Extracting activations:  44%|████▍     | 107/243 [26:17<32:56, 14.53s/it]

(196608, 512)


Extracting activations:  44%|████▍     | 108/243 [26:33<33:20, 14.82s/it]

(196608, 512)


Extracting activations:  45%|████▍     | 109/243 [26:48<33:33, 15.03s/it]

(196608, 512)


Extracting activations:  45%|████▌     | 110/243 [27:04<33:39, 15.18s/it]

(196608, 512)


Extracting activations:  46%|████▌     | 111/243 [27:19<33:31, 15.24s/it]

(196608, 512)


Extracting activations:  46%|████▌     | 112/243 [27:33<32:43, 14.99s/it]

(196608, 512)


Extracting activations:  47%|████▋     | 113/243 [27:48<32:21, 14.93s/it]

(196608, 512)


Extracting activations:  47%|████▋     | 114/243 [28:03<31:39, 14.72s/it]

(196608, 512)


Extracting activations:  47%|████▋     | 115/243 [28:17<31:13, 14.63s/it]

(196608, 512)


Extracting activations:  48%|████▊     | 116/243 [28:31<30:44, 14.52s/it]

(196608, 512)


Extracting activations:  48%|████▊     | 117/243 [28:47<31:03, 14.79s/it]

(196608, 512)


Extracting activations:  49%|████▊     | 118/243 [29:01<30:35, 14.68s/it]

(196608, 512)


Extracting activations:  49%|████▉     | 119/243 [29:15<30:07, 14.57s/it]

(196608, 512)


Extracting activations:  49%|████▉     | 120/243 [29:31<30:13, 14.75s/it]

(196608, 512)


Extracting activations:  50%|████▉     | 121/243 [29:46<30:26, 14.97s/it]

(196608, 512)


Extracting activations:  50%|█████     | 122/243 [30:01<29:59, 14.87s/it]

(196608, 512)


Extracting activations:  51%|█████     | 123/243 [30:16<30:00, 15.00s/it]

(196608, 512)


Extracting activations:  51%|█████     | 124/243 [30:30<29:24, 14.83s/it]

(196608, 512)


Extracting activations:  51%|█████▏    | 125/243 [30:45<28:57, 14.72s/it]

(196608, 512)


Extracting activations:  52%|█████▏    | 126/243 [30:59<28:27, 14.60s/it]

(196608, 512)


Extracting activations:  52%|█████▏    | 127/243 [31:13<28:03, 14.51s/it]

(196608, 512)


Extracting activations:  53%|█████▎    | 128/243 [31:29<28:11, 14.71s/it]

(196608, 512)


Extracting activations:  53%|█████▎    | 129/243 [31:44<28:19, 14.91s/it]

(196608, 512)


Extracting activations:  53%|█████▎    | 130/243 [31:59<28:16, 15.01s/it]

(196608, 512)


Extracting activations:  54%|█████▍    | 131/243 [32:14<27:48, 14.90s/it]

(196608, 512)


Extracting activations:  54%|█████▍    | 132/243 [32:28<27:08, 14.68s/it]

(196608, 512)


Extracting activations:  55%|█████▍    | 133/243 [32:43<27:12, 14.84s/it]

(196608, 512)


Extracting activations:  55%|█████▌    | 134/243 [32:59<27:18, 15.03s/it]

(196608, 512)


Extracting activations:  56%|█████▌    | 135/243 [33:14<27:12, 15.11s/it]

(196608, 512)


Extracting activations:  56%|█████▌    | 136/243 [33:28<26:32, 14.88s/it]

(196608, 512)


Extracting activations:  56%|█████▋    | 137/243 [33:43<25:57, 14.69s/it]

(196608, 512)


Extracting activations:  57%|█████▋    | 138/243 [33:58<26:10, 14.96s/it]

(196608, 512)


Extracting activations:  57%|█████▋    | 139/243 [34:14<26:15, 15.15s/it]

(196608, 512)


Extracting activations:  58%|█████▊    | 140/243 [34:29<26:11, 15.26s/it]

(196608, 512)


Extracting activations:  58%|█████▊    | 141/243 [34:44<25:33, 15.04s/it]

(196608, 512)


Extracting activations:  58%|█████▊    | 142/243 [35:00<25:37, 15.22s/it]

(196608, 512)


Extracting activations:  59%|█████▉    | 143/243 [35:15<25:28, 15.29s/it]

(196608, 512)


Extracting activations:  59%|█████▉    | 144/243 [35:29<24:43, 14.98s/it]

(196608, 512)


Extracting activations:  60%|█████▉    | 145/243 [35:44<24:15, 14.85s/it]

(196608, 512)


Extracting activations:  60%|██████    | 146/243 [35:58<23:43, 14.67s/it]

(196608, 512)


Extracting activations:  60%|██████    | 147/243 [36:13<23:30, 14.70s/it]

(196608, 512)


Extracting activations:  61%|██████    | 148/243 [36:27<23:06, 14.60s/it]

(196608, 512)


Extracting activations:  61%|██████▏   | 149/243 [36:42<23:12, 14.82s/it]

(196608, 512)


Extracting activations:  62%|██████▏   | 150/243 [36:57<22:48, 14.72s/it]

(196608, 512)


Extracting activations:  62%|██████▏   | 151/243 [37:13<22:58, 14.98s/it]

(196608, 512)


Extracting activations:  63%|██████▎   | 152/243 [37:28<22:45, 15.01s/it]

(196608, 512)


Extracting activations:  63%|██████▎   | 153/243 [37:43<22:29, 14.99s/it]

(196608, 512)


Extracting activations:  63%|██████▎   | 154/243 [37:57<21:55, 14.78s/it]

(196608, 512)


Extracting activations:  64%|██████▍   | 155/243 [38:11<21:29, 14.66s/it]

(196608, 512)


Extracting activations:  64%|██████▍   | 156/243 [38:26<21:25, 14.78s/it]

(196608, 512)


Extracting activations:  65%|██████▍   | 157/243 [38:41<20:58, 14.64s/it]

(196608, 512)


Extracting activations:  65%|██████▌   | 158/243 [38:55<20:37, 14.56s/it]

(196608, 512)


Extracting activations:  65%|██████▌   | 159/243 [39:10<20:21, 14.55s/it]

(196608, 512)


Extracting activations:  66%|██████▌   | 160/243 [39:25<20:39, 14.93s/it]

(196608, 512)


Extracting activations:  66%|██████▋   | 161/243 [39:40<20:08, 14.74s/it]

(196608, 512)


Extracting activations:  67%|██████▋   | 162/243 [39:55<20:10, 14.94s/it]

(196608, 512)


Extracting activations:  67%|██████▋   | 163/243 [40:11<20:10, 15.13s/it]

(196608, 512)


Extracting activations:  67%|██████▋   | 164/243 [40:25<19:46, 15.02s/it]

(196608, 512)


Extracting activations:  68%|██████▊   | 165/243 [40:40<19:11, 14.76s/it]

(196608, 512)


Extracting activations:  68%|██████▊   | 166/243 [40:54<18:53, 14.72s/it]

(196608, 512)


Extracting activations:  69%|██████▊   | 167/243 [41:08<18:28, 14.59s/it]

(196608, 512)


Extracting activations:  69%|██████▉   | 168/243 [41:23<18:04, 14.45s/it]

(196608, 512)


Extracting activations:  70%|██████▉   | 169/243 [41:38<18:20, 14.87s/it]

(196608, 512)


Extracting activations:  70%|██████▉   | 170/243 [41:54<18:31, 15.23s/it]

(196608, 512)


Extracting activations:  70%|███████   | 171/243 [42:10<18:13, 15.19s/it]

(196608, 512)


Extracting activations:  71%|███████   | 172/243 [42:24<17:46, 15.02s/it]

(196608, 512)


Extracting activations:  71%|███████   | 173/243 [42:38<17:15, 14.79s/it]

(196608, 512)


Extracting activations:  72%|███████▏  | 174/243 [42:52<16:44, 14.56s/it]

(196608, 512)


Extracting activations:  72%|███████▏  | 175/243 [43:08<16:56, 14.95s/it]

(196608, 512)


Extracting activations:  72%|███████▏  | 176/243 [43:24<16:54, 15.13s/it]

(196608, 512)


Extracting activations:  73%|███████▎  | 177/243 [43:39<16:40, 15.16s/it]

(196608, 512)


Extracting activations:  73%|███████▎  | 178/243 [43:55<16:34, 15.29s/it]

(196608, 512)


Extracting activations:  74%|███████▎  | 179/243 [44:09<16:02, 15.04s/it]

(196608, 512)


Extracting activations:  74%|███████▍  | 180/243 [44:25<15:58, 15.21s/it]

(196608, 512)


Extracting activations:  74%|███████▍  | 181/243 [44:39<15:28, 14.97s/it]

(196608, 512)


Extracting activations:  75%|███████▍  | 182/243 [44:55<15:22, 15.12s/it]

(196608, 512)


Extracting activations:  75%|███████▌  | 183/243 [45:09<14:55, 14.92s/it]

(196608, 512)


Extracting activations:  76%|███████▌  | 184/243 [45:24<14:32, 14.78s/it]

(196608, 512)


Extracting activations:  76%|███████▌  | 185/243 [45:38<14:09, 14.65s/it]

(196608, 512)


Extracting activations:  77%|███████▋  | 186/243 [45:53<14:07, 14.86s/it]

(196608, 512)


Extracting activations:  77%|███████▋  | 187/243 [46:09<14:08, 15.15s/it]

(196608, 512)


Extracting activations:  77%|███████▋  | 188/243 [46:25<14:10, 15.47s/it]

(196608, 512)


Extracting activations:  78%|███████▊  | 189/243 [46:41<13:53, 15.44s/it]

(196608, 512)


Extracting activations:  78%|███████▊  | 190/243 [46:55<13:24, 15.19s/it]

(196608, 512)


Extracting activations:  79%|███████▊  | 191/243 [47:10<12:56, 14.93s/it]

(196608, 512)


Extracting activations:  79%|███████▉  | 192/243 [47:24<12:30, 14.72s/it]

(196608, 512)


Extracting activations:  79%|███████▉  | 193/243 [47:39<12:25, 14.92s/it]

(196608, 512)


Extracting activations:  80%|███████▉  | 194/243 [47:54<12:02, 14.74s/it]

(196608, 512)


Extracting activations:  80%|████████  | 195/243 [48:08<11:41, 14.62s/it]

(196608, 512)


Extracting activations:  81%|████████  | 196/243 [48:22<11:24, 14.56s/it]

(196608, 512)


Extracting activations:  81%|████████  | 197/243 [48:37<11:14, 14.66s/it]

(196608, 512)


Extracting activations:  81%|████████▏ | 198/243 [48:52<10:57, 14.61s/it]

(196608, 512)


Extracting activations:  82%|████████▏ | 199/243 [49:06<10:39, 14.53s/it]

(196608, 512)


Extracting activations:  82%|████████▏ | 200/243 [49:22<10:45, 15.02s/it]

(196608, 512)


Extracting activations:  83%|████████▎ | 201/243 [49:37<10:29, 14.98s/it]

(196608, 512)


Extracting activations:  83%|████████▎ | 202/243 [49:53<10:21, 15.16s/it]

(196608, 512)


Extracting activations:  84%|████████▎ | 203/243 [50:08<10:09, 15.24s/it]

(196608, 512)


Extracting activations:  84%|████████▍ | 204/243 [50:23<09:50, 15.13s/it]

(196608, 512)


Extracting activations:  84%|████████▍ | 205/243 [50:39<09:39, 15.25s/it]

(196608, 512)


Extracting activations:  85%|████████▍ | 206/243 [50:53<09:16, 15.05s/it]

(196608, 512)


Extracting activations:  85%|████████▌ | 207/243 [51:09<09:06, 15.19s/it]

(196608, 512)


Extracting activations:  86%|████████▌ | 208/243 [51:24<08:53, 15.24s/it]

(196608, 512)


Extracting activations:  86%|████████▌ | 209/243 [51:39<08:31, 15.04s/it]

(196608, 512)


Extracting activations:  86%|████████▋ | 210/243 [51:53<08:09, 14.83s/it]

(196608, 512)


Extracting activations:  87%|████████▋ | 211/243 [52:07<07:49, 14.66s/it]

(196608, 512)


Extracting activations:  87%|████████▋ | 212/243 [52:23<07:47, 15.07s/it]

(196608, 512)


Extracting activations:  88%|████████▊ | 213/243 [52:39<07:35, 15.19s/it]

(196608, 512)


Extracting activations:  88%|████████▊ | 214/243 [52:54<07:24, 15.31s/it]

(196608, 512)


Extracting activations:  88%|████████▊ | 215/243 [53:10<07:11, 15.42s/it]

(196608, 512)


Extracting activations:  89%|████████▉ | 216/243 [53:26<07:00, 15.59s/it]

(196608, 512)


Extracting activations:  89%|████████▉ | 217/243 [53:40<06:36, 15.24s/it]

(196608, 512)


Extracting activations:  90%|████████▉ | 218/243 [53:55<06:15, 15.01s/it]

(196608, 512)


Extracting activations:  90%|█████████ | 219/243 [54:09<05:54, 14.76s/it]

(196608, 512)


Extracting activations:  91%|█████████ | 220/243 [54:25<05:48, 15.13s/it]

(196608, 512)


Extracting activations:  91%|█████████ | 221/243 [54:39<05:27, 14.89s/it]

(196608, 512)


Extracting activations:  91%|█████████▏| 222/243 [54:55<05:15, 15.03s/it]

(196608, 512)


Extracting activations:  92%|█████████▏| 223/243 [55:10<05:03, 15.18s/it]

(196608, 512)


Extracting activations:  92%|█████████▏| 224/243 [55:25<04:44, 14.97s/it]

(196608, 512)


Extracting activations:  93%|█████████▎| 225/243 [55:39<04:26, 14.79s/it]

(196608, 512)


Extracting activations:  93%|█████████▎| 226/243 [55:54<04:10, 14.76s/it]

(196608, 512)


Extracting activations:  93%|█████████▎| 227/243 [56:08<03:53, 14.61s/it]

(196608, 512)


Extracting activations:  94%|█████████▍| 228/243 [56:23<03:39, 14.63s/it]

(196608, 512)


Extracting activations:  94%|█████████▍| 229/243 [56:37<03:23, 14.52s/it]

(196608, 512)


Extracting activations:  95%|█████████▍| 230/243 [56:51<03:07, 14.45s/it]

(196608, 512)


Extracting activations:  95%|█████████▌| 231/243 [57:07<02:56, 14.72s/it]

(196608, 512)


Extracting activations:  95%|█████████▌| 232/243 [57:21<02:40, 14.58s/it]

(196608, 512)


Extracting activations:  96%|█████████▌| 233/243 [57:37<02:29, 14.92s/it]

(196608, 512)


Extracting activations:  96%|█████████▋| 234/243 [57:52<02:16, 15.13s/it]

(196608, 512)


Extracting activations:  97%|█████████▋| 235/243 [58:06<01:59, 14.89s/it]

(196608, 512)


Extracting activations:  97%|█████████▋| 236/243 [58:22<01:44, 15.00s/it]

(196608, 512)


Extracting activations:  98%|█████████▊| 237/243 [58:36<01:28, 14.82s/it]

(196608, 512)


Extracting activations:  98%|█████████▊| 238/243 [58:50<01:13, 14.69s/it]

(196608, 512)


Extracting activations:  98%|█████████▊| 239/243 [59:05<00:58, 14.57s/it]

(196608, 512)


Extracting activations:  99%|█████████▉| 240/243 [59:19<00:43, 14.50s/it]

(196608, 512)


Extracting activations:  99%|█████████▉| 241/243 [59:34<00:29, 14.59s/it]

(196608, 512)


Extracting activations: 100%|█████████▉| 242/243 [59:49<00:14, 14.75s/it]

(193536, 512)


Extracting activations: 100%|██████████| 243/243 [1:00:03<00:00, 14.83s/it]


# Training SAE

In [13]:
importlib.reload(utils)

# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)

# Loads chunks of the dataset at a time
file_path = '/content/drive/MyDrive/SAEs_for_Genomics/plasmidpretrained_NT50_L12_mlp.activations.h5'
start_chunk = 0
dataset = utils.ChunkedActivationsDataset(file_path)

# Train sae_model on dataset chunks
sae_model = utils.train_sae(
        sae_model=sae_model,
        dataset=dataset,
        batch_size=2048*4,
        num_epochs=20,
        learning_rate=5e-5,
        device='cuda',
        start_chunk = start_chunk,
        wandb_log= False,
        save_dir='/content/drive/MyDrive/SAEs_for_Genomics/sae_weights/SAE_NT50_plasmidpre_L12_mlpout_40mtokens_190325.pt'
)



Epoch 1/20:   0%|          | 0/5832 [00:00<?, ?it/s]

IndexError: index 13620 is out of bounds for dimension 0 with size 0

In [15]:
# save sae weights and HPs
import os
import torch

model_dict = {
    'model_state_dict': sae_model.state_dict(),
    'hyperparameters': cfg
}

#save
torch.save(model_dict, '/content/drive/MyDrive/SAEs_for_Genomics/sae_weights/extra_addgene_pretraining/SAE_NT50_plasmidpre_L12_mlpout_40mtokens_190325.pt')


# HP Tuning

In [None]:
### wandb init
!pip install wandb
import wandb

# Set your API key here
API_KEY = "9d4aba4e2ab3fcfc7278c20827aacff0f7fca089"

def initialize_wandb(run_name, hyperparameter, reinit, project = "DeepGEA"):
    wandb.login(key=API_KEY)  # Ensure API key is stored securely
    wandb.init(project=project, entity="amaiwald", name=run_name, mode = 'online', config=hyperparameter, reinit=reinit)

run_name = 'SAE_random.NT50_L11.mlp_48mtokens' ## choose
#initialize_wandb(run_name = run_name, hyperparameter = cfg)




In [None]:
# randomly search grid of HPs and train each for 20 ep
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import ParameterSampler
import numpy as np
from datetime import datetime
import json
import traceback

param_distributions = {
    'lr': np.log10(np.array([5e-6, 1e-3])),  # log-uniform between 5e-6 and 1e-3
    'l0_weight': np.log10(np.array([0.001, 2.0])),  # log-uniform between 0.001 and 2.0
    'threshold': np.array([0.03, 2.4]),  # uniform between 0.03 and 2.4
    'dict_mult': [8, 16, 32]  # keep discrete values
}

def sample_params(n_samples=10, random_seed=42):
    samples = list(ParameterSampler(param_distributions, n_samples, random_state=random_seed))
    for sample in samples:
        sample['lr'] = 10 ** sample['lr']
        sample['l0_weight'] = 10 ** sample['l0_weight']
    return samples


# Create full training dataset
dataset = ChunkedActivationsDataset(
    '/content/drive/MyDrive/SAEs_for_Genomics/plasmidpretrained_NT50_L11_mlp.activations.h5',
    batch_size=2048*4,
    chunks_in_memory=20  # Adjust based on your GPU memory
)

# Test the sampling
samples = sample_params(n_samples=20, random_seed=42)
print(samples)

# Constants
N_EPOCHS = 10


# In your main script
for i, sample in enumerate(samples):
    # Create unique directory for this trial
    trial_dir = os.path.join('/content/drive/MyDrive/SAEs_for_Genomics', f'NT.plasmidpretr_SAE_trial_{i}')
    os.makedirs(trial_dir, exist_ok=True)

    # Save trial config
    trial_config = {
        'trial_id': i,
        'hyperparameters': sample,
        'timestamp': datetime.now().isoformat()
    }
    with open(os.path.join(trial_dir, 'config.json'), 'w') as f:
        json.dump(trial_config, f, indent=2)

    # Set HPs for this trial
    cfg['lr'] = sample['lr']
    cfg['l0_weight'] = sample['l0_weight']
    cfg['threshold'] = sample['threshold']
    cfg['dict_mult'] = sample['dict_mult']

    print(f"\nStarting trial {i+1}/20")
    print("HPs:", cfg)
    print(50*"-")

    # Initialize model and W&B
    sae_model = AutoEncoder(cfg)
    run_name = f'SAE_NT50.plasmidpretr_L11.mlp_48mtokens_trial_{i}'
    initialize_wandb(run_name=run_name, hyperparameter=cfg, reinit=True)

    try:
        results = train_sae(
            sae_model=sae_model,
            dataset=dataset,
            batch_size=2048*4,
            num_epochs=N_EPOCHS,
            learning_rate=cfg['lr'],
            device='cuda',
            start_chunk=start_chunk,
            wandb_log=True,
            save_dir=trial_dir  # Save in trial-specific directory
        )

        sae_model = results['model']
        best_loss = results['best_loss']
        best_epoch = results['best_epoch']
        final_metrics = results['final_metrics']

        # Save final model state with hyperparameters
        final_save_path = os.path.join(trial_dir, 'final_model.pt')
        torch.save({
            'trial_id': i,
            'hyperparameters': sample,
            'model_state_dict': sae_model.state_dict(),
            'final_metrics': {
                'best_loss': best_loss,  # You'll need to modify train_sae to return this
                'final_epoch': N_EPOCHS
            }
        }, final_save_path)

    except Exception as e:
        print(f"Trial {i+1} failed with error: {str(e)}")
        with open(os.path.join(trial_dir, 'error_log.txt'), 'w') as f:
            f.write(f"Error: {str(e)}\n")
            f.write(traceback.format_exc())

    finally:
        wandb.finish()
        print(f"Cleaned up trial {i+1}")



[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[{'threshold': 0.03, 'lr': 4.9999999999999996e-06, 'l0_weight': 0.001, 'dict_mult': 16}, {'threshold': 0.03, 'lr': 4.9999999999999996e-06, 'l0_weight': 0.001, 'dict_mult': 32}, {'threshold': 0.03, 'lr': 4.9999999999999996e-06, 'l0_weight': 0.001, 'dict_mult': 8}, {'threshold': 0.03, 'lr': 0.001, 'l0_weight': 0.001, 'dict_mult': 32}, {'threshold': 2.4, 'lr': 0.001, 'l0_weight': 0.001, 'dict_mult': 16}, {'threshold': 2.4, 'lr': 4.9999999999999996e-06, 'l0_weight': 0.001, 'dict_mult': 16}, {'threshold': 2.4, 'lr': 4.9999999999999996e-06, 'l0_weight': 2.0, 'dict_mult': 16}, {'threshold': 2.4, 'lr': 4.9999999999999996e-06, 'l0_weight': 0.001, 'dict_mult': 8}, {'threshold': 2.4, 'lr': 4.9999999999999996e-06, 'l0_weight': 2.0, 'dict_mult': 32}, {'threshold': 2.4, 'lr': 4.9999999999999996e-06, 'l0_weight': 2.0, 'dict_mult': 8}, {'threshold': 0.03, 'lr': 0.001, 'l0_weight': 0.001, 'dict_mult': 8}, {'threshold': 0.03, 'lr': 4.9999999999999996e-06, 'l0_weight': 2.0, 'dict_mult': 16}, {'threshold'

Epoch 1, Chunk 1/293:   5%|▌         | 1/20 [00:00<00:11,  1.62it/s, loss=2458.4248, l2_loss=901.9199, l1_loss=774.8301, l0_loss=1556.5049, nmse=1.3272]

Input stats:
Input range: [-215.5000, 69.0000]
Input mean/std: 0.0184 / 11.7914

Activations after encoding:
Acts range: [0.0000, 1.8428]
Acts mean/std: 0.1378 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8428]
Sparsity: 80.9697%

Reconstruction stats:
Recon range: [-4.0586, 4.3750]
Recon mean/std: -0.0100 / 0.8696
Input stats:
Input range: [-214.2500, 66.0625]
Input mean/std: 0.0165 / 11.8339

Activations after encoding:
Acts range: [0.0000, 1.8193]
Acts mean/std: 0.1377 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8193]
Sparsity: 80.9997%

Reconstruction stats:
Recon range: [-3.9180, 4.2734]
Recon mean/std: -0.0099 / 0.8691


Epoch 1, Chunk 1/293:  15%|█▌        | 3/20 [00:00<00:03,  4.37it/s, loss=2459.9651, l2_loss=902.2513, l1_loss=775.2341, l0_loss=1557.7137, nmse=1.3275]

Input stats:
Input range: [-212.5000, 66.7500]
Input mean/std: 0.0177 / 11.8423

Activations after encoding:
Acts range: [0.0000, 1.8447]
Acts mean/std: 0.1377 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8447]
Sparsity: 80.9895%

Reconstruction stats:
Recon range: [-4.0938, 4.2852]
Recon mean/std: -0.0099 / 0.8691
Input stats:
Input range: [-215.5000, 66.5000]
Input mean/std: 0.0178 / 11.8128

Activations after encoding:
Acts range: [0.0000, 1.8535]
Acts mean/std: 0.1378 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8535]
Sparsity: 80.9849%

Reconstruction stats:
Recon range: [-3.9570, 4.2422]
Recon mean/std: -0.0104 / 0.8696
Input stats:
Input range: [-212.0000, 67.2500]
Input mean/std: 0.0177 / 11.8573

Activations after encoding:
Acts range: [0.0000, 1.8223]
Acts mean/std: 0.1377 / 0.2050

Sparse activations:


Epoch 1, Chunk 1/293:  25%|██▌       | 5/20 [00:01<00:02,  6.41it/s, loss=2462.3271, l2_loss=902.3057, l1_loss=776.0931, l0_loss=1560.0215, nmse=1.3275]

Sparse acts range: [0.0000, 1.8223]
Sparsity: 81.0067%

Reconstruction stats:
Recon range: [-3.9766, 4.2969]
Recon mean/std: -0.0098 / 0.8691
Input stats:
Input range: [-214.7500, 67.1250]
Input mean/std: 0.0174 / 11.7638

Activations after encoding:
Acts range: [0.0000, 1.8564]
Acts mean/std: 0.1378 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8564]
Sparsity: 80.9568%

Reconstruction stats:
Recon range: [-3.9570, 4.3477]
Recon mean/std: -0.0101 / 0.8696
Input stats:
Input range: [-213.5000, 67.0000]
Input mean/std: 0.0178 / 11.8059

Activations after encoding:
Acts range: [0.0000, 1.8398]
Acts mean/std: 0.1378 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8398]
Sparsity: 80.9688%

Reconstruction stats:
Recon range: [-3.8867, 4.2344]
Recon mean/std: -0.0099 / 0.8696


Epoch 1, Chunk 1/293:  45%|████▌     | 9/20 [00:01<00:01,  9.29it/s, loss=2446.2397, l2_loss=890.7029, l1_loss=774.0552, l0_loss=1555.5370, nmse=1.3190]

Input stats:
Input range: [-213.7500, 66.0625]
Input mean/std: 0.0168 / 11.8329

Activations after encoding:
Acts range: [0.0000, 1.8701]
Acts mean/std: 0.1377 / 0.2050

Sparse activations:
Sparse acts range: [0.0000, 1.8701]
Sparsity: 80.9955%

Reconstruction stats:
Recon range: [-4.1289, 4.2578]
Recon mean/std: -0.0102 / 0.8691
Input stats:
Input range: [-214.2500, 67.1250]
Input mean/std: 0.0177 / 11.8205

Activations after encoding:
Acts range: [0.0000, 1.8506]
Acts mean/std: 0.1376 / 0.2048

Sparse activations:
Sparse acts range: [0.0000, 1.8506]
Sparsity: 81.0115%

Reconstruction stats:
Recon range: [-3.8145, 4.2500]
Recon mean/std: -0.0097 / 0.8628
Input stats:
Input range: [-213.6250, 67.3750]
Input mean/std: 0.0179 / 11.8741

Activations after encoding:
Acts range: [0.0000, 1.8281]
Acts mean/std: 0.1375 / 0.2047

Sparse activations:
Sparse acts range: [0.0000, 1.8281]
Sparsity: 81.0610%

Reconstruction stats:
Recon range: [-3.7676, 4.2227]

Epoch 1, Chunk 1/293:  55%|█████▌    | 11/20 [00:01<00:00, 10.13it/s, loss=2422.1863, l2_loss=869.2629, l1_loss=772.3613, l0_loss=1552.9235, nmse=1.3030]


Recon mean/std: -0.0088 / 0.8560
Input stats:
Input range: [-214.7500, 68.3125]
Input mean/std: 0.0160 / 11.8427

Activations after encoding:
Acts range: [0.0000, 1.8281]
Acts mean/std: 0.1373 / 0.2046

Sparse activations:
Sparse acts range: [0.0000, 1.8281]
Sparsity: 81.0434%

Reconstruction stats:
Recon range: [-3.9512, 4.2930]
Recon mean/std: -0.0087 / 0.8501
Input stats:
Input range: [-215.3750, 67.6250]
Input mean/std: 0.0178 / 11.8562

Activations after encoding:
Acts range: [0.0000, 1.8135]
Acts mean/std: 0.1372 / 0.2045

Sparse activations:
Sparse acts range: [0.0000, 1.8135]
Sparsity: 81.0681%

Reconstruction stats:
Recon range: [-3.8711, 4.1172]
Recon mean/std: -0.0080 / 0.8438


Epoch 1, Chunk 1/293:  65%|██████▌   | 13/20 [00:01<00:00, 10.72it/s, loss=2386.1106, l2_loss=837.4636, l1_loss=769.6873, l0_loss=1548.6470, nmse=1.2789]

Input stats:
Input range: [-212.2500, 67.4375]
Input mean/std: 0.0173 / 11.8154

Activations after encoding:
Acts range: [0.0000, 1.8281]
Acts mean/std: 0.1371 / 0.2043

Sparse activations:
Sparse acts range: [0.0000, 1.8281]
Sparsity: 81.0834%

Reconstruction stats:
Recon range: [-3.7070, 4.1992]
Recon mean/std: -0.0078 / 0.8379
Input stats:
Input range: [-211.7500, 69.0625]
Input mean/std: 0.0177 / 11.8040

Activations after encoding:
Acts range: [0.0000, 1.8486]
Acts mean/std: 0.1371 / 0.2042

Sparse activations:
Sparse acts range: [0.0000, 1.8486]
Sparsity: 81.0956%

Reconstruction stats:
Recon range: [-3.7754, 4.1641]
Recon mean/std: -0.0071 / 0.8320
Input stats:
Input range: [-212.7500, 65.4375]
Input mean/std: 0.0167 / 11.8101

Activations after encoding:
Acts range: [0.0000, 1.8203]
Acts mean/std: 0.1368 / 0.2041

Sparse activations:
Sparse acts range: [0.0000, 1.8203]
Sparsity: 81.1201%

Reconstruction stats:


Epoch 1, Chunk 1/293:  75%|███████▌  | 15/20 [00:01<00:00, 11.16it/s, loss=2360.7837, l2_loss=816.8083, l1_loss=767.3152, l0_loss=1543.9752, nmse=1.2631]

Recon range: [-3.6445, 4.1367]
Recon mean/std: -0.0066 / 0.8257
Input stats:
Input range: [-213.5000, 66.3750]
Input mean/std: 0.0174 / 11.8377

Activations after encoding:
Acts range: [0.0000, 1.8203]
Acts mean/std: 0.1367 / 0.2041

Sparse activations:
Sparse acts range: [0.0000, 1.8203]
Sparsity: 81.1526%

Reconstruction stats:
Recon range: [-3.6582, 4.0859]
Recon mean/std: -0.0060 / 0.8198
Input stats:
Input range: [-213.1250, 66.1250]
Input mean/std: 0.0169 / 11.8000

Activations after encoding:
Acts range: [0.0000, 1.8213]
Acts mean/std: 0.1367 / 0.2039

Sparse activations:
Sparse acts range: [0.0000, 1.8213]
Sparsity: 81.1317%

Reconstruction stats:
Recon range: [-3.8125, 4.0547]
Recon mean/std: -0.0059 / 0.8145


Epoch 1, Chunk 1/293:  95%|█████████▌| 19/20 [00:02<00:00, 11.73it/s, loss=2329.0815, l2_loss=787.2192, l1_loss=765.5928, l0_loss=1541.8623, nmse=1.2400]

Input stats:
Input range: [-212.5000, 66.7500]
Input mean/std: 0.0183 / 11.8541

Activations after encoding:
Acts range: [0.0000, 1.8379]
Acts mean/std: 0.1365 / 0.2039

Sparse activations:
Sparse acts range: [0.0000, 1.8379]
Sparsity: 81.1909%

Reconstruction stats:
Recon range: [-3.8086, 3.8848]
Recon mean/std: -0.0053 / 0.8081
Input stats:
Input range: [-212.8750, 66.3750]
Input mean/std: 0.0170 / 11.8189

Activations after encoding:
Acts range: [0.0000, 1.7988]
Acts mean/std: 0.1365 / 0.2037

Sparse activations:
Sparse acts range: [0.0000, 1.7988]
Sparsity: 81.1784%

Reconstruction stats:
Recon range: [-3.6504, 3.9355]
Recon mean/std: -0.0052 / 0.8032
Input stats:
Input range: [-211.8750, 67.2500]
Input mean/std: 0.0160 / 11.8548

Activations after encoding:
Acts range: [0.0000, 1.7949]
Acts mean/std: 0.1362 / 0.2036

Sparse activations:
Sparse acts range: [0.0000, 1.7949]
Sparsity: 81.2209%

Reconstruction stats:
Recon range: [-3.6074, 3.8867]
Recon mean/std: -0.0046 / 0.7974

Epoch 1, Chunk 1/293: 100%|██████████| 20/20 [00:02<00:00,  8.71it/s, loss=2315.5640, l2_loss=777.1824, l1_loss=763.9532, l0_loss=1538.3817, nmse=1.2320]





Epoch 1, Chunk 2/293:   0%|          | 0/20 [00:00<?, ?it/s]

### test HP run

In [None]:
initialize_wandb(run_name=run_name, reinit = True, hyperparameter=cfg, project="test_genomics_sae")  # Note the test_ prefix


# Test version of your code
def test_hp_search(n_trials=2, n_epochs=2):
    # Minimal parameter space
    test_param_distributions = {
        'lr': np.log10(np.array([1e-4, 1e-3])),  # just two values
        'l0_weight': np.log10(np.array([0.1, 0.2])),
        'threshold': np.array([0.1, 0.2]),
        'dict_mult': [8]  # single value
    }
    BATCH_SIZE = 2048*4

    test_dataset = ChunkedActivationsDataset(
    '/content/drive/MyDrive/SAEs_for_Genomics/plasmidpretrained_NT50_L11_mlp.activations.h5',
    batch_size=BATCH_SIZE,  # Your actual batch size
    chunks_in_memory=4,  # Each chunk holds 4 batches worth of data
    max_chunks=10  # Limit total chunks for testing
    )

    samples = sample_params(n_samples=n_trials)

    for i, sample in enumerate(samples):
        print(f"Test trial {i+1}/{n_trials}")

        cfg['lr'] = sample['lr']
        cfg['l0_weight'] = sample['l0_weight']
        cfg['threshold'] = sample['threshold']
        cfg['dict_mult'] = sample['dict_mult']

        sae_model = AutoEncoder(cfg)

        # Note the TEST prefix in run name
        run_name = f'TEST_run{i}'
        initialize_wandb(
            run_name=run_name,
            hyperparameter=cfg,
            reinit=True,
            project="test_genomics_sae"  # separate test project
        )


        try:
            # Use minimal data and epochs
            sae_model = train_sae(
                sae_model=sae_model,
                dataset=test_dataset,  # just use 100 samples
                batch_size=BATCH_SIZE,  # smaller batch
                num_epochs=n_epochs,  # minimal epochs
                learning_rate=cfg['lr'],
                device='cuda',
                start_chunk=start_chunk,
                wandb_log=True,
                save_dir='/tmp/test_sae/'  # temporary directory
            )
        finally:
            wandb.finish()
            print(f"Finished and cleaned up trial {i+1}")

# Run the test
test_hp_search(n_trials=5, n_epochs=2)



In [None]:
def cleanup_test_runs():
    api = wandb.Api()
    runs = api.runs("amaiwald/test_genomics_sae")
    for run in runs:
        if run.name.startswith('TEST_'):
            print(f"Deleting test run: {run.name}")
            run.delete()

# Run cleanup after testing
cleanup_test_runs()

Deleting test run: TEST_run0
Deleting test run: TEST_run0
Deleting test run: TEST_run0
Deleting test run: TEST_run0
Deleting test run: TEST_run0
Deleting test run: TEST_run0
Deleting test run: TEST_run0
Deleting test run: TEST_run1
Deleting test run: TEST_run2
Deleting test run: TEST_run3
Deleting test run: TEST_run4


In [None]:
## load custom functions from utils.py

import sys
sys.path.append('//content/drive/MyDrive/SAEs_for_Genomics')

import importlib
import utils
importlib.reload(utils)

<module 'utils' from '//content/drive/MyDrive/SAEs_for_Genomics/utils.py'>

In [None]:
import torch
from torch.utils.checkpoint import checkpoint
import gc
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
    return LambdaLR(optimizer, lr_lambda)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(sae_model.parameters(), lr=cfg['lr'], betas=(cfg['beta1'], cfg['beta2']))
total_steps = cfg['total_training_steps']
warmup_steps = cfg['lr_warm_up_steps']
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

# Enable gradient checkpointing for the transformer model
model_nt.to(device).eval()
sae_model.train()

num_layer = 11
d_mlp = model_nt.config.hidden_size
n_epochs = 20
global_step = 0
train_epoch_loss = 0
train_epoch_l2_loss = 0
train_epoch_l1_loss = 0
train_epoch_l0_loss = 0
train_epoch_nmse = 0
all_latents = []

for epoche in range(n_epochs):
    for batch in tqdm(train_dataloader, desc="Inference"):
        with torch.no_grad():
            # Move everything to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            mlp_act = utils.get_layer_activations(
                model_nt.cuda(),
                input_ids,
                attention_mask
            )

            mlp_act = mlp_act[0].reshape(-1, d_mlp)

            # Move to CPU temporarily to save GPU memory
            mlp_act = mlp_act.cpu()

        # Move back to GPU for SAE processing
        mlp_act = mlp_act.cuda()

        # Train SAE based on this
        optimizer.zero_grad()
        loss, x_reconstruct, acts, l2_loss, nmse, l1_loss, l0_loss = sae_model(mlp_act)
        loss.backward()

        # Explicit gradient cleanup
        for param in sae_model.parameters():
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()

        sae_model.remove_parallel_component_of_grads()
        optimizer.step()
        scheduler.step()

        train_epoch_loss += loss.item()
        train_epoch_l2_loss += l2_loss.item()
        train_epoch_l1_loss += l1_loss.item()
        train_epoch_l0_loss += l0_loss.item()
        train_epoch_nmse += nmse

        # Explicit cleanup
        del loss, x_reconstruct, acts, l2_loss, nmse, l1_loss, l0_loss
        del mlp_act
        torch.cuda.empty_cache()
        gc.collect()

        global_step += 1
        if global_step >= total_steps:
            break

    # Calculate average training losses
    avg_train_loss = train_epoch_loss / len(train_dataloader)
    avg_train_l2_loss = train_epoch_l2_loss / len(train_dataloader)
    avg_train_l1_loss = train_epoch_l1_loss / len(train_dataloader)
    avg_train_l0_loss = train_epoch_l0_loss / len(train_dataloader)
    avg_train_nmse = train_epoch_nmse / len(train_dataloader)

    print(f'Training - Epoch {epoche}, Loss: {avg_train_loss:.4f}, L2 Loss: {avg_train_l2_loss:.4f}, '
          f'L1 Loss: {avg_train_l1_loss:.4f}, L0 Loss: {avg_train_l0_loss:.4f}, NMSE: {avg_train_nmse:.4f}, '
          f'LR: {scheduler.get_last_lr()[0]:.6f}')

    # Reset metrics for next epoch
    train_epoch_loss = 0
    train_epoch_l2_loss = 0
    train_epoch_l1_loss = 0
    train_epoch_l0_loss = 0
    train_epoch_nmse = 0
    torch.cuda.empty_cache()

Inference:   0%|          | 0/23327 [00:00<?, ?it/s]

Exception ignored in: <function _xla_gc_callback at 0x7dbe53a839c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


KeyboardInterrupt: 

# Load pretraining data of NT

In [None]:
!pip install datasets
!pip install huggingface_hub
!pip install biopython

## Load HF dataset for Streaming

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from typing import Optional
import numpy as np
from tqdm import tqdm



class StreamingGenomicDataset(Dataset):
    def __init__(
        self,
        split: str = "train",
        tokenizer = None,
        seq_length: int = 512,
        max_samples: Optional[int] = None,
        cache_mode: bool = True
    ):
        """
        Streaming dataset for genomic sequences.

        Args:
            split: Dataset split ('train', 'validation', 'test')
            tokenizer: Tokenizer for DNA sequences
            seq_length: Maximum sequence length
            max_samples: Maximum number of samples to load (None for all)
            cache_mode: If True, caches all sequences in memory
        """
        self.seq_length = seq_length
        self.tokenizer = tokenizer

        # In the __init__ method:
        dataset = load_dataset("metagene-ai/HumanVirusInfecting", "class-1", streaming=True)
        # Get the specific split
        dataset = dataset[split]  # This was missing in your original code

        if cache_mode:
            # Cache all sequences in memory
            self.sequences = []
            pbar = tqdm(dataset, total=max_samples, desc=f"Loading {split} data")

            for i, item in enumerate(pbar):
                if max_samples and i >= max_samples:
                    break
                if not isinstance(item, dict):
                    print(f"Warning: item at index {i} is not a dictionary: {item}")
                    continue
                self.sequences.append(item)
        else:
            # Store iterator for streaming mode
            self.sequences = dataset
            self.max_samples = max_samples

    def __len__(self):
        if isinstance(self.sequences, list):
            return len(self.sequences)
        return self.max_samples if self.max_samples else int(1e9)  # Large number for streaming

    # First, let's modify the __getitem__ method to debug the data structure:

    def __getitem__(self, idx):
        if isinstance(self.sequences, list):
            # Cached mode
            item = self.sequences[idx]
        else:
            # Streaming mode
            item = next(iter(self.sequences))


        # Let's safely access the sequence
        if isinstance(item, dict):
            sequence = item.get('sequence')
        else:
            raise TypeError(f"Expected item to be dict, got {type(item)} instead")

        if sequence is None:
            raise ValueError("No 'sequence' field found in item")

        # Rest of your code remains the same
        inputs = self.tokenizer(
            sequence,
            max_length=self.seq_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'sequence_info': {
                'description': item.get('description', ''),
                'start_pos': item.get('start_pos', 0),
                'end_pos': item.get('end_pos', 0)
            }
        }

def create_genomic_dataloaders(
    tokenizer,
    batch_size: int = 32,
    seq_length: int = 512,
    max_samples: Optional[int] = None,
    num_workers: int = 2,
    cache_mode: bool = True
):
    """
    Create training and validation DataLoaders for genomic data.

    Args:
        tokenizer: DNA sequence tokenizer
        batch_size: Batch size for DataLoader
        seq_length: Maximum sequence length
        max_samples: Maximum samples per split (None for all)
        num_workers: Number of DataLoader workers
        cache_mode: If True, caches all sequences in memory
    """
    # Create datasets
    train_dataset = StreamingGenomicDataset(
        split="train",
        tokenizer=tokenizer,
        seq_length=seq_length,
        max_samples=max_samples,
        cache_mode=cache_mode
    )


    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=cache_mode,  # Can only shuffle if data is cached
        num_workers=num_workers,
        pin_memory=True
    )


    return train_loader

# Example usage
# Create dataloaders with small sample size for testing
train_loader = create_genomic_dataloaders(
    tokenizer=tokenizer,
    batch_size=16*3,
    seq_length=50,
    max_samples=19600,  # Small sample size for testing
    cache_mode=True
)

# Print dataset sizes
print(f"Training batches: {len(train_loader)}")

# Example of iterating through one batch
for batch in train_loader:
    print("\nBatch shapes:")
    print(f"Input ids: {batch['input_ids'].shape}")
    print(f"Attention mask: {batch['attention_mask'].shape}")
    break



Loading train data: 100%|██████████| 19600/19600 [00:01<00:00, 10101.86it/s]

Training batches: 409






Batch shapes:
Input ids: torch.Size([48, 50])
Attention mask: torch.Size([48, 50])


In [None]:
from tqdm import tqdm  # Import tqdm

## test how long it takes for the model to perform forward passes on all of the train batches
model.eval()

for batch in tqdm(train_loader):  # Wrap train_loader with tqdm
    input_ids = batch['input_ids'].to(model.device)
    attention_mask = batch['attention_mask'].to(model.device)

    # Forward pass
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

100%|██████████| 77/77 [01:22<00:00,  1.08s/it]
