In [4]:
import numpy as np
from scipy.special import kv  # modified Bessel function of the second kind: K_v
from typing import Tuple, Optional
import matplotlib.pyplot as plt

from matplotlib.animation import FuncAnimation
import ipywidgets as widgets

from IPython.display import display

In [None]:
def besselk3(bessel_matr, dimx, dimy, centerx, centery, ix, iy):
    """
    Cut the (dimx x dimy) window out of the precomputed kernel for position (ix, iy).

    `ix` and `iy` are 1-based (MATLAB-style); `centerx` and `centery` are
    0-based indices into `bessel_matr`. The window is placed so that its entry
    [ix - 1, iy - 1] coincides with bessel_matr[centerx, centery].

    The MATLAB original uses inclusive ranges
        rows: (centerx-(ix-1)) : (centerx+(dimx-ix))
        cols: (centery-(iy-1)) : (centery+(dimy-iy))
    which span exactly dimx rows and dimy columns, so here each axis is
    expressed as a half-open slice "start : start + length".
    """
    first_row = centerx - ix + 1
    first_col = centery - iy + 1
    row_window = slice(first_row, first_row + dimx)
    col_window = slice(first_col, first_col + dimy)
    return bessel_matr[row_window, col_window]

# Helper functions: everything they need (kernel matrix, grid geometry, belief, pA) is passed in explicitly as arguments
def R(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy):
    """
    Scaled kernel window for position (ix, iy).

    Computes rate / ln(lam / particle_size) times the (dimx x dimy) window that
    `besselk3` cuts out of `bessel_matr`. The logarithm is the natural log
    (`np.log`).
    """
    kernel_window = besselk3(bessel_matr, dimx, dimy, centerx, centery, ix, iy)
    scale = rate / np.log(lam / particle_size)
    return scale * kernel_window

def p0y(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy):
    """
    Probability of zero hits given position (ix, iy).

    Returns exp(-R(...)) element-wise, i.e. one value per entry of the window
    returned by `R`.
    """
    hit_rate = R(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy)
    return np.exp(np.negative(hit_rate))

def p1y(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy):
    """
    Probability of a hit given position (ix, iy).

    This is the complement of `p0y` for the same arguments, element-wise.
    """
    no_hit = p0y(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy)
    return 1.0 - no_hit

def p0(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy, pA_lin, Pt_lin):
    """
    Probability of zero hits at (ix, iy), averaged over the belief.

    Implements (1 - pA) * sum(Pt .* p0y) + pA:
      * with probability (1 - pA_lin) the source is inside the space, and the
        zero-hit probability is the `Pt_lin`-weighted sum of `p0y` (NaN terms
        are skipped by `np.nansum`);
      * with probability pA_lin the source is not in the space, in which case
        zero hits is certain (factor 1).
    """
    no_hit_per_cell = p0y(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy)
    no_hit_if_inside = np.nansum(Pt_lin * no_hit_per_cell)
    inside_weight = 1.0 - pA_lin
    # The "source not in space" branch contributes pA_lin * 1 == pA_lin.
    return inside_weight * no_hit_if_inside + pA_lin

def p1(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy, pA_lin, Pt_lin):
    """
    Probability of a hit at (ix, iy), averaged over the belief.

    Implements (1 - pA) * sum(Pt .* p1y):
      * with probability (1 - pA_lin) the source is inside the space, and the
        hit probability is the `Pt_lin`-weighted sum of `p1y` (NaN terms are
        skipped by `np.nansum`);
      * with probability pA_lin the source is not in the space, in which case
        a hit has probability 0.
    """
    hit_per_cell = p1y(bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy)
    hit_if_inside = np.nansum(Pt_lin * hit_per_cell)
    inside_term = (1.0 - pA_lin) * hit_if_inside
    # Kept as an explicit product so the arithmetic is unchanged for any pA_lin.
    outside_term = pA_lin * 0
    return inside_term + outside_term

