In [None]:
from keras import backend as K

def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        """Recall metric.

        Only computes a batch-wise average of recall.

        Computes the recall, a metric for multi-label classification of
        how many relevant items are selected.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        """Precision metric.

        Only computes a batch-wise average of precision.

        Computes the precision, a metric for multi-label classification of
        how many selected items are relevant.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))


In [None]:
!sudo apt-get install g++ openjdk-7-jdk # Install Java 1.7+
!sudo apt-get install python-dev; pip install konlpy    
!sudo apt-get install curl
!bash <(curl -s https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh)

Reading package lists... Done
Building dependency tree       
Reading state information... Done
Package openjdk-7-jdk is not available, but is referred to by another package.
This may mean that the package is missing, has been obsoleted, or
is only available from another source

E: Package 'openjdk-7-jdk' has no installation candidate
Reading package lists... Done
Building dependency tree       
Reading state information... Done
python-dev is already the newest version (2.7.15~rc1-1).
0 upgraded, 0 newly installed, 0 to remove and 91 not upgraded.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
curl is already the newest version (7.58.0-2ubuntu3.14).
0 upgraded, 0 newly installed, 0 to remove and 91 not upgraded.
mecab-ko is already installed
mecab-ko-dic is already installed
mecab-python is already installed
Done.


In [None]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3 # 최신 버전으로 설치하면 "Input: must be Tensor, not str" 라는 에러 발생
!pip install torch
!pip install konlpy

!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master


import re
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

torch.manual_seed(123)
np.random.seed(123)

##GPU 사용 시
device = torch.device("cuda:0")

bertmodel, vocab = get_pytorch_kobert_model()

from google.colab import drive
drive.mount('/content/drive')\

Collecting git+https://****@github.com/SKTBrain/KoBERT.git@master
  Cloning https://****@github.com/SKTBrain/KoBERT.git (to revision master) to /tmp/pip-req-build-qv4oq1sr
  Running command git clone -q 'https://****@github.com/SKTBrain/KoBERT.git' /tmp/pip-req-build-qv4oq1sr
using cached model
using cached model
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Setting parameters
max_len = 64 # 해당 길이를 초과하는 단어에 대해선 bert가 학습하지 않음
batch_size = 128
warmup_ratio = 0.1
num_epochs = 30
max_grad_norm = 1
log_interval = 100
learning_rate = 3e-5

In [None]:
# 학습용 데이터셋 불러오기
import pandas as pd
from konlpy.tag import Mecab
train = pd.read_csv('/content/drive/My Drive/mulcamnlp2021/datain/train.csv')
test = pd.read_csv('/content/drive/My Drive/mulcamnlp2021/datain/test.csv')

# 데이터 전처리
train=train[['과제명', '요약문_연구내용','label']]
test=test[['과제명', '요약문_연구내용']]
train['요약문_연구내용'].fillna('NAN', inplace=True)
test['요약문_연구내용'].fillna('NAN', inplace=True)

# 대표 질병 5개 추출
train['data']=train['과제명']+train['요약문_연구내용']
test['data']=test['과제명']+test['요약문_연구내용']
def clean_text(sents):
    sents_clean = []
    for sent in sents:
        sents_clean.append(re.sub("[^가-힣ㄱ-ㅎㅏ-ㅣ]", " ", sent))
    return sents_clean


def tokenize(docs):
  mecab = Mecab()
  res=[]
  for doc in docs:
    tokenlist = mecab.pos(doc)
    temp = []
    for w in tokenlist:
      if w[1] in ['NNG', 'NNP', 'NNB', 'NR']:
        temp.append(w[0])
    res.append(' '.join(temp))
  return res

train['data'] = tokenize(clean_text(train['data']))
test['data'] = tokenize(clean_text(test['data']))

# Train / Test set 분리
from sklearn.model_selection import train_test_split
trains, tests = train_test_split(train, test_size=0.2, random_state=123)
print("train shape is:", len(trains))
print("test shape is:", len(tests))


# 기본 Bert tokenizer 사용
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) 

        self.sentences = [transform([i]) for i in dataset['data']]
        self.labels = [np.int32(i) for i in dataset['label']]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))
        
data_train = BERTDataset(trains, 3, 2, tok, max_len, True, False)
data_test = BERTDataset(tests, 3, 2, tok, max_len, True, False)

# pytorch용 DataLoader 사용
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=2)

train shape is: 139443
test shape is: 34861
using cached model


In [None]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 46, # softmax 사용 <- binary일 경우는 2
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)


In [None]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# 옵티마이저 선언
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss() # softmax용 Loss Function 정하기 <- binary classification도 해당 loss function 사용 가능

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)


# 학습 평가 지표인 accuracy 계산 -> 얼마나 타겟값을 많이 맞추었는가
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc


In [None]:
# del train, test, data_train, data_test, trains, tests
del train, data_train, data_test, trains, tests

In [None]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
highest_acc = 0
patience = 0

