In [1]:
def dpr_setting(args):
    """Fill in default training settings on ``args``.

    Every attribute named in the table below is (re)assigned on ``args``: an
    attribute that already exists keeps its current value, a missing one gets
    the listed default. ``args`` is modified in place; nothing is returned.

    Args:
        args: Any object that supports ``getattr``/``setattr`` for the names
            in the table (for example an ``argparse.Namespace``).
    """
    defaults = (
        ("batch_size", 8),
        ("epoch", 20),
        ("report_freq", 5),
        # Gradient accumulation; used when the batch size is very small.
        ("accumulate_step", 2),
        ("model_type", "princeton-nlp/sup-simcse-roberta-base"),
        ("warmup_steps", 200),
        ("grad_norm", 1),
        ("seed", 42),
        ("max_lr", 2e-5),
        ("max_length", 128),
        # Original author's note: "number of evaluations".
        ("eval_interval", 20),
        ("retrieval_num", 4),
        # Original author's note: "within one batch".
        ("evidence_samples", 64),
    )
    for name, fallback in defaults:
        setattr(args, name, getattr(args, name, fallback))

In [2]:
from torch.utils.data import Dataset
import json
import random

In [3]:
def process(text):
    """Normalise ``text`` by converting it to lower case and return the result."""
    lowered = text.lower()
    return lowered


def to_cuda(batch):
    """Replace selected entries of ``batch`` with the result of ``.cuda()``.

    Only the four query/evidence ``input_ids`` and ``attention_mask`` entries
    are touched; every other key is left as it is. ``batch`` is updated in
    place and nothing is returned.

    Args:
        batch: Mapping produced by one of the ``collate_fn`` methods; the
            selected values must expose a ``.cuda()`` method.
    """
    device_keys = (
        "query_input_ids",
        "evidence_input_ids",
        "query_attention_mask",
        "evidence_attention_mask",
    )
    for key in batch.keys():
        if key not in device_keys:
            continue
        # Re-assigning an existing key does not change the dict's size, so it
        # is safe to do while iterating over the keys.
        batch[key] = batch[key].cuda()

