In [None]:
import torch
import matplotlib.pyplot as plt
import os
import yaml
import sys
import numpy as np
from tqdm import tqdm
sys.path.append('../')
from datamodule.datamodule import select_data
from models.models import Classifier, CooperativeOpticalModelRemote
from scipy.spatial.distance import pdist, squareform

from sklearn import datasets, decomposition
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap

In [None]:
# Show which matplotlib styles are installed, then select the whitegrid style
# for every figure drawn below.
available_styles = plt.style.available
print(available_styles)
plt.style.use('seaborn-v0_8-whitegrid')

# Load in the data

In [None]:
# Parse the project configuration. The context manager closes the file handle
# as soon as parsing is done; a bare open() inside the call never closes it.
# NOTE(review): FullLoader can build more than plain YAML types. That is fine for
# a trusted local config, but yaml.safe_load is preferable if no custom tags are used.
with open('../../config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Override the path entries so they resolve relative to this notebook's location
# and point at the baseline data directory.
config['paths']['path_root'] = '../../'
config['paths']['path_data'] = 'data/baseline'

In [None]:
# Full paths of every file in the baseline data directory, in sorted order.
baseline_dir = os.path.join(config['paths']['path_root'], config['paths']['path_data'])
baseline_filenames = sorted(os.path.join(baseline_dir, entry) for entry in os.listdir(baseline_dir))

In [None]:
# Read each baseline sample file exactly once and pull all four fields out of
# it, instead of deserialising every file separately for each field.
baseline_bench_images = []
baseline_sim_images = []
baseline_ideal_images = []
baseline_targets = []
for f in tqdm(baseline_filenames):
    sample = torch.load(f, weights_only=True)
    # .detach().cpu() keeps .numpy() valid even if a stored tensor tracks
    # gradients or was saved from a GPU (same handling as the post-training data).
    baseline_bench_images.append(sample['bench_image'].squeeze().detach().cpu().numpy())
    baseline_sim_images.append(sample['sim_output'].squeeze().detach().cpu().numpy())
    baseline_ideal_images.append(sample['resampled_sample'].squeeze().detach().cpu().numpy())
    # The class label is the position of the largest entry of the stored target.
    baseline_targets.append(torch.argmax(sample['target']).cpu().numpy())

baseline_targets = np.asarray(baseline_targets).squeeze()
baseline_unique_targets = np.unique(baseline_targets)

In [None]:
# Switch the configured data directory to the post-training samples and list
# the full path of every file in it, in sorted order.
config['paths']['path_data'] = 'data/post_training'
pt_dir = os.path.join(config['paths']['path_root'], config['paths']['path_data'])
pt_filenames = sorted(os.path.join(pt_dir, entry) for entry in os.listdir(pt_dir))

In [None]:
# Read each post-training sample file exactly once and pull all four fields
# out of it, instead of deserialising every file separately for each field.
pt_bench_images = []
pt_sim_images = []
pt_ideal_images = []
pt_targets = []
for f in tqdm(pt_filenames):
    sample = torch.load(f, weights_only=True)
    # .detach().cpu() keeps .numpy() valid even if a stored tensor tracks
    # gradients or was saved from a GPU.
    pt_bench_images.append(sample['bench_image'].squeeze().detach().cpu().numpy())
    pt_sim_images.append(sample['sim_output'].squeeze().detach().cpu().numpy())
    pt_ideal_images.append(sample['resampled_sample'].squeeze().detach().cpu().numpy())
    # The class label is the position of the largest entry of the stored target.
    pt_targets.append(torch.argmax(sample['target']).cpu().numpy())

pt_targets = np.asarray(pt_targets).squeeze()
pt_unique_targets = np.unique(pt_targets)

# Load in the classifier

In [None]:
checkpoint_path = '../../results/classifier_baseline_bench_resampled_sample/version_0/checkpoints/last.ckpt'
# Restore the trained classifier in float64 on the CPU (create_feature_vectors
# feeds it float64 tensors). eval() switches layers such as dropout and batch
# norm to inference behaviour, so the single-image forward passes used for
# feature extraction are deterministic and do not update running statistics.
classifier = Classifier.load_from_checkpoint(checkpoint_path).double().cpu().eval()

# Populate the feature representations for the different images

In [None]:
def create_feature_vectors(classifier, images):
    """Run every image through ``classifier.feature_extractor`` and collect the outputs.

    Each NumPy image is squeezed, given two leading singleton axes, copied
    three times along axis 1, and cast to float64 before it is handed to the
    feature extractor.

    Args:
        classifier: object exposing a callable ``feature_extractor``.
        images: iterable of NumPy arrays (presumably 2-D single-channel images
            once squeezed -- TODO confirm against the data files).

    Returns:
        list holding one ``feature_extractor`` output per image, in input order.
    """
    def _to_model_input(array):
        # (..., H, W) -> squeezed plane -> two leading singleton axes
        plane = torch.from_numpy(array).squeeze()
        batched = plane.unsqueeze(0).unsqueeze(0)
        # Triplicate along axis 1 so the input has three identical channels.
        return torch.cat((batched, batched, batched), dim=1).double()

    return [classifier.feature_extractor(_to_model_input(array)) for array in tqdm(images)]

In [None]:
# Feature extraction only needs forward passes, so gradient tracking is disabled.
with torch.no_grad():
    (
        baseline_bench_feature_embeddings,
        baseline_sim_feature_embeddings,
        baseline_ideal_feature_embeddings,
    ) = [
        create_feature_vectors(classifier, image_set)
        for image_set in (baseline_bench_images, baseline_sim_images, baseline_ideal_images)
    ]

In [None]:
# Same extraction for the post-training bench and simulated images
# (forward passes only, so gradient tracking is disabled).
with torch.no_grad():
    pt_bench_feature_embeddings, pt_sim_feature_embeddings = [
        create_feature_vectors(classifier, image_set)
        for image_set in (pt_bench_images, pt_sim_images)
    ]

In [None]:
# Directory that holds the saved feature embeddings; created if it is missing.
path_results = '../../results/feature_embeddings/'
os.makedirs(name=path_results, exist_ok=True)

In [None]:
# Persist every embedding list under path_results, one file per list.
embeddings_to_save = {
    'baseline_bench_feature_embeddings.pt': baseline_bench_feature_embeddings,
    'baseline_sim_feature_embeddings.pt': baseline_sim_feature_embeddings,
    'baseline_ideal_feature_embeddings.pt': baseline_ideal_feature_embeddings,
    'pt_bench_feature_embeddings.pt': pt_bench_feature_embeddings,
    'pt_sim_feature_embeddings.pt': pt_sim_feature_embeddings,
}
for filename, embeddings in embeddings_to_save.items():
    torch.save(embeddings, os.path.join(path_results, filename))

In [None]:
# Read the saved embedding lists back from path_results.
(
    baseline_bench_feature_embeddings,
    baseline_sim_feature_embeddings,
    baseline_ideal_feature_embeddings,
    pt_bench_feature_embeddings,
    pt_sim_feature_embeddings,
) = [
    torch.load(os.path.join(path_results, filename), weights_only=True)
    for filename in (
        'baseline_bench_feature_embeddings.pt',
        'baseline_sim_feature_embeddings.pt',
        'baseline_ideal_feature_embeddings.pt',
        'pt_bench_feature_embeddings.pt',
        'pt_sim_feature_embeddings.pt',
    )
]

In [None]:
# Turn each list of per-image tensors into one NumPy array: every tensor is
# squeezed, detached, moved to the CPU and converted, then the rows are stacked.
# NOTE: this rebinds the same names, so the cell cannot be re-run on its own output.
baseline_bench_feature_embeddings = np.asarray(
    [vec.squeeze().detach().cpu().numpy() for vec in baseline_bench_feature_embeddings]
)
baseline_sim_feature_embeddings = np.asarray(
    [vec.squeeze().detach().cpu().numpy() for vec in baseline_sim_feature_embeddings]
)
baseline_ideal_feature_embeddings = np.asarray(
    [vec.squeeze().detach().cpu().numpy() for vec in baseline_ideal_feature_embeddings]
)
pt_bench_feature_embeddings = np.asarray(
    [vec.squeeze().detach().cpu().numpy() for vec in pt_bench_feature_embeddings]
)
pt_sim_feature_embeddings = np.asarray(
    [vec.squeeze().detach().cpu().numpy() for vec in pt_sim_feature_embeddings]
)

In [None]:
# Sanity check: display the shape of the stacked baseline bench embedding array.
baseline_bench_feature_embeddings.shape

# Colors

In [None]:
# 29-entry palette; color_indices picks ten of its entries.
colors = [
    '#E8ECFB', '#D9CCE3', '#D1BBD7', '#CAACCB', '#BA8DB4', '#AE76A3',
    '#AA6F9E', '#994F88', '#882E72', '#1965B0', '#437DBF', '#5289C7',
    '#6195CF', '#7BAFDE', '#4EB265', '#90C987', '#CAE0AB', '#F7F056',
    '#F7CB45', '#F6C141', '#F4A736', '#F1932D', '#EE8026', '#E8601C',
    '#E65518', '#DC050C', '#A5170E', '#72190E', '#42150A',
]

# 10-entry palette; color2_indices maps position i straight to entry i.
colors2 = [
    '#a6cee3', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99',
    '#e31a1c', '#fdbf6f', '#ff7f00', '#cab2d6', '#6a3d9a',
]

color_indices = [9, 10, 14, 15, 17, 18, 21, 24, 26, 28]
color2_indices = list(range(10))
len(color_indices)

# PCA comparison

In [None]:
# Fit a 2-component PCA on the baseline ideal-image embeddings only; the other
# embedding sets are projected with this same fitted basis.
pca = PCA(n_components=2)
pca.fit(X=baseline_ideal_feature_embeddings)

In [None]:
# Project every embedding set with the PCA fitted on the baseline ideal embeddings.
(
    baseline_bench_pca,
    baseline_sim_pca,
    baseline_ideal_pca,
    pt_bench_pca,
    pt_sim_pca,
) = [
    pca.transform(embeddings)
    for embeddings in (
        baseline_bench_feature_embeddings,
        baseline_sim_feature_embeddings,
        baseline_ideal_feature_embeddings,
        pt_bench_feature_embeddings,
        pt_sim_feature_embeddings,
    )
]

In [None]:
# Rebuild the class labels (argmax of each stored target) and the set of
# distinct labels for both datasets. This repeats the label extraction done
# when the images were loaded.
baseline_targets = np.asarray(
    [torch.argmax(torch.load(path, weights_only=True)['target']).numpy() for path in tqdm(baseline_filenames)]
).squeeze()
baseline_unique_targets = np.unique(baseline_targets)

pt_targets = np.asarray(
    [torch.argmax(torch.load(path, weights_only=True)['target']).numpy() for path in tqdm(pt_filenames)]
).squeeze()
pt_unique_targets = np.unique(pt_targets)

In [None]:
# 2x3 grid of PCA projections. Top row: baseline data. Bottom row:
# post-training data. Columns: ideal / simulated / bench image embeddings.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panel_titles = ("Ideal image embeddings", "Simulated image embeddings", "Bench image embeddings")

# Top row: one scatter series per class, same colour in all three panels.
for target in baseline_unique_targets:
    indices = np.where(baseline_targets == target)[0]
    ideal_transform_values = baseline_ideal_pca[indices]
    sim_transform_values = baseline_sim_pca[indices]
    bench_transform_values = baseline_bench_pca[indices]
    color = colors2[color2_indices[target]]

    for panel, values in zip(axes[0], (ideal_transform_values, sim_transform_values, bench_transform_values)):
        panel.scatter(values[:, 0], values[:, 1], color=color, label=target)

# Bottom row: post-training simulated and bench embeddings.
for target in pt_unique_targets:
    indices = np.where(pt_targets == target)[0]
    # Only bench and simulated embeddings exist for the post-training data, so
    # the reference panel re-plots the baseline ideal points of the same class.
    baseline_indices = np.where(baseline_targets == target)[0]
    ideal_transform_values = baseline_ideal_pca[baseline_indices]
    sim_transform_values = pt_sim_pca[indices]
    bench_transform_values = pt_bench_pca[indices]
    color = colors2[color2_indices[target]]

    for panel, values in zip(axes[1], (ideal_transform_values, sim_transform_values, bench_transform_values)):
        panel.scatter(values[:, 0], values[:, 1], color=color, label=target)

# Shared cosmetics, applied once per panel. The loop variable is deliberately
# not called `axes`/`ax`, so the array of axes is not overwritten.
for row in axes:
    for panel, title in zip(row, panel_titles):
        panel.set_title(title)
        panel.set_aspect('equal')
        panel.legend(frameon=True, framealpha=1)
        # Identical limits everywhere so the panels can be compared directly.
        panel.set_xlim(-5.5, 5.5)
        panel.set_ylim(-5.5, 5.5)
plt.tight_layout()
fig.savefig('feature_space_comparison.png')

# UMAP comparison

In [None]:
# Fit UMAP on the baseline ideal-image embeddings only; the fixed random_state
# makes the layout repeatable between runs.
baseline_ideal_reducer = umap.UMAP(n_neighbors=5, random_state=42)
umap_transform = baseline_ideal_reducer.fit(baseline_ideal_feature_embeddings)

In [None]:
# Project every embedding set with the UMAP model fitted on the baseline ideal embeddings.
(
    baseline_ideal_umap,
    baseline_bench_umap,
    baseline_sim_umap,
    pt_bench_umap,
    pt_sim_umap,
) = [
    umap_transform.transform(embeddings)
    for embeddings in (
        baseline_ideal_feature_embeddings,
        baseline_bench_feature_embeddings,
        baseline_sim_feature_embeddings,
        pt_bench_feature_embeddings,
        pt_sim_feature_embeddings,
    )
]

In [None]:
# 2x3 grid of UMAP projections. Top row: baseline data. Bottom row:
# post-training data. Columns: ideal / simulated / bench image embeddings.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panel_titles = ("Ideal image embeddings", "Simulated image embeddings", "Bench image embeddings")

# Top row: one scatter series per class, same colour in all three panels.
for target in baseline_unique_targets:
    indices = np.where(baseline_targets == target)[0]
    ideal_transform_values = baseline_ideal_umap[indices]
    sim_transform_values = baseline_sim_umap[indices]
    bench_transform_values = baseline_bench_umap[indices]
    color = colors2[color2_indices[target]]

    for panel, values in zip(axes[0], (ideal_transform_values, sim_transform_values, bench_transform_values)):
        panel.scatter(values[:, 0], values[:, 1], color=color, label=target)

# Bottom row: post-training simulated and bench embeddings.
for target in pt_unique_targets:
    indices = np.where(pt_targets == target)[0]
    # Only bench and simulated embeddings exist for the post-training data, so
    # the reference panel re-plots the baseline ideal points of the same class.
    baseline_indices = np.where(baseline_targets == target)[0]
    ideal_transform_values = baseline_ideal_umap[baseline_indices]
    sim_transform_values = pt_sim_umap[indices]
    bench_transform_values = pt_bench_umap[indices]
    color = colors2[color2_indices[target]]

    for panel, values in zip(axes[1], (ideal_transform_values, sim_transform_values, bench_transform_values)):
        panel.scatter(values[:, 0], values[:, 1], color=color, label=target)

# Shared cosmetics, applied once per panel. The loop variable is deliberately
# not called `axes`/`ax`, so the array of axes is not overwritten.
for row in axes:
    for panel, title in zip(row, panel_titles):
        panel.set_title(title)
        panel.set_aspect('equal')
        panel.legend(frameon=True, framealpha=1)
        # Identical limits everywhere so the panels can be compared directly.
        panel.set_xlim(-10, 20)
        panel.set_ylim(-10, 20)
plt.tight_layout()

In [None]:
# Refit UMAP, this time on the post-training bench embeddings themselves.
# NOTE: this rebinds `umap_transform`, replacing the model fitted on the
# baseline ideal embeddings.
pt_bench_reducer = umap.UMAP(n_neighbors=5, random_state=42)
umap_transform = pt_bench_reducer.fit(pt_bench_feature_embeddings)

In [None]:
# Project the post-training bench embeddings with the current `umap_transform`.
# NOTE: this overwrites the `pt_bench_umap` computed earlier with the UMAP model
# fitted on the baseline ideal embeddings.
pt_bench_umap = umap_transform.transform(pt_bench_feature_embeddings)


In [None]:
# Single-panel scatter of the post-training bench embeddings in UMAP space,
# one series per class.
fig,ax = plt.subplots(1,1, figsize=(5,5))

for target in pt_unique_targets:
        indices = np.where(pt_targets == target)[0]
        bench_transform_values = pt_bench_umap[indices]
        color_idx = color2_indices[target]
        color = colors2[color_idx]

        # Plot the bench points selected above. The previous code read
        # `ideal_transform_values`, a stale leftover from an earlier cell, so
        # every class drew the same unrelated points under this title.
        x_vals = bench_transform_values[:,0]
        y_vals = bench_transform_values[:,1]
        ax.scatter(x_vals, y_vals, color=color, label = target)
        ax.set_title("Bench image embeddings")