In [1]:
import copy

import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from seqeval.scheme import IOB2
from seqeval.metrics import classification_report,f1_score

In [2]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
import warnings
import random

from keras.preprocessing.text import Tokenizer
from sklearn.model_selection import KFold

In [3]:
data_train = pd.read_csv("BC_train_IOB2.txt",sep = '\t')
data_dev = pd.read_csv("BC_dev_IOB2.txt",sep = '\t')

In [4]:
data_train

Unnamed: 0,Word,Tag,Length,Sentence#
0,@,O,1,0
1,Rhy_QD10,O,8,0
2,yeah,O,4,0
3,irking,O,6,0
4,he,O,2,0
...,...,...,...,...
14779,smell,O,5,704
14780,like,O,4,704
14781,dogs,O,4,704
14782,medicine,O,8,704


In [5]:
class SentenceGetter(object):
    
    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, t) for w, t in zip(s["Word"].values.tolist(),
                                                     s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence#").apply(agg_func)
        self.sentences = [s for s in self.grouped]
    
    def get_next(self):
        try:
            s = self.grouped[self.n_sent]
            self.n_sent += 1
            return s
        except:
            return None

In [6]:
getter = SentenceGetter(data_train)
dev_getter = SentenceGetter(data_dev)

In [7]:
sentences = [[word[0] for word in sentence] for sentence in getter.sentences]
dev_sentences = [[word[0] for word in sentence] for sentence in dev_getter.sentences]
labels = [[s[1] for s in sent] for sent in getter.sentences]
dev_labels = [[s[1] for s in sent] for sent in dev_getter.sentences]

In [8]:
max_features = 40000
tokenizer = Tokenizer(num_words=max_features, split=' ', oov_token='<unk>')
tokenizer.fit_on_texts(sentences)

In [9]:
word2idx = tokenizer.word_index.copy()
word2idx['<pad>'] = 0
idx2word = {word2idx[i]:i for i in word2idx}

In [10]:
len(word2idx)

1328

In [11]:
tags = list(set(data_train['Tag'].values))
tags.append("<pad>")
tag2idx = {t: i for i, t in enumerate(tags)}
idx2tag = {i: t for i, t in enumerate(tags)}

In [12]:
X_train = tokenizer.texts_to_sequences(sentences)
X_test = tokenizer.texts_to_sequences(dev_sentences)
y_train = [[tag2idx[l] for l in s] for s in labels]
y_test = [[tag2idx[l] for l in s] for s in dev_labels]

In [13]:
for i in X_test[0]:
    print(idx2word[i],end = ' ')

just get on birth control and use two <unk> . oh and do n't have <unk> when <unk> . 

In [14]:
from torch.utils.data import Dataset
from torchcrf import CRF
class NERDataset(Dataset):
    def __init__(self,sentences,labels, word_pad_idx, tag_pad_idx, max_len = 500):
        self.sentences = sentences
        self.labels = labels
        self.word_pad_idx = word_pad_idx
        self.tag_pad_idx = tag_pad_idx
        self.max_len = max_len
    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        return (self.sentences[index],self.labels[index])
        
    def collate_fn(self, datasets):
        sentences = [dataset[0] for dataset in datasets]
        labels = [dataset[1] for dataset in datasets]
        max_sent = max([len(data) for data in sentences])
        max_len = max([min(len(sentence), self.max_len) for sentence in sentences])
        pad_sentence = []
        pad_label = []
        for sentence,label in zip(sentences,labels):
            
            if len(sentence) > max_len:
#                 print('asd')
                pad_sentence.append(sentence[:max_len])
                pad_label.append(label[:max_len])
            else:
#                 print('zxc')
                pad_sentence.append(sentence+[self.word_pad_idx]*(max_len-len(sentence)))
                pad_label.append(label+[self.tag_pad_idx]*(max_len-len(label)))
        return torch.LongTensor(pad_sentence), torch.LongTensor(pad_label)

In [15]:
class EmbeddedRnn(nn.Module):
    def __init__(self, vocab, hidden_dim, output_vocab, n_layer,word_pad_idx,tag_pad_idx):
        super(EmbeddedRnn, self).__init__()
        self.n_layer = n_layer
        self.embedding_size = 300
        self.hidden_dim = hidden_dim
        self.embedded = nn.Embedding(vocab, self.embedding_size , padding_idx  = word_pad_idx)
        self.lstm = nn.LSTM(self.embedding_size, hidden_dim, num_layers=n_layer,batch_first = True, bidirectional=True)
        self.fc1 = nn.Linear(2 * hidden_dim, output_vocab)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.crf = CRF(num_tags=output_vocab, batch_first = True)
        self.tag_pad_idx = tag_pad_idx
        
    def forward(self, x, hidden,y_tag):
        embedded = self.embedded(x)
        output, hidden = self.lstm(embedded, hidden)
        output = self.fc1(output)
        if y_tag is not None:
            mask = y_tag != self.tag_pad_idx
            crf_out = self.crf.decode(output, mask=mask)
            crf_loss = -self.crf(output, tags=y_tag, mask=mask)
        else:
            crf_out = self.crf.decode(output)
            crf_loss = None
        return crf_out, crf_loss
    
    def initHidden(self, batch_size):
        hidden = Variable(torch.zeros(2 * self.n_layer, batch_size, self.hidden_dim))
        cell = Variable(torch.zeros(2 * self.n_layer, batch_size, self.hidden_dim))
        return [hidden, cell]
#         return hidden
    def init_crf_transitions(tag_names, imp_value=-100):
        crf = CRF(num_tags=n_tags)
        num_tags = len(tag_names)
        for i in range(num_tags):
            tag_name = tag_names[i]
            if tag_name[0] in ("I") or tag_name[0] in ("L") or tag_name == "ENDPAD":
                torch.nn.init.constant_(crf.start_transitions[i], imp_value)
        tag_is = {}
        for tag_position in ("B", "I", "O", 'L','U'):
            tag_is[tag_position] = [i for i, tag in enumerate(tag_names) if tag[0] == tag_position]
        impossible_transitions_position = {
            "O": "I",  
            "O": 'L',
            
            'B': 'U',
            'B': 'I',
            
            'I': 'B',
            'I': 'U',
            'I': 'O',
            
            'L': 'I',
            
            'U': 'I',
            'U': 'L',
        }
        for from_tag, to_tag_list in impossible_transitions_position.items():
            to_tags = list(to_tag_list)
            for from_tag_i in tag_is[from_tag]:
                for to_tag in to_tags:
                    for to_tag_i in tag_is[to_tag]:
                        torch.nn.init.constant_(
                            crf.transitions[from_tag_i, to_tag_i], imp_value
                        )
        # init impossible B and I transitions to different entity types
        impossible_transitions_tags = {
            "B": "I",
            'B': 'L',
            "I": "I",
            'I': 'L',
            'U': 'B',
        }
        for from_tag, to_tag_list in impossible_transitions_tags.items():
            to_tags = list(to_tag_list)
            for from_tag_i in tag_is[from_tag]:
                for to_tag in to_tags:
                    for to_tag_i in tag_is[to_tag]:
                        if tag_names[from_tag_i].split("-")[1] != tag_names[to_tag_i].split("-")[1]:
                            torch.nn.init.constant_(
                                crf.transitions[from_tag_i, to_tag_i], imp_value
                            )

## 使用 trainingset+validset 去做 10-fold CV

In [16]:
from torch.utils.data import TensorDataset, DataLoader ,SubsetRandomSampler ,ConcatDataset
bs = 16
k_folds = 10
kfold = KFold(n_splits=k_folds, shuffle=True)
tr_dataset = NERDataset(X_train,y_train,word2idx['<pad>'],tag2idx['<pad>'])
va_dataset = NERDataset(X_test,y_test,word2idx['<pad>'],tag2idx['<pad>'])
dataset = ConcatDataset([tr_dataset, va_dataset])
word_pad_idx = word2idx['<pad>']
tag_pad_idx = tag2idx['<pad>']
num_epoch = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
All_Fold_score = []
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')
    print('--------------------------------')
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    train_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=tr_dataset.collate_fn,
                                  sampler=train_subsampler)
    valid_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=va_dataset.collate_fn,
                                  sampler=test_subsampler)  
    model = EmbeddedRnn(len(word2idx), 256, len(tags),2,word_pad_idx,tag_pad_idx)
    optimizer = optim.AdamW(model.parameters(), lr=5e-3)
    model = model.to(device)
    all_loader = {"train" : train_dataloader,
                  "valid" : valid_dataloader}
    Fold_score = []
    for epoch in tqdm(range(num_epoch)):
        all_loss = {
            'train': [],
            'valid': [],
        }
        print('')
        for loader in all_loader:
            predictions , true_labels  = [],[]
            for x, y in all_loader[loader]:
                optimizer.zero_grad()
                x = x.to(device)
                y = y.to(device)
                hidden = model.initHidden(x.size(0))
                hidden[0] = hidden[0].to(device)
                hidden[1] = hidden[1].to(device)
                output, loss = model(x, hidden,y)
                if loader == 'train':
                    loss.backward()
                    optimizer.step()
                all_loss[loader].append(loss.cpu().item()) 
                predictions.extend([[idx2tag[j] for j in i] for i in output])
                for i in y.detach().cpu().numpy():
                    _ = []
                    for j in i:
                        if j != tag_pad_idx:
                            _.append(idx2tag[j])
                    true_labels.append(_)
            print(f'{loader}_loss : {np.mean(np.array(all_loss[loader]))/64}')
            f_ = f1_score(true_labels,predictions,scheme = IOB2)
            print(f'{loader}_F1: {f_}')
            if loader == 'valid':
                Fold_score.append(f_)
    if Fold_score != []:
        All_Fold_score.append(Fold_score)

