In [2]:
import time
%time

CPU times: user 5 µs, sys: 1 µs, total: 6 µs
Wall time: 10.7 µs


In [3]:
import os.path
import time
import pickle
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model
from transformers import AdamW, BertModel
from transformers.optimization import get_cosine_schedule_with_warmup
import json
import pandas as pd

In [4]:
# device = torch.device("cpu")
device = torch.device("cuda:0")
model_path = 'ksic_model'


In [6]:
try:
    # bertmodel = BertModel.from_pretrained(model_path, return_dict=False)
    bertmodel = torch.load(os.path.join(model_path, 'KSIC_KoBERT.pt'))  # 전체 모델을 통째로 불러옴, 클래스 선언 필수
    bertmodel.load_state_dict(
        torch.load(os.path.join(model_path, 'KSIC_model_state_dict.pt')))  # state_dict를 불러 온 후, 모델에 저장
    _, vocab = get_pytorch_kobert_model(cachedir=".cache")
    print('Using saved model')
except Exception as e:
    print(e)
    bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")

using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_v1.zip
using cached model. /home/hdh/PycharmProjects/KoBERT-master/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [7]:
print(vocab)

Vocab(size=8002, unk="[UNK]", reserved="['[CLS]', '[SEP]', '[MASK]', '[PAD]']")


