# Generate all the consistency maps

### First run this cell

Import libraries

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib notebook
%load_ext autoreload
%autoreload 2


#load some packages in
import numpy as np
import matplotlib.pyplot as plt
import random as python_random
from numba import njit
import hyperspy.api as hs
import json
import itertools
from skimage.metrics import structural_similarity as SSI
from stemutils.io import Path
import palettable
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec

Define the helper functions

In [None]:
def get_map_label_df(map1):
    '''
    Decompose a label map into one binary (0/1) mask per unique label.

    The masks are stacked along a new leading axis in sorted label order
    (the order produced by np.unique), so row k corresponds to the k-th
    smallest label, not necessarily to the label value k.
    '''
    masks = []
    for label in np.unique(map1):
        masks.append(np.where(map1 == label, 1, 0))
    return np.asarray(masks)

def get_cluster_label_overlap(map_pair):
    '''
    Measure how much each region of one decomposition overlaps each region of another.

    ``map_pair`` is a pair ``(first, second)`` of stacked binary masks, shaped
    (N, ...) and (M, ...). Entry [i, j] of the returned (N, M) float array is the
    number of pixels shared by first[i] and second[j], divided by the number of
    pixels in first[i], i.e. the fraction of region i covered by region j.
    '''
    first, second = map_pair
    overlap = np.zeros((first.shape[0], second.shape[0]))
    for row, mask_a in enumerate(first):
        # Area of region `row`; the same denominator for the whole row.
        area_a = np.sum(mask_a)
        for col, mask_b in enumerate(second):
            overlap[row, col] = np.sum(mask_a * mask_b) / area_a
    return overlap

def find_map_label(pos, map1):
    '''Return ``map1[pos]``: the label (or, for a 2-D map and an integer ``pos``, the row of labels) at ``pos``.'''
    label = map1[pos]
    return label

def get_confidence_from_maps(maps):
    '''
    Compute a per-pixel consistency ("confidence") map from several label maps.

    For every ordered pair of maps (a, b) and every pixel p, take the region of map a
    that contains p (all pixels sharing p's label in a) and measure the fraction of
    that region that also carries p's label in map b. The confidence at p is the mean
    of this fraction over all ordered pairs, so it lies in [0, 1] and equals 1 where
    every pair of maps agrees exactly about p's region.

    Parameters
    ----------
    maps : sequence of integer label arrays, all with the same shape
        Label values may be arbitrary (negative or non-contiguous labels are fine).

    Returns
    -------
    confidence : float32 array with the shape of maps[0]

    Raises
    ------
    ValueError
        If fewer than two maps are given, or the maps do not all share one shape.
    '''
    maps = [np.asarray(m) for m in maps]
    if len(maps) < 2:
        raise ValueError('get_confidence_from_maps needs at least two maps to compare')
    shape = maps[0].shape
    if any(m.shape != shape for m in maps):
        raise ValueError('all maps must have the same shape')

    # Re-encode each map's labels as contiguous indices 0..K-1 (in sorted label order) so
    # they can index the overlap tables directly. Indexing with the raw label values picks
    # the wrong entries (or fails) for negative or non-contiguous labels.
    encoded = []
    for m in maps:
        uniq, inverse = np.unique(m, return_inverse=True)
        # reshape(-1): the shape of np.unique's inverse differs between NumPy versions.
        encoded.append((inverse.reshape(-1), uniq.size))

    pairs = list(itertools.permutations(range(len(maps)), 2))
    total = np.zeros(int(np.prod(shape)), dtype='float64')
    for a, b in pairs:
        idx_a, n_a = encoded[a]
        idx_b, n_b = encoded[b]
        # Contingency table: counts[i, j] = pixels with label i in map a and label j in map b.
        counts = np.bincount(idx_a * n_b + idx_b, minlength=n_a * n_b).reshape(n_a, n_b)
        # Row-normalise: fraction of each region of map a covered by each region of map b.
        # Every row sum is >= 1 because every label occurs at least once.
        overlap = counts / counts.sum(axis=1, keepdims=True)
        total += overlap[idx_a, idx_b]
    return (total / len(pairs)).reshape(shape).astype('float32')

