In [150]:
import numpy as np
import random
import math

random.seed(96)
F=6 #we assume level[0] accessed by stage[0] has F nodes, etc. (root is not a level)
D = 128 #number of pipe stages
pip_lst = [None] * D
Y = 256 #number of banks
bank_lst = [None] * Y
    
assigned_tree_level_lst = [1000] * Y #lst[ind]=i means bank[ind] is assigned tree level i (1000: unassigned)
# === Assume each bank can store 40 node entries at maximum. tree size=N=250*40=10K nodes 
bank_node_counter=[0] * Y #lst[ind]=c means bank[ind] has already inserted c nodes. if c=40, this bank is full.
# ===========================================================
# current_insertion_bank = [3000] * D # memoization. lst[ind]=2 means tree level[ind] is currently inserting to bank 2, etc ...
level_banks_dict = {} 
[level_banks_dict.setdefault(x, []) for x in range(D)] #dict key is tree level, value is a list of bank idx storing the tree level.
# ===========================================================
#with current F and N, only the first 5 levels can be full. 
# e.g., when num_nodes_inserted[0]=F, num_treelevels_tobe_excluded should always = 1
num_nodes_inserted_lst=[0] * 5
num_treelevels_tobe_excluded = 0 

# ===========================================================
# # Global congestion_count_dict: count num congestions from stage pov (total num possible conflicts with other logD stages)
# # key: treelev(pipe stage).  value: sum of num potential butterfly conflicts with other logD stages (logD at max)
# congestion_count_dict = {} 
# [congestion_count_dict.setdefault(x, 0) for x in range(D)] 
# ===========================================================
#Global src_cong_list_dict: tracks and counts congestions from stage pov by avoiding repeating counts.
# should only be updated when an empty banks is assigned.
src_cong_list_dict={} 
#key: src pipe stage(treelev)).  
# value: list of other src IDs that incurs a congestion with the key based on level_banks_dict (len is the counts of congestions, logD at max)
[src_cong_list_dict.setdefault(x, []) for x in range(D)] 

# print(pip_lst)
# print(bank_lst)
    
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# basic version: only look for bank, do not consider butterfly congestions
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# output: update level_banks_dict, assigned_tree_level_lst
# output: return_index is the decided bank index for the treelevel
def find_a_valid_bank_basic(treelevel):
    found_valid=False
    return_index=3000
    indices = [i for i, x in enumerate(assigned_tree_level_lst) if x == treelevel]
    #found a labeled valid bank to insert to, return bank index     
    if (len(indices) is not 0): 
        #now that there exist banks labeled treelevel, check if they are full
        for ind in indices:
            if (bank_node_counter[ind]<=60): # this bound is set larger than 40
                assert(level_banks_dict[treelevel][-1]==ind)
                bank_node_counter[ind]+=1
                found_valid=True
                return_index=ind
                break
    #cannot find a labeled valid bank to insert to, label a new bank
    if (found_valid!=True):
        try:
            first_empty_index=assigned_tree_level_lst.index(1000) #if cannot find, this will throw an error
            assigned_tree_level_lst[first_empty_index]=treelevel
            return_index=first_empty_index
            found_valid=True
        except ValueError:
            print("ERROR: no empty banks can be found! Returned an invalid index 3000")
    if (return_index not in level_banks_dict[treelevel]):
        level_banks_dict[treelevel].append(return_index)
    
    return return_index
    
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# advanced version: Strictly avoid any butterfly congestions
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# output: update level_banks_dict, assigned_tree_level_lst, assigned_tree_level_lst
# output: return_index is the decided bank index for the treelevel
def find_a_valid_bank_advanced(treelevel):
    found_valid=False
    return_index=3000
    indices = [i for i, x in enumerate(assigned_tree_level_lst) if x == treelevel]
    #found a labeled valid bank to insert to, return bank index     
    if (len(indices) is not 0): 
        #now that there exist banks labeled treelevel, check if they are full
        for ind in indices:
            if (bank_node_counter[ind]<=60): # this bound is set larger than 40
                assert(level_banks_dict[treelevel][-1]==ind)
                bank_node_counter[ind]+=1
                found_valid=True
                return_index=ind #if there already exists a non-full bank to insert to, return this and do not search for new banks
                break
    #cannot find a labeled valid bank to insert to (all full or no such bank), label a new bank
    #the new bank should be contrained to one that do not create conflict (or create minimum conflict) with an existing banks assignment:
    if (found_valid!=True):
        flag_found_0_cong=False
        num_empty=assigned_tree_level_lst.count(1000)
        min_cong_bankidx=0
        min_cong=1000000
        if (num_empty==0):
            print("ERROR: no empty banks can be found! Returned an invalid index 3000")
        if (num_empty>0):
            bank_itr=1 #only used to iterate over the empty banks, not the actual idx of empty banks
            while (bank_itr!=num_empty):
                candidate_bank_idx=find_yth_empty_bank(bank_itr) #actual index of empty bank
                num_cong=len(congestion_check(treelevel, candidate_bank_idx))
                if (num_cong<min_cong):
                    min_cong=num_cong
                    min_cong_bankidx=candidate_bank_idx
                if (num_cong==0):
                    flag_found_0_cong=True
                    return_index=candidate_bank_idx
                    break       
                bank_itr+=1
            if (flag_found_0_cong==False): #after the while loop, cannot find 0-cong new banks