In [8]:
def load_train_dataset():
    try:
        train_ds = nlp.data.TSVDataset('.cache/train_ds.tsv', encoding='utf-8',
                                       field_indices=[0, 1], num_discard_samples=1)
        val_ds = nlp.data.TSVDataset('.cache/val_ds.tsv', encoding='utf-8',
                                     field_indices=[0, 1], num_discard_samples=1)
        test_ds = nlp.data.TSVDataset('.cache/test_ds.tsv', encoding='utf-8',
                                      field_indices=[0, 1], num_discard_samples=1)
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
        print('Loading saved text-label pair dataset completed')
        with open('.cache/label_ksic.pickle', 'rb') as f:
            ksic_index_dict = pickle.load(f)
        with open('.cache/ksic_label.pickle', 'rb') as f:
            ksic_label_dict = pickle.load(f)
    except:
        def load_datasets():
            try:
                with open('.cache/label_ksic.pickle', 'rb') as f:
                    ksic_index_dict = pickle.load(f)
                with open('.cache/ksic_label.pickle', 'rb') as f:
                    ksic_label_dict = pickle.load(f)
                train_input = pd.read_csv('.cache/train_input.csv', encoding='utf-8', low_memory=False)
                val_input = pd.read_csv('.cache/val_input.csv', encoding='utf-8', low_memory=False)
                test_input = pd.read_csv('.cache/test_input.csv', encoding='utf-8', low_memory=False)
                print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
                print('Loading saved train_input completed')
            except:
                fname = ".cache/ksic00.json"
                f_list = [".cache/ksic00.json", ".cache/ksic01.json", ".cache/ksic02.json"]

                with open(fname, encoding='utf-8') as f:
                    for line in tqdm(f):
                        try:
                            # print('line: ', line)
                            temp = json.loads(line)
                            print(temp['an'])
                        #             text = re.sub('[-=.#/?:$}(){,]', ' ', patent['title'] + patent['ab'] + patent['cl'])
                        #             token = text.split()
                        except:
                            pass

                def read_column_names(fn):
                    with open(fn, encoding='utf-8') as json_f:
                        json_line = json_f.readline()
                    temp = json.loads(json_line)
                    return temp.keys()

                col_name = read_column_names(f_list[0])

                temp = []
                error = []
                for fn in f_list[1:]:
                    with open(fn, encoding='utf-8') as f:
                        for i, line in enumerate(f):
                            try:
                                temp.append(json.loads(line.replace('\\\\"', '\\"')))
                            except Exception as e:
                                error.append([e, line])
                raw_df = pd.DataFrame(data=temp, columns=col_name)

                class_count = raw_df['ksic'].value_counts()
                class_count2 = class_count[class_count >= 500]

                raw_df2 = raw_df.loc[raw_df['ksic'].isin(class_count2.keys())].copy()

                ksic_label = raw_df2['ksic'].unique()
                ksic_index_dict = {i: label for i, label in enumerate(ksic_label)}
                ksic_label_dict = {ksic_index_dict[key]: key for key in ksic_index_dict.keys()}

                # ksic_index_dict
                with open('.cache/label_ksic.pickle', 'wb') as f:
                    pickle.dump(ksic_index_dict, f)
                with open('.cache/ksic_label.pickle', 'wb') as f:
                    pickle.dump(ksic_label_dict, f)

                raw_df2['label'] = raw_df2['ksic'].map(ksic_label_dict)
                train_input, test_input = train_test_split(raw_df2, random_state=15, test_size=0.2,
                                                           stratify=raw_df2['ksic'],
                                                           shuffle=True)
                train_input, val_input = train_test_split(train_input, random_state=15, test_size=0.15,
                                                          stratify=train_input['ksic'], shuffle=True)

                train_input.to_csv('.cache/train_input.csv', encoding='utf-8', mode='w', index=False)
                val_input.to_csv('.cache/val_input.csv', encoding='utf-8', mode='w', index=False)
                test_input.to_csv('.cache/test_input.csv', encoding='utf-8', mode='w', index=False)
                print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
                print('Loading json files and saving "train_input.csv" completed')
            return ksic_index_dict, ksic_label_dict, train_input, val_input, test_input

        ksic_index_dict, ksic_label_dict, train_input, val_input, test_input = load_datasets()

        def make_input_text(df):
            input_tl = df[['title', 'label']].copy()
            input_tl.rename(columns={'title': 'text'}, inplace=True)
            input_ab = df[['ab', 'label']].copy()
            input_ab.rename(columns={'ab': 'text'}, inplace=True)
            input_cl = df[['cl', 'label']].copy()
            input_cl.rename(columns={'cl': 'text'}, inplace=True)
            input_text = pd.concat([input_tl, input_ab, input_cl]).copy()
            input_text['text_len'] = input_text['text'].str.len()
            input_text2 = input_text.loc[
                input_text['text_len'] > 3, ['text', 'label']].copy()  # 60813 rows × 3 columns 제거
            return input_text2

        train_ds = make_input_text(train_input)
        val_ds = make_input_text(val_input)
        test_ds = make_input_text(test_input)

        train_ds.to_csv('.cache/train_ds.tsv', encoding='utf-8', mode='w', index=False, sep='\t')
        val_ds.to_csv('.cache/val_ds.tsv', encoding='utf-8', mode='w', index=False, sep='\t')
        test_ds.to_csv('.cache/test_ds.tsv', encoding='utf-8', mode='w', index=False, sep='\t')
        print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
        print('Saving text-label pair dataset completed')
    return train_ds, val_ds, test_ds


In [5]:
train_ds = nlp.data.TSVDataset('.cache/train_ds.tsv', encoding='utf-8',
                               field_indices=[0, 1], num_discard_samples=1)
val_ds = nlp.data.TSVDataset('.cache/val_ds.tsv', encoding='utf-8',
                             field_indices=[0, 1], num_discard_samples=1)
test_ds = nlp.data.TSVDataset('.cache/test_ds.tsv', encoding='utf-8',
                              field_indices=[0, 1], num_discard_samples=1)
print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
print('Loading saved text-label pair dataset completed')
with open('.cache/label_ksic.pickle', 'rb') as f:
    ksic_index_dict = pickle.load(f)
with open('.cache/ksic_label.pickle', 'rb') as f:
    ksic_label_dict = pickle.load(f)

 3:44PM KST on Nov 24, 2022
Loading saved text-label pair dataset completed


In [6]:
print(len(train_ds), len(train_ds), len(train_ds))

2108062


In [9]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

In [10]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in tqdm(dataset)]
        self.labels = [np.int32(i[label_idx]) for i in tqdm(dataset)]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [11]:
max_len = 256
batch_size = 16
warmup_ratio = 0.1
num_epochs = 5
max_grad_norm = 1
log_interval = 10000
learning_rate = 5e-5

