# Investigate a Sample

### First run this cell

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib notebook
%load_ext autoreload
%autoreload 2


#load some packages in
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import random as python_random
from numba import njit
from tensorboard.plugins.hparams import api as hp
from stemutils.io import Path
import hyperspy.api as hs
import concurrent.futures
from skimage.transform import resize
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from functools import lru_cache
from stemseg.processing_funcs import *
import json

# Report the TensorFlow build in use, reset the matplotlib style, and seed the
# Python, NumPy and TensorFlow random number generators (all with 0) so that
# repeated runs of the notebook are reproducible.
print('Using TensorFlow v{}'.format(tf.__version__))
plt.style.use('default')
for seeder in (python_random.seed, np.random.seed, tf.random.set_seed):
    seeder(0)


#define some functions

###################################################
########### Data Preprocessing ####################
###################################################

def batch_resize(d, bs=512, out_shape=(128, 128)):
    """Resize every pattern in *d* to ``out_shape`` using parallel worker processes.

    Parameters
    ----------
    d : array
        Either 4-D ``(nav_y, nav_x, sig_y, sig_x)`` -- the two navigation axes are
        flattened first via ``flatten_nav`` -- or an already-flat 3-D stack
        ``(n, sig_y, sig_x)``.
    bs : int
        Number of patterns resized per worker task.
    out_shape : tuple of int
        Target ``(rows, cols)`` of each resized pattern (default 128x128).

    Returns
    -------
    numpy.ndarray
        Shape ``(n_patterns,) + out_shape``, patterns in their flattened order.
    """
    out_shape = tuple(out_shape)
    flat_d = flatten_nav(d) if len(d.shape) == 4 else d
    n_patterns = flat_d.shape[0]
    if n_patterns == 0:
        return np.empty((0,) + out_shape)

    # Contiguous slices of at most `bs` patterns; no empty trailing batch.
    batches = [flat_d[start:start + bs] for start in range(0, n_patterns, bs)]
    # Leaving the `with` block waits for every submitted task to finish.
    with concurrent.futures.ProcessPoolExecutor() as exe:
        futures = [exe.submit(resize, batch, (batch.shape[0],) + out_shape)
                   for batch in batches]
    r_batches = [fut.result() for fut in futures]
    # Use the flattened pattern count: for 4-D input `d.shape[0]` is only the
    # first navigation axis and the reshape would fail.
    return np.concatenate(r_batches, axis=0).reshape((n_patterns,) + out_shape)

def data_manip(d, bs = 512):
    """Pre-process diffraction patterns for the VAE: normalise, resize, log-scale.

    Parameters
    ----------
    d : array
        NumPy array, or a lazy array exposing ``.compute()`` (e.g. dask), which is
        materialised first. The last two axes are taken to be the pattern axes.
    bs : int
        Batch size forwarded to ``batch_resize``.

    Returns
    -------
    numpy.ndarray
        Patterns resized by ``batch_resize`` (128x128) and mapped through
        ``log(1000*p + 1) / log(1001)``, so a per-pattern maximum of 1 stays 1.
    """
    if not isinstance(d, np.ndarray):
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    # Normalise each pattern (last two axes) to a maximum of 1.
    d_maxes = np.max(d, axis=(-2, -1), keepdims=True)
    # All-zero patterns would give 0/0 -> NaN; keep them at zero instead.
    d_maxes = np.where(d_maxes == 0, 1, d_maxes)
    d = d / d_maxes
    del d_maxes  # release before the memory-hungry resize
    d = batch_resize(d, bs)
    print('resized')
    scaler = np.log(1001)
    return np.log((d * 1000) + 1) / scaler


###################################################
###################################################
###################################################

def flatten_nav(sig):
    """Merge the first two (navigation) axes of *sig* into a single axis.

    An input of shape ``(a, b, *rest)`` comes back with shape ``(a*b, *rest)``.
    Any object offering ``.shape`` and ``.reshape`` (NumPy, dask) is accepted.
    """
    n_flat = sig.shape[0] * sig.shape[1]
    return sig.reshape([n_flat, *sig.shape[2:]])


class My_Custom_Generator(keras.utils.Sequence) :
    """Keras ``Sequence`` that loads ``.npy`` images from disk in batches.

    Each item is an ``(x, x)`` pair (input equals target, as an autoencoder
    needs). Loaded batches are cached per instance, so each batch of files is
    read from disk at most once per generator.
    """

    def __init__(self, image_filenames,  batch_size) :
        super().__init__()
        self.image_filenames = image_filenames
        self.batch_size = batch_size
        # Per-instance cache: @lru_cache on the method kept every generator
        # (and all its batches) alive for the lifetime of the process.
        self._batch_cache = {}

    def __len__(self) :
        # Number of batches, counting a final partial one. `np.int` was removed
        # in NumPy 1.24, so convert with the builtin int.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx) :
        cached = self._batch_cache.get(idx)
        if cached is not None:
            return cached
        batch_x = self.image_filenames[idx * self.batch_size : (idx+1) * self.batch_size]
        # Add a trailing channel axis to each loaded (presumably 2-D) image.
        out_img = np.asarray([np.load(file_name)[:,:,None] for file_name in batch_x])
        self._batch_cache[idx] = (out_img, out_img)
        return out_img, out_img
        
        
class Array_Generator(keras.utils.Sequence) :
    """Keras ``Sequence`` serving batches from an in-memory image stack.

    ``images`` is indexed as ``images[start:stop, :, :]``, so it is expected to
    be a 3-D ``(n, rows, cols)`` array. Each item is an ``(x, x)`` pair with a
    trailing channel axis added. Batches are cached per instance.
    """

    def __init__(self, images,  batch_size) :
        super().__init__()
        self.images = images
        self.batch_size = batch_size
        # Per-instance cache instead of @lru_cache on the method (which leaked
        # every generator instance for the lifetime of the process).
        self._batch_cache = {}

    def __len__(self) :
        # `np.int` was removed in NumPy 1.24; use the builtin int.
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    def __getitem__(self, idx) :
        cached = self._batch_cache.get(idx)
        if cached is not None:
            return cached
        out_img = self.images[idx * self.batch_size : (idx+1) * self.batch_size, :,:,None]
        self._batch_cache[idx] = (out_img, out_img)
        return out_img, out_img

class Sampling(layers.Layer):
    """Reparameterisation-trick layer for the VAE latent space.

    Called on ``(z_mean, z_log_var)``; returns ``z_mean + exp(z_log_var / 2) * eps``,
    where ``eps`` is standard-normal noise with the same shape as ``z_mean``.
    """

    def call(self, inputs):
        mean, log_var = inputs
        noise = tf.keras.backend.random_normal(shape=tf.shape(mean))
        std = tf.exp(0.5 * log_var)
        return mean + std * noise
    
def create_vae_model(hparams):
    """Build and compile the convolutional (beta-)VAE for 128x128 single-channel patterns.

    hparams: dict of hyperparameters. Keys read here:
        'LAT'         latent dimension
        'B'           weight (beta) applied to the KL term in train_step
        'KN1'..'KN5'  filters of the five stride-2 encoder convolutions
                      (the decoder uses them in reverse order)
        'D1', 'D2'    widths of the dense layers
        'LR'          Adam learning rate

    Returns the compiled ``VAE`` model, built for inputs of shape (batch, 128, 128, 1).
    """
    
    n_img = 128
    latent_dim = hparams['LAT']
    beta = hparams['B']

    # Encoder: five stride-2 'same' convolutions, 128 -> 64 -> 32 -> 16 -> 8 -> 4 pixels per side.
    # NOTE(review): `input_shape=` is redundant in the functional API (the shape comes from image_input).
    image_input = keras.Input(shape=(n_img, n_img,1), name = 'enc_input')
    x = layers.Conv2D(hparams['KN1'],5, strides = 2, activation='relu',padding='same', input_shape=image_input.shape, name = 'enc_conv1')(image_input)
    x = layers.Conv2D(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv2')(x)
    x = layers.Conv2D(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv3')(x)
    x = layers.Conv2D(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv4')(x)
    x = layers.Conv2D(hparams['KN5'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv5')(x)
    x = layers.Flatten()(x)
    # Dense stack: one D1-wide layer followed by seven D2-wide layers.
    x = layers.Dense(hparams['D1'], activation='relu', name = 'enc_d1')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d2_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d3_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d4_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d5_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d6_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d7_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d8_t')(x)
    # Gaussian posterior parameters and a reparameterised sample from it.
    z_mean = layers.Dense(latent_dim, name="z_mean_t")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var_t")(x)
    z_output = Sampling()([z_mean, z_log_var])
    encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output])

    # Decoder: mirrors the encoder's dense stack, then reshapes to the 4x4xKN5
    # feature map the encoder produced and upsamples back to 128x128x1.
    z_input = keras.Input(shape=(latent_dim,), name = 'dec_input_t')
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d1_t')(z_input)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d2')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d3')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d4')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d5')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d6')(x)
    x = layers.Dense(hparams['D1'], activation="relu", name = 'dec_d7')(x)
    x = layers.Dense(4*4*hparams['KN5'], activation="relu", name = 'dec_d8')(x)
    x = layers.Reshape((4, 4,hparams['KN5']))(x)
    x = layers.Conv2DTranspose(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv1')(x)
    x = layers.Conv2DTranspose(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv2')(x)
    x = layers.Conv2DTranspose(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv3')(x)
    x = layers.Conv2DTranspose(hparams['KN1'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv4')(x)
    # Sigmoid keeps reconstructions in (0, 1), as the binary cross-entropy loss requires.
    image_output = layers.Conv2DTranspose(1,5, strides = 2, activation='sigmoid',padding='same', name = 'dec_conv5')(x)
    #image_output = layers.Conv2DTranspose(16,3, strides = 2, activation='sigmoid',padding='same')
    #image_output = layers.Reshape((n_img, n_img,1))(x)
    decoder_VAE = keras.Model(z_input, image_output)

    # VAE class
    class VAE(keras.Model):
        # constructor
        def __init__(self, encoder, decoder, **kwargs):
            super(VAE, self).__init__(**kwargs)
            self.encoder = encoder
            self.decoder = decoder

        # customise train_step() to implement the loss 
        def train_step(self, x):
            # The generators above yield (input, target) tuples; only the input is needed.
            if isinstance(x, tuple):
                x = x[0]
            with tf.GradientTape() as tape:
                # encoding
                z_mean, z_log_var, z = self.encoder(x)
                # decoding
                x_prime = self.decoder(z)
                # reconstruction error by binary crossentropy loss
                # (mean BCE rescaled by the pixel count n_img * n_img)
                reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * n_img * n_img
                # KL divergence
                # NOTE(review): averaged over batch AND latent dims (not summed over latent dims);
                # the effective KL weight therefore also depends on latent_dim.
                kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
                # loss = reconstruction error + KL divergence
                loss = reconstruction_loss + beta* kl_loss
            # apply gradient
            grads = tape.gradient(loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
            # return loss for metrics log
            return {"loss": loss}


        def call(self, x):
            # Forward pass: decodes the *sampled* latent z, so repeated calls on
            # the same input give slightly different reconstructions.
            if isinstance(x, tuple):
                x = x[0]
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            return x_prime
    # build the VAE
    vae_model = VAE(encoder_VAE, decoder_VAE)

    # compile the VAE
    # custom_loss (reconstruction term only) is presumably what Keras' default
    # test_step uses for evaluation/validation, since only train_step is overridden.
    vae_model.compile(optimizer=keras.optimizers.Adam(learning_rate=hparams['LR']),loss=custom_loss)
    vae_model.build((1,128,128,1))
    
    return vae_model



def custom_loss(x,y):
    """Reconstruction loss: mean binary cross-entropy scaled by the 128x128 pixel count."""
    side = 128
    bce = keras.losses.binary_crossentropy(x, y)
    return tf.reduce_mean(bce) * side * side

def remove_background(sample, thresh = 500, old_tag=None, new_tag=None,blanker = 30):
    """Flag probe positions whose off-centre intensity is below ``thresh``.

    The central ``2*blanker x 2*blanker`` square of every pattern (around the
    centre of each pattern axis) is zeroed on a copy of the data. The remaining
    intensity is then summed per probe position, and positions whose sum is
    below ``thresh`` count as background. ``sample.raw_data.data`` itself is
    never modified.

    Parameters
    ----------
    sample : object with ``raw_data.data`` (4-D: nav_y, nav_x, sig_y, sig_x) and,
        when ``old_tag`` is used, an ``all_maps`` dict of 2-D label maps.
    thresh : threshold on the centre-blanked summed intensity.
    old_tag : key of an existing map in ``sample.all_maps`` to relabel.
    new_tag : if given together with ``old_tag``, the relabelled map is also
        stored under this key in ``sample.all_maps``.
    blanker : half-width of the blanked central square.

    Returns
    -------
    numpy.ndarray
        Without ``old_tag``: 0 for background and 1 elsewhere.
        With ``old_tag``: the old labels are shifted by +1 and background is set
        to 0. The values are then renumbered to consecutive integers in sorted
        order, starting at 1, so background (when present) becomes label 1.
    """
    d = sample.raw_data.data.copy()
    cy, cx = d.shape[2] // 2, d.shape[3] // 2
    # Clamp at 0 so a large `blanker` cannot produce a wrapping negative slice.
    rows = slice(max(cy - blanker, 0), cy + blanker)
    cols = slice(max(cx - blanker, 0), cx + blanker)
    try:
        d[:, :, rows, cols] = 0
    except (TypeError, NotImplementedError):
        # Lazy (dask) arrays without item-assignment support: materialise first.
        d = d.compute()
        d[:, :, rows, cols] = 0
    # np.asarray only materialises the small 2-D per-position sum for lazy data.
    background = np.asarray(d.sum(axis=(2, 3))) < thresh

    if old_tag is None:
        return np.where(background, 0, 1)

    clustmap = sample.all_maps[old_tag].copy()
    clustmap += 1
    clustmap[background] = 0
    # Consecutive renumbering (sorted order) in one pass, starting at 1.
    _, inverse = np.unique(clustmap, return_inverse=True)
    newmap = inverse.reshape(clustmap.shape).astype(clustmap.dtype) + 1
    if new_tag is not None:
        sample.all_maps[new_tag] = newmap
    return newmap

def show_cluster_patterns(sample, tag):
    """Stack, for every label of map *tag*, its region mask above its pattern.

    Each frame is 512x256: the top half is the label's region (resized to
    256x256, drawn at the pattern's maximum on a background of 1). The bottom
    half is the pattern resized to 256x256, with its first row set to its
    maximum as a separator. Frames are returned as a hyperspy ``Signal2D``.
    """
    labels = np.unique(sample.all_maps[tag])
    stacked = np.zeros((labels.size, 512, 256))
    for slot, label in enumerate(labels):
        # NOTE(review): the pattern is looked up at index label-1, whereas
        # signal_boosted_scan uses index `label`; for maps whose labels start
        # at 0 this picks the last pattern for label 0 — confirm the convention.
        pattern = resize(sample.all_patterns[tag][label - 1], (256, 256))
        peak = pattern.max()
        region = resize(np.where(sample.all_maps[tag] == label, peak, 1), (256, 256))
        pattern[0, :] = peak
        stacked[slot] = np.concatenate([region, pattern], axis=0)
    return hs.signals.Signal2D(stacked)

def signal_boosted_scan(sample, tag):
    """Return a scan in which every probe position shows its cluster's pattern.

    Each position labelled ``t`` in ``sample.all_maps[tag]`` receives
    ``sample.all_patterns[tag][t]``. The result is a float32 hyperspy
    ``Signal2D`` with the same shape as ``sample.raw_data.data``.
    """
    cluster_map = sample.all_maps[tag]
    # Allocate float32 directly: zeros_like on the signal allocated the raw
    # data's dtype first and then needed a second, cast copy.
    boosted = np.zeros(sample.raw_data.data.shape, dtype='float32')
    for label in np.unique(cluster_map):
        boosted[cluster_map == label] = sample.all_patterns[tag][label]
    return hs.signals.Signal2D(boosted)

def inv_sbs(sbs, tag = 'vl_vae', sp = (0,0), return_fig = False, interactive = True, *, sample=None, **kwargs):
    """Show a cluster-highlighted navigator above the boosted pattern at one position.

    Parameters
    ----------
    sbs : signal whose ``.data`` is 4-D (nav_y, nav_x, sig_y, sig_x), e.g. the
        output of ``signal_boosted_scan``.
    tag : key of the cluster map in ``sample.all_maps``.
    sp : (row, col) of the probe position shown initially.
    return_fig : return the matplotlib figure when true.
    interactive : when true, clicking on the navigator selects a new position.
        Clicked (x, y) pairs are collected in the module-level list ``coords``.
    sample : object with ``all_maps``; defaults to the module-level ``sample``
        (the notebook's current sample) for backwards compatibility.
    **kwargs : forwarded to ``imshow`` of the pattern panel (e.g. ``vmax``).
    """
    if sample is None:
        try:
            sample = globals()['sample']
        except KeyError:
            raise NameError("inv_sbs needs a sample: pass sample=... explicitly") from None

    highlight = np.array([0.1254902 , 0.69803922, 0.66666667])
    cluster_map = sample.all_maps[tag]

    # Grey navigator: integrated intensity per position, normalised, as RGB.
    # Out-of-place division so integer data does not fail.
    sbsg = np.repeat(sbs.data.sum(axis=(2, 3))[:, :, None], 3, -1)
    sbsg = sbsg / sbsg.max()

    fig, ax = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]}, figsize=(8, 8))

    def boost(array):
        # Double log compresses the dynamic range of the pattern.
        return np.log10(np.log10(array + 1) + 1)

    def format_ax():
        ax[0].set_frame_on(False)
        for axis in ax:
            axis.set_xticks([])
            axis.set_yticks([])

    def draw(row, col):
        """Highlight the cluster at (row, col) and show that position's pattern."""
        ax[0].clear()
        ax[1].clear()
        new_nav = sbsg.copy()
        new_nav[cluster_map == cluster_map[row, col]] = highlight
        ax[0].imshow(new_nav)
        ax[1].imshow(boost(sbs.data[row, col]), cmap='gray', **kwargs)
        format_ax()

    draw(sp[0], sp[1])

    if interactive:
        global coords
        coords = []

        def onclick(event):
            global ix, iy
            # Ignore clicks outside the navigator (xdata is None there, and the
            # pattern panel's coordinates are not probe positions).
            if event.inaxes is not ax[0] or event.xdata is None:
                return coords
            ix, iy = np.round(event.xdata, 0), np.round(event.ydata, 0)
            print(ix, iy)
            coords.append((ix, iy))
            draw(int(iy), int(ix))
            # Axes.draw() requires a renderer; ask the canvas to redraw instead.
            fig.canvas.draw_idle()
            return coords

        fig.canvas.mpl_connect('button_press_event', onclick)

    if return_fig:
        return fig
        

### Check the GPU can be found

In [None]:
# Report the installed TensorFlow version.
tf.__version__

In [None]:
# List every device TensorFlow can see; a GPU should appear here if one is usable.
tf.config.list_physical_devices()

In [None]:
fdp = Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/')

# Show the entries of the calibrated-data folder with their indices (used by `select` below).
dl = fdp.ls()
[i for i in enumerate(dl)]

### Select your data set

In [None]:
#Can just set a Path to your dataset
# NOTE(review): placeholder path; it is overwritten if the `dl[select]` cell below is run.
dp = Path('something/something/file.hdf5')

In [None]:
#Or can use the index from the list above

#Can set index directly 
select = 0

#or find the timestamp
#ts = '135457'
#select = np.argmax([str(x).find(ts) for x in dl])

In [None]:
# First '.hdf5' match from stemutils' Path.walk under the selected entry.
# NOTE(review): confirm what the 'Model' argument of walk() filters on.
dp = dl[select].walk('.hdf5', 'Model')[0]
dp

### Set the Model Path

In [None]:
# Model folder for this dataset, derived from the data path via stemutils' Path.redirect.
mp = dp.redirect('Final_Models')
mp

### Create a ProcessedSample Object

In [None]:
# ProcessedSample is presumably provided by the `stemseg.processing_funcs` star import.
sample = ProcessedSample(dp, 'Test')

In [None]:
sample.raw_data.plot()

### Pre-process the data to speed up later functions (skip this step if memory is limited)

In [None]:
# Presumably applies data_manip (with bs=64) to the raw data and stores the
# result under the tag 'full_ds' — confirm against ProcessedSample.
sample.save_ml_manipulation('full_ds', data_manip,64)

## Set the hparams; they can be read from the info dictionary

In [None]:
# NOTE(review): takes the second '.json' returned by walk(); confirm this is the intended info file.
info_path = dp.redirect('Final_Models',1).walk('.json')[1]

In [None]:
with open(info_path, 'r') as f:
    info = json.load(f)

In [None]:
info

In [None]:
# Hyperparameters stored under info['full']['hparams'] (keys used by create_vae_model).
hparams= info['full']['hparams']

In [None]:
# NOTE(review): hard-coded model directory from a different proposal (mg28034-1) than the data (mg28749-1).
sample.set_model_data('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/mg28034-1/processing/Models/New','cnn',hparams['LAT'],use_generic_model = True)

### Check the model has built

In [None]:
model = create_vae_model(hparams)
model.summary()

### Load in the trained weights

# NOTE(review): `best_model` is located here, but the next cell loads a hard-coded checkpoint instead.
best_model = mp.walk('.hdf5')[1]
best_model

In [None]:
# NOTE(review): hard-coded checkpoint from dataset 20210925_152115, independent of `select`;
# consider model.load_weights(str(best_model)) once its index is confirmed.
model.load_weights('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_152115/FullModel/chk-250-3.23907e+02.hdf5')

### Attach the model to the sample

In [None]:
sample.set_model(model)

### Encode the data pre-processed earlier (without input_data_tag, sample.raw_data is used)

In [None]:
sample.encode('vae',input_data_tag='full_ds')

### You can inspect the reconstructed image and compare it to the raw data

In [None]:
sample.inspect_model()

In [None]:
from skimage.metrics import structural_similarity as SSI
from skimage.transform import PiecewiseAffineTransform, warp
from sklearn.neighbors import NearestNeighbors as kNN

def get_latgrid(sample, res=100):
    """Regular ``res x res`` grid spanning the 2-D latent encoding of *sample*.

    The x/y limits are the floor of the minimum and the ceiling of the maximum
    of ``sample.encoded_data[:, 0]`` and ``[:, 1]``. Returns an array of shape
    ``(res, res, 2)`` where ``grid[i, j] == (xs[i], ys[j])``.
    """
    enc = sample.encoded_data
    x_lo, x_hi = np.floor(np.min(enc[:, 0])), np.ceil(np.max(enc[:, 0]))
    y_lo, y_hi = np.floor(np.min(enc[:, 1])), np.ceil(np.max(enc[:, 1]))
    grid_x, grid_y = np.meshgrid(np.linspace(x_lo, x_hi, res),
                                 np.linspace(y_lo, y_hi, res),
                                 indexing='ij')
    return np.stack((grid_x, grid_y), axis=2)


def batch_calc_grad(img, radial_kernel, decoded_data, weighting_func, bs=256, decoder=None):
    """Estimate a complex gradient of pattern similarity at each latent point.

    Parameters
    ----------
    img : complex array ``(n_points, n_kernel)``; row ``i`` holds the latent
        positions around point ``i`` (point + each kernel offset).
    radial_kernel : complex array ``(n_kernel,)`` of offsets; used as weights.
    decoded_data : array ``(n_points, *pattern_shape)`` of patterns decoded at
        the points themselves.
    weighting_func : callable ``(ref_pattern, neighbour_pattern) -> float``.
    bs : number of points decoded per batch.
    decoder : callable mapping an ``(m, 2)`` array of (x, y) latent coordinates
        to an object with ``.numpy()``. Defaults to the notebook-level
        ``sample.model.decoder``.

    Returns
    -------
    numpy.ndarray of complex, shape ``(n_points,)``:
        ``sum_k weighting_func(decoded_data[i], pattern(img[i, k])) * radial_kernel[k]``.
    """
    if decoder is None:
        decoder = sample.model.decoder
    n_points, n_kernel = img.shape[0], img.shape[1]
    pattern_shape = tuple(decoded_data.shape[1:])
    starts = range(0, n_points, bs)  # no empty trailing batch
    ssi_ff = []
    for b_idx, start in enumerate(starts):
        print(b_idx, len(starts))
        flat = img[start:start + bs].reshape(-1)
        cart = np.stack((flat.real, flat.imag), axis=1)
        neighbours = decoder(cart).numpy().reshape((-1, n_kernel) + pattern_shape)
        for ref_pat, neigh_pats in zip(decoded_data[start:start + bs], neighbours):
            weights = np.asarray([weighting_func(ref_pat, y) for y in neigh_pats])
            ssi_ff.append(np.sum(weights * radial_kernel))
    return np.asarray(ssi_ff)


def SSI_weighting(img1, img2):
    """Structural similarity of two patterns, expressed as a percentage."""
    score = SSI(img1, img2)
    return score * 100

def get_mobile_points(nn_comp_enc,steps, prev_mp_locs = (), thresh = 'mean', relative_locs = False):
    """Select the points whose proposed step magnitude exceeds a threshold.

    Parameters
    ----------
    nn_comp_enc : complex array of point positions.
    steps : array (same shape) of proposed steps; their magnitudes are thresholded.
    prev_mp_locs : index array mapping the current points back to an earlier,
        larger array. Empty (default) means the points are the original array.
    thresh : number, or 'mean' (mean |step|) or 'ten' (max |step| / 10).
    relative_locs : also return the indices relative to *nn_comp_enc*.

    Returns
    -------
    ``(mobile_points, n_mp_locs)`` or, with ``relative_locs``,
    ``(mobile_points, n_mp_locs, mp_locs)``. ``n_mp_locs`` is
    ``prev_mp_locs[mp_locs]``, or ``mp_locs`` itself when there are no
    previous locations (this case previously raised UnboundLocalError).
    """
    abs_steps = np.abs(steps)
    if isinstance(thresh, str):
        if thresh == 'mean':
            thresh = np.mean(abs_steps)
        elif thresh == 'ten':
            thresh = np.max(abs_steps) / 10
        else:
            raise ValueError(f"unknown threshold rule {thresh!r}; use 'mean', 'ten' or a number")
    mp_locs = np.where(abs_steps > thresh)
    mobile_points = nn_comp_enc[mp_locs]

    if len(prev_mp_locs) != 0:
        n_mp_locs = prev_mp_locs[mp_locs]
    else:
        # No earlier selection: the indices already refer to the full array.
        n_mp_locs = mp_locs

    if relative_locs:
        return mobile_points, n_mp_locs, mp_locs
    return mobile_points, n_mp_locs

def get_grad_and_decode_data(mobile_points, radial_kernel, r_scale_kernel = False, nn_scale = False):
    """Build probe positions around each point and decode the points themselves.

    Parameters
    ----------
    mobile_points : complex array ``(n,)`` of latent positions (x + iy).
    radial_kernel : complex array ``(k,)`` of offsets around a point.
    r_scale_kernel : falsy -> offsets are used as-is. Otherwise, a scale factor
        for the offsets, which are additionally multiplied per point by:
          * ``nn_scale`` false: ``round(|z| / min|z|)`` (outer points probe further);
          * ``nn_scale`` true: nearest-neighbour distance / smallest such distance.
    nn_scale : see above.

    Returns
    -------
    grad_points : complex array ``(n, k)`` of probe positions.
    dec_dat : patterns decoded (via ``get_terr_patts``) at *mobile_points*.
    """
    centres = np.repeat(mobile_points[:, None], radial_kernel.shape[0], axis=1)
    if not r_scale_kernel:
        grad_points = centres + radial_kernel[None, :]
    elif not nn_scale:
        # NOTE(review): a point at the origin makes min|z| zero -> inf factors.
        rf = np.round((np.abs(mobile_points) / np.abs(mobile_points).min()), 0).astype('int')
        grad_points = centres + r_scale_kernel * rf[:, None] * radial_kernel[None, :]
    else:
        # (n, 2) coordinates: concatenate along axis 1 (axis 0 gave a (2n, 1) array).
        sample_locs = np.concatenate((mobile_points.real[:, None], mobile_points.imag[:, None]), axis=1)
        # Two neighbours: the first is each point itself at distance 0.
        nbrs = kNN(n_neighbors=2, algorithm='ball_tree').fit(sample_locs)
        p_sep, _ = nbrs.kneighbors(sample_locs)
        nn_sep = p_sep[:, 1]
        norm_sep = nn_sep / nn_sep.min()
        grad_points = centres + r_scale_kernel * norm_sep[:, None] * radial_kernel[None, :]

    dec_dat = get_terr_patts(np.concatenate([mobile_points.real[:, None], mobile_points.imag[:, None]], axis=1))
    return grad_points, dec_dat

def sig_step_from_grad(d_gp, gradient_step, sigz=0.25, sigf=100):
    """Turn gradient estimates into sigmoid-gated steps along the gradient.

    Each step points along ``d_gp`` and has length
    ``gradient_step * sigmoid(|d_gp|, sigz, sigf)``. A zero gradient yields a
    zero step instead of the NaN that 0/0 would give.
    """
    d_gp = np.asarray(d_gp)
    grad_mag = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, grad_mag, out=direction, where=grad_mag != 0)
    return sigmoid(grad_mag, sigz, sigf) * gradient_step * direction

def norm_step_from_grad(d_gp, factor):
    """Scale gradients so the largest has magnitude *factor*.

    An all-zero gradient gives all-zero steps rather than 0/0 = NaN.
    """
    grad_mag = np.max(np.abs(d_gp))
    if grad_mag == 0:
        return np.zeros_like(d_gp, dtype=np.result_type(d_gp, 1.0))
    return (d_gp / grad_mag) * factor

def sigmoid(z, sigz=0.25, sigf=100):
    """Logistic curve centred at *sigz* with steepness *sigf*.

    Evaluates ``1 / (1 + exp(-sigf * (z - sigz)))`` in the overflow-safe form
    ``exp(-logaddexp(0, -x))``.
    """
    neg_x = sigf * (sigz - z)
    return np.exp(np.negative(np.logaddexp(0, neg_x)))


def adjust_encoding(mobile_points, grads, comp_enc, mp_locs):
    """Move *mobile_points* by complex *grads* and write them back into a copy.

    Returns ``((X, Y), (U, V), migrated)``:
    * ``(X, Y)``: the old positions;
    * ``(U, V)``: the moved positions, ``U = X + Re(grads)`` and ``V = Y + Im(grads)``;
    * ``migrated``: a copy of *comp_enc* with ``migrated[mp_locs] = U + iV``.
    *comp_enc* itself is not modified.
    """
    old_x, old_y = mobile_points.real, mobile_points.imag
    new_x = old_x + grads.real
    new_y = old_y + grads.imag

    migrated = comp_enc.copy()
    migrated[mp_locs] = new_x + 1j * new_y
    return (old_x, old_y), (new_x, new_y), migrated

def get_terr_patts(img, bs =256, decoder=None, pattern_shape=(128, 128)):
    """Decode latent coordinates into patterns, *bs* points at a time.

    Parameters
    ----------
    img : array ``(n, 2)`` of latent (x, y) coordinates.
    bs : number of points per decoder call.
    decoder : callable returning an object with ``.numpy()``; defaults to the
        notebook-level ``sample.model.decoder``.
    pattern_shape : shape of one decoded pattern (default 128x128).

    Returns
    -------
    numpy.ndarray of shape ``(n,) + pattern_shape``.
    """
    if decoder is None:
        decoder = sample.model.decoder
    n_points = img.shape[0]
    out_shape = (n_points,) + tuple(pattern_shape)
    if n_points == 0:
        return np.empty(out_shape)
    # range() step batching never produces an empty trailing batch.
    decoded = [decoder(img[start:start + bs]).numpy() for start in range(0, n_points, bs)]
    return np.concatenate(decoded, axis=0).reshape(out_shape)

In [None]:
# 250x250 regular grid over the encoded region, flattened to (62500, 2).
latgrid = flatten_nav(get_latgrid(sample, res =250))

In [None]:
# Same grid points as complex numbers x + iy.
comp_latgrid = latgrid[:,0] + 1j*latgrid[:,1]

In [None]:
plt.figure()
plt.scatter(latgrid[:,0], latgrid[:,1])

In [None]:
gradient_step = 0.01
n_rsteps = 12

In [None]:
# n_rsteps offsets of length gradient_step, evenly spaced by 360/n_rsteps degrees
# (the 0-degree sample is dropped because it duplicates 360).
radial_kernel = gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, n_rsteps+1)/180))[1:]

