In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import time, json
import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm
import spacy

import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional
from datasets import load_dataset

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-08-21 18:01:18 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-08-21 18:01:18 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-08-21 18:01:18 __main__ INFO     transformers.__version__='4.43.3'


In [4]:
import torch

from nnsight import LanguageModel
from src.models import ModelandTokenizer

# model_name = "openai-community/gpt2-xl"
# model_name = "openai-community/gpt2"
model_name = "EleutherAI/pythia-410m"

mt = ModelandTokenizer(
    model_key=model_name,
    torch_dtype=torch.float32,
)

2024-08-21 18:01:19 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2024-08-21 18:01:20 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/EleutherAI/pythia-410m> | size: 1648.227 MB | dtype: torch.float32 | device: cuda:0


In [5]:
dataset_name = "roneneldan/TinyStories"

In [5]:
from dictionary_learning.dictionary import AutoEncoder

model_data_dir = os.path.join(
    model_name.split("/")[-1],
    dataset_name.split("/")[-1],
)

sae_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "trained_saes",
    model_data_dir,
    "trainer_0/ae.pt"
)

sae = AutoEncoder.from_pretrained(
    path = sae_dir,
    device=mt.device
).to(mt.dtype)
sae

AutoEncoder(
  (encoder): Linear(in_features=1024, out_features=4096, bias=True)
  (decoder): Linear(in_features=4096, out_features=1024, bias=False)
)

In [6]:
import numpy as np
from src.utils import experiment_utils
experiment_utils.set_seed(123456)


dataset = load_dataset("mickume/harry_potter_tiny")
dataset["train"][:5]["text"]

2024-08-21 17:00:34 src.utils.experiment_utils INFO     setting all seeds to 123456
2024-08-21 17:00:34 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2024-08-21 17:00:34 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/mickume/harry_potter_tiny HTTP/11" 200 975
2024-08-21 17:00:34 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2024-08-21 17:00:35 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/mickume/harry_potter_tiny/mickume/harry_potter_tiny.py HTTP/11" 404 0
2024-08-21 17:00:35 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/mickume/harry_potter_tiny HTTP/11" 200 975
2024-08-21 17:00:35 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2024-08-21 17:00:35 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/mickume/harry_potte

['"RUN!" Harry yelled, grabbing at her robes. Hermione’s feet hit the hard ground, running in tandem to her heartbeat as the prophecies shattered around them. With each dropped prophecy, the voices began whispering into the dark, speaking over one-another in a cacophony as Lucius Malfoy yelled for pursuit.\xa0',
 'A death eater appeared beside them and Hermione screamed, watching as Harry elbowed him in the face. She could hear more yelling, screaming against the rising cacophony as shelves began shuddering from the impact of their spells, knocking into one-another.\xa0',
 'Looking over shoulder, she glanced as another Death Eater appeared, reaching out to grab at Harry’s shoulder as the wand pulled back, a curse almost spoken––"Stupefy!" She yelled and watched as the Death Eater froze, his wand’s light dying like a candle blown out.',
 'They ran on, Neville wheezing for breath beside her. "Come on," she said, grabbing him. "Come on, we have to––" Neville threw a stupefy behind her and

In [21]:
relu = torch.nn.ReLU()

cache_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "sae_mixtures",
    model_data_dir,
)

os.makedirs(cache_dir, exist_ok=True)

from src.models import prepare_input
from src.functional import get_module_nnsight, free_gpu_cache

limit = 100
context_limit = 1024

sae_layer_name = mt.layer_name_format.format(mt.n_layer // 2)

for doc_index, doc in tqdm(enumerate(dataset["train"][:limit]["text"])):
    inputs = prepare_input(
        prompts = doc,
        tokenizer = mt
    )
    if inputs["input_ids"].shape[1] > context_limit:
        inputs["input_ids"] = inputs["input_ids"][:, :context_limit]
        inputs["attention_mask"] = inputs["attention_mask"][:, :context_limit]

    # print(f"{doc=}")
    # logger.info(inputs["input_ids"].shape)

    with mt.trace(inputs, scan = False, validate = False) as trace:
        module = get_module_nnsight(mt, sae_layer_name)
        sae_input = module.output[0].save()
    
    sae_mixture = relu(sae.encoder(sae_input))
    # logger.info(f"{sae_input.shape=} | {sae_mixture.shape=}")

    cache = {
        "layer": sae_layer_name,
        "doc": doc,
        "sae_input": sae_input.detach().cpu().numpy().astype(np.float32),
        "sae_mixture": sae_mixture.detach().cpu().numpy().astype(np.float32),
    }

    cache_path = os.path.join(cache_dir, f"{doc_index}")
    np.savez_compressed(cache_path, **cache)

    free_gpu_cache()

100it [00:16,  5.99it/s]


In [9]:
import numpy as np
sae_path = "/home/local_arnab/Codes/sae/results/sae_mixtures/pythia-410m/TinyStories/4.npz"

file = np.load(sae_path)
file["sae_mixture"].shape, file["sae_input"].shape

((1, 52, 4096), (1, 52, 1024))

In [1]:
from torch.nn import ReLU
relu = ReLU()
relu(torch.Tensor(file["sae_mixture"]))

NameError: name 'torch' is not defined