In [1]:
import os
os.chdir("/traindata/maksim/repos/unilm/simlm/src")
!pwd

/traindata/maksim/repos/unilm/simlm/src


In [2]:
import logging

import torch
from typing import Dict
from functools import partial
from transformers.utils.logging import enable_explicit_format
from transformers.trainer_callback import PrinterCallback
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    EvalPrediction,
    Trainer,
    set_seed,
    PreTrainedTokenizerFast
)

from logger_config import logger, LoggerCallback
from config import Arguments
from trainers import BiencoderTrainer
from loaders import RetrievalDataLoader
from collators import BiencoderCollator
from metrics import accuracy, batch_mrr
from models import BiencoderModel

def _common_setup(args: Arguments):
    """Apply process-wide logging and seeding settings for a run.

    Processes whose ``args.process_index`` is greater than zero have the
    project logger raised to WARNING. Every process then enables the
    transformers explicit log format and seeds the RNGs with ``args.seed``.
    """
    is_first_process = args.process_index <= 0
    if not is_first_process:
        logger.setLevel(logging.WARNING)

    enable_explicit_format()
    set_seed(args.seed)


def _compute_metrics(args: Arguments, eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute MRR and top-1 / top-3 accuracy from evaluation predictions.

    Args:
        args: run arguments; only ``train_n_passages`` is read here.
        eval_pred: evaluation predictions; the last entry of
            ``eval_pred.predictions`` is used as the score matrix (the field
            order is expected to be consistent with ``BiencoderOutput``).

    Returns:
        A dict with the keys ``'mrr'``, ``'acc1'`` and ``'acc3'``.
    """
    score_matrix = torch.tensor(eval_pred.predictions[-1]).float()

    # The target column for row i is i * train_n_passages, wrapped around the
    # number of score columns.
    row_ids = torch.arange(0, score_matrix.shape[0], dtype=torch.long)
    target_columns = (row_ids * args.train_n_passages) % score_matrix.shape[1]

    acc_at_k = accuracy(output=score_matrix, target=target_columns, topk=(1, 3))
    mean_reciprocal_rank = batch_mrr(output=score_matrix, target=target_columns)

    return {
        'mrr': mean_reciprocal_rank,
        'acc1': acc_at_k[0],
        'acc3': acc_at_k[1],
    }

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Environment for this session.
# NOTE(review): DATA_DIR / OUTPUT_DIR are relative and therefore resolve
# against the current working directory, while the command line built below
# passes absolute paths for --data_dir / --output_dir - confirm which of the
# two is actually consumed.
os.environ["DATA_DIR"] = "./data/msmarco_bm25_official/"
os.environ["OUTPUT_DIR"] = "./tmp/"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# The tokenizer is used in this process before the DataLoader forks its
# workers, which makes `tokenizers` print a "process just got forked" warning
# and disable parallelism on its own. Configure it explicitly instead;
# setdefault keeps any value the user exported before starting the kernel.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


In [4]:
import sys

# Emulate the command line of `src/train_biencoder.py` so that
# HfArgumentParser can be driven from inside the notebook.
# One flag (plus its value, if any) per line.
sys.argv = [
    'src/train_biencoder.py',
    '--deepspeed', '/traindata/maksim/repos/unilm/simlm/ds_config.json',
    '--model_name_or_path', 'intfloat/simlm-base-msmarco',
    '--per_device_train_batch_size', '16',
    '--per_device_eval_batch_size', '32',
    '--add_pooler', 'False',
    '--t', '0.02',
    '--seed', '1234',
    '--do_train',
    '--fp16',
    '--train_file', '/traindata/maksim/repos/unilm/simlm/data/msmarco_bm25_official/train.jsonl',
    '--validation_file', '/traindata/maksim/repos/unilm/simlm/data/msmarco_bm25_official/dev.jsonl',
    '--q_max_len', '32',
    '--p_max_len', '144',
    '--train_n_passages', '16',
    '--dataloader_num_workers', '1',
    '--num_train_epochs', '3',
    '--learning_rate', '2e-5',
    '--use_scaled_loss', 'True',
    '--warmup_steps', '1000',
    '--share_encoder', 'True',
    '--logging_steps', '50',
    '--output_dir', '/traindata/maksim/repos/unilm/simlm/tmp/',
    '--data_dir', '/traindata/maksim/repos/unilm/simlm/data/msmarco_bm25_official/',
    '--save_total_limit', '2',
    '--save_strategy', 'epoch',
    '--evaluation_strategy', 'epoch',
    '--remove_unused_columns', 'False',
    '--overwrite_output_dir',
    '--disable_tqdm', 'True',
    '--report_to', 'none',
]

# Parse the emulated command line into the project's Arguments dataclass.
parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
#args.local_rank = -1 # disable dist training, for debugging!
_common_setup(args)
args



In [5]:
# Tokenizer and bi-encoder are both derived from the parsed arguments.
tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
    args.model_name_or_path
)
model: BiencoderModel = BiencoderModel.build(args=args)

logger.info(model)
logger.info('Vocab size: {}'.format(len(tokenizer)))

# Pad to a multiple of 8 only when fp16 is enabled; otherwise no constraint.
data_collator = BiencoderCollator(
    pad_to_multiple_of=(8 if args.fp16 else None),
    tokenizer=tokenizer,
)

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[2024-11-11 12:05:43,291 INFO] BiencoderModel(
  (lm_q): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
    

In [6]:
# Datasets come from the project's retrieval data loader.
retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
train_dataset, eval_dataset = (
    retrieval_data_loader.train_dataset,
    retrieval_data_loader.eval_dataset,
)

# A split is handed to the trainer only when the matching do_* flag is set.
trainer: Trainer = BiencoderTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=partial(_compute_metrics, args),
    train_dataset=(train_dataset if args.do_train else None),
    eval_dataset=(eval_dataset if args.do_eval else None),
)

# Swap the default printer callback for the project's logger callback.
trainer.remove_callback(PrinterCallback)
trainer.add_callback(LoggerCallback)

# Give the data loader and the model a back-reference to the trainer.
retrieval_data_loader.trainer = trainer
model.trainer = trainer

[2024-11-11 12:05:43,542 INFO] Sample 231070 of the training set: {'query_id': '1179605', 'query': 'average cost of fence homeadvisor', 'positives': {'doc_id': ['4435119'], 'score': [-1.0]}, 'negatives': {'doc_id': ['905674', '6466825', '5013998', '6008466', '7792504', '6381090', '5323265', '6202403', '8409672', '3529117', '1786957', '6098973', '5041970', '2113635', '7383823', '2046830', '6953357', '3318649', '5570938', '5890165', '2349672', '2975852', '6831044', '5653389', '6242690', '5014719', '571773', '1729765', '3472300', '5537537', '395723', '4322414', '8247821', '5471572', '6509108', '8311952', '3860934', '757714', '1097619', '6491138', '5182081', '1493425', '7057757', '1326609', '4323128', '4479581', '1347526', '6615847', '5314570', '757716', '6439609', '5841024', '6014361', '6471506', '3756582', '1851271', '4106998', '5570467', '5493263', '4257654', '7938907', '3613226', '2104322', '7749241', '1697703', '4581371', '1889331', '2559842', '7990737', '4094925', '3809462', '6913103

In [7]:
# Take the first training example and list the fields it carries.
# In the recorded run these are query ("q_*") and document ("d_*") encodings.
example = train_dataset[0]
list(example.keys())

['q_input_ids',
 'q_token_type_ids',
 'q_attention_mask',
 'd_input_ids',
 'd_token_type_ids',
 'd_attention_mask']

In [8]:
# Length of the example's query token-id list.
len(example["q_input_ids"])

16

In [9]:
# Raw token ids of the query (decoded back to text in a later cell).
example["q_input_ids"]

[101,
 1007,
 2054,
 2001,
 1996,
 6234,
 4254,
 1997,
 1996,
 3112,
 1997,
 1996,
 7128,
 2622,
 1029,
 102]

In [10]:
# Number of document encodings attached to the example
# (16 in the recorded run, the same value as --train_n_passages).
len(example["d_input_ids"])

16

In [11]:
# Token count of each document encoding; the lengths differ, so no padding has
# been applied at this point.
[len(d) for d in example["d_input_ids"]]

[57, 59, 47, 98, 44, 79, 70, 89, 71, 63, 72, 76, 46, 57, 55, 32]

In [12]:
# Decode the query token ids back to text for a sanity check.
# `AutoTokenizer` is already imported and `tokenizer` was already loaded from
# args.model_name_or_path (the same "intfloat/simlm-base-msmarco" checkpoint)
# in an earlier cell, so that instance is reused here instead of re-importing
# and rebinding `tokenizer` to a second copy.
tokenizer.decode(example["q_input_ids"])



'[CLS] ) what was the immediate impact of the success of the manhattan project? [SEP]'

In [13]:
# Decode one of the document encodings (index 4) to inspect its text; in the
# recorded output it contains two [SEP]-terminated segments.
tokenizer.decode(example["d_input_ids"][4])

'[CLS] history of the twin towers [SEP] downtown lower manhattan association is created by real estate developer david rockefeller to revitalize lower manhattan and begins to promote the idea of a world trade and finance center in new york city. [SEP]'

In [14]:
from torch.utils.data import DataLoader

# Build a training DataLoader by hand from the trainer's own settings so that
# a collated batch can be inspected outside of trainer.train().
dataloader_params = dict(
    batch_size=trainer._train_batch_size,
    collate_fn=data_collator,
    num_workers=trainer.args.dataloader_num_workers,
    pin_memory=trainer.args.dataloader_pin_memory,
    persistent_workers=trainer.args.dataloader_persistent_workers,
)

# These options are only added when the dataset is not an IterableDataset.
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
    dataloader_params.update(
        sampler=trainer._get_train_sampler(),
        drop_last=trainer.args.dataloader_drop_last,
        prefetch_factor=trainer.args.dataloader_prefetch_factor,
    )

train_dataloader = DataLoader(train_dataset, **dataloader_params)

In [15]:
# Pull exactly one collated batch from the loader.
# next(iter(...)) states the intent directly and raises StopIteration on an
# empty loader, instead of silently leaving `batch` undefined or stale from a
# previous run of this cell.
batch = next(iter(train_dataloader))
batch

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


{'q_input_ids': tensor([[  101,  4248, 17470,  4013,  2490,  2193,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  2003,  1996,  2783,  2051, 22851,  2102,  2030,  8827,  2102,
           102,     0,     0,     0,     0,     0],
        [  101,  2054,  2515,  2632,  8569, 27833,  2812,   102,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  2043,  2064,  2017,  4929,  2317,  6471,   102,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  6207, 11306,  1005,  1055,  5592,  4003,  5703,  3042,  2193,
           102,     0,     0,     0,     0,     0],
        [  101,  2054,  2064,  1045,  6570,  6001,  2007, 13908, 13597,   102,
             0,     0,     0,     0,     0,     0],
        [  101,  2079, 21122, 15580,  9880,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  2054,  2217,  2003,  2115, 22524,  2006,  2005,  2

In [16]:
# Shape of the collated query ids; torch.Size([16, 16]) in the recorded run.
batch["q_input_ids"].shape

torch.Size([16, 16])

In [17]:
# _unpack_qp is a private helper of the project's trainer module; it is used
# here to split the collated batch into a query dict and a document dict.
# In the recorded output the query dict is keyed by 'input_ids' rather than
# 'q_input_ids'.
from trainers.biencoder_trainer import _unpack_qp
query_batch_dict, doc_batch_dict = _unpack_qp(batch)
query_batch_dict

{'input_ids': tensor([[  101,  4248, 17470,  4013,  2490,  2193,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101,  2003,  1996,  2783,  2051, 22851,  2102,  2030,  8827,  2102,
            102,     0,     0,     0,     0,     0],
         [  101,  2054,  2515,  2632,  8569, 27833,  2812,   102,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101,  2043,  2064,  2017,  4929,  2317,  6471,   102,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101,  6207, 11306,  1005,  1055,  5592,  4003,  5703,  3042,  2193,
            102,     0,     0,     0,     0,     0],
         [  101,  2054,  2064,  1045,  6570,  6001,  2007, 13908, 13597,   102,
              0,     0,     0,     0,     0,     0],
         [  101,  2079, 21122, 15580,  9880,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101,  2054,  2217,  2003,  2115, 22524,  200

In [22]:
import torch
import torch.distributed as dist
import os

def init_distributed_single_gpu(master_addr="localhost", master_port=12355,
                                backend="nccl", device_index=0):
    """Initialize a one-process ``torch.distributed`` group on a single GPU.

    Safe to call more than once: if the default process group already exists
    (for example because this cell was re-run), the group is left untouched
    and only the CUDA device is selected again.

    Args:
        master_addr: value exported as ``MASTER_ADDR`` for the rendezvous.
        master_port: value exported as ``MASTER_PORT`` (converted to ``str``).
        backend: ``torch.distributed`` backend name; NCCL is the GPU backend.
        device_index: CUDA device made current for this process.

    Raises:
        Whatever ``dist.init_process_group`` / ``torch.cuda.set_device`` raise,
        e.g. when no CUDA device is available.
    """
    if dist.is_initialized():
        # A second init_process_group call on the default group would raise.
        torch.cuda.set_device(device_index)
        return

    # Rendezvous settings read by the default env:// init method.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

    # Single process: this process is rank 0 of a world of size 1.
    dist.init_process_group(backend=backend, rank=0, world_size=1)

    torch.cuda.set_device(device_index)

# Create the single-process group before the model's forward pass is run
# (presumably the model or loss uses torch.distributed collectives - confirm).
init_distributed_single_gpu()

In [31]:
# Put the model and both input dicts on the GPU. Values that are not tensors
# are passed through unchanged.
model = model.to("cuda")
query_batch_dict = {
    name: (value.to("cuda") if isinstance(value, torch.Tensor) else value)
    for name, value in query_batch_dict.items()
}
doc_batch_dict = {
    name: (value.to("cuda") if isinstance(value, torch.Tensor) else value)
    for name, value in doc_batch_dict.items()
}

In [32]:
# Forward pass of the bi-encoder on the unpacked query / document inputs;
# the recorded result is a BiencoderOutput carrying q_reps and p_reps.
outputs = model(query=query_batch_dict, passage=doc_batch_dict)

In [33]:
# Display the forward-pass result.
outputs

BiencoderOutput(q_reps=tensor([[ 0.0355, -0.0662,  0.0463,  ..., -0.0304,  0.0158,  0.0318],
        [ 0.0407, -0.0098,  0.0277,  ...,  0.0115,  0.0487, -0.0128],
        [ 0.0358,  0.0359,  0.0065,  ..., -0.0172, -0.0137, -0.0005],
        ...,
        [ 0.0457,  0.0100, -0.0348,  ..., -0.0243,  0.0192, -0.0326],
        [-0.0131,  0.0027, -0.0103,  ..., -0.0231,  0.0251,  0.0269],
        [ 0.0325,  0.0376, -0.0209,  ..., -0.0306, -0.0295, -0.0003]],
       device='cuda:0', grad_fn=<DivBackward0>), p_reps=tensor([[-2.3926e-02,  2.9929e-04,  1.7268e-02,  ...,  1.2327e-03,
          3.0578e-03,  4.8459e-02],
        [ 4.3832e-02,  6.3219e-03,  1.8771e-03,  ...,  2.9059e-02,
          2.4033e-02,  1.9232e-02],
        [ 1.2833e-02, -2.8508e-02, -2.6362e-02,  ...,  3.7568e-02,
         -2.4967e-03,  1.3990e-03],
        ...,
        [ 2.5597e-02, -1.2830e-02, -6.7153e-05,  ...,  2.8175e-02,
          1.7792e-02,  1.4907e-02],
        [ 1.8049e-02,  3.0032e-02,  8.6424e-03,  ...,  2.8988e

In [35]:
# Move the tensors of the full collated batch to the GPU; non-tensor values
# are kept as they are.
batch = {
    name: (value.to("cuda") if isinstance(value, torch.Tensor) else value)
    for name, value in batch.items()
}

In [37]:
# compute_loss is called directly here, outside of trainer.train(), so the
# epoch on the trainer state is set by hand first.
# NOTE(review): presumably BiencoderTrainer.compute_loss reads
# trainer.state.epoch - confirm against the trainer implementation.
trainer.state.epoch = 0
trainer.compute_loss(model, batch)

tensor(10.0293, device='cuda:0', grad_fn=<MulBackward0>)