In [1]:
import sys
import argparse
from os.path import join
import os

import pickle
from sklearn.model_selection import train_test_split

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.utils import class_weight
import warnings
import struct
import time
from sklearn import preprocessing

prepare data for simulation

In [2]:
def Quant(Vx, Q, RQM):
    """Quantize a single value.

    The value ``Vx`` is multiplied by the scale ``Q``, rounded with the
    built-in ``round`` and shifted by subtracting the offset ``RQM``.
    """
    scaled_value = Q * Vx
    return round(scaled_value) - RQM


def ListQuant(data_list, quant_bits):
    """Affine-quantize a sequence of numbers to ``quant_bits``-bit integer codes.

    The scale ``Q`` stretches the observed value range onto
    ``2**quant_bits - 1`` steps and the offset ``RQM`` is the rounded, scaled
    minimum, so the smallest input always maps to code 0.

    Args:
        data_list: non-empty sequence of numbers.
        quant_bits: number of bits of the target code.

    Returns:
        Tuple ``(Q, RQM, quant_data_list)`` where ``quant_data_list`` is a
        NumPy array holding one code per input value.

    Raises:
        ValueError: if ``data_list`` is empty.

    Note:
        Each value and the offset are rounded separately, so in rare cases
        the largest code can come out as ``2**quant_bits`` rather than
        ``2**quant_bits - 1`` (e.g. min=1, max=511, 8 bits gives 256).
    """
    if len(data_list) == 0:
        raise ValueError("ListQuant() requires a non-empty data_list")

    data_min = min(data_list)
    data_max = max(data_list)
    data_span = data_max - data_min

    if data_span == 0:
        # Constant input: there is no range to stretch, and dividing by the
        # zero span would raise.  A unit scale maps every sample to code 0.
        Q = 1.0
    else:
        Q = ((1 << quant_bits) - 1) * 1.0 / data_span
    RQM = int(round(Q * data_min))

    # Same per-value formula as Quant(): round(Q * x) - RQM.
    quant_data_list = np.array([round(Q * x) - RQM for x in data_list])
    return (Q, RQM, quant_data_list)

In [3]:
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader


class ReadDataset(Dataset):
    """Dataset of binary ECG recordings listed in a header-less CSV index.

    Column 0 of the CSV holds the recording file name and column 1 its class
    label.  Every item is the recording quantized with ``ListQuant`` to 8 bits
    and laid out as a single-channel ``(1, 73, 73)`` float32 array, paired
    with its integer-encoded label.
    """

    # Side length of the square array each recording is laid out into.
    IMAGE_SIDE = 73

    def __init__(self, csv_path, root_dir='./simplified_atrial_fibrillation/'):
        """Read the CSV index and encode the labels.

        Args:
            csv_path: path of the header-less CSV index file.
            root_dir: directory that the file names in column 0 are relative
                to.  Defaults to the location the notebook has always used.
        """
        self.root_dir = root_dir
        self.data_info = pd.read_csv(csv_path, header=None)
        # Column 0: recording file names.
        self.ecg_arr = np.asarray(self.data_info.iloc[:, 0])
        # Column 1: class labels, mapped to integers.  The encoder is fitted
        # on this CSV alone, so the mapping depends on the labels it contains.
        raw_labels = np.asarray(self.data_info.iloc[:, 1])
        self.label_arr = preprocessing.LabelEncoder().fit_transform(raw_labels)
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        """Return ``(data, label)`` for the recording at position ``index``."""
        single_ecg_name = self.ecg_arr[index]

        # Context manager so the file handle is closed after every read
        # instead of being left to the garbage collector.
        with open(os.path.join(self.root_dir, single_ecg_name), 'rb') as ecg_file:
            ecg_raw_data = ecg_file.read()

        # One native 'h' (short) value is decoded from the first two bytes of
        # every 4-byte group; the remaining two bytes of each group are
        # skipped.  NOTE(review): presumably a second channel or padding -
        # confirm against the file format.
        ecg_data = [
            struct.unpack('h', ecg_raw_data[t:t + 2])[0]
            for t in range(0, len(ecg_raw_data), 4)
        ]

        # Quantize to 8 bits; only the codes are used here.
        ecg_Q, ecg_RQM, ecg_Quantdata = ListQuant(ecg_data, quant_bits=8)

        # The codes are kept unscaled here; this class does not divide by 255.
        data_resize = np.float32(np.trunc(ecg_Quantdata))

        # 1D -> 2D: in-place resize to 73 * 73 = 5329 cells.  A shorter
        # recording is zero-padded at the end, a longer one is truncated.
        # (The original note quotes 5250 samples per recording - TODO confirm.)
        data_resize.resize((self.IMAGE_SIDE, self.IMAGE_SIDE), refcheck=False)
        # Prepend the channel axis: (73, 73) -> (1, 73, 73).
        data = np.expand_dims(data_resize, axis=0)

        label = self.label_arr[index]
        return (data, label)

    def __len__(self):
        """Number of recordings listed in the CSV index."""
        return self.data_len
    

