In [1]:
from torchmetrics.functional import auc, mean_squared_error
from torchmetrics import F1Score
from tools import *
from CONSTANT import *
from models import CNNBiLSTM, CNNTransformer
from config import Params
from torch.utils.data import (
    TensorDataset, DataLoader, SequentialSampler, WeightedRandomSampler)
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import math
import time
import datetime
import os
import sys
import time
import warnings

# load baseline results

In [2]:
bs_results = load_dict_model(r'./output/KEC/valence_CTransformer_loso_0.0001_64_32/results.pkl')
parse_res(bs_results)
# valence CTransformer

0.56804

In [3]:
bs_results = load_dict_model(r'./output/KEC/arousal_CLSTM_loso_0.0001_64_32/results.pkl')
parse_res(bs_results)
# arousal CLSTM

0.9817

# set params

In [4]:
ckpt_path = r'./output/HKU956/valence_CTransformer_loso_0.0001_256_32/fold1_checkpoint.pt'

args = Params(dataset='KEC', 
              model='CTransformer',
              target='arousal', 
              debug=False, 
              fcn_input=12608,
              batch_size=64
              )

# load data

In [5]:
spliter = load_model(args.spliter)
data = pd.read_pickle(args.data)

for i, k in enumerate(spliter[args.valid]):
    train_index = k['train_index']
    test_index = k['test_index']
    break

dataprepare = DataPrepare(args,
                        target=args.target, data=data, train_index=train_index, test_index=test_index, device=args.device, batch_size=args.batch_size
                        )

train_dataloader, test_dataloader = dataprepare.get_data()

(2837, 4, 400) (2837, 1) (608, 4, 400) (608, 1)


# load pretrain model

In [6]:
kec_fcn = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(args.fcn_input, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

model = CNNTransformer.CTransformer(args)
model.load_state_dict(torch.load(ckpt_path))
model.fcn = kec_fcn
model = model.to(args.device)

# train and eval

In [6]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss_list = []
    loss_fn = nn.MSELoss()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target.float())
        loss.backward()
        optimizer.step()
        train_loss_list.append(loss.item())
    return np.mean(train_loss_list)


def eval(model, device, val_loader):
    model.eval()
    val_loss = []
    loss_fn = nn.MSELoss()
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output, target.float())
            val_loss.append(loss.item())
    return np.mean(val_loss)

def run(train_loader, val_loader, ckpt_path):
    best_score = float('inf')
    patience = 25
    stop_count = 0
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                                    verbose=True, threshold_mode='rel',
                                    cooldown=0, min_lr=0, eps=1e-08
                                    )
    for epoch in range(1, args.epochs + 1):
        train_loss = train(model, args.device, train_loader, optimizer, epoch)
        val_loss = eval(model, args.device, val_loader)
        scheduler.step(val_loss)
        print('[Epoch{}] | train_loss:{:.4f} | val_loss:{:.4f} | lr:{:e}'.format(epoch, train_loss, val_loss, optimizer.param_groups[0]['lr']))

        if val_loss < best_score:
            best_score = val_loss
            torch.save(model.state_dict(), ckpt_path)
            print("<<<<<< reach best {0} >>>>>>".format(val_loss))
            stop_count = 0
        else:
            model.load_state_dict(torch.load(ckpt_path))
            stop_count += 1
            if stop_count >= patience:
                print("<<<<<< without improvement in {} epoch, early stopping, best score {:.4f} >>>>>>".format(patience, best_score))
                break
        # wandb.log({'train_loss': train_loss, 'val_loss': val_loss})
    print('best score', best_score)
    return model

In [None]:
model = run(train_dataloader, test_dataloader, ckpt_path=os.path.join(args.save_path, 'arousal_checkpoint_dp02.pt'))

# back to HKU956

In [2]:
ckpt_path = r'./output/KEC/valence_CTransformer_loso_0.0001_32_32/checkpoint.pt'

args = Params(dataset='HKU956', 
              model='CTransformer',
              target='arousal', 
              debug=False, 
              fcn_input=12608,
              batch_size=64
              )

In [3]:
spliter = load_model(args.spliter)
data = pd.read_pickle(args.data)