In [None]:
# grad_p1: each grid point plus every kernel offset; decdat1: patterns decoded at the grid points.
grad_p1, decdat1 = get_grad_and_decode_data(comp_latgrid, radial_kernel)

In [None]:
# Kernel-weighted sum of SSI between each point's pattern and its neighbours' patterns.
delta_gp1 = batch_calc_grad(grad_p1, radial_kernel, decdat1, SSI_weighting, 256)

In [None]:
def sig_step_from_grad(d_gp, gradient_step, sigz=0.25, sigf=100):
    """Turn gradient estimates into sigmoid-gated steps along the gradient.

    Each step points along ``d_gp`` and has length
    ``gradient_step * sigmoid(|d_gp|, sigz, sigf)``. A zero gradient yields a
    zero step instead of the NaN that 0/0 would give.
    """
    d_gp = np.asarray(d_gp)
    grad_mag = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, grad_mag, out=direction, where=grad_mag != 0)
    return sigmoid(grad_mag, sigz, sigf) * gradient_step * direction

def sigmoid(z, sigz=0.25, sigf=100):
    """Logistic curve centred at *sigz* with steepness *sigf*.

    Evaluates ``1 / (1 + exp(-sigf * (z - sigz)))`` in the overflow-safe form
    ``exp(-logaddexp(0, -x))``.
    """
    neg_x = sigf * (sigz - z)
    return np.exp(np.negative(np.logaddexp(0, neg_x)))

def lin_thresh_step(d_gp, thresh, mag = 1):
    """Step along the gradient with a length that grows linearly, then saturates.

    The length is ``mag * min(|d_gp|, thresh) / thresh``, so gradients at or
    above *thresh* give steps of length *mag*. A zero gradient gives a zero
    step (0/0 previously produced NaN).
    """
    d_gp = np.asarray(d_gp)
    scale = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, scale, out=direction, where=scale != 0)
    return (np.where(scale > thresh, thresh, scale) / thresh) * direction * mag

def scaled_thresh_step(d_gp, thresh, mobile_points, mag):
    """``lin_thresh_step`` additionally scaled by each point's local spacing.

    The nearest-neighbour distance of every point in *mobile_points* (complex)
    is divided by the smallest such distance. Steps in sparse regions are
    therefore proportionally longer.
    NOTE(review): duplicate points make the smallest distance 0 (-> inf scaling).
    """
    sample_locs = np.concatenate((mobile_points.real[:, None], mobile_points.imag[:, None]), axis=1)
    # Two neighbours per query: the point itself (distance 0) and its nearest neighbour.
    nbrs = kNN(n_neighbors=2, algorithm='ball_tree').fit(sample_locs)
    p_sep, _ = nbrs.kneighbors(sample_locs, n_neighbors=2)
    nn_sep = p_sep[:, 1]
    norm_sep = nn_sep / nn_sep.min()

    d_gp = np.asarray(d_gp)
    scale = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, scale, out=direction, where=scale != 0)
    return (np.where(scale > thresh, thresh, scale) / thresh) * direction * norm_sep * mag

In [None]:
linsteps = lin_thresh_step(delta_gp1, 0.001, 0.01)

In [None]:
plt.figure()
plt.plot(np.sort(np.abs(linsteps)))

sigmp = srss[int(srss.size//10)]
sig_steps = sig_step_from_grad(delta_gp1,gradient_step, sigmp, 1e6)

In [None]:
#steps1 = norm_step_from_grad(delta_gp1,gradient_step*100)
op1, np1, current_ps1= adjust_encoding(comp_latgrid, linsteps, comp_latgrid, np.where(comp_latgrid != None))

In [None]:
# Original lattice, moved lattice and the encoded data overlaid.
plt.figure(figsize = (8,8))
plt.scatter(op1[0], op1[1],s =20)
plt.scatter(np1[0], np1[1],s =20)
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.2)

In [None]:
# (n, 2) arrays of the original and moved lattice coordinates.
o_latgrid = np.asarray(op1).T
n_latgrid = np.asarray(np1).T

In [None]:
# Piecewise-affine warp estimated from the moved lattice (src) to the original lattice (dst);
# per skimage's convention, `tform.inverse` then maps original -> moved positions.
tform = PiecewiseAffineTransform()
tform.estimate(n_latgrid, o_latgrid)

In [None]:
test_data = sample.encoded_data
out_data = tform.inverse(test_data)

In [None]:
# Apply the same warp 50 more times (51 in total), keeping every intermediate for the animation.
r_out_data = out_data
refine_steps=[]
for i in range(50):
    r_out_data = tform.inverse(r_out_data)
    refine_steps.append(r_out_data)

In [None]:
plt.figure()
#plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(out_data[:,0], out_data[:,1], s =10)
plt.scatter(r_out_data[:,0], r_out_data[:,1], s =10, alpha = 0.1)

In [None]:
from matplotlib import animation, rc
from IPython.display import HTML, Image

In [None]:
# Render animations as HTML5 video in the notebook.
rc('animation', html='html5')
# First set up the figure, the axis, and the plot element we want to animate
fig, ax = plt.subplots()

ax.set_xlim(( -1, 1))
ax.set_ylim((-1, 1))

# Markers only (no connecting line): one point per encoded pattern.
line, = ax.plot([], [], lw=2, ls = '', marker = 'o', alpha = 0.2)

In [None]:
def init():
    """Start the animation from an empty scatter."""
    line.set_data([], [])
    return (line,)
def animate(i):
    """Show the points after refinement step i."""
    d = refine_steps[i]
    x,y = d[:,0], d[:,1]
    line.set_data(x, y)
    return (line,)
# call the animator. blit=True means only re-draw the parts that 
# have changed.
# frames=50 matches the 50 entries stored in refine_steps.
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=50, interval=20, blit=True)

In [None]:
anim

In [None]:
comp_enc = sample.encoded_data[:,0] +1j*sample.encoded_data[:,1]

In [None]:
nrs, nts = 60,40
max_r = np.max(np.abs(comp_enc))

rs = np.linspace(0, max_r, nrs+1)[1:]
ts = np.linspace(0, 2*np.pi, nts, endpoint = False)

latspace = []

for r in rs:
    for t in ts:
        latspace.append(r*np.exp(-1j*t))
        
latspace = np.asarray(latspace)

In [None]:
gradient_step_r = 0.015
n_rsteps_r = 12

# n_rsteps_r offsets of length gradient_step_r, evenly spaced in angle.
radial_kernel_r = gradient_step_r* np.exp(1j*np.pi*(np.linspace(0, 360, n_rsteps_r+1)/180))[1:]

In [None]:
# The third positional argument is r_scale_kernel=0.01: offsets are scaled per point by
# 0.01 * round(|z| / min|z|), so outer rings are probed further out.
grad_p1_r, decdat1_r = get_grad_and_decode_data(latspace, radial_kernel_r, 0.01)

In [None]:
plt.figure()
plt.scatter(latspace.real, latspace.imag)
plt.scatter(flatten_nav(grad_p1_r).real, flatten_nav(grad_p1_r).imag)

In [None]:
# NOTE(review): the weights are the unscaled radial_kernel_r, although grad_p1_r used scaled offsets.
delta_gp1_r = batch_calc_grad(grad_p1_r, radial_kernel_r, decdat1_r, SSI_weighting, 256)

In [None]:
sts = lin_thresh_step(delta_gp1_r, 0.001, 0.1)

In [None]:
sts

In [None]:
plt.figure()
plt.plot(np.sort(np.abs(sts)))

In [None]:
# NOTE(review): norm_step_r, r_s_norm_step_r and r_s_step_r are alternative step rules;
# the adjust_encoding call below uses `sts` instead.
norm_step_r = delta_gp1_r/np.abs(delta_gp1_r).max()

In [None]:
r_s_norm_step_r = (np.abs(latspace)/np.abs(latspace).min())*gradient_step_r*norm_step_r*100

In [None]:
# Cap step lengths at 0.01 while keeping their direction.
r_s_step_r = np.where(np.abs(r_s_norm_step_r) > 0.01, (r_s_norm_step_r/np.abs(r_s_norm_step_r))*0.01, r_s_norm_step_r)

In [None]:
# `latspace != None` is elementwise True for a numeric array, so every point is moved.
op1r, np1r, current_ps1r = adjust_encoding(latspace, sts, latspace, np.where(latspace != None))

In [None]:
# Original polar net, moved net and the encoded data overlaid.
plt.figure(figsize = (8,8))
plt.scatter(op1r[0], op1r[1],s =20)
plt.scatter(np1r[0], np1r[1],s =20)
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.2)

In [None]:
o_latspacer = np.asarray(op1r).T
n_latspacer = np.asarray(np1r).T

In [None]:
# Second piecewise-affine warp, estimated from the moved polar net (src) to the original net (dst).
tform2 = PiecewiseAffineTransform()
tform2.estimate(n_latspacer, o_latspacer)

In [None]:
# Apply the polar-net warp on top of the points already refined by the first warp.
out_data2r = tform2.inverse(r_out_data)

r_out_data2 = out_data2r

In [None]:
plt.figure()
#plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(out_data[:,0], out_data[:,1], s =10)
plt.scatter(r_out_data[:,0], r_out_data[:,1], s =10)
plt.scatter(r_out_data2[:,0], r_out_data2[:,1], s =10)

In [None]:
# 500 further applications of the second warp, keeping every intermediate step.
refine_steps2=[]

for i in range(500):
    r_out_data2 = tform2.inverse(r_out_data2)
    refine_steps2.append(r_out_data2)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(out_data[:,0], out_data[:,1], s =10)
#plt.scatter(r_out_data[:,0], r_out_data[:,1], s =10)
plt.scatter(r_out_data2[:,0], r_out_data2[:,1], s =10)

