In [2]:
import gurobipy as gp
from gurobipy import GRB
import numpy as np 
import torch
import scipy
import scipy.sparse
from math import prod

In [3]:
def generate_01_sparse_matrix(size, density, random_state=None):
    """Generate a random dense array of 0s and 1s with the requested density.

    Args:
        size: ``(f, c, k)`` tuple giving the shape of the returned array.
        density: Fraction of entries that should be 1 (between 0 and 1).
        random_state: Optional seed or NumPy random generator forwarded to
            ``scipy.sparse.rand``. Pass a fixed value to get the same array
            on every run; the default ``None`` keeps the unseeded behaviour.

    Returns:
        A ``numpy.ndarray`` of shape ``(f, c, k)`` whose entries are 0 or 1.
    """
    f, c, k = size
    # The non-zero pattern is sampled on the flattened (f, c*k) layout and
    # folded back into three dimensions at the end.
    pattern = scipy.sparse.rand(f, c * k, density, format='csr', random_state=random_state)
    # Only the positions matter: overwrite the sampled values with 1.
    pattern.data[:] = 1
    return pattern.toarray().reshape((f, c, k))

In [4]:
def get_ones_count(matrix):
    """Count the entries of ``matrix`` that are equal to 1.

    Args:
        matrix: Array (or array-like) of any shape, including 0-d.

    Returns:
        The number of entries equal to 1, as a Python ``int``.
    """
    # count_nonzero on the boolean mask avoids building index arrays and a
    # fancy-indexed copy just to measure their length.
    return int(np.count_nonzero(np.asarray(matrix) == 1))

In [5]:
def calc_density(matrix):
    """Return the fraction of entries in ``matrix`` that are equal to 1.

    Args:
        matrix: Array exposing a ``shape`` attribute.

    Returns:
        Number of ones divided by the total number of entries.
    """
    n_ones = get_ones_count(matrix)
    n_entries = prod(matrix.shape)
    return n_ones / n_entries

In [6]:
def density_of_remaining_weights(original_matrix, selection_bitmap):
    """Return the density of ones among the entries left out of a selection.

    Args:
        original_matrix: Array of weights.
        selection_bitmap: Array combined elementwise with ``original_matrix``;
            expected to hold 1 for selected entries and 0 for the rest.

    Returns:
        Count of ones in the unselected part of ``original_matrix`` divided by
        the number of unselected entries.
    """
    complement = 1 - selection_bitmap
    # Zero out the selected entries so only unselected weights are counted.
    leftover_weights = np.multiply(complement, original_matrix)
    return get_ones_count(leftover_weights) / np.sum(complement)

In [7]:
# Random 64 x 64 x 9 tensor of 0s and 1s, generated with density 0.5.
weight_tensor = generate_01_sparse_matrix(size=(64, 64, 9), density=0.5)
# Display the non-zero entries as a sanity check (they should all be 1).
weight_tensor[weight_tensor != 0]

array([1., 1., 1., ..., 1., 1., 1.])

In [23]:
def _top_indices_by_sum(tensor, reduce_axes, count):
    """Return the ``count`` indices along the kept axis with the largest sums (ties keep index order)."""
    totals = np.sum(tensor, axis=reduce_axes, dtype=float)
    return np.argsort(-totals, kind='stable')[:max(count, 0)].tolist()


