Skip to content

Commit

Permalink
completed centroids + source only training
Browse files Browse the repository at this point in the history
  • Loading branch information
take2rohit committed Jan 15, 2022
1 parent b78b8c6 commit 597c8d3
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 152 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ centroids
model/
OfficeHomeDataset_10072016.zip
dalib/domainbed/__pycache__/*
degaa
44 changes: 5 additions & 39 deletions image_source_final.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def train_source(args):
args.out_file.write(log_str + '\n')
args.out_file.flush()
print(log_str+'\n')
print(args)
test_target(args)

if acc_s_te >= acc_init:
Expand Down Expand Up @@ -300,13 +299,12 @@ def print_args(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='SHOT')
parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
parser.add_argument('--source',default='Ar,Pr', type=str, help="source")
parser.add_argument('--target',default='Cl,Rw', type=str, help="target")
parser.add_argument('--root', default='data/', type=str, help="source")
parser.add_argument('--max_epoch', type=int, default=20, help="max iterations")
parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
parser.add_argument('--workers', type=int, default=4, help="number of workers")
parser.add_argument('--workers', type=int, default=8, help="number of workers")
parser.add_argument('--dataset', type=str, default='OfficeHome', choices=['visda-2017', 'office', 'OfficeHome','pacs', 'domain_net'])
parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
Expand Down Expand Up @@ -344,29 +342,16 @@ def print_args(args):
names = ['clipart', 'infograph', 'painting', 'quickdraw', 'sketch', 'real']
# args.class_num = 345

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# torch.backends.cudnn.deterministic = True

# folder = './data/'
# args.s_dset_path = folder + args.dataset + '/' + names[args.source] + '.txt'
# args.test_dset_path = folder + args.dataset + '/' + names[args.t] + '.txt'
mode = 'online' if args.wandb else 'disabled'
wandb.init(project='degaa', entity='vclab', name=f'SRC Train: {args.source}', mode=mode)
print(print_args(args))
# if args.dataset == 'OfficeHome':
# if args.da == 'pda':
# args.class_num = 65
# args.src_classes = [i for i in range(65)]
# args.tar_classes = [i for i in range(25)]
# if args.da == 'oda':
# args.class_num = 25
# args.src_classes = [i for i in range(25)]
# args.tar_classes = [i for i in range(65)]

args.output_dir_src = osp.join(args.output, args.da, args.dataset, args.source.replace(',',''))
args.name_src = args.source.replace(',','')
Expand All @@ -378,27 +363,8 @@ def print_args(args):
args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
args.out_file.write(print_args(args)+'\n')
args.out_file.flush()
train_source(args)

args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
# for i in range(len(names)):
# if i == args.source:
# continue
# args.t = i
args.name = args.source.replace(',','') + '_' + args.target.replace(',','')

# folder = './data/'
# args.s_dset_path = folder + args.dataset + '/' + names[args.source] + '_list.txt'
# args.test_dset_path = folder + args.dataset + '/' + names[args.t] + '_list.txt'

# if args.dataset == 'OfficeHome':
# if args.da == 'pda':
# args.class_num = 65
# args.src_classes = [i for i in range(65)]
# args.tar_classes = [i for i in range(25)]
# if args.da == 'oda':
# args.class_num = 25
# args.src_classes = [i for i in range(25)]
# args.tar_classes = [i for i in range(65)]

args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')

train_source(args)
test_target(args)
9 changes: 5 additions & 4 deletions scripts/endtoend.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# for prototypes
CUDA_VISIBLE_DEVICE=0 python src/embeddings.py --wandb --output_dir ./run2
# for warmup (will save 3 weights netF, netB, netC)
python image_source_final.py --gpu_id 2 --dataset OfficeHome --output weights --batch_size 128 --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0

# for adaptation
# for source only training (will save 3 weights netF, netB, netC)
CUDA_VISIBLE_DEVICE=0 python image_source_final.py --dataset OfficeHome --output weights --batch_size 128 --max_epoch 50 --source Ar,Pr --target Cl,Rw --wandb 0
# for computing centroids and saving TSNE plots
CUDA_VISIBLE_DEVICE=0 python warmup.py --dataset OfficeHome --source Ar,Pr --target Cl,Rw --trained_wt weights/uda/OfficeHome
# for adaptation
144 changes: 35 additions & 109 deletions warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import matplotlib.pyplot as plt
import pandas as pd
import os

sys.path.append('../../..')
import sys
sys.path.append('..')
sys.path.append('common')
sys.path.append('src')
from common.modules.classifier import Classifier
import common.vision.datasets as datasets
import common.vision.models as models
Expand All @@ -39,6 +40,8 @@
import network
from sklearn.manifold import TSNE
import seaborn as sns
from data_helper import setup_datasets

sns.set(style="darkgrid")

np.random.seed(0)
Expand All @@ -47,100 +50,15 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def data_load(args, batch_size=64):
## prepare data
txt_path=f'data/{args.dset}'


def image(resize_size=256, crop_size=224, alexnet=False):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
return transforms.Compose([
transforms.Resize((resize_size, resize_size)),
transforms.RandomCrop(crop_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])

def image_test(resize_size=256, crop_size=224, alexnet=False):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
return transforms.Compose([
transforms.Resize((resize_size, resize_size)),
transforms.ToTensor(),
normalize
])

dsets = {}
dset_loaders = {}
dset_loaders_test = {}

train_bs = batch_size

if args.dset == 'office':
txt_files = {'amazon' : f'{txt_path}/amazon.txt',
'webcam': f'{txt_path}/webcam.txt',
'dslr': f'{txt_path}/dslr.txt'}
args.num_classes = 31

if args.dset == 'office-home':
txt_files = {'Art' : f'{txt_path}/Art.txt',
'Clipart': f'{txt_path}/Clipart.txt',
'Product': f'{txt_path}/Product.txt',
'RealWorld': f'{txt_path}/RealWorld.txt'}
args.num_classes = 65

if args.dset == 'pacs':
txt_files = {'art_painting' : f'{txt_path}/art_painting.txt',
'cartoon': f'{txt_path}/cartoon.txt',
'photo': f'{txt_path}/photo.txt',
'sketch': f'{txt_path}/sketch.txt'}

if args.dset == 'domain_net':
txt_files = {'clipart': f'{txt_path}/clipart.txt',
'infograph': f'{txt_path}/infograph.txt',
'painting': f'{txt_path}/painting.txt',
'quickdraw': f'{txt_path}/quickdraw.txt',
'sketch': f'{txt_path}/sketch.txt',
'real': f'{txt_path}/real.txt'}

txt_files_test = {'clipart': f'{txt_path}/clipart_test.txt',
'infograph': f'{txt_path}/infograph_test.txt',
'painting': f'{txt_path}/painting_test.txt',
'quickdraw': f'{txt_path}/quickdraw_test.txt',
'sketch': f'{txt_path}/sketch_test.txt',
'real': f'{txt_path}/real_test.txt'}
args.num_classes = 345

for domain, paths in txt_files_test.items():
txt_tar = open(paths).readlines()
dsets[domain] = ImageList_idx(txt_tar, transform=image_test())
dset_loaders_test[domain] = DataLoader(dsets[domain], batch_size=train_bs,drop_last=False)

if args.dset != 'domain_net':
dset_loaders_test = dset_loaders

for domain, paths in txt_files.items():
if domain in [args.source, args.target]:
txt_tar = open(paths).readlines()

dsets[domain] = ImageList_idx(txt_tar, transform=image())
dset_loaders[domain] = DataLoader(dsets[domain], batch_size=train_bs, shuffle=True,drop_last=True)

return dset_loaders, dset_loaders_test


def load_models(args,domain):
def load_models(args,dom_adapts):

model_loaders = {}
if args.dset == 'office-home':
wt_abbr = { 'Art': 'A','Clipart': 'C', 'Product': 'P', 'RealWorld': 'R'}
# if args.dataset == 'office-home':
# wt_abbr = { 'Art': 'A','Clipart': 'C', 'Product': 'P', 'RealWorld': 'R'}

dom_adapts = wt_abbr[domain]
# dom_adapts = wt_abbr[domain]

print('Loading weights for ', domain)
print('Loading weights for ', dom_adapts)
if args.net[0:3] == 'res':
netF = network.ResBase(res_name=args.net).to(device)
elif args.net == 'vit':
Expand All @@ -150,7 +68,7 @@ def load_models(args,domain):
args.bottleneck_dim = 256

netB = network.feat_bootleneck(type='bn', feature_dim=args.feature_dim,bottleneck_dim=args.bottleneck_dim).to(device)
netC = network.feat_classifier(type='wn', class_num=args.num_classes, bottleneck_dim=args.bottleneck_dim).to(device)
netC = network.feat_classifier(type='wn', class_num=args.class_num, bottleneck_dim=args.bottleneck_dim).to(device)

modelpathF = f'{args.trained_wt}/{dom_adapts}/source_F.pt'
netF.load_state_dict(torch.load(modelpathF))
Expand Down Expand Up @@ -188,8 +106,8 @@ def compute_features(args, net, dataloader, dataset_name=None):
iter_test = iter(dataloader)
for i in tqdm(range(len(dataloader))):
data = iter_test.next()
inputs = data[0].to('cuda')
labels = data[1].to('cuda')
inputs = data[0][0].to('cuda')
labels = data[0][1].to('cuda')

feats = netB(netF(inputs))

Expand Down Expand Up @@ -219,15 +137,15 @@ def compute_features(args, net, dataloader, dataset_name=None):

def compute_centroids(args,features_labels):

run_sum = torch.zeros((args.num_classes, args.bottleneck_dim), dtype=torch.float).to(device)
lbl_cnt = torch.ones((args.num_classes), dtype=torch.float).to(device)
cls_centroids = torch.zeros((args.num_classes, args.bottleneck_dim), dtype=torch.float).to(device)
run_sum = torch.zeros((args.class_num, args.bottleneck_dim), dtype=torch.float).to(device)
lbl_cnt = torch.ones((args.class_num), dtype=torch.float).to(device)
cls_centroids = torch.zeros((args.class_num, args.bottleneck_dim), dtype=torch.float).to(device)

for feat, lbl in zip(features_labels['feature'],features_labels['label']):
run_sum[lbl] += feat
lbl_cnt[lbl] += 1

for i in range(args.num_classes):
for i in range(args.class_num):
# if lbl_cnt[i] == 0:
# print('Failed computing centroid (no images present) for Class', i)
cls_centroids[i] = run_sum[i] / lbl_cnt[i]
Expand All @@ -237,7 +155,7 @@ def compute_centroids(args,features_labels):
def tsne_plotter(args, np_cls_centroids, np_feat, np_label, plt_name=None, save_cnt_path=None):

cat_feat = np.concatenate((np_cls_centroids, np_feat), axis=0)
save_folder = f'{save_cnt_path}/{args.dset}'
save_folder = f'{save_cnt_path}/{args.dataset}'
if save_folder is not None:
if not os.path.exists(save_folder):
os.makedirs(save_folder)
Expand All @@ -246,8 +164,8 @@ def tsne_plotter(args, np_cls_centroids, np_feat, np_label, plt_name=None, save_

data_cen, data_feat = {}, {}
X_embedded = TSNE(n_components=2, perplexity=15, learning_rate=10).fit_transform(cat_feat)
data_cen['x_cent'], data_cen['y_cent'], data_cen['lbl_cent'] = X_embedded[:args.num_classes,0],X_embedded[:args.num_classes,1], np.arange(args.num_classes)
data_feat['x_feat'], data_feat['y_feat'], data_feat['lbl_feat'] = X_embedded[args.num_classes:,0],X_embedded[args.num_classes:,1], np_label
data_cen['x_cent'], data_cen['y_cent'], data_cen['lbl_cent'] = X_embedded[:args.class_num,0],X_embedded[:args.class_num,1], np.arange(args.class_num)
data_feat['x_feat'], data_feat['y_feat'], data_feat['lbl_feat'] = X_embedded[args.class_num:,0],X_embedded[args.class_num:,1], np_label

data_feat=pd.DataFrame(data_feat)
data_cen=pd.DataFrame(data_cen)
Expand All @@ -259,11 +177,17 @@ def tsne_plotter(args, np_cls_centroids, np_feat, np_label, plt_name=None, save_

if not os.path.exists('tsne_plts'):
os.makedirs('tsne_plts')
plt.savefig(f'tsne_plts/tsne_{plt_name}.pdf')
plt.savefig(f'tsne_plts/tsne_{plt_name}.png')

def main(args):

dom_dataloaders, dset_loaders_test = data_load(args, batch_size=args.batch_size)
dom_dataloaders = {}
# dom_dataloaders, dataset_loaders_test = data_load(args, batch_size=args.batch_size)
args.class_num, train_src_loader, train_target_loader, _, _ = setup_datasets(args)
args.source = ''.join(args.source)
args.target = ''.join(args.target)
dom_dataloaders[args.source] = train_src_loader
dom_dataloaders[args.target] = train_target_loader
model_loaders ={}

for domain in [args.source, args.target]:
Expand All @@ -285,13 +209,15 @@ def main(args):
parser = argparse.ArgumentParser(description='Clusformer')
parser.add_argument('-b', '--batch_size', default=32, type=int,help='mini-batch size (default: 54)')
# parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
parser.add_argument('-s', '--source', type=str,help='Select the source [amazon, dslr, webcam]')
parser.add_argument('-t', '--target', type=str,help='Select the target [amazon, dslr, webcam]')
parser.add_argument('-d', '--dset', type=str,help='Select the target [amazon, dslr, webcam]')
parser.add_argument('--root', default='data/', type=str, )
parser.add_argument('--workers', default=8, type=int )
parser.add_argument('-s', '--source',default='Ar,Pr', type=str,help='Select the source [amazon, dslr, webcam]')
parser.add_argument('-t', '--target', default='Cl,Rw',type=str,help='Select the target [amazon, dslr, webcam]')
parser.add_argument('-d', '--dataset',default='OfficeHome', type=str,help='Select the target [amazon, dslr, webcam]')
# parser.add_argument('-e', '--epochs', default=40, type=int,help='select number of cycles')
parser.add_argument('-w', '--wandb', default=0, type=int,help='Log to wandb or not [0 - dont use | 1 - use]')
parser.add_argument('--net', default='vit', type=str,help='Select vit or rn50 based (default: vit)')
parser.add_argument('-l', '--trained_wt', default='weights/office-home', type=str,help='Load src')
parser.add_argument('--net', default='resnet50', type=str,help='Select vit or resnet50 based (default: vit)')
parser.add_argument('-l', '--trained_wt', default='weights/uda/OfficeHome', type=str,help='Load src')
args = parser.parse_args()
print(args)
main(args)
Expand Down

0 comments on commit 597c8d3

Please sign in to comment.