In [20]:
import sys

# Make modules two directories above the working directory importable
# (contrastive_training, dataloaders). The guard keeps the cell idempotent:
# re-running it does not keep appending duplicate entries to sys.path.
if "../.." not in sys.path:
    sys.path.append("../..")

In [21]:
# Experiment configuration: checkpoint directory, encoder backbone name, dataset key.
model_path, backbone, dataset = (
    '../../trained_models/simsiam_50_skiPan',
    'resnet50',
    'skiPose',
)

# Imports

In [22]:
import os
import torch
import re
import random
from torchvision.io import read_image

import matplotlib.pyplot as plt
import torchvision.transforms as T

## SimSiam method

In [23]:
from contrastive_training.simsiam.model import get_siam_net

# Clustering

I extract the representations from the trained model and run the K-means clustering algorithm (number of clusters = 8) on these representations. I reduce the dimensionality of the representations with the PCA and LDA algorithms and plot the resulting clusters.
For comparison, I use the representations extracted from the encoder (before the projection head), after the projection head, and after the prediction head.

## Setup

In [24]:
from dataloaders.datasets import cluster_datasets

In [25]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [26]:
def get_cluster_data(batch_size, dataset='panoptic', dataset_dir='../../datasets', image_size=(128, 128)):
    """Build the dataset and the (unshuffled) loader used for clustering.

    Args:
        batch_size: Number of items per batch yielded by the loader.
        dataset: Key into ``cluster_datasets`` selecting the dataset class.
        dataset_dir: Root directory passed to the dataset constructor.
        image_size: Target size passed to ``T.Resize`` after ``T.ToTensor``.

    Returns:
        ``(cluster_dataset, cluster_loader)``.
    """
    transforms = T.Compose([T.ToTensor(), T.Resize(size=image_size)])

    cluster_dataset = cluster_datasets[dataset](transforms, dataset_dir=dataset_dir)

    # shuffle=False is load-bearing: cluster labels computed from the loader's
    # output are later matched back to dataset items by position.
    cluster_loader = torch.utils.data.DataLoader(cluster_dataset, batch_size, shuffle=False)

    return cluster_dataset, cluster_loader

In [27]:
def extract_representations(path, cluster_loader, load=True):
    """Run the SimSiam network over ``cluster_loader`` and collect its three outputs.

    Args:
        path: Checkpoint file holding a state dict for ``get_siam_net(backbone)``.
        cluster_loader: Iterable of batches; each batch is indexed with ``'image'``.
        load: If False, the network returned by ``get_siam_net`` is used as-is
            and ``path`` is not read.

    Returns:
        Three NumPy arrays ``(base, proj, pred)`` holding the network's three
        forward outputs, concatenated over all batches in loader order (used
        downstream as encoder / projection-head / prediction-head features).

    Raises:
        ValueError: If ``cluster_loader`` yields no batches.
    """
    net = get_siam_net(backbone)

    if load:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        net.load_state_dict(torch.load(path, map_location=device))

    net.to(device)
    net.eval()

    base, proj, pred = [], [], []

    with torch.no_grad():
        for inputs in cluster_loader:
            # Tensor.to() is NOT in-place: rebind so the batch actually lives on `device`.
            images = inputs['image'].to(device)
            x, j, d = net(images)
            # Move each batch to the CPU immediately so device memory does not grow
            # with the size of the dataset.
            base.append(x.cpu())
            proj.append(j.cpu())
            pred.append(d.cpu())

    if not base:
        raise ValueError('cluster_loader yielded no batches')

    return torch.cat(base).numpy(), torch.cat(proj).numpy(), torch.cat(pred).numpy()

In [28]:
from sklearn.cluster import KMeans


def kmeans_algorithm(features, n_clusters=8, random_state=0):
    """Fit K-means to ``features`` and return the cluster label assigned to each sample.

    ``random_state`` seeds K-means so repeated runs give the same labels.
    """
    fitted = KMeans(n_clusters=n_clusters, random_state=random_state).fit(features)
    return fitted.labels_

In [29]:
#import LDA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
#import PCA
from sklearn.decomposition import PCA