FOLD 0
--------------------------------


  0%|                                                                                           | 0/10 [00:00<?, ?it/s]




 10%|████████▎                                                                          | 1/10 [00:03<00:31,  3.52s/it]

train_loss : 1.0442178495552228
train_F1: 0.453355155482815
valid_loss : 0.2164492905139923
valid_F1: 0.8604651162790699



 20%|████████████████▌                                                                  | 2/10 [00:06<00:24,  3.09s/it]

train_loss : 0.2118674526396005
train_F1: 0.8717607973421927
valid_loss : 0.12267504135767619
valid_F1: 0.9447852760736196



 30%|████████████████████████▉                                                          | 3/10 [00:09<00:20,  2.98s/it]

train_loss : 0.15433670288842657
train_F1: 0.8952254641909814
valid_loss : 0.13719341158866882
valid_F1: 0.934131736526946



 40%|█████████████████████████████████▏                                                 | 4/10 [00:12<00:17,  2.95s/it]

train_loss : 0.10282802193061165
train_F1: 0.9065420560747665
valid_loss : 0.16306322813034058
valid_F1: 0.9125



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.91s/it]

train_loss : 0.08690741042728009
train_F1: 0.9082256968048947
valid_loss : 0.12125401695569356
valid_F1: 0.9240506329113924



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.91s/it]

train_loss : 0.07717215334591658
train_F1: 0.9111111111111111
valid_loss : 0.1494238426287969
valid_F1: 0.9192546583850932

train_loss : 0.06970002534596817
train_F1: 0.9125168236877523
valid_loss : 0.18335077166557312
valid_F1: 0.8982035928143713

 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.87s/it]





 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.87s/it]

train_loss : 0.06264131807762643
train_F1: 0.9216909216909218
valid_loss : 0.19721119602521262
valid_F1: 0.9146341463414634



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.87s/it]

train_loss : 0.0595356575820757
train_F1: 0.926330150068213
valid_loss : 0.21213054160277048
valid_F1: 0.8848484848484849



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.85s/it]

train_loss : 0.058006096793257675
train_F1: 0.9255393180236604
valid_loss : 0.22351795931657156
valid_F1: 0.9182389937106918


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.91s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 1
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:03<00:27,  3.03s/it]

train_loss : 1.0214742846463039
train_F1: 0.4137931034482759
valid_loss : 0.275197500983874
valid_F1: 0.8354430379746836



 20%|████████████████▌                                                                  | 2/10 [00:05<00:23,  2.90s/it]

train_loss : 0.21957138925790787
train_F1: 0.8722109533468559
valid_loss : 0.17739402254422507
valid_F1: 0.8941176470588234



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:20,  2.87s/it]

train_loss : 0.14509571635204813
train_F1: 0.898861352980576
valid_loss : 0.1313612312078476
valid_F1: 0.8588957055214724



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.88s/it]

train_loss : 0.10729622873275177
train_F1: 0.9040163376446563
valid_loss : 0.13695518672466278
valid_F1: 0.860759493670886



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.91s/it]

train_loss : 0.08556999168966127
train_F1: 0.9156956819739548
valid_loss : 0.1383363902568817
valid_F1: 0.8500000000000001



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.94s/it]

train_loss : 0.07415454186823058
train_F1: 0.9186602870813397
valid_loss : 0.12091301878293355
valid_F1: 0.8862275449101796



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.93s/it]

train_loss : 0.06697391005961792
train_F1: 0.9192200557103064
valid_loss : 0.11088752746582031
valid_F1: 0.8622754491017963



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.89s/it]

train_loss : 0.06736362401558005
train_F1: 0.9167240192704749
valid_loss : 0.12997986872990927
valid_F1: 0.8701298701298701



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.88s/it]

train_loss : 0.06214010067608045
train_F1: 0.9229711141678129
valid_loss : 0.11943572759628296
valid_F1: 0.8395061728395061



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.89s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.06273515716842983
train_F1: 0.919972164231037
valid_loss : 0.11923639973004659
valid_F1: 0.8674698795180723
FOLD 2
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:25,  2.88s/it]

