In [1]:
!pip install numpy==1.21.0

Collecting numpy==1.21.0
  Downloading numpy-1.21.0.zip (10.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.3/10.3 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: numpy
  Building wheel for numpy (pyproject.toml) ... [?25l[?25hdone
  Created wheel for numpy: filename=numpy-1.21.0-cp310-cp310-linux_x86_64.whl size=15985728 sha256=ccb65d6aeb0df15f47b4717c9312a661e14a325dc8555428a351555a9ea23b93
  Stored in directory: /root/.cache/pip/wheels/05/61/d1/ccc2cd557b39e127ad98a392d9558f3c5dda28764b7f54b2f5
Successfully built numpy
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver doe

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Pruning -> Fine-tune

1. Sort the weight of Batchnorm1d.
2. Remove the channel of Conv1d if the weight of Batchnorm1d before this channel is smaller than threshold.
3. Fine-tine model to recover accuracy.


## Load model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from speech_command_dataset import SpeechCommandDataset
from torchvision.transforms import Compose
import torchvision.models as models
import model

# from apex import amp

import os,time
import numpy as np
import matplotlib.pyplot as plt

In [None]:
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
BATCH_SIZE = 4
training_params = {"batch_size": BATCH_SIZE,
                       "shuffle": True,
                       "drop_last": False,
                       "num_workers": 1}

testing_params = {"batch_size": BATCH_SIZE,
                       "shuffle": False,
                       "drop_last": False,
                       "num_workers": 1}

train_set = SpeechCommandDataset()
train_loader = DataLoader(train_set, **training_params)

test_set = SpeechCommandDataset(is_training=False)
test_loader = DataLoader(test_set, **testing_params)

In [None]:
# load model
net = model.SincNet().cuda()
model_path = '/content/drive/MyDrive/DL_lab4/Checkpoint/SincNet_best.pth.tar'
# you could load the model after pruning and fine-tune to prune again
# model_path = './Checkpoint/SincNet_finetune.pth.tar'

if os.path.isfile(model_path):
    print("=> loading checkpoint '{}'".format(model_path))
    checkpoint = torch.load(model_path)
#     print(checkpoint)
    net.load_state_dict(checkpoint['state_dict'])
else:
    print("=> no checkpoint found at '{}'".format(model_path))

=> loading checkpoint './Checkpoint/SincNet_best.pth.tar'


In [None]:
print('Before pruning')

loss_func = nn.CrossEntropyLoss()

total_val_loss = 0
correct = 0
total = 0
batch_num = 0

net.eval()

for audios, labels in test_loader:
    audios = audios.cuda()
    labels = labels.cuda()

    outputs = net(audios)
    loss = loss_func(outputs, labels)
    total_val_loss += loss.item()
    batch_num += 1
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()


val_loss = total_val_loss / batch_num
val_acc = 100.0 * float(correct) / float(total)


print('Validation loss: %.4f' % val_loss,'Validation accuracy: %.2f' % val_acc)

Before pruning
Validation loss: 0.1473 Validation accuracy: 95.95


## Start Pruning

In [None]:
# you can choose your pruning rate
pruning_rate = 0.1

In [None]:
# extract the weight of BatchNorm1d layer and sort them

total_par = sum([param.nelement() for param in net.parameters()])
print("Number of parameter: %.2fM" % (total_par/1e6))

total = 0

for m in net.modules():
    if isinstance(m, nn.BatchNorm1d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in net.modules():
    if isinstance(m, nn.BatchNorm1d):
        size = m.weight.data.shape[0]
#         print(size)
        bn[index:(index+size)] = m.weight.data.abs().clone()

        index += size

y, i = torch.sort(bn)
thre_index = int(total * pruning_rate)
thre = y[thre_index]


Number of parameter: 0.27M


In [None]:
# record the renaming weight
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(net.modules()):
    if isinstance(m, nn.BatchNorm1d):
        weight_copy = m.weight.data.abs().clone().cpu()

        # if weight larger than threshold
        mask = weight_copy.gt(thre).float().cuda()

        # pruning number
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)

        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))

    elif isinstance(m, nn.AvgPool1d):
        cfg.append('P')


# print('cfg',cfg)
pruned_ratio = pruned/total

print('Pre-processing Successful!')

layer index: 4 	 total channel: 40 	 remaining channel: 37
layer index: 11 	 total channel: 256 	 remaining channel: 243
layer index: 17 	 total channel: 256 	 remaining channel: 208
layer index: 23 	 total channel: 256 	 remaining channel: 239
layer index: 29 	 total channel: 256 	 remaining channel: 214
layer index: 35 	 total channel: 160 	 remaining channel: 160
Pre-processing Successful!


In [None]:
new_module = model.SincNet(cfg).cuda()

In [None]:
old_modules = list(net.modules())
new_modules = list(new_module.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(1)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm1d):


        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))

        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv1d):
        if isinstance(old_modules[layer_id-4], nn.BatchNorm1d) or isinstance(old_modules[layer_id-3], nn.BatchNorm1d) or isinstance(old_modules[layer_id+2], nn.BatchNorm1d):
            # This convers the convolutions in the residual block.
            conv_count += 1

            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))


            if conv_count % 2 != 0:
                w1 = m0.weight.data[idx0.tolist(), :, :].clone()
                print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx0.size))
            else:
                w1 = m0.weight.data[:, idx0.tolist(), :].clone()
                w1 = w1[idx1.tolist(), :, :].clone()
                print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))


            m1.weight.data = w1.clone()

            continue

        # We need to consider the case where there are downsampling convolutions.
        # For these convolutions, we just copy the weights.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

# print(cfg)
torch.save({'cfg': cfg, 'state_dict': new_module.state_dict()}, os.path.join('./Checkpoint', 'SincNet_prune.pth.tar'))

