In [1]:
%matplotlib inline
import sys
sys.path
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from tqdm import tqdm_notebook
import json
import shutil
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import xlrd
import numpy as np
import pandas as pd

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from torchvision import datasets, transforms
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
from torch.autograd import Variable

from ECG_Dataset_class import ECG_Dataset
from util_wqaq_1D import *
from ECG_Quant_Net import ECG_Quant_Net, ECG_Net, quantization_train

In [2]:
from collections import Counter  # explicit; previously only reachable via `from util_wqaq_1D import *`

# Run configuration. The `with` block guarantees the file handle is closed.
with open('config.json', 'r') as config_file:
    params = json.load(config_file)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader worker processes; may be overridden in config.json, defaults to 16 as before.
num_workers = params.get("num_workers", 16)


def _make_split(split_file, shuffle):
    """Build the ECG_Dataset and DataLoader for one split list file.

    Args:
        split_file: name of the split list (e.g. 'train.txt') inside
            params["multilabel_save_folder"].
        shuffle: whether the DataLoader reshuffles the samples every epoch.

    Returns:
        (dataset, loader) tuple.
    """
    # os.path.join works whether or not the configured folder ends with a separator.
    dataset = ECG_Dataset(txt=os.path.join(params["multilabel_save_folder"], split_file),
                          transform=transforms.ToTensor())
    loader = DataLoader(dataset=dataset,
                        batch_size=params["batch_size"],
                        shuffle=shuffle,
                        num_workers=num_workers,
                        # Pinned host memory only helps host->GPU copies; skip it on CPU-only hosts.
                        pin_memory=(device.type == "cuda"))
    return dataset, loader


train_data, train_loader = _make_split('train.txt', shuffle=True)
val_data, val_loader = _make_split('validation.txt', shuffle=False)
test_data, test_loader = _make_split('test.txt', shuffle=False)

# --- Class weights: one sub-directory per class under ecg_root_path --------------------------
ecg_root = params["ecg_root_path"]
# Only directories are class folders (stray files would make the inner os.listdir fail).
# The sort key reads a one-digit numeric prefix when the second character is '-', otherwise a
# two-digit prefix, so enumerate() below assigns class ids in numeric folder order.
class_dirs = sorted(
    (name for name in os.listdir(ecg_root) if os.path.isdir(os.path.join(ecg_root, name))),
    key=lambda x: int(x[0]) if x[1] == '-' else int(x[:2]))
tmp = {i: len(os.listdir(os.path.join(ecg_root, j))) for i, j in enumerate(class_dirs)}
counter = Counter(tmp)
if not counter:
    raise ValueError(f"no class directories found under {ecg_root!r}")
empty_classes = [class_id for class_id, num_ecg in counter.items() if num_ecg == 0]
if empty_classes:
    raise ValueError(f"class directories with no files (class ids {empty_classes}) "
                     "would get an infinite weight")

# Inverse-frequency weights scaled so the most populated class gets weight 1.0.
max_val = float(max(counter.values()))
class_weight_tmp = {class_id: max_val / num_ecg for class_id, num_ecg in counter.items()}
# Position i of the weight vector must correspond to class id i.
c_weight = [class_weight_tmp[class_id] for class_id in range(len(class_weight_tmp))]
class_weight = torch.FloatTensor(c_weight).to(device)
# NOTE(review): BCELoss applies no sigmoid, so this assumes the model outputs probabilities, and that
# targets are (batch, num_classes) so the per-class weight broadcasts -- confirm against ECG_Net.
Loss_fn = nn.BCELoss(weight=class_weight).to(device)

In [4]:
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location places the tensors on `device`, so a checkpoint saved on GPU also loads on a
# CPU-only machine.
model_CKPT = torch.load('./ECG/finetune_slimming_model.pth.tar', map_location=device)

# Output size of both networks; must match the classifier stored in the checkpoint.
num_classes = 26

# Full-precision network rebuilt from the layer config ('cfg') saved alongside the weights
# (presumably the channel-slimmed config, judging by the file name -- confirm).
model = ECG_Net(init_weights=False, num_classes=num_classes, cfg=model_CKPT['cfg']).to(device)
model.load_state_dict(model_CKPT['state_dict'])

