# Investigate a Sample

### First run this cell

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib notebook
%load_ext autoreload
%autoreload 2


#load some packages in
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import random as python_random
from numba import njit
from tensorboard.plugins.hparams import api as hp
from stemutils.io import Path
import hyperspy.api as hs
import concurrent.futures
from skimage.transform import resize
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from functools import lru_cache
from stemseg.processing_funcs import *
import json
import itertools

#set some variables
# Report the TensorFlow version in use.
print('Using TensorFlow v%s' % tf.__version__)
# Reset matplotlib to its default style sheet.
plt.style.use('default')
# Seed the Python, NumPy and TensorFlow random number generators with 0.
python_random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)


#define some functions

###################################################
########### Data Preprocessing ####################
###################################################

def batch_resize(d, bs=512):
    """Resize every frame of `d` to 128x128, working in parallel batches.

    Parameters
    ----------
    d : array
        Either a stack of frames (3 axes) or a scan with two leading
        navigation axes (4 axes); the latter is flattened with `flatten_nav`.
    bs : int
        Number of frames handed to each `resize` call.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_frames, 128, 128), where n_frames is the number of
        frames after flattening.
    """
    flat_d = flatten_nav(d) if len(d.shape) == 4 else d
    n_frames = flat_d.shape[0]

    # Stepping by `bs` yields exactly the non-empty batches, so no trailing
    # empty batch has to be popped afterwards.
    batches = [flat_d[start:start + bs] for start in range(0, n_frames, bs)]
    if not batches:
        return np.zeros((0, 128, 128))
    print(len(batches[-1]))

    # Leaving the `with` block waits for every submitted job to finish.
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as exe:
        res = [exe.submit(resize, batch, (batch.shape[0],128,128)) for batch in batches]
    r_batches = [f.result() for f in res]

    # Reshape with the flattened frame count: for 4-axis input d.shape[0] is
    # only the first navigation axis and would not match the data size.
    return np.concatenate(r_batches, axis = 0).reshape((n_frames,128,128))

def data_manip(d, bs = 512):
    """Normalise, resize and log-scale a stack of frames.

    Each entry along the first axis is divided by its own maximum, the stack
    is resized to 128x128 with `batch_resize`, and the result is mapped
    through log(1000*x + 1) / log(1001).

    Parameters
    ----------
    d : numpy.ndarray or object with a `.compute()` method
        Input frames. Anything that is not a NumPy array is materialised by
        calling `.compute()` on it.
    bs : int
        Batch size forwarded to `batch_resize`.

    Returns
    -------
    numpy.ndarray
        The log-scaled, resized frames.
    """
    if not isinstance(d, np.ndarray):
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    #d = resize(d,(d.shape[0],128,128))
    print('resized')
    d = d.astype('float32')
    for i in range(d.shape[0]):
        d_max = np.max(d[i])
        # An all-zero frame has no scale; dividing would turn it into NaN.
        if d_max != 0:
            d[i] = d[i]/d_max
    d = batch_resize(d, bs)
    scaler = np.log(1001)
    return np.log((d*1000)+1)/scaler

def data_manip_lowq(d, central_box = 128):
    """Crop the centre of every frame, normalise it and log-scale it.

    Parameters
    ----------
    d : numpy.ndarray or object with a `.compute()` method
        Stack of frames indexed as (frame, row, column).
    central_box : int
        Side length of the square window cut around the frame centre
        (central_box//2 pixels either side of the centre).

    Returns
    -------
    numpy.ndarray
        The cropped frames, each divided by its own maximum and mapped
        through log(1000*x + 1) / log(1001).
    """
    half = central_box//2
    pxc, pyc = d.shape[1]//2, d.shape[2]//2

    # Crop before materialising so that only the window is computed.
    d = d[:, pxc - half:pxc + half, pyc - half:pyc + half]
    if not isinstance(d, np.ndarray):
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    #d = resize(d,(d.shape[0],128,128))
    print('resized')
    d = d.astype('float32')
    for i in range(d.shape[0]):
        d_max = np.max(d[i])
        # An all-zero frame has no scale; dividing would turn it into NaN.
        if d_max != 0:
            d[i] = d[i]/d_max

    scaler = np.log(1001)
    return np.log((d*1000)+1)/scaler



###################################################
###################################################
###################################################

def flatten_nav(sig):
    """Collapse the first two axes of `sig` into one, keeping the rest as is."""
    merged = sig.shape[0]*sig.shape[1]
    trailing = list(sig.shape[2:])
    return sig.reshape([merged] + trailing)


class My_Custom_Generator(keras.utils.Sequence) :
    """Keras Sequence that loads batches of images from .npy files.

    Every batch is returned as (images, images), i.e. the input doubles as
    the target.
    """

    def __init__(self, image_filenames,  batch_size) :
        super().__init__()
        self.image_filenames = image_filenames
        self.batch_size = batch_size

    def __len__(self) :
        # Number of batches, rounded up. `np.int` no longer exists in current
        # NumPy releases, so convert with the builtin instead.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    # Batches are memoised per (instance, idx), so each file is read from disk
    # only once; the cache is unbounded and keeps every loaded batch in memory.
    @lru_cache(None)
    def __getitem__(self, idx) :
        batch_x = self.image_filenames[idx * self.batch_size : (idx+1) * self.batch_size]
        # Append a trailing channel axis to every loaded image.
        out_img = np.asarray([np.load(file_name)[:,:,None] for file_name in batch_x])
        return out_img, out_img
        
        
class Array_Generator(keras.utils.Sequence) :
    """Keras Sequence that serves batches sliced from an in-memory array.

    Every batch is returned as (images, images), i.e. the input doubles as
    the target.
    """

    def __init__(self, images,  batch_size) :
        super().__init__()
        self.images = images
        self.batch_size = batch_size

    def __len__(self) :
        # Number of batches, rounded up. `np.int` no longer exists in current
        # NumPy releases, so convert with the builtin instead.
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    # Batches are memoised per (instance, idx); the cache is unbounded.
    @lru_cache(None)
    def __getitem__(self, idx) :
        # Slice the batch and append a trailing channel axis.
        out_img = self.images[idx * self.batch_size : (idx+1) * self.batch_size, :,:,None]
        return out_img, out_img

class Sampling(layers.Layer):
    """Layer that draws z = mean + exp(0.5 * log_var) * noise."""

    def call(self, inputs):
        mean, log_var = inputs
        # Standard-normal noise with the same shape as the mean tensor.
        noise = tf.keras.backend.random_normal(shape=tf.shape(mean))
        std = tf.exp(0.5 * log_var)
        return mean + std * noise
    
