In [1]:
import torch
# Pruning configuration: target every Conv2d layer with a sparse ratio of 0.5,
# except the two explicitly excluded `fin_out.0` layers (presumably the final
# output convolutions of each sub-network — confirm against the model definition).
config_list = [dict(
    op_types=['Conv2d'],
    sparse_ratio=0.5,
    exclude_op_names=[
        'recon.decoder.fin_out.0',
        'seg.decoder_segment.fin_out.0',
    ],
)]
# File the pruned model is written to at the end of the notebook.
save_path = 'compressed_bubbles_50.mod'
# All pruning and speedup work in this notebook runs on the CPU.
device = torch.device('cpu')
# Shape of the dummy input handed to the speedup step; dim 1 (8) matches the
# `in_channels` the model is built with.
input_shape = (1, 8, 128, 128)
from nni.compression.pruning import L1NormPruner
# Alias for the pruner class that the pruning cell instantiates as
# `prunerClass(model, config_list)`; kept with the other settings so the
# pruning algorithm is chosen in one place.
prunerClass = L1NormPruner

In [2]:
import path, sys
# Put the directory two levels above the current working directory on sys.path
# so that the `DRAEM` package imported just below can be resolved.
cur_path = path.Path('.').abspath()
# Guard the append: re-running this cell must not pile up duplicate entries.
if cur_path.parent.parent not in sys.path:
    sys.path.append(cur_path.parent.parent)
from DRAEM.base.model_unet import ReconstructiveSubNetwork, DiscriminativeSubNetwork

