In [None]:
# DialogueRNN: Colab setup, model definitions and dataset loaders

In [None]:
# Switch the working directory to the Google Drive root so that later relative
# paths resolve there (assumes Drive is already mounted at /content/drive — TODO confirm).
cd "/content/drive/My Drive/"

/content/drive/My Drive


In [None]:
#MELD data Download
# Clone the MELD repository into the current working directory.
!git clone https://github.com/declare-lab/MELD.git

import yaml
# Parse the dataset description shipped with the MELD checkout.
# NOTE(review): this path uses "MyDrive" while the earlier `cd` uses "My Drive";
# confirm both resolve to the same mount.
# NOTE(review): yaml.full_load is more permissive than yaml.safe_load; prefer
# safe_load if the file only contains plain data.
with open("/content/drive/MyDrive/MELD/data/MELD/datasets.yaml") as f:
    file = yaml.full_load(f)

fatal: destination path 'MELD' already exists and is not an empty directory.


In [None]:
#DialogueRNN, bs_LSTM
# Clone the conv-emotion repository into the current working directory.
!git clone https://github.com/declare-lab/conv-emotion.git

fatal: destination path 'conv-emotion' already exists and is not an empty directory.


In [None]:
!pip install tensorboardX

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/101.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [None]:
#DialogueRNN_Model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

class SimpleAttention(nn.Module):
    """Attention pooling that scores every timestep with one learned vector."""

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        # Bias-free projection of each timestep onto a single scalar score.
        self.scalar = nn.Linear(input_dim, 1, bias=False)

    def forward(self, M, x=None):
        """
        M -> (seq_len, batch, vector)
        x -> unused; accepted so the call signature matches MatchingAttention

        Returns (attn_pool, alpha):
            attn_pool -> (batch, vector)
            alpha     -> (batch, 1, seq_len), softmax-normalised over seq_len
        """
        scores = self.scalar(M)                    # seq_len, batch, 1
        weights = torch.softmax(scores, dim=0)     # normalise across time
        alpha = weights.permute(1, 2, 0)           # batch, 1, seq_len
        memory = M.transpose(0, 1)                 # batch, seq_len, vector
        attn_pool = torch.bmm(alpha, memory).squeeze(1)  # batch, vector

        return attn_pool, alpha

class MatchingAttention(nn.Module):
    """Attention over a memory sequence, scored against a candidate vector.

    `att_type` selects the scoring function:
        'dot'      -> plain dot product (requires mem_dim == cand_dim)
        'general'  -> bias-free linear map of the candidate, then dot product
        'general2' -> linear map with bias, then dot product; honours `mask`
        'concat'   -> tanh(Linear([memory; candidate])) projected to a scalar
                      (requires alpha_dim)
    """

    def __init__(self, mem_dim, cand_dim, alpha_dim=None, att_type='general'):
        super(MatchingAttention, self).__init__()
        assert att_type!='concat' or alpha_dim is not None
        assert att_type!='dot' or mem_dim==cand_dim
        self.mem_dim = mem_dim
        self.cand_dim = cand_dim
        self.att_type = att_type
        # The variants are mutually exclusive, hence a single if/elif chain.
        # 'dot' needs no learned parameters.
        if att_type=='general':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=False)
        elif att_type=='general2':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=True)
        elif att_type=='concat':
            self.transform = nn.Linear(cand_dim+mem_dim, alpha_dim, bias=False)
            self.vector_prod = nn.Linear(alpha_dim, 1, bias=False)

    def forward(self, M, x, mask=None):
        """
        M -> (seq_len, batch, mem_dim)
        x -> (batch, cand_dim)
        mask -> (batch, seq_len); only consulted by the 'general2' variant

        Returns (attn_pool, alpha):
            attn_pool -> (batch, mem_dim)
            alpha     -> (batch, 1, seq_len)
        """
        if self.att_type=='dot':
            # vector = cand_dim = mem_dim
            M_ = M.permute(1,2,0) # batch, vector, seqlen
            x_ = x.unsqueeze(1) # batch, 1, vector
            alpha = F.softmax(torch.bmm(x_, M_), dim=2) # batch, 1, seqlen
        elif self.att_type=='general':
            M_ = M.permute(1,2,0) # batch, mem_dim, seqlen
            x_ = self.transform(x).unsqueeze(1) # batch, 1, mem_dim
            alpha = F.softmax(torch.bmm(x_, M_), dim=2) # batch, 1, seqlen
        elif self.att_type=='general2':
            if mask is None:
                # No mask given: attend to every position.
                mask = M.new_ones(M.size(1), M.size(0))
            mask_ = mask.unsqueeze(1) # batch, 1, seqlen
            M_ = M.permute(1,2,0) # batch, mem_dim, seqlen
            x_ = self.transform(x).unsqueeze(1) # batch, 1, mem_dim
            alpha_ = F.softmax(torch.bmm(x_, M_)*mask_, dim=2) # batch, 1, seqlen
            # Drop the masked positions, then renormalise so the kept weights sum to 1.
            alpha_masked = alpha_*mask_ # batch, 1, seqlen
            alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True) # batch, 1, 1
            alpha = alpha_masked/alpha_sum # batch, 1, seqlen ; normalized
        else:
            # 'concat' variant
            M_ = M.transpose(0,1) # batch, seqlen, mem_dim
            x_ = x.unsqueeze(1).expand(-1,M.size()[0],-1) # batch, seqlen, cand_dim
            M_x_ = torch.cat([M_,x_],2) # batch, seqlen, mem_dim+cand_dim
            # torch.tanh replaces the deprecated F.tanh alias.
            mx_a = torch.tanh(self.transform(M_x_)) # batch, seqlen, alpha_dim
            alpha = F.softmax(self.vector_prod(mx_a),1).transpose(1,2) # batch, 1, seqlen

        attn_pool = torch.bmm(alpha, M.transpose(0,1))[:,0,:] # batch, mem_dim

        return attn_pool, alpha


class DialogueRNNCell(nn.Module):
    """Single DialogueRNN step: updates the global, party and emotion GRU states."""

    def __init__(self, D_m, D_g, D_p, D_e, listener_state=False,
                            context_attention='simple', D_a=100, dropout=0.5):
        super().__init__()

        self.D_m, self.D_g, self.D_p, self.D_e = D_m, D_g, D_p, D_e
        self.listener_state = listener_state

        # GRU cells for the global (g), party (p) and emotion (e) states.
        self.g_cell = nn.GRUCell(D_m + D_p, D_g)
        self.p_cell = nn.GRUCell(D_m + D_g, D_p)
        self.e_cell = nn.GRUCell(D_p, D_e)
        if listener_state:
            # Extra cell that also updates the non-speaking parties.
            self.l_cell = nn.GRUCell(D_m + D_p, D_p)

        self.dropout = nn.Dropout(dropout)

        if context_attention == 'simple':
            self.attention = SimpleAttention(D_g)
        else:
            self.attention = MatchingAttention(D_g, D_m, D_a, context_attention)

    def _select_parties(self, X, indices):
        """Pick, for every batch row of X (batch, party, dim), the row at its party index."""
        picked = [party_states[party] for party_states, party in zip(X, indices)]
        return torch.stack(picked, 0)

    def forward(self, U, qmask, g_hist, q0, e0):
        """
        U -> batch, D_m
        qmask -> batch, party
        g_hist -> t-1, batch, D_g
        q0 -> batch, party, D_p
        e0 -> batch, self.D_e

        Returns (g_, q_, e_, alpha); alpha is None while g_hist is empty.
        """
        batch = U.size(0)
        n_party = qmask.size(1)
        has_history = g_hist.size(0) != 0
        speaker = torch.argmax(qmask, 1)  # index of the speaking party per row

        # Global state: fed with the utterance and the speaker's previous state.
        if has_history:
            g_prev = g_hist[-1]
        else:
            g_prev = torch.zeros(batch, self.D_g).type(U.type())
        g_input = torch.cat([U, self._select_parties(q0, speaker)], dim=1)
        g_ = self.dropout(self.g_cell(g_input, g_prev))

        # Context vector attended over the previous global states.
        if has_history:
            c_, alpha = self.attention(g_hist, U)
        else:
            c_ = torch.zeros(batch, self.D_g).type(U.type())
            alpha = None

        # Candidate speaker update, computed for every party slot at once.
        p_input = torch.cat([U, c_], dim=1).unsqueeze(1).expand(-1, n_party, -1)
        p_input = p_input.contiguous().view(-1, self.D_m + self.D_g)
        qs_ = self.p_cell(p_input, q0.view(-1, self.D_p)).view(batch, -1, self.D_p)
        qs_ = self.dropout(qs_)

        if self.listener_state:
            u_tiled = U.unsqueeze(1).expand(-1, n_party, -1).contiguous().view(-1, self.D_m)
            speaker_state = self._select_parties(qs_, speaker)
            s_tiled = speaker_state.unsqueeze(1).expand(-1, n_party, -1)
            s_tiled = s_tiled.contiguous().view(-1, self.D_p)
            l_input = torch.cat([u_tiled, s_tiled], 1)
            ql_ = self.l_cell(l_input, q0.view(-1, self.D_p)).view(batch, -1, self.D_p)
            ql_ = self.dropout(ql_)
        else:
            # Listeners keep their previous state.
            ql_ = q0

        # Speaker slots take qs_, all other slots take ql_.
        gate = qmask.unsqueeze(2)
        q_ = ql_ * (1 - gate) + qs_ * gate

        if e0.size(0) == 0:
            e0 = torch.zeros(qmask.size(0), self.D_e).type(U.type())
        e_ = self.dropout(self.e_cell(self._select_parties(q_, speaker), e0))

        return g_, q_, e_, alpha