def move_entropy(bessel_matr, dimx, dimy, centerx, centery, log_posterior, pA_lin, rate, lam, particle_size, ix, iy, S) -> float:
    """
    Compute expected change in entropy for moving to (ix, iy).

    `log_posterior` is read but not modified; `S` is the entropy (in bits) of
    the current belief. The result combines three outcomes:
      * the source is at (ix, iy): entropy change of -S;
      * otherwise, no hit is received: p0 * (entropy after a miss - S);
      * otherwise, a hit is received:  p1 * (entropy after a hit  - S).

    NOTE(review): the belief array handed to `p0` / `p1` is the one with the
    candidate cell already set to zero (and not renormalised). Confirm against
    the MATLAB reference that this is intended.
    """
    kernel_args = (bessel_matr, dimx, dimy, centerx, centery, rate, lam, particle_size, ix, iy)
    row, col = ix - 1, iy - 1

    # Working copy of the belief with the candidate cell ruled out.
    belief = np.exp(log_posterior)
    belief[row, col] = 0.0

    # --- Outcome: zero hits ---
    weighted_miss = belief * p0y(*kernel_args)
    weighted_miss = np.where(np.isnan(weighted_miss), 0.0, weighted_miss)
    miss_mass = np.sum(weighted_miss)
    entropy_after_miss = 0.0
    if miss_mass > 0:
        belief_after_miss = weighted_miss / miss_mass
        keep = belief_after_miss > 1e-300  # skip underflowed cells
        entropy_after_miss = -np.sum(belief_after_miss[keep] * np.log2(belief_after_miss[keep]))
    gain_miss = p0(*kernel_args, pA_lin, belief) * (entropy_after_miss - S)

    # --- Outcome: at least one hit ---
    weighted_hit = belief * p1y(*kernel_args)
    weighted_hit = np.where(np.isnan(weighted_hit), 0.0, weighted_hit)
    hit_mass = np.sum(weighted_hit)
    entropy_after_hit = 0.0
    if hit_mass > 0:
        belief_after_hit = weighted_hit / hit_mass
        keep = belief_after_hit > 1e-300  # skip underflowed cells
        entropy_after_hit = -np.sum(belief_after_hit[keep] * np.log2(belief_after_hit[keep]))
        if not np.isfinite(entropy_after_hit):
            entropy_after_hit = 0.0
    gain_hit = p1(*kernel_args, pA_lin, belief) * (entropy_after_hit - S)

    # Immediate reward if the source is exactly at (ix, iy); taken from the
    # untouched log posterior because `belief` was zeroed at that cell.
    prob_source_here = np.exp(log_posterior[row, col])
    return (prob_source_here * (-S)) + ((1.0 - prob_source_here) * (gain_miss + gain_hit))

