In [None]:
import numpy as np

class MemoryState:
    """
    One level of the simulated memory hierarchy.

    :param name: Label shown in reprs and OOM error messages (e.g. "SMEM").
    :param unit_access_cost: Cost charged per element when data is read
        from this level.
    :param capacity_bytes: Maximum tensor size, in bytes, this level accepts.
    """

    def __init__(self, name, unit_access_cost, capacity_bytes):
        self.name = name
        self.unit_access_cost = unit_access_cost
        self.capacity_bytes = capacity_bytes

    def __repr__(self):
        # Render the capacity in the largest binary unit it reaches
        # (GB / MB / KB); anything under 1 KB is shown as raw bytes.
        size = self.capacity_bytes
        for scale, unit in ((1024**3, "GB"), (1024**2, "MB"), (1024, "KB")):
            if size >= scale:
                cap_str = f"{size / scale:.1f} {unit}"
                break
        else:
            cap_str = f"{size} B"

        return f"{self.name}(Cost={self.unit_access_cost}, Max={cap_str})"

# --- Define Your Hierarchy (approximate Hopper capacities) ---
# unit_access_cost is a relative, per-element read cost (REG is the cheapest
# at 1.0); capacity_bytes is the largest tensor each level will accept.

# GMEM: modeled as 80 GB -- effectively unlimited for these experiments.
GMEM = MemoryState("GMEM", unit_access_cost=100.0, capacity_bytes=80 * 2**30)

# DSMEM: distributed SMEM at cluster level -- large, but still finite (100 MB).
DSMEM = MemoryState("DSMEM", unit_access_cost=20.0, capacity_bytes=100 * 2**20)

# SMEM: Hopper allows roughly 227 KB per SM; the limit is enforced strictly.
SMEM = MemoryState("SMEM", unit_access_cost=10.0, capacity_bytes=227 * 2**10)

# REG: 64K 32-bit registers per SM = 256 KB in total. Very tight constraint.
REG = MemoryState("REG", unit_access_cost=1.0, capacity_bytes=64 * 2**10 * 4)

class AugmentedTensor:
    """
    A numpy array tagged with a simulated memory level (``state``).

    The tensor's byte size is checked against ``state.capacity_bytes`` when it
    is constructed. Each time the tensor (or a slice of it) is the SOURCE of a
    ``load()``, ``size * state.unit_access_cost`` is added to its
    ``total_cost``.
    """

    def __init__(self, data, state, dtype=None):
        """
        :param data: Array-like data (copied via ``np.array``), or a shape
            given as a tuple/list of ints to allocate a zero-filled tensor.
            A flat list of ints is always treated as a shape; wrap integer
            values in ``np.array`` to store them as data.
        :param state: MemoryState object (holds the cost and capacity logic)
        :param dtype: Optional numpy dtype. ``None`` keeps numpy's defaults:
            float64 for shape inputs, the inferred dtype for array-like data.
            Smaller dtypes (e.g. float16) let larger tensors fit a level.
        :raises MemoryError: if the tensor's ``nbytes`` exceeds
            ``state.capacity_bytes``.
        """
        # 1. Initialize Data.
        # Only an all-integer tuple/list is a shape. Float lists and nested
        # lists used to crash inside np.zeros; they are now stored as data.
        is_shape = isinstance(data, (tuple, list)) and all(
            isinstance(dim, (int, np.integer)) for dim in data
        )
        if is_shape:
            self.data = np.zeros(data, dtype=dtype)  # dtype=None -> float64
        else:
            self.data = np.array(data, dtype=dtype)

        self.state = state
        self.total_cost = 0.0  # accumulated cost of reads FROM this tensor

        # 2. CAPACITY CHECK
        tensor_bytes = self.data.nbytes
        if tensor_bytes > self.state.capacity_bytes:
            raise MemoryError(
                f"OOM: Cannot allocate tensor of size {tensor_bytes} bytes "
                f"in {self.state.name} (Max: {self.state.capacity_bytes} bytes)."
            )

    def load(self, source_tensor):
        """
        B.load(A) -> Copies data from A (or a slice of A) into B, in place.

        The COST is incurred by A (the Source): ``source.data.size *
        source.state.unit_access_cost`` is added to the owning tensor's
        ``total_cost``. The charge is applied only after the copy succeeds,
        so a failed transfer leaves all costs unchanged.

        :param source_tensor: AugmentedTensor or TensorSlice to copy from.
        :raises ValueError: if the source is neither type, or if numpy cannot
            broadcast the source data into this tensor's shape.
        """
        # No capacity check is needed here: the in-place copy never changes
        # self.data's shape or dtype, and those were validated in __init__.
        # ``[...]`` (rather than ``[:]``) also works for 0-d tensors.
        if isinstance(source_tensor, AugmentedTensor):
            self.data[...] = source_tensor.data
            source_tensor.total_cost += (
                source_tensor.data.size * source_tensor.state.unit_access_cost
            )
        elif isinstance(source_tensor, TensorSlice):
            # Broadcasting/Slicing handled by numpy; the slice bills its parent.
            self.data[...] = source_tensor.data
            source_tensor.report_access()
        else:
            raise ValueError("Source must be an AugmentedTensor or TensorSlice")

    def __getitem__(self, key):
        """Return a TensorSlice over ``self.data[key]`` that bills this tensor."""
        return TensorSlice(self, key)

    def __repr__(self):
        return f"[{self.state.name}] Shape:{self.data.shape} | Size:{self.data.nbytes} B | AccCost: {self.total_cost}"


