In [1]:
!pip install scikit-learn-extra



## Pick initial samples using KMedoids clustering and random sampling

In [2]:
import sys
sys.path.append('../')

In [3]:
import torch
import torch.nn as nn
import numpy as np
from sklearn_extra.cluster import KMedoids
from copy import deepcopy
import os
import random

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
from dataloader.dataloader import random_split_dataloader
from util.util import *
from copy import deepcopy

In [6]:
def split_domain(domains, split_idx, print_domain=True):
    """Separate one held-out target domain from the remaining source domains.

    Args:
        domains: List of domain names; it is copied, never modified.
        split_idx: Position in ``domains`` of the domain to hold out.
        print_domain: When true, print the resulting source/target split.

    Returns:
        Tuple ``(source_domain, target_domain)``; ``target_domain`` is a
        one-element list holding the held-out domain.
    """
    remaining = deepcopy(domains)
    held_out = [remaining.pop(split_idx)]
    if not print_domain:
        return remaining, held_out

    # Everything is printed on one line; only the target name ends it.
    print('Source domain: ', end='')
    for name in remaining:
        print(name, end=', ')
    print('Target domain: ', end='')
    for name in held_out:
        print(name)
    return remaining, held_out
    
# Domain names available for each supported dataset.
domain_map = {
    'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
    'VLCS': ['Caltech', 'Labelme', 'Pascal', 'Sun']
}

def get_domain(name):
    """Return the list of domain names registered for dataset ``name``.

    A new list is returned on every call, so callers may mutate the result
    (e.g. ``pop``) without altering ``domain_map``.

    Raises:
        ValueError: If ``name`` is not a key of ``domain_map``.
    """
    if name not in domain_map:
        # Wrap in a tuple so a tuple-valued name cannot break the formatting.
        raise ValueError('Name of dataset unknown %s' % (name,))
    return list(domain_map[name])

In [7]:
# Project root. Set the DOM_GEN_ROOT_DIR environment variable to run the
# notebook on another machine without editing this cell; the default keeps
# the original location.
ROOT_DIR = os.environ.get("DOM_GEN_ROOT_DIR", "/home/arfeen/papers_code/dom_gen_aaai_2020")
data_root = f"{ROOT_DIR}/PACS/kfold/"
# Number of clusters (medoids) requested from KMedoids for each split.
#domain_samples = 170
domain_samples = 255
# Seed passed to KMedoids.
RANDOM_STATE = 42

# Where the computed indices are saved.
INDICES_DIR = f"{ROOT_DIR}/saved-indices"
INDICES_PATH = f"{INDICES_DIR}/mixed_domain_clustering_indices_all_indices.pt"

# Index (into the dataset's domain list) of the held-out target domain.
exp_nums = 3

In [8]:
# Hold out the domain at position `exp_nums` as the target; the rest are sources.
source_domain, target_domain = split_domain(domains=get_domain("PACS"), split_idx=exp_nums)

Source domain: photo, art_painting, cartoon, Target domain: sketch


In [9]:
# Look up the model constructor, build the network, then switch it to
# evaluation mode for feature extraction.
model_factory = get_model("caffenet", "general")
model = model_factory(num_classes=7, num_domains=3, pretrained=True)
caffenet_model = model.eval()

Using Caffe AlexNet


In [None]:
def _class_counts(sample_labels, sample_indices, num_classes=7):
    """Count how many of ``sample_indices`` carry each class label."""
    counts = np.zeros(num_classes)
    for sample_index in sample_indices:
        counts[sample_labels[sample_index]] += 1
    return counts


# Previously selected per-domain indices. The file is the same for every
# split, so it is loaded once instead of once per loop iteration.
indice = torch.load(f"{ROOT_DIR}/dg_mmld-master/indices_final.pt")['indices']

# For every leave-one-domain-out split: extract features of the mixed source
# training data, cluster them with KMedoids, and record the medoid indices,
# the remaining (non-excluded) indices, and the class histogram of each set.
clustering_indices = {}
for i in range(len(get_domain("PACS"))):
    exp_nums = i
    source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)
    print("*"*40)

    # One result entry per split, keyed by the printed list of source domains.
    split_key = 'mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])
    split_entry = clustering_indices[split_key] = {}

    source_train, source_val, target_test, source = random_split_dataloader(
        data="PACS", data_root=data_root, source_domain=source_domain,
        target_domain=target_domain, batch_size=128)

    # Indices to exclude from the "random" pool, gathered from every source domain.
    # NOTE(review): these look like per-domain indices but are compared below
    # against positions in the mixed feature list, and the uniqueness check
    # further down reports duplicates -- confirm this mapping is intended.
    all_indices_per_domain = []
    for domain_name in source_domain:
        all_indices_per_domain.extend(indice[domain_name]['random_indices'])
        print('domain wise indices :', len(all_indices_per_domain))

    features = []
    labels = []
    # Inference only: without no_grad every stored feature would keep its
    # autograd graph alive for the whole dataset.
    with torch.no_grad():
        for ind, batch in enumerate(source_train):
            print(f"Domain: {i} Batch number: {ind}")
            features.extend(caffenet_model.features(batch[0]))
            labels.extend(batch[1].numpy())

    # One flat vector per sample; flattening is a no-op for 1-D features and
    # keeps KMedoids (which needs a 2-D sample matrix) working otherwise.
    features = [feature.reshape(-1).cpu().numpy() for feature in features]
    print('length of features :', len(features))

    # Set membership keeps the exclusion linear in the number of samples.
    excluded = set(all_indices_per_domain)
    indices_range = [k for k in range(len(features)) if k not in excluded]

    print('length of indices_range :', len(indices_range))
    print('length of all_indices_per_domain :', len(all_indices_per_domain))

    kmedoids = KMedoids(n_clusters=domain_samples, random_state=RANDOM_STATE)
    kmedoids.fit(features)

    split_entry['clustering_indices'] = kmedoids.medoid_indices_.tolist()
    # The whole non-excluded pool is stored (no sub-sampling to domain_samples).
    split_entry['random_indices'] = indices_range

    split_entry['clustering_class_strength'] = _class_counts(labels, split_entry['clustering_indices'])
    split_entry['random_class_strength'] = _class_counts(labels, split_entry['random_indices'])