In [4]:
class TrainDataset(Dataset):
    """Training claims paired with a fixed-size pool of evidence per batch.

    Each item is a lower-cased claim text together with the ids of its
    evidences (and, when ``using_negative`` is set, the ids of its
    pre-generated negative evidences). ``collate_fn`` merges a list of such
    items into one tokenised batch that holds ``evidence_samples`` evidence
    texts, padding with randomly drawn evidence ids when the batch does not
    provide enough.
    """

    def __init__(self, mode, tok, evidence_samples, max_length=512, using_negative=True):
        """Load the claims and the evidence texts from disk.

        Args:
            mode: Split name. Selects ``data/{mode}-claims.json`` when
                ``using_negative`` is False; it is stored on the instance
                either way.
            tok: Tokenizer callable used by ``collate_fn``.
            evidence_samples: Number of evidence texts per collated batch.
            max_length: ``max_length`` forwarded to the tokenizer.
            using_negative: When True, claims are read from
                ``data/train-claims-with-negatives.json`` (claims with
                generated negative samples) regardless of ``mode``.
        """
        self.max_length = max_length
        if using_negative:
            # Claims file that already contains generated negative samples.
            claims_path = "data/train-claims-with-negatives.json"
        else:
            claims_path = "data/{}-claims.json".format(mode)
        # ``with`` guarantees the handle is closed even if json.load raises.
        with open(claims_path, "r") as f:
            self.dataset = json.load(f)
        self.using_negative = using_negative
        # Alternative source, disabled in the original: "data/evidence.json".
        with open("temp_data/reduced-evidences.json", "r") as f:
            self.evidences = json.load(f)

        self.tokenizer = tok
        self.claim_ids = list(self.dataset.keys())
        self.mode = mode
        self.evidence_samples = evidence_samples
        self.evidence_ids = list(self.evidences.keys())

    def __len__(self):
        """Return the number of claims."""
        return len(self.claim_ids)

    def __getitem__(self, idx):
        """Return ``[query, evidence_ids]`` or ``[query, evidence_ids, negative_ids]``.

        The query is the lower-cased ``claim_text``; the third element is only
        present when ``using_negative`` is True.
        """
        data = self.dataset[self.claim_ids[idx]]
        processed_query = process(data["claim_text"])
        # Copy so callers never hold a reference into the loaded dataset.
        evidences = list(data["evidences"])
        if self.using_negative:
            negative_evidences = data["negative_evidences"]
            return [processed_query, evidences, negative_evidences]
        return [processed_query, evidences]

    def collate_fn(self, batch):
        """Merge dataset items into one tokenised batch.

        The evidence list is laid out as: the evidences of every query in
        batch order, then (with ``using_negative``) all negative evidences,
        then randomly drawn ids that are not yet in the list, until
        ``evidence_samples`` entries are reached. If the batch already
        supplies more than ``evidence_samples`` ids the list is cut off at
        that length.

        Args:
            batch: List of items as returned by ``__getitem__``.

        Returns:
            dict with ``query_input_ids``, ``query_attention_mask``,
            ``evidence_input_ids``, ``evidence_attention_mask`` (tokenizer
            outputs) and ``labels``, a list holding the number of evidences
            of each query.

        Raises:
            ValueError: If the evidence pool does not contain enough unused
                ids to pad the batch up to ``evidence_samples``.
        """
        queries = []
        evidences = []
        labels = []
        if self.using_negative:
            negative_evidences = []
            for query, evidence, negative_evidence in batch:
                queries.append(query)
                evidences.extend(evidence)
                negative_evidences.extend(negative_evidence)
                labels.append(len(evidence))
            # Negatives go after every positive so positives stay at the front.
            evidences.extend(negative_evidences)
        else:
            for query, evidence in batch:
                queries.append(query)
                evidences.extend(evidence)
                labels.append(len(evidence))

        if len(evidences) > self.evidence_samples:
            # NOTE(review): ``labels`` is not adjusted here, so if the cut ever
            # removes positives the counts no longer match the evidence list.
            evidences = evidences[:self.evidence_samples]
        evidences_text = [process(self.evidences[evidence_id]) for evidence_id in evidences]

        # Keep drawing random negatives until the batch is full.
        missing = self.evidence_samples - len(evidences)
        if missing > 0:
            seen = set(evidences)
            # Every id in ``seen`` was just looked up in ``self.evidences``, so
            # the number of still-unused pool ids is an exact difference.
            available = len(self.evidence_ids) - len(seen)
            if available < missing:
                # Without this guard the rejection loop below never terminates.
                raise ValueError(
                    "evidence pool too small: need {} more unused evidence ids "
                    "but only {} are available".format(missing, available)
                )
            for _ in range(missing):
                evidence_id = random.choice(self.evidence_ids)
                while evidence_id in seen:
                    evidence_id = random.choice(self.evidence_ids)
                seen.add(evidence_id)
                evidences.append(evidence_id)
                evidences_text.append(process(self.evidences[evidence_id]))

        query_text = self.tokenizer(
            queries,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )

        evidence_text = self.tokenizer(
            evidences_text,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )

        batch_encoding = dict()
        batch_encoding["query_input_ids"] = query_text.input_ids
        batch_encoding["evidence_input_ids"] = evidence_text.input_ids
        batch_encoding["query_attention_mask"] = query_text.attention_mask
        batch_encoding["evidence_attention_mask"] = evidence_text.attention_mask
        batch_encoding["labels"] = labels
        return batch_encoding