In [None]:
print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # 8:37AM KST on Jun 16, 2022
print('Starting to make BERTDataset')
try:
    # data_train_id = np.load('.cache/data_train_id.npy', mmap_mode='r')
    # data_train_id = np.load('.cache/ksic_data_train_id.npy', allow_pickle=True)
    # data_test_id = np.load('.cache/ksic_data_test_id.npy', allow_pickle=True)
    with open(".cache/ksic_data_train_id.pickle", "rb") as fr:
        data_train_id = pickle.load(fr)
    with open(".cache/ksic_data_test_id.pickle", "rb") as fr:
        data_test_id = pickle.load(fr)
    # Array can't be memory-mapped: Python objects in dtype. -> mmap_mode='r' 주석처리
    # Object arrays cannot be loaded when allow_pickle=False -> np.dave에 allow_pickle=True 문구 추가
    print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
    print('Completed to load BERTDataset')
except Exception as e:
    print(e)
    train_ds, val_ds, test_ds = load_train_dataset()
    data_train_id = BERTDataset(train_ds, 0, 1, tok, max_len, True, False)
    # np.save('.cache/ksic_data_train_id.npy', data_train_id, allow_pickle=True)
    with open('.cache/ksic_data_train_id.pickle', "wb") as fw:
        pickle.dump(data_train_id, fw)
    data_test_id = BERTDataset(test_ds, 0, 1, tok, max_len, True, False)
    # np.save('.cache/ksic_data_test_id.npy', data_test_id, allow_pickle=True)
    with open('.cache/ksic_data_test_id.pickle', "wb") as fw:
        pickle.dump(data_test_id, fw)
    # data_train_id = BERTDataset(train_ds.to_numpy(), 0, 1, tok, max_len, True, False)
    # data_test_id = BERTDataset(test_ds.to_numpy(), 0, 1, tok, max_len, True, False)
    print(time.strftime('%l:%M%p %Z on %b %d, %Y'))  # ' 1:36PM EDT on Oct 18, 2010'
    print('Completed to make BERTDataset')

In [13]:
train_dataloader = torch.utils.data.DataLoader(data_train_id, batch_size=batch_size, num_workers=8, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(data_test_id, batch_size=batch_size, num_workers=8, shuffle=True)

0it [00:00, ?it/s]

특1994-021923
특1994-016867
특1995-016613
특1995-017642
특1995-017643
특1994-012484
특1993-019813
특1994-022446
특1994-027240
특1991-002356
특1994-013563
특1993-028967
특1994-003074
특1990-020903
특1990-019424
특1994-028441
특1994-033382
특1990-015538
특1989-012382
특1994-001380
특1989-009500


In [14]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)

In [15]:
model = BERTClassifier(bertmodel, num_classes=500, dr_rate=0.5).to(device)

In [16]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]


In [17]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)


567 ,  ['20411' '25121' '20491' '30110' '11111']


In [18]:
def calc_accuracy(X, y):
    max_vals, max_indices = torch.max(X, 1)
    acc = (max_indices == y).sum().data.cpu().numpy()/max_indices.size()[0]
    return acc

28119    16022
26121    15241
27112    14819
Name: ksic, dtype: int64


In [19]:
for e in range(num_epochs):
    print(time.strftime('%l:%M%p %Z on %b %d, %Y'), ': starting epoch ', e)
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(train_dataloader),
                                                                        total=len(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length = valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("\ntime: {}, epoch {} batch id {}\n loss {} train acc {}".format(time.strftime('%l:%M%p'),
                                                                                   e+1, batch_id+1,
                                                                                   loss.data.cpu().numpy(),
                                                                                   train_acc / (batch_id+1)))
            # torch.save(model, model_path)  # 전체 모델 저장
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader),
                                                                        total=len(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length = valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    print(time.strftime('%l:%M%p %Z on %b %d, %Y'))
    # torch.save(model, model_path)  # 전체 모델 저장
    torch.save(model, os.path.join(model_path, 'KSIC_KoBERT.pt'))  # 전체 모델 저장
    torch.save(model.state_dict(), os.path.join(model_path, 'KSIC_model_state_dict.pt'))  # 모델 객체의 state_dict 저장
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, os.path.join(model_path, 'all.tar'))
    # 여러 가지 값 저장, 학습 중 진행 상황 저장을 위해 epoch, loss 값 등 일반 scalar값 저장 가능,
    # https://velog.io/@dev-junku/KoBERT-%EB%AA%A8%EB%8D%B8%EC%97%90-%EB%8C%80%ED%95%B4