class TensorSlice:
    """
    Handle produced by ``AugmentedTensor[key]``.

    It keeps the indexed data together with a back-reference to the tensor it
    was taken from. When the slice is passed to ``load()``, the access cost is
    billed to that PARENT tensor rather than to the slice itself.
    """

    def __init__(self, parent, slice_key):
        owner = parent
        self.parent = owner
        # Index once, up front. Basic slicing gives a numpy view of the parent;
        # advanced (fancy) indexing gives a copy.
        self.data = owner.data[slice_key]
        # A slice lives in the same memory level as its parent.
        self.state = owner.state

    def report_access(self):
        """Bill the parent for reading every element covered by this slice."""
        per_element = self.state.unit_access_cost
        self.parent.total_cost += per_element * self.data.size


# ==========================================
#  USER DEMO: Valid & Invalid Allocations
# ==========================================

print("--- 1. Valid Allocation (GMEM) ---")
# 100x100 float64 = 80,000 bytes (~78 KB). Fits easily in GMEM.
A_gmem = AugmentedTensor(np.ones((100, 100)), GMEM)
print(A_gmem)

print("\n--- 2. Valid Allocation (SMEM) ---")
# 10x100 float64 = 8,000 bytes (~7.8 KB). Fits in 227 KB SMEM.
B_smem = AugmentedTensor(np.zeros((10, 100)), SMEM)
print(B_smem)

print("\n--- 3. Perform Load ---")
# Copies rows 0-9 of A into B; the cost (1000 elements * GMEM's per-element
# cost) is charged to A, the source.
B_smem.load(A_gmem[0:10])
print(f"Transfer Complete. Source Cost: {A_gmem.total_cost}")

print("\n--- 4. INVALID Allocation (OOM Check) ---")
# Try to fit a HUGE tensor into SMEM:
# 200x200 float64 = 320,000 bytes (~312.5 KB) > 227 KB max -> this should FAIL.
huge_shape = (200, 200)
print(f"Attempting to alloc {huge_shape} in SMEM...")
try:
    C_fail = AugmentedTensor(np.zeros(huge_shape), SMEM)
except MemoryError as e:
    print(f"SUCCESS: Caught Expected Error -> {e}")
else:
    # Without this branch, a broken capacity check would pass silently.
    print(f"FAILURE: Expected MemoryError, but allocation succeeded -> {C_fail}")

--- Initial State ---
A (Source): [GMEM] Shape:(100, 100) | AccCost: 0.0
B (Dest):   [SMEM] Shape:(10, 100) | AccCost: 0.0

--- After Load ---
A (Source): [GMEM] Shape:(100, 100) | AccCost: 100000.0
B (Dest):   [SMEM] Shape:(10, 100) | AccCost: 0.0

B Data Sample (First row): [1. 1. 1. 1. 1.]
