In [13]:
import torch 
import numpy as np
import time 
from typing import Tuple
import threading
import time

In [14]:
import threading
import time


class mBar:
  """Thread-safe arrival barrier that tracks a one-bit phase.

  The barrier starts in phase 0 with `arrival_count` pending arrivals. Each
  call to `arrive` subtracts from the pending count; when it reaches exactly
  zero the phase bit is toggled, the pending count is re-armed to
  `arrival_count`, and every thread blocked in `wait` is woken up.

  Attributes:
    name: Label used in log and error messages.
    phase: Current phase bit (0 or 1).
    full_arrival_count: Pending count the barrier is re-armed to after a flip.
    current_arrival_count: Arrivals still outstanding in the current phase.
    count_per_thread: Amount each arriving thread contributes.
    lock: Mutex guarding all mutable state.
    cond: Condition variable (sharing `lock`) used to wake waiters.
  """

  def __init__(self, name: str, arrival_count: int, count_per_thread: int = 1) -> None:
    """Create a barrier in phase 0.

    Args:
      name: Label used in log and error messages.
      arrival_count: Total count that must arrive to complete one phase.
      count_per_thread: Count contributed by each arriving thread.

    Raises:
      ValueError: If `arrival_count` or `count_per_thread` is less than 1;
        such a barrier could never complete a phase through normal arrivals.
    """
    if arrival_count < 1:
      raise ValueError(f"{name}: arrival_count must be >= 1, got {arrival_count}")
    if count_per_thread < 1:
      raise ValueError(f"{name}: count_per_thread must be >= 1, got {count_per_thread}")

    self.name = name
    self.phase = 0
    self.full_arrival_count = arrival_count
    self.current_arrival_count = arrival_count
    self.count_per_thread = count_per_thread
    self.lock = threading.Lock()
    self.cond = threading.Condition(self.lock)

  def arrive(self, n_threads: int) -> None:
    """Record the arrival of `n_threads` threads.

    The pending count drops by `n_threads * count_per_thread`. If that brings
    it to zero, the phase flips, the count is re-armed and waiters are
    notified.

    Args:
      n_threads: Number of threads arriving in this call (0 is a no-op).

    Raises:
      ValueError: If `n_threads` is negative.
      RuntimeError: If the arrival would push the pending count below zero.
        The barrier state is left untouched in that case.
    """
    if n_threads < 0:
      raise ValueError(f"{self.name}: n_threads must be >= 0, got {n_threads}")

    # Entering the condition acquires the shared `self.lock`.
    with self.cond:
      remaining = self.current_arrival_count - n_threads * self.count_per_thread

      # Validate before committing so a rejected arrival cannot leave the
      # barrier with a negative (corrupted) pending count.
      if remaining < 0:
        raise RuntimeError(f"{self.name}: Barrier over-arrival")

      if remaining == 0:
        self.phase ^= 1
        self.current_arrival_count = self.full_arrival_count
        print(f"[{self.name}] Phase flipped to {self.phase}")
        self.cond.notify_all()
      else:
        self.current_arrival_count = remaining

  def wait(self, expected_phase: int) -> None:
    """Block while the barrier is still in `expected_phase`.

    Returns immediately if the barrier's phase already differs from
    `expected_phase`.

    Args:
      expected_phase: Phase bit the caller believes the barrier is in.
    """
    with self.cond:
      # wait_for re-checks the predicate after every wake-up, which also
      # guards against spurious wake-ups.
      self.cond.wait_for(lambda: self.phase != expected_phase)


def tma_copy(name: str, bar: mBar, sleep_time: float, n_threads: int):
  """Simulate a TMA copy and signal its completion on `bar`.

  Logs a launch message, sleeps for `sleep_time` seconds to stand in for the
  transfer, logs completion, and finally arrives on `bar` with `n_threads`.

  Args:
    name: Prefix for the two log lines.
    bar: Barrier to arrive on once the simulated copy is done.
    sleep_time: Duration of the simulated copy, in seconds.
    n_threads: Thread count forwarded to `bar.arrive`.
  """
  launch_msg = f"{name}: Launch TMA"
  print(launch_msg)

  # The sleep is the stand-in for the actual data movement.
  time.sleep(sleep_time)

  done_msg = f"{name}: TMA complete"
  print(done_msg)

  # Completion is only signalled after the "copy" has finished.
  bar.arrive(n_threads)


def ldmatrix(name: str, sleep_time: float):
  """Simulate an ldmatrix step.

  Prints one log line prefixed with `name`, then sleeps for `sleep_time`
  seconds to stand in for the work.
  """
  log_line = f"{name}: ldmatrix"
  print(log_line)
  time.sleep(sleep_time)


def mma(name: str, sleep_time: float):
  """Simulate an mma step.

  Prints one log line prefixed with `name`, then sleeps for `sleep_time`
  seconds to stand in for the work.
  """
  log_line = f"{name}: mma"
  print(log_line)
  time.sleep(sleep_time)


# --------------------------------------




In [23]:
# Number of pipeline stages; each stage owns one `empty` and one `full` barrier.
num_stages = 2

# Per-stage barriers, each completing a phase after a single arrival.
# The long run of leading spaces in the `empty` names only serves to push
# their log lines into a separate column of the printed trace.
full = [
  mBar(f"full[{i}]", 1)
  for i in range(num_stages)
]
empty = [
  mBar(f"                                          empty[{i}]", 1)
  for i in range(num_stages)
]

