# Analysis of the pretrained facenet

## Imports

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import rasterio as rio
import sklearn as sk
import scipy as sc
import seaborn as sns

In [2]:
import tqdm
#from tqdm import tqdm
from tqdm.notebook import tqdm # for notebooks
tqdm.pandas()

In [3]:
import pickle
import gc
def collectAll():
    """Run ``gc.collect`` on generations 0, 1 and 2 (in that order).

    Returns the sum of the three values reported by ``gc.collect``.
    """
    return sum(gc.collect(generation) for generation in (0, 1, 2))

In [4]:
#from pandarallel import pandarallel
#pandarallel.initialize(progress_bar=True)

In [5]:
#import itables
#from itables import show

In [6]:
from collections import namedtuple
import itertools
from itertools import product

In [7]:
from skimage.feature import graycomatrix
from skimage.feature import graycoprops
from math import pi

In [8]:
import PIL as pil
from PIL import Image, ImageDraw

In [9]:
#import papermill as pm

In [10]:
import math

In [11]:
import matplotlib
matplotlib.rcParams["image.origin"] = 'upper'

In [12]:
import os

In [13]:
import functools

In [None]:
import tensorflow as tf

In [None]:
import torch
import torchvision

In [None]:
from einops import *

In [None]:
import cv2

## Load in the pretrained facenet

In [None]:
import model
from model import face_model

In [None]:
import importlib
model = importlib.reload(model)
face_model = model.face_model

In [None]:
class Params:
    """Plain attribute container handed to ``face_model``.

    Attributes set on every instance:
        image_size: 250
        embedding_size: 128
    """

    def __init__(self):
        # NOTE(review): presumably the input side length and the embedding
        # dimensionality expected by the checkpoint -- confirm against model.py.
        self.image_size, self.embedding_size = 250, 128

In [None]:
fm = face_model(Params())

In [None]:
checkpoint = tf.train.Checkpoint(fm)

In [None]:
checkpoint.restore('../weights/weights/ckpt/ckpt-11')

## Load in our data

In [None]:
from pathlib import Path
result = Path("./glcms_3_7_128").rglob("*.npz")

In [None]:
df = pd.DataFrame(result, columns=['npz_path'])

In [None]:
# Break every path into its components with pathlib instead of splitting the
# string on '\\': the backslash split only produced separate components on
# Windows, so the positional lookups that follow (x[1], x[-1]) failed elsewhere.
df['path_str_split'] = df['npz_path'].apply(lambda p: list(p.parts))

In [None]:
df['phase'] = df['path_str_split'].apply(lambda x: x[1]).astype("category")

In [None]:
df['filename'] = df['path_str_split'].apply(lambda x: x[-1])

In [None]:
df[['species', 'tree_id']] = df['filename'].str.split('_').apply(lambda x: pd.Series(x[:2]))

In [None]:
df['species'] = df['species'].astype('category')

In [None]:
df['tree_id'] = df['tree_id'].str.split('.').apply(lambda x: x[0]).astype("int")

In [None]:
df_nv = df[['npz_path', 'phase', 'species', 'tree_id']]

In [None]:
df_nv.head()

In [None]:
def get_glcm(filename_or_arr, is_filename=True):
    """Return the GLCM data for one sample.

    Parameters
    ----------
    filename_or_arr
        A path accepted by ``np.load`` when ``is_filename`` is true, otherwise
        an already-loaded object that is returned untouched.
    is_filename : bool, default True
        Whether ``filename_or_arr`` has to be loaded from disk.

    Returns
    -------
    The loaded object. When ``np.load`` yields an archive (an object exposing
    ``.files``) that holds exactly one array, that array is returned and the
    archive is closed, so callers can index the result directly. Archives
    with several members are returned unchanged.
    """
    if not is_filename:
        return filename_or_arr

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code -- only use this on files you trust.
    loaded = np.load(filename_or_arr, allow_pickle=True)

    # ndarrays have no ``files`` attribute; archive objects do.
    member_names = getattr(loaded, "files", None)
    if member_names is not None and len(member_names) == 1:
        # Read the single member eagerly and release the file handle.
        with loaded:
            return loaded[member_names[0]]
    return loaded

In [None]:
df['glcm'] = df['npz_path'].progress_apply(lambda x: get_glcm(x))

In [None]:
df['glcm'].iloc[0].shape

In [None]:
#df = df[df['species'] != 'Cratoxylum Formosum']

In [None]:
#df = df[df['glcm'].apply(lambda x: x.shape[0]) >= 64]
#df = df[df['glcm'].apply(lambda x: x.shape[1]) >= 64]

In [None]:
len(df)

## Make sense of our channel ordering

