In [7]:
import numpy as np
import torch
from torch import nn
import os
from tqdm.notebook import tqdm

In [8]:
# Define the LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers * 2, self.hidden_size).to(x.device) 
        c0 = torch.zeros(self.num_layers * 2, self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out)
        return out


In [9]:
# Prepare data
X = []
y = []
for dirname, _, filenames in os.walk('./result'):
    for filename in filenames:
        file_path = os.path.join(dirname, filename)
        file = np.load(file_path)
        X.append(file['features'])
        y.append(file['labels'])
N = len(y)

X_test = X[0:N//4]
y_test = y[0:N//4]

X_train = X[N//4:N]
y_train = y[N//4:N]
n = len(y_train)
# print(torch.cuda.is_available())

In [10]:

def train(inputs, outputs):
    # Parameters
    hidden_size = 256
    input_size = 105
    output_size = 7
    num_layers = 2
    num_epochs = 100
    learning_rate = 0.01
    batch_size = 32

    # Instantiate the model
    model = LSTMModel(input_size, hidden_size, num_layers, output_size)

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model = model.to(device='cuda')
    criterion = criterion.to(device='cuda')

    n = len(inputs)
    # Training loop
    for epoch in range(num_epochs):
        total_loss=0
        slices = np.random.randint(0,n-1,(batch_size,))
        for j in tqdm(range(batch_size)):
            i = slices[j]
            input = torch.tensor(inputs[i], dtype=torch.float32).to(device='cuda')
            output = torch.tensor(outputs[i], dtype=torch.long).to(device='cuda')
            output_pred = model(input)
            optimizer.zero_grad()
            loss = criterion(output_pred, output)
            total_loss += loss
            loss.backward()
            optimizer.step()

        # if (epoch+1) % 10 == 0:
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss.item()}')
    return model


In [11]:
model = train(X_train, y_train)

  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1/100, Loss: 25.896987915039062


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 2/100, Loss: 23.696809768676758


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 3/100, Loss: 18.68931007385254


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 4/100, Loss: 19.201263427734375


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 5/100, Loss: 17.947114944458008


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 6/100, Loss: 17.619916915893555


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 7/100, Loss: 15.686386108398438


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 8/100, Loss: 19.294219970703125


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 9/100, Loss: 14.474236488342285


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 10/100, Loss: 16.77131462097168


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 11/100, Loss: 18.2420597076416


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 12/100, Loss: 18.23877716064453


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 13/100, Loss: 14.372781753540039


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 14/100, Loss: 14.732131004333496


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 15/100, Loss: 15.532395362854004


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 16/100, Loss: 13.538240432739258


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 17/100, Loss: 15.730769157409668


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 18/100, Loss: 16.739091873168945


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 19/100, Loss: 14.220458984375


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 20/100, Loss: 14.59953784942627


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 21/100, Loss: 15.775411605834961


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 22/100, Loss: 14.186503410339355


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 23/100, Loss: 12.718279838562012


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 24/100, Loss: 13.198070526123047


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 25/100, Loss: 14.160350799560547


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 26/100, Loss: 12.696510314941406


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 27/100, Loss: 14.810194969177246


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 28/100, Loss: 12.526734352111816


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 29/100, Loss: 13.474266052246094


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 30/100, Loss: 14.682963371276855


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 31/100, Loss: 13.40316390991211


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 32/100, Loss: 11.401834487915039


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 33/100, Loss: 11.912910461425781


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 34/100, Loss: 11.081364631652832


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 35/100, Loss: 13.592769622802734


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 36/100, Loss: 12.126814842224121


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 37/100, Loss: 12.392843246459961


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 38/100, Loss: 12.91526985168457


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 39/100, Loss: 13.156179428100586


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 40/100, Loss: 12.641663551330566


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 41/100, Loss: 13.074846267700195


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 42/100, Loss: 12.694363594055176


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 43/100, Loss: 12.871809959411621


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 44/100, Loss: 13.969071388244629


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 45/100, Loss: 10.540632247924805


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 46/100, Loss: 10.86459732055664


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 47/100, Loss: 11.0736665725708


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 48/100, Loss: 11.714128494262695


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 49/100, Loss: 13.21635913848877


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 50/100, Loss: 10.638955116271973


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 51/100, Loss: 12.50018310546875


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 52/100, Loss: 11.525776863098145


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 53/100, Loss: 11.476055145263672


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 54/100, Loss: 9.706205368041992


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 55/100, Loss: 12.775014877319336


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 56/100, Loss: 11.374446868896484


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 57/100, Loss: 10.439889907836914


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 58/100, Loss: 9.786698341369629


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 59/100, Loss: 11.005970001220703


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 60/100, Loss: 10.535987854003906


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 61/100, Loss: 11.554068565368652


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 62/100, Loss: 9.297061920166016


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 63/100, Loss: 10.35595989227295


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 64/100, Loss: 10.576927185058594


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 65/100, Loss: 10.29020881652832


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 66/100, Loss: 11.605777740478516


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 67/100, Loss: 10.140684127807617


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 68/100, Loss: 10.171228408813477


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 69/100, Loss: 10.704283714294434


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 70/100, Loss: 13.180298805236816


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 71/100, Loss: 10.427045822143555


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 72/100, Loss: 10.945908546447754


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 73/100, Loss: 10.546229362487793


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 74/100, Loss: 9.227752685546875


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 75/100, Loss: 9.98647403717041


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 76/100, Loss: 9.55467414855957


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 77/100, Loss: 10.266983985900879


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 78/100, Loss: 8.527851104736328


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 79/100, Loss: 9.542495727539062


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 80/100, Loss: 10.754571914672852


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 81/100, Loss: 10.44175910949707


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 82/100, Loss: 10.070686340332031


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 83/100, Loss: 9.525789260864258


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 84/100, Loss: 8.882603645324707


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 85/100, Loss: 8.678709030151367


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 86/100, Loss: 9.17988395690918


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 87/100, Loss: 9.00948429107666


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 88/100, Loss: 9.149613380432129


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 89/100, Loss: 9.546273231506348


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 90/100, Loss: 9.539321899414062


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 91/100, Loss: 8.7986421585083


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 92/100, Loss: 8.44086742401123


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 93/100, Loss: 7.99735164642334


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 94/100, Loss: 7.994756698608398


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 95/100, Loss: 8.630611419677734


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 96/100, Loss: 8.077882766723633


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 97/100, Loss: 7.2956132888793945


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 98/100, Loss: 7.92908239364624


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 99/100, Loss: 8.804515838623047


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 100/100, Loss: 8.215948104858398


In [23]:
import scoring

def predict(model, inputs):
    y = []
    for input in tqdm(inputs):
        y += predict_one(model, input)
        # print(y)
    return y

def predict_one(model, input):
    output_pred = model(torch.tensor(input, dtype=torch.float32).to(device='cuda'))
    N = len(input)
    y = []
    for i in range(N):
        mx = output_pred[i][0]
        p = 0
        for j in range(1,output_pred.shape[1]):
            if output_pred[i][j] > mx:
                mx = output_pred[i][j]
                p = j
        y.append(p)
    return y


y_pred = predict(model, X_train)


  0%|          | 0/115 [00:00<?, ?it/s]

train score:


TypeError: only length-1 arrays can be converted to Python scalars

In [30]:
print('train score:')
y_true = []
for y in y_train:
    y_true += y.tolist()
print(len(y_pred))
print(len(y_true))
scoring.score(y_pred, y_true)


train score:
293884
293884
Class 0: Precision: 0.959, Recall: 0.988
Class 1: Precision: 0.519, Recall: 0.334
Class 2: Precision: 0.821, Recall: 0.857
Class 3: Precision: 0.628, Recall: 0.131
Class 4: Precision: 0.653, Recall: 0.440
Class 5: Precision: 0.825, Recall: 0.797
Class 6: Precision: 0.890, Recall: 0.565
mf1 score: 0.661
acc_score: 0.913
kappa_score: 0.983


  pe += np.dot(cfm[:,i], cfm[i,:])


  0%|          | 0/38 [00:00<?, ?it/s]

test score:


ValueError: operands could not be broadcast together with shapes (0,) (2497,) 

In [None]:
y_pred = predict(model, X_test)


In [31]:
print('test score:')
y_true = []
for y in y_test:
    y_true += y.tolist()
scoring.score(y_pred, y_true)

test score:
Class 0: Precision: 0.907, Recall: 0.987
Class 1: Precision: 0.134, Recall: 0.153
Class 2: Precision: 0.702, Recall: 0.687
Class 3: Precision: 0.255, Recall: 0.018
Class 4: Precision: 0.225, Recall: 0.024
Class 5: Precision: 0.694, Recall: 0.529
Class 6: Precision: 0.000, Recall: 0.000
mf1 score: nan
acc_score: 0.846
kappa_score: 0.697


  pre += TP / (TP + FP)
