In [2]:
import gurobipy as gp
from gurobipy import GRB
import numpy as np 
import torch
import scipy
import scipy.sparse
from math import prod

In [3]:
def generate_01_sparse_matrix(size, density):
    f, c, k = size
    x = scipy.sparse.rand(f, c*k, density, format='csr')
    x.data[:] = 1
    x = np.array(x.todense()).reshape((f, c, k))
    return x

In [4]:
def get_ones_count(matrix):
    return prod(matrix[np.where(matrix == 1)].shape)

In [5]:
def calc_density(matrix):
    ones = get_ones_count(matrix)
    return ones/prod(matrix.shape)

In [6]:
def density_of_remaining_weights(original_matrix, selection_bitmap):
    not_selected = 1-selection_bitmap
    total_unselected = np.sum(not_selected)
    ignored_weights = np.multiply(not_selected, original_matrix)
    unselected_ones = get_ones_count(ignored_weights)
    return  unselected_ones/ total_unselected 

In [46]:
weight_tensor = generate_01_sparse_matrix(size = (64, 64, 9), density = 0.95)
weight_tensor[np.where(weight_tensor != 0)]

array([1., 1., 1., ..., 1., 1., 1.])

In [47]:
def find_densest_subtensor_in_weight_tensor(tensor, filter_bounds, channel_bounds, initialize = False, timeout = None):
    f_size, c_size, k_size = tensor.shape
    min_filters, max_filters = filter_bounds
    min_channels, max_channels = channel_bounds
    if min_filters > f_size:
        raise ValueError("filter lowerbound must be lower than max filters")
    if min_channels > c_size:
        raise ValueError("channel lowerbound must be lower than max filters")
    
    tensor_cpy = np.copy(tensor)
    tensor_cpy[np.where(tensor_cpy == 0)] = -1
    m = gp.Model('densify')
    if timeout is not None:
        m.setParam(GRB.Param.TimeLimit, timeout)
    F = m.addVars(f_size, vtype=GRB.BINARY, name='F')
    C = m.addVars(c_size, vtype=GRB.BINARY, name='C')
    
    if initialize:
        est_filter_density = [(i, s) for s, i in zip(np.sum(tensor, axis=(2, 1)), range(tensor.shape[0]))]
        est_filter_density.sort(key = lambda x: x[1], reverse = True)
        est_channel_density = [(i, s) for s, i in zip(np.sum(tensor, axis=(2, 0)), range(tensor.shape[1]))]
        est_channel_density.sort(key = lambda x: x[1], reverse = True)
        initial_filters = [i for i, _ in est_filter_density[:min_filters+1]]
        initial_channels = [j for j, _ in est_channel_density[:min_channels+1]]

        for i in initial_filters:
            F[i].start = 1
        for j in initial_channels:
            C[j].start = 1
        
    Z = m.addVars(f_size, c_size, vtype=GRB.BINARY, name='Z')
    m.addConstr(gp.quicksum([F[i] for i in range(len(F))]) <= max_filters )
    m.addConstr(min_filters <= gp.quicksum([F[i] for i in range(len(F))]))
    m.addConstr(gp.quicksum([C[j] for j in range(len(C))]) <= max_channels )
    m.addConstr(min_channels <= gp.quicksum([C[j] for j in range(len(C))]))
    m.addConstrs((Z[i, j] == gp.and_(F[i], C[j]) for i in range(len(F)) for j in range(len(C))), name='and_constraints')
    m.setObjective(gp.quicksum(Z[i, j]*tensor_cpy[i, j, k] for i in range(len(F)) for j in range(len(C)) for k in range(k_size)), GRB.MAXIMIZE)
    m.optimize()
    dense_filter_indicies = [i for i, f in F.items() if f.X > 0]
    dense_channel_indicies = [j for j, c in C.items() if c.X > 0]
    selection_bitmap = np.zeros(tensor.shape)
    for f in dense_filter_indicies:
        for c in dense_channel_indicies:
            selection_bitmap[f, c] = 1
    dense_tensor = tensor[dense_filter_indicies, :, :][:, dense_channel_indicies, :]
    return dense_tensor, selection_bitmap, dense_filter_indicies, dense_channel_indicies
dense_tensor, selection_bitmap, dense_filter_indicies, dense_channel_indicies = find_densest_subtensor_in_weight_tensor(weight_tensor, (16, 32), (16, 32), timeout=60)
print(f'density of weight tensor: {calc_density(weight_tensor)}')
print(f'density of dense tensor: {calc_density(dense_tensor)}')
print(f'density of sparse tensor: {density_of_remaining_weights(weight_tensor, selection_bitmap)}')
print(f'selected filters: {dense_filter_indicies}')
print(f'selected channels: {dense_channel_indicies}')

Gurobi Optimizer version 9.5.0 build v9.5.0rc5 (linux64)
Thread count: 8 physical cores, 8 logical processors, using up to 8 threads
Optimize a model with 4 rows, 4224 columns and 256 nonzeros
Model fingerprint: 0x145c7ff9
Model has 4096 general constraints
Variable types: 0 continuous, 4224 integer (4224 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [1e+00, 9e+00]
  Bounds range     [1e+00, 1e+00]
  RHS range        [4e+00, 3e+01]
Presolve added 8191 rows and 0 columns
Presolve time: 0.03s
Presolved: 8195 rows, 4224 columns, 16639 nonzeros
Variable types: 0 continuous, 4224 integer (4224 binary)
Found heuristic solution: objective 8292.0000000

Root relaxation: objective 1.658950e+04, 6949 iterations, 0.26 seconds (0.28 work units)

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0 16589.5000    0 4223 8292.00000 16589.5000   100%    