In [1]:
import copy

import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from seqeval.scheme import IOB2
from seqeval.metrics import classification_report,f1_score

In [2]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
import warnings
import random

from keras.preprocessing.text import Tokenizer
from sklearn.model_selection import KFold

from torch.utils.data import TensorDataset, DataLoader ,SubsetRandomSampler ,ConcatDataset ,Dataset
from torchcrf import CRF


In [3]:
data_train = pd.read_csv("IOB2_Data/BCSMM4H_train_IOB2_all.txt",sep = '\t')
data_dev = pd.read_csv("IOB2_Data/BC_dev_IOB2_all.txt",sep = '\t')

In [4]:
data_train = data_train.fillna('NA')
data_dev = data_dev.fillna('NA')

In [5]:
class SentenceGetter(object):
    
    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, t) for w, t in zip(s["Word"].values.tolist(),
                                                     s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence#").apply(agg_func)
        self.sentences = [s for s in self.grouped]
    
    def get_next(self):
        try:
            s = self.grouped[self.n_sent]
            self.n_sent += 1
            return s
        except:
            return None

In [6]:
getter = SentenceGetter(data_train)
dev_getter = SentenceGetter(data_dev)

In [7]:
sentences = [[word[0] for word in sentence] for sentence in getter.sentences]
dev_sentences = [[word[0] for word in sentence] for sentence in dev_getter.sentences]
labels = [[s[1] for s in sent] for sent in getter.sentences]
dev_labels = [[s[1] for s in sent] for sent in dev_getter.sentences]

In [8]:
max_features = 40000
tokenizer = Tokenizer(num_words=max_features, split=' ', oov_token='<unk>')
tokenizer.fit_on_texts(sentences)

In [9]:
word2idx = tokenizer.word_index.copy()
word2idx['<pad>'] = 0
idx2word = {word2idx[i]:i for i in word2idx}

In [10]:
len(word2idx)

98877

In [11]:
tags = list(set(data_train['Tag'].values))
tags.append("<pad>")
tag2idx = {t: i for i, t in enumerate(tags)}
idx2tag = {i: t for i, t in enumerate(tags)}

In [12]:
X_train = tokenizer.texts_to_sequences(sentences)
X_test = tokenizer.texts_to_sequences(dev_sentences)
y_train = [[tag2idx[l] for l in s] for s in labels]
y_test = [[tag2idx[l] for l in s] for s in dev_labels]

In [13]:
for i in X_test[0]:
    print(idx2word[i],end = ' ')

i 'm 37 weeks pregnant so i can do whatever the hell i want 

In [14]:
class NERDataset(Dataset):
    def __init__(self,sentences,labels, word_pad_idx, tag_pad_idx, max_len = 500):
        self.sentences = sentences
        self.labels = labels
        self.word_pad_idx = word_pad_idx
        self.tag_pad_idx = tag_pad_idx
        self.max_len = max_len
    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        return (self.sentences[index],self.labels[index])
        
    def collate_fn(self, datasets):
        sentences = [dataset[0] for dataset in datasets]
        labels = [dataset[1] for dataset in datasets]
        max_sent = max([len(data) for data in sentences])
        max_len = max([min(len(sentence), self.max_len) for sentence in sentences])
        pad_sentence = []
        pad_label = []
        for sentence,label in zip(sentences,labels):
            
            if len(sentence) > max_len:
#                 print('asd')
                pad_sentence.append(sentence[:max_len])
                pad_label.append(label[:max_len])
            else:
#                 print('zxc')
                pad_sentence.append(sentence+[self.word_pad_idx]*(max_len-len(sentence)))
                pad_label.append(label+[self.tag_pad_idx]*(max_len-len(label)))
        return torch.LongTensor(pad_sentence), torch.LongTensor(pad_label)

In [15]:
class EmbeddedRnn(nn.Module):
    def __init__(self, vocab, hidden_dim, output_vocab, n_layer,word_pad_idx,tag_pad_idx):
        super(EmbeddedRnn, self).__init__()
        self.n_layer = n_layer
        self.embedding_size = 300
        self.hidden_dim = hidden_dim
        self.embedded = nn.Embedding(vocab, self.embedding_size , padding_idx  = word_pad_idx)
        self.lstm = nn.LSTM(self.embedding_size, hidden_dim, num_layers=n_layer,batch_first = True, bidirectional=True)
        self.fc1 = nn.Linear(2 * hidden_dim, output_vocab)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.crf = CRF(num_tags=output_vocab, batch_first = True)
        self.tag_pad_idx = tag_pad_idx
        
    def forward(self, x, hidden,y_tag):
        embedded = self.embedded(x)
        output, hidden = self.lstm(embedded, hidden)
        output = self.fc1(output)
        if y_tag is not None:
            mask = y_tag != self.tag_pad_idx
            crf_out = self.crf.decode(output, mask=mask)
            crf_loss = -self.crf(output, tags=y_tag, mask=mask)
        else:
            crf_out = self.crf.decode(output)
            crf_loss = None
        return crf_out, crf_loss
    
    def initHidden(self, batch_size):
        hidden = Variable(torch.zeros(2 * self.n_layer, batch_size, self.hidden_dim))
        cell = Variable(torch.zeros(2 * self.n_layer, batch_size, self.hidden_dim))
        return [hidden, cell]
#         return hidden
    def init_crf_transitions(tag_names, imp_value=-100):
        crf = CRF(num_tags=n_tags)
        num_tags = len(tag_names)
        for i in range(num_tags):
            tag_name = tag_names[i]
#             if tag_name[0] in ("I") or tag_name[0] in ("L") or tag_name == "ENDPAD":
            if tag_name[0] in ("I") or tag_name == "ENDPAD":
                torch.nn.init.constant_(crf.start_transitions[i], imp_value)
        tag_is = {}
        for tag_position in ("B", "I", "O"):
            tag_is[tag_position] = [i for i, tag in enumerate(tag_names) if tag[0] == tag_position]
        impossible_transitions_position = {
            "O": "I",  
#             "O": 'L',
            
#             'B': 'U',
#             'B': 'I',
            
#             'I': 'B',
#             'I': 'U',
#             'I': 'O',
            
#             'L': 'I',
            
#             'U': 'I',
#             'U': 'L',
        }
        for from_tag, to_tag_list in impossible_transitions_position.items():
            to_tags = list(to_tag_list)
            for from_tag_i in tag_is[from_tag]:
                for to_tag in to_tags:
                    for to_tag_i in tag_is[to_tag]:
                        torch.nn.init.constant_(
                            crf.transitions[from_tag_i, to_tag_i], imp_value
                        )
        # init impossible B and I transitions to different entity types
        impossible_transitions_tags = {
            "B": "I",
#             'B': 'L',
#             "I": "I",
#             'I': 'L',
#             'U': 'B',
        }
        for from_tag, to_tag_list in impossible_transitions_tags.items():
            to_tags = list(to_tag_list)
            for from_tag_i in tag_is[from_tag]:
                for to_tag in to_tags:
                    for to_tag_i in tag_is[to_tag]:
                        if tag_names[from_tag_i].split("-")[1] != tag_names[to_tag_i].split("-")[1]:
                            torch.nn.init.constant_(
                                crf.transitions[from_tag_i, to_tag_i], imp_value
                            )

## 使用 trainingset+validset 去做 10-fold CV

In [16]:
bs = 64
k_folds = 10
kfold = KFold(n_splits=k_folds, shuffle=True)
tr_dataset = NERDataset(X_train,y_train,word2idx['<pad>'],tag2idx['<pad>'])
va_dataset = NERDataset(X_test,y_test,word2idx['<pad>'],tag2idx['<pad>'])
dataset = ConcatDataset([tr_dataset, va_dataset])
word_pad_idx = word2idx['<pad>']
tag_pad_idx = tag2idx['<pad>']
num_epoch = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
All_Fold_score = []
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')
    print('--------------------------------')
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    train_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=tr_dataset.collate_fn,
                                  sampler=train_subsampler)
    valid_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=va_dataset.collate_fn,
                                  sampler=test_subsampler)  
    model = EmbeddedRnn(len(word2idx), 256, len(tags),2,word_pad_idx,tag_pad_idx)
    optimizer = optim.AdamW(model.parameters(), lr=5e-3)
    model = model.to(device)
    all_loader = {"train" : train_dataloader,
                  "valid" : valid_dataloader}
    Fold_score = []
    for epoch in tqdm(range(num_epoch)):
        all_loss = {
            'train': [],
            'valid': [],
        }
        print('')
        for loader in all_loader:
            predictions , true_labels  = [],[]
            for x, y in all_loader[loader]:
                optimizer.zero_grad()
                x = x.to(device)
                y = y.to(device)
                hidden = model.initHidden(x.size(0))
                hidden[0] = hidden[0].to(device)
                hidden[1] = hidden[1].to(device)
                output, loss = model(x, hidden,y)
                if loader == 'train':
                    loss.backward()
                    optimizer.step()
                all_loss[loader].append(loss.cpu().item()) 
                predictions.extend([[idx2tag[j] for j in i] for i in output])
                for i in y.detach().cpu().numpy():
                    _ = []
                    for j in i:
                        if j != tag_pad_idx:
                            _.append(idx2tag[j])
                    true_labels.append(_)
            print(f'{loader}_loss : {np.mean(np.array(all_loss[loader]))/64}')
            f_ = f1_score(true_labels,predictions,scheme = IOB2)
            print(f'{loader}_F1: {f_}')
            if loader == 'valid':
                Fold_score.append(f_)
    if Fold_score != []:
        All_Fold_score.append(Fold_score)