# 모델 학습 시작
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # gradient clipping
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    #테스트 문장 예측
    
    model.eval() # 평가 모드로 변경
    
    for test_batch_id, (test_token_ids, test_valid_length, test_segment_ids, test_label) in enumerate(tqdm_notebook(test_dataloader)):
        test_token_ids = test_token_ids.long().to(device)
        test_segment_ids = test_segment_ids.long().to(device)
        test_valid_length= test_valid_length
        test_label = test_label.long().to(device)
        test_out = model(token_ids, valid_length, segment_ids)
        test_loss = loss_fn(out, label)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (test_batch_id+1)))

    if test_acc - highest_acc >= 0.01:
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss,
            }, '/content/drive/My Drive/mulcamnlp2021/torchckpt/model.pt')
        patience = 0
    else:
        print("test acc did not improved. best:{} current:{}".format(highest_acc, test_acc))
        patience += 1
        if patience > 5:
            break
    print('current patience: {}'.format(patience))
    print("************************************************************************************")


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 1 batch id 1 loss 3.8252406120300293 train acc 0.0234375
epoch 1 batch id 101 loss 3.2361252307891846 train acc 0.2130259900990099
epoch 1 batch id 201 loss 2.231053352355957 train acc 0.4992615049751244
epoch 1 batch id 301 loss 1.4669082164764404 train acc 0.6032495847176079
epoch 1 batch id 401 loss 1.3117352724075317 train acc 0.6576917082294265
epoch 1 batch id 501 loss 1.3120595216751099 train acc 0.6892776946107785
epoch 1 batch id 601 loss 0.9133090376853943 train acc 0.7112884775374376
epoch 1 batch id 701 loss 1.0392524003982544 train acc 0.7265513552068473
epoch 1 batch id 801 loss 0.9304847717285156 train acc 0.7379545099875156
epoch 1 batch id 901 loss 0.923143744468689 train acc 0.7466443534961155
epoch 1 batch id 1001 loss 0.6584336161613464 train acc 0.7539413711288712

epoch 1 train acc 0.7596899453588775


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 1 test acc 0.8823529411764658
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 2 batch id 1 loss 1.1076102256774902 train acc 0.734375
epoch 2 batch id 101 loss 0.6918722987174988 train acc 0.8230971534653465
epoch 2 batch id 201 loss 0.7229592204093933 train acc 0.8268034825870647
epoch 2 batch id 301 loss 0.551899254322052 train acc 0.8268791528239202
epoch 2 batch id 401 loss 0.7899765372276306 train acc 0.8296056733167082
epoch 2 batch id 501 loss 0.7412642240524292 train acc 0.8320234530938124
epoch 2 batch id 601 loss 0.4149633049964905 train acc 0.8351835482529119
epoch 2 batch id 701 loss 0.6622830033302307 train acc 0.8370296897289586
epoch 2 batch id 801 loss 0.6959438920021057 train acc 0.8388830368289638
epoch 2 batch id 901 loss 0.6231672167778015 train acc 0.8403596698113207
epoch 2 batch id 1001 loss 0.49918535351753235 train acc 0.8419549200799201

epoch 2 train acc 0.8431655030131319


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 2 test acc 0.901960784313728
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 3 batch id 1 loss 0.6913585066795349 train acc 0.8125
epoch 3 batch id 101 loss 0.41447609663009644 train acc 0.8589882425742574
epoch 3 batch id 201 loss 0.5009818077087402 train acc 0.8600357587064676
epoch 3 batch id 301 loss 0.4276493489742279 train acc 0.8595047757475083
epoch 3 batch id 401 loss 0.5935074090957642 train acc 0.8612063591022444
epoch 3 batch id 501 loss 0.551464319229126 train acc 0.8624001996007984
epoch 3 batch id 601 loss 0.3070865273475647 train acc 0.8648086522462562
epoch 3 batch id 701 loss 0.492123007774353 train acc 0.8653151747503567
epoch 3 batch id 801 loss 0.4707980155944824 train acc 0.8666413077403246
epoch 3 batch id 901 loss 0.5225281119346619 train acc 0.8672221836847946
epoch 3 batch id 1001 loss 0.38055354356765747 train acc 0.8680460164835165

epoch 3 train acc 0.8689038889638425


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 3 test acc 0.9411764705882324
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 4 batch id 1 loss 0.5628318786621094 train acc 0.875
epoch 4 batch id 101 loss 0.3890545666217804 train acc 0.8790996287128713
epoch 4 batch id 201 loss 0.36980944871902466 train acc 0.8794698383084577
epoch 4 batch id 301 loss 0.35116395354270935 train acc 0.879438330564784
epoch 4 batch id 401 loss 0.4419063925743103 train acc 0.8821306109725686
epoch 4 batch id 501 loss 0.4153319001197815 train acc 0.8831711576846307
epoch 4 batch id 601 loss 0.1903926283121109 train acc 0.8855423252911814
epoch 4 batch id 701 loss 0.36025822162628174 train acc 0.8861336483594865
epoch 4 batch id 801 loss 0.3023512065410614 train acc 0.8881183676654182
epoch 4 batch id 901 loss 0.3590075671672821 train acc 0.8891769561598224
epoch 4 batch id 1001 loss 0.33748650550842285 train acc 0.8903986638361638

