diff --git a/torchreid/data/datamanager.py b/torchreid/data/datamanager.py index 7ae28cbaf..872009d8e 100644 --- a/torchreid/data/datamanager.py +++ b/torchreid/data/datamanager.py @@ -175,7 +175,8 @@ def __init__( train_sampler_t='RandomSampler', cuhk03_labeled=False, cuhk03_classic_split=False, - market1501_500k=False + market1501_500k=False, + **kwargs ): super(ImageDataManager, self).__init__( @@ -202,7 +203,8 @@ def __init__( split_id=split_id, cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split, - market1501_500k=market1501_500k + market1501_500k=market1501_500k, + **kwargs ) trainset.append(trainset_) trainset = sum(trainset) @@ -246,7 +248,8 @@ def __init__( split_id=split_id, cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split, - market1501_500k=market1501_500k + market1501_500k=market1501_500k, + **kwargs ) trainset_t.append(trainset_t_) trainset_t = sum(trainset_t) @@ -295,7 +298,8 @@ def __init__( split_id=split_id, cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split, - market1501_500k=market1501_500k + market1501_500k=market1501_500k, + **kwargs ) self.test_loader[name]['query'] = torch.utils.data.DataLoader( queryset, @@ -317,7 +321,8 @@ def __init__( split_id=split_id, cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split, - market1501_500k=market1501_500k + market1501_500k=market1501_500k, + **kwargs ) self.test_loader[name]['gallery'] = torch.utils.data.DataLoader( galleryset,