for i, k in enumerate(spliter[args.valid]):
    train_index = k['train_index']
    test_index = k['test_index']
    break

dataprepare = DataPrepare(args,
                        target=args.target, data=data, train_index=train_index, test_index=test_index, device=args.device, batch_size=args.batch_size
                        )

train_dataloader, test_dataloader = dataprepare.get_data()

(18089, 4, 400) (18089, 1) (4638, 4, 400) (4638, 1)


In [4]:
kec_fcn = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(args.fcn_input, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

model = CNNTransformer.CTransformer(args)
model.load_state_dict(torch.load(ckpt_path))
model.fcn = kec_fcn
model = model.to(args.device)

In [None]:
model = run(train_dataloader, test_dataloader, ckpt_path=os.path.join(args.save_path, 'pretrain_hku_arousal_pretrain_checkpoint.pt'))

# classification

In [2]:
from sklearn.metrics import f1_score

In [3]:
def bin_train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss_list = []
    train_f1s = []
    sig = nn.Sigmoid()
    f1_m = F1Score().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target.float())
        # f1 = f1_score(torch.round(sig(output)).cpu().detach().numpy().astype(int), target.cpu().detach().numpy(), average='macro')
        f1 = f1_m(torch.round(sig(output)).long(), target).item()
        train_f1s.append(f1)
        loss.backward()
        optimizer.step()
        train_loss_list.append(loss.item())
    return np.mean(train_loss_list), np.mean(train_f1s)


def bin_eval(model, device, val_loader):
    model.eval()
    val_loss = []
    val_f1s = []
    sig = nn.Sigmoid()
    f1_m = F1Score().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output, target.float())
            # f1 = f1_score(torch.round(sig(output)).cpu().detach().numpy(), target.cpu().detach().numpy(), average='macro')
            # print(torch.round(sig(output)).long(), target)
            f1 = f1_m(torch.round(sig(output)).long(), target).item()
            val_f1s.append(f1)
            val_loss.append(loss.item())
    return np.mean(val_loss), np.mean(val_f1s)

def bin_run(model, train_loader, val_loader, ckpt_path):
    best_score = -1 * float('inf')
    patience = 25
    stop_count = 0
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                                    verbose=True, threshold_mode='rel',
                                    cooldown=0, min_lr=0, eps=1e-08
                                    )
    for epoch in range(1, args.epochs + 1):
        train_loss, train_f1 = bin_train(model, args.device, train_loader, optimizer, epoch)
        val_loss, val_f1 = bin_eval(model, args.device, val_loader)
        scheduler.step(val_loss)
        print('[Epoch{}] | train_loss:{:.4f} | val_loss:{:.4f} | train_f1:{:.4f} | val_f1:{:.4f} | lr:{:e}'.format(epoch, train_loss, val_loss, train_f1, val_f1, optimizer.param_groups[0]['lr']))

        if val_f1 > best_score:
            best_score = val_f1
            torch.save(model.state_dict(), ckpt_path)
            print("<<<<<< reach best {0} >>>>>>".format(val_f1))
            stop_count = 0
        else:
            model.load_state_dict(torch.load(ckpt_path))
            stop_count += 1
            if stop_count >= patience:
                print("<<<<<< without improvement in {} epoch, early stopping, best score {:.4f} >>>>>>".format(patience, best_score))
                break
        # wandb.log({'train_loss': train_loss, 'val_loss': val_loss})
    print('best score', best_score)
    return model

In [10]:
ckpt_path = r'./output/HKU956/valence_CTransformer_loso_0.0001_256_32/fold2_checkpoint.pt'

args = Params(dataset='KEC', 
              model='CTransformer',
              target='arousal_label', 
              debug=False, 
              fcn_input=12608,
              batch_size=64,
              valid='cv'
              )

# valence_label 0.7562

spliter = load_model(args.spliter)
data = pd.read_pickle(args.data)

for i, k in enumerate(spliter[args.valid]):
    train_index = k['train_index']
    test_index = k['test_index']
    break

dataprepare = DataPrepare(args,
                        target=args.target, data=data, train_index=train_index, test_index=test_index, device=args.device, batch_size=args.batch_size
                        )

