In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import random

In [2]:
# Disease presets: name -> (infection length in steps, base infection probability).
#   [0] is loaded into Cell.rem_infect_time when a cell becomes infected and counted down once per step.
#   [1] is the base increment added to a susceptible cell's infection probability when it has infected neighbours.
MODE = "flu"
infect_len_dict = {"flu": (7, 0.38), "cold": (5, 0.5), "covid": (7, 0.63), "stomach bug": (6, 0.49)}
assert MODE in infect_len_dict, f"unknown MODE {MODE!r}; choose one of {sorted(infect_len_dict)}"

# The simulation draws from two RNGs: numpy (cell placement and movement) and stdlib
# `random` (infection draws). Seed both so Restart & Run All reproduces the same figures.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

In [3]:
# @title
class Sim:
  """Grid-based agent simulation of disease spread.

  The board is a 2-D object array whose tiles hold either 0 (empty) or a Cell.
  The first cell starts infected at (0, 0); the remaining cells start
  susceptible on distinct random tiles. On each step every cell (in list
  order) takes one random orthogonal move to an empty tile and then updates
  its state from the number of infected cells among its 8 neighbours. The
  board is plotted after construction and after every step.

  Note: `Cell` is defined in a later cell; it only needs to exist by the time
  a Sim is constructed.
  """

  # Orthogonal moves as (row, col) offsets.
  _MOVES = ((1, 0), (-1, 0), (0, 1), (0, -1))
  # The 8 surrounding tiles as (row, col) offsets.
  _NEIGHBOR_OFFSETS = (
    (-1, -1), (-1, 0), (-1, 1),
    (0, -1),           (0, 1),
    (1, -1),  (1, 0),  (1, 1),
  )

  def __init__(self, size=(10, 10), num_cells=5):
    """Create the board, place the cells, and plot the initial state.

    Args:
      size: (rows, cols) of the board.
      num_cells: total number of cells, including the initially infected one
        at (0, 0). At least that one cell is always created.

    Raises:
      ValueError: if the cells cannot all be placed on distinct tiles.
    """
    rows, cols = size
    n_tiles = rows * cols
    if max(num_cells, 1) > n_tiles:
      raise ValueError(f"num_cells={num_cells} does not fit on a {rows}x{cols} board")

    self.board = np.zeros(size, dtype=object)
    self.cells = [Cell((0, 0), True)]
    self.board[0, 0] = self.cells[0]

    # Sample distinct flat tile indices, excluding 0 == (0, 0), so no two cells
    # share a tile. Independent draws could collide and silently drop a cell
    # from the board while it kept acting in self.cells.
    n_others = max(num_cells - 1, 0)
    flat_positions = np.random.choice(np.arange(1, n_tiles), size=n_others, replace=False)
    for flat in flat_positions:
      pos = divmod(int(flat), cols)
      cell = Cell(pos)
      self.cells.append(cell)
      self.board[pos] = cell

    # time must exist before plot(), which uses it in the figure title.
    self.time = 0
    self.plot()

  def step(self):
    """Advance one time step: move then update each cell in turn, then plot."""
    for cell in self.cells:
      self.move_cell(cell)
      self.check_neighborhood(cell)

    self.time += 1
    self.plot()

  def is_legal(self, new_pos, move=True):
    """Return True if new_pos is on the board and, when `move` is True, also empty."""
    r, c = new_pos
    rows, cols = self.board.shape
    if not (0 <= r < rows and 0 <= c < cols):
      return False
    if move and self.board[r, c] != 0:
      return False
    return True

  def move_cell(self, cell):
    """Move `cell` one tile in a uniformly random legal orthogonal direction.

    Picking uniformly among the legal moves has the same distribution as
    retrying random directions until one is legal. Unlike a retry loop, it
    terminates when the cell is boxed in by edges and other cells; in that
    case the cell stays where it is.
    """
    r, c = cell.pos
    options = [(r + dr, c + dc) for dr, dc in self._MOVES if self.is_legal((r + dr, c + dc))]
    if not options:
      return

    new_pos = options[np.random.randint(len(options))]
    self.board[new_pos] = cell
    self.board[r, c] = 0
    cell.update_pos(new_pos)

  def check_neighborhood(self, cell):
    """Count infected (state == 1) cells in the 8 tiles around `cell` and pass the count to cell.update_state."""
    r, c = cell.pos
    infected = 0
    for dr, dc in self._NEIGHBOR_OFFSETS:
      pos = (r + dr, c + dc)
      if self.is_legal(pos, False):
        tile = self.board[pos]
        if tile != 0 and tile.state == 1:
          infected += 1
    cell.update_state(infected)

  def plot(self):
    """Draw the board as an RGBA image (one colour per tile kind) with a legend, titled with the current time."""
    cmap = {0: [0, 0, 0, 1], 1: [0.1, 0.1, 1, 1], 2: [1, 0.1, 0.1, 1], 3: [0.1, 1, 0.1, 1]}
    labels = {0: "empty", 1: "susceptible", 2: "infected", 3: "recovered"}

    rows, cols = self.board.shape
    state_board = np.zeros((rows, cols, 4))
    for i in range(rows):
      for j in range(cols):
        tile = self.board[i, j]
        # Empty tiles use key 0; a cell with state s uses key s + 1.
        state_board[i, j] = cmap[0] if tile == 0 else cmap[tile.state + 1]

    fig, ax = plt.subplots()
    ax.imshow(state_board)
    ax.set_title(f"t = {self.time}")
    patches = [mpatches.Patch(color=cmap[k], label=labels[k]) for k in cmap]
    ax.legend(handles=patches, loc="lower right", borderaxespad=0)

