In [1]:
import logging
import os
from pathlib import Path
import torch
from torch import nn
from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers import Trainer

import sys
sys.path.append("..")
from src.arguments import (
    ModelArguments,
    DataArguments,
    RetrieverTrainingArguments as TrainingArguments,
)
from src.data import TrainDatasetForEmbedding, EmbedCollator
from src.modeling import BiEncoderModel
# from trainer import BiTrainer


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from hf import model_args, data_args, training_args
from data import dataset, data_collator

In [4]:
# Not read anywhere else in the visible notebook; only the commented-out
# `num_labels` kwarg below refers to it.
num_labels = 1

# Prefer an explicitly configured tokenizer name, falling back to the model path.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name or model_args.model_name_or_path,
    use_fast=False,
    cache_dir=model_args.cache_dir,
)

# Same fallback rule for the config: explicit config name first, model path otherwise.
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    # num_labels=num_labels,
    cache_dir=model_args.cache_dir,
)

# Disabled: model construction driven entirely by `training_args`, kept for reference.
# model = BiEncoderModel(
#     model_name=model_args.model_name_or_path,
#     normlized=training_args.normlized,
#     sentence_pooling_method=training_args.sentence_pooling_method,
#     negatives_cross_device=training_args.negatives_cross_device,
#     temperature=training_args.temperature,
# )

In [5]:
# Reload the tokenizer directly from the model path with library-default options.
# This rebinds `tokenizer`, so the instance built earlier with `use_fast=False`
# and an explicit `cache_dir` is no longer the one in scope.
# NOTE(review): confirm which of the two tokenizer configurations is intended.
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

In [6]:
# Collate the first three examples into one mini-batch for manual inspection.
head_data = [dataset[i] for i in range(3)]

# The unpacking relies on the collator's result yielding exactly two values,
# query inputs first and passage inputs second.
query, passage = data_collator(head_data).values()

print(passage["input_ids"].shape)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([12, 27])


In [7]:
# Display the first raw (un-tokenized) training example.
dataset[0]

('Generate representations for this sentence to retrieve related articles:Five women walk along a beach wearing flip-flops.',
 ['Some women with flip-flops on, are walking along the beach',
  'A woman is standing outside.',
  'There was a reform in 1996.',
  'The battle was over. '])

In [8]:
# Build the bi-encoder used for the experiments below.
# The pooling method is hard-coded to "mean" here instead of being read from
# `training_args`, so this model always uses mean pooling regardless of the
# configured training arguments.
mean_model = BiEncoderModel(
    model_name=model_args.model_name_or_path,
    normlized=training_args.normlized,
    sentence_pooling_method="mean",
)

In [9]:
# Encode the collated query inputs into embeddings.
query_tensor = mean_model.encode(query)

In [12]:
# Shape of the query embeddings after inserting a singleton axis at dim 1.
query_tensor.unsqueeze(1).shape

torch.Size([3, 1, 768])

In [13]:
# Shape of the raw query embeddings, for comparison with the unsqueezed form.
query_tensor.shape

torch.Size([3, 768])

In [17]:
# Display the loaded model configuration (its `hidden_size` is used when reshaping).
config

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.50.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [None]:
# torch.Size([3, 1, 768])

In [18]:
# Encode all passages in the batch, then regroup the embeddings so that the
# middle axis has `train_group_size` entries of width `hidden_size`.
passage_tensor = mean_model.encode(passage).reshape(
    -1,
    data_args.train_group_size,
    config.hidden_size,
)
passage_tensor.shape

torch.Size([3, 4, 768])

In [33]:
# Applies the same (-1, train_group_size, hidden_size) reshape a second time.
# NOTE(review): this looks redundant when `passage_tensor` has already been
# regrouped by the cell that encodes the passages; confirm and remove if so.
passage_tensor = passage_tensor.reshape(-1, data_args.train_group_size, config.hidden_size)

