In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import os
import numpy as np
from lstm_classification import LstmClassification
import matplotlib.pyplot as plt
label2text = [str(i+1) for i in range(60)]

In [33]:
def read_data(test=False):
    # plt.figure()
    train_X = []
    train_Y = []
    lengths = []
    # #load data
    files = []
    path = "./data/"
    files = os.listdir(path)
    for file in files:
        if ((not test) and int(file.split("_")[2].split(".")[0])>8) or (test and int(file.split("_")[2].split(".")[0])<=8):
            continue
        x = np.loadtxt(path + file, delimiter=',')
        is_useful = np.sum(abs(x) > 10, 1) > 0
        start = np.argmax(is_useful)
        end = len(is_useful) - np.argmax(is_useful[::-1])
        length = end-start
        lengths.append(length)
        pad = [[0,0,0] for i in range(max_length-length)]
        x = np.concatenate((x[start:end], pad))
        train_X.append(x)
        train_Y.append(label2text.index(file.split("_")[1]))

    return np.array(train_X), np.array(train_Y), np.array(lengths)


In [76]:
input_dim = 3
hidden_dim = 128
output_dim = len(label2text)
epoch = 40000    #20000
batch_size = 16
max_length = 150

train_X, train_Y, lengths = read_data()
fabrics_lstm = LstmClassification(input_dim, hidden_dim, output_dim, bidirectional=True).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(fabrics_lstm.parameters(), lr = 0.0001)
fabrics_lstm.train()
optimizer.zero_grad()
for i in range(epoch):
    batch_index = np.array([random.randint(0,len(train_X)-1) for _ in range(batch_size)])
    batch_X = torch.from_numpy(train_X[batch_index]).float().cuda()
    batch_Y = torch.from_numpy(train_Y[batch_index]).long().cuda()
    batch_lengths = torch.from_numpy(lengths[batch_index]).cuda()
    batch_lengths, perm_idx = batch_lengths.sort(0, True)
    batch_X = batch_X[perm_idx]
    batch_Y = batch_Y[perm_idx]

    output = fabrics_lstm(batch_X, batch_lengths)

    loss = criterion(output, batch_Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  
    
    
    print ("epoch:%f  loss:%f" % (i, loss))

    if loss < 1e-4  and i > 1000 or i == epoch - 1:
        torch.save(fabrics_lstm.state_dict(), "./model/fabrics_lstm.pt")
        break
        


epoch:0.000000  loss:4.114791
epoch:1.000000  loss:4.098124
epoch:2.000000  loss:4.096746
epoch:3.000000  loss:4.088993
epoch:4.000000  loss:4.069513
epoch:5.000000  loss:4.090737
epoch:6.000000  loss:4.107424
epoch:7.000000  loss:4.091794
epoch:8.000000  loss:4.101750
epoch:9.000000  loss:4.074938
epoch:10.000000  loss:4.050159
epoch:11.000000  loss:4.071853
epoch:12.000000  loss:4.045362
epoch:13.000000  loss:4.069872
epoch:14.000000  loss:4.088027
epoch:15.000000  loss:4.072928
epoch:16.000000  loss:4.077981
epoch:17.000000  loss:4.088943
epoch:18.000000  loss:4.050716
epoch:19.000000  loss:4.053168
epoch:20.000000  loss:4.029744
epoch:21.000000  loss:4.052345
epoch:22.000000  loss:4.035497
epoch:23.000000  loss:4.014260
epoch:24.000000  loss:4.038762
epoch:25.000000  loss:4.083438
epoch:26.000000  loss:4.055924
epoch:27.000000  loss:4.051844
epoch:28.000000  loss:4.012474
epoch:29.000000  loss:3.992504
epoch:30.000000  loss:4.013988
epoch:31.000000  loss:4.027504
epoch:32.000000  l

In [86]:
fabrics_lstm.eval()
#validation
valida_lengths = torch.from_numpy(lengths).cuda()
valida_lengths, valida_perm_idx = valida_lengths.sort(0, True)
valida__X = torch.from_numpy(train_X).float().cuda()
valida__Y = torch.from_numpy(train_Y).cuda()
valida__X = valida__X[valida_perm_idx]
valida__Y = valida__Y[valida_perm_idx].long()
output = fabrics_lstm(valida__X, valida_lengths)
_, result = output.max(1)
train_accuracy = torch.eq(result,valida__Y).sum().item()*1.0/len(valida__X)

#test
test_X, test_Y, test_lengths = read_data(True)
test_lengths = torch.from_numpy(test_lengths).cuda()
test_lengths, test_perm_idx = test_lengths.sort(0, True)
test__X = torch.from_numpy(test_X).float().cuda()
test__Y = torch.from_numpy(test_Y).cuda()
test__X = test__X[test_perm_idx]
test__Y = test__Y[test_perm_idx].long()
output = fabrics_lstm(test__X, test_lengths)
_, result = output.max(1)
test_accuracy = torch.eq(result,test__Y).sum().item()*1.0/len(test__X)    

print ("training_accuracy:%f%%   test_accuracy:%f%%" % (train_accuracy*100,test_accuracy*100))
wrong_test = test__Y[~torch.eq(result,test__Y)]
wrong_result = result[~torch.eq(result,test__Y)]
for i in range(len(wrong_test)):
    print ("%d --> %d"%(wrong_test[i],wrong_result[i]))

training_accuracy:100.000000%   test_accuracy:99.166667%
4 --> 12
