In [235]:
import numpy as np

def swizzle_cols(cols, num_banks):
    """Return how many padding columns are appended to each tensor row.

    The pad is ``(cols mod num_banks) + 1``. The resulting row stride,
    ``cols + pad``, is congruent to ``2 * (cols mod num_banks) + 1`` modulo
    ``num_banks``.

    NOTE(review): that stride is always odd, so it is coprime with a
    power-of-two bank count. For other bank counts it may share a factor
    (e.g. cols=12, num_banks=5 gives stride 15), which would make a
    column-wise (transposed) access of ``num_banks`` rows repeat a bank.
    Confirm this is intended.
    """
    remainder = cols % num_banks
    return remainder + 1

def simulate_tensor_in_scpad(tensor, num_banks, rows, cols):
    """Assign each element of a ``rows`` x ``cols`` tensor to a scratchpad bank.

    Each row is laid out with a padded stride of
    ``cols + swizzle_cols(cols, num_banks)``. Element ``(r, c)`` goes to bank
    ``(r * stride + c) % num_banks``.

    ``tensor`` is accepted but not read; only ``rows`` and ``cols`` determine
    the layout.

    Returns a list of ``num_banks`` lists. Each inner list holds the
    ``"r,c"`` labels stored in that bank, in row-major insertion order.
    """
    row_stride = cols + swizzle_cols(cols, num_banks)
    banks = [[] for _ in range(num_banks)]

    for r in range(rows):
        row_base = r * row_stride
        for c in range(cols):
            banks[(row_base + c) % num_banks].append(f"{r},{c}")

    return banks

def get_bank_from_id(index, banks):
    """Return the position of the first bank whose contents include ``index``.

    ``banks`` is a list of per-bank label lists, as built by
    ``simulate_tensor_in_scpad``. Returns ``None`` when no bank holds
    ``index``, e.g. a label outside the simulated tensor.
    """
    return next(
        (bank_idx for bank_idx, contents in enumerate(banks) if index in contents),
        None,
    )

def get_coalesced_indices(start_row, start_col, transposed_mode, num_banks):
    """Return the ``"row,col"`` labels touched by one coalesced access.

    The access covers ``num_banks`` consecutive elements:

    * row mode (``transposed_mode`` falsy): row ``start_row``, columns
      ``start_col .. start_col + num_banks - 1``;
    * transposed mode: column ``start_col``, rows
      ``start_row .. start_row + num_banks - 1``.

    Labels use the same ``"r,c"`` format that ``simulate_tensor_in_scpad``
    stores in each bank. Labels are not bounds-checked against any tensor
    shape.
    """
    # Offset by the start coordinate along the walked axis. Previously the
    # walk always began at index 0, silently ignoring start_col (row mode) or
    # start_row (transposed mode).
    if transposed_mode:
        return [f"{start_row + i},{start_col}" for i in range(num_banks)]
    return [f"{start_row},{start_col + i}" for i in range(num_banks)]

def print_scpad(banks):
    """Print the scratchpad as a grid with one column per bank.

    Output row ``k`` shows the ``k``-th label stored in every bank. Banks
    holding fewer labels show a blank cell. Cells are centred in a width of
    at least 5 characters, widened to the longest label so columns stay
    aligned for large tensors. An empty ``banks`` list prints nothing.
    """
    # Minimum width 5 keeps the original layout for short labels.
    width = max([5] + [len(f"{label}") for bank in banks for label in bank])
    # default=0 avoids ValueError from max() when there are no banks at all.
    depth = max((len(bank) for bank in banks), default=0)

    for row in range(depth):
        cells = [
            (f"{bank[row]}" if row < len(bank) else "").center(width)
            for bank in banks
        ]
        print(" || ".join(cells))

def print_tensor(num_rows, num_cols):
    """Print the logical ``num_rows`` x ``num_cols`` tensor as a grid of ``"r,c"`` labels.

    Each label is centred in a 5-character cell, and cells are separated by
    ``" || "``.
    """
    for row_idx in range(num_rows):
        line = " || ".join(f"{row_idx},{col_idx}".center(5) for col_idx in range(num_cols))
        print(line)

