In [1]:
#@title Don't forget to upload usps.h5

import numpy as np
from sklearn import metrics

def purity_score(y_true, y_pred):
    """Return the purity of a clustering with respect to reference labels.

    Adapted from https://stackoverflow.com/a/51672699/7947996.

    The contingency matrix of ``y_true`` against ``y_pred`` is built with
    scikit-learn; the largest count in each column is summed and divided by
    the total count. Scores close to 1 are good, scores close to 0 are bad.

    Args:
        y_true: reference labels, one per sample.
        y_pred: predicted cluster assignments, one per sample.

    Returns:
        The purity as a float (1 means every column of the matrix holds a
        single non-zero entry).
    """
    # Contingency (a.k.a. confusion) table of reference labels vs. predictions.
    table = metrics.cluster.contingency_matrix(y_true, y_pred)
    # Largest count per column, i.e. the dominant label count of each column.
    dominant_counts = table.max(axis=0)
    return dominant_counts.sum() / table.sum()

from sklearn.metrics.cluster import adjusted_rand_score # at most 1 (perfect match); ~0 for random labelings and can be slightly negative
from sklearn.metrics.cluster import normalized_mutual_info_score # in [0,1]; 0-bad,1-good

!pip install coclust
from coclust.evaluation.external import accuracy # in [0,1]; 0-bad,1-good

def get_data_20news():
  """Load the complete 20 Newsgroups corpus as dense TF-IDF features.

  Only scikit-learn is required by this loader.

  Returns:
    (data, target): ``data`` is the dense array obtained from a
    ``TfidfVectorizer(max_features=2000)`` fitted on every document of the
    ``subset="all"`` corpus; ``target`` is the matching label array
    returned by ``fetch_20newsgroups``.
  """
  from sklearn.datasets import fetch_20newsgroups
  from sklearn.feature_extraction.text import TfidfVectorizer

  corpus = fetch_20newsgroups(subset="all")
  documents = corpus.data
  target = corpus.target

  # Keep only the 2000 highest-ranked vocabulary terms so the dense matrix stays manageable.
  vectorizer = TfidfVectorizer(max_features=2000)
  tfidf = vectorizer.fit_transform(documents)
  # The callers in this notebook expect a dense array, not a sparse matrix.
  data = tfidf.toarray()

  return data, target


