In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.externals import joblib
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
data = joblib.load('data.pkl')

In [3]:
data_dir = 'bbdc_2019_Bewegungsdaten/'

In [4]:
data.keys()

dict_keys(['features', 'labels', 'bounds', 'lens', 'label_encoder'])

In [5]:
def standard_scale(data):
    features = data['features']
    lens = np.concatenate((data['lens']['train'],data['lens']['valid'],data['lens']['test']))
    features = np.concatenate(features, axis=0)
    features = StandardScaler().fit_transform(features)
    ret = []
    l = 0
    for t in lens:
        ret.append(features[l: l + t])
        l += t
    features = np.array(ret)
    data['features'] = features

In [6]:
standard_scale(data)

In [7]:
def get_split(data, split):
    if split == 'train':
        return data['features'][: data['bounds']['train']]
    elif split == 'valid':
        return data['features'][data['bounds']['train']: data['bounds']['train'] + data['bounds']['valid']]
    elif split == 'test':
        return data['features'][data['bounds']['train'] + data['bounds']['valid']:]

In [8]:
def generate_batch_idx(n, batch_size, randomise=False):
    idx = np.arange(0, n)
    if randomise:
        np.random.shuffle(idx)
    for batch_idx in np.arange(0, n, batch_size):
        yield idx[batch_idx:batch_idx+batch_size]

In [9]:
def generate_batches(data, split, batch_size,
                     skip=5, randomise=False):
    features = get_split(data, split)

    try:
        labels = data['label_encoder'].transform(data['labels'][split])
    except:
        labels = np.zeros(features.shape[0])
    
    lens = data['lens'][split]
    new_features = []
    new_labels = []
    new_lens = []
    for i in range(len(lens)):
        mat = features[i]
        label = labels[i]
        l = lens[i]
        idx = np.arange(0, mat.shape[0], skip)
        new_features.append(mat[idx,:])
        new_labels.append(label)
        new_lens.append(idx.shape[0])
    
    features = np.array(new_features)
    labels = np.array(new_labels)
    lens = np.array(new_lens)
    
    n = len(features)
    for batch_idx in generate_batch_idx(n, batch_size, randomise):
        batch_data = features[batch_idx]
        batch_labels = labels[batch_idx]
        batch_lens = lens[batch_idx]
        yield batch_data, batch_labels, batch_lens

In [10]:
# def sort_batch(batch, targets, lengths):
#     """
#     Sort a minibatch by the length of the sequences with the longest sequences first
#     return the sorted batch targes and sequence lengths.
#     This way the output can be used by pack_padded_sequences(...)
#     """
#     perm_idx = np.argsort(lengths)[::-1]
#     seq_lengths = lengths[perm_idx]
#     seq_tensor = batch[perm_idx]
#     target_tensor = targets[perm_idx]
#     return seq_tensor, target_tensor, seq_lengths

def pad_batch(batch, lens):
    max_len = max(lens)
    batch_size = batch.shape[0]
    num_feature = batch[0].shape[1]
    padded_seqs = np.zeros((batch_size, max_len, num_feature))
    
    for i, l in enumerate(lens):
        padded_seqs[i, :l, :] = batch[i][:l]

    return padded_seqs

In [11]:
def torch_batch(batch, targets):
    return torch.from_numpy(batch).float(), torch.from_numpy(targets).long()

In [12]:
def get_preds(model, data, split, batch_size, stride):
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for b_data, b_labels, b_lens in generate_batches(data, split, batch_size, 
                                                         stride, False):
            b_data = pad_batch(b_data, b_lens)
            b_data, b_labels = torch_batch(b_data, b_labels)
            preds.append(model(b_data, b_lens))
            labels.append(b_labels)
    preds = torch.cat(preds, dim=0)
    labels = torch.cat(labels, dim=0)
    return preds, labels

In [13]:
def get_accuracy(preds, labels, le):
    preds = preds.max(dim=1)[1].numpy()
    preds = [le.classes_[i] for i in preds]
    labels = labels.numpy()
    labels = [le.classes_[i] for i in labels]
    return accuracy_score(labels, preds)

In [37]:
class HARNet(nn.Module):
    def __init__(self, time_steps):
        super().__init__()
        self.rnn = nn.GRU(14, 64)
        self.lin2 = nn.Linear(64, 64, bias=True)
        self.lin3 = nn.Linear(64, 22, bias=True)
    
    def forward(self, data, lens):
        x, _ = self.rnn(data[:, :, 5:])
        x = torch.cat([seq[l-1].view(1, -1) for l, seq in zip(lens, x)], dim=0)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return x