In [None]:
def infotaxis(runtime: int,
              pos_x: int,
              pos_y: int,
              sourcex: int,
              sourcey: int,
              dimx: int,
              dimy: int,
              tau: float,
              D: float,
              rate: float,
              particle_size: float,
              gauss_sig: float,
              nowherebase: float = -300.0,
              seed: Optional[int] = None) -> Tuple[np.ndarray, dict, np.ndarray, np.ndarray, int]:
    """
    Python port of MATLAB `infotaxis.m`.

    Parameters
    ----------
    runtime : int
        Number of time bins to run (the loop executes at most runtime + 1 times).
    pos_x, pos_y : int
        Initial position (1-based indices, to match MATLAB).
    sourcex, sourcey : int
        Center of Gaussian prior over source location (1-based).
    dimx, dimy : int
        Arena dimensions.
    tau : float
        Prior belief about lifetime of emitted particles.
    D : float
        Diffusion rate of emitted particles.
    rate : float
        Emission rate of particles.
    particle_size : float
        Relative size of particles.
    gauss_sig : float
        Width parameter of the Gaussian prior: the unnormalised log prior is
        -(distance to (sourcex, sourcey))**2 / gauss_sig.
    nowherebase : float, optional
        Initial log-probability that the food is nowhere in the arena.
    seed : int, optional
        Seed for the random generator that breaks ties between equally good
        moves. ``None`` (the default) draws fresh entropy, as before.

    Returns
    -------
    path : (n_steps + 1, 2) ndarray of int
        Trajectory (1-based indices, like MATLAB): the start position followed
        by one row per executed iteration.
    debug : dict
        Placeholder for debug values (kept for API parity); always empty.
    A : (runtime+1, dimx, dimy) ndarray of float
        Log-posterior snapshots, one per executed iteration. Rows for
        iterations that never ran (early stop) stay zero.
    pA_log_curve : (runtime+1,) ndarray of float
        Log-probability that the source is outside the arena, recorded at each
        executed iteration; remaining entries stay zero.
    time_taken : int
        Number of iterations completed without triggering a stop criterion.
        It is a valid index into `A` and `pA_log_curve` when the run stopped
        early.
    """
    # Only used to break ties between equally good moves.
    rng = np.random.default_rng(seed)

    # Positions are tracked 1-based (MATLAB semantics); arrays are indexed
    # with [pos - 1].
    lam = np.sqrt(D * tau)

    # --- Precompute the Bessel K0 distance matrix over a large window ----
    # MATLAB builds a (2*dimx+1) x (2*dimy+1) grid (after transpose).
    xs = np.arange(-dimx, dimx + 1, dtype=float)
    ys = np.arange(-dimy, dimy + 1, dtype=float)
    XX, YY = np.meshgrid(xs, ys, indexing='xy')  # shape (2*dimy+1, 2*dimx+1)
    r_over_lambda = np.sqrt(XX**2 + YY**2) / lam

    # kv(0, 0) -> inf; like the MATLAB code, set NaN/Inf to 0 after computing
    bessel_matr = kv(0, r_over_lambda)
    bessel_matr = np.where(np.isfinite(bessel_matr), bessel_matr, 0.0)
    # MATLAB transposes after building; keep the same orientation here.
    bessel_matr = bessel_matr.T  # shape (2*dimx+1, 2*dimy+1)

    centerx = dimx  # zero-based index of the zero-distance entry along x
    centery = dimy  # zero-based index of the zero-distance entry along y

    # --- Initial log belief: Gaussian bump around (sourcex, sourcey) ---
    # Array of shape (dimx, dimy), addressed as [ix-1, iy-1].
    xs_grid = np.arange(1, dimx + 1)
    ys_grid = np.arange(1, dimy + 1)
    Xg, Yg = np.meshgrid(xs_grid, ys_grid, indexing='ij')  # shape (dimx, dimy)

    dist = np.sqrt((Xg - sourcex)**2 + (Yg - sourcey)**2)
    log_posterior = - (dist**2) / gauss_sig  # called Lt in the MATLAB code
    # Subtract log(sum(exp(.))) so that sum(exp(log_posterior)) == 1
    log_posterior = log_posterior - np.log(np.sum(np.exp(log_posterior)))

    debug = {}

    A = np.zeros((runtime + 1, dimx, dimy), dtype=float)
    pA_log_curve = np.zeros(runtime + 1)
    time_taken = 0

    path = [np.array([pos_x, pos_y], dtype=int)]

    # Sentinel score for moves that would leave the arena (never the minimum).
    big = 1_000_000.0

    def expected_gain(ix, iy, current_log_posterior, current_pA, current_S):
        # Expected entropy change of stepping to (ix, iy) under the current belief.
        return move_entropy(bessel_matr, dimx, dimy, centerx, centery, current_log_posterior,
                            current_pA, rate, lam, particle_size, ix, iy, current_S)

    # --- Main loop ---
    pA_log = float(nowherebase)  # log-probability that source is outside space
    for t in range(runtime + 1):
        # Normalize log posterior in log-space so sum(posterior) == 1
        log_posterior = log_posterior - np.log(np.sum(np.exp(log_posterior)))
        normalized_posterior = np.exp(log_posterior)

        pA_lin = np.exp(pA_log)  # probability that source is outside space
        # Update pA in log-space by dividing by the probability of zero hits.
        prob_of_receiving_0_hits = p0(bessel_matr, dimx, dimy, centerx, centery, rate, lam,
                                      particle_size, pos_x, pos_y, pA_lin, normalized_posterior)
        pA_log = pA_log - np.log(prob_of_receiving_0_hits)
        # Clamp: a log-probability cannot exceed log(1) == 0
        if pA_log > 0:
            pA_log = 0.0

        # Entropy S of current belief in bits (exclude underflows)
        mask = normalized_posterior > 1e-300
        S = -np.sum(normalized_posterior[mask] * (np.log2(normalized_posterior[mask])))

        # Expected entropy change for the five actions; moves that fail the
        # border test get the `big` sentinel.
        if pos_y + 1 < dimy:
            up = expected_gain(pos_x, pos_y + 1, log_posterior, pA_lin, S)
        else:
            up = big

        if pos_y - 1 > 0:
            down = expected_gain(pos_x, pos_y - 1, log_posterior, pA_lin, S)
        else:
            down = big

        if pos_x - 1 > 0:
            left = expected_gain(pos_x - 1, pos_y, log_posterior, pA_lin, S)
        else:
            left = big

        if pos_x + 1 < dimx:
            right = expected_gain(pos_x + 1, pos_y, log_posterior, pA_lin, S)
        else:
            right = big

        stay = expected_gain(pos_x, pos_y, log_posterior, pA_lin, S)

        # Record the belief and pA used for this decision
        A[t, :, :] = log_posterior
        pA_log_curve[t] = pA_log

        # Decision matrix (same layout as MATLAB); corners hold a value larger
        # than every real option so they are never selected.
        dc = max(up, down, left, right, stay) + 10.0
        decision_matr = np.array([
            [dc,   down, dc],
            [left, stay, right],
            [dc,   up,   dc]
        ])
        mval = decision_matr.min()
        decisions = np.argwhere(decision_matr == mval)

        # Tie-break uniformly
        if len(decisions) > 1:
            chosen = decisions[rng.integers(0, len(decisions))]
        else:
            chosen = decisions[0]

        movey, movex = chosen  # MATLAB's (row, col) order
        movex = movex - 1  # col 1 -> 0 (no move); col 0 -> -1 (left); col 2 -> +1 (right)
        movey = movey - 1  # row 1 -> 0 (no move); row 2 -> +1 (up);   row 0 -> -1 (down)

        # Parity with MATLAB: if diagonal (product nonzero), force right move
        if movex * movey != 0:
            movex = 1
            movey = 0

        # Update position (still 1-based tracking)
        pos_x = pos_x + movex
        pos_y = pos_y + movey
        path.append(np.array([pos_x, pos_y], dtype=int))

        # Update log posterior after the move: the visited cell is ruled out,
        # and every other cell is multiplied by the zero-hit likelihood.
        log_posterior[pos_x - 1, pos_y - 1] = -np.inf
        likelihood_observing_zero_hits = p0y(bessel_matr, dimx, dimy, centerx, centery, rate, lam,
                                             particle_size, pos_x, pos_y)
        log_posterior = log_posterior + np.log(likelihood_observing_zero_hits)
        log_posterior = np.where(np.isnan(log_posterior), 0.0, log_posterior)

        # Stopping criterion 1: the last 11 positions all lie within two cells
        # of an x-border, or all lie within two cells of a y-border.
        if len(path) > 15:
            recent = np.array(path[-11:], dtype=int)
            recent_x = recent[:, 0]
            recent_y = recent[:, 1]
            # Every term of the x test reads column 0 (the earlier version
            # compared column 1 against 2 here).
            near_x_border = (recent_x == dimx) | (recent_x == dimx - 1) | (recent_x == 1) | (recent_x == 2)
            near_y_border = (recent_y == dimy) | (recent_y == dimy - 1) | (recent_y == 1) | (recent_y == 2)
            cond_x = np.sum(near_x_border) == 11
            cond_y = np.sum(near_y_border) == 11
            if cond_x or cond_y:
                break

        # Stopping criterion 2: the agent has not moved for the last 501 positions.
        if len(path) > 500:
            recent = np.array(path[-501:], dtype=int)
            if np.all(recent == recent[-1]):
                break

        time_taken += 1

    path = np.vstack(path)
    return path, debug, A, pA_log_curve, time_taken


