In [1]:
import numpy as np
import torch
from torch import nn
import os
from tqdm.notebook import tqdm

# Define the LSTM model
class LSTMModel(nn.Module):
    """Bidirectional multi-layer LSTM followed by a per-timestep linear head.

    Args:
        input_size: number of features per timestep.
        hidden_size: hidden units per direction of each LSTM layer.
        num_layers: number of stacked LSTM layers.
        output_size: number of output scores produced per timestep.

    ``forward`` accepts either an unbatched ``(seq_len, input_size)`` tensor or a
    batched ``(seq_len, batch, input_size)`` tensor (``batch_first`` is not set)
    and returns scores of shape ``(seq_len, output_size)`` or
    ``(seq_len, batch, output_size)`` respectively.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
        # Forward and backward hidden states are concatenated, hence 2 * hidden_size.
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        # Omitting (h0, c0) makes nn.LSTM start from zero states that match x's
        # device, dtype and batched/unbatched layout automatically. The former
        # hand-built 2-D float32 zeros only worked for unbatched float32 input.
        out, _ = self.lstm(x)
        return self.fc(out)

# Prepare data
def load_npz_dataset(root):
    """Load the ``features``/``labels`` arrays from every ``.npz`` file under *root*.

    Directories and file names are visited in sorted order, so the resulting
    list order is reproducible. That order also fixes the train/test split
    below; ``os.walk`` on its own makes no ordering guarantee.

    Args:
        root: directory searched recursively. A missing directory yields two
            empty lists, because ``os.walk`` silently produces nothing.

    Returns:
        ``(features, labels)``: two parallel lists with one array per file.
    """
    features, labels = [], []
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # sorting in place steers os.walk's descent order
        for filename in sorted(filenames):
            # Only .npz archives expose arrays by name; skip stray files
            # (OS metadata, notes, backups) instead of crashing on them.
            if not filename.endswith('.npz'):
                continue
            # The context manager closes the archive's file handle. Indexing an
            # NpzFile reads the array into memory, so the appended arrays stay valid.
            with np.load(os.path.join(dirpath, filename)) as archive:
                features.append(archive['features'])
                labels.append(archive['labels'])
    return features, labels


X, y = load_npz_dataset('./result')
N = len(y)

# Hold out the first quarter of the (sorted) file list as the test split.
n_test = N // 4
X_test, y_test = X[:n_test], y[:n_test]
X_train, y_train = X[n_test:], y[n_test:]
n = len(y_train)
# print(torch.cuda.is_available())

def train(inputs, outputs, hidden_size=256, input_size=105, output_size=7,
          num_layers=2, num_epochs=100, learning_rate=0.01, batch_size=32,
          device=None):
    """Train a fresh ``LSTMModel`` with one Adam step per sampled sequence.

    Each "epoch" draws ``batch_size`` sequence indices uniformly at random
    (with replacement, via the global NumPy RNG) and performs one optimizer
    step per drawn sequence. It is therefore not a full pass over the data.

    Args:
        inputs: indexable collection of per-sequence feature arrays, each
            convertible to a float32 tensor (presumably
            ``(seq_len, input_size)`` — confirm against the data files).
        outputs: matching collection of per-timestep integer class labels.
        hidden_size, input_size, output_size, num_layers: ``LSTMModel`` sizes.
        num_epochs: number of sampling rounds.
        learning_rate: Adam learning rate.
        batch_size: sequences sampled (and optimizer steps taken) per epoch.
        device: torch device to train on; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The trained model, left on ``device``.

    Raises:
        ValueError: if ``inputs`` is empty or its length differs from ``outputs``.
    """
    n = len(inputs)
    if n == 0:
        raise ValueError('train() needs at least one training sequence')
    if n != len(outputs):
        raise ValueError(f'inputs and outputs differ in length: {n} vs {len(outputs)}')
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Move the model before building the optimizer so it tracks the on-device params.
    model = LSTMModel(input_size, hidden_size, num_layers, output_size).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        total_loss = 0.0
        # randint's upper bound is exclusive: pass n (not n - 1) so the last
        # sequence can be drawn and a single-sequence dataset works.
        indices = np.random.randint(0, n, size=batch_size)
        for i in tqdm(indices):
            features = torch.as_tensor(inputs[i], dtype=torch.float32, device=device)
            targets = torch.as_tensor(outputs[i], dtype=torch.long, device=device)
            optimizer.zero_grad()
            loss = criterion(model(features), targets)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float; summing the tensor itself would keep
            # every step's autograd node alive for the whole epoch.
            total_loss += loss.item()

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}')
    return model


In [2]:
# Fit the LSTM on the training split only; X_test / y_test stay held out.
model = train(inputs=X_train, outputs=y_train)

  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1/100, Loss: 26.21760368347168


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 2/100, Loss: 20.542049407958984


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 3/100, Loss: 21.689367294311523


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 4/100, Loss: 19.155364990234375


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 5/100, Loss: 20.514076232910156


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 6/100, Loss: 20.364412307739258


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 7/100, Loss: 18.978309631347656


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 8/100, Loss: 20.255218505859375


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 9/100, Loss: 17.76738929748535


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 10/100, Loss: 16.102169036865234


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 11/100, Loss: 15.042852401733398


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 12/100, Loss: 16.715681076049805


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 13/100, Loss: 16.876367568969727


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 14/100, Loss: 18.29715347290039


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 15/100, Loss: 15.626832962036133


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 16/100, Loss: 18.182239532470703


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 17/100, Loss: 16.88909149169922


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 18/100, Loss: 14.39378833770752


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 19/100, Loss: 16.220455169677734


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 20/100, Loss: 15.135056495666504


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 21/100, Loss: 15.986533164978027


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 22/100, Loss: 13.620040893554688


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 23/100, Loss: 15.330870628356934


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 24/100, Loss: 14.645596504211426


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 25/100, Loss: 14.118328094482422


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 26/100, Loss: 13.413165092468262


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 27/100, Loss: 14.425512313842773


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 28/100, Loss: 14.006898880004883


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 29/100, Loss: 13.926262855529785


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 30/100, Loss: 15.042328834533691


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 31/100, Loss: 12.267387390136719


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 32/100, Loss: 13.68831729888916


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 33/100, Loss: 12.635119438171387


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 34/100, Loss: 12.96300983428955


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 35/100, Loss: 11.923195838928223


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 36/100, Loss: 12.379673957824707


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 37/100, Loss: 12.515949249267578


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 38/100, Loss: 16.869647979736328


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 39/100, Loss: 12.431291580200195


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 40/100, Loss: 14.307660102844238


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 41/100, Loss: 10.956543922424316


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 42/100, Loss: 12.940834045410156


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 43/100, Loss: 12.011139869689941


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 44/100, Loss: 12.496047019958496


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 45/100, Loss: 11.47098159790039


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 46/100, Loss: 11.951377868652344


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 47/100, Loss: 12.207024574279785


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 48/100, Loss: 14.110847473144531


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 49/100, Loss: 13.56688117980957


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 50/100, Loss: 13.34900951385498


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 51/100, Loss: 12.100984573364258


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 52/100, Loss: 11.335356712341309


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 53/100, Loss: 10.990632057189941


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 54/100, Loss: 10.529277801513672


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 55/100, Loss: 11.068413734436035


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 56/100, Loss: 11.79809284210205


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 57/100, Loss: 13.035368919372559


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 58/100, Loss: 11.03519344329834


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 59/100, Loss: 11.546463012695312


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 60/100, Loss: 9.301698684692383


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 61/100, Loss: 11.27316665649414


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 62/100, Loss: 11.141389846801758


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 63/100, Loss: 10.12055778503418


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 64/100, Loss: 10.572088241577148


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 65/100, Loss: 11.413536071777344


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 66/100, Loss: 11.055182456970215


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 67/100, Loss: 10.730358123779297


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 68/100, Loss: 10.266783714294434


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 69/100, Loss: 9.979137420654297


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 70/100, Loss: 8.80406379699707


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 71/100, Loss: 10.536677360534668


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 72/100, Loss: 12.777175903320312


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 73/100, Loss: 10.602477073669434


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 74/100, Loss: 9.85476016998291


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 75/100, Loss: 10.699928283691406


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 76/100, Loss: 9.240870475769043


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 77/100, Loss: 11.475152969360352


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 78/100, Loss: 9.52474594116211


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 79/100, Loss: 10.80014705657959


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 80/100, Loss: 10.257065773010254


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 81/100, Loss: 13.998160362243652


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 82/100, Loss: 9.847394943237305


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 83/100, Loss: 8.900575637817383


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 84/100, Loss: 9.729482650756836


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 85/100, Loss: 9.807718276977539


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 86/100, Loss: 9.877592086791992


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 87/100, Loss: 9.72604751586914


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 88/100, Loss: 8.783428192138672


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 89/100, Loss: 9.317639350891113


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 90/100, Loss: 8.244608879089355


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 91/100, Loss: 9.742929458618164


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 92/100, Loss: 9.747270584106445


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 93/100, Loss: 11.514961242675781


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 94/100, Loss: 9.975007057189941


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 95/100, Loss: 9.85997486114502


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 96/100, Loss: 11.24404525756836


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 97/100, Loss: 10.01763916015625


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 98/100, Loss: 12.164592742919922


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 99/100, Loss: 11.329961776733398


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 100/100, Loss: 10.103872299194336


In [3]:

import scoring
# Parameters
# NOTE(review): these mirror the values used inside train() and are not read by
# any cell shown here. num_epochs now records the 100 epochs that train()
# actually ran (it previously said 50), so the cell describes the evaluated model.
hidden_size = 256
input_size = 105
output_size = 7
num_layers = 2
num_epochs = 100
learning_rate = 0.01
batch_size = 32
    
def predict(model, inputs):
    """Predict a class index for every timestep of every sequence.

    Args:
        model: model passed through unchanged to ``predict_one``.
        inputs: iterable of per-sequence feature arrays.

    Returns:
        One flat list holding the per-row predictions of all sequences,
        concatenated in iteration order.
    """
    predictions = []
    for sequence in tqdm(inputs):
        predictions.extend(predict_one(model, sequence))
    return predictions

def predict_one(model, input):
    """Return the arg-max class index for each row of one sequence.

    Args:
        model: torch module assumed to map a ``(rows, features)`` float32 tensor
            to ``(rows, classes)`` scores, as ``LSTMModel`` does for unbatched input.
        input: array-like of shape ``(rows, features)``. The name shadows the
            builtin but is kept for keyword compatibility.

    Returns:
        A list of Python ints, one per row. Ties resolve to the lowest class
        index, matching the former strict ``>`` scan.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        scores = model(torch.as_tensor(input, dtype=torch.float32, device=device))
    # One vectorized argmax replaces the per-element Python loop, which synced
    # with the device on every comparison. torch.argmax returns the first maximum.
    return scores.argmax(dim=1).tolist()