#                 print("cannot find 0 cong!")
                return_index=min_cong_bankidx #choose the first bank idx that creates the smallest number of cong
    if (return_index not in level_banks_dict[treelevel]):
        level_banks_dict[treelevel].append(return_index)
    assigned_tree_level_lst[return_index]=treelevel
    src_cong_list_dict[treelevel]=congestion_check(treelevel, return_index)
        
    return return_index

# ++++++++++++++++++++++++++++++++++++++++++++++
# helper function to find the y^th empty bank in assigned_tree_level_lst
# input: y->1: the first empty bank, y->2: the second... the y^th
# output: the bank index that is the y^th empty bank
def find_yth_empty_bank(y):
    ind_itr=0
    empty_cnt=0
    while (ind_itr!=len(assigned_tree_level_lst)):
        if (assigned_tree_level_lst[ind_itr]==1000):
            empty_cnt+=1
        if (empty_cnt==y):
            break
        ind_itr+=1
    if (ind_itr>=len(assigned_tree_level_lst)):
        ind_itr=None
    return ind_itr #bank index

# ++++++++++++++++++++++++++++++++++++++++++++++
# ???????
# helper function to check if the routing from treelevel to a candidate destination bank index incurs conflict(s) with 
# other routes specified by level_banks_dict (key:treelevel(i.e., src), value:bank_ids(i.e., dests))
# It only checks and returns the result, it does not modify the src_cong_list_dict because it does not make the decision of whether
# the bank cand_dest should be used to store treelev.
# input: treelev->src index, cand_dest->candidate bank destination index
# output: flag_no_cong->True if no congestion identified, false otherwise.
# output: list of other src that will conflict with treelev if cand_dest is selected
# Proof ganranteed: only need to compare with logD other treelevs' routing that has only-1-bit diff with given treelev
# O(logD)
def congestion_check(treelev, cand_dest):
    lev_ind=0
    empty_cnt=0
    treelev2='{0:07b}'.format(treelev) #logDbits
    cand_dest2='{0:07b}'.format(cand_dest) #logD bits
    #for lev_ind in range(D):
    #    if (lev_ind!=treelev): #for every pair of srcs, check if they have cong
    #        if (level_banks_dict[lev_ind]==[]): #skip empty treelevs
    #            continue
    treelev_cong_list=[]
    for x in range(int(math.log2(D))): 
    #iterate trhough only the last (Least Signif.) logD bits of src/dest. MSB->x=0+; LSB->x=math.log2(D)-1
        potential_conflict_src2=treelev2
        list_temp = list(potential_conflict_src2)
        list_temp[x]=str(1-int(treelev2[x])) #only bit x different with treelev2
        potential_conflict_src2 = ''.join(list_temp)
        potential_conflict_src=int(potential_conflict_src2,2) #binary to int
        for potential_src_dest in level_banks_dict[potential_conflict_src]: #check all the dests of potential_conflict_src
            #if empty, will not enter for loop
            if (LSXBsame(int(math.log2(D))-x,cand_dest,potential_src_dest)):  
                #print ("found congestion at level",x)
                if potential_conflict_src not in treelev_cong_list:
                    #even if all the dest banks from one src conflict with treelev, they count as one cong because at runtime only one bank can be accessed
                    treelev_cong_list.append(potential_conflict_src)
    return treelev_cong_list #its len is number of other src that will conflict with treelev if cand_dest is selected