class DialogueRNN(nn.Module):
    """Runs a DialogueRNNCell over every utterance of a dialogue, in order."""

    def __init__(self, D_m, D_g, D_p, D_e, listener_state=False,
                            context_attention='simple', D_a=100, dropout=0.5):
        super().__init__()

        self.D_m, self.D_g, self.D_p, self.D_e = D_m, D_g, D_p, D_e
        self.dropout = nn.Dropout(dropout)

        self.dialogue_cell = DialogueRNNCell(D_m, D_g, D_p, D_e,
                            listener_state, context_attention, D_a, dropout)

    def forward(self, U, qmask):
        """
        U -> seq_len, batch, D_m
        qmask -> seq_len, batch, party

        Returns (e, alpha): the stacked emotion states (seq_len, batch, D_e)
        and the list of attention weights from every step that had a history.
        """
        tensor_type = U.type()
        n_batch, n_party = qmask.size(1), qmask.size(2)

        g_hist = torch.zeros(0).type(tensor_type)  # empty until the first step
        q_ = torch.zeros(n_batch, n_party, self.D_p).type(tensor_type)  # batch, party, D_p
        e_ = torch.zeros(0).type(tensor_type)  # empty marker; the cell zero-fills it
        e = e_

        alpha = []
        for utterance, speakers in zip(U, qmask):
            g_, q_, e_, step_alpha = self.dialogue_cell(utterance, speakers, g_hist, q_, e_)
            g_hist = torch.cat([g_hist, g_.unsqueeze(0)], 0)
            e = torch.cat([e, e_.unsqueeze(0)], 0)
            if step_alpha is not None:
                alpha.append(step_alpha[:, 0, :])

        return e, alpha  # seq_len, batch, D_e
class BiModel(nn.Module):
    """Bidirectional DialogueRNN classifier that emits per-utterance log-probabilities."""

    def __init__(self, D_m, D_g, D_p, D_e, D_h,
                 n_classes=7, listener_state=False, context_attention='simple', D_a=100, dropout_rec=0.5,
                 dropout=0.5):
        super().__init__()

        self.D_m, self.D_g, self.D_p = D_m, D_g, D_p
        self.D_e, self.D_h = D_e, D_h
        self.n_classes = n_classes

        self.dropout = nn.Dropout(dropout)
        # Dropout applied to the recurrent outputs; `dropout_rec` itself is the
        # dropout rate handed to the two DialogueRNNs.
        self.dropout_rec = nn.Dropout(dropout+0.15)
        rnn_args = (D_m, D_g, D_p, D_e, listener_state, context_attention, D_a, dropout_rec)
        self.dialog_rnn_f = DialogueRNN(*rnn_args)
        self.dialog_rnn_r = DialogueRNN(*rnn_args)
        self.linear = nn.Linear(2*D_e, 2*D_h)
        self.smax_fc = nn.Linear(2*D_h, n_classes)
        self.matchatt = MatchingAttention(2*D_e, 2*D_e, att_type='general2')

    def _reverse_seq(self, X, mask):
        """
        X -> seq_len, batch, dim
        mask -> batch, seq_len

        Reverses the first sum(mask) steps of every sequence and zero-pads the rest.
        """
        lengths = torch.sum(mask, 1).int()
        flipped = [seq[:length].flip(0)
                   for seq, length in zip(X.transpose(0, 1), lengths)]
        return pad_sequence(flipped)

    def forward(self, U, qmask, umask,att2=True):
        """
        U -> seq_len, batch, D_m
        qmask -> seq_len, batch, party
        umask -> batch, seq_len

        Returns (log_prob, alpha, alpha_f, alpha_b); alpha is [] when att2 is False.
        """
        # Forward pass.
        emotions_f, alpha_f = self.dialog_rnn_f(U, qmask)  # seq_len, batch, D_e
        emotions_f = self.dropout_rec(emotions_f)

        # Backward pass: run on the reversed dialogue, then restore the order.
        emotions_b, alpha_b = self.dialog_rnn_r(self._reverse_seq(U, umask),
                                                self._reverse_seq(qmask, umask))
        emotions_b = self.dropout_rec(self._reverse_seq(emotions_b, umask))

        emotions = torch.cat([emotions_f, emotions_b], dim=-1)

        alpha = []
        if att2:
            # Re-express every utterance as an attention pool over the dialogue.
            pooled = []
            for query in emotions:
                att_em, alpha_ = self.matchatt(emotions, query, mask=umask)
                pooled.append(att_em)
                alpha.append(alpha_[:, 0, :])
            features = torch.stack(pooled, dim=0)
        else:
            features = emotions

        hidden = self.dropout(F.relu(self.linear(features)))
        log_prob = F.log_softmax(self.smax_fc(hidden), 2)  # seq_len, batch, n_classes
        return log_prob, alpha, alpha_f, alpha_b

class BiE2EModel(nn.Module):
    """End-to-end bidirectional model over three embedded turns; classifies the last one."""

    def __init__(self, D_emb, D_m, D_g, D_p, D_e, D_h, word_embeddings,
                 n_classes=7, listener_state=False, context_attention='simple', D_a=100, dropout_rec=0.5,
                 dropout=0.5):
        super().__init__()

        self.D_emb = D_emb
        self.D_m, self.D_g, self.D_p = D_m, D_g, D_p
        self.D_e, self.D_h = D_e, D_h
        self.n_classes = n_classes

        self.dropout = nn.Dropout(dropout)
        self.dropout_rec = nn.Dropout(dropout)
        self.turn_rnn = nn.GRU(D_emb, D_m)
        rnn_args = (D_m, D_g, D_p, D_e, listener_state, context_attention, D_a, dropout_rec)
        self.dialog_rnn_f = DialogueRNN(*rnn_args)
        self.dialog_rnn_r = DialogueRNN(*rnn_args)
        self.linear1 = nn.Linear(2*D_e, D_h)
        self.smax_fc = nn.Linear(D_h, n_classes)

        # Embedding table initialised from the supplied vectors and left trainable.
        vocab_size, emb_dim = word_embeddings.shape[0], word_embeddings.shape[1]
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.embedding.weight.data.copy_(word_embeddings)
        self.embedding.weight.requires_grad = True
        self.matchatt = MatchingAttention(2*D_e, 2*D_e, att_type='general2')

    def _reverse_seq(self, X, mask):
        """
        X -> seq_len, batch, dim
        mask -> batch, seq_len

        Reverses the first sum(mask) steps of every sequence and zero-pads the rest.
        """
        lengths = torch.sum(mask, 1).int()
        flipped = [seq[:length].flip(0)
                   for seq, length in zip(X.transpose(0, 1), lengths)]
        return pad_sequence(flipped)

    def forward(self, data, att2=False):
        """
        data -> object exposing turn1, turn2, turn3 index tensors that are fed
                to the embedding layer
        att2 -> when True, attend over all three turns instead of using the last

        Returns log-probabilities of shape (batch, n_classes).
        """
        embedded = [self.embedding(turn) for turn in (data.turn1, data.turn2, data.turn3)]
        first = embedded[0]
        batch = first.size(1)

        # Encode every turn with the shared GRU; keep only the final hidden state.
        finals = []
        for turn in embedded:
            h0 = torch.zeros(1, batch, self.D_m).type(first.type())
            finals.append(self.turn_rnn(turn, h0)[1])
        U = torch.cat(finals, 0)  # 3, batch, D_m

        # Fixed speaker pattern: party 0, party 1, party 0.
        qmask = torch.FloatTensor([[1,0],[0,1],[1,0]]).type(first.type())
        qmask = qmask.unsqueeze(1).expand(-1, batch, -1)

        # All three turns are valid for every batch element.
        umask = torch.FloatTensor([[1,1,1]]).type(first.type())
        umask = umask.expand(batch, -1)

        emotions_f, alpha_f = self.dialog_rnn_f(U, qmask)  # seq_len, batch, D_e
        emotions_f = self.dropout_rec(emotions_f)
        emotions_b, alpha_b = self.dialog_rnn_r(self._reverse_seq(U, umask),
                                                self._reverse_seq(qmask, umask))
        emotions_b = self._reverse_seq(emotions_b, umask)
        emotions = self.dropout_rec(torch.cat([emotions_f, emotions_b], dim=-1))

        last = emotions[-1]
        if att2:
            summary, _ = self.matchatt(emotions, last)
        else:
            summary = last
        hidden = F.relu(self.linear1(summary))
        return F.log_softmax(self.smax_fc(hidden), -1)  # batch, n_classes

class E2EModel(nn.Module):
    """End-to-end unidirectional model over three embedded turns; classifies the last one."""

    def __init__(self, D_emb, D_m, D_g, D_p, D_e, D_h,
                 n_classes=7, listener_state=False, context_attention='simple', D_a=100, dropout_rec=0.5,
                 dropout=0.5):
        super().__init__()

        self.D_emb = D_emb
        self.D_m, self.D_g, self.D_p = D_m, D_g, D_p
        self.D_e, self.D_h = D_e, D_h
        self.n_classes = n_classes

        self.dropout = nn.Dropout(dropout)
        # Dropout applied to the recurrent output; `dropout_rec` itself is the
        # dropout rate handed to the DialogueRNN.
        self.dropout_rec = nn.Dropout(dropout+0.15)
        self.turn_rnn = nn.GRU(D_emb, D_m)
        self.dialog_rnn = DialogueRNN(D_m, D_g, D_p, D_e, listener_state,
                                    context_attention, D_a, dropout_rec)
        self.linear1 = nn.Linear(D_e, D_h)
        self.smax_fc = nn.Linear(D_h, n_classes)

        self.matchatt = MatchingAttention(D_e, D_e, att_type='general2')

    def forward(self, data, word_embeddings, att2=False):
        """
        data -> object exposing turn1, turn2, turn3 index tensors
        word_embeddings -> table indexed by those tensors
        att2 -> when True, attend over all three turns instead of using the last

        Returns log-probabilities of shape (batch, n_classes).
        """
        # seq_len, batch, D_emb for every turn
        embedded = [word_embeddings[turn] for turn in (data.turn1, data.turn2, data.turn3)]
        first = embedded[0]
        batch = first.size(1)

        # Encode every turn with the shared GRU; keep only the final hidden state.
        finals = []
        for turn in embedded:
            h0 = torch.zeros(1, batch, self.D_m).type(first.type())
            finals.append(self.turn_rnn(turn, h0)[1])
        U = torch.cat(finals, 0)  # 3, batch, D_m

        # Fixed speaker pattern: party 0, party 1, party 0.
        qmask = torch.FloatTensor([[1,0],[0,1],[1,0]]).type(first.type())
        qmask = qmask.unsqueeze(1).expand(-1, batch, -1)

        emotions, _ = self.dialog_rnn(U, qmask)  # seq_len, batch, D_e
        emotions = self.dropout_rec(emotions)

        last = emotions[-1]
        if att2:
            summary, _ = self.matchatt(emotions, last)
        else:
            summary = last
        hidden = self.dropout(F.relu(self.linear1(summary)))
        return F.log_softmax(self.smax_fc(hidden), -1)  # batch, n_classes