epoch 4 train acc 0.8913882611530851


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 4 test acc 0.9803921568627413
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 5 batch id 1 loss 0.4330187439918518 train acc 0.8984375
epoch 5 batch id 101 loss 0.30805471539497375 train acc 0.900990099009901
epoch 5 batch id 201 loss 0.3644924461841583 train acc 0.9020522388059702
epoch 5 batch id 301 loss 0.27585655450820923 train acc 0.9009811046511628
epoch 5 batch id 401 loss 0.3458540439605713 train acc 0.9030938279301746
epoch 5 batch id 501 loss 0.28241533041000366 train acc 0.9038173652694611
epoch 5 batch id 601 loss 0.1832352876663208 train acc 0.9063279950083195
epoch 5 batch id 701 loss 0.33181482553482056 train acc 0.9067960948644793
epoch 5 batch id 801 loss 0.247635617852211 train acc 0.9084737827715356
epoch 5 batch id 901 loss 0.2963389456272125 train acc 0.9093975443951166
epoch 5 batch id 1001 loss 0.26346778869628906 train acc 0.9104567307692307

epoch 5 train acc 0.9112492129879475


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 5 test acc 0.9803921568627413
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 6 batch id 1 loss 0.41288504004478455 train acc 0.8828125
epoch 6 batch id 101 loss 0.3323139548301697 train acc 0.9172339108910891
epoch 6 batch id 201 loss 0.24398934841156006 train acc 0.9198927238805971
epoch 6 batch id 301 loss 0.23682457208633423 train acc 0.9187603820598007
epoch 6 batch id 401 loss 0.2799283266067505 train acc 0.920160536159601
epoch 6 batch id 501 loss 0.2648429274559021 train acc 0.9205963073852296
epoch 6 batch id 601 loss 0.14379076659679413 train acc 0.9222259775374376
epoch 6 batch id 701 loss 0.2587035596370697 train acc 0.9231455064194009
epoch 6 batch id 801 loss 0.23675182461738586 train acc 0.9246352215980025
epoch 6 batch id 901 loss 0.29859432578086853 train acc 0.9250485571587126
epoch 6 batch id 1001 loss 0.24822843074798584 train acc 0.9259178321678322

epoch 6 train acc 0.926627006880734


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 6 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 7 batch id 1 loss 0.2728045880794525 train acc 0.9296875
epoch 7 batch id 101 loss 0.23481757938861847 train acc 0.9326268564356436
epoch 7 batch id 201 loss 0.24331910908222198 train acc 0.9347403606965174
epoch 7 batch id 301 loss 0.23004840314388275 train acc 0.9341777408637874
epoch 7 batch id 401 loss 0.22035937011241913 train acc 0.9353569201995012
epoch 7 batch id 501 loss 0.14115944504737854 train acc 0.936501996007984
epoch 7 batch id 601 loss 0.1497342586517334 train acc 0.937551996672213
epoch 7 batch id 701 loss 0.19484181702136993 train acc 0.9382912803138374
epoch 7 batch id 801 loss 0.1018046960234642 train acc 0.9394604400749064
epoch 7 batch id 901 loss 0.19328951835632324 train acc 0.939927857935627
epoch 7 batch id 1001 loss 0.19698257744312286 train acc 0.9405126123876124

epoch 7 train acc 0.9407861969329017


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 7 test acc 0.9803921568627413
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 8 batch id 1 loss 0.22160230576992035 train acc 0.9453125
epoch 8 batch id 101 loss 0.22842423617839813 train acc 0.942605198019802
epoch 8 batch id 201 loss 0.29408761858940125 train acc 0.9439521144278606
epoch 8 batch id 301 loss 0.11963357031345367 train acc 0.9441185631229236
epoch 8 batch id 401 loss 0.1669701337814331 train acc 0.9453125
epoch 8 batch id 501 loss 0.12781548500061035 train acc 0.9467939121756487
epoch 8 batch id 601 loss 0.14181025326251984 train acc 0.9478343386023295
epoch 8 batch id 701 loss 0.15483000874519348 train acc 0.948644793152639
epoch 8 batch id 801 loss 0.1545329988002777 train acc 0.9493699282147315
epoch 8 batch id 901 loss 0.1563025563955307 train acc 0.9497693534961155
epoch 8 batch id 1001 loss 0.16402161121368408 train acc 0.9504089660339661

