In [None]:
from pathlib import Path
import numpy as np
import copy as cp
from gpcam.gp_optimizer import fvGPOptimizer
import matplotlib.pyplot as plt
from numpy.random import default_rng
import time
from typing import *
from tqdm.auto import tqdm

import dataloader as dl 


# Define helper functions

In [None]:
def aqf_multid(x, gp):
    """Upper-confidence-bound style acquisition function summed over all tasks.

    For every task index ``i`` in ``range(gp.input_dim - 1)`` the candidate
    positions ``x`` are extended with a trailing column holding ``i`` and the
    GP posterior is evaluated there. The score is::

        a * sqrt(sum_i v_i(x)) + norm * sum_i f_i(x)

    Args:
        x: 2-D array of candidate positions, one row per candidate (the task
            column is appended here).
        gp: object exposing ``input_dim``, ``posterior_covariance`` (returning
            a dict with key ``"v(x)"``) and ``posterior_mean`` (returning a
            dict with key ``"f(x)"``).

    Returns:
        Array of acquisition scores, one per row of ``x``.

    Raises:
        ValueError: if ``gp.input_dim`` leaves no task index to evaluate.
    """
    # Exploration weight in units of posterior standard deviation. For a
    # Gaussian, 2.0 corresponds to roughly a 95 % interval and 3.0 to ~99.7 %.
    a = 2.0
    norm = 1.0

    n_tasks = gp.input_dim - 1
    if n_tasks < 1:
        raise ValueError("aqf_multid needs gp.input_dim >= 2 (at least one output task)")

    n_points = x.shape[0]
    variance_sum = None
    task_means = []
    for task in range(n_tasks):
        # Build the task-augmented inputs once and reuse them for both queries.
        x_pred = np.c_[x, np.zeros(n_points) + task].reshape(-1, gp.input_dim)
        cov = gp.posterior_covariance(x_pred)["v(x)"]
        # Accumulate into our own array so the GP's returned buffer is never
        # modified in place.
        variance_sum = np.array(cov) if variance_sum is None else variance_sum + cov
        task_means.append(gp.posterior_mean(x_pred)["f(x)"])

    ret = a * np.sqrt(variance_sum)
    for mean in task_means:
        ret += norm * mean

    return ret

def init_gp(
        points,
        values,
        index_set_bounds,
        hyperparameter_bounds,
        hps_guess,
        vp,
        device:Literal['cpu','gpu']='cpu',
    ) -> fvGPOptimizer:
    """Create, initialise and briefly train an ``fvGPOptimizer``.

    The optimizer is hard-wired to a 2-D input space with two outputs.

    Args:
        points: positions handed to ``tell``.
        values: measured values handed to ``tell``.
        index_set_bounds: forwarded as ``input_space_bounds``.
        hyperparameter_bounds: bounds handed to ``train_gp``.
        hps_guess: initial hyperparameters handed to ``init_fvgp``.
        vp: forwarded to ``tell`` as ``value_positions``.
        device: forwarded to ``init_fvgp`` as ``compute_device``.

    Returns:
        The trained optimizer instance.
    """
    optimizer = fvGPOptimizer(
        input_space_dimension=2,
        output_space_dimension=1,
        output_number=2,
        input_space_bounds=index_set_bounds,
    )
    optimizer.tell(points, values, value_positions=vp)
    optimizer.init_fvgp(hps_guess, compute_device=device)
    # Short initial training run: only two iterations.
    optimizer.train_gp(
        hyperparameter_bounds,
        pop_size=20,
        tolerance=1e-6,
        max_iter=2,
    )
    return optimizer

def find_next(
        points,
        values,
        gp,
        hyperparameter_bounds,
        maxnum:int=None,
        dask_client=None,
        method:Literal['global','local']='global',
    ) -> Tuple[int,int]:
    """Update the GP with the measurements, retrain it and ask for one new point.

    Args:
        points: positions measured so far.
        values: values measured so far.
        gp: optimizer exposing ``tell``, ``train_gp`` and ``ask``.
        hyperparameter_bounds: bounds handed to ``train_gp``.
        maxnum: when given, only the last ``maxnum`` points/values are told to
            the GP; ``None`` passes everything.
        dask_client: forwarded to both ``train_gp`` and ``ask``.
        method: optimisation method forwarded to both ``train_gp`` and ``ask``.

    Returns:
        The suggested position, rounded to the nearest integers.
    """
    told_points = points if maxnum is None else points[-maxnum:]
    told_values = values if maxnum is None else values[-maxnum:]
    gp.tell(told_points, told_values)

    train_options = dict(
        hyperparameter_bounds=hyperparameter_bounds,
        pop_size=20,
        method=method,
        tolerance=1e-6,
        max_iter=2,
        dask_client=dask_client,
    )
    gp.train_gp(**train_options)

    ask_options = dict(
        position=None,
        n=1,
        acquisition_function=aqf_multid,
        bounds=None,
        method=method,
        pop_size=20,
        max_iter=20,
        tol=10e-6,  # i.e. 1e-5
        x0=None,
        dask_client=dask_client,
    )
    suggestion = gp.ask(**ask_options)

    # Snap the first suggested position onto the integer grid.
    x, y = np.round(suggestion['x'][0]).astype(int)
    return x, y

