In [1]:
import sys, os
sys.path.append(os.path.dirname(os.path.abspath('.')))

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_pruning as tp

In [2]:
class DeepFCN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(DeepFCN, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.add_module('first_relu', nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(256,64),
            nn.ReLU()
        )
        self.fc3 = nn.ModuleList(
            [nn.Sequential(
            nn.Linear(64,64),
            nn.ReLU()) for i in range(3)
            ]
        )
        self.fc4 = nn.ModuleDict({
            'fc4-1': nn.Linear(64,32),
            'relu': nn.ReLU()
        })
        self.fc5 = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.first_relu(x)
        x = self.fc2(x)
        for i, l in enumerate(self.fc3):
            x = l(x)
        x = self.fc4['fc4-1'](x)
        x = self.fc4['relu'](x)
        y_hat = self.fc5(x)
        return y_hat

model = DeepFCN(128, 10)
print(model)

DeepFCN(
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
  )
  (fc3): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
  )
  (fc4): ModuleDict(
    (fc4-1): Linear(in_features=64, out_features=32, bias=True)
    (relu): ReLU()
  )
  (fc5): Linear(in_features=32, out_features=10, bias=True)
)


In [3]:
strategy = tp.strategy.RandomStrategy()

In [4]:
module_to_idxs = {}

def init_strategy(m):
    if isinstance(m, nn.Linear):
        print('[linear]', m, end='\n')
        module_to_idxs[m] = strategy(m.weight, amount=0.2)
        print(module_to_idxs[m])
        print()
    else:
        print(m, end='\n\n')

model.apply(init_strategy)

[linear] Linear(in_features=128, out_features=256, bias=True)
[10, 54, 15, 111, 238, 170, 183, 59, 104, 191, 88, 99, 177, 56, 215, 205, 8, 167, 62, 121, 85, 25, 80, 49, 82, 206, 231, 131, 89, 100, 124, 113, 203, 225, 44, 114, 90, 237, 161, 103, 96, 144, 180, 150, 24, 248, 73, 244, 174, 14, 171]

ReLU()

[linear] Linear(in_features=256, out_features=64, bias=True)
[5, 57, 54, 31, 47, 22, 0, 11, 38, 2, 48, 55]

ReLU()

Sequential(
  (0): Linear(in_features=256, out_features=64, bias=True)
  (1): ReLU()
)

[linear] Linear(in_features=64, out_features=64, bias=True)
[59, 19, 57, 47, 18, 4, 40, 9, 61, 52, 23, 11]

ReLU()

Sequential(
  (0): Linear(in_features=64, out_features=64, bias=True)
  (1): ReLU()
)

[linear] Linear(in_features=64, out_features=64, bias=True)
[42, 53, 46, 18, 21, 25, 40, 24, 23, 47, 35, 16]

ReLU()

Sequential(
  (0): Linear(in_features=64, out_features=64, bias=True)
  (1): ReLU()
)

[linear] Linear(in_features=64, out_features=64, bias=True)
[29, 41, 11, 4, 46, 42,

DeepFCN(
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
  )
  (fc3): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
  )
  (fc4): ModuleDict(
    (fc4-1): Linear(in_features=64, out_features=32, bias=True)
    (relu): ReLU()
  )
  (fc5): Linear(in_features=32, out_features=10, bias=True)
)

In [5]:
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,128))

<torch_pruning.dependency.DependencyGraph at 0x7fef012db250>

In [6]:
pruning_plans = []

def get_pruning_plans(m):
    if m in module_to_idxs:
        pruning_plans.append(DG.get_pruning_plan(m, tp.prune_linear, idxs=module_to_idxs[m]))

model.apply(get_pruning_plans)
print(pruning_plans)

[<torch_pruning.dependency.PruningPlan object at 0x7fef012addc0>, <torch_pruning.dependency.PruningPlan object at 0x7fef012ad5b0>, <torch_pruning.dependency.PruningPlan object at 0x7fef012adc40>, <torch_pruning.dependency.PruningPlan object at 0x7fef012ad100>, <torch_pruning.dependency.PruningPlan object at 0x7fef012add00>, <torch_pruning.dependency.PruningPlan object at 0x7fef012f30d0>, <torch_pruning.dependency.PruningPlan object at 0x7fef012f3280>]


In [7]:
print(model)

DeepFCN(
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
  )
  (fc3): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
  )
  (fc4): ModuleDict(
    (fc4-1): Linear(in_features=64, out_features=32, bias=True)
    (relu): ReLU()
  )
  (fc5): Linear(in_features=32, out_features=10, bias=True)
)


In [8]:
pruning_plans.pop() # don't change the output layer
for plan in pruning_plans:
        plan.exec()
print(model)

DeepFCN(
  (fc1): Linear(in_features=128, out_features=205, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=205, out_features=52, bias=True)
    (1): ReLU()
  )
  (fc3): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=52, out_features=52, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=52, out_features=52, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=52, out_features=52, bias=True)
      (1): ReLU()
    )
  )
  (fc4): ModuleDict(
    (fc4-1): Linear(in_features=52, out_features=26, bias=True)
    (relu): ReLU()
  )
  (fc5): Linear(in_features=26, out_features=10, bias=True)
)
