## Example: Learning the state transitions of a triple-dot array.

In this notebook, we show how our algorithm can be used on a triple-dot array to discover the transitions of a given state. We choose this example because it can still be easily visualized in 3D.


In [8]:
#import our simulation
from simulator import sim_NxM, Simulator

#for the algorithm
import numpy as np
from fit_zero_boundary import learn_zero_polytope
from fit_convex_polytope import learn_convex_polytope, generate_transitions_for_state, sample_boundary_points

#for plotting
import plotly.graph_objects as go
from scipy.spatial import HalfspaceIntersection
from scipy.optimize import linprog

#some helper code:

#function that finds a point inside a given polytope
def find_feasible_point(halfspaces):
    """Return a point strictly inside the polytope ``{x : A x + b <= 0}``.

    ``halfspaces`` uses the stacked layout expected by
    ``scipy.spatial.HalfspaceIntersection``: row ``[A_i, b_i]`` encodes the
    constraint ``A_i @ x + b_i <= 0``.

    The returned point is the Chebyshev center (the center of the largest
    ball contained in the polytope). It is found by a linear program over the
    variables ``(x, r)``: maximize ``r`` subject to
    ``A_i @ x + ||A_i|| * r <= -b_i`` for every row ``i``.

    Parameters
    ----------
    halfspaces : array_like of shape (n_halfspaces, dim + 1)
        Stacked ``[A, b]`` rows.

    Returns
    -------
    ndarray of shape (dim,)
        The Chebyshev center of the polytope.

    Raises
    ------
    ValueError
        If the linear program does not solve (e.g. arbitrarily large balls
        fit, so the radius is unbounded), or if the optimal radius is not
        positive. A non-positive radius means the polytope is empty or has
        no interior, so no strictly interior point exists.
    """
    halfspaces = np.asarray(halfspaces, dtype=float)
    normals = halfspaces[:, :-1]
    offsets = halfspaces[:, -1]
    # Scaling r by ||A_i|| makes r the Euclidean distance to each hyperplane.
    norm_vector = np.linalg.norm(normals, axis=1, keepdims=True)
    # linprog minimizes, so minimizing -r maximizes the radius.
    c = np.zeros(halfspaces.shape[1])
    c[-1] = -1
    A = np.hstack((normals, norm_vector))
    res = linprog(c, A_ub=A, b_ub=-offsets, bounds=(None, None))
    if not res.success:
        raise ValueError(f"could not find a point inside the polytope: {res.message}")
    radius = res.x[-1]
    # r is not bounded below, so an empty/degenerate polytope still yields an
    # "optimal" solution with r <= 0; that x is not strictly inside.
    if radius <= 0:
        raise ValueError(f"polytope has no interior point (Chebyshev radius {radius:g})")
    return res.x[:-1]

def plot_facet(halfspaces, labels = None, points = None):
    """Render the 3D polytope ``{x : A x + b <= 0}`` with plotly and show it.

    ``halfspaces`` uses the ``HalfspaceIntersection`` row layout: row
    ``[A_i, b_i]`` encodes ``A_i @ x + b_i <= 0``. The polytope must be
    bounded and have a non-empty interior, because its vertices are computed
    from an interior point returned by ``find_feasible_point``.

    Coordinates are plotted in reverse order: component 2 on the x axis
    (V3), component 1 on the y axis (V2) and component 0 on the z axis (V1).

    Parameters
    ----------
    halfspaces : ndarray of shape (n_halfspaces, 4)
        Stacked ``[A, b]`` rows of the polytope.
    labels : sequence, optional
        One label per row of ``halfspaces``. If given, each facet gets a
        marker at the mean of its vertices, a short segment along its outward
        normal ``A_i``, and the label text just beyond that segment.
    points : ndarray of shape (n_points, 3), optional
        Extra points drawn as red markers (e.g. measurement points).
    """
    startpoint = find_feasible_point(halfspaces)
    vertices = HalfspaceIntersection(halfspaces, startpoint).intersections

    fig = go.Figure()
    # Solid body of the polytope; alphahull=0 meshes the convex hull of the vertices.
    fig.add_trace(go.Mesh3d(x=vertices[:, 2],
                            y=vertices[:, 1],
                            z=vertices[:, 0],
                            color="yellow",
                            opacity=1.0,
                            alphahull=0))
    for i in range(halfspaces.shape[0]):
        Ai = halfspaces[i, :-1]
        bi = halfspaces[i, -1]
        # Vertices lying (up to tolerance) on hyperplane i span facet i.
        vi = vertices[np.abs(vertices @ Ai + bi) < 1.e-5]
        if vi.shape[0] == 0:
            # Redundant halfspace: it touches no vertex, so there is no facet
            # to annotate or outline (np.mean would otherwise yield NaN).
            continue
        if labels is not None:
            _add_facet_label(fig, vi, Ai, labels[i])
        _add_facet_edges(fig, halfspaces, i, vi)

    if points is not None:
        fig.add_trace(go.Scatter3d(x=points[:, 2],
                                   y=points[:, 1],
                                   z=points[:, 0],
                                   mode='markers', marker=dict(color='red', size=3)))
    fig.update_layout(scene=dict(xaxis_title=r'V<sub>3</sub>', yaxis_title=r'V<sub>2</sub>', zaxis_title=r'V<sub>1</sub>'))
    fig.show()


