In [1]:
from splade.utils.hydra import hydra_chdir
from omegaconf import DictConfig

In [2]:
toy_exp_dict = DictConfig({
    'config': {
        'lr': 2e-5,
        'seed': 123,
        'gradient_accumulation_steps': 1,
        'weight_decay': 0.01,
        'validation_metrics': ['MRR@10', 'recall@100', 'recall@200', 'recall@500'],
        'pretrained_no_yamlconfig': False,
        'nb_iterations': 10,
        'train_batch_size': 6,
        'eval_batch_size': 8,
        'index_retrieve_batch_size': 6,
        'record_frequency': 3,
        'train_monitoring_freq': 2,
        'warmup_steps': 5,
        'max_length': 10,
        'fp16': False,
        'augment_pairs': 'in_batch_negatives',
        'matching_type': 'splade',
        'monitoring_ckpt': 'loss',
        'loss': 'InBatchPairwiseNLL',
        'regularizer': {
            'FLOPS': {
                'lambda_q': 5e-4,
                'lambda_d': 3e-4,
                'T': 3,
                'targeted_rep': 'rep',
                'reg': 'FLOPS'
            }
        },
        'tokenizer_type': 'distilbert-base-uncased',
        'top_k': 5,
        'threshold': 0.4,
        'eval_metric': [['mrr_10', 'recall']],
        'retrieval_name': ['TOY'],
        'checkpoint_dir': 'experiments/debug/checkpoint',
        'index_dir': 'experiments/debug/index',
        'out_dir': 'experiments/debug/out'
    },
    'data': {
        'type': 'triplets',
        'TRAIN_DATA_DIR': 'data/toy_data/triplets',
        'VALIDATION_SIZE_FOR_LOSS': 20,
        'VALIDATION_FULL_RANKING': {
            'D_COLLECTION_PATH': 'data/toy_data/val_collection',
            'Q_COLLECTION_PATH': 'data/toy_data/val_queries',
            'QREL_PATH': 'data/toy_data/qrel/qrel.json',
            'TOP_K': 20
        },
        'COLLECTION_PATH': 'data/toy_data/full_collection',
        'Q_COLLECTION_PATH': ['data/toy_data/dev_queries'],
        'EVAL_QREL_PATH': ['data/toy_data/qrel/qrel.json'],
        'flops_queries': 'data/toy_data/dev_queries'
    },
    'init_dict': {
        'model_type_or_dir': 'distilbert-base-uncased',
        'model_type_or_dir_q': None,
        'freeze_d_model': 0,
        'agg': 'max',
        'fp16': False
    }
})
splade_exp_dict = DictConfig({
    'config': {
        'lr': 2e-05,
        'seed': 123,
        'gradient_accumulation_steps': 1,
        'weight_decay': 0.01,
        'validation_metrics': ['MRR@10', 'recall@100', 'recall@200', 'recall@500'],
        'pretrained_no_yamlconfig': False,
        'nb_iterations': 150000,
        'train_batch_size': 128,
        'eval_batch_size': 600,
        'index_retrieve_batch_size': 500,
        'record_frequency': 10000,
        'train_monitoring_freq': 500,
        'warmup_steps': 6000,
        'max_length': 256,
        'fp16': True,
        'matching_type': 'splade',
        'monitoring_ckpt': 'MRR@10',
        'loss': 'InBatchPairwiseNLL',
        'tokenizer_type': 'distilbert-base-uncased',
        'top_k': 1000,
        'eval_metric': [['mrr_10', 'recall'], ['ndcg_cut'], ['mrr_10', 'recall'], ['ndcg_cut'], ['mrr_10', 'recall']],
        'retrieval_name': ['MSMARCO', 'TREC_DL_2019', 'TREC_DL_2019', 'TREC_DL_2020', 'TREC_DL_2020'],
        'threshold': 0,
        'checkpoint_dir': 'experiments/splade/checkpoint',
        'index_dir': 'experiments/splade/index',
        'out_dir': 'experiments/splade/out'
    },
    'data': {
        'type': 'triplets',
        'TRAIN_DATA_DIR': 'data/msmarco/triplets',
        'VALIDATION_SIZE_FOR_LOSS': 60000,
        'VALIDATION_FULL_RANKING': {
            'D_COLLECTION_PATH': 'data/msmarco/val_retrieval/collection',
            'Q_COLLECTION_PATH': 'data/msmarco/val_retrieval/queries',
            'QREL_PATH': 'data/msmarco/val_retrieval/qrel.json',
            'TOP_K': 500
        },
        'COLLECTION_PATH': 'data/msmarco/full_collection',
        'Q_COLLECTION_PATH': [
            'data/msmarco/dev_queries',
            'data/msmarco/TREC_DL_2019/queries_2019',
            'data/msmarco/TREC_DL_2020/queries_2020'
        ],
        'EVAL_QREL_PATH': [
            'data/msmarco/dev_qrel.json',
            'data/msmarco/TREC_DL_2019/qrel.json',
            'data/msmarco/TREC_DL_2019/qrel_binary.json',
            'data/msmarco/TREC_DL_2020/qrel.json',
            'data/msmarco/TREC_DL_2020/qrel_binary.json'
        ],
        'flops_queries': 'data/msmarco/all_dev_queries/'
    },
    'init_dict': {
        'model_type_or_dir': 'distilbert-base-uncased',
        'model_type_or_dir_q': None,
        'freeze_d_model': 0,
        'agg': 'max',
        'fp16': True
    }
})
exp_dict = splade_exp_dict
hydra_chdir(exp_dict)

