# from resnet_v0

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np
from argparse import Namespace

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class SparseConv(nn.Module):
    def __init__(self, k, k_percent):
        super(SparseConv, self).__init__()
        self.k = k
        self.k_percent = k_percent

    def unravel_index(self, mask, indices, k):
        """ """
        
        a = torch.arange(0, indices.shape[0])
        a = torch.cat(k * [a.unsqueeze(-1)], axis=-1)
        indices = [a.view(-1), indices.view(-1)]
        mask[indices] = 1
        return mask

    def batch_topk(self, inp, k):
        """ """ 
        
        (buffer, indices) = torch.topk(inp, k, -1, True)#k个最大
        mask = torch.zeros_like(inp).bool()
        mask = self.unravel_index(mask, indices, k)
        inp = inp * mask
        return inp

    def filter(self, x):  # Applied to a batch.
        """ """
        k_per_map = math.ceil(self.k * x.shape[2] * x.shape[3]) #k乘维度2和维度3
        x = x.permute(1, 0, 2, 3)#维度换位,原来的维度0和维度1换位
        inp = x.reshape(x.shape[0] * x.shape[1], -1)#保留新维度0，其余展开
        
        inp = self.batch_topk(inp, k_per_map)
        inp = inp.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
        k_factor = math.ceil(self.k_percent * k_per_map * x.shape[1])
        inp = self.batch_topk(inp, k_factor)
        inp = inp.reshape(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
        inp = inp.permute(1, 0, 2, 3)
        return inp

In [21]:

def unravel_index(mask, indices, k):

    a = torch.arange(0, indices.shape[0])
    a = torch.cat(k * [a.unsqueeze(-1)], axis=-1)
    indices = [a.view(-1), indices.view(-1)]
    mask[indices] = 1
    return mask

def batch_topk( inp, k):

    (buffer, indices) = torch.topk(inp, k, -1, True)#k个最大
    mask = torch.zeros_like(inp).bool()
    mask = unravel_index(mask, indices, k)
    inp = inp * mask
    return inp

def filter(x):  # Applied to a batch.
    k_per_map = math.ceil(k * x.shape[2] * x.shape[3]) #k乘维度2和维度3
    x = x.permute(1, 0, 2, 3)#维度换位,原来的维度0和维度1换位
    inp = x.reshape(x.shape[0] * x.shape[1], -1)#保留新维度0，其余展开
        
    inp = batch_topk(inp, k_per_map)
    inp = inp.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
    k_factor = math.ceil(k_percent * k_per_map * x.shape[1])
    inp = batch_topk(inp, k_factor)
    inp = inp.reshape(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
    inp = inp.permute(1, 0, 2, 3)
    return inp

In [63]:
x=torch.rand(3, 1, 3, 4)
print(x)

tensor([[[[0.7267, 0.2396, 0.1355, 0.6031],
          [0.9992, 0.6631, 0.2969, 0.4840],
          [0.1630, 0.7144, 0.1736, 0.9222]]],


        [[[0.1905, 0.5748, 0.1654, 0.6083],
          [0.9095, 0.4006, 0.0388, 0.7231],
          [0.3230, 0.2629, 0.1584, 0.0243]]],


        [[[0.2465, 0.3460, 0.6081, 0.3679],
          [0.8153, 0.0063, 0.5922, 0.6756],
          [0.6459, 0.7110, 0.4536, 0.4574]]]])


In [64]:
import numpy
b=x.reshape(1,-1).tolist()
b[0].sort()
print(b[0])

[0.006291508674621582, 0.024332642555236816, 0.038789212703704834, 0.13550329208374023, 0.15841537714004517, 0.16299211978912354, 0.16537773609161377, 0.17358827590942383, 0.19052261114120483, 0.23958462476730347, 0.24649661779403687, 0.2629052996635437, 0.29690873622894287, 0.32299792766571045, 0.3459511399269104, 0.36789363622665405, 0.4005868434906006, 0.4535815119743347, 0.45743393898010254, 0.4839909076690674, 0.5747764110565186, 0.5922250151634216, 0.6031078696250916, 0.6080601811408997, 0.6082941889762878, 0.6459240317344666, 0.6630747318267822, 0.6755623817443848, 0.7109774351119995, 0.7143705487251282, 0.7230732440948486, 0.726678192615509, 0.8152569532394409, 0.909501850605011, 0.9222007989883423, 0.9991892576217651]


In [68]:
klist=[1,0.1,0.9]
kperlist=[1,0.1,0.9]
for a in klist:
    for b in kperlist:
        k=a
        print("k is "+str(k))
        k_percent=b
        print("k_percent is "+str(k_percent))
        print(filter(x))

k is 1
k_percent is 1
tensor([[[[0.7267, 0.2396, 0.1355, 0.6031],
          [0.9992, 0.6631, 0.2969, 0.4840],
          [0.1630, 0.7144, 0.1736, 0.9222]]],


        [[[0.1905, 0.5748, 0.1654, 0.6083],
          [0.9095, 0.4006, 0.0388, 0.7231],
          [0.3230, 0.2629, 0.1584, 0.0243]]],


        [[[0.2465, 0.3460, 0.6081, 0.3679],
          [0.8153, 0.0063, 0.5922, 0.6756],
          [0.6459, 0.7110, 0.4536, 0.4574]]]])
k is 1
k_percent is 0.1
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000],
          [0.9992, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.9222]]],


        [[[0.0000, 0.0000, 0.0000, 0.0000],
          [0.9095, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.0000, 0.0000, 0.0000, 0.0000],
          [0.8153, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]])