epoch 8 train acc 0.9505805619266054


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 8 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 9 batch id 1 loss 0.19032856822013855 train acc 0.9453125
epoch 9 batch id 101 loss 0.1593199521303177 train acc 0.9532023514851485
epoch 9 batch id 201 loss 0.2169194221496582 train acc 0.9525031094527363
epoch 9 batch id 301 loss 0.09625088423490524 train acc 0.9526318521594684
epoch 9 batch id 401 loss 0.18069803714752197 train acc 0.9536120635910225
epoch 9 batch id 501 loss 0.10296524316072464 train acc 0.954684381237525
epoch 9 batch id 601 loss 0.11400092393159866 train acc 0.9552308652246256
epoch 9 batch id 701 loss 0.12573549151420593 train acc 0.9557440263908702
epoch 9 batch id 801 loss 0.14158710837364197 train acc 0.9560997971285893
epoch 9 batch id 901 loss 0.16422706842422485 train acc 0.9565153301886793
epoch 9 batch id 1001 loss 0.12386158853769302 train acc 0.957199050949051

epoch 9 train acc 0.9575328296456197


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 9 test acc 0.9607843137254916
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 10 batch id 1 loss 0.2890418469905853 train acc 0.9375
epoch 10 batch id 101 loss 0.08264759927988052 train acc 0.9565284653465347
epoch 10 batch id 201 loss 0.2171182930469513 train acc 0.9581389925373134
epoch 10 batch id 301 loss 0.12688849866390228 train acc 0.9581343438538206
epoch 10 batch id 401 loss 0.18362605571746826 train acc 0.9592035536159601
epoch 10 batch id 501 loss 0.035124316811561584 train acc 0.9603293413173652
epoch 10 batch id 601 loss 0.06274165958166122 train acc 0.9615094633943427
epoch 10 batch id 701 loss 0.08466865122318268 train acc 0.9623083095577746
epoch 10 batch id 801 loss 0.04591037705540657 train acc 0.9629662141073658
epoch 10 batch id 901 loss 0.11679039150476456 train acc 0.9634520671476138
epoch 10 batch id 1001 loss 0.05123162642121315 train acc 0.9640827922077922

epoch 10 train acc 0.9643205275229357


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 10 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 11 batch id 1 loss 0.20685380697250366 train acc 0.9609375
epoch 11 batch id 101 loss 0.11325320601463318 train acc 0.966506806930693
epoch 11 batch id 201 loss 0.14149875938892365 train acc 0.9658737562189055
epoch 11 batch id 301 loss 0.09347961097955704 train acc 0.9652200996677741
epoch 11 batch id 401 loss 0.11422090977430344 train acc 0.9661588216957606
epoch 11 batch id 501 loss 0.10834876447916031 train acc 0.9668007734530938
epoch 11 batch id 601 loss 0.049419429153203964 train acc 0.9678530574043261
epoch 11 batch id 701 loss 0.05684167146682739 train acc 0.9684713801711841
epoch 11 batch id 801 loss 0.04147379472851753 train acc 0.9691498907615481
epoch 11 batch id 901 loss 0.14633747935295105 train acc 0.9693656354051055
epoch 11 batch id 1001 loss 0.04250151664018631 train acc 0.9695460789210789

epoch 11 train acc 0.9696315940366973


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 11 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 12 batch id 1 loss 0.10968412458896637 train acc 0.9765625
epoch 12 batch id 101 loss 0.08583611249923706 train acc 0.9725402227722773
epoch 12 batch id 201 loss 0.1531163454055786 train acc 0.9701492537313433
epoch 12 batch id 301 loss 0.10093823075294495 train acc 0.9695546096345515
epoch 12 batch id 401 loss 0.09928421676158905 train acc 0.9708541147132169
epoch 12 batch id 501 loss 0.02033929154276848 train acc 0.9717596057884231
epoch 12 batch id 601 loss 0.046069782227277756 train acc 0.9721037853577371
epoch 12 batch id 701 loss 0.051867444068193436 train acc 0.9722271754636234
epoch 12 batch id 801 loss 0.05013185739517212 train acc 0.9727001404494382
epoch 12 batch id 901 loss 0.08112433552742004 train acc 0.9729900804661488
epoch 12 batch id 1001 loss 0.09225459396839142 train acc 0.9734328171828172

epoch 12 train acc 0.9731436353211009


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 12 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 13 batch id 1 loss 0.13306458294391632 train acc 0.9765625
epoch 13 batch id 101 loss 0.061965472996234894 train acc 0.9754022277227723
epoch 13 batch id 201 loss 0.12119060754776001 train acc 0.9748911691542289
epoch 13 batch id 301 loss 0.04535337910056114 train acc 0.9746418189368771
epoch 13 batch id 401 loss 0.11497481167316437 train acc 0.97504286159601
epoch 13 batch id 501 loss 0.04911861941218376 train acc 0.975813997005988
epoch 13 batch id 601 loss 0.05211295560002327 train acc 0.9762375207986689
epoch 13 batch id 701 loss 0.0385054312646389 train acc 0.9765959343794579
epoch 13 batch id 801 loss 0.04171347618103027 train acc 0.9769818976279651
epoch 13 batch id 901 loss 0.047420110553503036 train acc 0.9770914261931187
epoch 13 batch id 1001 loss 0.03847985342144966 train acc 0.977179070929071