config:
  lr: 2.0e-05
  seed: 123
  gradient_accumulation_steps: 1
  weight_decay: 0.01
  validation_metrics:
  - MRR@10
  - recall@100
  - recall@200
  - recall@500
  pretrained_no_yamlconfig: false
  nb_iterations: 150000
  train_batch_size: 128
  eval_batch_size: 600
  index_retrieve_batch_size: 500
  record_frequency: 10000
  train_monitoring_freq: 500
  warmup_steps: 6000
  max_length: 256
  fp16: true
  matching_type: splade
  monitoring_ckpt: MRR@10
  loss: InBatchPairwiseNLL
  tokenizer_type: distilbert-base-uncased
  top_k: 1000
  eval_metric:
  - - mrr_10
    - recall
  - - ndcg_cut
  - - mrr_10
    - recall
  - - ndcg_cut
  - - mrr_10
    - recall
  retrieval_name:
  - MSMARCO
  - TREC_DL_2019
  - TREC_DL_2019
  - TREC_DL_2020
  - TREC_DL_2020
  threshold: 0
  checkpoint_dir: experiments/splade/checkpoint
  index_dir: experiments/splade/index
  out_dir: experiments/splade/out
data:
  type: triplets
  TRAIN_DATA_DIR: data/msmarco/triplets
  VALIDATION_SIZE_FOR_LOSS: 60000
  V

In [3]:
import os

import hydra
import torch
from omegaconf import open_dict
from torch.utils import data

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
from splade.datasets.dataloaders import CollectionDataLoader, SiamesePairsDataLoader, DistilSiamesePairsDataLoader
from splade.datasets.datasets import PairsDatasetPreLoad, DistilPairsDatasetPreLoad, MsMarcoHardNegatives, \
    CollectionDatasetPreLoad
from splade.losses.regularization import init_regularizer, RegWeightScheduler
from splade.models.models_utils import get_model
from splade.optim.bert_optim import init_simple_bert_optim
from splade.tasks.transformer_evaluator import SparseApproxEvalWrapper
from splade.tasks.transformer_trainer import SiameseTransformerTrainer
from splade.utils.utils import set_seed, restore_model, get_initialize_config, get_loss, set_seed_from_config

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
exp_dict, config, init_dict, _ = get_initialize_config(exp_dict, train=True)
model = get_model(config, init_dict)
model