train_loss : 0.9675028499053873
train_F1: 0.4852941176470589
valid_loss : 0.2682473932703336
valid_F1: 0.8513513513513513



 20%|████████████████▌                                                                  | 2/10 [00:05<00:23,  2.90s/it]

train_loss : 0.20133513050234836
train_F1: 0.8792528352234823
valid_loss : 0.16835667937994003
valid_F1: 0.8606060606060606



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:19,  2.85s/it]

train_loss : 0.1412908353883287
train_F1: 0.899009900990099
valid_loss : 0.15195515751838684
valid_F1: 0.8846153846153846



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.87s/it]

train_loss : 0.09876601106446722
train_F1: 0.9072580645161291
valid_loss : 0.14167756338914236
valid_F1: 0.87248322147651



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.87s/it]

train_loss : 0.0836430817194607
train_F1: 0.9022038567493113
valid_loss : 0.1123570700486501
valid_F1: 0.888888888888889



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.91s/it]

train_loss : 0.06924710325572801
train_F1: 0.9189944134078212
valid_loss : 0.16334273417790732
valid_F1: 0.8903225806451613



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.89s/it]

train_loss : 0.06656149224094722
train_F1: 0.9258241758241759
valid_loss : 0.12082989017168681
valid_F1: 0.8974358974358974



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.92s/it]

train_loss : 0.06585031359092049
train_F1: 0.9226557152635183
valid_loss : 0.14129778742790222
valid_F1: 0.8831168831168831



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.89s/it]

train_loss : 0.061542586787887245
train_F1: 0.9170731707317072
valid_loss : 0.12906432648499808
valid_F1: 0.8933333333333333



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.90s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.06021994481916013
train_F1: 0.9203664552501761
valid_loss : 0.167593980828921
valid_F1: 0.8918918918918919
FOLD 3
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:26,  2.91s/it]

train_loss : 1.0143362484548404
train_F1: 0.49409780775716694
valid_loss : 0.2545907547076543
valid_F1: 0.8275862068965517



 20%|████████████████▌                                                                  | 2/10 [00:05<00:23,  2.95s/it]

train_loss : 0.20156753289958704
train_F1: 0.8872581721147431
valid_loss : 0.19490453600883484
valid_F1: 0.8823529411764706



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:20,  2.90s/it]

train_loss : 0.13459298306185266
train_F1: 0.902488231338265
valid_loss : 0.1650985380013784
valid_F1: 0.8969696969696969



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.93s/it]

train_loss : 0.09793117707190306
train_F1: 0.9151391717583164
valid_loss : 0.2201439986626307
valid_F1: 0.8641975308641975

train_loss : 0.0802116633757301
train_F1: 0.9093387866394002
valid_loss : 0.22494809329509735
valid_F1: 0.8624999999999999

 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.95s/it]





 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.97s/it]

train_loss : 0.06866942993972612
train_F1: 0.9192200557103064
valid_loss : 0.24146036803722382
valid_F1: 0.8571428571428572



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.94s/it]

train_loss : 0.06168718247309975
train_F1: 0.9211618257261411
valid_loss : 0.26307331522305805
valid_F1: 0.8819875776397514



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.93s/it]

train_loss : 0.05881759459557741
train_F1: 0.926490984743412
valid_loss : 0.2668558855851491
valid_F1: 0.8588957055214724



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.91s/it]

train_loss : 0.05705782123233961
train_F1: 0.9232895646164477
valid_loss : 0.28795959055423737
valid_F1: 0.8734177215189873



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.93s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05602396247179612
train_F1: 0.9355281207133059
valid_loss : 0.2871881077686946
valid_F1: 0.8789808917197451
FOLD 4
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:26,  2.95s/it]

train_loss : 1.0348913928736811
train_F1: 0.5314091680814941
valid_loss : 0.19956830392281213
valid_F1: 0.8500000000000001



 20%|████████████████▌                                                                  | 2/10 [00:05<00:23,  2.97s/it]

train_loss : 0.2100717094929322
train_F1: 0.8662593346911067
valid_loss : 0.12068269898494084
valid_F1: 0.9156626506024096



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:20,  3.00s/it]

train_loss : 0.13772446986125864
train_F1: 0.9010695187165775
valid_loss : 0.12344420949618022
valid_F1: 0.8765432098765432



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.98s/it]

train_loss : 0.11101098293843477
train_F1: 0.9012178619756427
valid_loss : 0.1110730121533076
valid_F1: 0.8674698795180723



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.91s/it]

train_loss : 0.08939832286990207
train_F1: 0.9054325955734407
valid_loss : 0.12312527497609456
valid_F1: 0.8902439024390244



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.89s/it]

train_loss : 0.07302355409964272
train_F1: 0.9180555555555556
valid_loss : 0.15412658949693045
valid_F1: 0.8875



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.86s/it]

train_loss : 0.06844824422960696
train_F1: 0.9174690508940854
valid_loss : 0.1378340870141983
valid_F1: 0.8765432098765432

train_loss : 0.0632084530332814
train_F1: 0.921443736730361
valid_loss : 0.1331190566221873
valid_F1: 0.89171974522293

 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.88s/it]





 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.95s/it]

train_loss : 0.06065622749535934
train_F1: 0.9270326615705352
valid_loss : 0.1387586643298467
valid_F1: 0.8902439024390244



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.98s/it]

train_loss : 0.060659166587435684
train_F1: 0.9214732453092425
valid_loss : 0.1458306759595871
valid_F1: 0.8834355828220859


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.94s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 5
--------------------------------

train_loss : 1.0073696575734927
train_F1: 0.5426086956521738


 10%|████████▎                                                                          | 1/10 [00:03<00:27,  3.01s/it]

valid_loss : 0.2933809806903203
valid_F1: 0.8187134502923977



 20%|████████████████▌                                                                  | 2/10 [00:06<00:24,  3.09s/it]

train_loss : 0.1960111762518468
train_F1: 0.8798938287989383
valid_loss : 0.18151339888572693
valid_F1: 0.8606060606060606



 30%|████████████████████████▉                                                          | 3/10 [00:09<00:21,  3.06s/it]

train_loss : 0.12397061936233354
train_F1: 0.9022252191503709
valid_loss : 0.16086803376674652
valid_F1: 0.8554216867469879



 40%|█████████████████████████████████▏                                                 | 4/10 [00:12<00:18,  3.02s/it]

train_loss : 0.09734682933143947
train_F1: 0.9113233287858117
valid_loss : 0.18159512182076773
valid_F1: 0.8227848101265822



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:15<00:14,  2.98s/it]

train_loss : 0.08257498916076578
train_F1: 0.9125683060109291
valid_loss : 0.16175648073355356
valid_F1: 0.8536585365853657

train_loss : 0.06709812613932983
train_F1: 0.9203049203049204
valid_loss : 0.1871513028939565
valid_F1: 0.8226950354609929

 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.95s/it]





 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.96s/it]