In [None]:
rc('animation', html='html5')
# First set up the figure, the axis, and the plot element we want to animate
figr, axr = plt.subplots()

axr.set_xlim(( -1, 1))
axr.set_ylim((-1, 1))

liner, = axr.plot([], [], lw=2, ls = '', marker = 'o', alpha = 0.2)

def init2():
    """Start the second animation from an empty scatter."""
    liner.set_data([], [])
    return (liner,)
def animate2(i):
    """Show the points after step i of the second refinement."""
    d = refine_steps2[i]
    x,y = d[:,0], d[:,1]
    liner.set_data(x, y)
    return (liner,)
# call the animator. blit=True means only re-draw the parts that 
# have changed.
# NOTE(review): frames=100 shows only the first 100 of the 500 stored steps.
animr = animation.FuncAnimation(figr, animate2, init_func=init2,
                               frames=100, interval=20, blit=True)

In [None]:
animr

In [None]:
from sklearn.cluster import DBSCAN

In [None]:
db = DBSCAN(eps = 0.05, min_samples= 10).fit_predict(r_out_data2)

db.max()

In [None]:
plt.figure()
plt.imshow(db.reshape(255,255))

In [None]:
db.min()

In [None]:
db += 1

In [None]:
sample.all_maps['latwarp'] = db.reshape(255,255)

In [None]:
np.unique(sample.all_maps['latwarp'])

In [None]:
sample.get_map_patterns('latwarp',method='mean', recompute = True)

In [None]:
sample.all_patterns['latwarp'].shape

In [None]:
import itertools
# SSI between every unordered pair of 'latwarp' cluster patterns.
latwarp_patterns = sample.all_patterns['latwarp']
ssi_pair = list(itertools.combinations(np.arange(0, latwarp_patterns.shape[0]), 2))
ssi_val = [SSI(latwarp_patterns[a], latwarp_patterns[b]) for a, b in ssi_pair]

In [None]:
import networkx as nx

def get_graph_from_connectivity(uthresh):
    '''Build an undirected graph from an edge list.

    uthresh: array-like of shape (n_edges, 2); each row is a pair of node
        indices to connect.

    Returns the graph and the sorted unique node indices that occur in uthresh.
    '''
    node_ids = np.unique(uthresh)
    graph = nx.Graph()
    graph.add_nodes_from(node_ids)
    graph.add_edges_from(uthresh)
    return graph, node_ids

    
def view_graph(g):
    '''Draw graph *g*, with node labels, in a new matplotlib figure.'''
    plt.figure()
    nx.draw(g, with_labels=True)
    
def get_connected_nodes(g1):
    '''Return the connected components of *g1* as a list of node sets.'''
    components = nx.connected_components(g1)
    return list(components)

In [None]:
# Pattern pairs with SSI above 0.975 are linked (candidates for merging).
similarity_inds = np.where(np.asarray(ssi_val)>0.975)

In [None]:
ssi_connect = np.asarray(ssi_pair)[similarity_inds]

In [None]:
g, ginds = get_graph_from_connectivity(ssi_connect)

In [None]:
view_graph(g)

In [None]:
# Each connected component is a group of (transitively) similar patterns to merge.
con_nodes = get_connected_nodes(g)

In [None]:
con_nodes

In [None]:
all_nodes = list(range(sample.all_patterns['latwarp'].shape[0]))
for each in con_nodes:
    combine = list(each)
    [all_nodes.pop(all_nodes.index(e)) for e in combine]

In [None]:
all_nodes

In [None]:
blank = np.zeros_like(sample.all_maps['latwarp'])
count = 1
for nc_node in all_nodes:
    blank += np.where(sample.all_maps['latwarp'] == nc_node, count, 0)
    count += 1
for each in con_nodes:
    combine = list(each)
    for c_node in combine:
        blank += np.where(sample.all_maps['latwarp'] == c_node, count, 0)
    count += 1


In [None]:
# Shift down by one so labels start at 0 (unmatched positions become -1).
refine_latwarp1 = blank -1

In [None]:
np.unique(refine_latwarp1)

In [None]:
sample.all_maps['refine_latwarp1'] = refine_latwarp1

In [None]:
sample.imshow(None, 'refine_latwarp1')

In [None]:
sample.get_map_patterns('refine_latwarp1', method = 'mean', recompute = True)

In [None]:
# Label 0 is the first unmerged node, presumably latwarp label 0, i.e. DBSCAN
# noise (as long as node 0 was not merged) — confirm.
outliers = np.where(sample.all_maps['refine_latwarp1'] == 0)

In [None]:
# (n_outliers, 2) array of (row, col) positions.
outliers = np.asarray(outliers).T

In [None]:
# Assign each outlier to the most similar of patterns 1.. (SSI against its raw pattern);
# +1 converts the argmax back into a pattern index.
n_clust = []
for o in outliers:
    patt = sample.raw_data.data[o[0], o[1]]
    ssis = np.asarray([SSI(patt, cpatt) for cpatt in sample.all_patterns['refine_latwarp1'][1:]])
    n_clust.append(np.argmax(ssis) + 1)
    

In [None]:
n_clust = np.asarray(n_clust)

In [None]:
refine_latwarp2 = sample.all_maps['refine_latwarp1'].copy()

In [None]:
refine_latwarp2[[outliers.T[0],outliers.T[1]]] = n_clust

In [None]:
sample.all_maps['refine_latwarp2'] = refine_latwarp2 -1 

np.save('/dls/science/groups/imaging/ePSIC_students/latwarp_cluster.npy', sample.all_maps['refine_latwarp2'])

sample.all_maps['refine_latwarp2'] = np.load('/dls/science/groups/imaging/ePSIC_students/latwarp_cluster.npy')

In [None]:
sample.imshow(None, 'refine_latwarp2')

In [None]:
# NOTE(review): colours by 'refine_latwarp1' (before outlier reassignment), not the final 'refine_latwarp2'.
plt.figure(figsize = (8,8))
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], c = flatten_nav(sample.all_maps['refine_latwarp1']), cmap= 'turbo', s = 10)

In [None]:
sample.get_map_patterns('refine_latwarp2', method = 'mean', recompute=True)

In [None]:
### View the patterns and their associated regions

show_cluster_patterns(sample, 'refine_latwarp2').plot()

In [None]:
### View a signal boosted representation of the sample

sbs = signal_boosted_scan(sample, 'refine_latwarp2')

sbs.plot()

In [None]:
# One representative (row, col) per label: the first position np.where returns for it.
unique_regions = [np.asarray(np.where(sample.all_maps['refine_latwarp2'] == x))[:,0] for x in np.unique(sample.all_maps['refine_latwarp2'])]

f = inv_sbs(sbs, tag='refine_latwarp2', return_fig=True, vmax = 0.1)

In [None]:
figp = '/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/Latspace_Clustering'

In [None]:
time_stamp = str(dp).split('/')[-2]

In [None]:
for i, ur in enumerate(unique_regions):
    f = inv_sbs(sbs,'refine_latwarp2', ur, return_fig=True, vmax = 0.1)
    f.savefig(str(figp)+f'/{time_stamp}-region-{i}-vmax0.01.jpg', dpi = 200)

In [None]:
def sig_step_from_grad(d_gp, gradient_step, sigz=0.25, sigf=100):
    """Turn gradient estimates into sigmoid-gated steps along the gradient.

    Each step points along ``d_gp`` and has length
    ``gradient_step * sigmoid(|d_gp|, sigz, sigf)``. A zero gradient yields a
    zero step instead of the NaN that 0/0 would give.
    """
    d_gp = np.asarray(d_gp)
    grad_mag = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, grad_mag, out=direction, where=grad_mag != 0)
    return sigmoid(grad_mag, sigz, sigf) * gradient_step * direction

def sigmoid(z, sigz=0.25, sigf=100):
    """Logistic curve centred at *sigz* with steepness *sigf*.

    Evaluates ``1 / (1 + exp(-sigf * (z - sigz)))`` in the overflow-safe form
    ``exp(-logaddexp(0, -x))``.
    """
    neg_x = sigf * (sigz - z)
    return np.exp(np.negative(np.logaddexp(0, neg_x)))

def lin_thresh_step(d_gp, thresh, mag = 1):
    """Step along the gradient with a length that grows linearly, then saturates.

    The length is ``mag * min(|d_gp|, thresh) / thresh``, so gradients at or
    above *thresh* give steps of length *mag*. A zero gradient gives a zero
    step (0/0 previously produced NaN).
    """
    d_gp = np.asarray(d_gp)
    scale = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, scale, out=direction, where=scale != 0)
    return (np.where(scale > thresh, thresh, scale) / thresh) * direction * mag

def scaled_thresh_step(d_gp, thresh, mobile_points, mag):
    """``lin_thresh_step`` additionally scaled by each point's local spacing.

    The nearest-neighbour distance of every point in *mobile_points* (complex)
    is divided by the smallest such distance. Steps in sparse regions are
    therefore proportionally longer.
    NOTE(review): duplicate points make the smallest distance 0 (-> inf scaling).
    """
    sample_locs = np.concatenate((mobile_points.real[:, None], mobile_points.imag[:, None]), axis=1)
    # Two neighbours per query: the point itself (distance 0) and its nearest neighbour.
    nbrs = kNN(n_neighbors=2, algorithm='ball_tree').fit(sample_locs)
    p_sep, _ = nbrs.kneighbors(sample_locs, n_neighbors=2)
    nn_sep = p_sep[:, 1]
    norm_sep = nn_sep / nn_sep.min()

    d_gp = np.asarray(d_gp)
    scale = np.abs(d_gp)
    direction = np.zeros(d_gp.shape, dtype=np.result_type(d_gp, 1.0))
    np.divide(d_gp, scale, out=direction, where=scale != 0)
    return (np.where(scale > thresh, thresh, scale) / thresh) * direction * norm_sep * mag


from sklearn.neighbors import KernelDensity
def get_density_net(sample, n_samples, n_bkg_samples, density_approx = 10,  bandwidth=0.5):
    """Sample net points from the density of the encoded data plus a uniform background.

    Parameters
    ----------
    sample : object with ``encoded_data`` of shape ``(n, 2)``.
    n_samples : number of points drawn from a Gaussian KDE of the encoding.
    n_bkg_samples : number of points drawn uniformly over the bounding box.
    density_approx : keep every ``density_approx``-th point of a shuffled copy
        of the encoding when fitting the KDE (a speed/accuracy trade-off).
    bandwidth : KDE bandwidth.

    Returns
    -------
    numpy.ndarray of shape ``(n_samples + n_bkg_samples, 2)``: the KDE samples
    followed by the background samples. The bounding box is the floor/ceil of
    the encoding's min/max in each coordinate. Uses NumPy's global RNG.
    """
    enc = sample.encoded_data
    D = enc.copy()
    np.random.shuffle(D)
    D = D[::density_approx]
    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(D)
    R = kde.sample(n_samples)

    xmin, xmax = np.floor(np.min(enc[:, 0])), np.ceil(np.max(enc[:, 0]))
    ymin, ymax = np.floor(np.min(enc[:, 1])), np.ceil(np.max(enc[:, 1]))
    centre = np.array((np.mean((xmin, xmax)), np.mean((ymin, ymax))))
    # Half-widths per axis (the y extent previously reused the x half-width).
    half_width = np.array((xmax - centre[0], ymax - centre[1]))
    # Uniform in [-1, 1)^2, scaled to the box, then shifted onto its centre.
    # (Shifting before scaling left the samples off-centre.)
    s_samples = (2 * np.random.random((n_bkg_samples, 2)) - 1) * half_width + centre

    return np.concatenate((R, s_samples), axis=0)

import sklearn.metrics.cluster as cmet

def get_map_label_df(map1):
    """One 0/1 indicator array per unique label of *map1* (in sorted label order)."""
    labels = np.unique(map1)
    return np.asarray([(map1 == lab).astype(int) for lab in labels])

def get_cluster_label_overlap(map_pair):
    """Fraction of each cluster in the first map covered by each cluster in the second.

    map_pair: ``(db1_df, db2_df)``, two stacks of 0/1 indicator arrays (e.g. from
    ``get_map_label_df``) whose entries share one shape.

    Returns an array ``(len(db1_df), len(db2_df))`` with entry ``[i, j]`` equal to
    ``|cluster_i AND cluster_j| / |cluster_i|``.
    """
    db1_df, db2_df = map_pair
    n1, n2 = len(db1_df), len(db2_df)
    if n1 == 0 or n2 == 0:
        return np.zeros((n1, n2))
    # Flatten each indicator and cast to float, so bool inputs are counted
    # rather than OR-ed; one matrix product replaces the double Python loop.
    a = np.asarray(db1_df, dtype=float).reshape(n1, -1)
    b = np.asarray(db2_df, dtype=float).reshape(n2, -1)
    return (a @ b.T) / a.sum(axis=1)[:, None]

def find_map_label(pos, map1):
    """Return the label stored in *map1* at index *pos*."""
    label = map1[pos]
    return label

def get_confidence_from_maps(maps):
    """Per-point agreement between several cluster label maps of the same points.

    For every ordered pair of maps (a, b) and every point p, take the fraction of
    p's cluster in map a that also lies in p's cluster in map b. The confidence
    of p is the mean of that fraction over all ordered pairs, so 1.0 means every
    map groups p with exactly the same points.

    Parameters
    ----------
    maps : sequence of ndarray
        At least two label maps of identical shape. Labels may be any values;
        they need not be contiguous or start at 0.

    Returns
    -------
    ndarray of float32 with the shape of ``maps[0]``

    Raises
    ------
    ValueError
        If fewer than two maps are given or their shapes differ.
    """
    import itertools

    maps = [np.asarray(m) for m in maps]
    if len(maps) < 2:
        raise ValueError('need at least two maps to compare')
    shape = maps[0].shape
    if any(m.shape != shape for m in maps):
        raise ValueError('all maps must have the same shape')
    if maps[0].size == 0:
        return np.zeros(shape, dtype='float32')

    # Row index of each point's label within np.unique(map); using the label
    # value directly is only valid when labels are exactly 0..k-1.
    rows = [np.unique(m, return_inverse=True)[1].ravel() for m in maps]

    pairs = list(itertools.permutations(range(len(maps)), 2))
    total = np.zeros(rows[0].size)
    for a, b in pairs:
        ra, rb = rows[a], rows[b]
        # Contingency table: counts[i, j] = points with label i in a and j in b.
        counts = np.zeros((ra.max() + 1, rb.max() + 1))
        np.add.at(counts, (ra, rb), 1)
        overlap = counts / counts.sum(axis=1, keepdims=True)
        total += overlap[ra, rb]
    return (total / len(pairs)).reshape(shape).astype('float32')



def refine_based_on_density(sample, density_approx = 5, bw = 0.4, n_sample_points = 2500,
                            n_bkg_points = 500, show_net = True, rand_gradient_step = 0.01,
                            rand_n_rsteps = 12, step_scale = 0.01, step_thresh = 0.001,
                            show_step_size = True, show_net_movement= True, show_first_refinement = False,
                            n_refine_steps = 200, show_refinement = True, animate_refinement = True):
    '''
    Refine the encoded positions of ``sample`` by warping them with a moved net of latent points.

    A net of points is drawn from a density estimate of the encoded data plus a
    background (get_density_net). Each net point is stepped along its gradient. A
    piecewise-affine transform is estimated from the displaced net to the original
    net, and its inverse is applied to the encoded data once and then
    ``n_refine_steps`` more times.

    sample: ProcessedSample with the encoded data (``sample.encoded_data``; columns 0 and 1 are used as x and y)
    density_approx: keep every ``density_approx``-th point of the randomly shuffled encoded data when estimating the density (default - 5)
    bw: bandwidth of the Gaussian kernel density estimate (default - 0.4)
    n_sample_points: number of net points to sample from the density distribution (default - 2500)
    n_bkg_points: number of background net points (see get_density_net) (default - 500)
    show_net: show the positions of the net points (default - True)
    rand_gradient_step: radius of the ring of offsets around each net point used for the gradient approximation (default - 0.01)
    rand_n_rsteps: number of offsets on that ring (default - 12)
    step_scale: magnitude of the largest step a net point can take (default - 0.01)
    step_thresh: gradient magnitude at and above which the step size == step_scale (default - 0.001)
    show_step_size: plot the sorted step sizes (default - True)
    show_net_movement: show how the net points have moved (default - True)
    show_first_refinement: show how the encoded data moved after one transform (default - False)
    n_refine_steps: number of additional refinement steps to run (default - 200)
    show_refinement: show how the encoded data moved after all steps (default - True)
    animate_refinement: build an animation of the refinement steps (default - True)

    Returns (refined_positions, accessory_dict). refined_positions is the encoded data after
    1 + n_refine_steps inverse transforms. accessory_dict holds intermediate results under
    'net', 'grads', 'steps', 'net_displacement', 'transform', 'first_refinement',
    'refinement_steps' and, when animate_refinement is set, 'animation'.
    '''
    from skimage.transform import PiecewiseAffineTransform

    accessory_dict = {}

    # Density-based net of latent points.
    R = get_density_net(sample, n_sample_points, n_bkg_points, density_approx, bw)
    accessory_dict['net'] = R
    if show_net:
        plt.figure()
        plt.scatter(sample.encoded_data[:, 0], sample.encoded_data[:, 1], s=10)
        plt.scatter(R[:, 0], R[:, 1], s=20)

    # Net points as complex numbers, and a ring of rand_n_rsteps offsets of radius
    # rand_gradient_step ([1:] drops angle 0, which coincides with 360 degrees).
    rand_latspace = R[:, 0] + 1j * R[:, 1]
    rand_radial_kernel = rand_gradient_step * np.exp(1j * np.pi * (np.linspace(0, 360, rand_n_rsteps + 1) / 180))[1:]
    rand_grad_p1, rand_decdat1 = get_grad_and_decode_data(rand_latspace, rand_radial_kernel)

    # Gradients at the net points (helpers defined elsewhere).
    rand_delta_gp1 = batch_calc_grad(rand_grad_p1, rand_radial_kernel, rand_decdat1, SSI_weighting, 256)
    accessory_dict['grads'] = rand_delta_gp1

    # Step sizes saturate at step_scale once the gradient reaches step_thresh.
    rand_linsteps = lin_thresh_step(rand_delta_gp1, step_thresh, step_scale)
    if show_step_size:
        plt.figure()
        plt.plot(np.sort(np.abs(rand_linsteps)))
    accessory_dict['steps'] = rand_linsteps

    # Move the net points; `rand_latspace != None` is elementwise True, so every index is selected.
    rand_op1, rand_np1, rand_current_ps1 = adjust_encoding(rand_latspace, rand_linsteps, rand_latspace, np.where(rand_latspace != None))
    accessory_dict['net_displacement'] = rand_np1
    if show_net_movement:
        plt.figure(figsize=(8, 8))
        plt.scatter(rand_op1[0], rand_op1[1], s=20)
        plt.scatter(rand_np1[0], rand_np1[1], s=20)
        plt.scatter(sample.encoded_data[:, 0], sample.encoded_data[:, 1], s=10, alpha=0.2)

    # Estimate displaced -> original net transform and apply its inverse once.
    test_data = sample.encoded_data.copy()
    rand_o_latspacer = np.asarray(rand_op1).T
    rand_n_latspacer = np.asarray(rand_np1).T

    rand_tform2 = PiecewiseAffineTransform()
    accessory_dict['transform'] = rand_tform2
    rand_tform2.estimate(rand_n_latspacer, rand_o_latspacer)

    rand_out_data2 = rand_tform2.inverse(test_data)
    accessory_dict['first_refinement'] = rand_out_data2.copy()

    if show_first_refinement:
        plt.figure()
        plt.scatter(test_data[:, 0], test_data[:, 1], s=10)
        plt.scatter(rand_out_data2[:, 0], rand_out_data2[:, 1], s=10)

    # Apply the same inverse transform repeatedly, keeping every intermediate state.
    rand_refine_steps2 = []
    for _ in range(n_refine_steps):
        rand_out_data2 = rand_tform2.inverse(rand_out_data2)
        rand_refine_steps2.append(rand_out_data2)

    if show_refinement:
        plt.figure()
        plt.scatter(test_data[:, 0], test_data[:, 1], s=10)
        plt.scatter(rand_out_data2[:, 0], rand_out_data2[:, 1], s=10)

    accessory_dict['refinement_steps'] = rand_refine_steps2

    if animate_refinement:
        from matplotlib import animation

        figr3, axr3 = plt.subplots()
        axr3.set_xlim((-1, 1))
        axr3.set_ylim((-1, 1))
        liner3, = axr3.plot([], [], lw=2, ls='', marker='o', alpha=0.2)

        def init3():
            liner3.set_data([], [])
            return (liner3,)

        def animate3(i):
            frame = rand_refine_steps2[i]
            liner3.set_data(frame[:, 0], frame[:, 1])
            return (liner3,)

        # blit=True redraws only the artists returned by init3/animate3. One frame
        # per stored step, so animate3 never indexes past the end of the list.
        animr3 = animation.FuncAnimation(figr3, animate3, init_func=init3,
                                         frames=len(rand_refine_steps2), interval=20, blit=True)
        accessory_dict['animation'] = animr3
    return rand_out_data2, accessory_dict
        