def flatten_nav(sig):
    '''
    Merge the first two axes of ``sig`` into a single leading axis.

    Shape (A, B, *rest) becomes (A * B, *rest); the remaining axes are kept as-is.
    (Per the name, the first two axes are presumably the navigation/scan axes.)
    '''
    nav_size = sig.shape[0] * sig.shape[1]
    return sig.reshape((nav_size, *sig.shape[2:]))


def plot_map_confs(gtmap, conf, save_root = None, flat_data = None, conf_cmap = None, **kwargs):
    '''
    Plot the confidence sub-regions of every labelled region in a domain map.

    For each unique label in ``gtmap`` the confidence map is masked to that label and
    rounded to one decimal place. Each distinct non-zero rounded value defines a
    sub-region, whose mean pattern is taken from ``flat_data``. Labels with at least two
    sub-regions get a figure: the masked confidence image on top (titled with the mean
    pairwise SSIM between the sub-region patterns) and a grid of the patterns below.

    Parameters
    ----------
    gtmap : array
        Label map; one figure per unique label that has >= 2 sub-regions.
    conf : array with the same shape as gtmap
        Per-pixel confidence, e.g. from get_confidence_from_maps.
    save_root : path-like or None
        If not None, each figure is saved as '{save_root}/{label}-conf_regions.jpg'.
    flat_data : array or None
        Data whose first axis is the flattened pixel index (see flatten_nav). Defaults to
        the notebook-level global ``frd``, which earlier versions always used implicitly.
    conf_cmap : colormap or None
        Colormap for the confidence image. Defaults to the notebook-level global ``cmap``.
    **kwargs
        Forwarded to ``imshow`` for the pattern panels (e.g. ``vmax``).

    Returns
    -------
    conf_r : list of arrays
        The masked, rounded confidence image of each plotted label.
    r_patts : list of arrays
        The stacked sub-region mean patterns of each plotted label.
    '''
    # Fall back to the notebook globals this function previously relied on implicitly.
    if flat_data is None:
        flat_data = frd
    if conf_cmap is None:
        conf_cmap = cmap

    conf_r, r_patts = [], []
    for uind in np.unique(gtmap):
        # Confidence restricted to this label (0 elsewhere), binned to one decimal place.
        rconf = np.round(np.where(gtmap == uind, 1, 0) * conf, 1)
        frconf = flatten_nav(rconf)
        # Mean pattern of each non-zero confidence bin. NOTE: pixels inside the label whose
        # confidence rounds to 0.0 cannot be told apart from pixels outside it and are skipped.
        patterns = np.asarray([flat_data[np.where(frconf == confind)].mean(axis = 0)
                               for confind in np.unique(frconf) if confind != 0])

        if patterns.shape[0] < 2:
            # Fewer than two sub-regions: nothing to compare or plot.
            continue

        conf_r.append(rconf)
        r_patts.append(patterns)

        # Mean pairwise SSIM between the sub-region patterns.
        # NOTE(review): recent scikit-image versions require `data_range` for floating-point
        # images and these mean patterns are float -- confirm against the installed version.
        class_ssi = np.mean([SSI(patterns[a], patterns[b])
                             for a, b in itertools.combinations(range(patterns.shape[0]), 2)])

        # Square grid just big enough for all patterns, below an equally tall confidence panel.
        pgs = int(np.ceil(np.sqrt(patterns.shape[0])))
        fig = plt.figure()
        gs = GridSpec(pgs*2, pgs, figure = fig)
        ax0 = fig.add_subplot(gs[:pgs,:])
        ax0.imshow(rconf, cmap = conf_cmap, interpolation = 'nearest')
        ax0.set_xticks([])
        ax0.set_yticks([])
        ax0.set_title(str(class_ssi))
        for ipatt, patt in enumerate(patterns):
            gsx, gsy = divmod(ipatt, pgs)
            axp = fig.add_subplot(gs[pgs+gsx, gsy])
            axp.imshow(patt, cmap = 'gray', **kwargs)
            axp.set_xticks([])
            axp.set_yticks([])
        if save_root is not None:
            fig.savefig(f'{save_root}/{uind}-conf_regions.jpg')
    return conf_r, r_patts

