In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup
from sklearn.metrics import f1_score
import time
import math
from tqdm import tqdm
import torchaudio
import IPython.display as ipd
from transformers import AutoConfig, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model
)

In [None]:
#Check Parameters
fold = 0
model_name = 'klue/roberta-base'
BATCH_SIZE =64
MAX_LEN =256
EPOCHS = 30
set_lr = 1e-4

In [2]:
import random
import os 
def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)

In [3]:
df = pd.read_csv('/workspace/ETRI/new_data.csv')
df = df[['Numb','Segment ID','Total Evaluation','text','max_count']]

In [4]:
mapping_info = {"neutral":0,"angry":1,"disgust":2,"fear":3,"happy":4,"sad":5,"surprise":6}
df['target'] = df['Total Evaluation'].map(mapping_info)

In [None]:
import re
s1 = re.compile('\n')
s1 = re.compile('\n')

def remove_characters(sentence, lower=True):
    sentence = s1.sub(' ', str(sentence))
    if lower:
        sentence = sentence.lower()
    return sentence

df['text'] = df['text'].map(remove_characters)

In [6]:
wav_dir = '/workspace/ETRI/KEMDy20/wav/'
for i in range(len(df)):
    SegmentID = df.iloc[i,1]
    tmp_dir = wav_dir + "Session" + SegmentID[4:6] +"/" + SegmentID + ".wav"
    df.loc[i,'wav_dir'] = tmp_dir

In [7]:
folds = [['Sess01','Sess02','Sess03','Sess04','Sess05','Sess06','Sess07','Sess08'],
['Sess09','Sess10','Sess11','Sess12','Sess13','Sess14','Sess15','Sess16'],
['Sess17','Sess18','Sess19','Sess20','Sess21','Sess22','Sess23','Sess24'],
['Sess25','Sess26','Sess27','Sess28','Sess29','Sess30','Sess31','Sess32'],
['Sess33','Sess34','Sess35','Sess36','Sess37','Sess38','Sess39','Sess40']]

In [8]:
test_list = '|'.join(folds[fold])

In [9]:
test = df[df['Segment ID'].str.contains(test_list)]

In [10]:
train =df[df['Segment ID'].str.contains(test_list) == False]

In [11]:
num_labels = 7
input_column = "wav_dir"
output_column = "Total Evaluation"

model_name_or_path = "facebook/wav2vec2-base-960h"
pooling_mode = "mean"

audio_config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
    finetuning_task="wav2vec2_clf",
)
setattr(audio_config, 'pooling_mode', pooling_mode)

In [12]:
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path,)
target_sampling_rate = processor.feature_extractor.sampling_rate
print(f"The target sampling rate: {target_sampling_rate}")

The target sampling rate: 16000


In [13]:
def speech_file_to_array_fn(path):
    speech_array, sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(sampling_rate, target_sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()
    input_values = processor(speech, sampling_rate=sampling_rate, return_tensors="pt").input_values
    return input_values

In [14]:
class SentimentDataset(Dataset):
  def __init__(self, subjects, targets, tokenizer, max_len,wav_dir,max_wv_len):
    self.subjects = subjects
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len
    self.wav_dir = wav_dir
    self.max_wv_len = max_wv_len
  def __len__(self):
    return len(self.subjects)
  def __getitem__(self, item):
    subject = str(self.subjects[item])
    target = self.targets[item]
    encoding = self.tokenizer.encode_plus(
      subject,
      add_special_tokens=True,
      max_length=self.max_len,
      return_token_type_ids=False,
      padding = 'max_length',
      truncation = True,
      return_attention_mask=True,
      return_tensors='pt',
    )

    wav_data = speech_file_to_array_fn(self.wav_dir[item])
    if wav_data.size(-1) > self.max_wv_len:
      wav_data = wav_data[:, :self.max_wv_len]
    else:
      k = self.max_wv_len // wav_data.size(-1)
      tmp = torch.zeros(self.max_wv_len - k * wav_data.size(-1)).unsqueeze(0)
      tmp2 = wav_data
      for i in range(k-1):
        wav_data = torch.cat([wav_data,tmp2], dim=1) 
      wav_data = torch.cat([wav_data,tmp], dim=1) 

    return {
      'subject_text': subject,
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long),
      'wav_data': wav_data.flatten(),
    }
