#Set-up

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!pip install transformers

Mounted at /content/drive


In [None]:
# set seeds
import random
import numpy as np
import torch

def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)

set_seed(42)

# Load NT model

In [None]:
"loading smallest nucleotide transformer (50m params)"


from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

num_params = 50 ## default 50

# Import the tokenizer and the model
tokenizer_nt = AutoTokenizer.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{num_params}m-multi-species", trust_remote_code=True)
model_nt = AutoModelForMaskedLM.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{num_params}m-multi-species", trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
!pip install -U bitsandbytes --upgrade
!pip install -U accelerate
!python -m bitsandbytes


Collecting bitsandbytes
  Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl.metadata (2.9 kB)
Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl (69.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.0
Collecting accelerate
  Downloading accelerate-1.3.0-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.3.0-py3-none-any.whl (336 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m336.6/336.6 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 1.2.1
    Uninstalling accelerate-1.2.1:
      Successfully uninstalled accelerate-1.2.1
Successfully installed accelerate-1.3.0
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
++++++++++++++++++ BU

In [None]:
import transformers
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import bitsandbytes as bnb

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("metagene-ai/METAGENE-1-BnB-4Bit", '_load_in_4bit')
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Define quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Load model with quantization config
model = AutoModelForCausalLM.from_pretrained(
    "metagene-ai/METAGENE-1-BnB-4Bit",
    quantization_config=quantization_config,
    device_map="auto"
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/41.9k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/833 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


model.safetensors:   0%|          | 0.00/3.36G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

# Load and preprocess addgene dataset

In [None]:
import pandas as pd


# Constants
TEST_DATA_PATH = '/content/drive/MyDrive/NOO_paper/Datasets/WorldWide/BLAST_geac_ext_169k_val_random.csv'
TRAIN_DATA_PATH = '/content/drive/MyDrive/NOO_paper/Datasets/WorldWide/BLAST_geac_ext_169k_train_random.csv'
INFREQUENT_THRESHOLD = 10

def split_test_data(test_data):
    """Split test data into input and target variables."""
    y_test = test_data['nations']
    x_test = test_data[['sequence']]
    return x_test, y_test

def replace_infrequent_labels(labels, threshold=INFREQUENT_THRESHOLD):
    """Identify and replace infrequent labels."""
    label_counts = labels.value_counts()
    infrequent_labels = label_counts[label_counts < threshold].index
    return labels.replace(infrequent_labels, 'infrequent')

def map_labels_to_integers(labels):
    """Map labels to integers."""
    unique_labels = labels.unique()
    return {label: int(i) for i, label in enumerate(unique_labels)}

def without_US(data):
    """Filter out rows where the nation is 'UNITED STATES'."""
    data_wo_US = data[data['nations'] != 'UNITED STATES']
    data_wo_US.reset_index(drop=True, inplace=True)

    data_w_US = data[data['nations'] == 'UNITED STATES']
    data_w_US.reset_index(drop=True, inplace=True)
    return data_wo_US, data_w_US

def US_vs_them(labels):
    """Categorize labels into 'UNITED STATES' and 'NON US'."""
    return labels.apply(lambda x: x if x == 'UNITED STATES' else 'NON US')

def pad_sequence(seq, length, pad_char='N'):
    """Pad sequences to the specified length with the given character."""
    return seq.ljust(length, pad_char)[:length]

# Load data
train_data = pd.read_csv(TRAIN_DATA_PATH)
test_data = pd.read_csv(TEST_DATA_PATH)

print(f'test_data shape: {test_data.shape}')

# Remove US
# train_data, train_data_US = without_US(train_data)
# test_data, test_data_US = without_US(test_data)

print(f'test_data shape: {test_data.shape}')

# Split data
x_train, y_train = train_data[['sequence']], train_data['nations']
x_test, y_test = split_test_data(test_data)

print(f'test_data shape: {y_test.shape}')
print(f'x_train shape: {x_train.shape}')
print(f'y_train shape: {y_train.shape}')

# Combine labels from train and test datasets
processed_labels = pd.concat([y_train, y_test], axis=0, ignore_index=True)
label_to_int = map_labels_to_integers(processed_labels)


# map labels to integers
y_train = y_train.map(label_to_int)
y_test = y_test.map(label_to_int)

print(f'y_test shape: {y_test.shape}')


# reset indices before concat
x_train.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)
x_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)

df_train = pd.concat([x_train, y_train], axis=1)
df_val = pd.concat([x_test, y_test], axis=1)

print(f'test_data shape: {test_data.shape}')


# Filter out sequences shorter than min_length and clean them
min_length = 0
df_train = df_train[df_train['sequence'].str.len() > min_length]
df_val = df_val[df_val['sequence'].str.len() > min_length]

print(f'test_data shape: {test_data.shape}')


# Ensure indices are reset correctly
df_train.reset_index(drop=True, inplace=True)
df_val.reset_index(drop=True, inplace=True)

# Display the split data
print("Train Data Shape:", df_train.shape)
print("Validation Data Shape:", df_val.shape)


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

class GenomicDataset(Dataset):
    def __init__(self,
                 ds: pd.DataFrame,
                 tokenizer_nt,
                 seq_length: int = 8000):


        self.sequences = ds['sequence']
        self.labels = ds['nations']
        self.seq_len = seq_length
        self.tokenizer = tokenizer_nt



    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences.iloc[idx]
        label = self.labels.iloc[idx]

        # Tokenize the sequence
        inputs = self.tokenizer(sequence, max_length=512, padding='max_length', truncation=True, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze()  # Remove batch dimension
        attention_mask = inputs['attention_mask'].squeeze()  # Remove batch dimension

        # to torch tensors
        label = torch.tensor(label, dtype=torch.long)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label
        }

# Parameters

val_dataset = GenomicDataset(df_val, tokenizer_nt=tokenizer_nt)
train_dataset = GenomicDataset(df_train, tokenizer_nt=tokenizer_nt)

BS = 64

val_loader_dna = DataLoader(val_dataset, batch_size=BS, shuffle=False, pin_memory=True, num_workers=2)
train_loader_dna = DataLoader(train_dataset, batch_size=BS, shuffle=True, pin_memory=True, num_workers=2)

# Load pretraining data of NT

In [None]:
!pip install datasets
!pip install huggingface_hub
!pip install biopython

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

## Multispecies

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from typing import Optional
import numpy as np
from tqdm import tqdm



class StreamingGenomicDataset(Dataset):
    def __init__(
        self,
        split: str = "train",
        tokenizer = None,
        seq_length: int = 512,
        max_samples: Optional[int] = None,
        cache_mode: bool = True
    ):
        """
        Streaming dataset for genomic sequences.

        Args:
            split: Dataset split ('train', 'validation', 'test')
            tokenizer: Tokenizer for DNA sequences
            seq_length: Maximum sequence length
            max_samples: Maximum number of samples to load (None for all)
            cache_mode: If True, caches all sequences in memory
        """
        self.seq_length = seq_length
        self.tokenizer = tokenizer

        # Load dataset in streaming mode
        dataset = load_dataset("InstaDeepAI/multi_species_genomes", split=split, streaming=True)

        if cache_mode:
            # Cache all sequences in memory
            self.sequences = []
            pbar = tqdm(dataset, total=max_samples, desc=f"Loading {split} data")

            for i, item in enumerate(pbar):
                if max_samples and i >= max_samples:
                    break
                self.sequences.append(item)
        else:
            # Store iterator for streaming mode
            self.sequences = dataset
            self.max_samples = max_samples

    def __len__(self):
        if isinstance(self.sequences, list):
            return len(self.sequences)
        return self.max_samples if self.max_samples else int(1e9)  # Large number for streaming

    def __getitem__(self, idx):
        if isinstance(self.sequences, list):
            # Cached mode
            item = self.sequences[idx]
        else:
            # Streaming mode
            item = next(iter(self.sequences))

        sequence = item['sequence']

        # Tokenize the sequence
        inputs = self.tokenizer(
            sequence,
            max_length=self.seq_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'sequence_info': {
                'description': item['description'],
                'start_pos': item['start_pos'],
                'end_pos': item['end_pos']
            }
        }

def create_genomic_dataloaders(
    tokenizer,
    batch_size: int = 32,
    seq_length: int = 512,
    max_samples: Optional[int] = None,
    num_workers: int = 2,
    cache_mode: bool = True
):
    """
    Create training and validation DataLoaders for genomic data.

    Args:
        tokenizer: DNA sequence tokenizer
        batch_size: Batch size for DataLoader
        seq_length: Maximum sequence length
        max_samples: Maximum samples per split (None for all)
        num_workers: Number of DataLoader workers
        cache_mode: If True, caches all sequences in memory
    """
    # Create datasets
    train_dataset = StreamingGenomicDataset(
        split="train",
        tokenizer=tokenizer,
        seq_length=seq_length,
        max_samples=max_samples,
        cache_mode=cache_mode
    )

    val_dataset = StreamingGenomicDataset(
        split="validation",
        tokenizer=tokenizer,
        seq_length=seq_length,
        max_samples=max_samples//10,
        cache_mode=cache_mode
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=cache_mode,  # Can only shuffle if data is cached
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader

# Example usage
# Create dataloaders with small sample size for testing
train_loader, val_loader = create_genomic_dataloaders(
    tokenizer=tokenizer,
    batch_size=32,
    seq_length=512,
    max_samples=19600,  # Small sample size for testing
    cache_mode=True
)

# Print dataset sizes
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Example of iterating through one batch
for batch in train_loader:
    print("\nBatch shapes:")
    print(f"Input ids: {batch['input_ids'].shape}")
    print(f"Attention mask: {batch['attention_mask'].shape}")
    break



Loading train data: 100%|██████████| 19600/19600 [00:18<00:00, 1051.35it/s]
Loading validation data: 100%|██████████| 1960/1960 [00:01<00:00, 1191.80it/s]

Training batches: 613
Validation batches: 62






Batch shapes:
Input ids: torch.Size([32, 512])
Attention mask: torch.Size([32, 512])


In [None]:
# Example of iterating through one batch
for batch in train_loader:
    print("\nBatch shapes:")
    print(f"Input ids: {batch['input_ids']}")
    print(f"Attention mask: {batch['attention_mask'].shape}")
    break



Batch shapes:
Input ids: tensor([[  6,  57,  64,  ...,  23,  26, 234],
        [941,  50,  27,  ...,  40, 116,  63],
        [  6, 828,  73,  ..., 491, 361,  32],
        ...,
        [157, 337,  41,  ..., 160,  58, 100],
        [  6,  18, 854,  ...,  84, 772,  64],
        [622, 453,  61,  ..., 340, 952,  25]])
Attention mask: torch.Size([32, 512])


In [None]:
## test how long it takes for the model to perform forward passes on all of these

# Compare MLM loss on plasmids vs multi-specis

In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Optional
from dataclasses import dataclass

@dataclass
class MLMEvalResults:
    total_loss: float
    num_batches: int
    total_tokens: int
    masked_tokens: int
    per_batch_losses: List[float]

class MatchedMLMEvaluator:
    def __init__(self, model, tokenizer, device='cuda'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.loss_fct = nn.CrossEntropyLoss(reduction='mean')

        # Exact hyperparameters from training
        self.mask_token_id = model.config.mask_token_id
        self.mask_ratio = 0.15
        self.mask_prob = 0.8
        self.random_token_prob = 0.1  # Changed from 0.5 to match training
        self.pad_token_id = model.config.pad_token_id

    def create_mlm_mask(self, input_ids):
        """
        Create MLM masks exactly matching training setup:
        - 15% of tokens selected for corruption
        - Of these:
          - 80% replaced with [MASK]
          - 10% replaced with random token
          - 10% unchanged
        """
        probability_matrix = torch.full(input_ids.shape, self.mask_ratio, device=self.device)

        # Don't mask padding tokens
        special_tokens_mask = input_ids == self.pad_token_id
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Select tokens for corruption
        corrupted_indices = torch.bernoulli(probability_matrix).bool()

        # Prepare outputs
        labels = input_ids.clone()
        labels[~corrupted_indices] = -100  # Only compute loss on corrupted tokens

        # Copy input_ids for masking
        masked_input_ids = input_ids.clone()

        # For corrupted tokens:
        # - 80% [MASK]
        indices_mask = corrupted_indices & (torch.rand_like(probability_matrix) < self.mask_prob)
        masked_input_ids[indices_mask] = self.mask_token_id

        # - 10% random token
        indices_random = corrupted_indices & ~indices_mask & (torch.rand_like(probability_matrix) < self.random_token_prob)
        random_words = torch.randint(4, self.model.config.vocab_size, labels.shape, device=self.device)
        masked_input_ids[indices_random] = random_words[indices_random]

        # - 10% unchanged (already handled by not modifying those positions)

        return masked_input_ids, labels, corrupted_indices

    def evaluate_batch(self, batch: Dict[str, torch.Tensor]):
        """Evaluate a single batch with exact training settings"""
        self.model.eval()
        with torch.no_grad():
            # Move batch to device
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch.get('attention_mask', torch.ones_like(input_ids)).to(self.device)

            # Create MLM masks
            masked_input_ids, labels, corrupted_indices = self.create_mlm_mask(input_ids)

            # Forward pass
            outputs = self.model(
                input_ids=masked_input_ids,
                attention_mask=attention_mask
            )

            # Calculate loss
            logits = outputs.logits

            # Only compute loss on corrupted tokens
            active_logits = logits[corrupted_indices]
            active_labels = labels[corrupted_indices]

            loss = self.loss_fct(
                active_logits.view(-1, self.model.config.vocab_size),
                active_labels.view(-1)
            )

            return (
                loss.item(),
                corrupted_indices.sum().item(),
                attention_mask.sum().item()
            )

    def evaluate_dataset(self, dataloader, num_batches=None):
        """Evaluate entire dataset"""
        total_loss = 0
        total_tokens = 0
        total_masked = 0
        batch_losses = []

        iterator = tqdm(dataloader) if num_batches is None else tqdm(list(islice(dataloader, num_batches)))

        for batch in iterator:
            loss, num_masked, num_tokens = self.evaluate_batch(batch)
            total_loss += loss
            total_masked += num_masked
            total_tokens += num_tokens
            batch_losses.append(loss)

        return MLMEvalResults(
            total_loss=total_loss,
            num_batches=len(batch_losses),
            total_tokens=total_tokens,
            masked_tokens=total_masked,
            per_batch_losses=batch_losses
        )

def print_evaluation_results(name: str, results: MLMEvalResults):
    """Print evaluation results"""
    print(f"\n{name} Dataset Results:")
    print("-" * 60)
    print(f"Average Loss: {results.total_loss / results.num_batches:.4f}")
    print(f"Mask Ratio: {results.masked_tokens / results.total_tokens:.1%}")
    print(f"Total Tokens: {results.total_tokens}")
    print(f"Masked Tokens: {results.masked_tokens}")

In [None]:
# Initialize evaluator with matched hyperparameters
model_nt = model_nt.cuda()
evaluator = MatchedMLMEvaluator(model_nt, tokenizer_nt)

# Evaluate both datasets
multi_results = evaluator.evaluate_dataset(train_loader, num_batches=10)
print_evaluation_results("Multi-species", multi_results)

addgene_results = evaluator.evaluate_dataset(train_loader_dna, num_batches=10)
print_evaluation_results("Addgene", addgene_results)

100%|██████████| 10/10 [00:03<00:00,  3.05it/s]


Multi-species Dataset Results:
------------------------------------------------------------
Average Loss: 6.2668
Mask Ratio: 15.0%
Total Tokens: 327680
Masked Tokens: 49089



100%|██████████| 10/10 [00:03<00:00,  3.04it/s]


Addgene Dataset Results:
------------------------------------------------------------
Average Loss: 6.0871
Mask Ratio: 14.9%
Total Tokens: 273280
Masked Tokens: 40768





### Test

In [None]:
import torch
import torch.nn as nn
from typing import Dict, Optional
from dataclasses import dataclass
import numpy as np

class MLMEvaluatorTests:
    def __init__(self, evaluator, model, tokenizer):
        """
        Initialize testing suite for MLM evaluation.

        Args:
            evaluator: MLMDatasetEvaluator instance
            model: The transformer model
            tokenizer: The tokenizer used
        """
        self.evaluator = evaluator
        self.model = model
        self.tokenizer = tokenizer

    def run_all_checks(self, batch):
        """Run all sanity checks on a single batch."""
        results = {}
        print("Running sanity checks...")

        # Test 1: Check mask token application
        results["mask_check"] = self.check_mask_application(batch)

        # Test 2: Check loss computation
        results["loss_check"] = self.check_loss_computation(batch)

        # Test 3: Check attention mask handling
        results["attention_check"] = self.check_attention_mask(batch)

        # Test 4: Check token distributions
        results["token_dist_check"] = self.check_token_distribution(batch)

        # Test 5: Check model output shapes
        results["shape_check"] = self.check_output_shapes(batch)

        return results

    def check_mask_application(self, batch, mask_ratio=0.15):
        """Verify masking is applied correctly."""
        with torch.no_grad():
            input_ids = batch['input_ids'].to(self.evaluator.device)
            masked_input_ids = input_ids.clone()

            # Generate mask
            rand = torch.rand(input_ids.shape, device=self.evaluator.device)
            mask_arr = (rand < mask_ratio) * (input_ids != 0) * (input_ids != 1)

            # Apply masking
            masked_input_ids[mask_arr] = 3  # mask token ID

            # Compute statistics
            total_tokens = (input_ids != 0).sum().item()
            masked_tokens = mask_arr.sum().item()
            actual_ratio = masked_tokens / total_tokens if total_tokens > 0 else 0

            result = {
                "pass": abs(actual_ratio - mask_ratio) < 0.05,  # Within 5% of target
                "target_ratio": mask_ratio,
                "actual_ratio": actual_ratio,
                "total_tokens": total_tokens,
                "masked_tokens": masked_tokens
            }

            print(f"\nMask Application Check:")
            print(f"Target mask ratio: {mask_ratio:.3f}")
            print(f"Actual mask ratio: {actual_ratio:.3f}")
            print(f"Status: {'PASS' if result['pass'] else 'FAIL'}")

            return result

    def check_loss_computation(self, batch):
        """Verify loss computation is reasonable."""
        with torch.no_grad():
            loss, num_masked, ratio = self.evaluator.calculate_batch_mlm_loss(batch)

            result = {
                "pass": 0 < loss < 20,  # Reasonable range for cross-entropy loss
                "loss_value": loss,
                "num_masked": num_masked,
                "ratio": ratio
            }

            print(f"\nLoss Computation Check:")
            print(f"Loss value: {loss:.3f}")
            print(f"Status: {'PASS' if result['pass'] else 'FAIL'}")

            return result

    def check_attention_mask(self, batch):
        """Verify attention mask is being properly applied."""
        attention_mask = batch.get('attention_mask')
        if attention_mask is None:
            print("\nAttention Mask Check: SKIP - No attention mask provided")
            return {"pass": None, "message": "No attention mask"}

        with torch.no_grad():
            # Check if mask aligns with pad tokens
            input_ids = batch['input_ids']
            pad_tokens = (input_ids == 0)
            mask_match = (attention_mask == 0) == pad_tokens

            result = {
                "pass": mask_match.all().item(),
                "matching_ratio": mask_match.float().mean().item()
            }

            print(f"\nAttention Mask Check:")
            print(f"Mask-padding alignment: {result['matching_ratio']:.3%}")
            print(f"Status: {'PASS' if result['pass'] else 'FAIL'}")

            return result

    def check_token_distribution(self, batch):
        """Check distribution of tokens in batch."""
        input_ids = batch['input_ids']

        # Get token counts
        unique_tokens, counts = torch.unique(input_ids, return_counts=True)
        total_tokens = input_ids.numel()

        # Calculate distribution
        distribution = {
            self.tokenizer.convert_ids_to_tokens(t.item()):
            (c.item() / total_tokens)
            for t, c in zip(unique_tokens, counts)
        }

        result = {
            "pass": len(distribution) > 1,  # Should have multiple token types
            "distribution": distribution,
            "unique_tokens": len(distribution)
        }

        print(f"\nToken Distribution Check:")
        print(f"Unique tokens: {len(distribution)}")
        print(f"Top 5 tokens: {dict(sorted(distribution.items(), key=lambda x: x[1], reverse=True)[:5])}")
        print(f"Status: {'PASS' if result['pass'] else 'FAIL'}")

        return result

    def check_output_shapes(self, batch):
        """Verify model output shapes are correct."""
        with torch.no_grad():
            input_ids = batch['input_ids'].to(self.evaluator.device)
            attention_mask = batch.get('attention_mask')

            if attention_mask is not None:
                attention_mask = attention_mask.to(self.evaluator.device)
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = self.model(input_ids=input_ids)

            # Check shapes
            batch_size, seq_len = input_ids.shape
            expected_logits_shape = (batch_size, seq_len, self.tokenizer.vocab_size)

            result = {
                "pass": outputs.logits.shape == expected_logits_shape,
                "expected_shape": expected_logits_shape,
                "actual_shape": tuple(outputs.logits.shape)
            }

            print(f"\nOutput Shapes Check:")
            print(f"Expected shape: {expected_logits_shape}")
            print(f"Actual shape: {tuple(outputs.logits.shape)}")
            print(f"Status: {'PASS' if result['pass'] else 'FAIL'}")

            return result

def print_test_summary(test_results: Dict):
    """Print summary of all test results."""
    print("\n" + "="*50)
    print("MLM Evaluator Test Summary")
    print("="*50)

    all_passed = True
    for test_name, result in test_results.items():
        if result.get('pass') is not None:  # Skip tests that weren't run
            status = 'PASS' if result['pass'] else 'FAIL'
            print(f"{test_name:<20}: {status}")
            all_passed = all_passed and result['pass']

    print("\nOverall Status:", "PASS" if all_passed else "FAIL")
    print("="*50)

# Initialize evaluator and tests
evaluator = MLMDatasetEvaluator(model_nt)
test_suite = MLMEvaluatorTests(evaluator, model_nt, tokenizer_nt)

# Run tests on a single batch from each dataset
print("\nTesting Multi-species dataset:")
for batch in train_loader:
    test_results = test_suite.run_all_checks(batch)
    print_test_summary(test_results)
    break

print("\nTesting Human Reference dataset:")
for batch in train_loader_dna:
    test_results = test_suite.run_all_checks(batch)
    print_test_summary(test_results)
    break

# If tests pass, proceed with main evaluation
if all(result.get('pass', False) for result in test_results.values()):
    print("\nAll tests passed! Proceeding with main evaluation...")
    dataloaders = {
        "Multi-species": train_loader,
        "Human Reference": train_loader_dna
    }
    results = evaluator.compare_datasets(dataloaders, num_batches=10)
    print_comparison_results(results)
else:
    print("\nSome tests failed! Please check the results above.")


Testing Multi-species dataset:
Running sanity checks...

Mask Application Check:
Target mask ratio: 0.150
Actual mask ratio: 0.151
Status: PASS

Loss Computation Check:
Loss value: 8.035
Status: PASS

Attention Mask Check:
Mask-padding alignment: 100.000%
Status: PASS

Token Distribution Check:
Unique tokens: 4102
Top 5 tokens: {'<cls>': 0.001953125, 'CGCCGC': 0.001373291015625, 'GCGGCG': 0.00136566162109375, 'GCCGCC': 0.00131988525390625, 'GCGCCG': 0.001251220703125}
Status: PASS

Output Shapes Check:
Expected shape: (256, 512, 4107)
Actual shape: (256, 512, 4107)
Status: PASS

MLM Evaluator Test Summary
mask_check          : PASS
loss_check          : PASS
attention_check     : PASS
token_dist_check    : PASS
shape_check         : PASS

Overall Status: PASS

Testing Human Reference dataset:
Running sanity checks...

Mask Application Check:
Target mask ratio: 0.150
Actual mask ratio: 0.149
Status: PASS

Loss Computation Check:
Loss value: 7.954
Status: PASS

Attention Mask Check:
Mas

Evaluating Multi-species: 100%|██████████| 10/10 [00:12<00:00,  1.28s/it]


Evaluating dataset: Human Reference



Evaluating Human Reference: 100%|██████████| 10/10 [00:25<00:00,  2.54s/it]


MLM Loss Comparison Results:
------------------------------------------------------------
Dataset                Avg Loss  Per Token Mask Ratio
------------------------------------------------------------
Multi-species            8.0329     0.0004    100.00%
Human Reference          7.9300     0.0002    100.00%





## Human ref

In [None]:
# Install required packages
!pip install pysam pandas

import subprocess
import pandas as pd
import pysam
import os

class GenomeDataLoader:
    def __init__(self, data_dir="./genome_data"):
        """Initialize the genome data loader."""
        self.data_dir = data_dir
        os.makedirs(data_dir, exist_ok=True)

    def download_hg38(self):
        """Download the latest human reference genome (hg38)."""
        print("Downloading hg38 reference genome...")
        # Install required tools
        subprocess.run("apt-get update && apt-get install -y wget samtools", shell=True)

        # Download from UCSC
        genome_url = "https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz"
        subprocess.run(f"wget {genome_url} -P {self.data_dir}", shell=True)

        # Decompress
        subprocess.run(f"gunzip {self.data_dir}/hg38.fa.gz", shell=True)

        # Index the genome
        subprocess.run(f"samtools faidx {self.data_dir}/hg38.fa", shell=True)

        return f"{self.data_dir}/hg38.fa"

    def download_1000g_vcf(self, chromosome="chr1"):
        """
        Download 1000 Genomes Project VCF for specified chromosome.
        """
        print(f"Downloading 1000G data for {chromosome}...")
        # Install required tools
        subprocess.run("apt-get update && apt-get install -y tabix", shell=True)

        # Download from 1000G FTP
        base_url = "http://ftp.1000genomes.ebi.ac.uk/vol1/ftp/data_collections/1000_genomes_project/release/20190312_biallelic_SNV_and_INDEL"
        vcf_url = f"{base_url}/{chromosome}.vcf.gz"

        subprocess.run(f"wget {vcf_url} -P {self.data_dir}", shell=True)
        subprocess.run(f"wget {vcf_url}.tbi -P {self.data_dir}", shell=True)

        return f"{self.data_dir}/{chromosome}.vcf.gz"

    def read_reference_sequence(self, fasta_file, chromosome, start, end):
        """
        Read a sequence from the reference genome.
        """
        with pysam.FastaFile(fasta_file) as fasta:
            return fasta.fetch(chromosome, start, end)

    def read_variants(self, vcf_file, chromosome, start, end):
        """
        Read variants from 1000G VCF file.
        """
        variants = []
        with pysam.VariantFile(vcf_file) as vcf:
            for record in vcf.fetch(chromosome, start, end):
                variants.append({
                    'position': record.pos,
                    'reference': record.ref,
                    'alternate': record.alts[0],
                    'allele_freq': record.info.get('AF', [None])[0]
                })
        return pd.DataFrame(variants)

    def get_sequence_with_variants(self, ref_file, vcf_file, chromosome, start, end):
        """
        Get reference sequence and its variants in the specified region.
        """
        sequence = self.read_reference_sequence(ref_file, chromosome, start, end)
        variants = self.read_variants(vcf_file, chromosome, start, end)

        return {
            'sequence': sequence,
            'variants': variants
        }

def example_usage():
    """Example usage of the GenomeDataLoader class."""
    loader = GenomeDataLoader()

    # Download reference genome
    ref_file = loader.download_hg38()

    # Download 1000G data for chromosome 1
    vcf_file = loader.download_1000g_vcf("chr1")

    # Get sequence and variants for a specific region
    region = loader.get_sequence_with_variants(
        ref_file=ref_file,
        vcf_file=vcf_file,
        chromosome="chr1",
        start=1000000,
        end=1001000
    )

    print("\nReference sequence:")
    print(region['sequence'][:100] + "...")
    print("\nVariants found:")
    print(region['variants'].head())



# Copy and paste the code above
# Then initialize and use:
loader = GenomeDataLoader()
ref_file = loader.download_hg38()
#vcf_file = loader.download_1000g_vcf("chr1")

Collecting pysam
  Downloading pysam-0.22.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.5 kB)
Downloading pysam-0.22.1-cp310-cp310-manylinux_2_28_x86_64.whl (22.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.0/22.0 MB[0m [31m84.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pysam
Successfully installed pysam-0.22.1
Downloading hg38 reference genome...


KeyboardInterrupt: 

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pysam
import numpy as np
from typing import Optional

class Hg38Dataset(Dataset):
    def __init__(
        self,
        fasta_file: str,
        tokenizer,
        seq_length: int = 512,
        chromosomes: Optional[list] = None,
        stride: Optional[int] = None
    ):
        """
        Dataset for hg38 reference genome sequences.

        Args:
            fasta_file: Path to the hg38.fa file
            tokenizer: DNA tokenizer
            seq_length: Length of sequences to return
            chromosomes: List of chromosomes to use (default: chr1-22,X,Y)
            stride: Stride length for splitting sequences (default: seq_length)
        """
        self.fasta_file = fasta_file
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.stride = stride if stride else seq_length

        # Initialize chromosome list
        if chromosomes is None:
            self.chromosomes = [f"chr{i}" for i in range(1, 23)]
            self.chromosomes.extend(["chrX", "chrY"])
        else:
            self.chromosomes = chromosomes

        # Calculate number of sequences per chromosome
        self.sequence_indices = []
        with pysam.FastaFile(self.fasta_file) as fasta:
            for chrom in self.chromosomes:
                chrom_length = fasta.get_reference_length(chrom)
                num_sequences = (chrom_length - self.seq_length) // self.stride + 1

                for i in range(num_sequences):
                    start = i * self.stride
                    end = start + self.seq_length
                    if end <= chrom_length:
                        self.sequence_indices.append({
                            'chromosome': chrom,
                            'start': start,
                            'end': end
                        })

    def __len__(self):
        return len(self.sequence_indices)

    def __getitem__(self, idx):
        # Get sequence location
        loc = self.sequence_indices[idx]

        # Read sequence
        with pysam.FastaFile(self.fasta_file) as fasta:
            sequence = fasta.fetch(
                loc['chromosome'],
                loc['start'],
                loc['end']
            )

        # Tokenize sequence
        inputs = self.tokenizer(
            sequence,
            max_length=self.seq_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'position': torch.tensor([loc['start'], loc['end']], dtype=torch.long),
            'chromosome': loc['chromosome']
        }

def create_genome_dataloaders(
    fasta_file: str,
    tokenizer,
    seq_length: int = 512,
    batch_size: int = 32,
    val_chromosomes: Optional[list] = None,
    num_workers: int = 2
):
    """
    Create train and validation dataloaders for genomic data.

    Args:
        fasta_file: Path to hg38.fa file
        tokenizer: DNA tokenizer
        seq_length: Sequence length
        batch_size: Batch size
        val_chromosomes: List of chromosomes to use for validation (default: chr21, chr22)
        num_workers: Number of worker processes for data loading
    """
    if val_chromosomes is None:
        val_chromosomes = ['chr21', 'chr22']

    # Create list of training chromosomes (all except validation chromosomes)
    train_chromosomes = [f"chr{i}" for i in range(1, 23)]
    train_chromosomes.extend(['chrX', 'chrY'])
    train_chromosomes = [c for c in train_chromosomes if c not in val_chromosomes]

    # Create datasets
    train_dataset = Hg38Dataset(
        fasta_file=fasta_file,
        tokenizer=tokenizer,
        seq_length=seq_length,
        chromosomes=train_chromosomes
    )

    val_dataset = Hg38Dataset(
        fasta_file=fasta_file,
        tokenizer=tokenizer,
        seq_length=seq_length,
        chromosomes=val_chromosomes
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader

# Example usage
def example_usage(fasta_file, tokenizer):
    """
    Example of how to use the dataloaders.
    """
    train_loader, val_loader = create_genome_dataloaders(
        fasta_file=fasta_file,
        tokenizer=tokenizer,
        seq_length=512,
        batch_size=32
    )

    # Print dataset sizes
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # Example of iterating through data
    for batch in train_loader:
        input_ids = batch['input_ids']  # Shape: [batch_size, seq_length]
        attention_mask = batch['attention_mask']  # Shape: [batch_size, seq_length]
        positions = batch['position']  # Shape: [batch_size, 2]
        chromosomes = batch['chromosome']  # List of chromosome names
        break

    return input_ids.shape, attention_mask.shape

# Create dataloaders
train_loader, val_loader = create_genome_dataloaders(
    fasta_file="/content/genome_data/hg38.fa",
    tokenizer=tokenizer_nt,
    seq_length=512,  # in tokens
    batch_size=512    # adjust based on your GPU memory
)

In [None]:
for batch in train_loader:
    input_ids = batch['input_ids']  # Shape: [batch_size, seq_length]
    attention_mask = batch['attention_mask']  # Shape: [batch_size, seq_length]
    print(input_ids)
    break

In [None]:
for batch in train_loader_dna:
    input_ids = batch['input_ids']  # Shape: [batch_size, seq_length]
    attention_mask = batch['attention_mask']  # Shape: [batch_size, seq_length]
    print(input_ids)
    break

# Set-up & Load SAE

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

cfg = {
    "seed": 49,
    "batch_size": 4096*6,
    "buffer_mult": 384,
    "lr": 5e-5,
    "num_tokens": tokenizer_nt.vocab_size,
    "d_model": 512,
    "l1_coeff": 1e-1,
    "beta1": 0.9,
    "beta2": 0.999,
    "dict_mult": 8, # hidden_d = d_model * dict_mult
    "seq_len": 512,
    "d_mlp": 512,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
    "total_training_steps": 10000,
    "lr_warm_up_steps": 1000,
    "device": "cuda"
}
cfg["model_batch_size"] = 64
cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # HP-choices
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        self.l0_coeff = cfg.get("l0_coeff", 5)
        self.threshold = cfg.get("activation_threshold", 0.3)
        # Temperature for sigmoid approximation
        self.temperature = cfg.get("temperature", 1.0)
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])

        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.to("cuda")

    def get_continuous_l0(self, x):
        """
        Compute continuous relaxation of L0 norm using sigmoid
        This provides useful gradients unlike the discrete L0
        """
        # Shifted sigmoid to approximate step function
        return torch.sigmoid((x.abs() - self.threshold) / self.temperature)

    def forward(self, x):
        # encoding and decoding of input vec
        x_cent = x - self.b_dec
        pre_acts = x_cent @ self.W_enc + self.b_enc
        acts = F.relu(pre_acts)

        # Compute continuous L0 approximation before thresholding
        l0_proxy = self.get_continuous_l0(acts)

        # Apply hard threshold for forward pass --- This is actually jumprelu (I think!)
        acts_sparse = (acts.abs() > self.threshold).float() * acts
        x_reconstruct = acts_sparse @ self.W_dec + self.b_dec

        # L2 Loss (Reconstruction Loss)
        l2_loss = F.mse_loss(x_reconstruct.float(), x.float(), reduction='none')
        l2_loss = l2_loss.sum(-1)
        l2_loss = l2_loss.mean()

        # Normalized MSE for reporting
        nmse = torch.norm(x - x_reconstruct, p=2) / torch.norm(x, p=2)

        # Continuous L0 loss (using sigmoid approximation)
        l0_loss = l0_proxy.sum(dim=1).mean()

        # Total Loss: reconstruction + sparsity
        loss = l2_loss + self.l0_coeff * l0_loss

        # For monitoring: true L0 count (not used in optimization)
        true_l0 = (acts_sparse.float().abs() > 0).float().sum(dim=1).mean()

        # For monitoring: L1 loss
        l1_loss = acts_sparse.float().abs().sum(-1).mean()

        return loss, x_reconstruct, acts_sparse, l2_loss, nmse, l1_loss, true_l0

    @torch.no_grad()
    def remove_parallel_component_of_grads(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj



sae_model = AutoEncoder(cfg)
sae_res = AutoEncoder(cfg)
sae_l10 = AutoEncoder(cfg)


## Load already-trained SAE

In [None]:
weights_path = "/content/drive/MyDrive/SAEs_for_Genomics/Weights/nt18.5m_sae_l10_2024-11-13.pt"
state_dict = torch.load(weights_path, weights_only=True)
sae_l10.load_state_dict(state_dict)

weights_path = "/content/drive/MyDrive/SAEs_for_Genomics/Weights/nt50m_sae_+40mtokens.pt"
state_dict = torch.load(weights_path, weights_only=True)
sae_model.load_state_dict(state_dict)

weights_path = "/content/drive/MyDrive/SAEs_for_Genomics/Weights/nt15m_sae_final.res_2024-11-25.pt"
state_dict = torch.load(weights_path, weights_only=True)
sae_res.load_state_dict(state_dict)

<All keys matched successfully>

In [None]:
## load custom functions from utils.py

import sys
sys.path.append('//content/drive/MyDrive/SAEs_for_Genomics')

import importlib
import utils
importlib.reload(utils)

<module 'utils' from '//content/drive/MyDrive/SAEs_for_Genomics/utils.py'>

# Eval % of CE Loss reconstructed

In [None]:
import torch
import torch.nn as nn
import numpy as np

class SAEEvaluator:
    def __init__(self, model, sae, layer_N=11):
        """
        Initialize evaluator with model, SAE, and target layer number.

        Args:
            model: The ESM model
            sae: Trained SAE model
            layer_N: Layer number to evaluate (default: 11)
        """
        self.model = model
        self.sae = sae
        self.layer_N = layer_N
        self.original_state = {}
        self.loss_fct = nn.CrossEntropyLoss(reduction='mean')

        # Hyperparameters from training
        self.mask_token_id = model.config.mask_token_id  # Should be 2
        self.pad_token_id = model.config.pad_token_id   # Should be 1
        self.mask_ratio = 0.15
        self.mask_prob = 0.8      # 80% mask token
        self.random_prob = 0.1    # 10% random token, 10% unchanged

        # move to cuda
        self.model.to('cuda')
        self.sae.to('cuda')

    def _get_target_layer(self):
        layer = self.model.esm.encoder.layer[self.layer_N].output
        return layer.dense

    def _store_original_state(self):
        """Store the original layer state."""
        target_layer = self._get_target_layer()
        self.original_state = target_layer.state_dict()

    def _restore_original_state(self):
        """Restore the original layer state."""
        target_layer = self._get_target_layer()
        target_layer.load_state_dict(self.original_state)

    def _replace_activations_with_zeros(self):
        """Replace layer outputs with zeros."""
        target_layer = self._get_target_layer()

        def zero_forward_hook(module, input, output):
            zeros = torch.zeros_like(output)
            # Add debug prints
            print(f"Original output stats - mean: {output.mean():.3f}, std: {output.std():.3f}")
            print(f"Zeroed output stats - mean: {zeros.mean():.3f}, std: {zeros.std():.3f}")
            assert torch.all(zeros == 0), "Not all values are zero!"
            return zeros

        handle = target_layer.register_forward_hook(zero_forward_hook)
        return handle

    def _replace_activations_with_reconstructions(self):
        """Replace layer outputs with SAE reconstructions."""
        target_layer = self._get_target_layer()
        def reconstruction_forward_hook(module, input, output):
            with torch.no_grad():
                loss, reconstructed, hidden, l2_loss, nmse, l1_loss, true_l0 = self.sae(output)
                return reconstructed
        handle = target_layer.register_forward_hook(reconstruction_forward_hook)
        return handle

    def create_mlm_mask(self, input_ids, device='cuda'):
        """
        Create MLM masks matching training setup:
        - 15% of tokens selected for masking
        - Of those:
          - 80% replaced with [MASK]
          - 10% replaced with random token
          - 10% unchanged
        """
        probability_matrix = torch.full(input_ids.shape, self.mask_ratio, device=device)

        # Don't mask padding tokens
        special_tokens_mask = input_ids == self.pad_token_id
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Select tokens for corruption
        corrupted_indices = torch.bernoulli(probability_matrix).bool()

        # Prepare outputs
        labels = input_ids.clone()
        labels[~corrupted_indices] = -100  # Only compute loss on corrupted tokens

        # Copy input_ids for masking
        masked_input_ids = input_ids.clone()

        # For corrupted tokens:
        # - 80% [MASK]
        indices_mask = corrupted_indices & (torch.rand_like(probability_matrix) < self.mask_prob)
        masked_input_ids[indices_mask] = self.mask_token_id

        # - 10% random token
        indices_random = corrupted_indices & ~indices_mask & (torch.rand_like(probability_matrix) < self.random_prob)
        random_words = torch.randint(4, self.model.config.vocab_size, labels.shape, device=device)
        masked_input_ids[indices_random] = random_words[indices_random]

        # - 10% unchanged (already handled by not modifying those positions)

        return masked_input_ids, labels, corrupted_indices

    def calculate_mlm_loss(self, batch, device='cuda', mask_arr=None):
        """Calculate MLM loss for a single batch."""
        self.model.eval()
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch.get('attention_mask', torch.ones_like(input_ids)).to(device)

            # Create masks
            if mask_arr is None:
                masked_input_ids, labels, mask_arr = self.create_mlm_mask(input_ids, device)
            else:
                # If mask provided, use it but apply same masking strategy
                labels = input_ids.clone()
                labels[~mask_arr] = -100
                masked_input_ids = input_ids.clone()

                # Apply masking strategy to selected tokens
                indices_mask = mask_arr & (torch.rand_like(input_ids, dtype=torch.float) < self.mask_prob)
                indices_random = mask_arr & ~indices_mask & (torch.rand_like(input_ids, dtype=torch.float) < self.random_prob)

                masked_input_ids[indices_mask] = self.mask_token_id
                random_words = torch.randint(4, self.model.config.vocab_size, labels.shape, device=device)
                masked_input_ids[indices_random] = random_words[indices_random]

            # Forward pass
            outputs = self.model(
                input_ids=masked_input_ids,
                attention_mask=attention_mask
            )

            # Calculate loss only on masked tokens
            logits = outputs.logits
            active_loss = labels != -100
            active_logits = logits[active_loss]
            active_labels = labels[active_loss]

            loss = self.loss_fct(
                active_logits.view(-1, self.model.config.vocab_size),
                active_labels.view(-1)
            )

            return loss.item(), mask_arr

    def calculate_percent_loss_recovered(self, batch, device='cuda'):
        """Calculate percentage of loss recovered by SAE reconstruction."""
        self._store_original_state()

        # Calculate original loss and get mask
        ce_original, mask_arr = self.calculate_mlm_loss(batch, device)
        print(f"Original CE Loss: {ce_original}")

        # Use same mask for zero calculation
        zero_handle = self._replace_activations_with_zeros()
        ce_zero, _ = self.calculate_mlm_loss(batch, device, mask_arr)
        print(f"Zero CE Loss: {ce_zero}")
        zero_handle.remove()

        # Use same mask for reconstruction calculation
        reconstruction_handle = self._replace_activations_with_reconstructions()
        ce_reconstruction, _ = self.calculate_mlm_loss(batch, device, mask_arr)
        print(f"Reconstruction CE Loss: {ce_reconstruction}")
        reconstruction_handle.remove()

        self._restore_original_state()

        # Sanity checks
        if ce_zero <= ce_original:
            print("WARNING: Zero loss not higher than original loss!")
            print("Mask percentage:", (mask_arr.sum() / mask_arr.numel()).item())

        percent_recovered = (1 - (ce_reconstruction - ce_original) /
                           (ce_zero - ce_original)) * 100 if ce_zero > ce_original else 0

        return percent_recovered

In [None]:
# Example usage:

# Initialize your model, SAE, and dataloader
model = model_nt
sae = sae_model
dataloader = train_loader_dna

# Create evaluator
# if you have an existing instance
evaluator = SAEEvaluator(model, sae)

# Calculate percent loss recovered on single batch of data_loader
avg_percent_recovered = 0

for i, batch in enumerate(dataloader):
    if i > 10:
        break
    percent_recovered = evaluator.calculate_percent_loss_recovered(batch)
    avg_percent_recovered += percent_recovered

avg_percent_recovered /= len(dataloader)
print(f"Average Percent Loss Recovered: {avg_percent_recovered:.2f}%")



Original CE Loss: 6.964635372161865
Original output stats - mean: 0.000, std: 0.000
Zeroed output stats - mean: 0.000, std: 0.000
Zero CE Loss: 6.96466064453125
Reconstruction CE Loss: 7.017735004425049
Original CE Loss: 6.923696517944336
Original output stats - mean: 0.000, std: 0.000
Zeroed output stats - mean: 0.000, std: 0.000
Zero CE Loss: 6.88441801071167
Reconstruction CE Loss: 6.884451389312744
Mask percentage: 0.12274169921875
Original CE Loss: 6.979600429534912
Original output stats - mean: 0.000, std: 0.000
Zeroed output stats - mean: 0.000, std: 0.000
Zero CE Loss: 6.970797061920166
Reconstruction CE Loss: 7.051527976989746
Mask percentage: 0.114990234375
Original CE Loss: 6.959843158721924
Original output stats - mean: 0.000, std: 0.000
Zeroed output stats - mean: 0.000, std: 0.000
Zero CE Loss: 6.9217352867126465
Reconstruction CE Loss: 7.009720802307129
Mask percentage: 0.12554931640625
Original CE Loss: 6.858223915100098
Original output stats - mean: 0.000, std: 0.000
Z

KeyboardInterrupt: 

## Setup

In [None]:
try:
    # import google.colab # type: ignore
    # from google.colab import output
    %pip install sae-lens transformer-lens circuitsvis
except:
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting sae-lens
  Downloading sae_lens-5.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.10.0-py3-none-any.whl.metadata (12 kB)
Collecting circuitsvis
  Downloading circuitsvis-1.43.2-py3-none-any.whl.metadata (2.3 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Do

In [None]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


In [None]:

total_training_steps = 30_000
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="togethercomputer/evo-1-8k-base",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    model_class_name="AutoModelForCausalLM",
    hook_name="",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=10,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="LongSafari/open-genome",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=False,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=256,  # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=False,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32"
)
sparse_autoencoder = SAETrainingRunner(cfg).run()

The repository for togethercomputer/evo-1-8k-base contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/togethercomputer/evo-1-8k-base.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
The repository for togethercomputer/evo-1-8k-base contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/togethercomputer/evo-1-8k-base.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/modules/transformers_modules/togethercomputer/evo-1-131k-base/78c715ab81852e02ec3b1c7e795dc7250d8c7625/positional_embeddings.py'

In [None]:
!pip install evo-model
#pip install flash_attn

Collecting evo-model
  Downloading evo_model-0.2.1-py3-none-any.whl.metadata (7.8 kB)
Collecting stripedhyena==0.2.2 (from evo-model)
  Downloading stripedhyena-0.2.2-py3-none-any.whl.metadata (19 kB)
Collecting biopython (from evo-model)
  Downloading biopython-1.84-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading evo_model-0.2.1-py3-none-any.whl (20 kB)
Downloading stripedhyena-0.2.2-py3-none-any.whl (30 kB)
Downloading biopython-1.84-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m43.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython, stripedhyena, evo-model
Successfully installed biopython-1.84 evo-model-0.2.1 stripedhyena-0.2.2