class DRAEMPack(torch.nn.Module):
    """Bundle the reconstructive and discriminative sub-networks in one module.

    The reconstructive network maps the input to a tensor with the same number
    of channels; the discriminative network receives that reconstruction
    concatenated with the original input (hence ``2 * in_channels``) and emits
    a 2-channel map that is turned into probabilities by a softmax.

    Args:
        in_channels: Number of channels of the input tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.recon = ReconstructiveSubNetwork(
            in_channels=in_channels,
            out_channels=in_channels,
        )
        self.seg = DiscriminativeSubNetwork(
            in_channels=2 * in_channels,
            out_channels=2,
        )

    def load_pack_checkpoint(self, pack_path):
        """Restore both sub-networks from a single packed checkpoint file.

        The file must hold a mapping with the state dict of the reconstructive
        network under ``'model'`` and that of the discriminative network under
        ``'model_seg'``. Tensors are mapped to the CPU while loading.

        Args:
            pack_path: Path of the packed checkpoint.
        """
        # NOTE: torch.load unpickles the file; only open checkpoints you trust.
        checkpoint = torch.load(pack_path, map_location='cpu')
        for submodule, key in ((self.recon, 'model'), (self.seg, 'model_seg')):
            submodule.load_state_dict(checkpoint[key])

    def load_checkpoint(self, recon_path, seg_path):
        """Restore the two sub-networks from two separate state-dict files.

        Args:
            recon_path: Path of the reconstructive network's state dict.
            seg_path: Path of the discriminative network's state dict.
        """
        # NOTE: torch.load unpickles the files; only open checkpoints you trust.
        for submodule, ckpt_path in ((self.recon, recon_path), (self.seg, seg_path)):
            submodule.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    def forward(self, x):
        """Run reconstruction followed by segmentation.

        Args:
            x: Input tensor; it is concatenated with its reconstruction along
                dim 1 before being passed to the discriminative network.

        Returns:
            In training mode the tuple ``(reconstruction, mask_probabilities)``;
            otherwise only ``mask_probabilities`` (softmax over dim 1).
        """
        reconstruction = self.recon(x)
        seg_logits = self.seg(torch.cat((reconstruction, x), dim=1))
        seg_probs = torch.softmax(seg_logits, dim=1)
        return (reconstruction, seg_probs) if self.training else seg_probs

In [3]:
# Build the combined model and restore both sub-networks from one packed checkpoint.
# NOTE(review): the checkpoint path is absolute and user-specific — consider
# moving it next to the other settings in the first cell.
model = DRAEMPack(in_channels=8)
model.load_pack_checkpoint('/home/caoxiatian/DRAEM/checkpoints/gray_ms1k/81000.ckpt')
model = model.to(device)
# Print the architecture before pruning so it can be compared with the
# slimmed structure shown after the speedup step.
print(model)
# Instantiate the pruner on the model with the layer selection from config_list.
pruner = prunerClass(model, config_list)
# Uncomment to show the wrapped model structure: the layers selected by
# config_list have been wrapped in `PrunerModuleWrapper`.
# print(model)
# Compress the model and generate the masks.
_, masks = pruner.compress()
# Uncomment to show the sparsity of each mask:
# for name, mask in masks.items():
#     print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))
# The model has to be unwrapped before speedup if it was wrapped by the pruner.
pruner.unwrap_model()
# Speed up the model using the masks; see the NNI pruning speedup
# documentation for details.
from nni.compression.speedup import ModelSpeedup
# A random tensor of shape input_shape serves as the dummy input.
# NOTE(review): the model is never switched to eval() in this notebook, so it
# is presumably still in training mode here (forward then returns a tuple) —
# confirm that this is intended for the speedup step.
ModelSpeedup(model, torch.rand(*input_shape).to(device), masks).speedup_model()

DRAEMPack(
  (recon): ReconstructiveSubNetwork(
    (encoder): EncoderReconstructive(
      (block1): Sequential(
        (0): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (mp1): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (block2): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256,



[2023-11-02 19:28:23] [32mStart to speedup the model...[0m
[2023-11-02 19:28:23] [32mResolve the mask conflict before mask propagate...[0m
[2023-11-02 19:28:24] [32mdim0 sparsity: 0.500000[0m
[2023-11-02 19:28:24] [32mdim1 sparsity: 0.000000[0m
0 Filter
[2023-11-02 19:28:24] [32mdim0 sparsity: 0.500000[0m
[2023-11-02 19:28:24] [32mdim1 sparsity: 0.000000[0m
[2023-11-02 19:28:24] [32mInfer module masks...[0m
[2023-11-02 19:28:24] [32mPropagate original variables[0m
[2023-11-02 19:28:24] [32mPropagate variables for placeholder: x, output mask:  0.0000 [0m
[2023-11-02 19:28:24] [32mPropagate variables for call_module: recon_encoder_block1_0, weight:  0.5000 bias:  0.5000 , output mask:  0.0000 [0m
[2023-11-02 19:28:24] [32mPropagate variables for call_module: recon_encoder_block1_1, , output mask:  0.0000 [0m
[2023-11-02 19:28:24] [32mPropagate variables for call_module: recon_encoder_block1_2, , output mask:  0.0000 [0m
[2023-11-02 19:28:24] [32mPropagate variabl

DRAEMPack(
  (recon): ReconstructiveSubNetwork(
    (encoder): EncoderReconstructive(
      (block1): Sequential(
        (0): Conv2d(8, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (mp1): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (block2): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1

In [7]:
import psutil
import os

# Report the memory footprint of this kernel process and of the whole machine.
bytes_per_gib = 1024 ** 3

# Resident set size of the current process, in GiB.
process_rss = psutil.Process(os.getpid()).memory_info().rss
print('当前进程的内存使用：%.4f GB' % (process_rss / bytes_per_gib))

# System-wide figures: total memory in GiB, percentage currently in use,
# and the CPU count reported by psutil.
info = psutil.virtual_memory()
print('电脑总内存：%.4f GB' % (info.total / bytes_per_gib))
print('当前使用的总内存占比：', info.percent)
print('cpu个数：', psutil.cpu_count())

当前进程的内存使用：6.8981 GB
电脑总内存：125.3765 GB
当前使用的总内存占比： 29.0
cpu个数： 64


In [5]:
# Save the whole module object (not just its state_dict) to save_path, after
# moving it to the CPU. The speedup step changed the layer shapes (see the
# halved channel counts in the output above), so a state_dict alone would
# presumably no longer fit a freshly constructed DRAEMPack.
torch.save(model.cpu(), save_path)

In [6]:
# To reload the saved pruned model in a later session, uncomment:
# model = torch.load(save_path, map_location='cpu')