def create_vae_model(hparams):
    """Build, compile and return a convolutional VAE for 128x128x1 images.

    `hparams` is indexed with the keys 'LAT' (latent dimension), 'B' (weight
    applied to the KL term), 'KN1'..'KN5' (filters of the five convolutions),
    'D1' and 'D2' (dense layer widths) and 'LR' (Adam learning rate).

    The returned model exposes the two sub-models as `.encoder` and
    `.decoder`.
    """
    
    n_img = 128
    latent_dim = hparams['LAT']
    beta = hparams['B']

    # Encoder: five stride-2 5x5 convolutions, then a stack of dense layers
    # feeding the z_mean / z_log_var heads and the Sampling layer.
    image_input = keras.Input(shape=(n_img, n_img,1), name = 'enc_input')
    # NOTE(review): input_shape is passed although the layer is applied to
    # image_input directly — presumably redundant; confirm before removing.
    x = layers.Conv2D(hparams['KN1'],5, strides = 2, activation='relu',padding='same', input_shape=image_input.shape, name = 'enc_conv1')(image_input)
    x = layers.Conv2D(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv2')(x)
    x = layers.Conv2D(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv3')(x)
    x = layers.Conv2D(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv4')(x)
    x = layers.Conv2D(hparams['KN5'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv5')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(hparams['D1'], activation='relu', name = 'enc_d1')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d2_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d3_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d4_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d5_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d6_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d7_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d8_t')(x)
    z_mean = layers.Dense(latent_dim, name="z_mean_t")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var_t")(x)
    z_output = Sampling()([z_mean, z_log_var])
    # The encoder returns the mean, the log-variance and a sampled latent vector.
    encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output])

    # Decoder: dense stack, reshape to a 4x4 feature map, then five stride-2
    # transposed convolutions ending in a single sigmoid channel.
    z_input = keras.Input(shape=(latent_dim,), name = 'dec_input_t')
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d1_t')(z_input)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d2')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d3')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d4')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d5')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d6')(x)
    x = layers.Dense(hparams['D1'], activation="relu", name = 'dec_d7')(x)
    x = layers.Dense(4*4*hparams['KN5'], activation="relu", name = 'dec_d8')(x)
    x = layers.Reshape((4, 4,hparams['KN5']))(x)
    x = layers.Conv2DTranspose(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv1')(x)
    x = layers.Conv2DTranspose(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv2')(x)
    x = layers.Conv2DTranspose(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv3')(x)
    x = layers.Conv2DTranspose(hparams['KN1'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv4')(x)
    image_output = layers.Conv2DTranspose(1,5, strides = 2, activation='sigmoid',padding='same', name = 'dec_conv5')(x)
    #image_output = layers.Conv2DTranspose(16,3, strides = 2, activation='sigmoid',padding='same')
    #image_output = layers.Reshape((n_img, n_img,1))(x)
    decoder_VAE = keras.Model(z_input, image_output)

    # VAE class
    class VAE(keras.Model):
        """Keras model chaining the encoder and decoder built above."""

        # constructor
        def __init__(self, encoder, decoder, **kwargs):
            super(VAE, self).__init__(**kwargs)
            self.encoder = encoder
            self.decoder = decoder

        # customise train_step() to implement the loss 
        def train_step(self, x):
            # When a tuple is supplied only its first element is used.
            if isinstance(x, tuple):
                x = x[0]
            with tf.GradientTape() as tape:
                # encoding
                z_mean, z_log_var, z = self.encoder(x)
                # decoding
                x_prime = self.decoder(z)
                # reconstruction error by binary crossentropy loss
                # (mean cross-entropy multiplied by n_img * n_img)
                reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * n_img * n_img
                # KL divergence (averaged with reduce_mean over all elements)
                kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
                # loss = reconstruction error + beta * KL divergence
                loss = reconstruction_loss + beta* kl_loss
            # apply gradient
            grads = tape.gradient(loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
            # return loss for metrics log
            return {"loss": loss}


        def call(self, x):
            """Encode `x`, decode the sampled latent vector and return the reconstruction."""
            if isinstance(x, tuple):
                x = x[0]
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            return x_prime
    # build the VAE
    vae_model = VAE(encoder_VAE, decoder_VAE)

    # compile the VAE
    # custom_loss is registered here; train_step above computes its own loss
    # and does not call it.
    vae_model.compile(optimizer=keras.optimizers.Adam(learning_rate=hparams['LR']),loss=custom_loss)
    vae_model.build((1,128,128,1))
    
    return vae_model



def custom_loss(x,y):
    """Mean binary cross-entropy between x and y, multiplied by 128*128."""
    side = 128
    per_element = keras.losses.binary_crossentropy(x, y)
    return tf.reduce_mean(per_element) * side * side

def remove_background(sample, thresh = 500, old_tag=None, new_tag=None,blanker = 30):
    """Flag scan positions whose pattern intensity falls below `thresh`.

    A copy of `sample.raw_data.data` is taken, a square of half-width
    `blanker` around the centre of every pattern is set to zero, and each
    pattern is summed over its last two axes. Positions whose sum is below
    `thresh` are treated as background.

    Parameters
    ----------
    sample : object
        Must provide `raw_data.data` (4 axes: two navigation, two pattern)
        and, when `old_tag` is given, the mapping `all_maps`.
    thresh : number
        Summed-intensity cut-off.
    old_tag : hashable, optional
        Key of an existing map in `sample.all_maps` to relabel.
    new_tag : hashable, optional
        If given together with `old_tag`, the relabelled map is stored in
        `sample.all_maps[new_tag]`.
    blanker : int
        Half-width of the central square that is zeroed before summing.

    Returns
    -------
    numpy.ndarray
        Without `old_tag`: a 0/1 map (0 = below threshold).
        With `old_tag`: the old map with background positions merged into one
        label and all labels renumbered consecutively starting at 1.
    """
    d = sample.raw_data.data.copy()
    p_shape = d.shape[2:]
    # Centre each pattern axis separately so non-square patterns work, and
    # clip the lower bound so a large `blanker` cannot wrap around.
    c0, c1 = p_shape[0]//2, p_shape[1]//2
    centre = (slice(None), slice(None),
              slice(max(c0 - blanker, 0), c0 + blanker),
              slice(max(c1 - blanker, 0), c1 + blanker))
    try:
        d[centre] = 0
    except Exception:
        # Fall back to materialising lazy data; anything without `.compute()`
        # failed for a genuine reason, so let that error surface.
        if not hasattr(d, 'compute'):
            raise
        d = d.compute()
        d[centre] = 0

    below = d.sum(axis=(2,3)) < thresh
    if old_tag is None:
        return np.where(below, 0, 1)

    maskx, masky = np.where(below)
    clustmap = sample.all_maps[old_tag].copy()
    # Shift labels up by one so that 0 is free to mark the background.
    clustmap += 1
    clustmap[maskx, masky] = 0
    # Renumber the surviving labels consecutively.
    newmap = np.zeros_like(clustmap)
    for i, o in enumerate(np.unique(clustmap)):
        newmap[clustmap == o] = i
    newmap += 1
    if new_tag is not None:
        sample.all_maps[new_tag] = newmap
    return newmap

def show_cluster_patterns(sample, tag):
    """Stack, for every label in `sample.all_maps[tag]`, the label's location
    map above its pattern (both resized to 256x256) and wrap the result in a
    hyperspy Signal2D of shape (n_labels, 512, 256)."""
    cluster_map = sample.all_maps[tag]
    labels = np.unique(cluster_map)
    stacked = np.zeros((labels.size, 512, 256))
    for idx, label in enumerate(labels):
        pattern = resize(sample.all_patterns[tag][idx], (256,256))
        # Location map: pattern maximum where the label occurs, 1 elsewhere.
        location = resize(np.where(cluster_map == label, pattern.max(), 1), (256,256))
        # Bright first row acts as a separator between the two panels.
        pattern[0,:] = pattern.max()
        stacked[idx] = np.concatenate([location, pattern], axis = 0)
    return hs.signals.Signal2D(stacked)

def signal_boosted_scan(sample, tag):
    """Build a scan in which every position holds the pattern of its cluster.

    Positions are grouped by their label in `sample.all_maps[tag]`; the i-th
    unique label receives `sample.all_patterns[tag][i]`. The result is
    returned as a hyperspy Signal2D.
    """
    cluster_map = sample.all_maps[tag]
    labels = np.unique(cluster_map)
    # float32 array of zeros shaped like the raw data.
    boosted = np.zeros_like((sample.raw_data)).astype('float32')
    for idx, label in enumerate(labels):
        print(idx, label)
        boosted[np.where(cluster_map == label)] = sample.all_patterns[tag][idx]
    return hs.signals.Signal2D(boosted)

def inv_sbs(sbs, tag = 'vl_vae', sp = (0,0), return_fig = False, interactive = True, **kwargs):
    """Show a navigation image with one cluster highlighted, plus its pattern.

    The top panel is the summed-intensity image of `sbs` with every position
    belonging to the selected cluster painted in a highlight colour; the
    bottom panel is the (log-boosted) pattern at the selected position.

    NOTE: the cluster map is read from the module-level `sample`
    (`sample.all_maps[tag]`), which must exist before this is called.

    Parameters
    ----------
    sbs : object with a 4-axis `.data` array
        Scan to display (two navigation axes followed by two pattern axes).
    tag : hashable
        Key into `sample.all_maps`.
    sp : tuple of int
        (row, column) of the initially selected position.
    return_fig : bool
        Return the matplotlib figure when True.
    interactive : bool
        When True, clicking the navigation panel re-selects the position.
        Clicked coordinates are appended to the global list `coords`, and the
        last click is stored in the globals `ix`, `iy`.
    **kwargs
        Forwarded to `imshow` for the pattern panel.
    """
    highlight = np.array([0.1254902 , 0.69803922, 0.66666667])

    sbsg = np.repeat(sbs.data.sum(axis= (2,3))[:,:,None],3, -1)
    # Out-of-place division so integer-valued sums do not raise on `/=`.
    sbsg = sbsg / sbsg.max()

    def boost(array):
        return np.log10(np.log10(array+1)+1)

    fig, ax = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]}, figsize=(8,8))

    def format_ax():
        ax[0].set_frame_on(False)
        ax[0].set_xticks([])
        ax[0].set_yticks([])
        ax[1].set_xticks([])
        ax[1].set_yticks([])

    def render(row, col):
        # Draw both panels for the navigation position (row, col).
        cluster_map = sample.all_maps[tag]
        clust_loc = np.where(cluster_map == cluster_map[row, col])
        new_nav = sbsg.copy()
        new_nav[clust_loc] = highlight
        ax[0].imshow(new_nav)
        ax[1].imshow(boost(sbs.data[row, col]), cmap = 'gray', **kwargs)
        format_ax()

    render(sp[0], sp[1])

    if interactive == True:

        global coords
        coords = []

        def onclick(event):
            global ix, iy
            # Only clicks inside the navigation panel carry usable scan
            # coordinates; elsewhere xdata/ydata are None or pattern pixels.
            if event.inaxes is not ax[0] or event.xdata is None or event.ydata is None:
                return coords
            ix, iy = np.round(event.xdata,0), np.round(event.ydata,0)
            print(ix, iy)

            coords.append((ix, iy))

            ax[0].clear()
            ax[1].clear()
            render(int(iy), int(ix))

            # Axes.draw() needs a renderer argument; ask the canvas to redraw.
            fig.canvas.draw_idle()

            return coords

        fig.canvas.mpl_connect('button_press_event', onclick)

    if return_fig == True:
        return fig

from skimage.metrics import structural_similarity as SSI
from skimage.transform import PiecewiseAffineTransform, warp
from sklearn.neighbors import NearestNeighbors as kNN

def get_latgrid(sample, res=100):
    """Return a (res, res, 2) grid spanning the first two encoded dimensions.

    The bounds are the floor of the minimum and the ceiling of the maximum of
    columns 0 and 1 of `sample.encoded_data`. Entry [i, j] holds
    (x value i, y value j).
    """
    encoded = sample.encoded_data
    x_lo, x_hi = np.floor(np.min(encoded[:,0])), np.ceil(np.max(encoded[:,0]))
    y_lo, y_hi = np.floor(np.min(encoded[:,1])), np.ceil(np.max(encoded[:,1]))

    x_axis = np.linspace(x_lo, x_hi, res)
    y_axis = np.linspace(y_lo, y_hi, res)
    # 'ij' indexing: x varies along the first axis, y along the second.
    xgrid, ygrid = np.meshgrid(x_axis, y_axis, indexing='ij')
    return np.stack([xgrid, ygrid], axis = 2)

def get_latgrid_free(sample, xmin, xmax, ymin, ymax, res=100):
    """Return a (res, res, 2) grid over the given bounds.

    Entry [i, j] holds (x value i, y value j). `sample` is accepted but not
    used.
    """
    x_axis = np.linspace(xmin, xmax, res)
    y_axis = np.linspace(ymin, ymax, res)
    # 'ij' indexing: x varies along the first axis, y along the second.
    xgrid, ygrid = np.meshgrid(x_axis, y_axis, indexing='ij')
    return np.stack([xgrid, ygrid], axis = 2)








def batch_calc_grad(img, radial_kernel, decoded_data, weighting_func, bs=256):
    """Estimate a complex-valued gradient for every point from decoded patterns.

    For each point, the patterns decoded at its surrounding sample locations
    are compared with the pattern decoded at the point itself; the comparison
    scores weight the radial kernel offsets, and their sum is the gradient.

    NOTE: decoding uses the module-level `sample` (`sample.model.decoder`),
    which must exist before this is called. Decoded patterns are reshaped to
    128x128.

    Parameters
    ----------
    img : numpy.ndarray, complex, shape (n_points, n_radial)
        Sample locations encoded as x + 1j*y.
    radial_kernel : numpy.ndarray, complex, shape (n_radial,)
        Offsets matching the second axis of `img`.
    decoded_data : sequence of 2-D arrays
        Pattern decoded at each point, in the same order as `img`.
    weighting_func : callable(pattern_a, pattern_b) -> float
        Similarity score between two patterns.
    bs : int
        Number of points decoded per batch.

    Returns
    -------
    numpy.ndarray
        One complex value per point.
    """
    # `time` is not imported at the top of the notebook, so bring it in here.
    import time

    n_points, n_radial = img.shape[0], img.shape[1]
    # Stepping by `bs` never produces an empty trailing batch.
    starts = range(0, n_points, bs)
    n_batches = len(starts)

    ssi_ff = []
    for i, start in enumerate(starts):
        t1 = time.time()
        print(i, n_batches)
        batch = img[start:start + bs]
        flat = batch.reshape(batch.shape[0]*batch.shape[1])
        # Split complex locations into (x, y) columns for the decoder.
        cart = np.concatenate([flat.real[:,None], flat.imag[:,None]], axis = 1)
        nimg = sample.model.decoder(cart).numpy()
        rs_nimg = nimg.reshape((batch.shape[0], n_radial, 128, 128))
        for dec_pat, ring in zip(decoded_data[start:start + bs], rs_nimg):
            grad_ssi = np.asarray([weighting_func(dec_pat, y) for y in ring])
            ssi_ff.append(np.sum(grad_ssi*radial_kernel))
        print(time.time()-t1, 'single thread')
    return np.asarray(ssi_ff)


def SSI_weighting(img1, img2):
    """Return 100 times the structural similarity of the two images."""
    similarity = SSI(img1,img2)
    return 100*similarity

def get_mobile_points(nn_comp_enc,steps, prev_mp_locs = (), thresh = 'mean', relative_locs = False):
    """Select the points whose step magnitude exceeds a threshold.

    Parameters
    ----------
    nn_comp_enc : numpy.ndarray
        Points, indexed like `steps`.
    steps : array-like
        Step taken by each point; only the magnitude is used.
    prev_mp_locs : numpy.ndarray, optional
        Locations of the current points in some earlier, larger array. When
        given, the selected locations are translated through it. When empty
        (the default) the current locations are returned unchanged.
    thresh : 'mean', 'ten' or number
        'mean' uses the mean magnitude, 'ten' uses a tenth of the maximum
        magnitude, a number is used directly.
    relative_locs : bool
        Also return the locations relative to `nn_comp_enc`.

    Returns
    -------
    (mobile_points, n_mp_locs) or (mobile_points, n_mp_locs, mp_locs)
        `mp_locs` is the tuple produced by `np.where`.

    Raises
    ------
    ValueError
        If `thresh` is a string other than 'mean' or 'ten'.
    """
    magnitudes = np.abs(steps)
    if isinstance(thresh, str):
        if thresh == 'mean':
            thresh = np.mean(magnitudes)
        elif thresh == 'ten':
            thresh = np.max(magnitudes)/10
            print(thresh)
            print(np.where(magnitudes > thresh))
        else:
            raise ValueError("thresh must be 'mean', 'ten' or a number, got %r" % (thresh,))

    mp_locs = np.where(magnitudes > thresh)
    mobile_points = nn_comp_enc[mp_locs]

    if len(prev_mp_locs) != 0:
        n_mp_locs = prev_mp_locs[mp_locs]
    else:
        # No earlier selection: the locations are already absolute. Without
        # this branch n_mp_locs would be unbound on the default call.
        n_mp_locs = np.arange(magnitudes.size).reshape(magnitudes.shape)[mp_locs]

    if relative_locs == True:
        return mobile_points, n_mp_locs, mp_locs
    return mobile_points, n_mp_locs

def get_grad_and_decode_data(mobile_points, radial_kernel, r_scale_kernel = False, nn_scale = False):
    """Build the ring of sample locations around every point and decode the points.

    Parameters
    ----------
    mobile_points : numpy.ndarray, complex, shape (n,)
        Point locations encoded as x + 1j*y.
    radial_kernel : numpy.ndarray, complex, shape (k,)
        Offsets added to every point.
    r_scale_kernel : False or number
        When falsy the kernel is used as is. Otherwise it is multiplied by
        this factor and by a per-point scale (see `nn_scale`).
    nn_scale : bool
        Per-point scale when `r_scale_kernel` is set: False scales by the
        point's magnitude relative to the smallest magnitude (rounded);
        True scales by the point's nearest-neighbour distance relative to
        the smallest such distance.

    Returns
    -------
    (grad_points, dec_dat)
        `grad_points` has shape (n, k); `dec_dat` is the output of
        `get_terr_patts` for the points themselves.
    """
    centres = np.repeat(mobile_points[:,None], radial_kernel.shape[0], axis = 1)
    ring = radial_kernel[None, :]

    if not r_scale_kernel:
        grad_points = centres + ring
    elif not nn_scale:
        rf = np.round((np.abs(mobile_points)/np.abs(mobile_points).min()),0).astype('int')
        grad_points = centres + r_scale_kernel*rf[:,None]*ring
    else:
        # (n, 2) array of x/y columns, as in scaled_thresh_step.
        sample_locs = np.concatenate((mobile_points.real[:,None], mobile_points.imag[:,None]), axis = 1)
        # Two neighbours: column 0 is the query point itself at distance 0,
        # column 1 is the true nearest neighbour.
        nbrs = kNN(n_neighbors=2, algorithm='ball_tree').fit(sample_locs)
        p_sep, indices = nbrs.kneighbors(sample_locs)
        p_sep = p_sep[:,1]
        closest = p_sep.min()
        norm_sep = p_sep/closest
        grad_points = centres + r_scale_kernel*norm_sep[:,None]*ring

    dec_dat = get_terr_patts(np.concatenate([mobile_points.real[:,None], mobile_points.imag[:,None]],axis = 1))
    return grad_points, dec_dat

def sig_step_from_grad(d_gp, gradient_step, sigz=0.25, sigf=100):
    """Turn complex gradients into steps whose length follows a sigmoid.

    The step points along the gradient; its length is
    `gradient_step * sigmoid(|gradient|, sigz, sigf)`. A zero gradient gives
    a zero step.

    NOTE: this name is defined again further down the file; once the whole
    cell has run, the later definition is the one in effect.
    """
    grad_mag = np.abs(d_gp)
    # Divide by 1 where the magnitude is 0 so 0/0 does not produce NaN.
    direction = d_gp/np.where(grad_mag == 0, 1, grad_mag)

    return sigmoid(grad_mag, sigz, sigf)*gradient_step*direction

def norm_step_from_grad(d_gp, factor):
    """Scale gradients so that the largest one has magnitude `factor`.

    If every gradient is zero, zeros are returned instead of dividing by
    zero.
    """
    grad_mag = np.max(np.abs(d_gp))
    if grad_mag == 0:
        return np.zeros_like(d_gp)*factor

    return (d_gp/grad_mag)*factor

def sigmoid(z, sigz=0.25, sigf=100):
    """Logistic function of sigf*(z - sigz), evaluated through logaddexp.

    NOTE: this name is defined again further down the file; once the whole
    cell has run, the later definition is the one in effect.
    """
    scaled = sigf*(z - sigz)
    log_denominator = np.logaddexp(0, -scaled)
    return np.exp(-log_denominator)


def adjust_encoding(mobile_points, grads, comp_enc, mp_locs):
    """Move points by their gradients and write them into a copy of `comp_enc`.

    Returns ((x_old, y_old), (x_new, y_new), migrated) where `migrated` is a
    copy of `comp_enc` with the entries at `mp_locs` replaced by the moved
    points (as x + 1j*y). `comp_enc` itself is left untouched.
    """
    x_old, y_old = mobile_points.real, mobile_points.imag
    x_new = x_old + grads.real
    y_new = y_old + grads.imag

    migrated_points = comp_enc.copy()
    migrated_points[mp_locs] = x_new + 1j*y_new

    return (x_old, y_old), (x_new, y_new), migrated_points

def get_terr_patts(img, bs =256):
    """Decode latent points in batches and return the patterns.

    NOTE: decoding uses the module-level `sample` (`sample.model.decoder`),
    which must exist before this is called.

    Parameters
    ----------
    img : array, first axis = points
        Latent coordinates passed to the decoder batch by batch.
    bs : int
        Number of points per decoder call.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_points, 128, 128).
    """
    n_points = img.shape[0]
    # Stepping by `bs` avoids decoding an empty trailing batch when n_points
    # is a multiple of bs; max(..., 1) keeps one (empty) batch for empty input
    # so the concatenate below still has something to work with.
    starts = range(0, max(n_points, 1), bs)
    nimg = [sample.model.decoder(img[start:start + bs]).numpy() for start in starts]
    return np.concatenate(nimg, axis = 0).reshape((n_points, 128,128))


def sig_step_from_grad(d_gp, gradient_step, sigz=0.25, sigf=100):
    """Turn complex gradients into steps whose length follows a sigmoid.

    The step points along the gradient; its length is
    `gradient_step * sigmoid(|gradient|, sigz, sigf)`. A zero gradient gives
    a zero step.
    """
    grad_mag = np.abs(d_gp)
    # Divide by 1 where the magnitude is 0 so 0/0 does not produce NaN.
    direction = d_gp/np.where(grad_mag == 0, 1, grad_mag)

    return sigmoid(grad_mag, sigz, sigf)*gradient_step*direction

def sigmoid(z, sigz=0.25, sigf=100):
    """Logistic function of sigf*(z - sigz), evaluated through logaddexp."""
    x = sigf*(z - sigz)
    return np.exp(-np.logaddexp(0, -x))

def lin_thresh_step(d_gp, thresh, mag = 1):
    """Turn complex gradients into steps that grow linearly up to a cap.

    The step points along the gradient. Its length is
    `mag * min(|gradient|, thresh) / thresh`, so gradients at or above
    `thresh` all move by `mag`. A zero gradient gives a zero step.
    """
    scale = np.abs(d_gp)
    # Divide by 1 where the magnitude is 0 so 0/0 does not produce NaN.
    direction = d_gp/np.where(scale == 0, 1, scale)
    return (np.where(scale>thresh, thresh, scale)/thresh)*direction*mag

def scaled_thresh_step(d_gp, thresh, mobile_points, mag):
    """Like `lin_thresh_step`, additionally scaled by nearest-neighbour distance.

    Each step is multiplied by the point's distance to its nearest neighbour
    divided by the smallest non-zero such distance among all points.

    Parameters
    ----------
    d_gp : numpy.ndarray, complex
        Gradient per point.
    thresh : number
        Gradient magnitude at which the step length saturates.
    mobile_points : numpy.ndarray, complex
        Point locations encoded as x + 1j*y, in the same order as `d_gp`.
    mag : number
        Step length at saturation (before the neighbour scaling).
    """
    sample_locs = np.concatenate((mobile_points.real[:,None], mobile_points.imag[:,None]), axis = 1)
    nbrs = kNN(n_neighbors=5, algorithm='ball_tree').fit(sample_locs)
    # Column 0 is the query point itself; column 1 is its nearest neighbour.
    p_sep, indices = nbrs.kneighbors(sample_locs, n_neighbors = 2)
    print(p_sep.shape, p_sep[:,0])
    p_sep = p_sep[:,1]
    # Coincident points have separation 0; normalise by the smallest positive
    # separation so the scale factors stay finite.
    positive = p_sep[p_sep > 0]
    closest = positive.min() if positive.size else 1.0
    norm_sep = p_sep/closest

    scale = np.abs(d_gp)
    direction = d_gp/np.where(scale == 0, 1, scale)
    return (np.where(scale>thresh, thresh, scale)/thresh)*direction*norm_sep*mag


from sklearn.neighbors import KernelDensity
def get_density_net(sample, n_samples, n_bkg_samples, density_approx = 10,  bandwidth=0.5):
    """Draw net points from a density estimate plus a uniform background.

    A Gaussian kernel density estimate is fitted to every
    `density_approx`-th row of a shuffled copy of `sample.encoded_data`, and
    `n_samples` points are drawn from it. A further `n_bkg_samples` points
    are drawn uniformly from the floor/ceil bounding box of the first two
    encoded columns. Both sets are returned stacked (density samples first).
    """
    encoded = sample.encoded_data

    # Work on a copy so the shuffle does not disturb the sample's own data.
    subset = encoded.copy()
    np.random.shuffle(subset)
    subset = subset[::density_approx]
    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(subset)
    density_points = kde.sample(n_samples)

    xmin, xmax = np.floor(np.min(encoded[:,0])), np.ceil(np.max(encoded[:,0]))
    ymin, ymax = np.floor(np.min(encoded[:,1])), np.ceil(np.max(encoded[:,1]))

    print(xmin, xmax,ymin,ymax)

    # Uniform draws in the unit square, stretched to the box, then shifted.
    background = np.random.random((n_bkg_samples, 2))
    background[:,0] *= np.abs((xmax - xmin))
    background[:,1] *= np.abs((ymax - ymin))
    background = background + np.array((xmin, ymin))

    return np.concatenate((density_points, background), axis = 0)

import sklearn.metrics.cluster as cmet

def get_map_label_df(map1):
    """Return one 0/1 mask per unique label of `map1`, in sorted label order."""
    masks = []
    for label in np.unique(map1):
        masks.append(np.where(map1 == label, 1, 0))
    return np.asarray(masks)

def get_cluster_label_overlap(map_pair):
    """Fraction of each mask in the first stack covered by each mask in the second.

    `map_pair` holds two stacks of 0/1 masks. Entry [i, j] of the result is
    sum(first[i] * second[j]) / sum(first[i]).
    """
    first, second = map_pair
    overlap = np.zeros((first.shape[0], second.shape[0]))
    for row, mask_a in enumerate(first):
        for col, mask_b in enumerate(second):
            shared = np.sum(mask_a * mask_b)
            overlap[row, col] = shared / np.sum(mask_a)
    return overlap

def find_map_label(pos, map1):
    """Return the entry of `map1` at index `pos`."""
    label = map1[pos]
    return label

def get_confidence_from_maps(maps):
    """Score every position by how consistently several label maps group it.

    For every ordered pair of maps (a, b), the overlap matrix from
    `get_cluster_label_overlap` gives the fraction of each cluster of `a`
    that falls inside each cluster of `b`. The confidence of a position is
    the mean, over all ordered pairs, of the overlap between the clusters
    that position belongs to in the two maps.

    Parameters
    ----------
    maps : sequence of array-like
        Label maps of identical shape; at least two are needed.

    Returns
    -------
    numpy.ndarray, float32
        Confidence map with the shape of the first map.

    Raises
    ------
    ValueError
        If fewer than two maps are supplied.
    """
    maps = [np.asarray(m) for m in maps]
    if len(maps) < 2:
        raise ValueError('at least two maps are needed to compute a confidence')

    dfs = [get_map_label_df(m) for m in maps]
    # Position of every label within its map's sorted unique labels. The
    # overlap matrices are indexed by this position, not by the raw label.
    index_maps = [np.searchsorted(np.unique(m), m) for m in maps]

    overlap_inds = list(itertools.permutations(range(len(maps)), 2))

    total = np.zeros(maps[0].shape)
    for first, second in overlap_inds:
        overlap = get_cluster_label_overlap((dfs[first], dfs[second]))
        total += overlap[index_maps[first], index_maps[second]]

    return (total/len(overlap_inds)).astype('float32')



def refine_based_on_density(sample, density_cutoff = -9, n_bulk_samples = 500, sample_grid_res = 200,
                            gn = 0.1,density_approx = 5, bw = 0.4, n_sample_points = 2500, n_bkg_points = 500, 
                            show_net = True, rand_gradient_step = 0.01,rand_n_rsteps = 12, step_scale = 0.01, 
                            step_thresh = 0.001,show_step_size = True, show_net_movement= True, 
                            show_first_refinement = False,n_refine_steps = 200, show_refinement = True, 
                            animate_refinement = True):
    
    '''
    Warp the encoded data of a sample with a net of points that is moved along
    decoded-pattern similarity gradients.

    sample: ProcessedSample with the encoded data
    density_cutoff: Cutoff value below which the density gradient is ignored (default - -9)
    n_bulk_samples: number of samples to take from the density distribution (default - 500)
    sample_grid_res: The number of grid points to use to approximate the density gradient function (default - 200)
    gn: Gaussian noise maximum magnitude to be added to the randomly sampled density gradient points (default - 0.1)
    density_approx: the number of skips to take of randomly shuffled encoded data to approximate the density (default - 5)
    bw: bandwidth for the density approximation (default - 0.4)
    n_sample_points: number of net points to sample from the density gradient distribution (default - 2500)
    n_bkg_points: number of net points to sample uniformly (default - 500)
    show_net: Show the positions of the net points (default - True)
    rand_gradient_step: The distance around net points to sample for gradient approximation (default - 0.01)
    rand_n_rsteps: The number of radial samples around net points to sample for grad approx (default - 12)
    step_scale: The magnitude of the largest steps you want in net movement (default - 0.01)
    step_thresh: The gradient cut-off, above which, the step size == step_scale (default - 0.001)
    show_step_size: Show a graph of the distribution of step sizes (default - True)
    show_net_movement: Show how the net points have distorted (default - True)
    show_first_refinement: Show how the sample points have moved after one step (default - False)
    n_refine_steps: Number of refinement steps to run (default - 200)
    show_refinement: Show how the sample points have moved after n steps (default - True)
    animate_refinement: Show an animation of the refinement process (default - True)

    Returns (refined_data, accessory_dict). accessory_dict holds the keys
    'net', 'grads', 'steps', 'net_displacement', 'transform',
    'first_refinement', 'refinement_steps' and, when animated, 'animation'.
    '''
    
    #create a dictionary to hold some data that might be useful to return
    accessory_dict = {}
    
    #get a density based net
    #the net builder shuffles the array it is given, so hand it a copy of the
    #encoded data rather than the sample object
    R = get_density_gradient_net(sample.encoded_data.copy(), n_sample_points, density_cutoff, n_bkg_points,n_bulk_samples, density_approx, sample_grid_res,bw, gn)
    
    accessory_dict['net'] = R
    
    #view the point distribution
    if show_net:
        plt.figure()
        plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10)
        plt.scatter(R[:,0], R[:,1], s = 20)
    
    #get all the sample points for the gradient 
    rand_latspace = R[:,0] + 1j*R[:,1]
    #ring of rand_n_rsteps equally spaced offsets of radius rand_gradient_step
    rand_radial_kernel = rand_gradient_step* np.exp(1j*np.pi*(np.linspace(0, 360, rand_n_rsteps+1)/180))[1:]
    rand_grad_p1, rand_decdat1 = get_grad_and_decode_data(rand_latspace, rand_radial_kernel)
    
    #calculate the gradients
    rand_delta_gp1 = batch_calc_grad(rand_grad_p1, rand_radial_kernel, rand_decdat1, SSI_weighting, 256)
    accessory_dict['grads'] = rand_delta_gp1
    
    #scale the gradients for step sizes
    rand_linsteps = lin_thresh_step(rand_delta_gp1, step_thresh, step_scale)
    if show_step_size:
        plt.figure()
        plt.plot(np.sort(np.abs(rand_linsteps)))
    accessory_dict['steps'] = rand_linsteps
    
    #adjust the net points (every net point is moved)
    every_point = (np.arange(rand_latspace.shape[0]),)
    rand_op1, rand_np1, rand_current_ps1= adjust_encoding(rand_latspace, rand_linsteps, rand_latspace, every_point)
    accessory_dict['net_displacement'] = rand_np1
    if show_net_movement:
        plt.figure(figsize = (8,8))
        plt.scatter(rand_op1[0], rand_op1[1],s =20)
        plt.scatter(rand_np1[0], rand_np1[1],s =20)
        plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.2)
    #refine the points once
    test_data = sample.encoded_data.copy()

    rand_o_latspacer = np.asarray(rand_op1).T
    rand_n_latspacer = np.asarray(rand_np1).T

    rand_tform2 = PiecewiseAffineTransform()
    accessory_dict['transform'] = rand_tform2
    rand_tform2.estimate(rand_n_latspacer, rand_o_latspacer)

    rand_out_data2 = rand_tform2.inverse(test_data)
    
    accessory_dict['first_refinement'] = rand_out_data2.copy()

    if show_first_refinement:
        plt.figure()
        plt.scatter(test_data[:,0], test_data[:,1], s =10)
        plt.scatter(rand_out_data2[:,0], rand_out_data2[:,1], s =10)
        
    #apply the same transform repeatedly, keeping every intermediate result
    rand_refine_steps2=[]

    for i in range(n_refine_steps):
        rand_out_data2 = rand_tform2.inverse(rand_out_data2)
        rand_refine_steps2.append(rand_out_data2)
    
    if show_refinement:
        plt.figure()
        plt.scatter(test_data[:,0], test_data[:,1], s =10)
        #plt.scatter(out_data[:,0], out_data[:,1], s =10)
        plt.scatter(rand_out_data2[:,0], rand_out_data2[:,1], s =10)
        
    accessory_dict['refinement_steps'] = rand_refine_steps2

    
    if animate_refinement:
        #matplotlib.animation is not imported at the top of the notebook
        from matplotlib import animation

        # First set up the figure, the axis, and the plot element we want to animate
        figr3, axr3 = plt.subplots()

        axr3.set_xlim(( -1, 1))
        axr3.set_ylim((-1, 1))

        liner3, = axr3.plot([], [], lw=2, ls = '', marker = 'o', alpha = 0.2)

        def init3():
            liner3.set_data([], [])
            return (liner3,)
        def animate3(i):
            d3 = rand_refine_steps2[i]
            x3,y3 = d3[:,0], d3[:,1]
            liner3.set_data(x3, y3)
            return (liner3,)
        # call the animator. blit=True means only re-draw the parts that 
        # have changed. One frame per stored refinement step, so the frame
        # index can never run past the end of the list.
        animr3 = animation.FuncAnimation(figr3, animate3, init_func=init3,
                                       frames=len(rand_refine_steps2), interval=20, blit=True)
        accessory_dict['animation'] = animr3
    return rand_out_data2, accessory_dict

def get_density_gradient_net(D, n_samples, density_cutoff, n_bkg_samples, n_bulk_samples, density_approx = 10, sample_grid_res = 200, bandwidth=0.5, gn = 0.1):
    """Draw net points concentrated where the data density changes fastest.

    Three sets of points are returned stacked in this order:
    points drawn where the gradient of the log-density is large (with
    uniform jitter), points drawn uniformly over the bounding box of the
    data, and points drawn from the density estimate itself.

    Parameters
    ----------
    D : array of shape (n, 2), or object with an `encoded_data` attribute
        Encoded data. It is copied before use, so the caller's array keeps
        its order.
    n_samples : int
        Number of density-gradient points.
    density_cutoff : float
        Only grid cells whose log-density lies between `density_cutoff` and
        `density_cutoff + 3` can be drawn.
    n_bkg_samples : int
        Number of uniformly drawn points.
    n_bulk_samples : int
        Number of points drawn from the density estimate.
    density_approx : int
        Every `density_approx`-th row of the shuffled data is used for the
        density estimate.
    sample_grid_res : int
        Grid resolution per axis used to evaluate the density.
    bandwidth : float
        Kernel bandwidth of the density estimate.
    gn : float
        Width of the uniform jitter added to the density-gradient points.

    Raises
    ------
    ValueError
        If no grid cell falls inside the log-density window.
    """
    # np.array copies, so the in-place shuffle below cannot reorder the
    # caller's data. A sample object is accepted for convenience.
    encoded = np.array(getattr(D, 'encoded_data', D))

    # Bounding box of the full data set, taken before subsampling.
    xmin, xmax = np.floor(np.min(encoded[:,0])), np.ceil(np.max(encoded[:,0]))
    ymin, ymax = np.floor(np.min(encoded[:,1])), np.ceil(np.max(encoded[:,1]))

    np.random.shuffle(encoded)
    D = encoded[::density_approx]
    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(D)
    R = kde.sample(n_bulk_samples)

    xgrid = np.linspace(np.floor(D[:,0].min()),np.ceil(D[:,0].max()),sample_grid_res)
    ygrid = np.linspace(np.floor(D[:,1].min()),np.ceil(D[:,1].max()),sample_grid_res)
    X,Y = np.meshgrid(xgrid, ygrid)
    xy = np.vstack((X.ravel(), Y.ravel())).T

    # Log-density on the grid and the magnitude of its gradient.
    Z = kde.score_samples(xy).reshape(X.shape)
    dY, dX = np.gradient(Z)

    # Keep only cells inside the window [density_cutoff, density_cutoff + 3].
    dZ = np.hypot(dY,dX)*np.where(Z < density_cutoff, 0, 1)*np.where(Z> (density_cutoff+3), 0, 1)

    weight_sum = np.sum(dZ)
    if weight_sum == 0:
        raise ValueError('no grid point has a log-density between density_cutoff and density_cutoff + 3')
    dZ = (dZ/weight_sum).reshape(xy.shape[0])

    draw = np.random.choice(np.arange(xy.shape[0]), n_samples,
                  p=dZ, replace = True)

    # Uniform jitter in [-gn/2, gn/2) on both coordinates.
    kdgrad = xy[draw] + gn*(np.random.random((n_samples, 2))-0.5*np.ones((n_samples, 2)))

    print(xmin, xmax,ymin,ymax)

    # Uniform draws in the unit square, stretched to the box, then shifted.
    s_samples = np.random.random((n_bkg_samples, 2))
    s_samples[:,0] *= np.abs((xmax - xmin))
    s_samples[:,1] *= np.abs((ymax - ymin))
    s_samples = s_samples + np.array((xmin, ymin))

    return np.concatenate((kdgrad, s_samples, R), axis = 0)

def SSI_remesh(sample, R, n_add_points = 1, ssi_thresh = 0.95):
    """Add mesh points on triangle edges whose end-point patterns differ.

    `R` is triangulated (Delaunay). For every simplex, each pair of its
    vertices forms an edge; the patterns decoded at the two ends of the edge
    (via `get_terr_patts`) are compared with `compare_point_SSI` in a process
    pool. Edges scoring below `ssi_thresh` contribute their interior points
    to the mesh.

    Two figures are drawn: the original mesh with the added points, and the
    re-triangulated mesh.

    Parameters
    ----------
    sample : object with `encoded_data`
        Only used for the background scatter in the plots.
    R : numpy.ndarray
        Mesh points, one row per point (columns 0 and 1 are plotted).
    n_add_points : int
        `line_interp` is called with 2 + n_add_points and its first and last
        entries are dropped. `line_interp` is not defined in this part of the
        file — presumably it returns that many points along the edge
        including both ends; confirm.
    ssi_thresh : float
        Edges with a score below this value are refined.

    Returns
    -------
    numpy.ndarray
        `R` with the new points appended.
    """
    tri = Delaunay(R)

    all_simps = tri.simplices

    # Every pair of vertices of every simplex, i.e. the edges.
    line_segs= np.asarray([[np.asarray(x) for x in itertools.combinations(R[simps], 2)] for simps in all_simps])

    # Candidate points along each edge, end points excluded.
    line_add_points= np.asarray([np.asarray([line_interp(x[0], x[1],2+n_add_points)[1:-1] for x in itertools.combinations(R[simps], 2)]) for simps in all_simps])

    # Merge the (simplex, edge) axes into a single edge axis.
    f_line_segs = flatten_nav(line_segs)

    f_add_points = flatten_nav(line_add_points)

    fline_start, fline_finish = f_line_segs[:,0,:], f_line_segs[:,1,:] 

    patts_start, patts_finish = get_terr_patts(fline_start), get_terr_patts(fline_finish)

    line_ssi = np.zeros(patts_start.shape[0])

    # Each job carries its edge index so results can be put back in place.
    ssi_input_data = []
    for i in range(patts_start.shape[0]):
        ssi_input_data.append((i, patts_start[i], patts_finish[i]))
        
    print('getting line ssi')

    # Leaving the `with` block waits for every submitted job to finish.
    with concurrent.futures.ProcessPoolExecutor() as exe:
        res = [exe.submit(compare_point_SSI, ssi_input) for ssi_input in ssi_input_data]
    r_batches = [f.result() for f in res]
    
    print('done getting line ssi')

    for each in r_batches:
        line_ssi[each[0]] = each[1]

    poor_line_locs = np.where(line_ssi < ssi_thresh)

    new_points = flatten_nav(f_add_points[poor_line_locs])

    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    plt.triplot(R[:,0], R[:,1], all_simps, lw = 1)
    plt.scatter(new_points[:,0], new_points[:,1], s = 10, alpha = 1, c = 'black', marker = 'x')
    plt.title('Additional Mesh Points')
    
    Rp = np.concatenate((R, new_points), axis = 0)
    
    # Re-triangulate with the added points for the second figure.
    trip = Delaunay(Rp)
    
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    plt.triplot(Rp[:,0], Rp[:,1], trip.simplices, lw = 1)
    plt.title('New Mesh')
    return Rp

def compare_point_SSI(required_data):
    """Return (index, SSI(first, second)) for a job given as (index, first, second)."""
    index = required_data[0]
    first, second = required_data[1], required_data[2]
    return (index, SSI(first, second))

from scipy.spatial import Delaunay

def latspace_from_R(R):
    """Encode the first two columns of `R` as complex numbers x + 1j*y."""
    x_part, y_part = R[:,0], R[:,1]
    return x_part + 1j*y_part

def get_mesh_gradients(R, rand_gradient_step, rand_n_rsteps, bs = 256):
    """Evaluate gradients around every mesh point of ``R``.

    A ring of ``rand_n_rsteps`` complex offsets of magnitude
    ``rand_gradient_step`` is built (evenly spaced angles over 360 degrees, the
    0-degree duplicate dropped) and handed, together with the complex form of
    ``R``, to ``get_grad_and_decode_data``. The result is reduced by
    ``batch_calc_grad`` using the notebook-level ``SSI_weighting`` and batch
    size ``bs``.

    Returns whatever ``batch_calc_grad`` returns.
    """
    latent_points = latspace_from_R(R)

    # rand_n_rsteps + 1 samples over [0, 360]; index 0 is discarded below.
    ring_degrees = np.linspace(0, 360, rand_n_rsteps+1)
    unit_ring = np.exp(1j*np.pi*(ring_degrees/180))
    radial_kernel = rand_gradient_step * unit_ring[1:]

    grad_points, decoded = get_grad_and_decode_data(latent_points, radial_kernel)

    return batch_calc_grad(grad_points, radial_kernel, decoded, SSI_weighting, bs)

def get_mesh_transform(R, rand_linsteps):
    """Fit a piecewise-affine transform from the adjusted mesh to the original.

    ``adjust_encoding`` is run on the complex form of ``R`` for every point
    (the index argument selects all entries). The transform is estimated with
    the second returned point set as source and the first as destination.

    Returns
    -------
    (transform, R_moves)
        ``R_moves`` is ``(second_point_set, first_point_set)`` exactly as
        returned by ``adjust_encoding``.
    """
    latent_points = latspace_from_R(R)
    # Element-wise comparison against None is True everywhere, so this yields
    # the index of every mesh point.
    every_index = np.where(latent_points != None)
    first_pts, second_pts, _unused_current = adjust_encoding(
        latent_points, rand_linsteps, latent_points, every_index
    )

    destination = np.asarray(first_pts).T
    source = np.asarray(second_pts).T

    transform = PiecewiseAffineTransform()
    transform.estimate(source, destination)
    return transform, (second_pts, first_pts)

def plot_R_movement(sample, R_moves):
    """Draw mesh movement as arrows over the encoded data.

    Arrows start at ``R_moves[1]`` (blue) and point to ``R_moves[0]``
    (orange); each entry is indexed as ``[0]`` for x and ``[1]`` for y.
    """
    tips, tails = R_moves[0], R_moves[1]
    tail_x, tail_y = tails[0], tails[1]
    delta_x = tips[0] - tails[0]
    delta_y = tips[1] - tails[1]

    plt.figure(figsize = (8,8))
    encoded = sample.encoded_data
    plt.scatter(encoded[:,0], encoded[:,1], s = 10, alpha = 0.2, c = 'grey')
    plt.scatter(tail_x, tail_y, s = 10, c = 'blue')
    plt.scatter(tips[0], tips[1], s = 10, c = 'orange')
    plt.quiver(tail_x, tail_y, delta_x, delta_y)
    
def plot_enc_movement(sample, tform):
    """Scatter the encoded data before and after applying ``tform.inverse``."""
    before = sample.encoded_data
    after = tform.inverse(before)
    plt.figure()
    for points in (before, after):
        plt.scatter(points[:,0], points[:,1], s = 10)
    
def repeat_tform(tform, n_refine_steps, iter_seq):
    """Apply ``tform.inverse`` repeatedly, recording every intermediate result.

    Starts from the last element of ``iter_seq`` and appends one new element
    per step. ``iter_seq`` is modified in place and also returned. The step
    index is printed on each iteration and the elapsed time at the end.
    """
    import time

    latest = iter_seq[-1]
    started_at = time.time()
    step = 0
    while step < n_refine_steps:
        print(step)
        latest = tform.inverse(latest)
        iter_seq.append(latest)
        step += 1
    print(time.time() - started_at)
    return iter_seq

def cosine_rule(a, b, c):
    """Angle (radians) opposite side ``c`` in a triangle with sides a, b, c."""
    numerator = (a**2) + (b**2) - (c**2)
    denominator = 2*a*b
    return np.arccos(numerator/denominator)

def angles(ps):
    """Return the three interior angles (degrees) of the triangle ``ps``.

    ``ps`` unpacks into three vertices. With side lengths
    ``al = |p2-p1|``, ``bl = |p3-p2|`` and ``cl = |p1-p3|``, the returned list
    holds the angles opposite ``cl``, ``al`` and ``bl`` in that order.
    """
    v1, v2, v3 = ps
    al = np.linalg.norm(v2 - v1)
    bl = np.linalg.norm(v3 - v2)
    cl = np.linalg.norm(v1 - v3)
    side_orderings = ((al, bl, cl), (bl, cl, al), (cl, al, bl))
    return [np.rad2deg(cosine_rule(x, y, z)) for x, y, z in side_orderings]

def heron(ps):
    """Area of the triangle with vertices ``ps`` via Heron's formula."""
    v1, v2, v3 = ps
    sides = [np.linalg.norm(end - start) for start, end in ((v1, v2), (v2, v3), (v3, v1))]
    half_perimeter = (sides[0] + sides[1] + sides[2])/2
    # s * (s - a) * (s - b) * (s - c), accumulated left to right.
    product = half_perimeter
    for side in sides:
        product = product*(half_perimeter - side)
    return np.sqrt(product)

def line_interp(p1, p2, nsteps):
    """Return ``nsteps`` evenly spaced 2-D points from ``p1`` to ``p2`` inclusive.

    The result has shape ``(nsteps, 2)``: column 0 interpolates the first
    coordinate and column 1 the second.
    """
    xs = np.linspace(p1[0], p2[0], nsteps)
    ys = np.linspace(p1[1], p2[1], nsteps)
    return np.column_stack((xs, ys))

def remesh(sample, R, R_moves, density_cutoff, n_add_points = 1, n_movement_bins = 10, area_thresh = None, angle_thresh = 15, n_area_bins = 1000, n_line_interp = 10):
    """Refine the mesh ``R`` by inserting points on selected simplex edges.

    Pipeline (one diagnostic figure per stage):

    1. Delaunay-triangulate ``R``.
    2. Keep simplices whose summed vertex movement lies above the first
       movement-histogram bin.
    3. Of those, keep simplices that are large (area above the threshold) or
       thin (first angle from ``angles`` below ``angle_thresh`` and area above
       a tenth of the threshold).
    4. Sample the notebook-level ``kde`` along every edge of the surviving
       simplices; an edge is selected when the absolute gradient of
       ``kde.score_samples`` along it exceeds 1 somewhere and its maximum
       score exceeds ``density_cutoff``.
    5. Add ``n_add_points`` interior points of each selected edge to ``R``.

    NOTE: this function reads the global ``kde`` (any object exposing
    ``score_samples``); it is not passed in as a parameter.

    Parameters
    ----------
    sample : object exposing ``encoded_data`` (used for background scatter only).
    R : array of 2-D mesh points.
    R_moves : pair ``(rand_np1, rand_op1)``; movement is the first minus the
        second, per coordinate (index 0 = x, index 1 = y).
    density_cutoff : minimum peak ``kde.score_samples`` value along an edge.
    n_add_points : number of evenly spaced interior points added per selected edge.
    n_movement_bins, n_area_bins : histogram bin counts.
    area_thresh : area threshold; when ``None`` the right edge of the first
        area-histogram bin is used.
    angle_thresh : angle limit in degrees for the "thin simplex" test.
    n_line_interp : number of density samples taken along each edge.

    Returns
    -------
    (Rp, refinement_points)
        ``Rp`` is ``R`` with the new points appended. ``refinement_points`` has
        shape ``(0, R.shape[1])`` when nothing qualified (previously this case
        raised an IndexError).
    """
    #get current triangulation
    tri = Delaunay(R)
    #view current triangulation
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    plt.triplot(R[:,0], R[:,1], tri.simplices, lw = 1)
    plt.plot(R[:,0], R[:,1], 'o', markersize= 1)
    plt.title('Initial Triangulation')
    #get current mesh movements
    rand_np1, rand_op1 = R_moves
    movement = np.asarray((rand_np1[0] -  rand_op1[0], rand_np1[1]- rand_op1[1])).T
    #get total movement of the simplex vertices
    simp_move = np.sum(np.asarray([np.asarray([np.linalg.norm(movement[x]) for x in simp]) for simp in tri.simplices]), axis = 1)
    #hist these
    plt.figure()
    (n, bins, patches) =  plt.hist(simp_move, n_movement_bins)
    plt.title('Simplex Movement Histogram')
    #Truncate after first bin
    high_m_simps = tri.simplices[np.where(simp_move > bins[1])]
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    plt.triplot(R[:,0], R[:,1], high_m_simps, lw = 1)
    plt.title('High Movement Simplices')
    #Get area and angles of remaining simplices
    simp_area = np.asarray([heron(R[simp]) for simp in high_m_simps])
    simp_angles =  np.asarray([angles(R[simp]) for simp in high_m_simps])
    plt.figure()
    (n, area_bins, patches) =  plt.hist(simp_area, n_area_bins)
    plt.title('Simplex Area Histogram')
    #Default threshold is the right edge of the first area bin
    effective_area_thresh = area_bins[1] if area_thresh is None else area_thresh
    is_large = simp_area > effective_area_thresh
    #Only the first angle returned by angles() is tested here
    is_thin = (simp_angles[:,0] < angle_thresh) & (simp_area > effective_area_thresh/10)
    high_a_simps = high_m_simps[np.where(is_large | is_thin)]
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    plt.triplot(R[:,0], R[:,1], tri.simplices, lw = 1)
    #triplot cannot draw an empty triangle list
    if len(high_a_simps) > 0:
        plt.triplot(R[:,0], R[:,1], high_a_simps, lw = 1)
    plt.title('Area and Movement Pruned Simplices')
    if len(high_a_simps) == 0:
        #Nothing survived pruning, so there are no edges to analyse
        line_add_points = None
        ilines = np.empty((0, 2), dtype = int)
    else:
        #For each of the remaining mesh lines, calculate a linear interpolation of sampling points
        line_segs= np.asarray([np.asarray([line_interp(x[0], x[1],n_line_interp) for x in itertools.combinations(R[simps], 2)]) for simps in high_a_simps])
        #and the interior point(s) to be potentially added to the new mesh
        line_add_points= np.asarray([np.asarray([line_interp(x[0], x[1],2+n_add_points)[1:-1] for x in itertools.combinations(R[simps], 2)]) for simps in high_a_simps])
        #score the data density at each of the points along the mesh line
        print(line_segs.shape)
        den_seg = np.asarray([[kde.score_samples(lss) for lss in ls] for ls in line_segs])
        #peak score along each line (is the line near data at all?)
        den_seg_max = den_seg.max(axis = 2)
        #gradient of the score along each line
        den_seg_grad = np.gradient(den_seg, axis = 2)
        #a line is "changing" when the absolute gradient exceeds 1 somewhere along it
        grad_changes = np.abs(den_seg_grad).max(axis = 2) > 1
        #(simplex index, line index) pairs, in row-major order
        ilines = np.argwhere(grad_changes & (den_seg_max > density_cutoff))
    if len(ilines) == 0:
        i_simps = high_a_simps[:0]
        refinement_points = np.empty((0, R.shape[1]), dtype = R.dtype)
    else:
        i_simps =  high_a_simps[np.unique(ilines[:,0])]
        #get the interior points of these lines
        refinement_points = flatten_nav(np.asarray([line_add_points[ninds[0], ninds[1]] for ninds in ilines]))
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    if len(high_a_simps) > 0:
        plt.triplot(R[:,0], R[:,1], high_a_simps, lw = 1)
    if len(i_simps) > 0:
        plt.triplot(R[:,0], R[:,1], i_simps, lw = 1)
    plt.scatter(refinement_points[:,0], refinement_points[:,1], s = 10, alpha = 1, c = 'black', marker = 'x')
    plt.title('Additional Mesh Points')
    #Add these points to the original points
    Rp = np.concatenate((R, refinement_points), axis = 0)
    #View the new Triangulation
    trip = Delaunay(Rp)
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.5, c = 'grey')
    plt.triplot(Rp[:,0], Rp[:,1], trip.simplices, lw = 1)
    plt.plot(Rp[:,0], Rp[:,1], 'o', markersize= 1)
    plt.title('New Triangulation')
    return (Rp, refinement_points)

