In [2]:
from typing import Tuple

import numpy as np

class MemoryState:
    """Descriptor for one level of the simulated memory hierarchy.

    Attributes:
        name: Label used in reprs and error messages.
        unit_access_cost: Cost charged per element read from this level.
        capacity_bytes: Maximum number of bytes a tensor at this level may occupy.
        level: Hierarchy level tag; stored as given and not used by this class.
    """

    def __init__(self, name, unit_access_cost, capacity_bytes,level):
        self.name = name
        self.level = level
        self.capacity_bytes = capacity_bytes
        self.unit_access_cost = unit_access_cost

    def __repr__(self):
        """Return ``name(Cost=..., Max=...)`` with the capacity in a readable unit."""
        size = self.capacity_bytes
        # Scan binary units from largest to smallest; the first one that fits wins.
        for factor, suffix in ((1024**3, "GB"), (1024**2, "MB"), (1024, "KB")):
            if size >= factor:
                pretty = f"{size / factor:.1f} {suffix}"
                break
        else:
            # Below 1 KB: print the raw byte count unformatted.
            pretty = f"{size} B"

        return f"{self.name}(Cost={self.unit_access_cost}, Max={pretty})"



class AugmentedTensor:
    """A numpy array bound to a memory level, with access-cost bookkeeping.

    Attributes:
        data: The underlying numpy array.
        state: MemoryState-like object providing ``name``, ``capacity_bytes``
            and ``unit_access_cost``.
        total_cost: Accumulated cost of every read made FROM this tensor.
    """

    def __init__(self, data, state):
        """
        :param data: Array-like contents, OR a tuple/list that is interpreted
                     as the SHAPE of a zero-filled array (numpy default dtype,
                     float64). Pass an ndarray if you mean actual values.
        :param state: MemoryState object (holds the cost logic)
        :raises MemoryError: if the array needs more bytes than the state allows.
        """
        # 1. Initialize Data
        if isinstance(data, (tuple, list)):
            self.data = np.zeros(data)  # tuple/list means "shape", not values
        else:
            self.data = np.array(data)  # copies, so the caller's array is not aliased

        self.state = state
        self.total_cost = 0.0

        # 2. CAPACITY CHECK
        tensor_bytes = self.data.nbytes
        if tensor_bytes > self.state.capacity_bytes:
            raise MemoryError(
                f"OOM: Cannot allocate tensor of size {tensor_bytes} bytes "
                f"in {self.state.name} (Max: {self.state.capacity_bytes} bytes)."
            )

    def load(self, source_tensor):
        """
        B.load(A) -> Copies data from A to B.
        The COST is incurred by A (the Source): element count of the data read
        times the source state's unit access cost.

        The cost is charged only after the copy succeeded, so a load that
        raises (e.g. non-broadcastable shapes) leaves the source's total_cost
        untouched.

        :param source_tensor: an AugmentedTensor or a TensorSlice of one.
        :raises ValueError: if the source is neither of those types, or if
                            numpy cannot broadcast the source into this tensor.
        """
        # No capacity check is needed here: 'self' was checked in __init__
        # and an in-place copy never changes its size.
        if isinstance(source_tensor, AugmentedTensor):
            from_slice = False
        elif isinstance(source_tensor, TensorSlice):
            from_slice = True
        else:
            raise ValueError("Source must be an AugmentedTensor or TensorSlice")

        # Ellipsis assignment behaves like [:] for ndim >= 1 and also works
        # for 0-d tensors, where [:] would raise IndexError.
        self.data[...] = source_tensor.data

        # Charge the source only now that the copy went through.
        if from_slice:
            source_tensor.report_access()  # a slice bills its parent
        else:
            source_tensor.total_cost += (
                source_tensor.data.size * source_tensor.state.unit_access_cost
            )

    def __getitem__(self, key):
        """Return a TensorSlice over ``self.data[key]`` that bills this tensor when loaded."""
        return TensorSlice(self, key)

    def __repr__(self):
        return f"[{self.state.name}] Shape:{self.data.shape} | Size:{self.data.nbytes} B | AccCost: {self.total_cost}"