def _add_facet_label(fig, vi, Ai, label):
    """Annotate one facet: vertex-mean marker, outward normal segment and label text."""
    mean = np.mean(vi, axis=0)
    unit_normal = Ai / np.linalg.norm(Ai)
    # A_i points out of {A x + b <= 0}, so these positions lie outside the polytope.
    text_pos = mean + 0.07 * unit_normal
    arrow_end = mean + 0.05 * unit_normal
    fig.add_trace(go.Scatter3d(x=mean[2:3],
                               y=mean[1:2],
                               z=mean[0:1],
                               showlegend=False,
                               mode="markers",
                               marker=dict(color='black', size=3)))
    fig.add_trace(go.Scatter3d(x=[mean[2], arrow_end[2]],
                               y=[mean[1], arrow_end[1]],
                               z=[mean[0], arrow_end[0]],
                               mode='lines',
                               showlegend=False,
                               line=dict(color='black', width=2)))
    fig.add_trace(go.Scatter3d(x=text_pos[2:3],
                               y=text_pos[1:2],
                               z=text_pos[0:1],
                               showlegend=False,
                               textfont=dict(
                                   color="black",
                                   size=14
                               ),
                               mode="text", textposition="middle center",
                               marker=dict(color='black', size=3), text=[str(label)]))


def _add_facet_edges(fig, halfspaces, i, vi):
    """Draw the edges of facet ``i`` (vertices ``vi``) as black line segments."""
    for j in range(vi.shape[0]):
        for k in range(j + 1, vi.shape[0]):
            # If the midpoint strictly satisfies every other constraint, the
            # segment cuts across the facet's interior (a diagonal), not an edge.
            mid = (vi[j] + vi[k]) / 2
            di = halfspaces[:, :-1] @ mid + halfspaces[:, -1]
            di[i] = -1
            if np.all(di < -1.e-5):
                continue
            fig.add_trace(go.Scatter3d(x=vi[[j, k], 2],
                                       y=vi[[j, k], 1],
                                       z=vi[[j, k], 0],
                                       mode='lines',
                                       showlegend=False,
                                       line=dict(color='black', width=2)))

## Step 1: Initializing the simulation

First, we create the simulation. We simulate a 1x3 array according to the constant interaction model and plot the polytope of the state (1,1,1).

In [9]:
# Set a seed for reproducibility
np.random.seed(0)

delta = 1.0/1000 #precision of the line-search. 1mV
rho=1.0/10 #rho from the paper, a measure of simulation complexity. Per the authors, this value is what the paper calls rho=1
sim = sim_NxM(1, 3, delta, rho) #simulation of a 1x3 array.
sim.set_reservoir(True)

#plot the polytope of the 1,1,1 state
sim.activate_state([1,1,1]) #activate simulation of state (1,1,1)
#ground truth info of the active state
A=sim.boundaries().A #normals of transitions of simulated state
b=sim.boundaries().b #offsets of transitions
labels=sim.boundaries().labels #labels of transitions. Same as t in the paper
print("labels of transitions present in the state")
print(labels)
halfspaces=np.hstack([A,b.reshape(-1,1)]) #[A, b] rows as expected by plot_facet; if this polytope was open, we would need to add lower bounds as well
print("plot of the device")
plot_facet(halfspaces, labels)