def reduce_dim_LDA(features, labels, n_components=2):
    """Project ``features`` to ``n_components`` dimensions with LDA fitted on ``labels``."""
    return LDA(n_components=n_components).fit(features, labels).transform(features)

def reduce_dim_PCA(features, n_components=2):
    """Project ``features`` onto their first ``n_components`` principal components."""
    return PCA(n_components=n_components).fit(features).transform(features)

In [30]:
def plot_clusters(clusters, features, title):
    """Scatter-plot the first two columns of ``features``, coloured by cluster label.

    Args:
        clusters: Integer cluster label for each row of ``features``.
        features: 2-D array; only columns 0 and 1 are plotted.
        title: Figure title.
    """
    palette = [
        '#F8512E', '#F8F82E', '#40F82E', '#2EC1F8',
        '#6B2EF8', '#D92EF8', '#731642', '#092040',
    ]
    labels = [int(c) for c in clusters]

    if all(0 <= c < len(palette) for c in labels):
        cluster_colors = [palette[c] for c in labels]
    else:
        # More clusters than hand-picked colours (e.g. n_clusters > 8): fall back
        # to a 20-colour qualitative colormap instead of failing with a KeyError.
        cmap = plt.get_cmap('tab20')
        cluster_colors = [cmap(c % cmap.N) for c in labels]

    fig, ax = plt.subplots()
    ax.scatter(features[:, 0], features[:, 1], c=cluster_colors)
    ax.set_title(title)
    ax.set_xlabel('component 1')
    ax.set_ylabel('component 2')
    plt.show()

In [31]:
from sklearn.metrics import silhouette_score


def cluster(model_path, load=True, n_clusters=8, n_components=2, dataset='panoptic', random_state=0, batch_size=30):
    """Cluster and visualise the three kinds of representation produced by the model.

    For each network output (encoder, projection head, prediction head), this
    runs K-means, prints the silhouette score and plots the clusters. The plots
    use the features reduced to ``n_components`` dimensions with PCA, and with
    LDA fitted on the K-means labels.

    Args:
        model_path: Checkpoint path forwarded to ``extract_representations``.
        load: Whether to load the checkpoint (False = network without trained weights).
        n_clusters: Number of K-means clusters.
        n_components: Output dimensionality of PCA/LDA; only the first two are plotted.
            sklearn's LDA requires it to be at most ``min(n_clusters - 1, n_features)``.
        dataset: Dataset key forwarded to ``get_cluster_data``.
        random_state: Seed for K-means.
        batch_size: Batch size of the feature-extraction loader.

    Returns:
        13-tuple ``(cluster_set, base, proj, pred, labels_base, labels_proj,
        labels_pred, lda_base, lda_proj, lda_pred, pca_base, pca_proj, pca_pred)``.
    """
    cluster_set, cluster_loader = get_cluster_data(batch_size, dataset)

    base, proj, pred = extract_representations(model_path, cluster_loader, load)

    # (name used in the silhouette message, plot title, features)
    heads = [
        ('encoder', 'Encoder features', base),
        ('projection head', 'Projection head features', proj),
        ('prediction head', 'Prediction head features', pred),
    ]

    labels, pca, lda = [], [], []
    for _, _, features in heads:
        head_labels = kmeans_algorithm(features, n_clusters, random_state)
        labels.append(head_labels)
        pca.append(reduce_dim_PCA(features, n_components))
        lda.append(reduce_dim_LDA(features, head_labels, n_components))

    for (name, _, features), head_labels in zip(heads, labels):
        score = silhouette_score(features, head_labels)
        print("Silhouette score for the {} features: {}".format(name, score))

    # Suffix the reduction method: otherwise the PCA and LDA plots of the same
    # head carry identical titles and cannot be told apart.
    for method, reduced in (('PCA', pca), ('LDA', lda)):
        for (_, title, _), head_labels, points in zip(heads, labels, reduced):
            plot_clusters(head_labels, points, '{} ({})'.format(title, method))

    return (cluster_set, base, proj, pred, *labels, *lda, *pca)

In [None]:
# Pick the checkpoint with the highest epoch number among files named
# exactly `epoch_<N>.pt` in `model_path` (the same format used to build `path`).
checkpoint_pattern = re.compile(r'^epoch_(\d+)\.pt$')
saved_epochs = [
    int(match.group(1))
    for match in map(checkpoint_pattern.match, os.listdir(model_path))
    if match
]
if not saved_epochs:
    raise FileNotFoundError('No epoch_<N>.pt checkpoints found in {}'.format(model_path))