In [19]:
# Shape check for per-query scoring: each query, as a one-row matrix, is
# multiplied by the transpose of its own group of passage embeddings.
torch.matmul(query_tensor.unsqueeze(1), passage_tensor.transpose(-2, -1)).shape

torch.Size([3, 1, 4])

In [35]:
# Similarity scores from the model's own helper.
# NOTE(review): `query_tensor` is passed here without `unsqueeze(1)`. If
# `compute_similarity` is a plain matmul against the transposed passages, the
# 2-D queries are broadcast over the batch axis, so every query is scored
# against every group rather than only its own. Confirm this is intended.
mean_model.compute_similarity(query_tensor, passage_tensor)

tensor([[[0.7989, 0.5650, 0.4587, 0.6340],
         [0.7240, 0.5850, 0.4034, 0.6214],
         [0.6837, 0.5761, 0.4178, 0.6391]],

        [[0.7432, 0.6825, 0.5964, 0.5350],
         [0.7682, 0.6258, 0.5924, 0.5315],
         [0.6731, 0.5850, 0.5655, 0.5124]],

        [[0.6734, 0.7426, 0.5338, 0.4966],
         [0.6542, 0.7297, 0.4738, 0.5386],
         [0.6489, 0.6687, 0.4550, 0.5075]]], grad_fn=<CloneBackward0>)

In [None]:
# Call the wrapped `model` attribute directly on the passage token ids.
# NOTE(review): only `input_ids` is passed (no attention mask); if the batch is
# padded, confirm how the underlying model treats the padding positions.
mean_model.model(passage["input_ids"])

In [45]:
# Hand-checkable batched matmul: three (1, 3) row vectors times three (3, 2) matrices.
A = torch.tensor(
    [
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]
).unsqueeze(1)  # (3, 1, 3)

B = torch.stack(
    [
        torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]),
        torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]),
    ]
)  # (3, 3, 2)

# The `@` operator is torch.matmul; each batch entry is multiplied independently.
res = A @ B

In [47]:
# Shapes of the two toy operands.
A.shape, B.shape

(torch.Size([3, 1, 3]), torch.Size([3, 3, 2]))

In [48]:
# Display the batched matmul result of the toy operands.
res

tensor([[[  4.,   5.]],

        [[ 49.,  64.]],

        [[220., 244.]]])

In [46]:
# Shape of the batched matmul result.
res.shape

torch.Size([3, 1, 2])

In [53]:
# Shape probe: both operands are 3-D with the same leading batch size, so
# matmul multiplies the (1, 3) and (3, 2) matrices batch entry by batch entry.
torch.matmul(
    torch.randn(3, 1, 3),
    torch.randn(3, 3, 2),
).shape

torch.Size([3, 1, 2])

In [52]:
# Shape probe: a 2-D left operand against a 3-D right operand. The 2-D matrix
# is broadcast across the batch axis of the right operand.
torch.matmul(
    torch.randn(3, 3),
    torch.randn(3, 3, 2),
).shape

torch.Size([3, 3, 2])

In [None]:
# Run the model's forward pass on the collated query and passage inputs.
mean_model(query=query, passage=passage)

验证 hf 在计算 encode 时使用的是 mean 池化还是 cls 池化

In [None]:
# Wire the mean-pooling bi-encoder into a Hugging Face Trainer.
#
# The training set is `dataset`, the object imported from `data` near the top of
# the notebook. The previous `train_dataset` name was never defined in this
# notebook, so the cell raised NameError on a fresh kernel.
trainer = Trainer(
    model=mean_model,
    args=training_args,
    train_dataset=dataset,
    # Build a fresh collator around the tokenizer currently in scope, using the
    # query/passage length limits from the data arguments.
    data_collator=EmbedCollator(
        tokenizer,
        query_max_len=data_args.query_max_len,
        passage_max_len=data_args.passage_max_len,
    ),
    # NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=`; switch once the minimum supported version allows it.
    tokenizer=tokenizer,
)

In [17]:
# trainer.train()

In [18]:
# trainer.save_model()