def create_data_loader(df, tokenizer, max_len, max_wv_len, batch_size, shuffle_=False, valid=False):
  ds = SentimentDataset(
    subjects=df.text.to_numpy(),
    targets=df.target.to_numpy(),
    tokenizer=tokenizer,
    max_len=max_len,
    wav_dir=df.wav_dir.to_numpy(),
    max_wv_len = max_wv_len,
  )
  return DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=4,
    shuffle = shuffle_
  )

In [15]:
def calc_review_acc(pred, label):
    _, idx = pred.max(1)
    
    acc = torch.eq(idx, label).sum().item() / idx.size()[0] 
    x = label.cpu().numpy()
    y = idx.cpu().numpy()
    f1_acc = f1_score(x, y, average='macro')
    return acc,f1_acc


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

In [16]:
model_name = 'klue/roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
BATCH_SIZE =64
MAX_LEN =196
MAX_WV_LEN = 4 * 16000
train_data_loader = create_data_loader(train, tokenizer, MAX_LEN, MAX_WV_LEN, BATCH_SIZE, shuffle_=True)
test_data_loader = create_data_loader(test, tokenizer, MAX_LEN, MAX_WV_LEN, 1, valid=True)

In [17]:
class SentimentClassifier(nn.Module):
  def __init__(self, n_classes,audio_config):
    super(SentimentClassifier, self).__init__()
    self.bert = AutoModel.from_pretrained(model_name)
    self.drop = nn.Dropout(p=0.1)
    self.audio_config = audio_config
    self.pooling_mode = audio_config.pooling_mode
    self.wav2vec2 = Wav2Vec2Model(audio_config)
    def get_cls(target_size= n_classes):
      return nn.Sequential(
          nn.Linear(self.bert.config.hidden_size + self.audio_config.hidden_size, self.bert.config.hidden_size + self.audio_config.hidden_size),
          nn.LayerNorm(self.bert.config.hidden_size + self.audio_config.hidden_size),
          nn.Dropout(p = 0.1),
          nn.ReLU(),
          nn.Linear(self.bert.config.hidden_size + self.audio_config.hidden_size, target_size),
      )  
    self.cls = get_cls(n_classes)



  def freeze_feature_extractor(self):
      self.wav2vec2.feature_extractor._freeze_parameters()

  def merged_strategy(
          self,
          hidden_states,
          mode="mean"
  ):
      if mode == "mean":
          outputs = torch.mean(hidden_states, dim=1)
      elif mode == "sum":
          outputs = torch.sum(hidden_states, dim=1)
      elif mode == "max":
          outputs = torch.max(hidden_states, dim=1)[0]
      else:
          raise Exception(
              "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

      return outputs

  def forward(self, input_ids, attention_mask,input_values,
            audio_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,):
    _, pooled_output = self.bert(
      input_ids=input_ids,
      attention_mask=attention_mask,
      return_dict=False
    )
    output = self.drop(pooled_output)

    return_dict = return_dict if return_dict is not None else self.audio_config.use_return_dict
    output2 = self.wav2vec2(
            input_values,
            attention_mask=audio_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    hidden_states = output2[0]
    hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)

    output2 = self.drop(hidden_states)

    output = torch.cat([output,output2],1) 
    out = self.cls(output)

    return out

In [18]:
device = torch.device("cuda")

EPOCHS = 20
model = SentimentClassifier(n_classes=7,audio_config = audio_config).to(device)
optimizer = optim.AdamW(model.parameters(), lr=set_lr)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_cosine_schedule_with_warmup(
  optimizer,
  num_warmup_steps=int(total_steps*0.1),
  num_training_steps=total_steps
)

nSamples = train.target.value_counts().sort_index().tolist()
normedWeights = [1 - (x / sum(nSamples)) for x in nSamples]
normedWeights = torch.FloatTensor(normedWeights).to(device)
#loss_fn = nn.CrossEntropyLoss(normedWeights).to(device)

loss_fn = nn.CrossEntropyLoss().to(device)

Some weights of the model checkpoint at klue/roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at klue/roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for

In [19]:
model = nn.DataParallel(model)

In [20]:
def train_epoch(model,data_loader,loss_fn,optimizer,device,scheduler,n_examples):

  batch_time = AverageMeter()     
  data_time = AverageMeter()      
  losses = AverageMeter()         
  accuracies = AverageMeter()
  f1_accuracies = AverageMeter()
  
  sent_count = AverageMeter()   
    

  start = end = time.time()

  model = model.train()
  correct_predictions = 0
  for step,d in enumerate(data_loader):
    data_time.update(time.time() - end)
    batch_size = d["input_ids"].size(0) 
    wav_data = d["wav_data"].to(device)
    input_ids = d["input_ids"].to(device)
    attention_mask = d["attention_mask"].to(device)
    targets = d["targets"].to(device)
    outputs = model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      input_values=wav_data,
    )
    _, preds = torch.max(outputs, dim=1)
    loss = loss_fn(outputs, targets)
    correct_predictions += torch.sum(preds == targets)
    losses.update(loss.item(), batch_size)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    batch_time.update(time.time() - end)
    end = time.time()

    sent_count.update(batch_size)
    if step % 50 == 0 or step == (len(data_loader)-1):
                acc,f1_acc = calc_review_acc(outputs, targets)
                accuracies.update(acc, batch_size)
                f1_accuracies.update(f1_acc, batch_size)

                
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.3f}({loss.avg:.3f}) '
                      'Acc: {acc.val:.3f}({acc.avg:.3f}) '   
                      'f1_Acc: {f1_acc.val:.3f}({f1_acc.avg:.3f}) '           
                      'sent/s {sent_s:.0f} '
                      .format(
                      epoch, step+1, len(data_loader),
                      data_time=data_time, loss=losses,
                      acc=accuracies,
                      f1_acc=f1_accuracies,
                      remain=timeSince(start, float(step+1)/len(data_loader)),
                      sent_s=sent_count.avg/batch_time.avg
                      ))

  return correct_predictions.double() / n_examples, losses.avg