epoch 13 train acc 0.9770175616118006


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 13 test acc 0.9803921568627413
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 14 batch id 1 loss 0.07400265336036682 train acc 0.9765625
epoch 14 batch id 101 loss 0.07379217445850372 train acc 0.9786509900990099
epoch 14 batch id 201 loss 0.08643775433301926 train acc 0.9791277985074627
epoch 14 batch id 301 loss 0.023778917267918587 train acc 0.9785610465116279
epoch 14 batch id 401 loss 0.1545373499393463 train acc 0.9788224750623441
epoch 14 batch id 501 loss 0.019256791099905968 train acc 0.978932759481038
epoch 14 batch id 601 loss 0.061866339296102524 train acc 0.9794093178036606
epoch 14 batch id 701 loss 0.04622013866901398 train acc 0.9799505171184023
epoch 14 batch id 801 loss 0.030335593968629837 train acc 0.980210284019975
epoch 14 batch id 901 loss 0.028136244043707848 train acc 0.9804123890122086
epoch 14 batch id 1001 loss 0.02263488806784153 train acc 0.9803321678321678

epoch 14 train acc 0.9802967316513761


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 14 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 15 batch id 1 loss 0.052279822528362274 train acc 0.9765625
epoch 15 batch id 101 loss 0.07412195950746536 train acc 0.981822400990099
epoch 15 batch id 201 loss 0.09227012097835541 train acc 0.9803327114427861
epoch 15 batch id 301 loss 0.050083234906196594 train acc 0.9800404900332226
epoch 15 batch id 401 loss 0.08389188349246979 train acc 0.9803615960099751
epoch 15 batch id 501 loss 0.038320936262607574 train acc 0.9807728293413174
epoch 15 batch id 601 loss 0.019829602912068367 train acc 0.9810602121464226
epoch 15 batch id 701 loss 0.009906499646604061 train acc 0.9812990370898717
epoch 15 batch id 801 loss 0.06170134246349335 train acc 0.981585518102372
epoch 15 batch id 901 loss 0.04555971175432205 train acc 0.9815916342952276
epoch 15 batch id 1001 loss 0.01490563154220581 train acc 0.9816277472527473

epoch 15 train acc 0.9818520642201835


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 15 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 16 batch id 1 loss 0.022448178380727768 train acc 0.9921875
epoch 16 batch id 101 loss 0.06032249704003334 train acc 0.9833694306930693
epoch 16 batch id 201 loss 0.07105602324008942 train acc 0.9831312189054726
epoch 16 batch id 301 loss 0.01607424020767212 train acc 0.9829993770764119
epoch 16 batch id 401 loss 0.12359761446714401 train acc 0.9835567331670823
epoch 16 batch id 501 loss 0.03126830607652664 train acc 0.9837980289421158
epoch 16 batch id 601 loss 0.07136562466621399 train acc 0.983946027454243
epoch 16 batch id 701 loss 0.05750839039683342 train acc 0.9838623395149786
epoch 16 batch id 801 loss 0.008374162018299103 train acc 0.9842286985018727
epoch 16 batch id 901 loss 0.04413134604692459 train acc 0.9843316453940066
epoch 16 batch id 1001 loss 0.06040889769792557 train acc 0.9843515859140859

epoch 16 train acc 0.9845255160550459


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 16 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 17 batch id 1 loss 0.02936646156013012 train acc 0.9921875
epoch 17 batch id 101 loss 0.012535830959677696 train acc 0.9859993811881188
epoch 17 batch id 201 loss 0.04476921632885933 train acc 0.9856576492537313
epoch 17 batch id 301 loss 0.04895586892962456 train acc 0.9852055647840532
epoch 17 batch id 401 loss 0.06712943315505981 train acc 0.9850568890274314
epoch 17 batch id 501 loss 0.05931737646460533 train acc 0.984936377245509
epoch 17 batch id 601 loss 0.005976209416985512 train acc 0.9850769550748752
epoch 17 batch id 701 loss 0.04410187155008316 train acc 0.9849545292439372
epoch 17 batch id 801 loss 0.061684317886829376 train acc 0.9854576310861424
epoch 17 batch id 901 loss 0.060382694005966187 train acc 0.9855108906770256
epoch 17 batch id 1001 loss 0.005005488637834787 train acc 0.98565497002997

