In [33]:
import torch 
import numpy as np
import time 
from typing import Tuple
import threading
import time

In [34]:

# ==============================
# Pretty Logger
# ==============================

# Serializes rows printed from the producer and consumer threads.
print_lock = threading.Lock()

def log(iter_tag: str, left: str = "", right: str = "", barrier: str = ""):
  """Print one aligned row of the producer/consumer timeline.

  Columns: the iteration tag left-aligned in 12 characters, the producer
  text in 35, the consumer text in 35, then an optional barrier event.
  The print happens under ``print_lock`` so rows from different threads
  never interleave.
  """
  cells = (format(iter_tag, "<12"), format(left, "<35"), format(right, "<35"))
  row = "{} | {} || {} {}".format(*cells, barrier)
  with print_lock:
    print(row)


# ==============================
# Barrier
# ==============================

class mBar:
  """Phase-parity barrier shared by the producer and consumer threads.

  NOTE(review): judging by the TMA/ldmatrix naming elsewhere this appears to
  be a toy model of a GPU ``mbarrier`` — confirm intent. What this class
  itself implements:

  * ``phase`` is a single bit that starts at 0.
  * Every time ``arrival_count`` arrivals have been recorded, ``phase`` is
    flipped and the pending-arrival counter is re-armed.
  * ``wait(p)`` blocks until ``phase`` is no longer ``p``.
  """

  def __init__(self, name: str, arrival_count: int):
    """Create a barrier that completes a phase after ``arrival_count`` arrivals.

    Raises:
      ValueError: if ``arrival_count`` is not positive (such a barrier could
        never complete a phase).
    """
    if arrival_count <= 0:
      raise ValueError(f"{name}: arrival_count must be positive, got {arrival_count}")
    self.name = name
    self.phase = 0
    self.full_arrival_count = arrival_count      # arrivals needed per phase
    self.current_arrival_count = arrival_count   # arrivals still pending
    self.lock = threading.Lock()
    self.cond = threading.Condition(self.lock)

  def arrive(self, n_threads: int, iter_tag: str):
    """Record ``n_threads`` arrivals; flip the phase when none remain pending.

    Args:
      n_threads: number of arrivals to record (must be >= 1).
      iter_tag: label passed to ``log`` when this call completes the phase.

    Raises:
      ValueError: if ``n_threads`` is not positive.
      RuntimeError: if the arrivals exceed those still pending in the current
        phase. The barrier state is left unchanged in that case.
    """
    if n_threads <= 0:
      raise ValueError(f"{self.name}: n_threads must be positive, got {n_threads}")
    with self.lock:
      # Validate before mutating: a rejected arrival must not leave the
      # counter negative, or every later arrival on this barrier would fail.
      if n_threads > self.current_arrival_count:
        raise RuntimeError(f"{self.name}: over-arrival")

      self.current_arrival_count -= n_threads
      if self.current_arrival_count == 0:
        self.phase ^= 1
        self.current_arrival_count = self.full_arrival_count
        # Logged while still holding the lock: waiters re-check the phase
        # under this same lock, so this line prints before any of them can
        # return from wait().
        log(iter_tag, barrier=f"[{self.name}] → phase {self.phase}")
        self.cond.notify_all()

  def wait(self, expected_phase: int, timeout: "float | None" = None) -> bool:
    """Block until ``phase`` differs from ``expected_phase``.

    Args:
      expected_phase: the parity the caller treats as "not yet completed".
      timeout: optional limit in seconds. ``None`` (the default) waits
        indefinitely, as before.

    Returns:
      True once the phase has moved on; False if ``timeout`` expired first.
    """
    with self.lock:
      return self.cond.wait_for(lambda: self.phase != expected_phase, timeout)


# ==============================
# Config
# ==============================

# Number of pipeline stages the producer and consumer cycle through.
num_stages = 3

