In [1]:
import numpy as np

import sys
import os

# sys.path.append("./code/train_code/")

import config as config
import metrics as metrics
from config_populate import data_settings
import exp_select as exp_select

import importlib

import gaussian_utils

from trainer import MultiSourceTrainer

import torch
from torch.autograd import Variable

from evaluate_model import get_target_accuracy

In [2]:
def change_config(dataset, task, exp_id):
    config.dataset_name = dataset
    config.data_key = task
    
    config.settings['dataset_dir']                 = os.path.join('data', 'pretrained-features', config.dataset_name)
    
    config.settings['C']                           = data_settings[config.dataset_name][config.data_key]['C']
    config.settings['C_dash']                      = data_settings[config.dataset_name][config.data_key]['C_dash']
    config.settings['num_C_dash']                  = data_settings[config.dataset_name][config.data_key]['num_C_dash']

    config.settings['num_C']                       = data_settings[config.dataset_name][config.data_key]['num_C']
    config.settings['src_datasets']                = data_settings[config.dataset_name][config.data_key]['src_datasets']
    config.settings['trgt_datasets']               = data_settings[config.dataset_name][config.data_key]['trgt_datasets']
    
    config.settings['log_interval']                = config.settings['expt_dict'][config.dataset_name][config.data_key]['val_after']
    config.settings['start_iter']                  = 0
    config.settings['max_iter']                    = config.settings['expt_dict'][config.dataset_name][config.data_key]['max_iter']
    config.settings['enough_iter']                 = config.settings['expt_dict'][config.dataset_name][config.data_key]['enough_iter']
    config.settings['val_after']                   = config.settings['expt_dict'][config.dataset_name][config.data_key]['val_after']
    config.settings['batch_size']                  = config.settings['expt_dict'][config.dataset_name][config.data_key]['batch_size']
    config.settings['adapt_batch_size']            = config.settings['expt_dict'][config.dataset_name][config.data_key]['adapt_batch_size']
    config.settings['val_batch_size_factor']       = config.settings['expt_dict'][config.dataset_name][config.data_key]['val_batch_size_factor']
    
    exp_select.exp_id = exp_id
    config.settings['exp_name']                    = config.gen_exp_name()

In [3]:
# Learn the weights w minimizing the generalizability objective
def learn_w_w2(it_thresh='max_iter', num_steps=100):
    w2_dist = np.zeros(len(config.settings['src_datasets']))
    
    # The approximation is better if the batch size is larger. Let's see how large we can make it
    multiplier = 5
    for src_domain_idx in range(len(config.settings['src_datasets'])):
        dom = config.settings['src_datasets'][src_domain_idx]
        trainer = MultiSourceTrainer(src_domain_idx)
        trainer.load_model_weights(it_thresh=it_thresh)
        trainer.set_mode(trainer.settings['mode']['val'])
        initial_batch_size = trainer.adapt_batch_size
        
        for curr_multiplier in range(1, 6):
            print("Loading trainer for source domain {}".format(config.settings['src_datasets'][src_domain_idx]))

            # align the batch sizes
            trainer.adapt_batch_size = initial_batch_size * curr_multiplier
            trainer.batch_size = initial_batch_size * curr_multiplier

            if trainer.batch_size > 600:
                multiplier = min(multiplier, curr_multiplier - 1)
                break

            try:
                trainer.initialize_src_train_dataloader()
                _,X_src,_,_ = trainer.source_dl_iter_train_list.next()
                X_src = Variable(X_src).to(trainer.settings['device']).float()

                trainer.initialize_target_adapt_dataloader()
                _,X_tar,_,_ = trainer.adapt_target_dl_iter_train_list.next()
                X_tar = Variable(X_tar).to(trainer.settings['device']).float()
            except StopIteration:
                print("Multiplier for {} is {}".format(dom, curr_multiplier - 1))
                multiplier = min(multiplier, curr_multiplier - 1)
                break
                
    print('multiplier = {}'.format(multiplier))
    print("\n\n")
    
    for src_domain_idx in range(len(config.settings['src_datasets'])):
        print("Loading trainer for source domain {}".format(config.settings['src_datasets'][src_domain_idx]))

        dom = config.settings['src_datasets'][src_domain_idx]
        trainer = MultiSourceTrainer(src_domain_idx)
        
        # learn the gaussians
        trainer.load_model_weights(it_thresh='enough_iter')
        gaussian_utils.learn_gaussians(trainer)
        n_samples = np.ones(trainer.N_CLASSES, dtype=int) * trainer.settings["gaussian_samples_per_class"] * 100
        trainer.gaussian_z, trainer.gaussian_y = gaussian_utils.sample_from_gaussians(trainer.means, trainer.covs, n_samples)
        
        # If there are high confidence pseudo-labels for most of the samples, use those for adaptation
