In [20]:
import json
import os
from glob import glob

import torch
import numpy as np
from safetensors.torch import save_file, load_file

from transformers import AutoTokenizer, AutoModel

In [21]:
model_name = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, load_in_8bit=True, device_map="auto")

In [22]:
shards = glob("../../../data/text-laion-20M-images-with-text-train/*")

In [23]:
shard = shards[0]

In [24]:
batch_size = 8
max_length - 128
extract_layer = -1
if extract_layer < 0:
    _extract_layer = model.config.num_hidden_layers + extract_layer

batch = []
file_suffix = os.path.basename(model_name) + "_layer" + str(_extract_layer) + ".safetensors"
save_dir = "features/"
os.makedirs(save_dir, exist_ok=True)

for meta_path in glob(os.path.join(shard, "*.json")):
    id_ = os.path.basename(meta_path).rstrip(".json")
    meta = json.load(open(meta_path))

    caption = meta["caption"]
    batch.append({"id": id_, "caption": caption})

    if len(batch) == batch_size:
        captions_batch = [x["caption"] for x in batch]
        inputs = tokenizer(captions_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to("cuda")
        with torch.no_grad():
            output = model(**inputs, output_hidden_states=True)
            hidden_states = output.hidden_states[extract_layer].cpu()
            for i, item in enumerate(batch):
                item_id = item["id"]
                file_path = os.path.join(save_dir, f"{item_id}_{file_suffix}")
                save_file({"hidden_states": hidden_states[i]}, file_path)
        batch = []
        break

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [19]:
load_file(file_path)

{'hidden_states': tensor([[-1.6025,  0.0000, -0.8257,  ..., -1.3506,  0.9683, -1.4590],
         [ 1.9883,  1.7012, -3.8281,  ..., -0.1750, -0.2998, -0.0864],
         [-1.3213,  0.8457,  1.4121,  ..., -1.4688, -4.5781, -0.7188],
         ...,
         [-1.6465,  6.1914,  1.8418,  ...,  2.0059, -0.7075,  0.1279],
         [-1.6465,  6.1914,  1.8418,  ...,  2.0059, -0.7075,  0.1279],
         [-1.6465,  6.1914,  1.8418,  ...,  2.0059, -0.7075,  0.1279]],
        dtype=torch.float16)}