def find_densest_subtensor_in_weight_tensor(tensor, filter_bounds, channel_bounds, initialize = False, timeout = 10):
    """Select filters and channels whose cross product forms a dense sub-tensor.

    A Gurobi binary program picks a set of filters (axis 0) and a set of
    channels (axis 1). Every entry of the selected block contributes its value
    to the objective, with zeros counted as -1, so for a 0/1 tensor the
    objective is (#ones - #zeros) inside the block.

    Args:
        tensor: 3-D ``numpy.ndarray`` indexed as ``[filter, channel, k]``.
        filter_bounds: ``(min_filters, max_filters)`` bounds on the number of
            selected filters.
        channel_bounds: ``(min_channels, max_channels)`` bounds on the number
            of selected channels.
        initialize: If true, warm-start the solver with the filters and
            channels that have the largest sums.
        timeout: Value passed to Gurobi's ``TimeLimit`` parameter.

    Returns:
        Tuple ``(dense_tensor, selection_bitmap, filter_indices, channel_indices)``:
        the selected sub-tensor, an array shaped like ``tensor`` holding 1 on
        the selected block and 0 elsewhere, and the selected index lists.

    Raises:
        ValueError: If a lower bound exceeds the matching tensor dimension or
            its own upper bound.
        RuntimeError: If the solver finished without any feasible solution.
    """
    f_size, c_size, _ = tensor.shape
    min_filters, max_filters = filter_bounds
    min_channels, max_channels = channel_bounds
    if min_filters > f_size:
        raise ValueError(f"filter lower bound ({min_filters}) must not exceed the number of filters ({f_size})")
    if min_channels > c_size:
        raise ValueError(f"channel lower bound ({min_channels}) must not exceed the number of channels ({c_size})")
    if min_filters > max_filters:
        raise ValueError(f"filter lower bound ({min_filters}) must not exceed filter upper bound ({max_filters})")
    if min_channels > max_channels:
        raise ValueError(f"channel lower bound ({min_channels}) must not exceed channel upper bound ({max_channels})")

    # Zeros are penalised with -1. Building a float array (instead of writing
    # -1 into a copy) keeps this correct for bool and unsigned inputs too.
    signed = np.where(tensor == 0, -1.0, tensor).astype(float)
    # Z[i, j] multiplies every k-entry of the pair (i, j), so the k axis can be
    # summed up front: the objective needs f*c terms instead of f*c*k.
    pair_scores = signed.sum(axis=2)

    m = gp.Model('densify')
    m.setParam(GRB.Param.TimeLimit, timeout)
    F = m.addVars(f_size, vtype=GRB.BINARY, name='F')
    C = m.addVars(c_size, vtype=GRB.BINARY, name='C')
    Z = m.addVars(f_size, c_size, vtype=GRB.BINARY, name='Z')

    if initialize:
        # Warm start from the rows/columns with the most weight. The count is
        # capped at the upper bound so the start never violates a constraint.
        for i in _top_indices_by_sum(tensor, (2, 1), min(min_filters + 1, max_filters)):
            F[i].Start = 1
        for j in _top_indices_by_sum(tensor, (2, 0), min(min_channels + 1, max_channels)):
            C[j].Start = 1

    n_filters = gp.quicksum(F[i] for i in range(f_size))
    n_channels = gp.quicksum(C[j] for j in range(c_size))
    m.addConstr(n_filters <= max_filters)
    m.addConstr(min_filters <= n_filters)
    m.addConstr(n_channels <= max_channels)
    m.addConstr(min_channels <= n_channels)
    # Z[i, j] is 1 exactly when both filter i and channel j are selected.
    m.addConstrs((Z[i, j] == gp.and_(F[i], C[j]) for i in range(f_size) for j in range(c_size)), name='and_constraints')
    m.setObjective(
        gp.quicksum(Z[i, j] * float(pair_scores[i, j]) for i in range(f_size) for j in range(c_size)),
        GRB.MAXIMIZE,
    )
    m.optimize()

    if m.SolCount == 0:
        raise RuntimeError(f"Gurobi returned no feasible solution (status {m.Status}); try a larger timeout")

    # Binary variables are only integral up to the solver's tolerance, so
    # compare against 0.5 rather than 0.
    dense_filter_indices = [i for i, f in F.items() if f.X > 0.5]
    dense_channel_indices = [j for j, c in C.items() if c.X > 0.5]

    block = np.ix_(dense_filter_indices, dense_channel_indices)
    selection_bitmap = np.zeros(tensor.shape)
    selection_bitmap[block] = 1
    dense_tensor = tensor[block]
    return dense_tensor, selection_bitmap, dense_filter_indices, dense_channel_indices