Wideband Red = 0, Wideband Green = 1, Wideband Blue = 2, RedEdge = 3, Blue = 4, NIR = 5, Red = 6, Green = 7

NONE = 0, HOMOGENEITY = 1, CONTRAST = 2, ASM = 3, MEAN_I = 4, VAR_I = 5, CORRELATION = 6

In [None]:
bands = ['wr', 'wg', 'wb', 're', 'b', 'ni', 'r', 'g']

In [None]:
filters = ['none', 'homogeneity', 'contrast', 'asm', 'mean', 'var', 'correlation']

## Duplicate how Aaron did it

### Feature selection

Bands used: Red-Edge, Narrowband Blue, NIR, Narrowband Red, Narrowband Green

GLCM filters used: None, GLCM Mean.

In [None]:
selected = list(itertools.product(('none', 'mean'),
                                  ('re', 'b', 'nir', 'r', 'g')))

In [None]:
selected = [(i[1], i[0]) for i in selected]

In [None]:
selected

In [None]:
def aaron_feature_selection(arr):
    """Pick the ten (band, filter) planes used downstream from a GLCM stack.

    ``arr`` is indexed as ``arr[row, col, band, filter]``. Output channels
    0-4 are bands 3..7 with filter 0, channels 5-9 are the same bands with
    filter 4. The result is a new float64 array of shape
    ``(arr.shape[0], arr.shape[1], 10)``.
    """
    band_indices = range(3, 8)
    filter_indices = (0, 4)
    planes = [
        arr[:, :, band, filt]
        for filt in filter_indices
        for band in band_indices
    ]
    return np.stack(planes, axis=-1).astype(np.float64)

In [None]:
df['sel_features'] = df['glcm'].progress_apply(aaron_feature_selection)

In [None]:
# Show each selected (band, filter) channel of the first sample on a 2x5 grid.
fig, ax = plt.subplots(2, 5, figsize=(5.6*5, 4.8*2))
first_sample = df['sel_features'].iloc[0]
for idx, (ax_, title) in enumerate(zip(ax.flatten(), selected)):
    img = ax_.imshow(first_sample[:,:,idx])
    ax_.set_title(title)
    # Name the target axes explicitly: without ax= the panel that receives the
    # colorbar depends on pyplot's implicit axes selection.
    fig.colorbar(img, ax=ax_)

In [None]:
df[['species', 'phase']].iloc[0]

In [None]:
def crop_center(img,cropx,cropy):
    """Cut a ``cropy`` x ``cropx`` window out of the middle of ``img``.

    The first two axes of ``img`` are treated as (rows, columns); any further
    axes are kept whole. If the image is smaller than the requested window
    along either axis, ``float('NaN')`` is returned instead of an array.
    """
    height, width = img.shape[:2]
    if cropy > height or cropx > width:
        return float('NaN')
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    return img[top:top + cropy, left:left + cropx]

In [None]:
# Keep only the samples that are large enough for a 64x64 centre crop.
# `>=` (not `>`): crop_center only rejects images that are strictly smaller
# than the window, so images of exactly 64 px must not be dropped here.
df_croppable = df['sel_features'][df['sel_features'].apply(lambda x: x.shape[0] >= 64 and x.shape[1] >= 64)]

In [None]:
len(df_croppable)

In [None]:
catted = np.concatenate(list(df_croppable\
                                 .apply(lambda x: crop_center(x, 64, 64))\
                                 .apply(lambda x: x.reshape(-1, 10))))

In [None]:
catted = pd.DataFrame(catted)

In [None]:
catted.columns = selected

In [None]:
catted = pd.melt(catted)

In [None]:
import warnings
warnings.filterwarnings("ignore", "is_categorical_dtype")
warnings.filterwarnings("ignore", "use_inf_as_na")

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(5.6*2, 4.8))
sns.histplot(catted,
             x='value',
             hue='variable',
             bins=256,
             ax=ax[0])
sns.kdeplot(catted,
            x='value',
            hue='variable',
            ax=ax[1])

In [None]:
catted = np.concatenate(list(df_croppable.iloc[:1]\
                                 .apply(lambda x: crop_center(x, 64, 64))\
                                 .apply(lambda x: x.reshape(-1, 10))))

In [None]:
catted = pd.DataFrame(catted)

In [None]:
catted.columns = selected

In [None]:
catted = pd.melt(catted)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(5.6*2, 4.8))
sns.histplot(catted.iloc[:64*64*10],
             x='value',
             hue='variable',
             bins=256,
             ax=ax[0])
sns.kdeplot(catted.iloc[:64*64*10],
            x='value',
            hue='variable',
            ax=ax[1])

### Principal Components Analysis -- One image

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Fitted PCA objects, one per channel, kept for the explained-variance plots below.
pcas = []

