In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.externals import joblib
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
data = joblib.load('data.pkl')

In [3]:
data_dir = 'bbdc_2019_Bewegungsdaten/'

In [4]:
data.keys()

dict_keys(['features', 'labels', 'bounds', 'lens', 'label_encoder'])

In [5]:
def standard_scale(data):
    features = data['features']
    lens = np.concatenate((data['lens']['train'],data['lens']['valid'],data['lens']['test']))
    features = np.concatenate(features, axis=0)
    features = StandardScaler().fit_transform(features)
    ret = []
    l = 0
    for t in lens:
        ret.append(features[l: l + t])
        l += t
    features = np.array(ret)
    data['features'] = features

In [None]:
standard_scale(data)

In [6]:
def get_split(data, split):
    if split == 'train':
        return data['features'][: data['bounds']['train']]
    elif split == 'valid':
        return data['features'][data['bounds']['train']: data['bounds']['train'] + data['bounds']['valid']]
    elif split == 'test':
        return data['features'][data['bounds']['train'] + data['bounds']['valid']:]

In [7]:
def generate_batch_idx(n, batch_size, randomise=False):
    idx = np.arange(0, n)
    if randomise:
        np.random.shuffle(idx)
    for batch_idx in np.arange(0, n, batch_size):
        yield idx[batch_idx:batch_idx+batch_size]

In [8]:
def generate_batches(data, split, batch_size,
                     randomise=False):
    features = get_split(data, split)
    n = features.shape[0]
    try:
        labels = data['label_encoder'].transform(data['labels'][split])
    except:
        labels = np.zeros(features.shape[0])
    lens = data['lens'][split]
    for batch_idx in generate_batch_idx(n, batch_size, randomise):
        batch_data = features[batch_idx]
        batch_labels = labels[batch_idx]
        batch_lens = lens[batch_idx]
#         batch_data = torch.from_numpy(batch_data).float()
#         batch_labels = torch.from_numpy(labels[batch_idx]).float()
#         lens = torch.from_numpy(lens[batch_idx]).float()
        yield batch_data, batch_labels, batch_lens

In [9]:
# def sort_batch(batch, targets, lengths):
#     """
#     Sort a minibatch by the length of the sequences with the longest sequences first
#     return the sorted batch targes and sequence lengths.
#     This way the output can be used by pack_padded_sequences(...)
#     """
#     perm_idx = np.argsort(lengths)[::-1]
#     seq_lengths = lengths[perm_idx]
#     seq_tensor = batch[perm_idx]
#     target_tensor = targets[perm_idx]
#     return seq_tensor, target_tensor, seq_lengths

def pad_batch(batch, lens):
    max_len = max(lens)
    batch_size = batch.shape[0]
    num_feature = batch[0].shape[1]
    padded_seqs = np.zeros((batch_size, max_len, num_feature))
    
    for i, l in enumerate(lens):
        padded_seqs[i, :l, :] = batch[i][:l]

    return padded_seqs

In [10]:
def torch_batch(batch, targets):
    return torch.from_numpy(batch).float(), torch.from_numpy(targets).long()

In [28]:
def get_preds(model, data, split, batch_size):
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for b_data, b_labels, b_lens in generate_batches(data, split, batch_size, False):
            b_data = pad_batch(b_data, b_lens)
            b_data, b_labels = torch_batch(b_data, b_labels)
            preds.append(model(b_data, b_lens))
            labels.append(b_labels)
    preds = torch.cat(preds, dim=0)
    labels = torch.cat(labels, dim=0)
    return preds, labels

In [36]:
def get_accuracy(preds, labels, le):
    preds = preds.max(dim=1)[1].numpy()
    preds = [le.classes_[i] for i in preds]
    labels = labels.numpy()
    labels = [le.classes_[i] for i in labels]
    return accuracy_score(labels, preds)

In [40]:
class HARNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(input_size=14,
                          hidden_size=32,
                          num_layers=1,
                          batch_first=True)
#         self.lin1 = nn.Linear(19, 19, bias=True)
        self.lin2 = nn.Linear(32, 22, bias=True)
    
    def forward(self, data, lens):
        x, _ = self.rnn(data[:, :300, 5:])
        x = torch.cat([seq[-1].view(1, -1) for l, seq in zip(lens, x)], dim=0)
#         x = self.lin1(x)
#         x = F.relu(x)
        x = self.lin2(x)
        return x

In [41]:
num_epochs = 10000
batch_size = 32
objective = nn.CrossEntropyLoss()
model = HARNet()
optimiser = torch.optim.Adam(model.parameters(), weight_decay=0.0)
running_loss = 0
running_batch = 0
min_valid_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    with tqdm(enumerate(generate_batches(data, 'train', batch_size, True), 1)) as pbar:
        model.train()
        for batch_num, (batch_data, batch_labels, batch_lens) in pbar:
            batch_data = pad_batch(batch_data, batch_lens)
            batch_data, batch_labels = torch_batch(batch_data, batch_labels)
            optimiser.zero_grad()
            preds = model(batch_data, batch_lens)
            loss = objective(preds, batch_labels)
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            running_batch += 1
            pbar.set_description(f'[Epoch: {epoch}] | Batch {batch_num} | Loss: {running_loss/running_batch}')
            
    valid_preds, valid_labels = get_preds(model, data, 'valid', 64)
    valid_loss = objective(valid_preds, valid_labels).item()
    if valid_loss < min_valid_loss:
        print(f'Validation loss improved from {min_valid_loss} to {valid_loss}')
        acc = get_accuracy(valid_preds, valid_labels, data['label_encoder'])
        print(f'Validation accuracy: {acc}')
        min_valid_loss = valid_loss
        with open('best_rnn_model.pt', 'wb') as f:
            torch.save(model.state_dict(), f)
    else:
        print('Validation loss did not improve')