epoch 17 train acc 0.9857904973916172


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 17 test acc 0.9803921568627413
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 18 batch id 1 loss 0.057391103357076645 train acc 0.984375
epoch 18 batch id 101 loss 0.00719530088827014 train acc 0.9856126237623762
epoch 18 batch id 201 loss 0.04779154807329178 train acc 0.9857742537313433
epoch 18 batch id 301 loss 0.01665523275732994 train acc 0.9856468023255814
epoch 18 batch id 401 loss 0.06805119663476944 train acc 0.98589463840399
epoch 18 batch id 501 loss 0.010677381418645382 train acc 0.9866049151696606
epoch 18 batch id 601 loss 0.005881864577531815 train acc 0.986922836938436
epoch 18 batch id 701 loss 0.045324310660362244 train acc 0.9871277639087018
epoch 18 batch id 801 loss 0.031801044940948486 train acc 0.9873595505617978
epoch 18 batch id 901 loss 0.020688442513346672 train acc 0.9874011514983352
epoch 18 batch id 1001 loss 0.04279530793428421 train acc 0.9875124875124875

epoch 18 train acc 0.9875286697247706


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 18 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 19 batch id 1 loss 0.009583309292793274 train acc 1.0
epoch 19 batch id 101 loss 0.007001370657235384 train acc 0.9892481435643564
epoch 19 batch id 201 loss 0.04173023998737335 train acc 0.9895055970149254
epoch 19 batch id 301 loss 0.009943192824721336 train acc 0.9889171511627907
epoch 19 batch id 401 loss 0.03186549246311188 train acc 0.988953397755611
epoch 19 batch id 501 loss 0.011699432507157326 train acc 0.9889283932135728
epoch 19 batch id 601 loss 0.005728963762521744 train acc 0.9891326955074875
epoch 19 batch id 701 loss 0.005789606366306543 train acc 0.989256419400856
epoch 19 batch id 801 loss 0.02363506518304348 train acc 0.9894662921348315
epoch 19 batch id 901 loss 0.015153857879340649 train acc 0.9894648307436182
epoch 19 batch id 1001 loss 0.0033383190166205168 train acc 0.9895573176823177

epoch 19 train acc 0.9896502293577981


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 19 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 20 batch id 1 loss 0.010927174240350723 train acc 0.9921875
epoch 20 batch id 101 loss 0.02897021733224392 train acc 0.9897122524752475
epoch 20 batch id 201 loss 0.023424094542860985 train acc 0.9894278606965174
epoch 20 batch id 301 loss 0.027768488973379135 train acc 0.9892545681063123
epoch 20 batch id 401 loss 0.03541212156414986 train acc 0.989518391521197
epoch 20 batch id 501 loss 0.0026739768218249083 train acc 0.9896768962075848
epoch 20 batch id 601 loss 0.03836830332875252 train acc 0.989951643094842
epoch 20 batch id 701 loss 0.006728284526616335 train acc 0.9898582382310984
epoch 20 batch id 801 loss 0.01186633575707674 train acc 0.9900319912609239
epoch 20 batch id 901 loss 0.004408322740346193 train acc 0.9900457824639289
epoch 20 batch id 1001 loss 0.03912509232759476 train acc 0.9900646228771228

epoch 20 train acc 0.9900587729357798


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 20 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 21 batch id 1 loss 0.05350055545568466 train acc 0.984375
epoch 21 batch id 101 loss 0.008391736075282097 train acc 0.9905631188118812
epoch 21 batch id 201 loss 0.051236964762210846 train acc 0.9906327736318408
epoch 21 batch id 301 loss 0.07630635797977448 train acc 0.9899034468438538
epoch 21 batch id 401 loss 0.04930884763598442 train acc 0.9903561408977556
epoch 21 batch id 501 loss 0.0024028290063142776 train acc 0.9905189620758483
epoch 21 batch id 601 loss 0.049699023365974426 train acc 0.9905496048252912
epoch 21 batch id 701 loss 0.061080366373062134 train acc 0.9904712018544936
epoch 21 batch id 801 loss 0.03421863541007042 train acc 0.9907830056179775
epoch 21 batch id 901 loss 0.02195429429411888 train acc 0.9908261653718091
epoch 21 batch id 1001 loss 0.008246098645031452 train acc 0.9907670454545454

epoch 21 train acc 0.9908400229357798


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 21 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 22 batch id 1 loss 0.11161112785339355 train acc 0.96875
epoch 22 batch id 101 loss 0.014892689883708954 train acc 0.9914139851485149
epoch 22 batch id 201 loss 0.0270235538482666 train acc 0.9912935323383084
epoch 22 batch id 301 loss 0.011452407576143742 train acc 0.9913309800664452
epoch 22 batch id 401 loss 0.04321344569325447 train acc 0.9915056109725686
epoch 22 batch id 501 loss 0.0027701924555003643 train acc 0.9917664670658682
epoch 22 batch id 601 loss 0.007477934006601572 train acc 0.9916675332778702
epoch 22 batch id 701 loss 0.033737048506736755 train acc 0.9917082738944365
epoch 22 batch id 801 loss 0.05203976854681969 train acc 0.9919534176029963
epoch 22 batch id 901 loss 0.004089305177330971 train acc 0.991953385127636
epoch 22 batch id 1001 loss 0.00581613602116704 train acc 0.9918518981018981

