In [None]:
# Reload edited Python modules (e.g. activault_rcache) automatically before
# executing each cell, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:

import json
import logging
import os
import tempfile
from getpass import getpass

from transformers import AutoTokenizer

from activault_rcache import S3RCache, create_s3_client, ActivaultS3ActivationBuffer


In [None]:
# Send INFO-and-above records through the root logger so the helper
# functions' progress messages show up in the cell output.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


def get_first_hook_prefix(run_name, bucket_name):
    """Return the first hook "directory" stored under ``run_name/``.

    Lists one level below the run prefix (``Delimiter="/"`` groups keys into
    ``CommonPrefixes``) and returns the first entry with its trailing slash
    stripped, i.e. ``"<run_name>/<hook>"``.

    Returns:
        The prefix string, or None when the listing has no ``CommonPrefixes``.
    """
    listing = create_s3_client().list_objects_v2(
        Bucket=bucket_name,
        Prefix=f"{run_name}/",
        Delimiter="/",
    )
    if "CommonPrefixes" not in listing:
        return None
    first_entry = listing["CommonPrefixes"][0]
    return first_entry["Prefix"].rstrip("/")


def get_model_name_from_config(run_name, bucket_name):
    """Return ``transformer_config.model_name`` from ``<run_name>/cfg.json`` in S3.

    Args:
        run_name: Run prefix in the bucket; the config is read from
            ``<run_name>/cfg.json``.
        bucket_name: S3 bucket holding the run.

    Returns:
        The ``model_name`` value stored under ``transformer_config``.

    Raises:
        KeyError: if the config has no ``transformer_config``/``model_name``.
        Any error raised by ``download_file`` or ``json.load`` propagates.
    """
    s3_client = create_s3_client()
    # A private per-call temp dir (instead of a fixed /tmp/<run>_cfg.json path)
    # avoids collisions between concurrent runs, works when run_name contains
    # "/", and is removed even if the download or JSON parsing fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg_path = os.path.join(tmp_dir, "cfg.json")
        s3_client.download_file(bucket_name, f"{run_name}/cfg.json", cfg_path)
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
    return cfg["transformer_config"]["model_name"]


def inspect_batch(states, input_ids, tokenizer):
    """Log a short INFO-level summary of one activation batch.

    Logs, in order: ``states.shape``, ``input_ids.shape``, the mean and
    standard deviation of ``states`` (4 decimals), and the first 100
    characters of ``tokenizer.decode(input_ids[0])``.
    """
    logger.info("States shape: %s", states.shape)
    logger.info("Input IDs shape: %s", input_ids.shape)
    logger.info(
        "\nStats: mean=%.4f, std=%.4f",
        states.mean().item(),
        states.std().item(),
    )
    logger.info("Sample text: %s...", tokenizer.decode(input_ids[0])[:100])


In [None]:
# Credentials are taken from the environment (export them before starting
# Jupyter) or prompted for interactively, so secrets are never typed into the
# notebook file. (`%env VAR=#...` would set VAR to the literal placeholder text
# and override any real value already in the environment.)
for _name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_ENDPOINT_URL"):
    if not os.environ.get(_name):
        os.environ[_name] = getpass(f"{_name}: ")
os.environ.setdefault("S3_BUCKET_NAME", "main")