# Optimize Parameters

### params:

In [None]:
# Parameters for the density-based net built below.
density_approx = 5
bw = 1
n_sample_points = 1000
n_bkg_points = 50

### process:

In [None]:
# Subsample the shuffled encoded data: keep every density_approx-th point.
D = sample.encoded_data.copy()

np.random.shuffle(D)

D = D[::density_approx]

In [None]:
D.shape

In [None]:
plt.figure()
plt.scatter(D[:,0], D[:,1])

In [None]:
# Gaussian kernel density estimate of the subsampled encoded data.
from sklearn.neighbors import KernelDensity
kde = KernelDensity(kernel='gaussian', bandwidth=bw).fit(D)

In [None]:
# 100 x 100 evaluation grid over [-3, 3] x [-3, 3], flattened to (10000, 2) points.
xgrid = np.linspace(-3,3,100)
ygrid = np.linspace(-3,3,100)
X,Y = np.meshgrid(xgrid, ygrid)
xy = np.vstack((X.ravel(), Y.ravel())).T

In [None]:
xy.shape

In [None]:
# score_samples returns the log-density; reshape it back onto the grid.
Z = kde.score_samples(xy).reshape(X.shape)

In [None]:
# Filled contour plot of the log-density with 50 levels.
plt.figure()
plt.contourf(X,Y,Z, levels = np.linspace(Z.min(), Z.max(), 50))
In [None]:
# Density-based net: keywords make sure bw is used as the KDE bandwidth and not
# taken positionally as density_approx.
R = get_density_net(sample, n_sample_points, n_bkg_points, density_approx=density_approx, bandwidth=bw)

### Optimise the net point distribution here:

In [None]:
# Encoded data (small markers) overlaid with the net points R (larger markers).
plt.figure()
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10)
plt.scatter(R[:,0], R[:,1], s = 20)

##### Note: the fewer net points, the faster this runs (both the gradient calculation and the refinement iterations).

In [None]:
# Total number of net points.
n_sample_points + n_bkg_points

Next 

### params:

In [None]:
rand_gradient_step = 0.01  # radius of the ring of offsets used for the gradient approximation
rand_n_rsteps = 12  # number of offsets on that ring
step_scale = 0.015  # largest step a net point may take
step_thresh = 0.001  # gradient magnitude at which the step saturates (a float, not a 1-tuple)

### process:

In [None]:
# Net points as complex numbers x + iy.
rand_latspace = R[:,0] + 1j*R[:,1]

# rand_n_rsteps offsets of radius rand_gradient_step evenly spaced around a circle
# ([1:] drops angle 0, which coincides with the kept 360 degrees).
rand_radial_kernel = rand_gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, rand_n_rsteps+1)/180))[1:]

# Helpers defined elsewhere; presumably decode the ring points and compute an
# SSI-weighted gradient per net point in batches of 256 — TODO confirm.
rand_grad_p1, rand_decdat1 = get_grad_and_decode_data(rand_latspace, rand_radial_kernel)

rand_delta_gp1 = batch_calc_grad(rand_grad_p1, rand_radial_kernel, rand_decdat1, SSI_weighting, 256)

In [None]:
# Convert gradients to steps and plot the sorted step magnitudes.
rand_linsteps = lin_thresh_step(rand_delta_gp1, step_thresh, step_scale)

plt.figure()
plt.plot(np.sort(np.abs(rand_linsteps)))

In [None]:
from skimage.transform import PiecewiseAffineTransform

# Move the net points by the computed steps (adjust_encoding is defined elsewhere;
# `rand_latspace != None` is elementwise True, so every index is selected).
rand_op1, rand_np1, rand_current_ps1 = adjust_encoding(rand_latspace, rand_linsteps, rand_latspace, np.where(rand_latspace != None))
test_data = sample.encoded_data.copy()

# Original and displaced net points as (n, 2) arrays.
rand_o_latspacer = np.asarray(rand_op1).T
rand_n_latspacer = np.asarray(rand_np1).T

# Piecewise-affine map estimated from displaced -> original net points; its
# inverse is applied to the encoded data.
rand_tform2 = PiecewiseAffineTransform()
rand_tform2.estimate(rand_n_latspacer, rand_o_latspacer)

rand_out_data2 = rand_tform2.inverse(test_data)

### Optimise the movement of the net points based on the steps:

In [None]:
# Encoded data in grey; net points before (blue) and after (orange) the step.
plt.figure(figsize = (8,8))
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.2, c = 'grey')
plt.scatter(rand_op1[0], rand_op1[1],s =20, c='blue')
plt.scatter(rand_np1[0], rand_np1[1],s =20, c ='orange')

### Optimise the effect on the encoded data:

In [None]:
# Encoded data before and after one inverse transform.
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(out_data[:,0], out_data[:,1], s =10)
plt.scatter(rand_out_data2[:,0], rand_out_data2[:,1], s =10)

Next

### params:

In [None]:
# Number of additional inverse-transform applications.
n_refine_steps = 1000

### process:

In [None]:
# Apply the same inverse transform n_refine_steps more times, keeping every
# intermediate state; prints the elapsed wall time in seconds.
rand_refine_steps2=[]

import time
t1 = time.time()
for i in range(n_refine_steps):
    rand_out_data2 = rand_tform2.inverse(rand_out_data2)
    rand_refine_steps2.append(rand_out_data2)
print(time.time() - t1)

In [None]:
# Encoded data before and after all refinement steps.
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(out_data[:,0], out_data[:,1], s =10)
plt.scatter(rand_out_data2[:,0], rand_out_data2[:,1], s =10)

# NOTE(review): hard-coded absolute output path; running this cell overwrites the file.
np.save('/dls/science/groups/imaging/ePSIC_students/Al_alloy_4DSTEM_EM19064-2/refine_enc_positions_4.npy', rand_out_data2)

### Optimise the refinement:

In [None]:
from matplotlib import animation

# Animate the stored refinement steps (rand_refine_steps2) as a scatter of points.
figr3, axr3 = plt.subplots()

axr3.set_xlim((-1, 1))
axr3.set_ylim((-1, 1))

# ls='' draws markers only, no connecting line.
liner3, = axr3.plot([], [], lw=2, ls='', marker='o', alpha=0.2)

def init3():
    liner3.set_data([], [])
    # Return the artist this cell created; `liner` is not defined here.
    return (liner3,)

def animate3(i):
    d3 = rand_refine_steps2[i]
    liner3.set_data(d3[:, 0], d3[:, 1])
    return (liner3,)

# blit=True redraws only the artists returned by init3/animate3. Frames are
# capped at the number of stored steps so animate3 never indexes past the end.
animr3 = animation.FuncAnimation(figr3, animate3, init_func=init3,
                                 frames=min(100, len(rand_refine_steps2)), interval=20, blit=True)

In [None]:
animr3

In [None]:
# 50 independent refinements, each from a freshly sampled random density net
# (plots and animation disabled); only the estimated 'transform' is kept.
all_transforms = [refine_based_on_density(sample, animate_refinement=False, n_refine_steps=1,n_sample_points=1450,n_bkg_points=50,step_scale=0.01, bw = 1, show_first_refinement=False, show_net = False, show_net_movement= False, show_step_size=False,show_refinement=False)[1]['transform'] for x in range(50)]

In [None]:
# Starting positions for the successive refinement below.
rp1 = test_data.copy()

In [None]:
def successive_trans(x, ts):
    """Apply the inverse of each transform in ``ts`` to ``x``, in order, and return the result."""
    result = x
    for transform in ts:
        result = transform.inverse(result)
    return result

In [None]:
# Collects every intermediate refinement state (used by the animation cells).
rand_refine_steps2=[]

In [None]:
# 100 rounds; each round applies the inverse of every transform in all_transforms in turn.
for i in range(100):
    print(i)
    rp1= successive_trans(rp1, all_transforms)
    rand_refine_steps2.append(rp1)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1])

In [None]:
# Cluster the refined positions, then shift labels so the smallest is 0
# (DBSCAN labels noise as -1). reshape(255, 255) assumes 255*255 points.
tset1 = rp1

db1 = DBSCAN(eps = 0.05, min_samples= 10).fit_predict(tset1)

db1-= db1.min()

db1.max()

plt.figure()
plt.imshow(db1.reshape(255,255))

In [None]:
# NOTE(review): map4..map1 are snapshots of db1; presumably these cells are run
# after separate refinements so that the maps differ — TODO confirm.
map4 = db1.copy()

In [None]:
map3 = db1.copy()

In [None]:
map2 = db1.copy()

In [None]:
map1 = db1.copy()

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10, c = db1)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
#plt.scatter(rp1[:,0], rp1[:,1],  c = db1)

In [None]:
# Per-point agreement between the four label maps.
# NOTE(review): `cmap` is only created in a later cell; run that cell first.
conf = get_confidence_from_maps([map1,map2,map3,map4])

plt.figure()
plt.imshow(conf.reshape((255,255)), cmap= cmap)
plt.colorbar()

In [None]:
import palettable
import matplotlib.colors as mcolors

In [None]:
# 256 colours sampled from the ColorBrewer YlGn_9 colormap; the first (lowest)
# colour is replaced with opaque black.
colors1 = palettable.colorbrewer.sequential.YlGn_9.mpl_colormap(np.linspace(0, 1, 256))
colors1[0] = [0.,0.,0.,1.]
# Build a LinearSegmentedColormap interpolating between these colours.
cmap = mcolors.LinearSegmentedColormap.from_list('colormap', colors1)

In [None]:
colors1[0]

In [None]:
# Confidence restricted to each map1 cluster, one figure per cluster, on a fixed 0-1 scale.
for uind in np.unique(map1):
    plt.figure()
    plt.imshow((np.where(map1 == uind,1,0) * conf).reshape((255,255)), cmap= cmap, interpolation = 'nearest')
    plt.colorbar()
    plt.clim(0,1)
    

In [None]:
# NOTE(review): refined_positions is not defined in this section; presumably a list
# of (positions, accessory_dict) results of refine_based_on_density — TODO confirm.
rtrans1 = refined_positions[0][1]['transform']
rtrans2 = refined_positions[1][1]['transform']
rtrans3 = refined_positions[2][1]['transform']

In [None]:
# 100 more rounds using only these three transforms, appended to the same step list.
for i in range(100):
    rp1= successive_trans(rp1, (rtrans1,rtrans2, rtrans3))
    rand_refine_steps2.append(rp1)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1])

In [None]:
# Re-cluster the refined positions (labels shifted to start at 0).
tset1 = rp1

db1 = DBSCAN(eps = 0.05, min_samples= 10).fit_predict(tset1)

db1-= db1.min()

db1.max()

plt.figure()
plt.imshow(db1.reshape(255,255))

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1], c = db1)

In [None]:
# Repeated cluster merging with decreasing SSI thresholds.
# NOTE(review): combine_closest_clusters is defined in a later cell.
b = db1
for t in [0.98,0.97,0.96,0.95]:
    b, g, cc_fig = combine_closest_clusters(b, t)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10, c = b, cmap = 'turbo')

plt.figure()
plt.imshow(b.reshape(255,255), cmap = 'turbo')

In [None]:
# Mean refined position (centre) of each DBSCAN cluster, in np.unique(db1) order.
cluster_centres = [np.mean(rp1[np.where(db1 ==uind)], axis = 0) for uind in np.unique(db1)]

In [None]:
cluster_centres = np.asarray(cluster_centres)

In [None]:
cluster_centres

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1])
plt.scatter(cluster_centres[:,0], cluster_centres[:,1], marker = 'x')

In [None]:
# Patterns for each cluster centre (get_terr_patts is defined elsewhere).
c_patts = get_terr_patts(cluster_centres)

In [None]:
from sklearn.neighbors import NearestNeighbors

In [None]:
# Two neighbours per centre: itself and its nearest other centre (the 2 passed to
# kneighbors overrides n_neighbors=5).
nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(cluster_centres)
distances, indices = nbrs.kneighbors(cluster_centres, 2)

In [None]:
# Drop column 0 (each centre itself), keeping the nearest other centre.
indices = np.asarray(indices)[:,1:]

In [None]:
# Link a centre to its nearest neighbour when their pattern SSI exceeds thresh_sim.
comb_nodes = []
thresh_sim = 0.975
for pi, sis in enumerate(indices):
    for si in sis:
        if SSI(c_patts[pi], c_patts[si]) > thresh_sim:
            comb_nodes.append([pi,si])

In [None]:
g, ginds = get_graph_from_connectivity(comb_nodes)

view_graph(g)

In [None]:
con_nodes = get_connected_nodes(g)

con_nodes

In [None]:
combine_centres = np.zeros(cluster_centres.shape[0])

In [None]:
# Group index per centre; centres in no group keep 0, same as group 0.
for i, each in enumerate(con_nodes):
    for e in each:
        combine_centres[e] = i

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1])
plt.scatter(cluster_centres[:,0], cluster_centres[:,1], marker = 'o', c = combine_centres, s = 30)

In [None]:
all_nodes = [u for u in np.unique(db1)]
for each in con_nodes:
    [all_nodes.pop(all_nodes.index(e)) for e in each]

In [None]:
all_nodes

In [None]:
blank = np.zeros_like(db1)
count = 1 
for an in all_nodes:
    blank += np.where(db1 ==an, count, 0)
    count += 1
for each in con_nodes:
    count += 1
    for e in each:
        blank += np.where(db1 ==e, count, 0)
    

In [None]:
# Refined positions coloured by the merged labels.
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1], c = blank)
#plt.scatter(cluster_centres[:,0], cluster_centres[:,1], marker = 'x')

In [None]:
# Merged labels as a map; reshape(255, 255) assumes 255*255 points.
plt.figure()
plt.imshow(blank.reshape(255,255))

In [None]:
def combine_closest_clusters(input_cluster, thresh_sim = 0.975, positions=None, raw_positions=None):
    """Merge neighbouring clusters whose centre patterns are similar.

    For each label in ``input_cluster`` the mean position is computed. Patterns for
    the centres are obtained with ``get_terr_patts``. Each centre is linked to its
    nearest other centre when ``SSI`` between their patterns exceeds
    ``thresh_sim``, and connected groups of linked clusters are merged.

    Parameters
    ----------
    input_cluster : ndarray of int
        Label per point (1-D, aligned with ``positions``).
    thresh_sim : float
        SSI above which nearest-neighbour clusters are merged.
    positions : ndarray, optional
        (n_points, 2) positions used for the cluster centres. Defaults to the
        notebook global ``rp1`` (the original behaviour).
    raw_positions : ndarray, optional
        Positions drawn first in the diagnostic figure. Defaults to the notebook
        global ``test_data``.

    Returns
    -------
    blank : ndarray of int
        New labels starting at 0: unmerged clusters first (in ascending original
        label order), then one label per merged group.
    g : graph from ``get_graph_from_connectivity``
    cc_fig : matplotlib Figure with the cluster centres coloured by group.
    """
    from sklearn.neighbors import NearestNeighbors

    if positions is None:
        positions = rp1
    if raw_positions is None:
        raw_positions = test_data
    db1 = input_cluster

    # Centres in np.unique order: node index k in the graph refers to labels[k].
    labels = np.unique(db1)
    cluster_centres = np.asarray([np.mean(positions[db1 == lab], axis=0) for lab in labels])
    c_patts = get_terr_patts(cluster_centres)

    # Column 0 of kneighbors is each centre itself; column 1 is its nearest other centre.
    nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(cluster_centres)
    _, indices = nbrs.kneighbors(cluster_centres, 2)
    nearest = np.asarray(indices)[:, 1]

    comb_nodes = [[pi, si] for pi, si in enumerate(nearest)
                  if SSI(c_patts[pi], c_patts[si]) > thresh_sim]

    g, _ = get_graph_from_connectivity(comb_nodes)
    con_nodes = get_connected_nodes(g)

    # Group index per centre, for colouring the diagnostic plot.
    combine_centres = np.zeros(cluster_centres.shape[0])
    for i, each in enumerate(con_nodes):
        for e in each:
            combine_centres[e] = i
    cc_fig = plt.figure()
    plt.scatter(raw_positions[:, 0], raw_positions[:, 1], s=10)
    plt.scatter(positions[:, 0], positions[:, 1])
    plt.scatter(cluster_centres[:, 0], cluster_centres[:, 1], marker='o', c=combine_centres, s=30)

    merged = {e for each in con_nodes for e in each}
    unmerged_labels = [lab for k, lab in enumerate(labels) if k not in merged]

    blank = np.zeros_like(db1)
    count = 0
    for lab in unmerged_labels:
        count += 1
        blank += np.where(db1 == lab, count, 0)
    for each in con_nodes:
        count += 1
        for e in each:
            # Map the graph node index back to the actual label value.
            blank += np.where(db1 == labels[e], count, 0)
    blank -= 1

    return blank, g, cc_fig

In [None]:
# Successive merging passes with decreasing SSI thresholds; each pass starts
# from the previous pass's labels.
b, g, cc_fig = combine_closest_clusters(db1)

In [None]:
b2, g, cc_fig = combine_closest_clusters(b, 0.96)

In [None]:
b3, g, cc_fig = combine_closest_clusters(b2, 0.96)

In [None]:
b4, g, cc_fig = combine_closest_clusters(b3, 0.95)

In [None]:
b5, g, cc_fig = combine_closest_clusters(b4, 0.95)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10, c = b5, cmap = 'turbo')

In [None]:
plt.figure()
plt.imshow(b5.reshape(255,255), cmap = 'turbo')

In [None]:
# Group index per cluster centre: centres in the i-th connected group get i.
for i, each in enumerate(con_nodes):
    for e in each:
        combine_centres[e] = i

In [None]:
from matplotlib import animation

# Animate the stored refinement steps (rand_refine_steps2) as a scatter of points.
figr3, axr3 = plt.subplots()

axr3.set_xlim((-1, 1))
axr3.set_ylim((-1, 1))

# ls='' draws markers only, no connecting line.
liner3, = axr3.plot([], [], lw=2, ls='', marker='o', alpha=0.2)

def init3():
    liner3.set_data([], [])
    # Return the artist this cell created; `liner` is not defined here.
    return (liner3,)

def animate3(i):
    d3 = rand_refine_steps2[i]
    liner3.set_data(d3[:, 0], d3[:, 1])
    return (liner3,)

# blit=True redraws only the artists returned by init3/animate3. Frames are
# capped at the number of stored steps so animate3 never indexes past the end.
animr3 = animation.FuncAnimation(figr3, animate3, init_func=init3,
                                 frames=min(100, len(rand_refine_steps2)), interval=20, blit=True)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(refined_positions[0][0][:,0], refined_positions[0][0][:,1])
plt.scatter(rp1[:,0], rp1[:,1])

In [None]:
# Positions from each entry of refined_positions (presumably the first element of
# refine_based_on_density's return value — TODO confirm).
rps = np.asarray([x[0] for x in refined_positions])

In [None]:
rps.shape

In [None]:
plt.figure()
plt.scatter(rps[0,:,0], rps[0,:,1])

In [None]:
plt.figure()
for e in rps:
    plt.scatter(e[:,0], e[:,1])

In [None]:
# Mean position across all refinements.
# NOTE(review): c=db3 uses db3 from a later cell (or from a previous run).
plt.figure()
mean_pos = np.mean(rps, axis = 0)
plt.scatter(mean_pos[:,0], mean_pos[:,1], alpha = 0.2, c=db3, cmap = 'turbo')

In [None]:
tset1 = rp1

In [None]:
# Cluster and shift labels so the smallest is 0 (DBSCAN labels noise as -1).
db1 = DBSCAN(eps = 0.05, min_samples= 50).fit_predict(tset1)

db1-= db1.min()

In [None]:
db1.max()

In [None]:
plt.figure()
plt.imshow(db1.reshape(255,255))

In [None]:
tset2 = rps[1]

db2 = DBSCAN(eps = 0.05, min_samples= 50).fit_predict(tset2)

db2-= db2.min()
plt.figure()
plt.imshow(db2.reshape(255,255))

In [None]:
tset3 = rps[2]

db3 = DBSCAN(eps = 0.05, min_samples= 50).fit_predict(tset3)

db3-= db3.min()

plt.figure()
plt.imshow(db3.reshape(255,255))

In [None]:
import sklearn.metrics.cluster as cmet

def get_map_label_df(map1):
    """One-hot encode a label map.

    Returns an integer array with one row per distinct label of ``map1`` (in
    ascending label order). Each row has the shape of ``map1`` and is 1 where
    ``map1`` holds that label, 0 elsewhere.
    """
    rows = []
    for label in np.unique(map1):
        rows.append(np.where(map1 == label, 1, 0))
    return np.asarray(rows)

