In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.externals import joblib
from collections import Counter
from scipy import signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
data = joblib.load('data.pkl')

In [3]:
data_dir = 'bbdc_2019_Bewegungsdaten/'

In [4]:
data.keys()

dict_keys(['features', 'labels', 'bounds', 'lens', 'label_encoder'])

In [5]:
def standard_scale(data):
    features = data['features']
    lens = np.concatenate((data['lens']['train'],data['lens']['valid'],data['lens']['test']))
    features = np.concatenate(features, axis=0)
    features = StandardScaler().fit_transform(features)
    ret = []
    l = 0
    for t in lens:
        ret.append(features[l: l + t])
        l += t
    features = np.array(ret)
    data['features'] = features

In [6]:
def filter_emg(x):
    high = 20/(1000/2)
    low = 450/(1000/2)
    b, a = signal.butter(4, [high, low], btype='bandpass')
    emg_filtered = signal.filtfilt(b, a, x, axis=0)
    return emg_filtered

In [7]:
# standard_scale(data)

In [8]:
def get_split(data, split):
    if split == 'train':
        return data['features'][: data['bounds']['train']]
    elif split == 'valid':
        return data['features'][data['bounds']['train']: data['bounds']['train'] + data['bounds']['valid']]
    elif split == 'test':
        return data['features'][data['bounds']['train'] + data['bounds']['valid']:]

In [9]:
def generate_batch_idx(n, batch_size, randomise=False):
    idx = np.arange(0, n)
    if randomise:
        np.random.shuffle(idx)
    for batch_idx in np.arange(0, n, batch_size):
        yield idx[batch_idx:batch_idx+batch_size]

In [10]:
def generate_batches(data, split, batch_size,
                     time_steps=10, stride=5, randomise=False):
    features = get_split(data, split)

    try:
        labels = data['label_encoder'].transform(data['labels'][split])
    except:
        labels = np.zeros(features.shape[0])
    
    lens = data['lens'][split]
    new_features = []
    new_labels = []
    new_lens = []
    for i in range(len(lens)):
        mat = features[i]
        label = labels[i]
        l = lens[i]
        acc_emg = [0, 1, 2, 3, 5, 6, 7, 9, 10, 11]
        mat[:, acc_emg] -= mat[:, acc_emg].mean(axis=0)
        mat[:, :4] = filter_emg(mat[:, :4])
        extracted_steps = []
        for j in range(0, len(mat) - time_steps, stride):
            window = mat[j: j + time_steps, :]
            means = window[:, 4:].mean(axis=0).reshape(1, -1)
            rms = np.sqrt((window[:, :4]**2).mean(axis=0)).reshape(1, -1)
            feature_vector = np.concatenate((rms, means), axis=1).reshape(1, -1)
            
            extracted_steps.append(feature_vector)
        extracted_steps = np.concatenate(extracted_steps, axis=0)
        new_features.append(extracted_steps)
        new_labels.append(label)
        new_lens.append(len(extracted_steps))
    features = np.array(new_features)
    labels = np.array(new_labels)
    lens = np.array(new_lens)
    
    n = len(features)
    for batch_idx in generate_batch_idx(n, batch_size, randomise):
        batch_data = features[batch_idx]
        batch_labels = labels[batch_idx]
        batch_lens = lens[batch_idx]
#         batch_data = torch.from_numpy(batch_data).float()
#         batch_labels = torch.from_numpy(labels[batch_idx]).float()
#         lens = torch.from_numpy(lens[batch_idx]).float()
        yield batch_data, batch_labels, batch_lens

In [11]:
# def sort_batch(batch, targets, lengths):
#     """
#     Sort a minibatch by the length of the sequences with the longest sequences first
#     return the sorted batch targes and sequence lengths.
#     This way the output can be used by pack_padded_sequences(...)
#     """
#     perm_idx = np.argsort(lengths)[::-1]
#     seq_lengths = lengths[perm_idx]
#     seq_tensor = batch[perm_idx]
#     target_tensor = targets[perm_idx]
#     return seq_tensor, target_tensor, seq_lengths