def _centroids_from_gradient_map(bdZ, allowed_centroid_mask, thresh, eps, min_samples):
    """Shared implementation of ``get_dense_centroids`` / ``get_sparse_centroids``.

    Pixels of ``bdZ`` below ``thresh`` that are also allowed by
    ``allowed_centroid_mask`` are clustered with DBSCAN; each cluster's mean
    pixel position is rounded to a grid index and mapped to encoded-space
    coordinates.

    NOTE: reads the notebook-level globals ``D`` (grid extent),
    ``sample_grid_res`` (grid size) and ``sample`` (background scatter).

    Returns ``(espace_dsp_centroids, centroid_search_region)``.
    """
    plt.figure()
    plt.imshow(bdZ)

    low_gradient = np.where(bdZ<thresh, 1, 0)
    plt.figure()
    plt.imshow(low_gradient)

    centroid_search_region = low_gradient * allowed_centroid_mask
    plt.figure()
    plt.imshow(centroid_search_region)

    # One row per candidate pixel, as (axis-0 index, axis-1 index).
    density_stationary_points = np.asarray(np.where(centroid_search_region == 1)).T

    dspc = DBSCAN(eps = eps, min_samples = min_samples).fit_predict(density_stationary_points)

    plt.figure()
    plt.scatter(density_stationary_points[:,0], density_stationary_points[:,1], c = dspc)

    # Label -1 is DBSCAN noise. The reshape keeps a 2-D (0, ndim) array when
    # every point is noise, so the column indexing below does not fail.
    dsp_centroids = np.asarray(
        [np.mean(density_stationary_points[np.where(dspc == uind)], axis = 0) for uind in np.unique(dspc) if uind != -1]
    ).reshape(-1, density_stationary_points.shape[1])

    plt.figure()
    plt.scatter(density_stationary_points[:,0], density_stationary_points[:,1], c = dspc)
    plt.scatter(dsp_centroids[:,0], dsp_centroids[:,1], marker = 'x', c = 'red')

    xgrid = np.linspace(np.floor(D[:,0].min()),np.ceil(D[:,0].max()),sample_grid_res)
    ygrid = np.linspace(np.floor(D[:,1].min()),np.ceil(D[:,1].max()),sample_grid_res)

    centroid_approx = np.round(dsp_centroids,0).astype('int')

    # Pixel column 1 indexes the x grid and column 0 the y grid.
    espace_dsp_centroids = np.concatenate((xgrid[centroid_approx[:,1]][:,None], ygrid[centroid_approx[:,0]][:,None]), axis = 1)

    plt.figure()
    # No cmap here: without a `c` array matplotlib ignores it and only warns.
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s=5, alpha = 0.1)
    plt.scatter(espace_dsp_centroids[:,0], espace_dsp_centroids[:,1], marker='x')

    return espace_dsp_centroids, centroid_search_region


