自己实现 EmbedCollator 

In [1]:
from dataclasses import dataclass
import torch
from transformers import DataCollatorWithPadding, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from hf import model_args, data_args, training_args



Downloading Model from https://www.modelscope.cn to directory: /home/jie/.cache/modelscope/hub/models/AI-ModelScope/bert-base-uncased


In [3]:
import sys
sys.path.append("../")

from src.data import TrainDatasetForEmbedding

In [4]:
# Load the tokenizer for the checkpoint named in model_args.model_name_or_path.
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

## 加载dataset

In [5]:
# Build the training dataset from the data arguments and the tokenizer.
dataset = TrainDatasetForEmbedding(
    args=data_args,
    tokenizer=tokenizer
)
# Show the first sample; the output below is a (query, list-of-passages) tuple of raw strings.
dataset[0]

('Generate representations for this sentence to retrieve related articles:Five women walk along a beach wearing flip-flops.',
 ['Some women with flip-flops on, are walking along the beach',
  'The man is talking about hawaii.',
  'The battle was over. ',
  'A group of people plays volleyball.'])

In [6]:
from typing import List

In [61]:
# Scratch check: size(-1) gives the length of the last dimension (3 for a 2x3 tensor).
torch.randn(2, 3).size(-1)

3

In [86]:
@dataclass
class EmbedCollator(DataCollatorWithPadding):
    """
    Collate raw ``(query, passages)`` samples into padded tensor batches.

    Each element of ``features`` is a pair ``(query, passages)`` where
    ``query`` is a string and ``passages`` is a list of strings. Queries are
    tokenized together; passages are tokenized one sample at a time and then
    padded to a common length so they can be stacked into a single tensor per
    tokenizer output key.

    All tokenization and padding goes through ``self.tokenizer`` so the
    collator never depends on a module-level ``tokenizer`` variable.
    """

    # Queries are truncated to at most this many tokens.
    query_max_len: int = 32
    # Passages are truncated to at most this many tokens.
    passage_max_len: int = 128

    def __call__(self, features):
        """
        Tokenize and pad a batch of ``(query, passages)`` samples.

        Args:
            features: sequence of ``(query, passages)`` pairs; index 0 is the
                query string and index 1 is the list of passage strings.

        Returns:
            A dict with two entries:
              * ``"query"``: the tokenizer output for all queries, padded to
                the longest query in the batch.
              * ``"passage"``: a plain dict mapping each tokenizer output key
                to a tensor stacked over samples, i.e. of shape
                ``(num_samples, passages_per_sample, padded_length)``.

        Note:
            ``torch.stack`` requires equally shaped tensors, so every sample
            must carry the same number of passages.
        """
        query: List[str] = [f[0] for f in features]  # batch_size
        passages: List[List[str]] = [f[1] for f in features]  # batch_size, group_size

        q_collated = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors="pt",
        )

        # First pass: tokenize each sample's passage group on its own and
        # track the longest padded length seen across the whole batch.
        batch_max_passage_length = 0
        passage_collated_data = []
        for passage in passages:
            tmp_collated = self.tokenizer(
                passage,
                padding=True,
                truncation=True,
                max_length=self.passage_max_len,
                return_tensors="pt",
            )
            batch_max_passage_length = max(
                batch_max_passage_length, tmp_collated.input_ids.size(-1)
            )
            passage_collated_data.append(tmp_collated)

        # Second pass: re-pad every group to the batch-wide maximum length so
        # the per-sample tensors share a shape and can be stacked.
        passage_collated = {}
        for item in passage_collated_data:
            # Use the collator's own tokenizer (not a notebook global) so the
            # padding matches the tokenizer used for encoding.
            padded_sentences = self.tokenizer.pad(
                item,
                padding="max_length",  # pad up to the explicit max_length below
                return_tensors="pt",  # return PyTorch tensors
                max_length=batch_max_passage_length,
            )

            for k, v in padded_sentences.items():
                passage_collated.setdefault(k, []).append(v)

        for k, v in passage_collated.items():
            passage_collated[k] = torch.stack(v)

        return {"query": q_collated, "passage": passage_collated}

In [87]:
# Instantiate the collator; query_max_len / passage_max_len keep their defaults (32 / 128).
data_collator = EmbedCollator(tokenizer=tokenizer)

In [49]:
# Take the first three samples as a small batch for exercising the collator.
head_data = [dataset[idx] for idx in range(3)]

In [91]:
# Positional unpack relies on dict insertion order: the collator returns "query" first, then "passage".
query, passage = data_collator(head_data).values()

In [95]:
# Shape of the stacked passage ids: (samples, passages per sample, padded length).
passage["input_ids"].shape

torch.Size([3, 4, 16])