[Epoch: 1] | Batch 165 | Loss: 2.920111407655658: : 165it [00:39,  5.18it/s] 
0it [00:00, ?it/s]

Validation loss improved from inf to 2.7040443420410156
Validation accuracy: 0.17213842058562556


[Epoch: 2] | Batch 165 | Loss: 2.7633545702153985: : 165it [00:33,  4.14it/s]
0it [00:00, ?it/s]

Validation loss improved from 2.7040443420410156 to 2.559023380279541
Validation accuracy: 0.16858917480035493


[Epoch: 3] | Batch 165 | Loss: 2.666747970773716: : 165it [00:34,  5.90it/s] 
[Epoch: 4] | Batch 1 | Loss: 2.666240318167594: : 1it [00:00,  5.22it/s]

Validation loss improved from 2.559023380279541 to 2.4815807342529297
Validation accuracy: 0.225377107364685


[Epoch: 4] | Batch 165 | Loss: 2.606426884911277: : 165it [00:28,  6.18it/s] 
[Epoch: 5] | Batch 1 | Loss: 2.605842046045139: : 1it [00:00,  6.79it/s]

Validation loss improved from 2.4815807342529297 to 2.419738531112671
Validation accuracy: 0.2422360248447205


[Epoch: 5] | Batch 165 | Loss: 2.558132241277984: : 165it [00:29,  5.74it/s] 
[Epoch: 6] | Batch 1 | Loss: 2.5577618475399064: : 1it [00:00,  7.11it/s]

Validation loss did not improve


[Epoch: 6] | Batch 165 | Loss: 2.5193893878146856: : 165it [00:25,  7.41it/s]
[Epoch: 7] | Batch 1 | Loss: 2.519496263095998: : 1it [00:00,  6.65it/s]

Validation loss improved from 2.419738531112671 to 2.417797327041626
Validation accuracy: 0.2422360248447205


[Epoch: 7] | Batch 165 | Loss: 2.4860201597213747: : 165it [00:23,  7.82it/s]
[Epoch: 8] | Batch 1 | Loss: 2.485858016051223: : 1it [00:00,  6.50it/s]

Validation loss improved from 2.417797327041626 to 2.3897104263305664
Validation accuracy: 0.25377107364685003


[Epoch: 8] | Batch 165 | Loss: 2.4563843357743638: : 165it [00:23,  7.88it/s]
[Epoch: 9] | Batch 1 | Loss: 2.456153692123477: : 1it [00:00,  7.46it/s]

Validation loss improved from 2.3897104263305664 to 2.3842618465423584
Validation accuracy: 0.23868677905944988


[Epoch: 9] | Batch 165 | Loss: 2.429340716805121: : 165it [00:22,  7.27it/s] 
[Epoch: 10] | Batch 1 | Loss: 2.4291690295188615: : 1it [00:00,  7.33it/s]

Validation loss improved from 2.3842618465423584 to 2.3750364780426025
Validation accuracy: 0.24134871339840283


[Epoch: 10] | Batch 165 | Loss: 2.4051210685209794: : 165it [00:22,  7.88it/s]
[Epoch: 11] | Batch 1 | Loss: 2.4049424775653288: : 1it [00:00,  7.29it/s]

Validation loss improved from 2.3750364780426025 to 2.3606815338134766
Validation accuracy: 0.24933451641526175


[Epoch: 11] | Batch 165 | Loss: 2.3808117205774817: : 165it [00:22,  7.31it/s]
[Epoch: 12] | Batch 1 | Loss: 2.3803689971500557: : 1it [00:00,  7.17it/s]

Validation loss improved from 2.3606815338134766 to 2.3310701847076416
Validation accuracy: 0.2511091393078971


[Epoch: 12] | Batch 165 | Loss: 2.3569025335287805: : 165it [00:22,  8.07it/s]
[Epoch: 13] | Batch 1 | Loss: 2.356614926257319: : 1it [00:00,  6.77it/s]

Validation loss did not improve


[Epoch: 13] | Batch 165 | Loss: 2.3340041033871524: : 165it [00:22,  7.13it/s]
[Epoch: 14] | Batch 1 | Loss: 2.333872556019807: : 1it [00:00,  7.50it/s]

Validation loss did not improve


[Epoch: 14] | Batch 165 | Loss: 2.312129600584765: : 165it [00:22,  6.55it/s] 
[Epoch: 15] | Batch 1 | Loss: 2.3121564007073654: : 1it [00:00,  7.21it/s]

Validation loss did not improve