def get_cluster_label_overlap(map_pair, db2_df=None):
    """Fraction of each cluster of one map that lies in each cluster of another.

    Parameters
    ----------
    map_pair : tuple ``(db1_df, db2_df)``, or ndarray
        One-hot label arrays as returned by ``get_map_label_df`` (one row per
        label). For compatibility with the two-argument form
        ``get_cluster_label_overlap(db1_df, db2_df)``, the first array may be
        passed here and the second as ``db2_df``.
    db2_df : ndarray, optional
        Second one-hot array when it is not supplied inside ``map_pair``.

    Returns
    -------
    ndarray of float, shape (len(db1_df), len(db2_df))
        ``[i, j]`` = (points in both label i of map 1 and label j of map 2)
        / (points in label i of map 1).
    """
    if db2_df is None:
        db1_df, db2_df = map_pair
    else:
        db1_df = map_pair
    db1_df = np.asarray(db1_df)
    db2_df = np.asarray(db2_df)
    n1, n2 = len(db1_df), len(db2_df)
    if n1 == 0 or n2 == 0:
        return np.zeros((n1, n2))
    # Float cast keeps boolean inputs counting (bool @ bool would be a logical OR).
    a = db1_df.reshape(n1, -1).astype(np.float64)
    b = db2_df.reshape(n2, -1).astype(np.float64)
    # One matrix product counts every pairwise intersection at once.
    intersections = a @ b.T
    return intersections / a.sum(axis=1)[:, None]

def find_map_label(pos, map1):
    """Return the label stored at index ``pos`` of the label map ``map1``."""
    label = map1[pos]
    return label

def get_confidence_from_maps(maps):
    """Per-point agreement between several cluster label maps of the same points.

    For every ordered pair of maps (a, b) and every point p, take the fraction of
    p's cluster in map a that also lies in p's cluster in map b. The confidence
    of p is the mean of that fraction over all ordered pairs, so 1.0 means every
    map groups p with exactly the same points.

    Parameters
    ----------
    maps : sequence of ndarray
        At least two label maps of identical shape. Labels may be any values;
        they need not be contiguous or start at 0.

    Returns
    -------
    ndarray of float32 with the shape of ``maps[0]``

    Raises
    ------
    ValueError
        If fewer than two maps are given or their shapes differ.
    """
    import itertools

    maps = [np.asarray(m) for m in maps]
    if len(maps) < 2:
        raise ValueError('need at least two maps to compare')
    shape = maps[0].shape
    if any(m.shape != shape for m in maps):
        raise ValueError('all maps must have the same shape')
    if maps[0].size == 0:
        return np.zeros(shape, dtype='float32')

    # Row index of each point's label within np.unique(map); using the label
    # value directly is only valid when labels are exactly 0..k-1.
    rows = [np.unique(m, return_inverse=True)[1].ravel() for m in maps]

    pairs = list(itertools.permutations(range(len(maps)), 2))
    total = np.zeros(rows[0].size)
    for a, b in pairs:
        ra, rb = rows[a], rows[b]
        # Contingency table: counts[i, j] = points with label i in a and j in b.
        counts = np.zeros((ra.max() + 1, rb.max() + 1))
        np.add.at(counts, (ra, rb), 1)
        overlap = counts / counts.sum(axis=1, keepdims=True)
        total += overlap[ra, rb]
    return (total / len(pairs)).reshape(shape).astype('float32')
    

In [None]:
# Per-point agreement between the three label maps.
conf = get_confidence_from_maps([map1,map2,map3])

In [None]:
plt.figure()
plt.imshow(conf.reshape((255,255)))

In [None]:
# One-hot label arrays for each map. In this section, dfs is otherwise only
# defined locally inside get_confidence_from_maps.
dfs = [get_map_label_df(m) for m in (map1, map2, map3)]

In [None]:
# One-hot label arrays, one row per label (same construction as get_map_label_df).
db1_df = np.asarray([np.where(map1 == uinds, 1, 0) for uinds in np.unique(map1)])

db2_df = np.asarray([np.where(map2 == uinds, 1, 0) for uinds in np.unique(map2)])

db3_df = np.asarray([np.where(map3 == uinds, 1, 0) for uinds in np.unique(map3)])

In [None]:
db1_df.shape

In [None]:
# Indicator of points in both label-row 2 of map1 and label-row 3 of map2.
olap = db1_df[2]*db2_df[3]

In [None]:
def get_cluster_label_overlap(db1_df, db2_df=None):
    """Fraction of each cluster of one map that lies in each cluster of another.

    Parameters
    ----------
    db1_df : ndarray, or tuple ``(db1_df, db2_df)``
        One-hot label array of the first map (one row per label). A pair may be
        passed here instead, with ``db2_df`` omitted, so the single-tuple call
        form used by ``get_confidence_from_maps`` keeps working.
    db2_df : ndarray, optional
        One-hot label array of the second map.

    Returns
    -------
    ndarray of float, shape (len(db1_df), len(db2_df))
        ``[i, j]`` = (points in both label i of map 1 and label j of map 2)
        / (points in label i of map 1).
    """
    if db2_df is None:
        db1_df, db2_df = db1_df
    db1_df = np.asarray(db1_df)
    db2_df = np.asarray(db2_df)
    n1, n2 = len(db1_df), len(db2_df)
    if n1 == 0 or n2 == 0:
        return np.zeros((n1, n2))
    # Float cast keeps boolean inputs counting (bool @ bool would be a logical OR).
    a = db1_df.reshape(n1, -1).astype(np.float64)
    b = db2_df.reshape(n2, -1).astype(np.float64)
    # One matrix product counts every pairwise intersection at once.
    intersections = a @ b.T
    return intersections / a.sum(axis=1)[:, None]

In [None]:
# Same measure as get_confidence_from_maps for three maps: the mean of the six
# directed overlap fractions between each point's clusters (e.g. c12 = fraction
# of p's map1 cluster that lies in its map2 cluster). Label values are used
# directly as row indices, so labels must be 0..k-1.
confidence = np.zeros_like(map1, dtype='float32')
for p in range(len(map1)):
    p1 = map1[p]
    p2 = map2[p]
    p3 = map3[p]
    c12 = np.sum(db1_df[p1] * db2_df[p2])/ np.sum(db1_df[p1])
    c21 = np.sum(db2_df[p2] * db1_df[p1])/ np.sum(db2_df[p2])
    c23 = np.sum(db2_df[p2] * db3_df[p3])/ np.sum(db2_df[p2])
    c32 = np.sum(db3_df[p3] * db2_df[p2])/ np.sum(db3_df[p3])
    c31 = np.sum(db3_df[p3] * db1_df[p1])/ np.sum(db3_df[p3])
    c13 = np.sum(db1_df[p1] * db3_df[p3])/ np.sum(db1_df[p1])
    confidence[p] =  np.mean((c12, c21, c23, c32, c31, c13))
    

In [None]:
plt.figure()
plt.imshow(map1.reshape((255,255)))

In [None]:
plt.figure()
plt.imshow(map2.reshape((255,255)))

In [None]:
plt.figure()
plt.imshow(map3.reshape((255,255)))

In [None]:
plt.figure()
plt.imshow(confidence.reshape((255,255)))

In [None]:
# Confidence restricted to each map1 cluster, one figure per cluster.
for uind in np.unique(map1):
    plt.figure()
    plt.imshow((np.where(map1 == uind,1,0) * confidence).reshape((255,255)))

In [None]:
# Points whose confidence is below 0.3.
low_conf = np.where(confidence< 0.3, 1, 0)

plt.figure()
plt.imshow(low_conf.reshape((255,255)))

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], s =10, c = db1)

In [None]:
db3 = DBSCAN(eps = 0.05, min_samples= 50).fit_predict(rand_out_data2)

db3.max()

In [None]:
plt.figure()
plt.imshow(db3.reshape(255,255))

In [None]:
# Number of points in each cluster.
[len(np.where(db3 == uind)[0]) for uind in np.unique(db3)]

In [None]:
# Indicator maps of clusters with more than 40 points.
pop_clust = [np.where(db3 == uind, 1, 0) for uind in np.unique(db3) if len(np.where(db3 == uind)[0])> 40]

In [None]:
# Label the populated clusters 1, 2, ...; everything else stays 0.
blank = np.zeros_like(db3)
for ind, df in enumerate(pop_clust):
    blank += df*(ind+1)

In [None]:
plt.figure()
plt.imshow(blank.reshape((255,255)))

In [None]:
plt.figure()
plt.scatter(rand_out_data2[:,0], rand_out_data2[:,1], s =10, c = db3)

In [None]:
# Shift labels by one so DBSCAN noise (-1) becomes 0, then store as a map.
db3 += 1

sample.all_maps['latwarp_random'] = db3.reshape(255,255)

In [None]:
np.unique(sample.all_maps['latwarp_random'])

sample.get_map_patterns('latwarp_random',method='mean', recompute = True)

sample.all_patterns['latwarp_random'].shape

# SSI between every unordered pair of the map's patterns.
import itertools
ssi_pair = []
ssi_val = []
for comb in itertools.combinations(np.arange(0,sample.all_patterns['latwarp_random'].shape[0]),2):
    ssi_pair.append(comb)
    ssi_val.append(SSI(sample.all_patterns['latwarp_random'][comb[0]],sample.all_patterns['latwarp_random'][comb[1]]))


In [None]:
# Connect pattern pairs with SSI above 0.975 and view the resulting graph.
similarity_inds = np.where(np.asarray(ssi_val)>0.975)

ssi_connect = np.asarray(ssi_pair)[similarity_inds]

g, ginds = get_graph_from_connectivity(ssi_connect)

view_graph(g)

In [None]:
con_nodes = get_connected_nodes(g)

# Pattern indices that are not part of any connected (merged) group.
all_nodes = list(range(sample.all_patterns['latwarp_random'].shape[0]))
for each in con_nodes:
    for e in each:
        all_nodes.remove(e)

# Relabel: unmerged regions get 1, 2, ...; each merged group then gets one shared
# label continuing the sequence. Subtracting 1 makes the labels start at 0.
latwarp_map = sample.all_maps['latwarp_random']
blank = np.zeros_like(latwarp_map)
count = 1
for nc_node in all_nodes:
    blank += np.where(latwarp_map == nc_node, count, 0)
    count += 1
for each in con_nodes:
    for c_node in each:
        blank += np.where(latwarp_map == c_node, count, 0)
    count += 1

refine_latwarp1 = blank - 1

In [None]:
# Store the merged map and recompute its mean patterns.
sample.all_maps['latwarp2'] = refine_latwarp1

In [None]:
sample.imshow(refine_latwarp1)

In [None]:
sample.get_map_patterns('latwarp2', method = 'mean', recompute = True)

In [None]:
# SSI between every unordered pair of the new map's patterns.
ssi_pair2 = []
ssi_val2 = []
for comb in itertools.combinations(np.arange(0,sample.all_patterns['latwarp2'].shape[0]),2):
    ssi_pair2.append(comb)
    ssi_val2.append(SSI(sample.all_patterns['latwarp2'][comb[0]],sample.all_patterns['latwarp2'][comb[1]]))

In [None]:
# Second merging pass with a lower SSI threshold (0.955).
similarity_inds2 = np.where(np.asarray(ssi_val2)>0.955)

ssi_connect2 = np.asarray(ssi_pair2)[similarity_inds2]

g2, ginds2 = get_graph_from_connectivity(ssi_connect2)

view_graph(g2)

In [None]:
con_nodes2 = get_connected_nodes(g2)

# Pattern indices of 'latwarp2' that are not part of any merged group.
all_nodes2 = list(range(sample.all_patterns['latwarp2'].shape[0]))
for each in con_nodes2:
    for e in each:
        all_nodes2.remove(e)

# Relabel: unmerged regions get 1, 2, ...; each merged group then gets one shared
# label continuing the sequence. Subtracting 1 makes the labels start at 0.
latwarp2_map = sample.all_maps['latwarp2']
blank = np.zeros_like(latwarp2_map)
count = 1
for nc_node in all_nodes2:
    blank += np.where(latwarp2_map == nc_node, count, 0)
    count += 1
for each in con_nodes2:
    for c_node in each:
        blank += np.where(latwarp2_map == c_node, count, 0)
    count += 1

refine_latwarp2 = blank - 1

In [None]:
sample.imshow(refine_latwarp2)

In [None]:
# Distance each encoded point moved during refinement.
# NOTE(review): rand_refine_steps2 is reassigned by several cells above; this uses
# whichever refinement ran last.
distance_moved = np.linalg.norm(sample.encoded_data.copy()- rand_refine_steps2[-1], axis = 1)

In [None]:
dmap = distance_moved.reshape(refine_latwarp2.shape)

In [None]:
plt.figure()
plt.imshow(dmap)

Refined centres of mass (CoM) of the clusters

In [None]:
# Mean refined position (centre of mass) of each region of refine_latwarp2, in
# ascending label order; com_inds records the label of each centre.
cluster_com = []
com_inds = []
flat_latwarp = flatten_nav(refine_latwarp2)
for uind in np.unique(flat_latwarp):
    com = np.mean(rand_refine_steps2[-1][np.where(flat_latwarp==uind)], axis = 0)
    cluster_com.append(com)
    com_inds.append(uind)
    

In [None]:
cluster_com = np.asarray(cluster_com)


In [None]:
plt.figure()
plt.scatter(cluster_com[:,0], cluster_com[:,1])

In [None]:
# Distance from every (unrefined) encoded point to every cluster centre,
# shape (n_points, n_clusters).
point_ccentre_distances = np.linalg.norm(sample.encoded_data.copy()[:,None,:] - cluster_com[None,:], axis = 2)

In [None]:
closest_centre = np.argmin(point_ccentre_distances, axis = 1)

In [None]:
# 1 where a point's nearest centre is its own region's centre.
# NOTE(review): compares a centre index with a label value; equal only when the
# labels are 0..k-1 (com_inds maps index -> label).
cc_overlap = np.where(closest_centre == flat_latwarp,1, 0).reshape(refine_latwarp2.shape)

In [None]:
plt.figure()
plt.imshow(cc_overlap)

In [None]:
sorted_ccentre_d = np.sort(point_ccentre_distances, axis = 1)
argsort_ccentre = np.argsort(point_ccentre_distances, axis = 1)

In [None]:
# Rank of each point's own region among the centres sorted by distance (0 = closest).
cluster_closeness_ind = []
for i in range(len(sample.encoded_data)):
    cluster_closeness_ind.append(np.where(argsort_ccentre[i] == flat_latwarp[i])[0][0])

In [None]:
plt.figure()
plt.imshow(np.asarray(cluster_closeness_ind).reshape(refine_latwarp2.shape))

In [None]:
sorted_ccentre_d[0]

In [None]:
# Margin between a point's own region centre and the competing centres: positive
# (gap to the runner-up) when its own centre is the closest, otherwise negative
# (closest distance minus own-centre distance).
cluster_closeness_difference = []
for i in range(len(sample.encoded_data)):
    # Distinct name so the cluster_closeness_ind list built earlier is not overwritten.
    own_rank = np.where(argsort_ccentre[i] == flat_latwarp[i])[0][0]
    if own_rank == 0:
        cluster_closeness_difference.append(sorted_ccentre_d[i][1] - sorted_ccentre_d[i][0])
    else:
        cluster_closeness_difference.append(sorted_ccentre_d[i][0] - sorted_ccentre_d[i][own_rank])
cluster_closeness_difference = np.asarray(cluster_closeness_difference)

In [None]:
sorted_ccentre_d.shape

In [None]:
cluster_closeness_difference = np.asarray(cluster_closeness_difference).reshape(refine_latwarp2.shape)

In [None]:
plt.figure()
plt.imshow(cluster_closeness_difference )

In [None]:
# 1 where the point's own region centre is its closest centre.
plt.figure()
plt.imshow(np.where(cluster_closeness_difference>0,1,0 ))

In [None]:
# Mean decoded pattern over axes 0 and 1 (presumably the navigation axes — TODO confirm).
mean_patt = np.mean(sample.decoded_data, axis = (0,1))

In [None]:
plt.figure()
plt.imshow(mean_patt)

In [None]:
# Regular grid over the latent space (get_latgrid defined elsewhere).
lg = get_latgrid(sample, 100)

In [None]:
flg = flatten_nav(lg)

In [None]:
plt.figure()
plt.scatter(flg[:,0], flg[:,1])

In [None]:
lg_patts = get_terr_patts(flg)

In [None]:
mean_patt.shape

In [None]:
# Similarity of each grid pattern to the mean pattern.
lg_ssi = [SSI(mean_patt, lg_p) for lg_p in lg_patts]

In [None]:
lg_ssi_img = np.asarray(lg_ssi).reshape(lg.shape[0:2])
plt.figure()
plt.imshow(lg_ssi_img)

In [None]:
plt.figure()
plt.scatter(flg[:,0], flg[:,1], c = lg_ssi)
plt.colorbar()

In [None]:
plt.figure()
plt.scatter(flg[:,0], flg[:,1], c = lg_ssi)
plt.colorbar()
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, c = 'black', alpha = 0.5)

## Individual DA investigation

In [None]:
gradient_step = 0.05
n_rsteps = 16

In [None]:
# n_rsteps offsets of radius gradient_step evenly spaced around a circle.
radial_kernel = gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, n_rsteps+1)/180))[1:]

In [None]:
def get_individual_movement(point, radial_kernel, w_func, decoder=None):
    """Weighted movement direction for a single latent point.

    The point and the ring of offsets ``point + radial_kernel`` are decoded to
    patterns. Each offset is weighted by ``w_func(centre_pattern, offset_pattern)``
    and the movement is the weighted sum of the offsets.

    Parameters
    ----------
    point : complex
        Latent position encoded as ``x + 1j*y``.
    radial_kernel : complex array-like, shape (n,)
        Offsets around ``point``.
    w_func : callable(pattern, pattern) -> float
    decoder : callable, optional
        Maps an (n+1, 2) batch of latent positions to an object whose ``.numpy()``
        returns shape (n+1, H, W, C); channel 0 is used. Defaults to the notebook
        global ``sample.model.decoder`` (the original behaviour).

    Returns
    -------
    movement : complex
    lat_patterns : ndarray, shape (n+1, H, W)
        Decoded patterns; index 0 is ``point`` itself.
    weightings : list
        ``w_func`` value for each offset, in kernel order.
    """
    if decoder is None:
        decoder = sample.model.decoder
    radial_kernel = np.asarray(radial_kernel)
    lat_points = np.concatenate(([point], point + radial_kernel))
    batch = np.column_stack((lat_points.real, lat_points.imag))
    lat_patterns = decoder(batch).numpy()[:, :, :, 0]
    c_patt, grad_patts = lat_patterns[0], lat_patterns[1:]
    weightings = [w_func(c_patt, gp) for gp in grad_patts]
    movement = np.sum(np.asarray(weightings) * radial_kernel)
    return movement, lat_patterns, weightings

In [None]:
test_point = -0.298 +0.389j

In [None]:
test_point

In [None]:
move,lp,w = get_individual_movement(test_point, radial_kernel, SSI_weighting)

In [None]:
np.abs(move)

In [None]:
w

In [None]:
fig, ax = plt.subplots(figsize= (10,10))

ax.scatter(radial_kernel.real, radial_kernel.imag)

from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage,
                                  AnnotationBbox)

im = OffsetImage(lp[0], zoom=0.5)
im.image.axes = ax

ab = AnnotationBbox(im, (0, 0),
                    xybox=(0, 50),
                    xycoords='data',
                    boxcoords="offset points",
                    pad=0.3,
                    arrowprops=dict(arrowstyle="->"))
ax.add_artist(ab)


for i in range(len(radial_kernel)):
    im = OffsetImage(lp[i+1], zoom=0.5)
    im.image.axes = ax

    ab = AnnotationBbox(im, (radial_kernel[i].real, radial_kernel[i].imag),
                        xybox=(0, 50.),
                        xycoords='data',
                        boxcoords="offset points",
                        pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    

vec = gradient_step*move/(2*np.abs(move))

ax.arrow(0,0,vec.real, vec.imag)
ax.axis('off')

ax.annotate(f'{np.round(np.abs(move),5)}',(vec.real,vec.imag+0.005),fontsize=15)

In [None]:
# Synthetic test data: two uniform 3000-point blobs, one in [-1, 0)^2 and one in [1, 2)^2.
test_data = np.concatenate((np.random.rand(3000,2)-1, np.random.rand(3000,2)+1))

In [None]:
from sklearn.neighbors import NearestNeighbors
neigh = NearestNeighbors(n_neighbors=200)
neigh.fit(test_data)

In [None]:
# Mean distance from each point to its 200 nearest neighbours (a local sparsity score).
d = np.mean(neigh.kneighbors(test_data)[0], axis = 1)

In [None]:
plt.figure()
plt.scatter(test_data[:,0], test_data[:,1], c= d)

In [None]:
from skimage.transform import PiecewiseAffineTransform, warp

In [None]:

# 20 x 20 grid of (col, row) source points on [-1, 1]^2, flattened to shape (400, 2).
src_cols = np.linspace(-1, 1, 20)
src_rows = np.linspace(-1, 1, 20)
src_rows, src_cols = np.meshgrid(src_rows, src_cols)
src = np.dstack([src_cols.flat, src_rows.flat])[0]




In [None]:
# Latent grid at res=200, flattened, and its complex form x + 1j*y.
# NOTE(review): a later cell rebinds `latgrid` to the un-flattened (res, res, 2) grid.
latgrid = flatten_nav(get_latgrid(sample, res =200))

comp_latgrid = latgrid[:,0] + 1j*latgrid[:,1]

In [None]:
# NOTE(review): `latspace` is not assigned in this part of the notebook — confirm its source.
comp_latspace = latspace[:,0] + 1j*latspace[:,1]

In [None]:
gradient_step = 0.05
n_rsteps = 16

In [None]:
# n_rsteps offsets evenly spaced on a circle of radius gradient_step; [1:] drops
# the 0-degree entry because it coincides with the 360-degree one.
radial_kernel = gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, n_rsteps+1)/180))[1:]