FOLD 0
--------------------------------


  0%|                                                                                           | 0/10 [00:00<?, ?it/s]


train_loss : 0.13933934730403455




train_F1: 0.6604336582454744
valid_loss : 0.07439310885269508


 10%|████████▏                                                                         | 1/10 [03:07<28:09, 187.74s/it]

valid_F1: 0.7713787085514833

train_loss : 0.04307648217434501
train_F1: 0.8434234064589021
valid_loss : 0.07099672815997467


 20%|████████████████▍                                                                 | 2/10 [06:14<24:59, 187.42s/it]

valid_F1: 0.8013876843018215

train_loss : 0.03151238429394513
train_F1: 0.8724632284490784
valid_loss : 0.08474308162683918


 30%|████████████████████████▌                                                         | 3/10 [09:21<21:50, 187.17s/it]

valid_F1: 0.8238255033557047

train_loss : 0.02528264556932886
train_F1: 0.8905407916744099
valid_loss : 0.09031563654333075


 40%|████████████████████████████████▊                                                 | 4/10 [12:28<18:41, 186.99s/it]

valid_F1: 0.830950378469302

train_loss : 0.02213661065195757
train_F1: 0.9088047403018239
valid_loss : 0.08292902967779436


 50%|█████████████████████████████████████████                                         | 5/10 [15:34<15:33, 186.65s/it]

valid_F1: 0.8125530110262933

train_loss : 0.019305546972903053
train_F1: 0.9144177449168207
valid_loss : 0.09369540454767575


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:41<12:27, 186.77s/it]

valid_F1: 0.8010118043844856

train_loss : 0.017302502772018395
train_F1: 0.9221247113163972
valid_loss : 0.09198502577388677


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:48<09:20, 186.79s/it]

valid_F1: 0.8266883645240032

train_loss : 0.016122532264392034
train_F1: 0.9288548491279874
valid_loss : 0.08365731330797355


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:55<06:14, 187.01s/it]

valid_F1: 0.8180300500834725

train_loss : 0.01646286974501398
train_F1: 0.9260319512420353
valid_loss : 0.08792155630185887


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [28:03<03:07, 187.35s/it]

valid_F1: 0.8108108108108107

train_loss : 0.015708150102913026
train_F1: 0.9314891655140618
valid_loss : 0.10364426289920077


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:10<00:00, 187.04s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8215488215488215
FOLD 1
--------------------------------

train_loss : 0.13411778084379133
train_F1: 0.6999795626405069
valid_loss : 0.07507161226854703


 10%|████████▏                                                                         | 1/10 [03:08<28:13, 188.16s/it]

valid_F1: 0.7625772285966461

train_loss : 0.043075265505393305
train_F1: 0.8465994020926756
valid_loss : 0.06532065665227509


 20%|████████████████▍                                                                 | 2/10 [06:15<25:00, 187.57s/it]

valid_F1: 0.7906103286384977

train_loss : 0.030855312008742305
train_F1: 0.8747220163083765
valid_loss : 0.06759572419050698


 30%|████████████████████████▌                                                         | 3/10 [09:22<21:50, 187.18s/it]

valid_F1: 0.8025247971145175

train_loss : 0.02450014921491854
train_F1: 0.8949308755760369
valid_loss : 0.06301150236835826


 40%|████████████████████████████████▊                                                 | 4/10 [12:28<18:41, 186.91s/it]

valid_F1: 0.8037383177570093

train_loss : 0.019980323182470817
train_F1: 0.9088402867120015
valid_loss : 0.06138924144428206


 50%|█████████████████████████████████████████                                         | 5/10 [15:36<15:36, 187.26s/it]

valid_F1: 0.805128205128205

train_loss : 0.01848790877239494
train_F1: 0.9189882697947215
valid_loss : 0.07365717581325323


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:42<12:27, 186.92s/it]

valid_F1: 0.7892976588628763

train_loss : 0.017806686240837186
train_F1: 0.9237257059039237
valid_loss : 0.07595374398031403


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:48<09:19, 186.52s/it]

valid_F1: 0.8056288478452066

train_loss : 0.01653601316081138
train_F1: 0.9263370332996972
valid_loss : 0.0814632080739068


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:54<06:13, 186.50s/it]

valid_F1: 0.7889908256880733

train_loss : 0.01471570374023311
train_F1: 0.9336141378994597
valid_loss : 0.08243522132860026


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [28:01<03:06, 186.59s/it]

valid_F1: 0.7972508591065292

train_loss : 0.0167327826077997
train_F1: 0.9257257990658484
valid_loss : 0.07846174132357413


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:07<00:00, 186.79s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.7894736842105263
FOLD 2
--------------------------------

train_loss : 0.1359193779853934
train_F1: 0.6325250444714914
valid_loss : 0.06685503864678267


 10%|████████▏                                                                         | 1/10 [03:06<27:55, 186.16s/it]

valid_F1: 0.8003663003663004

train_loss : 0.04241608892883068
train_F1: 0.8413793103448276
valid_loss : 0.05860440051792381


 20%|████████████████▍                                                                 | 2/10 [06:12<24:49, 186.19s/it]

valid_F1: 0.817167381974249

train_loss : 0.03023950504269965
train_F1: 0.8726868985936344
valid_loss : 0.05929103359649671


 30%|████████████████████████▌                                                         | 3/10 [09:18<21:43, 186.26s/it]

valid_F1: 0.8370883882149047

train_loss : 0.024376891401777203
train_F1: 0.8989694516010305
valid_loss : 0.07209052377006162


 40%|████████████████████████████████▊                                                 | 4/10 [12:24<18:37, 186.23s/it]

valid_F1: 0.8334771354616048

train_loss : 0.02094538212938182
train_F1: 0.9105720823798629
valid_loss : 0.07017975667857121


 50%|█████████████████████████████████████████                                         | 5/10 [15:37<15:42, 188.57s/it]

valid_F1: 0.8153191489361703

train_loss : 0.019259514303415926
train_F1: 0.9177029992684711
valid_loss : 0.0816930271424005


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:49<12:39, 189.80s/it]

valid_F1: 0.8321299638989169

train_loss : 0.016050084862556383
train_F1: 0.9290275497172048
valid_loss : 0.07842849971830984


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [22:01<09:31, 190.58s/it]

valid_F1: 0.7756521739130435

train_loss : 0.017713216909598647
train_F1: 0.9184792542496801
valid_loss : 0.07955465658141353


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [25:13<06:21, 190.98s/it]

valid_F1: 0.8274067649609714

train_loss : 0.016650495589568964
train_F1: 0.9296238674842134
valid_loss : 0.06979880195380858


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [28:26<03:11, 191.58s/it]

valid_F1: 0.8159457167090755

train_loss : 0.016003393136525556
train_F1: 0.9255669348939282
valid_loss : 0.08438412018259552


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:38<00:00, 189.89s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8184143222506394
FOLD 3
--------------------------------