def get_dense_centroids(boundaries, allowed_centroid_mask, thresh = 0.01, eps = 1.5, min_samples = 6):
    """Locate centroids from one application of ``get_grad`` to ``boundaries``.

    Parameters
    ----------
    boundaries : input passed to ``get_grad``.
    allowed_centroid_mask : array multiplied into the thresholded gradient map
        to restrict where centroids may be found.
    thresh : pixels with gradient below this value are candidates.
    eps, min_samples : DBSCAN parameters.

    Returns
    -------
    (espace_dsp_centroids, centroid_search_region)
        Centroid coordinates in encoded space (empty ``(0, 2)``-shaped when
        DBSCAN finds no cluster) and the masked candidate region.
    """
    bdZ = get_grad(boundaries)
    return _centroids_from_gradient_map(bdZ, allowed_centroid_mask, thresh, eps, min_samples)


def get_sparse_centroids(boundaries, allowed_centroid_mask, thresh = 0.01, eps = 6, min_samples = 7):
    """Locate centroids from four nested applications of ``get_grad``.

    Identical to ``get_dense_centroids`` apart from the repeated ``get_grad``
    and the default DBSCAN parameters.

    Returns
    -------
    (espace_dsp_centroids, centroid_search_region)
    """
    bdZ = get_grad(get_grad(get_grad(get_grad(boundaries))))
    return _centroids_from_gradient_map(bdZ, allowed_centroid_mask, thresh, eps, min_samples)

