In [43]:
import numpy as np 

In [44]:
class SharedMemoryMatrix:
    """Named array of random values standing in for a shared-memory tile.

    Args:
        name: Label shown by ``__repr__``.
        shape: Iterable of dimension sizes for the backing array.
        rng: Optional object exposing ``standard_normal(shape)``, for example
            ``np.random.default_rng(seed)``. When supplied, the contents are
            drawn from it so a run can be reproduced. When omitted, the
            global NumPy random state is used, exactly as before.

    Attributes:
        name: The label passed in.
        val: The backing array returned by the random draw.
    """

    def __init__(self, name, shape, rng=None):
        self.name = name
        if rng is None:
            # Legacy path: unseeded draw from NumPy's global random state.
            self.val = np.random.randn(*shape)
        else:
            # Caller-controlled randomness (seedable, hence reproducible).
            self.val = rng.standard_normal(tuple(shape))

    def __repr__(self):
        return f"SMEM({self.name})"

class Register:
    """A named register holding a value and the cycle at which it is usable.

    Attributes:
        name: Label used in ``__repr__``.
        ready_at: First cycle at which the register counts as ready;
            starts at 0.
        val: Current contents (``None`` until something writes it, unless a
            value is passed to the constructor).
    """

    def __init__(self, name, val=None):
        self.name = name
        # A fresh register is ready from cycle 0 onwards.
        self.ready_at = 0
        self.val = val

    def __repr__(self):
        return f"{self.name}@{self.ready_at}"

    def reserve_until(self, t):
        """Record ``t`` as the cycle from which the register is ready again."""
        self.ready_at = t

    def is_ready(self, t):
        """Return True when cycle ``t`` is at or after ``ready_at``."""
        return self.ready_at <= t



from collections import deque


class Pipe:
    """Issue-rate and occupancy model for one functional unit.

    Attributes:
        latency: Cycles between issuing an operation and its completion.
        throughput: Minimum number of cycles between two consecutive issues
            (``issue`` sets ``next_issue = t + throughput``).
        depth: Maximum number of operations allowed in flight at once.
        next_issue: Earliest cycle at which another operation may issue.
        inflight: Completion cycles of the operations still in flight,
            in issue order.
        name: Optional label for the pipe.
    """

    def __init__(self, latency, throughput, depth, name=""):
        self.latency = latency
        self.throughput = throughput
        self.depth = depth
        self.next_issue = 0
        self.inflight = deque()
        self.name = name

    def can_issue(self, t):
        """Return True if the issue slot is free at ``t`` and a queue entry is available."""
        slot_free = t >= self.next_issue
        has_room = len(self.inflight) < self.depth
        return slot_free and has_room

    def issue(self, t):
        """Record an operation issued at cycle ``t`` (no admission check is done here)."""
        self.inflight.append(t + self.latency)
        self.next_issue = t + self.throughput

    def advance(self, t):
        """Retire every leading in-flight entry whose completion cycle is <= ``t``."""
        pending = self.inflight
        while pending:
            if pending[0] > t:
                break
            pending.popleft()


class LDMatrix:
    """Load instruction: copies ``smem.val[smem_idx]`` into a destination register.

    Timing is delegated to a ``Pipe`` named "ldmatrix". The value is written
    to the register at issue time; only the register's ``ready_at`` reflects
    the load latency.
    """

    def __init__(self, latency, throughput, depth):
        self.pipe = Pipe(latency, throughput, depth, name="ldmatrix")

    def can_issue(self, t, smem_idx, dst_reg, smem):
        """Return True when ``dst_reg`` is ready at ``t`` and the pipe accepts an issue.

        ``smem_idx`` and ``smem`` are accepted so the signature mirrors
        ``issue``; they are not consulted here.
        """
        if not dst_reg.is_ready(t):
            return False
        return self.pipe.can_issue(t)

    def issue(self, t, smem_idx, dst_reg, smem):
        """Perform the load at cycle ``t`` and reserve ``dst_reg`` for the pipe latency."""
        # Functional effect happens immediately; the latency is modelled only
        # through the register's ready time.
        dst_reg.val = smem.val[smem_idx]
        dst_reg.reserve_until(t + self.pipe.latency)
        self.pipe.issue(t)



        
