In [1]:
# imports
%load_ext autoreload
%autoreload 2

from epsilon_transformers.persistence import S3Persister, HackyPersister
from epsilon_transformers.training.configs.model_configs import RawModelConfig
from epsilon_transformers.process.processes import RRXOR, TransitionMatrixGHMM, ZeroOneR, Mess3
from epsilon_transformers.analysis.activation_analysis import get_beliefs_for_transformer_inputs

import numpy as np
import torch
import plotly.express as px
import pathlib

from sklearn.linear_model import LinearRegression
from torch.utils.data import IterableDataset, DataLoader

In [2]:
# PLAN

# 1. Generate all possible sequences of length seq_len + 1 and their probabilities, on CPU


# 2. Create a dataset class that draws samples from the process; all data should live on the GPU
# 3. Create a dataloader for the dataset
# 4. Create a transformer model
# 5. Train the model
# 6. Evaluate the model

In [3]:
from epsilon_transformers.process.processes import Process
from typing import Tuple

def generate_all_seqs(process: "Process", seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Enumerate the sequences of length exactly `seq_len` for a process, with their probabilities.

    Args:
        process (Process): The process to enumerate. It must provide
            `derive_mixed_state_presentation(depth=...)`, whose result provides
            `get_paths_and_probs(depth=...)`.
        seq_len (int): The exact length of the sequences to keep. A caller that
            later splits each row into an input and a one-step-shifted target
            should pass the input length plus one.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - transformer_inputs: int32 tensor of shape (num_sequences, seq_len), one row per kept sequence.
            - probs: float32 tensor of shape (num_sequences,) holding the probability of each row.

    Raises:
        ValueError: If no sequence of length `seq_len` is returned, or if the kept
            probabilities do not sum to 1.0 (within `torch.allclose` tolerance).
    """
    # Generate all paths and probabilities
    msp = process.derive_mixed_state_presentation(depth=seq_len)
    all_paths, all_probs = msp.get_paths_and_probs(depth=seq_len)

    # Filter in a single pass so each path stays paired with its own probability;
    # the returned paths are not assumed to all have the requested length.
    kept = [(path, prob) for path, prob in zip(all_paths, all_probs) if len(path) == seq_len]
    if not kept:
        raise ValueError(f"No sequences of length {seq_len} were returned for this process.")

    valid_paths = [path for path, _ in kept]
    valid_probs = [prob for _, prob in kept]

    # Convert to tensors
    transformer_inputs = torch.tensor(valid_paths, dtype=torch.int32)
    probs = torch.tensor(valid_probs, dtype=torch.float32)

    # The kept sequences must form a complete probability distribution.
    total = probs.sum()
    if not torch.allclose(total, torch.tensor(1.0)):
        error_message = f"The sum of probabilities is not equal to 1.0. Actual sum: {total.item():.6f}"
        raise ValueError(error_message)

    return transformer_inputs, probs



In [20]:
from transformer_lens import HookedTransformer, HookedTransformerConfig # type: ignore
from typing import Optional

# Configuration
# Prefer an accelerator when one is present: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device_name = "cuda"
elif torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cpu"

cfg = {
    "seq_len": 8,               # length of the model input X (targets are shifted by one)
    "batch_size": 64,           # sequences per batch
    "batches_per_epoch": 100,   # batches drawn per pass over the batch generator
    "device": torch.device(device_name),
}

# Process initialization
process = TransitionMatrixGHMM(Mess3().transition_matrix)
# NOTE(review): the process is built from the Mess3 transition matrix but is
# labelled "RRXOR" — this looks like a copy-paste leftover. Confirm what reads
# `process.name` before renaming it.
process.name = "RRXOR"

# Generate sequences and probabilities
# One extra token per sequence so each row can be split into a length-seq_len
# input and its one-step-shifted target.
transformer_inputs, probs = generate_all_seqs(process, cfg["seq_len"] + 1)
transformer_inputs = transformer_inputs.to(cfg["device"])
probs = probs.to(cfg["device"])

# Create an iterable batch generator
class BatchGenerator:
    """
    Re-iterable source of (X, Y) training batches sampled from a fixed table of sequences.

    Every pass over the object draws `batches_per_epoch * batch_size` row indices
    from `transformer_inputs`, with replacement and weighted by `probs`, and
    yields them one batch at a time.

    Args:
        transformer_inputs: 2-D tensor with one candidate sequence per row.
        probs: 1-D tensor of sampling weights, one per row of `transformer_inputs`.
        cfg: Mapping that provides the "batches_per_epoch" and "batch_size" entries.
    """

    def __init__(self, transformer_inputs, probs, cfg):
        self.transformer_inputs = transformer_inputs
        self.probs = probs
        self.cfg = cfg

    def __len__(self):
        """Number of batches produced by one full iteration."""
        return self.cfg["batches_per_epoch"]

    def __iter__(self):
        """
        Yield `batches_per_epoch` pairs (X, Y).

        X is every sampled row without its last column and Y is the same row
        without its first column, i.e. Y is X shifted one step ahead.
        """
        n_batches = self.cfg["batches_per_epoch"]
        per_batch = self.cfg["batch_size"]

        # Draw every index for the epoch in one call, then carve it into batches.
        drawn = torch.multinomial(self.probs, n_batches * per_batch, replacement=True)
        index_table = drawn.reshape(n_batches, per_batch)

        for row_ids in index_table.unbind(0):
            sequences = self.transformer_inputs[row_ids]
            yield sequences[:, :-1], sequences[:, 1:]

# Sampler that yields (X, Y) batches drawn from the enumerated sequences
batch_generator = BatchGenerator(transformer_inputs, probs, cfg)

# Transformer hyperparameters consumed by create_hooked_transformer
model_cfg = dict(
    d_model=32,
    d_head=8,
    n_layers=1,
    n_ctx=cfg["seq_len"] + 1,  # context window is one position longer than seq_len
    n_heads=4,
    attn_only=False,
    act_fn="relu",
    positional_embedding_type="standard",
    normalization_type="LN",
)

def create_hooked_transformer(model_cfg: dict, device: torch.device, seed: Optional[int] = None
    ) -> HookedTransformer:
    """
    Build a HookedTransformer from a plain settings dict and place it on `device`.

    Args:
        model_cfg (dict): Model settings. Required keys: "d_model", "d_head",
            "n_layers", "n_ctx", "n_heads", "attn_only", "act_fn",
            "positional_embedding_type", "normalization_type".
            Optional keys: "d_vocab" (defaults to 3, the value that used to be
            hard-coded here) and "d_mlp" (defaults to 4 * d_model).
        device (torch.device): Device passed to the config and used for the final `.to(...)`.
        seed (Optional[int]): Seed forwarded to HookedTransformerConfig; None forwards no fixed seed.

    Returns:
        HookedTransformer: The newly constructed model, moved to `device`.

    Raises:
        KeyError: If a required key is missing from `model_cfg`.
    """
    d_model = model_cfg["d_model"]
    config = HookedTransformerConfig(
        d_model=d_model,
        d_head=model_cfg["d_head"],
        n_layers=model_cfg["n_layers"],
        n_ctx=model_cfg["n_ctx"],
        n_heads=model_cfg["n_heads"],
        # Overridable, but defaults reproduce the previously hard-coded values.
        d_mlp=model_cfg.get("d_mlp", 4 * d_model),
        d_vocab=model_cfg.get("d_vocab", 3),
        attn_only=model_cfg["attn_only"],
        seed=seed,
        device=device,
        act_fn=model_cfg["act_fn"],
        positional_embedding_type=model_cfg["positional_embedding_type"],
        normalization_type=model_cfg["normalization_type"],
    )
    return HookedTransformer(config).to(device)

model = create_hooked_transformer(model_cfg, cfg["device"])
from torch.nn.functional import cross_entropy
# train!
from tqdm.autonotebook import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Nothing inside the loop leaves training mode, so enable it once up front.
model.train()

for n_epoch in tqdm(range(10), desc="Epochs"):
    epoch_loss = 0.0
    for X, Y in tqdm(batch_generator, desc="Batches", leave=False):
        optimizer.zero_grad(set_to_none=True)
        logits = model(X)
        # Collapse every leading dimension so each row is one next-token prediction;
        # the class count is read from the logits instead of being hard-coded.
        logits = logits.reshape(-1, logits.shape[-1])
        # cross_entropy takes class-index targets as int64; the sequences are
        # built as int32, which not every backend accepts.
        targets = Y.reshape(-1).long()
        loss = cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Mean of the per-batch losses over the epoch.
    avg_loss = epoch_loss / len(batch_generator)
    tqdm.write(f"Epoch {n_epoch+1}, Average Loss: {avg_loss:.4f}")


Moving model to device:  mps


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Average Loss: 1.1021


Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2, Average Loss: 1.0923


Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3, Average Loss: 1.0908


Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 4, Average Loss: 1.0908


Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 5, Average Loss: 1.0927


Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 6, Average Loss: 1.0914


Batches:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [12]:
logits.shape

torch.Size([64, 8, 3])