In [3]:
import os
# Switch the working directory to the simlm `src` folder; presumably this is what
# lets the project-local imports in later cells (config, trainers, loaders, ...) resolve.
# NOTE(review): hard-coded absolute path for one machine — adjust before running elsewhere.
os.chdir("/traindata/maksim/repos/unilm/simlm/src")
# Print the working directory to confirm the change.
!pwd

/traindata/maksim/repos/unilm/simlm/src


In [4]:
import torch
# Check that PyTorch can see a CUDA device.
# NOTE(review): this runs before CUDA_VISIBLE_DEVICES is assigned in a later cell, and the
# forward pass further down reports device='cuda:1'; the later device mask may therefore
# not be taking effect — confirm.
torch.cuda.is_available()

True

In [5]:
import logging

import torch
from typing import Dict
from functools import partial
from transformers.utils.logging import enable_explicit_format
from transformers.trainer_callback import PrinterCallback
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    EvalPrediction,
    Trainer,
    set_seed,
    PreTrainedTokenizerFast
)

from logger_config import logger, LoggerCallback
from config import Arguments
from trainers import BiencoderTrainer
from loaders import RetrievalDataLoader
from collators import BiencoderCollator
from metrics import accuracy, batch_mrr
from models import BiencoderModel

def _common_setup(args: Arguments):
    """Apply process-wide logging and seeding settings taken from ``args``.

    Processes whose ``args.process_index`` is greater than zero have the project
    logger raised to WARNING; every process then enables the explicit
    transformers log format and seeds with ``args.seed``.
    """
    is_first_process = args.process_index <= 0
    if not is_first_process:
        # keep the log readable: only the first process logs below WARNING
        logger.setLevel(logging.WARNING)

    enable_explicit_format()
    set_seed(args.seed)


