In [1]:
from torchmetrics.functional import auc, mean_squared_error
from torchmetrics import F1Score
from tools import *
from CONSTANT import *
from models import CNNBiLSTM, CNNTransformer
from config import Params
from torch.utils.data import (
    TensorDataset, DataLoader, SequentialSampler, WeightedRandomSampler)
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import math
import time
import datetime
import os
import sys
import time
import warnings

# load baseline results

In [2]:
# Load the stored results of the KEC valence CTransformer baseline run and
# display whatever parse_res returns for them (the cell output below).
bs_results = load_dict_model(r'./output/KEC/valence_CTransformer_loso_0.0001_64_32/results.pkl')
parse_res(bs_results)
# valence CTransformer

0.56804

In [3]:
# Load the stored results of the KEC CLSTM baseline run and display whatever
# parse_res returns for them. Note the path points at the *arousal* folder.
bs_results = load_dict_model(r'./output/KEC/arousal_CLSTM_loso_0.0001_64_32/results.pkl')
parse_res(bs_results)
# arousal CLSTM (per the path above)
# NOTE(review): this comment previously said "valence CLSTM" - confirm which target was intended.

0.9817

# set params

In [3]:
# Checkpoint whose weights are later restored with model.load_state_dict.
# The path points at an HKU956 valence CTransformer run (fold 1), i.e. a
# different dataset from the one configured in `args` below.
ckpt_path = r'./output/HKU956/valence_CTransformer_loso_0.0001_256_32/fold1_checkpoint.pt'

# Run configuration for the KEC dataset / valence target.
args = Params(dataset='KEC', 
              model='CTransformer',
              target='valence', 
              debug=False, 
              fcn_input=12608,  # input width of the first Linear layer of the replacement head (kec_fcn)
              batch_size=32
              )

# load data

In [4]:
# Load the fold definitions and the dataset from the locations configured on
# `args`. pd.read_pickle unpickles arbitrary objects, so only point it at
# trusted files.
spliter = load_model(args.spliter)
data = pd.read_pickle(args.data)

# Only the first fold of the selected validation scheme is used in this
# notebook. Fetch it explicitly and fail loudly if the scheme has no folds,
# instead of leaving train_index / test_index undefined.
first_fold = next(iter(spliter[args.valid]), None)
if first_fold is None:
    raise ValueError(
        "no folds found for validation scheme {!r}".format(args.valid))
train_index = first_fold['train_index']
test_index = first_fold['test_index']

# DataPrepare builds the train / test loaders for the chosen fold.
dataprepare = DataPrepare(args,
                        target=args.target, data=data, train_index=train_index, test_index=test_index, device=args.device, batch_size=args.batch_size
                        )

train_dataloader, test_dataloader = dataprepare.get_data()

(2837, 4, 400) (2837, 1) (608, 4, 400) (608, 1)


# load pretrained model

In [5]:
# Freshly initialised single-output head that is assigned to model.fcn after
# the checkpoint weights have been loaded:
# Dropout -> Linear(fcn_input, 128) -> ReLU -> Dropout -> Linear(128, 32) -> ReLU -> Linear(32, 1)
kec_fcn = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(args.fcn_input, 128),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

In [6]:
# Build the CTransformer and restore the checkpoint weights.
model = CNNTransformer.CTransformer(args)
# map_location='cpu' lets a checkpoint that was saved on a GPU be restored on
# any machine; the model is moved to args.device further down.
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
# Swap in the freshly initialised head only after loading, so the checkpoint
# keys still match the model that load_state_dict sees.
model.fcn = kec_fcn
model = model.to(args.device)

# train and eval

