In [1]:
!pip install scikit-learn-extra



## Code to pick initial samples using KMedoids and Random

In [2]:
import sys
sys.path.append('../')

In [3]:
import torch
import torch.nn as nn
import numpy as np
from sklearn_extra.cluster import KMedoids
from copy import deepcopy
import os
import random

In [4]:
device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu")

In [5]:
from dataloader.dataloader import random_split_dataloader
from util.util import *
from copy import deepcopy

In [6]:
def split_domain(domains, split_idx, print_domain=True):
    source_domain = deepcopy(domains)
    target_domain = [source_domain.pop(split_idx)]
    if print_domain:
        print('Source domain: ', end='')
        for domain in source_domain:
            print(domain, end=', ')
        print('Target domain: ', end='')
        for domain in target_domain:
            print(domain)
    return source_domain, target_domain
    
domain_map = {
    'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
    'VLCS': ['Caltech', 'Labelme', 'Pascal', 'Sun']
}

def get_domain(name):
    if name not in domain_map:
        raise ValueError('Name of dataset unknown %s' %name)
    return domain_map[name]

In [7]:
ROOT_DIR = "/home/arfeen/papers_code/dom_gen_aaai_2020"
data_root = f"{ROOT_DIR}/PACS/kfold/"
#domain_samples = 170
domain_samples = 255
RANDOM_STATE = 42

INDICES_DIR = f"{ROOT_DIR}/saved-indices"
INDICES_PATH = f"{INDICES_DIR}/mixed_domain_clustering_indices_255.pt"

exp_nums = 3

In [8]:
source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)

Source domain: photo, art_painting, cartoon, Target domain: sketch


In [9]:
model = get_model("caffenet", "general")(num_classes=7, num_domains=3, pretrained=True)
caffenet_model = model.eval()

Using Caffe AlexNet


In [10]:
clustering_indices = {}
for i in range(4):
    exp_nums = i
    source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)
    print("*"*40)
       
    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])] = {}
    
    source_train, source_val, target_test, source= random_split_dataloader(
    data="PACS", data_root=data_root, source_domain=source_domain, target_domain=target_domain,
    batch_size=128)
    all_indices_per_domain = []
    features = []
    labels = []
    indice = torch.load('/home/arfeen/papers_code/dom_gen_aaai_2020/dg_mmld-master/indices_final.pt')['indices']
    for i in range(len(source_domain)):
        sample_indices = indice[source_domain[i]]['clustering_indices']
        all_indices_per_domain.extend(sample_indices)
    for ind, batch in enumerate(source_train):

        print(f"Domain: {i} Batch number: {ind}")
        features.extend(caffenet_model.features(batch[0]))
        labels.extend(batch[1].numpy())
        
        
    features = [feature.data.numpy() for feature in features]
    indices_range = [i for i in range(len(features)) if i not in all_indices_per_domain]
    
    kmedoids = KMedoids(n_clusters=domain_samples, random_state=RANDOM_STATE)
    I = kmedoids.fit_predict(features)

    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_indices'] = kmedoids.medoid_indices_.tolist()
    #clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices'] = random.sample(indices_range , k=domain_samples)

    class_strength_0 = np.zeros(7)
    class_strength_1 = np.zeros(7)

    for indices in clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_indices']:
        class_strength_0[labels[indices]] += 1

    #for indices in clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices']:
        #class_strength_1[labels[indices]] += 1

    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_class_strength'] = class_strength_0
    #clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_class_strength'] = class_strength_1

Source domain: art_painting, cartoon, sketch, Target domain: photo
****************************************
length of domain wise samples 170
length of domain wise samples 170
length of domain wise samples 170
length of mixed samples 255
length of all combined samples 765
Train: 688, Val: 77, Test: 1670
Domain: 2 Batch number: 0
Domain: 2 Batch number: 1
Domain: 2 Batch number: 2
Domain: 2 Batch number: 3
Domain: 2 Batch number: 4
Domain: 2 Batch number: 5
Source domain: photo, cartoon, sketch, Target domain: art_painting
****************************************
length of domain wise samples 170
length of domain wise samples 170
length of domain wise samples 170
length of mixed samples 255
length of all combined samples 765
Train: 688, Val: 77, Test: 2048
Domain: 2 Batch number: 0
Domain: 2 Batch number: 1
Domain: 2 Batch number: 2
Domain: 2 Batch number: 3
Domain: 2 Batch number: 4
Domain: 2 Batch number: 5
Source domain: photo, art_painting, sketch, Target domain: cartoon
***********

In [12]:
clustering_indices

{"mixed_'art_painting', 'cartoon', 'sketch'_source_domain_indices": {'clustering_indices': [72,
   458,
   348,
   352,
   339,
   460,
   354,
   335,
   333,
   331,
   466,
   318,
   366,
   291,
   287,
   481,
   274,
   263,
   259,
   256,
   20,
   254,
   372,
   608,
   247,
   499,
   500,
   27,
   240,
   229,
   442,
   424,
   583,
   648,
   220,
   595,
   211,
   515,
   650,
   652,
   528,
   193,
   192,
   530,
   33,
   656,
   182,
   176,
   171,
   166,
   389,
   158,
   157,
   659,
   438,
   155,
   545,
   149,
   392,
   674,
   600,
   113,
   444,
   443,
   95,
   91,
   89,
   566,
   86,
   69,
   84,
   410,
   175,
   647,
   409,
   75,
   349,
   2,
   78,
   407,
   80,
   81,
   614,
   671,
   347,
   1,
   271,
   87,
   88,
   66,
   670,
   340,
   92,
   618,
   353,
   431,
   96,
   63,
   98,
   99,
   334,
   101,
   404,
   103,
   104,
   8,
   561,
   62,
   591,
   109,
   613,
   430,
   447,
   439,
   327,
   115,
   553,
   4

In [13]:
len(clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_indices'])

255

In [14]:
if not os.path.exists(INDICES_DIR):
    print("Making directory to save indices")
    os.mkdir(INDICES_DIR)
else:
    print("Directory is already present!")

Directory is already present!


In [15]:
torch.save({'indices': clustering_indices}, INDICES_PATH)

In [16]:
INDICES = torch.load(INDICES_PATH)['indices']

In [17]:
INDICES.keys()

dict_keys(["mixed_'art_painting', 'cartoon', 'sketch'_source_domain_indices", "mixed_'photo', 'cartoon', 'sketch'_source_domain_indices", "mixed_'photo', 'art_painting', 'sketch'_source_domain_indices", "mixed_'photo', 'art_painting', 'cartoon'_source_domain_indices"])

In [20]:
len(INDICES["mixed_'art_painting', 'cartoon', 'sketch'_source_domain_indices"]['clustering_indices'])

255

In [None]:
len(INDICES["mixed_'art_painting', 'cartoon', 'sketch'_source_domain_indices"]['clustering_indices'])==len(set(INDICES["mixed_'art_painting', 'cartoon', 'sketch'_source_domain_indices"]['clustering_indices']))

In [None]:
dir(source)

In [None]:
source.samples

In [None]:
len(set(INDICES['art_painting']['clustering_indices'])) == len(INDICES['art_painting']['clustering_indices'])