# In [None]:
import numpy as np

class MemoryState:
    """One level of the memory hierarchy together with its per-element read cost."""

    def __init__(self, name, unit_access_cost):
        # name: label shown when printing.
        # unit_access_cost: amount charged for each element read from this level.
        self.name, self.unit_access_cost = name, unit_access_cost

    def __repr__(self):
        label, cost = self.name, self.unit_access_cost
        return f"{label}(Cost={cost})"

# --- Memory hierarchy, ordered from most to least expensive per element read ---
GMEM = MemoryState(name="GMEM", unit_access_cost=100.0)    # expensive to read
DSMEM = MemoryState(name="DSMEM", unit_access_cost=20.0)   # medium
SMEM = MemoryState(name="SMEM", unit_access_cost=10.0)     # cheap
REG = MemoryState(name="REG", unit_access_cost=1.0)        # fastest

class AugmentedTensor:
    """
    A numpy array tagged with a memory level that accumulates read costs.

    Each time the tensor (or a slice of it) is the source of a ``load``,
    ``elements read * state.unit_access_cost`` is added to ``total_cost``.
    """

    def __init__(self, data, state):
        """
        :param data: Array-like contents, or a shape given as a tuple/list
                     (a zero-filled array of that shape is then created).
        :param state: MemoryState-like object exposing ``name`` and
                      ``unit_access_cost``.
        """
        # A tuple/list is always interpreted as a shape, never as values.
        if isinstance(data, (tuple, list)):
            self.data = np.zeros(data)
        else:
            # np.array copies by default, so the tensor owns its buffer.
            self.data = np.array(data)

        self.state = state
        self.total_cost = 0.0

    def report_access(self):
        """
        Record one full read of this tensor's data.

        Views override this hook to bill the tensor that owns the data.
        """
        self.total_cost += self.data.size * self.state.unit_access_cost

    def load(self, source_tensor):
        """
        B.load(A) -> copies data from A into B.

        The cost is incurred by the source (A), because its memory is the
        one being read; the destination is not charged.

        :param source_tensor: AugmentedTensor, or a slice obtained from one.
        :raises ValueError: if ``source_tensor`` is not an AugmentedTensor.
        """
        if not isinstance(source_tensor, AugmentedTensor):
            raise ValueError("Source must be an AugmentedTensor")

        # Let the source decide who pays: a plain tensor charges itself,
        # a slice charges its parent.
        source_tensor.report_access()

        # Ellipsis (rather than [:]) also supports 0-d destinations; numpy
        # broadcasting rules apply when the shapes differ.
        self.data[...] = source_tensor.data

    def __getitem__(self, key):
        """
        Allow slicing, e.g. ``A[0:10]``.

        Returns a TensorSlice built from this tensor and ``key`` so that
        passing it to ``load()`` charges this tensor, not a temporary.
        """
        return TensorSlice(self, key)

    def __repr__(self):
        return f"[{self.state.name}] Shape:{self.data.shape} | AccCost: {self.total_cost}"


class TensorSlice(AugmentedTensor):
    """
    Helper class to handle A[slice].

    A slice keeps no cost counter of its own: ``total_cost`` reads from and
    writes to the parent, so any access made through the slice is billed to
    the parent. Because a parent may itself be a slice, the charge
    propagates up to the tensor at the root of the chain.
    """

    def __init__(self, parent, slice_key):
        # AugmentedTensor.__init__ is deliberately not called: it would copy
        # the data and assign total_cost = 0.0, which here would reset the
        # parent's counter through the property below.
        self.parent = parent
        # numpy returns a view for basic slices and a copy for advanced
        # (integer-array / boolean) indexing.
        self.data = parent.data[slice_key]
        self.state = parent.state  # Same memory level as the parent

    @property
    def total_cost(self):
        """Accumulated cost of the parent (slices own no counter)."""
        return self.parent.total_cost

    @total_cost.setter
    def total_cost(self, value):
        self.parent.total_cost = value

    def report_access(self):
        """Charge the parent for reading this slice's elements."""
        self.total_cost += self.data.size * self.state.unit_access_cost

# --- Re-implementing load to handle the Slice Wrapper ---
def smart_load(self, source):
    """
    Copy ``source``'s data into ``self`` and bill the read to the source side.

    A TensorSlice source charges its parent via ``report_access()``; a plain
    AugmentedTensor is charged directly for all of its elements.

    :raises ValueError: if ``source`` is neither of those types.
    """
    if isinstance(source, TensorSlice):
        # The slice forwards the charge to the tensor it was cut from.
        source.report_access()
    elif isinstance(source, AugmentedTensor):
        # Whole-tensor read: every element is billed at the source's rate.
        full_read_cost = source.data.size * source.state.unit_access_cost
        source.total_cost += full_read_cost
    else:
        raise ValueError("Unknown source type")

    # Charging happens before the copy in both branches.
    self.data[:] = source.data

# Patching the method for the demo: rebinding the class attribute replaces the
# load defined in the class body, so every AugmentedTensor instance (and any
# subclass that does not override load) dispatches to smart_load from here on.
AugmentedTensor.load = smart_load


# ==========================================
#  USER DEMO
# ==========================================

# 1. Allocate the tensors
# Source: a 100x100 block of ones placed in GMEM.
A_gmem = AugmentedTensor(np.ones(shape=(100, 100)), GMEM)
# Destination: a 10x100 zero-filled buffer placed in SMEM.
B_smem = AugmentedTensor(np.zeros(shape=(10, 100)), SMEM)

print("--- Initial State ---")
print("A (Source):", A_gmem)
print("B (Dest):  ", B_smem)

# 2. Load the first 10 rows of A into B.
# Slicing A yields a wrapper that bills A itself when it is read.
first_rows = A_gmem[:10]
B_smem.load(first_rows)

print("\n--- After Load ---")
# Expected cost on A: 10 rows * 100 columns = 1000 elements, each at 100 -> 100,000.
print("A (Source):", A_gmem)
# B now holds the copied rows; its cost should stay 0 because it was written, not read.
print("B (Dest):  ", B_smem)
print("\nB Data Sample (First row):", B_smem.data[0][:5])