[Epoch: 15] | Batch 165 | Loss: 2.2916503328265567: : 165it [00:22,  7.43it/s]
[Epoch: 16] | Batch 1 | Loss: 2.2915644715983956: : 1it [00:00,  7.07it/s]

Validation loss improved from 2.3310701847076416 to 2.316545009613037
Validation accuracy: 0.2759538598047915


[Epoch: 16] | Batch 165 | Loss: 2.2702608046658113: : 165it [00:22,  8.04it/s]
[Epoch: 17] | Batch 1 | Loss: 2.2701131935961: : 1it [00:00,  7.37it/s]

Validation loss did not improve


[Epoch: 17] | Batch 165 | Loss: 2.2481484899758866: : 165it [00:22,  7.36it/s]
[Epoch: 18] | Batch 1 | Loss: 2.248009513083816: : 1it [00:00,  7.33it/s]

Validation loss did not improve


[Epoch: 18] | Batch 165 | Loss: 2.2292532864243095: : 165it [00:23,  7.06it/s]
[Epoch: 19] | Batch 1 | Loss: 2.2290625518839273: : 1it [00:00,  7.63it/s]

Validation loss did not improve


[Epoch: 19] | Batch 165 | Loss: 2.2082789791447883: : 165it [00:22,  7.41it/s]
[Epoch: 20] | Batch 1 | Loss: 2.208196036441594: : 1it [00:00,  7.50it/s]

Validation loss did not improve


[Epoch: 20] | Batch 165 | Loss: 2.1877361034624503: : 165it [00:22,  7.40it/s]
[Epoch: 21] | Batch 1 | Loss: 2.187626479972821: : 1it [00:00,  7.40it/s]

Validation loss did not improve


[Epoch: 21] | Batch 165 | Loss: 2.167627999483249: : 165it [00:22,  7.44it/s] 
[Epoch: 22] | Batch 1 | Loss: 2.167524624338607: : 1it [00:00,  7.36it/s]

Validation loss did not improve


[Epoch: 22] | Batch 165 | Loss: 2.1487409524024352: : 165it [00:22,  7.44it/s]
[Epoch: 23] | Batch 1 | Loss: 2.1486266459751575: : 1it [00:00,  7.36it/s]

Validation loss did not improve


[Epoch: 23] | Batch 165 | Loss: 2.131818603778231: : 165it [00:22,  7.46it/s] 
[Epoch: 24] | Batch 1 | Loss: 2.131728010477834: : 1it [00:00,  7.75it/s]

Validation loss did not improve


[Epoch: 24] | Batch 165 | Loss: 2.114759165199116: : 165it [00:22,  7.48it/s] 
[Epoch: 25] | Batch 1 | Loss: 2.114618723424388: : 1it [00:00,  7.85it/s]

Validation loss did not improve


[Epoch: 25] | Batch 165 | Loss: 2.0969165690161966: : 165it [00:22,  7.44it/s]
[Epoch: 26] | Batch 1 | Loss: 2.09675434718592: : 1it [00:00,  7.76it/s]

Validation loss did not improve


[Epoch: 26] | Batch 165 | Loss: 2.079741462349614: : 165it [00:22,  7.44it/s] 
[Epoch: 27] | Batch 1 | Loss: 2.0796430876026673: : 1it [00:00,  7.70it/s]

Validation loss did not improve


[Epoch: 27] | Batch 165 | Loss: 2.063381568751351: : 165it [00:22,  7.45it/s] 
[Epoch: 28] | Batch 1 | Loss: 2.063231390966119: : 1it [00:00,  7.63it/s]

Validation loss did not improve


[Epoch: 28] | Batch 165 | Loss: 2.0464488895037474: : 165it [00:22,  7.46it/s]
[Epoch: 29] | Batch 1 | Loss: 2.046305445449947: : 1it [00:00,  7.54it/s]

Validation loss did not improve


[Epoch: 29] | Batch 165 | Loss: 2.0299948994393384: : 165it [00:22,  7.48it/s]
[Epoch: 30] | Batch 1 | Loss: 2.0299312873485444: : 1it [00:00,  7.74it/s]

Validation loss did not improve


[Epoch: 30] | Batch 165 | Loss: 2.0136594951875284: : 165it [00:22,  7.47it/s]
[Epoch: 31] | Batch 1 | Loss: 2.0134959705693145: : 1it [00:00,  7.78it/s]

Validation loss did not improve


[Epoch: 31] | Batch 165 | Loss: 1.9972610621275206: : 165it [00:22,  7.45it/s]
[Epoch: 32] | Batch 1 | Loss: 1.9971590858241186: : 1it [00:00,  7.84it/s]

Validation loss did not improve


[Epoch: 32] | Batch 165 | Loss: 1.9815271450375969: : 165it [00:22,  8.02it/s]
[Epoch: 33] | Batch 1 | Loss: 1.981434996336659: : 1it [00:00,  7.60it/s]

Validation loss did not improve


[Epoch: 33] | Batch 165 | Loss: 1.966059383905952: : 165it [00:22,  7.80it/s] 
[Epoch: 34] | Batch 1 | Loss: 1.9659537911918554: : 1it [00:00,  7.82it/s]

