In [None]:
import os
# Restrict this process to physical GPU 1. This cell runs before the cell that
# imports torch so the setting is in place before CUDA is initialised; the single
# visible GPU is then addressed as "cuda:0" later in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import h5py
import json
import torch

from tqdm import tqdm

from datasets import load_dataset 
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# --- One-time dataset preparation (kept commented out for provenance) ---------
# The block below shuffles the Tulu-3 SFT mixture, keeps only single-turn
# conversations, holds out 100,000 rows as a test split and writes both splits
# to the JSON files that are loaded further down in this cell.
#
# dataset = load_dataset("allenai/tulu-3-sft-mixture")

# def filter_single_turn(example):
#     # One user turn and one assistant turn
#     return len(example["messages"]) == 2

# dataset = dataset.shuffle(seed=42)

# print("Filtering for single turn data only")
# prev_len = len(dataset["train"])
# dataset = dataset.filter(filter_single_turn)
# print(f"Filtered out {prev_len - len(dataset['train'])} rows. Now dataset is length {len(dataset['train'])}")

# split_dataset = dataset["train"].train_test_split(test_size=100_000, seed=42)
# train_dataset = split_dataset["train"]
# test_dataset = split_dataset["test"]

# train_dataset.to_json("data/tulu3_sft_train.json")
# test_dataset.to_json("data/tulu3_sft_test.json")

# --- Load the prepared splits -------------------------------------------------
# Note: the rest of the notebook indexes these objects with ["train"]
# (e.g. train_dataset["train"][i]) to reach the actual rows.
train_dataset = load_dataset(
    "json",
    data_files="data/tulu3_sft_train.json",
)
test_dataset = load_dataset(
    "json",
    data_files="data/tulu3_sft_test.json",
)

# Heading and summary on separate lines, exactly as two consecutive print() calls would emit.
print("Train split:", train_dataset, sep="\n")
print("\nTest split:", test_dataset, sep="\n")

In [None]:
# Display the first training record as a quick look at its fields
# (generate_data below reads the "messages" field of each record).
train_dataset["train"][0]

In [None]:
# Hugging Face repo of the model whose hidden states are extracted.
# Alternative that was tried: "microsoft/Phi-3.5-mini-instruct"
# model_repo = "microsoft/Phi-3.5-mini-instruct"
model_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(
    model_repo,
    use_fast=False,
)

# Load the weights in bfloat16 on the first visible GPU and switch to eval mode.
# .eval() returns the module itself, so chaining keeps `model` bound to the same
# object while producing no cell output.
model = AutoModelForCausalLM.from_pretrained(
    model_repo,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
).eval()

In [None]:
def generate_data(start, end):
    """Run the model over training rows ``[start, end)`` and save hidden states to HDF5.

    For every index in ``range(start, end)`` the row's ``"messages"`` are rendered
    with the tokenizer's chat template and fed through ``model`` in one forward
    pass (no gradients). Rows whose templated length exceeds 2,048 tokens are
    skipped with a printed notice.

    Output file: ``data/train_dataset_w_hidden_states_{start}-{end}.h5`` with one
    group per *kept* row, named ``"0"``, ``"1"``, ... in processing order (skipped
    rows do not consume a number). Each group contains:

    * ``input_ids``     -- the templated token ids, stored as int32
    * ``hidden_states`` -- the last entry of ``output.hidden_states`` for that
      row, stored as float32

    Each row is written to disk as soon as it is computed instead of being
    accumulated as nested Python lists, so host memory stays bounded by a single
    sample. The file is written under a temporary name and moved into place at
    the end, so the final path only appears once the whole range has succeeded.

    Args:
        start: First dataset index to process (inclusive).
        end: Last dataset index to process (exclusive).

    Relies on the notebook globals ``train_dataset``, ``tokenizer`` and ``model``.
    """
    max_context_len = 2_048
    out_path = f"data/train_dataset_w_hidden_states_{start}-{end}.h5"
    tmp_path = out_path + ".tmp"
    samples = train_dataset["train"]

    with h5py.File(tmp_path, "w") as f:
        n_written = 0  # running group index; only advanced for rows that are kept
        for i in tqdm(range(start, end)):
            messages = samples[i]["messages"]

            input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
            if input_ids.shape[1] > max_context_len:
                print(f"Skipping sample with context len {input_ids.shape[1]}")
                continue

            with torch.no_grad():
                # Send the ids to wherever the model lives instead of hard-coding a device.
                output = model(input_ids.to(model.device), output_hidden_states=True)

            # Index 0 drops the leading batch dimension (one conversation per call).
            # The model runs in bfloat16, which numpy cannot represent, so cast to
            # float32 before leaving torch; the cast is exact.
            hidden_state = (
                output.hidden_states[-1][0]
                .to(device="cpu", dtype=torch.float32)
                .numpy()
            )

            grp = f.create_group(str(n_written))
            grp.create_dataset("input_ids", data=input_ids.cpu()[0].numpy(), dtype="int32")
            grp.create_dataset("hidden_states", data=hidden_state, dtype="float32")
            n_written += 1

    # Publish the finished file under its final name in one step.
    os.replace(tmp_path, out_path)

# start, end = 0, 1_000
# generate_data(start, end)

In [None]:
# Process the training set in fixed-size chunks; each chunk becomes its own HDF5 file.
CHUNK_SIZE = 1_000
big_start, big_end = 0, 20_000

# Never index past the end of the dataset, even if big_end is set too high.
last_index = min(big_end, len(train_dataset["train"]))

for i in range(big_start, last_index, CHUNK_SIZE):
    # Clamp the final chunk so a range that is not a multiple of CHUNK_SIZE
    # does not run past the requested end.
    small_start, small_end = i, min(i + CHUNK_SIZE, last_index)
    print(f"Generating data {small_start}-{small_end}")
    generate_data(small_start, small_end)