def validate(model,data_loader,loss_fn,optimizer,device,scheduler,n_examples):
  model = model.eval()
  losses = []
  outputs_arr = []
  preds_arr = []
  targets_arr = []
  correct_predictions = 0
  for d in tqdm(data_loader):
    input_ids = d["input_ids"].to(device)
    attention_mask = d["attention_mask"].to(device)
    wav_data = d["wav_data"].to(device)
    targets = d["targets"].to(device)
    outputs = model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      input_values=wav_data,
    )
    _, preds = torch.max(outputs, dim=1)
    outputs_arr.append(outputs.cpu().detach().numpy()[0])
    preds_arr.append(preds.cpu().numpy())
    
    loss = loss_fn(outputs, targets)
    correct_predictions += torch.sum(preds == targets)
    targets_arr.append(targets.cpu().numpy())
    losses.append(loss.item())
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  return correct_predictions.double() / n_examples, np.mean(losses), outputs_arr, preds_arr, targets_arr





In [21]:
for epoch in range(EPOCHS):
  print('-' * 10)
  print(f'Epoch {epoch}/{EPOCHS-1}')
  print('-' * 10)
  train_acc, train_loss = train_epoch(
    model,
    train_data_loader,
    loss_fn,
    optimizer, 
    device,
    scheduler,
    len(train)
  )
  validate_acc, validate_loss, outputs_arr, preds_arr, targets_max_arr= validate(
      model,
      test_data_loader,
      loss_fn,
      optimizer,
      device,
      scheduler,
      len(test)
  )
  print(f'Train loss {train_loss} accuracy {train_acc}')
  print(f'validate loss {validate_loss} accuracy {validate_acc}')
  print(f'validate f1-score: ',f1_score(preds_arr, targets_max_arr, average='macro'))
  print("")
  print("")