def from_centroids_refine_clusters_and_centroids(centroids, R, sample):
    """Build a cluster map from ``centroids`` and recompute centroids from it.

    Mesh points ``R`` are grouped by ``find_R_closest_centroid``, turned into
    a label map by ``get_map_from_R_boundaries``, and new centroids are taken
    from that map with ``get_centroids``. Two figures are drawn: the labelled
    encoded data with old (grey) and new (black) centroids, and the map image.

    Returns
    -------
    (hull_r_centroids, new_hull_map, probs, R_closest_c)
    """
    mesh_groups = find_R_closest_centroid(R, centroids)
    label_map, conflict_probs = get_map_from_R_boundaries(mesh_groups, sample, centroids)
    refined = get_centroids(sample, label_map)

    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], c=flatten_nav(label_map), cmap = 'turbo')
    for points, colour in ((centroids, 'grey'), (refined, 'black')):
        plt.scatter(points[:,0], points[:,1], marker='x', c = colour)

    plt.figure()
    plt.imshow(label_map, cmap= 'turbo')
    return refined, label_map, conflict_probs, mesh_groups

def point_in_hull(point, hull, tolerance=1e-12):
    """Return True when ``point`` satisfies every half-space in ``hull.equations``.

    Each row of ``hull.equations`` is ``[normal..., offset]``; the point passes
    a row when ``normal . point + offset <= tolerance``.
    """
    for facet in hull.equations:
        normal, offset = facet[:-1], facet[-1]
        if not (np.dot(normal, point) + offset <= tolerance):
            return False
    return True

