In [None]:
from __future__ import absolute_import, division, print_function
import sys
import collections 

import matplotlib.pyplot as plt
import numpy as np
import scipy
from sklearn.neighbors import NearestNeighbors
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar100, cifar10
from tqdm import tqdm
from IPython.display import clear_output

from scan_tf.models.train import pretext_training, pretext_training_bis, semantic_clustering_training 
from scan_tf.models.resnet import *
import scan_tf.utils.utils as utils
import scan_tf.utils.augmentations as augmentations


# Enable on-demand GPU memory allocation so TensorFlow does not grab all device
# memory up front. TensorFlow raises RuntimeError when memory growth is changed
# after the runtime is initialised (e.g. on re-running this cell); in that case
# the error is only reported.
memory_growth = True
if memory_growth:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        for gpu_device in gpus:
            tf.config.experimental.set_memory_growth(gpu_device, True)
    except RuntimeError as err:
        print(err)

In [None]:
# Model configuration
img_width, img_height, img_num_channels = 32, 32, 3
no_epochs = 100
optimizer = tf.keras.optimizers.Adam()
validation_split = 0.2
verbosity = 1
# Number of training images kept (subset for fast experiments);
# set to None to use the whole training split.
n_train_samples = 5000

# Load CIFAR-10 data
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
if n_train_samples is not None:
    X_train = X_train[:n_train_samples]
    y_train = y_train[:n_train_samples]
# Raw (pre-cast) training images, kept under their original name.
input_train = X_train

# Derive the class count from both splits *before* one-hot encoding so that
# y_train and y_test always get the same width, even if the training subset
# happens not to contain the highest class id.
num_classes = int(max(y_train.max(), y_test.max())) + 1
y_train = to_categorical(y_train, num_classes=num_classes)
y_test = to_categorical(y_test, num_classes=num_classes)

# Determine shape of the data
input_shape = (img_width, img_height, img_num_channels)

# Cast to float32 WITHOUT rescaling: pixel values keep their original scale
# (presumably 0-255 from cifar10.load_data); later cells display images via
# .astype("uint8"), which relies on that scale.
X_train = X_train.astype(np.float32)
X_test = X_test.astype(np.float32)

print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_test shape : {X_test.shape}")
print(f"y_test shape : {y_test.shape}")

In [None]:
# ResNet-18 backbone configured with n_output=128 (presumably the embedding
# size). It is built with a batch-agnostic input spec so summary() can report
# layer shapes before any data is seen.
backbone_input_spec = (None, img_width, img_height, img_num_channels)
backbone_model = resnet_18(n_output=128)
backbone_model.build(input_shape=backbone_input_spec)
backbone_model.summary()

<h3>Training for minimizing the Rotation Loss</h3>

In [None]:
# Either reuse previously saved pretext weights, or train the backbone on the
# rotation pretext task. pretext_training_bis receives save_path, presumably
# where it writes the trained weights.
pretext_model_save_path = "pretext_task_rotnet_correct.h5"
train_pretext = True
if not train_pretext:
    pretext_model = backbone_model
    pretext_model.load_weights(pretext_model_save_path)
else:
    pretext_model = pretext_training_bis(
        backbone_model,
        X_train,
        y_train,
        epochs=50,
        save_path=pretext_model_save_path,
    )

In [None]:
# Sanity check: rotate a few images by 180 degrees with scipy and compare
# visually against the original and against two successive np.rot90 turns.
images = X_train[:10]
degrees_to_rotate = np.random.choice([180], size=10).astype(int)
rotated_images = np.array([
    scipy.ndimage.rotate(img, angle, axes=(0, 1))
    for img, angle in zip(images, degrees_to_rotate)
])
print(rotated_images.shape)

first_image = images[0]
panels = (
    rotated_images[0].astype("uint8"),        # scipy.ndimage.rotate result
    first_image.astype("uint8"),              # untouched image
    np.rot90(first_image.astype("uint8"), k=2),  # reference 180-degree rotation
)
for panel in panels:
    plt.imshow(panel)
    plt.show()

<h3>Finding Nearest Neighbors (for debugging)</h3>

