Merge pull request #17 from ant-research/dev
Add functionalities to load data from huggingface
iLampard committed Feb 12, 2024
2 parents 23f615e + 1b04cf5 commit de5b060
Showing 3 changed files with 64 additions and 13 deletions.
49 changes: 40 additions & 9 deletions easy_tpp/preprocess/data_loader.py
@@ -32,6 +32,36 @@ def build_input_from_pkl(self, source_dir: str, split: str):
         input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
         return input_dict
 
+    def build_input_from_json(self, source_dir: str, split: str):
+        from datasets import load_dataset
+        split_ = 'validation' if split == 'dev' else split
+        # load locally
+        if source_dir.split('.')[-1] == 'json':
+            data = load_dataset('json', data_files={split_: source_dir})
+        elif source_dir.startswith('easytpp'):
+            data = load_dataset(source_dir, split=split_)
+        else:
+            raise NotImplementedError
+
+        py_assert(data[split_].data['dim_process'][0].as_py() == self.num_event_types,
+                  ValueError,
+                  "inconsistent dim_process in different splits?")
+
+        source_data = data[split_]['event_seqs'][0]
+        time_seqs, type_seqs, time_delta_seqs = [], [], []
+        for k, v in source_data.items():
+            cur_time_seq, cur_type_seq, cur_time_delta_seq = [], [], []
+            for k_, v_ in v.items():
+                cur_time_seq.append(v_['time_since_start'])
+                cur_type_seq.append(v_['type_event'])
+                cur_time_delta_seq.append(v_['time_since_last_event'])
+            time_seqs.append(cur_time_seq)
+            type_seqs.append(cur_type_seq)
+            time_delta_seqs.append(cur_time_delta_seq)
+
+        input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
+        return input_dict
+
     def get_loader(self, split='train', **kwargs):
         """Get the corresponding data loader.
@@ -50,16 +80,17 @@ def get_loader(self, split='train', **kwargs):
 
         if data_source_type == 'pkl':
            data = self.build_input_from_pkl(data_dir, split)
-            dataset = TPPDataset(data)
-            tokenizer = EventTokenizer(self.data_config.data_specs)
-            loader = get_data_loader(dataset,
-                                     self.backend,
-                                     tokenizer,
-                                     batch_size=self.kwargs['batch_size'],
-                                     shuffle=self.kwargs['shuffle'],
-                                     **kwargs)
         else:
-            raise NotImplementedError
+            data = self.build_input_from_json(data_dir, split)
+
+        dataset = TPPDataset(data)
+        tokenizer = EventTokenizer(self.data_config.data_specs)
+        loader = get_data_loader(dataset,
+                                 self.backend,
+                                 tokenizer,
+                                 batch_size=self.kwargs['batch_size'],
+                                 shuffle=self.kwargs['shuffle'],
+                                 **kwargs)
 
         return loader

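Both branches now feed the same TPPDataset / EventTokenizer path, since build_input_from_pkl and build_input_from_json return the same dict of nested lists. As a reading aid, here is a small sketch of the structure the new parsing loop walks over and what it produces; the sequence and event keys ('seq_0', 'event_0', ...) and all values are made-up placeholders, only the three field names come from the diff.

    # Sketch only: what `source_data = data[split_]['event_seqs'][0]` looks like to the loop above.
    source_data = {
        'seq_0': {
            'event_0': {'time_since_start': 0.0, 'time_since_last_event': 0.0, 'type_event': 3},
            'event_1': {'time_since_start': 1.7, 'time_since_last_event': 1.7, 'type_event': 5},
        },
        'seq_1': {
            'event_0': {'time_since_start': 0.0, 'time_since_last_event': 0.0, 'type_event': 1},
        },
    }

    # The double loop then yields one nested list per field:
    # {'time_seqs':       [[0.0, 1.7], [0.0]],
    #  'time_delta_seqs': [[0.0, 1.7], [0.0]],
    #  'type_seqs':       [[3, 5],     [1]]}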
8 changes: 4 additions & 4 deletions examples/configs/experiment_config.yaml
@@ -2,10 +2,10 @@ pipeline_config_id: runner_config
 
 data:
   taxi:
-    data_format: pkl
-    train_dir: ./data/taxi/train.pkl
-    valid_dir: ./data/taxi/dev.pkl
-    test_dir: ./data/taxi/test.pkl
+    data_format: json
+    train_dir: ./data/taxi/train.json
+    valid_dir: ./data/taxi/dev.json
+    test_dir: ./data/taxi/test.json
     data_specs:
       num_event_types: 10
       pad_token_id: 10
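The example config above still points at local JSON files. Because build_input_from_json also treats a source path starting with 'easytpp' as a Hugging Face dataset id, the same section could in principle reference a hub dataset directly. A hedged sketch, keeping data_format: json so get_loader falls through to build_input_from_json; the id easytpp/taxi is an assumption, not confirmed by this diff:

    data:
      taxi:
        data_format: json          # still json, so the non-pkl branch is taken
        train_dir: easytpp/taxi    # assumed hub dataset id; hits the startswith('easytpp') branch
        valid_dir: easytpp/taxi
        test_dir: easytpp/taxi
        data_specs:
          num_event_types: 10
          pad_token_id: 10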
20 changes: 20 additions & 0 deletions examples/hf_data_loader.py
@@ -0,0 +1,20 @@
+from datasets import load_dataset
+
+def load_data_from_hf(hf_dir=None, local_dir=None):
+    if hf_dir:
+        ds = load_dataset(hf_dir)
+    else:
+        ds = load_dataset('json', data_files=local_dir)
+    print(ds)
+    print('dim process: ' + str(ds['validation'].data['dim_process'][0].as_py()))
+    print('num seqs: ' + str(ds['validation'].data['num_seqs'][0].as_py()))
+    print('avg seq len: ' + str(ds['validation'].data['avg_seq_len'][0].as_py()))
+    print('min seq len: ' + str(ds['validation'].data['min_seq_len'][0].as_py()))
+    print('max seq len: ' + str(ds['validation'].data['max_seq_len'][0].as_py()))
+    return
+
+
+if __name__ == '__main__':
+    # in case one fails to load from hf directly
+    # one can load the json data file locally
+    load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'})

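A quick usage sketch for the helper above. The hub id easytpp/taxi is an assumption (any dataset in the EasyTPP JSON format with a validation split would do), and the local variant mirrors the __main__ block:

    # Assumed hub dataset id; the prints index ds['validation'], so a validation split must exist.
    load_data_from_hf(hf_dir='easytpp/taxi')

    # Local fallback: map split names to JSON files on disk.
    load_data_from_hf(local_dir={'train': 'train.json', 'validation': 'dev.json', 'test': 'test.json'})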