#         trainer.load_model_weights(it_thresh='enough_iter')
#         trainer.pseudo_target_dist, conf, _ = trainer.get_target_pseudo_distribution()
#         if np.sum(np.max(conf, axis=-1) > trainer.settings['confidence_thresh']) / conf.shape[0] < trainer.settings['confidence_ratio']:
        trainer.pseudo_target_dist = np.ones(trainer.N_CLASSES, dtype=int)
            
        print("Dist=")
        print(trainer.pseudo_target_dist)
        
        trainer.load_model_weights(it_thresh=it_thresh)
        trainer.set_mode(trainer.settings['mode']['val'])
        
        trainer.adapt_batch_size *= multiplier
        trainer.batch_size = trainer.adapt_batch_size    # align the batch sizes

        trainer.initialize_src_train_dataloader()
        trainer.initialize_target_adapt_dataloader()
        
        # Get the mean W2 distance between encodings of the current domain and the target domain
        # Get a source sample
        w2_dist[src_domain_idx] = 0
        for step in range(num_steps):
            # Get a target sample
            try:
                _,X_src,_,_ = trainer.source_dl_iter_train_list.next()
                X_src = Variable(X_src).to(trainer.settings['device']).float()
            except StopIteration:
                trainer.initialize_src_train_dataloader()
                _,X_src,_,_ = trainer.source_dl_iter_train_list.next()
                X_src = Variable(X_src).to(trainer.settings['device']).float()

            # Get a target sample
            try:
                _,X_tar,_,_ = trainer.adapt_target_dl_iter_train_list.next()
                X_tar = Variable(X_tar).to(trainer.settings['device']).float()
            except StopIteration:
                trainer.initialize_target_adapt_dataloader()
                _,X_tar,_,_ = trainer.adapt_target_dl_iter_train_list.next()
                X_tar = Variable(X_tar).to(trainer.settings['device']).float()
                
            # Compute the number of gaussian samples to be used for the current batch
            normalized_dist = trainer.pseudo_target_dist / np.sum(trainer.pseudo_target_dist)
            num_samples = np.array(normalized_dist * trainer.adapt_batch_size, dtype=int)
            while trainer.adapt_batch_size > np.sum(num_samples):
                idx = np.random.choice(range(trainer.N_CLASSES), p = normalized_dist)
                num_samples[idx] += 1

            # Get gaussian samples for the current batch
            gz = []
            gy = []
            for c in range(trainer.N_CLASSES):
                ind = np.where(trainer.gaussian_y == c)[0]
                ind = ind[np.random.choice(range(len(ind)), num_samples[c], replace=False)]
                gz.append(trainer.gaussian_z[ind])
                gy.append(trainer.gaussian_y[ind])
            gz = np.vstack(gz)
            gy = np.concatenate(gy)

            gz = torch.as_tensor(gz).to(trainer.settings['device']).float()
            gy = torch.as_tensor(gy).to(trainer.settings['device']).long()

            with torch.no_grad():
                f_src = trainer.network.model['global']['Fs'](X_src)
                f_tar = trainer.network.model['global']['Fs'](X_tar)
                
                d1 = metrics.sliced_wasserstein_distance(f_src, gz, trainer.settings['num_projections'], 2, trainer.settings['device']).item()
                d2 = metrics.sliced_wasserstein_distance(f_tar, gz, trainer.settings['num_projections'], 2, trainer.settings['device']).item()
                w2_dist[src_domain_idx] += d1 + d2
    

            if step % (num_steps // 10) == 0:
                print(step, w2_dist, dom)
                
    return w2_dist
        
#     w = 1 / w2_dist
#     w = w / np.sum(w)
        
#     return w

In [None]:
res_store = []
res_store_2 = []

for dataset in ['office-31', 'image-clef', 'office-caltech', 'office-home']:
    tasks = config.settings['expt_dict'][dataset].keys()
    
    print(dataset)
    
    
    for task in tasks:
        print(task)

        for exp in range(5):
            change_config(dataset, task, exp)
            
            w = learn_w_w2()
            w = w / np.sum(w)
            
            w_hat_w2 = 1 / np.copy(w)
            w_hat_w2 = w_hat_w2 / np.sum(w_hat_w2)
            res_store.append((dataset, task, exp, get_target_accuracy(w_hat_w2, 'max_iter')[0], np.copy(w_hat_w2)))
            
            w_hat_w2 = 1 - np.copy(w)
            w_hat_w2 = w_hat_w2 / np.sum(w_hat_w2)
            res_store_2.append((dataset, task, exp, get_target_accuracy(w_hat_w2, 'max_iter')[0], np.copy(w_hat_w2)))
            
            print(res_store[-1])
            print(res_store_2[-1])

office-31
AD_W
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain dslr
Loading trainer for source domain dslr
Loading trainer for source domain dslr
Loading trainer for source domain dslr
Loading trainer for source domain dslr
multiplier = 5



Loading trainer for source domain amazon


amazon: 100%|██████████| 176/176 [00:00<00:00, 387.08it/s]


Dist=
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
0 [0.05901814 0.        ] amazon
10 [0.64640579 0.        ] amazon
20 [1.25674414 0.        ] amazon
30 [1.90525379 0.        ] amazon
40 [2.53242396 0.        ] amazon
50 [3.14613749 0.        ] amazon
60 [3.7923124 0.       ] amazon
70 [4.42834812 0.        ] amazon
80 [5.07870832 0.        ] amazon
90 [5.68782908 0.        ] amazon
Loading trainer for source domain dslr


dslr: 100%|██████████| 31/31 [00:00<00:00, 183.02it/s]


Dist=
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
0 [6.23233495 0.05392897] dslr
10 [6.23233495 0.59802025] dslr
20 [6.23233495 1.09320621] dslr
30 [6.23233495 1.64965891] dslr
40 [6.23233495 2.17245644] dslr
50 [6.23233495 2.74557499] dslr
60 [6.23233495 3.27255114] dslr
70 [6.23233495 3.82315779] dslr
80 [6.23233495 4.33155094] dslr
90 [6.23233495 4.9015769 ] dslr
Loading trainer for source domain amazon


webcam: 100%|██████████| 3/3 [00:00<00:00, 22.85it/s]


Loading trainer for source domain dslr


webcam: 100%|██████████| 3/3 [00:00<00:00, 23.45it/s]


Loading trainer for source domain amazon


webcam: 100%|██████████| 3/3 [00:00<00:00, 21.50it/s]


Loading trainer for source domain dslr


webcam: 100%|██████████| 3/3 [00:00<00:00, 24.34it/s]


('office-31', 'AD_W', 0, 0.9382871536523929, array([0.46475592, 0.53524408]))
('office-31', 'AD_W', 0, 0.9382871536523929, array([0.46475592, 0.53524408]))
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain amazon
Loading trainer for source domain dslr
Loading trainer for source domain dslr
Loading trainer for source domain dslr
Loading trainer for source domain dslr
Loading trainer for source domain dslr
multiplier = 5



Loading trainer for source domain amazon


amazon: 100%|██████████| 176/176 [00:00<00:00, 342.76it/s]


Dist=
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
0 [0.06081343 0.        ] amazon
10 [0.62572178 0.        ] amazon
20 [1.17881612 0.        ] amazon
30 [1.71961884 0.        ] amazon
40 [2.27033443 0.        ] amazon
50 [2.83790485 0.        ] amazon
60 [3.40866809 0.        ] amazon
70 [3.95061831 0.        ] amazon
80 [4.50925705 0.        ] amazon
90 [5.07676887 0.        ] amazon
Loading trainer for source domain dslr


dslr: 100%|██████████| 31/31 [00:00<00:00, 191.00it/s]


Dist=
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
0 [5.59201054 0.05270137] dslr
10 [5.59201054 0.53290462] dslr
20 [5.59201054 1.03393811] dslr
30 [5.59201054 1.53718082] dslr
40 [5.59201054 2.05301042] dslr
50 [5.59201054 2.54216597] dslr


In [None]:
for dataset in ['office-31', 'image-clef', 'office-caltech', 'office-home']:
    tasks = config.settings['expt_dict'][dataset].keys()
    
    dataset_avg = []
    
    for t in tasks:
#         print(dataset, t)
        avg = []
        for r in res_store:
            if r[0] == dataset and r[1] == t:
                dataset_avg.append(r[3] * 100)
                avg.append(r[3] * 100)
#         avg = np.asarray(avg)
        
        print(dataset, t, np.mean(avg).round(1), np.var(avg).round(2))
    print(dataset, np.mean(dataset_avg).round(1))
    print()

In [None]:
# # for r in res_store:
# #     print(r)
    
# # Writing to file
# with open("../../run_benchmarks/updated_w2_scores.txt", "w") as save_file:
#     for dataset in ['office-31', 'image-clef', 'office-caltech', 'office-home']:
#         tasks = config.settings['expt_dict'][dataset].keys()

#         dataset_avg = []

#         for t in tasks:
#     #         print(dataset, t)
#             avg = []
#             for r in res_store:
#                 if r[0] == dataset and r[1] == t:
#                     dataset_avg.append(r[3] * 100)
#                     avg.append(r[3] * 100)
#     #         avg = np.asarray(avg)

#             save_file.write("{} {} {} {}\n".format(dataset, t, np.mean(avg).round(1), np.var(avg).round(2)))
#         save_file.write("{} {}\n".format(dataset, np.mean(dataset_avg).round(1)))
#         save_file.write("\n")
    
#     # Writing data to a file
#     for r in res_store:
#         save_file.write(str(r) + "\n")