def measure_next(a,new) -> np.ndarray:
    """Return the values of every channel of ``a`` at grid position ``new``.

    ``a`` is indexed as ``a[channel, new[0], new[1]]``.
    """
    row, col = new[0], new[1]
    return a[:, row, col]

def run_gp(init_points,init_values,device,method,dask_client=None):
    """Run the GP-driven acquisition loop starting from the initial measurements.

    Besides its arguments, this function reads the notebook-level names
    ``index_set_bounds``, ``hyperparameter_bounds``, ``hps_guess``, ``vp``,
    ``n_iterations`` and ``reduced_data``.

    Args:
        init_points (np.ndarray): positions measured before the loop starts.
        init_values (np.ndarray): values measured at ``init_points``.
        device (str): compute device handed to ``init_gp``.
        method (str): optimisation method handed to ``find_next``.
        dask_client: optional client forwarded to ``find_next``.

    Returns:
        Tuple[list, list, np.ndarray, np.ndarray]: ``(info, times, points,
        values)``. ``times`` holds the wall-clock seconds of each completed
        iteration; ``points``/``values`` are the initial data followed by one
        row per completed iteration. ``info`` has one entry per completed
        iteration; it is a NaN placeholder because the information-gain
        computation is currently disabled (see the commented lines below).
    """
    gp = init_gp(
        init_points,init_values,index_set_bounds,hyperparameter_bounds,hps_guess,vp,
        device=device,
    )
    info = []
    times = []
    # Work on copies so the caller's initial arrays stay untouched.
    points = init_points.copy()
    values = init_values.copy()
    for _ in tqdm(range(n_iterations)):
        try:
            t0 = time.time()
            newpoint = find_next(
                points,
                values,
                gp,
                method=method,
                hyperparameter_bounds=hyperparameter_bounds,
                dask_client=dask_client,
            )
            newval = measure_next(reduced_data,newpoint)
            values = np.append(values,np.array(newval)[None,:],axis=0)
            points = np.append(points,np.array(newpoint)[None,:],axis=0)
            times.append(time.time()-t0)
            # Information gain is disabled; keep `info` aligned with `times`.
            # ig = gp.shannon_information_gain(np.append(points[-1],np.array([0])))['prior entropy']
            # ig += gp.shannon_information_gain(np.append(points[-1],np.array([1])))['prior entropy']
            # info.append(ig)
            info.append(float("nan"))
        except KeyboardInterrupt:
            # Allow a manual stop while keeping everything gathered so far.
            break
    return info, times, points, values


# Load SGM4 data

In [None]:
# Machine-specific absolute Windows path: point this at your local copy of the
# scan file before running the notebook.
source_file =  Path(r"D:\data\SGM4\SmartScan\Z006_35_0.h5")
# `dl` is the local `dataloader` module imported in the first cell.
ldr = dl.load(source_file)

In [None]:
# Convert the loaded scan into the labelled array used by all later cells
# (presumably an xarray.DataArray, judging by the method name - TODO confirm).
xdata = ldr.to_xarray()

In [None]:
# Show the array's repr to inspect its dimensions and coordinates.
xdata

In [None]:
def hist_variance(data):
    """Return the standard deviation of the bin counts of a 100-bin histogram.

    Can be mapped over every scan position with
    ``eval_map(xdata, hist_variance)``.
    """
    counts, _edges = np.histogram(data, bins=100)
    return np.sqrt(np.var(counts))

In [None]:
from scipy.stats import entropy
from scipy.stats import norm
from scipy.stats import gaussian_kde
from scipy.ndimage import gaussian_filter