def eval_map_conf(gtmap, conf, flat_data = None):
    '''
    Score how similar the confidence sub-regions of each labelled region are.

    For each unique label in ``gtmap`` the confidence map is masked to that label and
    rounded to one decimal place. Each distinct non-zero value defines a sub-region,
    summarised by the element-wise max of its patterns in ``flat_data``. A label scores the
    mean pairwise SSIM between its sub-region patterns, or 1.0 if it has fewer than two.

    Parameters
    ----------
    gtmap : array
        Label map.
    conf : array with the same shape as gtmap
        Per-pixel confidence, e.g. from get_confidence_from_maps.
    flat_data : array or None
        Data whose first axis is the flattened pixel index (see flatten_nav). Defaults to
        the notebook-level global ``frd``, which earlier versions always used implicitly.

    Returns
    -------
    float
        Mean of the per-label scores.
    '''
    if flat_data is None:
        flat_data = frd

    class_scores = []
    for uind in np.unique(gtmap):
        rconf = np.round(np.where(gtmap == uind, 1, 0) * conf, 1)
        frconf = flatten_nav(rconf)
        # NOTE(review): uses the per-bin *max* pattern, whereas plot_map_confs uses the
        # mean -- confirm which is intended.
        patterns = np.asarray([flat_data[np.where(frconf == confind)].max(axis = 0)
                               for confind in np.unique(frconf) if confind != 0])

        if patterns.shape[0] > 1:
            class_scores.append(np.mean([SSI(patterns[a], patterns[b])
                                         for a, b in itertools.combinations(range(patterns.shape[0]), 2)]))
        else:
            # A single (or no) confidence level is treated as perfectly self-consistent.
            class_scores.append(1.0)

    return np.mean(class_scores)

Build a list (`dss`) of all the datasets for which you want to generate consistency maps

In [None]:
# Root directory whose numerically-named entries are treated as datasets (site-specific path).
ds_root = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/mg28034-1/processing/Merlin/Calibrated/O3_pure/')

In [None]:
# Keep only the entries whose final path component is purely numeric once spaces are removed.
dss = [entry for entry in ds_root.ls()
       if str(entry.parts[-1]).replace(' ', '').isnumeric()]

In [None]:
# Inspect the selected dataset directories
dss

First, test the workflow on a single dataset

In [None]:
# Pick a single dataset to trial the workflow on (the batch loop at the end processes dss[3:]).
ds = dss[2]

Find all the map files that you want to use to generate the consistency map

In [None]:
# Path to this dataset's 'Refined_N_components' folder (stemutils Path.redirect; the meaning of the 0 argument is not shown here).
domain_maps_root = ds.redirect('Refined_N_components',0)

In [None]:
# Collect the 'mapdata' files found up to two levels below domain_maps_root.
domain_map_paths = domain_maps_root.walk('/mapdata', max_depth=2)

In [None]:
# Inspect the map files that were found
domain_map_paths

Find the path to the raw data as well

In [None]:
# Raw data file named '<last space-separated token of the dataset folder name>.hdf5';
# the 'binned' and max_depth arguments are passed to stemutils Path.walk (semantics not shown here).
dp = ds.walk(f"{ds.parts[-1].split(' ')[-1]}.hdf5",'binned', max_depth=1)[0]

In [None]:
# Inspect the raw data path
dp

Load all the maps, then generate the consistency (confidence) map

In [None]:
# Load every domain map as int8 labels.
# NOTE(review): int8 wraps label values above 127, and allow_pickle=True can execute code
# embedded in a malicious .npy file -- only use this on trusted data.
maps = [np.load(p, allow_pickle=True).astype('int8') for p in domain_map_paths]

In [None]:
# Per-pixel agreement between all pairs of maps, in [0, 1]
conf = get_confidence_from_maps(maps)

Visualise

In [None]:
# Colormap for the consistency map, whose values lie in [0, 1] (plotted with vmin=0, vmax=1):
# palettable's YlGn_9 sampled at 256 points, with the lowest entry forced to opaque black
# so pixels at (or very near) 0 stand out. plot_map_confs also reads this global `cmap`.
colors1 = palettable.colorbrewer.sequential.YlGn_9.mpl_colormap(np.linspace(0, 1, 256))
colors1[0] = [0.,0.,0.,1.]
# generating a smoothly-varying LinearSegmentedColormap
cmap = mcolors.LinearSegmentedColormap.from_list('colormap', colors1)