train_loss : 0.06640475275723831
train_F1: 0.9292096219931272
valid_loss : 0.20870643854141235
valid_F1: 0.8226950354609929



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.91s/it]

train_loss : 0.06023509444101997
train_F1: 0.9266943291839558
valid_loss : 0.2169093539317449
valid_F1: 0.8133333333333332



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.91s/it]

train_loss : 0.05921123662720556
train_F1: 0.9267955801104973
valid_loss : 0.21435859302679697
valid_F1: 0.8163265306122449



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.93s/it]

train_loss : 0.05830122303703557
train_F1: 0.9289693593314763
valid_loss : 0.19879927734533945
valid_F1: 0.8402366863905325


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.96s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 6
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:03<00:27,  3.09s/it]

train_loss : 0.9539886648240297
train_F1: 0.5235434956105348
valid_loss : 0.37929170827070874
valid_F1: 0.7368421052631577



 20%|████████████████▌                                                                  | 2/10 [00:05<00:23,  2.93s/it]

train_loss : 0.19569411906211273
train_F1: 0.8596256684491979
valid_loss : 0.2228871782620748
valid_F1: 0.8452380952380952



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:19,  2.81s/it]

train_loss : 0.12464353247829106
train_F1: 0.90020366598778
valid_loss : 0.19645422200361887
valid_F1: 0.8520710059171598



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.85s/it]

train_loss : 0.09213785211677136
train_F1: 0.9056087551299589
valid_loss : 0.20087007184823355
valid_F1: 0.8363636363636364



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.84s/it]

train_loss : 0.07304367163906926
train_F1: 0.9192886456908345
valid_loss : 0.22494304180145264
valid_F1: 0.8333333333333334



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.84s/it]

train_loss : 0.06887961729713109
train_F1: 0.9205479452054794
valid_loss : 0.2236345261335373
valid_F1: 0.8125000000000001



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.87s/it]

train_loss : 0.06550597982562106
train_F1: 0.9225496915695682
valid_loss : 0.23144780099391937
valid_F1: 0.8516129032258065



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:22<00:05,  2.88s/it]

train_loss : 0.05888758049063061
train_F1: 0.9269662921348315
valid_loss : 0.25862740973631543
valid_F1: 0.8421052631578948



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:25<00:02,  2.90s/it]

train_loss : 0.058708191241907036
train_F1: 0.9316005471956224
valid_loss : 0.2640799582004547
valid_F1: 0.8387096774193549



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.93s/it]

train_loss : 0.05537198062824166
train_F1: 0.9323098394975575
valid_loss : 0.27105702459812164
valid_F1: 0.8387096774193549


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.89s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 7
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:25,  2.83s/it]

train_loss : 1.0124586163007694
train_F1: 0.5355285961871751
valid_loss : 0.22092862923940024
valid_F1: 0.8383233532934131



 20%|████████████████▌                                                                  | 2/10 [00:06<00:24,  3.08s/it]

train_loss : 0.19219536139913226
train_F1: 0.8835016835016836
valid_loss : 0.15788319955269495
valid_F1: 0.8928571428571429



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:20,  2.98s/it]

train_loss : 0.13498243138841962
train_F1: 0.9100671140939598
valid_loss : 0.154767836133639
valid_F1: 0.8780487804878048



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.95s/it]

train_loss : 0.10132372735635094
train_F1: 0.912542372881356
valid_loss : 0.13915157690644264
valid_F1: 0.8765432098765432



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.92s/it]

train_loss : 0.0793853240168613
train_F1: 0.9204152249134948
valid_loss : 0.12502253676454225
valid_F1: 0.8765432098765432



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.89s/it]

train_loss : 0.0722501595383105
train_F1: 0.9205479452054794
valid_loss : 0.12480438748995464
valid_F1: 0.8727272727272727



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  2.87s/it]

train_loss : 0.06288893196893774
train_F1: 0.9216909216909218
valid_loss : 0.1384169285496076
valid_F1: 0.8658536585365854



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.85s/it]

train_loss : 0.05953436377255813
train_F1: 0.9302004146510021
valid_loss : 0.12890997777382532
valid_F1: 0.880503144654088



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.85s/it]

train_loss : 0.060452900502992714
train_F1: 0.9278350515463917
valid_loss : 0.1394630347688993
valid_F1: 0.860759493670886



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.87s/it]

train_loss : 0.05689396806385206
train_F1: 0.9291666666666667
valid_loss : 0.1388466308514277
valid_F1: 0.8641975308641975


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.90s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 8
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:25,  2.84s/it]

train_loss : 1.053929908444052
train_F1: 0.3991537376586742
valid_loss : 0.18178798258304596
valid_F1: 0.8690476190476191



 20%|████████████████▌                                                                  | 2/10 [00:05<00:22,  2.86s/it]

train_loss : 0.20590444071137387
train_F1: 0.8764195056780227
valid_loss : 0.16868495444456735
valid_F1: 0.8848484848484849



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:19,  2.82s/it]

train_loss : 0.13892591080587843
train_F1: 0.8962264150943395
valid_loss : 0.1044229840238889
valid_F1: 0.8862275449101796



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:16,  2.81s/it]

train_loss : 0.10640211630126704
train_F1: 0.9092140921409213
valid_loss : 0.12646561612685522
valid_F1: 0.8834355828220859



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.87s/it]

train_loss : 0.0798454177768334
train_F1: 0.9134287661895024
valid_loss : 0.13983224829037985
valid_F1: 0.860759493670886



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.83s/it]

train_loss : 0.07912427057390628
train_F1: 0.9289617486338798
valid_loss : 0.15457448363304138
valid_F1: 0.8727272727272727



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:19<00:08,  2.87s/it]

train_loss : 0.07287671319816423
train_F1: 0.9250681198910082
valid_loss : 0.14031698306401572
valid_F1: 0.8695652173913043



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:22<00:05,  2.87s/it]

train_loss : 0.06549571130586707
train_F1: 0.9226480836236934
valid_loss : 0.1634787619113922
valid_F1: 0.8679245283018868



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:25<00:02,  2.91s/it]

train_loss : 0.06438311254200728
train_F1: 0.9290144727773949
valid_loss : 0.16176948448022208
valid_F1: 0.860759493670886



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.93s/it]

train_loss : 0.05911494856295378
train_F1: 0.9321203638908326
valid_loss : 0.18458642065525055
valid_F1: 0.8809523809523809


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.88s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 9
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:03<00:27,  3.05s/it]

train_loss : 1.0148448882543522
train_F1: 0.4184397163120568
valid_loss : 0.3107211838165919
valid_F1: 0.7719298245614035



 20%|████████████████▌                                                                  | 2/10 [00:05<00:23,  2.92s/it]

train_loss : 0.20369962149340173
train_F1: 0.8702594810379243
valid_loss : 0.23317093153794607
valid_F1: 0.8587570621468925



 30%|████████████████████████▉                                                          | 3/10 [00:08<00:20,  2.87s/it]

