In [None]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import glob
import shutil
import torch
import torch.nn.functional as F
from torch import nn
import torch.utils.data
import torch.optim as optim
import random
import time

def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus every CUDA device), and configures cuDNN for deterministic
    kernel selection.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    random.seed(seed)
    # Only child processes started after this point pick this up; the hash
    # seed of the already-running interpreter was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Seed all GPUs, not just the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner can choose a different algorithm on each run, which
    # defeats `deterministic = True`; switch it off as well.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
class ECGDataset(torch.utils.data.Dataset):
    """Map-style dataset serving fixed-length signal rows from a header-less CSV.

    Each CSV row is one sample. Only the first ``seq_len`` columns are
    returned; any trailing column (the label, according to the original
    comment in this notebook) is dropped.

    Args:
        path: Location of the CSV file (read with ``header=None``).
        seq_len: Number of leading columns that make up the signal.
            Defaults to 187, the window the notebook originally hard-coded.
    """

    def __init__(self, path, seq_len=187):
        self.seq_len = seq_len
        self.data = pd.read_csv(path, header=None)
        # Slice the signal columns once, positionally. Doing a label-based
        # DataFrame lookup for every sample is far slower and silently depends
        # on the frame keeping its default RangeIndex.
        self._signals = self.data.iloc[:, :seq_len].to_numpy()

    def __getitem__(self, idx):
        """Return the signal of row ``idx`` as a 1-D NumPy array."""
        # Hand out a copy so callers may modify a sample in place without
        # corrupting the cached array.
        return self._signals[idx].copy()

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.data)
        

In [None]:
# Directory that holds the pre-split ECG CSV files.
root_path = "../input/ecg-data-mit-arrhythmia-ptb"

# Use the first GPU when one is visible, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class opt:
    # Plain namespace of hyper-parameters and data locations used below.
    batch_size = 64   # samples per mini-batch
    workers = 2       # DataLoader worker processes
    lr = 0.001        # optimiser learning rate
    normal_train_path = f"{root_path}/normal_train.csv"
    normal_valid_path = f"{root_path}/normal_valid.csv"
    patient_valid_path = f"{root_path}/patient_valid.csv"
    

# Build the training dataset and a shuffling loader over it.
normal_train_dset = ECGDataset(opt.normal_train_path)
trainloader = torch.utils.data.DataLoader(
    dataset=normal_train_dset,
    num_workers=opt.workers,
    shuffle=True,
    batch_size=opt.batch_size,
)

# Sanity check of the data pipeline: take one batch, report its shape and
# dtype, and draw its first ten rows as line plots.
for x in trainloader:
    print(x.shape, x.dtype)
    x = pd.DataFrame(x.numpy())
    x.head(10).transpose().plot(figsize=(20, 8))
    plt.show()
    break