# Predict the training split first, then the held-out test split.
y_train_pred, y_test_pred = predict(model, X_train), predict(model, X_test)

  0%|          | 0/115 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

In [4]:
def _flatten_labels(label_arrays):
    """Concatenate per-sequence label arrays into one flat list, preserving order."""
    flat = []
    for labels in label_arrays:
        flat.extend(labels.tolist())
    return flat


# One code path for both splits. The former `for y in ...` loops also rebound
# the global `y` (the label list built while loading the data) to the last
# label array.
for split_name, split_pred, split_labels in (
    ('train', y_train_pred, y_train),
    ('test', y_test_pred, y_test),
):
    print(f'{split_name} score:')
    y_true = _flatten_labels(split_labels)
    # Predictions were produced sequence by sequence in the same order as the
    # labels, so the two flat lists must line up one-to-one.
    assert len(split_pred) == len(y_true), (split_name, len(split_pred), len(y_true))
    # NOTE(review): the argument order (predictions, truth) is kept from the
    # original calls. Confirm it matches scoring.score's signature.
    # The previously reported kappa_score of 1.037 exceeds Cohen's kappa
    # maximum of 1, so the kappa computation in `scoring` likely needs a look.
    scoring.score(split_pred, y_true)

train score:
Class 0: Precision: 0.955, Recall: 0.981
Class 1: Precision: 0.396, Recall: 0.365
Class 2: Precision: 0.792, Recall: 0.860
Class 3: Precision: 0.617, Recall: 0.021
Class 4: Precision: 0.650, Recall: 0.161
Class 5: Precision: 0.847, Recall: 0.625
Class 6: Precision: 0.981, Recall: 0.226
mf1 score: 0.572
acc_score: 0.897
kappa_score: 1.037
test score:
Class 0: Precision: 0.928, Recall: 0.985
Class 1: Precision: 0.110, Recall: 0.242
Class 2: Precision: 0.694, Recall: 0.729
Class 3: Precision: 0.364, Recall: 0.002
Class 4: Precision: 0.311, Recall: 0.012
Class 5: Precision: 0.751, Recall: 0.485
Class 6: Precision: 0.000, Recall: 0.000
mf1 score: 0.395
acc_score: 0.849
kappa_score: 0.708