config:
  lr: 2.0e-05
  seed: 123
  gradient_accumulation_steps: 1
  weight_decay: 0.01
  validation_metrics:
  - MRR@10
  - recall@100
  - recall@200
  - recall@500
  pretrained_no_yamlconfig: false
  nb_iterations: 150000
  train_batch_size: 128
  eval_batch_size: 600
  index_retrieve_batch_size: 500
  record_frequency: 10000
  train_monitoring_freq: 500
  warmup_steps: 6000
  max_length: 256
  fp16: true
  matching_type: splade
  monitoring_ckpt: MRR@10
  loss: InBatchPairwiseNLL
  tokenizer_type: distilbert-base-uncased
  top_k: 1000
  eval_metric:
  - - mrr_10
    - recall
  - - ndcg_cut
  - - mrr_10
    - recall
  - - ndcg_cut
  - - mrr_10
    - recall
  retrieval_name:
  - MSMARCO
  - TREC_DL_2019
  - TREC_DL_2019
  - TREC_DL_2020
  - TREC_DL_2020
  threshold: 0
  checkpoint_dir: experiments/splade/checkpoint
  index_dir: experiments/splade/index
  out_dir: experiments/splade/out
data:
  type: triplets
  TRAIN_DATA_DIR: data/msmarco/triplets
  VALIDATION_SIZE_FOR_LOSS: 60000
  V