Validation loss did not improve


[Epoch: 34] | Batch 165 | Loss: 1.9510011244053933: : 165it [00:22,  7.48it/s]
[Epoch: 35] | Batch 1 | Loss: 1.9508938957739377: : 1it [00:00,  7.56it/s]

Validation loss did not improve


[Epoch: 35] | Batch 165 | Loss: 1.936070972263039: : 165it [00:22,  7.41it/s] 
[Epoch: 36] | Batch 1 | Loss: 1.9359387093200413: : 1it [00:00,  7.85it/s]

Validation loss did not improve


[Epoch: 36] | Batch 165 | Loss: 1.9225690801536996: : 165it [00:22,  7.47it/s]
[Epoch: 37] | Batch 1 | Loss: 1.9224923132565903: : 1it [00:00,  7.57it/s]

Validation loss did not improve


[Epoch: 37] | Batch 165 | Loss: 1.9085673969754617: : 165it [00:22,  7.44it/s]
[Epoch: 38] | Batch 1 | Loss: 1.9085338117839703: : 1it [00:00,  7.19it/s]

Validation loss did not improve


[Epoch: 38] | Batch 165 | Loss: 1.8960952035072698: : 165it [00:22,  8.00it/s]
[Epoch: 39] | Batch 1 | Loss: 1.8959889443197084: : 1it [00:00,  7.60it/s]

Validation loss did not improve


[Epoch: 39] | Batch 165 | Loss: 1.882965840835764: : 165it [00:22,  7.47it/s] 
[Epoch: 40] | Batch 1 | Loss: 1.8828946992967943: : 1it [00:00,  7.67it/s]

Validation loss did not improve


[Epoch: 40] | Batch 165 | Loss: 1.869981798538656: : 165it [00:22,  7.23it/s] 
[Epoch: 41] | Batch 1 | Loss: 1.869862216392298: : 1it [00:00,  7.73it/s]

Validation loss did not improve


[Epoch: 41] | Batch 165 | Loss: 1.857338693533134: : 165it [00:22,  7.41it/s] 
[Epoch: 42] | Batch 1 | Loss: 1.8572645176013758: : 1it [00:00,  7.75it/s]

Validation loss did not improve


[Epoch: 42] | Batch 165 | Loss: 1.8449891431335075: : 165it [00:22,  7.38it/s]
[Epoch: 43] | Batch 1 | Loss: 1.844953007881061: : 1it [00:00,  7.54it/s]

Validation loss did not improve


[Epoch: 43] | Batch 165 | Loss: 1.8326601261400353: : 165it [00:22,  7.46it/s]
[Epoch: 44] | Batch 1 | Loss: 1.8325471488723222: : 1it [00:00,  7.53it/s]

Validation loss did not improve


[Epoch: 44] | Batch 165 | Loss: 1.820804473839844: : 165it [00:22,  8.05it/s] 
[Epoch: 45] | Batch 1 | Loss: 1.8207510024075206: : 1it [00:00,  7.49it/s]

Validation loss did not improve


[Epoch: 45] | Batch 165 | Loss: 1.8099340298681548: : 165it [00:22,  7.45it/s]
[Epoch: 46] | Batch 1 | Loss: 1.8098386175671646: : 1it [00:00,  7.35it/s]

Validation loss did not improve


[Epoch: 46] | Batch 165 | Loss: 1.7985597651937733: : 165it [00:22,  7.30it/s]
[Epoch: 47] | Batch 1 | Loss: 1.7984895021795142: : 1it [00:00,  7.69it/s]

Validation loss did not improve


[Epoch: 47] | Batch 165 | Loss: 1.78738700921577: : 165it [00:22,  7.39it/s]  
[Epoch: 48] | Batch 1 | Loss: 1.787311897434913: : 1it [00:00,  7.66it/s]

Validation loss did not improve


[Epoch: 48] | Batch 165 | Loss: 1.778301730210131: : 165it [00:22,  7.47it/s] 
[Epoch: 49] | Batch 1 | Loss: 1.778224898158202: : 1it [00:00,  7.93it/s]

Validation loss did not improve


[Epoch: 49] | Batch 165 | Loss: 1.76729589740509: : 165it [00:22,  8.07it/s]  
[Epoch: 50] | Batch 1 | Loss: 1.7672296616735064: : 1it [00:00,  7.58it/s]

Validation loss did not improve


[Epoch: 50] | Batch 165 | Loss: 1.7564808769153826: : 165it [00:22,  7.44it/s]
[Epoch: 51] | Batch 1 | Loss: 1.7563786091782543: : 1it [00:00,  7.59it/s]

Validation loss did not improve


[Epoch: 51] | Batch 165 | Loss: 1.7459602473820524: : 165it [00:22,  7.44it/s]
[Epoch: 52] | Batch 1 | Loss: 1.7458806281533985: : 1it [00:00,  7.58it/s]

Validation loss did not improve


[Epoch: 52] | Batch 165 | Loss: 1.7355035868349609: : 165it [00:22,  7.45it/s]
[Epoch: 53] | Batch 1 | Loss: 1.7354498857621643: : 1it [00:00,  7.50it/s]

Validation loss did not improve