In [23]:
# Run configuration: 128 x 128 arena, prior centred on (64, 64), agent starting
# at (64, 80). All positions are 1-based, as infotaxis() expects.
agent_init_position = np.array([64, 80])
source_init_position = np.array([64, 64])
grid_size = np.array([128, 128])

path, debug, A, pA_log_curve, time_taken = infotaxis(
    runtime=10000,
    pos_x=agent_init_position[0],
    pos_y=agent_init_position[1],
    sourcex=source_init_position[0],
    sourcey=source_init_position[1],
    dimx=grid_size[0],
    dimy=grid_size[1],
    tau=10,
    D=1,
    rate=1,
    particle_size=1,
    gauss_sig=0.6,
    nowherebase=-100,
)

In [24]:
def plot_probability_of_source(ind):
    """
    Plot the belief snapshot and the "source is elsewhere" curve up to step `ind`.

    Reads the notebook-level results `A`, `path`, `pA_log_curve` and
    `time_taken` produced by the infotaxis run.

    Left panel: log posterior `A[ind]` as an image, the trajectory up to step
    `ind` in red, and a red dot at the starting position. Positions in `path`
    are 1-based, so 1 is subtracted to get image coordinates. The legend entry
    shows the sum of exp(A[ind]).
    Right panel: `pA_log_curve` up to step `ind`.

    The signature is kept to the single argument `ind` because
    `widgets.interactive` builds its controls from the function signature.
    """
    fig, ax_all = plt.subplots(1, 2, figsize=(14, 5))
    plt.rcParams.update({'font.size': 14})

    ax = ax_all[0]
    im = ax.imshow((A[ind,:,:].T), cmap='viridis', vmin=-100, vmax=0)
    ax.plot(path[:ind+1, 0] - 1.0, path[:ind+1, 1] - 1.0, color='red', linewidth=2, label=f'{(np.sum(np.exp(A[ind,:,:].T)))}')
    # Start marker taken from the first trajectory point instead of a
    # hard-coded (64, 80), so it stays correct if the run is reconfigured.
    start_x, start_y = path[0, 0], path[0, 1]
    ax.scatter([start_x - 1], [start_y - 1], marker='.', color='r', s=100)
    ax.set_xlim(32, 96)
    ax.set_ylim(32, 96)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    ax.set_title('Infotaxis path')
    ax.legend()

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Log posterior')

    ax = ax_all[1]
    ax.plot(pA_log_curve[:ind+1], color='k', linewidth=1)
    ax.set_ylim(-100, 0)
    ax.set_xlim(0, time_taken)
    ax.set_ylabel('log p(source is elsewhere)')
    ax.set_xlabel('timesteps')

    fig.tight_layout()
    plt.show()

