In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras import layers, optimizers, datasets, models, utils, losses, callbacks
import keras.backend as K
from scipy.spatial.distance import euclidean
from sklearn.cluster import KMeans
from tqdm import tqdm

%matplotlib inline

Using TensorFlow backend.


### Helper functions

In [2]:
def get_labels_from_clusters(centroids, data) -> np.array:
    """
        Assigns the labels to the images based on the distance from 
        the centroid
        
        Returns:
            array of labels 
    """
    
    # reminder to pass the right data
    assert data.shape[1] == 10 
    
    return np.array([np.argmin(np.square([euclidean(u=centroid, v=feature) 
                                         for centroid in centroids])) 
                    for feature in data])

In [3]:
# assign clusters function
def assign_clusters(population, labels):
    """
        Assigns the images to clusters based on their labels
        
        Returns:
            clusters (np.array):
                images that were labeled 5, will be under index 4 in this array
                
            counts (np.array):
                count of images in every cluster
    """
    
    # init the vars
    clusters = list()
    counts = list()
    
    # iterate over classes
    for i in range(10):
        clusters.append(population[labels==i])
        counts.append(len(population[labels==i]))
    return np.array(clusters), np.array(counts)

### Load the data

In [4]:
# load the mnist dataset
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=3) / 255

### Initialize the CNN

In [6]:
# - - - input image - - -
# 28x28x1
image_input = layers.Input(shape=(28,28,1))


# - - - CNN - - - 
# first convolution
# 14x14x32
conv_1 = layers.Conv2D(filters=32, kernel_size=(2,2), strides=(2,2), 
                       activation="relu", padding="valid", name="conv_1")(image_input)

# second convolution
# 7x7x64
conv_2 = layers.Conv2D(filters=64, kernel_size=(2,2), strides=(2,2), 
                       activation="relu", padding="valid", name="conv_2")(conv_1)
# - - - CNN - - - 


# - - - Adaptation Layers - - -
ada_3 = layers.Conv2D(filters=128, kernel_size=(2,2), strides=(2,2),
                     activation="relu", padding="same", name="ada_1")(conv_2)

ada_4 = layers.Conv2D(filters=10, kernel_size=(2,2), strides=(2,2),
                     activation="relu", padding="same", name="ada_2")(ada_3)
# - - - Adaptation Layers - - - 


# - - - Global Pool - - -
global_max_pool = layers.GlobalMaxPool2D()(ada_4)
# - - - Global Pool - - -


# - - - FC9 - - -
fc = layers.Dense(units=10, activation='relu', name='fc')(global_max_pool)
# - - - FC9 - - -


# - - - SOFTMAX - - -
softmax = layers.Dense(units=10, activation='softmax', name='softmax')(fc)
# - - - SOFTMAX - - - 

### Forward pass of the K randomly selected images

### This is where hell starts

In [7]:
# choose K - (10) random images out of the set to use as initial cluster centroids
index_list = np.arange(len(x_train))

# get 10 random indices
random_indices = np.random.choice(a=index_list, size=10, replace=False)

# get the corresponding 10 images
initial_random_images = x_train[random_indices]

In [8]:
# define the K.function to get the centroids
ccnn_function = K.function(inputs=[image_input], outputs=[softmax])

# get the initial centroids
initial_centroids = ccnn_function([initial_random_images])[0]

# get the labels of the initial centroids
# ARGMAX since softmax outputs the probabilities
ccnn_initial_labels = np.argmax(initial_centroids, 1)

# print the initial labels
print(ccnn_initial_labels)

[2 6 2 2 2 6 6 2 2 2]


In [9]:
# current centroids
random_indices

array([59492, 42410, 56477, 32314,  6661, 55999, 46247, 48428, 43234,
       31748])

### Forward pass of the rest of the data

In [10]:
# create the mask to drop the initial indices
mask = np.ones(shape=index_list.shape)
mask[random_indices] = False

# get the other images tensor
other_images = x_train[mask.astype(np.bool)]

In [11]:
# forward pass the rest of the images
other_images_features = ccnn_function([other_images])[0]

In [12]:
# get the predictions of the other images based on the CNN
other_images_predictions = np.argmax(other_images_features, 1)

In [13]:
# get the pseudo-ground truth labels
other_images_labels = get_labels_from_clusters(centroids=initial_centroids, data=other_images_features)

### Train the CNN

Now we can use the *labels* and the *predictions* to define the loss function

In [14]:
# one hot encode the labels
labels = utils.to_categorical(y=other_images_labels, num_classes=10)
predictions = utils.to_categorical(y=other_images_predictions, num_classes=10)

In [15]:
labels[:5]

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=float32)

In [16]:
predictions[:5]

array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [17]:
# define the keras model
model = models.Model(inputs=image_input, outputs=softmax)

# define metrics
metrics = ["accuracy"]

# define callbacks
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=1)
callbacks_ = [early_stopping]

# define loss
# predictions_ = K.constant(predictions)
# labels_ = K.constant(labels)
loss_ = losses.categorical_crossentropy

# compile the model
model.compile(loss=loss_, optimizer=optimizers.SGD(lr=0.001), metrics=metrics)

In [21]:
# fit the model
model.fit(x=other_images, y=labels, validation_split=0.20, epochs=10, verbose=1, callbacks=callbacks_)

Train on 47992 samples, validate on 11998 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x2afa99eaac8>

### Update the centroids

In [22]:
# we got the new features from the CNN
new_features = ccnn_function([other_images])[0]

# now we are going to assign them to the old centroids
new_labels = get_labels_from_clusters(centroids=initial_centroids, data=new_features)

In [56]:
# get the clusters and counts
clusters, counts = assign_clusters(labels=new_labels, population=other_images)

# replace zeros in the counts
counts[counts == 0] = 2

# get gammas
gammas = 1 / counts

In [44]:
gammas.shape

(10,)

In [57]:
clusters[0].shape

(387, 28, 28, 1)

In [41]:
new_centroids = np.multiply((1 - gammas), initial_centroids) + np.multiply(gammas, new_features)

ValueError: operands could not be broadcast together with shapes (10,10) (59990,10) 

In [45]:
gammas

array([2.58397933e-03, 2.64725347e-05, 6.30914826e-05, 3.86847195e-04,
       1.01832994e-03, 3.44827586e-02, 5.00000000e-01, 5.37634409e-04,
       1.91570881e-03, 5.00000000e-01])

In [None]:
get_labels(centroids=initial_centroids_features, data=new_features)

In [None]:
assign_clusters(new_features, new_features_labels_kmeans)

### Forward pass the images

In [None]:
for t in tqdm(range(10)):
    
    # get the new feature vector by forward pass of the CNN
    new_features = ccnn_function([x_train])[0]
    
    # form new clusters
    kmeans = KMeans(n_clusters=10).fit(new_features)
    
    # get the new labels
    labels_kmeans_onehot = utils.to_categorical(y=kmeans.labels_, num_classes=10)
    
    # fit the model
    model.fit(x=x_train, y=labels_kmeans_onehot, 
              validation_split=0.20, epochs=10, verbose=0, 
              callbacks=callbacks_)

In [None]:
random_indices = np.random.choice(np.arange(len(x_train)), replace=False, size=100)

In [None]:
random_images = x_train[random_indices]

In [None]:
true_labels = y_train[random_indices]

In [None]:
predicted_labels = np.argmax(ccnn_function([random_images])[0], 1)

In [None]:
predicted_labels

In [None]:
true_labels

In [None]:
true_labels == predicted_labels

In [None]:
plt.imshow(random_images[0].reshape(28,28))