epoch 22 train acc 0.9917574541284404


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 22 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 23 batch id 1 loss 0.00591482687741518 train acc 1.0
epoch 23 batch id 101 loss 0.001786773675121367 train acc 0.9917233910891089
epoch 23 batch id 201 loss 0.0368562713265419 train acc 0.9919542910447762
epoch 23 batch id 301 loss 0.01507573388516903 train acc 0.9914607558139535
epoch 23 batch id 401 loss 0.028899049386382103 train acc 0.9919537094763092
epoch 23 batch id 501 loss 0.001698757172562182 train acc 0.9922186876247505
epoch 23 batch id 601 loss 0.06802044063806534 train acc 0.992239496672213
epoch 23 batch id 701 loss 0.012097359634935856 train acc 0.9922989479315264
epoch 23 batch id 801 loss 0.020564820617437363 train acc 0.9923728152309613
epoch 23 batch id 901 loss 0.019954364746809006 train acc 0.9923349056603774
epoch 23 batch id 1001 loss 0.07216832041740417 train acc 0.9922187187812188

epoch 23 train acc 0.9921516628440367


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 23 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 24 batch id 1 loss 0.09105177223682404 train acc 0.984375
epoch 24 batch id 101 loss 0.02037147805094719 train acc 0.9913366336633663
epoch 24 batch id 201 loss 0.029290063306689262 train acc 0.9919154228855721
epoch 24 batch id 301 loss 0.0170882698148489 train acc 0.99187603820598
epoch 24 batch id 401 loss 0.04669230431318283 train acc 0.9921875
epoch 24 batch id 501 loss 0.008823946118354797 train acc 0.9924058133732535
epoch 24 batch id 601 loss 0.03560088947415352 train acc 0.9924084858569051
epoch 24 batch id 701 loss 0.013044072315096855 train acc 0.9924772646219686
epoch 24 batch id 801 loss 0.017100505530834198 train acc 0.9924410892634207
epoch 24 batch id 901 loss 0.0011715575819835067 train acc 0.9924823113207547
epoch 24 batch id 1001 loss 0.00925537571310997 train acc 0.9924762737262737

epoch 24 train acc 0.9924526949541285


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 24 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 25 batch id 1 loss 0.012756844982504845 train acc 0.9921875
epoch 25 batch id 101 loss 0.046230364590883255 train acc 0.9919554455445545
epoch 25 batch id 201 loss 0.024195943027734756 train acc 0.9925761815920398
epoch 25 batch id 301 loss 0.009878363460302353 train acc 0.9924210963455149
epoch 25 batch id 401 loss 0.053736455738544464 train acc 0.9924212905236908
epoch 25 batch id 501 loss 0.0018295206828042865 train acc 0.9928268463073853
epoch 25 batch id 601 loss 0.04005928710103035 train acc 0.9927854617304492
epoch 25 batch id 701 loss 0.04250267520546913 train acc 0.9929342011412268
epoch 25 batch id 801 loss 0.014277215115725994 train acc 0.9929872815230961
epoch 25 batch id 901 loss 0.0019085027743130922 train acc 0.9929158573806881
epoch 25 batch id 1001 loss 0.057211216539144516 train acc 0.9928743131868132

epoch 25 train acc 0.9927608944954128


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 25 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 26 batch id 1 loss 0.03599162399768829 train acc 0.984375
epoch 26 batch id 101 loss 0.013513141311705112 train acc 0.9935798267326733
epoch 26 batch id 201 loss 0.012774228118360043 train acc 0.9934312810945274
epoch 26 batch id 301 loss 0.020287102088332176 train acc 0.9928104235880398
epoch 26 batch id 401 loss 0.04798802360892296 train acc 0.9929083541147132
epoch 26 batch id 501 loss 0.0014022176619619131 train acc 0.9930763473053892
epoch 26 batch id 601 loss 0.03232946619391441 train acc 0.9931364392678869
epoch 26 batch id 701 loss 0.04264099523425102 train acc 0.9932574001426534
epoch 26 batch id 801 loss 0.01828615926206112 train acc 0.9932896379525593
epoch 26 batch id 901 loss 0.021757347509264946 train acc 0.9932453523862376
epoch 26 batch id 1001 loss 0.0098301125690341 train acc 0.9931786963036963

epoch 26 train acc 0.9930762614678899


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 26 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 27 batch id 1 loss 0.014168926514685154 train acc 0.984375
epoch 27 batch id 101 loss 0.03172614797949791 train acc 0.9923422029702971
epoch 27 batch id 201 loss 0.008075034245848656 train acc 0.9926539179104478
epoch 27 batch id 301 loss 0.018584053963422775 train acc 0.9925249169435216
epoch 27 batch id 401 loss 0.03910302370786667 train acc 0.9927330112219451
epoch 27 batch id 501 loss 0.002055867575109005 train acc 0.9930139720558883
epoch 27 batch id 601 loss 0.014721712097525597 train acc 0.9930454450915142
epoch 27 batch id 701 loss 0.00486795837059617 train acc 0.9933019793152639
epoch 27 batch id 801 loss 0.012741422280669212 train acc 0.9933676654182272
epoch 27 batch id 901 loss 0.007582757622003555 train acc 0.9933494034406215
epoch 27 batch id 1001 loss 0.03456367552280426 train acc 0.9932177197802198