def get_centroids(sample, map1):
    """Mean encoded position of each label in ``map1``.

    For every unique value of ``map1`` (ascending), the rows of
    ``sample.encoded_data`` whose flattened map entry equals that value are
    averaged. Returns one row per label.
    """
    flat_labels = flatten_nav(map1)
    encoded = sample.encoded_data
    per_label_means = []
    for label in np.unique(map1):
        members = encoded[np.where(flat_labels == label)]
        per_label_means.append(np.mean(members, axis = 0))
    return np.asarray(per_label_means)

def get_new_centroids(centroids, map1, ssi_thresh = 1):
    """Merge similar centroids, relabel ``map1`` accordingly and recompute centroids.

    A dissimilarity matrix ``(1 - SSI) * 100`` is computed between the patterns
    of all centroid pairs (diagonal fixed at 100). Pairs below ``ssi_thresh``
    become graph edges, and each connected component is merged into one label.
    Labels in the combined map are assigned as: unconnected centroids first
    (``0 .. k-1`` in index order), then one label per connected component
    (``k, k+1, ...``).

    NOTE: reads the notebook-level ``sample`` for the centroid computation and
    the figure.

    Parameters
    ----------
    centroids : array with one centroid per row.
    map1 : label map whose values are indices into ``centroids``.
    ssi_thresh : merge threshold on the ``(1 - SSI) * 100`` scale.

    Returns
    -------
    Array of centroids computed from the combined map by ``get_centroids``.
    """
    centroid_patts = get_terr_patts(centroids)

    centroid_cm = np.zeros((centroids.shape[0], centroids.shape[0]))

    for i in range((centroid_patts.shape[0])):
        for j in range((centroid_patts.shape[0])):
            if i == j:
                # Keep the diagonal far above any threshold (no self edges).
                centroid_cm[i][j] = 100
            else:
                centroid_cm[i][j] = (1- SSI(centroid_patts[i], centroid_patts[j]))*100

    edges = np.asarray(np.where(centroid_cm<ssi_thresh)).T
    nodes = np.unique(edges)

    g = nx.Graph()

    g.add_nodes_from(nodes)
    g.add_edges_from(edges)

    plt.figure()
    nx.draw(g, with_labels= True)

    con_comp = list(nx.connected_components(g))

    all_con_comps = [node for component in con_comp for node in component]
    connected = set(all_con_comps)
    uncon_comps = [idx for idx in range(centroids.shape[0]) if idx not in connected]

    print(all_con_comps, uncon_comps)

    comb_map1 = np.zeros_like(map1)

    for new_label, uc in enumerate(uncon_comps):
        comb_map1[np.where(map1 == uc)] = new_label
    # Merged labels start right after the unconnected ones. The offset is
    # derived from the list length rather than a leftover loop variable, which
    # held a stale value when there were no unconnected centroids.
    first_merged_label = len(uncon_comps)
    for j, cc in enumerate(con_comp):
        for jc in cc:
            comb_map1[np.where(map1 == jc)] = first_merged_label + j

    new_centroids = get_centroids(sample, comb_map1)
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s=5, alpha = 0.1, c = flatten_nav(comb_map1), cmap= 'turbo')
    plt.scatter(new_centroids[:,0], new_centroids[:,1], marker='x', c='red', s = 25)

    return new_centroids

