In [1]:
from datasets import load_dataset
import random
import numpy as np
import torch

def set_seed(SEED):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!huggingface-cli login --token "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"

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/cychang/.cache/huggingface/token
Login successful


In [3]:
from transformers import AutoModel, AutoTokenizer, pipeline, BitsAndBytesConfig

################################################################################
# bitsandbytes parameters
################################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# The model that you want to train from the Hugging Face hub
model_name = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModel.from_pretrained(model_name, device_map='auto', quantization_config=bnb_config)

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, device_map='auto')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

Loading checkpoint shards: 100%|██████████| 2/2 [01:55<00:00, 57.74s/it]


In [4]:
import os
TRAIN_ROOT = "/tmp2/cychang/tmp"
images = [image.rstrip('.jpg') for image in os.listdir(TRAIN_ROOT)]
len(images)

471436

In [11]:
from tqdm import tqdm, tnrange
pipe = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device_map='auto')
llama2_embedding = {}
for e, caption in enumerate(tqdm(images[:210000])):
    result = pipe(f"<s>[INST] {caption} [/INST]", return_tensors=True)
    llama2_embedding[caption] = result.cpu().detach().numpy()[0][0]

  0%|          | 101/210000 [00:28<16:37:56,  3.51it/s]

{'4.4 cu ft all-refrigerator, glass shelves, vegetable crisper, black': array([ 0.34350586, -0.07244873, -0.02818298, ...,  0.10522461,
       -0.10870361,  0.30541992], dtype=float32), 'Will your Cabinet or <PERSON> fit - Find the right size storage piece Interior Design Basics, Home Furnishings, Locker Storage, <PERSON>, Hardware, Cabinet, Furniture, Home Decor, Jelly Cupboard': array([ 0.34350586, -0.07244873, -0.02818298, ...,  0.10522461,
       -0.10870361,  0.30541992], dtype=float32), "PARIS, FRANCE - SEPTEMBER 08: Kids show the screen of their smartphone with Nintendo Co.'s Pokemon Go augmented-reality game at the Trocadero in front of the Eiffel tower on September 8, 2016 in Paris, France": array([ 0.34350586, -0.07244873, -0.02818298, ...,  0.10522461,
       -0.10870361,  0.30541992], dtype=float32), 'Continuous one line drawing of women cook in a pastry shop. Cuts flour products in modern minimalistic style, vector illustration stock illustration': array([ 0.34350586, -0.0




In [12]:
import pickle
llama2_embed_file = '/tmp2/cychang/llama2_embedding.pkl'
with open(llama2_embed_file, 'wb') as fp:
    pickle.dump(llama2_embedding, fp)
    print('dictionary saved successfully to file')

dictionary saved successfully to file
