In [11]:
!pip install scikit-learn-extra



## Pick initial samples using K-Medoids clustering and random sampling

In [12]:
import sys
# Add the parent directory to the module search path; presumably this is what
# makes the `dataloader` and `util` packages imported below resolvable -- confirm.
sys.path.append('../')

In [13]:
import torch
import torch.nn as nn
import numpy as np
from sklearn_extra.cluster import KMedoids

import os
import random

In [14]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [15]:
from dataloader.dataloader import random_split_dataloader
from util.util import *
from copy import deepcopy

In [16]:
def split_domain(domains, split_idx, print_domain=True):
    """Split `domains` into source domains and one held-out target domain.

    Args:
        domains: list of domain names; it is deep-copied, never modified.
        split_idx: index of the domain to hold out as the target.
        print_domain: when true, print the resulting split on one line.

    Returns:
        (source_domain, target_domain): the remaining domains in their
        original order, and a one-element list holding the held-out domain.
    """
    remaining = deepcopy(domains)
    held_out = [remaining.pop(split_idx)]
    if not print_domain:
        return remaining, held_out

    # Every source name is followed by ', ' (including the last one), and the
    # target label continues on the same line.
    source_text = ''.join(str(name) + ', ' for name in remaining)
    print(f'Source domain: {source_text}Target domain: {held_out[0]!s}')
    return remaining, held_out
    
# Domain names available for each supported dataset, keyed by dataset name.
domain_map = {
    'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
    'VLCS': ['Caltech', 'Labelme', 'Pascal', 'Sun']
}

def get_domain(name):
    """Return the list of domain names registered for dataset `name`.

    The returned list is the object stored in `domain_map`, not a copy.

    Raises:
        ValueError: if `name` is not a key of `domain_map`.
    """
    if name not in domain_map:
        # Wrap `name` in a 1-tuple so a tuple-valued name is formatted as one
        # value instead of being unpacked as several % arguments (which would
        # raise TypeError and mask the intended ValueError).
        raise ValueError('Name of dataset unknown %s' % (name,))
    return domain_map[name]

In [22]:
# Base directory; the data and saved-indices paths below are built relative to it.
ROOT_DIR = "../.."
# Passed as `data_root` to the data-loader factory.
data_root = f"{ROOT_DIR}/PACS_data/kfold/"
# Samples selected per domain: used both as the K-Medoids cluster count and as
# the size of the random draw.
domain_samples = 170
# Seed handed to KMedoids as `random_state`.
RANDOM_STATE = 42

# Directory and file the selected indices are written to with torch.save.
INDICES_DIR = f"{ROOT_DIR}/saved-indices"
INDICES_PATH = f"{INDICES_DIR}/indices_final.pt"

# Index (into the dataset's domain list) of the domain held out as the target.
exp_nums = 0

In [23]:
# Hold out the PACS domain at index `exp_nums` as the target; the rest are sources.
source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)

Source domain: art_painting, cartoon, sketch, Target domain: photo


In [24]:
# Build the "caffenet" network with 7 classes and 3 domains, requesting
# pretrained weights.
# NOTE(review): get_model is presumably provided by `from util.util import *` -- confirm.
model = get_model("caffenet", "general")(
        num_classes=7, num_domains=3, pretrained=True)
# The result of eval() is what the feature-extraction loop uses; presumably it
# is the same model switched to evaluation mode (as nn.Module.eval does) -- confirm.
caffenet_model = model.eval()

Using Caffe AlexNet


In [27]:
source_domain[0:1]

['art_painting']

In [28]:
# Select `domain_samples` initial samples for every domain (sources first, then
# the target) in two ways and record both selections in `clustering_indices`:
#   * 'clustering_indices' -- medoids of a K-Medoids clustering of the features
#   * 'random_indices'     -- a random draw without replacement
# together with the per-class counts of each selection.
NUM_CLASSES = 7  # same value as `num_classes` passed to get_model