print(time.strftime('%l:%M%p %Z on %b %d, %Y'))


567 -> 500


In [None]:
https://velog.io/@dev-junku/KoBERT-%EB%AA%A8%EB%8D%B8%EC%97%90-%EB%8C%80%ED%95%B4
PATH = 'drive/MyDrive/colab/StoryFlower/bert' # google 드라이브 연동 해야함. 관련코드는 뺐음
torch.save(model, PATH + 'KoBERT_담화.pt')  # 전체 모델 저장
torch.save(model.state_dict(), PATH + 'model_state_dict.pt')  # 모델 객체의 state_dict 저장
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict()
}, PATH + 'all.tar')  # 여러 가지 값 저장, 학습 중 진행 상황 저장을 위해 epoch, loss 값 등 일반 scalar값 저장 가능

!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3.0.2
!pip install torch

!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

# torch
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

#kobert
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

#GPU 사용
device = torch.device("cuda:0")

#BERT 모델, Vocabulary 불러오기 필수
bertmodel, vocab = get_pytorch_kobert_model()


# KoBERT에 입력될 데이터셋 정리
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))  

# 모델 정의
class BERTClassifier(nn.Module): ## 클래스를 상속
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=6,   ##클래스 수 조정##
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

# Setting parameters
max_len = 64
batch_size = 32
warmup_ratio = 0.1
num_epochs = 20
max_grad_norm = 1
log_interval = 100
learning_rate =  5e-5

## 학습 모델 로드
PATH = 'drive/MyDrive/colab/StoryFlower/bert/'
model = torch.load(PATH + 'KoBERT_담화_86.pt')  # 전체 모델을 통째로 불러옴, 클래스 선언 필수
model.load_state_dict(torch.load(PATH + 'model_state_dict_86.pt'))  # state_dict를 불러 온 후, 모델에 저장

#토큰화
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

def new_softmax(a) : 
    c = np.max(a) # 최댓값
    exp_a = np.exp(a-c) # 각각의 원소에 최댓값을 뺀 값에 exp를 취한다. (이를 통해 overflow 방지)
    sum_exp_a = np.sum(exp_a)
    y = (exp_a / sum_exp_a) * 100
    return np.round(y, 3)


# 예측 모델 설정
def predict(predict_sentence):

    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tok, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=5)
    
    model.eval()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length= valid_length
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)

        test_eval=[]
        for i in out:
            logits=i
            logits = logits.detach().cpu().numpy()
            min_v = min(logits)
            total = 0
            probability = []
            logits = np.round(new_softmax(logits), 3).tolist()
            for logit in logits:
                print(logit)
                probability.append(np.round(logit, 3))

            if np.argmax(logits) == 0:  emotion = "기쁨"
            elif np.argmax(logits) == 1: emotion = "불안"
            elif np.argmax(logits) == 2: emotion = '당황'
            elif np.argmax(logits) == 3: emotion = '슬픔'
            elif np.argmax(logits) == 4: emotion = '분노'
            elif np.argmax(logits) == 5: emotion = '상처'

            probability.append(emotion)
            print(probability)
    return probability


In [None]:
test_dataloader2 = torch.utils.data.DataLoader(data_test_id, batch_size=batch_size, num_workers=5)
test_acc = 0.0
results = []
for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader2), total=len(test_dataloader2)):
    print(batch_id)
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    max_vals, max_indices = torch.max(out, 1)
    results.extend(max_indices.tolist())
    test_acc += calc_accuracy(out, label)
    print("test acc {}".format(test_acc / (batch_id+1)))