In [None]:
def eval_map(xdata,func,**kwargs) -> np.ndarray:
    """Apply ``func`` at every position of the first two axes of ``xdata``.

    Args:
        xdata: array-like exposing ``.shape`` and ``.values``; the first two
            axes are treated as the map grid.
        func: callable receiving ``xdata.values[i, j, ...]`` (plus ``kwargs``)
            and returning a scalar.
        **kwargs: forwarded unchanged to ``func``.

    Returns:
        Float array of shape ``xdata.shape[:2]`` holding ``func``'s result for
        every grid position.
    """
    grid_shape = tuple(xdata.shape[:2])
    out = np.zeros(grid_shape)
    # Fetch the underlying array once instead of on every iteration.
    values = xdata.values
    n_positions = int(np.prod(grid_shape))
    # np.ndindex walks the grid in row-major order and needs no extra import;
    # it has no len(), so tqdm is told the total explicitly.
    for i, j in tqdm(np.ndindex(*grid_shape), total=n_positions, leave=False):
        out[i, j] = func(values[i, j, ...], **kwargs)
    return out


In [None]:
def hist_entropy(data):
    """Return the entropy of the 100-bin histogram counts of ``data / data.sum()``."""
    normalised = data / data.sum()
    counts, _edges = np.histogram(normalised, bins=100)
    return entropy(counts)

In [None]:
def hist_gradient_entropy(data):
    """Return the entropy of the histogram of the smoothed gradient of ``data``.

    The input is flattened, entries that are not strictly positive are
    dropped, the 1-D gradient of what remains is Gaussian-smoothed (sigma=3)
    and the entropy of its 100-bin histogram counts is returned.
    """
    tol = 0
    flat = data.ravel()
    positive = flat[flat > tol]
    smoothed = gaussian_filter(np.gradient(positive), sigma=3)
    counts, _edges = np.histogram(smoothed, bins=100)
    return entropy(counts)

In [None]:
fig,ax = plt.subplots(1,3,figsize=(10,5),layout='constrained')
# Inspect the spectrum at scan position (20, 20).
vals = xdata.values[20,20,...]
tol = 0
# Gradient magnitude: square root of the summed squared partial derivatives.
partials = np.gradient(vals)
grad = np.sqrt(np.power(partials[0],2.)+np.power(partials[1],2))
smooth_grad = gaussian_filter(grad,sigma=5)
smooth_vals = gaussian_filter(vals,sigma=5)
# Left: smoothed data, middle: smoothed gradient magnitude, right: its histogram.
ax[0].imshow(smooth_vals)
ax[1].imshow(smooth_grad)
ax[2].hist(smooth_grad.ravel(),bins=100);

In [None]:
# Shape of the data stored at a single scan position.
xdata.values[20,20,...].shape

In [None]:
# Number of entries above `tol`; relies on `vals` and `tol` from the gradient
# preview cell above.
vals[vals>tol].shape

In [None]:
# Channel 0: mean over the 'Kinetic Energy' and 'OrdinateRange' dimensions.
mean_map = xdata.mean(['Kinetic Energy','OrdinateRange'])
# Channel 1: gradient-histogram entropy evaluated at every scan position.
# (Alternatives tried earlier: eval_map(xdata,hist_variance) and
#  xdata.std(['Kinetic Energy','OrdinateRange']).)
gradient_entropy_map = eval_map(xdata,hist_gradient_entropy)
reduced_data = np.array([mean_map, gradient_entropy_map])


In [None]:
# Show both reduced channels side by side, each with its own colorbar.
fig,ax = plt.subplots(1,2,figsize=(10,5))
for axis, channel in zip(ax, reduced_data):
    # Reuse the image handle for the colorbar instead of calling imshow a
    # second time, which would stack a duplicate image on the axes.
    image = axis.imshow(channel)
    fig.colorbar(image, ax=axis)

# Test variance evaluation of focus

In [None]:
# Pick two random scan positions along FSamX / FSamY and show their images
# side by side. The draws use the unseeded global NumPy generator, so the
# positions differ on every run.
idx_1 = np.random.randint(0,xdata.FSamX.size), np.random.randint(0,xdata.FSamY.size)
idx_2 = np.random.randint(0,xdata.FSamX.size), np.random.randint(0,xdata.FSamY.size)
print(idx_1,idx_2)
img1 = xdata.isel(FSamX=idx_1[0],FSamY=idx_1[1]).values
img2 = xdata.isel(FSamX=idx_2[0],FSamY=idx_2[1]).values
fig,ax = plt.subplots(1,2)
ax[0].imshow(img1)
ax[1].imshow(img2)