# Dedicated seeded generator so the random baseline is reproducible, like the
# K-Medoids selection (which already receives RANDOM_STATE).
sampling_rng = random.Random(RANDOM_STATE)

clustering_indices = {}
for i, domain_name in enumerate(source_domain + target_domain):
    print("*"*40)
    current_domain = [domain_name]

    # NOTE(review): the indices stored below are positions in the iteration
    # order of `source_train`; they only identify dataset samples if that
    # loader does not shuffle -- confirm.
    source_train, source_val, target_test, source = random_split_dataloader_init(
        data="PACS", data_root=data_root, source_domain=current_domain,
        target_domain=target_domain, batch_size=128)

    features = []
    labels = []
    # Inference only: without no_grad every batch would keep its autograd
    # graph alive for the whole loop.
    with torch.no_grad():
        for ind, batch in enumerate(source_train):
            print(f"Domain: {i} Batch number: {ind}")
            features.append(caffenet_model.features(batch[0]).cpu().numpy())
            labels.extend(batch[1].numpy())

    # One row per sample; the reshape is a no-op when features are already 2-D.
    features = np.concatenate(features, axis=0)
    features = features.reshape(len(features), -1)
    labels = np.asarray(labels)

    kmedoids = KMedoids(n_clusters=domain_samples, random_state=RANDOM_STATE)
    kmedoids.fit(features)
    medoid_indices = kmedoids.medoid_indices_.tolist()
    random_indices = sampling_rng.sample(range(len(features)), k=domain_samples)

    clustering_indices[domain_name] = {
        'clustering_indices': medoid_indices,
        'random_indices': random_indices,
        # Number of selected samples per class label, stored as float arrays.
        'clustering_class_strength': np.bincount(
            labels[medoid_indices], minlength=NUM_CLASSES).astype(float),
        'random_class_strength': np.bincount(
            labels[random_indices], minlength=NUM_CLASSES).astype(float),
    }

****************************************
Train: 2048, Val: 0, Test: 1670
Domain: 0 Batch number: 0
Domain: 0 Batch number: 1
Domain: 0 Batch number: 2
Domain: 0 Batch number: 3
Domain: 0 Batch number: 4
Domain: 0 Batch number: 5
Domain: 0 Batch number: 6
Domain: 0 Batch number: 7
Domain: 0 Batch number: 8
Domain: 0 Batch number: 9
Domain: 0 Batch number: 10
Domain: 0 Batch number: 11
Domain: 0 Batch number: 12
Domain: 0 Batch number: 13
Domain: 0 Batch number: 14
Domain: 0 Batch number: 15
****************************************
Train: 2344, Val: 0, Test: 1670
Domain: 1 Batch number: 0
Domain: 1 Batch number: 1
Domain: 1 Batch number: 2
Domain: 1 Batch number: 3
Domain: 1 Batch number: 4
Domain: 1 Batch number: 5
Domain: 1 Batch number: 6
Domain: 1 Batch number: 7
Domain: 1 Batch number: 8
Domain: 1 Batch number: 9
Domain: 1 Batch number: 10
Domain: 1 Batch number: 11
Domain: 1 Batch number: 12
Domain: 1 Batch number: 13
Domain: 1 Batch number: 14
Domain: 1 Batch number: 15
Domain: 1 

In [29]:
clustering_indices