In [None]:
grad_p1, decdat1 = get_grad_and_decode_data(comp_latspace, radial_kernel)

In [None]:
delta_gp1 = batch_calc_grad(grad_p1, radial_kernel, decdat1, SSI_weighting, 256)

In [None]:
# `comp_latspace != None` is an element-wise comparison that is True everywhere,
# so np.where selects every point as mobile.
# NOTE(review): this 4-argument adjust_encoding call does not match the
# 5-argument definition later in this notebook; it presumably uses another
# version (e.g. from stemseg.processing_funcs) — confirm.
steps1 = norm_step_from_grad(delta_gp1,gradient_step*10)
op1, np1, current_ps1, gm1 = adjust_encoding(comp_latspace, steps1, comp_latspace, np.where(comp_latspace != None))

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(op1[0], op1[1],s =50)
plt.scatter(np1[0], np1[1],s =50)
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.2)

In [None]:
# Old and new positions as (N, 2) arrays.
o_latspace = np.asarray(op1).T
n_latspace = np.asarray(np1).T

In [None]:
n_latspace.shape

In [None]:
# Fit a piecewise-affine map n_latspace -> o_latspace; tform.inverse therefore
# maps points the other way (original positions -> moved positions).
tform = PiecewiseAffineTransform()
tform.estimate(n_latspace, o_latspace)

In [None]:
test_data = sample.encoded_data

In [None]:
out_data = tform.inverse(test_data)

In [None]:
# Apply the inverse transform 100 more times, starting from the single-pass
# result of the previous cell. Seeding r_out_data here fixes the NameError the
# loop raised on a fresh kernel, where r_out_data had never been assigned.
r_out_data = out_data
for i in range(100):
    r_out_data = tform.inverse(r_out_data)

In [None]:
# NOTE(review): `db` is assigned in the next cell, and DBSCAN is only imported
# further down the notebook; run those cells before this plot.
plt.figure()
#plt.scatter(test_data[:,0], test_data[:,1], s =10)
#plt.scatter(out_data[:,0], out_data[:,1], s =10)
plt.scatter(out_data[:,0], out_data[:,1], s =10, alpha = 0.01, c=db)

In [None]:
db = DBSCAN(eps = 0.05, min_samples= 10).fit_predict(r_out_data)

In [None]:
db.max()

In [None]:
# Requires exactly 255 * 255 labels, i.e. a 255 x 255 navigation grid.
plt.figure()
plt.imshow(db.reshape(255,255))

In [None]:
# Jitter each coordinate of latspace by uniform noise in [0, 0.2) to exercise the transform.
n_latspace = latspace +np.random.rand(latspace.shape[0],latspace.shape[1])/5

In [None]:
plt.figure()
plt.scatter(n_latspace[:,0], n_latspace[:,1])

In [None]:
tform = PiecewiseAffineTransform()
tform.estimate(n_latspace, latspace)

In [None]:
trans = tform.inverse(src)

In [None]:
plt.figure()
plt.scatter(src[:,0], src[:,1])

In [None]:
plt.figure()
plt.scatter(trans[:,0], trans[:,1])

In [None]:
# Gaussian KDE over the encoded latent points.
# NOTE(review): KernelDensity (sklearn.neighbors) is not imported in this part of the notebook.
kde = KernelDensity(
    bandwidth=0.025, metric="euclidean", kernel="gaussian", algorithm="ball_tree"
)

In [None]:
kde.fit(sample.encoded_data)

In [None]:
# Query grid for the KDE density contour. The resolution names are n_x / n_y so
# this cell cannot clobber the `networkx as nx` alias imported later.
n_x, n_y = (20, 20)
x = np.linspace(-0.1, 0.3, n_x)
y = np.linspace(0.4, 0.6, n_y)
X, Y = np.meshgrid(x, y)

# Query points as (x, y) rows, the same column order as sample.encoded_data,
# which the KDE was fitted on. The previous (Y, X) order evaluated the density
# at swapped coordinates, so Z did not correspond to the plotted X, Y.
xy = np.vstack([X.ravel(), Y.ravel()]).T

In [None]:
# score_samples returns the log-density; exponentiate and lay it back on the grid.
Z= np.exp(kde.score_samples(xy))
Z = Z.reshape(X.shape)

In [None]:
# Filled density contours over the encoded points, zoomed to the queried window.
levels = np.linspace(0, Z.max(), 25)
plt.figure()
plt.contourf(X, Y, Z, levels=levels, cmap=plt.cm.Reds)
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1],s =5, alpha=0.2)
plt.xlim([-0.1,0.3])
plt.ylim([0.4,0.6])

## Density Adjustment

In [None]:
sample.encoded_data.shape

In [None]:
# Probe-ring parameters: ring radius and number of probe directions.
gradient_step = 0.05
n_rsteps = 16

In [None]:
# Encoded 2-D latent positions packed as complex numbers x + 1j*y.
comp_enc = sample.encoded_data[:,0] +1j*sample.encoded_data[:,1]

In [None]:
# n_rsteps offsets evenly spaced on a circle of radius gradient_step; [1:] drops
# the 0-degree entry because it coincides with the 360-degree one.
radial_kernel = gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, n_rsteps+1)/180))[1:]

In [None]:
plt.figure()
plt.scatter(radial_kernel.real, radial_kernel.imag)

In [None]:
grad_p1, decdat1 = get_grad_and_decode_data(comp_enc, radial_kernel)

In [None]:
delta_gp1 = batch_calc_grad(grad_p1, radial_kernel, decdat1, SSI_weighting, 256)

In [None]:
# Pass 1 moves every point: `comp_enc != None` is element-wise True, so np.where selects all indices.
# NOTE(review): this 4-argument adjust_encoding call does not match the
# 5-argument definition later in this notebook; it presumably uses another
# version (e.g. from stemseg.processing_funcs) — confirm.
steps1 = norm_step_from_grad(delta_gp1,gradient_step)
op1, np1, current_ps1, gm1 = adjust_encoding(comp_enc, steps1, comp_enc, np.where(comp_enc != None))

In [None]:
plt.figure(figsize = (8,8))
#plt.scatter(op1[0], op1[1],s =10,c= np.abs(steps1))
plt.scatter(np1[0], np1[1],s =10, c= np.abs(steps1))
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
plt.figure(figsize = (10,10))
#plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps1.real, current_ps1.imag,s =10)
#plt.xlim([-0.5,0.5])
#plt.ylim([-0.5,0.6])

## Pass 2

In [None]:
# NOTE(review): the `thresh` keyword does not exist on the get_mobile_points
# defined later in this notebook; this cell relies on another version — confirm.
mp2, mplocs2 = get_mobile_points(current_ps1,steps1, thresh = 0.0015)

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(current_ps1.real, current_ps1.imag,s =10)
plt.scatter(mp2.real, mp2.imag,s =10)
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
grad_p2, decdat2 = get_grad_and_decode_data(mp2, radial_kernel)

In [None]:
delta_gp2 = batch_calc_grad(grad_p2, radial_kernel, decdat2, SSI_weighting, 256)

In [None]:
# NOTE(review): the base positions passed here are comp_enc (the original
# encodings), not current_ps1; confirm that dropping pass-1 moves for
# non-mobile points is intended.
steps2 = norm_step_from_grad(delta_gp2,gradient_step)
op2, np2, current_ps2, gm2 = adjust_encoding(mp2, steps2, comp_enc, mplocs2)

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(op2[0], op2[1],s =10)#,c= np.abs(steps2))
plt.scatter(np2[0], np2[1],s =10)#, c= np.where(np.abs(steps2)>0.005,1,0))
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps2.real, current_ps2.imag,s =10)
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
# Moved positions of the pass-2 mobile points, as complex numbers.
updated_mp2 = np2[0]+1j*np2[1]

## Pass 3

In [None]:
# Pass 3 continues with the same mobile set, starting from its pass-2 positions.
mp3, mplocs3 = updated_mp2, mplocs2

In [None]:
len(mp3)

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(current_ps2.real, current_ps2.imag,s =10)
plt.scatter(mp3.real, mp3.imag,s =10)
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
grad_p3, decdat3 = get_grad_and_decode_data(mp3, radial_kernel)

In [None]:
delta_gp3 = batch_calc_grad(grad_p3, radial_kernel, decdat3, SSI_weighting, 256)

In [None]:
# Pass 3 uses a step three times the ring radius.
steps3 = norm_step_from_grad(delta_gp3,3*gradient_step)
op3, np3, current_ps3, gm3 = adjust_encoding(mp3, steps3, comp_enc, mplocs3)
updated_mp3 = np3[0]+1j*np3[1]

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(op3[0], op3[1],s =10)#,c= np.abs(steps2))
plt.scatter(np3[0], np3[1],s =10)#, c= np.abs(steps2))
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps3.real, current_ps3.imag,s =10)
plt.xlim([-1,1])
plt.ylim([-1,1])

In [None]:
# NOTE(review): this 5-positional-argument call with `relative_locs` and three
# return values relies on a get_mobile_points other than the one defined later here.
mp4, mplocs4,mp3locs4 = get_mobile_points(updated_mp3,steps3,mplocs3[0],0.0075, relative_locs=True)

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(updated_mp3.real, updated_mp3.imag,s =10)
plt.scatter(mp4.real,mp4.imag,s=10)
plt.xlim([-1,1])
plt.ylim([-1,1])

In [None]:
grad_p4, decdat4 = get_grad_and_decode_data(mp4, radial_kernel)

In [None]:
delta_gp4 = batch_calc_grad(grad_p4, radial_kernel, decdat4, SSI_weighting, 256)

In [None]:
def step_dot(step1,step2):
    """Cosine of the angle between paired 2-D steps.

    Each complex entry is read as the 2-D vector (real, imag). Both inputs are
    scaled to unit length and the row-wise dot product is taken.

    Parameters
    ----------
    step1, step2 : 1-D complex arrays of equal length.

    Returns
    -------
    list
        One cosine per pair; a zero-length step produces nan.
    """
    vecs_a = np.stack([step1.real, step1.imag], axis=1)
    vecs_b = np.stack([step2.real, step2.imag], axis=1)
    vecs_a = vecs_a / np.linalg.norm(vecs_a, axis=1, keepdims=True)
    vecs_b = vecs_b / np.linalg.norm(vecs_b, axis=1, keepdims=True)
    return [np.dot(a, b) for a, b in zip(vecs_a, vecs_b)]

In [None]:
# Zero the step of any point whose direction changed too much since pass 3:
# a cosine below 0.75 means a turn of more than about 41 degrees.
steps4 = norm_step_from_grad(delta_gp4,10*gradient_step)
bounce_points = np.where(np.asarray(step_dot(steps3[mp3locs4],steps4))<0.75)
steps4[bounce_points] = 0+0*1j



op4, np4, current_ps4, gm4 = adjust_encoding(mp4, steps4, comp_enc, mplocs4)
updated_mp4 = np4[0]+1j*np4[1]

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(op4[0], op4[1],s =10)
plt.scatter(np4[0], np4[1],s =10)#, c= np.abs(steps2))
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.6])

In [None]:
plt.figure(figsize = (10,10))
#plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps4.real, current_ps4.imag,s =10, alpha = 0.2)
plt.xlim([-1,1])
plt.ylim([-1,1])

In [None]:
# Refined positions as an (N, 2) array of (x, y).
CP = np.concatenate((current_ps4.real[:,None], current_ps4.imag[:,None]), axis = 1)

In [None]:
np.save('/dls/science/groups/imaging/ePSIC_students/Al_alloy_4DSTEM_EM19064-2/refine_enc_positions_2.npy', CP)

In [None]:
# Earlier, manual version of the gradient procedure (restarts from the original encodings).
comp_enc = sample.encoded_data[:,0] +1j*sample.encoded_data[:,1]
dec_dat = flatten_nav(sample.decoded_data)

In [None]:
radial_kernel = gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, n_rsteps+1)/180))[1:]

In [None]:
plt.figure()
plt.scatter(radial_kernel.real, radial_kernel.imag)

In [None]:
# NOTE(review): op2, gm2, mp2_loc and current_ps are only assigned in cells
# further down; this cell duplicates the later "pass 3" cells and will raise
# NameError if run in file order.
mp3, mp3_loc = get_mobile_points(op2[0],op2[1], gm2, mp2_loc)

len(mp3_loc)

gp3, dd3 = get_grad_and_decode_data(mp3, radial_kernel)

d_gp3 = batch_calc_grad(gp3, radial_kernel, dd3,256)

op3, np3, current_ps, gm3 = adjust_encoding(mp3, gradient_step, d_gp3, current_ps, mp3_loc)

plt.figure(figsize = (8,8))
plt.scatter(op3[0], op3[1],s =10)
plt.scatter(np3[0], np3[1],s =10)

plt.figure(figsize = (8,8))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps.real, current_ps.imag,s =10)

In [None]:
# Ring of n_rsteps probe points around every encoding: shape (N, n_rsteps).
grad_points = np.repeat(comp_enc[:,None], n_rsteps, axis = 1) + radial_kernel[None, :]

In [None]:
print(grad_points.shape, dec_dat.shape)

In [None]:
# NOTE(review): batch_calc_grad is called here with 4 arguments but elsewhere
# with 5 (including SSI_weighting); these cells presumably target an older version.
d_gp = batch_calc_grad(grad_points, radial_kernel, dec_dat,256)

In [None]:
X,Y  = comp_enc.real, comp_enc.imag

In [None]:
grad_mag = np.abs(d_gp)

In [None]:
# Scale so that the largest gradient corresponds to a step of gradient_step.
scale = gradient_step/np.max(np.abs(d_gp))

In [None]:
grads = scale*d_gp

In [None]:
dX, dY = grads.real, grads.imag

In [None]:
U, V = X+dX, Y+dY

In [None]:
plt.figure()
plt.scatter(comp_enc.real,comp_enc.imag,s =10,) #c = np.abs(d_gp))
#plt.scatter(U,V, alpha = 0.5, s =2, c = np.abs(d_gp))
#[plt.arrow(X[i],Y[i],dX[i],dY[i]) for i in range(len(X))]
#print('done')

In [None]:
plt.figure()
#plt.scatter(X,Y,s =20, c = np.abs(d_gp))
plt.scatter(U,V, s =20, c = np.abs(d_gp))

In [None]:
new_comp_enc = U+1j*V

In [None]:
# Points with an above-average gradient keep moving in the next pass.
mp_locs = np.where(grad_mag > np.mean(grad_mag))

In [None]:
mobile_points = new_comp_enc[mp_locs]

In [None]:
plt.figure()
plt.scatter(mobile_points.real, mobile_points.imag)

In [None]:
grad_points = np.repeat(mobile_points[:,None], n_rsteps, axis = 1) + radial_kernel[None, :]

In [None]:
grad_points.shape

In [None]:
dec_dat = get_terr_patts(np.concatenate([mobile_points.real[:,None], mobile_points.imag[:,None]],axis = 1))

In [None]:
print(grad_points.shape, dec_dat.shape)

In [None]:
d_gp = batch_calc_grad(grad_points, radial_kernel, dec_dat,256)

In [None]:
X,Y  = mobile_points.real, mobile_points.imag

In [None]:
grad_mag = np.abs(d_gp)

In [None]:
scale = gradient_step/np.max(grad_mag)

In [None]:
grads = scale*d_gp

In [None]:
dX, dY = grads.real, grads.imag
U, V = X+dX, Y+dY

In [None]:
plt.figure()
#plt.scatter(X,Y,s =40, c = np.abs(d_gp))
plt.scatter(U,V, alpha = 0.5, s =2, c = np.abs(d_gp))
#[plt.arrow(X[i],Y[i],dX[i],dY[i]) for i in range(len(X))]
#print('done')

In [None]:
moved_points = U+1j*V

In [None]:
# Full set of positions with the mobile points written back at their moved locations.
migrated_points = comp_enc.copy()

In [None]:
migrated_points[mp_locs] = moved_points

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(migrated_points.real, migrated_points.imag,s =10)

In [None]:
def get_mobile_points(U,V,grad_mag, prev_mp_locs = ()):
    """Select the points whose gradient magnitude is above the mean.

    Parameters
    ----------
    U, V : 1-D arrays
        x and y coordinates of the current points.
    grad_mag : 1-D array
        Gradient magnitude per point, aligned with ``U``/``V``.
    prev_mp_locs : sequence of int, optional
        Index map from the current points into a larger array (e.g. the full
        encoding). When non-empty, the returned locations are expressed in that
        larger index space. Any array-like of ints is accepted, lists included.

    Returns
    -------
    mobile_points : 1-D complex array
        ``U + 1j*V`` restricted to the selected points.
    mp_locs : tuple or ndarray
        With no ``prev_mp_locs``: the ``np.where`` tuple of selected positions.
        Otherwise: the matching entries of ``prev_mp_locs``.
    """
    nn_comp_enc = U + 1j * V

    mp_locs = np.where(grad_mag > np.mean(grad_mag))
    mobile_points = nn_comp_enc[mp_locs]

    # np.asarray lets callers pass plain lists; indexing a list with the
    # np.where tuple would raise TypeError.
    prev = np.asarray(prev_mp_locs)
    if prev.size != 0:
        # Translate subset positions into the caller's index space.
        mp_locs = prev[mp_locs]

    return mobile_points, mp_locs

def get_grad_and_decode_data(mobile_points, radial_kernel):
    """Build the probe ring around each mobile point and decode the centres.

    Parameters
    ----------
    mobile_points : 1-D complex array
        Latent positions encoded as ``x + 1j*y``.
    radial_kernel : 1-D complex array
        Offsets added to every mobile point.

    Returns
    -------
    grad_points : complex ndarray, shape (len(mobile_points), len(radial_kernel))
        ``grad_points[i, k] = mobile_points[i] + radial_kernel[k]``.
    dec_dat
        Output of ``get_terr_patts`` for the mobile points, passed as an
        ``(n, 2)`` real array of (x, y) coordinates.
    """
    # Broadcasting takes the ring size from radial_kernel itself. The old
    # np.repeat(..., n_rsteps) silently relied on the global n_rsteps equalling
    # len(radial_kernel).
    grad_points = mobile_points[:, None] + radial_kernel[None, :]
    coords = np.stack([mobile_points.real, mobile_points.imag], axis=1)
    dec_dat = get_terr_patts(coords)
    return grad_points, dec_dat

def adjust_encoding(mobile_points, gradient_step, d_gp, comp_enc, mp_locs):
    """Step the mobile points along their gradients and write them into a copy of the full set.

    Gradients are rescaled so the largest one has length ``gradient_step``, and
    each mobile point is shifted by its rescaled gradient.

    Parameters
    ----------
    mobile_points : 1-D complex array
        Current positions (``x + 1j*y``) of the points to move.
    gradient_step : float
        Step length for the point with the largest gradient.
    d_gp : 1-D complex array
        Gradient for each mobile point.
    comp_enc : 1-D complex array
        Full set of positions. It is copied and not modified.
    mp_locs : index
        Positions in ``comp_enc`` that correspond to ``mobile_points``.

    Returns
    -------
    (X, Y) : positions before the move.
    (U, V) : positions after the move.
    migrated_points : copy of ``comp_enc`` with the moved points written in.
    grad_mag : ``abs(d_gp)``.
    """
    X, Y = mobile_points.real, mobile_points.imag
    grad_mag = np.abs(d_gp)

    max_mag = grad_mag.max() if grad_mag.size else 0.0
    # With no points, or only zero gradients, there is no direction to move in.
    # Use a zero scale instead of 0/0, which previously turned every output
    # coordinate into nan (and raised ValueError for empty input).
    scale = 0.0 if max_mag == 0 else gradient_step / max_mag

    grads = scale * d_gp
    U, V = X + grads.real, Y + grads.imag

    migrated_points = comp_enc.copy()
    migrated_points[mp_locs] = U + 1j * V

    return (X, Y), (U, V), migrated_points, grad_mag

In [None]:
# Pass 2 selects from the moved positions U, V; mp_locs[0] maps the subset back to comp_enc indices.
mp2, mp2_loc = get_mobile_points(U,V, grad_mag, mp_locs[0])

In [None]:
gp2, dd2 = get_grad_and_decode_data(mp2, radial_kernel)

In [None]:
d_gp2 = batch_calc_grad(gp2, radial_kernel, dd2,256)