In [7]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader`` with an MSE objective.

    Args:
        model: module to optimise; it is switched to training mode.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimiser stepped once per batch.
        epoch: accepted for the caller's convenience; not used in the body.

    Returns:
        The unweighted mean of the per-batch MSE values.
    """
    criterion = nn.MSELoss()
    batch_losses = []
    model.train()
    progress = tqdm(enumerate(train_loader))
    for _, (inputs, labels) in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels.float())
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return np.mean(batch_losses)


def eval(model, device, val_loader):
    """Compute the mean-squared error of ``model`` over ``val_loader``.

    The per-batch losses are averaged weighted by batch size, so a shorter
    final batch does not count as much as a full one and the result equals
    the per-sample MSE over the whole loader.

    NOTE: this function shadows the built-in ``eval``; the name is kept
    because ``run`` calls it.

    Args:
        model: module to evaluate; it is switched to evaluation mode.
        device: device the batches are moved to before the forward pass.
        val_loader: iterable yielding ``(data, target)`` batches.

    Returns:
        The sample-weighted mean MSE as a NumPy float, or NaN when the
        loader yields no batches.
    """
    model.eval()
    batch_losses = []
    batch_sizes = []
    loss_fn = nn.MSELoss()
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output, target.float())
            batch_losses.append(loss.item())
            # Remember how many samples this batch-mean stands for.
            batch_sizes.append(target.size(0))
    if not batch_losses:
        # Same value the plain mean of an empty list would give.
        return np.float64('nan')
    return np.average(batch_losses, weights=batch_sizes)

In [8]:
def run(train_loader, val_loader, ckpt_path):
    """Train the notebook-level ``model`` with early stopping.

    Reads the module-level ``model`` and ``args`` (``lr``, ``epochs``,
    ``device``). After every epoch the validation loss drives a
    ReduceLROnPlateau scheduler. On an improvement the weights are saved to
    ``ckpt_path``; otherwise the model weights are rolled back to the saved
    checkpoint (the optimiser state is left as is) and a stall counter is
    increased. Training stops once the counter reaches 25.

    Args:
        train_loader: batches used for optimisation.
        val_loader: batches used for model selection and LR scheduling.
        ckpt_path: file the best weights are written to and restored from.

    Returns:
        The module-level ``model``.
    """
    max_stalled_epochs = 25
    stalled_epochs = 0
    lowest_val_loss = float('inf')

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        verbose=True,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-08,
    )

    epoch_report = '[Epoch{}] | train_loss:{:.4f} | val_loss:{:.4f} | lr:{:e}'
    for epoch in range(1, args.epochs + 1):
        train_loss = train(model, args.device, train_loader, optimizer, epoch)
        val_loss = eval(model, args.device, val_loader)
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(epoch_report.format(epoch, train_loss, val_loss, current_lr))

        no_improvement = not (val_loss < lowest_val_loss)
        if no_improvement:
            # Roll the weights back to the best checkpoint seen so far.
            model.load_state_dict(torch.load(ckpt_path))
            stalled_epochs += 1
            if stalled_epochs >= max_stalled_epochs:
                print("<<<<<< without improvement in {} epoch, early stopping, best score {:.4f} >>>>>>".format(max_stalled_epochs, lowest_val_loss))
                break
        else:
            lowest_val_loss = val_loss
            torch.save(model.state_dict(), ckpt_path)
            print("<<<<<< reach best {0} >>>>>>".format(val_loss))
            stalled_epochs = 0
        # wandb.log({'train_loss': train_loss, 'val_loss': val_loss})
    print('best score', lowest_val_loss)
    return model

In [9]:
# Fine-tune on the first fold; the best weights are written to
# <args.save_path>/checkpoint.pt.
# NOTE(review): test_dataloader is passed as the validation loader, so early
# stopping, LR scheduling and checkpoint selection all see the test split and
# the reported "best score" is optimistic - consider a separate validation split.
model = run(train_dataloader, test_dataloader, ckpt_path=os.path.join(args.save_path, 'checkpoint.pt'))

89it [00:05, 15.39it/s]


[Epoch1] | train_loss:1.3936 | val_loss:0.5025 | lr:1.000000e-04
<<<<<< reach best 0.5025444182346722 >>>>>>


89it [00:03, 23.02it/s]


[Epoch2] | train_loss:0.7854 | val_loss:0.5155 | lr:1.000000e-04


89it [00:03, 23.08it/s]


[Epoch3] | train_loss:0.7828 | val_loss:0.5251 | lr:1.000000e-04


89it [00:03, 22.95it/s]


[Epoch4] | train_loss:0.7892 | val_loss:0.5078 | lr:1.000000e-04


89it [00:03, 23.05it/s]


[Epoch5] | train_loss:0.8268 | val_loss:0.4897 | lr:1.000000e-04
<<<<<< reach best 0.489677058330639 >>>>>>


89it [00:03, 23.02it/s]


[Epoch6] | train_loss:0.7558 | val_loss:0.6786 | lr:1.000000e-04


89it [00:03, 23.16it/s]


[Epoch7] | train_loss:0.7603 | val_loss:0.6708 | lr:1.000000e-04


89it [00:03, 23.20it/s]


[Epoch8] | train_loss:0.7804 | val_loss:0.6652 | lr:1.000000e-04


89it [00:03, 23.12it/s]


[Epoch9] | train_loss:0.7842 | val_loss:0.6708 | lr:1.000000e-04


89it [00:03, 23.08it/s]


[Epoch10] | train_loss:0.7857 | val_loss:0.7078 | lr:1.000000e-04


89it [00:03, 23.09it/s]


Epoch 00011: reducing learning rate of group 0 to 5.0000e-05.
[Epoch11] | train_loss:0.7844 | val_loss:0.6434 | lr:5.000000e-05


89it [00:03, 23.05it/s]


[Epoch12] | train_loss:0.6903 | val_loss:0.6647 | lr:5.000000e-05


89it [00:03, 23.17it/s]


[Epoch13] | train_loss:0.6949 | val_loss:0.6435 | lr:5.000000e-05


89it [00:03, 23.19it/s]


[Epoch14] | train_loss:0.6822 | val_loss:0.6380 | lr:5.000000e-05


89it [00:03, 23.27it/s]


[Epoch15] | train_loss:0.7047 | val_loss:0.6568 | lr:5.000000e-05


89it [00:03, 23.07it/s]


[Epoch16] | train_loss:0.7056 | val_loss:0.6612 | lr:5.000000e-05


89it [00:03, 23.17it/s]


Epoch 00017: reducing learning rate of group 0 to 2.5000e-05.
[Epoch17] | train_loss:0.6923 | val_loss:0.6357 | lr:2.500000e-05


89it [00:03, 23.11it/s]


[Epoch18] | train_loss:0.6520 | val_loss:0.5582 | lr:2.500000e-05


89it [00:03, 23.20it/s]


[Epoch19] | train_loss:0.6478 | val_loss:0.5451 | lr:2.500000e-05


89it [00:03, 23.24it/s]


[Epoch20] | train_loss:0.6492 | val_loss:0.5592 | lr:2.500000e-05


89it [00:03, 23.21it/s]


[Epoch21] | train_loss:0.6500 | val_loss:0.5542 | lr:2.500000e-05


89it [00:03, 23.22it/s]


[Epoch22] | train_loss:0.6501 | val_loss:0.5494 | lr:2.500000e-05


89it [00:03, 23.26it/s]


Epoch 00023: reducing learning rate of group 0 to 1.2500e-05.
[Epoch23] | train_loss:0.6502 | val_loss:0.5512 | lr:1.250000e-05


89it [00:03, 23.20it/s]


[Epoch24] | train_loss:0.6388 | val_loss:0.4813 | lr:1.250000e-05
<<<<<< reach best 0.48131972308711785 >>>>>>


89it [00:03, 23.19it/s]


[Epoch25] | train_loss:0.6262 | val_loss:0.4758 | lr:1.250000e-05
<<<<<< reach best 0.4757885927591767 >>>>>>


89it [00:03, 23.12it/s]


[Epoch26] | train_loss:0.6190 | val_loss:0.4746 | lr:1.250000e-05
<<<<<< reach best 0.474614597183015 >>>>>>


89it [00:03, 23.16it/s]


[Epoch27] | train_loss:0.6195 | val_loss:0.4727 | lr:1.250000e-05
<<<<<< reach best 0.4726544385500203 >>>>>>


89it [00:03, 23.13it/s]


[Epoch28] | train_loss:0.6253 | val_loss:0.4706 | lr:1.250000e-05
<<<<<< reach best 0.470639728119989 >>>>>>


89it [00:03, 23.17it/s]


[Epoch29] | train_loss:0.6118 | val_loss:0.4715 | lr:1.250000e-05


89it [00:03, 23.21it/s]


[Epoch30] | train_loss:0.6058 | val_loss:0.4716 | lr:1.250000e-05


89it [00:03, 23.19it/s]


[Epoch31] | train_loss:0.6010 | val_loss:0.4714 | lr:1.250000e-05


89it [00:03, 23.22it/s]


[Epoch32] | train_loss:0.6148 | val_loss:0.4720 | lr:1.250000e-05


89it [00:03, 23.23it/s]


[Epoch33] | train_loss:0.6070 | val_loss:0.4719 | lr:1.250000e-05


89it [00:03, 23.27it/s]


Epoch 00034: reducing learning rate of group 0 to 6.2500e-06.
[Epoch34] | train_loss:0.6235 | val_loss:0.4706 | lr:6.250000e-06


89it [00:03, 23.25it/s]


[Epoch35] | train_loss:0.5926 | val_loss:0.4820 | lr:6.250000e-06


89it [00:03, 23.15it/s]


[Epoch36] | train_loss:0.5940 | val_loss:0.4817 | lr:6.250000e-06


89it [00:03, 23.19it/s]


[Epoch37] | train_loss:0.5988 | val_loss:0.4810 | lr:6.250000e-06


89it [00:03, 23.19it/s]


[Epoch38] | train_loss:0.6018 | val_loss:0.4803 | lr:6.250000e-06


89it [00:03, 23.08it/s]


[Epoch39] | train_loss:0.5996 | val_loss:0.4823 | lr:6.250000e-06


89it [00:03, 22.99it/s]


Epoch 00040: reducing learning rate of group 0 to 3.1250e-06.
[Epoch40] | train_loss:0.6040 | val_loss:0.4817 | lr:3.125000e-06


89it [00:03, 23.01it/s]


[Epoch41] | train_loss:0.6051 | val_loss:0.4855 | lr:3.125000e-06


89it [00:03, 23.14it/s]


[Epoch42] | train_loss:0.5946 | val_loss:0.4853 | lr:3.125000e-06


89it [00:03, 23.17it/s]


[Epoch43] | train_loss:0.5996 | val_loss:0.4857 | lr:3.125000e-06


89it [00:03, 23.19it/s]


[Epoch44] | train_loss:0.6016 | val_loss:0.4855 | lr:3.125000e-06


89it [00:03, 23.25it/s]


[Epoch45] | train_loss:0.6022 | val_loss:0.4856 | lr:3.125000e-06


89it [00:03, 23.10it/s]


Epoch 00046: reducing learning rate of group 0 to 1.5625e-06.
[Epoch46] | train_loss:0.5972 | val_loss:0.4870 | lr:1.562500e-06


89it [00:03, 23.18it/s]


[Epoch47] | train_loss:0.6039 | val_loss:0.4796 | lr:1.562500e-06


89it [00:03, 23.22it/s]


[Epoch48] | train_loss:0.6011 | val_loss:0.4803 | lr:1.562500e-06


89it [00:03, 23.15it/s]


[Epoch49] | train_loss:0.5997 | val_loss:0.4806 | lr:1.562500e-06


89it [00:03, 23.17it/s]


[Epoch50] | train_loss:0.5968 | val_loss:0.4808 | lr:1.562500e-06


89it [00:03, 23.22it/s]


[Epoch51] | train_loss:0.5985 | val_loss:0.4801 | lr:1.562500e-06


89it [00:03, 23.19it/s]


Epoch 00052: reducing learning rate of group 0 to 7.8125e-07.
[Epoch52] | train_loss:0.5915 | val_loss:0.4804 | lr:7.812500e-07


89it [00:03, 23.21it/s]


[Epoch53] | train_loss:0.5997 | val_loss:0.4745 | lr:7.812500e-07
<<<<<< without improvement in 25 epoch, early stopping, best score 0.4706 >>>>>>


CTransformer(
  (cnns): Sequential(
    (0): Conv1d(4, 16, kernel_size=(3,), stride=(1,))
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(1,))
    (5): MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
    (6): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
  )
  (transformer): TransformerBlock(
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (encoder_layers): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (linear1): Linear(in_features=32, out_features=2048, bias=True)
      (dropout): Dropout(p=0.2, inplace=False)
      (linear2): Linear(in_features=204