In [113]:
%matplotlib inline
import sys
# NOTE(review): bare expression in the middle of a cell -- its value is
# neither stored nor displayed, so this line has no visible effect.
sys.path
import os
# Expose only CUDA device 0 to this process. The variable is set here,
# before `import torch` further down in this cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from tqdm import tqdm_notebook
import tqdm
import json
import shutil
import xlrd
import numpy as np
import pandas as pd

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from torchvision import datasets, transforms
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
from torch.autograd import Variable

from ECG_Dataset_class import ECG_Dataset
from ECG_Model import updateBN, ECG_Net,fine_tune

In [115]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ECG_Net().to(device)

# Read the experiment configuration. The context manager closes the file
# handle deterministically instead of leaving it open until garbage collection.
with open('config.json', 'r') as config_file:
    params = json.load(config_file)


def _make_split(split_file, shuffle):
    """Build the dataset and its DataLoader for one split listing file.

    Args:
        split_file: file name appended verbatim to
            params["multilabel_save_folder"] (plain string concatenation,
            so the folder value must already end with a path separator).
        shuffle: passed to DataLoader; True only for the training split.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = ECG_Dataset(txt=params["multilabel_save_folder"] + split_file,
                          transform=transforms.ToTensor())
    loader = DataLoader(dataset=dataset, batch_size=params["batch_size"],
                        shuffle=shuffle, num_workers=16, pin_memory=True)
    return dataset, loader


train_data, train_loader = _make_split('train.txt', shuffle=True)
val_data, val_loader = _make_split('validation.txt', shuffle=False)
test_data, test_loader = _make_split('test.txt', shuffle=False)

In [116]:
# Class folders under params["ecg_root_path"], ordered by their numeric prefix:
# a name whose second character is '-' uses its first character as the number,
# any other name uses its first two characters.
class_folders = sorted(
    os.listdir(params["ecg_root_path"]),
    key=lambda x: int(x[0]) if x[1] == '-' else int(x[:2]))

# tmp maps the 0-based position of each class folder (in the sorted order
# above) to the number of directory entries inside that folder.
tmp = {i: len(os.listdir(os.path.join(params["ecg_root_path"], j)))
       for i, j in enumerate(class_folders)}

# A plain dict is all that is needed here. The previous Counter(tmp) wrapper
# relied on a name that is never imported in this notebook (NameError on a
# fresh kernel) and added nothing beyond the dict itself.
counter = dict(tmp)

# Fail with a readable message instead of a bare ZeroDivisionError below.
empty_classes = [class_id for class_id, num_ecg in counter.items() if num_ecg == 0]
if empty_classes:
    raise ValueError('class folders without any files: {}'.format(empty_classes))

# Weight of a class = (size of the largest class) / (size of this class),
# so the largest class gets weight 1.0 and smaller classes get more.
max_val = float(max(counter.values()))
class_weight_tmp = {class_id : max_val/num_ecg for class_id, num_ecg in counter.items()}
c_weight = list(class_weight_tmp.values())

# The weights must be converted to a tensor before being handed to the loss.
class_weight = torch.FloatTensor(c_weight).to(device)
Loss_fn = nn.BCELoss(weight = class_weight).to(device)

In [None]:
# Restore the weights stored under the 'state_dict' key of the checkpoint.
# map_location=device remaps the stored tensors onto the device selected
# earlier, so a checkpoint written from a GPU run can also be loaded on a
# CPU-only machine (the notebook explicitly supports the CPU fallback).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_CKPT = torch.load('./ECG/Pruned_model/ECG_sparse_model.pth.tar', map_location=device)
model.load_state_dict(model_CKPT['state_dict'])

In [None]:
# Collect every BatchNorm1d layer once, paired with its position in
# model.modules() (the position is only used for the progress print below).
bn_layers = [(pos, layer) for pos, layer in enumerate(model.modules())
             if isinstance(layer, nn.BatchNorm1d)]

# total is the overall number of BN scale factors (gamma) in the model.
total = sum(layer.weight.data.shape[0] for pos, layer in bn_layers)

# Flatten |gamma| of all BN layers into a single vector.
bn = torch.zeros(total)
offset = 0
for pos, layer in bn_layers:
    width = layer.weight.data.shape[0]
    bn[offset:offset + width] = layer.weight.data.abs().clone()
    offset += width

# Ascending sort; only the sorted values are needed.
y = torch.sort(bn)[0]

# The pruning threshold is the |gamma| found at the 50% position of the
# sorted vector.
thre_index = int(total * 0.5)
thre = y[thre_index]

pruned = 0
cfg = []        # number of surviving channels per BN layer
cfg_mask = []   # 0/1 keep-mask per BN layer
for k, layer in bn_layers:
    # 1.0 where |gamma| is strictly greater than the threshold, 0.0 elsewhere.
    mask = layer.weight.data.abs().gt(thre).float().to(device)
    kept = int(torch.sum(mask))
    pruned = pruned + mask.shape[0] - torch.sum(mask)
    # Zero the scale and shift of the channels that are going to be removed.
    layer.weight.data.mul_(mask)
    layer.bias.data.mul_(mask)
    cfg.append(kept)
    cfg_mask.append(mask.clone())
    print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
        format(k, mask.shape[0], kept))

pruned_ratio = pruned/total

In [123]:
# Build the slimmed network; cfg carries the surviving channel count of every
# BatchNorm1d layer of the original model.
newmodel = ECG_Net(init_weights= False, num_classes=26, cfg = cfg).to(device)
layer_id_in_cfg = 0
end_mask = cfg_mask[layer_id_in_cfg]
# Walk both networks in lockstep; this relies on model.modules() and
# newmodel.modules() yielding corresponding layers in the same order.
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm1d):
        # Indices of the channels kept by this layer's mask. flatnonzero always
        # returns a 1-D array; the previous squeeze(argwhere(...)) collapsed to
        # a 0-d array when exactly one channel survived, which turned the
        # copied BN parameters into scalars.
        idx1 = np.flatnonzero(end_mask.cpu().numpy())
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        # Advance to the mask of the next BN layer, if there is one.
        layer_id_in_cfg = layer_id_in_cfg + 1
        if layer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[layer_id_in_cfg]

In [None]:
convlayer_id_in_cfg = 0
# The first convolution keeps all 12 of its input channels
# (presumably the 12 ECG leads -- TODO confirm against ECG_Net).
start_mask = torch.ones(12)
end_mask = cfg_mask[convlayer_id_in_cfg]
# start_mask selects the input channels of the current conv, end_mask its
# output channels; after each conv the output mask becomes the next input mask.
# NOTE(review): this pairs the n-th Conv1d with the n-th BN mask -- confirm
# that ECG_Net has exactly one BatchNorm1d per Conv1d, in the same order.
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.Conv1d):
        # flatnonzero always yields 1-D index arrays; the previous
        # squeeze(argwhere(...)) produced a 0-d array for a single surviving
        # channel, on which `.shape[0]` raises IndexError.
        idx0 = np.flatnonzero(start_mask.cpu().numpy())
        idx1 = np.flatnonzero(end_mask.cpu().numpy())
        print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        # Keep the surviving input channels first, then the surviving outputs.
        w = m0.weight.data[:, idx0, :].clone()
        w = w[idx1, :, :].clone()
        m1.weight.data = w.clone()
        # Carry over the bias of the surviving output channels as well;
        # without this a conv that has a bias keeps the new model's initial
        # values instead of the trained ones.
        if m0.bias is not None and m1.bias is not None:
            m1.bias.data = m0.bias.data[idx1].clone()
        convlayer_id_in_cfg = convlayer_id_in_cfg + 1
        start_mask = end_mask.clone()
        if convlayer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[convlayer_id_in_cfg]

In [125]:
# Channels kept by the last BN layer decide which input columns of the first
# Linear layer survive.
end_mask = cfg_mask[-1]
# flatnonzero always yields a 1-D array; the previous squeeze(argwhere(...))
# produced a 0-d array for a single surviving channel, on which len() fails.
idx0 = np.flatnonzero(end_mask.cpu().numpy())
# Every kept channel i contributes two input columns of the first Linear
# layer: i and i + 512.
# NOTE(review): the columns are emitted interleaved (i0, i0+512, i1, i1+512, ...),
# exactly as before. Confirm against ECG_Net.forward that the slimmed model
# lays its features out in this order rather than as two contiguous halves.
idx1 = idx0 + 512
idx = np.stack([idx0, idx1], axis=1).reshape(-1).tolist()

first_linear = True
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.Linear):
        if first_linear:
            # Only the first Linear layer sees pruned features, so only its
            # input columns are sub-selected.
            m1.weight.data = m0.weight.data[:, idx].clone()
            first_linear = False
        else:
            m1.weight.data = m0.weight.data.clone()
        # Copy the bias for every Linear layer; previously only the first one
        # had its bias transferred and later layers kept their initial bias.
        if m0.bias is not None and m1.bias is not None:
            m1.bias.data = m0.bias.data.clone()


In [None]:
# Fine-tuning hyper-parameters.
num_epochs = 20
lr = 1e-5

# Adam over every parameter of the slimmed network.
optimizer_finetune = optim.Adam(
    newmodel.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

# fine_tune hands back the model together with two loss histories.
finetune_result = fine_tune(
    newmodel,
    train_loader,
    val_loader,
    params["batch_size"],
    num_epochs,
    Loss_fn,
    optimizer_finetune,
)
newmodel, avg_train_losses, avg_valid_losses = finetune_result

# Store the channel configuration next to the weights; cfg is what ECG_Net
# was given above to build the slimmed architecture.
state = {
    'cfg': cfg,
    'state_dict': newmodel.state_dict(),
}
torch.save(state, './ECG/finetune_slimming_model.pth.tar')