In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to them take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm
import spacy

import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional
from datasets import load_dataset

logger = logging.getLogger(__name__)

# Send every log record (DEBUG and above) to stdout so it is rendered inline
# in the notebook output, using the project's shared format strings.
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

# Record the library versions and the visible accelerator for reproducibility.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
if torch.cuda.is_available():
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
else:
    # torch.cuda.get_device_name() raises when no CUDA device is usable, so
    # it is only queried on the GPU branch above.
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}")
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-23 22:09:33 __main__ INFO     torch.__version__='2.4.1', torch.version.cuda='12.1'
2024-10-23 22:09:33 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-10-23 22:09:33 __main__ INFO     transformers.__version__='4.44.2'


In [3]:
import torch

from nnsight import LanguageModel
from src.models import ModelandTokenizer

# Other candidate checkpoints, left commented out; exactly one `model_name`
# assignment should be active.
# model_name = "openai-community/gpt2-xl"
# model_name = "openai-community/gpt2"
# model_name = "EleutherAI/pythia-410m"
# model_name = "google/gemma-2-2b"
# model_name = "meta-llama/Llama-3.2-1B"
model_name = "allenai/OLMo-1B-0724-hf"

# Build the project's model + tokenizer wrapper, requesting float32 weights.
mt = ModelandTokenizer(model_key=model_name, torch_dtype=torch.float32)

2024-10-23 22:09:51 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.65it/s]

2024-10-23 22:09:52 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/allenai/OLMo-1B-0724-hf> | size: 4882.004 MB | dtype: torch.float32 | device: cuda:0





In [13]:
# Identifies which trained SAE to load. Both values are used further down as
# path components when locating the SAE checkpoint and when naming the
# activation cache directory.
dataset_name = "roneneldan/TinyStories"
sae_data_checkpoint = 2000000

In [14]:
from dictionary_learning.dictionary import AutoEncoder, GatedAutoEncoder

# "<model>/<dataset>" sub-path built from the last component of each hub id
# (everything after the final "/").
model_data_dir = os.path.join(
    *(hub_id.rsplit("/", 1)[-1] for hub_id in (model_name, dataset_name))
)

# Full path to the serialized autoencoder weights for the chosen checkpoint.
# NOTE: despite its name, `sae_dir` points at a file (ae.pt), not a directory.
sae_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, "train_sae", model_data_dir, str(sae_data_checkpoint), "trainer_0/ae.pt",
)

# Load the gated SAE onto the model's device, then cast it to the model's
# dtype so the two can be used together.
sae = GatedAutoEncoder.from_pretrained(path=sae_dir, device=mt.device)
sae = sae.to(mt.dtype)
sae

  state_dict = t.load(path)


GatedAutoEncoder(
  (encoder): Linear(in_features=2048, out_features=16384, bias=False)
  (decoder): Linear(in_features=16384, out_features=2048, bias=False)
)

In [15]:
import numpy as np
from src.utils import experiment_utils
# Fix the random seeds up front so the rest of the notebook is reproducible.
experiment_utils.set_seed(123456)

# Evaluation corpus (distinct from the dataset the SAE was trained on).
# eval_dataset_name = "mickume/harry_potter_tiny"
eval_dataset_name = "jahjinx/IMDb_movie_reviews"

# Fetch the dataset by name and preview the first five "text" entries of the
# train split.
eval_dataset = load_dataset(eval_dataset_name)
eval_dataset["train"][:5]["text"]

2024-10-23 22:15:55 src.utils.experiment_utils INFO     setting all seeds to 123456
2024-10-23 22:15:55 urllib3.connectionpool DEBUG    Resetting dropped connection: huggingface.co


2024-10-23 22:15:56 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/jahjinx/IMDb_movie_reviews HTTP/11" 200 1693
2024-10-23 22:15:56 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2024-10-23 22:15:56 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/jahjinx/IMDb_movie_reviews/jahjinx/IMDb_movie_reviews.py HTTP/11" 404 0
2024-10-23 22:15:56 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/jahjinx/IMDb_movie_reviews HTTP/11" 200 1693
2024-10-23 22:15:56 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2024-10-23 22:15:56 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/jahjinx/IMDb_movie_reviews/resolve/ef30f6a046230c843d79822b928267efd9453d5b/README.md HTTP/11" 200 0
2024-10-23 22:15:56 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:

['Beautifully photographed and ably acted, generally, but the writing is very slipshod. There are scenes of such unbelievability that there is no joy in the watching. The fact that the young lover has a twin brother, for instance, is so contrived that I groaned out loud. And the "emotion-light bulb connection" seems gimmicky, too.<br /><br />I don\'t know, though. If you have a few glasses of wine and feel like relaxing with something pretty to look at with a few flaccid comedic scenes, this is a pretty good movie. No major effort on the part of the viewer required. But Italian film, especially Italian comedy, is usually much, much better than this.',
 'Well, where to start describing this celluloid debacle? You already know the big fat NADA passing as a plot, so let\'s jut point out that this is so PC it\'s offensive. Hard to believe that Frank Oz, the same guy that gave us laugh riots like Little Shop of Horrors and Bowfinger, made this unfunny mess.<br /><br />So, this guy doesn\'t 

In [16]:
relu = torch.nn.ReLU()  # not used in this cell

# Cache layout:
#   <results>/cache_sae_mixtures/<eval dataset>/<model>/<SAE train dataset>/<checkpoint>/
cache_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "cache_sae_mixtures",
    eval_dataset_name.split("/")[-1],
    model_data_dir,
    str(sae_data_checkpoint),
)

os.makedirs(cache_dir, exist_ok=True)

from src.models import prepare_input
from src.functional import get_module_nnsight, free_gpu_cache

limit = 100           # number of documents (from the start of the train split) to cache
context_limit = 1024  # maximum number of tokens kept per document

# The SAE is applied to the output of the layer at index n_layer // 2.
sae_layer_name = mt.layer_name_format.format(mt.n_layer // 2)

# Wrap the *list* in tqdm (rather than the enumerate iterator) so the bar
# knows the total and can display progress and an ETA.
for doc_index, doc in enumerate(
    tqdm(eval_dataset["train"][:limit]["text"], desc="caching SAE mixtures")
):
    inputs = prepare_input(
        prompts = doc,
        tokenizer = mt
    )
    # Truncate over-long documents to the first `context_limit` tokens.
    if inputs["input_ids"].shape[1] > context_limit:
        inputs["input_ids"] = inputs["input_ids"][:, :context_limit]
        inputs["attention_mask"] = inputs["attention_mask"][:, :context_limit]

    # Nothing below is differentiated (results are detached and written to
    # disk), so disable autograd to avoid building a graph for each document.
    with torch.no_grad():
        with mt.trace(inputs, scan = False, validate = False) as trace:
            module = get_module_nnsight(mt, sae_layer_name)
            # Element 0 of the layer output is taken as the SAE input.
            sae_input = module.output[0].save()

        sae_mixture = sae.encode(sae_input)

    cache = {
        "layer": sae_layer_name,
        "doc": doc,
        "sae_input": sae_input.detach().cpu().numpy().astype(np.float32),
        "sae_mixture": sae_mixture.detach().cpu().numpy().astype(np.float32),
    }

    # np.savez_compressed appends ".npz", so files are named "<doc_index>.npz".
    cache_path = os.path.join(cache_dir, f"{doc_index}")
    np.savez_compressed(cache_path, **cache)

    free_gpu_cache()

100it [00:34,  2.86it/s]


In [17]:
import numpy as np
import torch
# Inspect one cached document (index 39). The path is derived from
# `cache_dir`, defined in the caching cell, instead of a hard-coded absolute
# path, so the notebook works on any machine or results directory.
sae_path = os.path.join(cache_dir, "39.npz")

# The archive handle is kept open because later cells read from `file`.
file = np.load(sae_path)
file["sae_mixture"].shape, file["sae_input"].shape

((1, 162, 16384), (1, 162, 2048))

In [18]:
from torch.nn import ReLU
# Apply a ReLU to the cached SAE mixture and display the resulting tensor.
relu = ReLU()
relu(torch.Tensor(file["sae_mixture"]))

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [23]:
# Drop only the leading batch axis (size 1 in the cached arrays). A bare
# .squeeze() would also collapse the token axis for a one-token document.
t = torch.Tensor(file["sae_mixture"]).squeeze(0)
t.shape
# t.mean(dim = 0).shape

torch.Size([162, 16384])

In [28]:
# Norm of the mixture tensor with torch's default arguments (a single scalar
# computed over all elements).
t.norm()

tensor(1501.8020)

In [29]:
file.keys()  # List the array names stored in the loaded .npz archive

KeysView(NpzFile '/home/local_arnab/Codes/Projects/sae/results/cache_sae_mixtures/IMDb_movie_reviews/OLMo-1B-0724-hf/TinyStories/2000000/39.npz' with keys: layer, doc, sae_input, sae_mixture)