In [1]:
import torch

import copy
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from scipy.io import arff

from torch import nn, optim

import torch.nn.functional as F



%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x1aed9751e30>

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
with open('ECG5000_TRAIN.arff') as f:
    data, meta = arff.loadarff(f)
train = pd.DataFrame(data)

In [4]:
with open('ECG5000_TEST.arff') as f:
    data, meta = arff.loadarff(f)
test = pd.DataFrame(data)

In [5]:
train.head()

Unnamed: 0,att1,att2,att3,att4,att5,att6,att7,att8,att9,att10,...,att132,att133,att134,att135,att136,att137,att138,att139,att140,target
0,-0.112522,-2.827204,-3.773897,-4.349751,-4.376041,-3.474986,-2.181408,-1.818286,-1.250522,-0.477492,...,0.792168,0.933541,0.796958,0.578621,0.25774,0.228077,0.123431,0.925286,0.193137,b'1'
1,-1.100878,-3.99684,-4.285843,-4.506579,-4.022377,-3.234368,-1.566126,-0.992258,-0.75468,0.042321,...,0.538356,0.656881,0.78749,0.724046,0.555784,0.476333,0.77382,1.119621,-1.43625,b'1'
2,-0.567088,-2.59345,-3.87423,-4.584095,-4.187449,-3.151462,-1.74294,-1.490659,-1.18358,-0.394229,...,0.886073,0.531452,0.311377,-0.021919,-0.713683,-0.532197,0.321097,0.904227,-0.421797,b'1'
3,0.490473,-1.914407,-3.616364,-4.318823,-4.268016,-3.88111,-2.99328,-1.671131,-1.333884,-0.965629,...,0.350816,0.499111,0.600345,0.842069,0.952074,0.990133,1.086798,1.403011,-0.383564,b'1'
4,0.800232,-0.874252,-2.384761,-3.973292,-4.338224,-3.802422,-2.53451,-1.783423,-1.59445,-0.753199,...,1.148884,0.958434,1.059025,1.371682,1.277392,0.960304,0.97102,1.614392,1.421456,b'1'


In [6]:
test.head()

Unnamed: 0,att1,att2,att3,att4,att5,att6,att7,att8,att9,att10,...,att132,att133,att134,att135,att136,att137,att138,att139,att140,target
0,3.690844,0.711414,-2.114091,-4.141007,-4.574472,-3.431909,-1.950791,-1.107067,-0.632322,0.334577,...,0.022847,0.188937,0.480932,0.62925,0.577291,0.665527,1.035997,1.492287,-1.905073,b'1'
1,-1.348132,-3.996038,-4.22675,-4.251187,-3.477953,-2.228422,-1.808488,-1.534242,-0.779861,-0.397999,...,1.570938,1.591394,1.549193,1.193077,0.515134,0.126274,0.267532,1.071148,-1.164009,b'1'
2,1.024295,-0.590314,-1.916949,-2.806989,-3.527905,-3.638675,-2.779767,-2.019031,-1.980754,-1.44068,...,0.443502,0.827582,1.237007,1.235121,1.738103,1.800767,1.816301,1.473963,1.389767,b'1'
3,0.545657,-1.014383,-2.316698,-3.63404,-4.196857,-3.758093,-3.194444,-2.221764,-1.588554,-1.202146,...,0.77753,1.11924,0.902984,0.554098,0.497053,0.418116,0.703108,1.064602,-0.044853,b'1'
4,0.661133,-1.552471,-3.124641,-4.313351,-4.017042,-3.005993,-1.832411,-1.503886,-1.071705,-0.521316,...,1.280823,1.494315,1.618764,1.447449,1.238577,1.749692,1.986803,1.422756,-0.357784,b'1'


In [7]:
df = train._append(test)
df = df.sample(frac=1.0)
df.shape

(5000, 141)

In [8]:
CLASS_NORMAL = 1
class_names = ['Normal', 'R on T', 'PVC', 'SP', 'UB']

In [9]:
new_columns = list(df.columns)
new_columns[-1] = 'target'
df.columns = new_columns
df.head

<bound method NDFrame.head of           att1      att2      att3      att4      att5      att6      att7  \
1001  1.469756 -1.048520 -3.394356 -4.254399 -4.162834 -3.822570 -3.003609   
2086 -1.998602 -3.770552 -4.267091 -4.256133 -3.515288 -2.554540 -1.699639   
2153 -1.187772 -3.365038 -3.695653 -4.094781 -3.992549 -3.425381 -2.057643   
555   0.604969 -1.671363 -3.236131 -3.966465 -4.067820 -3.551897 -2.582864   
205  -1.197203 -3.270123 -3.778723 -3.977574 -3.405060 -2.392634 -1.726322   
...        ...       ...       ...       ...       ...       ...       ...   
3926 -0.248881 -1.346474 -1.855199 -2.519039 -2.947360 -3.233288 -3.087431   
466  -0.287286 -1.199089 -1.563916 -2.078314 -2.456073 -2.508211 -2.465002   
2592 -1.032096 -2.811901 -3.588706 -3.883206 -3.279964 -2.275187 -1.771033   
3272 -1.592541 -2.461370 -2.524132 -3.062815 -2.968224 -2.784655 -2.738399   
360  -1.945586 -3.840519 -3.994683 -4.075513 -3.825354 -2.707352 -1.890840   

          att8      att9     att1