train_loss : 0.1331894485198933
train_F1: 0.8986083499005965
valid_loss : 0.2765078047911326
valid_F1: 0.8658536585365854



 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:17,  2.84s/it]

train_loss : 0.10137805277886598
train_F1: 0.9116465863453815
valid_loss : 0.21009308099746704
valid_F1: 0.8809523809523809



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.90s/it]

train_loss : 0.07555053739444069
train_F1: 0.9164420485175202
valid_loss : 0.20132736365000406
valid_F1: 0.8928571428571429



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.95s/it]

train_loss : 0.06908856137939122
train_F1: 0.9186935371785963
valid_loss : 0.1906373699506124
valid_F1: 0.9056603773584906



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:20<00:08,  3.00s/it]

train_loss : 0.06346160562142082
train_F1: 0.9224433768016473
valid_loss : 0.22618363300959268
valid_F1: 0.903225806451613



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:23<00:05,  2.97s/it]

train_loss : 0.0625229413094728
train_F1: 0.9208128941836019
valid_loss : 0.20550512274106345
valid_F1: 0.8780487804878048



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:26<00:02,  2.93s/it]

train_loss : 0.05860645291597947
train_F1: 0.924346629986245
valid_loss : 0.20725572109222412
valid_F1: 0.8789808917197451



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.91s/it]

train_loss : 0.057276470181734665
train_F1: 0.9316420014094433
valid_loss : 0.21379833420117697
valid_F1: 0.9012345679012346





In [18]:
All_score = len(All_Fold_score[0]) * [0]
for score in All_Fold_score:
    All_score = np.sum([All_score,score], axis = 0)
All_score = np.round((All_score / 10) ,2) 

In [19]:
All_score

array([0.83, 0.88, 0.88, 0.87, 0.87, 0.87, 0.87, 0.87, 0.86, 0.88])

## 單用 trainingset 去做 10-fold CV

In [20]:
All_Fold_score = []
for fold, (train_ids, test_ids) in enumerate(kfold.split(tr_dataset)):
    print(f'FOLD {fold}')
    print('--------------------------------')
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    train_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=tr_dataset.collate_fn,
                                  sampler=train_subsampler)
    valid_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=va_dataset.collate_fn,
                                  sampler=test_subsampler)  
    model = EmbeddedRnn(len(word2idx), 256, len(tags),2,word_pad_idx,tag_pad_idx)
    optimizer = optim.AdamW(model.parameters(), lr=5e-3)
    model = model.to(device)
    all_loader = {"train" : train_dataloader,
                  "valid" : valid_dataloader}
    Fold_score = []
    for epoch in tqdm(range(num_epoch)):
        all_loss = {
            'train': [],
            'valid': [],
        }
        print('')
        for loader in all_loader:
            predictions , true_labels  = [],[]
            for x, y in all_loader[loader]:
                optimizer.zero_grad()
                x = x.to(device)
                y = y.to(device)
                hidden = model.initHidden(x.size(0))
                hidden[0] = hidden[0].to(device)
                hidden[1] = hidden[1].to(device)
                output, loss = model(x, hidden,y)
                if loader == 'train':
                    loss.backward()
                    optimizer.step()
                all_loss[loader].append(loss.cpu().item()) 
                predictions.extend([[idx2tag[j] for j in i] for i in output])
                for i in y.detach().cpu().numpy():
                    _ = []
                    for j in i:
                        if j != tag_pad_idx:
                            _.append(idx2tag[j])
                    true_labels.append(_)
            print(f'{loader}_loss : {np.mean(np.array(all_loss[loader]))/64}')
            f_ = f1_score(true_labels,predictions,scheme = IOB2)
            print(f'{loader}_F1: {f_}')
            if loader == 'valid':
                Fold_score.append(f_)
    if Fold_score != []:
        All_Fold_score.append(Fold_score)

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 0
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.52s/it]

train_loss : 1.1518160752952098
train_F1: 0.4047856430707876
valid_loss : 0.21394155025482178
valid_F1: 0.875



 20%|████████████████▌                                                                  | 2/10 [00:05<00:20,  2.51s/it]

train_loss : 0.15547118932008744
train_F1: 0.9047619047619049
valid_loss : 0.10516305565834046
valid_F1: 0.9324324324324325



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.48s/it]

train_loss : 0.09333100691437721
train_F1: 0.9252120277563608
valid_loss : 0.08272008299827575
valid_F1: 0.9166666666666666



 40%|█████████████████████████████████▏                                                 | 4/10 [00:09<00:14,  2.49s/it]

train_loss : 0.07505445517599582
train_F1: 0.9320843091334895
valid_loss : 0.07060465216636658
valid_F1: 0.8888888888888888



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.51s/it]

train_loss : 0.05935889407992363
train_F1: 0.9256329113924051
valid_loss : 0.05605042576789856
valid_F1: 0.9041095890410958



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:15<00:10,  2.56s/it]

train_loss : 0.05721569433808327
train_F1: 0.9287392325763508
valid_loss : 0.056273990869522096
valid_F1: 0.8873239436619719



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.54s/it]

train_loss : 0.05613960474729538
train_F1: 0.9340746624305004
valid_loss : 0.05742876529693604
valid_F1: 0.8985507246376812



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:20<00:05,  2.56s/it]

train_loss : 0.0562263622879982
train_F1: 0.9270248596631916
valid_loss : 0.052383828163146975
valid_F1: 0.9027777777777778



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.53s/it]

train_loss : 0.054189054667949675
train_F1: 0.9382329945269742
valid_loss : 0.05786541700363159
valid_F1: 0.8857142857142857



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:25<00:00,  2.52s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05554633736610413
train_F1: 0.9311999999999999
valid_loss : 0.05536609888076782
valid_F1: 0.918918918918919
FOLD 1
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.46s/it]

train_loss : 1.1549613550305367
train_F1: 0.3772609819121447
valid_loss : 0.22663582861423492
valid_F1: 0.7913669064748201



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.46s/it]

train_loss : 0.1564963825047016
train_F1: 0.8952234206471493
valid_loss : 0.16876394748687745
valid_F1: 0.88



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.56s/it]

train_loss : 0.10066209845244885
train_F1: 0.9202453987730062
valid_loss : 0.08620806038379669
valid_F1: 0.8951048951048951



 40%|█████████████████████████████████▏                                                 | 4/10 [00:10<00:15,  2.53s/it]

train_loss : 0.07086763456463814
train_F1: 0.9227166276346603
valid_loss : 0.0626590222120285
valid_F1: 0.8920863309352517



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.53s/it]

train_loss : 0.05846415907144546
train_F1: 0.9322957198443579
valid_loss : 0.05958524942398071
valid_F1: 0.9172932330827067



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:15<00:10,  2.51s/it]

train_loss : 0.055303892493247984
train_F1: 0.9260143198090692
valid_loss : 0.05544670820236206
valid_F1: 0.9037037037037037



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.48s/it]