----------
Epoch 0/19
----------
Epoch: [0][1/168] Data 0.388 (0.388) Elapsed 0m 5s (remain 15m 24s) Loss: 1.876(1.876) Acc: 0.125(0.125) f1_Acc: 0.042(0.042) sent/s 12 
Epoch: [0][51/168] Data 0.004 (0.012) Elapsed 0m 42s (remain 1m 36s) Loss: 0.638(1.147) Acc: 0.844(0.484) f1_Acc: 0.183(0.113) sent/s 78 
Epoch: [0][101/168] Data 0.004 (0.008) Elapsed 1m 18s (remain 0m 52s) Loss: 0.593(0.840) Acc: 0.844(0.604) f1_Acc: 0.183(0.136) sent/s 82 
Epoch: [0][151/168] Data 0.004 (0.007) Elapsed 1m 55s (remain 0m 13s) Loss: 0.571(0.718) Acc: 0.844(0.664) f1_Acc: 0.153(0.140) sent/s 83 
Epoch: [0][168/168] Data 0.004 (0.007) Elapsed 2m 9s (remain 0m 0s) Loss: 0.404(0.689) Acc: 0.879(0.704) f1_Acc: 0.305(0.171) sent/s 83 


100%|██████████| 2716/2716 [02:17<00:00, 19.79it/s]

Train loss 0.689118568555128 accuracy 0.8181648985669087
validate loss 0.36044301031039416 accuracy 0.9072164948453608
validate f1-score:  0.2293595862836857


----------
Epoch 1/19
----------





Epoch: [1][1/168] Data 0.453 (0.453) Elapsed 0m 1s (remain 4m 0s) Loss: 0.630(0.630) Acc: 0.844(0.844) f1_Acc: 0.255(0.255) sent/s 45 
Epoch: [1][51/168] Data 0.004 (0.014) Elapsed 0m 38s (remain 1m 27s) Loss: 0.449(0.427) Acc: 0.891(0.867) f1_Acc: 0.308(0.282) sent/s 86 
Epoch: [1][101/168] Data 0.004 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.328(0.422) Acc: 0.875(0.870) f1_Acc: 0.385(0.316) sent/s 87 
Epoch: [1][151/168] Data 0.004 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.244(0.424) Acc: 0.922(0.883) f1_Acc: 0.340(0.322) sent/s 87 
Epoch: [1][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.324(0.425) Acc: 0.897(0.885) f1_Acc: 0.519(0.358) sent/s 86 


100%|██████████| 2716/2716 [02:18<00:00, 19.58it/s]


Train loss 0.42512929782884296 accuracy 0.8831193002047274
validate loss 0.32120578646363585 accuracy 0.9171575846833578
validate f1-score:  0.22865605844959255


----------
Epoch 2/19
----------
Epoch: [2][1/168] Data 0.451 (0.451) Elapsed 0m 1s (remain 3m 18s) Loss: 0.510(0.510) Acc: 0.844(0.844) f1_Acc: 0.262(0.262) sent/s 54 
Epoch: [2][51/168] Data 0.005 (0.014) Elapsed 0m 37s (remain 1m 27s) Loss: 0.394(0.380) Acc: 0.875(0.859) f1_Acc: 0.419(0.341) sent/s 86 
Epoch: [2][101/168] Data 0.004 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.365(0.379) Acc: 0.891(0.870) f1_Acc: 0.312(0.331) sent/s 86 
Epoch: [2][151/168] Data 0.004 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.378(0.364) Acc: 0.859(0.867) f1_Acc: 0.243(0.309) sent/s 86 
Epoch: [2][168/168] Data 0.006 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.359(0.364) Acc: 0.879(0.869) f1_Acc: 0.251(0.298) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.53it/s]


Train loss 0.3643220168362111 accuracy 0.889819467708915
validate loss 0.30679449361227096 accuracy 0.9120029455081001
validate f1-score:  0.2362752362752363