Splade(
  (transformer_rep): TransformerRep(
    (transformer): DistilBertForMaskedLM(
      (activation): GELUActivation()
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (transformer): Transformer(
          (layer): ModuleList(
            (0-5): 6 x TransformerBlock(
              (attention): DistilBertSdpaAttention(
                (dropout): Dropout(p=0.1, inplace=False)
                (q_lin): Linear(in_features=768, out_features=768, bias=True)
                (k_lin): Linear(in_features=768, out_features=768, bias=True)
                (v_lin): Linear(in_features=768, out_features=768, bias=True)
                (out_lin): Linear(in_features=768, out_features=768, bias=True)
              )
 

In [5]:
random_seed = set_seed_from_config(config)

optimizer, scheduler = init_simple_bert_optim(model, lr=config["lr"], warmup_steps=config["warmup_steps"],
                                                weight_decay=config["weight_decay"],
                                                num_training_steps=config["nb_iterations"])

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

################################################################
# CHECK IF RESUME TRAINING
################################################################
iterations = (1, config["nb_iterations"] + 1)  # tuple with START and END
regularizer = None



In [6]:
model.to(device)

loss = get_loss(config)

In [7]:
if "regularizer" in config and regularizer is None:  # else regularizer is loaded
    output_dim = model.module.output_dim if hasattr(model, "module") else model.output_dim
    regularizer = {"eval": {"L0": {"loss": init_regularizer("L0")},
                            "sparsity_ratio": {"loss": init_regularizer("sparsity_ratio",
                                                                        output_dim=output_dim)}},
                    "train": {}}
    if config["regularizer"] == "eval_only":
        # just in the case we train a model without reg but still want the eval metrics like L0
        pass
    else:
        for reg in config["regularizer"]:
            temp = {"loss": init_regularizer(config["regularizer"][reg]["reg"]),
                    "targeted_rep": config["regularizer"][reg]["targeted_rep"]}
            d_ = {}
            if "lambda_q" in config["regularizer"][reg]:
                d_["lambda_q"] = RegWeightScheduler(config["regularizer"][reg]["lambda_q"],
                                                    config["regularizer"][reg]["T"])
            if "lambda_d" in config["regularizer"][reg]:
                d_["lambda_d"] = RegWeightScheduler(config["regularizer"][reg]["lambda_d"],
                                                    config["regularizer"][reg]["T"])
            temp["lambdas"] = d_  # it is possible to have reg only on q or d if e.g. you only specify lambda_q
            # in the reg config
            # targeted_rep is just used to indicate which rep to constrain (if e.g. the model outputs several
            # representations)
            # the common case: model outputs "rep" (in forward) and this should be the value for this targeted_rep
            regularizer["train"][reg] = temp

# fix for current in batch neg losses that break on last batch
if config["loss"] in ("InBatchNegHingeLoss", "InBatchPairwiseNLL"):
    drop_last = True
else:
    drop_last = False

In [8]:
if exp_dict["data"].get("type", "") == "triplets":
    data_train = PairsDatasetPreLoad(data_dir=exp_dict["data"]["TRAIN_DATA_DIR"])
    train_mode = "triplets"
elif exp_dict["data"].get("type", "") == "triplets_with_distil":
    data_train = DistilPairsDatasetPreLoad(data_dir=exp_dict["data"]["TRAIN_DATA_DIR"])
    train_mode = "triplets_with_distil"
elif exp_dict["data"].get("type", "") == "hard_negatives":
    data_train = MsMarcoHardNegatives(
        dataset_path=exp_dict["data"]["TRAIN"]["DATASET_PATH"],
        document_dir=exp_dict["data"]["TRAIN"]["D_COLLECTION_PATH"],
        query_dir=exp_dict["data"]["TRAIN"]["Q_COLLECTION_PATH"],
        qrels_path=exp_dict["data"]["TRAIN"]["QREL_PATH"])
    train_mode = "triplets_with_distil"
else:
    raise ValueError("provide valid data type for training")

Preloading dataset


29768it [00:00, 285802.02it/s]

39780811it [01:17, 514652.19it/s]


In [9]:
example = data_train[0]

In [10]:
example

('is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'It is generally safe for pregnant women to eat chocolate because studies have shown to prove certain benefits of eating chocolate during pregnancy. However, pregnant women should ensure their caffeine intake is below 200 mg per day.')

In [11]:
val_loss_loader = None  # default
if "VALIDATION_SIZE_FOR_LOSS" in exp_dict["data"]:
    print("initialize loader for validation loss")
    print("split train, originally {} pairs".format(len(data_train)))
    data_train, data_val = torch.utils.data.random_split(data_train, lengths=[
        len(data_train) - exp_dict["data"]["VALIDATION_SIZE_FOR_LOSS"],
        exp_dict["data"]["VALIDATION_SIZE_FOR_LOSS"]])
    print("train: {} pairs ~~ val: {} pairs".format(len(data_train), len(data_val)))
    if train_mode == "triplets":
        val_loss_loader = SiamesePairsDataLoader(dataset=data_val, batch_size=config["eval_batch_size"],
                                                    shuffle=False,
                                                    num_workers=4,
                                                    tokenizer_type=config["tokenizer_type"],
                                                    max_length=config["max_length"], drop_last=drop_last)
    elif train_mode == "triplets_with_distil":
        val_loss_loader = DistilSiamesePairsDataLoader(dataset=data_val, batch_size=config["eval_batch_size"],
                                                        shuffle=False,
                                                        num_workers=4,
                                                        tokenizer_type=config["tokenizer_type"],
                                                        max_length=config["max_length"], drop_last=drop_last)
    else:
        raise NotImplementedError

if train_mode == "triplets":
    train_loader = SiamesePairsDataLoader(dataset=data_train, batch_size=config["train_batch_size"], shuffle=True,
                                            num_workers=4,
                                            tokenizer_type=config["tokenizer_type"],
                                            max_length=config["max_length"], drop_last=drop_last)
elif train_mode == "triplets_with_distil":
    train_loader = DistilSiamesePairsDataLoader(dataset=data_train, batch_size=config["train_batch_size"],
                                                shuffle=True,
                                                num_workers=4,
                                                tokenizer_type=config["tokenizer_type"],
                                                max_length=config["max_length"], drop_last=drop_last)
else:
    raise NotImplementedError

val_evaluator = None
if "VALIDATION_FULL_RANKING" in exp_dict["data"]:
    with open_dict(config):
        config["val_full_rank_qrel_path"] = exp_dict["data"]["VALIDATION_FULL_RANKING"]["QREL_PATH"]
    full_ranking_d_collection = CollectionDatasetPreLoad(
        data_dir=exp_dict["data"]["VALIDATION_FULL_RANKING"]["D_COLLECTION_PATH"], id_style="row_id")
    full_ranking_d_loader = CollectionDataLoader(dataset=full_ranking_d_collection,
                                                    tokenizer_type=config["tokenizer_type"],
                                                    max_length=config["max_length"],
                                                    batch_size=config["eval_batch_size"],
                                                    shuffle=False, num_workers=4)
    full_ranking_q_collection = CollectionDatasetPreLoad(
        data_dir=exp_dict["data"]["VALIDATION_FULL_RANKING"]["Q_COLLECTION_PATH"], id_style="row_id")
    full_ranking_q_loader = CollectionDataLoader(dataset=full_ranking_q_collection,
                                                    tokenizer_type=config["tokenizer_type"],
                                                    max_length=config["max_length"], batch_size=1,
                                                    # TODO fix: bs currently set to 1
                                                    shuffle=False, num_workers=4)
    val_evaluator = SparseApproxEvalWrapper(model,
                                            config={"top_k": exp_dict["data"]["VALIDATION_FULL_RANKING"]["TOP_K"],
                                                    "out_dir": os.path.join(config["checkpoint_dir"],
                                                                            "val_full_ranking")
                                                    },
                                            collection_loader=full_ranking_d_loader,
                                            q_loader=full_ranking_q_loader,
                                            restore=False)

initialize loader for validation loss
split train, originally 39780811 pairs
train: 39720811 pairs ~~ val: 60000 pairs
Preloading dataset


276142it [00:00, 434448.78it/s]


Preloading dataset


1600it [00:00, 1159247.95it/s]




In [12]:
for batch in train_loader:
    break
batch

{'q_input_ids': tensor([[ 101, 2054, 2003,  ...,    0,    0,    0],
         [ 101, 7095, 8738,  ...,    0,    0,    0],
         [ 101, 2003, 1037,  ...,    0,    0,    0],
         ...,
         [ 101, 2054, 2003,  ...,    0,    0,    0],
         [ 101, 2054, 5320,  ...,    0,    0,    0],
         [ 101, 4127, 1997,  ...,    0,    0,    0]]),
 'q_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'pos_input_ids': tensor([[  101,  2703,  1038,  ...,     0,     0,     0],
         [  101,  2012,  5369,  ...,     0,     0,     0],
         [  101,  2852,  1012,  ...,     0,     0,     0],
         ...,
         [  101,  4517,  4255,  ...,     0,     0,     0],
         [  101,  4958, 13340,  ...,     0,     0,     0],
         [  101,  2045,  2024,  ...,     0,     0,     0]]),
 'pos_attention_ma

In [13]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [14]:
for i in range(len(batch['q_input_ids'])):
    print(train_loader.tokenizer.decode(batch['q_input_ids'][i]))
    print(train_loader.tokenizer.decode(batch['pos_input_ids'][i]))
    print(train_loader.tokenizer.decode(batch['neg_input_ids'][i]))
    print()

[CLS] what is paul blart mall cops real name [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[CLS] paul blart : mall cop. from wikipedia, the free encyclopedia. paul blart : mall cop is a 2009 american action comedy film directed by steve carr and co - written by kevin james, who stars as the title character, paul blart. filming began in february 2008 with most of the shooting taking place at the burlington mall in burlington, massachusetts. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [15]:
trainer = SiameseTransformerTrainer(model=model, iterations=iterations, loss=loss, optimizer=optimizer,
                                    config=config, scheduler=scheduler,
                                    train_loader=train_loader, validation_loss_loader=val_loss_loader,
                                    validation_evaluator=val_evaluator,
                                    regularizer=regularizer)

initialize trainer...
 --- total number parameters: 66985530
 === trainer config === 


In [16]:
for k, v in batch.items():
    batch[k] = v.to(device)

In [17]:
# all pos and neg examples are of the same shape!
batch["neg_input_ids"], batch["pos_input_ids"]

(tensor([[  101, 13584,  3253,  ...,     0,     0,     0],
         [  101,  2045,  2003,  ...,     0,     0,     0],
         [  101,  2011,  4806,  ...,     0,     0,     0],
         ...,
         [  101,  8085,  1024,  ...,     0,     0,     0],
         [  101, 15321,  6305,  ...,     0,     0,     0],
         [  101, 10247,  1004,  ...,     0,     0,     0]], device='cuda:0'),
 tensor([[  101,  2703,  1038,  ...,     0,     0,     0],
         [  101,  2012,  5369,  ...,     0,     0,     0],
         [  101,  2852,  1012,  ...,     0,     0,     0],
         ...,
         [  101,  4517,  4255,  ...,     0,     0,     0],
         [  101,  4958, 13340,  ...,     0,     0,     0],
         [  101,  2045,  2024,  ...,     0,     0,     0]], device='cuda:0'))

In [18]:
for row in batch["neg_input_ids"]:
    print(train_loader.tokenizer.decode(row))


[CLS] sonny offered to buy paul a drink as a way of thanking him for the help. paul insisted that sonny didn ' t need to thank him ; he admitted, however, that he would like a drink. sonny acknowledged that paul wanted him to keep his head down and leave deimos alone. paul conceded that it wasn ' t in sonny ' s nature to do that. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [40]:
preds = trainer.forward(batch)

  with torch.cuda.amp.autocast() if self.fp16 else amp.NullContextManager():
  with torch.cuda.amp.autocast() if self.fp16 else NullContextManager():
  with torch.cuda.amp.autocast() if self.fp16 else NullContextManager():
  with torch.cuda.amp.autocast() if self.fp16 else NullContextManager():
  with torch.cuda.amp.autocast() if self.fp16 else NullContextManager():
  with torch.cuda.amp.autocast() if self.fp16 else NullContextManager():


In [26]:
preds

{'pos_d_rep': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0606, 0.0000],
         [0.7258, 0.7134, 0.7828,  ..., 0.9857, 0.7774, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
        device='cuda:0', grad_fn=<MaxBackward0>),
 'pos_q_rep': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
        grad_fn=<MaxBackward0>),
 'pos_score': tensor([[ 7115.9946],
         [ 3571.4072],
         [ 2081.2451],
         [ 3470.3240],
         [ 3201.7109],
         [ 6247.2480],
         [ 3102.1016],
         [ 2221.3164],
         [ 

In [27]:
batch["pos_input_ids"].shape

torch.Size([128, 164])

In [21]:
list(preds.keys())

['pos_d_rep', 'pos_q_rep', 'pos_score', 'neg_d_rep', 'neg_q_rep', 'neg_score']

In [28]:
preds["pos_d_rep"].shape

torch.Size([128, 30522])

In [30]:
torch.sum(preds["pos_d_rep"][0] > 0)

tensor(15211, device='cuda:0')

In [31]:
preds["pos_score"].shape

torch.Size([128, 1])

In [32]:
preds["pos_score"][0]

tensor([7115.9946], device='cuda:0', grad_fn=<SelectBackward0>)

In [33]:
preds["neg_score"][0]

tensor([3493.5142], device='cuda:0', grad_fn=<SelectBackward0>)

In [42]:
loss = trainer.loss(preds)
loss

tensor(168.5342, device='cuda:0', grad_fn=<MeanBackward0>)

In [41]:
trainer.loss

<splade.losses.pairwise.InBatchPairwiseNLL at 0x7f424b373850>