In [236]:
def run_tests(rows, cols, num_banks, transposed_mode, print_stuff):
    """Check that each probed coalesced access hits ``num_banks`` distinct banks.

    Lays out a ``rows`` x ``cols`` tensor with ``simulate_tensor_in_scpad``.
    It then issues one access of ``num_banks`` elements for each row (row
    mode) or each column (``transposed_mode``). The access starts at column 0
    or row 0 respectively, and the bank of every element is looked up.

    When ``print_stuff`` is truthy, the tensor, the bank layout and every
    access (label plus bank) are printed.

    Returns the bank layout if every probed access is conflict-free.
    On the first access that repeats a bank, or touches an element outside
    the tensor, it prints the set of banks hit followed by " - FAILED!!!"
    and returns None.
    """
    tensor = np.arange(rows * cols).reshape((rows, cols))
    banks = simulate_tensor_in_scpad(tensor, num_banks, rows, cols)

    if print_stuff:
        print_tensor(rows, cols)
        print("-------------------------")
        print_scpad(banks)

    # Row mode probes one access per row; transposed mode one per column.
    domain = cols if transposed_mode else rows

    for a in range(domain):
        if print_stuff:
            print("---------------------------------------")

        r, c = (0, a) if transposed_mode else (a, 0)

        indices = get_coalesced_indices(r, c, transposed_mode, num_banks)
        curr_banks = [get_bank_from_id(index, banks) for index in indices]

        if print_stuff and indices:
            print(", ".join(f"{index} [{bank}]" for index, bank in zip(indices, curr_banks)))

        # get_bank_from_id returns None for labels not present in any bank,
        # i.e. elements outside the tensor (e.g. cols < num_banks in row
        # mode). Such an access is invalid, and None must not be counted as
        # one of the distinct banks, which could otherwise produce a false pass.
        if None in curr_banks or len(set(curr_banks)) != num_banks:
            print(f"{set(curr_banks)} - FAILED!!!")
            return

    if print_stuff:
        print("---------------------------------------")

    return banks

In [237]:
# Simulation settings for the shape-sweep cell:
#   num_banks       - number of scratchpad banks
#   transposed_mode - falsy: probe row-wise accesses; truthy: column-wise
#   print_stuff     - truthy enables verbose per-access printing in run_tests
num_banks, transposed_mode, print_stuff = 5, False, 0

In [238]:
# Sweep every (rows, cols) combination in row-major order. run_tests prints
# "... - FAILED!!!" for any shape whose probed accesses are not conflict-free.
for num_rows, num_cols in ((nr, nc)
                           for nr in [10, 12, 13, 15, 20, 21, 22]
                           for nc in [10, 12, 13, 20, 21, 22, 24]):
    print(num_rows, num_cols)
    banks = run_tests(num_rows, num_cols, num_banks, transposed_mode, print_stuff)

10 10
10 12
10 13
10 20
10 21
10 22
10 24
12 10
12 12
12 13
12 20
12 21
12 22
12 24
13 10
13 12
13 13
13 20
13 21
13 22
13 24
15 10
15 12
15 13
15 20
15 21
15 22
15 24
20 10
20 12
20 13
20 20
20 21
20 22
20 24
21 10
21 12
21 13
21 20
21 21
21 22
21 24
22 10
22 12
22 13
22 20
22 21
22 22
22 24


In [239]:
banks = run_tests(rows=12, cols=10, num_banks=4, transposed_mode=False, print_stuff=1)

 0,0  ||  0,1  ||  0,2  ||  0,3  ||  0,4  ||  0,5  ||  0,6  ||  0,7  ||  0,8  ||  0,9 
 1,0  ||  1,1  ||  1,2  ||  1,3  ||  1,4  ||  1,5  ||  1,6  ||  1,7  ||  1,8  ||  1,9 
 2,0  ||  2,1  ||  2,2  ||  2,3  ||  2,4  ||  2,5  ||  2,6  ||  2,7  ||  2,8  ||  2,9 
 3,0  ||  3,1  ||  3,2  ||  3,3  ||  3,4  ||  3,5  ||  3,6  ||  3,7  ||  3,8  ||  3,9 
 4,0  ||  4,1  ||  4,2  ||  4,3  ||  4,4  ||  4,5  ||  4,6  ||  4,7  ||  4,8  ||  4,9 
 5,0  ||  5,1  ||  5,2  ||  5,3  ||  5,4  ||  5,5  ||  5,6  ||  5,7  ||  5,8  ||  5,9 
 6,0  ||  6,1  ||  6,2  ||  6,3  ||  6,4  ||  6,5  ||  6,6  ||  6,7  ||  6,8  ||  6,9 
 7,0  ||  7,1  ||  7,2  ||  7,3  ||  7,4  ||  7,5  ||  7,6  ||  7,7  ||  7,8  ||  7,9 
 8,0  ||  8,1  ||  8,2  ||  8,3  ||  8,4  ||  8,5  ||  8,6  ||  8,7  ||  8,8  ||  8,9 
 9,0  ||  9,1  ||  9,2  ||  9,3  ||  9,4  ||  9,5  ||  9,6  ||  9,7  ||  9,8  ||  9,9 
 10,0 ||  10,1 ||  10,2 ||  10,3 ||  10,4 ||  10,5 ||  10,6 ||  10,7 ||  10,8 ||  10,9
 11,0 ||  11,1 ||  11,2 ||  11,3 ||  11,4 |