<h1> Multiple Kernel Transfer Clustering </h1>


In [1]:

from mktc import mktc


<h2> The function interface:  </h2>

<h2> mktc(X, bagX, bagY, bag_sidx, bag_len, list_sigmas, max_iter=100, n_init=10, n_jobs=1) </h2>

<br>

<h2> Function Parameters: </h2>
    
<br>

<h2> X : The dataset </h2>

<h2> bagX : Multi-Instance dataset </h2>

<h2> bagY : Weakly supervised labels provided to the Multi-Instance dataset </h2>

<h2> bag_sidx : Starting idx of each bag </h2>

<h2> bag_len : Length of each bag </h2>

<h2> list_sigmas : List of Gaussian kernel parameters </h2>

<h2> max_iter : Maximum number of iterations </h2>

<h2> n_init : Number of runs of MKTC </h2>

<h2> n_jobs : Number of parallel processor threads to compute distances and kernel similarities </h2>
    
<br>

<h2> Returned Variables: </h2>
    
<br>
    
<h2> data_mem : Cluster membership of dataset X </h2>

<h2> multi_instance_mem : Cluster membership of bagX </h2>
    
<h2> w : Multiple Kernel parameters </h2>
    
<h2> centers : Cluster centers estimated on bagX </h2>
    
<h2> costs : Objective Function values per iteration on MKMIKM </h2>
    
<h2> v_iter : Iteration number on which MKMIKM terminates </h2>
    
<br>
    

In [2]:

import numpy as np


# Gaussian kernel parameters handed to mktc as `list_sigmas`
list_sigmas = np.array([0.01, 0.05, 0.1, 1, 10, 50, 100])
# Number of kernel parameters in the list
m = list_sigmas.size
# Forwarded to mktc as `n_jobs`; -1 presumably means "use all available
# processors" (joblib convention) -- TODO confirm against mktc
n_jobs = -1


print('Gaussian kernel parameters :', list_sigmas)


Gaussian kernel parameters : [1.e-02 5.e-02 1.e-01 1.e+00 1.e+01 5.e+01 1.e+02]


In [3]:

# Load the scikit-learn digits dataset: features X and ground-truth labels y

dataset = 'digits'
from sklearn.datasets import load_digits
# Fetch features and labels with a single call instead of loading the dataset twice
X, y = load_digits(return_X_y=True)

# Preprocess the data: flatten all features of each sample into one vector,
# translate the data to remove negative values, then scale every feature
# by its maximum value
if X.ndim > 2:
    X = X.reshape(X.shape[0], -1)
# Translate whenever ANY entry is negative; the earlier `.sum() > 1` test
# silently skipped the translation when exactly one entry was negative
if (X < 0).any():
    X = X - X.min()
feature_max = X.max(axis=0)
# Features whose maximum is 0 would cause a division by zero; dividing by 1
# leaves them unchanged instead
feature_max[feature_max == 0] = 1
X = X / feature_max

print(dataset, 'dataset size:', X.shape)


digits dataset size: (1797, 64)


In [4]:

# Run MKTC on n_sets multi-instance subsets of the dataset,
# predict the cluster memberships on the full dataset for each subset,
# and display the per-set and mean ARI / NMI obtained


from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI


n_sets = 10
# Running sums of the scores; divided by n_sets after the loop
mean_multi_instance_ari = 0
mean_multi_instance_nmi = 0
mean_data_ari = 0
mean_data_nmi = 0


for i2 in range(n_sets):
    print(i2+1, '/', n_sets)
    
    # Read in, from the i2-th archive:
    #   the multi-instance subset of the dataset bagX (rows of X picked by 'bagX_idxs'),
    #   the weakly supervised bag labels bagY,
    #   the bag starting idx bag_sidx,
    #   the bag instance labels true_y ('all_instances_y').
    # The archive is opened in a `with` block so its file handle is closed
    # as soon as the arrays are read, instead of leaking one handle per set.
    bag_path = 'data_bags_npz/'+dataset+'/'+dataset+'_set'+str(i2)+'.npz'
    with np.load(bag_path) as bag_file:
        bagX = X[bag_file['bagX_idxs']]
        bagY = bag_file['bagy']
        bag_sidx = bag_file['bag_sidx']
        true_y = bag_file['all_instances_y']
    n = bagX.shape[0]
    n_clusters = bagY.shape[1]
    # Bag lengths: gaps between consecutive start indices; the last bag
    # runs from its start index to the end of bagX
    bag_len = np.hstack((bag_sidx[1:] - bag_sidx[0:-1], n - bag_sidx[-1]))

    
    # Run MKTC
    data_mem, multi_instance_mem, w, centers, costs, v_iter = mktc(
        X, bagX, bagY, bag_sidx, bag_len, list_sigmas, max_iter=100, n_init=10, n_jobs=n_jobs
    )

    
    # Calculate ARI and NMI: bag instances are scored against true_y,
    # the full dataset against y
    multi_instance_ari = ARI(true_y, multi_instance_mem)
    multi_instance_nmi = NMI(true_y, multi_instance_mem)
    data_ari = ARI(y, data_mem)
    data_nmi = NMI(y, data_mem)
    print('Data ARI:', data_ari, ', Data NMI:', data_nmi)
    print('Multi-Instance ARI:', multi_instance_ari, ', Multi-Instance NMI:', multi_instance_nmi)
    mean_multi_instance_ari += multi_instance_ari
    mean_data_ari += data_ari
    mean_multi_instance_nmi += multi_instance_nmi
    mean_data_nmi += data_nmi

    
print('Mean Data ARI:', mean_data_ari/n_sets, ', Mean Multi-Instance ARI:', mean_multi_instance_ari/n_sets)
print('Mean Data NMI:', mean_data_nmi/n_sets, ', Mean Multi-Instance NMI:', mean_multi_instance_nmi/n_sets)


1 / 10
break at 13
break at 8
break at 12
break at 10
break at 10
break at 8
break at 8
break at 13
break at 16
break at 10
Data ARI: 0.9264489938671637 , Data NMI: 0.9263948786094083
Multi-Instance ARI: 0.981932190701072 , Multi-Instance NMI: 0.9784735395348774
2 / 10
break at 8
break at 8
break at 7
break at 8
break at 9
break at 9
break at 8
break at 12
break at 7
break at 9
Data ARI: 0.9287996117288587 , Data NMI: 0.925827655079438
Multi-Instance ARI: 0.9546231706999228 , Multi-Instance NMI: 0.9502693534136666
3 / 10
break at 9
break at 10
break at 8
break at 9
break at 22
break at 10
break at 11
break at 11
break at 9
break at 12
Data ARI: 0.9150655535061447 , Data NMI: 0.9162445410209763
Multi-Instance ARI: 0.9630891841877544 , Multi-Instance NMI: 0.9613841752319181
4 / 10
break at 11
break at 9
break at 10
break at 7
break at 11
break at 12
break at 7
break at 11
break at 9
break at 9
Data ARI: 0.916374641466335 , Data NMI: 0.9167507227727185
Multi-Instance ARI: 0.97740879302984