In [None]:
class LSTMEncoder(nn.Module):
    """LSTM sequence autoencoder for univariate signals.

    An encoder ``LSTMCell`` reads the input one time step at a time. Its final
    hidden and cell states initialise a decoder ``LSTMCell`` that regenerates
    the sequence autoregressively: each step's hidden state is projected to a
    single value, squashed by a sigmoid, and fed back as the next input. The
    decoded sequence is reversed along time before being returned.

    Args:
        hidden_size: Width of the hidden/cell state of both LSTM cells.
        seq_len: Number of steps ``Decoder`` generates when no explicit
            length is passed.
    """

    def __init__(self, hidden_size=20, seq_len=187):
        super().__init__()

        # Inputs are (batch, seq_len); every time step is fed to the cells as
        # a (batch, 1) slice, hence input_size == 1.
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.lstm1 = nn.LSTMCell(1, self.hidden_size)  # encoder cell
        self.lstm2 = nn.LSTMCell(1, self.hidden_size)  # decoder cell
        self.linear = nn.Linear(self.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def Encoder(self, inp):
        """Run the encoder over ``inp`` and return its final state.

        Args:
            inp: Tensor of shape (batch, seq_len).

        Returns:
            Tuple ``(ht, ct)`` of final hidden and cell states, each of shape
            (batch, hidden_size).
        """
        # Allocate the initial state next to the input rather than on a
        # module-level global device, so the model works wherever it is moved.
        ht = inp.new_zeros(inp.size(0), self.hidden_size)
        ct = inp.new_zeros(inp.size(0), self.hidden_size)

        # split(1, dim=1) yields one (batch, 1) slice per time step.
        for input_t in inp.split(1, dim=1):
            ht, ct = self.lstm1(input_t, (ht, ct))

        return ht, ct

    def Decoder(self, ht, ct, seq_len=None):
        """Generate a sequence autoregressively from an initial LSTM state.

        Args:
            ht: Initial hidden state, shape (batch, hidden_size).
            ct: Initial cell state, shape (batch, hidden_size).
            seq_len: Number of steps to generate; defaults to ``self.seq_len``.

        Returns:
            Tensor of shape (batch, seq_len) with values produced by a sigmoid.
        """
        steps = self.seq_len if seq_len is None else seq_len

        # The first decoder input is all zeros; afterwards the previous output
        # is fed back in.
        ot = ht.new_zeros(ht.size(0), 1)
        outputs = []
        for _ in range(steps):
            ht, ct = self.lstm2(ot, (ht, ct))
            ot = self.sigmoid(self.linear(ht))
            outputs.append(ot)

        if not outputs:
            # torch.cat rejects an empty list; return an empty sequence instead.
            return ht.new_zeros(ht.size(0), 0)
        # Concatenating the per-step (batch, 1) outputs avoids in-place writes
        # into a preallocated buffer that autograd has to track.
        return torch.cat(outputs, dim=1)

    def forward(self, inp):
        """Encode ``inp`` and return its time-reversed reconstruction.

        Args:
            inp: Tensor of shape (batch, seq_len).

        Returns:
            Tensor with the same shape as ``inp``.
        """
        he, ce = self.Encoder(inp)  # hidden / cell state of the encoder
        # Decode as many steps as the input has so the output always lines up
        # with the target passed to the loss.
        out = self.Decoder(he, ce, seq_len=inp.size(1))

        # Flip along time so the decoder's first output lands on the last step.
        return torch.flip(out, dims=[1])


In [None]:
# Create the autoencoder, move it to the selected device, and attach an Adam
# optimiser over all of its parameters.
network = LSTMEncoder()
network = network.to(device)
optimizer = optim.Adam(params=network.parameters(), lr=opt.lr)

def loss_function(pred, real):
    """Return the squared error between ``pred`` and ``real``, summed over all elements."""
    criterion = nn.MSELoss(reduction="sum")
    return criterion(pred, real)

In [None]:

# Train the autoencoder to reconstruct its own input and record the summed
# reconstruction loss of every epoch.
NUM_EPOCHS = 60   # full passes over the training loader
LOG_EVERY = 10    # print and checkpoint interval, in epochs

total_loss = []   # one entry per epoch: loss summed over all batches
network.train()
t1 = time.time()

for epoch in range(NUM_EPOCHS):
    loss_count = 0
    for signal in trainloader:

        signal = signal.to(device).float()  # (batch_size, seq_len)
        optimizer.zero_grad()
        output = network(signal)
        loss = loss_function(output, signal)
        loss_count += loss.item()

        loss.backward()
        optimizer.step()

    total_loss.append(loss_count)
    # Checkpoint after the last epoch as well; with only the periodic save the
    # weights produced by the final epochs would never be written to disk.
    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print("Epoch : {} Loss : {}".format(epoch, loss_count))
        torch.save(network.state_dict(), "klconv_epoch_{}.pt".format(epoch))

t2 = time.time()
time_taken = (t2 - t1) / 60  # elapsed wall-clock time in minutes
print(f"Total Time Taken : {time_taken:.2f}")

In [None]:
# Training curve: summed loss of each epoch against the epoch index.
epoch_axis = np.arange(len(total_loss))
plt.plot(epoch_axis, total_loss)
plt.ylabel("Loss value")
plt.xlabel("Total Epochs")
plt.show()