{'art_painting': {'clustering_indices': [1170,
   825,
   218,
   217,
   1276,
   1471,
   533,
   874,
   1758,
   1854,
   627,
   1655,
   1888,
   1235,
   1092,
   177,
   1680,
   176,
   665,
   1351,
   1352,
   950,
   1222,
   279,
   498,
   1356,
   280,
   1132,
   1088,
   1780,
   140,
   1541,
   138,
   1645,
   596,
   945,
   1447,
   681,
   1585,
   1781,
   1370,
   1494,
   1818,
   43,
   730,
   1382,
   1743,
   1805,
   792,
   592,
   935,
   446,
   1546,
   105,
   822,
   55,
   1197,
   91,
   1191,
   397,
   578,
   704,
   1735,
   1919,
   927,
   1421,
   1420,
   1703,
   801,
   1725,
   1931,
   612,
   1417,
   929,
   64,
   1287,
   1761,
   77,
   78,
   1978,
   1690,
   1723,
   1676,
   420,
   1048,
   85,
   758,
   1897,
   1573,
   1986,
   1193,
   717,
   309,
   214,
   966,
   308,
   1793,
   1162,
   820,
   1267,
   1792,
   1309,
   697,
   103,
   582,
   204,
   201,
   981,
   1051,
   1626,
   10,
   340,
   1928,
   1687,

In [30]:
# Make sure the output directory exists before saving the indices.
# isdir (rather than exists) keeps a regular file with the same name from being
# mistaken for the directory; makedirs also creates missing parents and, with
# exist_ok=True, tolerates the directory appearing between check and creation.
if not os.path.isdir(INDICES_DIR):
    print("Making directory to save indices")
    os.makedirs(INDICES_DIR, exist_ok=True)
else:
    print("Directory is already present!")

Directory is already present!


In [31]:
# Persist the per-domain selections under the 'indices' key of a single file.
torch.save({'indices': clustering_indices}, INDICES_PATH)

In [32]:
# Read the saved selections back from disk.
# NOTE(review): torch.load unpickles the file, so only load files you trust.
# The payload contains NumPy arrays, which recent PyTorch releases (where
# weights_only defaults to True) may refuse to load -- confirm, and pass
# weights_only=False if needed.
INDICES = torch.load(INDICES_PATH)['indices']

In [33]:
INDICES.keys()

dict_keys(['art_painting', 'cartoon', 'sketch', 'photo'])

In [None]:
# NOTE: this rebinds `source_domain` to all four PACS domains (target included),
# replacing the source-only list returned by split_domain. Cells that are
# re-run after this one will see the new value.
source_domain = domain_map["PACS"]

In [None]:
source_domain

In [9]:
INDICES

{'photo': {'clustering_indices': [921,
   250,
   2,
   1660,
   267,
   206,
   1657,
   341,
   1648,
   1617,
   10,
   11,
   1271,
   1065,
   357,
   198,
   369,
   378,
   1570,
   1563,
   1560,
   1541,
   1539,
   1276,
   192,
   420,
   1525,
   1519,
   28,
   1512,
   430,
   1500,
   1493,
   1484,
   1473,
   431,
   1451,
   481,
   441,
   1416,
   1356,
   1409,
   446,
   462,
   467,
   496,
   1375,
   525,
   173,
   575,
   580,
   165,
   1339,
   1312,
   1311,
   55,
   594,
   57,
   1303,
   24,
   598,
   1297,
   600,
   614,
   634,
   1269,
   640,
   641,
   1240,
   1218,
   1195,
   1186,
   660,
   1599,
   664,
   691,
   1170,
   727,
   729,
   741,
   744,
   748,
   756,
   767,
   518,
   1091,
   773,
   784,
   1040,
   1034,
   1031,
   1016,
   794,
   679,
   1006,
   999,
   818,
   969,
   129,
   854,
   960,
   865,
   870,
   951,
   871,
   233,
   949,
   834,
   872,
   874,
   889,
   117,
   904,
   946,
   1601,
   1530,
   88

In [None]:
# Build the loaders for the first entry of `source_domain` only, so the
# returned `source` object can be inspected in the following cells.
source_train, source_val, target_test, source= random_split_dataloader_init(
data="PACS", data_root=data_root, source_domain=source_domain[:1], target_domain=target_domain,
batch_size=128)

In [None]:
dir(source)

In [None]:
source.samples

In [39]:
len(INDICES['photo']['clustering_indices'])

170