arr = df['sel_features'].iloc[0]

# Buffer for the rank-5 reconstruction of every channel.
arr_filt = np.zeros((arr.shape[0], arr.shape[1], 10))

for j in range(len(selected)):
    # Fit a 5-component PCA on channel j (a 2-D rows x columns matrix), then
    # map the projection back with inverse_transform to get a low-rank
    # approximation of that channel.
    pca = PCA(5)
    pca_ = pca.fit_transform(arr[:,:,j])
    inverted = pca.inverse_transform(pca_)
    arr_filt[:,:,j] = inverted
    pcas.append(pca)

In [None]:
evrs = pd.DataFrame([i.explained_variance_ratio_ for i in pcas])

In [None]:
evrs = evrs.reset_index()
evrs['index'] = selected
evrs.columns = ['feature'] + list(evrs.columns)[1:]

In [None]:
evrs = pd.melt(evrs, id_vars=['feature'], var_name='component #', value_name='explained variance')

In [None]:
# Accumulate explained variance separately for each feature. `evrs` is in long
# format here (one row per feature/component pair), so a plain cumulative sum
# over the column would run across features and exceed 1.
evrs['cum. EV'] = evrs.groupby('feature', sort=False)['explained variance'].cumsum()

In [None]:
selected_ordered = evrs[evrs['component #'] == 0].sort_values('cum. EV', ascending=False)['feature']

In [None]:
sns.lineplot(evrs,
             x='component #',
             y='cum. EV',
             hue='feature',
             hue_order=selected_ordered)

In [None]:
# One column per feature, ordered by `selected_ordered`: original channel on
# top, PCA reconstruction below.
fig, ax = plt.subplots(2, 10, figsize=(5.6*10, 4.8*2))
for col_idx, col in enumerate(ax.T):
    # `selected_ordered` still carries the row labels of `evrs`, so indexing
    # it with [] looks up by label and ignores the sort order; .iloc selects
    # by position in the sorted series.
    corrected_index = selected.index(selected_ordered.iloc[col_idx])
    col[0].imshow(arr[:,:,corrected_index])
    col[1].imshow(arr_filt[:,:,corrected_index])
    col[0].set_title(selected[corrected_index])
ax[0][0].set_ylabel("pre-pca filtering")
ax[1][0].set_ylabel("post-pca filtering")

## Do the PCA on all our data

In [None]:
def aaron_pca(arr, n_components=5):
    """Replace every channel of ``arr`` by its low-rank PCA reconstruction.

    Each channel ``arr[:, :, j]`` is passed to PCA as a 2-D matrix, projected
    onto the leading components and mapped back with ``inverse_transform``.

    Parameters
    ----------
    arr : array of shape (rows, cols, channels)
    n_components : int, default 5
        Number of components to keep per channel. It is clamped to
        ``min(rows, cols)`` so that small images do not make PCA raise.

    Returns
    -------
    A new float64 array with the same shape as ``arr``.
    """
    n_rows, n_cols, n_channels = arr.shape
    # PCA cannot keep more components than the smaller matrix dimension.
    rank = min(n_components, n_rows, n_cols)
    arr_filt = np.zeros((n_rows, n_cols, n_channels))
    for j in range(n_channels):
        pca = PCA(rank)
        projected = pca.fit_transform(arr[:,:,j])
        arr_filt[:,:,j] = pca.inverse_transform(projected)
    return arr_filt

In [None]:
df['post_pca'] = df['sel_features'].progress_apply(aaron_pca)

In [None]:
plt.imshow(df['post_pca'].iloc[2][:,:,0])

## Viz. our clusterings with resizing, no PCA, no whitening

In [None]:
df_resizing = pd.DataFrame(
    df['sel_features'],
    index = df.index)
df_resizing.columns = ['input']

df_resizing['resized'] = df_resizing['input'].progress_apply(
    lambda x:
        np.stack([cv2.resize(x[:,:,i], dsize=(250, 250), interpolation=cv2.INTER_CUBIC) for i in range(x.shape[-1])],
                 axis=-1))

fig, ax = plt.subplots(2, len(df), figsize=(5.6*len(df), 4.8*2))
for col, col_idx in zip(ax.T, range(len(ax.T))):
    col[0].imshow(df['sel_features'].iloc[col_idx][:,:,3])
    col[1].imshow(df_resizing['resized'].iloc[col_idx][:,:,3])
    col[0].set_title('{}, {}'.format(df.iloc[col_idx]['phase'],
                                     df.iloc[col_idx]['species']))
ax[0][0].set_ylabel("pre-resizing")
ax[1][0].set_ylabel("post-resizing")