[Epoch: 53] | Batch 165 | Loss: 1.7253756389337107: : 165it [00:22,  8.00it/s]
[Epoch: 54] | Batch 1 | Loss: 1.725292866784213: : 1it [00:00,  7.72it/s]

Validation loss did not improve


[Epoch: 54] | Batch 165 | Loss: 1.7152688554976256: : 165it [00:22,  7.46it/s]
[Epoch: 55] | Batch 1 | Loss: 1.7151762772898738: : 1it [00:00,  7.39it/s]

Validation loss did not improve


[Epoch: 55] | Batch 165 | Loss: 1.7063412965559106: : 165it [00:22,  7.45it/s]
[Epoch: 56] | Batch 1 | Loss: 1.7062904555769207: : 1it [00:00,  7.75it/s]

Validation loss did not improve


[Epoch: 56] | Batch 165 | Loss: 1.6981247947071538: : 165it [00:22,  7.45it/s]
[Epoch: 57] | Batch 1 | Loss: 1.698061182519314: : 1it [00:00,  7.63it/s]

Validation loss did not improve


[Epoch: 57] | Batch 165 | Loss: 1.6888806663188194: : 165it [00:22,  7.44it/s]
[Epoch: 58] | Batch 1 | Loss: 1.6888418784870027: : 1it [00:00,  7.73it/s]

Validation loss did not improve


[Epoch: 58] | Batch 165 | Loss: 1.6794738707497576: : 165it [00:22,  7.98it/s]
[Epoch: 59] | Batch 1 | Loss: 1.6793699271999523: : 1it [00:00,  7.51it/s]

Validation loss did not improve


[Epoch: 59] | Batch 165 | Loss: 1.670440160058741: : 165it [00:22,  7.35it/s] 
[Epoch: 60] | Batch 1 | Loss: 1.670336339270716: : 1it [00:00,  7.27it/s]

Validation loss did not improve


[Epoch: 60] | Batch 165 | Loss: 1.6620793059317753: : 165it [00:22,  8.08it/s]
[Epoch: 61] | Batch 1 | Loss: 1.66203848580112: : 1it [00:00,  7.75it/s]

Validation loss did not improve


[Epoch: 61] | Batch 165 | Loss: 1.653540717205241: : 165it [00:22,  7.45it/s] 
[Epoch: 62] | Batch 1 | Loss: 1.6535270306039893: : 1it [00:00,  7.61it/s]

Validation loss did not improve


[Epoch: 62] | Batch 165 | Loss: 1.6449542480188615: : 165it [00:22,  7.34it/s]
[Epoch: 63] | Batch 1 | Loss: 1.644902999357282: : 1it [00:00,  7.65it/s]

Validation loss did not improve


[Epoch: 63] | Batch 165 | Loss: 1.6366717882658668: : 165it [00:22,  7.43it/s]
[Epoch: 64] | Batch 1 | Loss: 1.636598360537336: : 1it [00:00,  7.51it/s]

Validation loss did not improve


[Epoch: 64] | Batch 165 | Loss: 1.628527182447865: : 165it [00:22,  7.99it/s] 
[Epoch: 65] | Batch 1 | Loss: 1.628470430922953: : 1it [00:00,  7.56it/s]

Validation loss did not improve


[Epoch: 65] | Batch 165 | Loss: 1.6205005968811907: : 165it [00:22,  7.43it/s]
[Epoch: 66] | Batch 1 | Loss: 1.6204541911223964: : 1it [00:00,  6.13it/s]

Validation loss did not improve


[Epoch: 66] | Batch 165 | Loss: 1.6125389135792454: : 165it [00:22,  7.30it/s]
[Epoch: 67] | Batch 1 | Loss: 1.6125070644622443: : 1it [00:00,  7.58it/s]

Validation loss did not improve


[Epoch: 67] | Batch 165 | Loss: 1.6048126174820219: : 165it [00:22,  7.39it/s]
[Epoch: 68] | Batch 1 | Loss: 1.6047406076964439: : 1it [00:00,  7.67it/s]

Validation loss did not improve


[Epoch: 68] | Batch 165 | Loss: 1.5971147348005292: : 165it [00:22,  7.44it/s]
[Epoch: 69] | Batch 1 | Loss: 1.5970759535949186: : 1it [00:00,  6.67it/s]

Validation loss did not improve


[Epoch: 69] | Batch 165 | Loss: 1.5893980982484763: : 165it [00:21,  7.61it/s]
[Epoch: 70] | Batch 1 | Loss: 1.5893108561810387: : 1it [00:00,  7.86it/s]

Validation loss did not improve


[Epoch: 70] | Batch 165 | Loss: 1.5816218041702783: : 165it [00:21,  7.78it/s]
[Epoch: 71] | Batch 1 | Loss: 1.581570328283017: : 1it [00:00,  7.93it/s]

Validation loss did not improve


[Epoch: 71] | Batch 165 | Loss: 1.574378060297329: : 165it [00:21,  7.82it/s] 
[Epoch: 72] | Batch 1 | Loss: 1.5743028360265399: : 1it [00:00,  7.96it/s]

Validation loss did not improve