def _compute_metrics(args: Arguments, eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute MRR and top-1 / top-3 accuracy from the evaluation score matrix.

    Args:
        args: run configuration; only ``train_n_passages`` is read here.
        eval_pred: evaluation output whose ``predictions[-1]`` entry is the score
            matrix (the original comment says the field order follows
            ``BiencoderOutput``).

    Returns:
        A dict with the keys ``'mrr'``, ``'acc1'`` and ``'acc3'``.
    """
    # the scores are the last entry of the predictions container
    scores = torch.tensor(eval_pred.predictions[-1]).float()
    num_rows, num_cols = scores.size(0), scores.size(1)

    # target column for row i is (i * train_n_passages), wrapped to the number of columns
    # (presumably the position of row i's positive passage — confirm against the model)
    row_ids = torch.arange(0, num_rows, dtype=torch.long)
    labels = (row_ids * args.train_n_passages) % num_cols

    topk_metrics = accuracy(output=scores, target=labels, topk=(1, 3))
    mrr = batch_mrr(output=scores, target=labels)

    metrics = {'mrr': mrr}
    metrics['acc1'] = topk_metrics[0]
    metrics['acc3'] = topk_metrics[1]
    return metrics

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Environment configuration for this session: data/output locations and the GPU mask.
os.environ["DATA_DIR"] = "./data/msmarco_bm25_official/"
os.environ["OUTPUT_DIR"] = "./tmp/"
# NOTE(review): torch was already imported and torch.cuda.is_available() already called in
# an earlier cell, and the forward pass below runs on device='cuda:1'. With a mask of a
# single GPU only index 0 would exist, so this assignment likely has no effect at this
# point — move it before the first CUDA call if the mask is intended. Confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "8"

In [7]:
import sys

# Emulate the command line of the training script so HfArgumentParser can build
# the `Arguments` dataclass inside the notebook. One flag (and its value) per line.
sys.argv = [
    'src/train_biencoder.py',
    '--deepspeed', '/traindata/maksim/repos/unilm/simlm/ds_config.json',
    '--model_name_or_path', 'intfloat/simlm-base-msmarco',
    '--per_device_train_batch_size', '16',
    '--per_device_eval_batch_size', '16',
    '--kd_mask_hn', 'False',
    '--kd_cont_loss_weight', '0.2',
    '--seed', '123',
    '--do_train',
    '--do_kd_biencoder',
    '--t', '0.02',
    '--fp16',
    '--train_file', '/traindata/maksim/repos/unilm/simlm/data/msmarco_distillation//kd_train.jsonl',
    '--validation_file', '/traindata/maksim/repos/unilm/simlm/data/msmarco_distillation//kd_dev.jsonl',
    '--q_max_len', '32',
    '--p_max_len', '144',
    '--train_n_passages', '24',
    '--dataloader_num_workers', '1',
    '--num_train_epochs', '6',
    '--learning_rate', '3e-5',
    '--warmup_steps', '1000',
    '--share_encoder', 'True',
    '--logging_steps', '50',
    '--output_dir', '/traindata/maksim/repos/unilm/simlm/checkpoint/distilled_biencoder/',
    '--data_dir', '/traindata/maksim/repos/unilm/simlm/data/msmarco_distillation/',
    '--save_total_limit', '10',
    '--save_strategy', 'epoch',
    '--evaluation_strategy', 'epoch',
    '--load_best_model_at_end',
    '--metric_for_best_model', 'mrr',
    '--greater_is_better', 'True',
    '--remove_unused_columns', 'False',
    '--overwrite_output_dir',
    '--disable_tqdm', 'True',
    '--report_to', 'none',
]

parser = HfArgumentParser((Arguments,))
# a single dataclass type was registered, so the parsed config is the first element
args: Arguments = parser.parse_args_into_dataclasses()[0]
_common_setup(args)
args



In [8]:
# Load the tokenizer for the checkpoint named in `args` and build the bi-encoder from the same args.
tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
model: BiencoderModel = BiencoderModel.build(args=args)
# Log the module structure and the tokenizer's vocabulary size for reference.
logger.info(model)
logger.info('Vocab size: {}'.format(len(tokenizer)))

# Collator used to batch examples; padding lengths are rounded up to a multiple of 8
# only when fp16 is enabled, otherwise no multiple is enforced.
data_collator = BiencoderCollator(
    tokenizer=tokenizer,
    pad_to_multiple_of=8 if args.fp16 else None)

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[2024-11-12 13:29:47,576 INFO] BiencoderModel(
  (lm_q): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
    

In [9]:
# Build the data loader wrapper and take its train / eval datasets.
retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
train_dataset = retrieval_data_loader.train_dataset
eval_dataset = retrieval_data_loader.eval_dataset

# A dataset is only handed to the trainer when the matching do_train / do_eval flag is set.
trainer: Trainer = BiencoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset if args.do_train else None,
    eval_dataset=eval_dataset if args.do_eval else None,
    data_collator=data_collator,
    # bind `args` up front so the metric function can be called with just the EvalPrediction
    compute_metrics=partial(_compute_metrics, args),
    tokenizer=tokenizer,
)
# Replace the PrinterCallback with the project's LoggerCallback.
trainer.remove_callback(PrinterCallback)
trainer.add_callback(LoggerCallback)
# Give the data loader and the model a reference back to the trainer.
# NOTE(review): how these attributes are consumed is not visible here — confirm in loaders / models.
retrieval_data_loader.trainer = trainer
model.trainer = trainer

[2024-11-12 13:29:47,829 INFO] Sample 27453 of the training set: {'query_id': '689043', 'query': 'what is a lime rickey?', 'positives': {'doc_id': ['559262'], 'score': [2.54883]}, 'negatives': {'doc_id': ['945415', '2225798', '2174436', '4025338', '1351518', '189410', '3841208', '559263', '5637829', '4967760', '4070900', '3476454', '559258', '3799440', '5035615', '5055167', '7631620', '3030057', '8075849', '7609175', '559257', '4810917', '3653953', '51826', '945414', '4967767', '7003968', '559259', '4038251', '4573199', '3991139', '559264', '4057818', '7157577', '3438664', '3991141', '1875093', '1378439', '8123451', '7610009', '2653049', '5984731', '3991140', '6513406', '2526329', '2560919', '5018273', '1378438', '8410762', '7631615', '3369864', '7631617', '5393019', '4718163', '6513415', '2526334', '1555993', '7743876', '3007692', '3081016', '4070908', '5344416', '2437633', '1924078', '3897853', '3799438', '7612325', '4573192', '2710018', '2936717', '2936722', '1379731', '3297383', '5

In [10]:
# Take the first training example and list the fields it provides.
example = train_dataset[0]
list(example.keys())

['q_input_ids',
 'q_token_type_ids',
 'q_attention_mask',
 'd_input_ids',
 'd_token_type_ids',
 'd_attention_mask',
 'kd_labels']

In [11]:
# Compare the number of kd_labels with the number of passages in the example
# (both come out as 24, the value passed for --train_n_passages).
len(example["kd_labels"]), len(example["d_input_ids"])


(24, 24)

In [12]:
# Token count of each passage in the example; None entries are skipped.
[len(elem) for elem in example["d_input_ids"] if elem is not None]


[57,
 144,
 57,
 77,
 83,
 120,
 47,
 70,
 65,
 75,
 58,
 116,
 130,
 60,
 79,
 119,
 47,
 134,
 116,
 85,
 84,
 56,
 144,
 132]

In [13]:
# Show the per-passage kd_labels values
# (presumably teacher scores used for knowledge distillation — confirm against the dataset code).
example["kd_labels"]

[-1.2793,
 -4.66406,
 -4.52734,
 -6.86719,
 -4.89062,
 -6.89844,
 -3.67969,
 -4.18359,
 -4.62891,
 -3.20312,
 -3.77539,
 -6.39844,
 -4.60547,
 -4.36719,
 -6.19531,
 -6.30078,
 -5.84375,
 -4.19922,
 -3.94141,
 -4.52734,
 -5.3125,
 -4.17188,
 -4.52734,
 -4.73438]

In [14]:
# Number of token ids in the example's query.
len(example["q_input_ids"])

16

In [15]:
# Decode the query's token ids back to text for a visual check.
# The tokenizer loaded earlier from args.model_name_or_path is reused: that argument is
# already 'intfloat/simlm-base-msmarco', so re-importing AutoTokenizer and re-loading the
# same checkpoint under a hard-coded name was redundant and could drift from `args`.
tokenizer.decode(example["q_input_ids"])



'[CLS] ) what was the immediate impact of the success of the manhattan project? [SEP]'

In [16]:
# Decode each passage's token ids back to text and print one passage per line.
for elem in example["d_input_ids"]:
    print(tokenizer.decode(elem))


[CLS] introduction [SEP] the presence of communication amid scientific minds was equally important to the success of the manhattan project as scientific intellect was. the only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant ; hundreds of thousands of innocent lives obliterated. [SEP]
[CLS] 51f. the manhattan project [SEP] by the summer of 1945, oppenheimer was ready to test the first bomb. on july 16, 1945, at trinity site near alamogordo, new mexico, scientists of the manhattan project readied themselves to watch the detonation of the world's first atomic bomb. the device was affixed to a 100 - foot tower and discharged just before dawn. he main assembly plant was built at los alamos, new mexico. robert oppenheimer was put in charge of putting the pieces together at los alamos. after the final bill was tallied, nearly $ 2 billion had been spent on research and development of the atomic bomb. the manhattan project

In [17]:
from torch.utils.data import DataLoader

# Keyword arguments for a DataLoader over the training set, taken from the trainer's
# own settings so the batches look like the ones the trainer would produce.
dataloader_params = dict(
    batch_size=trainer._train_batch_size,
    collate_fn=data_collator,
    num_workers=trainer.args.dataloader_num_workers,
    pin_memory=trainer.args.dataloader_pin_memory,
    persistent_workers=trainer.args.dataloader_persistent_workers,
)

# Sampler, drop_last and prefetch_factor are only supplied when the dataset
# is not an IterableDataset.
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
    dataloader_params.update(
        sampler=trainer._get_train_sampler(),
        drop_last=trainer.args.dataloader_drop_last,
        prefetch_factor=trainer.args.dataloader_prefetch_factor,
    )

train_dataloader = DataLoader(train_dataset, **dataloader_params)

In [18]:
# Pull a single batch for inspection.
# next(iter(...)) raises StopIteration if the loader is empty, instead of silently
# leaving (and displaying) a stale `batch` from an earlier run of this cell.
batch = next(iter(train_dataloader))
batch

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


{'q_input_ids': tensor([[  101,  2129,  2172,  2003, 24728,  8566, 11058,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  2054,  2186,  2001,  7673,  9303, 10762,  1999,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  2054,  2221,  2003, 10493, 12436,  2284,  1999,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  2054,  4295,  2515,  2552,  5740,  8029,  9623,  3426,  1999,
          7125,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101, 18833,  2000, 29533,  1013, 29533,  8197,  2958,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
    

In [19]:
# Shape of the padded query ids in the batch (16 rows, matching --per_device_train_batch_size).
batch["q_input_ids"].shape

torch.Size([16, 24])

In [20]:
# Shape of the padded passage ids in the batch: 384 rows = 16 queries x 24 passages,
# each padded to 144 tokens (the --p_max_len value).
batch["d_input_ids"].shape

torch.Size([384, 144])

In [21]:
from trainers.biencoder_trainer import _unpack_qp
# Split the collated batch into a query dict and a passage dict using the trainer's helper.
# NOTE(review): the exact key renaming done by _unpack_qp is not visible here; the displayed
# query dict shows an 'input_ids' key — confirm the rest in trainers/biencoder_trainer.py.
query_batch_dict, doc_batch_dict = _unpack_qp(batch)
query_batch_dict

{'input_ids': tensor([[  101,  2129,  2172,  2003, 24728,  8566, 11058,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0],
         [  101,  2054,  2186,  2001,  7673,  9303, 10762,  1999,   102,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0],
         [  101,  2054,  2221,  2003, 10493, 12436,  2284,  1999,   102,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0],
         [  101,  2054,  4295,  2515,  2552,  5740,  8029,  9623,  3426,  1999,
           7125,   102,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0],
         [  101, 18833,  2000, 29533,  1013, 29533,  8197,  2958,   102,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0, 

In [22]:
import torch
import torch.distributed as dist
import os

def init_distributed_single_gpu(device_index: int = 1,
                                master_addr: str = "localhost",
                                master_port: str = "12356") -> None:
    """Create a one-process NCCL process group and select the CUDA device to use.

    Safe to call more than once: if the default process group already exists
    (for example because the notebook cell was re-run), only the device
    selection is repeated and the group is left alone, since initialising the
    default group a second time raises an error.

    Args:
        device_index: CUDA device index passed to ``torch.cuda.set_device``.
            The default of 1 keeps the previous hard-coded behaviour; pass 0
            when only a single GPU is visible to the process.
        master_addr: value written to the ``MASTER_ADDR`` environment variable.
        master_port: value written to the ``MASTER_PORT`` environment variable.
    """
    # Select the device up front so later CUDA / NCCL work lands on the intended GPU.
    torch.cuda.set_device(device_index)

    if dist.is_initialized():
        # Default group already exists; initialising it again would raise.
        return

    # Rendezvous settings read by the default env:// initialisation.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

    dist.init_process_group(
        backend="nccl",  # GPU backend
        rank=0,          # the only process
        world_size=1     # total number of processes
    )

# Create the single-process group before the model is moved to the GPU and run.
# NOTE(review): presumably required because the model's forward pass uses
# torch.distributed collectives — confirm in models.
init_distributed_single_gpu()

In [23]:
DEVICE = "cuda"


def _to_device(batch_dict, device):
    """Return a new dict with every tensor value moved to ``device``; other values pass through."""
    moved = {}
    for key, value in batch_dict.items():
        moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
    return moved


# Move the model and both input dicts onto the same device.
model = model.to(DEVICE)
query_batch_dict = _to_device(query_batch_dict, DEVICE)
doc_batch_dict = _to_device(doc_batch_dict, DEVICE)


In [25]:
# Run one forward pass of the bi-encoder on the prepared query / passage batch and show the result.
# NOTE(review): autograd is enabled here (the displayed tensors carry a grad_fn); wrap the call
# in torch.no_grad() if this is inspection only — confirm intent.
outputs = model(query=query_batch_dict, passage=doc_batch_dict)
outputs

BiencoderOutput(q_reps=tensor([[ 0.0417, -0.0576,  0.0260,  ..., -0.0078,  0.0197,  0.0434],
        [-0.0273, -0.0229, -0.0409,  ...,  0.0051,  0.0109,  0.0149],
        [ 0.0280,  0.0017,  0.0068,  ..., -0.0052,  0.0060, -0.0040],
        ...,
        [ 0.0115,  0.0308, -0.0189,  ..., -0.0083,  0.0140, -0.0190],
        [-0.0075, -0.0525, -0.0408,  ..., -0.0002,  0.0039, -0.0084],
        [ 0.0630,  0.0090, -0.0447,  ..., -0.0181,  0.0172, -0.0029]],
       device='cuda:1', grad_fn=<DivBackward0>), p_reps=tensor([[-0.0107, -0.0948,  0.0390,  ...,  0.0099,  0.0322,  0.0229],
        [-0.0070, -0.0582,  0.0135,  ...,  0.0241, -0.0216,  0.0216],
        [-0.0309, -0.0687,  0.0285,  ..., -0.0108, -0.0566,  0.0010],
        ...,
        [ 0.0161, -0.0075, -0.0256,  ..., -0.0083,  0.0362, -0.0082],
        [-0.0584, -0.0090, -0.0012,  ...,  0.0298,  0.0180,  0.0149],
        [ 0.0006, -0.0477, -0.0298,  ...,  0.0213,  0.0414,  0.0023]],
       device='cuda:1', grad_fn=<DivBackward0>), loss