train_loss : 0.13997031718236527
train_F1: 0.6740975559873198
valid_loss : 0.07199914558468579


 10%|████████▏                                                                         | 1/10 [03:13<28:57, 193.10s/it]

valid_F1: 0.8165289256198347

train_loss : 0.04399889737817535
train_F1: 0.8391292055414192
valid_loss : 0.06661897840274271


 20%|████████████████▍                                                                 | 2/10 [06:25<25:42, 192.81s/it]

valid_F1: 0.8229088168801808

train_loss : 0.03177909905120986
train_F1: 0.8649204864359215
valid_loss : 0.05971491971340413


 30%|████████████████████████▌                                                         | 3/10 [09:37<22:26, 192.34s/it]

valid_F1: 0.8457142857142858

train_loss : 0.026904744680904005
train_F1: 0.8872166480862134
valid_loss : 0.0530401814955278


 40%|████████████████████████████████▊                                                 | 4/10 [12:49<19:12, 192.15s/it]

valid_F1: 0.8424336973478939

train_loss : 0.02231395757649984
train_F1: 0.9000648328239326
valid_loss : 0.05396707741561058


 50%|█████████████████████████████████████████                                         | 5/10 [15:56<15:51, 190.23s/it]

valid_F1: 0.8442687747035573

train_loss : 0.020450883940405725
train_F1: 0.9095440280598116
valid_loss : 0.06303999745656919


 60%|█████████████████████████████████████████████████▏                                | 6/10 [19:02<12:35, 188.86s/it]

valid_F1: 0.8462164361269324

train_loss : 0.018507421355076277
train_F1: 0.9189938968004439
valid_loss : 0.0619375455458692


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [22:08<09:24, 188.07s/it]

valid_F1: 0.8293436293436294

train_loss : 0.01841925529714379
train_F1: 0.9211886304909561
valid_loss : 0.06682160237334042


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [25:15<06:14, 187.49s/it]

valid_F1: 0.8253452477660439

train_loss : 0.017461486161118766
train_F1: 0.9219255289660907
valid_loss : 0.058085170214360354


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [28:21<03:07, 187.25s/it]

valid_F1: 0.8219852337981953

train_loss : 0.016840383111900916
train_F1: 0.9277397893180559
valid_loss : 0.06495994376453841


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:28<00:00, 188.83s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8171384850803368
FOLD 4
--------------------------------

train_loss : 0.1358024398265392
train_F1: 0.6824038758890837
valid_loss : 0.06832894331244664


 10%|████████▏                                                                         | 1/10 [03:06<28:02, 186.96s/it]

valid_F1: 0.8013698630136987

train_loss : 0.04349273486079368
train_F1: 0.8401354147075418
valid_loss : 0.062217593471580575


 20%|████████████████▍                                                                 | 2/10 [06:13<24:53, 186.68s/it]

valid_F1: 0.8305647840531561

train_loss : 0.030819537958740133
train_F1: 0.8731843575418995
valid_loss : 0.062382325068742875


 30%|████████████████████████▌                                                         | 3/10 [09:19<21:45, 186.50s/it]

valid_F1: 0.8450465707027942

train_loss : 0.02494918949018368
train_F1: 0.8907235621521336
valid_loss : 0.06921717264280419


 40%|████████████████████████████████▊                                                 | 4/10 [12:26<18:40, 186.71s/it]

valid_F1: 0.8283333333333333

train_loss : 0.02079831184998721
train_F1: 0.9111562615441449
valid_loss : 0.07014016213469973


 50%|█████████████████████████████████████████                                         | 5/10 [15:33<15:33, 186.61s/it]

valid_F1: 0.8337595907928389

train_loss : 0.017718407483616305
train_F1: 0.9196905507459938
valid_loss : 0.08463143845410825


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:40<12:26, 186.68s/it]

valid_F1: 0.833195020746888

train_loss : 0.01686997777574639
train_F1: 0.9258815946966209
valid_loss : 0.08367976991059346


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:46<09:19, 186.52s/it]

valid_F1: 0.827930174563591

train_loss : 0.015608396482323289
train_F1: 0.9280397022332506
valid_loss : 0.0915076709134835


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:52<06:12, 186.42s/it]

valid_F1: 0.832632464255677

train_loss : 0.015511940518326967
train_F1: 0.9238121496186013
valid_loss : 0.09666060966931318


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [27:58<03:06, 186.46s/it]

valid_F1: 0.8229426433915211

train_loss : 0.01557051020713556
train_F1: 0.9274749517418881
valid_loss : 0.08845200591174468


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:05<00:00, 186.58s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8164291701592623
FOLD 5
--------------------------------

train_loss : 0.13286020029442872
train_F1: 0.6365773389750823
valid_loss : 0.06434898054070562


 10%|████████▏                                                                         | 1/10 [03:06<27:58, 186.51s/it]

valid_F1: 0.8069565217391306

train_loss : 0.04311743834555801
train_F1: 0.8421249532360644
valid_loss : 0.055352162743840264


 20%|████████████████▍                                                                 | 2/10 [06:13<24:52, 186.52s/it]

valid_F1: 0.8219424460431655

train_loss : 0.029898590987615135
train_F1: 0.8770043562888126
valid_loss : 0.06793002822549543


 30%|████████████████████████▌                                                         | 3/10 [09:20<21:47, 186.81s/it]

valid_F1: 0.8115183246073299

train_loss : 0.025417369121577996
train_F1: 0.891452833671774
valid_loss : 0.05484502203762531


 40%|████████████████████████████████▊                                                 | 4/10 [12:26<18:40, 186.68s/it]

valid_F1: 0.837248322147651

train_loss : 0.020421543863688277
train_F1: 0.9110497237569061
valid_loss : 0.08148586925850293


 50%|█████████████████████████████████████████                                         | 5/10 [15:32<15:31, 186.33s/it]

valid_F1: 0.8149466192170819

train_loss : 0.018929130627234593
train_F1: 0.9223836674636747
valid_loss : 0.058966034786693024


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:36<12:22, 185.70s/it]

valid_F1: 0.8253452477660438

train_loss : 0.0172947703529574
train_F1: 0.923981983638202
valid_loss : 0.06798919861283258


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:40<09:15, 185.15s/it]

valid_F1: 0.8370118845500849

train_loss : 0.016230404224025666
train_F1: 0.9326975976526682
valid_loss : 0.07180735630757898


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:45<06:09, 184.89s/it]

valid_F1: 0.824896265560166

train_loss : 0.018233531080273632
train_F1: 0.9182865039152465
valid_loss : 0.06485935691395932


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [27:49<03:04, 184.60s/it]

valid_F1: 0.8415758591785414

train_loss : 0.01515902003049082
train_F1: 0.9331251148263825
valid_loss : 0.06700043671842769


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [30:53<00:00, 185.39s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8260869565217391
FOLD 6
--------------------------------

train_loss : 0.13998998381403013
train_F1: 0.6886888934342401
valid_loss : 0.06521253420092235


 10%|████████▏                                                                         | 1/10 [03:04<27:37, 184.15s/it]

valid_F1: 0.7980691874497184

train_loss : 0.04352155210084744
train_F1: 0.8414416109908723
valid_loss : 0.05736807497026764


 20%|████████████████▍                                                                 | 2/10 [06:08<24:34, 184.35s/it]

valid_F1: 0.8040201005025126

train_loss : 0.030550687477615104
train_F1: 0.879523676621081
valid_loss : 0.06600524627856005


 30%|████████████████████████▌                                                         | 3/10 [09:13<21:31, 184.48s/it]

valid_F1: 0.8116666666666665

train_loss : 0.025878054272860506
train_F1: 0.893286480488256
valid_loss : 0.06246640344333147


 40%|████████████████████████████████▊                                                 | 4/10 [12:17<18:25, 184.19s/it]

valid_F1: 0.8147554129911789

train_loss : 0.021904840403396922
train_F1: 0.9076327433628318
valid_loss : 0.06433180202599441


 50%|█████████████████████████████████████████                                         | 5/10 [15:21<15:21, 184.22s/it]

valid_F1: 0.7895569620253166

train_loss : 0.018465432720690313
train_F1: 0.9148328883159929
valid_loss : 0.06839082314345603


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:24<12:15, 183.89s/it]

valid_F1: 0.8338870431893688

