In [17]:
import math
import paddle
import paddle.nn.functional as F
import os
import random
from paddle.io import Dataset, IterableDataset
import gzip
from functools import reduce
from args import config
import numpy as np
# ----------------------  DataLoader ----------------------- #

def process_data(query, title, content, max_seq_len):
    """Encode one (query, title, content) triple as fixed-length model inputs.

    Sequence layout:
        [CLS] query [SEP] title [SEP] content [SEP] [PAD] ...

    Args:
        query, title, content: bytes objects holding integer token ids
            separated by b'\\x01'. Every id is shifted by +10 before use
            (presumably to keep the low ids free for special tokens -- TODO confirm).
        max_seq_len: length that every returned sequence is padded or cut to.

    Returns:
        (data, segment, padding_mask), each an int32 paddle tensor:
        data holds the token ids, segment is 0 for [CLS] + query + its [SEP]
        and 1 everywhere else (padding included), and padding_mask is
        True (1) on padded positions and False (0) on real tokens.
    """
    token_ids = [config._CLS_]
    segment_ids = [0]

    # Fields in sequence order, paired with the segment id assigned to them.
    for field, seg_id in ((query, 0), (title, 1), (content, 1)):
        pieces = field.split(b'\x01')
        token_ids.extend(int(piece) + 10 for piece in pieces)
        token_ids.append(config._SEP_)
        # One segment id per token plus one for the trailing [SEP].
        segment_ids.extend([seg_id] * (len(pieces) + 1))

    # Pad or truncate the token ids, building the padding mask alongside.
    real_len = len(token_ids)
    mask = [False] * real_len
    if real_len < max_seq_len:
        pad_count = max_seq_len - real_len
        mask = mask + [True] * pad_count
        token_ids = token_ids + [config._PAD_] * pad_count
    else:
        mask = mask[:max_seq_len]
        token_ids = token_ids[:max_seq_len]

    # Segment ids get the same treatment; padded positions use segment 1.
    if len(segment_ids) < max_seq_len:
        segment_ids = segment_ids + [1] * (max_seq_len - len(segment_ids))
    else:
        segment_ids = segment_ids[:max_seq_len]

    padding_mask = paddle.to_tensor(mask, dtype='int32')
    data = paddle.to_tensor(token_ids, dtype="int32")
    segment = paddle.to_tensor(segment_ids, dtype="int32")
    return data, segment, padding_mask


class TrainDataset(IterableDataset):
    """Streaming training set read from gzip-compressed log shards.

    Every ``*.gz`` file in ``directory_path`` except ``part-00000.gz`` (kept
    for evaluation) is read line by line. Lines with exactly 3 tab-separated
    fields start a new query; lines with more than 6 fields are documents
    belonging to the most recent query. Encoded records are collected in a
    buffer that is shuffled and emitted whenever it reaches ``buffer_size``.

    Args:
        directory_path: directory containing the ``.gz`` shards.
        buffer_size: number of records to accumulate before shuffling and
            yielding them.
        max_seq_len: sequence length forwarded to ``process_data``.

    Yields:
        ``[src_input, segment, src_padding_mask, click_label]`` where the
        first three come from ``process_data`` and ``click_label`` is a float.
    """

    def __init__(self, directory_path, buffer_size=100000, max_seq_len=128):
        super().__init__()
        self.directory_path = directory_path
        self.buffer_size = buffer_size
        # Shard order is randomised once, at construction time.
        self.files = os.listdir(self.directory_path)
        random.shuffle(self.files)
        # Placeholder until the first query line is seen.
        self.cur_query = "#"
        self.max_seq_len = max_seq_len

    def __iter__(self):
        buffer = []
        for file in self.files:
            print('load file', file)
            if file[-3:] != '.gz' or file == 'part-00000.gz':  # part-00000.gz is for evaluation
                continue
            with gzip.open(os.path.join(self.directory_path, file), 'rb') as f:
                # Iterate lazily instead of readlines() so a shard is never
                # fully decompressed into memory.
                for line in f:
                    line_list = line.strip(b'\n').split(b'\t')
                    if len(line_list) == 3:  # new query
                        self.cur_query = line_list[1]
                    elif len(line_list) > 6:  # urls
                        title, content, click_label = line_list[2], line_list[3], line_list[5]
                        try:
                            src_input, segment, src_padding_mask = process_data(
                                self.cur_query, title, content, self.max_seq_len)
                            buffer.append([src_input, segment, src_padding_mask, float(click_label)])
                        except Exception:
                            # Best effort: skip records that fail to parse.
                            # Only Exception is caught so KeyboardInterrupt and
                            # SystemExit still propagate.
                            pass
                    if len(buffer) >= self.buffer_size:
                        random.shuffle(buffer)
                        yield from buffer
                        # Start a fresh buffer; otherwise the same records are
                        # re-yielded on every following line.
                        buffer = []

        # Emit whatever is left in a partially filled buffer so the tail of
        # the data is not silently dropped.
        if buffer:
            random.shuffle(buffer)
            yield from buffer

