In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.externals import joblib
from collections import Counter
from scipy import signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
data = joblib.load('data.pkl')

In [3]:
data.keys()

dict_keys(['features', 'labels', 'bounds', 'lens', 'label_encoder'])

In [4]:
# def standard_scale(data):
#     features = data['features']
#     lens = np.concatenate((data['lens']['train'],data['lens']['valid'],data['lens']['test']))
#     features = np.concatenate(features, axis=0)
#     features = StandardScaler().fit_transform(features)
#     ret = []
#     l = 0
#     for t in lens:
#         ret.append(features[l: l + t])
#         l += t
#     features = np.array(ret)
#     data['features'] = features

In [5]:
def filter_emg(x):
    high = 20/(1000/2)
    low = 450/(1000/2)
    b, a = signal.butter(4, [high, low], btype='bandpass')
    emg_filtered = signal.filtfilt(b, a, x, axis=0)
    return emg_filtered

In [6]:
# standard_scale(data)

In [7]:
def get_split(data, split):
    if split == 'train':
        return data['features'][: data['bounds']['train']]
    elif split == 'valid':
        return data['features'][data['bounds']['train']: data['bounds']['train'] + data['bounds']['valid']]
    elif split == 'test':
        return data['features'][data['bounds']['train'] + data['bounds']['valid']:]

In [8]:
def generate_batch_idx(n, batch_size, randomise=False):
    idx = np.arange(0, n)
    if randomise:
        np.random.shuffle(idx)
    for batch_idx in np.arange(0, n, batch_size):
        yield idx[batch_idx:batch_idx+batch_size]

In [9]:
def generate_batches(data, split, batch_size,
                     time_steps=10, stride=5, randomise=False):
    features = get_split(data, split)

    try:
        labels = data['label_encoder'].transform(data['labels'][split])
    except:
        labels = np.zeros(features.shape[0])
    
    lens = data['lens'][split]
    new_features = []
    new_labels = []
    new_lens = []
    for i in range(len(lens)):
        mat = features[i]
        label = labels[i]
        l = lens[i]
        acc_emg = [0, 1, 2, 3, 5, 6, 7, 9, 10, 11]
        mat[:, acc_emg] -= mat[:, acc_emg].mean(axis=0)
        mat[:, :4] = filter_emg(mat[:, :4])
        extracted_steps = []
        for j in range(0, len(mat) - time_steps, stride):
            window = mat[j: j + time_steps, :]
            means = window[:, 4:].mean(axis=0).reshape(1, -1)
            rms = np.sqrt((window[:, :4]**2).mean(axis=0)).reshape(1, -1)
            feature_vector = np.concatenate((rms, means), axis=1).reshape(1, -1)
            
            extracted_steps.append(feature_vector)
        extracted_steps = np.concatenate(extracted_steps, axis=0)
        new_features.append(extracted_steps)
        new_labels.append(label)
        new_lens.append(len(extracted_steps))
    features = np.array(new_features)
    labels = np.array(new_labels)
    lens = np.array(new_lens)
    
    n = len(features)
    for batch_idx in generate_batch_idx(n, batch_size, randomise):
        batch_data = features[batch_idx]
        batch_labels = labels[batch_idx]
        batch_lens = lens[batch_idx]
#         batch_data = torch.from_numpy(batch_data).float()
#         batch_labels = torch.from_numpy(labels[batch_idx]).float()
#         lens = torch.from_numpy(lens[batch_idx]).float()
        yield batch_data, batch_labels, batch_lens

In [10]:
# def sort_batch(batch, targets, lengths):
#     """
#     Sort a minibatch by the length of the sequences with the longest sequences first
#     return the sorted batch targes and sequence lengths.
#     This way the output can be used by pack_padded_sequences(...)
#     """
#     perm_idx = np.argsort(lengths)[::-1]
#     seq_lengths = lengths[perm_idx]
#     seq_tensor = batch[perm_idx]
#     target_tensor = targets[perm_idx]
#     return seq_tensor, target_tensor, seq_lengths