# ++++++++++++++++++++++++++++++++++++++++++++++
# helper helper function to check 1-bit difference in two IDs
# input: i - an log2(Y)-bit source ID, j - another log2(Y)-bit source ID  
# output: If only LSB is diff and rest same => 0; if only LSB+1 diff => 1;...If only MSB-1 diff =>. Otherwise:None 
def oneBdiff(i,j):
    stringi='{0:08b}'.format(i)
    stringj='{0:08b}'.format(j)
    flag_pass=0 #0 means no conflict pattern detected. flag_pass=0 iff level_conflict=None
    level_conflict=None #this is the only bit that are diff, which is the butterfly network level where potential conflict happens
    for x in range(int(math.log2(Y))): #only bit x is diff, other same. MSB->x=0; LSB->x=math.log2(Y)-1
        if (stringi[x]!=stringj[x]):
            flag_all_other_same=True
            for xx in range(int(math.log2(Y))):
                if (xx!=x and stringi[xx]!=stringj[xx]): #all otehr bits than x must be the same, if not, cannot pass
                    flag_all_other_same=False
            if (flag_all_other_same==True):
                flag_pass=1
                level_conflict=x
    return level_conflict
            

# ++++++++++++++++++++++++++++++++++++++++++++++
# helper helper function to check X-LSB similarities in two IDs
# input: X - check the last X bits; a - an log2(Y)-bit destnation ID; b - another log2(Y)-bit destnation ID 
def LSXBsame(X,a,b):
    stringa='{0:08b}'.format(a)
    stringb='{0:08b}'.format(b) 
#     print(stringa,stringb)
    flag_same=False
#     print("stringa[-X:]",stringa[-X:])
#     print("stringb[-X:]",stringb[-X:])
    if (stringa[-X:]==stringb[-X:]): 
        flag_same=True
    return flag_same
        
# ++++++++++++++++++++++++++++++++++++++++++++++
# Testbench helper function to count butterfly congestions for every pair of src-dest routes in a list
# input: a list of destination (bank) IDs, index is their source ids
def find_congestion(list_dests):
    Flag_found_cong=0
    src_cong_list=[]
    for i in range(len(list_dests)): #i: A src, list_dests[i]: its dest
        for j in range(i): #j: Another src, list_dests[j]: its dest
            if (i!=j): #for every pair of routing requests, check if they have cong
                if (oneBdiff(i,j)==None): #these srcs cannot have congestion
                    continue
                for x in range(int(math.log2(Y))): #MSB->x=0; LSB->x=math.log2(Y)-1
                    if (oneBdiff(i,j)==x): #if srcs have potential congestion, check dest
#                         print("srcs have oneBdiff only at bit",x)
                        if (LSXBsame(int(math.log2(Y))-x,list_dests[i],list_dests[j])):  
                            print ("found congestion at level",x)
                            Flag_found_cong=1
                            src_cong_list.append((i,j))
    return Flag_found_cong, src_cong_list