def get_data_mnist():
  """Load MNIST through Keras and return flattened, rescaled images with labels.

  The train and test splits returned by ``mnist.load_data()`` are joined
  (train first, then test) into a single dataset.

  Returns:
    (samples, real_labels): ``samples`` is a float32 array with one
    flattened image per row, pixel values divided by 255; ``real_labels``
    holds the corresponding labels in the same order.
  """
  import tensorflow as tf

  (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

  # Clustering is unsupervised here, so both splits are merged.
  all_images = np.concatenate((train_images, test_images))
  real_labels = np.concatenate((train_labels, test_labels))

  # One row per image, then rescale by 1/255 and store as float32.
  flat_images = all_images.reshape((all_images.shape[0], -1))
  samples = (flat_images / 255.).astype(np.float32)

  return samples, real_labels



def get_data_usps(path="./usps.h5"):
  """Load the USPS dataset from an HDF5 file, merging its train and test splits.

  Args:
    path: Location of the HDF5 file. Defaults to ``"./usps.h5"``, i.e. the
      file uploaded next to the notebook.

  Returns:
    (samples, real_labels): ``samples`` is the ``train/data`` array followed
    by the ``test/data`` array; ``real_labels`` is ``train/target`` followed
    by ``test/target``.

  Raises:
    FileNotFoundError: if ``path`` does not point to an existing file.
    KeyError: if the file lacks one of the expected groups or datasets.
  """
  import os

  # Fail early with an actionable message instead of a low-level HDF5 error.
  if not os.path.isfile(path):
    raise FileNotFoundError(
        f"USPS data file not found at {path!r}; upload usps.h5 before running this cell."
    )

  import h5py  # imported after the path check so a missing file is reported first

  with h5py.File(path, 'r') as hf:
    # Indexing (rather than .get) raises KeyError naming the missing entry
    # instead of silently yielding None.
    # [:] copies each dataset into memory before the file is closed.
    X_tr = hf['train']['data'][:]
    y_tr = hf['train']['target'][:]
    X_te = hf['test']['data'][:]
    y_te = hf['test']['target'][:]

  samples = np.concatenate((X_tr, X_te))
  real_labels = np.concatenate((y_tr, y_te))
  return samples, real_labels

original_data_name = "usps" # @param ["mnist", "20news", "usps"]

# Map every selectable dataset name to the function that loads it.
DATA_LOADERS = {
    "mnist": get_data_mnist,
    "20news": get_data_20news,
    "usps": get_data_usps,
}

# Reject unknown names here rather than failing later with a NameError on `samples`.
if original_data_name not in DATA_LOADERS:
    raise ValueError(
        f"Unknown dataset {original_data_name!r}; expected one of {sorted(DATA_LOADERS)}"
    )
samples, real_labels = DATA_LOADERS[original_data_name]()

k = len(np.unique(real_labels))  # number of clusters = number of distinct ground-truth labels
n_init = 10  # number of initialisations passed to the clustering models

Collecting coclust
  Downloading https://files.pythonhosted.org/packages/5d/44/ad5a69c7187c2b7bcf2c45596e9052811a3be52f4fcaa6709937c5146ee2/coclust-0.2.1.tar.gz
Building wheels for collected packages: coclust
  Building wheel for coclust (setup.py) ... [?25l[?25hdone
  Created wheel for coclust: filename=coclust-0.2.1-cp37-none-any.whl size=29871 sha256=5641b5e4eba70f6dc7fab9cbf4461af1abd5cd0a14a01a3119b84da1f5b4da75
  Stored in directory: /root/.cache/pip/wheels/cd/d7/68/df601d0b5f8b934cf890dc626c2271df381fb0c3e910b0a34e
Successfully built coclust
Installing collected packages: coclust
Successfully installed coclust-0.2.1




### Random

In [2]:
# Baseline: give every sample a uniformly random cluster id in [0, k).
# A fixed seed keeps the baseline scores reproducible across kernel restarts.
predicted_random = np.random.RandomState(0).randint(k,size=len(real_labels))

# Printed in order: purity, adjusted Rand index, normalized mutual information, accuracy.
for score in (purity_score, adjusted_rand_score, normalized_mutual_info_score, accuracy):
    print(score(real_labels,predicted_random))

0.16831576683157667
-7.939696487276653e-05
0.0017561598357265873
0.11271241127124113




### k-means

In [3]:
from sklearn.cluster import KMeans
import numpy as np
X = samples
# Fixed random_state so the centroid initialisations, and therefore the scores, are reproducible.
kmeans = KMeans(n_clusters=k,n_init=n_init,random_state=0).fit(X)
predicted_km = kmeans.predict(X)

# Printed in order: purity, adjusted Rand index, normalized mutual information, accuracy.
for score in (purity_score, adjusted_rand_score, normalized_mutual_info_score, accuracy):
    print(score(real_labels,predicted_km))

0.7387610238761024
0.5455530410934258
0.6270275336895192
0.6678855667885567




### GMM

In [4]:
import numpy as np
from sklearn.mixture import GaussianMixture
X = samples
# Fixed random_state so the mixture initialisations, and therefore the scores, are reproducible.
gm = GaussianMixture(n_components=k,n_init=n_init,random_state=0).fit(X)
predicted_gmm = gm.predict(X)

# Printed in order: purity, adjusted Rand index, normalized mutual information, accuracy.
for score in (purity_score, adjusted_rand_score, normalized_mutual_info_score, accuracy):
    print(score(real_labels,predicted_gmm))

0.6079802107980211
0.39829472426072704
0.5370877440055406
0.5526995052699505