class TensorSlice:
    """
    Lightweight handle produced by indexing a tensor (``A[key]``).

    It keeps a reference to the tensor it was cut from so that, when the
    slice is consumed by ``load()``, the access cost lands on that PARENT
    rather than on the slice itself.
    """
    def __init__(self, parent, slice_key):
        selected = parent.data[slice_key]  # whatever parent.data yields for this key
        self.parent, self.data, self.state = parent, selected, parent.state

    def report_access(self):
        """ Bill the parent: number of elements in the slice times the unit cost """
        owner = self.parent
        charge = self.data.size * self.state.unit_access_cost
        owner.total_cost += charge




In [3]:
# Capacities of the modelled device; the unit is encoded in each key's suffix
# (_kb / _mb / _gb). "N_sms" is presumably the SM count (not used in this cell).
# NOTE(review): "rmem_kb" equals "SMEM_kb" (228); NVIDIA documents a 256 KB
# register file per SM for H100 - confirm which value is intended.
H100 = {"N_sms":132,"rmem_kb":228, "SMEM_kb":228, "DSMEM_mb":3.5, "GMEM_gb":80}

# MemoryState(name, unit_access_cost, capacity_bytes, level)
# Capacities are cast to int so byte counts stay integral: 3.5 * 1024**2 is a
# float and would otherwise show up as "3670016.0 bytes" in OOM messages.
GMEM = MemoryState("gmem", 500, int(H100["GMEM_gb"] * 1024**3), 3)
DSMEM = MemoryState("DSMEM", 100, int(H100["DSMEM_mb"] * 1024**2), 2)
SMEM = MemoryState("SMEM", 30, int(H100["SMEM_kb"] * 1024), 1)
RMEM = MemoryState("RMEM", 1, int(H100["rmem_kb"] * 1024), 0)



In [None]:
# We make the simplifying assumption of tiling-shift isomorphism across parallel processors: that is,
# we only care about what cluster_0 processes (its whole loop), what block_0 within cluster_0 processes, and so on.
# We also assume that keeping K as the outermost loop is always the best ordering.