def pad_batch(batch, lens):
    max_len = max(lens)
    batch_size = batch.shape[0]
    num_feature = batch[0].shape[1]
    padded_seqs = np.zeros((batch_size, max_len, num_feature))
    
    for i, l in enumerate(lens):
        padded_seqs[i, :l, :] = batch[i][:l]

    return padded_seqs

In [11]:
def torch_batch(batch, targets):
    return torch.from_numpy(batch).float(), torch.from_numpy(targets).long()

In [12]:
def get_preds(model, data, split, batch_size, time_steps, stride):
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for b_data, b_labels, b_lens in generate_batches(data, split, batch_size, 
                                                         time_steps, stride, False):
            b_data = pad_batch(b_data, b_lens)
            b_data, b_labels = torch_batch(b_data, b_labels)
            preds.append(model(b_data, b_lens))
            labels.append(b_labels)
        preds = torch.cat(preds, dim=0)
        labels = torch.cat(labels, dim=0)
    return preds, labels

In [13]:
def get_accuracy(preds, labels, le):
    preds = preds.max(dim=1)[1].numpy()
    preds = [le.classes_[i] for i in preds]
    labels = labels.numpy()
    labels = [le.classes_[i] for i in labels]
    return accuracy_score(labels, preds)

In [14]:
class HARNet(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.conv1 = nn.Conv1d(19, 64, 8)
        self.conv2 = nn.Conv1d(64, 128, 8)
        self.maxpool1 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(128, 64, 8)
        self.conv4 = nn.Conv1d(64, 64, 4)
#         self.avgpool1 = nn.AvgPool1d(4)
        self.dropout = nn.Dropout(.2)
        self.lin1 = nn.Linear(64*3,22)
        
    def forward(self, data, lens):
        data = data.transpose(1,2)
        x = self.conv1(data)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = torch.cat((x.mean(dim=2), x.min(dim=2)[0], x.max(dim=2)[0]), dim=1)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.lin1(x)
        return x

In [28]:
num_epochs = 10000
batch_size = 32
time_steps = 100
stride = 40
objective = nn.CrossEntropyLoss()
model = HARNet()
optimiser = torch.optim.Adam(model.parameters(), weight_decay=0.0001)
running_loss = 0
running_batch = 0
min_valid_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    with tqdm(enumerate(generate_batches(data, 'train', batch_size,
                                         time_steps, stride, True), 1)) as pbar:
        model.train()
        for batch_num, (batch_data, batch_labels, batch_lens) in pbar:
            batch_data = pad_batch(batch_data, batch_lens)
            batch_data, batch_labels = torch_batch(batch_data, batch_labels)
            optimiser.zero_grad()
            preds = model(batch_data, batch_lens)
            loss = objective(preds, batch_labels)
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            running_batch += 1
            pbar.set_description(f'[Epoch: {epoch}] | Batch {batch_num} | Loss: {running_loss/running_batch}')
            
    valid_preds, valid_labels = get_preds(model, data, 'valid',
                                          64, time_steps, stride)
    
    valid_loss = objective(valid_preds, valid_labels).item()
    
    if valid_loss < min_valid_loss:
        print(f'Validation loss improved from {min_valid_loss} to {valid_loss}')
        acc = get_accuracy(valid_preds, valid_labels, data['label_encoder'])
        print(f'Validation accuracy: {acc}')
        min_valid_loss = valid_loss
        with open('best_cnn_model.pt', 'wb') as f:
            torch.save(model.state_dict(), f)
    else:
        print('Validation loss did not improve')

[Epoch: 1] | Batch 180 | Loss: 15.052666171391804: : 180it [00:38, 13.83it/s]
0it [00:00, ?it/s]

Validation loss improved from inf to 2.8088064193725586
Validation accuracy: 0.1189358372456964


[Epoch: 2] | Batch 180 | Loss: 8.83424322406451: : 180it [00:38,  4.68it/s]  
0it [00:00, ?it/s]

Validation loss improved from 2.8088064193725586 to 2.335017204284668
Validation accuracy: 0.18466353677621283


[Epoch: 3] | Batch 180 | Loss: 6.670256291495429: : 180it [00:42,  4.20it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.335017204284668 to 2.2054688930511475
Validation accuracy: 0.2112676056338028


[Epoch: 4] | Batch 180 | Loss: 5.560101707610819: : 180it [00:39,  4.54it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.2054688930511475 to 2.078712224960327
Validation accuracy: 0.25508607198748046


[Epoch: 5] | Batch 180 | Loss: 4.882208083205753: : 180it [00:38,  4.71it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.078712224960327 to 2.0682897567749023
Validation accuracy: 0.28482003129890454


[Epoch: 6] | Batch 180 | Loss: 4.425819495872215: : 180it [00:46,  3.88it/s] 
0it [00:00, ?it/s]

Validation loss improved from 2.0682897567749023 to 2.0609207153320312
Validation accuracy: 0.25508607198748046


[Epoch: 7] | Batch 180 | Loss: 4.0873875406053335: : 180it [00:47,  3.76it/s]
0it [00:00, ?it/s]

Validation loss improved from 2.0609207153320312 to 1.9338470697402954
Validation accuracy: 0.3348982785602504


[Epoch: 8] | Batch 180 | Loss: 3.8246031049225064: : 180it [00:48,  3.74it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.9338470697402954 to 1.86893892288208
Validation accuracy: 0.3380281690140845


[Epoch: 9] | Batch 180 | Loss: 3.6220894274152355: : 180it [00:58,  5.06it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.86893892288208 to 1.786018967628479
Validation accuracy: 0.37402190923317685


[Epoch: 10] | Batch 180 | Loss: 3.4521975760327446: : 180it [01:22,  4.38it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.786018967628479 to 1.7749454975128174
Validation accuracy: 0.36306729264475746


[Epoch: 11] | Batch 180 | Loss: 3.306398932139079: : 180it [01:19,  4.40it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.7749454975128174 to 1.7451066970825195
Validation accuracy: 0.41471048513302033


[Epoch: 12] | Batch 180 | Loss: 3.176460279138: : 180it [01:19,  4.28it/s]    
0it [00:00, ?it/s]

Validation loss improved from 1.7451066970825195 to 1.6876899003982544
Validation accuracy: 0.41471048513302033


[Epoch: 13] | Batch 180 | Loss: 3.0636081647159705: : 180it [01:13,  4.54it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.6876899003982544 to 1.5839236974716187
Validation accuracy: 0.4522691705790297


[Epoch: 14] | Batch 180 | Loss: 2.965014170465015: : 180it [01:14,  4.81it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.5839236974716187 to 1.5485446453094482
Validation accuracy: 0.4522691705790297


[Epoch: 15] | Batch 180 | Loss: 2.8765870292981464: : 180it [01:13,  4.72it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.5485446453094482 to 1.468976616859436
Validation accuracy: 0.49295774647887325


[Epoch: 16] | Batch 180 | Loss: 2.793602237270938: : 180it [01:11,  4.66it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 17] | Batch 180 | Loss: 2.7215028848133835: : 180it [01:10,  4.84it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.468976616859436 to 1.3766539096832275
Validation accuracy: 0.4835680751173709


[Epoch: 18] | Batch 180 | Loss: 2.6539720512466665: : 180it [01:10,  5.12it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.3766539096832275 to 1.373326063156128
Validation accuracy: 0.4945226917057903


[Epoch: 19] | Batch 180 | Loss: 2.5918031407727136: : 180it [01:06,  5.27it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.373326063156128 to 1.35984206199646
Validation accuracy: 0.49921752738654146


[Epoch: 20] | Batch 180 | Loss: 2.53811689219541: : 180it [01:05,  5.64it/s]  
0it [00:00, ?it/s]

Validation loss improved from 1.35984206199646 to 1.3202146291732788
Validation accuracy: 0.5352112676056338


[Epoch: 21] | Batch 180 | Loss: 2.484073953842991: : 180it [01:09,  5.33it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.3202146291732788 to 1.2600979804992676
Validation accuracy: 0.5492957746478874


[Epoch: 22] | Batch 180 | Loss: 2.4349537151780996: : 180it [01:10,  4.94it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.2600979804992676 to 1.238937497138977
Validation accuracy: 0.5539906103286385


[Epoch: 23] | Batch 180 | Loss: 2.3877262581373757: : 180it [01:21,  3.87it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.238937497138977 to 1.232243299484253
Validation accuracy: 0.5805946791862285


[Epoch: 24] | Batch 180 | Loss: 2.3426841274593717: : 180it [01:18,  4.36it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 25] | Batch 180 | Loss: 2.302709248476558: : 180it [01:14,  4.59it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.232243299484253 to 1.2063792943954468
Validation accuracy: 0.5665101721439749


[Epoch: 26] | Batch 180 | Loss: 2.264157097130759: : 180it [01:14,  4.52it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.2063792943954468 to 1.1855942010879517
Validation accuracy: 0.5696400625978091


[Epoch: 27] | Batch 180 | Loss: 2.2259029984228897: : 180it [01:12,  4.47it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.1855942010879517 to 1.1633621454238892
Validation accuracy: 0.5508607198748043


[Epoch: 28] | Batch 180 | Loss: 2.191855466129288: : 180it [01:11,  4.55it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.1633621454238892 to 1.159451961517334
Validation accuracy: 0.5712050078247262


[Epoch: 29] | Batch 180 | Loss: 2.1587350577230198: : 180it [01:14,  4.18it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 30] | Batch 180 | Loss: 2.12684331084843: : 180it [01:16,  4.26it/s]  
0it [00:00, ?it/s]

Validation loss improved from 1.159451961517334 to 1.085875153541565
Validation accuracy: 0.594679186228482


[Epoch: 31] | Batch 180 | Loss: 2.097572628208386: : 180it [01:17,  4.51it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.085875153541565 to 1.0731785297393799
Validation accuracy: 0.6165884194053208


[Epoch: 32] | Batch 180 | Loss: 2.067651619638006: : 180it [01:02,  5.24it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 33] | Batch 180 | Loss: 2.0397774159406574: : 180it [01:03,  5.39it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 34] | Batch 180 | Loss: 2.013298524127287: : 180it [01:00,  5.53it/s] 
0it [00:00, ?it/s]

Validation loss improved from 1.0731785297393799 to 1.0509331226348877
Validation accuracy: 0.6071987480438185


[Epoch: 35] | Batch 180 | Loss: 1.9874683843340193: : 180it [01:02,  4.63it/s]
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 36] | Batch 180 | Loss: 1.963235319028666: : 180it [01:04,  5.48it/s] 
0it [00:00, ?it/s]

Validation loss did not improve


[Epoch: 37] | Batch 180 | Loss: 1.9392913635697093: : 180it [01:09,  4.63it/s]
0it [00:00, ?it/s]

Validation loss improved from 1.0509331226348877 to 1.0128012895584106
Validation accuracy: 0.6118935837245696


[Epoch: 38] | Batch 180 | Loss: 1.9166010909610325: : 180it [01:07,  4.03it/s]
0it [00:00, ?it/s]

Validation loss did not improve





KeyboardInterrupt: 

In [18]:
model.load_state_dict(torch.load('best_cnn_model.pt'))

In [19]:
train_preds, train_labels = get_preds(model, data, 'train',
                                          64, time_steps, stride)

In [20]:
train_loss = objective(train_preds, train_labels).item()
train_loss

0.1921447217464447

In [21]:
acc = get_accuracy(train_preds, train_labels, data['label_encoder'])
acc

0.9416985729202924

In [22]:
valid_preds, valid_labels = get_preds(model, data, 'valid',
                                          64, time_steps, stride)

In [23]:
valid_loss = objective(valid_preds, valid_labels).item()
valid_loss

0.27615249156951904

In [24]:
acc = get_accuracy(valid_preds, valid_labels, data['label_encoder'])
acc

0.9248826291079812