In [14]:
import logging
import os
import pickle
import sys
from contextlib import nullcontext
import dataclasses

import numpy as np
from tqdm import tqdm

import torch
import transformers
from transformers import AutoConfig, AutoModel

from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import (
    HfArgumentParser,
)

from tevatron.arguments import ModelArguments, DataArguments, \
    TevatronTrainingArguments as TrainingArguments
from data import HFQueryDataset, HFCorpusDataset

from repllama import RepLLaMA
from data import EncodeDataset, EncodeCollator
from utils import replace_with_xformers_attention

#pd.set_option('display.max_columns', 70)
#pd.set_option('display.max_rows', 120)

logger = logging.getLogger(__name__)
%reload_ext autoreload
%autoreload 2

In [2]:
torch.cuda.is_available()

True

In [4]:
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    #model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    model_args, data_args, training_args = parser.parse_json_file(json_file='./train_params.json')
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

if training_args.local_rank > 0 or training_args.n_gpu > 1:
    raise NotImplementedError('Multi-GPU encoding is not supported.')

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)

tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir, token='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'
    )
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"

In [13]:
dataclasses.asdict(data_args)

{'train_dir': None,
 'dataset_name': 'Tevatron/msmarco-passage',
 'passage_field_separator': ' ',
 'dataset_proc_num': 32,
 'train_n_passages': 16,
 'positive_passage_no_shuffle': False,
 'negative_passage_no_shuffle': False,
 'encode_in_path': None,
 'encoded_save_path': None,
 'encode_is_qry': False,
 'encode_num_shard': 1,
 'encode_shard_index': 0,
 'q_max_len': 32,
 'p_max_len': 196,
 'data_cache_dir': None}

In [7]:
text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
if data_args.encode_is_qry:
    encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args,
                                    cache_dir=data_args.data_cache_dir or model_args.cache_dir)
else:
    encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
                                        cache_dir=data_args.data_cache_dir or model_args.cache_dir)

Using the latest cached version of the module from /home/azureuser/.cache/huggingface/modules/datasets_modules/datasets/Tevatron--beir-corpus/02e1318cd9412cdf85d3f039bf36bec0af49ddeeab2279d4cf19fe556af6f29a (last modified on Wed Mar 13 12:48:30 2024) since it couldn't be found locally at Tevatron/beir-corpus, or remotely on the Hugging Face Hub.


In [8]:
encode_dataset.dataset[0]

{'docid': '4983',
 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.',
 'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule,

In [9]:
encode_dataset = EncodeDataset(encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
                                   tokenizer, max_len=text_max_length)

In [13]:
encode_dataset[0]

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


('4983',
 {'input_ids': [1, 13382, 29901, 20140, 4984, 3631, 5849, 310, 5199, 716, 4939, 274, 406, 1182, 284, 4796, 4383, 1223, 11517, 297, 325, 4243, 491, 23253, 12489, 15611, 27396, 749, 6382, 292, 29889, 20561, 800, 310, 278, 11258, 310, 274, 406, 1182, 284, 4796, 4383, 297, 278, 14338, 5199, 17294, 508, 6602, 13979, 936, 5849, 322, 1121, 297, 13303, 766, 11614, 29889, 319, 1196, 12812, 23253, 29899, 7915, 287, 15611, 27396, 749, 6382, 292, 313, 29924, 3960, 29897, 5665, 411, 23253, 12489, 7418, 471, 7436, 304, 5645, 278, 20295, 23253, 10825, 29892, 304, 8147, 6198, 385, 275, 327, 14441, 29892, 322, 304, 628, 457, 403, 2211, 29899, 12531, 5713, 495, 11258, 297, 274, 406, 1182, 284, 4796, 4383, 297, 758, 8489, 313, 29876, 353, 29871, 29896, 29955, 29897, 322, 2989, 29899, 8489, 3041, 1934, 313, 29876, 353, 29871, 29955, 467, 1763, 24809, 9545, 310, 5188, 1337, 537, 373, 274, 406, 1182, 284, 4796, 4383, 5849, 29892, 4688, 7737, 362, 758, 8489, 3041, 1934, 313, 29876, 353, 29871, 29896

In [16]:
len(encode_dataset[0][1]['input_ids'])

511

In [17]:
encode_loader = DataLoader(
        encode_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=EncodeCollator(
            tokenizer,
            max_length=text_max_length,
            padding='max_length'
        ),
        shuffle=False,
        drop_last=False,
        num_workers=training_args.dataloader_num_workers,
    )

In [18]:
model = RepLLaMA.load(
        model_name_or_path=model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

Loading checkpoint shards: 100%|██████████| 2/2 [01:32<00:00, 46.05s/it]


In [19]:
encoded = []
lookup_indices = []
model = model.to(training_args.device)
model.eval()

RepLLaMA(
  (lm_q): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_p): LlamaModel(

In [20]:
for (batch_ids, batch) in tqdm(encode_loader):
        lookup_indices.extend(batch_ids)
        with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext():
            with torch.no_grad():
                for k, v in batch.items():
                    batch[k] = v.to(training_args.device)
                if data_args.encode_is_qry:
                    model_output = model(query=batch)
                    encoded.append(model_output.q_reps.cpu().detach().numpy())
                else:
                    model_output = model(passage=batch)
                    encoded.append(model_output.p_reps.cpu().detach().numpy())

encoded = np.concatenate(encoded)

with open(data_args.encoded_save_path, 'wb') as f:
    pickle.dump((encoded, lookup_indices), f)

100%|██████████| 324/324 [33:31<00:00,  6.21s/it]


In [24]:
model.config

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-2-7b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 32000
}

In [34]:
dir(model.lm_p.layers[0].self_attn.q_proj)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__constants__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_is_hf_initialized',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_nam

In [40]:
model.lm_p.layers[0].self_attn.q_proj.state_dict()['weight'].dtype

torch.float32