# Interpolation of trajectories

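This notebook takes a CSV of discretized trajectories and interpolates them so that every pair of consecutive states is connected by a single action. Each interpolation step is chosen according to a Boltzmann distribution over the valid actions; the implementation used below picks (uniformly) among the actions with the highest Boltzmann probability.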

In [2]:
import sys
import os


import numpy as np
import pandas as pd
from insect_rl.mdp.utils import grid_math, algebra, policies

# ANSI SGR escape sequences for colouring terminal/debug output.
# NOTE(review): none of these names are referenced by the visible cells
# (the only coloured debug print is commented out and uses raw literals).
red = '\033[91m'
green = '\033[92m'
yellow = '\033[93m'
blue = '\033[94m'
pink = '\033[95m'
teal = '\033[96m'
grey = '\033[97m'  # NOTE(review): SGR 97 is "bright white" on most terminals
black = '\033[0m'  # SGR 0 resets all attributes, i.e. back to the default colour

In [3]:
def combine_utilities(actions, utilities):
    """Combine several per-action utility tables into one weighted sum.

    Args:
        actions: Iterable of hashable actions to score.
        utilities: Iterable of ``(pressure, utility)`` pairs, where
            ``utility`` is indexable by action and ``pressure`` is the
            weight applied to that table.

    Returns:
        dict mapping every action to
        ``sum(pressure * utility[action])`` over all pairs.
    """
    # The pairs are scanned once per action, so materialize them first:
    # a one-shot iterator would otherwise be exhausted after the first
    # action and every later action would silently sum to 0.
    weighted_tables = list(utilities)
    return {
        action: sum(pressure * utility[action] for pressure, utility in weighted_tables)
        for action in actions
    }


def calculate_utilities(actions, delta):
    """Score each action by its similarity to the displacement ``delta``.

    Args:
        actions: Iterable of hashable actions.
        delta: Displacement towards the target state (the "homing vector"
            in the original comments).

    Returns:
        dict mapping each action to ``algebra.similarity(action, delta)``.
        NOTE(review): the exact similarity measure is defined in
        ``insect_rl.mdp.utils.algebra`` and is not visible here.
    """
    # Only the similarity to the homing vector is used; a per-action move
    # cost / memory term was sketched in the original but never enabled.
    scores = {}
    for candidate in actions:
        scores[candidate] = algebra.similarity(candidate, delta)
    return scores


def _interpolate(s1, s2, actions):
    """Sample one action leading from ``s1`` towards ``s2``.

    The actions are scored with ``calculate_utilities`` against the
    displacement ``s2 - s1``, turned into probabilities by
    ``policies.boltzmann`` and one action is drawn from that distribution.
    Intermediate values are printed for debugging.

    Args:
        s1: Current state; must support ``s2 - s1`` (e.g. a numpy array).
        s2: Target state.
        actions: Sequence of candidate actions.

    Returns:
        The action drawn by ``numpy.random.Generator.choice``.
    """
    print(f"\n fun interpol. {s1} {s2}")
    delta = s2 - s1

    utils = calculate_utilities(actions, delta)
    # presumably one probability per action, in the order of `actions` —
    # it is passed straight to `choice(..., p=...)`; verify in `policies`.
    boltzmann = policies.boltzmann(actions, utils)

    debug_values = (
        ("delta", delta),
        ("utils", utils),
        ("a", actions),
        ("boltz", boltzmann),
    )
    for label, value in debug_values:
        print(label, value)

    # A fresh, unseeded generator per call: draws are not reproducible.
    return np.random.default_rng().choice(actions, p=boltzmann)

def valid(a, s, a_idx, transition_probs):
    """Tell whether action ``a`` may be taken from state ``s``.

    An action is valid when the resulting cell lies inside the
    ``width`` x ``height`` grid (module-level globals) and the transition
    table has at least one non-zero entry for ``(s, a_idx)``.

    Args:
        a: Action as an ``(dx, dy)`` pair.
        s: Current state as an ``(x, y)`` pair.
        a_idx: Index of ``a`` along the action axis of ``transition_probs``.
        transition_probs: Array indexed as ``[state, action, next_state]``.

    Returns:
        A truthy value if the action is valid, a falsy one otherwise.
    """
    state_idx = grid_math.point_to_int(s, width)
    outcomes = transition_probs[state_idx, a_idx, :]

    next_x = s[0] + a[0]
    next_y = s[1] + a[1]
    inside_grid = (0 <= next_x < width) and (0 <= next_y < height)
    return inside_grid and any(outcomes)

def _interpolate_max(s1, s2, actions, transition_probs):
    """Pick the best single action leading from ``s1`` towards ``s2``.

    Two cases:

    * ``s1`` is a trap cell and ``s2`` is not (``traps`` / ``trap_exits``
      are module-level globals): the step is hard-coded. It moves left if
      ``s1`` is right of the second exit's x, right if it is left of the
      first exit's x, and otherwise steps ``(0, 1)``.
    * Otherwise the valid actions are scored with ``calculate_utilities``,
      converted to probabilities with ``policies.boltzmann`` and one of the
      actions with maximal probability is drawn uniformly at random.

    Args:
        s1: Current state as a numpy array ``[x, y]``.
        s2: Target state as a numpy array ``[x, y]``.
        actions: Sequence (supporting ``.index``) of candidate actions.
        transition_probs: Array indexed as ``[state, action, next_state]``.

    Returns:
        The chosen action: a tuple in the trap case, otherwise the value
        drawn by ``numpy.random.Generator.choice``.

    Raises:
        ValueError: If no action is valid from ``s1``.
    """
    if tuple(s1) in traps and tuple(s2) not in traps:
        if s1[0] > trap_exits[1][0][0]:
            return (-1, 0)
        elif s1[0] < trap_exits[0][0][0]:
            return (1, 0)
        else:
            return (0, 1)

    delta = s2 - s1
    # `actions.index(a)` recovers the index of the action in the *full*
    # action set, which is what the transition table is indexed by.
    candidates = tuple(
        a for a in actions if valid(a, s1, actions.index(a), transition_probs)
    )
    if not candidates:
        # Fail here with context instead of deep inside numpy's `choice`.
        raise ValueError(
            f"no valid action from state {tuple(s1)} towards {tuple(s2)}"
        )

    utils = calculate_utilities(candidates, delta)
    boltzmann = policies.boltzmann(candidates, utils)

    # Collapse the distribution onto its maximum: 1 for every action that
    # attains the highest probability, 0 for the rest, then renormalize so
    # ties are broken uniformly. The maximum is computed once.
    best = max(boltzmann)
    is_best = [int(p == best) for p in boltzmann]
    n_best = sum(is_best)
    greedy = [flag / n_best for flag in is_best]

    rng = np.random.default_rng()
    return rng.choice(candidates, p=greedy)


def ok(s1, s2, actions, transition_probs):
    """Tell whether ``s2`` is reachable from ``s1`` with a single action.

    Args:
        s1: Source state as an ``(x, y)`` pair.
        s2: Destination state as an ``(x, y)`` pair.
        actions: Unused; kept for signature compatibility.
        transition_probs: Array indexed as ``[state, action, next_state]``.

    Returns:
        True if any action has a non-zero transition entry from ``s1``
        to ``s2``.
    """
    # Convert both points to flat state indices (uses the global `width`).
    src, dst = (grid_math.point_to_int(point, width) for point in (s1, s2))
    return any(transition_probs[src, :, dst])


def interpolate(trajectory, actions, transition_probs):
    """Fill the gaps of a trajectory so consecutive states are one action apart.

    Walks through ``trajectory``; whenever the next waypoint is not
    reachable from the current state in a single transition, intermediate
    states are inserted one action at a time (chosen by
    ``_interpolate_max``) until it is.

    When ``snakemake.wildcards.condition`` is ``"trap"``, a waypoint whose
    y coordinate is smaller than the current state's is skipped (one
    waypoint per loop iteration).

    Args:
        trajectory: Iterable of ``(x, y)`` states.
        actions: Sequence of available actions.
        transition_probs: Array indexed as ``[state, action, next_state]``.

    Returns:
        List of states starting with the first state of ``trajectory``.
        An empty trajectory yields ``[]``; a single-point trajectory is
        returned as a one-element list.
    """
    traj_iter = iter(trajectory)
    # Use a default so empty / single-point trajectories do not raise
    # StopIteration.
    s1 = next(traj_iter, None)
    if s1 is None:
        return []
    s2 = next(traj_iter, None)

    interpolated = [s1]
    while s2 is not None:
        if snakemake.wildcards.condition == "trap" and s1[1] > s2[1]:
            # Skip a waypoint that lies "behind" the current state.
            s2 = next(traj_iter, None)
            if s2 is None:
                # The skipped waypoint was the last one: nothing left to
                # reach, and `ok` must not be called with None.
                break

        if ok(s1, s2, actions, transition_probs):
            # Directly reachable: accept the waypoint and advance.
            interpolated.append(tuple(s2))
            s1 = s2
            s2 = next(traj_iter, None)
        else:
            # Not reachable in one step: take one step towards s2 and retry.
            inter = _interpolate_max(np.array(s1), np.array(s2), actions, transition_probs)
            s1 = (s1[0] + inter[0], s1[1] + inter[1])
            interpolated.append(tuple(s1))

    return interpolated

In [4]:
# Inputs: the discretized trajectories and the action set whose name is
# given in the workflow config (looked up as an attribute of `grid_math`).
df = pd.read_csv(snakemake.input[0])
actions = vars(grid_math)[snakemake.config["actions"]]

# Grid dimensions used for bounds checks and for state indexing.
width, height = 23, 60

# Trap layout, only present in the "trap" condition.
# M.bagoti 10 cm wide, 2m long, 10cm deep with 20cm wide exit board
if snakemake.wildcards.condition != "trap":
    traps = []
    trap_exits = []
else:
    traps = [(x, 26) for x in range(2, 20)]
    trap_exits = [((5, 26), (5, 27)), ((6, 26), (6, 27))]

print("Actions:", actions)

df

In [5]:
# Transition table for this experiment/condition (the wind=0.0 variant).
# It is indexed below as [state, action, next_state].
transition_probs = np.load(f"irl/{snakemake.wildcards.experiment}/{snakemake.wildcards.condition}/transition_probs_wind=0.0.npy")

In [6]:

# Every trajectory is cut at the first point whose y exceeds `y_cutoff`
# and is then made to end at `final_point`.
y_cutoff = 52
final_point = (10, 52)

trajectories_dfs = []
for ant in pd.unique(df["ant_nb"]):
    ant_df = df[df.ant_nb == ant]
    for run in pd.unique(ant_df["trial_nb"]):
        run_df = ant_df[ant_df.trial_nb == run]
        traj = list(zip(run_df.path_x, run_df.path_y))

        # Keep only the points before the first one beyond the cutoff.
        cut = next(
            (idx for idx, (_, y) in enumerate(traj) if y > y_cutoff),
            len(traj),
        )
        traj = traj[:cut]

        # `not traj` covers a run whose very first point is already beyond
        # the cutoff; indexing traj[-1] would raise IndexError there.
        if not traj or traj[-1] != final_point:
            traj.append(final_point)

        if len(traj) > 1:
            interpolated = interpolate(traj, actions, transition_probs)
        else:
            # A single point has nothing to interpolate; keep it as is.
            interpolated = traj

        data = {
            'ant_nb': ant,
            'trial_nb': run,
            'path_x': [point[0] for point in interpolated],
            'path_y': [point[1] for point in interpolated]
        }

        trajectories_dfs.append(pd.DataFrame(data))

In [7]:
# Stack the per-run frames into one table with a fresh 0..n-1 index.
# Uncomment the next line to display every row in the notebook output.
#pd.set_option('display.max_rows', None)
result = pd.concat(trajectories_dfs, ignore_index=True)
result

In [8]:
# Save the interpolated trajectories to the workflow's first output file
# as CSV, without writing the row index.
result.to_csv(snakemake.output[0], index=False)