In [45]:
import math
import sys

In [46]:
# Integer code for each access-type name; the code is the name's position in
# the tuple below (and the index used for per-block access counters).
accessTypes = {name: code for code, name in enumerate(('LD', 'RFO', 'PF', 'WB'))}

In [47]:
class Block():
    """Bookkeeping record for a single cache line (one way of one set).

    Holds the line's tag/valid/offset plus the statistics that are exposed
    as features through ``getState``.
    """

    def __init__(self, numWays = 16):
        """Create an invalid block whose recency is the largest rank, ``numWays - 1``."""
        # Identity of the cached line.
        self.tag: int = 0
        self.valid: bool = False
        self.offset: int = 0
        # self.dirty: bool = False

        # Preuse tracking: ``preuse`` starts at sys.maxsize, the counter at 0.
        self.preuse: int = sys.maxsize
        self.preuseCounter: int = 0

        # Age counters.
        self.ageSinceInsertion: int = 0
        self.ageSinceAccess: int = 0

        # Last access type and one counter slot per access type (four slots).
        self.accessType: int = 0
        self.accessCounts = [0] * 4

        self.hits: int = 0
        self.recency: int = numWays - 1

    def getState(self):
        """Return the block's features as a flat list.

        Order: offset, valid, preuse, ageSinceInsertion, ageSinceAccess,
        accessType, every entry of accessCounts, hits, recency.
        The tag is not part of the returned state.
        """
        leading = [
            self.offset,
            self.valid,
            self.preuse,
            self.ageSinceInsertion,
            self.ageSinceAccess,
            self.accessType,
        ]
        trailing = [self.hits, self.recency]
        return leading + list(self.accessCounts) + trailing
               