train_dataloader, test_dataloader = dataprepare.get_data()

(2756, 4, 400) (2756, 1) (689, 4, 400) (689, 1)


In [11]:
kec_fcn = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(args.fcn_input, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

model = CNNTransformer.CTransformer(args)
model.load_state_dict(torch.load(ckpt_path))
model.fcn = kec_fcn
model = model.to(args.device)

In [12]:
model = bin_run(model, train_dataloader, test_dataloader, 'bin_kec_valence_pretrain_checkpoint.pt')

44it [00:04, 10.64it/s]


[Epoch1] | train_loss:0.6640 | val_loss:0.6094 | train_f1:0.6584 | val_f1:0.7190 | lr:1.000000e-04
<<<<<< reach best 0.7189819108356129 >>>>>>


44it [00:03, 11.62it/s]


[Epoch2] | train_loss:0.6420 | val_loss:0.6150 | train_f1:0.6808 | val_f1:0.7190 | lr:1.000000e-04


44it [00:03, 11.54it/s]


[Epoch3] | train_loss:0.6402 | val_loss:0.6139 | train_f1:0.6808 | val_f1:0.7190 | lr:1.000000e-04


44it [00:03, 11.59it/s]


[Epoch4] | train_loss:0.6393 | val_loss:0.6144 | train_f1:0.6808 | val_f1:0.7190 | lr:1.000000e-04


44it [00:03, 11.51it/s]


[Epoch5] | train_loss:0.6416 | val_loss:0.6160 | train_f1:0.6808 | val_f1:0.7190 | lr:1.000000e-04


44it [00:03, 11.60it/s]


[Epoch6] | train_loss:0.6412 | val_loss:0.6153 | train_f1:0.6808 | val_f1:0.7190 | lr:1.000000e-04


44it [00:03, 11.10it/s]


Epoch 00007: reducing learning rate of group 0 to 5.0000e-05.
[Epoch7] | train_loss:0.6407 | val_loss:0.6162 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.66it/s]


[Epoch8] | train_loss:0.6335 | val_loss:0.6095 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.65it/s]


[Epoch9] | train_loss:0.6324 | val_loss:0.6091 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.50it/s]


[Epoch10] | train_loss:0.6320 | val_loss:0.6094 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.67it/s]


[Epoch11] | train_loss:0.6322 | val_loss:0.6094 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.78it/s]


[Epoch12] | train_loss:0.6339 | val_loss:0.6098 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.76it/s]


[Epoch13] | train_loss:0.6325 | val_loss:0.6094 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.12it/s]


[Epoch14] | train_loss:0.6357 | val_loss:0.6109 | train_f1:0.6808 | val_f1:0.7190 | lr:5.000000e-05


44it [00:03, 11.64it/s]


Epoch 00015: reducing learning rate of group 0 to 2.5000e-05.
[Epoch15] | train_loss:0.6337 | val_loss:0.6092 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.76it/s]


[Epoch16] | train_loss:0.6281 | val_loss:0.6043 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.40it/s]


[Epoch17] | train_loss:0.6279 | val_loss:0.6042 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.40it/s]


[Epoch18] | train_loss:0.6291 | val_loss:0.6046 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.44it/s]


[Epoch19] | train_loss:0.6269 | val_loss:0.6042 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.52it/s]


[Epoch20] | train_loss:0.6292 | val_loss:0.6042 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.59it/s]


[Epoch21] | train_loss:0.6297 | val_loss:0.6042 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.83it/s]


[Epoch22] | train_loss:0.6292 | val_loss:0.6048 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.83it/s]


[Epoch23] | train_loss:0.6293 | val_loss:0.6046 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.67it/s]


[Epoch24] | train_loss:0.6278 | val_loss:0.6041 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.59it/s]


[Epoch25] | train_loss:0.6263 | val_loss:0.6040 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05


44it [00:03, 11.66it/s]


[Epoch26] | train_loss:0.6287 | val_loss:0.6047 | train_f1:0.6808 | val_f1:0.7190 | lr:2.500000e-05
<<<<<< without improvement in 25 epoch, early stopping, best score 0.7190 >>>>>>
best score 0.7189819108356129
