In [1]:
!pip install scikit-learn-extra



## Code to pick initial samples using KMedoids and Random

In [2]:
import sys
sys.path.append('../')

In [3]:
import torch
import torch.nn as nn
import numpy as np
from sklearn_extra.cluster import KMedoids

import os
import random

In [4]:
device = torch.device("cuda:" + str(0) if torch.cuda.is_available() else "cpu")

In [5]:
from dataloader.dataloader_aaai2020 import *
from util.util import *
from copy import deepcopy

In [6]:
def split_domain(domains, split_idx, print_domain=True):
    source_domain = deepcopy(domains)
    target_domain = [source_domain.pop(split_idx)]
    if print_domain:
        print('Source domain: ', end='')
        for domain in source_domain:
            print(domain, end=', ')
        print('Target domain: ', end='')
        for domain in target_domain:
            print(domain)
    return source_domain, target_domain
    
domain_map = {
    'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
    'VLCS': ['Caltech', 'Labelme', 'Pascal', 'Sun']
}

def get_domain(name):
    if name not in domain_map:
        raise ValueError('Name of dataset unknown %s' %name)
    return domain_map[name]

In [7]:
ROOT_DIR = "../"
data_root = f"{ROOT_DIR}/PACS/kfold/"
domain_samples = 170
RANDOM_STATE = 42

INDICES_DIR = f"{ROOT_DIR}"
INDICES_PATH = f"{INDICES_DIR}/dg_mmld-master/indices_final_170_june2021.pt"

exp_nums = 3
INDICES_DIR

'../'

In [8]:
source_domain, target_domain = split_domain(get_domain("PACS"), exp_nums)

Source domain: photo, art_painting, cartoon, Target domain: sketch


In [9]:
model = get_model("caffenet", "general")(
        num_classes=7, num_domains=3, pretrained=True)
caffenet_model = model.eval()

Using Caffe AlexNet


In [10]:
source_domain[0:1]

['photo']

In [15]:
clustering_indices = {}
for i in range(4):
    print("*"*40)
    if i==3:
        current_domain = target_domain
    else:
        current_domain = source_domain[i:i+1]
    
    clustering_indices[current_domain[0]] = {}
    
    source_train, source_val, target_test = random_split_dataloader_init(
        data="PACS", data_root=data_root, source_domain=current_domain, target_domain=target_domain, batch_size=128)
    
    #print(list(source_train))
    features = []
    labels = []
    for ind, batch in enumerate(source_train):
        data, lbl = batch
        print(f"Domain: {i} Batch number: {ind}")
        features.extend(caffenet_model.features(batch[0]))
        labels.extend(batch[1].numpy())
    
    features = [feature.data.numpy() for feature in features]
    print("len of features:", len(features))

    kmedoids = KMedoids(n_clusters=domain_samples, random_state=RANDOM_STATE)
    I = kmedoids.fit_predict(features)

    clustering_indices[current_domain[0]]['clustering_indices'] = kmedoids.medoid_indices_.tolist()
    clustering_indices[current_domain[0]]['random_indices'] = random.sample(range(0, len(features)) , k=domain_samples)

    class_strength_0 = np.zeros(7)
    class_strength_1 = np.zeros(7)

    for indices in clustering_indices[current_domain[0]]['clustering_indices']:
        class_strength_0[labels[indices]] += 1

    for indices in clustering_indices[current_domain[0]]['random_indices']:
        class_strength_1[labels[indices]] += 1

    clustering_indices[current_domain[0]]['clustering_class_strength'] = class_strength_0
    clustering_indices[current_domain[0]]['random_class_strength'] = class_strength_1

****************************************
Train: 1503, Val: 167, Test: 3929
Domain: 0 Batch number: 0
Domain: 0 Batch number: 1
Domain: 0 Batch number: 2
Domain: 0 Batch number: 3
Domain: 0 Batch number: 4
Domain: 0 Batch number: 5
Domain: 0 Batch number: 6
Domain: 0 Batch number: 7
Domain: 0 Batch number: 8
Domain: 0 Batch number: 9
Domain: 0 Batch number: 10
Domain: 0 Batch number: 11
len of features: 1503
****************************************
Train: 1843, Val: 205, Test: 3929
Domain: 1 Batch number: 0
Domain: 1 Batch number: 1
Domain: 1 Batch number: 2
Domain: 1 Batch number: 3
Domain: 1 Batch number: 4
Domain: 1 Batch number: 5
Domain: 1 Batch number: 6
Domain: 1 Batch number: 7
Domain: 1 Batch number: 8
Domain: 1 Batch number: 9
Domain: 1 Batch number: 10
Domain: 1 Batch number: 11
Domain: 1 Batch number: 12
Domain: 1 Batch number: 13
Domain: 1 Batch number: 14
len of features: 1843
****************************************
Train: 2109, Val: 235, Test: 3929
Domain: 2 Batch number

In [16]:
clustering_indices

{'photo': {'clustering_indices': [715,
   1068,
   654,
   165,
   1418,
   967,
   952,
   246,
   946,
   1024,
   925,
   144,
   278,
   1444,
   291,
   299,
   309,
   615,
   892,
   1238,
   110,
   1058,
   829,
   1365,
   424,
   425,
   815,
   1317,
   1064,
   1490,
   787,
   72,
   573,
   1277,
   744,
   706,
   559,
   558,
   1342,
   520,
   1281,
   174,
   303,
   547,
   1111,
   545,
   187,
   47,
   537,
   722,
   532,
   760,
   531,
   525,
   553,
   1102,
   1336,
   454,
   639,
   59,
   740,
   741,
   1476,
   592,
   1341,
   506,
   554,
   689,
   1413,
   1117,
   1483,
   1272,
   574,
   486,
   483,
   482,
   479,
   468,
   78,
   1084,
   34,
   1306,
   1408,
   1353,
   154,
   927,
   433,
   664,
   692,
   428,
   917,
   629,
   913,
   23,
   230,
   1367,
   1249,
   97,
   394,
   680,
   1371,
   631,
   1291,
   838,
   1460,
   307,
   546,
   1443,
   368,
   911,
   623,
   366,
   907,
   113,
   1167,
   1126,
   106,
   899

In [None]:
if not os.path.exists(INDICES_DIR):
    print("Making directory to save indices")
    os.mkdir(INDICES_DIR)
else:
    print("Directory is already present!")

In [None]:
torch.save({'indices': clustering_indices}, INDICES_PATH)

In [None]:
INDICES = torch.load(INDICES_PATH)['indices']

In [None]:
INDICES.keys()

In [None]:
source_domain = domain_map["PACS"]

In [None]:
source_domain

In [None]:
INDICES

In [None]:
source_train, source_val, target_test, source= random_split_dataloader_init(
data="PACS", data_root=data_root, source_domain=source_domain[:1], target_domain=target_domain,
batch_size=128)

In [None]:
dir(source)

In [None]:
source.samples

In [None]:
len(INDICES['photo']['clustering_indices'])