epoch 27 train acc 0.9931622706422019


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 27 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 28 batch id 1 loss 0.03645220398902893 train acc 0.9921875
epoch 28 batch id 101 loss 0.012981239706277847 train acc 0.9928836633663366
epoch 28 batch id 201 loss 0.013914299197494984 train acc 0.9930814676616916
epoch 28 batch id 301 loss 0.019570374861359596 train acc 0.9930180647840532
epoch 28 batch id 401 loss 0.04563343897461891 train acc 0.9933369700748129
epoch 28 batch id 501 loss 0.0017289485549554229 train acc 0.9933570359281437
epoch 28 batch id 601 loss 0.005634106695652008 train acc 0.993409421797005
epoch 28 batch id 701 loss 0.011784291826188564 train acc 0.9936028887303852
epoch 28 batch id 801 loss 0.011079435236752033 train acc 0.9935724875156055
epoch 28 batch id 901 loss 0.004104245454072952 train acc 0.9935488346281909
epoch 28 batch id 1001 loss 0.039856452494859695 train acc 0.9935299075924076

epoch 28 train acc 0.9934561353211009


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 28 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 29 batch id 1 loss 0.025547953322529793 train acc 0.984375
epoch 29 batch id 101 loss 0.009342765435576439 train acc 0.9931157178217822
epoch 29 batch id 201 loss 0.0056269969791173935 train acc 0.9930037313432836
epoch 29 batch id 301 loss 0.01134889479726553 train acc 0.9930699750830565
epoch 29 batch id 401 loss 0.03156683221459389 train acc 0.9934149002493765
epoch 29 batch id 501 loss 0.0012113861739635468 train acc 0.9935909431137725
epoch 29 batch id 601 loss 0.006763978395611048 train acc 0.993409421797005
epoch 29 batch id 701 loss 0.004724093712866306 train acc 0.9936363231098431
epoch 29 batch id 801 loss 0.009990357793867588 train acc 0.9936992821473158
epoch 29 batch id 901 loss 0.004290229640901089 train acc 0.9936268729189789
epoch 29 batch id 1001 loss 0.023840848356485367 train acc 0.9936625874125874

epoch 29 train acc 0.9936281536697248


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 29 test acc 1.0
current patience: 0
************************************************************************************


HBox(children=(FloatProgress(value=0.0, max=1090.0), HTML(value='')))

epoch 30 batch id 1 loss 0.0341605618596077 train acc 0.984375
epoch 30 batch id 101 loss 0.009319574572145939 train acc 0.9939665841584159
epoch 30 batch id 201 loss 0.006309538148343563 train acc 0.9940531716417911
epoch 30 batch id 301 loss 0.02153102681040764 train acc 0.9938486295681063
epoch 30 batch id 401 loss 0.019820481538772583 train acc 0.994096789276808
epoch 30 batch id 501 loss 0.0011880180099979043 train acc 0.9942926646706587
epoch 30 batch id 601 loss 0.005668388679623604 train acc 0.9942413685524126
epoch 30 batch id 701 loss 0.0036351950839161873 train acc 0.9942938659058488
epoch 30 batch id 801 loss 0.00866598729044199 train acc 0.9941576935081149
epoch 30 batch id 901 loss 0.006880796980112791 train acc 0.9941904827968924
epoch 30 batch id 1001 loss 0.016397107392549515 train acc 0.9942479395604396

epoch 30 train acc 0.994129873853211


HBox(children=(FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch 30 test acc 1.0
current patience: 0
************************************************************************************


In [None]:
class BERTDataset1(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) 

        self.sentences = [transform([i]) for i in dataset['data']]
        self.labels = [np.int32(0) for i in dataset['data']]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))
        #return (self.sentences[i] + (0, ))

    def __len__(self):
        return (len(self.sentences))


test_set = BERTDataset1(test, 3, 2, tok, max_len, True, False)
test_input = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=2)

In [None]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

checkpoint = torch.load('/content/drive/My Drive/mulcamnlp2021/torchckpt/model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:
model.eval()

result = []

for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_input)):
  token_ids = token_ids.long().to(device)
  segment_ids = segment_ids.long().to(device)
  valid_length= valid_length
  label = label.long().to(device)
  out = model(token_ids, valid_length, segment_ids)
  for i in out:
      logits = i
      logits = logits.detach().cpu().numpy()
      final = np.argmax(logits)
  result.append(final)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, max=43576.0), HTML(value='')))




In [None]:
submission = pd.read_csv('/content/drive/My Drive/mulcamnlp2021/datain/sample_submission.csv')

submission['label'] = np.array(result)

submission.to_csv('/content/drive/My Drive/mulcamnlp2021/datain/kobert_baseline.csv', index=False)