class Model(nn.Module):
    """Unidirectional DialogueRNN classifier that emits per-utterance log-probabilities."""

    def __init__(self, D_m, D_g, D_p, D_e, D_h,
                 n_classes=7, listener_state=False, context_attention='simple', D_a=100, dropout_rec=0.5,
                 dropout=0.5):
        super(Model, self).__init__()

        self.D_m       = D_m
        self.D_g       = D_g
        self.D_p       = D_p
        self.D_e       = D_e
        self.D_h       = D_h
        self.n_classes = n_classes
        self.dropout   = nn.Dropout(dropout)
        # Dropout applied to the recurrent output; `dropout_rec` itself is the
        # dropout rate handed to the DialogueRNN.
        self.dropout_rec = nn.Dropout(dropout+0.15)
        self.dialog_rnn = DialogueRNN(D_m, D_g, D_p, D_e,listener_state,
                                    context_attention, D_a, dropout_rec)
        self.linear1     = nn.Linear(D_e, D_h)
        self.smax_fc    = nn.Linear(D_h, n_classes)

        self.matchatt = MatchingAttention(D_e,D_e,att_type='general2')

    def forward(self, U, qmask, umask=None, att2=False):
        """
        U -> seq_len, batch, D_m
        qmask -> seq_len, batch, party
        umask -> batch, seq_len; only used by the att2 attention (optional)
        att2 -> when True, replace every utterance by an attention pool over
                the whole dialogue before classification

        Returns log-probabilities of shape (seq_len, batch, n_classes).
        """

        # DialogueRNN returns (emotions, attention weights). The weights are not
        # needed here, but the tuple must be unpacked before dropout is applied.
        emotions, _ = self.dialog_rnn(U, qmask) # seq_len, batch, D_e
        emotions = self.dropout_rec(emotions)

        if att2:
            att_emotions = []
            for t in emotions:
                att_emotions.append(self.matchatt(emotions,t,mask=umask)[0].unsqueeze(0))
            att_emotions = torch.cat(att_emotions,dim=0)
            hidden = F.relu(self.linear1(att_emotions))
        else:
            hidden = F.relu(self.linear1(emotions))
        hidden = self.dropout(hidden)
        log_prob = F.log_softmax(self.smax_fc(hidden), 2) # seq_len, batch, n_classes
        return log_prob

class AVECModel(nn.Module):
    """DialogueRNN regressor that predicts one scalar per utterance."""

    def __init__(self, D_m, D_g, D_p, D_e, D_h, attr, listener_state=False,
            context_attention='simple', D_a=100, dropout_rec=0.5, dropout=0.5):
        super(AVECModel, self).__init__()

        self.D_m         = D_m
        self.D_g         = D_g
        self.D_p         = D_p
        self.D_e         = D_e
        self.D_h         = D_h
        # Stored for callers; the forward pass does not depend on it.
        self.attr        = attr
        self.dropout     = nn.Dropout(dropout)
        self.dropout_rec = nn.Dropout(dropout)
        self.dialog_rnn  = DialogueRNN(D_m, D_g, D_p, D_e,listener_state,
                                    context_attention, D_a, dropout_rec)
        self.linear      = nn.Linear(D_e, D_h)
        self.smax_fc     = nn.Linear(D_h, 1)

    def forward(self, U, qmask):
        """
        U -> seq_len, batch, D_m
        qmask -> seq_len, batch, party

        Returns a flat tensor of batch*seq_len predictions in batch-major order.
        """

        emotions,_ = self.dialog_rnn(U, qmask) # seq_len, batch, D_e
        emotions = self.dropout_rec(emotions)
        hidden = torch.tanh(self.linear(emotions))
        hidden = self.dropout(hidden)
        # Squeeze only the trailing singleton. A bare squeeze() would also drop
        # a batch or sequence axis of size 1 and break the transpose below.
        pred = self.smax_fc(hidden).squeeze(-1) # seq_len, batch
        return pred.transpose(0,1).contiguous().view(-1)

class MaskedNLLLoss(nn.Module):
    """NLL loss that zeroes masked positions and averages over the unmasked ones."""

    def __init__(self, weight=None):
        super().__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight, reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len, n_classes
        target -> batch*seq_len
        mask -> batch, seq_len
        """
        flat_mask = mask.view(-1, 1)  # batch*seq_len, 1
        # Masked rows become all-zero log-probabilities and contribute nothing.
        summed = self.loss(pred * flat_mask, target)
        if self.weight is None:
            normaliser = torch.sum(mask)
        else:
            # Normalise by the total class weight of the unmasked targets.
            normaliser = torch.sum(self.weight[target] * flat_mask.squeeze())
        return summed / normaliser

class MaskedMSELoss(nn.Module):
    """Summed squared error on masked predictions, divided by the mask total."""

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss(reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len
        target -> batch*seq_len
        mask -> batch*seq_len
        """
        masked_pred = pred * mask  # only the prediction is masked, not the target
        squared_error = self.loss(masked_pred, target)
        return squared_error / torch.sum(mask)

# Typed tensor aliases: the CUDA variants when a GPU is available, otherwise
# the CPU variants.
FloatTensor, LongTensor, ByteTensor = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor, torch.cuda.ByteTensor)
    if torch.cuda.is_available()
    else (torch.FloatTensor, torch.LongTensor, torch.ByteTensor)
)

class CNNFeatureExtractor(nn.Module):
    """Embeds the words of each utterance and pools them into one feature vector with 1-D CNNs."""

    def __init__(self, vocab_size, embedding_dim, output_size, filters, kernel_sizes, dropout):
        super(CNNFeatureExtractor, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # One convolution per kernel size, each producing `filters` channels.
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=embedding_dim, out_channels=filters, kernel_size=K) for K in kernel_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(kernel_sizes) * filters, output_size)
        self.feature_dim = output_size


    def init_pretrained_embeddings_from_numpy(self, pretrained_word_vectors):
        """Replace the embedding table with the given numpy matrix and freeze it."""
        self.embedding.weight = nn.Parameter(torch.from_numpy(pretrained_word_vectors).float())
        self.embedding.weight.requires_grad = False


    def forward(self, x, umask):
        """
        x -> num_utt, batch, num_words (word indices)
        umask -> batch, num_utt

        Returns features of shape (num_utt, batch, feature_dim), zeroed where
        umask is 0.
        """

        num_utt, batch, num_words = x.size()

        # Follow the device the embedding lives on instead of assuming CUDA
        # whenever a GPU happens to be present.
        x = x.to(device=self.embedding.weight.device, dtype=torch.long)
        x = x.reshape(-1, num_words) # (num_utt * batch, num_words)
        emb = self.embedding(x) # (num_utt * batch, num_words, embedding_dim)
        emb = emb.transpose(-2, -1).contiguous() # (num_utt * batch, embedding_dim, num_words)

        convoluted = [F.relu(conv(emb)) for conv in self.convs]
        # Max over time. Squeeze only the pooled axis so that a single-row
        # input (num_utt * batch == 1) keeps its row dimension.
        pooled = [F.max_pool1d(c, c.size(2)).squeeze(2) for c in convoluted]
        concated = torch.cat(pooled, 1) # (num_utt * batch, len(kernel_sizes) * filters)
        features = F.relu(self.fc(self.dropout(concated))) # (num_utt * batch, feature_dim)
        features = features.view(num_utt, batch, -1) # (num_utt, batch, feature_dim)

        mask = umask.unsqueeze(-1).to(features) # (batch, num_utt, 1), dtype/device of features
        mask = mask.transpose(0, 1) # (num_utt, batch, 1); broadcasts over feature_dim
        features = features * mask # padded utterances become all-zero rows

        return features

