In [1]:
from argparse import ArgumentParser
from torchmetrics.functional import auc, mean_squared_error
from torchmetrics import F1Score
from tools import *
from CONSTANT import *
from models import CNNBiLSTM, Transformer
from config import Params
from torch.utils.data import (
    TensorDataset, DataLoader, SequentialSampler, WeightedRandomSampler)
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import math
import time
import datetime
import os
import sys
import time
import warnings
from copy import deepcopy

In [2]:
args = Params(debug=True)
spliter = load_model(
    r'./processed_signal/HKU956/400_4s_step_2s_spliter.pkl')
# data = pd.read_pickle(r'./processed_signal/HKU956/400_4s_step_2s.pkl')
data = pd.read_csv(r'./processed_signal/HKU956/400_4s_step_2s.csv')
for i, k in enumerate(spliter[args.valid]):
    args.k = i
    print('[Fold {}]'.format(i), '='*31)
    train_index = k['train_index']
    test_index = k['test_index']
    break



In [11]:
class DataPrepare(object):
    def __init__(self, args, target, data, train_index, test_index, device, batch_size=64):

        self.args = args

        X, y = join_signals(data, target=target)
        xtrain, ytrain, xtest, ytest = X[train_index], y[train_index], X[test_index], y[test_index]

        if self.args.debug:
            xtrain, ytrain, xtest, ytest = xtrain[:
                                                  100], ytrain[:100], xtest[:100], ytest[:100]
        print(xtrain.shape, ytrain.shape, xtest.shape, ytest.shape)

        xtrain = torch.from_numpy(xtrain).to(torch.float32)
        xtest = torch.from_numpy(xtest).to(torch.float32)

        self.xtrain, self.xtest = xtrain.to(device), xtest.to(device)

        self.xtrain.requires_grad_()
        # self.xtest.requires_grad_()

        ytrain = torch.from_numpy(ytrain)
        ytest = torch.from_numpy(ytest)

        self.ytrain, self.ytest = ytrain.to(device), ytest.to(device)

        if args.target in ['valence', 'arousal']:
            self.ytrain = self.ytrain.to(torch.float32)
            self.ytest = self.ytest.to(torch.float32)

        self.batch_size = batch_size

    def get_data(self):
        train_data = TensorDataset(self.xtrain, self.ytrain)
        test_data = TensorDataset(self.xtest, self.ytest)

        train_sampler = SequentialSampler(train_data)
        train_dataloader = DataLoader(
            train_data, sampler=train_sampler, batch_size=self.batch_size, drop_last=False)

        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(
            test_data, sampler=test_sampler, batch_size=self.batch_size, drop_last=False)

        return train_dataloader, test_dataloader

In [23]:
from sklearn.metrics import f1_score

In [58]:
def train(args, model, train_dataloader, optimizer, epoch):
    model.train()
    train_loss_list = []
    correct_list = []
    # loss_fn = nn.BCEWithLogitsLoss()
    # loss = F.binary_cross_entropy_with_logits(output, target)
    for batch_idx, (data, target) in tqdm(enumerate(train_dataloader)):
        data, target = data.to(args.device), target.to(args.device)
        optimizer.zero_grad()
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        pred = pred.float()
        target = target.float()
        loss = F.binary_cross_entropy_with_logits(pred, target)
        loss.requires_grad_()
        loss.backward()
        optimizer.step()
        train_loss_list.append(loss.item())
        
        correct_list.append(pred.eq(target.view_as(pred)).sum().item()/len(target))
    return np.mean(train_loss_list), np.mean(correct_list)

def eval(model, device, val_dataloader):
    model.eval()
    val_loss = 0
    correct = 0
    f1_ = []
    with torch.no_grad():
        for data, target in val_dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            pred = pred.float()
            target = target.float()
            loss = F.binary_cross_entropy_with_logits(pred, target)
            val_loss += loss.item()  # sum up batch loss
            # pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            f1_.append(f1_score(target.cpu().numpy(), pred.flatten().cpu().numpy(), average='weighted'))
            correct += pred.eq(target.view_as(pred)).sum().item()


    val_loss /= len(val_dataloader.dataset)
    accuracy = correct / len(val_dataloader.dataset)
    return val_loss, accuracy, np.mean(f1_)