In [25]:
# Slider over the recorded snapshots (indices 0 .. time_taken); moving it
# re-draws plot_probability_of_source for the selected index.
time_slider = widgets.IntSlider(
    description="Time index",
    value=0,
    min=0,
    max=time_taken,
    continuous_update=True,
    layout=widgets.Layout(width="1000px"),
    style={'description_width': 'initial'},
)

interactive_plot = widgets.interactive(plot_probability_of_source, ind=time_slider)
display(interactive_plot)

interactive(children=(IntSlider(value=0, description='Time index', layout=Layout(width='1000px'), max=524, sty…

In [10]:
# Second run with the same settings as the first; results are kept under
# separate names (path_2, A_2, ...) so the two runs can be compared.
agent_init_position = np.array([64, 80])
source_init_position = np.array([64, 64])
grid_size = np.array([128, 128])

path_2, debug_2, A_2, pA_log_curve2, time_taken2 = infotaxis(
    runtime=10000,
    pos_x=agent_init_position[0],
    pos_y=agent_init_position[1],
    sourcex=source_init_position[0],
    sourcey=source_init_position[1],
    dimx=grid_size[0],
    dimy=grid_size[1],
    tau=10,
    D=1,
    rate=1,
    particle_size=1,
    gauss_sig=0.6,
    nowherebase=-100,
)


In [11]:
# Compare the two runs: True means both produced element-wise identical
# pA_log_curve arrays (same shape and values).
np.array_equal(pA_log_curve, pA_log_curve2)

True