def find_R_closest_centroid(R, new_centroids):
    """Group mesh points ``R`` by the centroid whose pattern they match best.

    Every mesh point's pattern is scored with ``SSI`` against every centroid
    pattern and assigned to the highest-scoring centroid.

    The returned sequence has exactly one entry per centroid, in centroid
    order, so entry ``k`` always belongs to centroid ``k``. A centroid that
    attracts no mesh point gets an empty group. (Previously such centroids
    were dropped, which shifted the indices of all later groups away from
    their centroid indices.)

    Returns
    -------
    numpy array
        A regular array when all groups have the same size, otherwise a 1-D
        object array holding one ``(n_k, ...)`` array per centroid. The object
        array is built explicitly because ``np.asarray`` on ragged input raises
        on recent NumPy versions.
    """
    R_patts = get_terr_patts(R)

    ncp = get_terr_patts(new_centroids)

    # Index of the best-scoring centroid for each mesh point.
    best_centroid = np.asarray(
        [np.argsort([SSI(R_p, incp) for incp in ncp])[-1] for R_p in R_patts]
    )

    centroid_Rs = [R[np.where(best_centroid == cind)] for cind in range(len(ncp))]

    if len({len(group) for group in centroid_Rs}) <= 1:
        return np.asarray(centroid_Rs)

    ragged = np.empty(len(centroid_Rs), dtype = object)
    for cind, group in enumerate(centroid_Rs):
        ragged[cind] = group
    return ragged

def get_map_from_R_boundaries(centroid_Rs, sample, new_centroids):
    """Label every encoded point by the convex hull(s) of the mesh-point groups.

    Steps:

    1. For each group in ``centroid_Rs`` build a convex hull and mark which
       encoded points fall inside it. A group whose hull cannot be built marks
       no points.
    2. Points inside exactly one hull take that hull's index.
    3. Points inside several hulls ("conflicts") take the hull whose centroid
       pattern has the highest ``SSI`` with the point's pattern; the
       normalised scores are kept per point.
    4. Points inside no hull take the index of the best-matching centroid
       pattern overall.
    5. Labels are renumbered to ``0..n-1`` in ascending order and reshaped to
       the first two dimensions of ``sample.raw_data.data``.

    Hull row indices are used as indices into ``new_centroids``.

    Returns
    -------
    (hull_pred_map, conflicting_point_probs)
        The label map and, for each conflict point (in ascending point order),
        a dict ``{hull_index: probability}``.
    """
    hull_labels = []
    for cind in range(len(centroid_Rs)):
        try:
            hull = ConvexHull(centroid_Rs[cind])
            hull_labels.append(np.where([point_in_hull(p,hull) for p in sample.encoded_data], 1, 0))
        except Exception:
            # Best effort: a degenerate or empty group simply claims no points.
            # `Exception` (not a bare except) so Ctrl-C still interrupts this
            # long-running loop.
            hull_labels.append(np.zeros(sample.encoded_data.shape[0]))
    hull_labels = np.asarray(hull_labels)

    # Number of hulls containing each point.
    membership_count = hull_labels.sum(axis = 0)
    conflict_points = np.where(membership_count > 1)[0]

    ncp = get_terr_patts(new_centroids)
    conflict_patts = get_terr_patts(sample.encoded_data[conflict_points])

    conflicting_point_probs = []

    for cpi, cp in enumerate(conflict_points):
        conflicting_cents = np.where(hull_labels[:,cp] ==1)[0]
        cp_patt = conflict_patts[cpi]
        conflicting_ssi = np.asarray([SSI(ncp[ccent], cp_patt) for ccent in conflicting_cents])
        probabilities = conflicting_ssi/np.sum(conflicting_ssi)
        prob_dict = {}
        for i, ccent in enumerate(conflicting_cents):
            prob_dict[ccent] = probabilities[i]
        conflicting_point_probs.append(prob_dict)

    # Non-conflict points: -1 when in no hull, otherwise the single hull that
    # contains them (argmax finds the only 1 in the column). Conflict points
    # are overwritten just below.
    hull_prediction_labels = np.where(
        membership_count == 0, -1, hull_labels.argmax(axis = 0)
    ).astype(float)

    for i, cpi in enumerate(conflict_points):
        probs = conflicting_point_probs[i]
        best_fit = list(probs.keys())[np.asarray(list(probs.values())).argmax()]
        hull_prediction_labels[cpi] = best_fit

    outliers = np.where(hull_prediction_labels ==-1)[0]

    outlier_patts = get_terr_patts(sample.encoded_data[outliers])

    for i, p in enumerate(outlier_patts):
        hull_prediction_labels[outliers[i]] = np.argsort([SSI(p, incp) for incp in ncp])[-1]

    # Renumber the labels that actually occur to 0..n-1.
    hull_prediction_relabelled = np.zeros_like(hull_prediction_labels)
    for i, ind in enumerate(np.unique(hull_prediction_labels)):
        hull_prediction_relabelled[np.where(hull_prediction_labels==ind)]=i

    nav_shape = sample.raw_data.data.shape[:2]

    hull_pred_map = hull_prediction_relabelled.reshape(nav_shape)
    return hull_pred_map, conflicting_point_probs

def merge_centroids(centroids, ssi_thresh = 1):
    """Merge centroids whose patterns are similar into their mean position.

    Pairwise dissimilarity is ``(1 - SSI) * 100`` (diagonal fixed at 100).
    Pairs below ``ssi_thresh`` are linked in a graph; each connected component
    is replaced by the mean of its centroids. The result lists untouched
    centroids first (index order), then one merged centroid per component.

    NOTE: the final figure reads the notebook-level ``sample``.
    """
    centroid_patts = get_terr_patts(centroids)
    n_patts = centroid_patts.shape[0]

    centroid_cm = np.zeros((centroids.shape[0], centroids.shape[0]))
    for i, j in itertools.product(range(n_patts), repeat = 2):
        if i == j:
            centroid_cm[i][j] = 100
            continue
        centroid_cm[i][j] = (1- SSI(centroid_patts[i], centroid_patts[j]))*100

    edges = np.asarray(np.where(centroid_cm<ssi_thresh)).T
    nodes = np.unique(edges)

    g = nx.Graph()
    g.add_nodes_from(nodes)
    g.add_edges_from(edges)

    plt.figure()
    nx.draw(g, with_labels= True)

    con_comp = list(nx.connected_components(g))

    # Every centroid index that takes part in some merge.
    merged_indices = [node for component in con_comp for node in component]

    uncon_comps = list(range(centroids.shape[0]))
    for node in merged_indices:
        uncon_comps.remove(node)

    new_centroids = [centroids[uc] for uc in uncon_comps]
    for component in con_comp:
        members = np.asarray([centroids[jc] for jc in component])
        new_centroids.append(np.mean(members, axis = 0))

    new_centroids = np.asarray(new_centroids)
    plt.figure()
    plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s=5, alpha = 0.1, cmap= 'turbo')
    plt.scatter(new_centroids[:,0], new_centroids[:,1], marker='x', c='red', s = 25)

    return new_centroids

from sklearn.cluster import DBSCAN
import networkx as nx
from scipy.spatial import ConvexHull, convex_hull_plot_2d

def get_split_centroids(new_hull_map, sample, classes= (), eps=0.1, ms = 15, return_all = True):
    """Split selected clusters of ``new_hull_map`` with DBSCAN and return centroids.

    For every label in ``new_hull_map`` the member points are taken from
    ``sample.encoded_data``. Labels listed in ``classes`` are sub-clustered
    with ``DBSCAN(eps, min_samples=ms)``; all other labels are treated as one
    group. One scatter figure is drawn per label.

    Centroids are collected for every label when ``return_all`` is true,
    otherwise only for labels in ``classes``. DBSCAN noise (sub-label -1) is
    skipped unless it is the only sub-label of the cluster, in which case the
    centroid of the whole cluster is kept.

    Changes from the previous version:

    * The noise filter used ``len(np.unique(c)==0)``, which is always truthy,
      so noise centroids were never filtered out. It now tests whether -1 is
      the only sub-label.
    * Membership in ``classes`` is tested with ``in`` instead of a bare
      ``except`` around ``classes.index``, so DBSCAN errors are no longer
      silently swallowed.

    Returns
    -------
    Array with one centroid per row.
    """
    split_centroids = []
    flat_map = flatten_nav(new_hull_map)
    for uind in np.unique(new_hull_map):
        cmembers = sample.encoded_data[np.where(flat_map == uind)]
        include = uind in classes
        if include:
            c = DBSCAN(eps = eps, min_samples =ms).fit_predict(cmembers)
            print('scanned')
        else:
            # Not selected for splitting: the whole cluster is a single group.
            c = np.ones(cmembers.shape[0])*-1
        plt.figure()
        plt.scatter(cmembers[:,0], cmembers[:,1], c=c)
        if not (return_all or include):
            continue
        sub_labels = np.unique(c)
        for cind in sub_labels:
            if cind == -1 and len(sub_labels) > 1:
                # Real sub-clusters exist, so drop the DBSCAN noise group.
                continue
            split_centroids.append(np.mean(cmembers[np.where(c ==cind)], axis = 0))

    return np.asarray(split_centroids)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


2022-05-04 15:17:18.679152: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


Using TensorFlow v2.3.0
Using TensorFlow v2.3.0


### Check the GPU can be found

In [2]:
tf.__version__

'2.3.0'

In [3]:
tf.config.list_physical_devices()

2022-05-04 15:17:37.289830: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2022-05-04 15:17:37.290958: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:3b:00.0 name: Tesla V100-PCIE-32GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 31.75GiB deviceMemoryBandwidth: 836.37GiB/s
2022-05-04 15:17:37.290986: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2022-05-04 15:17:37.299737: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2022-05-04 15:17:37.305970: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2022-05-04 15:17:37.308053: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2022-05-04 15:17:37.312153: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolv

### Select your data set

In [4]:
dp = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/SimulatedData/SimulatedDSA-data.hdf5')

### Set the Model Path

In [5]:
mp = dp.redirect('Final_Models')
mp

Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/SimulatedData/Final_Models')

### Create a ProcessedSample Object

In [6]:
sample = ProcessedSample(dp, 'Test')