train_loss : 0.017483935249177424
train_F1: 0.9217743419841709
valid_loss : 0.06636365168841085


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:29<09:12, 184.21s/it]

valid_F1: 0.8246205733558178

train_loss : 0.016563476275294857
train_F1: 0.9257553426676493
valid_loss : 0.06528704588205736


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:34<06:08, 184.37s/it]

valid_F1: 0.8076602830974188

train_loss : 0.017029027584226916
train_F1: 0.9271840191475652
valid_loss : 0.07221103263757775


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [27:38<03:04, 184.42s/it]

valid_F1: 0.8056478405315615

train_loss : 0.01444423489659474
train_F1: 0.9327453142227122
valid_loss : 0.07377957482169444


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [30:43<00:00, 184.36s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.797743755036261
FOLD 7
--------------------------------

train_loss : 0.1403294109071778
train_F1: 0.6844145246838026
valid_loss : 0.06454595698861876


 10%|████████▏                                                                         | 1/10 [03:04<27:37, 184.12s/it]

valid_F1: 0.8076285240464345

train_loss : 0.045785060473356565
train_F1: 0.8339455590091363
valid_loss : 0.0552279521998401


 20%|████████████████▍                                                                 | 2/10 [06:08<24:32, 184.11s/it]

valid_F1: 0.837171052631579

train_loss : 0.03143500002999904
train_F1: 0.8710820895522389
valid_loss : 0.059634164994077705


 30%|████████████████████████▌                                                         | 3/10 [09:12<21:30, 184.38s/it]

valid_F1: 0.8433530906011855

train_loss : 0.025931369698228762
train_F1: 0.8902484241750094
valid_loss : 0.05402624894316509


 40%|████████████████████████████████▊                                                 | 4/10 [12:18<18:28, 184.73s/it]

valid_F1: 0.8410757946210269

train_loss : 0.021457877804228007
train_F1: 0.906088560885609
valid_loss : 0.06117178146971094


 50%|█████████████████████████████████████████                                         | 5/10 [15:22<15:23, 184.60s/it]

valid_F1: 0.8162627052384676

train_loss : 0.019013173029833668
train_F1: 0.9164594270977251
valid_loss : 0.0695942306856268


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:26<12:17, 184.27s/it]

valid_F1: 0.8377723970944311

train_loss : 0.017509948236840878
train_F1: 0.9185103244837758
valid_loss : 0.07540662537511682


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:31<09:13, 184.46s/it]

valid_F1: 0.8353609083536091

train_loss : 0.017619494936810268
train_F1: 0.9219727640780271
valid_loss : 0.07600446698630106


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:35<06:08, 184.46s/it]

valid_F1: 0.8268434134217065

train_loss : 0.016707540628465013
train_F1: 0.9283804958982396
valid_loss : 0.07400781861245284


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [27:39<03:04, 184.38s/it]

valid_F1: 0.809375

train_loss : 0.017348219994222865
train_F1: 0.926465167711021
valid_loss : 0.07066868011030221


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [30:44<00:00, 184.42s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8331907613344739
FOLD 8
--------------------------------

train_loss : 0.14264759377909475
train_F1: 0.6309089207151548
valid_loss : 0.050390514101241235


 10%|████████▏                                                                         | 1/10 [03:09<28:24, 189.44s/it]

valid_F1: 0.8319856244384546

train_loss : 0.043720844810181826
train_F1: 0.8468183929400837
valid_loss : 0.04836179918357145


 20%|████████████████▍                                                                 | 2/10 [06:17<25:08, 188.50s/it]

valid_F1: 0.8361990950226244

train_loss : 0.031538811170984254
train_F1: 0.8724832214765099
valid_loss : 0.049931192204843614


 30%|████████████████████████▌                                                         | 3/10 [09:28<22:08, 189.72s/it]

valid_F1: 0.8330206378986866

train_loss : 0.026349385994021576
train_F1: 0.8907208940185032
valid_loss : 0.04858175123754506


 40%|████████████████████████████████▊                                                 | 4/10 [12:37<18:56, 189.39s/it]

valid_F1: 0.8486486486486486

train_loss : 0.020866979427566826
train_F1: 0.9073599414134017
valid_loss : 0.05844456621524051


 50%|█████████████████████████████████████████                                         | 5/10 [15:48<15:50, 190.00s/it]

valid_F1: 0.8181818181818181

train_loss : 0.01955044207560458
train_F1: 0.9162696958592892
valid_loss : 0.05810123963159657


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:56<12:37, 189.47s/it]

valid_F1: 0.835304822565969

train_loss : 0.0191034984597083
train_F1: 0.9143535427319212
valid_loss : 0.06000875106470868


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [22:04<09:26, 188.89s/it]

valid_F1: 0.8057675996607294

train_loss : 0.018285267621021572
train_F1: 0.9213258058641414
valid_loss : 0.05417917099281945


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [25:16<06:19, 189.78s/it]

valid_F1: 0.8355795148247979

train_loss : 0.01643248972135612
train_F1: 0.9285323609845031
valid_loss : 0.06264945082545985


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [28:25<03:09, 189.57s/it]

valid_F1: 0.844811753902663

train_loss : 0.015968402216271806
train_F1: 0.93145197997269
valid_loss : 0.06257121720151088


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:29<00:00, 188.96s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8191304347826088
FOLD 9
--------------------------------

train_loss : 0.1328812434450994
train_F1: 0.673671920607122
valid_loss : 0.06406786569576954


 10%|████████▏                                                                         | 1/10 [03:05<27:46, 185.13s/it]

valid_F1: 0.788975021533161

train_loss : 0.04207048797761786
train_F1: 0.8452569722973604
valid_loss : 0.0677988989992398


 20%|████████████████▍                                                                 | 2/10 [06:10<24:42, 185.29s/it]

valid_F1: 0.7396098388464801

train_loss : 0.02993503924631812
train_F1: 0.8785740809506126
valid_loss : 0.07020596062294512


 30%|████████████████████████▌                                                         | 3/10 [09:16<21:40, 185.81s/it]

valid_F1: 0.8141592920353983

train_loss : 0.02385268669443584
train_F1: 0.8996043066163614
valid_loss : 0.07234365501980707


 40%|████████████████████████████████▊                                                 | 4/10 [12:21<18:32, 185.44s/it]

valid_F1: 0.798360655737705

train_loss : 0.021743164295420333
train_F1: 0.9032258064516128
valid_loss : 0.07202555527110915


 50%|█████████████████████████████████████████                                         | 5/10 [15:28<15:29, 185.95s/it]

valid_F1: 0.8016736401673641

train_loss : 0.019334993988183655
train_F1: 0.9132831159287157
valid_loss : 0.07724397153611008


 60%|█████████████████████████████████████████████████▏                                | 6/10 [18:35<12:25, 186.25s/it]

valid_F1: 0.7898423817863398

train_loss : 0.017013644562032557
train_F1: 0.927776244364707
valid_loss : 0.08797177524419031


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [21:42<09:19, 186.51s/it]

valid_F1: 0.8150042625745951

train_loss : 0.019247726156467932
train_F1: 0.9165672326755394
valid_loss : 0.0725164008413868


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [24:48<06:12, 186.30s/it]

valid_F1: 0.8065573770491804

train_loss : 0.01633739844743305
train_F1: 0.9268695014662756
valid_loss : 0.08542334786224588


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [27:57<03:07, 187.26s/it]

valid_F1: 0.8013816925734024

train_loss : 0.017759857354907465
train_F1: 0.9184517789831755
valid_loss : 0.08180525875609547


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [31:05<00:00, 186.57s/it]

valid_F1: 0.8053917438921651





In [18]:
All_score = len(All_Fold_score[0]) * [0]
for score in All_Fold_score:
    All_score = np.sum([All_score,score], axis = 0)
All_score = np.round((All_score / 10) ,2)

In [20]:
All_score

array([0.8 , 0.81, 0.83, 0.83, 0.82, 0.82, 0.82, 0.82, 0.82, 0.81])

## 單用 trainingset 去做 10-fold CV

