In [4]:
#@title Don't forget to upload usps.h5

import numpy as np
from sklearn import metrics

def purity_score(y_true, y_pred): # from https://stackoverflow.com/a/51672699/7947996; in [0,1]; 0-bad,1-good
    # compute contingency matrix (also called confusion matrix)
    contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred)
    # return purity
    return np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix) 

from sklearn.metrics.cluster import adjusted_rand_score # in [0,1]; 0-bad,1-good
from sklearn.metrics.cluster import normalized_mutual_info_score # in [0,1]; 0-bad,1-good

!pip install coclust
from coclust.evaluation.external import accuracy # in [0,1]; 0-bad,1-good

def get_data_20news():
  import tensorflow as tf
  from sklearn.datasets import fetch_20newsgroups
  from sklearn.feature_extraction.text import TfidfVectorizer

  _20news = fetch_20newsgroups(subset="all")
  data = _20news.data
  target = _20news.target

  vectorizer = TfidfVectorizer(max_features=2000)
  data = vectorizer.fit_transform(data)
  data = data.toarray()

  return data, target


def get_data_mnist():
  import tensorflow as tf
  mnist = tf.keras.datasets.mnist
  (x_train, y_train),(x_test, y_test) = mnist.load_data()

  x_train = np.concatenate((x_train,x_test))
  y_train = np.concatenate((y_train,y_test))

  real_labels = y_train

  # # indices = np.isin(y_train,range(number_of_dist))
  # x_train = x_train[indices]
  # y_train = y_train[indices]

  samples = (x_train.reshape((x_train.shape[0],-1))/255.).astype(np.float32)
  
  return samples, real_labels

def get_data_mnist5():
  import tensorflow as tf
  mnist = tf.keras.datasets.mnist
  (x_train, y_train),(x_test, y_test) = mnist.load_data()

  x_train = np.concatenate((x_train,x_test))
  y_train = np.concatenate((y_train,y_test))

  indices = y_train < 5
  x_train = x_train[indices]
  y_train = y_train[indices]

  real_labels = y_train

  samples = (x_train.reshape((x_train.shape[0],-1))/255.).astype(np.float32)
  
  return samples, real_labels

def get_data_fmnist():
  import tensorflow as tf
  mnist = tf.keras.datasets.fashion_mnist
  (x_train, y_train),(x_test, y_test) = mnist.load_data()

  x_train = np.concatenate((x_train,x_test))
  y_train = np.concatenate((y_train,y_test))

  real_labels = y_train

  # # indices = np.isin(y_train,range(number_of_dist))
  # x_train = x_train[indices]
  # y_train = y_train[indices]

  samples = (x_train.reshape((x_train.shape[0],-1))/255.).astype(np.float32)
  
  return samples, real_labels

def get_data_usps():
  import h5py
  path = "./usps.h5"
  with h5py.File(path, 'r') as hf:
    train = hf.get('train')
    X_tr = train.get('data')[:]
    y_tr = train.get('target')[:]
    test = hf.get('test')
    X_te = test.get('data')[:]
    y_te = test.get('target')[:]

  samples = np.concatenate((X_tr,X_te))
  real_labels = np.concatenate((y_tr,y_te))
  return samples, real_labels

original_data_name = "mnist5" # @param ["mnist", "mnist5", "fmnist", "20news", "usps"]

if original_data_name == "mnist":
    samples, real_labels = get_data_mnist()
elif original_data_name == "mnist5":
    samples, real_labels = get_data_mnist5()
elif original_data_name == "fmnist":
    samples, real_labels = get_data_fmnist()
elif original_data_name == "20news":
    samples, real_labels = get_data_20news()
elif original_data_name == "usps":
    samples, real_labels = get_data_usps()
  
k = len(np.unique(real_labels))
n_init = 10



### Random

In [5]:
predicted_random = np.random.randint(k,size=len(real_labels))

print(purity_score(real_labels,predicted_random))
print(adjusted_rand_score(real_labels,predicted_random))
print(normalized_mutual_info_score(real_labels,predicted_random))
print(accuracy(real_labels,predicted_random))

0.22042815167203023
-6.413886038376716e-05
5.886327645035204e-05
0.20282636071078775




### k-means

In [6]:
from sklearn.cluster import KMeans
import numpy as np
X = samples
kmeans = KMeans(n_clusters=k,n_init=n_init).fit(X)
predicted_km = kmeans.predict(X)

print(purity_score(real_labels,predicted_km))
print(adjusted_rand_score(real_labels,predicted_km))
print(normalized_mutual_info_score(real_labels,predicted_km))
print(accuracy(real_labels,predicted_km))

0.8810969637610186
0.7325198572443044
0.7099687984462205
0.8810969637610186




In [7]:
import sklearn.metrics
matrix = sklearn.metrics.cluster.contingency_matrix(real_labels, predicted_km)
print(matrix)
print(matrix/matrix.sum(axis=1, keepdims=True))

[[ 115  387   21 6132  248]
 [  17   39 7701    0  120]
 [ 380  596  786   94 5134]
 [ 176 6024  401   45  495]
 [6495    5  268   15   41]]
[[1.66594234e-02 5.60625815e-02 3.04215558e-03 8.88309431e-01
  3.59264088e-02]
 [2.15818205e-03 4.95112352e-03 9.77656468e-01 0.00000000e+00
  1.52342262e-02]
 [5.43633763e-02 8.52646638e-02 1.12446352e-01 1.34477825e-02
  7.34477825e-01]
 [2.46464081e-02 8.43579331e-01 5.61546002e-02 6.30163843e-03
  6.93180227e-02]
 [9.51787808e-01 7.32708089e-04 3.92731536e-02 2.19812427e-03
  6.00820633e-03]]


### GMM

In [8]:
import numpy as np
from sklearn.mixture import GaussianMixture
X = samples
gm = GaussianMixture(n_components=k,n_init=n_init).fit(X)
predicted_gmm = gm.predict(X)

print(purity_score(real_labels,predicted_gmm))
print(adjusted_rand_score(real_labels,predicted_gmm))
print(normalized_mutual_info_score(real_labels,predicted_gmm))
print(accuracy(real_labels,predicted_gmm))

0.5678466489436127
0.2639756041554489
0.3513609018936295
0.5678466489436127




In [9]:
import sklearn.metrics
matrix = sklearn.metrics.cluster.contingency_matrix(real_labels, predicted_km)
print(matrix)
print(matrix/matrix.sum(axis=1, keepdims=True))

[[ 115  387   21 6132  248]
 [  17   39 7701    0  120]
 [ 380  596  786   94 5134]
 [ 176 6024  401   45  495]
 [6495    5  268   15   41]]
[[1.66594234e-02 5.60625815e-02 3.04215558e-03 8.88309431e-01
  3.59264088e-02]
 [2.15818205e-03 4.95112352e-03 9.77656468e-01 0.00000000e+00
  1.52342262e-02]
 [5.43633763e-02 8.52646638e-02 1.12446352e-01 1.34477825e-02
  7.34477825e-01]
 [2.46464081e-02 8.43579331e-01 5.61546002e-02 6.30163843e-03
  6.93180227e-02]
 [9.51787808e-01 7.32708089e-04 3.92731536e-02 2.19812427e-03
  6.00820633e-03]]