In [4]:
# @title
class Cell:
  """A single agent on the Sim board.

  `state` is 0 (susceptible), 1 (infected) or 2 (recovered). Disease
  parameters come from `infect_len_dict[mode]`: index 0 is the number of steps
  an infection lasts, and index 1 is the base infection-probability increment.
  """

  SUSCEPTIBLE, INFECTED, RECOVERED = 0, 1, 2
  IDLE_SPREAD = 0.01  # per-step probability increment when no neighbour is infected

  def __init__(self, start_pos, start_infect=False, mode=None):
    """Create a cell.

    Args:
      start_pos: (row, col) starting tile.
      start_infect: if True, the cell starts infected with a full infection length.
      mode: key into infect_len_dict. None (the default) means the global MODE
        is used at each lookup, matching the original behaviour.
    """
    self.pos = start_pos
    self.mode = mode
    self.infect_prob = 0
    if start_infect:
      self.state = self.INFECTED
      self.rem_infect_time = self._disease()[0]
    else:
      self.state = self.SUSCEPTIBLE
      self.rem_infect_time = 0

  def _disease(self):
    """Return the (infection length, base probability) tuple for this cell's mode."""
    return infect_len_dict[MODE if self.mode is None else self.mode]

  def update_pos(self, new_pos):
    """Record the cell's new (row, col); the board itself is updated by the caller."""
    self.pos = new_pos

  def update_state(self, neighbors):
    """Advance the cell's disease state by one step.

    Args:
      neighbors: number of infected cells in the surrounding tiles.

    A susceptible cell first adds to its accumulated infection probability.
    It adds a small idle amount with no infected neighbours, and a larger,
    log-diminishing amount with more of them. It then becomes infected with
    that probability, or for certain once the probability reaches 1. An
    infected cell counts down its remaining infection time and recovers when
    the time runs out.
    """
    if self.state == self.SUSCEPTIBLE:
      infect_len, base_prob = self._disease()
      if neighbors == 0:
        self.infect_prob += self.IDLE_SPREAD  # idle environmental spread
      elif neighbors == 1:
        # Add the full base probability while below it; afterwards grow slowly.
        if self.infect_prob < base_prob:
          self.infect_prob += base_prob
        else:
          self.infect_prob += np.log10(1.1)
      else:
        dim_factor = np.log10(1 + (neighbors * 0.2))
        if self.infect_prob < base_prob:
          self.infect_prob += base_prob + dim_factor
        else:
          self.infect_prob += dim_factor

      # Bernoulli draw with probability infect_prob; certain once it reaches 1.
      if self.infect_prob >= 1 or random.random() < self.infect_prob:
        self.state = self.INFECTED
        self.infect_prob = 0
        self.rem_infect_time = infect_len

    elif self.state == self.INFECTED:
      self.rem_infect_time -= 1
      # `<=` rather than `==`: a non-positive infection length still recovers.
      if self.rem_infect_time <= 0:
        self.state = self.RECOVERED

In [None]:
N_STEPS = 10  # number of simulation steps; each step draws one more figure

sim = Sim()
for _ in range(N_STEPS):
  sim.step()