In [59]:
import torch
import torch.nn as nn
import torch.optim as optim
from tools import *
from config import Params
from tqdm import tqdm

In [72]:
# Run configuration: pre-processed signal windows and the fold split to use.
args = Params(use_cuda=True, debug=False,
              data=r'./processed_signal/HKU956/956_772_12s_step_4s.pkl',
              batch_size=128,
              spliter=r'./processed_signal/HKU956/956_772_12s_step_4s_spliter5.pkl')

spliter = load_model(args.spliter)
# NOTE: pd.read_pickle can execute arbitrary code; only load locally produced, trusted files.
data = pd.read_pickle(args.data)

# This notebook trains and evaluates on a single fold of the split.
FOLD = 0
folds = list(spliter[args.valid])
if not 0 <= FOLD < len(folds):
    # Fail here with a clear message instead of a later NameError on train_index/test_index.
    raise IndexError('Fold {} requested but spliter[{!r}] has {} fold(s)'.format(FOLD, args.valid, len(folds)))
fold = folds[FOLD]
args.k = FOLD
print("\n" + "=======" * 6 + '[Fold {}]'.format(FOLD), "=======" * 6)
train_index = fold['train_index']
test_index = fold['test_index']

dataprepare = DataPrepare(args,
                          target=args.target,
                          data=data,
                          train_index=train_index,
                          test_index=test_index,
                          device=args.device,
                          batch_size=args.batch_size)
train_dataloader, test_dataloader = dataprepare.get_data()


Target distribution:
1    552
0    404
dtype: int64
(764, 4, 768) (764, 1) (192, 4, 768) (192, 1)
Train target distribution
1    441
0    323
dtype: int64
Test target distribution
1    111
0     81
dtype: int64


