In [100]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future = 0, seq_len = 32):
        outputs = []
        h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            output = self.linear(h_t)
            outputs += [output]
        for i in range(future):# if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            output = self.linear(h_t)
            outputs += [output]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs

In [101]:
def make_data(directory):
    import os
    directory = "./train_data"
    xTr = []
    yTr = []
    for filename in os.listdir(directory):
        beats = np.load(directory + os.sep + filename)
        x, y = split_data(beats)
        xTr.append(x)
        yTr.append(y)
    xTr = np.concatenate(np.array(xTr))
    yTr = np.concatenate(np.array(yTr))
    
    
def split_data(beats, data_len=32):
    end = beats.shape[0]- (beats.shape[0]%(2*data_len))
    beats = beats[:end].reshape(-1, 2*data_len)
    x, y = beats[:,:32], beats[:,32:]
    return x, y

In [102]:
print(xTr.shape, yTr.shape)

(9861, 32) (9861, 32)


In [118]:
n_epochs = 1000
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
nesterov = True

# set random seed to 0
np.random.seed(0)
torch.manual_seed(0)

# load data and make training set
input = torch.from_numpy(xTr[:1000]).double()
target = torch.from_numpy(yTr[:1000]).double()
test_input = input[:100]
test_target = target[:100]

# build the model
seq = Sequence()
seq.double()
criterion = nn.MSELoss()

optimizer = optim.SGD(seq.parameters(), lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)

#begin to train
for i in range(n_epochs):
    print('STEP: ', i)
    
    def closure():
        optimizer.zero_grad()
        out = seq(input)
        loss = criterion(out, target)
        print('loss:', loss.item())
        loss.backward()
        return loss
    
    optimizer.step(closure)
    
    # begin to predict, no need to track gradient here
    with torch.no_grad():
        future = 32
        pred = seq(test_input, future=future)
        loss = criterion(pred[:, :-future], test_target)
        print('test loss:', loss.item())
        y = pred.detach().numpy()
        
    # draw the result
    if i % 10 == 0:
        plt.figure(figsize=(30,10))
        plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
        plt.xlabel('x', fontsize=20)
        plt.ylabel('y', fontsize=20)
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        def draw(yi, color):
            plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth = 2.0)
            plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth = 2.0)
        draw(y[0], 'r')
        draw(y[1], 'g')
        draw(y[2], 'b')
        plt.savefig('predict%d.png'%i, bbox_inches='tight')
        plt.close()


STEP:  0
loss: 0.10727303510896695
test loss: 0.10796947738200849
STEP:  1
loss: 0.1071526081953708
test loss: 0.10780587471215781
STEP:  2
loss: 0.10698275826256122
test loss: 0.10759977133457245
STEP:  3
loss: 0.10676818592580439
test loss: 0.107354934241635
STEP:  4
loss: 0.1065127904819715
test loss: 0.10707494092054311
STEP:  5
loss: 0.10622037695990272
test loss: 0.10676322110497093
STEP:  6
loss: 0.10589464184025878
test loss: 0.10642281591632491
STEP:  7
loss: 0.1055388798565359
test loss: 0.10605626234537655
STEP:  8
loss: 0.10515583345035137
test loss: 0.10566566589108001
STEP:  9
loss: 0.10474775167421073
test loss: 0.10525282248210761
STEP:  10
loss: 0.10431651631132384
test loss: 0.10481928010659751
STEP:  11
loss: 0.10386371847958939
test loss: 0.1043663376013991
STEP:  12
loss: 0.10339067554288527
test loss: 0.10389503075995302
STEP:  13
loss: 0.10289843267070048
test loss: 0.10340614029830553
STEP:  14
loss: 0.10238778054964788


KeyboardInterrupt: 