usage: ipykernel_launcher.py [-h] [--train_datadir TRAIN_DATADIR]
                             [--valid_annotate_path VALID_ANNOTATE_PATH]
                             [--valid_click_path VALID_CLICK_PATH]
                             [--num_candidates NUM_CANDIDATES]
                             [--ntokens NTOKENS] [--seed SEED]
                             [--max_seq_len MAX_SEQ_LEN] [--emb_dim EMB_DIM]
                             [--nlayers NLAYERS] [--nhead NHEAD]
                             [--dropout DROPOUT]
                             [--n_queries_for_each_gpu N_QUERIES_FOR_EACH_GPU]
                             [--init_parameters INIT_PARAMETERS]
                             [--eval_batch_size EVAL_BATCH_SIZE]
                             [--n_gpus N_GPUS] [--lr LR]
                             [--max_steps MAX_STEPS]
                             [--warmup_steps WARMUP_STEPS]
                             [--weight_decay WEIGHT_DECAY]
                             [--buffer_s

AssertionError: 

In [4]:
# Column names supplied to pd.read_csv when loading the annotation files.
val_col_names = [
    'Qid(undefined)',
    'Query',
    'Title',
    'Abstract',
    'Label',
    'Bucket',
]


In [5]:
import pandas as pd

# Load the validation split as a tab-separated table, labelling the columns
# with val_col_names.
val_data = pd.read_csv(
    'data/annotate_data/val_data.txt',
    sep='\t',
    names=val_col_names,
)

In [8]:
# Display the loaded validation frame as the cell output.
val_data

Unnamed: 0,Qid(undefined),Query,Title,Abstract,Label,Bucket
0,0,21438218642142628505986951615025266721...,20196400228505986951610956118481502526...,21864,3,8
1,0,21438218642142628505986951615025266721...,16784451220722850169133176150252667214...,21864,0,8
2,0,21438218642142628505986951615025266721...,20196100729763615488152850598695161502...,20196100722186497632186461548815218642...,2,8
3,0,21438218642142628505986951615025266721...,26088510285059869516150252667214362139...,34975647218642186421864126004127218643...,2,8
4,0,21438218642142628505986951615025266721...,15038168901095628501513616890598695161...,150382186416890109562186428502186415136...,0,8
...,...,...,...,...,...,...
495,16,213912143921391218642144621436214362139...,213682186421410218642143921864214466145...,21864218642186421864614526892186440191...,0,9
496,16,213912143921391218642144621436214362139...,213682186421410218642143921864214466145...,21864218642186421864614526892186440191...,0,9
497,16,213912143921391218642144621436214362139...,213682186421410218642143921864214466145...,21864218642186421864614526892186440191...,0,9
498,16,213912143921391218642144621436214362139...,213682186421410218642143921864214466145...,21864218642186421864614526892186440191...,0,9


In [6]:
# Load the annotation file as a tab-separated table, using the same column
# names as the validation split.
annotation_data = pd.read_csv(
    '/media/credog/2202-0B9C/WSDM2023/annotation_data_0522.txt',
    sep='\t',
    names=val_col_names,
)

In [13]:
# Count how many rows carry each relevance label.
annotation_data.loc[:, 'Label'].value_counts()

0    219305
2    112759
1     36622
3     28172
4       714
Name: Label, dtype: int64

In [1]:
import gzip
import csv

# Read every line of the part-00000.gz shard into memory as raw bytes.
with gzip.open('/media/credog/2202-0B9C/WSDM2023/part-00000.gz', 'rb') as f:
    data = list(f)

In [5]:
# Number of tab-separated fields on the second line of the shard.
data[1].count(b'\t') + 1

32

In [14]:
# Read the annotation file and both test_data files line by line.
with open('/media/credog/2202-0B9C/WSDM2023/annotation_data_0522.txt') as f:
    annotation_data = list(f)
with open('/media/credog/2202-0B9C/WSDM2023/test_data.txt') as f:
    test_data = list(f)
with open('/media/credog/2202-0B9C/WSDM2023/test_data (1).txt') as f:
    test_data1 = list(f)

# All three files must have identical contents.
assert annotation_data == test_data and test_data == test_data1