train_loss : 0.05354507938027382
train_F1: 0.9304897314375987
valid_loss : 0.05567455291748047
valid_F1: 0.8857142857142857



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.47s/it]

train_loss : 0.05295584127306938
train_F1: 0.9303442754203363
valid_loss : 0.054925060272216795
valid_F1: 0.9037037037037037



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.50s/it]

train_loss : 0.05206039175391197
train_F1: 0.9364069952305246
valid_loss : 0.053928375244140625
valid_F1: 0.8951048951048951



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:25<00:00,  2.50s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.052072839438915254
train_F1: 0.9311999999999999
valid_loss : 0.05347089767456055
valid_F1: 0.8873239436619719
FOLD 2
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.50s/it]

train_loss : 1.0603826422244311
train_F1: 0.4239226033421284
valid_loss : 0.12673501968383788
valid_F1: 0.9178082191780823



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.47s/it]

train_loss : 0.16484499387443066
train_F1: 0.9002284843869003
valid_loss : 0.0629791796207428
valid_F1: 0.951048951048951



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.49s/it]

train_loss : 0.1112667128443718
train_F1: 0.9179331306990881
valid_loss : 0.044490212202072145
valid_F1: 0.9589041095890412



 40%|█████████████████████████████████▏                                                 | 4/10 [00:09<00:14,  2.50s/it]

train_loss : 0.0778329972177744
train_F1: 0.9103448275862068
valid_loss : 0.03304000496864319
valid_F1: 0.9558823529411764



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.47s/it]

train_loss : 0.06546348109841346
train_F1: 0.9162717219589256
valid_loss : 0.02955838441848755
valid_F1: 0.9523809523809524



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:15<00:10,  2.53s/it]

train_loss : 0.06114770770072937
train_F1: 0.9194312796208529
valid_loss : 0.0268868088722229
valid_F1: 0.9655172413793103



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.51s/it]

train_loss : 0.0586444154381752
train_F1: 0.9258964143426295
valid_loss : 0.02683367133140564
valid_F1: 0.9577464788732394



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.48s/it]

train_loss : 0.056923973560333255
train_F1: 0.9285714285714287
valid_loss : 0.026916921138763428
valid_F1: 0.9635036496350364



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.48s/it]

train_loss : 0.05700540691614151
train_F1: 0.9286846275752774
valid_loss : 0.0277826189994812
valid_F1: 0.9558823529411764



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.49s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05633391961455345
train_F1: 0.9238171611868484
valid_loss : 0.029253482818603516
valid_F1: 0.9577464788732394
FOLD 3
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.47s/it]

train_loss : 1.025277066975832
train_F1: 0.39814097598760645
valid_loss : 0.24466924965381623
valid_F1: 0.8057553956834531



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.46s/it]

train_loss : 0.14980625100433825
train_F1: 0.8960739030023095
valid_loss : 0.11095255613327026
valid_F1: 0.8874172185430463



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.48s/it]

train_loss : 0.08911373429000377
train_F1: 0.9224865694551037
valid_loss : 0.09024635255336762
valid_F1: 0.9066666666666667



 40%|█████████████████████████████████▏                                                 | 4/10 [00:09<00:14,  2.46s/it]

train_loss : 0.06459811255335808
train_F1: 0.9276729559748429
valid_loss : 0.08732407093048096
valid_F1: 0.8859060402684563



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.47s/it]

train_loss : 0.05746899619698524
train_F1: 0.9296875
valid_loss : 0.10498000979423523
valid_F1: 0.8707482993197279



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:14<00:09,  2.44s/it]

train_loss : 0.05611406192183495
train_F1: 0.9349206349206349
valid_loss : 0.10456883311271667
valid_F1: 0.88



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.45s/it]

train_loss : 0.05493486002087593
train_F1: 0.928909952606635
valid_loss : 0.09787464737892151
valid_F1: 0.8707482993197279



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.48s/it]

train_loss : 0.05441888496279716
train_F1: 0.9321901792673422
valid_loss : 0.1057765781879425
valid_F1: 0.8707482993197279



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.47s/it]

train_loss : 0.05359228849411011
train_F1: 0.9293089753772836
valid_loss : 0.11863350868225098
valid_F1: 0.8707482993197279



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.47s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.053869336098432544
train_F1: 0.9362706530291109
valid_loss : 0.10495192408561707
valid_F1: 0.8707482993197279
FOLD 4
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:23,  2.57s/it]

train_loss : 1.022728405892849
train_F1: 0.43902439024390244
valid_loss : 0.2364177107810974
valid_F1: 0.8413793103448276



 20%|████████████████▌                                                                  | 2/10 [00:05<00:20,  2.57s/it]

train_loss : 0.15056984685361385
train_F1: 0.9050535987748852
valid_loss : 0.14224184155464173
valid_F1: 0.9166666666666666



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.55s/it]

train_loss : 0.09332856647670269
train_F1: 0.9252120277563608
valid_loss : 0.08194224834442139
valid_F1: 0.9178082191780823



 40%|█████████████████████████████████▏                                                 | 4/10 [00:10<00:14,  2.49s/it]

train_loss : 0.0653215229511261
train_F1: 0.9321766561514195
valid_loss : 0.08063842058181762
valid_F1: 0.9271523178807948



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.50s/it]

train_loss : 0.058840805664658546
train_F1: 0.9296
valid_loss : 0.06834262609481812
valid_F1: 0.906474820143885



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:15<00:10,  2.53s/it]

train_loss : 0.054326131194829944
train_F1: 0.9353312302839116
valid_loss : 0.06739122867584228
valid_F1: 0.9051094890510948



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.53s/it]

train_loss : 0.05261952430009842
train_F1: 0.9326845093268451
valid_loss : 0.0683364987373352
valid_F1: 0.9



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:20<00:05,  2.53s/it]

train_loss : 0.05362853407859802
train_F1: 0.9287392325763508
valid_loss : 0.07379831075668335
valid_F1: 0.9090909090909091



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.55s/it]

train_loss : 0.051544911414384845
train_F1: 0.9355608591885441
valid_loss : 0.06775501966476441
valid_F1: 0.8985507246376812



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:25<00:00,  2.53s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05187326297163963
train_F1: 0.9366085578446909
valid_loss : 0.0703742504119873
valid_F1: 0.9014084507042254
FOLD 5
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:21,  2.40s/it]

train_loss : 1.0721427656710147
train_F1: 0.41864555848724716
valid_loss : 0.2113047420978546
valid_F1: 0.8333333333333334



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.46s/it]

train_loss : 0.16351160630583764
train_F1: 0.9002284843869002
valid_loss : 0.06689940690994263
valid_F1: 0.945945945945946



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.48s/it]

train_loss : 0.08838750794529915
train_F1: 0.9155693261037954
valid_loss : 0.040873610973358156
valid_F1: 0.9370629370629371



 40%|█████████████████████████████████▏                                                 | 4/10 [00:09<00:14,  2.45s/it]