print(results)
#     print('label: ', label)    

In [None]:
print(max_indices)
print(max_indices.tolist())


In [None]:
m = nn.Softmax(dim=1)
exp = m(out)
max_vals, max_indices = torch.max(out, 1)
print(max_indices)

In [None]:
# https://dacon.io/en/competitions/official/235747/codeshare/3082?page=1&dtype=recent
def plot_graphs(history, string):
    plt.plot(history.history[string])
    plt.plot(history.history['val_'+string], '')
    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend([string, 'val_'+string])
    plt.show()

In [None]:
# 5. 모델 학습 과정 표시하기
%matplotlib inline
import matplotlib.pyplot as plt

fig, loss_ax = plt.subplots()

acc_ax = loss_ax.twinx()

loss_ax.plot(loss.long().to(device), 'y', label='train loss')
# loss_ax.plot(hist.history['val_loss'], 'r', label='val loss')

acc_ax.plot(train_acc.long().to(device), 'b', label='train_acc')
acc_ax.plot(test_acc.long().to(device), 'g', label='test_acc')

loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
acc_ax.set_ylabel('accuray')

loss_ax.legend(loc='upper left')
acc_ax.legend(loc='lower left')

plt.show()


In [None]:
# out = model(token_ids, valid_length, segment_ids)
# print(out)