In [39]:
num_epochs = 10000
batch_size = 32
stride = 5
objective = nn.CrossEntropyLoss()
model = HARNet(time_steps)
optimiser = torch.optim.Adam(model.parameters(), weight_decay=0.0)
running_loss = 0
running_batch = 0
min_valid_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    with tqdm(enumerate(generate_batches(data, 'valid', batch_size,
                                         stride, True), 1)) as pbar:
        model.train()
        for batch_num, (batch_data, batch_labels, batch_lens) in pbar:
            batch_data = pad_batch(batch_data, batch_lens)
            batch_data, batch_labels = torch_batch(batch_data, batch_labels)
            optimiser.zero_grad()
            preds = model(batch_data, batch_lens)
            loss = objective(preds, batch_labels)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2, norm_type=2)
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            running_batch += 1
            pbar.set_description(f'[Epoch: {epoch}] | Batch {batch_num} | Loss: {running_loss/running_batch}')
            
    valid_preds, valid_labels = get_preds(model, data, 'valid',
                                          64, stride)
    
    valid_loss = objective(valid_preds, valid_labels).item()
    
    if valid_loss < min_valid_loss:
        print(f'Validation loss improved from {min_valid_loss} to {valid_loss}')
        acc = get_accuracy(valid_preds, valid_labels, data['label_encoder'])
        print(f'Validation accuracy: {acc}')
        min_valid_loss = valid_loss
        with open('best_rnn_model.pt', 'wb') as f:
            torch.save(model.state_dict(), f)
    else:
        print('Validation loss did not improve')

[Epoch: 1] | Batch 36 | Loss: 3.073989040321774: : 36it [00:19,  1.89it/s] 
0it [00:00, ?it/s]

Validation loss improved from inf to 3.029745101928711
Validation accuracy: 0.09849157054125998


[Epoch: 2] | Batch 36 | Loss: 3.0157430701785617: : 36it [00:18,  1.99it/s]
0it [00:00, ?it/s]

Validation loss improved from 3.029745101928711 to 2.8421804904937744
Validation accuracy: 0.12244897959183673