[Epoch: 72] | Batch 165 | Loss: 1.5673542351935448: : 165it [00:21,  7.73it/s]
[Epoch: 73] | Batch 1 | Loss: 1.5672981871889151: : 1it [00:00,  8.02it/s]

Validation loss did not improve


[Epoch: 73] | Batch 165 | Loss: 1.5599879439538225: : 165it [00:21,  7.80it/s]
[Epoch: 74] | Batch 1 | Loss: 1.5599264787955673: : 1it [00:00,  7.94it/s]

Validation loss did not improve


[Epoch: 74] | Batch 165 | Loss: 1.554434490650526: : 165it [00:21,  7.84it/s] 
[Epoch: 75] | Batch 1 | Loss: 1.554406454154387: : 1it [00:00,  7.78it/s]

Validation loss did not improve


[Epoch: 75] | Batch 165 | Loss: 1.552157662755311: : 165it [00:21,  7.82it/s] 
[Epoch: 76] | Batch 1 | Loss: 1.5521295880989137: : 1it [00:00,  7.98it/s]

Validation loss did not improve


[Epoch: 76] | Batch 165 | Loss: 1.547211384599764: : 165it [00:21,  6.04it/s] 
[Epoch: 77] | Batch 1 | Loss: 1.5471477542075802: : 1it [00:00,  8.05it/s]

Validation loss did not improve


[Epoch: 77] | Batch 165 | Loss: 1.541786287387775: : 165it [00:21,  7.72it/s] 
[Epoch: 78] | Batch 1 | Loss: 1.541754160824018: : 1it [00:00,  8.14it/s]

Validation loss did not improve


[Epoch: 78] | Batch 165 | Loss: 1.5358262051721832: : 165it [00:21,  7.81it/s]
[Epoch: 79] | Batch 1 | Loss: 1.5357806685573616: : 1it [00:00,  7.89it/s]

Validation loss did not improve


[Epoch: 79] | Batch 165 | Loss: 1.5297539680521746: : 165it [00:21,  7.80it/s]
[Epoch: 80] | Batch 1 | Loss: 1.5297023258730655: : 1it [00:00,  7.85it/s]

Validation loss did not improve


[Epoch: 80] | Batch 165 | Loss: 1.5238705982200125: : 165it [00:21,  7.82it/s]
[Epoch: 81] | Batch 1 | Loss: 1.5238217411303319: : 1it [00:00,  7.81it/s]

Validation loss did not improve


[Epoch: 81] | Batch 165 | Loss: 1.5179247578230415: : 165it [00:21,  7.79it/s]
[Epoch: 82] | Batch 1 | Loss: 1.517871746963126: : 1it [00:00,  7.96it/s]

Validation loss did not improve


[Epoch: 82] | Batch 165 | Loss: 1.5121649922768627: : 165it [00:21,  7.69it/s]
[Epoch: 83] | Batch 1 | Loss: 1.5121560641842842: : 1it [00:00,  7.90it/s]

Validation loss did not improve


[Epoch: 83] | Batch 165 | Loss: 1.506136975143895: : 165it [00:21,  7.78it/s] 
[Epoch: 84] | Batch 1 | Loss: 1.5061145554792796: : 1it [00:00,  7.77it/s]

Validation loss did not improve


[Epoch: 84] | Batch 165 | Loss: 1.5000338952735732: : 165it [00:21,  7.80it/s]
[Epoch: 85] | Batch 1 | Loss: 1.5000279477583314: : 1it [00:00,  7.75it/s]

Validation loss did not improve


[Epoch: 85] | Batch 165 | Loss: 1.4939132704879299: : 165it [00:21,  8.30it/s]
[Epoch: 86] | Batch 1 | Loss: 1.49385755197325: : 1it [00:00,  7.88it/s]

Validation loss did not improve


[Epoch: 86] | Batch 165 | Loss: 1.4880942448459098: : 165it [00:21,  7.80it/s]
[Epoch: 87] | Batch 1 | Loss: 1.4880406736915959: : 1it [00:00,  7.96it/s]

Validation loss did not improve


[Epoch: 87] | Batch 165 | Loss: 1.482752057410665: : 165it [00:21,  7.81it/s] 
[Epoch: 88] | Batch 1 | Loss: 1.4827582842295932: : 1it [00:00,  8.03it/s]

Validation loss did not improve


[Epoch: 88] | Batch 165 | Loss: 1.4775297784115657: : 165it [00:21,  7.77it/s]
[Epoch: 89] | Batch 1 | Loss: 1.4774961207870378: : 1it [00:00,  7.81it/s]

Validation loss did not improve


[Epoch: 89] | Batch 165 | Loss: 1.4720291212784582: : 165it [00:21,  7.73it/s]
[Epoch: 90] | Batch 1 | Loss: 1.4719826451572624: : 1it [00:00,  7.67it/s]

Validation loss did not improve


[Epoch: 90] | Batch 165 | Loss: 1.466448855536554: : 165it [00:21,  7.76it/s] 
[Epoch: 91] | Batch 1 | Loss: 1.466417572264093: : 1it [00:00,  7.72it/s]

Validation loss did not improve