100%
2344/2344 [10:52<00:00, 3.82it/s]
epoch 1 batch id 1 loss 0.7244999408721924 train acc 0.453125
epoch 1 batch id 201 loss 0.4514557123184204 train acc 0.5766480099502488
epoch 1 batch id 401 loss 0.44362887740135193 train acc 0.6851231296758105
epoch 1 batch id 601 loss 0.5137903094291687 train acc 0.734426996672213
epoch 1 batch id 801 loss 0.43659207224845886 train acc 0.7636352996254682
epoch 1 batch id 1001 loss 0.3066664934158325 train acc 0.7802978271728271
epoch 1 batch id 1201 loss 0.3117211163043976 train acc 0.7930890924229809
epoch 1 batch id 1401 loss 0.30558276176452637 train acc 0.8026075124910778
epoch 1 batch id 1601 loss 0.3304724097251892 train acc 0.8104992973141787
epoch 1 batch id 1801 loss 0.25170841813087463 train acc 0.8164995141588006
epoch 1 batch id 2001 loss 0.2823370695114136 train acc 0.8224481509245377
epoch 1 batch id 2201 loss 0.3238280713558197 train acc 0.8274292935029532
epoch 1 train acc 0.8307713843856656
100%
782/782 [01:06<00:00, 13.16it/s]
epoch 1 test acc 0.8835118286445013
100%
2344/2344 [10:52<00:00, 3.81it/s]
epoch 2 batch id 1 loss 0.4955632984638214 train acc 0.828125
epoch 2 batch id 201 loss 0.2225867211818695 train acc 0.8816075870646766
epoch 2 batch id 401 loss 0.33707231283187866 train acc 0.8836502493765586
epoch 2 batch id 601 loss 0.39554905891418457 train acc 0.887115224625624
epoch 2 batch id 801 loss 0.30988579988479614 train acc 0.8883426966292135
epoch 2 batch id 1001 loss 0.28933045268058777 train acc 0.8911713286713286
epoch 2 batch id 1201 loss 0.24474024772644043 train acc 0.8932269983347211
epoch 2 batch id 1401 loss 0.1908964067697525 train acc 0.8957218058529621
epoch 2 batch id 1601 loss 0.23474052548408508 train acc 0.89772993441599
epoch 2 batch id 1801 loss 0.1599130779504776 train acc 0.8997171710161022
epoch 2 batch id 2001 loss 0.20312610268592834 train acc 0.9019865067466267
epoch 2 batch id 2201 loss 0.2386036366224289 train acc 0.9033251930940481
epoch 2 train acc 0.904759047923777
100%
782/782 [01:06<00:00, 11.77it/s]
epoch 2 test acc 0.8906449808184144
100%
2344/2344 [10:51<00:00, 3.81it/s]
epoch 3 batch id 1 loss 0.3390277922153473 train acc 0.875
epoch 3 batch id 201 loss 0.15983553230762482 train acc 0.9240516169154229
epoch 3 batch id 401 loss 0.14856880903244019 train acc 0.9265118453865336
epoch 3 batch id 601 loss 0.2642267644405365 train acc 0.9280366056572379
epoch 3 batch id 801 loss 0.1951659917831421 train acc 0.93020443196005
epoch 3 batch id 1001 loss 0.26453569531440735 train acc 0.9319586663336663
epoch 3 batch id 1201 loss 0.1282612681388855 train acc 0.9340523522064946
epoch 3 batch id 1401 loss 0.1837957799434662 train acc 0.9360501427551748
epoch 3 batch id 1601 loss 0.14137345552444458 train acc 0.9374414428482198
epoch 3 batch id 1801 loss 0.09849003702402115 train acc 0.9387493059411438
epoch 3 batch id 2001 loss 0.15634101629257202 train acc 0.9401080709645178
epoch 3 batch id 2201 loss 0.1982114464044571 train acc 0.9410495229441163
epoch 3 train acc 0.9420039640216155
100%
782/782 [01:06<00:00, 11.77it/s]
epoch 3 test acc 0.8969988810741688
100%
2344/2344 [10:52<00:00, 3.82it/s]
epoch 4 batch id 1 loss 0.388761043548584 train acc 0.875
epoch 4 batch id 201 loss 0.06205718219280243 train acc 0.9593439054726368
epoch 4 batch id 401 loss 0.06854245811700821 train acc 0.9596711346633416
epoch 4 batch id 601 loss 0.23485814034938812 train acc 0.9600925540765392
epoch 4 batch id 801 loss 0.1342923790216446 train acc 0.9612008426966292
epoch 4 batch id 1001 loss 0.1908232569694519 train acc 0.9621472277722277
epoch 4 batch id 1201 loss 0.11091630905866623 train acc 0.9632597835137385
epoch 4 batch id 1401 loss 0.10650145262479782 train acc 0.9642219842969307
epoch 4 batch id 1601 loss 0.08601253479719162 train acc 0.9649242660836976
epoch 4 batch id 1801 loss 0.057230446487665176 train acc 0.9656093836757357
epoch 4 batch id 2001 loss 0.07411567866802216 train acc 0.9665011244377811
epoch 4 batch id 2201 loss 0.1597125381231308 train acc 0.9671456156292594
epoch 4 train acc 0.9676701151877133
100%
782/782 [01:06<00:00, 11.77it/s]
epoch 4 test acc 0.8980778452685422
100%
2344/2344 [10:52<00:00, 3.81it/s]
epoch 5 batch id 1 loss 0.3727969229221344 train acc 0.890625
epoch 5 batch id 201 loss 0.02794063650071621 train acc 0.9752798507462687
epoch 5 batch id 401 loss 0.024620698764920235 train acc 0.9767378428927681
epoch 5 batch id 601 loss 0.15002880990505219 train acc 0.9765495008319468
epoch 5 batch id 801 loss 0.05448848009109497 train acc 0.9766112671660424
epoch 5 batch id 1001 loss 0.08006531745195389 train acc 0.9770541958041958
epoch 5 batch id 1201 loss 0.04451199620962143 train acc 0.9775317443796836
epoch 5 batch id 1401 loss 0.08561042696237564 train acc 0.9780625446109922
epoch 5 batch id 1601 loss 0.026716381311416626 train acc 0.9783533728919426
epoch 5 batch id 1801 loss 0.02442212402820587 train acc 0.9787357717934481
epoch 5 batch id 2001 loss 0.0197431817650795 train acc 0.9790339205397302
epoch 5 batch id 2201 loss 0.03802771866321564 train acc 0.9792565879145843
epoch 5 train acc 0.9794199729806597
100%
782/782 [01:06<00:00, 11.78it/s]
epoch 5 test acc 0.8974384590792839