In [None]:
# Overlay the histograms of the two sum-normalised images on a log scale.
fig,ax = plt.subplots(1,1)
for image, colour, tag in ((img1,'r','1'),(img2,'b','2')):
    ax.hist((image/image.sum()).ravel(),bins=100,color=colour,alpha=.5,label=tag)
ax.set_yscale('log')
ax.legend()
ax.set_ylim(1e2,1e6)

In [None]:
# histogram variance
# NOTE(review): `img_1_hist` is computed but not used in this cell, and the two
# printed numbers are the standard deviations of the raw image values, not of
# histogram counts - confirm whether `hist_variance(img1)` was intended.
img_1_hist = np.histogram(img1,bins=100)
img_1_var = np.sqrt(np.var(img1))
img_2_var = np.sqrt(np.var(img2))
print(img_1_var,img_2_var)

In [None]:
from itertools import product

In [None]:
# Flatten the two reduced channels into one row per grid position:
# `all_values[k]` holds both channel values of position `all_points[k]`.
# The reshape flattens in row-major order and `product` varies the second
# index fastest, so the two arrays are aligned.
all_values = reduced_data.reshape(2,-1).T
all_points = np.array(tuple(product(range(xdata.shape[0]),range(xdata.shape[1]))))

In [None]:
# Sanity check: one point and one value row per grid position (should show True).
len(all_points) == len(all_values) == np.prod(xdata.shape[:2])

In [None]:
# Show both reduced channels side by side, each with its own colorbar.
fig,ax = plt.subplots(1,2,figsize=(10,5))
for axis, channel in zip(ax, reduced_data):
    # Reuse the image handle for the colorbar instead of calling imshow a
    # second time, which would stack a duplicate image on the axes.
    image = axis.imshow(channel)
    fig.colorbar(image, ax=axis)

In [None]:
# Raw data at three hand-picked scan positions.
fig,ax = plt.subplots(1,3,figsize=(15,5))
for axis, (ix, iy) in zip(ax, ((20,20),(30,20),(50,15))):
    axis.imshow(xdata.values[ix,iy,...])

In [None]:
# Grid shape of the map: `reduced_data` without its leading channel axis.
map_shape = reduced_data.shape[1:]

# Initialize the GP

In [None]:

rng = default_rng()
# Draw five distinct starting positions. `rng.choice(n, ...)` samples from
# range(n), so the full length is passed; `len(all_points)-1` would make the
# last grid point impossible to select.
ind = rng.choice(len(all_points), size=5, replace=False)
init_points = all_points[ind]
init_values = all_values[ind]
print(f"init_points: {len(init_points)} {init_points[0]} -> {init_points[-1]}")
print(f"init_values: {len(init_values)} {init_values[0]} -> {init_values[-1]}")
print("x_min ", np.min(init_points[:,0])," x_max ",np.max(init_points[:,0]))
print("y_min ", np.min(init_points[:,1])," y_max ",np.max(init_points[:,1]))
# Only the second value channel is summarised here.
print("val_min ", np.min(init_values[:,1])," val_max ",np.max(init_values[:,1]))
print("length of data set: ", len(init_values))

# The input space is the integer grid of the map, bounds inclusive.
index_set_bounds = np.array([[0,map_shape[0]-1],[0,map_shape[1]-1]])
# One [lower, upper] row per hyperparameter; `hps_guess` supplies the matching
# five starting values.
hyperparameter_bounds = np.array([[0.001,1e9],[1,1000],[1,1000],[1,1000],[1,1000]])
hps_guess = np.array([4.71907062e+06, 4.07439017e+02, 3.59068120e+02,4e2,4e2])

# Value positions: task indices [0] and [1], repeated once per initial point.
z=[[[0],[1]]]
vp = np.array(z*len(ind))



In [None]:
# Mark the initial points on both reduced channels. imshow puts array rows on
# the vertical axis, hence column 1 is plotted as x and column 0 as y.
fig,ax = plt.subplots(1,2,figsize=(10,5))
for axis, channel in zip(ax, (reduced_data[0], reduced_data[1])):
    axis.imshow(channel)
    axis.scatter(init_points[:,1],init_points[:,0],c='r')

In [None]:
# Number of acquisition steps per run; read as a global inside `run_gp`.
n_iterations = 100

# CPU version

In [None]:
# Single CPU run with the global optimiser; the result is unpacked as
# (info, times, points, values).
info_cpu_g, times_cpu_g, points_cpu_g, values_cpu_g = run_gp(init_points,init_values,'cpu','global')