In [48]:
class Cache():
    """Set-associative cache model that records per-set and per-block statistics.

    Blocks are stored in one flat list; the block for ``way`` of set
    ``setIdx`` lives at index ``setIdx * numWays + way``.
    """

    def __init__(self, numSets_ = 2048, numWays_ = 16, blockSize_ = 65):
        """Allocate ``numSets_ * numWays_`` blocks and derive the address bit layout.

        NOTE(review): the default ``blockSize_`` of 65 is not a power of two;
        ``int(math.log2(65))`` still yields 6 offset bits. Possibly 64 was
        intended -- confirm before changing the default.
        """
        self.numSets: int = numSets_
        self.numWays: int = numWays_
        self.blockSize: int = blockSize_
        self.BLOCKS = [Block(numWays = self.numWays) for _ in range(self.numSets*self.numWays)]
        self.setAccesses = [0 for _ in range(self.numSets)]
        self.setAccessesSinceMiss = [0 for _ in range(self.numSets)]
        # Maps cache-line address -> value of globalAccessCount at its last access.
        self.preuseDistances = {}
        self.globalAccessCount = 0

        self.offsetBits: int = int(math.log2(self.blockSize))
        self.setBits: int = int(math.ceil(math.log2(self.numSets)))
        self.setBitMask: int = (1<<self.setBits)-1
        self.offsetBitMask: int = (1<<self.offsetBits)-1

    def splitAddress(self, address: int) -> (int, int, int):
        """Split ``address`` into ``(tag, setIdx, offset)`` using the bit layout from ``__init__``."""
        setIdx: int = (address >> self.offsetBits) & self.setBitMask
        offset: int = address & self.offsetBitMask
        tag: int = address >> (self.offsetBits + self.setBits)
        return tag, setIdx, offset

    def _getSetBlocks(self, setIdx):
        """Return the list of block objects (one per way) belonging to ``setIdx``."""
        start = setIdx*self.numWays
        return self.BLOCKS[start: start + self.numWays]

    def _findWay(self, setIdx, tag):
        """Return the way in ``setIdx`` holding a valid block with ``tag``, or None.

        If several ways match, the highest-numbered one is returned.
        """
        found = None
        for i, block in enumerate(self._getSetBlocks(setIdx)):
            if block.valid and block.tag == tag:
                found = i
        return found

    # TODO Need to normalize the state access count value
    def getCurrentState(self, address: int, accessTypes: int):
        """Build the flat feature list describing an access before it is applied.

        Layout: ``[offset, preuse, accessType, setIdx, setAccesses,
        setAccessesSinceMiss]`` followed by ``getState()`` of every block in
        the addressed set, in way order.

        ``accessTypes`` is the integer access-type code of this access (the
        parameter name is kept as-is for backward compatibility).
        ``preuse`` is ``sys.maxsize`` for a line never seen by ``accessCache``.
        """
        accessType = accessTypes
        tag, setIdx, offset = self.splitAddress(address)
        # Get preuse of cache
        preuse = sys.maxsize
        cacheLineAddress = address >> self.offsetBits # Address of cache line, remove offset from address
        if cacheLineAddress in self.preuseDistances:
            # Use the global access counter to compute the preuse distance as its
            # difference to the value stored in the preuseDistances dictionary
            preuse = self.globalAccessCount - self.preuseDistances[cacheLineAddress]

        state = [offset, preuse, accessType] # Access Info
        state.extend( [setIdx, self.setAccesses[setIdx], self.setAccessesSinceMiss[setIdx] ] ) # set info
        # cache line info
        for line in self._getSetBlocks(setIdx):
            state.extend(line.getState())
        return state

    def updateRecency(self, setIdx, way):
        """Make ``way`` the most recent block (recency 0) of set ``setIdx``.

        Every block that was more recent than ``way`` (smaller recency value)
        is aged by one; blocks that were less recent are left alone.
        """
        blocks = self._getSetBlocks(setIdx)
        # Store recency of block being updated
        previousRecency = blocks[way].recency
        # Age all those that were more recent than the current block. The
        # accessed block is reset only afterwards so it is not aged itself.
        for block in blocks:
            if block.recency < previousRecency:
                block.recency += 1
        blocks[way].recency = 0

    def accessCache(self, address: int, accessType: int, way: int):
        """Apply one access to the cache and update all statistics.

        On a hit the block that holds the line is updated and ``way`` is
        ignored. On a miss the line is filled into ``way`` of the addressed
        set (presumably the victim chosen by the replacement policy --
        confirm against the caller).

        Raises:
            ValueError: on a miss when ``way`` is not in ``[0, numWays)``;
                no state is modified in that case.
        """
        # Split address to parts
        tag, setIdx, offset = self.splitAddress(address)
        hitWay = self._findWay(setIdx, tag)
        hit: bool = hitWay is not None
        if hit:
            way = hitWay
        elif not 0 <= way < self.numWays:
            raise ValueError(
                "way must be in [0, %d) on a miss, got %r" % (self.numWays, way)
            )

        self.globalAccessCount += 1
        cacheLineAddress = address >> self.offsetBits
        # Record the current value of globalAccessCount on each access to a
        # cache line address (also for lines seen for the first time).
        self.preuseDistances[cacheLineAddress] = self.globalAccessCount

        #update set params
        self.setAccesses[setIdx] += 1
        self.setAccessesSinceMiss[setIdx] += 1

        # Age every block of the set
        for block in self._getSetBlocks(setIdx):
            block.ageSinceInsertion += 1 #reset on miss
            block.ageSinceAccess += 1 #reset on hit
            block.preuseCounter += 1

        addressParts = (tag, setIdx, offset)
        if hit:
            self.handleHit(setIdx, way, accessType, addressParts)
        else:
            self.handleMiss(setIdx, way, accessType, addressParts)
            self.setAccessesSinceMiss[setIdx] = 0

        self.updateRecency(setIdx, way)

    def handleHit(self, setIdx, way, accessType, addressParts):
        """Update the block at ``(setIdx, way)`` for a hit.

        ``addressParts`` is the ``(tag, setIdx, offset)`` tuple of the access.
        """
        block: Block = self.BLOCKS[setIdx*self.numWays + way]
        tag, setIdx, offset = addressParts
        # Update block params
        block.offset = offset
        block.preuse = block.preuseCounter
        block.preuseCounter = 0
        block.ageSinceAccess = 0
        block.accessType = accessType
        block.accessCounts[accessType] += 1
        block.hits += 1

    def handleMiss(self, setIdx, way, accessType, addressParts):
        """Fill the block at ``(setIdx, way)`` with the missing line.

        ``addressParts`` is the ``(tag, setIdx, offset)`` tuple of the access.

        NOTE(review): ``hits``, ``accessCounts`` and ``ageSinceAccess`` are not
        reset here, so they carry over from the previous occupant of the way --
        confirm whether that is intended.
        """
        block: Block = self.BLOCKS[setIdx*self.numWays + way]
        tag, setIdx, offset = addressParts
        # Update block params
        block.valid = True
        block.tag = tag
        block.offset = offset
        block.preuse = int(sys.maxsize)
        block.preuseCounter = 0
        # block.ageSinceAccess = 0
        block.ageSinceInsertion = 0
        block.accessType = accessType
        block.accessCounts[accessType] += 1
             