df_resizing['embeddings'] = df_resizing['resized'].progress_apply(
    lambda x:
        tf.math.l2_normalize(fm(x.reshape((-1, *x.shape))), axis=1, epsilon=1e-10))

df_resizing[['species', 'phase']] = df[['species', 'phase']]

In [None]:
from umap import UMAP
reducer = UMAP(n_components=2,
               metric='euclidean',
               n_neighbors=3,
               min_dist=0.4)

df_resizing['embeddings'].iloc[0].shape

X = reducer.fit_transform(np.stack(list(df_resizing['embeddings'].apply(lambda x: x[0]))))
X_train =\
    reducer.fit_transform(np.stack(list(df_resizing[df_resizing['phase'] == '10May2021']['embeddings'].apply(lambda x: x[0]))))
X_test =\
    reducer.fit_transform(np.stack(list(df_resizing[df_resizing['phase'] == '18Dec2020']['embeddings'].apply(lambda x: x[0]))))

species = df['species'].unique()

fig, ax = plt.subplots(2, 2, figsize=(5.6*2, 4.8*2))
sns.scatterplot(
    df_resizing,
    x=X[:,0], y=X[:,1],
    hue='phase',
    palette='tab20',
    s=12,
    legend=True,
    ax=ax[0][0])
sns.scatterplot(
    df_resizing,
    x=X[:,0], y=X[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=True,
    ax=ax[0][1])
sns.scatterplot(
    df_resizing[df_resizing['phase'] == '10May2021'],
    x=X_train[:,0], y=X_train[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=False,
    ax=ax[1][0])
sns.scatterplot(
    df_resizing[df_resizing['phase'] == '18Dec2020'],
    x=X_test[:,0], y=X_test[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=False,
    ax=ax[1][1])
sns.move_legend(ax[0][0], "upper left", bbox_to_anchor=(2.2, 1))
sns.move_legend(ax[0][1], "upper left", bbox_to_anchor=(1, 0.77))
ax[0][0].set_title("UMAP Train and Test, capture shown")
ax[0][1].set_title("UMAP Train and Test, species shown")
ax[1][0].set_title("UMAP Train/10May2021, species shown")
# This panel shows the rows whose phase is '18Dec2020' (the X_test filter), so
# the title must name that capture.
ax[1][1].set_title("UMAP Test/18Dec2020, species shown")
#fig.tight_layout()

## Viz. our clusterings with resizing and PCA, no whitening

In [None]:
df_resize_pca = pd.DataFrame(
    df['post_pca'],
    index = df.index)
df_resize_pca.columns = ['input']

df_resize_pca['resized'] = df_resize_pca['input'].progress_apply(
    lambda x:
        np.stack([cv2.resize(x[:,:,i], dsize=(250, 250), interpolation=cv2.INTER_CUBIC) for i in range(x.shape[-1])],
                 axis=-1))

fig, ax = plt.subplots(2, len(df), figsize=(5.6*len(df), 4.8*2))
for col, col_idx in zip(ax.T, range(len(ax.T))):
    col[0].imshow(df['post_pca'].iloc[col_idx][:,:,3])
    col[1].imshow(df_resize_pca['resized'].iloc[col_idx][:,:,3])
    col[0].set_title('{}, {}'.format(df.iloc[col_idx]['phase'],
                                     df.iloc[col_idx]['species']))
ax[0][0].set_ylabel("pre-resize_pca")
ax[1][0].set_ylabel("post-resize_pca")

df_resize_pca['embeddings'] = df_resize_pca['resized'].progress_apply(
    lambda x:
        tf.math.l2_normalize(fm(x.reshape((-1, *x.shape))), axis=1, epsilon=1e-10))

df_resize_pca[['species', 'phase']] = df[['species', 'phase']]

In [None]:
from umap import UMAP
reducer = UMAP(n_components=2,
               metric='euclidean',
               n_neighbors=3,
               min_dist=0.4)

df_resize_pca['embeddings'].iloc[0].shape

X = reducer.fit_transform(np.stack(list(df_resize_pca['embeddings'].apply(lambda x: x[0]))))
X_train =\
    reducer.fit_transform(np.stack(list(df_resize_pca[df_resize_pca['phase'] == '10May2021']['embeddings'].apply(lambda x: x[0]))))
X_test =\
    reducer.fit_transform(np.stack(list(df_resize_pca[df_resize_pca['phase'] == '18Dec2020']['embeddings'].apply(lambda x: x[0]))))

species = df['species'].unique()

fig, ax = plt.subplots(2, 2, figsize=(5.6*2, 4.8*2))
sns.scatterplot(
    df_resize_pca,
    x=X[:,0], y=X[:,1],
    hue='phase',
    palette='tab20',
    s=12,
    legend=True,
    ax=ax[0][0])
sns.scatterplot(
    df_resize_pca,
    x=X[:,0], y=X[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=True,
    ax=ax[0][1])
sns.scatterplot(
    df_resize_pca[df_resize_pca['phase'] == '10May2021'],
    x=X_train[:,0], y=X_train[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=False,
    ax=ax[1][0])
sns.scatterplot(
    df_resize_pca[df_resize_pca['phase'] == '18Dec2020'],
    x=X_test[:,0], y=X_test[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=False,
    ax=ax[1][1])
sns.move_legend(ax[0][0], "upper left", bbox_to_anchor=(2.2, 1))
sns.move_legend(ax[0][1], "upper left", bbox_to_anchor=(1, 0.77))
ax[0][0].set_title("UMAP Train and Test, capture shown")
ax[0][1].set_title("UMAP Train and Test, species shown")
ax[1][0].set_title("UMAP Train/10May2021, species shown")
# This panel shows the rows whose phase is '18Dec2020' (the X_test filter), so
# the title must name that capture.
ax[1][1].set_title("UMAP Test/18Dec2020, species shown")
#fig.tight_layout()

## Viz. our clusterings with cropping and PCA, no whitening

In [None]:
def crop_center(img,cropx,cropy):
    """Return the central ``cropy``-row by ``cropx``-column window of ``img``.

    Only the first two axes are cropped. When ``img`` has fewer rows than
    ``cropy`` or fewer columns than ``cropx`` the function returns
    ``float('NaN')`` rather than an array.
    """
    n_rows, n_cols = img.shape[:2]
    fits = n_rows >= cropy and n_cols >= cropx
    if not fits:
        return float('NaN')
    row0 = n_rows // 2 - (cropy // 2)
    col0 = n_cols // 2 - (cropx // 2)
    window = (slice(row0, row0 + cropy), slice(col0, col0 + cropx))
    return img[window]

In [None]:
df_crop_pca = pd.DataFrame(
    df['post_pca'],
    index = df.index)
df_crop_pca.columns = ['input']

# Keep only inputs large enough for a 64x64 centre crop. `>=` because
# crop_center accepts images of exactly the window size; .copy() so the column
# assignments that follow act on an independent frame rather than on a slice
# of the unfiltered one (avoids pandas' SettingWithCopyWarning).
df_crop_pca = df_crop_pca[df_crop_pca['input'].apply(lambda x: x.shape[0] >= 64 and x.shape[1] >= 64)].copy()

df_crop_pca['cropped'] = df_crop_pca['input'].progress_apply(
    lambda x: crop_center(x, 64, 64))
df_crop_pca['cropped'] = df_crop_pca['cropped'].progress_apply(
    lambda x:
        np.stack([cv2.resize(x[:,:,i], dsize=(250, 250), interpolation=cv2.INTER_CUBIC) for i in range(x.shape[-1])],
                 axis=-1))

fig, ax = plt.subplots(2, len(df), figsize=(5.6*len(df), 4.8*2))
for col, col_idx in zip(ax.T, range(len(ax.T))):
    iloc_to_loc = df['post_pca'].index[col_idx]
    col[0].imshow(df['post_pca'].loc[iloc_to_loc][:,:,3])
    if iloc_to_loc in df_crop_pca.index:
        col[1].imshow(df_crop_pca['cropped'].loc[iloc_to_loc][:,:,3])
    col[0].set_title('{}, {}'.format(df.loc[iloc_to_loc]['phase'],
                                     df.loc[iloc_to_loc]['species']))
ax[0][0].set_ylabel("pre-crop_pca")
ax[1][0].set_ylabel("post-crop_pca")

df_crop_pca['embeddings'] = df_crop_pca['cropped'].progress_apply(
    lambda x:
        tf.math.l2_normalize(fm(x.reshape((-1, *x.shape))), axis=1, epsilon=1e-10))

df_crop_pca[['species', 'phase']] = df[['species', 'phase']]

In [None]:
from umap import UMAP
reducer = UMAP(n_components=2,
               metric='euclidean',
               n_neighbors=3,
               min_dist=0.4)

df_crop_pca['embeddings'].iloc[0].shape

X = reducer.fit_transform(np.stack(list(df_crop_pca['embeddings'].apply(lambda x: x[0]))))
X_train =\
    reducer.fit_transform(np.stack(list(df_crop_pca[df_crop_pca['phase'] == '10May2021']['embeddings'].apply(lambda x: x[0]))))
X_test =\
    reducer.fit_transform(np.stack(list(df_crop_pca[df_crop_pca['phase'] == '18Dec2020']['embeddings'].apply(lambda x: x[0]))))

species = df['species'].unique()

fig, ax = plt.subplots(2, 2, figsize=(5.6*2, 4.8*2))
sns.scatterplot(
    df_crop_pca,
    x=X[:,0], y=X[:,1],
    hue='phase',
    palette='tab20',
    s=12,
    legend=True,
    ax=ax[0][0])
sns.scatterplot(
    df_crop_pca,
    x=X[:,0], y=X[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=True,
    ax=ax[0][1])
sns.scatterplot(
    df_crop_pca[df_crop_pca['phase'] == '10May2021'],
    x=X_train[:,0], y=X_train[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=False,
    ax=ax[1][0])
sns.scatterplot(
    df_crop_pca[df_crop_pca['phase'] == '18Dec2020'],
    x=X_test[:,0], y=X_test[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=12,
    legend=False,
    ax=ax[1][1])
sns.move_legend(ax[0][0], "upper left", bbox_to_anchor=(2.2, 1))
sns.move_legend(ax[0][1], "upper left", bbox_to_anchor=(1, 0.77))
ax[0][0].set_title("UMAP Train and Test, capture shown")
ax[0][1].set_title("UMAP Train and Test, species shown")
ax[1][0].set_title("UMAP Train/10May2021, species shown")
# This panel shows the rows whose phase is '18Dec2020' (the X_test filter), so
# the title must name that capture.
ax[1][1].set_title("UMAP Test/18Dec2020, species shown")
#fig.tight_layout()

## Implement a random crop-and-stretch transformation

In [None]:
plt.imshow(df.iloc[0]['sel_features'][:,:,3])

In [None]:
transform = torchvision.transforms.RandomResizedCrop(
    size=(250,250),
    scale=((64/250)**2,1),
    ratio=(1/2, 2),
    interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
    antialias=True
)

In [None]:
pre_trans = df.iloc[0]['sel_features']
post_trans =\
    [rearrange(
        transform(torch.Tensor(rearrange(pre_trans, 'h w c -> c h w'))).numpy(),
        'c h w -> h w c')
     for i in range(10)]

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(5.6*5, 4.8*2))
for ax_, idx in zip(ax.flatten(), range(len(ax.flatten()))):
    ax_.imshow(post_trans[idx][:,:,3])

## Apply it to our data

In [None]:
def rand_crop_stretch(pre_trans, n):
    """Produce ``n`` randomly cropped-and-resized views of one image.

    Parameters
    ----------
    pre_trans : array in 'h w c' layout
        Source image; it is converted to a channel-first torch tensor before
        being passed to the module-level ``transform``.
    n : int
        Number of views to draw.

    Returns
    -------
    list of ``n`` numpy arrays, each rearranged back to 'h w c' layout.
    """
    # The channel-first tensor does not change between draws, so build it once
    # instead of once per view.
    chw = torch.Tensor(rearrange(pre_trans, 'h w c -> c h w'))
    return [
        rearrange(transform(chw).numpy(), 'c h w -> h w c')
        for _ in range(n)
    ]

In [None]:
def embed(x):
    """Run one image through ``fm`` and L2-normalise the result along axis 1.

    A leading axis is added to ``x`` by reshaping before the model call; the
    normalised model output is returned as-is.
    """
    batch = x.reshape((-1, *x.shape))
    raw_embedding = fm(batch)
    return tf.math.l2_normalize(raw_embedding, axis=1, epsilon=1e-10)

In [None]:
df_augment = df.copy()

In [None]:
df_augment['augment'] = df['post_pca'].progress_apply(lambda x: rand_crop_stretch(x, 20))

In [None]:
df_augment = df_augment.explode('augment')

In [None]:
df_augment['embeddings'] = df_augment['augment'].progress_apply(embed)

In [None]:
plt.imshow(df.iloc[0]['post_pca'][:,:,3])

In [None]:
post_trans = list(df_augment.iloc[:20]['augment'])

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(5.6*5, 4.8*2))
for ax_, idx in zip(ax.flatten(), range(len(ax.flatten()))):
    ax_.imshow(post_trans[idx][:,:,3])

In [None]:
from umap import UMAP
reducer = UMAP(n_components=2,
               metric='euclidean',
               n_neighbors=20,
               min_dist=0.1)

df_augment['embeddings'].iloc[0].shape

X = reducer.fit_transform(np.stack(list(df_augment['embeddings'].apply(lambda x: x[0]))))
X_train =\
    reducer.fit_transform(np.stack(list(df_augment[df_augment['phase'] == '10May2021']['embeddings'].apply(lambda x: x[0]))))
X_test =\
    reducer.fit_transform(np.stack(list(df_augment[df_augment['phase'] == '18Dec2020']['embeddings'].apply(lambda x: x[0]))))

species = df['species'].unique()

fig, ax = plt.subplots(2, 2, figsize=(5.6*2, 4.8*2))
sns.scatterplot(
    df_augment,
    x=X[:,0], y=X[:,1],
    hue='phase',
    palette='tab20',
    s=4,
    legend=True,
    ax=ax[0][0])
sns.scatterplot(
    df_augment,
    x=X[:,0], y=X[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=4,
    legend=True,
    ax=ax[0][1])
sns.scatterplot(
    df_augment[df_augment['phase'] == '10May2021'],
    x=X_train[:,0], y=X_train[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=4,
    legend=False,
    ax=ax[1][0])
sns.scatterplot(
    df_augment[df_augment['phase'] == '18Dec2020'],
    x=X_test[:,0], y=X_test[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=4,
    legend=False,
    ax=ax[1][1])
sns.move_legend(ax[0][0], "upper left", bbox_to_anchor=(2.2, 1))
sns.move_legend(ax[0][1], "upper left", bbox_to_anchor=(1, 0.77))
ax[0][0].set_title("UMAP Train and Test, capture shown")
ax[0][1].set_title("UMAP Train and Test, species shown")
ax[1][0].set_title("UMAP Train/10May2021, species shown")
# This panel shows the rows whose phase is '18Dec2020' (the X_test filter), so
# the title must name that capture.
ax[1][1].set_title("UMAP Test/18Dec2020, species shown")
#fig.tight_layout()

## What happens if we introduce whitening?

In [None]:
df_whiten = df.copy()

In [None]:
def whiten_image(x):
    """Standardise ``x`` over its first two axes.

    The mean and standard deviation are computed over axes (0, 1) with
    ``keepdims=True``, i.e. independently for every index of the remaining
    axes, and ``(x - mean) / std`` is returned.

    A slice with zero standard deviation (constant values) is mapped to all
    zeros instead of NaN/inf: its divisor is replaced by 1.
    """
    centered = x - x.mean(axis=(0,1), keepdims=True)
    spread = x.std(axis=(0,1), keepdims=True)
    # `spread` is a fresh array, so patching it in place neither touches `x`
    # nor changes the dtype of the result.
    spread[spread == 0] = 1
    return centered / spread

In [None]:
df_whiten['whiten'] = df['post_pca'].progress_apply(lambda x: [whiten_image(i) for i in rand_crop_stretch(x, 20)])

In [None]:
df_whiten = df_whiten.explode('whiten')

In [None]:
df_whiten['embeddings'] = df_whiten['whiten'].progress_apply(embed)

In [None]:
plt.imshow(df.iloc[0]['post_pca'][:,:,3])

In [None]:
post_trans = list(df_whiten.iloc[:20]['whiten'])

In [None]:
# NOTE(review): `idx` is not assigned in this cell; it is whatever value an
# earlier plotting loop left behind, so the sample inspected here depends on
# which cells were run before -- consider using an explicit index.
post_trans[idx][:,:,2].min()

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(5.6*5, 4.8*2))
for ax_, idx in zip(ax.flatten(), range(len(ax.flatten()))):
    ax_.imshow(post_trans[idx][:,:,:3])

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(5.6*5, 4.8*2))
for ax_, idx in zip(ax.flatten(), range(len(ax.flatten()))):
    ax_.imshow((post_trans[idx][:,:,:3] - post_trans[idx][:,:,:3].min()) / (post_trans[idx][:,:,:3].max() - post_trans[idx][:,:,:3].min()))

In [None]:
from umap import UMAP
reducer = UMAP(n_components=2,
               metric='euclidean',
               n_neighbors=20,
               min_dist=0.1)

df_whiten['embeddings'].iloc[0].shape

X = reducer.fit_transform(np.stack(list(df_whiten['embeddings'].apply(lambda x: x[0]))))
X_train =\
    reducer.fit_transform(np.stack(list(df_whiten[df_whiten['phase'] == '10May2021']['embeddings'].apply(lambda x: x[0]))))
X_test =\
    reducer.fit_transform(np.stack(list(df_whiten[df_whiten['phase'] == '18Dec2020']['embeddings'].apply(lambda x: x[0]))))

species = df['species'].unique()

fig, ax = plt.subplots(2, 2, figsize=(5.6*2, 4.8*2))
sns.scatterplot(
    df_whiten,
    x=X[:,0], y=X[:,1],
    hue='phase',
    palette='tab20',
    s=4,
    legend=True,
    ax=ax[0][0])
sns.scatterplot(
    df_whiten,
    x=X[:,0], y=X[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=4,
    legend=True,
    ax=ax[0][1])
sns.scatterplot(
    df_whiten[df_whiten['phase'] == '10May2021'],
    x=X_train[:,0], y=X_train[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=4,
    legend=False,
    ax=ax[1][0])
sns.scatterplot(
    df_whiten[df_whiten['phase'] == '18Dec2020'],
    x=X_test[:,0], y=X_test[:,1],
    hue='species', hue_order=species,
    palette='tab20',
    s=4,
    legend=False,
    ax=ax[1][1])
sns.move_legend(ax[0][0], "upper left", bbox_to_anchor=(2.2, 1))
sns.move_legend(ax[0][1], "upper left", bbox_to_anchor=(1, 0.77))
ax[0][0].set_title("UMAP Train and Test, capture shown")
ax[0][1].set_title("UMAP Train and Test, species shown")
ax[1][0].set_title("UMAP Train/10May2021, species shown")
# This panel shows the rows whose phase is '18Dec2020' (the X_test filter), so
# the title must name that capture.
ax[1][1].set_title("UMAP Test/18Dec2020, species shown")
#fig.tight_layout()

## Visualise the intermediate activations on various images

In [None]:
df_resize_pca['conversion_layer_outputs'] = df_resize_pca['resized'].progress_apply(
    lambda x: rearrange(
        fm.conversion_layer_1(fm.conversion_layer_2(fm.conversion_layer(
            tf.expand_dims(x, axis=0)))),
        'b h w c -> (b h) w c').numpy())

In [None]:
# First sample's resized image and the matching conversion-layer activations.
# NOTE(review): `input` shadows the Python builtin of the same name from here
# on; later cells read this variable, so a rename has to be applied to all of
# them together.
input = df_resize_pca['resized'].iloc[0]
output = df_resize_pca['conversion_layer_outputs'].iloc[0]

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(5.6*2, 4.8))
sns.histplot([input[:,:,i].flatten() for i in range(10)], ax=ax[0])
sns.kdeplot([input[:,:,i].flatten() for i in range(10)], ax=ax[1])

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(5.6*2, 4.8))
sns.histplot([output[:,:,i].flatten() for i in range(3)], ax=ax[0])
sns.kdeplot([output[:,:,i].flatten() for i in range(3)], ax=ax[1])

In [None]:
np.unique(output[:,:,0]).size,\
np.unique(output[:,:,1]).size,\
np.unique(output[:,:,2]).size

In [None]:
sns.histplot([output[:,:,i].flatten() for i in range(3)],
             log_scale=(False, True),
             element='step',
             fill=False,
             bins=128)

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(5.6*3, 4.8))
for ax_, idx in zip(ax, range(3)):
    ax_.imshow(output[:,:,idx])

In [None]:
output.mean(axis=(0,1))

In [None]:
def minmax_image(x):
    """Scale ``x`` by the reciprocal of the maximum of its first channel.

    Returns a new array; ``x`` itself is not modified.

    NOTE(review): despite the name, no minimum is subtracted and only
    ``x[:, :, 0]`` determines the scale -- confirm this is intended.
    """
    first_channel_peak = x[:, :, 0].max()
    return x * (1 / first_channel_peak)

In [None]:
sns.histplot([minmax_image(output)[:,:,i].flatten() for i in range(3)],
             log_scale=(False, True),
             element='step',
             fill=False,
             bins=128)

In [None]:
# The original expression ended in a dangling `/ )` and did not parse.
# NOTE(review): dividing the three displayed channels by their joint maximum is
# the assumed intent (it brings the peak value to 1 for imshow) -- confirm.
shown_channels = input[:,:,[1, 3, 4]]
plt.imshow(shown_channels / shown_channels.max())

In [None]:
df_resize_pca['filename'] = df['filename']

In [None]:
selected

In [None]:
# One figure row per sample: three views of the resized input followed by the
# three conversion-layer output channels.
n_rows = len(df_resize_pca)
n_cols = 6
panel_labels = ('rgb', 'red_edge', 'nir', 'out1', 'out2', 'out3')
# Size the grid from the data instead of a hard-coded row count, so a change
# in the number of samples neither raises IndexError nor leaves empty rows.
# squeeze=False keeps `ax` two-dimensional even when there is a single row.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(5.6*n_cols, 4.8*n_rows), squeeze=False)
for row_ax, (_, row) in zip(ax, tqdm(df_resize_pca.iterrows(), total=n_rows)):
    row_ax[0].set_xlabel(str((row['filename'], row['phase'])))
    for panel, label in zip(row_ax, panel_labels):
        panel.set_ylabel(label)
    resized = row['resized']
    activations = row['conversion_layer_outputs']
    row_ax[0].imshow(minmax_image(resized[:,:,[3, 4, 1]]))
    row_ax[1].imshow(resized[:,:,0])
    row_ax[2].imshow(resized[:,:,2])
    for channel in range(3):
        row_ax[3 + channel].imshow(activations[:,:,channel])
plt.tight_layout()