train_loss : 0.06752798929810525
train_F1: 0.9173745173745174
valid_loss : 0.030399537086486815
valid_F1: 0.9444444444444445



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.44s/it]

train_loss : 0.05890435501933098
train_F1: 0.9223378702962369
valid_loss : 0.028673011064529418
valid_F1: 0.921985815602837



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:14<00:09,  2.47s/it]

train_loss : 0.057092498242855075
train_F1: 0.9228321400159109
valid_loss : 0.028254604339599608
valid_F1: 0.9370629370629371



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.48s/it]

train_loss : 0.05660344287753105
train_F1: 0.9238095238095239
valid_loss : 0.028996586799621582
valid_F1: 0.948905109489051



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.47s/it]

train_loss : 0.056515153497457504
train_F1: 0.9298245614035088
valid_loss : 0.03163902759552002
valid_F1: 0.9295774647887323



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.46s/it]

train_loss : 0.05695496946573257
train_F1: 0.9311424100156495
valid_loss : 0.03020545244216919
valid_F1: 0.9370629370629371



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.46s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05600139498710632
train_F1: 0.9293089753772835
valid_loss : 0.03018786907196045
valid_F1: 0.921985815602837
FOLD 6
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.55s/it]

train_loss : 1.0729405775666236
train_F1: 0.46209386281588444
valid_loss : 0.224786639213562
valid_F1: 0.8407643312101911



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.46s/it]

train_loss : 0.15931340344250203
train_F1: 0.8977709454265949
valid_loss : 0.14507938623428346
valid_F1: 0.8413793103448276



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.48s/it]

train_loss : 0.0750445444136858
train_F1: 0.9326103795507359
valid_loss : 0.09702659845352173
valid_F1: 0.8666666666666666



 40%|█████████████████████████████████▏                                                 | 4/10 [00:10<00:15,  2.52s/it]

train_loss : 0.057408804446458815
train_F1: 0.9321766561514195
valid_loss : 0.09993247985839844
valid_F1: 0.8450704225352113



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.53s/it]

train_loss : 0.05354450047016144
train_F1: 0.9301587301587302
valid_loss : 0.09657948017120362
valid_F1: 0.851063829787234



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:15<00:10,  2.52s/it]

train_loss : 0.04998828396201134
train_F1: 0.9382911392405064
valid_loss : 0.09533041715621948
valid_F1: 0.832116788321168



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.49s/it]

train_loss : 0.0491546630859375
train_F1: 0.9451073985680191
valid_loss : 0.09051567316055298
valid_F1: 0.832116788321168



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.48s/it]

train_loss : 0.04827393889427185
train_F1: 0.9423835832675611
valid_loss : 0.09699244499206543
valid_F1: 0.855263157894737



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.51s/it]

train_loss : 0.04840454906225204
train_F1: 0.9420970266040688
valid_loss : 0.10042332410812378
valid_F1: 0.8391608391608392



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.50s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.04837017506361008
train_F1: 0.9369085173501578
valid_loss : 0.0964084267616272
valid_F1: 0.8244274809160306
FOLD 7
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:21,  2.36s/it]

train_loss : 1.075817956030369
train_F1: 0.45454545454545453
valid_loss : 0.21690743565559387
valid_F1: 0.7538461538461538



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.42s/it]

train_loss : 0.15815497040748597
train_F1: 0.8992248062015504
valid_loss : 0.09614918231964112
valid_F1: 0.9295774647887323



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.45s/it]

train_loss : 0.08416910506784916
train_F1: 0.9168609168609169
valid_loss : 0.06944820284843445
valid_F1: 0.9172932330827068



 40%|█████████████████████████████████▏                                                 | 4/10 [00:09<00:14,  2.41s/it]

train_loss : 0.07035920359194278
train_F1: 0.9261318506751389
valid_loss : 0.06471070051193237
valid_F1: 0.887218045112782



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.46s/it]

train_loss : 0.06083245873451233
train_F1: 0.9264475743348983
valid_loss : 0.056942975521087645
valid_F1: 0.9251700680272108



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:14<00:09,  2.48s/it]

train_loss : 0.0560784712433815
train_F1: 0.933649289099526
valid_loss : 0.054404163360595705
valid_F1: 0.8905109489051095



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.46s/it]

train_loss : 0.054092679917812345
train_F1: 0.9309504467912266
valid_loss : 0.05345410108566284
valid_F1: 0.8905109489051095



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.43s/it]

train_loss : 0.053941337019205095
train_F1: 0.9348171701112877
valid_loss : 0.053601419925689696
valid_F1: 0.8951048951048951



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:21<00:02,  2.43s/it]

train_loss : 0.053299762308597565
train_F1: 0.934959349593496
valid_loss : 0.05271778106689453
valid_F1: 0.916030534351145



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.45s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05206667929887772
train_F1: 0.9416466826538769
valid_loss : 0.05273834466934204
valid_F1: 0.9104477611940298
FOLD 8
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.55s/it]

train_loss : 1.0564922861754895
train_F1: 0.41518578352180935
valid_loss : 0.15885572433471679
valid_F1: 0.8843537414965986

train_loss : 0.14724653996527196
train_F1: 0.906888720666162
valid_loss : 0.12228359580039978
valid_F1: 0.9066666666666667

 20%|████████████████▌                                                                  | 2/10 [00:05<00:20,  2.52s/it]





 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.53s/it]

train_loss : 0.09876161254942417
train_F1: 0.9200304645849201
valid_loss : 0.09094708859920501
valid_F1: 0.9241379310344828



 40%|█████████████████████████████████▏                                                 | 4/10 [00:10<00:15,  2.52s/it]

train_loss : 0.06683434434235096
train_F1: 0.92018779342723
valid_loss : 0.07210142612457275
valid_F1: 0.9275362318840579



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.50s/it]

train_loss : 0.05766922831535339
train_F1: 0.9227091633466136
valid_loss : 0.06836115121841431
valid_F1: 0.9090909090909091



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:15<00:09,  2.48s/it]

train_loss : 0.05563724935054779
train_F1: 0.9259547934528448
valid_loss : 0.06925854086875916
valid_F1: 0.8951048951048951



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.51s/it]

train_loss : 0.054621639847755435
train_F1: 0.9266347687400319
valid_loss : 0.06514195203781128
valid_F1: 0.9197080291970803



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:20<00:04,  2.50s/it]

train_loss : 0.05259038358926773
train_F1: 0.9253731343283582
valid_loss : 0.062290513515472413
valid_F1: 0.9117647058823529



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.47s/it]

train_loss : 0.05181585848331451
train_F1: 0.9305666400638467
valid_loss : 0.06356804370880127
valid_F1: 0.923076923076923



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.50s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

train_loss : 0.05186048224568367
train_F1: 0.9333333333333335
valid_loss : 0.06343463063240051
valid_F1: 0.916030534351145
FOLD 9
--------------------------------



 10%|████████▎                                                                          | 1/10 [00:02<00:22,  2.45s/it]