# ==================================================
# def main():
# if __name__ == "__main__":
#     main()
# ==================================================
# Algorithm for assigning tree levels to banks
# Simulate: a total of 10K nodes to be inserted to on-chip memory
for node_ind in range(10000): 
    #treelevel is a simulated node level to be inserted to
    #generate treelevel         
    if (node_ind<F):
        treelevel=0
    elif (F<=node_ind<F**2):
        treelevel=1
    else:
        if (node_ind<128):
            treelevel=node_ind-F**2+1
        else:
            treelevel=random.randrange(num_treelevels_tobe_excluded, min(D,node_ind)) 
        assert(treelevel<=D)
    #keep track of shallow levels status
    if (treelevel<5):
        num_nodes_inserted_lst[treelevel]+=1
    #when some shallow level is full, do not insert to this tree level in the future
    for i in range(5):
        if num_nodes_inserted_lst[i]>=F**(i+1):
            num_treelevels_tobe_excluded = i+1 
    
    #ind_bank = find_a_valid_bank_basic(treelevel)
    ind_bank = find_a_valid_bank_advanced(treelevel)
#     if (node_ind<128):
#         print("node_ind",node_ind,"treelevel",treelevel,"inserted to bank",ind_bank)
    
count_cong=0  
all_src_with_cong={}
for k,v in src_cong_list_dict.items():
    if (len(v)>0):
        all_src_with_cong[k]=v
    count_cong+=len(v)
print("======================================================")  
print("All the assigned tree levels for",Y,"banks:",assigned_tree_level_lst)
print("======================================================")  
print("All the accessed banks for",D,"tree levels (pipeline stages):",level_banks_dict)
print("======================================================")  
print("total number of potential congestions:",count_cong)
print("======================================================")  
print("total potential congestions:",all_src_with_cong)

# print("\n====================Test Bench=====================")
# # ??????This could not be correct because not all banks are routing at the same time
# print("total congestions:",find_congestion(assigned_tree_level_lst[1]))



All the assigned tree levels for 256 banks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 117, 119, 105, 98, 114, 104, 109, 126, 121, 108, 113, 103, 106, 96, 110, 100, 99, 107, 124, 111, 95, 120, 118, 115, 102, 125, 93, 112, 127, 123, 97, 116, 122, 94, 101, 123, 117, 103, 109, 4, 99, 126, 110, 100, 102, 107, 97, 12, 85, 14, 122, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 116, 5, 28, 29, 30, 31, 32, 65, 66, 67, 36, 106, 70, 39, 72, 73, 74, 43, 76, 121, 46, 47, 48, 81, 119, 51, 104, 53, 54, 55, 88, 127, 90, 59, 2, 105, 62, 113, 98, 63, 10, 118, 68, 27, 120, 71, 8, 95, 6, 75, 3, 11, 78, 87, 80, 13, 82, 93, 15, 114, 86, 91, 84, 101, 61, 9, 92, 7, 94, 79, 111, 33, 34, 35, 37, 69, 38

In [126]:
assigned_tree_level_lst=[3,2,3,5,6,7,3]
try:
    print("lst.index() method:",assigned_tree_level_lst.index(10))
except ValueError:
    print("no with 10 found")
indices = [i for i, x in enumerate(assigned_tree_level_lst) if x == 3]
print(indices)
if (len(indices) is not 0):
    print("ha!")
else:
    print("cannot find indices")

no with 10 found
[0, 2, 6]
ha!


# testing

In [102]:
# # get binary representations of stages and banks. note: number of bits should be log2(Y). e.g., 8=log2(256)
# for i in range(D):
#     pip_lst[i]='{0:08b}'.format(i)
# for i in range(Y):
#     bank_lst[i]='{0:08b}'.format(i)
# print(pip_lst[6])

# print(find_congestion([0,2,4,1]))
# print(oneBdiff(0,1))
# print(LSXBsame(8-7,0,3))

# print(find_yth_empty_bank(8))


256