In [None]:
# Overlay both reduced channels (red / blue, half transparent) and mark the
# initial points and the points visited by the CPU/global run.
fig,ax = plt.subplots(1,1,figsize=(8,4),sharex=True,sharey=True,layout='constrained')

ax.imshow(reduced_data[0,...],alpha=0.5,cmap='Reds')
ax.imshow(reduced_data[1,...],alpha=0.5,cmap='Blues')
# Column 1 is plotted as x and column 0 as y to match imshow's orientation.
# The labels are set but no legend is drawn in this cell.
ax.scatter(init_points[:,1],init_points[:,0],c='k',marker='o',label='initial points')
ax.scatter(points_cpu_g[:,1],points_cpu_g[:,0],c='r',marker='x',label='cpu global')
ax.set_title('cpu global')

In [None]:
# Benchmark all four device/method combinations from the same initial points.
# The first line repeats the CPU/global run of the earlier cell and overwrites
# its results.
info_cpu_g, times_cpu_g, points_cpu_g, values_cpu_g = run_gp(init_points,init_values,'cpu','global')
info_cpu_l, times_cpu_l, points_cpu_l, values_cpu_l = run_gp(init_points,init_values,'cpu','local')
info_gpu_g, times_gpu_g, points_gpu_g, values_gpu_g = run_gp(init_points,init_values,'gpu','global')
info_gpu_l, times_gpu_l, points_gpu_l, values_gpu_l = run_gp(init_points,init_values,'gpu','local')

# Plot Results

In [None]:
# (label, per-iteration times, information gain) for each benchmarked run.
benchmark_runs = (
    ('cpu global', times_cpu_g, info_cpu_g),
    ('cpu local', times_cpu_l, info_cpu_l),
    ('gpu global', times_gpu_g, info_gpu_g),
    ('gpu local', times_gpu_l, info_gpu_l),
)

fig,ax = plt.subplots(1,2,figsize=(8,4))

# Left panel: time spent in each iteration.
for run_label, run_times, _ in benchmark_runs:
    ax[0].plot(run_times,label=run_label)
ax[0].legend()
ax[0].set_title('time per iteration')
ax[0].set_xlabel('iteration')
ax[0].set_ylabel('time [s]')

# Right panel: information gain against elapsed time.
# reset colors
ax[1].set_prop_cycle(None)
for run_label, run_times, run_info in benchmark_runs:
    ax[1].plot(np.cumsum(run_times),run_info,label=run_label)
ax[1].legend()
ax[1].set_xlabel('time [s]')
ax[1].set_ylabel('information gain')
ax[1].set_title('cumulative time vs. information gain')


In [None]:
# One panel per device/method combination, each showing the two reduced
# channels as a red/blue overlay with the visited points on top.
fig,axes = plt.subplots(2,2,figsize=(8,4),sharex=True,sharey=True,layout='constrained')
for ax in axes.flatten():
    # `reduced_data` already has the map shape, so no reshape is needed.
    ax.imshow(reduced_data[0],alpha=0.5,cmap='Reds')
    ax.imshow(reduced_data[1],alpha=0.5,cmap='Blues')
# Column 1 is plotted as x and column 0 as y to match imshow's orientation.
axes[0,0].scatter(init_points[:,1],init_points[:,0],c='k',marker='x',label='initial points')
axes[0,0].scatter(points_cpu_g[:,1],points_cpu_g[:,0],c='r',marker='x',label='cpu global')
axes[0,0].set_title('cpu global')

axes[0,1].scatter(points_cpu_l[:,1],points_cpu_l[:,0],c='b',marker='x',label='cpu local')
axes[0,1].set_title('cpu local')

axes[1,0].scatter(points_gpu_g[:,1],points_gpu_g[:,0],c='r',marker='x',label='gpu global')
axes[1,0].set_title('gpu global')

axes[1,1].scatter(points_gpu_l[:,1],points_gpu_l[:,0],c='b',marker='x',label='gpu local')
axes[1,1].set_title('gpu local')

    
# ax[0,0].imshow(values_cpu_g[:,0].reshape(100*3,250*3))

# Implement dask client

In [None]:
# import dask
# import dask.array as da
# import dask.dataframe as dd
# dask client/server
from dask.distributed import Client, LocalCluster


In [None]:
# Start a local dask cluster with four single-threaded workers.
# NOTE(review): `Client` is imported above but never instantiated; later cells
# pass this cluster object itself as `dask_client` - confirm that gpCAM accepts
# a LocalCluster there rather than a `Client` connected to it.
dask_server = LocalCluster(n_workers=4,threads_per_worker=1)