----------
Epoch 3/19
----------
Epoch: [3][1/168] Data 0.425 (0.425) Elapsed 0m 1s (remain 3m 15s) Loss: 0.222(0.222) Acc: 0.906(0.906) f1_Acc: 0.521(0.521) sent/s 55 
Epoch: [3][51/168] Data 0.005 (0.013) Elapsed 0m 38s (remain 1m 27s) Loss: 0.378(0.297) Acc: 0.875(0.891) f1_Acc: 0.319(0.420) sent/s 86 
Epoch: [3][101/168] Data 0.004 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.121(0.297) Acc: 0.953(0.911) f1_Acc: 0.591(0.477) sent/s 86 
Epoch: [3][151/168] Data 0.005 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.281(0.297) Acc: 0.906(0.910) f1_Acc: 0.321(0.438) sent/s 86 
Epoch: [3][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.188(0.297) Acc: 0.948(0.917) f1_Acc: 0.886(0.521) sent/s 86 


100%|██████████| 2716/2716 [02:21<00:00, 19.20it/s]

Train loss 0.29688449194261357 accuracy 0.9011725293132329
validate loss 0.28670267105870295 accuracy 0.9120029455081001
validate f1-score:  0.28131667190702286


----------
Epoch 4/19
----------





Epoch: [4][1/168] Data 0.446 (0.446) Elapsed 0m 1s (remain 3m 22s) Loss: 0.100(0.100) Acc: 0.984(0.984) f1_Acc: 0.960(0.960) sent/s 53 
Epoch: [4][51/168] Data 0.004 (0.013) Elapsed 0m 37s (remain 1m 26s) Loss: 0.268(0.206) Acc: 0.891(0.938) f1_Acc: 0.224(0.592) sent/s 86 
Epoch: [4][101/168] Data 0.005 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.404(0.229) Acc: 0.875(0.917) f1_Acc: 0.358(0.514) sent/s 87 
Epoch: [4][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.181(0.223) Acc: 0.922(0.918) f1_Acc: 0.504(0.511) sent/s 87 
Epoch: [4][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.127(0.224) Acc: 0.948(0.924) f1_Acc: 0.887(0.581) sent/s 86 


100%|██████████| 2716/2716 [02:20<00:00, 19.35it/s]

Train loss 0.22444415357322328 accuracy 0.9205285687697748
validate loss 0.32189093813981134 accuracy 0.9024300441826215
validate f1-score:  0.37312362063237264


----------
Epoch 5/19
----------





Epoch: [5][1/168] Data 0.467 (0.467) Elapsed 0m 1s (remain 3m 27s) Loss: 0.156(0.156) Acc: 0.953(0.953) f1_Acc: 0.528(0.528) sent/s 52 
Epoch: [5][51/168] Data 0.005 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.172(0.165) Acc: 0.938(0.945) f1_Acc: 0.683(0.606) sent/s 86 
Epoch: [5][101/168] Data 0.004 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.083(0.169) Acc: 0.969(0.953) f1_Acc: 0.448(0.553) sent/s 86 
Epoch: [5][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.139(0.160) Acc: 0.969(0.957) f1_Acc: 0.923(0.646) sent/s 86 
Epoch: [5][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.199(0.160) Acc: 0.948(0.955) f1_Acc: 0.906(0.694) sent/s 86 


100%|██████████| 2716/2716 [02:18<00:00, 19.56it/s]

Train loss 0.15957949806275856 accuracy 0.9470500651405174
validate loss 0.38602960645705525 accuracy 0.8969072164948454
validate f1-score:  0.3315858906856479


----------
Epoch 6/19
----------





Epoch: [6][1/168] Data 0.442 (0.442) Elapsed 0m 1s (remain 3m 21s) Loss: 0.164(0.164) Acc: 0.891(0.891) f1_Acc: 0.894(0.894) sent/s 53 
Epoch: [6][51/168] Data 0.004 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.057(0.106) Acc: 0.984(0.938) f1_Acc: 0.977(0.935) sent/s 86 
Epoch: [6][101/168] Data 0.004 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.212(0.109) Acc: 0.938(0.938) f1_Acc: 0.441(0.771) sent/s 87 
Epoch: [6][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.056(0.112) Acc: 0.984(0.949) f1_Acc: 0.798(0.778) sent/s 86 
Epoch: [6][168/168] Data 0.005 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.129(0.112) Acc: 0.966(0.952) f1_Acc: 0.703(0.764) sent/s 86 


100%|██████████| 2716/2716 [02:18<00:00, 19.58it/s]

Train loss 0.11236401368747227 accuracy 0.9621254420249395
validate loss 0.44197168435246414 accuracy 0.8825478645066274
validate f1-score:  0.346115610374954


----------
Epoch 7/19
----------





Epoch: [7][1/168] Data 0.406 (0.406) Elapsed 0m 1s (remain 3m 8s) Loss: 0.139(0.139) Acc: 0.953(0.953) f1_Acc: 0.695(0.695) sent/s 57 
Epoch: [7][51/168] Data 0.004 (0.013) Elapsed 0m 37s (remain 1m 26s) Loss: 0.034(0.070) Acc: 0.984(0.969) f1_Acc: 0.664(0.679) sent/s 87 
Epoch: [7][101/168] Data 0.005 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.208(0.074) Acc: 0.938(0.958) f1_Acc: 0.635(0.665) sent/s 87 
Epoch: [7][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.072(0.074) Acc: 0.984(0.965) f1_Acc: 0.815(0.702) sent/s 86 
Epoch: [7][168/168] Data 0.007 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.132(0.076) Acc: 0.966(0.965) f1_Acc: 0.911(0.741) sent/s 86 


100%|██████████| 2716/2716 [02:20<00:00, 19.35it/s]

Train loss 0.07572758535577358 accuracy 0.9758049506793226
validate loss 0.5068199166310268 accuracy 0.8888070692194403
validate f1-score:  0.3491806784674959


----------
Epoch 8/19
----------





Epoch: [8][1/168] Data 0.431 (0.431) Elapsed 0m 1s (remain 3m 21s) Loss: 0.075(0.075) Acc: 0.953(0.953) f1_Acc: 0.641(0.641) sent/s 53 
Epoch: [8][51/168] Data 0.004 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.073(0.047) Acc: 0.969(0.961) f1_Acc: 0.848(0.744) sent/s 86 
Epoch: [8][101/168] Data 0.004 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.014(0.051) Acc: 1.000(0.974) f1_Acc: 1.000(0.829) sent/s 86 
Epoch: [8][151/168] Data 0.005 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.167(0.051) Acc: 0.922(0.961) f1_Acc: 0.382(0.718) sent/s 86 
Epoch: [8][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.126(0.051) Acc: 0.966(0.962) f1_Acc: 0.912(0.754) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.43it/s]


Train loss 0.05147629346862843 accuracy 0.9840871021775545
validate loss 0.5501056852510898 accuracy 0.875920471281296
validate f1-score:  0.3821749396658664


----------
Epoch 9/19
----------
Epoch: [9][1/168] Data 0.422 (0.422) Elapsed 0m 1s (remain 3m 11s) Loss: 0.026(0.026) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 56 
Epoch: [9][51/168] Data 0.004 (0.013) Elapsed 0m 37s (remain 1m 27s) Loss: 0.009(0.032) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [9][101/168] Data 0.004 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.026(0.035) Acc: 0.984(0.995) f1_Acc: 0.978(0.993) sent/s 86 
Epoch: [9][151/168] Data 0.009 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.007(0.034) Acc: 1.000(0.996) f1_Acc: 1.000(0.995) sent/s 86 
Epoch: [9][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.071(0.035) Acc: 0.983(0.994) f1_Acc: 0.970(0.990) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.53it/s]

Train loss 0.0350939930663991 accuracy 0.9888330541596874
validate loss 0.5736102638532568 accuracy 0.8928571428571428
validate f1-score:  0.3813363924963338


----------
Epoch 10/19
----------





Epoch: [10][1/168] Data 0.421 (0.421) Elapsed 0m 1s (remain 3m 16s) Loss: 0.006(0.006) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 55 
Epoch: [10][51/168] Data 0.005 (0.013) Elapsed 0m 37s (remain 1m 26s) Loss: 0.020(0.017) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [10][101/168] Data 0.004 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.009(0.021) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 87 
Epoch: [10][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.010(0.021) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [10][168/168] Data 0.005 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.049(0.021) Acc: 0.983(0.997) f1_Acc: 0.924(0.986) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.52it/s]

Train loss 0.020657192148155902 accuracy 0.9938581797878281
validate loss 0.6674389047890782 accuracy 0.8810751104565537
validate f1-score:  0.36863181016704166


----------
Epoch 11/19
----------





Epoch: [11][1/168] Data 0.445 (0.445) Elapsed 0m 1s (remain 3m 38s) Loss: 0.012(0.012) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 49 
Epoch: [11][51/168] Data 0.004 (0.014) Elapsed 0m 38s (remain 1m 27s) Loss: 0.015(0.016) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [11][101/168] Data 0.005 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.003(0.020) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [11][151/168] Data 0.005 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.013(0.020) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [11][168/168] Data 0.005 (0.008) Elapsed 2m 5s (remain 0m 0s) Loss: 0.036(0.020) Acc: 0.983(0.997) f1_Acc: 0.962(0.993) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.47it/s]

Train loss 0.01963695268623882 accuracy 0.9945095849618463
validate loss 0.6670324074928334 accuracy 0.8884388807069219
validate f1-score:  0.40468466258582547


----------
Epoch 12/19
----------





Epoch: [12][1/168] Data 0.423 (0.423) Elapsed 0m 1s (remain 3m 14s) Loss: 0.035(0.035) Acc: 0.984(0.984) f1_Acc: 0.976(0.976) sent/s 55 
Epoch: [12][51/168] Data 0.009 (0.014) Elapsed 0m 37s (remain 1m 27s) Loss: 0.002(0.014) Acc: 1.000(0.992) f1_Acc: 1.000(0.988) sent/s 86 
Epoch: [12][101/168] Data 0.005 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.028(0.014) Acc: 0.984(0.990) f1_Acc: 0.818(0.931) sent/s 87 
Epoch: [12][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.001(0.013) Acc: 1.000(0.992) f1_Acc: 1.000(0.949) sent/s 86 
Epoch: [12][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.004(0.014) Acc: 1.000(0.994) f1_Acc: 1.000(0.958) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.46it/s]

Train loss 0.013759005907350982 accuracy 0.9960915689558906
validate loss 0.7076337023703557 accuracy 0.8877025036818851
validate f1-score:  0.37375441499017853


----------
Epoch 13/19
----------





Epoch: [13][1/168] Data 0.425 (0.425) Elapsed 0m 1s (remain 3m 17s) Loss: 0.000(0.000) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 55 
Epoch: [13][51/168] Data 0.004 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.000(0.010) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [13][101/168] Data 0.005 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.001(0.010) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [13][151/168] Data 0.004 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.001(0.011) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [13][168/168] Data 0.007 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.000(0.010) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 


100%|██████████| 2716/2716 [02:20<00:00, 19.39it/s]

Train loss 0.009632485577449751 accuracy 0.9974874371859297
validate loss 0.7475144512047741 accuracy 0.8814432989690721
validate f1-score:  0.3578975845791333


----------
Epoch 14/19
----------





Epoch: [14][1/168] Data 0.433 (0.433) Elapsed 0m 1s (remain 3m 17s) Loss: 0.001(0.001) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 54 
Epoch: [14][51/168] Data 0.004 (0.013) Elapsed 0m 37s (remain 1m 27s) Loss: 0.002(0.005) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [14][101/168] Data 0.005 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.003(0.006) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [14][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.003(0.006) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [14][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.003(0.006) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.53it/s]

Train loss 0.005646911430490343 accuracy 0.9985110738879583
validate loss 0.7556522320687563 accuracy 0.8832842415316642
validate f1-score:  0.3522134936719322


----------
Epoch 15/19
----------





Epoch: [15][1/168] Data 0.430 (0.430) Elapsed 0m 1s (remain 3m 20s) Loss: 0.001(0.001) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 54 
Epoch: [15][51/168] Data 0.004 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.000(0.004) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [15][101/168] Data 0.004 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.003(0.004) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [15][151/168] Data 0.004 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.001(0.004) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [15][168/168] Data 0.006 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.000(0.004) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 


100%|██████████| 2716/2716 [02:20<00:00, 19.31it/s]


Train loss 0.004434960832579795 accuracy 0.9988833054159687
validate loss 0.7894133729849364 accuracy 0.8792341678939617
validate f1-score:  0.3518348432478013


----------
Epoch 16/19
----------
Epoch: [16][1/168] Data 0.433 (0.433) Elapsed 0m 1s (remain 3m 22s) Loss: 0.061(0.061) Acc: 0.984(0.984) f1_Acc: 0.980(0.980) sent/s 53 
Epoch: [16][51/168] Data 0.004 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.027(0.006) Acc: 0.984(0.984) f1_Acc: 0.980(0.980) sent/s 86 
Epoch: [16][101/168] Data 0.004 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.001(0.005) Acc: 1.000(0.990) f1_Acc: 1.000(0.987) sent/s 87 
Epoch: [16][151/168] Data 0.005 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.000(0.005) Acc: 1.000(0.992) f1_Acc: 1.000(0.990) sent/s 86 
Epoch: [16][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.016(0.005) Acc: 0.983(0.990) f1_Acc: 0.960(0.984) sent/s 86 


100%|██████████| 2716/2716 [02:18<00:00, 19.54it/s]

Train loss 0.0051306760146251585 accuracy 0.9983249581239532
validate loss 0.7796842033305615 accuracy 0.8770250368188512
validate f1-score:  0.35411618058775174


----------
Epoch 17/19
----------





Epoch: [17][1/168] Data 0.431 (0.431) Elapsed 0m 1s (remain 3m 19s) Loss: 0.001(0.001) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 54 
Epoch: [17][51/168] Data 0.005 (0.014) Elapsed 0m 37s (remain 1m 26s) Loss: 0.001(0.001) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [17][101/168] Data 0.005 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.000(0.002) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [17][151/168] Data 0.005 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.002(0.002) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [17][168/168] Data 0.004 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.001(0.002) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 


100%|██████████| 2716/2716 [02:18<00:00, 19.54it/s]


Train loss 0.002122973148537172 accuracy 0.999534710589987
validate loss 0.7710434791795442 accuracy 0.8836524300441826
validate f1-score:  0.3590533619855974


----------
Epoch 18/19
----------
Epoch: [18][1/168] Data 0.444 (0.444) Elapsed 0m 1s (remain 3m 17s) Loss: 0.001(0.001) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 54 
Epoch: [18][51/168] Data 0.004 (0.014) Elapsed 0m 37s (remain 1m 27s) Loss: 0.000(0.002) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [18][101/168] Data 0.004 (0.009) Elapsed 1m 14s (remain 0m 49s) Loss: 0.000(0.003) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 87 
Epoch: [18][151/168] Data 0.004 (0.008) Elapsed 1m 51s (remain 0m 12s) Loss: 0.001(0.003) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 
Epoch: [18][168/168] Data 0.005 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.008(0.003) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 86 


100%|██████████| 2716/2716 [02:18<00:00, 19.67it/s]

Train loss 0.002604080005220427 accuracy 0.9993485948259818
validate loss 0.7664422804734462 accuracy 0.8851251840942562
validate f1-score:  0.36618094138233637


----------
Epoch 19/19
----------





Epoch: [19][1/168] Data 0.464 (0.464) Elapsed 0m 1s (remain 3m 26s) Loss: 0.001(0.001) Acc: 1.000(1.000) f1_Acc: 1.000(1.000) sent/s 52 
Epoch: [19][51/168] Data 0.005 (0.014) Elapsed 0m 37s (remain 1m 27s) Loss: 0.022(0.002) Acc: 0.984(0.992) f1_Acc: 0.798(0.899) sent/s 86 
Epoch: [19][101/168] Data 0.010 (0.010) Elapsed 1m 14s (remain 0m 49s) Loss: 0.000(0.002) Acc: 1.000(0.995) f1_Acc: 1.000(0.933) sent/s 86 
Epoch: [19][151/168] Data 0.005 (0.008) Elapsed 1m 52s (remain 0m 12s) Loss: 0.000(0.004) Acc: 1.000(0.996) f1_Acc: 1.000(0.950) sent/s 86 
Epoch: [19][168/168] Data 0.004 (0.008) Elapsed 2m 5s (remain 0m 0s) Loss: 0.001(0.003) Acc: 1.000(0.997) f1_Acc: 1.000(0.959) sent/s 86 


100%|██████████| 2716/2716 [02:19<00:00, 19.40it/s]

Train loss 0.003394321558690494 accuracy 0.9988833054159687
validate loss 0.7681011658622294 accuracy 0.8862297496318114
validate f1-score:  0.36765774962107306







In [22]:
f1_score(preds_arr, targets_max_arr, average='macro')

0.36765774962107306

In [23]:
f1_score(preds_arr, targets_max_arr, average='micro')

0.8862297496318114

In [24]:
from sklearn.metrics import accuracy_score
accuracy_score(preds_arr,targets_max_arr)

0.8862297496318114

In [25]:
from sklearn.metrics import balanced_accuracy_score
balanced_accuracy_score(targets_max_arr, preds_arr)

0.43061963116057417