# One "empty" and one "full" barrier per stage; each completes a phase after a
# single arrival. All empty barriers are created first, then all full ones.
empty, full = (
  [mBar(f"{kind}[{s}]", 1) for s in range(num_stages)]
  for kind in ("empty", "full")
)

# Loop extents: GM x GN outer iterations, each streaming GK stages.
GM, GN, GK = 2, 2, 6

# Simulated step latencies, in seconds (passed to time.sleep).
tma_sleep = 0.01
ldmatrix_sleep_time = 0.005
mma_sleep_time = 0.001
store_sleep_time = 0.01


# ==============================
# Producer
# ==============================

def producer():
  """Fill pipeline stages in round-robin order.

  For every (i, j, k) iteration: wait until ``empty[stage]`` has left the
  parity the producer expects, log a TMA step and sleep ``tma_sleep``
  seconds, then record one arrival on ``full[stage]``. The stage index
  cycles through ``num_stages`` slots and the expected parity flips each
  time it wraps back to 0.
  """
  tile = 0  # how many stages have been filled so far
  for i in range(GM):
    for j in range(GN):
      for k in range(GK):
        # Ring position and its wrap-count parity, both derived from the counter.
        slot = tile % num_stages
        parity = (tile // num_stages) % 2
        tag = f"({i},{j},{k})"

        empty[slot].wait(parity)

        log(tag, left=f"TMA   s{slot} p{parity}")
        time.sleep(tma_sleep)

        full[slot].arrive(1, tag)
        tile += 1


# ==============================
# Consumer
# ==============================

def consumer():
  """Drain pipeline stages, hand them back, and log a STORE per (i, j).

  Before the main loop, one arrival is recorded on every ``empty`` barrier
  so the producer's first wait on each stage can pass. Then for each
  (i, j, k): wait on ``full[stage]``, log LD and sleep
  ``ldmatrix_sleep_time``, release the stage via ``empty[stage]``, then log
  MMA and sleep ``mma_sleep_time``. After the last k of each (i, j) pair a
  STORE is logged followed by ``store_sleep_time`` of sleep.
  """
  # Mark every stage as initially free for the producer.
  for slot in range(num_stages):
    empty[slot].arrive(1, "init")

  tile = 0  # how many stages have been consumed so far
  for i in range(GM):
    for j in range(GN):
      for k in range(GK):
        slot = tile % num_stages
        parity = (tile // num_stages) % 2
        tag = f"({i},{j},{k})"

        full[slot].wait(parity)

        log(tag, right=f"LD    s{slot} p{parity}")
        time.sleep(ldmatrix_sleep_time)

        # The stage is released right after LD and before MMA, so the
        # producer may refill it while the MMA step is still sleeping.
        empty[slot].arrive(1, tag)

        log(tag, right=f"MMA   s{slot} p{parity}")
        time.sleep(mma_sleep_time)
        tile += 1

      log(f"({i},{j})", right="STORE")
      time.sleep(store_sleep_time)


# ==============================
# Launch Threads
# ==============================

# Column header for the rows printed by log().
print("\nITER        | PRODUCER                            || CONSUMER")
print("-" * 90)

# One thread per role; both are started, then joined, producer first.
t_prod, t_cons = (threading.Thread(target=role) for role in (producer, consumer))

for worker in (t_prod, t_cons):
  worker.start()

for worker in (t_prod, t_cons):
  worker.join()

print("\nDone.")


ITER        | PRODUCER                            || CONSUMER
------------------------------------------------------------------------------------------
init         |                                     ||                                     [empty[0]] → phase 1
init         |                                     ||                                     [empty[1]] → phase 1
init         |                                     ||                                     [empty[2]] → phase 1
(0,0,0)      | TMA   s0 p0                         ||                                     
(0,0,0)      |                                     ||                                     [full[0]] → phase 1
(0,0,1)      | TMA   s1 p0                         ||                                     
(0,0,0)      |                                     || LD    s0 p0                         
(0,0,0)      |                                     ||                                     [empty[0]] → phase 0
(0,0,0)      |     