In [None]:
def adjust_encoding(mobile_points, gradient_step, d_gp, comp_enc, mp_locs):
    """Step the mobile points along their gradients and write them into a copy of the full set.

    An identical definition also appears in an earlier cell; keep the two in sync.

    Gradients are rescaled so the largest one has length ``gradient_step``, and
    each mobile point is shifted by its rescaled gradient.

    Parameters
    ----------
    mobile_points : 1-D complex array
        Current positions (``x + 1j*y``) of the points to move.
    gradient_step : float
        Step length for the point with the largest gradient.
    d_gp : 1-D complex array
        Gradient for each mobile point.
    comp_enc : 1-D complex array
        Full set of positions. It is copied and not modified.
    mp_locs : index
        Positions in ``comp_enc`` that correspond to ``mobile_points``.

    Returns
    -------
    (X, Y) : positions before the move.
    (U, V) : positions after the move.
    migrated_points : copy of ``comp_enc`` with the moved points written in.
    grad_mag : ``abs(d_gp)``.
    """
    X, Y = mobile_points.real, mobile_points.imag
    grad_mag = np.abs(d_gp)

    max_mag = grad_mag.max() if grad_mag.size else 0.0
    # With no points, or only zero gradients, there is no direction to move in.
    # Use a zero scale instead of 0/0, which previously turned every output
    # coordinate into nan (and raised ValueError for empty input).
    scale = 0.0 if max_mag == 0 else gradient_step / max_mag

    grads = scale * d_gp
    U, V = X + grads.real, Y + grads.imag

    migrated_points = comp_enc.copy()
    migrated_points[mp_locs] = U + 1j * V

    return (X, Y), (U, V), migrated_points, grad_mag

In [None]:
op2, np2, current_ps, gm2 = adjust_encoding(mp2, gradient_step, d_gp2, migrated_points, mp2_loc)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(op2[0], op2[1],s =10)
plt.scatter(np2[0], np2[1],s =10)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps.real, current_ps.imag,s =10)

In [None]:
# NOTE(review): op2 holds the positions *before* pass 2's move (adjust_encoding
# returns (X, Y) first), whereas pass 2 selected from the moved U, V. Passes 3-5
# therefore restart from stale positions, and adjust_encoding overwrites the
# previous pass's update in current_ps. Confirm whether np2[0], np2[1] (and
# likewise np3/np4 below) was intended.
mp3, mp3_loc = get_mobile_points(op2[0],op2[1], gm2, mp2_loc)

In [None]:
len(mp3_loc)

In [None]:
gp3, dd3 = get_grad_and_decode_data(mp3, radial_kernel)

d_gp3 = batch_calc_grad(gp3, radial_kernel, dd3,256)

In [None]:
op3, np3, current_ps, gm3 = adjust_encoding(mp3, gradient_step, d_gp3, current_ps, mp3_loc)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(op3[0], op3[1],s =10)
plt.scatter(np3[0], np3[1],s =10)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps.real, current_ps.imag,s =10)

In [None]:
mp4, mp4_loc = get_mobile_points(op3[0],op3[1], gm3, mp3_loc)

len(mp4_loc)

In [None]:
gp4, dd4 = get_grad_and_decode_data(mp4, radial_kernel)

d_gp4 = batch_calc_grad(gp4, radial_kernel, dd4,256)

op4, np4, current_ps, gm4 = adjust_encoding(mp4, gradient_step, d_gp4, current_ps, mp4_loc)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(op4[0], op4[1],s =10)
plt.scatter(np4[0], np4[1],s =10)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps.real, current_ps.imag,s =10)

In [None]:
mp5, mp5_loc = get_mobile_points(op4[0],op4[1], gm4, mp4_loc)

len(mp5_loc)

In [None]:
gp5, dd5 = get_grad_and_decode_data(mp5, radial_kernel)

d_gp5 = batch_calc_grad(gp5, radial_kernel, dd5,256)

op5, np5, current_ps, gm5 = adjust_encoding(mp5, gradient_step, d_gp5, current_ps, mp5_loc)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(op5[0], op5[1],s =20)
plt.scatter(np5[0], np5[1],s =20)

In [None]:
plt.figure(figsize = (10,10))
plt.scatter(comp_enc.real, comp_enc.imag,s =10)
plt.scatter(current_ps.real, current_ps.imag,s =5)
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.5])

## DB clustering

In [None]:
from sklearn.cluster import DBSCAN

In [None]:
sample.raw_data

In [None]:
# NOTE(review): CP is built from comp_enc (the original, unrefined encodings)
# but saved as "refine_enc_positions.npy", and the clustering below runs on it.
# Confirm whether current_ps (the refined positions) was intended.
CP = np.concatenate((comp_enc.real[:,None], comp_enc.imag[:,None]), axis = 1)

np.save('/dls/science/groups/imaging/ePSIC_students/Al_alloy_4DSTEM_EM19064-2/refine_enc_positions.npy', CP)

In [None]:
clustering = DBSCAN(eps=0.02,min_samples=10).fit_predict(CP)

In [None]:
np.max(clustering)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(comp_enc.real, comp_enc.imag,s =10, c = clustering)
plt.xlim([-0.5,0.5])
plt.ylim([-0.5,0.5])

In [None]:
# DBSCAN labels noise points -1; keep only points assigned to a cluster.
non_outlier = np.where(clustering>-1)

In [None]:
plt.figure(figsize = (8,8))
plt.scatter(current_ps.real[non_outlier], current_ps.imag[non_outlier],s =10, c = clustering[non_outlier])

In [None]:
def get_latgrid(sample, res=100):
    """Build a regular res x res grid covering the encoded data.

    Each axis spans from the floor of the data minimum to the ceiling of the
    data maximum (column 0 for x, column 1 for y of ``sample.encoded_data``).

    Returns
    -------
    ndarray, shape (res, res, 2)
        ``grid[i, j] == (x_i, y_j)``: axis 0 steps through x, axis 1 through y.
    """
    enc = sample.encoded_data
    x_axis = np.linspace(np.floor(np.min(enc[:, 0])), np.ceil(np.max(enc[:, 0])), res)
    y_axis = np.linspace(np.floor(np.min(enc[:, 1])), np.ceil(np.max(enc[:, 1])), res)
    grid_x, grid_y = np.meshgrid(x_axis, y_axis, indexing='ij')
    return np.stack([grid_x, grid_y], axis=2)

def closest_gridpoint(latgrid, rp):
    """Locate the grid node nearest to ``rp`` by Euclidean distance.

    Returns
    -------
    point : ndarray, shape (2,)
        Coordinates of the nearest node.
    index : tuple
        Its (row, col) position in ``latgrid``.
    """
    n_rows, n_cols = latgrid.shape[0], latgrid.shape[1]
    nodes = latgrid.reshape((n_rows * n_cols, 2))
    distances = np.linalg.norm(nodes - rp, axis=-1)
    best = np.argmin(distances)
    return nodes[best], np.unravel_index(best, (n_rows, n_cols))
    
def surrounding_inds(ind, latgrid):
    """List the in-bounds 8-connected neighbours of grid cell ``ind``.

    Parameters
    ----------
    ind : pair of int
        (row, col) index into ``latgrid``.
    latgrid : ndarray
        Grid whose first two dimensions bound the indices.

    Returns
    -------
    list of tuple
        Neighbour indices in the fixed offset order
        (0,+1), (0,-1), (+1,+1), (+1,-1), (+1,0), (-1,+1), (-1,-1), (-1,0),
        with out-of-range ones dropped. This matches the order of the previous
        hand-enumerated edge cases.
    """
    indx, indy = ind
    n_rows, n_cols = latgrid.shape[0], latgrid.shape[1]
    offsets = [(0, 1), (0, -1), (1, 1), (1, -1), (1, 0), (-1, 1), (-1, -1), (-1, 0)]
    # Each axis is bounded by its own length. The old code tested the column
    # index against shape[0], which mislabelled interior cells as edges on
    # non-square grids, and it divided by zero when an axis had length 1.
    return [(indx + dx, indy + dy) for dx, dy in offsets
            if 0 <= indx + dx < n_rows and 0 <= indy + dy < n_cols]

def get_lat_region(latgrid, rp, nrmse, decoder=None):
    """Grow a connected region of latent-grid cells outward from the cell nearest ``rp``.

    The seed is the grid node closest to ``rp``. Its 8-connected neighbours are
    decoded and scored against the seed's decoded pattern with ``nrmse``. The
    acceptance threshold is the mean of the three best seed-neighbour scores
    minus 0.01. Accepted cells form the frontier; their unvisited neighbours
    are scored against the same seed pattern and threshold, round after round,
    until a round accepts nothing.

    Parameters
    ----------
    latgrid : ndarray
        Grid of latent coordinates, ``latgrid[i, j] -> (x, y)`` (as from ``get_latgrid``).
    rp : array-like
        Latent point that seeds the region.
    nrmse : callable
        Pattern-comparison metric. A cell is accepted when its score is
        *greater* than the threshold, so this must be a similarity measure (the
        notebook passes ``skimage.metrics.structural_similarity``). Despite the
        parameter name, an error metric such as NRMSE would invert the test.
    decoder : callable, optional
        Maps an ``(n, 2)`` batch of latent points to a tensor with a ``.numpy()``
        method whose last axis is the channel. Defaults to the notebook-global
        ``sample.model.decoder``.

    Returns
    -------
    ndarray of int
        Mask with the shape of ``latgrid[:, :, 0]``: 1 inside the region, 0 elsewhere.
    """
    if decoder is None:
        decoder = sample.model.decoder

    def _decode(points):
        # Decode a batch of latent points and drop the channel axis.
        return decoder(points).numpy()[:, :, :, 0]

    # Cell states: 0 = unvisited, 2 = in region, -1 = rejected.
    clust_lat = np.zeros_like(latgrid[:, :, 0])
    _, gind = closest_gridpoint(latgrid, rp)
    clust_lat[gind] = 2
    clust_gt = _decode(latgrid[gind[0], gind[1]][None, :])[0]

    # Seed round: calibrate the threshold on the seed's own neighbours.
    seed_nbrs = [x for x in surrounding_inds(gind, latgrid) if clust_lat[x] == 0]
    lat_patts = _decode(np.concatenate([latgrid[x][None, :] for x in seed_nbrs], axis=0))
    metrics = [nrmse(clust_gt, cp) for cp in lat_patts]
    thresh = np.mean(np.flip(np.sort(metrics))[:3]) - 0.01
    # The acceptance test keeps the original argument order (candidate first),
    # which matters for non-symmetric metrics.
    for x, cp in zip(seed_nbrs, lat_patts):
        clust_lat[x] = 2 if nrmse(cp, clust_gt) > thresh else -1

    while True:
        print('still growing')
        frontier = np.asarray(np.where(clust_lat == 2)).T
        candidates = [np.asarray([x for x in surrounding_inds(g, latgrid) if clust_lat[x] == 0])
                      for g in frontier]
        candidates = [c for c in candidates if c.size > 0]
        if not candidates:
            break
        candidates = np.unique(np.concatenate(candidates, axis=0), axis=0)
        lat_patts = _decode(np.concatenate([latgrid[u, v][None, :] for u, v in candidates], axis=0))
        accepted = [nrmse(cp, clust_gt) > thresh for cp in lat_patts]
        for (u, v), ok in zip(candidates, accepted):
            clust_lat[u, v] = 2 if ok else -1
        if not any(accepted):
            break

    return np.where(clust_lat == 2, 1, 0)

def evaluate_current_seg(lat_regions, region, sample,latgrid):
    """Register ``region`` as a new labelled region, then assign encoded points to regions.

    ``region`` (a 0/1 mask over ``latgrid``) is multiplied by the next label,
    ``len(lat_regions) + 1``, and stored in ``lat_regions`` **in place** under
    that label. Each encoded point is mapped to its nearest grid node. It is
    "unsorted" if no region covers that node; otherwise it is recorded under
    the mask value of every region that is non-zero there.

    Returns
    -------
    sorted_points : dict, label -> list of encoded points
    arg_sorted_points : dict, label -> list of their indices in ``sample.encoded_data``
    unsorted_points : list of encoded points not covered by any region
    arg_unsorted_points : list of their indices
    bmaps : list of 2-D 0/1 maps, one per label in ``lat_regions`` order, shaped
        like ``sample.all_maps['vae']``. Flat position i corresponds to encoded point i.
    """
    label = int(len(lat_regions) + 1)
    lat_regions[label] = region * label
    combined = np.asarray(list(lat_regions.values())).sum(axis=0)

    sorted_points = {key: [] for key in lat_regions}
    arg_sorted_points = {key: [] for key in lat_regions}
    unsorted_points = []
    arg_unsorted_points = []

    for idx, point in enumerate(sample.encoded_data):
        node = closest_gridpoint(latgrid, point)[1]
        if combined[node] == 0:
            unsorted_points.append(point)
            arg_unsorted_points.append(idx)
            continue
        # Overlapping regions: record the point under every region covering the node.
        for mask in lat_regions.values():
            value = mask[node]
            if value != 0:
                sorted_points[value].append(point)
                arg_sorted_points[value].append(idx)

    nav_rows, nav_cols = sample.all_maps['vae'].shape[0], sample.all_maps['vae'].shape[1]
    bmaps = []
    for key in lat_regions:
        flat = np.zeros(nav_rows * nav_cols)
        for pos in arg_sorted_points[key]:
            flat[pos] = 1
        bmaps.append(flat.reshape(nav_rows, nav_cols))
    return sorted_points, arg_sorted_points, unsorted_points, arg_unsorted_points, bmaps

## Pass 1

In [None]:
import skimage.metrics as mets

In [None]:
# Pick a random encoded point as a seed.
rp = sample.encoded_data[np.random.randint(0, len(sample.encoded_data))]

In [None]:
rp

In [None]:
# Requires the 3-D (res, res, 2) latgrid from get_latgrid; closest_gridpoint
# cannot reshape the flattened grid built earlier.
gind = closest_gridpoint(latgrid, rp)[1]

In [None]:
gind

In [None]:
# Manual override: inspect a fixed grid node instead of the random one.
gind = (53,55)

In [None]:
# (dx, dy) offsets of the full 3x3 neighbourhood in row-major subplot order:
# rows dy = -1, 0, +1; columns dx = -1, 0, +1. Index 4 is the centre (0, 0).
# The previous list repeated (1, 0) three times and omitted (1, -1), (-1, 1) and (0, 1).
grid_moves = [(-1,-1),(0,-1),(1,-1),(-1,0),(0,0),(1,0),(-1,1),(0,1),(1,1)]
# Decode the grid nodes at gind + each offset and score each pattern by SSIM
# against patts[4] (the centre). The threshold is the mean of the 2nd-4th
# highest scores (the highest is the centre compared with itself) minus 0.01.
gps = np.concatenate([latgrid[int(gind[0]+gm[0]),int(gind[1]+gm[1])][None,:] for gm in grid_moves],axis= 0)
patts = sample.model.decoder(gps).numpy()[:,:,:,0]

metrics = [mets.structural_similarity(patts[4], cp) for cp in patts]

thresh = np.mean(np.flip(np.sort(metrics))[1:4])-0.01
print(thresh)

In [None]:
# 3x3 panel of the decoded patterns: green frame = score above thresh, red otherwise.
plt.figure(figsize = (8,8))
for i in range(len(grid_moves)):
    cp = patts[i]
    ax = plt.subplot(3,3,int(i+1))
    ax.imshow(cp)
    ax.set_xticks([])
    ax.set_yticks([])
    metric= metrics[i]
    ax.set_title(str(metric))
    for spine in ax.spines.values():
        # Infinite scores get no coloured frame.
        if metric == np.inf:
            continue
        elif metric >thresh:
            spine.set_edgecolor('green')
            spine.set_lw(10)
        else:
            spine.set_edgecolor('red')
            spine.set_lw(10)

In [None]:
# Label -> labelled mask; filled in place by evaluate_current_seg.
lat_regions = {}

In [None]:
res = 100
latgrid = get_latgrid(sample,res)

In [None]:
rp = sample.encoded_data[np.random.randint(0, len(sample.encoded_data))]

In [None]:
rp

In [None]:
closest_gridpoint(latgrid, rp)

In [None]:
region = get_lat_region(latgrid,rp,mets.structural_similarity)

In [None]:
plt.figure()
plt.imshow(region)

In [None]:
sp, asp, up, aup, bmaps = evaluate_current_seg(lat_regions, region, sample, latgrid)

In [None]:
bmaps

In [None]:
plt.figure()
plt.imshow(bmaps[0])

In [None]:
# Keep seeding new regions at random still-unassigned points (aup) until every
# encoded point is covered by some region.
# NOTE(review): if aup is already empty when this cell starts, randint(0, 0) raises ValueError.
cont_looping = True
while cont_looping == True:
    rp = sample.encoded_data[aup[np.random.randint(0, len(aup))]]

    region = get_lat_region(latgrid,rp,mets.structural_similarity)

    sp, asp, up, aup, bmaps = evaluate_current_seg(lat_regions, region, sample, latgrid)
    
    print(len(bmaps), len(aup))
    
    if len(aup) == 0:
        cont_looping = False
        

In [None]:
plt.close('all')

In [None]:
len(bmaps)

In [None]:
lat_regions.keys()

In [None]:
def generate_n_primes(N):
    """Return the first ``N`` primes in ascending order, found by trial division."""
    primes = []
    candidate = 2
    while len(primes) < N:
        # A candidate is prime when none of the primes found so far divides it.
        if not any(candidate % p == 0 for p in primes):
            primes.append(candidate)
        candidate += 1
    return primes

In [None]:
# One distinct prime per region label.
pfacts = generate_n_primes(len(bmaps)+1)

In [None]:
pfacts

In [None]:
# Multiply each cell by the prime of every region covering it. Each distinct
# product then identifies a distinct combination of overlapping regions.
# NOTE(review): umap inherits an integer dtype from lat_regions[1]; a cell
# covered by many regions can overflow int64.
umap = np.ones_like(lat_regions[1])
for k,v in lat_regions.items():
    p = pfacts[int(k-1)]
    print(k, p)
    pfact_map = np.where(v!=0, p, 1)
    umap*=pfact_map

In [None]:
# Relabel the distinct products as consecutive labels 1..K.
final_lat_map = np.zeros_like(umap)
for i, u in enumerate(np.unique(umap)):
    final_lat_map += np.where(umap == u, i+1, 0)

In [None]:
plt.figure()
plt.imshow(final_lat_map)

In [None]:
lat_region_sizes = [np.sum(np.where(v!=0,1,0)) for v in lat_regions.values()]

In [None]:
# NOTE(review): indexing a Python list with a list raises TypeError; this cell fails as written.
list(lat_regions.values())[list(np.argsort(lat_region_sizes))]

In [None]:
list(np.argsort(lat_region_sizes))

In [None]:
# Binary region masks ordered by ascending size.
ranked_lrs = np.asarray([np.where(list(lat_regions.values())[a]!=0,1,0) for a in np.argsort(lat_region_sizes)])

In [None]:
px,py = list(lat_regions.values())[0:2]

In [None]:
plt.figure()
plt.imshow(ranked_lrs[0])

In [None]:
# NOTE(review): `img` is not assigned in this part of the notebook.
img

In [None]:
def terr_patts(latgrid, res, batch_size=256):
    """Decode every node of a ``res`` x ``res`` latent grid.

    Uses the notebook-global ``sample.model.decoder``.

    Parameters
    ----------
    latgrid : ndarray, shape (res, res, 2)
        Latent grid, e.g. from ``get_latgrid``.
    res : int
        Grid resolution along each axis.
    batch_size : int, optional
        Number of grid points sent to the decoder per call.

    Returns
    -------
    ndarray, shape (res, res, 128, 128)
        Decoded pattern per grid node. The reshape assumes the decoder emits
        128 x 128 single-channel patterns.
    """
    flat = latgrid.reshape((res * res, 2))
    # Fixed-size chunks; the last one is simply shorter. The previous
    # `ceil(n // 256) + 1` batch count sent an empty batch to the decoder
    # whenever res*res was a multiple of 256.
    decoded = [sample.model.decoder(flat[start:start + batch_size]).numpy()
               for start in range(0, flat.shape[0], batch_size)]
    return np.concatenate(decoded, axis=0).reshape((res, res, 128, 128))

In [None]:
# NOTE(review): this rebinds `terr_patts` from the function to its result, so
# the function cannot be called again until its definition cell is re-run.
terr_patts = terr_patts(latgrid, res)

In [None]:
plt.figure()
plt.imshow(terr_patts[55,44])

In [None]:
# Link overlapping latent regions whose mean decoded patterns are near-identical
# (SSIM > 0.98). Pairs are (i, j) indices into ranked_lrs.
connect_list = []

# ranked_lrs is ordered by ascending size; the 15 largest regions are excluded.
candidate_lrs = ranked_lrs[:-15]
for a, lr in enumerate(candidate_lrs):
    # Every other candidate region, with region `a` itself removed.
    other_lrs = np.concatenate([candidate_lrs[:a], candidate_lrs[a+1:]])
    cpatt = np.mean(terr_patts[np.where(lr==1)], axis = 0)
    overlap = other_lrs + lr[None,:,:]
    touching = np.unique(np.where(overlap == 2)[0])
    if len(touching) > 0:
        # Indices from np.where refer to other_lrs, which skips region `a`.
        # Shift those at or past `a` up by one so they index ranked_lrs. The
        # unshifted indices picked the wrong neighbours and compared `a` with itself.
        ilrs = [int(j) if j < a else int(j) + 1 for j in touching]
        print(ilrs)
        olrs = ranked_lrs[ilrs]
        patts = [np.mean(terr_patts[np.where(olr==1)],axis = 0) for olr in olrs]
        print(len(patts))
        for axind, p in enumerate(patts):
            metric = mets.structural_similarity(cpatt, p)
            if metric > 0.98:
                connect_list.append((a, ilrs[axind]))
            
        