# Trip counts of the three nested loops walked by the producer and consumer.
GM = GN = GK = 4

# Simulated durations, in seconds, of each kind of step.
tma_sleep = 1e-2
ldmatrix_sleep_time = 1e-4
mma_sleep_time = 1e-5


def producer():
  """Producer side of the staged pipeline simulation.

  Runs GM * GN * GK steps. On each step it waits on the current stage's
  `empty` barrier, performs a simulated TMA copy that arrives on the matching
  `full` barrier, and then moves on to the next stage in round-robin order.
  The local phase bit toggles every time the stage index wraps back to 0.
  """
  # Same iteration count and order as three nested loops over GM, GN and GK;
  # the individual indices are not used by the body.
  tile_steps = ((i, j, k) for i in range(GM) for j in range(GN) for k in range(GK))

  for step, _ in enumerate(tile_steps):
    # `stage` cycles through the slots; `wraps` counts completed cycles, and
    # its parity is the phase this thread expects for the slot.
    wraps, stage = divmod(step, num_stages)
    phase = wraps & 1

    # Blocks for as long as the barrier's phase equals our local phase.
    empty[stage].wait(phase)

    label = f"producer: stage {stage}, phase {phase}"
    tma_copy(label, full[stage], tma_sleep, 1)


def consumer():
  """Consumer side of the staged pipeline simulation.

  First arrives once on every `empty` barrier so that all stages start out
  available. It then runs GM * GN * GK steps: wait on the current stage's
  `full` barrier, run a simulated ldmatrix, arrive on the stage's `empty`
  barrier, run a simulated mma, and advance round-robin to the next stage.
  The local phase bit toggles every time the stage index wraps back to 0.
  """
  # Prime the pipeline: flip each `empty` barrier once so every slot is free.
  for slot in range(num_stages):
    empty[slot].arrive(1)

  # Same iteration count and order as three nested loops over GM, GN and GK;
  # the individual indices are not used by the body.
  tile_steps = ((i, j, k) for i in range(GM) for j in range(GN) for k in range(GK))

  for step, _ in enumerate(tile_steps):
    # `stage` cycles through the slots; `wraps` counts completed cycles, and
    # its parity is the phase this thread expects for the slot.
    wraps, stage = divmod(step, num_stages)
    phase = wraps & 1

    # Blocks for as long as the barrier's phase equals our local phase.
    full[stage].wait(phase)

    ld_label = f"                                    consumer_ld: stage {stage}, phase {phase}"
    ldmatrix(ld_label, ldmatrix_sleep_time)

    # The slot is handed back right after ldmatrix, before the mma runs.
    empty[stage].arrive(1)

    mma_label = f"                                         consumer_mma: stage {stage}, phase {phase}"
    mma(mma_label, mma_sleep_time)


In [24]:

# --------------------------------------

# Put every barrier back into its freshly-constructed state before launching.
# The barriers live at module level and keep whatever phase the previous run
# (or an interrupted run) left behind. With the current settings a completed
# run leaves the `empty` barriers in phase 1, so re-running this cell on its
# own would deadlock: the producer and the consumer both start from local
# phase 0 and end up waiting on each other. On a fresh kernel this reset is a
# no-op.
for _bar in (*empty, *full):
  with _bar.lock:
    _bar.phase = 0
    _bar.current_arrival_count = _bar.full_arrival_count

# One thread per pipeline role.
t_prod = threading.Thread(target=producer)
t_cons = threading.Thread(target=consumer)

t_prod.start()
t_cons.start()

# Block the cell until both sides have finished all of their steps.
t_prod.join()
t_cons.join()


[                                          empty[0]] Phase flipped to 1
[                                          empty[1]] Phase flipped to 1
producer: stage 0, phase 0: Launch TMA
producer: stage 0, phase 0: TMA complete
[full[0]] Phase flipped to 1
producer: stage 1, phase 0: Launch TMA
                                    consumer_ld: stage 0, phase 0: ldmatrix
[                                          empty[0]] Phase flipped to 0
                                         consumer_mma: stage 0, phase 0: mma
producer: stage 1, phase 0: TMA complete
[full[1]] Phase flipped to 1
producer: stage 0, phase 1: Launch TMA
                                    consumer_ld: stage 1, phase 0: ldmatrix
[                                          empty[1]] Phase flipped to 0
                                         consumer_mma: stage 1, phase 0: mma
producer: stage 0, phase 1: TMA complete
[full[0]] Phase flipped to 0
producer: stage 1, phase 1: Launch TMA
                                    cons

producer: stage 1, phase 1: TMA complete
[full[1]] Phase flipped to 0
producer: stage 0, phase 0: Launch TMA
                                    consumer_ld: stage 1, phase 1: ldmatrix
[                                          empty[1]] Phase flipped to 1
                                         consumer_mma: stage 1, phase 1: mma
producer: stage 0, phase 0: TMA complete
[full[0]] Phase flipped to 1
producer: stage 1, phase 0: Launch TMA
                                    consumer_ld: stage 0, phase 0: ldmatrix
[                                          empty[0]] Phase flipped to 0
                                         consumer_mma: stage 0, phase 0: mma
producer: stage 1, phase 0: TMA complete
[full[1]] Phase flipped to 1
producer: stage 0, phase 1: Launch TMA
                                    consumer_ld: stage 1, phase 0: ldmatrix
[                                          empty[1]] Phase flipped to 0
                                         consumer_mma: stage 1, phase 0: mm