In [None]:
# Display the cluster summary.
dask_server

In [None]:
# CPU/global run with `dask_server` passed as the dask client.
info_cpu_g_dask, times_cpu_g_dask, points_cpu_g_dask, values_cpu_g_dask = run_gp(init_points,init_values,'cpu','global',dask_client=dask_server)


In [None]:
# Same call as the previous dask run, stored under the `_8` suffix.
# NOTE(review): presumably the suffix is a worker count set by re-creating
# `dask_server` by hand between runs - as written, the same cluster is used.
info_cpu_g_dask_8, times_cpu_g_dask_8, points_cpu_g_dask_8, values_cpu_g_dask_8 = run_gp(init_points,init_values,'cpu','global',dask_client=dask_server)


In [None]:
# Same call again, stored under the `_16` suffix.
# NOTE(review): presumably the suffix is a worker count set by re-creating
# `dask_server` by hand between runs - as written, the same cluster is used.
info_cpu_g_dask_16, times_cpu_g_dask_16, points_cpu_g_dask_16, values_cpu_g_dask_16 = run_gp(init_points,init_values,'cpu','global',dask_client=dask_server)


In [None]:
# Same call again, stored under the `_1` suffix.
# NOTE(review): presumably the suffix is a worker count set by re-creating
# `dask_server` by hand between runs - as written, the same cluster is used.
info_cpu_g_dask_1, times_cpu_g_dask_1, points_cpu_g_dask_1, values_cpu_g_dask_1 = run_gp(init_points,init_values,'cpu','global',dask_client=dask_server)

In [None]:
# GPU/global run with the dask client. Stored under `gpu` names so it does not
# overwrite the CPU results that are later plotted as 'cpu global dask 1'.
info_gpu_g_dask_1, times_gpu_g_dask_1, points_gpu_g_dask_1, values_gpu_g_dask_1 = run_gp(init_points,init_values,'gpu','global',dask_client=dask_server)


In [None]:
# One (label, per-iteration times, information gain) entry per run, in the
# order used for plotting.
labelled_runs = [
    ('cpu global', times_cpu_g, info_cpu_g),
    ('cpu global dask', times_cpu_g_dask, info_cpu_g_dask),
    ('cpu global dask 8', times_cpu_g_dask_8, info_cpu_g_dask_8),
    ('cpu global dask 16', times_cpu_g_dask_16, info_cpu_g_dask_16),
    ('cpu global dask 1', times_cpu_g_dask_1, info_cpu_g_dask_1),
    ('cpu local', times_cpu_l, info_cpu_l),
    ('gpu global', times_gpu_g, info_gpu_g),
    ('gpu local', times_gpu_l, info_gpu_l),
]
# Split into the three parallel lists used by the plotting cell.
times = [run_times for _, run_times, _ in labelled_runs]
infos = [run_info for _, _, run_info in labelled_runs]
labels = [run_label for run_label, _, _ in labelled_runs]

In [None]:
# Switch matplotlib to the interactive widget backend for the following figures.
%matplotlib widget

In [None]:
from scipy.signal import savgol_filter, butter, filtfilt

In [None]:
fig,ax = plt.subplots(1,2,figsize=(8,4))
# Low-pass Butterworth coefficients are the same for every run: compute once.
lp_b, lp_a = butter(3,.3)
# Only the first five entries of `times` / `infos` / `labels` are compared.
for time_,info,label in zip(times[:5],infos[:5],labels[:5]):
    raw_line = ax[0].plot(time_,'.',alpha=0.5)
    run_color = raw_line[0].get_color()
    # Zero-phase low-pass filter of the per-iteration times, drawn in the same
    # colour as the raw samples.
    time_filt = filtfilt(lp_b,lp_a,time_)
    ax[0].plot(time_filt,label=label,c=run_color,alpha=1)
    ax[1].plot(np.cumsum(time_),info,'.',alpha=.5)
    ax[1].plot(np.cumsum(time_filt),info,label=label,c=run_color,alpha=1)

ax[0].legend(fontsize='x-small')
ax[0].set_title('time per iteration')
ax[0].set_xlabel('iteration')
ax[0].set_ylabel('time [s]')
ax[1].legend()
ax[1].set_xlabel('time [s]')
ax[1].set_ylabel('information gain')
ax[1].set_title('cumulative time vs. information gain')