train_loss : 1.1059430897235871
train_F1: 0.4123893805309734
valid_loss : 0.22844302654266357
valid_F1: 0.8368794326241136



 20%|████████████████▌                                                                  | 2/10 [00:04<00:19,  2.42s/it]

train_loss : 0.15693032443523408
train_F1: 0.9002284843869002
valid_loss : 0.09943410158157348
valid_F1: 0.9395973154362416



 30%|████████████████████████▉                                                          | 3/10 [00:07<00:17,  2.47s/it]

train_loss : 0.09477468430995942
train_F1: 0.9153318077803203
valid_loss : 0.07063151001930237
valid_F1: 0.923076923076923



 40%|█████████████████████████████████▏                                                 | 4/10 [00:09<00:14,  2.48s/it]

train_loss : 0.07236549034714698
train_F1: 0.9245432883240667
valid_loss : 0.062227845191955566
valid_F1: 0.9315068493150684



 50%|█████████████████████████████████████████▌                                         | 5/10 [00:12<00:12,  2.49s/it]

train_loss : 0.06055576764047146
train_F1: 0.923682140047207
valid_loss : 0.05470069646835327
valid_F1: 0.9264705882352942



 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:14<00:09,  2.48s/it]

train_loss : 0.058972270041704175
train_F1: 0.9273743016759776
valid_loss : 0.05177503824234009
valid_F1: 0.9253731343283582



 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:17<00:07,  2.48s/it]

train_loss : 0.05461499541997909
train_F1: 0.9281150159744409
valid_loss : 0.04816208481788635
valid_F1: 0.9185185185185185



 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:19<00:04,  2.47s/it]

train_loss : 0.05508963316679001
train_F1: 0.9265536723163842
valid_loss : 0.04886829853057861
valid_F1: 0.9343065693430657



 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:22<00:02,  2.48s/it]

train_loss : 0.05333574637770653
train_F1: 0.9350441058540497
valid_loss : 0.047375190258026126
valid_F1: 0.9253731343283582



100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.46s/it]

train_loss : 0.05351424887776375
train_F1: 0.9270248596631917
valid_loss : 0.04525509476661682
valid_F1: 0.9185185185185185





In [21]:
All_score = len(All_Fold_score[0]) * [0]
for score in All_Fold_score:
    All_score = np.sum([All_score,score], axis = 0)
All_score = np.round((All_score / 10) ,2) 
All_score

array([0.84, 0.91, 0.92, 0.91, 0.91, 0.9 , 0.9 , 0.91, 0.9 , 0.9 ])

## 全部的 trainingset 訓練 model 後, 透過 validset 評估 model

In [26]:
tr_dataset = NERDataset(X_train,y_train,word2idx['<pad>'],tag2idx['<pad>'])
va_dataset = NERDataset(X_test,y_test,word2idx['<pad>'],tag2idx['<pad>'])
train_dataloader = DataLoader(tr_dataset, batch_size=bs,
                            collate_fn=tr_dataset.collate_fn)
valid_dataloader = DataLoader(va_dataset, batch_size=bs,
                            collate_fn=va_dataset.collate_fn)  
model = EmbeddedRnn(len(word2idx), 256, len(tags),2,word_pad_idx,tag_pad_idx)
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
model = model.to(device)
all_loader = {"train" : train_dataloader,
              "valid" : valid_dataloader}

In [27]:
F_score = []
for epoch in tqdm(range(num_epoch)):
    all_loss = {
    'train': [],
    'valid': [],
    }
    print('')
    for loader in all_loader:
        predictions , true_labels  = [],[]
        for x, y in all_loader[loader]:
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            hidden = model.initHidden(x.size(0))
            hidden[0] = hidden[0].to(device)
            hidden[1] = hidden[1].to(device)
            output, loss = model(x, hidden,y)
            if loader == 'train':
                loss.backward()
                optimizer.step()
            all_loss[loader].append(loss.cpu().item()) 
            predictions.extend([[idx2tag[j] for j in i] for i in output])
            for i in y.detach().cpu().numpy():
                _ = []
                for j in i:
                    if j != tag_pad_idx:
                        _.append(idx2tag[j])
                true_labels.append(_)
        print(f'{loader}_loss : {np.mean(np.array(all_loss[loader]))/64}')
        f_ = f1_score(true_labels,predictions,scheme = IOB2)
        print(f'{loader}_F1: {f_}')
        if loader == 'valid':
            F_score.append(f_)

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]


train_loss : 0.9918276329835256
train_F1: 0.44787644787644787


 10%|████████▎                                                                          | 1/10 [00:02<00:24,  2.76s/it]

valid_loss : 0.8534433160509381
valid_F1: 0.603174603174603

train_loss : 0.1488828053077062
train_F1: 0.9070887818306951


 20%|████████████████▌                                                                  | 2/10 [00:05<00:21,  2.74s/it]

valid_loss : 0.6977005771228245
valid_F1: 0.6969696969696969

train_loss : 0.09127396510707007
train_F1: 0.926027397260274


 30%|████████████████████████▉                                                          | 3/10 [00:08<00:19,  2.76s/it]

valid_loss : 0.7558879256248474
valid_F1: 0.6804123711340206

train_loss : 0.060549938016467623
train_F1: 0.9298369950389794


 40%|█████████████████████████████████▏                                                 | 4/10 [00:11<00:16,  2.78s/it]

valid_loss : 0.8134842429842267
valid_F1: 0.6489361702127661

train_loss : 0.05712536573410034
train_F1: 0.9308624376336423


 50%|█████████████████████████████████████████▌                                         | 5/10 [00:14<00:14,  2.83s/it]

valid_loss : 0.8615969589778355
valid_F1: 0.66

train_loss : 0.05198064645131429
train_F1: 0.935251798561151
valid_loss : 0.9410698328699384
valid_F1: 0.6699029126213593

 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:17<00:11,  2.89s/it]



train_loss : 0.051404794719484115
train_F1: 0.9356223175965664


 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:19<00:08,  2.83s/it]

valid_loss : 0.9697112611361912
valid_F1: 0.6699029126213593

train_loss : 0.05101567970381843
train_F1: 0.9363831308077198


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:22<00:05,  2.83s/it]

valid_loss : 0.9949791942323957
valid_F1: 0.6763285024154588

train_loss : 0.051482242345809934
train_F1: 0.9358327325162221


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:25<00:02,  2.81s/it]

valid_loss : 0.9544537237712315
valid_F1: 0.6699507389162561

train_loss : 0.05123677518632677
train_F1: 0.9358974358974359


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:28<00:00,  2.81s/it]

valid_loss : 0.965035447052547
valid_F1: 0.6601941747572816





In [28]:
np.round(np.array(F_score),2)

array([0.6 , 0.7 , 0.68, 0.65, 0.66, 0.67, 0.67, 0.68, 0.67, 0.66])