Source domain: art_painting, cartoon, sketch, Target domain: photo
****************************************
Train: 7488, Val: 833, Test: 1670
source_train length 59
domain wise indices : 170
domain wise indices : 340
domain wise indices : 510
Domain: 0 Batch number: 0
Domain: 0 Batch number: 1
Domain: 0 Batch number: 2
Domain: 0 Batch number: 3
Domain: 0 Batch number: 4
Domain: 0 Batch number: 5
Domain: 0 Batch number: 6
Domain: 0 Batch number: 7
Domain: 0 Batch number: 8
Domain: 0 Batch number: 9
Domain: 0 Batch number: 10
Domain: 0 Batch number: 11
Domain: 0 Batch number: 12
Domain: 0 Batch number: 13
Domain: 0 Batch number: 14
Domain: 0 Batch number: 15
Domain: 0 Batch number: 16
Domain: 0 Batch number: 17
Domain: 0 Batch number: 18
Domain: 0 Batch number: 19
Domain: 0 Batch number: 20
Domain: 0 Batch number: 21
Domain: 0 Batch number: 22
Domain: 0 Batch number: 23
Domain: 0 Batch number: 24
Domain: 0 Batch number: 25
Domain: 0 Batch number: 26
Domain: 0 Batch number: 27
Domain: 0 B

In [15]:
# Display the per-domain index dictionary loaded by the selection cell above.
indice

{'art_painting': {'clustering_indices': [1170,
   825,
   218,
   217,
   1276,
   1471,
   533,
   874,
   1758,
   1854,
   627,
   1655,
   1888,
   1235,
   1092,
   177,
   1680,
   176,
   665,
   1351,
   1352,
   950,
   1222,
   279,
   498,
   1356,
   280,
   1132,
   1088,
   1780,
   140,
   1541,
   138,
   1645,
   596,
   945,
   1447,
   681,
   1585,
   1781,
   1370,
   1494,
   1818,
   43,
   730,
   1382,
   1743,
   1805,
   792,
   592,
   935,
   446,
   1546,
   105,
   822,
   55,
   1197,
   91,
   1191,
   397,
   578,
   704,
   1735,
   1919,
   927,
   1421,
   1420,
   1703,
   801,
   1725,
   1931,
   612,
   1417,
   929,
   64,
   1287,
   1761,
   77,
   78,
   1978,
   1690,
   1723,
   1676,
   420,
   1048,
   85,
   758,
   1897,
   1573,
   1986,
   1193,
   717,
   309,
   214,
   966,
   308,
   1793,
   1162,
   820,
   1267,
   1792,
   1309,
   697,
   103,
   582,
   204,
   201,
   981,
   1051,
   1626,
   10,
   340,
   1928,
   1687,

In [14]:
# True only when no index occurs more than once in the combined exclusion list.
len(all_indices_per_domain) == len(set(all_indices_per_domain))

False

In [None]:
# Display the full result dictionary built by the selection loop.
clustering_indices

In [None]:
# Number of 'random_indices' stored for the split named by the current source_domain.
len(clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices'])

In [None]:
# Make sure the output directory exists before saving. makedirs also creates
# missing parent directories, and exist_ok avoids an error if the directory
# appears between the check and the creation.
if os.path.exists(INDICES_DIR):
    print("Directory is already present!")
else:
    print("Making directory to save indices")
    os.makedirs(INDICES_DIR, exist_ok=True)

In [None]:
# Persist the collected results under the 'indices' key at INDICES_PATH.
torch.save({'indices': clustering_indices}, INDICES_PATH)

In [None]:
# Read the saved file back and take its 'indices' entry.
# NOTE(review): torch.load may unpickle arbitrary objects -- only load files you trust.
INDICES = torch.load(INDICES_PATH)['indices']

In [None]:
# List the keys stored in the reloaded dictionary.
INDICES.keys()

In [None]:
# Take a copy: the following cells call source_domain.pop(0), which would
# otherwise remove the first domain from the shared domain_map entry itself
# and change the result of every later get_domain("PACS") call.
source_domain = list(domain_map["PACS"])

In [None]:
# Show the current domain list.
source_domain

In [None]:
# Remove the first domain in place; the removed name is the cell's output.
source_domain.pop(0)

In [None]:
# Show the domain list after the removal.
source_domain

In [None]:
# Exploratory: list the attributes of the `source` object returned by random_split_dataloader.
dir(source)

In [None]:
# Exploratory: show the `samples` attribute of `source`.
source.samples

In [None]:
# True only when the saved 'random_indices' of the current split contain no duplicates.
len(INDICES['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices']) == len(set(INDICES['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices']))