In shape: 37, Out shape 37.
In shape: 37, Out shape 243.
In shape: 243, Out shape 243.
In shape: 243, Out shape 208.
In shape: 208, Out shape 208.
In shape: 208, Out shape 239.
In shape: 239, Out shape 239.
In shape: 239, Out shape 214.
In shape: 214, Out shape 214.
In shape: 214, Out shape 160.


## Fine-tune

In [None]:
checkpoint = torch.load('./Checkpoint/SincNet_prune.pth.tar')
net = model.SincNet(cfg=checkpoint['cfg'])
net.load_state_dict(checkpoint['state_dict'])

net.cuda()
print(net)

SincNet(
  (sincconv): _Layer(
    (conv0): SincConv1d()
    (logabs): LogAbs()
    (bn): BatchNorm1d(37, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
  )
  (features): ModuleList(
    (0): _Layer(
      (conv0): Conv1d(37, 37, kernel_size=(25,), stride=(2,), groups=37)
      (conv1): Conv1d(37, 243, kernel_size=(1,), stride=(1,))
      (relu): ReLU()
      (bn): BatchNorm1d(243, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
    )
    (1): _Layer(
      (conv0): Conv1d(243, 243, kernel_size=(9,), stride=(1,), groups=243)
      (conv1): Conv1d(243, 208, kernel_size=(1,), stride=(1,))
      (relu): ReLU()
      (bn): BatchNorm1d(208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
    )
    (2): _Layer(
      (conv0): Conv1d(208, 208,

In [None]:
print('Before fine-tune')

total_val_loss = 0
correct = 0
total = 0
batch_num = 0

net.eval()

for audios, labels in test_loader:
    audios = audios.cuda()
    labels = labels.cuda()

    outputs = net(audios)
    loss = loss_func(outputs, labels)
    total_val_loss += loss.item()
    batch_num += 1
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()


val_loss = total_val_loss / batch_num
val_acc = 100.0 * float(correct) / float(total)

print('Validation loss: %.4f' % val_loss,'Validation accuracy: %.2f' % val_acc)

Before fine-tune
Validation loss: 17.7283 Validation accuracy: 21.50


In [None]:
EPOCH = 10
LR = 1e-3
Weight_decay = 1e-9

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=LR, weight_decay=Weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,  patience=15, verbose=True, eps=1e-09)

In [None]:
print('Begin fine-tune...')
best_accuracy = 0

for epoch in range(EPOCH):
    net.train()
    start_time = time.time()
    total_train_loss = 0
    correct = 0
    total = 0
    batch_num = 0

    for step, (audios, labels) in enumerate(train_loader):
        audios = audios.cuda()
        labels = labels.cuda()
        outputs = net(audios)

        loss = loss_func(outputs, labels)

        total_train_loss += loss.item()
        batch_num += 1
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    end_time = time.time()
    train_time = end_time - start_time
    train_loss = total_train_loss / batch_num
    train_acc = 100.0 * float(correct) / float(total)



    print(time.strftime("%d %b %Y %H:%M:%S", time.localtime()))
    print('Epoch: %3d' % epoch, '|train loss: %.4f' % train_loss, '|train accuracy: %.2f' % train_acc,
          '|train time: %.2f' % train_time)

    scheduler.step(train_loss)

torch.save({'cfg': None, 'state_dict': net.state_dict()}, os.path.join('./Checkpoint', 'SincNet_finetune.pth.tar'))
print('Saving..')


Begin fine-tune...
22 Aug 2023 09:03:08
Epoch:   0 |train loss: 0.4616 |train accuracy: 86.07 |train time: 54.33
22 Aug 2023 09:04:02
Epoch:   1 |train loss: 0.3654 |train accuracy: 88.82 |train time: 53.94
22 Aug 2023 09:04:57
Epoch:   2 |train loss: 0.3346 |train accuracy: 89.87 |train time: 54.12
22 Aug 2023 09:05:51
Epoch:   3 |train loss: 0.2982 |train accuracy: 90.93 |train time: 54.47
22 Aug 2023 09:06:45
Epoch:   4 |train loss: 0.2848 |train accuracy: 91.36 |train time: 54.21
22 Aug 2023 09:07:39
Epoch:   5 |train loss: 0.2652 |train accuracy: 91.97 |train time: 54.27
22 Aug 2023 09:08:34
Epoch:   6 |train loss: 0.2423 |train accuracy: 92.51 |train time: 54.34
22 Aug 2023 09:09:29
Epoch:   7 |train loss: 0.2268 |train accuracy: 93.12 |train time: 54.75
22 Aug 2023 09:10:23
Epoch:   8 |train loss: 0.2204 |train accuracy: 93.38 |train time: 54.54
22 Aug 2023 09:11:18
Epoch:   9 |train loss: 0.2125 |train accuracy: 93.68 |train time: 54.65
Saving..


In [None]:
total_par = sum([param.nelement() for param in net.parameters()])
print("Number of parameter: %.2fM" % (total_par/1e6))

Number of parameter: 0.21M


In [None]:
print('After fine-tune')

total_val_loss = 0
correct = 0
total = 0
batch_num = 0

net.eval()

for audios, labels in test_loader:
    audios = audios.cuda()
    labels = labels.cuda()

    outputs = net(audios)
    loss = loss_func(outputs, labels)
    total_val_loss += loss.item()
    batch_num += 1
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()


val_loss = total_val_loss / batch_num
val_acc = 100.0 * float(correct) / float(total)

print('Validation loss: %.4f' % val_loss,'Validation accuracy: %.2f' % val_acc)

After fine-tune
Validation loss: 0.2430 Validation accuracy: 93.62