def pad_batch(batch, lens):
    max_len = max(lens)
    batch_size = batch.shape[0]
    num_feature = batch[0].shape[1]
    padded_seqs = np.zeros((batch_size, max_len, num_feature))
    
    for i, l in enumerate(lens):
        padded_seqs[i, :l, :] = batch[i][:l]

    return padded_seqs

In [12]:
def torch_batch(batch, targets):
    return torch.from_numpy(batch).float(), torch.from_numpy(targets).long()

In [13]:
def get_preds(model, data, split, batch_size, time_steps, stride):
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for b_data, b_labels, b_lens in generate_batches(data, split, batch_size, 
                                                         time_steps, stride, False):
            b_data = pad_batch(b_data, b_lens)
            b_data, b_labels = torch_batch(b_data, b_labels)
            preds.append(model(b_data, b_lens))
            labels.append(b_labels)
        preds = torch.cat(preds, dim=0)
        labels = torch.cat(labels, dim=0)
    return preds, labels

In [14]:
def get_accuracy(preds, labels, le):
    preds = preds.max(dim=1)[1].numpy()
    preds = [le.classes_[i] for i in preds]
    labels = labels.numpy()
    labels = [le.classes_[i] for i in labels]
    return accuracy_score(labels, preds)

In [28]:
class HARNet(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.rnn1 = nn.RNN(19, 64, num_layers=1)
        self.rnn2 = nn.RNN(64, 64)
        self.lin2 = nn.Linear(64*3, 256, bias=True)
        self.lin3 = nn.Linear(256, 22, bias=True)
    
    def forward(self, data, lens):
        x, _ = self.rnn1(data)
        x = x[:, ::8, :]
        x, _ = self.rnn2(x)
        x_max = torch.cat([seq.max(dim=0)[0].view(1, -1) for seq in x], dim=0)
        x_min = torch.cat([seq.min(dim=0)[0].view(1, -1) for seq in x], dim=0)
        x_avg = torch.cat([seq.mean(dim=0).view(1, -1) for seq in x], dim=0)
#         x_max = torch.cat([seq[0:l,:].max(dim=0)[0].view(1, -1) for l, seq in zip(lens, x)], dim=0)
#         x_min = torch.cat([seq[0:l,:].min(dim=0)[0].view(1, -1) for l, seq in zip(lens, x)], dim=0)
#         x_avg = torch.cat([seq[0:l,:].mean(dim=0).view(1, -1) for l, seq in zip(lens, x)], dim=0)
        x = torch.cat((x_min, x_avg, x_max), dim=1)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return x

In [30]:
num_epochs = 10000
batch_size = 32
time_steps = 10
stride = 5
objective = nn.CrossEntropyLoss()
model = HARNet()
optimiser = torch.optim.Adam(model.parameters(), weight_decay=0.0)
running_loss = 0
running_batch = 0
min_valid_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    with tqdm(enumerate(generate_batches(data, 'train', batch_size,
                                         time_steps, stride, True), 1)) as pbar:
        model.train()
        for batch_num, (batch_data, batch_labels, batch_lens) in pbar:
            batch_data = pad_batch(batch_data, batch_lens)
            batch_data, batch_labels = torch_batch(batch_data, batch_labels)
            optimiser.zero_grad()
            preds = model(batch_data, batch_lens)
            loss = objective(preds, batch_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 4, norm_type=2)
            optimiser.step()
            running_loss += loss.item()
            running_batch += 1
            pbar.set_description(f'[Epoch: {epoch}] | Batch {batch_num} | Loss: {running_loss/running_batch}')
            
    valid_preds, valid_labels = get_preds(model, data, 'valid',
                                          64, time_steps, stride)
    
    valid_loss = objective(valid_preds, valid_labels).item()
    
    if valid_loss < min_valid_loss:
        print(f'Validation loss improved from {min_valid_loss} to {valid_loss}')
        acc = get_accuracy(valid_preds, valid_labels, data['label_encoder'])
        print(f'Validation accuracy: {acc}')
        min_valid_loss = valid_loss
        with open('best_rnn_model.pt', 'wb') as f:
            torch.save(model.state_dict(), f)
    else:
        print('Validation loss did not improve')

[Epoch: 1] | Batch 186 | Loss: 2.391327222188314: : 186it [02:36,  3.73it/s] 
0it [00:00, ?it/s]

Validation loss improved from inf to 1.9982115030288696
Validation accuracy: 0.21867881548974943


[Epoch: 2] | Batch 186 | Loss: 2.138530998140253: : 186it [02:36,  4.48it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.9982115030288696 to 1.9344723224639893
Validation accuracy: 0.2870159453302961


[Epoch: 3] | Batch 186 | Loss: 2.027135161088786: : 186it [02:39,  3.93it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 4] | Batch 186 | Loss: 1.9486290923049372: : 186it [02:45,  4.56it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.9344723224639893 to 1.900794267654419
Validation accuracy: 0.3234624145785877


[Epoch: 5] | Batch 186 | Loss: 1.889930397208019: : 186it [02:49,  4.36it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 6] | Batch 186 | Loss: 1.850513397365488: : 186it [02:54,  4.68it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 7] | Batch 186 | Loss: 1.8155900985231415: : 186it [02:49,  4.34it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 8] | Batch 186 | Loss: 1.7838551599972992: : 186it [02:44,  4.50it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 9] | Batch 186 | Loss: 1.76314821137892: : 186it [02:50,  4.28it/s]  
0it [00:00, ?it/s]

Validation loss improved from 1.900794267654419 to 1.8866676092147827
Validation accuracy: 0.3143507972665148


[Epoch: 10] | Batch 186 | Loss: 1.7515867497331352: : 186it [02:55,  4.34it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 11] | Batch 186 | Loss: 1.7426855088095978: : 186it [02:36,  4.07it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.8866676092147827 to 1.8850773572921753
Validation accuracy: 0.34851936218678814


[Epoch: 12] | Batch 186 | Loss: 1.7309430442403295: : 186it [02:41,  4.35it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 13] | Batch 186 | Loss: 1.723718890974204: : 186it [02:46,  4.58it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 14] | Batch 186 | Loss: 1.7134885437294451: : 186it [02:57,  4.07it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 15] | Batch 186 | Loss: 1.702919852989976: : 186it [02:54,  4.39it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.8850773572921753 to 1.8379545211791992
Validation accuracy: 0.3325740318906606


[Epoch: 16] | Batch 186 | Loss: 1.695331424594887: : 186it [02:58,  4.07it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.8379545211791992 to 1.8309239149093628
Validation accuracy: 0.2847380410022779


[Epoch: 17] | Batch 186 | Loss: 1.6913639067246293: : 186it [02:52,  4.08it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 18] | Batch 186 | Loss: 1.6875915736707736: : 186it [02:41,  4.55it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 19] | Batch 186 | Loss: 1.6821235637915761: : 186it [03:04,  4.62it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 20] | Batch 186 | Loss: 1.6779625953525625: : 186it [02:58,  3.97it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.8309239149093628 to 1.7847496271133423
Validation accuracy: 0.3097949886104784


[Epoch: 21] | Batch 186 | Loss: 1.671354329195379: : 186it [02:48,  4.50it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.7847496271133423 to 1.7836196422576904
Validation accuracy: 0.3462414578587699


[Epoch: 22] | Batch 186 | Loss: 1.6623346857427967: : 186it [02:37,  4.42it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 23] | Batch 186 | Loss: 1.6596172772428042: : 186it [02:42,  4.59it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 24] | Batch 186 | Loss: 1.6594098740153842: : 186it [02:41,  4.64it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 25] | Batch 186 | Loss: 1.659781830772277: : 186it [02:36,  4.23it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.7836196422576904 to 1.729230284690857
Validation accuracy: 0.27790432801822323


[Epoch: 26] | Batch 186 | Loss: 1.6612499338768059: : 186it [02:38,  4.37it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.729230284690857 to 1.6946970224380493
Validation accuracy: 0.30751708428246016


[Epoch: 27] | Batch 186 | Loss: 1.6618906587113047: : 186it [02:36,  4.24it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 28] | Batch 186 | Loss: 1.659756799424482: : 186it [02:37,  4.55it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 29] | Batch 186 | Loss: 1.6559218791488546: : 186it [02:44,  4.36it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 30] | Batch 186 | Loss: 1.653870115057969: : 186it [02:39,  4.56it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 31] | Batch 186 | Loss: 1.6541643689954773: : 186it [02:33,  4.52it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 32] | Batch 186 | Loss: 1.65548610222596: : 186it [02:34,  4.69it/s]  
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 33] | Batch 186 | Loss: 1.654758428917447: : 186it [03:07,  3.54it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 34] | Batch 186 | Loss: 1.654248680101777: : 186it [02:49,  4.36it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 35] | Batch 186 | Loss: 1.6556237679106482: : 186it [02:36,  4.57it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 36] | Batch 186 | Loss: 1.655000576907446: : 186it [02:39,  4.79it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 37] | Batch 186 | Loss: 1.6557005930175548: : 186it [02:36,  4.65it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 38] | Batch 186 | Loss: 1.6536827449685243: : 186it [02:36,  4.50it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 39] | Batch 186 | Loss: 1.6499872681586447: : 186it [02:37,  4.45it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 40] | Batch 186 | Loss: 1.646816321550518: : 186it [02:37,  4.17it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 41] | Batch 186 | Loss: 1.6433557625083264: : 186it [02:37,  4.59it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 42] | Batch 186 | Loss: 1.6413456269459301: : 186it [02:38,  4.58it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 43] | Batch 186 | Loss: 1.6431399923051766: : 186it [02:40,  4.46it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 44] | Batch 186 | Loss: 1.6431951667905902: : 186it [02:38,  4.54it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 45] | Batch 186 | Loss: 1.6455262312894796: : 186it [02:38,  4.69it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 46] | Batch 186 | Loss: 1.6472020512181167: : 186it [02:38,  4.43it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 47] | Batch 186 | Loss: 1.6488342465429278: : 186it [02:35,  4.53it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 48] | Batch 186 | Loss: 1.6493170243090411: : 186it [02:37,  4.71it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 49] | Batch 186 | Loss: 1.6500866618992494: : 186it [02:37,  4.61it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 50] | Batch 186 | Loss: 1.6509028367201488: : 186it [02:36,  4.17it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 51] | Batch 186 | Loss: 1.6517823937822438: : 186it [02:38,  4.34it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 52] | Batch 186 | Loss: 1.6515400039819375: : 186it [02:37,  4.34it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 53] | Batch 186 | Loss: 1.6516832968578736: : 186it [02:38,  4.41it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 54] | Batch 186 | Loss: 1.6531391962858273: : 186it [02:41,  4.30it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 55] | Batch 186 | Loss: 1.654210024203956: : 186it [02:39,  4.86it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 56] | Batch 186 | Loss: 1.6562289557492678: : 186it [02:37,  4.69it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 57] | Batch 186 | Loss: 1.65742292690448: : 186it [02:36,  4.53it/s]  
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 58] | Batch 186 | Loss: 1.6574135051310217: : 186it [02:37,  4.45it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 59] | Batch 186 | Loss: 1.657076194102981: : 186it [02:35,  4.29it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 60] | Batch 186 | Loss: 1.6550496292904713: : 186it [02:35,  4.74it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 61] | Batch 186 | Loss: 1.6545802082688472: : 186it [02:35,  4.13it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 62] | Batch 186 | Loss: 1.6550983764667955: : 186it [02:35,  4.50it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 63] | Batch 186 | Loss: 1.6566410950984456: : 186it [02:36,  4.30it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 64] | Batch 186 | Loss: 1.6577280807378951: : 186it [02:35,  4.62it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 65] | Batch 186 | Loss: 1.6577084154704074: : 186it [02:41,  4.55it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 66] | Batch 186 | Loss: 1.6569160591724759: : 186it [02:37,  4.55it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 67] | Batch 186 | Loss: 1.6567619678067693: : 186it [02:39,  4.74it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 68] | Batch 186 | Loss: 1.656458906353387: : 186it [02:38,  4.20it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 69] | Batch 186 | Loss: 1.6567246201477894: : 186it [02:36,  4.55it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 70] | Batch 186 | Loss: 1.6570392996210106: : 186it [02:36,  4.47it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 71] | Batch 186 | Loss: 1.656477243525141: : 186it [02:35,  4.31it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 72] | Batch 186 | Loss: 1.6556730591463928: : 186it [02:36,  4.45it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 73] | Batch 186 | Loss: 1.6551606785848962: : 186it [02:35,  4.26it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 74] | Batch 186 | Loss: 1.654514066071678: : 186it [02:36,  4.68it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 75] | Batch 186 | Loss: 1.6539567200770087: : 186it [02:36,  4.65it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 76] | Batch 186 | Loss: 1.653285127459911: : 186it [02:36,  4.76it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 77] | Batch 186 | Loss: 1.653861655842374: : 186it [02:39,  4.42it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 78] | Batch 186 | Loss: 1.6535822235588071: : 186it [02:36,  4.46it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 79] | Batch 186 | Loss: 1.6540313918233258: : 186it [02:35,  4.67it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 80] | Batch 186 | Loss: 1.6540276530609337: : 186it [02:35,  4.64it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 81] | Batch 186 | Loss: 1.6543117006524821: : 186it [02:42,  4.49it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 82] | Batch 186 | Loss: 1.654617398828776: : 186it [02:34,  4.20it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 83] | Batch 186 | Loss: 1.6545605910984187: : 186it [02:35,  4.83it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 84] | Batch 186 | Loss: 1.6546915126225306: : 186it [02:35,  4.70it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 85] | Batch 186 | Loss: 1.6549753858951133: : 186it [02:35,  4.29it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 86] | Batch 186 | Loss: 1.654297500729531: : 186it [02:35,  4.78it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 87] | Batch 186 | Loss: 1.6537189297999635: : 186it [02:38,  4.35it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 88] | Batch 186 | Loss: 1.6528495357299946: : 186it [02:38,  4.31it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 89] | Batch 186 | Loss: 1.6517825102307766: : 186it [02:37,  4.65it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 90] | Batch 186 | Loss: 1.651853125603823: : 186it [02:37,  4.55it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 91] | Batch 186 | Loss: 1.6511840731500664: : 186it [02:37,  4.40it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 92] | Batch 186 | Loss: 1.6504014195014851: : 186it [02:37,  4.29it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 93] | Batch 186 | Loss: 1.6498654530990853: : 186it [02:36,  4.59it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 94] | Batch 186 | Loss: 1.649680685892009: : 186it [02:36,  4.35it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 95] | Batch 186 | Loss: 1.6495278753867308: : 186it [02:37,  4.44it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 96] | Batch 186 | Loss: 1.6501762560538706: : 186it [02:38,  4.35it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 97] | Batch 186 | Loss: 1.6503269001750693: : 186it [02:42,  4.23it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 98] | Batch 186 | Loss: 1.6498513815432807: : 186it [02:38,  4.12it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 99] | Batch 186 | Loss: 1.6494576656703155: : 186it [02:38,  4.60it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 100] | Batch 186 | Loss: 1.6490917599265293: : 186it [02:36,  4.47it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 101] | Batch 186 | Loss: 1.6483608715059301: : 186it [02:36,  4.37it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 102] | Batch 186 | Loss: 1.6475773022668965: : 186it [02:37,  4.67it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 103] | Batch 186 | Loss: 1.646558422748101: : 186it [02:36,  4.66it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 104] | Batch 186 | Loss: 1.645920145853035: : 186it [02:35,  4.60it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 105] | Batch 186 | Loss: 1.6450486665130943: : 186it [02:36,  4.71it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 106] | Batch 186 | Loss: 1.6446310230178411: : 186it [02:38,  4.59it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 107] | Batch 186 | Loss: 1.6443896233202908: : 186it [02:36,  4.21it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 108] | Batch 186 | Loss: 1.6442362672776543: : 186it [02:41,  4.47it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 109] | Batch 186 | Loss: 1.6440344714278756: : 186it [02:38,  4.35it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 110] | Batch 186 | Loss: 1.6432126205046738: : 186it [02:38,  4.69it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 111] | Batch 186 | Loss: 1.6430594604987454: : 186it [02:37,  4.72it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 112] | Batch 186 | Loss: 1.6433004716814663: : 186it [02:35,  4.66it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 113] | Batch 186 | Loss: 1.6433697594075372: : 186it [02:37,  3.84it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 114] | Batch 186 | Loss: 1.6437230125010587: : 186it [02:36,  4.54it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.6946970224380493 to 1.6564568281173706
Validation accuracy: 0.35990888382687924


[Epoch: 115] | Batch 186 | Loss: 1.64390445258808: : 186it [02:36,  4.35it/s]  
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 116] | Batch 186 | Loss: 1.6439622979131856: : 186it [02:38,  4.06it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 117] | Batch 186 | Loss: 1.6438918181031452: : 186it [02:37,  4.45it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 118] | Batch 186 | Loss: 1.6440694262328772: : 186it [02:35,  4.49it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 119] | Batch 186 | Loss: 1.6446377916841033: : 186it [02:35,  4.39it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 120] | Batch 186 | Loss: 1.644895737252355: : 186it [02:37,  4.24it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 121] | Batch 186 | Loss: 1.6447709135762174: : 186it [02:37,  4.52it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 122] | Batch 186 | Loss: 1.6441010020932003: : 186it [02:36,  4.89it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 123] | Batch 186 | Loss: 1.6432844494458148: : 186it [02:36,  4.37it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 124] | Batch 186 | Loss: 1.6424141107201369: : 186it [02:38,  4.26it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 125] | Batch 186 | Loss: 1.6419514037819318: : 186it [02:37,  4.61it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 126] | Batch 186 | Loss: 1.6415058619069818: : 186it [02:37,  4.57it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 127] | Batch 186 | Loss: 1.6408500148010399: : 186it [02:36,  4.99it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 128] | Batch 186 | Loss: 1.639620883690734: : 186it [02:35,  4.41it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 129] | Batch 186 | Loss: 1.6388182613743716: : 186it [02:43,  4.53it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 130] | Batch 186 | Loss: 1.6376757948284213: : 186it [02:37,  4.40it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 131] | Batch 186 | Loss: 1.6366252336196396: : 186it [02:37,  4.08it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 132] | Batch 186 | Loss: 1.635576070607699: : 186it [02:36,  4.26it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 133] | Batch 186 | Loss: 1.6343660265368722: : 186it [02:35,  4.62it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 134] | Batch 186 | Loss: 1.6334989814061143: : 186it [02:35,  3.92it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 135] | Batch 186 | Loss: 1.632917484155645: : 186it [02:36,  4.56it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 136] | Batch 186 | Loss: 1.6319930902237496: : 186it [02:35,  4.51it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 137] | Batch 186 | Loss: 1.6313744001420922: : 186it [02:35,  4.38it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 138] | Batch 186 | Loss: 1.6308702511548328: : 186it [02:39,  4.31it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 139] | Batch 186 | Loss: 1.6304873996496736: : 186it [02:37,  4.59it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 140] | Batch 186 | Loss: 1.6301968624904042: : 186it [02:38,  3.68it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 141] | Batch 186 | Loss: 1.6294057058754388: : 186it [02:35,  4.69it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 142] | Batch 186 | Loss: 1.6290877354306632: : 186it [02:40,  4.52it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 143] | Batch 186 | Loss: 1.6287109378527247: : 186it [02:38,  4.33it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 144] | Batch 186 | Loss: 1.6278666541705813: : 186it [02:36,  4.52it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 145] | Batch 186 | Loss: 1.6268933955081357: : 186it [02:36,  4.69it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 146] | Batch 186 | Loss: 1.6257208409159607: : 186it [02:38,  4.39it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 147] | Batch 186 | Loss: 1.6248829190204892: : 186it [02:35,  4.55it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 148] | Batch 186 | Loss: 1.6244010737545267: : 186it [02:35,  4.48it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 149] | Batch 186 | Loss: 1.6240203468656096: : 186it [02:36,  4.46it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.6564568281173706 to 1.6423368453979492
Validation accuracy: 0.3895216400911162


[Epoch: 150] | Batch 186 | Loss: 1.6245301305756348: : 186it [02:36,  4.55it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 151] | Batch 186 | Loss: 1.62443561611637: : 186it [02:35,  4.58it/s]  
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 152] | Batch 186 | Loss: 1.6241803776027965: : 186it [02:36,  4.39it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 153] | Batch 186 | Loss: 1.6247879384050585: : 186it [02:38,  5.03it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 154] | Batch 186 | Loss: 1.6253110394037864: : 186it [02:36,  4.39it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 155] | Batch 186 | Loss: 1.6250582550885238: : 186it [02:36,  4.48it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 156] | Batch 186 | Loss: 1.6254248518197867: : 186it [02:37,  4.58it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 157] | Batch 186 | Loss: 1.6250359883076377: : 186it [02:35,  4.69it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 158] | Batch 186 | Loss: 1.624731376288504: : 186it [02:36,  4.23it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 159] | Batch 186 | Loss: 1.6244137428402134: : 186it [02:37,  4.64it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 160] | Batch 186 | Loss: 1.623881434843505: : 186it [02:37,  4.32it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 161] | Batch 186 | Loss: 1.623538353673157: : 186it [02:35,  4.39it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 162] | Batch 186 | Loss: 1.6237258663373562: : 186it [02:36,  4.44it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 163] | Batch 186 | Loss: 1.6241337668947624: : 186it [02:37,  4.52it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 164] | Batch 186 | Loss: 1.6252072232951493: : 186it [02:38,  3.92it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 165] | Batch 186 | Loss: 1.626448906794477: : 186it [02:36,  4.88it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 166] | Batch 186 | Loss: 1.6269021988405374: : 186it [02:35,  4.70it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 167] | Batch 186 | Loss: 1.6274223019443363: : 186it [02:36,  4.47it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 168] | Batch 186 | Loss: 1.6277218101802724: : 186it [02:37,  4.49it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 169] | Batch 186 | Loss: 1.6283939162089927: : 186it [02:36,  4.86it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 170] | Batch 186 | Loss: 1.6287597338754418: : 186it [02:34,  4.06it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 171] | Batch 186 | Loss: 1.629023420406289: : 186it [02:35,  4.54it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 172] | Batch 186 | Loss: 1.6288883107066692: : 186it [02:35,  4.63it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 173] | Batch 186 | Loss: 1.6285222680297216: : 186it [02:38,  4.61it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 174] | Batch 186 | Loss: 1.628157010388277: : 186it [02:36,  4.18it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 175] | Batch 186 | Loss: 1.6279335509393988: : 186it [02:38,  4.12it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 176] | Batch 186 | Loss: 1.6280446816108327: : 186it [02:37,  4.44it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 177] | Batch 186 | Loss: 1.627824625592498: : 186it [02:36,  4.46it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 178] | Batch 186 | Loss: 1.6273302753449523: : 186it [02:37,  4.14it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 179] | Batch 186 | Loss: 1.626620467478086: : 186it [02:36,  4.33it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 180] | Batch 186 | Loss: 1.6258989760311702: : 186it [02:37,  4.46it/s]
0it [00:00, ?it/s]

Validation loss did not improve





KeyboardInterrupt: 

In [39]:
model.load_state_dict(torch.load('best_rnn_model.pt'))

In [40]:
train_preds, train_labels = get_preds(model, data, 'train',
                                          64, time_steps, stride)

  b = a[a_slice]


In [41]:
train_loss = objective(train_preds, train_labels).item()
train_loss

1.5043046474456787

In [42]:
acc = get_accuracy(train_preds, train_labels, data['label_encoder'])
acc

0.43340060544904135

In [43]:
valid_preds, valid_labels = get_preds(model, data, 'valid',
                                          64, time_steps, stride)

In [44]:
valid_loss = objective(valid_preds, valid_labels).item()
valid_loss

1.6424287557601929

In [45]:
acc = get_accuracy(valid_preds, valid_labels, data['label_encoder'])
acc

0.3895216400911162