In [None]:
# Graph utilities for merging connected latent regions.
import networkx as nx
import itertools

In [None]:
connect_list

In [None]:
def get_graph_from_connectivity(uthresh):
    '''
    Build an undirected graph from a list of index pairs.

    uthresh: iterable of (i, j) pairs (an edge list, not a matrix). Each pair
             becomes an edge, and every index appearing in any pair becomes a node.

    Returns the graph and the sorted array of node indices.
    '''
    nodes = np.unique(uthresh)
    graph = nx.Graph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(uthresh)
    return graph, nodes

    
def view_graph(g):
    '''
    Draw graph ``g``, with node labels, in a new matplotlib figure.
    '''
    plt.figure()
    nx.draw(g, with_labels = True)
    
def get_connected_nodes(g1):
    '''Return the connected components of ``g1`` as a list of node sets.'''
    components = nx.connected_components(g1)
    return list(components)

In [None]:
# Graph whose nodes are ranked_lrs indices, linked when their patterns match.
g, graph_inds = get_graph_from_connectivity(connect_list)

In [None]:
view_graph(g)

In [None]:
def get_connected_nodes(g1):
    '''Return the connected components of ``g1`` as a list of node sets.

    An identical definition also appears in an earlier cell.
    '''
    components = nx.connected_components(g1)
    return list(components)

In [None]:
conn_nodes = get_connected_nodes(g)

In [None]:
unconn_nodes = list(np.arange(0, len(ranked_lrs)))

In [None]:
# Remove every node that belongs to a connected component, leaving only
# regions with no match. (The list comprehension is run for its side effect.)
# NOTE(review): np.concatenate raises if conn_nodes is empty.
[unconn_nodes.pop(unconn_nodes.index(x)) for x in np.unique(np.concatenate([np.asarray(list(x)) for x in conn_nodes]))]

In [None]:
unconn_nodes

In [None]:
len(graph_inds)

In [None]:
cpatt = np.mean(terr_patts[np.where(ranked_lrs[35]==1)], axis = 0)

In [None]:
plt.figure()
plt.imshow(cpatt)

In [None]:
bmaps[2]

In [None]:
len(np.where(bmaps[0]!=0)[0])

In [None]:
[np.argsort([len(np.where(b!=0)[0]) for b in bmaps])]

In [None]:
# Navigation maps ordered by descending number of assigned pixels.
biggest_bmaps = np.flip([bmaps[i] for i in [np.argsort([len(np.where(b!=0)[0]) for b in bmaps])][0]], axis = 0)

In [None]:
for b in biggest_bmaps[:20]:
    plt.figure()
    plt.imshow(b)

In [None]:
# Label each navigation pixel with the index of the largest map covering it.
# NOTE(review): 0 means both "covered by the largest map" and "not covered".
bfbmap = np.zeros_like(bmaps[0])
for i in range(bmaps[0].shape[0]):
    print(i)
    for j in range(bmaps[0].shape[1]):
        for num, bbm in enumerate(biggest_bmaps):
            if bbm[i,j] != 0:
                bfbmap[i,j] = num  
                break

In [None]:
for i in range(1, 15):
    plt.figure()
    plt.imshow(np.where(bfbmap==i, 1, 0))

In [None]:
# One binary map per bfbmap label, ranked by descending pixel count.
ubfbmap = [np.where(bfbmap == u, 1, 0) for u in np.unique(bfbmap)]
sbfbmap = [len(np.where(ubfbmap[i] ==1)[0]) for i in range(len(ubfbmap))]
rbfbmap = [ubfbmap[i] for i in np.flip(np.argsort(sbfbmap))]

In [None]:
for i in range(1, 20):
    plt.figure()
    plt.imshow(rbfbmap[i])

In [None]:
# Map each encoded point to the final_lat_map label of its nearest grid node,
# then lay the labels out on the navigation grid.
fmap = np.zeros(sample.all_maps['vae'].shape[0]*sample.all_maps['vae'].shape[1])
for i,e in enumerate(sample.encoded_data):
    cgp = closest_gridpoint(latgrid, e)[1]
    cls = final_lat_map[cgp]
    fmap[i] = cls
fmap = fmap.reshape((sample.all_maps['vae'].shape[0],sample.all_maps['vae'].shape[1]))
        

In [None]:
plt.figure()
plt.imshow(fmap)

In [None]:
finalregions = [np.where(fmap == u, 1, 0) for u in np.unique(fmap)]

In [None]:
np.flip([np.argsort([len(np.where(b!=0)[0]) for b in finalregions])])

In [None]:
# Final region masks by descending size.
ranked_final_regions = [finalregions[i] for i in np.flip([np.argsort([len(np.where(b!=0)[0]) for b in finalregions])])[0]]

In [None]:
# NOTE(review): hard-coded index; requires at least 52 final regions.
plt.figure()
plt.imshow(ranked_final_regions[51])

### You can then use auto_cartography to map the latent terrain and cluster your encoded data

In [None]:
tmp_dir = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/tmp')

In [None]:
# Segment the latent space with stemseg's auto_cartography; the result is
# presumably stored under sample.all_maps['vae'], which the next cells read — confirm.
sample.auto_cartography(tmp_dir, terr_resolution=12, n_comps=16, n_segments=36, tag= 'vae',pca_skips=1, use_terr_decomp=True, norm = True, mask_PCA=False, inc_nav=False)

### View the result of the clustering

In [None]:
X, Y = sample.encoded_data[:,0], sample.encoded_data[:,1] 
plt.figure()
fig = plt.scatter(X, Y, c = sample.all_maps['vae'], cmap = 'turbo')

In [None]:
# 36-component clustering directly in latent space, for comparison.
# NOTE(review): GM is not defined in this part of the notebook — presumably
# sklearn's GaussianMixture; confirm.
lsc = GM(36).fit_predict(sample.encoded_data)

In [None]:
# Lay the per-point labels out on the navigation grid.
sample.all_maps['lat_clust'] = lsc.reshape(sample.raw_data.data.shape[0:2])

In [None]:
X, Y = sample.encoded_data[:,0], sample.encoded_data[:,1] 
plt.figure()
fig = plt.scatter(X, Y, c = lsc, cmap = 'turbo')

In [None]:
sample.imshow(None, 'vae')

In [None]:
sample.imshow(None, 'lat_clust')

### Remove the pixels corresponding to the background or support

In [None]:
# Derive 'vl_vae' from the 'vae' segmentation with the background/support removed.
remove_background(sample, old_tag = 'vae', new_tag = 'vl_vae')

In [None]:
sample.imshow(None, 'vl_vae')

In [None]:
np.save('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/ManualSegmentation/workflow_seg5.npy',sample.all_maps['vl_vae'])

In [None]:
# Reload the saved segmentation (lets later cells run without recomputing it).
sample.all_maps['vl_vae']= np.load('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/ManualSegmentation/workflow_seg5.npy')

In [None]:
X, Y = sample.encoded_data[:,0], sample.encoded_data[:,1] 
plt.figure(figsize = (10,10))
plt.scatter(X, Y, c = sample.all_maps['vl_vae'], cmap = 'turbo', s= 15)

### Calculate the patterns associated with each cluster

In [None]:
# Mean pattern per 'vl_vae' cluster, recomputed even if cached.
sample.get_map_patterns('vl_vae', method = 'mean', recompute=True)

### View the patterns and their associated regions

In [None]:
# Interactive view of each cluster's pattern and region.
show_cluster_patterns(sample, 'vl_vae').plot()

### View a signal boosted representation of the sample

In [None]:
sbs = signal_boosted_scan(sample, 'vl_vae')

In [None]:
sbs.plot()

In [None]:
# First (row, col) pixel, in row-major order, of each 'vl_vae' label.
unique_regions = [np.asarray(np.where(sample.all_maps['vl_vae'] == x))[:,0] for x in np.unique(sample.all_maps['vl_vae'])]

In [None]:
def inv_sbs(sbs, sp = (0,0), return_fig = False, interactive = True, vmax= 0.1, tag = 'vl_vae'):
    """Show a segmented region on the navigation image alongside one pixel's pattern.

    Top axes: intensity of ``sbs`` summed over its last two axes, scaled to
    [0, 1] and shown as greyscale RGB. Every navigation pixel sharing the
    selected pixel's label in ``sample.all_maps[tag]`` is painted teal.
    Bottom axes: the selected pixel's pattern after a double-log boost.

    Parameters
    ----------
    sbs
        Object whose ``.data`` is a 4-D array indexed ``[row, col, :, :]``.
    sp : (row, col), optional
        Navigation pixel shown initially.
    return_fig : bool, optional
        Return the figure when True.
    interactive : bool, optional
        When True, clicking on the navigation (top) axes selects a new pixel.
        The clicked (x, y) positions are appended to the module-global list ``coords``.
    vmax : float, optional
        Upper colour limit for the boosted pattern.
    tag : str, optional
        Key of the segmentation in the notebook-global ``sample.all_maps``.

    Returns
    -------
    matplotlib Figure if ``return_fig`` is True, else None.
    """
    # Normalise out of place: `/=` on an integer-typed sum raises a casting error.
    sbsg = np.repeat(sbs.data.sum(axis= (2,3))[:,:,None],3, -1)
    sbsg = sbsg / sbsg.max()
    highlight = np.array([0.1254902 , 0.69803922, 0.66666667])

    def boost(array):
        # Double log compresses the dynamic range so faint features stay visible.
        return np.log10(np.log10(array+1)+1)

    fig, ax = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]}, figsize=(8,8))

    def render(row, col):
        # Paint the selected pixel's cluster and show that pixel's pattern.
        seg_map = sample.all_maps[tag]
        new_nav = sbsg.copy()
        new_nav[np.where(seg_map == seg_map[row, col])] = highlight
        ax[0].imshow(new_nav)
        ax[1].imshow(boost(sbs.data[row, col]), cmap= 'gray', vmax=vmax)
        ax[0].set_frame_on(False)
        for axis in ax:
            axis.set_xticks([])
            axis.set_yticks([])

    render(sp[0], sp[1])

    if interactive:
        global coords
        coords = []

        def onclick(event):
            global ix, iy
            # Only clicks on the navigation image carry navigation coordinates:
            # off-axes clicks have xdata None, and clicks on the pattern axes
            # are in signal space.
            if event.inaxes is not ax[0] or event.xdata is None:
                return
            ix, iy = np.round(event.xdata,0), np.round(event.ydata,0)
            print(ix, iy)
            coords.append((ix, iy))

            ax[0].clear()
            ax[1].clear()
            render(int(iy), int(ix))
            # Axes.draw() needs a renderer argument (calling it bare raised
            # TypeError); ask the canvas to redraw instead.
            fig.canvas.draw_idle()
            return coords

        fig.canvas.mpl_connect('button_press_event', onclick)

    if return_fig:
        return fig
        

In [None]:
# Interactive region browser starting at pixel (0, 0).
f = inv_sbs(sbs, return_fig=True, vmax = 0.1)

In [None]:
sample.all_maps['vl_vae']

In [None]:
def get_refined_training_set(sample, tag, target_nums = 5000):
    """Build a class-balanced, shuffled array of flat navigation indices.

    For every label in ``sample.all_maps[tag]``, the flat (raveled) indices of
    the pixels carrying that label are collected and shuffled.
    - Labels with more than ``target_nums`` pixels are thinned by taking every
      ``(count // target_nums)``-th index of the shuffled list, i.e. a random
      subset of between ``target_nums`` and ``2 * target_nums - 1`` indices.
    - Smaller labels have each index repeated ``target_nums // count`` times,
      giving at most ``target_nums`` entries.

    The per-label selections are concatenated, processed in order of increasing
    label size, then shuffled together.

    Parameters
    ----------
    sample : object
        Must expose ``all_maps``, a mapping from tag to a label-map array.
    tag : str
        Key of the label map to balance.
    target_nums : int
        Approximate number of indices wanted per label. Must be >= 1.

    Returns
    -------
    numpy.ndarray
        1-D array of indices into the raveled label map. It is empty if the map
        has no pixels.

    Raises
    ------
    ValueError
        If ``target_nums < 1``.
    """
    if target_nums < 1:
        raise ValueError(f'target_nums must be >= 1, got {target_nums}')

    # Loop-invariant: flatten the label map once.
    flat_map = np.asarray(sample.all_maps[tag]).ravel()

    cluster_dict = {}
    lens = []
    cs = []
    for uc in np.unique(flat_map):
        pos = np.where(flat_map == uc)[0]
        # Shuffle the index array itself. Shuffling the tuple returned by
        # np.where (length 1) is a silent no-op.
        np.random.shuffle(pos)
        lens.append(len(pos))
        cs.append(uc)
        cluster_dict[uc] = pos

    new_inds = {}
    for i in np.argsort(lens):
        c, count = cs[i], lens[i]
        if count > target_nums:
            # Strided slice of the shuffled indices = random subset.
            new_inds[c] = cluster_dict[c][::count // target_nums]
        else:
            # Oversample small clusters by repeating each index.
            new_inds[c] = np.repeat(cluster_dict[c], target_nums // count)

    for v in new_inds.values():
        print(v.shape)

    if not new_inds:
        return np.empty(0, dtype=np.intp)

    new_training_set = np.concatenate(list(new_inds.values()))
    np.random.shuffle(new_training_set)
    return new_training_set

In [None]:
# Class-balanced, shuffled flat navigation indices over the vl_vae clusters,
# using target_nums=1000 per cluster (see get_refined_training_set).
new_training_set = get_refined_training_set(sample, 'vl_vae',1000)

In [None]:
new_training_set.shape

In [None]:
# Persist the index set for later training runs.
np.save('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/ManualSegmentation/refined_training_set3.npy', new_training_set)

In [None]:
# Output directory for the figures and pattern arrays saved in later cells.
figp = Path(f'/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/TransferFigures')

In [None]:
# Second-to-last path component of `dp`, used as a file-name prefix.
# NOTE(review): presumably the acquisition timestamp directory; confirm against dp.
time_stamp = str(dp).split('/')[-2]

In [None]:
# Save one figure per cluster (highlighted map + pattern at the cluster's first
# pixel). Close each figure after saving so the loop does not accumulate open
# matplotlib figures, matching the cleanup done by the later export cell.
for i, ur in enumerate(unique_regions):
    f = inv_sbs(sbs, ur, return_fig=True, vmax = 0.01)
    f.savefig(str(figp)+f'/{time_stamp}-region-{i}-vmax0.01.jpg', dpi = 200)
    plt.close(f)

In [None]:
# Per-cluster patterns for the vl_vae map.
patts = sample.all_patterns['vl_vae']

In [None]:
patts.shape

In [None]:
# NOTE(review): hard-coded grid. Assumes exactly 6*5 = 30 patterns of 476x476;
# reshape raises if the total size differs.
patts = patts.reshape(6,5,476,476)

In [None]:
patts

In [None]:
# np.save appends ".npy" to the file name.
np.save(f'{figp}/tilt1patterns', patts)

In [None]:
# Target location for the saved maps.
# NOTE(review): presumably a 'Final_Maps' path derived from dp; confirm redirect semantics.
sp = dp.redirect('Final_Maps')
sp

In [None]:
sp.mkdir()

In [None]:
sample.save_all(sp, 'transferbin3')

## If loading saved results, skip to here:

In [None]:
# Show the vl_vae map (presumably the cluster label image; confirm sample.imshow signature).
sample.imshow(None, 'vl_vae')

In [None]:
# Recompute the per-cluster patterns for vl_vae with method='mean'.
sample.get_map_patterns('vl_vae',method = 'mean', recompute=True)

In [None]:
# Scan object consumed by inv_sbs below (its .data is summed over axes 2 and 3 there).
sbs = signal_boosted_scan(sample, 'vl_vae')

In [None]:
# For every vl_vae label (in sorted label order) record the first (row, col)
# navigation position carrying that label, as a length-2 index array.
unique_regions = [np.argwhere(sample.all_maps['vl_vae'] == x)[0] for x in np.unique(sample.all_maps['vl_vae'])]

In [None]:
def inv_sbs(sbs, sp = (0,0), return_fig = False, interactive = True):
    """Show a navigation image with one vl_vae cluster highlighted, plus one pattern.

    Top panel: the per-position summed intensity of ``sbs.data`` (normalised to
    its maximum, shown as greyscale RGB). Every navigation pixel that has the same
    label in ``sample.all_maps['vl_vae']`` as the selected position is painted in
    a teal highlight colour. Bottom panel: ``log10(log10(pattern + 1))`` of
    ``sbs.data[row, col]``, with the -inf values from zero-count pixels filled
    with the smallest finite value.

    Parameters
    ----------
    sbs : object with ``.data``
        ``.data`` must be at least 4-D: two navigation axes followed by two
        signal axes (inferred from the sum over axes 2 and 3 and the
        ``data[row, col]`` indexing below).
    sp : tuple of int
        ``(row, col)`` navigation position shown initially.
    return_fig : bool
        If True, return the created figure.
    interactive : bool
        If True, clicking on the navigation (top) panel selects a new position.
        The rounded click coordinates are printed, appended to the global list
        ``coords`` (reset to ``[]`` on each call), and the last click is stored
        in the globals ``ix`` / ``iy``.

    Returns
    -------
    matplotlib Figure or None
        The figure when ``return_fig`` is True, otherwise None.

    Notes
    -----
    Reads the module-level ``sample`` object for the cluster map.
    """
    highlight = np.array([0.1254902 , 0.69803922, 0.66666667])

    # Greyscale RGB version of the summed-intensity image. Out-of-place division
    # so integer-typed data does not hit numpy's in-place casting error.
    sbsg = np.repeat(sbs.data.sum(axis= (2,3))[:,:,None],3, -1)
    sbsg = sbsg / sbsg.max()

    def boost(array):
        # log10(log10(x + 1)) is -inf where x == 0. Fill those pixels with the
        # smallest finite value so they render as the darkest level. Using only
        # finite values for the minimum keeps NaNs (from negative inputs) from
        # spreading into every zero pixel.
        with np.errstate(divide='ignore'):
            barr = np.log10(np.log10(array+1))
        finite = barr[np.isfinite(barr)]
        low = finite.min() if finite.size else 0.0
        return np.where(np.isneginf(barr), low, barr)

    def format_ax():
        ax[0].set_frame_on(False)
        for a in ax:
            a.set_xticks([])
            a.set_yticks([])

    def render(row, col):
        """Draw the highlighted cluster map and the pattern at (row, col)."""
        cluster_map = sample.all_maps['vl_vae']
        clust = cluster_map[row, col]
        new_nav = sbsg.copy()
        new_nav[np.where(cluster_map == clust)] = highlight
        ax[0].imshow(new_nav)
        ax[1].imshow(boost(sbs.data[row, col]), cmap= 'gray')
        format_ax()

    fig, ax = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]}, figsize=(8,8))
    render(sp[0], sp[1])

    if interactive:
        global coords
        coords = []

        def onclick(event):
            # Only clicks inside the navigation panel map to (row, col). Clicks on
            # the pattern panel carry pattern-pixel coordinates, and clicks outside
            # any axes have xdata None.
            if event.inaxes is not ax[0] or event.xdata is None or event.ydata is None:
                return
            global ix, iy
            ix, iy = np.round(event.xdata,0), np.round(event.ydata,0)
            print(ix, iy)

            coords.append((ix, iy))

            # Rounding a click on the far image edge can land one past the last
            # pixel, so clip to the displayed image extent.
            row = int(np.clip(iy, 0, sbsg.shape[0] - 1))
            col = int(np.clip(ix, 0, sbsg.shape[1] - 1))

            ax[0].clear()
            ax[1].clear()
            render(row, col)

            # Axes.draw() requires a renderer argument; ask the canvas to redraw.
            fig.canvas.draw_idle()

            return coords

        fig.canvas.mpl_connect('button_press_event', onclick)

    if return_fig:
        return fig
        

In [None]:
# Open the cluster/pattern explorer at navigation pixel (0, 0).
f = inv_sbs(sbs, return_fig=True)

In [None]:
# Second-to-last path component of `dp`, used as a file-name prefix.
# NOTE(review): presumably the acquisition timestamp directory; confirm against dp.
time_stamp = str(dp).split('/')[-2]
figp = Path(f'/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/TransferFigures')

# Save one figure per cluster (highlighted map + pattern at the cluster's first pixel).
for i, ur in enumerate(unique_regions):
    f = inv_sbs(sbs, ur, return_fig=True)
    f.savefig(str(figp)+f'/{time_stamp}-region-{i}.jpg', dpi = 200)
    # Clear the figure's contents right away; the figures themselves stay open
    # until the close('all') below.
    f.clf()
plt.close('all')

In [None]:
# Per-cluster patterns for the vl_vae map.
patts = sample.all_patterns['vl_vae']

# Not the last expression of the cell, so this shape is not displayed.
patts.shape

In [None]:
# NOTE(review): hard-coded grid. Assumes exactly 5*5 = 25 patterns of 476x476;
# reshape raises if the total size differs.
patts = patts.reshape(5,5,476,476)

# np.save appends ".npy" to the file name.
np.save(f'{figp}/tilt2patterns', patts)