-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
21 changed files
with
141 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import bisect | ||
import warnings | ||
|
||
from torch.utils.data import Dataset, IterableDataset | ||
# from torch.utils.data.dataset import ConcatDataset | ||
|
||
|
||
class ConcatenateDataset(Dataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets. Indexing
    returns a pair: the sample fetched from the underlying dataset and that
    dataset's ``domain_index`` attribute.

    Arguments:
        datasets (iterable): Datasets to be concatenated. Each one must
            support ``len()`` and integer indexing, must not be an
            ``IterableDataset``, and must expose a ``domain_index``
            attribute (it is read on every ``__getitem__`` call).
    """

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` over the items of *sequence*."""
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets) -> None:
        super(ConcatenateDataset, self).__init__()
        # Materialize before validating: ``datasets`` may be a generator or
        # another unsized iterable, on which ``len()`` would raise TypeError.
        self.datasets = list(datasets)
        # Asserts (rather than raises) are kept so the exception type seen by
        # existing callers does not change.
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatenateDataset does not support IterableDataset"
        # cumulative_sizes[i] is the total number of samples in datasets[0..i].
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return ``(sample, domain_index)`` for the global index *idx*.

        Negative indices count from the end.

        Raises:
            ValueError: if *idx* is negative and its absolute value exceeds
                the total length.
            IndexError: if *idx* is past the end (raised by the lookup in
                ``self.datasets``).
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # First position whose cumulative size is strictly greater than idx,
        # i.e. the dataset that contains the requested sample.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Subtract the samples held by all preceding datasets.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        dataset = self.datasets[dataset_idx]
        return dataset[sample_idx], dataset.domain_index

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of ``cumulative_sizes``."""
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Launch src/degaa.py with only CUDA device 1 visible to the process.
# Arguments as passed: output directory ./adapt/run1, the --tensorboard
# switch, source "Ar,Pr", target "Cl,Rw", and a batch size of 32.
# NOTE(review): Ar/Pr/Cl/Rw look like dataset-domain abbreviations — confirm
# against the argument parser in src/degaa.py.
CUDA_VISIBLE_DEVICES=1 python src/degaa.py --output_dir ./adapt/run1 --tensorboard --source Ar,Pr --target Cl,Rw --batch_size 32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# CUDA_VISIBLE_DEVICES=0 python src/embeddings_old.py --output_dir ./protoruns/run4 --batch_size 48 --num_proto_steps 100000 --tensorboard --resume 60000 | ||
|
||
# Launch src/embeddings_old.py with only CUDA device 1 visible to the process.
# Arguments as passed: output directory ./protoruns/run5, batch size 32,
# hparams given as the JSON string {"mixup":1}, --num_proto_steps 1000000,
# the --tensorboard switch, --resume 400100 and --checkpoint_freq 10000.
# NOTE(review): presumably --resume is the step to restart from and
# --checkpoint_freq is counted in steps — confirm in src/embeddings_old.py.
CUDA_VISIBLE_DEVICES=1 python src/embeddings_old.py --output_dir ./protoruns/run5 --batch_size 32 --hparams '{"mixup":1}' --num_proto_steps 1000000 --tensorboard --resume 400100 --checkpoint_freq 10000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.