In [12]:
dataprepare = DataPrepare(args,
            target='valence', data=data, train_index=train_index, test_index=test_index, device=args.device, batch_size=args.batch_size)

(100, 4, 400) (100, 1) (100, 4, 400) (100, 1)


In [17]:
train_dataloader, test_dataloader = dataprepare.get_data()

In [18]:
model = CNNBiLSTM.CNNBiLSTM(args)

In [19]:
model = model.to(args.device)

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

In [59]:
for epoch in range(args.epochs):
    print('Epoch', epoch)
    train_loss, train_acc = train(args, model, train_dataloader, optimizer, epoch)
    val_loss, val_acc, f1 = eval(model, args.device, test_dataloader)
    print('[Epoch{}] | train_loss:{:.4f} | val_loss:{:.4f} | train_acc:{:.4f} | val_acc:{:.4f} | val_f1:{:.4f} | lr:{:e}'.format(epoch, train_loss,
            val_loss, train_acc, val_acc, f1, optimizer.param_groups[0]['lr']))

Epoch 0


7it [00:01,  5.42it/s]


[Epoch0] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 1


7it [00:01,  6.87it/s]


[Epoch1] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 2


7it [00:01,  6.91it/s]


[Epoch2] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 3


7it [00:01,  6.77it/s]


[Epoch3] | train_loss:0.6796 | val_loss:0.0485 | train_acc:0.0357 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 4


7it [00:01,  6.84it/s]


[Epoch4] | train_loss:0.6830 | val_loss:0.0485 | train_acc:0.0268 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 5


7it [00:01,  6.82it/s]


[Epoch5] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 6


7it [00:01,  6.85it/s]


[Epoch6] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 7


7it [00:01,  6.83it/s]


[Epoch7] | train_loss:0.6796 | val_loss:0.0485 | train_acc:0.0357 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 8


7it [00:01,  6.80it/s]


[Epoch8] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 9


7it [00:01,  6.85it/s]


[Epoch9] | train_loss:0.6796 | val_loss:0.0485 | train_acc:0.0357 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 10


7it [00:01,  6.78it/s]


[Epoch10] | train_loss:0.6830 | val_loss:0.0485 | train_acc:0.0268 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 11


7it [00:01,  6.79it/s]


[Epoch11] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 12


7it [00:01,  6.63it/s]


[Epoch12] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 13


7it [00:01,  6.76it/s]


[Epoch13] | train_loss:0.6796 | val_loss:0.0485 | train_acc:0.0357 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 14


7it [00:01,  6.71it/s]


[Epoch14] | train_loss:0.6931 | val_loss:0.0485 | train_acc:0.0000 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 15


7it [00:01,  6.67it/s]


[Epoch15] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 16


7it [00:01,  6.83it/s]


[Epoch16] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 17


7it [00:01,  6.90it/s]


[Epoch17] | train_loss:0.6830 | val_loss:0.0485 | train_acc:0.0268 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 18


7it [00:01,  6.81it/s]


[Epoch18] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 19


7it [00:01,  6.71it/s]


[Epoch19] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 20


7it [00:01,  6.94it/s]


[Epoch20] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 21


7it [00:01,  6.90it/s]


[Epoch21] | train_loss:0.6898 | val_loss:0.0485 | train_acc:0.0089 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 22


7it [00:01,  6.91it/s]


[Epoch22] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 23


7it [00:01,  6.90it/s]


[Epoch23] | train_loss:0.6796 | val_loss:0.0485 | train_acc:0.0357 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 24


7it [00:01,  6.92it/s]


[Epoch24] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 25


7it [00:01,  6.74it/s]


[Epoch25] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 26


7it [00:01,  6.75it/s]


[Epoch26] | train_loss:0.6864 | val_loss:0.0485 | train_acc:0.0179 | val_acc:0.5400 | val_f1:0.4905 | lr:1.000000e-02
Epoch 27


0it [00:00, ?it/s]


KeyboardInterrupt: 