In [1]:
!pip install scikit-learn-extra



## Code to pick initial samples using KMedoids and Random

In [2]:
import sys
sys.path.append('../')

In [15]:
import torch
import torch.nn as nn
import numpy as np
from sklearn_extra.cluster import KMedoids
from copy import deepcopy
import os
import random

In [16]:
device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu")

In [17]:
from dataloader.dataloader import random_split_dataloader
from util.util import *
from copy import deepcopy

In [18]:
def split_domain(domains, split_idx, print_domain=True):
    source_domain = deepcopy(domains)
    target_domain = [source_domain.pop(split_idx)]
    if print_domain:
        print('Source domain: ', end='')
        for domain in source_domain:
            print(domain, end=', ')
        print('Target domain: ', end='')
        for domain in target_domain:
            print(domain)
    return source_domain, target_domain
    
domain_map = {
    'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
    'VLCS': ['Caltech', 'Labelme', 'Pascal', 'Sun']
}

def get_domain(name):
    if name not in domain_map:
        raise ValueError('Name of dataset unknown %s' %name)
    return domain_map[name]

In [19]:
ROOT_DIR = "/home/arfeen/papers_code/dom_gen_aaai_2020"
data_root = f"{ROOT_DIR}/PACS/kfold/"
#domain_samples = 170
domain_samples = 255
RANDOM_STATE = 42

INDICES_DIR = f"{ROOT_DIR}/saved-indices"
INDICES_PATH = f"{INDICES_DIR}/mixed_domain_indices_255.pt"

exp_nums = 0

In [20]:
source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)

Source domain: art_painting, cartoon, sketch, Target domain: photo


In [21]:
model = get_model("caffenet", "general")(num_classes=7, num_domains=3, pretrained=True)
caffenet_model = model.eval()

Using Caffe AlexNet


In [22]:
clustering_indices = {}
for i in range(4):
    exp_nums = i
    source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)
    print("*"*40)
       
    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])] = {}
    
    source_train, source_val, target_test, source= random_split_dataloader(
    data="PACS", data_root=data_root, source_domain=source_domain, target_domain=target_domain,
    batch_size=128)
    all_indices_per_domain = []
    features = []
    labels = []
    indice = torch.load('/home/arfeen/papers_code/dom_gen_aaai_2020/dg_mmld-master/indices_final.pt')['indices']
    for i in range(len(source_domain)):
        sample_indices = indice[source_domain[i]]['random_indices']
        all_indices_per_domain.extend(sample_indices)
    for ind, batch in enumerate(source_train):

        print(f"Domain: {i} Batch number: {ind}")
        features.extend(caffenet_model.features(batch[0]))
        labels.extend(batch[1].numpy())
        
        
    features = [feature.data.numpy() for feature in features]
    indices_range = [i for i in range(len(features)) if i not in all_indices_per_domain]
    kmedoids = KMedoids(n_clusters=domain_samples, random_state=RANDOM_STATE)
    I = kmedoids.fit_predict(features)

    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_indices'] = kmedoids.medoid_indices_.tolist()
    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices'] = random.sample(indices_range , k=domain_samples)

    class_strength_0 = np.zeros(7)
    class_strength_1 = np.zeros(7)

    for indices in clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_indices']:
        class_strength_0[labels[indices]] += 1

    for indices in clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices']:
        class_strength_1[labels[indices]] += 1

    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['clustering_class_strength'] = class_strength_0
    clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_class_strength'] = class_strength_1

Source domain: art_painting, cartoon, sketch, Target domain: photo
****************************************
Train: 7488, Val: 833, Test: 1670
Domain: 2 Batch number: 0
Domain: 2 Batch number: 1
Domain: 2 Batch number: 2
Domain: 2 Batch number: 3
Domain: 2 Batch number: 4
Domain: 2 Batch number: 5
Domain: 2 Batch number: 6
Domain: 2 Batch number: 7
Domain: 2 Batch number: 8
Domain: 2 Batch number: 9
Domain: 2 Batch number: 10
Domain: 2 Batch number: 11
Domain: 2 Batch number: 12
Domain: 2 Batch number: 13
Domain: 2 Batch number: 14
Domain: 2 Batch number: 15
Domain: 2 Batch number: 16
Domain: 2 Batch number: 17
Domain: 2 Batch number: 18
Domain: 2 Batch number: 19
Domain: 2 Batch number: 20
Domain: 2 Batch number: 21
Domain: 2 Batch number: 22
Domain: 2 Batch number: 23
Domain: 2 Batch number: 24
Domain: 2 Batch number: 25
Domain: 2 Batch number: 26
Domain: 2 Batch number: 27
Domain: 2 Batch number: 28
Domain: 2 Batch number: 29
Domain: 2 Batch number: 30
Domain: 2 Batch number: 31
Doma

In [23]:
len(clustering_indices['mixed_{}_source_domain_indices'.format(str(source_domain)[1:-1])]['random_indices'])

255

In [24]:
if not os.path.exists(INDICES_DIR):
    print("Making directory to save indices")
    os.mkdir(INDICES_DIR)
else:
    print("Directory is already present!")

Directory is already present!


In [25]:
torch.save({'indices': clustering_indices}, INDICES_PATH)

In [None]:
INDICES = torch.load(INDICES_PATH)['indices']

In [None]:
INDICES.keys()

In [None]:
source_domain = domain_map["PACS"]

In [None]:
source_domain

In [None]:
source_train, source_val, target_test, source= random_split_dataloader_init(
data="PACS", data_root=data_root, source_domain=source_domain[:1], target_domain=target_domain,
batch_size=128)

In [None]:
dir(source)

In [None]:
source.samples

In [None]:
len(INDICES['art_painting']['random_indices'])