# Quantization-aware network with the same config: 8-bit activations and 8-bit weights.
model_q = ECG_Quant_Net(init_weights=False, num_classes=num_classes, cfg=model_CKPT['cfg'],
                        act_bits=8, weight_bits=8, q_type=1).to(device)

In [5]:
def _pair_check(label, source_modules, target_modules):
    """Raise ValueError unless both module lists are non-empty and equally long.

    Layers are paired purely by position with zip(), which silently stops at the shorter list;
    a count mismatch would leave quantized layers randomly initialised or pair the wrong layers.
    """
    if not target_modules or len(source_modules) != len(target_modules):
        raise ValueError(
            f"{label}: found {len(source_modules)} source and {len(target_modules)} target layers; "
            "both networks must contain the same, non-zero number of them to be paired by order")


def _copy_into(dst, src, label):
    """Overwrite tensor `dst` in place with the values of `src`, requiring identical shapes.

    Unlike `dst.data = src.clone()`, this cannot silently change the parameter's shape.
    """
    if dst.shape != src.shape:
        raise ValueError(f"{label}: cannot copy shape {tuple(src.shape)} into {tuple(dst.shape)}")
    with torch.no_grad():
        dst.copy_(src)


# Layers of the full-precision model, in module-registration order.
model_conv_module = [m for m in model.modules() if isinstance(m, nn.Conv1d)]
model_batchnorm_module = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)]
model_liner_module = [m for m in model.modules() if isinstance(m, nn.Linear)]

# Corresponding layers of the quantization-aware model, in the same order.
model_q_conv_module = [m for m in model_q.modules() if isinstance(m, BNFold_Conv1d_Q)]
model_q_liner_module = [m for m in model_q.modules() if isinstance(m, Linear_Q)]

_pair_check("Conv1d -> BNFold_Conv1d_Q", model_conv_module, model_q_conv_module)
_pair_check("BatchNorm1d -> BNFold_Conv1d_Q", model_batchnorm_module, model_q_conv_module)
_pair_check("Linear -> Linear_Q", model_liner_module, model_q_liner_module)

# Each BNFold_Conv1d_Q receives the weight/bias of the Conv1d at the same position, plus the
# affine parameters of the BatchNorm1d at that position (BN weight -> gamma, BN bias -> beta).
# NOTE(review): the BatchNorm1d running_mean / running_var are NOT transferred. If BNFold_Conv1d_Q
# keeps its own running statistics, they start from their initial values -- confirm this is intended.
for i, (conv, bn, q_conv) in enumerate(zip(model_conv_module, model_batchnorm_module,
                                           model_q_conv_module)):
    _copy_into(q_conv.weight, conv.weight, f"conv {i} weight")
    _copy_into(q_conv.bias, conv.bias, f"conv {i} bias")
    _copy_into(q_conv.gamma, bn.weight, f"batchnorm {i} -> gamma")
    _copy_into(q_conv.beta, bn.bias, f"batchnorm {i} -> beta")

# Fully connected layer(s): every Linear / Linear_Q pair is copied, in order.
for i, (fc, q_fc) in enumerate(zip(model_liner_module, model_q_liner_module)):
    _copy_into(q_fc.weight, fc.weight, f"linear {i} weight")
    _copy_into(q_fc.bias, fc.bias, f"linear {i} bias")

In [None]:
# Quantization-aware fine-tuning of model_q.
num_epochs, lr = 15, 1e-5

# Adam with its standard beta/eps coefficients, a small learning rate and no weight decay.
optimizer = torch.optim.Adam(
    model_q.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

# quantization_train returns the trained model together with average train / validation loss
# histories (assumed to be one entry per epoch -- confirm in ECG_Quant_Net).
model_q, avg_train_losses, avg_valid_losses = quantization_train(
    model_q,
    train_loader,
    val_loader,
    params["batch_size"],
    num_epochs,
    Loss_fn,
    optimizer,
)