k is 1
k_percent is 0.9
tensor([[[[0.7267, 0.2396, 0.1355, 0.6031],
          [0.9992, 0.6631, 0.2969, 0.4840],
          [0.1630, 0.7144, 0

In [3]:
k=0.5

x=1,2,3,4
k_per_map=0.1*3*4 取整=2向上


SyntaxError: invalid syntax (<ipython-input-3-03ad5002bbe5>, line 4)

In [4]:
x=torch.rand(1,2,3,4)
print(x.size())
x=x.permute(1, 0, 2, 3)
print(x.size())
inp = x.reshape(x.shape[0] * x.shape[1], -1)
print(inp.size())

torch.Size([1, 2, 3, 4])
torch.Size([2, 1, 3, 4])
torch.Size([2, 12])


In [5]:
x

tensor([[[[0.0629, 0.3333, 0.9897, 0.0092],
          [0.6994, 0.1588, 0.5580, 0.6335],
          [0.8399, 0.8978, 0.1073, 0.4354]]],


        [[[0.9030, 0.6095, 0.8526, 0.5466],
          [0.3246, 0.6502, 0.0882, 0.1741],
          [0.5851, 0.9862, 0.4548, 0.1090]]]])

In [6]:
k=0.5
k_per_map = math.ceil(k * x.shape[2] * x.shape[3])
print(k_per_map)

6


# batch topk

In [7]:
inp

tensor([[0.0629, 0.3333, 0.9897, 0.0092, 0.6994, 0.1588, 0.5580, 0.6335, 0.8399,
         0.8978, 0.1073, 0.4354],
        [0.9030, 0.6095, 0.8526, 0.5466, 0.3246, 0.6502, 0.0882, 0.1741, 0.5851,
         0.9862, 0.4548, 0.1090]])

In [8]:
(buffer, indices)=torch.topk(inp, 11, -1, True)
print(buffer)#值
print(indices)#索引
print(buffer.size())
print(indices.size())
mask = torch.zeros_like(inp).bool()
print(mask)

tensor([[0.9897, 0.8978, 0.8399, 0.6994, 0.6335, 0.5580, 0.4354, 0.3333, 0.1588,
         0.1073, 0.0629],
        [0.9862, 0.9030, 0.8526, 0.6502, 0.6095, 0.5851, 0.5466, 0.4548, 0.3246,
         0.1741, 0.1090]])
tensor([[ 2,  9,  8,  4,  7,  6, 11,  1,  5, 10,  0],
        [ 9,  0,  2,  5,  1,  8,  3, 10,  4,  7, 11]])
torch.Size([2, 11])
torch.Size([2, 11])
tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False]])


# unravel_index

In [9]:
indices.size()

torch.Size([2, 11])

In [10]:
a = torch.arange(0, indices.shape[0])
print(a)
a = torch.cat(k_per_map * [a.unsqueeze(-1)], axis=-1)
print(a)
indices = [a.view(-1), indices.view(-1)]
print(indices)
mask[indices] = 1

tensor([0, 1])
tensor([[0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1]])
[tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), tensor([ 2,  9,  8,  4,  7,  6, 11,  1,  5, 10,  0,  9,  0,  2,  5,  1,  8,  3,
        10,  4,  7, 11])]


IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [12], [22]

In [11]:
print(mask)

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False]])


# batch top k

In [None]:
inp = inp * mask

In [None]:
inp

In [None]:
inp = inp.reshape(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
inp = inp.permute(1, 0, 2, 3)
inp