In [4]:
# Markov Transition Matrix from a streaming data source
import numpy as np
import itertools
# https://sam-black.medium.com/creating-a-markov-transition-matrix-from-a-streaming-data-source-995fcf28422
class MarkovTransitionMatrix:
  """Estimate a first-order Markov transition matrix from streamed observations.

  Each observation is a ``(previous_state, current_state)`` pair. The class
  keeps running counts and derives the row-normalised transition
  probabilities from them on demand.

  Public attributes:
    states: the state labels, sorted; row/column ``i`` of every matrix
      corresponds to ``states[i]``.
    n_states: number of states.
    state_combinations: ``"<from>_<to>"`` labels for every ordered pair of
      states, in row-major order.
    state_combination_matrix: the same labels as an ``n_states x n_states``
      array.
    counter: dict holding, per state, how many transitions left that state
      and, per ``"<from>_<to>"`` label, how often that transition was seen.
    state_transition_matrix: the most recently computed probability matrix.
  """

  def __init__(self, states):
    """Set up zeroed counts for the given state labels.

    Args:
      states: iterable of at least two distinct, sortable, hashable labels.

    Raises:
      ValueError: if fewer than two states are given, if the states contain
        duplicates, or if the labels become ambiguous once joined with "_"
        (for example "a" together with "a_a").
    """
    self.states = sorted(states)
    if len(self.states) < 2:
      raise ValueError("need at least 2 states for a markov process")
    # Check the sorted copy: `states` may be a one-shot iterator that
    # sorted() has already consumed.
    if len(set(self.states)) != len(self.states):
      raise ValueError("the state vector contains duplicate states")
    self.n_states = len(self.states)
    # Label -> row/column position, so lookups do not need list.index().
    self._state_index = {state: position for position, state in enumerate(self.states)}
    # constructs a square matrix representing the state transition matrix
    self.state_transition_matrix = np.zeros((self.n_states, self.n_states))
    self.state_combinations = [
      "_".join(str(_) for _ in pair)
      for pair in itertools.product(self.states, repeat=2)
    ]
    counter_keys = list(self.states) + self.state_combinations
    # State counts and transition counts share one dict; colliding keys would
    # silently merge unrelated counts and yield wrong probabilities.
    if len(set(counter_keys)) != len(counter_keys):
      raise ValueError("state labels are ambiguous once joined with '_'")
    self.state_combination_matrix = np.array(self.state_combinations).reshape(
      (self.n_states, self.n_states)
    )
    # initializes counts
    self.counter = dict.fromkeys(counter_keys, 0)

  def observe_state(self, observation, return_state_transitions=True):
    """Record one transition and optionally return the updated matrix.

    Args:
      observation: pair ``(previous_state, current_state)``.
      return_state_transitions: when true, recompute and return the
        transition matrix; otherwise only the counts are updated and
        ``None`` is returned.

    Raises:
      ValueError: if ``observation`` does not have exactly two elements.
      KeyError: if either element is not a known state. The counts are left
        untouched in that case.
    """
    if len(observation) != 2:
      raise ValueError("a state observation must be a tuple containing last state and the current state")
    previous, current = observation
    # Resolve both states before mutating anything so a bad observation
    # cannot leave the counts half-updated.
    try:
      row = self._state_index[previous]
      col = self._state_index[current]
    except KeyError as err:
      raise KeyError(f"unknown state in observation: {err.args[0]!r}") from None
    self.counter[self.states[row]] += 1
    self.counter[self.state_combinations[row * self.n_states + col]] += 1
    if return_state_transitions:
      return self._get_state_transition_matrix()

  def _get_state_transition_matrix(self):
    """Recompute the transition probabilities and return a copy of the matrix.

    Entry ``[i][j]`` is the number of observed ``states[i] -> states[j]``
    transitions divided by the number of transitions leaving ``states[i]``.
    Rows of states that were never observed as a source are left unchanged
    (all zeros on a fresh instance).
    """
    for row, source in enumerate(self.states):
      departures = self.counter[source]
      if departures == 0:
        continue
      for col in range(self.n_states):
        # state_combinations is row-major, matching the matrix layout.
        combination_key = self.state_combinations[row * self.n_states + col]
        self.state_transition_matrix[row, col] = self.counter[combination_key] / departures
    return self.state_transition_matrix.copy()
    
# Demo: feed a random state sequence through the matrix one transition at a
# time and print the running estimate every 10000 transitions.
states = ["x", "y", "z", "p", "q", "r"]
mtm = MarkovTransitionMatrix(states)
print(mtm.state_combination_matrix)

observations = np.random.choice(
  states,
  size=300000,
  p=[0.2, 0.1, 0.2, 0.4, 0.05, 0.05],
)

# Pair each element with its predecessor; `i` is the index of the current
# element, so the first pair is numbered 1.
for i, (prev, current) in enumerate(zip(observations[:-1], observations[1:]), start=1):
  _obs = (prev, current)
  report_now = i % 10000 == 0
  snapshot = mtm.observe_state(_obs, return_state_transitions=report_now)
  if report_now:
    print(snapshot.tolist())

[['p_p' 'p_q' 'p_r' 'p_x' 'p_y' 'p_z']
 ['q_p' 'q_q' 'q_r' 'q_x' 'q_y' 'q_z']
 ['r_p' 'r_q' 'r_r' 'r_x' 'r_y' 'r_z']
 ['x_p' 'x_q' 'x_r' 'x_x' 'x_y' 'x_z']
 ['y_p' 'y_q' 'y_r' 'y_x' 'y_y' 'y_z']
 ['z_p' 'z_q' 'z_r' 'z_x' 'z_y' 'z_z']]
[[0.3933214376752485, 0.05174611266887586, 0.04919704307927607, 0.19423910272750447, 0.10833545755799133, 0.20316084629110376], [0.3737957610789981, 0.04046242774566474, 0.05202312138728324, 0.20809248554913296, 0.10211946050096339, 0.22350674373795762], [0.414, 0.06, 0.028, 0.176, 0.104, 0.218], [0.39028056112224446, 0.05110220440881764, 0.057615230460921846, 0.20140280561122245, 0.10170340681362726, 0.19789579158316634], [0.3775894538606403, 0.0583804143126177, 0.038606403013182675, 0.2128060263653484, 0.1111111111111111, 0.2015065913370998], [0.399, 0.0505, 0.055, 0.205, 0.1055, 0.185]]
[[0.3946137311923126, 0.046149955746617774, 0.05158679984827412, 0.1998988494120622, 0.10696674674421545, 0.20078391705651788], [0.3829145728643216, 0.05025125628140704