Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmchat/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .concat_dataset import ConcatDataset
from .huggingface import process_hf_dataset

# NOTE(review): this span comes from a rendered diff ("2 additions & 1
# deletion"): the first assignment appears to be the removed line and the
# second its replacement. If both are kept, the later assignment is the one
# that takes effect.
__all__ = ['process_hf_dataset']
__all__ = ['process_hf_dataset', 'ConcatDataset']
25 changes: 25 additions & 0 deletions mmchat/datasets/concat_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from torch.utils.data import ConcatDataset as _ConcatDataset

from mmchat.registry import DATASETS


class ConcatDataset(_ConcatDataset):
    """Concatenation of several named datasets built from configs.

    Every entry of ``datasets_cfg`` is passed to ``DATASETS.build`` and the
    resulting datasets are handed, in iteration order, to the parent
    ``ConcatDataset``. The entry names are kept in ``self.names`` so that
    ``repr`` can label each sub-dataset.

    Args:
        tokenizer: Fallback tokenizer. It is written into any dataset config
            whose ``'tokenizer'`` entry is missing or ``None``.
        datasets_cfg: Mapping from a dataset name to the config used to
            build that dataset. Note that configs lacking a tokenizer are
            modified in place (the fallback tokenizer is stored on them).
    """

    def __init__(self, tokenizer, datasets_cfg):
        named_datasets = []
        for key, sub_cfg in datasets_cfg.items():
            # Only fill in the shared tokenizer when the config does not
            # already carry its own.
            if sub_cfg.get('tokenizer', None) is None:
                sub_cfg['tokenizer'] = tokenizer
            named_datasets.append((key, DATASETS.build(sub_cfg)))

        # Names are recorded before delegating to the parent constructor,
        # which receives the built datasets in the same order.
        self.names = [key for key, _ in named_datasets]
        super().__init__(datasets=[built for _, built in named_datasets])

    def __repr__(self):
        """Return a header line followed by one ``name: repr,`` line each."""
        header = 'Dataset as a concatenation of multiple datasets. \n'
        entries = []
        for label, member in zip(self.names, self.datasets):
            entries.append(f'{label}: {member!r},')
        return header + '\n'.join(entries)