class make_matmul:
  """Hierarchically tiled matmul skeleton used to tally memory-access cost.

  Allocates A/B/C buffers at four memory levels (GMEM -> DSMEM -> SMEM -> RMEM)
  and, in run(), walks the tiling loops issuing load() calls between adjacent
  levels. Each load adds to the ``total_cost`` of the tensor being read.
  """

  def __init__(self, g_shape: Tuple[int,int,int], c_shape: Tuple[int,int], c_loop_shape: Tuple[int,int,int],
               b_shape: Tuple[int,int], b_loop_shape: Tuple[int,int,int],
               t_shape: Tuple[int,int], t_loop_shape: Tuple[int,int,int]):
    """
    :param g_shape: (M, K, N) of the global problem; gA is MxK, gB is KxN, gC is MxN.
    :param c_shape: (CM, CN) extent iterated at the cluster level.
    :param c_loop_shape: (cm, ck, cn) loop steps at the cluster level and shapes of the DSMEM tiles.
    :param b_shape: (BM, BN) extent iterated at the block level.
    :param b_loop_shape: (bm, bk, bn) loop steps at the block level and shapes of the SMEM tiles.
    :param t_shape: (TM, TN) extent iterated at the innermost tile level.
    :param t_loop_shape: (tm, tk, tn) loop steps at the innermost level and shapes of the RMEM tiles.
    :raises MemoryError: (from AugmentedTensor) if a buffer exceeds its level's capacity.
    """
    self.M, self.K, self.N = g_shape
    self.CM, self.CN = c_shape
    self.cm, self.ck, self.cn = c_loop_shape
    self.BM, self.BN = b_shape
    self.bm, self.bk, self.bn = b_loop_shape
    self.TM, self.TN = t_shape
    self.tm, self.tk, self.tn = t_loop_shape

    # Global operands: random inputs, zero-initialised output.
    self.gA = AugmentedTensor(np.random.randn(self.M, self.K), GMEM)
    self.gB = AugmentedTensor(np.random.randn(self.K, self.N), GMEM)
    self.gC = AugmentedTensor(np.zeros((self.M, self.N)), GMEM)

    # One staging tile per operand at every lower level.
    self.cA = AugmentedTensor(np.zeros((self.cm, self.ck)), DSMEM)
    self.cB = AugmentedTensor(np.zeros((self.ck, self.cn)), DSMEM)
    self.cC = AugmentedTensor(np.zeros((self.cm, self.cn)), DSMEM)

    self.sA = AugmentedTensor(np.zeros((self.bm, self.bk)), SMEM)
    self.sB = AugmentedTensor(np.zeros((self.bk, self.bn)), SMEM)
    self.sC = AugmentedTensor(np.zeros((self.bm, self.bn)), SMEM)

    self.rA = AugmentedTensor(np.zeros((self.tm, self.tk)), RMEM)
    self.rB = AugmentedTensor(np.zeros((self.tk, self.tn)), RMEM)
    self.rC = AugmentedTensor(np.zeros((self.tm, self.tn)), RMEM)

  def run(self):
    """Walk the tiling hierarchy (K outermost at every level) and issue the loads.

    Tiles of A and B are staged GMEM -> DSMEM -> SMEM -> RMEM. For every
    element of the innermost tile, and for each of its tk steps, the matching
    gC element is loaded into rC.

    NOTE(review): only loads are issued. No multiply-accumulate and no
    write-back to sC/cC/gC happens yet, and the gC load broadcasts a single
    scalar over the whole rC tile. Confirm the intended accumulator traffic.
    """
    for ck_idx in range(0, self.K, self.ck):
      for cm_idx in range(0, self.CM, self.cm):
        for cn_idx in range(0, self.CN, self.cn):
          self.cA.load(self.gA[cm_idx:cm_idx+self.cm, ck_idx:ck_idx+self.ck])
          self.cB.load(self.gB[ck_idx:ck_idx+self.ck, cn_idx:cn_idx+self.cn])
          for bk_idx in range(0, self.ck, self.bk):
            for bm_idx in range(0, self.BM, self.bm):
              for bn_idx in range(0, self.BN, self.bn):
                self.sA.load(self.cA[bm_idx:bm_idx+self.bm, bk_idx:bk_idx+self.bk])
                self.sB.load(self.cB[bk_idx:bk_idx+self.bk, bn_idx:bn_idx+self.bn])
                for tk_idx in range(0, self.bk, self.tk):
                  for tm_idx in range(0, self.TM, self.tm):
                    for tn_idx in range(0, self.TN, self.tn):
                      self.rA.load(self.sA[tm_idx:tm_idx+self.tm, tk_idx:tk_idx+self.tk])
                      self.rB.load(self.sB[tk_idx:tk_idx+self.tk, tn_idx:tn_idx+self.tn])
                      # Iterate over the tile EXTENTS (tk, tm, tn). Using the
                      # tile offsets as bounds would skip the first tile and
                      # grow the trip count with the offset.
                      for _k in range(self.tk):
                        for m in range(self.tm):
                          # Global row: sum of the offsets at every level plus the in-tile row.
                          c_global_m_coord = tm_idx + bm_idx + cm_idx + m
                          for n in range(self.tn):
                            c_global_n_coord = tn_idx + bn_idx + cn_idx + n
                            self.rC.load(self.gC[c_global_m_coord, c_global_n_coord])


# `run` was previously defined at module level; keep that name callable as run(instance).
run = make_matmul.run
              
        
        

    