class MMA:
    """Multiply-accumulate instruction: ``C.val = C.val + A.val * B.val``.

    Timing is delegated to a ``Pipe`` named "mma". The arithmetic uses
    Python's ``*`` and ``+`` on whatever the registers hold, and is applied
    at issue time; only ``C``'s ``ready_at`` reflects the latency.
    """

    def __init__(self, latency, throughput, depth):
        self.pipe = Pipe(latency, throughput, depth, name="mma")

    def can_issue(self, t, A, B, C):
        """Return True when A, B and C are all ready at ``t`` and the pipe accepts an issue."""
        for operand in (A, B, C):
            if not operand.is_ready(t):
                return False
        return self.pipe.can_issue(t)

    def issue(self, t, A, B, C):
        """Accumulate ``A.val * B.val`` into ``C`` and reserve ``C`` for the pipe latency."""
        # An accumulator that was never written starts from zero.
        if C.val is None:
            C.val = 0.0
        product = A.val * B.val
        C.val = C.val + product
        C.reserve_until(t + self.pipe.latency)
        self.pipe.issue(t)





In [45]:
class WarpMachine:
    """In-order issue model of one warp driving a load unit and an MMA unit.

    Args:
        A_regs, B_regs, C_regs: Register collections, stored as ``A``, ``B``
            and ``C``. They are kept for reference and not used by
            ``step`` or ``run``.
        As, Bs: Shared-memory objects, stored as-is.
        ld: Load instruction unit; must expose ``pipe.advance``,
            ``can_issue`` and ``issue``.
        mma: MMA instruction unit with the same interface.

    Attributes:
        t: Current cycle. Starts at 0 and is never reset by ``run``, so a
            second ``run`` on the same machine continues from where the
            previous one stopped.
    """

    def __init__(self, A_regs, B_regs, C_regs, As, Bs, ld, mma):
        self.t = 0

        self.A = A_regs
        self.B = B_regs
        self.C = C_regs

        self.As = As
        self.Bs = Bs

        self.ld = ld
        self.mma = mma

    def step(self, inst, operands):
        """Try to issue ``inst`` with ``operands`` at the current cycle.

        Both pipes are advanced to the current cycle first so completed
        operations free their queue slots.

        Returns:
            True if the instruction was issued, False if it has to stall.
        """
        self.ld.pipe.advance(self.t)
        self.mma.pipe.advance(self.t)

        if inst.can_issue(self.t, *operands):
            inst.issue(self.t, *operands)
            return True

        return False

    def run(self, program, max_cycles=None):
        """Issue ``program`` in order, stalling one cycle at a time.

        Args:
            program: Sequence of ``(inst, operands)`` pairs.
            max_cycles: Optional cycle limit. ``None`` means no limit. Any
                other value, including 0, stops the run as soon as
                ``self.t`` has reached it, before another instruction is
                issued.

        Returns:
            ``self.t`` when the loop exits. The clock only advances on
            stalls, so this is the cycle at which the last instruction was
            issued; the latency of operations still in flight is not added.
            Several instructions may issue in the same cycle.
        """
        pc = 0

        while pc < len(program):
            # Compare against None so a limit of 0 is honoured, and check
            # before issuing so nothing is issued at or past the limit.
            if max_cycles is not None and self.t >= max_cycles:
                break

            inst, operands = program[pc]

            if self.step(inst, operands):
                pc += 1
            else:
                self.t += 1

        return self.t


In [46]:
# Hand-written example: two independent load / load / mma groups.

# Three registers of each kind; only indices 0 and 1 are used by the program.
A0, A1, A2 = (Register(f"A{i}") for i in range(3))
B0, B1, B2 = (Register(f"B{i}") for i in range(3))
C0, C1, C2 = (Register(f"C{i}") for i in range(3))

