# Ant and Seeds - Problem 280
<p>A laborious ant walks randomly on a $5 \times 5$ grid. The walk starts from the central square. At each step, the ant moves to an adjacent square at random, without leaving the grid; thus there are $2$, $3$ or $4$ possible moves at each step depending on the ant's position.</p>

<p>At the start of the walk, a seed is placed on each square of the lower row. When the ant isn't carrying a seed and reaches a square of the lower row containing a seed, it will start to carry the seed. The ant will drop the seed on the first empty square of the upper row it eventually reaches.</p>

<p>What's the expected number of steps until all seeds have been dropped in the top row? <br>
Give your answer rounded to $6$ decimal places.</p>

## Solution.
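Model the walk as an absorbing Markov chain. Let $Q$ be the transient-to-transient block of the transition matrix. The fundamental matrix is $N = (I - Q)^{-1}$, and the expected number of steps to absorption from each transient state is the vector of row sums $t = N\mathbf{1}$. Background: https://lips.cs.princeton.edu/the-fundamental-matrix-of-a-finite-markov-chain/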

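Each state is a tuple `(x, y, carry, bottom, top)`:
- `x, y` are the ant's column and row, each in 1..5; row 1 is the lower row and row 5 the upper row.
- `carry` says whether the ant is carrying a seed.
- `bottom` and `top` are 5-tuples of 0/1 flags marking which lower-row and upper-row squares currently hold a seed.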

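We start at `(3, 3, False, (1,1,1,1,1), (0,0,0,0,0))`: the ant is on the central square, not carrying a seed, every lower-row square holds a seed and the upper row is empty.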

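We finish when we hit the absorbing set $A = \{(x, 5, \text{False}, (0,0,0,0,0), (1,1,1,1,1)) : x = 1, \dots, 5\}$, i.e. the last seed has just been dropped and the whole upper row is full.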


In [1]:
from math import factorial
from functools import lru_cache
from collections import defaultdict
import numpy as np

#### Construct the graph

In [2]:
def neighbourhood(x, y):
    """Return the squares orthogonally adjacent to ``(x, y)`` that lie on the grid.

    Both coordinates range over 1..5. Candidates are generated in the order
    (x-1, y), (x+1, y), (x, y-1), (x, y+1), and any that fall off the grid are
    dropped. Corners therefore get 2 neighbours, edge squares 3 and interior
    squares 4.
    """
    candidates = ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1))
    return [(cx, cy) for cx, cy in candidates if 1 <= cx <= 5 and 1 <= cy <= 5]

def construct_graph(start):
    """Explore every state reachable from ``start`` and record its transitions.

    A state is ``(x, y, carry, bottom, top)``:
    - ``x`` and ``y`` are the ant's column and row (1..5; row 1 is the lower
      row, row 5 the upper row).
    - ``carry`` says whether the ant is carrying a seed.
    - ``bottom`` and ``top`` are 5-tuples of 0/1 flags saying which lower-row
      and upper-row squares currently hold a seed.

    Returns:
        graph: ``defaultdict(set)`` mapping each expanded state to a set of
            ``(next_state, probability)`` pairs, one per in-grid neighbour.
        V: set of every discovered state. It includes ``start`` itself and the
            final states (upper row full), which are recorded but not expanded.
    """
    graph = defaultdict(set)
    # Seed V with start so it can always be indexed later, even if no
    # transition ever leads back into it.
    V = {start}
    stack = [start]

    while stack:
        state = stack.pop()
        x, y, carry, bottom, top = state

        neighbours = neighbourhood(x, y)
        p = 1 / len(neighbours)  # every in-grid move is equally likely

        for x_new, y_new in neighbours:
            carry_new = carry
            bottom_new = list(bottom)
            top_new = list(top)

            if not carry:
                # Pick up the seed when entering an occupied lower-row square.
                if y_new == 1 and bottom[x_new - 1] == 1:
                    carry_new = True
                    bottom_new[x_new - 1] = 0
            elif y_new == 5 and top[x_new - 1] == 0:
                # Drop the carried seed on the first empty upper-row square reached.
                carry_new = False
                top_new[x_new - 1] = 1

            state_new = (x_new, y_new, carry_new, tuple(bottom_new), tuple(top_new))
            graph[state].add((state_new, p))

            # Each state is pushed (and thus expanded) at most once.
            if state_new not in V:
                V.add(state_new)
                # Once the upper row is full the walk has ended; don't expand.
                if tuple(top_new) != (1, 1, 1, 1, 1):
                    stack.append(state_new)

    return graph, V

#### Transition matrix and expected time to absorption

In [3]:
def construct_transition_matrix(graph, V):
    """Build the dense one-step transition matrix over the states in ``V``.

    Rows and columns follow the order of ``list(V)``. Entry ``[i, j]`` is the
    probability that ``graph`` records for moving from state ``i`` to state
    ``j``. A state with no entry in ``graph`` gets an all-zero row.

    Returns:
        (P, state_list): the matrix and the state ordering used to index it.
    """
    state_list = list(V)
    lookup = dict(zip(state_list, range(len(state_list))))
    P = np.zeros((len(state_list), len(state_list)))

    for origin, successors in graph.items():
        row = lookup[origin]
        for target, prob in successors:
            P[row, lookup[target]] = prob

    return P, state_list

def identify_states(state_list):
    """Split states into transient and absorbing lists, preserving input order.

    A state counts as absorbing when its upper-row tuple (element 4) equals
    ``(1, 1, 1, 1, 1)``, i.e. every upper-row square holds a seed. Every other
    state is transient.

    Returns:
        (transient_states, absorbing_states)
    """
    full_top = (1, 1, 1, 1, 1)
    absorbing = [s for s in state_list if s[4] == full_top]
    transient = [s for s in state_list if s[4] != full_top]
    return transient, absorbing

def calculate_fundamental_matrix(Q):
    """Return the fundamental matrix ``N = (I - Q)^-1`` of an absorbing chain.

    ``Q`` is the square transient-to-transient block of the transition matrix.
    ``N[i, j]`` is the expected number of visits to transient state ``j`` when
    the chain starts in transient state ``i``.
    """
    n = Q.shape[0]
    return np.linalg.inv(np.identity(n) - Q)

def expected_time_to_absorption(N):
    """Return the expected number of steps to absorption from each transient state.

    This is ``t = N @ 1``, the row sums of the fundamental matrix, returned as
    an ``(n, 1)`` column vector.
    """
    column_of_ones = np.ones((N.shape[0], 1))
    return N @ column_of_ones

#### Main

In [4]:
# Start: ant in the centre, not carrying, lower row full of seeds, upper row empty.
start = (3, 3, False, (1, 1, 1, 1, 1), (0, 0, 0, 0, 0))
graph, V = construct_graph(start)
P, state_list = construct_transition_matrix(graph, V)
transient_states, absorbing_states = identify_states(state_list)

# Partition the transition matrix: Q holds transient -> transient probabilities.
# Build the state -> row map once; calling state_list.index() for every state
# would be an O(n^2) scan over a state space with thousands of states.
row_of = {state: i for i, state in enumerate(state_list)}
transient_indices = [row_of[s] for s in transient_states]
Q = P[np.ix_(transient_indices, transient_indices)]

# Calculate the fundamental matrix N = (I - Q)^-1
N = calculate_fundamental_matrix(Q)

# Compute the expected time to absorption from every transient state
t = expected_time_to_absorption(N)

# Rows of t follow the order of transient_states, so locate the start state there.
start_index = transient_states.index(start)

print(round(t[start_index][0], 6))

430.088247


In [None]:
# Display the transient-to-transient sub-matrix Q for inspection (notebook cell output).
Q