In [None]:
def find_neighbor_consistancy(pretext_model, images, labels, n_neighbors=5, plot=False, n_examples=4):
    """Measure how often an image's nearest neighbours share its ground-truth class.

    Neighbour lists are obtained from ``utils.CLusteringNN`` built on
    ``pretext_model``. The first entry of every neighbour list is treated as the
    query (assumes CLusteringNN returns the query as its own closest neighbour —
    TODO confirm), and its class is compared with the remaining entries.

    Args:
        pretext_model: model handed to ``utils.CLusteringNN``.
        images: images to search over; shown in a preview grid via ``.astype("uint8")``.
        labels: one-hot labels, one row per image.
        n_neighbors: neighbour-list length requested from CLusteringNN, query
            included. Must be at least 2.
        plot: if True, also plot the distribution of the consistency scores.
        n_examples: number of neighbour lists shown in the preview grid.

    Returns:
        List with one float per image: the fraction of its neighbours (excluding
        the first entry) whose class matches the first entry's class.

    Raises:
        ValueError: if ``n_neighbors < 2`` (there would be nothing to compare).
    """
    if n_neighbors < 2:
        raise ValueError(f"n_neighbors must be >= 2, got {n_neighbors}")

    nn = utils.CLusteringNN(pretext_model, n_neighbors=n_neighbors)
    nn.fit(images)
    nn_indexes = nn.get_neighbors_indexes(images)

    # Preview grid: one row per query, one column per neighbour-list entry.
    # squeeze=False keeps `axes` 2-D even for a single row/column.
    preview = nn_indexes[:n_examples]
    if len(preview):
        fig, axes = plt.subplots(len(preview), n_neighbors, squeeze=False)
        for row, cluster_indexes in enumerate(preview):
            for col, im_i in enumerate(cluster_indexes):
                axes[row, col].axis('off')
                axes[row, col].imshow(images[im_i].astype("uint8"))
        fig.tight_layout()
        plt.show()

    # Share of neighbours agreeing with the query's class. (An alternative
    # metric would be the majority-class share of the whole neighbour list.)
    true_classes = np.argmax(labels, axis=1)
    consistancies = []
    for cluster_indexes in nn_indexes:
        cluster_classes = true_classes[cluster_indexes]
        consistancies.append(np.mean(cluster_classes[1:] == cluster_classes[0]))

    if plot:
        # sns.distplot is deprecated; histplot with a KDE overlay and density
        # scaling is its documented replacement.
        sns.histplot(consistancies, kde=True, stat="density")
        plt.xlabel(f"Consistancy of {n_neighbors-1} nearest neighbors")
        plt.show()
    return consistancies

# 6 entries per neighbour list = the query plus its 5 nearest neighbours
# (assuming the query is returned first — see the docstring).
consistancies = find_neighbor_consistancy(pretext_model, X_train, y_train, n_neighbors=6, plot=True)
print(f"Correct Number of pairs: {np.mean(consistancies)*100:.2f}%")

## K-means clustering (an alternative to semantic clustering)

In [None]:
from sklearn.cluster import KMeans

# Embed the training images with the pretext model, then group the embeddings
# into as many K-means clusters as there are one-hot label columns.
train_embeddings = pretext_model.predict(X_train)
kmeans = KMeans(n_clusters=y_train.shape[1], random_state=0)
kmeans.fit(train_embeddings)
predicted_clusters = kmeans.predict(train_embeddings)

## Semantic Clustering

In [None]:
# Either train the semantic-clustering head on top of the pretext model, or
# rebuild the same architecture and load previously saved weights.
semantic_clustering_model_save_path = "semantic_clustering_task.h5"
train_semantic_clustering = False
if train_semantic_clustering:
    semantic_clustering_model = semantic_clustering_training(
        pretext_model, X_train, y_train, epochs=20,
        save_path=semantic_clustering_model_save_path,
    )
else:
    # The rebuilt head must match the saved checkpoint. This assumes it was
    # trained with one cluster per ground-truth class — TODO confirm.
    num_clusters = y_train.shape[1]
    semantic_clustering_model = add_classification_layer(pretext_model, num_clusters)
    # Use a dedicated name so the config-cell `input_shape` (H, W, C) is not
    # overwritten with a batched shape (hidden state on re-runs).
    batched_input_shape = (None,) + X_train.shape[1:]
    semantic_clustering_model.build(batched_input_shape)
    semantic_clustering_model.load_weights(semantic_clustering_model_save_path)


## Hungarian algorithm to match clusters with labels

In [None]:
# Per-cluster scores for every training image; the hard assignment is the
# highest-scoring column of each row.
cluster_probability_associations = semantic_clustering_model.predict(X_train)
predicted_clusters = cluster_probability_associations.argmax(axis=1)

In [None]:
from scipy.optimize import linear_sum_assignment

# Match predicted clusters to ground-truth labels with the Hungarian algorithm,
# maximising the number of samples whose cluster maps to their true label.
true_labels = np.argmax(y_train, axis=1)
num_labels = y_train.shape[1]
# Cluster ids index rows directly, so size the table by the largest id rather
# than by the number of distinct ids (sparse ids would overflow/misalign rows).
num_clusters = int(predicted_clusters.max()) + 1

# Integer-aligned bins so each cluster id / label gets exactly one bar.
bins = np.arange(max(num_clusters, num_labels) + 1) - 0.5
fig, ax = plt.subplots()
ax.hist(true_labels, bins=bins, alpha=0.5, label="Label")
ax.hist(predicted_clusters, bins=bins, alpha=0.5, label="Cluster")
ax.set_xlabel("Label / cluster id")
ax.set_ylabel("Number of samples")
ax.legend()
plt.show()

# frequencies[c, l] = number of samples in cluster c whose true label is l.
frequencies = np.zeros((num_clusters, num_labels))
np.add.at(frequencies, (predicted_clusters, true_labels), 1)
# linear_sum_assignment minimises cost, so negate counts to maximise matches.
cost_matrix = -frequencies

row_ind, col_ind = linear_sum_assignment(cost_matrix)
for cluster_id, label_id in zip(row_ind, col_ind):
    print(f"Cluster {cluster_id} matched with label {label_id}")

correct_assigned = frequencies[row_ind, col_ind].sum()
print(f"Accuracy: {correct_assigned/np.sum(frequencies)*100:.2f}%")

# Lookup table cluster id -> matched label. Clusters left unmatched (only
# possible when there are more clusters than labels) map to -1.
cluster_to_label = np.full(num_clusters, -1, dtype=int)
cluster_to_label[row_ind] = col_ind
predicted_labels = cluster_to_label[predicted_clusters]