In [None]:
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix
import time
from threading import Thread
from threading import Event
from SpArchClasses import MatrixAFetcher
from SpArchClasses import DistanceListBuilder
from SpArchClasses import MatrixBPrefetcher
from SpArchClasses import MultiplierArray
from SpArchClasses import MergeTree
from SpArchClasses import Scheduler
from SpArchClasses import Memory

In [None]:
# Experiment configuration. A is an (I x K) sparse matrix, B is (K x J).
# NOTE(review): LLB_TILE_SIZE / PE_TILE_SIZE are not passed to any component in
# this notebook as shown — confirm whether they are still needed.
LLB_TILE_SIZE = 300
PE_TILE_SIZE = 150
I = 100
K = 100
J = 100
# Number of random (value, row, col) triples drawn per matrix. Duplicate
# coordinates are summed by the CSR conversion below, so nnz <= NUM_INTS.
NUM_INTS = 100
# Seed for the random generator. None draws fresh OS entropy on every run
# (non-reproducible); set an int to reproduce a specific pair of input matrices
# and therefore a specific simulation run.
SEED = None

gen = np.random.default_rng(SEED)

# Matrix A: values in [1, 10) (upper bound exclusive), rows in [0, I), cols in [0, K).
# The draw order (values, rows, cols) is kept stable so a given SEED always
# yields the same matrices.
data1 = gen.integers(1,10,NUM_INTS)
row1 = gen.integers(0,I,NUM_INTS)
col1 = gen.integers(0,K,NUM_INTS)

# Matrix B: values in [1, 10), rows in [0, K), cols in [0, J).
data2 = gen.integers(1,10,NUM_INTS)
row2 = gen.integers(0,K,NUM_INTS)
col2 = gen.integers(0,J,NUM_INTS)

# Build via COO (triplet form) and convert to CSR, which sums duplicate entries.
i1 = csr_matrix(coo_matrix((data1, (row1, col1)), shape=(I, K)))
i2 = csr_matrix(coo_matrix((data2, (row2, col2)), shape=(K, J)))

In [None]:
# Dense view of matrix A (i1) for visual inspection.
print(np.asarray(i1.todense()))

In [None]:
# Dense view of matrix B (i2) for visual inspection.
print(np.asarray(i2.todense()))

In [None]:
# Hardware configuration for the simulated accelerator. In this notebook these
# values are only consumed by the constructors below: NUM_CHANNELS and
# BANDWIDTH_PER_CHANNEL go to Memory, NUM_INPUT_FIFOS to Scheduler, and
# MERGE_SIZE to MergeTree.
NUM_CHANNELS = 16
BANDWIDTH_PER_CHANNEL = 8
NUM_INPUT_FIFOS = 64
MERGE_SIZE = 1

# Instantiate the pipeline components. Matrix A (i1) is handed to the A fetcher
# and matrix B (i2) to the B prefetcher; the Scheduler receives references to
# the other stages plus the input-FIFO count.
AL = MatrixAFetcher(i1)
DL = DistanceListBuilder()
MBP = MatrixBPrefetcher(i2)
MPA = MultiplierArray()
MT = MergeTree(MERGE_SIZE)
M = Memory(NUM_CHANNELS, BANDWIDTH_PER_CHANNEL)
S = Scheduler(AL,MT,DL,MBP,MPA,NUM_INPUT_FIFOS)

# Wire each stage to the components it talks to. The exact meaning of each
# set* call lives in SpArchClasses (not shown here).
AL.setDLB(DL)
AL.setMemory(M)

DL.setMBP(MBP)
DL.setMultiplier(MPA)

MBP.setMultiplier(MPA)
MBP.setDistanceListBuilder(DL)
MBP.setMemory(M)

MPA.setMergeTree(MT)
MPA.setScheduler(S)

MT.setScheduler(S)

# Loop-control flag read by the simulation cell's `while endFlag:` loop.
# NOTE(review): it is only reset here, and the simulation loop leaves it False
# when it finishes, so re-running the simulation cell without first re-running
# this cell skips the loop entirely.
endFlag = True



In [None]:

# Launch one worker thread per pipeline stage. Each stage's `running` method is
# given its own Event, which the main loop below uses as a per-cycle handshake;
# every event is set before its thread is started.
# NOTE(review): the Thread objects are not kept, so they are never joined.
ALEvent = Event()
ALEvent.set()
Thread(target=AL.running,args=[ALEvent]).start()

MBPEvent = Event()
MBPEvent.set()
Thread(target=MBP.running,args=[MBPEvent]).start()

MPAEvent = Event()
MPAEvent.set()
Thread(target=MPA.running,args=[MPAEvent]).start()

MTEvent = Event()
MTEvent.set()
Thread(target=MT.running,args=[MTEvent]).start()