In [5]:
class EvidenceDataset(Dataset):
    """Dataset over every evidence text, used to encode the evidence corpus.

    Items are ``[evidence_id, evidence_text]`` pairs; ``collate_fn`` lower-cases
    and tokenises the texts and keeps the ids alongside the encodings.
    """

    def __init__(self, tok, max_length=512):
        """Load the evidence texts from disk.

        Args:
            tok: Tokenizer callable used by ``collate_fn``.
            max_length: ``max_length`` forwarded to the tokenizer.
        """
        self.max_length = max_length

        # Alternative source, disabled in the original: "data/evidence.json".
        # ``with`` guarantees the handle is closed even if json.load raises.
        with open("temp_data/reduced-evidences.json", "r") as f:
            self.evidences = json.load(f)

        self.tokenizer = tok
        self.evidences_ids = list(self.evidences.keys())

    def __len__(self):
        """Return the number of evidence entries."""
        return len(self.evidences_ids)

    def __getitem__(self, idx):
        """Return ``[evidence_id, evidence_text]`` for position ``idx``."""
        evidences_id = self.evidences_ids[idx]
        return [evidences_id, self.evidences[evidences_id]]

    def collate_fn(self, batch):
        """Tokenise a list of ``[evidence_id, evidence_text]`` items.

        Returns:
            dict with ``evidence_input_ids`` and ``evidence_attention_mask``
            (tokenizer outputs) and ``evidences_ids``, the ids in batch order.
        """
        evidences_ids = []
        evidences = []
        for evidences_id, evidence in batch:
            evidences_ids.append(evidences_id)
            evidences.append(process(evidence))

        evidences_text = self.tokenizer(
            evidences,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )

        batch_encoding = dict()
        batch_encoding["evidence_input_ids"] = evidences_text.input_ids
        batch_encoding["evidence_attention_mask"] = evidences_text.attention_mask
        batch_encoding["evidences_ids"] = evidences_ids
        return batch_encoding

In [6]:
class ValDataset(Dataset):
    """Claims used for evaluation or prediction.

    Items are ``[query, raw_record, claim_id]``. For every mode other than
    ``"test"`` the collated batch also carries each claim's ``evidences``
    entry.
    """

    def __init__(self, mode, tok, max_length=512):
        """Load the claims of the requested split from disk.

        Args:
            mode: ``"test"`` reads ``data/test-claims-unlabelled.json``; any
                other value reads ``data/{mode}-claims.json``.
            tok: Tokenizer callable used by ``collate_fn``.
            max_length: ``max_length`` forwarded to the tokenizer.
        """
        self.max_length = max_length
        if mode != "test":
            claims_path = "data/{}-claims.json".format(mode)
        else:
            claims_path = "data/test-claims-unlabelled.json"
        # ``with`` guarantees the handle is closed even if json.load raises.
        with open(claims_path, "r") as f:
            self.dataset = json.load(f)

        self.tokenizer = tok
        self.claim_ids = list(self.dataset.keys())
        self.mode = mode

    def __len__(self):
        """Return the number of claims."""
        return len(self.claim_ids)

    def __getitem__(self, idx):
        """Return ``[lower-cased claim text, raw claim record, claim id]``."""
        claim_id = self.claim_ids[idx]
        data = self.dataset[claim_id]
        processed_text = process(data["claim_text"])
        return [processed_text, data, claim_id]

    def collate_fn(self, batch):
        """Tokenise the queries of ``batch`` and carry the metadata along.

        Returns:
            dict with ``query_input_ids`` and ``query_attention_mask``
            (tokenizer outputs), ``datas`` (raw claim records) and
            ``claim_ids``; when ``mode`` is not ``"test"`` it also contains
            ``evidences``, the ``"evidences"`` entry of each record.
        """
        has_labels = self.mode != "test"
        queries = []
        datas = []
        evidences = []
        claim_ids = []
        for query, data, claim_id in batch:
            queries.append(query)
            datas.append(data)
            if has_labels:
                evidences.append(data["evidences"])
            claim_ids.append(claim_id)

        query_text = self.tokenizer(
            queries,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )

        batch_encoding = dict()
        batch_encoding["query_input_ids"] = query_text.input_ids
        batch_encoding["query_attention_mask"] = query_text.attention_mask

        batch_encoding["datas"] = datas
        batch_encoding["claim_ids"] = claim_ids
        if has_labels:
            batch_encoding["evidences"] = evidences
        return batch_encoding