[Epoch: 91] | Batch 165 | Loss: 1.4611330969151837: : 165it [00:21,  7.79it/s]
[Epoch: 92] | Batch 1 | Loss: 1.4610994389726244: : 1it [00:00,  8.01it/s]

Validation loss did not improve


[Epoch: 92] | Batch 165 | Loss: 1.4558628086751628: : 165it [00:21,  7.54it/s]
[Epoch: 93] | Batch 1 | Loss: 1.4558635553849733: : 1it [00:00,  7.90it/s]

Validation loss did not improve


[Epoch: 93] | Batch 165 | Loss: 1.4505900113638546: : 165it [00:21,  7.57it/s]
[Epoch: 94] | Batch 1 | Loss: 1.4505540410756568: : 1it [00:00,  7.27it/s]

Validation loss did not improve


[Epoch: 94] | Batch 165 | Loss: 1.4460338171184792: : 165it [00:22,  7.28it/s]
[Epoch: 95] | Batch 1 | Loss: 1.4460249675879602: : 1it [00:00,  7.86it/s]

Validation loss did not improve


[Epoch: 95] | Batch 165 | Loss: 1.4415285838695995: : 165it [00:21,  7.77it/s]
[Epoch: 96] | Batch 1 | Loss: 1.441493819611022: : 1it [00:00,  8.01it/s]

Validation loss did not improve


[Epoch: 96] | Batch 165 | Loss: 1.4365802163314638: : 165it [00:21,  7.81it/s]
[Epoch: 97] | Batch 1 | Loss: 1.4365426190543165: : 1it [00:00,  7.71it/s]

Validation loss did not improve


[Epoch: 97] | Batch 165 | Loss: 1.4315814364593278: : 165it [00:21,  7.77it/s]
[Epoch: 98] | Batch 1 | Loss: 1.4315496772531418: : 1it [00:00,  8.01it/s]

Validation loss did not improve


[Epoch: 98] | Batch 165 | Loss: 1.4265581943477166: : 165it [00:21,  7.53it/s]
[Epoch: 99] | Batch 1 | Loss: 1.4265506145396971: : 1it [00:00,  7.51it/s]

Validation loss did not improve


[Epoch: 99] | Batch 165 | Loss: 1.421549510065481: : 165it [00:21,  7.68it/s] 
[Epoch: 100] | Batch 1 | Loss: 1.4215212899290324: : 1it [00:00,  8.00it/s]

Validation loss did not improve


[Epoch: 100] | Batch 165 | Loss: 1.4165301036617972: : 165it [00:21,  7.77it/s]
[Epoch: 101] | Batch 1 | Loss: 1.4165118559729901: : 1it [00:00,  7.99it/s]

Validation loss did not improve


[Epoch: 101] | Batch 165 | Loss: 1.411672870718678: : 165it [00:21,  7.75it/s] 
[Epoch: 102] | Batch 1 | Loss: 1.4116329455610284: : 1it [00:00,  7.88it/s]

Validation loss did not improve


[Epoch: 102] | Batch 165 | Loss: 1.4070443191063666: : 165it [00:21,  7.78it/s]
[Epoch: 103] | Batch 1 | Loss: 1.4070177130977934: : 1it [00:00,  7.61it/s]

Validation loss did not improve


[Epoch: 103] | Batch 165 | Loss: 1.4024110711753701: : 165it [00:21,  7.79it/s]
[Epoch: 104] | Batch 1 | Loss: 1.4023765362194214: : 1it [00:00,  7.64it/s]

Validation loss did not improve


[Epoch: 104] | Batch 165 | Loss: 1.3977168294585947: : 165it [00:21,  7.73it/s]
[Epoch: 105] | Batch 1 | Loss: 1.3976743130700293: : 1it [00:00,  7.90it/s]

Validation loss did not improve


[Epoch: 105] | Batch 165 | Loss: 1.3932014446255105: : 165it [00:21,  7.79it/s]
[Epoch: 106] | Batch 1 | Loss: 1.393165172228331: : 1it [00:00,  7.80it/s]

Validation loss did not improve


[Epoch: 106] | Batch 165 | Loss: 1.388596670928514: : 165it [00:21,  7.78it/s] 
[Epoch: 107] | Batch 1 | Loss: 1.388573562788463: : 1it [00:00,  7.82it/s]

Validation loss did not improve


[Epoch: 107] | Batch 165 | Loss: 1.3841804421258628: : 165it [00:21,  7.81it/s]
[Epoch: 108] | Batch 1 | Loss: 1.384164927099415: : 1it [00:00,  7.80it/s]

Validation loss did not improve


[Epoch: 108] | Batch 165 | Loss: 1.38009269714991: : 165it [00:21,  7.75it/s]  
[Epoch: 109] | Batch 1 | Loss: 1.3800688570269257: : 1it [00:00,  7.97it/s]

Validation loss did not improve


[Epoch: 109] | Batch 165 | Loss: 1.3760729330726753: : 165it [00:21,  7.56it/s]
[Epoch: 110] | Batch 1 | Loss: 1.3760372290923573: : 1it [00:00,  7.11it/s]

Validation loss did not improve


[Epoch: 110] | Batch 165 | Loss: 1.3716606187393514: : 165it [00:21,  7.75it/s]
[Epoch: 111] | Batch 1 | Loss: 1.371626319343292: : 1it [00:00,  7.91it/s]