# Main clock loop: each iteration counts as one simulated cycle.
count = 0
while endFlag:
    count += 1
    # Clear every handshake event, then block until each still-active stage
    # sets its event again (presumably signalling it finished this cycle's
    # work — confirm against the `running` implementations in SpArchClasses).
    # NOTE(review): if a stage sets its event before these clear() calls, or
    # sets its endFlag after the check below without also setting its event,
    # the main thread may skip a cycle or block forever — verify the handshake.
    ALEvent.clear()
    MBPEvent.clear()
    MPAEvent.clear()
    MTEvent.clear()

    # Stages that have already finished are not waited on.
    if not AL.endFlag:
        ALEvent.wait()
    if not MBP.endFlag:
        MBPEvent.wait()
    if not MPA.endFlag:
        MPAEvent.wait()
    if not MT.endFlag:
        MTEvent.wait()
    
    # Keep running until every stage reports endFlag.
    endFlag = not (AL.endFlag and MBP.endFlag and MPA.endFlag and MT.endFlag)
    
    # Between cycles: load the multiplier array's inputs while it is still
    # active, then advance the memory model via M.cycle().
    if not MPA.endFlag:
        MPA.loadInputs()
    M.cycle()

def _fraction(part, whole, ndigits=4):
    """Return ``part / whole`` rounded to ``ndigits``, or NaN when ``whole`` is 0.

    Report denominators can legitimately be zero (``count`` is 0 when the
    simulation loop is skipped because ``endFlag`` is already False, and a run
    may have no merge cycles), and a ZeroDivisionError would abort the report.
    """
    if not whole:
        return float("nan")
    return round(part / whole, ndigits)


def _utilization(wasted, whole, ndigits=4):
    """Return ``1 - wasted / whole`` rounded to ``ndigits``, or NaN when ``whole`` is 0.

    Rounding is applied after the subtraction so the printed value does not
    pick up float artifacts such as ``0.8765999999999999``.
    """
    if not whole:
        return float("nan")
    return round(1 - wasted / whole, ndigits)


# Cycle output
print("Cycle Count: ", count)

# Memory Bandwidth Utilization:
# NOTE(review): labelled "(%)", but these are plain TotalMemoryPulled-per-cycle
# ratios (not multiplied by 100) — confirm the intended units.
avg_bandwidth = M.TotalMemoryPulled / count if count else float("nan")
busy_bandwidth = (M.TotalMemoryPulled / M.NumCyclesInUse
                  if M.NumCyclesInUse else float("nan"))
print("Average Bandwidth Utilization (%): " + str(avg_bandwidth) + ", Bandwidth Utilization When Memory Is In Use (%): " + str(busy_bandwidth))

# Memory Use
print("Memory Use ")
print("Matrix A Fetcher Memory Use (bytes): ", AL.memoryUse)
print("Matrix B Prefetcher Memory Use: (bytes)", MBP.memoryAccessBytes)
print("Total Memory Use: (bytes)", M.TotalMemoryPulled)

# Hardware Utilization
print("Hardware Utilization ")
# NOTE(review): the -1 presumably discounts an initial round counted by the
# Scheduler — confirm.
print("Number of Rounds of Merging: ", S.rounds-1)
# We don't need any numbers for A matrix Fetcher.
# NOTE(review): the "%" values below are fractions of the cycle count (not
# multiplied by 100).
print("Matrix B Prefetcher Memory Wasted Cycles %: ", _fraction(MBP.memoryWastedCycles, count), ", MBP Memory Wasted Cycles: ", MBP.memoryWastedCycles)
print("Matrix B Prefetcher Hardware Utilization %: ", _utilization(MBP.wastedCycles, count), ", MBP Wasted Cycles: ", MBP.wastedCycles)
print("Multiplier Array Hardware Utilization %", _utilization(MPA.wastedCycles, count), ", MPA Wasted Cycles: ", MPA.wastedCycles)
print("MergeTree Cycles Idle (No merging) %", _fraction(MT.idleCycles, count), ", MergeTree Cycles Idle: ", MT.idleCycles)
print("MergeTree Wasted Cycles (No Partial Loading or Merging) %", _fraction(MT.wastedCycles, count), ", MergeTree Cycles Wasted: ", MT.wastedCycles)
print("MergeTree Average FIFO Uptime (%)", _utilization(MT.totalEmptyFifos, MT.totalMergeCycles))

In [None]:
# Rebuild the simulator's last partial matrix as a sparse matrix and compare it
# against a dense NumPy reference product of the two inputs.
# Each entry of MT.partialMatrices[-1] is read as (value, row, col).
# NOTE(review): assumes the last partial matrix is the fully merged result — confirm.
data = []
i = []
j = []
for entry in MT.partialMatrices[-1]:
    data.append(entry[0])
    i.append(entry[1])
    j.append(entry[2])

# A is (I x K) and B is (K x J), so their product is (I x J).
total = coo_matrix((data,(i,j)),shape=(I,J))

# Dense views computed once and reused by every check below.
actual = total.toarray()
expected = np.matmul(i1.toarray(),i2.toarray())
print(actual)
print(expected)
# Exact comparison (inputs are integer-valued) plus the number of differing
# cells, instead of dumping the whole element-wise boolean mask.
print("Exact match:", np.array_equal(actual, expected), ", mismatched entries:", int(np.count_nonzero(actual != expected)))
print(np.allclose(actual,expected,rtol=0.01))