In [4]:
# Build the dataset from the test.csv index and wrap it in a shuffling
# loader that yields batches of up to 512 recordings.
csv_path = "./simplified_atrial_fibrillation/test.csv"
read_dataset = ReadDataset(csv_path)
test_loader = DataLoader(
    dataset=read_dataset,
    batch_size=512,
    shuffle=True,
)

In [85]:
import random

# Only the first (shuffled) batch is needed to pick one sample.
X_test, y_test = next(iter(test_loader))

# randrange() excludes its upper bound.  randint(0, len(X_test)) includes it
# and could therefore index one past the end of the batch.
i = random.randrange(len(X_test))
output_txt = X_test[i]
output_label = y_test[i]

print(output_txt)
print(np.median(X_test))
print(X_test.shape)
print(output_label)
print(output_txt.shape)

tensor([[[153., 152., 154.,  ..., 166., 163., 162.],
         [161., 160., 162.,  ..., 128., 149., 156.],
         [168., 172., 179.,  ..., 184., 181., 182.],
         ...,
         [190., 187., 191.,  ..., 222., 248., 205.],
         [100.,  27.,  68.,  ...,   0.,   0.,   0.],
         [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]])
132.0
torch.Size([512, 1, 73, 73])
tensor(0)
torch.Size([1, 73, 73])


In [86]:
# Display the sample selected in the previous cell (still un-normalized 8-bit codes).
output_txt

tensor([[[153., 152., 154.,  ..., 166., 163., 162.],
         [161., 160., 162.,  ..., 128., 149., 156.],
         [168., 172., 179.,  ..., 184., 181., 182.],
         ...,
         [190., 187., 191.,  ..., 222., 248., 205.],
         [100.,  27.,  68.,  ...,   0.,   0.,   0.],
         [  0.,   0.,   0.,  ...,   0.,   0.,   0.]]])

In [7]:
import onnx
import torch
from brevitas.nn import QuantConv2d, QuantIdentity, QuantMaxPool2d, QuantLinear,QuantDropout,QuantReLU,QuantHardTanh
from brevitas.core.quant import QuantType, BinaryQuant, TernaryQuant
from brevitas.core.scaling import ConstScaling
from brevitas.quant import SignedTernaryActPerTensorConst,SignedBinaryActPerTensorConst,Int8ActPerTensorFloatMinMaxInit
import torch.nn as nn
from torch import optim

In [8]:
# according to 1D

class ECG_AF_2D(nn.Module):
    """Quantized 2D CNN classifier built from Brevitas layers.

    The network is one strided convolution followed by three
    depthwise (``groups == channels``) + pointwise (1x1) convolution pairs,
    each convolution followed by BatchNorm and a quantized ReLU, and a final
    quantized linear layer.

    The linear layer expects ``2 * 2 * 3 = 12`` features, which is what a
    73x73 input produces:
    73 -> conv(k3, s2) 36 -> pool(2) 18 -> conv(k3) 16 -> conv(k3, p1) 16
    -> pool(2) 8 -> conv(k3) 6 -> pool(3) 2, with 3 channels.

    Args:
        input_size: number of input channels.
        num_classes: number of outputs of the final linear layer.
        weight_bit_width: bit width passed to every quantized conv/linear layer.
        act_bit_width: bit width passed to every quantized ReLU.
    """

    def __init__(self, input_size, num_classes, weight_bit_width, act_bit_width):
        super().__init__()

        # Stage 1: strided convolution + 2x2 max-pool.
        self.cnn_1 = QuantConv2d(in_channels=input_size, out_channels=6, kernel_size=3,
                                 padding=0, stride=2,
                                 bias=False, weight_bit_width=weight_bit_width)
        self.bn_1 = nn.BatchNorm2d(6)
        self.relu_1 = QuantReLU(bit_width=act_bit_width)
        self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=None)

        # Stage 2: depthwise 3x3 + pointwise 1x1 (6 -> 5 channels).
        self.cnn_d2 = QuantConv2d(in_channels=6, out_channels=6, kernel_size=3,
                                  padding=0, groups=6,
                                  bias=False, weight_bit_width=weight_bit_width)
        self.bn_d2 = nn.BatchNorm2d(6)
        self.relu_d2 = QuantReLU(bit_width=act_bit_width)
        self.cnn_2 = QuantConv2d(in_channels=6, out_channels=5, kernel_size=1,
                                 padding=0,
                                 bias=False, weight_bit_width=weight_bit_width)
        self.bn_2 = nn.BatchNorm2d(5)
        self.relu_2 = QuantReLU(bit_width=act_bit_width)

        # Stage 3: depthwise 3x3 (padded) + pointwise 1x1 + 2x2 max-pool.
        self.cnn_d3 = QuantConv2d(in_channels=5, out_channels=5, kernel_size=3,
                                  padding=1, groups=5,
                                  bias=False, weight_bit_width=weight_bit_width)
        self.bn_d3 = nn.BatchNorm2d(5)
        self.relu_d3 = QuantReLU(bit_width=act_bit_width)
        self.cnn_3 = QuantConv2d(in_channels=5, out_channels=5, kernel_size=1,
                                 padding=0,
                                 bias=False, weight_bit_width=weight_bit_width)
        self.bn_3 = nn.BatchNorm2d(5)
        self.relu_3 = QuantReLU(bit_width=act_bit_width)
        self.maxpool_3 = nn.MaxPool2d(kernel_size=2, stride=None)

        # Stage 4: depthwise 3x3 + pointwise 1x1 (5 -> 3 channels) + 3x3 max-pool.
        self.cnn_d4 = QuantConv2d(in_channels=5, out_channels=5, kernel_size=3,
                                  padding=0, groups=5,
                                  bias=False, weight_bit_width=weight_bit_width)
        self.bn_d4 = nn.BatchNorm2d(5)
        self.relu_d4 = QuantReLU(bit_width=act_bit_width)
        self.cnn_4 = QuantConv2d(in_channels=5, out_channels=3, kernel_size=1,
                                 padding=0,
                                 bias=False, weight_bit_width=weight_bit_width)
        self.bn_4 = nn.BatchNorm2d(3)
        self.relu_4 = QuantReLU(bit_width=act_bit_width)
        self.maxpool_4 = nn.MaxPool2d(kernel_size=3, stride=None)

        # Classifier head.
        self.drop = QuantDropout(p=0.5)
        self.dense = QuantLinear(2 * 2 * 3, num_classes, bias=False, weight_bit_width=weight_bit_width)
        # Sized from num_classes (was hard-coded to 2) so the head stays
        # consistent with the linear layer for any class count; identical
        # parameter shapes for num_classes == 2, so old checkpoints still load.
        self.bn_5 = nn.BatchNorm1d(num_classes)
        self.relu_5 = QuantReLU(bit_width=act_bit_width)

    def forward(self, x):
        """Run the network on ``x`` and return one row of ``num_classes`` scores per sample.

        The scores pass through a final quantized ReLU.
        """
        x = self.relu_1(self.bn_1(self.cnn_1(x)))
        x = self.maxpool_1(x)

        x = self.relu_d2(self.bn_d2(self.cnn_d2(x)))
        x = self.relu_2(self.bn_2(self.cnn_2(x)))

        x = self.relu_d3(self.bn_d3(self.cnn_d3(x)))
        x = self.relu_3(self.bn_3(self.cnn_3(x)))
        x = self.maxpool_3(x)

        x = self.relu_d4(self.bn_d4(self.cnn_d4(x)))
        x = self.relu_4(self.bn_4(self.cnn_4(x)))
        x = self.maxpool_4(x)

        # Flatten everything except the batch axis for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.drop(x)
        x = self.dense(x)
        x = self.bn_5(x)
        x = self.relu_5(x)

        return x

