#### Import packages

In [None]:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from itertools import product, combinations

#### Functions

In [None]:
def create_baseline_lattice(x: int, y: int):
    '''
    Create a baseline directed 2D lattice where all lateral and downwards directions are allowed.

    Nodes are (column, row) tuples with column in range(x) and row in range(y).
    Every node in rows 1..y-1 has edges to its left and right neighbours
    (periodic in the column index when x > 2) and to the node directly below
    it (row - 1). Row 0 has no outgoing edges, so it acts as the outflow row.

    Inputs:
        x: number of nodes per row, type=int (cast with int())
        y: number of nodes per column, type=int (cast with int())
    Outputs
        g: 2D lattice, type=nx.DiGraph
        bulk_nodes: list of non-outflow nodes (row != 0), in g's node order, type=list
    '''
    x, y = int(x), int(y)
    g = nx.grid_2d_graph(x, y)
    # The outflow row (row 0) has no lateral connections.
    g.remove_edges_from([((i1, 0), (i1 + 1, 0)) for i1 in range(x - 1)])
    g = g.to_directed()
    # Drop every upward edge (row -> row + 1) so particles only move sideways or down.
    g.remove_edges_from([((i1, i2), (i1, i2 + 1)) for i1, i2 in product(range(x), range(y - 1))])
    # Periodic lateral boundary for every non-outflow row. With x == 1 the
    # wrap-around edge would be a self-loop, and with x == 2 it is identical
    # to an already existing neighbour edge, so it is only added for x > 2.
    if x > 2:
        g.add_edges_from(
            [((0, i2), (x - 1, i2)) for i2 in range(1, y)]
            + [((x - 1, i2), (0, i2)) for i2 in range(1, y)]
        )
    # Outflow nodes are exactly those in row 0; filter on the row index directly.
    bulk_nodes = [node for node in g if node[1] != 0]
    return g, bulk_nodes

def declare_simulation_variables(g):
    '''
    Initialise all the variables required for a sandpile simulation.
    Input
        g: landscape lattice whose nodes are (column, row) tuples, type=nx.DiGraph
    Outputs
        state: number of particles per node, every node initialised to 0, type=dict
        coupling: distances covered by particles exchanged between nodes, type=dict
            (coupling[i] is a dict; coupling[i][j] starts as an empty list for
            every pair where j is reachable from i, see below)
        current_av: nodes participating in the currently occurring avalanche, empty during accumulation phase, type=list
        branches: length of branches described by the currently occurring avalanche, type=dict
        new_active: nodes having received particles during the last time step, type=list
        size: list of avalanche sizes having occurred during the entire simulation, type=list
    '''
    state, coupling, current_av, branches, new_active, size = {}, {}, [], {}, [], []
    for node in g:
        state[node], coupling[node] = 0, {}
    # One reachability traversal per node, instead of one nx.has_path
    # traversal per node pair; j is reachable from i (i != j) exactly when
    # j is in nx.descendants(g, i).
    reachable = {node: nx.descendants(g, node) for node in g}
    # Sort from the highest row to the lowest (stable within a row), so in
    # every pair (i, j) below, i's row is >= j's row.
    ordlist = sorted(g, key=lambda a: a[1], reverse=True)
    for i, j in combinations(ordlist, 2):
        if j in reachable[i]:
            coupling[i][j] = []
        # Reverse reachability is only checked for same-row pairs (lateral
        # moves). This assumes no edge goes up a row, as in
        # create_baseline_lattice; otherwise some reverse couplings between
        # different rows would not be recorded.
        if i[1] == j[1] and i in reachable[j]:
            coupling[j][i] = []
    return state, coupling, current_av, branches, new_active, size