In [21]:
All_Fold_score = []
for fold, (train_ids, test_ids) in enumerate(kfold.split(tr_dataset)):
    print(f'FOLD {fold}')
    print('--------------------------------')
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    train_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=tr_dataset.collate_fn,
                                  sampler=train_subsampler)
    valid_dataloader = DataLoader(dataset, batch_size=bs,
                                  collate_fn=va_dataset.collate_fn,
                                  sampler=test_subsampler)  
    model = EmbeddedRnn(len(word2idx), 256, len(tags),2,word_pad_idx,tag_pad_idx)
    optimizer = optim.AdamW(model.parameters(), lr=5e-3)
    model = model.to(device)
    all_loader = {"train" : train_dataloader,
                  "valid" : valid_dataloader}
    Fold_score = []
    for epoch in tqdm(range(num_epoch)):
        all_loss = {
            'train': [],
            'valid': [],
        }
        print('')
        for loader in all_loader:
            predictions , true_labels  = [],[]
            for x, y in all_loader[loader]:
                optimizer.zero_grad()
                x = x.to(device)
                y = y.to(device)
                hidden = model.initHidden(x.size(0))
                hidden[0] = hidden[0].to(device)
                hidden[1] = hidden[1].to(device)
                output, loss = model(x, hidden,y)
                if loader == 'train':
                    loss.backward()
                    optimizer.step()
                all_loss[loader].append(loss.cpu().item()) 
                predictions.extend([[idx2tag[j] for j in i] for i in output])
                for i in y.detach().cpu().numpy():
                    _ = []
                    for j in i:
                        if j != tag_pad_idx:
                            _.append(idx2tag[j])
                    true_labels.append(_)
            print(f'{loader}_loss : {np.mean(np.array(all_loss[loader]))/64}')
            f_ = f1_score(true_labels,predictions,scheme = IOB2)
            print(f'{loader}_F1: {f_}')
            if loader == 'valid':
                Fold_score.append(f_)
    if Fold_score != []:
        All_Fold_score.append(Fold_score)

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

FOLD 0
--------------------------------

train_loss : 0.1900658209428245
train_F1: 0.6817943061245928
valid_loss : 0.08059124796131215


 10%|████████▏                                                                         | 1/10 [02:18<20:50, 138.96s/it]

valid_F1: 0.8197009674582235

train_loss : 0.05481904276011223
train_F1: 0.8467987804878049
valid_loss : 0.08001796457335933


 20%|████████████████▍                                                                 | 2/10 [04:38<18:36, 139.59s/it]

valid_F1: 0.821203953279425

train_loss : 0.03560141487524505
train_F1: 0.891022021456804
valid_loss : 0.07813419904801752


 30%|████████████████████████▌                                                         | 3/10 [06:56<16:08, 138.41s/it]

valid_F1: 0.8377896613190731

train_loss : 0.02972302059189077
train_F1: 0.9054205607476635
valid_loss : 0.07918884336658112


 40%|████████████████████████████████▊                                                 | 4/10 [09:14<13:51, 138.62s/it]

valid_F1: 0.8205128205128206

train_loss : 0.02373184944882935
train_F1: 0.9205846755423144
valid_loss : 0.09360239310069131


 50%|█████████████████████████████████████████                                         | 5/10 [11:31<11:29, 137.86s/it]

valid_F1: 0.8232142857142857

train_loss : 0.018619274283466786
train_F1: 0.942096443865202
valid_loss : 0.09513132143523786


 60%|█████████████████████████████████████████████████▏                                | 6/10 [13:47<09:08, 137.25s/it]

valid_F1: 0.8098591549295775

train_loss : 0.018025922927238883
train_F1: 0.9393291833132026
valid_loss : 0.09339787142181938


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [16:02<06:49, 136.58s/it]

valid_F1: 0.7966942148760331

train_loss : 0.017066114051780237
train_F1: 0.9440878221229883
valid_loss : 0.09173317691432191


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [18:18<04:32, 136.35s/it]

valid_F1: 0.8021390374331551

train_loss : 0.015349872564999487
train_F1: 0.9461066716223749
valid_loss : 0.10387363182017942


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [20:34<02:16, 136.29s/it]

valid_F1: 0.7911547911547911