In [7]:
# NOTE(review): `pxm` (presumably pyxem) is not imported in any cell visible
# here — confirm an earlier import brings it into scope on a fresh kernel.
sig = pxm.signals.ElectronDiffraction2D(sample.raw_data)

In [8]:
sig.set_scan_calibration(100)
sig.set_diffraction_calibration(0.01)

In [9]:
sig.plot()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [10]:
sample.raw_data.plot()

[########################################] | 100% Completed |  4.0s


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Pre-process the data to speed up later steps (this can be skipped if memory is constrained)

In [11]:
sample.save_ml_manipulation('full_ds', data_manip_lowq, 128)

dask to numpy
dask to numpy done
started data manipulations
resized


## Set the hyperparameters (these can be pulled out of the info dictionary)

In [12]:
hparams= {'KN1':32,'KN2':64,'KN3':128, 'KN4':128, 'KN5':256,'D1':128,'D2':512,'LAT':2,'LR':0.0001, 'B':1}

In [13]:
sample.set_model_data('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/mg28034-1/processing/Models/New','cnn',hparams['LAT'],use_generic_model = True)

### Check the model has built

In [14]:
model = create_vae_model(hparams)
model.summary()

2022-05-04 15:18:06.034907: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-04 15:18:06.045618: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2400000000 Hz
2022-05-04 15:18:06.049830: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55ad591fe420 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-05-04 15:18:06.049853: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2022-05-04 15:18:06.148023: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55ad5926af80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2022-05-04 15:18:06.

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
functional_1 (Functional)    [(None, 2), (None, 2), (N 3654660   
_________________________________________________________________
functional_3 (Functional)    (None, 128, 128, 1)       3394817   
Total params: 7,049,477
Trainable params: 7,049,477
Non-trainable params: 0
_________________________________________________________________


### Load in the trained weights

best_model = mp.walk('.hdf5')[1]
best_model

In [200]:
model.load_weights('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/SimulatedData/FullModel/chk-189-4.37663e+01.hdf5')

### You now need to set the model to the sample 

In [201]:
sample.set_model(model)

### You can encode the data you pre-processed earlier (otherwise encoding defaults to `sample.raw_data`)

In [202]:
sample.encode('vae',input_data_tag='full_ds', bn= 16)

In [203]:
sample.raw_data

<LazySignal2D, title: , dimensions: (296, 295|128, 128)>

### You can inspect the reconstructed image and compare it to the raw data

In [204]:
dim0_info = (0, np.floor(sample.encoded_data[:,0].min()), np.ceil(sample.encoded_data[:,0].max()),200)
dim1_info = (1, np.floor(sample.encoded_data[:,1].min()), np.ceil(sample.encoded_data[:,1].max()),200)

In [205]:
dim0_info

(0, -5.0, 5.0, 200)

In [206]:
sample.chart_terrain(dim0_info, dim1_info)

(200, 200, 128, 128)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Sample from KDE first derivative

In [207]:
density_approx = 2
bw = 0.025
n_sample_points = 4000
n_bkg_points = 0
n_bulk_samples = 1000
density_cutoff = -8
sample_grid_res = 200
gn = 0.1

D = sample.encoded_data.copy()

np.random.shuffle(D)

D.shape

(87320, 2)

In [208]:
plt.figure()
plt.scatter(D[:,0], D[:,1], s =1)

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f4eb4785190>

In [209]:
from sklearn.neighbors import KernelDensity
kde = KernelDensity(kernel='gaussian', bandwidth=bw).fit(D)

In [210]:
xgrid, ygrid = sample.terrain_grid

In [211]:
X,Y = np.meshgrid(xgrid, ygrid)
xy = np.vstack((X.ravel(), Y.ravel())).T

xy.shape

Z = kde.score_samples(xy).reshape(X.shape)

dY, dX = np.gradient(Z)

In [212]:
plt.figure()
plt.imshow(Z)

plt.figure()
plt.imshow(np.hypot(dY,dX)*np.where(Z < density_cutoff, 0, 1)*np.where(Z > (density_cutoff+3), 0, 1))

plt.figure()
plt.contourf(X,Y,Z, levels = np.linspace(Z.min(), Z.max(), 50))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.contour.QuadContourSet at 0x7f4eb4641d50>

In [213]:
R = get_density_gradient_net(sample.encoded_data.copy(), n_sample_points, density_cutoff, n_bkg_points,n_bulk_samples, density_approx, sample_grid_res,bw, gn)

-5.0 5.0 -3.0 3.0


In [214]:
plt.figure()
plt.scatter(sample.encoded_data[:,0], sample.encoded_data[:,1], s = 10, alpha = 0.3)
plt.scatter(R[:,0], R[:,1], s = 5, alpha = 1 )

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f515e808310>

In [215]:
density_cutoff = -6

In [216]:
allowed_centroid_mask = np.where(Z < density_cutoff, 0, 1)

In [217]:
plt.figure()
plt.imshow(allowed_centroid_mask)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f515674ed50>

In [218]:
bounded_locs = (np.where(allowed_centroid_mask==1)[1], np.where(allowed_centroid_mask==1)[0])

In [219]:
bounded_regions = sample.terrain_signal.data[bounded_locs]

In [220]:
norm_bounded_regions = bounded_regions/np.max(bounded_regions, axis = (1,2))[:,None,None]

In [221]:
cent = pxm.signals.ElectronDiffraction2D(bounded_regions)

In [222]:
cent.decomposition(True, algorithm='NMF', output_dimension=16)

Decomposition info:
  normalize_poissonian_noise=True
  algorithm=NMF
  output_dimension=16
  centre=None
scikit-learn estimator:
NMF(n_components=16)


In [223]:
cent.plot_decomposition_factors()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [224]:
decomp_facts = cent.get_decomposition_factors().data

In [225]:
norm_dcf = decomp_facts/ np.max(decomp_facts, axis = (1,2))[:,None,None]

In [226]:
br_locs = np.asarray(np.where(allowed_centroid_mask==1)).T

In [227]:
br_locs

array([[ 12, 118],
       [ 13,  81],
       [ 13,  93],
       ...,
       [189,  59],
       [189, 100],
       [191,  71]])

In [228]:
br_locs[0]

array([ 12, 118])

In [229]:
best_matches = [np.argmax([SSI(br, df) for br in norm_bounded_regions]) for df in norm_dcf]

In [230]:
pca_centroid_locs = br_locs[np.asarray(best_matches)]

In [231]:
pca_centroid_locs

array([[ 66, 164],
       [108,  81],
       [179, 113],
       [139,  54],
       [123,  47],
       [101,  94],
       [147,  71],
       [116,  79],
       [157, 147],
       [130,  60],
       [ 92,  47],
       [104,  52],
       [111, 122],
       [ 95, 143],
       [141, 119],
       [ 94,  88]])

In [232]:
pca_centroids = np.concatenate((xgrid[pca_centroid_locs[:,1]][:,None], ygrid[pca_centroid_locs[:,0]][:,None]), axis = 1)

plt.figure()
plt.scatter(D[:,0], D[:,1], s =1)
plt.scatter(pca_centroids[:,0], pca_centroids[:,1], s =10)

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7f4eb442c5d0>

In [233]:
decomp_ind = 10
plt.figure()
plt.imshow(norm_dcf[decomp_ind])
pcax, pcay = pca_centroid_locs[decomp_ind]
plt.figure()
plt.imshow(sample.terrain_signal.data[pcay, pcax])
plt.figure()
plt.imshow(sample.model.decoder(pca_centroids[decomp_ind][None,:]).numpy()[0,:,:,0])

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f4eb42b5e90>

In [234]:
refined_centroids = merge_centroids(pca_centroids, ssi_thresh=1)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [235]:
# Time the full refinement; the recorded run below took roughly 289 s.
# NOTE(review): `time` is only imported locally inside `repeat_tform` in the
# code visible here — confirm it is available at notebook level on a fresh kernel.
t1 = time.time()
hull_r_centroids, new_hull_map, probs, clustered_Rs = from_centroids_refine_clusters_and_centroids(refined_centroids, R, sample)
print(time.time()-t1)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

288.60171341896057


In [236]:
plt.figure()
plt.imshow(new_hull_map, interpolation = 'nearest', cmap = 'turbo')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f515821f450>

In [240]:
max_diff = np.max(sample.raw_data.data, axis = (0,1)).compute()
mean_nav = np.mean(sample.raw_data.data, axis = (2,3)).compute()

In [243]:
plt.figure(figsize = (8,8))
plt.imshow(max_diff, cmap = 'gray', vmax = 30)
plt.xticks([])
plt.yticks([])

<IPython.core.display.Javascript object>

([], [])

In [245]:
plt.figure(figsize = (8,8))
plt.imshow(mean_nav, cmap = 'gray')
plt.xticks([])
plt.yticks([])

<IPython.core.display.Javascript object>

([], [])

In [191]:
np.save('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/SimulatedData/bestmap-amorph-pca.npy', new_hull_map)

In [194]:
new_hull_map = np.load('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/SimulatedData/bestmap-amorph-pca.npy')

In [195]:
sample.all_maps['refine1'] = new_hull_map.astype('int')

sample.get_map_patterns('refine1', method = 'mean', recompute=True)

(512, 512)


<IPython.core.display.Javascript object>

Paper Figures

In [196]:
gtmap = np.load('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/SimulatedData/SimulatedDSA-gtmap.npy')

In [197]:
pattern_locations = ([15,20],[32,50], [90,50],[146,50], [202, 50],[260,50], [80,139], [61,139],[46,139],
                     [210,139], [228, 139],[242,139], [60,243],[150,243],[250,243])
text_color = ['white','black', 'black', 'black', 'black', 'black', 'black','black', 'black', 
              'black', 'black', 'black', 'white','white','white']

In [198]:
plt.figure(figsize = (8,8))
plt.imshow(gtmap, cmap = 'turbo', interpolation = 'nearest')
for i, coord in enumerate(pattern_locations):
    plt.annotate(str(i), coord, c = text_color[i], fontsize = 13)
plt.xticks([])
plt.yticks([])

<IPython.core.display.Javascript object>

([], [])

In [199]:
# 4x4 grid of example diffraction patterns, one panel per annotated location.
fig, ax = plt.subplots(4, 4, figsize = (8,8))
for i, coord in enumerate(pattern_locations):
    # Fill the grid row by row: panel i sits at (i // 4, i % 4).
    ax_c_x, ax_c_y = divmod(i, 4)
    set_ax = ax[ax_c_x, ax_c_y]
    print(coord)
    # coord[1] indexes the first data axis and coord[0] the second;
    # 8 pixels are trimmed from every edge of the pattern.
    pattern = sample.raw_data.data[coord[1], coord[0]][8:-8, 8:-8]
    set_ax.imshow(pattern, cmap = 'gray', vmax = 3)
    set_ax.annotate(str(i), (5,12), c = 'white')
    set_ax.set_frame_on(False)
    set_ax.set_xticks([])
    set_ax.set_yticks([])

# Blank the bottom-right panel, which the loop above never reaches.
ax_c_x, ax_c_y = 3, 3
set_ax = ax[ax_c_x, ax_c_y]
set_ax.set_frame_on(False)
set_ax.set_xticks([])
set_ax.set_yticks([])

<IPython.core.display.Javascript object>

[15, 20]
[32, 50]
[90, 50]
[146, 50]
[202, 50]
[260, 50]
[80, 139]
[61, 139]
[46, 139]
[210, 139]
[228, 139]
[242, 139]
[60, 243]
[150, 243]
[250, 243]


[]

In [19]:
sample.all_maps['gt'] = gtmap

In [89]:
sample.raw_data

<LazySignal2D, title: , dimensions: (296, 295|128, 128)>