Validation loss did not improve


[Epoch: 111] | Batch 165 | Loss: 1.367299912640087: : 165it [00:21,  6.72it/s] 
[Epoch: 112] | Batch 1 | Loss: 1.3672570129553416: : 1it [00:00,  7.83it/s]

Validation loss did not improve


[Epoch: 112] | Batch 165 | Loss: 1.3630449921053984: : 165it [00:21,  7.70it/s]
[Epoch: 113] | Batch 1 | Loss: 1.3630145578942743: : 1it [00:00,  7.87it/s]

Validation loss did not improve


[Epoch: 113] | Batch 165 | Loss: 1.3587497192546105: : 165it [00:21,  7.74it/s]
[Epoch: 114] | Batch 1 | Loss: 1.3587288650370688: : 1it [00:00,  7.74it/s]

Validation loss did not improve


[Epoch: 114] | Batch 165 | Loss: 1.3549368955666306: : 165it [00:21,  7.73it/s]
[Epoch: 115] | Batch 1 | Loss: 1.3549029950043985: : 1it [00:00,  7.97it/s]

Validation loss did not improve


[Epoch: 115] | Batch 165 | Loss: 1.3508107742694684: : 165it [00:21,  7.74it/s]
[Epoch: 116] | Batch 1 | Loss: 1.35078447772124: : 1it [00:00,  7.87it/s]

Validation loss did not improve


[Epoch: 116] | Batch 165 | Loss: 1.346493605725942: : 165it [00:21,  7.77it/s] 
[Epoch: 117] | Batch 1 | Loss: 1.3464583321790027: : 1it [00:00,  7.86it/s]

Validation loss did not improve


[Epoch: 117] | Batch 165 | Loss: 1.3428550731731308: : 165it [00:21,  7.66it/s]
[Epoch: 118] | Batch 1 | Loss: 1.3428271843999775: : 1it [00:00,  7.78it/s]

Validation loss did not improve


[Epoch: 118] | Batch 165 | Loss: 1.338781065260011: : 165it [00:21,  6.77it/s] 
[Epoch: 119] | Batch 1 | Loss: 1.338759555072656: : 1it [00:00,  6.60it/s]

Validation loss did not improve


[Epoch: 119] | Batch 165 | Loss: 1.3347384366850712: : 165it [00:21,  7.55it/s]
[Epoch: 120] | Batch 1 | Loss: 1.334710666253771: : 1it [00:00,  7.89it/s]

Validation loss did not improve


[Epoch: 120] | Batch 165 | Loss: 1.3306520347763793: : 165it [00:22,  7.47it/s]
[Epoch: 121] | Batch 1 | Loss: 1.3306240499197224: : 1it [00:00,  7.50it/s]

Validation loss did not improve


[Epoch: 121] | Batch 165 | Loss: 1.3267031887331135: : 165it [00:21,  7.70it/s]
[Epoch: 122] | Batch 1 | Loss: 1.3266743452139038: : 1it [00:00,  7.75it/s]

Validation loss did not improve


[Epoch: 122] | Batch 165 | Loss: 1.323846699786731: : 165it [00:21,  7.78it/s] 
[Epoch: 123] | Batch 1 | Loss: 1.3238389876248151: : 1it [00:00,  7.34it/s]

Validation loss did not improve


[Epoch: 123] | Batch 165 | Loss: 1.32094875839462: : 165it [00:21,  7.73it/s]  
[Epoch: 124] | Batch 1 | Loss: 1.3209303974015745: : 1it [00:00,  7.98it/s]

Validation loss did not improve


[Epoch: 124] | Batch 165 | Loss: 1.317348076538606: : 165it [00:21,  7.77it/s] 
[Epoch: 125] | Batch 1 | Loss: 1.3173320845397976: : 1it [00:00,  7.73it/s]

Validation loss did not improve


[Epoch: 125] | Batch 165 | Loss: 1.3136375039794228: : 165it [00:22,  7.50it/s]
[Epoch: 126] | Batch 1 | Loss: 1.313603371323964: : 1it [00:00,  7.91it/s]

Validation loss did not improve


[Epoch: 126] | Batch 161 | Loss: 1.3100012689757183: : 161it [00:24,  6.70it/s]


KeyboardInterrupt: 

In [20]:
preds.max(dim=1)

(tensor([0.9428, 1.2849, 0.8000, 0.7982, 4.9522, 1.0544, 1.2595, 1.0704, 1.5669,
         0.7547], grad_fn=<MaxBackward0>),
 tensor([21,  0, 20, 20, 12,  0, 20,  6, 20, 17]))

In [21]:
batch_labels

tensor([ 1, 18, 16,  9, 11, 16,  7, 13,  1, 10])

In [23]:
data['label_encoder'].classes_[16]

'stand-to-sit'

In [24]:
data['label_encoder'].classes_[17]

'v-cut-left-Lfirst'

In [33]:
preds.max(dim=1)[1].numpy()

array([21,  9, 17, ...,  8,  0,  4])

In [3]:
np.arange(0, 10, 3)

array([0, 3, 6, 9])