In [None]:
from modelzipper.tutils import *

# Hard-coded absolute path to the JSON data file (a test split "with content",
# judging by the variable and file names); it only resolves on a machine that
# has this volume mounted.
fanout_final_test_w_content = "/vepfs/wcf/G/zecheng/data/fanout-final-test-content.json"
# `auto_read_data` is not defined in this file; presumably it comes from the
# star import of `modelzipper.tutils` — confirm. The structure of what it
# returns is not visible here; later code indexes it by integer position.
content = auto_read_data(fanout_final_test_w_content)

In [None]:
# Evaluate the first loaded record so its structure can be inspected
# (as the last expression of a notebook cell it is presumably displayed).
content[0]

In [None]:
from torch.utils.data import Dataset

class Fanout(Dataset):
    """Dataset that renders question/evidence records into tokenized prompts.

    Each record supplies a question plus a collection of evidence items, where
    every item has a ``'title'`` and a ``'content'`` entry. The evidence texts
    are truncated to an equal share of ``max_context_length`` tokens, wrapped
    in ``<document>`` markup, concatenated, and inserted into ``TEMPLATE``
    together with the question. The resulting prompt is tokenized and returned.

    NOTE(review): the record schema is not visible in this file. The original
    code iterated over the record itself to obtain the evidence items *and*
    indexed the same record with ``'question'``; for a plain ``dict`` record
    that cannot work (iterating a dict yields its keys). Pass ``evidence_key``
    to read the evidence list from ``record[evidence_key]``; leaving it as
    ``None`` keeps the original behaviour of iterating the record directly.
    """

    # Prompt skeleton. The runs of spaces before "Answer" and "If" come from
    # the backslash-continued literal this was originally written as; they are
    # kept so the prompt text stays byte-identical.
    # NOTE(review): those spaces look accidental — consider removing them.
    TEMPLATE = (
        "*** BEGIN DATA ***\n\n{context}\n*** END DATA ***\n\n "
        "            Answer the following question based on the documents above, and output only your answer. "
        "            If the answer is a list, output one on each line. \n\n[Question]: {question}"
    )
    # Markup for a single evidence document; uses *named* fields.
    QUESTION_TEMPLATE = "<document>\n<title>{title}</title>\n<content>{evidence}</content>\n</document>\n"

    def __init__(self, file_path, tokenizer, max_context_length, evidence_key=None) -> None:
        """Load the records and remember the tokenizer and length budget.

        Args:
            file_path: Path handed to ``auto_read_data`` to load the records.
            tokenizer: Callable tokenizer exposing ``decode``; calling it must
                return an object with an ``input_ids`` attribute.
            max_context_length: Total token budget shared by all evidence
                texts of one record.
            evidence_key: Optional key under which a record stores its list of
                evidence items. ``None`` means the record itself is iterated.
        """
        super().__init__()
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.content = auto_read_data(file_path)
        self.max_context_length = max_context_length
        self.evidence_key = evidence_key

    def __len__(self) -> int:
        """Return the number of loaded records."""
        # ``len()`` calls ``__len__`` without arguments, so no extra
        # parameter may be required here.
        return len(self.content)

    def filter_tok_context_length(self, s, L) -> str:
        """Return ``s`` cut down to at most ``L`` tokens.

        The text is tokenized without special tokens, the first ``L`` token
        ids are kept, and the result is decoded back to a string.
        """
        encoded = self.tokenizer(s, return_tensors='pt', add_special_tokens=False)
        # ``input_ids`` has a leading batch axis (one entry per input text):
        # select the single sequence first, then truncate along the token
        # axis. Slicing ``input_ids[:L]`` directly would only slice the batch
        # axis and leave the sequence untouched.
        kept_ids = encoded.input_ids[0][:L]
        return self.tokenizer.decode(kept_ids, skip_special_tokens=True)

    def __getitem__(self, index) -> Any:
        """Build and tokenize the prompt for the record at ``index``.

        Returns:
            Whatever the tokenizer returns for the full prompt when called
            with ``return_tensors='pt'``.
        """
        sample = self.content[index]
        if self.evidence_key is None:
            evidence_items = sample
        else:
            evidence_items = sample[self.evidence_key]

        titles = [item['title'] for item in evidence_items]
        evidences = [item['content'] for item in evidence_items]

        # Split the token budget evenly across the evidence texts. With no
        # evidence there is nothing to truncate (and nothing to divide by).
        if evidences:
            per_evidence_max_length = self.max_context_length // len(evidences)
            evidences = [
                self.filter_tok_context_length(text, per_evidence_max_length)
                for text in evidences
            ]

        # Each rendered document already ends with a newline, so the pieces
        # are joined without a separator into one context string.
        context = "".join(
            self.QUESTION_TEMPLATE.format(title=title, evidence=evidence)
            for title, evidence in zip(titles, evidences)
        )
        query = self.TEMPLATE.format(context=context, question=sample['question'])

        # Tokenize the complete prompt.
        return self.tokenizer(query, return_tensors='pt')
