In [3]:
import math
import numpy as np
class fully_associative_cache:
  """Address-decomposition helper for a fully associative cache.

  There is no set index in this model: the low ``offset_bits`` bits of an
  address select a byte inside a block and every remaining bit, up to the
  address width, belongs to the tag.

  Attributes:
    w: Memory address width in bits.
    b: Cache block size in bytes.
    N_b: Number of cache blocks (stored only; the mask math does not use it).
    offset_bits: Number of low address bits that select a byte in a block.
    offset_mask: Bitmask covering the offset field.
    tag_mask: Bitmask covering the tag field (address bits above the offset).
  """

  @staticmethod
  def _exact_log2(value, what):
    """Return log2(value) as an int, or raise ValueError if it is not exact.

    Truncating the logarithm of a size that is not a power of two would
    yield a bit field that cannot address the whole block, so such sizes
    are rejected instead of being rounded down.
    """
    if value < 1:
      raise ValueError(f"{what} must be a positive power of two, got {value!r}")
    bits = int(math.log2(value))
    if (1 << bits) != value:
      raise ValueError(f"{what} must be a positive power of two, got {value!r}")
    return bits

  def __init__(self, mem_addr_width, cache_block_size, N_cache_blocks):
    """Compute the offset/tag bit layout and print a short summary.

    Args:
      mem_addr_width: Address width in bits (e.g. 32 or 64).
      cache_block_size: Block size in bytes; must be a power of two.
      N_cache_blocks: Number of blocks in the cache.

    Raises:
      ValueError: If ``cache_block_size`` is not a positive power of two, or
        if the offset field would be wider than the address itself.
    """
    self.w = mem_addr_width      # e.g., 32 or 64 bits
    self.b = cache_block_size    # e.g., 64 bytes
    self.N_b = N_cache_blocks

    # 1. Number of bits that represent the byte offset inside a block.
    #    If block size is 64, offset_bits = 6.
    self.offset_bits = self._exact_log2(self.b, "cache_block_size")
    if self.offset_bits > self.w:
      # Without this check the tag mask below would be built from a field
      # that is narrower than the offset and would overlap the offset bits.
      raise ValueError(
          f"cache_block_size needs {self.offset_bits} offset bits but "
          f"mem_addr_width is only {self.w}")

    # 2. Masks: the offset occupies the low bits, the tag is everything else
    #    within the address width.
    self.offset_mask = (1 << self.offset_bits) - 1
    self.tag_mask = ((1 << self.w) - 1) & ~self.offset_mask

    print(f"Block Size: {self.b} bytes")
    print(f"Offset Bits: {self.offset_bits}")
    print(f"Offset Mask: {bin(self.offset_mask)}")
    print(f"Tag Mask:    {bin(self.tag_mask)}")

# Example usage: 32-bit addresses, 64-byte blocks, 128 blocks in total.
c = fully_associative_cache(
    mem_addr_width=32,
    cache_block_size=64,
    N_cache_blocks=128,
)
    
  

Block Size: 64 bytes
Offset Bits: 6
Offset Mask: 0b111111
Tag Mask:    0b11111111111111111111111111000000


In [4]:


class TensorCache:
    """Set-associative cache simulator whose state is held in NumPy arrays.

    An address is split, from high to low bits, into ``tag | index | offset``.
    The index selects one set (a row of the state arrays), the tag is compared
    against every way of that set, and the offset selects a byte inside the
    block. Replacement is least-recently-used, driven by a per-lookup counter.

    Attributes:
        W: Address width in bits.
        B: Block size in bytes.
        S: Number of sets (rows).
        K: Associativity, i.e. ways per set (columns).
        offset_bits, index_bits, tag_bits: Widths of the three address fields.
        offset_mask, index_mask: Bitmasks for the offset and index fields.
        tags: (S, K) int64 array of stored tags; -1 marks a never-filled way.
        valid: (S, K) bool array, True where a way holds a block.
        lru_counters: (S, K) int64 array holding the ``global_timer`` value of
            the last access to each way; 0 means the way was never touched.
        global_timer: Count of lookups performed so far.
        data: (S, K, B) uint8 array with the cached block contents.
    """

    @staticmethod
    def _exact_log2(value, what):
        """Return log2(value) as an int, or raise ValueError if it is not exact.

        A truncated logarithm would produce an index/offset field that cannot
        reach every set or byte, so sizes that are not powers of two are
        rejected rather than rounded down.
        """
        if value < 1:
            raise ValueError(f"{what} must be a positive power of two, got {value!r}")
        bits = int(math.log2(value))
        if (1 << bits) != value:
            raise ValueError(f"{what} must be a positive power of two, got {value!r}")
        return bits

    def __init__(self, addr_width, block_size, num_sets, associativity):
        """Allocate the cache state and print a summary of the geometry.

        Args:
            addr_width: Address width in bits.
            block_size: Block size in bytes; must be a power of two.
            num_sets: Number of sets; must be a power of two.
            associativity: Number of ways per set; must be at least 1.

        Raises:
            ValueError: If ``block_size`` or ``num_sets`` is not a positive
                power of two, if ``associativity`` is less than 1, or if
                ``addr_width`` is too small to hold the index and offset
                fields.
        """
        self.W = addr_width
        self.B = block_size      # Block Size (bytes)
        self.S = num_sets        # Number of Sets (Rows)
        self.K = associativity   # Ways (Columns)

        # --- 1. Geometry & Addressing ---
        self.offset_bits = self._exact_log2(self.B, "block_size")
        self.index_bits  = self._exact_log2(self.S, "num_sets")
        if self.K < 1:
            # A set with no ways has nowhere to place a block on a miss.
            raise ValueError(f"associativity must be at least 1, got {self.K!r}")
        self.tag_bits    = self.W - self.index_bits - self.offset_bits
        if self.tag_bits < 0:
            raise ValueError(
                f"addr_width={self.W} is too small for {self.index_bits} index "
                f"bits plus {self.offset_bits} offset bits")

        # Bitmasks
        self.offset_mask = (1 << self.offset_bits) - 1
        self.index_mask  = ((1 << self.index_bits) - 1) << self.offset_bits

        # --- 2. The Tensors (State) ---
        # Shape: (S, K) -> (Sets, Ways)
        # -1 marks a way that has never been filled; `valid` is what actually
        # guards the comparison in lookup().
        self.tags = np.full((self.S, self.K), -1, dtype=np.int64)
        self.valid = np.zeros((self.S, self.K), dtype=bool)

        # LRU Tensor: (S, K).
        # Logic: Higher number = More recently used. 0 = Oldest/Empty.
        self.lru_counters = np.zeros((self.S, self.K), dtype=np.int64)
        self.global_timer = 0 # Monotonic clock for LRU

        # Data Tensor: (S, K, B)
        self.data = np.zeros((self.S, self.K, self.B), dtype=np.uint8)

        print(f"TensorCache Init: {self.S} Sets x {self.K} Ways x {self.B} Bytes")
        print(f"Geometry: [ Tag: {self.tag_bits} | Index: {self.index_bits} | Offset: {self.offset_bits} ]")

    def _split_addr(self, addr):
        """Split ``addr`` into its ``(tag, index, offset)`` fields."""
        offset = addr & self.offset_mask
        index  = (addr & self.index_mask) >> self.offset_bits
        tag    = addr >> (self.index_bits + self.offset_bits)
        return tag, index, offset

    def lookup(self, addr):
        """Access one byte through the cache.

        Advances the LRU clock, then searches the set selected by the address
        for a valid way holding the same tag. On a miss the block is brought
        in via ``_handle_miss``.

        Args:
            addr: Byte address to access.

        Returns:
            A ``(status, value)`` tuple where ``status`` is ``"HIT"`` or
            ``"MISS"`` and ``value`` is the uint8 byte at that address.
        """
        self.global_timer += 1
        tag_in, idx_in, off_in = self._split_addr(addr)

        # --- 1. Slice (Select the Set) ---
        # Grab the row for this set index.
        # Hardware analogy: the decoder enabling one specific wordline.
        set_tags = self.tags[idx_in, :]      # Shape: (K,)
        set_valid = self.valid[idx_in, :]    # Shape: (K,)

        # --- 2. Search (Broadcast & Contract) ---
        # Compare the input tag against all K ways at once.
        # hit_vector is a boolean array, e.g. [False, True, False, False]
        hit_vector = (set_tags == tag_in) & (set_valid)

        if np.any(hit_vector):
            # === HIT ===
            # Index of the first way that matched (e.g., Way 1)
            way_idx = np.where(hit_vector)[0][0]

            # Mark this way as the most recently used one in its set
            self.lru_counters[idx_in, way_idx] = self.global_timer

            # Retrieve the byte: data shape (S, K, B) -> scalar
            val = self.data[idx_in, way_idx, off_in]
            return "HIT", val

        else:
            # === MISS ===
            return "MISS", self._handle_miss(tag_in, idx_in, off_in)

    def _handle_miss(self, tag, idx, offset):
        """Fill the least recently used way of set ``idx`` and return one byte.

        Args:
            tag: Tag of the block being brought in.
            idx: Set index the block maps to.
            offset: Byte offset inside the block to return.

        Returns:
            The uint8 byte at ``offset`` in the newly installed block.
        """
        # 1. Find Victim (LRU Eviction)
        # argmin returns the way with the smallest timestamp. Never-filled ways
        # still hold 0 while lookup() stamps used ways with a timer >= 1, so
        # empty ways are chosen before any occupied way is evicted.
        victim_way = np.argmin(self.lru_counters[idx, :])

        # 2. "Fetch" from memory (simulated)
        # In a real engine, this is a GMEM load
        new_data_block = np.full((self.B), 0xAA, dtype=np.uint8) # 0xAA dummy data

        # 3. Update Tensors
        self.tags[idx, victim_way]  = tag
        self.valid[idx, victim_way] = True
        self.lru_counters[idx, victim_way] = self.global_timer
        self.data[idx, victim_way, :] = new_data_block

        return new_data_block[offset]

# --- Usage ---
# Geometry: 32-bit addresses, 64-byte blocks, 4 sets, 2 ways.
cache = TensorCache(addr_width=32, block_size=64, num_sets=4, associativity=2)

# Every address below maps to set 0 but carries a different tag.
# Accesses, in order:
#   0x1000 -> cold miss, fills way 0
#   0x2000 -> cold miss (same set, different tag), fills way 1
#   0x3000 -> conflict miss: the set is full, so the LRU way (way 0) is evicted
#   0x3000 -> hit on the block that was just installed
for demo_addr in (0x1000, 0x2000, 0x3000, 0x3000):
    print(cache.lookup(demo_addr))

TensorCache Init: 4 Sets x 2 Ways x 64 Bytes
Geometry: [ Tag: 24 | Index: 2 | Offset: 6 ]
('MISS', np.uint8(170))
('MISS', np.uint8(170))
('MISS', np.uint8(170))
('HIT', np.uint8(170))