[Epoch: 3] | Batch 36 | Loss: 2.938975621152807: : 36it [00:17,  2.02it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.8421804904937744 to 2.69958758354187
Validation accuracy: 0.14995563442768411


[Epoch: 4] | Batch 36 | Loss: 2.868657206495603: : 36it [00:18,  1.97it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.69958758354187 to 2.505964994430542
Validation accuracy: 0.2067435669920142


[Epoch: 5] | Batch 36 | Loss: 2.7836705300543043: : 36it [00:18,  1.99it/s]
0it [00:00, ?it/s]

Validation loss improved from 2.505964994430542 to 2.296405076980591
Validation accuracy: 0.2573203194321207


[Epoch: 6] | Batch 36 | Loss: 2.7050909869096897: : 36it [00:18,  1.94it/s]
0it [00:00, ?it/s]

Validation loss improved from 2.296405076980591 to 2.1602771282196045
Validation accuracy: 0.25643300798580304


[Epoch: 7] | Batch 36 | Loss: 2.6336016451555584: : 36it [00:19,  1.81it/s]
0it [00:00, ?it/s]

Validation loss improved from 2.1602771282196045 to 2.0873987674713135
Validation accuracy: 0.29458740017746227


[Epoch: 8] | Batch 36 | Loss: 2.572827105720838: : 36it [00:18,  1.93it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.0873987674713135 to 2.052206039428711
Validation accuracy: 0.29991126885536823


[Epoch: 9] | Batch 36 | Loss: 2.5207819213837754: : 36it [00:18,  1.98it/s]
0it [00:00, ?it/s]

Validation loss improved from 2.052206039428711 to 1.983601450920105
Validation accuracy: 0.3132209405501331


[Epoch: 10] | Batch 36 | Loss: 2.4739689336882695: : 36it [00:18,  1.92it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 11] | Batch 36 | Loss: 2.4336269732677573: : 36it [00:20,  1.79it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.983601450920105 to 1.9380539655685425
Validation accuracy: 0.3531499556344277


[Epoch: 12] | Batch 36 | Loss: 2.399449874681455: : 36it [00:18,  1.94it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 13] | Batch 36 | Loss: 2.3679552536744337: : 36it [00:18,  1.96it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.9380539655685425 to 1.8757035732269287
Validation accuracy: 0.3744454303460515


[Epoch: 14] | Batch 36 | Loss: 2.337828344768948: : 36it [00:18,  1.92it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.8757035732269287 to 1.8289741277694702
Validation accuracy: 0.3824312333629104


[Epoch: 15] | Batch 36 | Loss: 2.3111490746339163: : 36it [00:18,  1.96it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.8289741277694702 to 1.7901771068572998
Validation accuracy: 0.4010647737355812


[Epoch: 16] | Batch 36 | Loss: 2.286141916695568: : 36it [00:18,  1.99it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 17] | Batch 36 | Loss: 2.262679155550751: : 36it [00:18,  1.99it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.7901771068572998 to 1.7365647554397583
Validation accuracy: 0.43744454303460517


[Epoch: 18] | Batch 36 | Loss: 2.238934001988835: : 36it [00:18,  1.94it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.7365647554397583 to 1.711935043334961
Validation accuracy: 0.4321206743566992


[Epoch: 19] | Batch 36 | Loss: 2.2150668853904767: : 36it [00:17,  2.01it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.711935043334961 to 1.6605236530303955
Validation accuracy: 0.4418811002661934


[Epoch: 20] | Batch 36 | Loss: 2.1933689494927724: : 36it [00:18,  1.94it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.6605236530303955 to 1.6445105075836182
Validation accuracy: 0.45075421472937


[Epoch: 21] | Batch 36 | Loss: 2.172361069255405: : 36it [00:18,  1.97it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.6445105075836182 to 1.619515299797058
Validation accuracy: 0.45430346051464066


[Epoch: 22] | Batch 36 | Loss: 2.150661195468421: : 36it [00:19,  1.85it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 23] | Batch 36 | Loss: 2.1297866385340116: : 36it [00:17,  2.01it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.619515299797058 to 1.5414314270019531
Validation accuracy: 0.4862466725820763


[Epoch: 24] | Batch 36 | Loss: 2.110472594284349: : 36it [00:18,  2.58it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 25] | Batch 36 | Loss: 2.09272279686398: : 36it [00:18,  1.91it/s]  
0it [00:00, ?it/s]

Validation loss improved from 1.5414314270019531 to 1.492256999015808
Validation accuracy: 0.5004436557231589


[Epoch: 26] | Batch 36 | Loss: 2.0750573822575755: : 36it [00:18,  1.96it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.492256999015808 to 1.4650413990020752
Validation accuracy: 0.5057675244010648


[Epoch: 27] | Batch 36 | Loss: 2.057118269764347: : 36it [00:17,  2.01it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 28] | Batch 36 | Loss: 2.038963364939841: : 36it [00:19,  1.84it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.4650413990020752 to 1.4376190900802612
Validation accuracy: 0.5022182786157942


[Epoch: 29] | Batch 36 | Loss: 2.022542560351763: : 36it [00:17,  2.01it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.4376190900802612 to 1.4093499183654785
Validation accuracy: 0.5164152617568767


[Epoch: 30] | Batch 36 | Loss: 2.0059462621256157: : 36it [00:18,  1.91it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.4093499183654785 to 1.3820074796676636
Validation accuracy: 0.5199645075421473


[Epoch: 31] | Batch 36 | Loss: 1.9897874696280367: : 36it [00:18,  1.96it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.3820074796676636 to 1.3562664985656738
Validation accuracy: 0.5368234250221828


[Epoch: 32] | Batch 36 | Loss: 1.9731516257549326: : 36it [00:18,  1.99it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.3562664985656738 to 1.3495666980743408
Validation accuracy: 0.5217391304347826


[Epoch: 33] | Batch 36 | Loss: 1.9593594096325062: : 36it [00:18,  1.97it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 34] | Batch 36 | Loss: 1.9454453556366216: : 36it [00:17,  2.04it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.3495666980743408 to 1.3168466091156006
Validation accuracy: 0.5536823425022183


[Epoch: 35] | Batch 36 | Loss: 1.9318115951522947: : 36it [00:18,  1.98it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.3168466091156006 to 1.3043681383132935
Validation accuracy: 0.5501330967169477


[Epoch: 36] | Batch 36 | Loss: 1.9188400379577537: : 36it [00:18,  1.99it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.3043681383132935 to 1.2966817617416382
Validation accuracy: 0.5563442768411713


[Epoch: 37] | Batch 36 | Loss: 1.9043906098818995: : 36it [00:18,  1.96it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 38] | Batch 36 | Loss: 1.8917683393895974: : 36it [00:18,  2.00it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.2966817617416382 to 1.26130211353302
Validation accuracy: 0.5634427684117125


[Epoch: 39] | Batch 36 | Loss: 1.8781211357510668: : 36it [00:18,  2.00it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 40] | Batch 36 | Loss: 1.8647837036185795: : 36it [00:18,  1.96it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.26130211353302 to 1.2363808155059814
Validation accuracy: 0.577639751552795


[Epoch: 41] | Batch 36 | Loss: 1.8517591049516104: : 36it [00:22,  1.33it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.2363808155059814 to 1.2127554416656494
Validation accuracy: 0.5811889973380656


[Epoch: 42] | Batch 36 | Loss: 1.8399063578200718: : 36it [00:28,  1.26it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 43] | Batch 36 | Loss: 1.8293299906013547: : 36it [00:36,  1.34it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 44] | Batch 36 | Loss: 1.8179723211009093: : 36it [00:27,  1.71it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.2127554416656494 to 1.175963044166565
Validation accuracy: 0.5980479148181012


[Epoch: 45] | Batch 9 | Loss: 1.814711975849281: : 9it [00:05,  1.49it/s] 


KeyboardInterrupt: 

In [None]:
np.arange(0, 6000, 10).shape