class DailyDialogueModel(nn.Module):
    """CNN utterance encoder followed by a bidirectional DialogueRNN classifier."""

    def __init__(self, D_m, D_g, D_p, D_e, D_h,
                 vocab_size, n_classes=7, embedding_dim=300,
                 cnn_output_size=100, cnn_filters=50, cnn_kernel_sizes=(3,4,5), cnn_dropout=0.5,
                 listener_state=False, context_attention='simple', D_a=100, dropout_rec=0.5,
                 dropout=0.5, att2=True):

        super(DailyDialogueModel, self).__init__()

        self.cnn_feat_extractor = CNNFeatureExtractor(vocab_size, embedding_dim, cnn_output_size, cnn_filters, cnn_kernel_sizes, cnn_dropout)

        self.D_m       = D_m
        self.D_g       = D_g
        self.D_p       = D_p
        self.D_e       = D_e
        self.D_h       = D_h
        self.dropout   = nn.Dropout(dropout)
        self.dropout_rec = nn.Dropout(dropout_rec)
        self.dialog_rnn_f = DialogueRNN(D_m, D_g, D_p, D_e, listener_state,
                                    context_attention, D_a, dropout_rec)
        self.dialog_rnn_r = DialogueRNN(D_m, D_g, D_p, D_e, listener_state,
                                    context_attention, D_a, dropout_rec)
        self.linear     = nn.Linear(2*D_e, 2*D_h)
        self.matchatt = MatchingAttention(2*D_e,2*D_e,att_type='general2')

        self.n_classes = n_classes
        self.smax_fc    = nn.Linear(2*D_h, n_classes)
        # When True, forward() attends over the whole dialogue for every utterance.
        self.att2 = att2


    def init_pretrained_embeddings(self, pretrained_word_vectors):
        """Load pretrained word vectors (numpy matrix) into the CNN's embedding layer."""
        self.cnn_feat_extractor.init_pretrained_embeddings_from_numpy(pretrained_word_vectors)


    def _reverse_seq(self, X, mask):
        """
        X -> seq_len, batch, dim
        mask -> batch, seq_len

        Reverses the first sum(mask) steps of every sequence and zero-pads the rest.
        """
        X_ = X.transpose(0,1)
        mask_sum = torch.sum(mask, 1).int()

        xfs = []
        for x, c in zip(X_, mask_sum):
            xf = torch.flip(x[:c], [0])
            xfs.append(xf)

        return pad_sequence(xfs)


    def forward(self, input_seq, qmask, umask):
        """
        input_seq -> word indices handed to the CNN feature extractor
        qmask -> seq_len, batch, party
        umask -> batch, seq_len

        Returns (log_prob, alpha, alpha_f, alpha_b); alpha is [] when att2 is off.
        """

        U = self.cnn_feat_extractor(input_seq, umask) # seq_len, batch, D_m

        emotions_f, alpha_f = self.dialog_rnn_f(U, qmask) # seq_len, batch, D_e
        emotions_f = self.dropout_rec(emotions_f)
        # Backward direction: run on the reversed dialogue, then restore the order.
        rev_U = self._reverse_seq(U, umask)
        rev_qmask = self._reverse_seq(qmask, umask)
        emotions_b, alpha_b = self.dialog_rnn_r(rev_U, rev_qmask)
        emotions_b = self._reverse_seq(emotions_b, umask)
        emotions_b = self.dropout_rec(emotions_b)
        emotions = torch.cat([emotions_f, emotions_b], dim=-1)
        # Defined up front so the return below is valid when att2 is off
        # (previously an UnboundLocalError).
        alpha = []
        if self.att2:
            att_emotions = []
            for t in emotions:
                att_em, alpha_ = self.matchatt(emotions,t,mask=umask)
                att_emotions.append(att_em.unsqueeze(0))
                alpha.append(alpha_[:,0,:])
            att_emotions = torch.cat(att_emotions,dim=0)
            hidden = F.relu(self.linear(att_emotions))
        else:
            hidden = F.relu(self.linear(emotions))
        hidden = self.dropout(hidden)
        log_prob = F.log_softmax(self.smax_fc(hidden), 2) # seq_len, batch, n_classes
        return log_prob, alpha, alpha_f, alpha_b

class UnMaskedWeightedNLLLoss(nn.Module):
    """Summed NLL loss; divided by the targets' total class weight when weights are given."""

    def __init__(self, weight=None):
        super().__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight, reduction='sum')

    def forward(self, pred, target):
        """
        pred -> batch*seq_len, n_classes
        target -> batch*seq_len
        """
        summed = self.loss(pred, target)
        if self.weight is None:
            # Without class weights the raw sum is returned as-is.
            return summed
        return summed / torch.sum(self.weight[target])


In [None]:
#Dataloader

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pickle
import pandas as pd