labels of transitions present in the state
[[ 1. -1.  1.]
 [-1.  1. -1.]
 [-1.  0.  0.]
 [-1.  0.  1.]
 [-1.  1.  0.]
 [ 0. -1.  0.]
 [ 0. -1.  1.]
 [ 0.  0. -1.]
 [ 0.  0.  1.]
 [ 0.  1. -1.]
 [ 0.  1.  0.]
 [ 1. -1.  0.]
 [ 1.  0. -1.]
 [ 1.  0.  0.]]
plot of the device


## Computing Gamma

To compute gamma, we first learn the boundaries of the state (0,0,0). We then assign each learned transition to the dot to which we think it adds one electron, i.e. we order the transitions so that the i-th transition adds an electron to the i-th dot.

As a quick check that the learned boundaries are correct, we compute the relative error of each element of gamma.

In [10]:
#set the simulation to the polytope of the zero state
sim.activate_state([0,0,0])

#learning the polytope belonging to the 0,0,0 state has its own algorithm
num_start_samples = 4*sim.num_dots*(sim.num_dots+5) #number of initial line-searches (same as in the paper)
# only the first return value, the learned [normal, offset] rows, is used here
halfspaces, _,_,_ = learn_zero_polytope(sim, delta, num_start_samples = num_start_samples)

# A priori, we do not know which transition (row of gamma) belongs to which dot.
# We assume each dot's transition couples most strongly to that dot, so we reorder
# greedily: for dot i, pick the not-yet-assigned row j (among the first num_dots rows)
# with the largest |halfspaces[j, i]|, making |gamma_ii| dominate column i.
gamma_order = []
for i in range(sim.num_dots):
    jmax = -1
    maxv = -np.inf  # np.inf, not np.infty: the np.infty alias was removed in NumPy 2.0
    for j in range(sim.num_dots):
        if j in gamma_order: continue
        if np.abs(halfspaces[j,i]) > maxv:
            maxv = np.abs(halfspaces[j,i])
            jmax = j
    gamma_order.append(jmax)
halfspaces=halfspaces[gamma_order,:]


halfspaces /= np.linalg.norm(halfspaces[:,:-1],axis=1).reshape(-1,1) #normalize to unit length normals (offsets scaled by the same factor)
gamma = halfspaces[:,:-1]#normal
b = halfspaces[:,-1]#offset

print("relative error in %")
gamma_truth = sim.boundaries().A/np.linalg.norm(sim.boundaries().A,axis=1).reshape(-1,1)
# relative error w.r.t. the ground truth: |truth - estimate| / |truth|
print(100*np.abs(gamma_truth-gamma)/np.abs(gamma_truth))

relative error in %
[[3.22572896e-03 8.89552902e-02 3.99778413e-01]
 [2.06431692e-02 1.02582115e-03 2.56285294e-03]
 [1.32272239e-01 3.94747427e-03 1.32173837e-05]]


### Now learn all of the other polytopes that are relevant to the target state

We first search for all transitions of the state. For this, we call ``generate_transitions_for_state`` with arguments ``max_k=3`` and ``max_moves = 2``. ``max_k`` is the maximum number of sites in the array that a transition can change. ``max_moves`` limits the number of electrons that can move in a transition. For example, transition (1,-1,1) involves at least 2 moving electrons: one moves from the middle dot to one of the outer dots, and another one is added to the array at the other outer dot.

Afterwards, we plot the learned polytope and repeat the experiment while allowing at most one electron to move (so at most 2 sites can be affected). This excludes two facets in our example, and we show that we can still learn the polytope reasonably well.

In [12]:
# Set a seed for reproducibility
np.random.seed(0)

#the state we want to compute the polytope for
target_state=np.array([1,1,1],dtype=np.int64)

# Set device in the target state
print("\nComputing polytope for state ", target_state)
sim.activate_state(target_state)
    
# Generate the transitions we want to look for: transitions changing at most 3 sites
# (max_k=3) and moving at most 2 electrons (max_moves=2).
T = generate_transitions_for_state(target_state, max_k=3, max_moves = 2)
print("searching for transitions")
print(T)