load parameters

In [None]:
input_size = 1  # number of input channels: the dataset yields single-channel arrays
num_classes = 2
weight_bit_width = 2
act_bit_width = 2

original_model = ECG_AF_2D(input_size=input_size, num_classes=num_classes,
                           weight_bit_width=weight_bit_width, act_bit_width=act_bit_width)
# map_location='cpu' lets a checkpoint that was saved from a GPU session be
# loaded on a machine without CUDA.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
state_dict = torch.load('./multi_model/relu_without_ternarize.pth', map_location='cpu')
original_model.load_state_dict(state_dict)
# Inference mode: dropout disabled, BatchNorm uses its running statistics.
original_model.eval()

In [88]:
# Divide the quantized codes by 255 and prepend a batch axis for the model.
input_tensor = (output_txt / 255.0).unsqueeze(0)
print(input_tensor.shape)
print("input:", input_tensor)
print("Label_true:", output_label)

torch.Size([1, 1, 73, 73])
input: tensor([[[[0.6000, 0.5961, 0.6039,  ..., 0.6510, 0.6392, 0.6353],
          [0.6314, 0.6275, 0.6353,  ..., 0.5020, 0.5843, 0.6118],
          [0.6588, 0.6745, 0.7020,  ..., 0.7216, 0.7098, 0.7137],
          ...,
          [0.7451, 0.7333, 0.7490,  ..., 0.8706, 0.9725, 0.8039],
          [0.3922, 0.1059, 0.2667,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]])
Label_true: tensor(0)


prediction

In [89]:
# Call the module itself rather than .forward() so that any registered hooks
# run; no_grad() skips building the autograd graph for this inference.
with torch.no_grad():
    output_golden_tensor = original_model(input_tensor)
output_golden = output_golden_tensor.numpy()
print("output_golden:", output_golden)
# Index of the largest score along the class axis, taken from the tensor
# directly instead of converting NumPy back to torch.
output_golden_label = torch.max(output_golden_tensor, 1)[1]
print("Label_pred:", output_golden_label)

output_golden: [[1.0797518 0.       ]]
Label_pred: tensor([0])


Create the test file if needed

In [93]:
# Write the selected sample as a 2D grid of hexadecimal integers.
# The integer copy gets its own name: rebinding output_txt to a NumPy array
# would break the later cell that calls output_txt.unsqueeze(0).
output_txt_int = np.array(output_txt.squeeze()).astype(int)
# savetxt does not create missing directories.
os.makedirs('./test_data', exist_ok=True)
np.savetxt('./test_data/test_5.txt', output_txt_int, fmt="%x")

## Simulation using Python

In [12]:
import numpy as np
from finn.core.modelwrapper import ModelWrapper
import onnx.numpy_helper as nph

In [90]:
# Prepend a batch axis to the selected sample for the ONNX simulation.
input_test_1 = output_txt.unsqueeze(0)
print(input_test_1.shape)

torch.Size([1, 1, 73, 73])


In [91]:
# The simulation input is a float32 NumPy array under the graph input name "global_in".
# NOTE(review): unlike the PyTorch path, the values are not divided by 255
# here - presumably the exported graph accounts for the scaling; confirm.
input_numpy = input_test_1.numpy().astype(np.float32)
input_dict = {"global_in": input_numpy}
print("input:", input_numpy)
print("Label_true:", output_label)

# Load the exported ONNX model that will be simulated.
model_for_sim = ModelWrapper("./test/ECG_AF_2D_w2a2_streamlined.onnx")

input: [[[[153. 152. 154. ... 166. 163. 162.]
   [161. 160. 162. ... 128. 149. 156.]
   [168. 172. 179. ... 184. 181. 182.]
   ...
   [190. 187. 191. ... 222. 248. 205.]
   [100.  27.  68. ...   0.   0.   0.]
   [  0.   0.   0. ...   0.   0.   0.]]]]
Label_true: tensor(0)


In [92]:
import finn.core.onnx_exec as oxe

# Execute the ONNX graph in Python and take the first entry of the result dict.
output_dict = oxe.execute_onnx(model_for_sim, input_dict)
first_output_name = list(output_dict)[0]
output_pysim = output_dict[first_output_name]
print("Results for Simulation: Label_pred", output_pysim)

Results for Simulation: Label_pred [[0]]