epoch = max(saved_epochs)

path = os.path.join(model_path, 'epoch_{:d}.pt'.format(epoch))
print(path)

## Before training

In [None]:
# Baseline before SimSiam training: same pipeline, but the checkpoint is not loaded.
# Seed torch so any random weight initialisation inside get_siam_net, and therefore
# this baseline, is reproducible across runs.
torch.manual_seed(0)
baseline_results = cluster(path, load=False, n_clusters=8, n_components=2, dataset=dataset, random_state=0)

## Trained model
Visualization based on the model pretrained with a learning rate of 0.0001 for the base encoder.

In [None]:
# Full analysis on the trained checkpoint; every output is kept for the
# per-cluster inspection below.
(
    cluster_set,
    base_features, proj_features, pred_features,
    labels_base, labels_proj, labels_pred,
    lda_base, lda_proj, lda_pred,
    pca_base, pca_proj, pca_pred,
) = cluster(path, load=True, n_clusters=8, n_components=2, dataset=dataset, random_state=0)

### Clustering results

In [35]:
def plot_one_cluster(dataset, cluster_ind, labels, n_images=9, seed=None):
    """Show a grid of randomly chosen images that were assigned to one cluster.

    Args:
        dataset: Indexable dataset whose items are indexed with ``'image'``.
        cluster_ind: Cluster label to display.
        labels: Per-item cluster labels, aligned by position with ``dataset``.
        n_images: Maximum number of images to show (default 9, a 3x3 grid).
        seed: Optional seed for reproducible sampling; ``None`` uses the global
            ``random`` module state.

    Raises:
        ValueError: If ``n_images`` is smaller than 1.
    """
    import math

    if n_images < 1:
        raise ValueError('n_images must be at least 1, got {}'.format(n_images))

    members = [i for i, label in enumerate(labels) if label == cluster_ind]
    if not members:
        print('Cluster {} is empty; nothing to show.'.format(cluster_ind))
        return

    sampler = random if seed is None else random.Random(seed)
    # Sample without replacement so no image appears twice; small clusters show all members.
    chosen = sampler.sample(members, min(n_images, len(members)))

    n_cols = math.ceil(math.sqrt(n_images))
    n_rows = math.ceil(n_images / n_cols)

    fig = plt.figure(figsize=(10, 10))
    fig.suptitle('Cluster {} ({} images)'.format(cluster_ind, len(members)))
    for slot, idx in enumerate(chosen):
        ax = fig.add_subplot(n_rows, n_cols, slot + 1)
        # permute(1, 2, 0) assumes channel-first (C, H, W) tensors, as produced by
        # T.ToTensor; imshow expects channel-last.
        ax.imshow(dataset[idx]['image'].permute(1, 2, 0))
        ax.axis('off')

    plt.show()

#### Base encoder
Clustering based on features extracted from the base encoder

In [None]:
# One image grid per K-means cluster of the encoder features. Iterate over the
# labels actually present instead of hard-coding 0..7 so this follows n_clusters.
for cluster_ind in sorted(set(labels_base)):
    print('Cluster {}:'.format(cluster_ind))
    plot_one_cluster(cluster_set, cluster_ind, labels_base)

#### Projection head
Clustering based on features extracted from the projection head

In [None]:
# One image grid per K-means cluster of the projection-head features. Iterate over
# the labels actually present instead of hard-coding 0..7 so this follows n_clusters.
for cluster_ind in sorted(set(labels_proj)):
    print('Cluster {}:'.format(cluster_ind))
    plot_one_cluster(cluster_set, cluster_ind, labels_proj)

#### Prediction head
Clustering based on features extracted from the prediction head

In [None]:
# One image grid per K-means cluster of the prediction-head features. Iterate over
# the labels actually present instead of hard-coding 0..7 so this follows n_clusters.
for cluster_ind in sorted(set(labels_pred)):
    print('Cluster {}:'.format(cluster_ind))
    plot_one_cluster(cluster_set, cluster_ind, labels_pred)