In [73]:
class LSTMClassifier(nn.Module):
    """Single-layer LSTM followed by a linear head applied to the last time step.

    Args:
        input_dim: number of features per time step (channels of the input).
        hidden_dim: LSTM hidden size.
        output_dim: number of output logits (classes).
        device: device the caller intends to run on; kept as ``self.device`` for
            compatibility. Initial states are created on the input's own device.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.device = device

    def forward(self, x):
        """Return logits of shape (batch, output_dim).

        ``x`` is expected as (batch, channels, seq_len), matching the (N, 4, 768)
        arrays produced by DataPrepare. It is permuted to (batch, seq_len, channels)
        for the ``batch_first`` LSTM.
        """
        x = x.permute(0, 2, 1)
        # Zero initial states allocated directly on x's device/dtype. They are
        # constants, so no grad flag or CPU->device copy is needed.
        h0 = x.new_zeros(1, x.size(0), self.hidden_dim)
        c0 = x.new_zeros(1, x.size(0), self.hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        # Classify from the hidden output at the final time step.
        return self.fc(out[:, -1, :])

In [74]:
# define the training function
def train(model, optimizer, criterion, train_loader, test_loader, n_epochs):
    for epoch in range(n_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in tqdm(enumerate(train_loader)):
            labels = labels.flatten()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, n_epochs, running_loss / len(train_loader)))

        test(model, criterion, test_loader)

def test(model, criterion, test_loader):
    with torch.no_grad():
        total_correct = 0
        total_samples = 0
        for inputs, labels in test_loader:
            labels = labels.flatten()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()
        accuracy = 100.0 * total_correct / total_samples
        print('Test Accuracy: %.2f%% (%d/%d)' % (accuracy, total_correct, total_samples))

In [75]:
# Model / optimisation hyper-parameters.
input_dim = 4        # channels per time step; matches the (N, 4, 768) arrays printed by DataPrepare
hidden_dim = 32
output_dim = 2       # the target takes values 0 and 1
learning_rate = 0.01
n_epochs = 100
SEED = 42

# Seed torch's global RNG so weight initialisation (and any torch-RNG-based
# shuffling) is reproducible across Restart & Run All.
torch.manual_seed(SEED)

# create the model
model = LSTMClassifier(input_dim, hidden_dim, output_dim, device=args.device)
model = model.to(args.device)

# define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train on the fold loaders built by DataPrepare.
# NOTE(review): in the recorded run, test accuracy stays at 111/192 = 57.81%, which
# equals the share of class 1 in the test split, i.e. the majority-class baseline.
# The model is effectively predicting one class; consider input normalisation, a
# lower learning rate, or a shorter/downsampled sequence before trusting results.
# Assigned so the return value is not echoed as cell output.
history = train(model, optimizer, criterion, train_dataloader, test_dataloader, n_epochs)


6it [00:01,  3.76it/s]


Epoch [1/100], Loss: 0.6820
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.91it/s]


Epoch [2/100], Loss: 0.6820
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.68it/s]


Epoch [3/100], Loss: 0.6805
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.82it/s]


Epoch [4/100], Loss: 0.6806
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.67it/s]


Epoch [5/100], Loss: 0.6802
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.43it/s]


Epoch [6/100], Loss: 0.6797
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.69it/s]


Epoch [7/100], Loss: 0.6796
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.19it/s]


Epoch [8/100], Loss: 0.6791
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.39it/s]


Epoch [9/100], Loss: 0.6795
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.80it/s]


Epoch [10/100], Loss: 0.6787
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.85it/s]


Epoch [11/100], Loss: 0.6792
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.04it/s]


Epoch [12/100], Loss: 0.6784
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.06it/s]


Epoch [13/100], Loss: 0.6791
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.40it/s]


Epoch [14/100], Loss: 0.6781
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.14it/s]


Epoch [15/100], Loss: 0.6790
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.51it/s]


Epoch [16/100], Loss: 0.6778
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.05it/s]


Epoch [17/100], Loss: 0.6787
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.18it/s]


Epoch [18/100], Loss: 0.6775
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.16it/s]


Epoch [19/100], Loss: 0.6782
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.40it/s]


Epoch [20/100], Loss: 0.6831
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.49it/s]


Epoch [21/100], Loss: 0.6781
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.34it/s]


Epoch [22/100], Loss: 0.6792
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.65it/s]


Epoch [23/100], Loss: 0.6779
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.64it/s]


Epoch [24/100], Loss: 0.6778
Test Accuracy: 57.81% (111/192)


6it [00:00, 23.67it/s]


Epoch [25/100], Loss: 0.6769
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.33it/s]


Epoch [26/100], Loss: 0.6804
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.20it/s]


Epoch [27/100], Loss: 0.6789
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.06it/s]


Epoch [28/100], Loss: 0.6781
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.47it/s]


Epoch [29/100], Loss: 0.6779
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.22it/s]


Epoch [30/100], Loss: 0.6778
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.55it/s]


Epoch [31/100], Loss: 0.6776
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.30it/s]


Epoch [32/100], Loss: 0.6779
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.43it/s]


Epoch [33/100], Loss: 0.6779
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.66it/s]


Epoch [34/100], Loss: 0.6775
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.98it/s]


Epoch [35/100], Loss: 0.6773
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.98it/s]


Epoch [36/100], Loss: 0.6772
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.88it/s]


Epoch [37/100], Loss: 0.6770
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.43it/s]


Epoch [38/100], Loss: 0.6768
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.14it/s]


Epoch [39/100], Loss: 0.6766
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.38it/s]


Epoch [40/100], Loss: 0.6781
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.88it/s]


Epoch [41/100], Loss: 0.6779
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.99it/s]


Epoch [42/100], Loss: 0.6761
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.24it/s]


Epoch [43/100], Loss: 0.6759
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.80it/s]


Epoch [44/100], Loss: 0.6747
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.17it/s]


Epoch [45/100], Loss: 0.6752
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.28it/s]


Epoch [46/100], Loss: 0.6851
Test Accuracy: 58.33% (112/192)


6it [00:00, 21.47it/s]


Epoch [47/100], Loss: 0.6823
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.55it/s]


Epoch [48/100], Loss: 0.6813
Test Accuracy: 57.81% (111/192)


6it [00:00, 23.57it/s]


Epoch [49/100], Loss: 0.6800
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.81it/s]


Epoch [50/100], Loss: 0.6798
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.76it/s]


Epoch [51/100], Loss: 0.6792
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.72it/s]


Epoch [52/100], Loss: 0.6789
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.42it/s]


Epoch [53/100], Loss: 0.6784
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.58it/s]


Epoch [54/100], Loss: 0.6782
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.53it/s]


Epoch [55/100], Loss: 0.6784
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.22it/s]


Epoch [56/100], Loss: 0.6780
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.19it/s]


Epoch [57/100], Loss: 0.6776
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.14it/s]


Epoch [58/100], Loss: 0.6777
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.13it/s]


Epoch [59/100], Loss: 0.6777
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.39it/s]


Epoch [60/100], Loss: 0.6769
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.95it/s]


Epoch [61/100], Loss: 0.6773
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.73it/s]


Epoch [62/100], Loss: 0.6767
Test Accuracy: 57.29% (110/192)


6it [00:00, 22.78it/s]


Epoch [63/100], Loss: 0.6768
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.55it/s]


Epoch [64/100], Loss: 0.6787
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.22it/s]


Epoch [65/100], Loss: 0.6772
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.98it/s]


Epoch [66/100], Loss: 0.6775
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.73it/s]


Epoch [67/100], Loss: 0.6768
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.44it/s]


Epoch [68/100], Loss: 0.6771
Test Accuracy: 57.29% (110/192)


6it [00:00, 21.26it/s]


Epoch [69/100], Loss: 0.6769
Test Accuracy: 57.29% (110/192)


6it [00:00, 21.35it/s]


Epoch [70/100], Loss: 0.6764
Test Accuracy: 57.29% (110/192)


6it [00:00, 22.73it/s]


Epoch [71/100], Loss: 0.6775
Test Accuracy: 57.29% (110/192)


6it [00:00, 22.08it/s]


Epoch [72/100], Loss: 0.6783
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.82it/s]


Epoch [73/100], Loss: 0.6775
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.28it/s]


Epoch [74/100], Loss: 0.6778
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.22it/s]


Epoch [75/100], Loss: 0.6772
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.47it/s]


Epoch [76/100], Loss: 0.6766
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.62it/s]


Epoch [77/100], Loss: 0.6762
Test Accuracy: 57.29% (110/192)


6it [00:00, 22.30it/s]


Epoch [78/100], Loss: 0.6758
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.47it/s]


Epoch [79/100], Loss: 0.6780
Test Accuracy: 57.29% (110/192)


6it [00:00, 19.64it/s]


Epoch [80/100], Loss: 0.6763
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.96it/s]


Epoch [81/100], Loss: 0.6758
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.32it/s]


Epoch [82/100], Loss: 0.6757
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.00it/s]


Epoch [83/100], Loss: 0.6753
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.66it/s]


Epoch [84/100], Loss: 0.6757
Test Accuracy: 57.29% (110/192)


6it [00:00, 21.21it/s]


Epoch [85/100], Loss: 0.6749
Test Accuracy: 57.29% (110/192)


6it [00:00, 22.48it/s]


Epoch [86/100], Loss: 0.6745
Test Accuracy: 57.81% (111/192)


6it [00:00, 19.67it/s]


Epoch [87/100], Loss: 0.6746
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.80it/s]


Epoch [88/100], Loss: 0.6749
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.02it/s]


Epoch [89/100], Loss: 0.6737
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.58it/s]


Epoch [90/100], Loss: 0.6843
Test Accuracy: 55.21% (106/192)


6it [00:00, 22.55it/s]


Epoch [91/100], Loss: 0.7793
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.12it/s]


Epoch [92/100], Loss: 0.7630
Test Accuracy: 42.19% (81/192)


6it [00:00, 21.90it/s]


Epoch [93/100], Loss: 0.7177
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.28it/s]


Epoch [94/100], Loss: 0.7112
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.13it/s]


Epoch [95/100], Loss: 0.6993
Test Accuracy: 55.21% (106/192)


6it [00:00, 21.90it/s]


Epoch [96/100], Loss: 0.6846
Test Accuracy: 57.81% (111/192)


6it [00:00, 20.90it/s]


Epoch [97/100], Loss: 0.6884
Test Accuracy: 57.81% (111/192)


6it [00:00, 22.28it/s]


Epoch [98/100], Loss: 0.6830
Test Accuracy: 55.73% (107/192)


6it [00:00, 21.90it/s]


Epoch [99/100], Loss: 0.6831
Test Accuracy: 57.81% (111/192)


6it [00:00, 21.51it/s]

Epoch [100/100], Loss: 0.6821
Test Accuracy: 57.81% (111/192)