As = SharedMemoryMatrix("As", (8, 8))
Bs = SharedMemoryMatrix("Bs", (8, 8))

# Pipe parameters are (latency, throughput, depth).
ld = LDMatrix(latency=60, throughput=11, depth=8)
mma = MMA(latency=34, throughput=4, depth=8)

machine = WarpMachine(
    [A0, A1, A2],
    [B0, B1, B2],
    [C0, C1, C2],
    As,
    Bs,
    ld,
    mma,
)

# Each load's first operand is a single integer index into the 8x8 array.
program = [
    # group 0
    (ld, (0, A0, As)),
    (ld, (1, B0, Bs)),
    (mma, (A0, B0, C0)),
    # group 1
    (ld, (2, A1, As)),
    (ld, (3, B1, Bs)),
    (mma, (A1, B1, C1)),
]

cycles = machine.run(program)
print("cycles:", cycles)


cycles: 142


In [51]:
# Register-tile configuration. The program-generation cell strides its bm/bn
# loops by c_reg_m / c_reg_n and indexes A_regs[0..1], B_regs[0..1] and
# C_regs[0..3] directly, so these counts need to stay in step with it.
c_reg_m = 2
c_reg_n = 2
n_a_regs = 2
n_b_regs = 2
n_c_regs = 4

# Block dimensions used for the shared-memory arrays and the loop bounds.
BM = BN = BK = 16

A_regs = [Register(f"A{idx}") for idx in range(n_a_regs)]
B_regs = [Register(f"B{idx}") for idx in range(n_b_regs)]
C_regs = [Register(f"C{idx}") for idx in range(n_c_regs)]

# Pipe parameters are (latency, throughput, depth).
ld = LDMatrix(latency=60, throughput=11, depth=8)
mma = MMA(latency=34, throughput=4, depth=8)

As = SharedMemoryMatrix("As", (BM, BK))
Bs = SharedMemoryMatrix("Bs", (BK, BN))

# Fresh machine, so its clock starts again at t = 0.
machine = WarpMachine(
    A_regs=A_regs,
    B_regs=B_regs,
    C_regs=C_regs,
    As=As,
    Bs=Bs,
    ld=ld,
    mma=mma,
)

In [52]:
program = []

# Emit, for every (bk, bm, bn) step, four loads followed by four MMAs.
# The tile is written out by hand as 2x2 (rows bm and bm+1, columns bn and
# bn+1), and the same four C registers are the accumulators in every step.
for bk in range(BK):
    for bm in range(0, BM, c_reg_m):
        for bn in range(0, BN, c_reg_n):
            # Loads: two elements of As (column bk) and two of Bs (row bk).
            a_load_0 = (ld, ((bm, bk), A_regs[0], As))
            a_load_1 = (ld, ((bm + 1, bk), A_regs[1], As))
            b_load_0 = (ld, ((bk, bn), B_regs[0], Bs))
            b_load_1 = (ld, ((bk, bn + 1), B_regs[1], Bs))

            # MMAs: C_regs[2 * i + j] accumulates A_regs[i] * B_regs[j].
            mma_00 = (mma, (A_regs[0], B_regs[0], C_regs[0]))
            mma_01 = (mma, (A_regs[0], B_regs[1], C_regs[1]))
            mma_10 = (mma, (A_regs[1], B_regs[0], C_regs[2]))
            mma_11 = (mma, (A_regs[1], B_regs[1], C_regs[3]))

            # Issue order: A0, both B loads, then A1, then the four MMAs.
            program.extend(
                [
                    a_load_0, b_load_0, b_load_1, a_load_1,
                    mma_00, mma_01, mma_10, mma_11,
                ]
            )


In [53]:
# Simulate the generated program. run() continues from machine.t and does not
# reset it, so re-running this cell without rebuilding `machine` keeps
# counting from the previous result.
cycles = machine.run(program)

In [54]:
# Display the value most recently assigned to `cycles` by machine.run(program).
cycles

99328