class IEMOCAPDataset(Dataset):
    """Dialogue-level dataset backed by a single pickled 9-tuple of per-video dicts."""

    def __init__(self, path, train=True):
        """
        path -> pickle file holding the 9-tuple unpacked below
        train -> iterate over trainVid when True, testVid otherwise
        """
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files from a trusted source.
        # The context manager closes the file handle deterministically.
        with open(path, 'rb') as f:
            (self.videoIDs, self.videoSpeakers, self.videoLabels, self.videoText,
             self.videoAudio, self.videoVisual, self.videoSentence, self.trainVid,
             self.testVid) = pickle.load(f, encoding='latin1')
        # label index mapping = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5}
        self.keys = list(self.trainVid if train else self.testVid)

        self.len = len(self.keys)

    def __getitem__(self, index):
        """Return (text, visual, audio, speaker one-hot, utterance mask, labels, vid)."""
        vid = self.keys[index]
        return torch.FloatTensor(self.videoText[vid]),\
               torch.FloatTensor(self.videoVisual[vid]),\
               torch.FloatTensor(self.videoAudio[vid]),\
               torch.FloatTensor([[1,0] if x=='M' else [0,1] for x in\
                                  self.videoSpeakers[vid]]),\
               torch.FloatTensor([1]*len(self.videoLabels[vid])),\
               torch.LongTensor(self.videoLabels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        """Batch samples: fields 0-3 padded time-major, fields 4-5 padded
        batch-first, remaining fields returned as plain lists."""
        # Transpose the list of samples into per-field columns with zip; no
        # DataFrame is needed for that.
        columns = [list(col) for col in zip(*data)]
        return [pad_sequence(col) if i<4 else pad_sequence(col, True) if i<6 else col
                for i, col in enumerate(columns)]

class AVECDataset(Dataset):
    """Dialogue-level dataset with real-valued labels, backed by a pickled 9-tuple."""

    def __init__(self, path, train=True):
        """
        path -> pickle file holding the 9-tuple unpacked below
        train -> iterate over trainVid when True, testVid otherwise
        """
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files from a trusted source.
        # The context manager closes the file handle deterministically.
        with open(path, 'rb') as f:
            (self.videoIDs, self.videoSpeakers, self.videoLabels, self.videoText,
             self.videoAudio, self.videoVisual, self.videoSentence,
             self.trainVid, self.testVid) = pickle.load(f, encoding='latin1')

        self.keys = list(self.trainVid if train else self.testVid)

        self.len = len(self.keys)

    def __getitem__(self, index):
        """Return (text, visual, audio, speaker one-hot, utterance mask, labels)."""
        vid = self.keys[index]
        return torch.FloatTensor(self.videoText[vid]),\
               torch.FloatTensor(self.videoVisual[vid]),\
               torch.FloatTensor(self.videoAudio[vid]),\
               torch.FloatTensor([[1,0] if x=='user' else [0,1] for x in\
                                  self.videoSpeakers[vid]]),\
               torch.FloatTensor([1]*len(self.videoLabels[vid])),\
               torch.FloatTensor(self.videoLabels[vid])

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        """Batch samples: fields 0-3 padded time-major, the rest padded batch-first."""
        # Transpose the list of samples into per-field columns with zip; no
        # DataFrame is needed for that.
        columns = [list(col) for col in zip(*data)]
        return [pad_sequence(col) if i<4 else pad_sequence(col, True)
                for i, col in enumerate(columns)]
class MELDDataset(Dataset):
    """Dialogue-level dataset backed by a pickled 9-tuple carrying two label sets."""

    def __init__(self, path, n_classes, train=True):
        """
        path -> pickle file holding the 9-tuple unpacked below
        n_classes -> 7 selects the third tuple entry as labels, 3 selects the last
        train -> iterate over trainVid when True, testVid otherwise

        Raises ValueError for any other n_classes.
        """
        if n_classes not in (3, 7):
            # Fail here rather than with an AttributeError on first access.
            raise ValueError('n_classes must be 3 or 7, got {!r}'.format(n_classes))
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files from a trusted source.
        # The context manager closes the file handle deterministically.
        with open(path, 'rb') as f:
            (self.videoIDs, self.videoSpeakers, labels_7, self.videoText,
             self.videoAudio, self.videoSentence, self.trainVid,
             self.testVid, labels_3) = pickle.load(f)
        self.videoLabels = labels_3 if n_classes == 3 else labels_7
        # label index mapping = {'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger':6}
        self.keys = list(self.trainVid if train else self.testVid)

        self.len = len(self.keys)

    def __getitem__(self, index):
        """Return (text, audio, speakers, utterance mask, labels, vid)."""
        vid = self.keys[index]
        return torch.FloatTensor(self.videoText[vid]),\
               torch.FloatTensor(self.videoAudio[vid]),\
               torch.FloatTensor(self.videoSpeakers[vid]),\
               torch.FloatTensor([1]*len(self.videoLabels[vid])),\
               torch.LongTensor(self.videoLabels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        """Batch samples: fields 0-2 padded time-major, fields 3-4 padded
        batch-first, remaining fields returned as plain lists."""
        # Transpose the list of samples into per-field columns with zip; no
        # DataFrame is needed for that.
        columns = [list(col) for col in zip(*data)]
        return [pad_sequence(col) if i<3 else pad_sequence(col, True) if i<5 else col
                for i, col in enumerate(columns)]


class DailyDialogueDataset(Dataset):
    """Conversation-level DailyDialogue dataset read from a pickled file.

    The pickle must contain an 8-tuple: speakers, input sequences, max sequence
    lengths, act labels, emotion labels, train ids, test ids and valid ids.

    Args:
        split: one of ``'train'``, ``'test'`` or ``'valid'``.
        path: location of the pickled file.

    Raises:
        ValueError: if ``split`` is not one of the three supported names.
    """

    def __init__(self, split, path):
        # Reject unknown splits up front; previously self.keys stayed undefined
        # and the constructor died with an AttributeError further down.
        if split not in ('train', 'test', 'valid'):
            raise ValueError(
                "split must be 'train', 'test' or 'valid', got {!r}".format(split))

        # NOTE(security): pickle executes arbitrary code on load -- only use
        # files obtained from a trusted source.
        with open(path, 'rb') as handle:  # context manager closes the file
            (self.Speakers, self.InputSequence, self.InputMaxSequenceLength,
             self.ActLabels, self.EmotionLabels,
             self.trainId, self.testId, self.validId) = pickle.load(handle)

        ids_by_split = {'train': self.trainId,
                        'test': self.testId,
                        'valid': self.validId}
        self.keys = list(ids_by_split[split])

        self.len = len(self.keys)

    def __getitem__(self, index):
        """Return the tensors and metadata of one conversation.

        Returns:
            tuple of (input sequence, one-hot speakers, mask of ones, act
            labels, emotion labels, max sequence length, conversation id).
            Speaker ``'0'`` is encoded as ``[1, 0]``, anything else as
            ``[0, 1]``.
        """
        conv = self.keys[index]

        return torch.LongTensor(self.InputSequence[conv]), \
                torch.FloatTensor([[1,0] if x=='0' else [0,1] for x in self.Speakers[conv]]),\
                torch.FloatTensor([1]*len(self.ActLabels[conv])), \
                torch.LongTensor(self.ActLabels[conv]), \
                torch.LongTensor(self.EmotionLabels[conv]), \
                self.InputMaxSequenceLength[conv], \
                conv

    def __len__(self):
        """Return the number of conversations in the selected split."""
        return self.len



class DailyDialoguePadCollate:
    """Callable collate function for ``DailyDialogueDataset`` samples.

    Args:
        dim: dimension of the input-sequence tensor that is padded so all
            samples of a batch can be stacked.
    """

    def __init__(self, dim=0):
        self.dim = dim

    def pad_tensor(self, vec, pad, dim):
        """Append long-typed zeros to ``vec`` along ``dim`` up to size ``pad``."""
        pad_size = list(vec.shape)
        pad_size[dim] = pad - vec.size(dim)
        # Allocate the padding as long directly instead of float-then-cast.
        return torch.cat([vec, torch.zeros(pad_size, dtype=torch.long)], dim=dim)

    def pad_collate(self, batch):
        """Pad every tensor in ``batch`` along ``self.dim`` and stack them.

        Returns:
            tensor with a new leading batch dimension.
        """
        tensors = list(batch)

        # find longest sequence
        max_len = max(t.shape[self.dim] for t in tensors)

        # pad according to max_len
        padded = [self.pad_tensor(t, pad=max_len, dim=self.dim) for t in tensors]

        # stack all
        return torch.stack(padded, dim=0)

    def __call__(self, batch):
        """Collate a list of dataset samples into batch tensors.

        Returns:
            list with one entry per sample field: field 0 is padded with
            ``pad_collate`` and transposed so the batch dimension is second;
            field 1 is padded sequence-first; fields 2-4 are padded
            batch-first; remaining fields are returned as plain lists.
        """
        # Transpose samples into per-field lists. pad_sequence is documented
        # for lists of tensors; a pandas Series is not accepted by every
        # PyTorch release.
        fields = [list(column) for column in zip(*batch)]

        collated = []
        for position, column in enumerate(fields):
            if position == 0:
                collated.append(self.pad_collate(column).transpose(1, 0).contiguous())
            elif position == 1:
                collated.append(pad_sequence(column))
            elif position < 5:
                collated.append(pad_sequence(column, True))
            else:
                collated.append(column)
        return collated


In [None]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pickle
import pandas as pd
import numpy as np


import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim

import argparse
import time
import pickle

from sklearn.metrics import f1_score, confusion_matrix, accuracy_score,\
                        classification_report, precision_recall_fscore_support

# from model import BiModel, Model, MaskedNLLLoss
# from dataloader import MELDDataset
# Seed PyTorch as well as NumPy: SubsetRandomSampler shuffling, layer
# initialisation and dropout draw from PyTorch's generator, so seeding NumPy
# alone does not make a run repeatable.
np.random.seed(1234)
torch.manual_seed(1234)

def get_train_valid_sampler(trainset, valid=0.1):
    """Split the indices of ``trainset`` into train and validation samplers.

    The first ``int(valid * len(trainset))`` indices go to validation and the
    rest to training; each sampler then visits its own indices in random order.

    Args:
        trainset: any object supporting ``len()``.
        valid: fraction of the indices reserved for validation.

    Returns:
        ``(train_sampler, valid_sampler)`` as ``SubsetRandomSampler`` objects.
    """
    n_items = len(trainset)
    n_valid = int(valid * n_items)
    positions = list(range(n_items))
    valid_part, train_part = positions[:n_valid], positions[n_valid:]
    return SubsetRandomSampler(train_part), SubsetRandomSampler(valid_part)

def get_MELD_loaders(path, n_classes, batch_size=32, valid=0.1, num_workers=0, pin_memory=False):
    """Create the train, validation and test ``DataLoader`` objects for MELD.

    Train and validation loaders share one ``MELDDataset`` built from the
    training split and differ only in their samplers; the test loader wraps a
    second ``MELDDataset`` built with ``train=False``.

    Args:
        path: pickled MELD feature file handed to ``MELDDataset``.
        n_classes: label granularity handed to ``MELDDataset``.
        batch_size: number of dialogues per batch for all three loaders.
        valid: fraction of the training split routed to the validation sampler.
        num_workers: forwarded to every ``DataLoader``.
        pin_memory: forwarded to every ``DataLoader``.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    # Options common to all three loaders.
    shared_opts = dict(batch_size=batch_size,
                       num_workers=num_workers,
                       pin_memory=pin_memory)

    train_data = MELDDataset(path=path, n_classes=n_classes)
    sampler_for_train, sampler_for_valid = get_train_valid_sampler(train_data, valid)

    train_loader = DataLoader(train_data,
                              sampler=sampler_for_train,
                              collate_fn=train_data.collate_fn,
                              **shared_opts)
    valid_loader = DataLoader(train_data,
                              sampler=sampler_for_valid,
                              collate_fn=train_data.collate_fn,
                              **shared_opts)

    test_data = MELDDataset(path=path, n_classes=n_classes, train=False)
    test_loader = DataLoader(test_data,
                             collate_fn=test_data.collate_fn,
                             **shared_opts)

    return train_loader, valid_loader, test_loader

def train_or_eval_model(model, loss_function, dataloader, epoch, optimizer=None, train=False):
    """Run one pass over ``dataloader``, either training or evaluating ``model``.

    Reads the notebook-level globals ``cuda`` (move batches to the GPU) and
    ``feature_type`` (``"audio"``, ``"text"`` or anything else for the
    concatenation of text and audio features).

    Args:
        model: called as ``model(features, qmask, umask)`` and expected to
            return ``(log_prob, alpha, alpha_f, alpha_b)``.
        loss_function: called as ``loss_function(log_probs, labels, umask)``.
        dataloader: yields batches of ``(text, audio, qmask, umask, label, ids)``.
        epoch: current epoch index; not used by the computation.
        optimizer: required when ``train`` is True.
        train: update the model when True, only evaluate otherwise.

    Returns:
        For a non-empty dataloader an 8-tuple
        ``(avg_loss, avg_accuracy, labels, preds, masks, avg_fscore,
        [alphas, alphas_f, alphas_b, vids], class_report)``; the attention
        lists and ids are only filled in evaluation mode.
        For an empty dataloader a 7-tuple of NaN / empty placeholders without
        the class report. The shorter tuple is kept as-is because existing
        callers unpack it.

    Raises:
        AssertionError: if ``train`` is True and no optimizer is given.
    """
    losses, preds, labels, masks = [], [], [], []
    alphas, alphas_f, alphas_b, vids = [], [], [], []
    assert not train or optimizer is not None
    if train:
        model.train()
    else:
        model.eval()

    # Autograd is only needed while training. Disabling it during evaluation
    # avoids building graphs and keeps the stored attention tensors from
    # holding on to them.
    with torch.set_grad_enabled(train):
        for data in dataloader:
            if train:
                optimizer.zero_grad()

            # The last batch entry holds the dialogue ids and stays on the CPU.
            tensors = data[:-1]
            if cuda:
                tensors = [d.cuda() for d in tensors]
            textf, acouf, qmask, umask, label = tensors

            if feature_type == "audio":
                features = acouf
            elif feature_type == "text":
                features = textf
            else:
                features = torch.cat((textf, acouf), dim=-1)
            log_prob, alpha, alpha_f, alpha_b = model(features, qmask, umask) # seq_len, batch, n_classes

            lp_ = log_prob.transpose(0,1).contiguous().view(-1,log_prob.size()[2]) # batch*seq_len, n_classes
            labels_ = label.view(-1) # batch*seq_len
            loss = loss_function(lp_, labels_, umask)

            pred_ = torch.argmax(lp_,1) # batch*seq_len
            preds.append(pred_.detach().cpu().numpy())
            labels.append(labels_.detach().cpu().numpy())
            masks.append(umask.view(-1).cpu().numpy())

            # Weight the batch loss by its number of unmasked utterances so the
            # final average is per utterance, not per batch.
            losses.append(loss.item()*masks[-1].sum())
            if train:
                loss.backward()
                optimizer.step()
            else:
                alphas += alpha
                alphas_f += alpha_f
                alphas_b += alpha_b
                vids += data[-1]

    if not preds:
        # Empty dataloader (e.g. a validation split of size zero).
        return float('nan'), float('nan'), [], [], [], float('nan'),[]

    preds  = np.concatenate(preds)
    labels = np.concatenate(labels)
    masks  = np.concatenate(masks)

    avg_loss = round(np.sum(losses)/np.sum(masks),4)
    avg_accuracy = round(accuracy_score(labels,preds,sample_weight=masks)*100,2)
    avg_fscore = round(f1_score(labels,preds,sample_weight=masks,average='weighted')*100,2)
    class_report = classification_report(labels,preds,sample_weight=masks,digits=4)
    return avg_loss, avg_accuracy, labels, preds, masks,avg_fscore, [alphas, alphas_f, alphas_b, vids], class_report

# ---- Device -----------------------------------------------------------------
# `cuda` is read by train_or_eval_model to decide whether batches go to the GPU.
cuda = torch.cuda.is_available()
if cuda:
    print('Running on GPU')
else:
    print('Running on CPU')

# ---- Logging ----------------------------------------------------------------
tensorboard = True
writer = None
if tensorboard:
    from tensorboardX import SummaryWriter
    # Create the writer only when logging is enabled; it used to be created
    # unconditionally, which raised NameError as soon as the flag was False.
    writer = SummaryWriter()

# ---- Experiment configuration -------------------------------------------------
# choose between 'sentiment' or 'emotion'
classification_type = 'emotion'
feature_type = 'multimodal'

data_path = "/content/drive/MyDrive/DialogueRNN_features/DialogueRNN_features/MELD_features/"
batch_size = 30
n_epochs = 100
active_listener = False
attention = 'general'
class_weight = False
dropout = 0.1
rec_dropout = 0.1
l2 = 0.00001
lr = 0.0005

# Input feature size depends on which modalities are fed to the model.
if feature_type == 'text':
    print("Running on the text features........")
    D_m = 600
elif feature_type == 'audio':
    print("Running on the audio features........")
    D_m = 300
else:
    print("Running on the multimodal features........")
    D_m = 900
D_g = 150
D_p = 150
D_e = 100
D_h = 100

D_a = 100 # concat attention

# 7 classes for emotion, 3 otherwise; uniform class weights in both cases.
n_classes = 7 if classification_type.strip().lower() == 'emotion' else 3
loss_weights = torch.ones(n_classes)

# ---- Model, loss, optimizer ---------------------------------------------------
model = BiModel(D_m, D_g, D_p, D_e, D_h,
                n_classes=n_classes,
                listener_state=active_listener,
                context_attention=attention,
                dropout_rec=rec_dropout,
                dropout=dropout)

if cuda:
    model.cuda()
if class_weight:
    loss_function  = MaskedNLLLoss(loss_weights.cuda() if cuda else loss_weights)
else:
    loss_function = MaskedNLLLoss()
optimizer = optim.Adam(model.parameters(),
                       lr=lr,
                       weight_decay=l2)

# ---- Data ---------------------------------------------------------------------
# valid=0.0 leaves the validation loader empty, so the validation metrics
# printed below are NaN.
train_loader, valid_loader, test_loader =\
        get_MELD_loaders(data_path + 'MELD_features_raw.pkl', n_classes,
                            valid=0.0,
                            batch_size=batch_size,
                            num_workers=0)

best_fscore, best_loss, best_label, best_pred, best_mask, best_attn =\
        None, None, None, None, None, None

# ---- Training loop ------------------------------------------------------------
for e in range(n_epochs):
    start_time = time.time()
    # train_or_eval_model returns 8 values normally but only 7 for an empty
    # loader; star-unpacking keeps these calls valid in both cases (the
    # validation call used to break as soon as valid > 0).
    train_loss, train_acc, _, _, _, train_fscore, *_ = train_or_eval_model(
        model, loss_function, train_loader, e, optimizer, True)
    valid_loss, valid_acc, _, _, _, val_fscore, *_ = train_or_eval_model(
        model, loss_function, valid_loader, e)
    test_loss, test_acc, test_label, test_pred, test_mask, test_fscore, attentions, test_class_report =\
        train_or_eval_model(model, loss_function, test_loader, e)

    # NOTE(review): the best epoch is selected on the test F-score, so the
    # reported number is an optimistic estimate.
    if best_fscore is None or best_fscore < test_fscore:
        best_fscore, best_loss, best_label, best_pred, best_mask, best_attn =\
                test_fscore, test_loss, test_label, test_pred, test_mask, attentions

    print('epoch {} train_loss {} train_acc {} train_fscore {} valid_loss {} valid_acc {} val_fscore {} test_loss {} test_acc {} test_fscore {} time {}'.format(
        e+1, train_loss, train_acc, train_fscore, valid_loss, valid_acc, val_fscore,
        test_loss, test_acc, test_fscore, round(time.time()-start_time,2)))
    print(test_class_report)
if writer is not None:
    writer.close()

# ---- Final report for the best epoch -----------------------------------------
print('Test performance..')
print('Fscore {} accuracy {}'.format(best_fscore,
                                 round(accuracy_score(best_label,best_pred,sample_weight=best_mask)*100,2)))
print(classification_report(best_label,best_pred,sample_weight=best_mask,digits=4))
print(confusion_matrix(best_label,best_pred,sample_weight=best_mask))


Running on GPU
Running on the multimodal features........


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 1 train_loss 1.5908 train_acc 45.52 train_fscore 29.57 valid_loss nan valid_acc nan val_fscore nan test_loss 1.5283 test_acc 48.12 test_fscore 31.27 time 14.46
              precision    recall  f1-score   support

           0     0.4812    1.0000    0.6498    1256.0
           1     0.0000    0.0000    0.0000     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.0000    0.0000    0.0000     208.0
           4     0.0000    0.0000    0.0000     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.0000    0.0000    0.0000     345.0

    accuracy                         0.4812    2610.0
   macro avg     0.0687    0.1429    0.0928    2610.0
weighted avg     0.2316    0.4812    0.3127    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 2 train_loss 1.4274 train_acc 49.51 train_fscore 37.55 valid_loss nan valid_acc nan val_fscore nan test_loss 1.3974 test_acc 50.27 test_fscore 44.84 time 13.42
              precision    recall  f1-score   support

           0     0.7080    0.7818    0.7431    1256.0
           1     0.0000    0.0000    0.0000     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.0000    0.0000    0.0000     208.0
           4     0.2920    0.2463    0.2672     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.2613    0.6696    0.3759     345.0

    accuracy                         0.5027    2610.0
   macro avg     0.1802    0.2425    0.1980    2610.0
weighted avg     0.4202    0.5027    0.4484    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 3 train_loss 1.2175 train_acc 56.34 train_fscore 50.25 valid_loss nan valid_acc nan val_fscore nan test_loss 1.4173 test_acc 50.5 test_fscore 46.18 time 13.57
              precision    recall  f1-score   support

           0     0.7351    0.7182    0.7265    1256.0
           1     0.0000    0.0000    0.0000     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.0000    0.0000    0.0000     208.0
           4     0.2779    0.6070    0.3812     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3406    0.4986    0.4047     345.0

    accuracy                         0.5050    2610.0
   macro avg     0.1934    0.2605    0.2161    2610.0
weighted avg     0.4416    0.5050    0.4618    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 4 train_loss 1.1733 train_acc 58.79 train_fscore 52.54 valid_loss nan valid_acc nan val_fscore nan test_loss 1.3317 test_acc 55.56 test_fscore 50.22 time 13.83
              precision    recall  f1-score   support

           0     0.7070    0.8320    0.7644    1256.0
           1     0.2469    0.0712    0.1105     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.0000    0.0000    0.0000     208.0
           4     0.3675    0.5622    0.4444     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3664    0.4609    0.4082     345.0

    accuracy                         0.5556    2610.0
   macro avg     0.2411    0.2752    0.2468    2610.0
weighted avg     0.4719    0.5556    0.5022    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 5 train_loss 1.1291 train_acc 60.73 train_fscore 55.41 valid_loss nan valid_acc nan val_fscore nan test_loss 1.326 test_acc 55.4 test_fscore 52.07 time 13.7
              precision    recall  f1-score   support

           0     0.7124    0.8145    0.7600    1256.0
           1     0.2878    0.3452    0.3139     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2162    0.0385    0.0653     208.0
           4     0.4941    0.3109    0.3817     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3528    0.5594    0.4327     345.0

    accuracy                         0.5540    2610.0
   macro avg     0.2948    0.2955    0.2791    2610.0
weighted avg     0.5138    0.5540    0.5207    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 6 train_loss 1.1099 train_acc 61.97 train_fscore 58.01 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2975 test_acc 56.74 test_fscore 53.83 time 13.49
              precision    recall  f1-score   support

           0     0.7109    0.8240    0.7633    1256.0
           1     0.3405    0.3950    0.3657     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2088    0.0913    0.1271     208.0
           4     0.5684    0.3308    0.4182     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3638    0.5304    0.4316     345.0

    accuracy                         0.5674    2610.0
   macro avg     0.3132    0.3102    0.3008    2610.0
weighted avg     0.5310    0.5674    0.5383    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 7 train_loss 1.0573 train_acc 65.32 train_fscore 62.21 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2561 test_acc 58.58 test_fscore 55.35 time 13.63
              precision    recall  f1-score   support

           0     0.7179    0.8002    0.7568    1256.0
           1     0.5190    0.3879    0.4440     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2564    0.0481    0.0810     208.0
           4     0.4165    0.6144    0.4965     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4293    0.4580    0.4432     345.0

    accuracy                         0.5858    2610.0
   macro avg     0.3342    0.3298    0.3173    2610.0
weighted avg     0.5427    0.5858    0.5535    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 8 train_loss 1.007 train_acc 67.27 train_fscore 63.94 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2556 test_acc 57.97 test_fscore 56.13 time 13.69
              precision    recall  f1-score   support

           0     0.7206    0.7906    0.7540    1256.0
           1     0.5283    0.3986    0.4544     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2189    0.1779    0.1963     208.0
           4     0.5373    0.4478    0.4885     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3702    0.5536    0.4437     345.0

    accuracy                         0.5797    2610.0
   macro avg     0.3393    0.3384    0.3338    2610.0
weighted avg     0.5528    0.5797    0.5613    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 9 train_loss 0.9943 train_acc 67.61 train_fscore 64.77 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2415 test_acc 59.31 test_fscore 56.16 time 13.52
              precision    recall  f1-score   support

           0     0.7178    0.7978    0.7557    1256.0
           1     0.4706    0.5409    0.5033     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.1860    0.0385    0.0637     208.0
           4     0.4966    0.5473    0.5207     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4099    0.4812    0.4427     345.0

    accuracy                         0.5931    2610.0
   macro avg     0.3258    0.3437    0.3266    2610.0
weighted avg     0.5416    0.5931    0.5616    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 10 train_loss 0.9716 train_acc 68.15 train_fscore 65.13 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2363 test_acc 57.66 test_fscore 56.65 time 13.51
              precision    recall  f1-score   support

           0     0.7432    0.7373    0.7402    1256.0
           1     0.4578    0.5409    0.4959     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2372    0.2452    0.2411     208.0
           4     0.4989    0.5423    0.5197     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4158    0.4580    0.4359     345.0

    accuracy                         0.5766    2610.0
   macro avg     0.3361    0.3605    0.3475    2610.0
weighted avg     0.5576    0.5766    0.5665    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 11 train_loss 0.9663 train_acc 68.66 train_fscore 66.01 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2322 test_acc 58.12 test_fscore 56.42 time 13.83
              precision    recall  f1-score   support

           0     0.7271    0.7635    0.7449    1256.0
           1     0.4961    0.4555    0.4750     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2636    0.1635    0.2018     208.0
           4     0.5370    0.4876    0.5111     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3711    0.5797    0.4525     345.0

    accuracy                         0.5812    2610.0
   macro avg     0.3421    0.3500    0.3407    2610.0
weighted avg     0.5561    0.5812    0.5642    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 12 train_loss 0.9599 train_acc 68.66 train_fscore 65.89 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2286 test_acc 59.23 test_fscore 57.04 time 13.89
              precision    recall  f1-score   support

           0     0.7153    0.8041    0.7571    1256.0
           1     0.4690    0.4840    0.4764     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2901    0.1827    0.2242     208.0
           4     0.5795    0.4353    0.4972     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3937    0.5420    0.4561     345.0

    accuracy                         0.5923    2610.0
   macro avg     0.3496    0.3497    0.3444    2610.0
weighted avg     0.5591    0.5923    0.5704    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 13 train_loss 0.9531 train_acc 68.98 train_fscore 66.42 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2242 test_acc 58.89 test_fscore 56.18 time 13.86
              precision    recall  f1-score   support

           0     0.7174    0.7922    0.7529    1256.0
           1     0.4774    0.4875    0.4824     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2821    0.1058    0.1538     208.0
           4     0.4381    0.6070    0.5089     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4618    0.4029    0.4303     345.0

    accuracy                         0.5889    2610.0
   macro avg     0.3395    0.3422    0.3326    2610.0
weighted avg     0.5476    0.5889    0.5618    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 14 train_loss 0.9453 train_acc 69.15 train_fscore 66.73 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2108 test_acc 59.39 test_fscore 57.24 time 13.64
              precision    recall  f1-score   support

           0     0.7303    0.7890    0.7585    1256.0
           1     0.4461    0.5302    0.4846     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3626    0.1587    0.2207     208.0
           4     0.5179    0.5050    0.5113     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3991    0.5043    0.4456     345.0

    accuracy                         0.5939    2610.0
   macro avg     0.3509    0.3553    0.3458    2610.0
weighted avg     0.5609    0.5939    0.5724    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 15 train_loss 0.9385 train_acc 69.25 train_fscore 66.68 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2234 test_acc 57.32 test_fscore 56.63 time 13.75
              precision    recall  f1-score   support

           0     0.7428    0.7381    0.7404    1256.0
           1     0.4528    0.5125    0.4808     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2310    0.3365    0.2740     208.0
           4     0.5098    0.5174    0.5136     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4414    0.4261    0.4336     345.0

    accuracy                         0.5732    2610.0
   macro avg     0.3397    0.3615    0.3489    2610.0
weighted avg     0.5615    0.5732    0.5663    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 16 train_loss 0.9311 train_acc 69.37 train_fscore 67.02 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2113 test_acc 59.85 test_fscore 57.44 time 13.6
              precision    recall  f1-score   support

           0     0.7180    0.8089    0.7608    1256.0
           1     0.4583    0.5089    0.4823     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3636    0.1731    0.2345     208.0
           4     0.5262    0.4751    0.4993     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4181    0.5101    0.4595     345.0

    accuracy                         0.5985    2610.0
   macro avg     0.3549    0.3537    0.3481    2610.0
weighted avg     0.5602    0.5985    0.5744    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 17 train_loss 0.9308 train_acc 69.35 train_fscore 66.8 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2252 test_acc 59.5 test_fscore 56.68 time 13.84
              precision    recall  f1-score   support

           0     0.7150    0.8129    0.7608    1256.0
           1     0.4851    0.4626    0.4736     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3729    0.1058    0.1648     208.0
           4     0.5312    0.4652    0.4960     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3837    0.5594    0.4552     345.0

    accuracy                         0.5950    2610.0
   macro avg     0.3554    0.3437    0.3358    2610.0
weighted avg     0.5586    0.5950    0.5668    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 18 train_loss 0.9319 train_acc 69.5 train_fscore 67.22 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2076 test_acc 59.77 test_fscore 57.53 time 13.49
              precision    recall  f1-score   support

           0     0.7257    0.7962    0.7593    1256.0
           1     0.4494    0.5374    0.4895     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3700    0.1779    0.2403     208.0
           4     0.5083    0.5348    0.5212     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4209    0.4551    0.4373     345.0

    accuracy                         0.5977    2610.0
   macro avg     0.3535    0.3573    0.3497    2610.0
weighted avg     0.5610    0.5977    0.5753    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 19 train_loss 0.9262 train_acc 69.63 train_fscore 67.21 valid_loss nan valid_acc nan val_fscore nan test_loss 1.203 test_acc 59.23 test_fscore 57.09 time 13.8
              precision    recall  f1-score   support

           0     0.7241    0.7962    0.7584    1256.0
           1     0.5000    0.4235    0.4586     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3519    0.1827    0.2405     208.0
           4     0.5396    0.4577    0.4953     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3782    0.5942    0.4622     345.0

    accuracy                         0.5923    2610.0
   macro avg     0.3563    0.3506    0.3450    2610.0
weighted avg     0.5634    0.5923    0.5709    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 20 train_loss 0.9117 train_acc 69.82 train_fscore 67.54 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2288 test_acc 58.08 test_fscore 56.58 time 13.61
              precision    recall  f1-score   support

           0     0.7384    0.7572    0.7476    1256.0
           1     0.4862    0.5018    0.4939     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3421    0.1875    0.2422     208.0
           4     0.5394    0.4428    0.4863     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3520    0.6000    0.4437     345.0

    accuracy                         0.5808    2610.0
   macro avg     0.3512    0.3556    0.3448    2610.0
weighted avg     0.5645    0.5808    0.5658    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 21 train_loss 0.9121 train_acc 70.09 train_fscore 67.8 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2181 test_acc 60.61 test_fscore 56.67 time 14.02
              precision    recall  f1-score   support

           0     0.6951    0.8623    0.7697    1256.0
           1     0.4839    0.4804    0.4821     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3611    0.0625    0.1066     208.0
           4     0.5372    0.4851    0.5098     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4171    0.4522    0.4339     345.0

    accuracy                         0.6061    2610.0
   macro avg     0.3563    0.3346    0.3289    2610.0
weighted avg     0.5533    0.6061    0.5667    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 22 train_loss 0.9076 train_acc 70.35 train_fscore 67.9 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2372 test_acc 58.47 test_fscore 56.79 time 13.86
              precision    recall  f1-score   support

           0     0.7318    0.7667    0.7488    1256.0
           1     0.4308    0.5872    0.4970     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2800    0.2019    0.2346     208.0
           4     0.5115    0.4975    0.5044     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4216    0.4522    0.4364     345.0

    accuracy                         0.5847    2610.0
   macro avg     0.3394    0.3579    0.3459    2610.0
weighted avg     0.5554    0.5847    0.5679    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 23 train_loss 0.8996 train_acc 70.45 train_fscore 68.14 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2179 test_acc 58.89 test_fscore 57.19 time 13.76
              precision    recall  f1-score   support

           0     0.7362    0.7731    0.7542    1256.0
           1     0.4566    0.5623    0.5040     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3203    0.1971    0.2440     208.0
           4     0.5352    0.4726    0.5020     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3831    0.5130    0.4387     345.0

    accuracy                         0.5889    2610.0
   macro avg     0.3474    0.3597    0.3490    2610.0
weighted avg     0.5620    0.5889    0.5719    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 24 train_loss 0.9007 train_acc 70.5 train_fscore 68.22 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2081 test_acc 59.31 test_fscore 57.1 time 13.75
              precision    recall  f1-score   support

           0     0.7224    0.7978    0.7582    1256.0
           1     0.4350    0.5480    0.4850     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3737    0.1779    0.2410     208.0
           4     0.5376    0.4627    0.4973     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3986    0.4899    0.4395     345.0

    accuracy                         0.5931    2610.0
   macro avg     0.3525    0.3537    0.3459    2610.0
weighted avg     0.5598    0.5931    0.5710    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 25 train_loss 0.8926 train_acc 70.47 train_fscore 68.17 valid_loss nan valid_acc nan val_fscore nan test_loss 1.1998 test_acc 58.66 test_fscore 56.91 time 13.84
              precision    recall  f1-score   support

           0     0.7206    0.7763    0.7474    1256.0
           1     0.4674    0.4591    0.4632     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.2809    0.2404    0.2591     208.0
           4     0.4978    0.5597    0.5269     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4330    0.4406    0.4368     345.0

    accuracy                         0.5866    2610.0
   macro avg     0.3428    0.3537    0.3476    2610.0
weighted avg     0.5534    0.5866    0.5691    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 26 train_loss 0.8826 train_acc 70.83 train_fscore 68.56 valid_loss nan valid_acc nan val_fscore nan test_loss 1.207 test_acc 58.62 test_fscore 56.89 time 13.8
              precision    recall  f1-score   support

           0     0.7334    0.7755    0.7539    1256.0
           1     0.4527    0.5445    0.4943     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3203    0.1971    0.2440     208.0
           4     0.5272    0.4577    0.4900     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3812    0.5159    0.4384     345.0

    accuracy                         0.5862    2610.0
   macro avg     0.3450    0.3558    0.3458    2610.0
weighted avg     0.5588    0.5862    0.5689    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 27 train_loss 0.8811 train_acc 71.04 train_fscore 68.74 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2038 test_acc 60.19 test_fscore 57.65 time 14.0
              precision    recall  f1-score   support

           0     0.7157    0.8177    0.7633    1256.0
           1     0.5040    0.4448    0.4726     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3933    0.1683    0.2357     208.0
           4     0.5385    0.4876    0.5117     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3966    0.5449    0.4591     345.0

    accuracy                         0.6019    2610.0
   macro avg     0.3640    0.3519    0.3489    2610.0
weighted avg     0.5654    0.6019    0.5765    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 28 train_loss 0.8725 train_acc 71.22 train_fscore 68.91 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2549 test_acc 58.66 test_fscore 56.58 time 14.04
              precision    recall  f1-score   support

           0     0.7229    0.7747    0.7479    1256.0
           1     0.4912    0.4947    0.4929     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3625    0.1394    0.2014     208.0
           4     0.5247    0.5025    0.5133     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3643    0.5449    0.4367     345.0

    accuracy                         0.5866    2610.0
   macro avg     0.3522    0.3509    0.3417    2610.0
weighted avg     0.5586    0.5866    0.5658    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 29 train_loss 0.8779 train_acc 70.92 train_fscore 68.66 valid_loss nan valid_acc nan val_fscore nan test_loss 1.223 test_acc 59.04 test_fscore 57.24 time 13.74
              precision    recall  f1-score   support

           0     0.7327    0.7683    0.7501    1256.0
           1     0.4807    0.4875    0.4841     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3704    0.1923    0.2532     208.0
           4     0.4922    0.5473    0.5183     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3951    0.5188    0.4486     345.0

    accuracy                         0.5904    2610.0
   macro avg     0.3530    0.3592    0.3506    2610.0
weighted avg     0.5619    0.5904    0.5724    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 30 train_loss 0.8722 train_acc 71.34 train_fscore 69.02 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2267 test_acc 58.89 test_fscore 56.89 time 13.73
              precision    recall  f1-score   support

           0     0.7320    0.7548    0.7432    1256.0
           1     0.4982    0.5018    0.5000     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3636    0.1346    0.1965     208.0
           4     0.4916    0.5846    0.5341     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3878    0.5362    0.4501     345.0

    accuracy                         0.5889    2610.0
   macro avg     0.3533    0.3589    0.3463    2610.0
weighted avg     0.5619    0.5889    0.5689    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 31 train_loss 0.861 train_acc 71.32 train_fscore 69.13 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2236 test_acc 59.81 test_fscore 57.3 time 13.54
              precision    recall  f1-score   support

           0     0.7142    0.8177    0.7624    1256.0
           1     0.4575    0.4982    0.4770     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3396    0.1731    0.2293     208.0
           4     0.5243    0.4826    0.5026     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4216    0.4754    0.4469     345.0

    accuracy                         0.5981    2610.0
   macro avg     0.3510    0.3496    0.3455    2610.0
weighted avg     0.5565    0.5981    0.5730    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 32 train_loss 0.8647 train_acc 71.53 train_fscore 69.32 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2298 test_acc 58.66 test_fscore 56.62 time 13.85
              precision    recall  f1-score   support

           0     0.7306    0.7643    0.7471    1256.0
           1     0.5059    0.4591    0.4813     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3765    0.1538    0.2184     208.0
           4     0.4564    0.6119    0.5228     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3933    0.4754    0.4304     345.0

    accuracy                         0.5866    2610.0
   macro avg     0.3518    0.3521    0.3429    2610.0
weighted avg     0.5583    0.5866    0.5662    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 33 train_loss 0.8545 train_acc 71.8 train_fscore 69.65 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2342 test_acc 59.2 test_fscore 57.18 time 13.76
              precision    recall  f1-score   support

           0     0.7207    0.7930    0.7551    1256.0
           1     0.5238    0.4306    0.4727     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3385    0.2115    0.2604     208.0
           4     0.5573    0.4353    0.4888     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3779    0.6058    0.4655     345.0

    accuracy                         0.5920    2610.0
   macro avg     0.3597    0.3538    0.3489    2610.0
weighted avg     0.5660    0.5920    0.5718    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 34 train_loss 0.8656 train_acc 70.95 train_fscore 68.82 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2705 test_acc 59.39 test_fscore 55.78 time 14.18
              precision    recall  f1-score   support

           0     0.7089    0.8280    0.7639    1256.0
           1     0.4174    0.5302    0.4671     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3818    0.1010    0.1597     208.0
           4     0.4527    0.6070    0.5186     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.5026    0.2783    0.3582     345.0

    accuracy                         0.5939    2610.0
   macro avg     0.3519    0.3349    0.3239    2610.0
weighted avg     0.5527    0.5939    0.5578    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 35 train_loss 0.8544 train_acc 71.66 train_fscore 69.47 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2382 test_acc 59.54 test_fscore 56.98 time 13.91
              precision    recall  f1-score   support

           0     0.7205    0.7986    0.7576    1256.0
           1     0.4679    0.5196    0.4924     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3824    0.1250    0.1884     208.0
           4     0.4988    0.5323    0.5150     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.4034    0.4783    0.4377     345.0

    accuracy                         0.5954    2610.0
   macro avg     0.3533    0.3505    0.3416    2610.0
weighted avg     0.5578    0.5954    0.5698    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 36 train_loss 0.8454 train_acc 71.65 train_fscore 69.38 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2306 test_acc 58.12 test_fscore 57.04 time 14.02
              precision    recall  f1-score   support

           0     0.7393    0.7404    0.7399    1256.0
           1     0.4847    0.5089    0.4965     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3000    0.2740    0.2864     208.0
           4     0.5167    0.5373    0.5268     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3808    0.4957    0.4307     345.0

    accuracy                         0.5812    2610.0
   macro avg     0.3459    0.3652    0.3543    2610.0
weighted avg     0.5618    0.5812    0.5704    2610.0

epoch 37 train_loss 0.8405 train_acc 71.78 train_fscore 69.54 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2442 test_acc 59.16 test_fscore 57.1 time 13.81
              precision    recall  f1-score   support

           0   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 38 train_loss 0.833 train_acc 72.08 train_fscore 69.88 valid_loss nan valid_acc nan val_fscore nan test_loss 1.255 test_acc 57.89 test_fscore 56.19 time 14.24
              precision    recall  f1-score   support

           0     0.7336    0.7389    0.7362    1256.0
           1     0.4630    0.5338    0.4959     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3718    0.1394    0.2028     208.0
           4     0.5153    0.5025    0.5088     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3673    0.5855    0.4514     345.0

    accuracy                         0.5789    2610.0
   macro avg     0.3501    0.3572    0.3422    2610.0
weighted avg     0.5604    0.5789    0.5619    2610.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 39 train_loss 0.8291 train_acc 72.55 train_fscore 70.48 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2661 test_acc 56.9 test_fscore 55.91 time 13.56
              precision    recall  f1-score   support

           0     0.7536    0.6990    0.7253    1256.0
           1     0.4660    0.5125    0.4881     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3814    0.1779    0.2426     208.0
           4     0.5299    0.4851    0.5065     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3443    0.6696    0.4547     345.0

    accuracy                         0.5690    2610.0
   macro avg     0.3536    0.3634    0.3453    2610.0
weighted avg     0.5704    0.5690    0.5591    2610.0

epoch 40 train_loss 0.8281 train_acc 72.34 train_fscore 70.26 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2499 test_acc 57.89 test_fscore 56.36 time 14.05
              precision    recall  f1-score   support

           0   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 43 train_loss 0.8057 train_acc 73.2 train_fscore 71.19 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2579 test_acc 58.77 test_fscore 56.88 time 13.78
              precision    recall  f1-score   support

           0     0.7275    0.7803    0.7530    1256.0
           1     0.4849    0.5160    0.5000     281.0
           2     0.0000    0.0000    0.0000      50.0
           3     0.3190    0.1779    0.2284     208.0
           4     0.5226    0.4602    0.4894     402.0
           5     0.0000    0.0000    0.0000      68.0
           6     0.3801    0.5420    0.4468     345.0

    accuracy                         0.5877    2610.0
   macro avg     0.3477    0.3538    0.3454    2610.0
weighted avg     0.5585    0.5877    0.5688    2610.0

epoch 44 train_loss 0.7964 train_acc 73.25 train_fscore 71.32 valid_loss nan valid_acc nan val_fscore nan test_loss 1.2675 test_acc 59.66 test_fscore 57.02 time 14.41
              precision    recall  f1-score   support

           0   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