train_loss : 0.01793819790410167
train_F1: 0.940312383785794
valid_loss : 0.09742935156667387


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [22:50<00:00, 137.07s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8092783505154639
FOLD 1
--------------------------------

train_loss : 0.18483634208890506
train_F1: 0.6734296546777334
valid_loss : 0.09492409045433069


 10%|████████▏                                                                         | 1/10 [02:18<20:42, 138.05s/it]

valid_F1: 0.794040315512708

train_loss : 0.05592102574635068
train_F1: 0.8479069323924859
valid_loss : 0.08596170195317887


 20%|████████████████▍                                                                 | 2/10 [04:36<18:24, 138.04s/it]

valid_F1: 0.8160418482999129

train_loss : 0.03686249832007429
train_F1: 0.8876914350538855
valid_loss : 0.0945793084509961


 30%|████████████████████████▌                                                         | 3/10 [07:00<16:27, 141.10s/it]

valid_F1: 0.8066115702479338

train_loss : 0.029904325933125045
train_F1: 0.9070380127963868
valid_loss : 0.0929515700569594


 40%|████████████████████████████████▊                                                 | 4/10 [09:23<14:10, 141.81s/it]

valid_F1: 0.8118811881188119

train_loss : 0.0248177259447665
train_F1: 0.9248078004875305
valid_loss : 0.10708585250880812


 50%|█████████████████████████████████████████                                         | 5/10 [11:45<11:49, 141.92s/it]

valid_F1: 0.8038125496425735

train_loss : 0.0214942588638985
train_F1: 0.9302499765939519
valid_loss : 0.12309841133121933


 60%|█████████████████████████████████████████████████▏                                | 6/10 [14:04<09:23, 140.81s/it]

valid_F1: 0.7889908256880733

train_loss : 0.018452511340319194
train_F1: 0.937570093457944
valid_loss : 0.13336214287714523


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [16:22<06:59, 139.82s/it]

valid_F1: 0.8311475409836065

train_loss : 0.016888040051148745
train_F1: 0.9428063278105402
valid_loss : 0.13227341619211358


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [18:38<04:37, 138.63s/it]

valid_F1: 0.7844690966719492

train_loss : 0.017478396657128088
train_F1: 0.9386686611817501
valid_loss : 0.1253761387998601


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [20:55<02:18, 138.28s/it]

valid_F1: 0.8255033557046979

train_loss : 0.01610762612858034
train_F1: 0.9454647676161919
valid_loss : 0.12278942843633038


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [23:12<00:00, 139.29s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8031496062992126
FOLD 2
--------------------------------

train_loss : 0.18572275512700476
train_F1: 0.6589945487583283
valid_loss : 0.09980054714946777


 10%|████████▏                                                                         | 1/10 [02:20<21:06, 140.73s/it]

valid_F1: 0.7883461868037702

train_loss : 0.05860524200414062
train_F1: 0.8450166905102526
valid_loss : 0.08058425051148062


 20%|████████████████▍                                                                 | 2/10 [04:46<19:11, 143.98s/it]

valid_F1: 0.8180272108843537

train_loss : 0.03948761306783783
train_F1: 0.8815677364385118
valid_loss : 0.08523786175609022


 30%|████████████████████████▌                                                         | 3/10 [07:13<16:55, 145.04s/it]

valid_F1: 0.8275862068965518

train_loss : 0.03209650401523612
train_F1: 0.9003100629521753
valid_loss : 0.09095033442045188


 40%|████████████████████████████████▊                                                 | 4/10 [09:36<14:26, 144.34s/it]

valid_F1: 0.8256880733944955

train_loss : 0.025335777809150813
train_F1: 0.9186656671664168
valid_loss : 0.10077076430686495


 50%|█████████████████████████████████████████                                         | 5/10 [12:03<12:06, 145.35s/it]

valid_F1: 0.8449014567266496

train_loss : 0.019583066157485604
train_F1: 0.9341295017727188
valid_loss : 0.10459018399479328


 60%|█████████████████████████████████████████████████▏                                | 6/10 [14:53<10:15, 153.78s/it]

valid_F1: 0.8239202657807309

train_loss : 0.01822875764066293
train_F1: 0.9432094152811507
valid_loss : 0.12189546707630544


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [17:52<08:05, 161.95s/it]

valid_F1: 0.8139158576051779

train_loss : 0.01672505892221463
train_F1: 0.9449275362318841
valid_loss : 0.10460999777749205


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [20:53<05:36, 168.01s/it]

valid_F1: 0.8198558847077664

train_loss : 0.014884442342100632
train_F1: 0.9528698086794213
valid_loss : 0.10446456060200543


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [23:52<02:51, 171.35s/it]

valid_F1: 0.826487845766974

train_loss : 0.018182233313973978
train_F1: 0.9423887587822015
valid_loss : 0.10728254052583588


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [26:47<00:00, 160.80s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.7952941176470588
FOLD 3
--------------------------------

train_loss : 0.17831835162553547
train_F1: 0.6982731554160126
valid_loss : 0.08099823428148573


 10%|████████▏                                                                         | 1/10 [02:30<22:31, 150.12s/it]

valid_F1: 0.815112540192926

train_loss : 0.053941396857857274
train_F1: 0.8547400611620797
valid_loss : 0.07754828255962241


 20%|████████████████▍                                                                 | 2/10 [05:15<21:14, 159.36s/it]

valid_F1: 0.8317460317460318

train_loss : 0.03743188667652409
train_F1: 0.8860520094562647
valid_loss : 0.09328197182289191


 30%|████████████████████████▌                                                         | 3/10 [08:13<19:35, 167.88s/it]

valid_F1: 0.8120754716981132

train_loss : 0.029518437241657976
train_F1: 0.9098938269284976
valid_loss : 0.09256254550214711


 40%|████████████████████████████████▊                                                 | 4/10 [11:08<17:03, 170.57s/it]

valid_F1: 0.8129251700680272

train_loss : 0.026337242275072996
train_F1: 0.9150560950315829
valid_loss : 0.10298637327338968


 50%|█████████████████████████████████████████                                         | 5/10 [14:04<14:22, 172.51s/it]

valid_F1: 0.8047808764940239

train_loss : 0.023469937499864733
train_F1: 0.926585618697578
valid_loss : 0.09207865381734325


 60%|█████████████████████████████████████████████████▏                                | 6/10 [16:57<11:30, 172.73s/it]

valid_F1: 0.8128000000000001

train_loss : 0.01974974488150261
train_F1: 0.9358781496803309
valid_loss : 0.0997609527563894


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [19:51<08:38, 172.98s/it]

valid_F1: 0.825910931174089

train_loss : 0.018114356551359706
train_F1: 0.9460830358820215
valid_loss : 0.104466049266713


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [22:41<05:43, 171.96s/it]

valid_F1: 0.8186946011281225

train_loss : 0.01890181938458624
train_F1: 0.9373999811729268
valid_loss : 0.10770376297560605


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [25:34<02:52, 172.36s/it]

valid_F1: 0.8108108108108109

train_loss : 0.015859827216855832
train_F1: 0.9467655619190687
valid_loss : 0.10143280606233067


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [28:19<00:00, 169.98s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8242914979757087
FOLD 4
--------------------------------

train_loss : 0.1882718615691154
train_F1: 0.6413268832066343
valid_loss : 0.09671460082391639


 10%|████████▏                                                                         | 1/10 [02:32<22:51, 152.44s/it]

valid_F1: 0.7882136279926335

train_loss : 0.0577131717428834
train_F1: 0.8458015267175573
valid_loss : 0.07908021586088392


 20%|████████████████▍                                                                 | 2/10 [05:24<21:51, 163.93s/it]

valid_F1: 0.8239148239148238

train_loss : 0.03857212866041204
train_F1: 0.8801054018445321
valid_loss : 0.08757957897161121


 30%|████████████████████████▌                                                         | 3/10 [08:15<19:29, 167.12s/it]

valid_F1: 0.8266897746967071

train_loss : 0.030501397446282073
train_F1: 0.9063789868667916
valid_loss : 0.09346464138429661


 40%|████████████████████████████████▊                                                 | 4/10 [11:01<16:40, 166.76s/it]

valid_F1: 0.7867820613690008

train_loss : 0.02639748613192071
train_F1: 0.9114114114114115
valid_loss : 0.11354258840347266


 50%|█████████████████████████████████████████                                         | 5/10 [13:47<13:53, 166.63s/it]

valid_F1: 0.8195292066259807

train_loss : 0.022384663246948582
train_F1: 0.9261555806087937
valid_loss : 0.1035261523559109


 60%|█████████████████████████████████████████████████▏                                | 6/10 [16:33<11:04, 166.17s/it]

valid_F1: 0.8174474959612278

train_loss : 0.020418217598833815
train_F1: 0.9347927776218543
valid_loss : 0.10156529301221108


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [19:19<08:18, 166.18s/it]

valid_F1: 0.7956448911222781

train_loss : 0.018821682965816655
train_F1: 0.940791931266343
valid_loss : 0.09465303304212096


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [22:07<05:33, 166.73s/it]

valid_F1: 0.8359511343804538

train_loss : 0.016214172752882564
train_F1: 0.9454681507810306
valid_loss : 0.10492920076063314


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [24:38<02:41, 161.97s/it]

valid_F1: 0.808724832214765

train_loss : 0.01649851279504705
train_F1: 0.9465819947702653
valid_loss : 0.10950063767709903


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [27:06<00:00, 162.70s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8274067649609714
FOLD 5
--------------------------------

train_loss : 0.18342008740247803
train_F1: 0.6818984318205421
valid_loss : 0.08568838495132211


 10%|████████▏                                                                         | 1/10 [02:35<23:19, 155.54s/it]

valid_F1: 0.8070175438596492

train_loss : 0.05671549700469532
train_F1: 0.8448143895905089
valid_loss : 0.08632202973807013


 20%|████████████████▍                                                                 | 2/10 [05:34<22:34, 169.27s/it]

valid_F1: 0.8194888178913737

train_loss : 0.039796887297815366
train_F1: 0.8780533989774664
valid_loss : 0.09403644396935577


 30%|████████████████████████▌                                                         | 3/10 [08:27<19:56, 170.98s/it]

valid_F1: 0.817741935483871

train_loss : 0.030291039816739328
train_F1: 0.8996897621509825
valid_loss : 0.093638860342371


 40%|████████████████████████████████▊                                                 | 4/10 [11:23<17:17, 172.99s/it]

valid_F1: 0.8094512195121951

train_loss : 0.025688381620265866
train_F1: 0.9191577364166197
valid_loss : 0.11170185175905754


 50%|█████████████████████████████████████████                                         | 5/10 [14:15<14:23, 172.78s/it]

valid_F1: 0.8410596026490067

train_loss : 0.02281466895253784
train_F1: 0.9251101321585902
valid_loss : 0.11723958282404906


 60%|█████████████████████████████████████████████████▏                                | 6/10 [16:54<11:11, 167.95s/it]

valid_F1: 0.8285479901558654

train_loss : 0.018966419376863254
train_F1: 0.9357832567732258
valid_loss : 0.10651995146090722


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [19:36<08:17, 165.92s/it]

valid_F1: 0.8386568386568386

train_loss : 0.018545236849790904
train_F1: 0.9416408232308993
valid_loss : 0.13421997153381635


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [22:32<05:38, 169.15s/it]

valid_F1: 0.8089887640449437

train_loss : 0.018026684939235933
train_F1: 0.9365496527125962
valid_loss : 0.11696644568235262


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [25:24<02:50, 170.13s/it]

valid_F1: 0.7966101694915254

train_loss : 0.01662423915071295
train_F1: 0.9447668975847219
valid_loss : 0.0980764412042963


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [28:19<00:00, 169.91s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8338709677419355
FOLD 6
--------------------------------

train_loss : 0.18567581861296717
train_F1: 0.6453494022742735
valid_loss : 0.08188112769517805


 10%|████████▏                                                                         | 1/10 [02:53<25:59, 173.32s/it]

valid_F1: 0.8236363636363637

train_loss : 0.05760489886471941
train_F1: 0.8477008781040505
valid_loss : 0.07376668089395993


 20%|████████████████▍                                                                 | 2/10 [05:49<23:20, 175.10s/it]

valid_F1: 0.8195829555757026

train_loss : 0.03776416514028496
train_F1: 0.8838120104438643
valid_loss : 0.08558444465909686


 30%|████████████████████████▌                                                         | 3/10 [08:19<19:05, 163.65s/it]

valid_F1: 0.8353265869365225

train_loss : 0.03053406959110434
train_F1: 0.9038782705511227
valid_loss : 0.08082418738828077


 40%|████████████████████████████████▊                                                 | 4/10 [11:10<16:39, 166.64s/it]

valid_F1: 0.8392007611798288

train_loss : 0.024807292905498282
train_F1: 0.92079940784604
valid_loss : 0.08900866141034798


 50%|█████████████████████████████████████████                                         | 5/10 [14:11<14:17, 171.53s/it]

valid_F1: 0.8284444444444444

train_loss : 0.01937748753946503
train_F1: 0.9350601295097133
valid_loss : 0.09212088125286164


 60%|█████████████████████████████████████████████████▏                                | 6/10 [17:09<11:35, 173.75s/it]

valid_F1: 0.845943482224248

train_loss : 0.018175603943416305
train_F1: 0.9418087472201631
valid_loss : 0.0917225479353945


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [20:04<08:42, 174.25s/it]

valid_F1: 0.8191489361702127

train_loss : 0.017064493462705603
train_F1: 0.947183424290075
valid_loss : 0.09742944278790579


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [23:00<05:49, 174.68s/it]

valid_F1: 0.8264758497316635

train_loss : 0.01556654535973653
train_F1: 0.9470664445678326
valid_loss : 0.1059589880857278


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [25:42<02:50, 170.88s/it]

valid_F1: 0.8212560386473431

train_loss : 0.016887620055967344
train_F1: 0.9442337926569869
valid_loss : 0.09689653079424586


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [28:12<00:00, 169.20s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8316467341306348
FOLD 7
--------------------------------

train_loss : 0.19014601494969013
train_F1: 0.6509873970427706
valid_loss : 0.08714150148746255


 10%|████████▏                                                                         | 1/10 [02:53<26:02, 173.61s/it]

valid_F1: 0.8018867924528302

train_loss : 0.057663331309918464
train_F1: 0.8417961348995832
valid_loss : 0.08668227665513367


 20%|████████████████▍                                                                 | 2/10 [05:50<23:24, 175.60s/it]

valid_F1: 0.801923076923077

train_loss : 0.03701969315617308
train_F1: 0.8896195396899954
valid_loss : 0.08171719969431689


 30%|████████████████████████▌                                                         | 3/10 [08:43<20:19, 174.25s/it]

valid_F1: 0.832309043020193

train_loss : 0.02959639472409491
train_F1: 0.9070332959221847
valid_loss : 0.08905331800123314


 40%|████████████████████████████████▊                                                 | 4/10 [11:35<17:20, 173.47s/it]

valid_F1: 0.8214936247723132

train_loss : 0.022279940301464997
train_F1: 0.9282127580435187
valid_loss : 0.08954748681855279


 50%|█████████████████████████████████████████                                         | 5/10 [14:30<14:29, 173.98s/it]

valid_F1: 0.8125530110262934

train_loss : 0.019220135439967312
train_F1: 0.9354599406528189
valid_loss : 0.09693359202620658


 60%|█████████████████████████████████████████████████▏                                | 6/10 [17:02<11:05, 166.41s/it]

valid_F1: 0.8158590308370044

train_loss : 0.019732193634609964
train_F1: 0.931657355679702
valid_loss : 0.09818050716459364


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [19:39<08:10, 163.36s/it]

valid_F1: 0.8176895306859207

train_loss : 0.016213556441700337
train_F1: 0.943809789170614
valid_loss : 0.09475404631902838


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [22:37<05:36, 168.13s/it]

valid_F1: 0.8189116859946476

train_loss : 0.01678608022836148
train_F1: 0.946106671622375
valid_loss : 0.1309688924025599


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [25:30<02:49, 169.52s/it]

valid_F1: 0.8172446110590441

train_loss : 0.016051368521306234
train_F1: 0.946997122435719
valid_loss : 0.10216900504303056


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [28:24<00:00, 170.44s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8031634446397188
FOLD 8
--------------------------------

train_loss : 0.19330833989198887
train_F1: 0.6659238249594813
valid_loss : 0.08127553589255004


 10%|████████▏                                                                         | 1/10 [02:52<25:49, 172.14s/it]

valid_F1: 0.813126709206928

train_loss : 0.05888244354229972
train_F1: 0.8423861254049934
valid_loss : 0.07580200781102304


 20%|████████████████▍                                                                 | 2/10 [05:46<23:06, 173.35s/it]

valid_F1: 0.8259041211101767

train_loss : 0.039260305829399975
train_F1: 0.8785714285714284
valid_loss : 0.07166295225260319


 30%|████████████████████████▌                                                         | 3/10 [08:40<20:15, 173.70s/it]

valid_F1: 0.8375768217734856

train_loss : 0.03067899863723168
train_F1: 0.8978368761119955
valid_loss : 0.09613041751473755


 40%|████████████████████████████████▊                                                 | 4/10 [11:30<17:13, 172.28s/it]

valid_F1: 0.8392226148409895

train_loss : 0.02519153745320461
train_F1: 0.9182495344506517
valid_loss : 0.08113345866660018


 50%|█████████████████████████████████████████                                         | 5/10 [14:14<14:05, 169.13s/it]

valid_F1: 0.8255093002657219

train_loss : 0.02132105545312274
train_F1: 0.9304623415361669
valid_loss : 0.0986273560987471


 60%|█████████████████████████████████████████████████▏                                | 6/10 [16:55<11:06, 166.54s/it]

valid_F1: 0.80067283431455

train_loss : 0.01830558909308367
train_F1: 0.9392080312325711
valid_loss : 0.10278309781297848


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [19:57<08:34, 171.54s/it]

valid_F1: 0.821397756686799

train_loss : 0.01843155517069697
train_F1: 0.9392193308550186
valid_loss : 0.09700570658284735


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [22:52<05:45, 172.55s/it]

valid_F1: 0.8010204081632654

train_loss : 0.017374254806834653
train_F1: 0.9450385938807774
valid_loss : 0.10402436438318971


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [25:50<02:54, 174.43s/it]

valid_F1: 0.8351254480286737

train_loss : 0.014816653351080924
train_F1: 0.9492281941603125
valid_loss : 0.09377933399866686


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [28:43<00:00, 172.37s/it]
  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

valid_F1: 0.8177083333333335
FOLD 9
--------------------------------

train_loss : 0.19163294981371626
train_F1: 0.6510396233817183
valid_loss : 0.08658657660828782


 10%|████████▏                                                                         | 1/10 [02:54<26:13, 174.84s/it]

valid_F1: 0.808206958073149

train_loss : 0.05759814428131933
train_F1: 0.8474320241691843
valid_loss : 0.09335462761099463


 20%|████████████████▍                                                                 | 2/10 [05:43<22:51, 171.47s/it]

valid_F1: 0.825

train_loss : 0.03934825434877339
train_F1: 0.8819678264122708
valid_loss : 0.08260658938534461


 30%|████████████████████████▌                                                         | 3/10 [08:11<18:42, 160.41s/it]

valid_F1: 0.8330341113105924

train_loss : 0.029875752633767006
train_F1: 0.9097358349668627
valid_loss : 0.09245482024240803


 40%|████████████████████████████████▊                                                 | 4/10 [10:42<15:41, 156.88s/it]

valid_F1: 0.8165057067603161

train_loss : 0.024582098632217098
train_F1: 0.9203094417000652
valid_loss : 0.09778362760840395


 50%|█████████████████████████████████████████                                         | 5/10 [13:40<13:41, 164.35s/it]

valid_F1: 0.8007213706041479

train_loss : 0.022793635829866255
train_F1: 0.928053541550474
valid_loss : 0.13084741369760655


 60%|█████████████████████████████████████████████████▏                                | 6/10 [16:36<11:13, 168.29s/it]

valid_F1: 0.8081556997219648

train_loss : 0.019287491744949016
train_F1: 0.9360751837722155
valid_loss : 0.11012472686442462


 70%|█████████████████████████████████████████████████████████▍                        | 7/10 [19:30<08:30, 170.10s/it]

valid_F1: 0.8170940170940171

train_loss : 0.01795271768815486
train_F1: 0.9416465341014681
valid_loss : 0.10321018825490753


 80%|█████████████████████████████████████████████████████████████████▌                | 8/10 [22:22<05:41, 170.72s/it]

valid_F1: 0.8191304347826087

train_loss : 0.017518246154021807
train_F1: 0.9452449567723342
valid_loss : 0.10500974839178288


 90%|█████████████████████████████████████████████████████████████████████████▊        | 9/10 [25:15<02:51, 171.57s/it]

valid_F1: 0.7899305555555555

train_loss : 0.01638613670063289
train_F1: 0.9427934621099554
valid_loss : 0.10262904069446899


100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [28:10<00:00, 169.08s/it]

valid_F1: 0.8296943231441049





In [22]:
All_score = len(All_Fold_score[0]) * [0]
for score in All_Fold_score:
    All_score = np.sum([All_score,score], axis = 0)
All_score = np.round((All_score / 10) ,2) 
All_score

array([0.81, 0.82, 0.83, 0.82, 0.82, 0.82, 0.82, 0.81, 0.81, 0.82])

## 全部的 trainingset 訓練 model 後, 透過 validset 評估 model

In [19]:
bs = 64
num_epoch = 20

tr_dataset = NERDataset(X_train,y_train,word2idx['<pad>'],tag2idx['<pad>'])
va_dataset = NERDataset(X_test,y_test,word2idx['<pad>'],tag2idx['<pad>'])
train_dataloader = DataLoader(tr_dataset, batch_size=bs,
                            collate_fn=tr_dataset.collate_fn)
valid_dataloader = DataLoader(va_dataset, batch_size=bs,
                            collate_fn=va_dataset.collate_fn)  
model = EmbeddedRnn(len(word2idx), 256, len(tags),2,word2idx['<pad>'],tag2idx['<pad>'])
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
model = model.to(device)
all_loader = {"train" : train_dataloader,
              "valid" : valid_dataloader}

In [20]:
F_score = []
for epoch in tqdm(range(num_epoch)):
    all_loss = {
    'train': [],
    'valid': [],
    }
    print('')
    for loader in all_loader:
        predictions , true_labels  = [],[]
        for x, y in all_loader[loader]:
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            hidden = model.initHidden(x.size(0))
            hidden[0] = hidden[0].to(device)
            hidden[1] = hidden[1].to(device)
            output, loss = model(x, hidden,y)
            if loader == 'train':
                loss.backward()
                optimizer.step()
            all_loss[loader].append(loss.cpu().item()) 
            predictions.extend([[idx2tag[j] for j in i] for i in output])
            for i in y.detach().cpu().numpy():
                _ = []
                for j in i:
                    if j != tag_pad_idx:
                        _.append(idx2tag[j])
                true_labels.append(_)
        print(f'{loader}_loss : {np.mean(np.array(all_loss[loader]))/64}')
        f_ = f1_score(true_labels,predictions,scheme = IOB2)
        print(f'{loader}_F1: {f_}')
        if loader == 'valid':
            F_score.append(f_)

  0%|                                                                                           | 0/20 [00:00<?, ?it/s]


train_loss : 0.276776714670418




train_F1: 0.5617333691467454
valid_loss : 0.5620519271832984


  5%|████                                                                              | 1/20 [02:56<56:00, 176.85s/it]

valid_F1: 0.015984499878905305

train_loss : 0.07233179081049812
train_F1: 0.8044286826399101
valid_loss : 0.38738970813535206


 10%|████████▏                                                                         | 2/20 [05:53<52:56, 176.48s/it]

valid_F1: 0.021440711204078963

train_loss : 0.046236297204035146
train_F1: 0.8571428571428572
valid_loss : 0.3310073914393883


 15%|████████████▎                                                                     | 3/20 [08:47<49:47, 175.71s/it]

valid_F1: 0.02430408365126466

train_loss : 0.035733658602537466
train_F1: 0.8837683870419182
valid_loss : 0.19924585208696807


 20%|████████████████▍                                                                 | 4/20 [11:42<46:46, 175.43s/it]

valid_F1: 0.041379310344827586

train_loss : 0.028873547185466537
train_F1: 0.9073933888468332
valid_loss : 0.1421058865251557


 25%|████████████████████▌                                                             | 5/20 [14:36<43:43, 174.91s/it]

valid_F1: 0.07675145024542615

train_loss : 0.025714961169233098
train_F1: 0.9145703190950532
valid_loss : 0.34587208956680043


 30%|████████████████████████▌                                                         | 6/20 [17:30<40:43, 174.53s/it]

valid_F1: 0.026912609615966133

train_loss : 0.024102576892723736
train_F1: 0.9188688772933851
valid_loss : 0.17734397363367696


 35%|████████████████████████████▋                                                     | 7/20 [20:25<37:49, 174.58s/it]

valid_F1: 0.05956552207428171

train_loss : 0.023020241731478843
train_F1: 0.9265622368199428
valid_loss : 0.3556627907959007


 40%|████████████████████████████████▊                                                 | 8/20 [23:20<34:55, 174.62s/it]

valid_F1: 0.029890401859847225

train_loss : 0.020057085074866803
train_F1: 0.9341287178625441
valid_loss : 0.07187469097966916


 45%|████████████████████████████████████▉                                             | 9/20 [26:14<32:00, 174.58s/it]

valid_F1: 0.1339637509850276

train_loss : 0.016698668313002707
train_F1: 0.9426195076871378
valid_loss : 0.13960473202856494


 50%|████████████████████████████████████████▌                                        | 10/20 [29:10<29:09, 174.97s/it]

valid_F1: 0.0737564322469983

train_loss : 0.020590253865194724
train_F1: 0.935174271762923
valid_loss : 0.06973302438417638


 55%|████████████████████████████████████████████▌                                    | 11/20 [32:05<26:14, 174.94s/it]

valid_F1: 0.14664310954063603

train_loss : 0.01946617009762039
train_F1: 0.9344963515893653
valid_loss : 0.10527594125948396


 60%|████████████████████████████████████████████████▌                                | 12/20 [34:59<23:18, 174.79s/it]

valid_F1: 0.10268948655256725

train_loss : 0.016788929116278595
train_F1: 0.9453749684370003
valid_loss : 0.28041423059649917


 65%|████████████████████████████████████████████████████▋                            | 13/20 [37:52<20:20, 174.33s/it]

valid_F1: 0.045275590551181105

train_loss : 0.01785921106203335
train_F1: 0.9378835000420274
valid_loss : 0.12568476038591173


 70%|████████████████████████████████████████████████████████▋                        | 14/20 [40:45<17:23, 173.93s/it]

valid_F1: 0.09747899159663867

train_loss : 0.017480470860090115
train_F1: 0.9426291474170517
valid_loss : 0.07383361906000732


 75%|████████████████████████████████████████████████████████████▊                    | 15/20 [43:39<14:28, 173.72s/it]

valid_F1: 0.1601525262154433

train_loss : 0.017619252223957178
train_F1: 0.9403873564182107
valid_loss : 0.031651734168197485


 80%|████████████████████████████████████████████████████████████████▊                | 16/20 [46:31<11:33, 173.42s/it]

valid_F1: 0.2851782363977486

train_loss : 0.01732724927222965
train_F1: 0.9449109842122942
valid_loss : 0.2683214161595042


 85%|████████████████████████████████████████████████████████████████████▊            | 17/20 [49:27<08:41, 173.99s/it]

valid_F1: 0.04538421866941722

train_loss : 0.016588048272788262
train_F1: 0.9451490970180596
valid_loss : 0.3565415931403037


 90%|████████████████████████████████████████████████████████████████████████▉        | 18/20 [52:24<05:49, 174.93s/it]

valid_F1: 0.03600395647873393

train_loss : 0.01693701799662422
train_F1: 0.9462112947889569
valid_loss : 0.20874208507546843


 95%|████████████████████████████████████████████████████████████████████████████▉    | 19/20 [55:22<02:55, 175.97s/it]

valid_F1: 0.0646345941975762

train_loss : 0.013794554162666424
train_F1: 0.9525088102030541
valid_loss : 0.2639475311090132


100%|█████████████████████████████████████████████████████████████████████████████████| 20/20 [58:18<00:00, 174.92s/it]

valid_F1: 0.0518783542039356





In [21]:
np.round(np.array(F_score),2)

array([0.02, 0.02, 0.02, 0.04, 0.08, 0.03, 0.06, 0.03, 0.13, 0.07, 0.15,
       0.1 , 0.05, 0.1 , 0.16, 0.29, 0.05, 0.04, 0.06, 0.05])