-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
187 additions
and
198 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import bisect | ||
import warnings | ||
|
||
from torch.utils.data import Dataset, IterableDataset | ||
# from torch.utils.data.dataset import ConcatDataset | ||
|
||
|
||
class ConcatenateDataset(Dataset): | ||
r"""Dataset as a concatenation of multiple datasets. | ||
This class is useful to assemble different existing datasets. | ||
Arguments: | ||
datasets (sequence): List of datasets to be concatenated | ||
""" | ||
|
||
@staticmethod | ||
def cumsum(sequence): | ||
r, s = [], 0 | ||
for e in sequence: | ||
l = len(e) | ||
r.append(l + s) | ||
s += l | ||
return r | ||
|
||
def __init__(self, datasets) -> None: | ||
super(ConcatenateDataset, self).__init__() | ||
# Cannot verify that datasets is Sized | ||
assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore | ||
self.datasets = list(datasets) | ||
for d in self.datasets: | ||
assert not isinstance(d, IterableDataset), "ConcatenateDataset does not support IterableDataset" | ||
self.cumulative_sizes = self.cumsum(self.datasets) | ||
|
||
def __len__(self): | ||
return self.cumulative_sizes[-1] | ||
|
||
def __getitem__(self, idx): | ||
if idx < 0: | ||
if -idx > len(self): | ||
raise ValueError("absolute value of index should not exceed dataset length") | ||
idx = len(self) + idx | ||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) | ||
if dataset_idx == 0: | ||
sample_idx = idx | ||
else: | ||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] | ||
return self.datasets[dataset_idx][sample_idx], self.datasets[dataset_idx].domain_index | ||
|
||
@property | ||
def cummulative_sizes(self): | ||
warnings.warn("cummulative_sizes attribute is renamed to " | ||
"cumulative_sizes", DeprecationWarning, stacklevel=2) | ||
return self.cumulative_sizes |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
CUDA_VISIBLE_DEVICES=1 python src/degaa.py --output_dir ./adapt/run1 --tensorboard --source Ar,Pr --target Cl,Rw --batch_size 32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# CUDA_VISIBLE_DEVICES=0 python src/embeddings_old.py --output_dir ./protoruns/run4 --batch_size 48 --num_proto_steps 100000 --tensorboard --resume 60000 | ||
|
||
CUDA_VISIBLE_DEVICES=1 python src/embeddings_old.py --output_dir ./protoruns/run5 --batch_size 32 --hparams '{"mixup":1}' --num_proto_steps 1000000 --tensorboard --resume 400100 --checkpoint_freq 10000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# for prototypes | ||
CUDA_VISIBLE_DEVICE=0 python src/embeddings.py --wandb --output_dir ./run2 | ||
# for warmup (will save 3 weights netF, netB, netC) | ||
python image_source_final.py --gpu_id 2 --dataset OfficeHome --output weights --batch_size 128 --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0 | ||
|
||
# for adaptation | ||
# for source only training (will save 3 weights netF, netB, netC) | ||
CUDA_VISIBLE_DEVICE=0 python image_source_final.py --dataset OfficeHome --output weights --batch_size 128 --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0 | ||
# for computing centroids and saving TSNE plots | ||
CUDA_VISIBLE_DEVICE=0 python warmup.py --dataset OfficeHome --source Ar,Pr --target Cl,Rw --trained_wt weights/uda/OfficeHome | ||
# for adaptation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.