Merge branch 'embeddings'
take2rohit committed Jan 15, 2022
2 parents ffb5c09 + 9bb22c4 commit c3c1a60
Showing 24 changed files with 187 additions and 198 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -9,10 +9,14 @@ wandb/
protoruns/
**/__pycache__/
**/**/__pycache__/
**/**/**/__pycache__/
logs/
*.npy
centroids
*.pyc
model/
OfficeHomeDataset_10072016.zip
dalib/domainbed/__pycache__/*
degaa
**/tensorboard/
san/
54 changes: 54 additions & 0 deletions common/vision/datasets/Concatenate/__init__.py
@@ -0,0 +1,54 @@
import bisect
import warnings

from torch.utils.data import Dataset, IterableDataset
# from torch.utils.data.dataset import ConcatDataset


class ConcatenateDataset(Dataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running cumulative lengths, e.g. datasets of sizes [5, 3] -> [5, 8].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets) -> None:
        super(ConcatenateDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatenateDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Locate the constituent dataset via binary search over the cumulative
        # sizes, then map the global index to a local one within that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Unlike torch's ConcatDataset, also return the dataset's domain index.
        return self.datasets[dataset_idx][sample_idx], self.datasets[dataset_idx].domain_index

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
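
A minimal usage sketch of the class above (not part of the commit): the two TensorDatasets and their domain_index values are hypothetical stand-ins, since ConcatenateDataset assumes each constituent dataset exposes a domain_index attribute.

import torch
from torch.utils.data import TensorDataset

# Hypothetical constituent datasets; ConcatenateDataset reads back the
# domain_index attribute attached to each one.
art = TensorDataset(torch.randn(5, 3))
art.domain_index = 0
product = TensorDataset(torch.randn(3, 3))
product.domain_index = 1

combined = ConcatenateDataset([art, product])
print(len(combined))          # 8, since cumsum gives [5, 8]
sample, domain = combined[6]  # bisect maps global index 6 -> dataset 1, local index 1
print(domain)                 # 1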
Binary file removed dalib/__pycache__/__init__.cpython-37.pyc
Binary file removed dalib/domainbed/__pycache__/__init__.cpython-37.pyc
Binary file removed dalib/domainbed/__pycache__/datasets.cpython-37.pyc
Binary file removed dalib/domainbed/__pycache__/networks.cpython-37.pyc
8 changes: 2 additions & 6 deletions dalib/domainbed/algorithms_proto.py
@@ -44,6 +44,7 @@ class Proto(nn.Module):
"""

def __init__(self, input_shape, num_classes, num_domains, hparams, use_relu=True):
# def __init__(self, hparams, use_relu=True):
super(Proto, self).__init__()

self.hparams = hparams
@@ -60,19 +61,14 @@ def __init__(self, input_shape, num_classes, num_domains, hparams, use_relu=True)
        #     weight_decay=self.hparams['weight_decay']
        # )

        # initializing constants
        self.nd = num_domains
        self.nc = num_classes

        # initializing architecture parameters
        featurizer = networks.Featurizer(input_shape, self.hparams)
        featurizer = networks.Featurizer(hparams["input_shape"], self.hparams)
        self.ft_output_size = featurizer.n_outputs
        self.proto_size = int(self.ft_output_size * 0.25)
        self.feat_size = int(self.ft_output_size)

        # initializing hyperparameters
        self.proto_frac = hparams["proto_train_frac"]
        self.epochs = hparams["n_steps"]
        self.proto_epochs = hparams["n_steps_proto"]

        # self.kernel_type = "gaussian"
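For a sense of the sizes computed in __init__ above, a quick worked instance; the 2048 is an assumption (the usual n_outputs of a ResNet-50 featurizer), since the hunk does not show the featurizer itself:

ft_output_size = 2048                    # featurizer.n_outputs, assuming ResNet-50
proto_size = int(ft_output_size * 0.25)  # 512-dimensional prototype space
feat_size = int(ft_output_size)          # 2048-dimensional feature space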
2 changes: 2 additions & 0 deletions dalib/domainbed/hparams_registry.py
@@ -27,6 +27,7 @@ def _hparams(algorithm, dataset, random_state):
hparams["dataset"] = (dataset, dataset)
hparams["domains_per_iter"] = (4, 4)
hparams["data_parallel"] = (True, True)
hparams["input_shape"] = ((3, 224, 224,), (3, 224, 224,))

if dataset in RESNET_DATASETS:
hparams["lr"] = (1e-4, 10 ** random_state.uniform(-5.5, -3.5))
@@ -44,6 +45,7 @@ def _hparams(algorithm, dataset, random_state):
hparams["weight_decay"] = (1e-4, 10 ** random_state.uniform(-4, -4))

hparams["class_balanced"] = (False, False)
hparams["num_proto_extraction_points"] = (500, 500)

if algorithm in ["DANN", "CDANN"]:

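Each registry entry above appears to follow the DomainBed convention of a (default, random-search) pair: the first element is used for a default run, the second is a value drawn once, while the registry is built, for random hyperparameter search. A minimal sketch of that convention; _hparams_demo is an illustrative helper, not a function from this file:

import numpy as np

def _hparams_demo(random_state):
    # Each value is (default, draw_for_random_search); the random draw
    # happens once, at registry construction time.
    hparams = {}
    hparams["lr"] = (1e-4, 10 ** random_state.uniform(-5.5, -3.5))
    hparams["input_shape"] = ((3, 224, 224), (3, 224, 224))
    return hparams

registry = _hparams_demo(np.random.RandomState(0))
defaults = {k: v[0] for k, v in registry.items()}  # default configuration
sampled = {k: v[1] for k, v in registry.items()}   # random-search configuration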
Binary file removed dalib/domainbed/lib/__pycache__/misc.cpython-37.pyc
44 changes: 5 additions & 39 deletions image_source_final.py
@@ -223,7 +223,6 @@ def train_source(args):
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            print(args)
            test_target(args)

            if acc_s_te >= acc_init:
@@ -300,13 +299,12 @@ def print_args(args):

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='SHOT')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--source', default='Ar,Pr', type=str, help="source domains")
    parser.add_argument('--target', default='Cl,Rw', type=str, help="target domains")
    parser.add_argument('--root', default='data/', type=str, help="root of the data directory")
    parser.add_argument('--max_epoch', type=int, default=20, help="max epochs")
    parser.add_argument('--batch_size', type=int, default=64, help="batch size")
    parser.add_argument('--workers', type=int, default=4, help="number of workers")
    parser.add_argument('--workers', type=int, default=8, help="number of workers")
    parser.add_argument('--dataset', type=str, default='OfficeHome', choices=['visda-2017', 'office', 'OfficeHome', 'pacs', 'domain_net'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
@@ -344,29 +342,16 @@ def print_args(args):
        names = ['clipart', 'infograph', 'painting', 'quickdraw', 'sketch', 'real']
        # args.class_num = 345

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    # torch.backends.cudnn.deterministic = True

    # folder = './data/'
    # args.s_dset_path = folder + args.dataset + '/' + names[args.source] + '.txt'
    # args.test_dset_path = folder + args.dataset + '/' + names[args.t] + '.txt'
    mode = 'online' if args.wandb else 'disabled'
    wandb.init(project='degaa', entity='vclab', name=f'SRC Train: {args.source}', mode=mode)
    print(print_args(args))
    # if args.dataset == 'OfficeHome':
    #     if args.da == 'pda':
    #         args.class_num = 65
    #         args.src_classes = [i for i in range(65)]
    #         args.tar_classes = [i for i in range(25)]
    #     if args.da == 'oda':
    #         args.class_num = 25
    #         args.src_classes = [i for i in range(25)]
    #         args.tar_classes = [i for i in range(65)]

    args.output_dir_src = osp.join(args.output, args.da, args.dataset, args.source.replace(',',''))
    args.name_src = args.source.replace(',','')
@@ -378,27 +363,8 @@ def print_args(args):
    args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
    args.out_file.write(print_args(args)+'\n')
    args.out_file.flush()
    train_source(args)

    args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
    # for i in range(len(names)):
    #     if i == args.source:
    #         continue
    #     args.t = i
    args.name = args.source.replace(',','') + '_' + args.target.replace(',','')

    # folder = './data/'
    # args.s_dset_path = folder + args.dataset + '/' + names[args.source] + '_list.txt'
    # args.test_dset_path = folder + args.dataset + '/' + names[args.t] + '_list.txt'

    # if args.dataset == 'OfficeHome':
    #     if args.da == 'pda':
    #         args.class_num = 65
    #         args.src_classes = [i for i in range(65)]
    #         args.tar_classes = [i for i in range(25)]
    #     if args.da == 'oda':
    #         args.class_num = 25
    #         args.src_classes = [i for i in range(25)]
    #         args.tar_classes = [i for i in range(65)]

    args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')

    train_source(args)
    test_target(args)
1 change: 1 addition & 0 deletions scripts/adapt.sh
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=1 python src/degaa.py --output_dir ./adapt/run1 --tensorboard --source Ar,Pr --target Cl,Rw --batch_size 32
3 changes: 3 additions & 0 deletions scripts/embeddings.sh
@@ -0,0 +1,3 @@
# CUDA_VISIBLE_DEVICES=0 python src/embeddings_old.py --output_dir ./protoruns/run4 --batch_size 48 --num_proto_steps 100000 --tensorboard --resume 60000

CUDA_VISIBLE_DEVICES=1 python src/embeddings_old.py --output_dir ./protoruns/run5 --batch_size 32 --hparams '{"mixup":1}' --num_proto_steps 1000000 --tensorboard --resume 400100 --checkpoint_freq 10000
9 changes: 5 additions & 4 deletions scripts/endtoend.sh
@@ -1,6 +1,7 @@
# for prototypes
CUDA_VISIBLE_DEVICES=0 python src/embeddings.py --wandb --output_dir ./run2
# for warmup (will save 3 weights netF, netB, netC)
python image_source_final.py --gpu_id 2 --dataset OfficeHome --output weights --batch_size 128 --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0

# for adaptation
# for source-only training (will save 3 weights netF, netB, netC)
CUDA_VISIBLE_DEVICES=0 python image_source_final.py --dataset OfficeHome --output weights --batch_size 128 --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0
# for computing centroids and saving TSNE plots
CUDA_VISIBLE_DEVICES=0 python warmup.py --dataset OfficeHome --source Ar,Pr --target Cl,Rw --trained_wt weights/uda/OfficeHome
# for adaptation
8 changes: 4 additions & 4 deletions scripts/source_only.sh
@@ -2,10 +2,10 @@
# python image_source_final.py --output san --gpu_id 0 --dset office --max_epoch 50 --s 0 --net vit

#office-home
python image_source_final.py --output weights --gpu_id 0 --dset office-home --max_epoch 50 --s 0 --net resnet50
python image_source_final.py --output weights --gpu_id 0 --dset office-home --max_epoch 50 --s 1 --net resnet50
python image_source_final.py --output weights --gpu_id 1 --dset office-home --max_epoch 50 --s 2 --net resnet50
python image_source_final.py --output weights --gpu_id 0 --dset office-home --max_epoch 50 --s 3 --net resnet50
python image_source_final.py --output weights --gpu_id 2 --batch_size 128 --dataset OfficeHome --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0
# python image_source_final.py --output weights --gpu_id 0 --dataset OfficeHome --max_epoch 50 --source "Ar,Pr" --target "Cl,Rw"
# python image_source_final.py --output weights --gpu_id 1 --dataset OfficeHome --max_epoch 50 --source "Ar,Pr" --target "Cl,Rw"
# python image_source_final.py --output weights --gpu_id 0 --dataset OfficeHome --max_epoch 50 --source "Ar,Pr" --target "Cl,Rw"

#pacs
# python image_source_final.py --output san --gpu_id 0 --dset pacs --max_epoch 50 --s 0 --net vit