conf_fig = plt.figure(figsize = (8,8))
plt.imshow(conf, cmap= cmap, interpolation = 'nearest', vmin=0, vmax =1 )
plt.colorbar()
plt.xticks([])
plt.yticks([])

Save the figure to the desired path

In [None]:
# Preview the output path for the consistency-map figure
domain_maps_root.redirect('consistency_map.jpg',0)

In [None]:
# Save the consistency-map figure as 'consistency_map.jpg' (path built from domain_maps_root)
conf_fig.savefig(domain_maps_root.redirect('consistency_map.jpg',0))

Load in the diffraction data

In [None]:
# Load the raw dataset with HyperSpy
data = hs.load(dp)

In [None]:
# Merge the two leading axes of the raw data; plot_map_confs and eval_map_conf read this global `frd`.
frd = flatten_nav(data.data)

Find all the map directories; for each one, visualise and save the consistency regions within each map region

In [None]:
# Numerically-named sub-directories (component counts) of the folder returned by
# domain_map_paths[0].redirect('', 2) -- presumably Refined_N_components; confirm redirect semantics.
ncomps = [int(x.parts[-1]) for x in domain_map_paths[0].redirect('',2).ls() if x.is_dir()]

In [None]:
# For each number of components, plot the confidence sub-regions of every region of that
# map and save them under <comp_dir>/ConsistencyRegions (plot_map_confs reads the globals
# `frd` and `cmap`).
for ncomp in ncomps:
    print(ncomp)
    comp_dir = domain_maps_root.redirect(f'{ncomp}/',0)
    gtmap = np.load(comp_dir.redirect('mapdata.npy',0))
    consistency_dir = comp_dir.redirect('ConsistencyRegions',0)
    consistency_dir.mk()
    # vmax=3 is forwarded to imshow for the pattern panels; conf_r/r_patts keep only the last ncomp's results.
    conf_r, r_patts = plot_map_confs(gtmap,conf, save_root= consistency_dir, vmax = 3)

In [None]:
# Close all open figures to release their memory
plt.close('all')

If that all worked, repeat the workflow in a loop over the remaining datasets (`dss[3:]`)

In [None]:
# Colormap for the consistency maps (values in [0, 1]): YlGn_9 sampled at 256 points with the
# lowest entry forced to opaque black. It does not depend on the dataset, so build it once.
# plot_map_confs reads this global `cmap`.
colors1 = palettable.colorbrewer.sequential.YlGn_9.mpl_colormap(np.linspace(0, 1, 256))
colors1[0] = [0., 0., 0., 1.]
cmap = mcolors.LinearSegmentedColormap.from_list('colormap', colors1)

for ds in dss[3:]:
    domain_maps_root = ds.redirect('Refined_N_components', 0)
    domain_map_paths = domain_maps_root.walk('/mapdata', max_depth=2)
    dp = ds.walk(f"{ds.parts[-1].split(' ')[-1]}.hdf5", 'binned', max_depth=1)[0]

    # Consistency map from all the domain maps of this dataset.
    maps = [np.load(p, allow_pickle=True).astype('int8') for p in domain_map_paths]
    conf = get_confidence_from_maps(maps)

    conf_fig = plt.figure(figsize=(8, 8))
    plt.imshow(conf, cmap=cmap, interpolation='nearest', vmin=0, vmax=1)
    plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    conf_fig.savefig(domain_maps_root.redirect('consistency_map.jpg', 0))

    # plot_map_confs reads the flattened data from the global `frd`.
    data = hs.load(dp)
    frd = flatten_nav(data.data)

    ncomps = [int(x.parts[-1]) for x in domain_map_paths[0].redirect('', 2).ls() if x.is_dir()]
    for ncomp in ncomps:
        print(ncomp)
        comp_dir = domain_maps_root.redirect(f'{ncomp}/', 0)
        gtmap = np.load(comp_dir.redirect('mapdata.npy', 0))
        consistency_dir = comp_dir.redirect('ConsistencyRegions', 0)
        consistency_dir.mk()
        conf_r, r_patts = plot_map_confs(gtmap, conf, save_root=consistency_dir, vmax=3)

    plt.close('all')
    # frd is normally a reshape *view* of data.data, so deleting `data` alone would keep the
    # whole dataset alive until frd is reassigned in the next iteration (two datasets in
    # memory at once while the next one loads). Drop both references.
    del data, frd