# we need a starting point to create a polytope
#this is slightly cheating, since we compute a point from the ground truth polytope
startpoint = sim.boundaries().point_inside
#this is less cheating and how we do it in the paper.
#startpoint, _, _ = sim.line_search(sim.boundaries().point_inside, np.random.randn(sim.num_inputs))
#startpoint = 0.95*startpoint + 0.05 * sim.boundaries().point_inside
print("start point:", startpoint)
# Learn the polytope (set verbose = 2 for more output, such as the convergence of the max likelihood)
# NOTE(review): the recorded output of this cell ends in
# "ValueError: not enough values to unpack (expected 7, got 6)", i.e. the installed
# learn_convex_polytope returns 6 values, not the 7 unpacked below. Its actual return
# signature is not visible here -- TODO confirm and adjust this unpacking to match.
A_res, b_res, x_m, x_p, found, num_searches, params = learn_convex_polytope(sim, delta, startpoint, T.astype(float), gamma, max_searches=500, verbose=1)


Computing polytope for state  [1 1 1]
searching for transitions
[[-1  0  0]
 [ 0 -1  0]
 [ 0  0 -1]
 [ 0  0  1]
 [ 0  1  0]
 [ 1  0  0]
 [-1  0  1]
 [-1  1  0]
 [ 0 -1  1]
 [ 0  1 -1]
 [ 1 -1  0]
 [ 1  0 -1]
 [-1 -1  1]
 [-1  1 -1]
 [-1  1  1]
 [ 1 -1 -1]
 [ 1 -1  1]
 [ 1  1 -1]]
start point: [0.11844219 0.12562139 0.12921188]
Number of searches:  49 / 500
Number of searches:  101 / 500
Number of searches:  152 / 500
Number of searches:  195 / 500
Number of searches:  224 / 500
Number of searches:  251 / 500
Finished learning polytope
Number of transitions found: 12
max_rad not found: 0.0006721040548163164
[-1.  0.  0.] 24 0.06995347288378195
[ 0. -1.  0.] 18 0.06413354421148029
[ 0.  0. -1.] 24 0.06473461762064453
[0. 0. 1.] 23 0.06464249956784385
[0. 1. 0.] 19 0.06411005395014217
[1. 0. 0.] 24 0.07004599107118795
[-1.  1.  0.] 20 0.007228898344261156
[ 0. -1.  1.] 21 0.006720647605911723
[ 0.  1. -1.] 22 0.006696025305421172
[ 1. -1.  0.] 18 0.007201190157847279
[-1.  1. -1.] 11 0.0

ValueError: not enough values to unpack (expected 7, got 6)

## Plot

We plot the learned polytope together with the measurement points.

In [None]:
# Build [A, b] rows from the transitions that were found; if this polytope were
# open, lower bounds would have to be added as well.
halfspaces = np.concatenate((A_res[found], b_res[found].reshape(-1, 1)), axis=1)
plot_facet(halfspaces, T[found], points=x_p)

## Excluding facets from the search

In [None]:
# Set a seed for reproducibility
np.random.seed(0)
    
# Generate the transitions we want to look for
#This time we set max_k=2 and max_moves = 1, which only allows a single electron to move during a transition.
#Among the true facets listed above, this excludes the [-1 1 -1] and [1 -1 1] transitions.
T = generate_transitions_for_state(target_state, max_k=2, max_moves = 1)
print("searching for transitions")
print(T)

# Learn the polytope (set verbose = 2 for more output, such as the convergence of the max likelihood)
# NOTE(review): this 7-value unpacking failed in the recorded run with
# "not enough values to unpack (expected 7, got 6)"; adjust it to
# learn_convex_polytope's actual return signature -- TODO confirm.
A_res, b_res, x_m, x_p, found, num_searches, params = learn_convex_polytope(sim, delta, startpoint, T.astype(float), gamma, max_searches=500, verbose=1)

In [None]:
# Build [A, b] rows from the transitions that were found; if this polytope were
# open, lower bounds would have to be added as well.
halfspaces = np.concatenate((A_res[found], b_res[found].reshape(-1, 1)), axis=1)
plot_facet(halfspaces, T[found], points=x_p)