# Search for a block of 16-32 filters by 16-32 channels, warm-started, with a
# solver time limit of 10.
(
    dense_tensor,
    selection_bitmap,
    dense_filter_indicies,
    dense_channel_indicies,
) = find_densest_subtensor_in_weight_tensor(
    weight_tensor,
    filter_bounds=(16, 32),
    channel_bounds=(16, 32),
    initialize=True,
    timeout=10,
)
# Compare the density of the whole tensor, the selected block, and the rest.
print(f'density of weight tensor: {calc_density(weight_tensor)}')
print(f'density of dense tensor: {calc_density(dense_tensor)}')
print(f'density of sparse tensor: {density_of_remaining_weights(weight_tensor, selection_bitmap)}')

Set parameter TimeLimit to value 10


NameError: name 'initalize' is not defined

In [9]:
# Display the selected sub-tensor together with the chosen filter and channel indices.
dense_tensor, dense_filter_indicies, dense_channel_indicies

(array([[[0., 1., 0., ..., 0., 1., 1.],
         [1., 1., 1., ..., 0., 1., 1.],
         [1., 1., 1., ..., 1., 1., 0.],
         ...,
         [1., 0., 0., ..., 1., 0., 0.],
         [1., 0., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 1., 0., 1.]],
 
        [[1., 1., 0., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [0., 1., 0., ..., 1., 0., 0.],
         ...,
         [1., 0., 1., ..., 1., 0., 1.],
         [1., 1., 1., ..., 0., 1., 1.],
         [1., 1., 1., ..., 0., 0., 0.]],
 
        [[1., 1., 1., ..., 0., 1., 0.],
         [0., 1., 0., ..., 0., 0., 1.],
         [0., 1., 1., ..., 1., 1., 0.],
         ...,
         [1., 0., 1., ..., 1., 1., 0.],
         [1., 1., 1., ..., 0., 1., 1.],
         [1., 1., 1., ..., 1., 0., 0.]],
 
        ...,
 
        [[1., 1., 0., ..., 1., 0., 0.],
         [1., 1., 0., ..., 1., 0., 1.],
         [1., 0., 0., ..., 0., 0., 1.],
         ...,
         [1., 0., 0., ..., 1., 0., 1.],
         [1., 1., 1., ..., 0., 1., 1.],


In [10]:
# Pair each filter index (axis 0) and each channel index (axis 1) with the sum
# of the weights it holds, as (index, sum) tuples.
est_filter_density = list(enumerate(np.sum(weight_tensor, axis=(2, 1))))
est_channel_density = list(enumerate(np.sum(weight_tensor, axis=(2, 0))))

In [11]:
# Rank filters from the largest sum to the smallest; the sort is stable, so
# ties keep their index order.
est_filter_density = sorted(est_filter_density, key=lambda pair: pair[1], reverse=True)

In [12]:
# Display the (filter index, sum) pairs in their current order.
est_filter_density

[(23, 316.0),
 (35, 315.0),
 (12, 312.0),
 (47, 308.0),
 (20, 307.0),
 (58, 307.0),
 (62, 307.0),
 (43, 306.0),
 (18, 302.0),
 (16, 301.0),
 (28, 301.0),
 (48, 301.0),
 (6, 300.0),
 (9, 296.0),
 (15, 296.0),
 (39, 296.0),
 (60, 296.0),
 (2, 295.0),
 (27, 295.0),
 (4, 294.0),
 (17, 294.0),
 (49, 294.0),
 (34, 292.0),
 (46, 292.0),
 (14, 291.0),
 (22, 291.0),
 (57, 291.0),
 (63, 291.0),
 (7, 290.0),
 (8, 290.0),
 (11, 290.0),
 (30, 289.0),
 (31, 289.0),
 (45, 289.0),
 (13, 288.0),
 (25, 286.0),
 (29, 286.0),
 (38, 286.0),
 (37, 285.0),
 (36, 283.0),
 (42, 282.0),
 (51, 282.0),
 (3, 281.0),
 (32, 281.0),
 (40, 281.0),
 (50, 281.0),
 (1, 280.0),
 (21, 280.0),
 (33, 280.0),
 (26, 277.0),
 (59, 276.0),
 (10, 275.0),
 (54, 275.0),
 (56, 275.0),
 (55, 274.0),
 (52, 273.0),
 (53, 272.0),
 (19, 271.0),
 (24, 271.0),
 (41, 270.0),
 (0, 268.0),
 (44, 268.0),
 (5, 261.0),
 (61, 260.0)]