In [None]:
# Report which S3 settings are present without writing secret values into
# the notebook's saved output.
_SECRET_VARS = {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"}
for _name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_ENDPOINT_URL", "S3_BUCKET_NAME"):
    _value = os.environ.get(_name)
    if not _value:
        _shown = "<unset>"
    elif _name in _SECRET_VARS:
        _shown = f"<set, {len(_value)} chars>"
    else:
        _shown = _value
    print(f"{_name}={_shown}")

In [None]:

# Constants
RUN_NAME = "mistral.24b.fineweb"  # Base run name without hook
HOOK = "blocks.10.hook_resid_post"  # hook directory stored under each run
# Runs to stream from. Add e.g. "mistral.24b.lmsys" to mix datasets; all runs
# should come from the same model, since one tokenizer (from RUN_NAME's config)
# is used to decode every batch.
RUN_NAMES = [RUN_NAME]
BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main")
shuffle = True

# Tokenizer used to decode input_ids; the model name comes from the run's cfg.json.
model_name = get_model_name_from_config(RUN_NAME, BUCKET_NAME)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# One "<run>/<hook>" prefix per run. To use whatever hook a run stored first,
# build the list with get_first_hook_prefix(run, BUCKET_NAME) instead.
s3_prefix = [f"{run}/{HOOK}" for run in RUN_NAMES]

cache = S3RCache.from_credentials(
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
    s3_prefix=s3_prefix,
    bucket_name=BUCKET_NAME,
    device="cpu",
    buffer_size=2,
    return_ids=True,
    shuffle=shuffle,
)

In [None]:
# Wrap the S3 cache in an activation buffer on the GPU (batch_size=8192).
# NOTE(review): io="out" presumably selects the hook's output side — confirm
# against activault_rcache.
buffer = ActivaultS3ActivationBuffer(
    cache,
    batch_size=8192,
    device="cuda",
    io="out",
)

In [None]:
# Smoke test: pull batches from the buffer and print each shape, stopping
# after the 12th one (indices 0..11).
for i, batch in enumerate(buffer):
    print(batch.shape)
    if i >= 11:
        break


In [None]:
logger.info("\nReading first two megabatch files from S3...")
logger.info("Each file contains n_batches_per_file batches concatenated together")
logger.info("Format: [n_batches_per_file, sequence_length, hidden_dim]\n")

N_INSPECT_FILES = 2  # number of megabatch files to pull from the cache and inspect
test_batches = []

# NOTE(review): `buffer` above also wraps `cache`; if it was iterated first,
# these are presumably not the very first files of the run — confirm.
try:
    if N_INSPECT_FILES > 0:
        for batch in cache:
            test_batches.append(batch)
            inspect_batch(batch["states"], batch["input_ids"], tokenizer)
            # Stop right after the last wanted item rather than at the top of
            # the next iteration, so no extra item is pulled and discarded.
            if len(test_batches) >= N_INSPECT_FILES:
                break
finally:
    # Release the cache even if reading or inspecting a batch fails.
    cache.finalize()

In [None]:
# Re-inspect the first megabatch collected above. `inspect_batch` is already
# defined in the helpers cell near the top, so it is not redefined here.
if not test_batches:
    raise RuntimeError("No batches were read from the cache; nothing to inspect.")
first_batch = test_batches[0]
inspect_batch(first_batch["states"], first_batch["input_ids"], tokenizer)


In [None]:
# Tokenize a hand-written chat-formatted string to see how the literal "<s>",
# "[INST]" and "[/INST]" markers are encoded.
# NOTE(review): encode() may add its own BOS in addition to the literal "<s>".
example = "<s>[INST]testing prompt[/INST] testing response"
example_ids = tokenizer.encode(example)
print(example_ids)

In [None]:
# First ten token ids of row 0 of input_ids (the row inspect_batch decodes).
print(first_batch["input_ids"][0][:10])

In [None]:
# Shape of input_ids for the first megabatch.
input_ids_shape = first_batch["input_ids"].shape
print(input_ids_shape)

In [None]:
# Count occurrences of token ids 1, 3 and 4 in the first megabatch's input_ids.
# NOTE(review): presumably these are the ids of "<s>", "[INST]" and "[/INST]"
# (compare with the encode example above) — confirm with
# tokenizer.convert_ids_to_tokens([1, 3, 4]).
first_input_ids = first_batch["input_ids"]
num_ones, num_threes, num_fours = (
    (first_input_ids == token_id).sum().item() for token_id in (1, 3, 4)
)

print(f"Number of 3s: {num_threes}")
print(f"Number of 4s: {num_fours}")
print(f"Number of 1s: {num_ones}")

In [None]:
# Rebuild the activation buffer over `cache` (same settings as earlier).
# NOTE(review): in a top-to-bottom run this executes after cache.finalize();
# confirm S3RCache is still iterable after finalize(), or recreate the cache.
buffer = ActivaultS3ActivationBuffer(
    cache,
    batch_size=8192,
    device="cuda",
    io="out",
)