In [None]:
from keras import backend as K

def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        """Recall metric.

        Only computes a batch-wise average of recall.

        Computes the recall, a metric for multi-label classification of
        how many relevant items are selected.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        """Precision metric.

        Only computes a batch-wise average of precision.

        Computes the precision, a metric for multi-label classification of
        how many selected items are relevant.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))


In [1]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3 # 최신 버전으로 설치하면 "Input: must be Tensor, not str" 라는 에러 발생
!pip install torch

!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master


import re
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

torch.manual_seed(123)
np.random.seed(123)

##GPU 사용 시
device = torch.device("cuda:0")

bertmodel, vocab = get_pytorch_kobert_model()

from google.colab import drive
drive.mount('/content/drive')\

Collecting mxnet
[?25l  Downloading https://files.pythonhosted.org/packages/30/07/66174e78c12a3048db9039aaa09553e35035ef3a008ba3e0ed8d2aa3c47b/mxnet-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (46.9MB)
[K     |████████████████████████████████| 46.9MB 114kB/s 
[?25hCollecting graphviz<0.9.0,>=0.8.1
  Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl
Installing collected packages: graphviz, mxnet
  Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.8.0.post0
Collecting gluonnlp
[?25l  Downloading https://files.pythonhosted.org/packages/9c/81/a238e47ccba0d7a61dcef4e0b4a7fd4473cb86bed3d84dd4fe28d45a0905/gluonnlp-0.10.0.tar.gz (344kB)
[K     |████████████████████████████████| 348kB 4.3MB/s 
Building wheels for collected packages: gluonnlp
  Building wheel fo

In [2]:
# 학습용 데이터셋 불러오기
import pandas as pd
train = pd.read_csv('/content/drive/MyDrive/Dacon/train.csv')
test = pd.read_csv('/content/drive/MyDrive/Dacon/test.csv')

# 데이터 전처리
train=train[['과제명', '요약문_연구내용','label']]
test=test[['과제명', '요약문_연구내용']]
train['요약문_연구내용'].fillna('NAN', inplace=True)
test['요약문_연구내용'].fillna('NAN', inplace=True)

# 대표 질병 5개 추출
train['data']=train['과제명']+train['요약문_연구내용']
test['data']=test['과제명']+test['요약문_연구내용']

def clean_text(sent):
    sent_clean=re.sub("[^가-힣ㄱ-하-ㅣ]", " ", sent)
    return sent_clean

# Train / Test set 분리
from sklearn.model_selection import train_test_split
trains, tests = train_test_split(train, test_size=0.2, random_state=123)
print("train shape is:", len(trains))
print("test shape is:", len(tests))


# 기본 Bert tokenizer 사용
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) 

        self.sentences = [transform([clean_text(i)]) for i in dataset['data']]
        self.labels = [np.int32(i) for i in dataset['label']]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))
        
# Setting parameters
max_len = 64 # 해당 길이를 초과하는 단어에 대해선 bert가 학습하지 않음
batch_size = 64
warmup_ratio = 0.1
num_epochs = 30
max_grad_norm = 1
log_interval = 200
learning_rate = 3e-5


data_train = BERTDataset(trains, 3, 2, tok, max_len, True, False)
data_test = BERTDataset(tests, 3, 2, tok, max_len, True, False)

# pytorch용 DataLoader 사용
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=2)

train shape is: 139443
test shape is: 34861
using cached model


In [3]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 46, # softmax 사용 <- binary일 경우는 2
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)


In [None]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# 옵티마이저 선언
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss() # softmax용 Loss Function 정하기 <- binary classification도 해당 loss function 사용 가능

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)


# 학습 평가 지표인 accuracy 계산 -> 얼마나 타겟값을 많이 맞추었는가
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc
  

In [5]:
highest_acc = 0
patience = 0

# 모델 학습 시작
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # gradient clipping
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    #테스트 문장 예측
    
    model.eval() # 평가 모드로 변경
    
    for test_batch_id, (test_token_ids, test_valid_length, test_segment_ids, test_label) in enumerate(tqdm_notebook(test_dataloader)):
        test_token_ids = test_token_ids.long().to(device)
        test_segment_ids = test_segment_ids.long().to(device)z
        test_valid_length= test_valid_length
        test_label = test_label.long().to(device)
        test_out = model(token_ids, valid_length, segment_ids)
        test_loss = loss_fn(out, label)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (test_batch_id+1)))

    if test_acc > highest_acc:
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss,
            }, '/content/drive/MyDrive/Dacon/torchckpt/model.pt')
        patience = 0
    else:
        print("test acc did not improved. best:{} current:{}".format(highest_acc, test_acc))
        patience += 1
        if patience > 5:
            break
    print('current patience: {}'.format(patience))
    print("************************************************************************************")


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 1 batch id 1 loss 3.746753692626953 train acc 0.046875
epoch 1 batch id 201 loss 2.833305597305298 train acc 0.35152363184079605
epoch 1 batch id 401 loss 1.6726614236831665 train acc 0.5844373441396509
epoch 1 batch id 601 loss 0.941444456577301 train acc 0.6603057404326124
epoch 1 batch id 801 loss 1.0692002773284912 train acc 0.7006476279650437
epoch 1 batch id 1001 loss 1.3497064113616943 train acc 0.7236357392607392
epoch 1 batch id 1201 loss 0.9132053852081299 train acc 0.7399432764363031
epoch 1 batch id 1401 loss 1.078210711479187 train acc 0.7513048715203426
epoch 1 batch id 1601 loss 0.8389707207679749 train acc 0.7601401467832605
epoch 1 batch id 1801 loss 0.8923542499542236 train acc 0.7669610632981677
epoch 1 batch id 2001 loss 0.8088604807853699 train acc 0.7727932908545727

epoch 1 train acc 0.7774260375329572


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 1 test acc 0.9019607843137151


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 2 batch id 1 loss 0.7161873579025269 train acc 0.8125
epoch 2 batch id 201 loss 0.6478219032287598 train acc 0.8365205223880597
epoch 2 batch id 401 loss 0.8538629412651062 train acc 0.8398924563591023
epoch 2 batch id 601 loss 0.5060431957244873 train acc 0.8388103161397671
epoch 2 batch id 801 loss 0.5608747005462646 train acc 0.841916354556804
epoch 2 batch id 1001 loss 0.7570269107818604 train acc 0.8438904845154845
epoch 2 batch id 1201 loss 0.5110994577407837 train acc 0.8461438384679434
epoch 2 batch id 1401 loss 0.6903754472732544 train acc 0.8468504639543183
epoch 2 batch id 1601 loss 0.576520562171936 train acc 0.8486492816989382
epoch 2 batch id 1801 loss 0.5613363981246948 train acc 0.8500225569128262
epoch 2 batch id 2001 loss 0.45418351888656616 train acc 0.8513633808095952

epoch 2 train acc 0.8523565518451529


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 2 test acc 0.9411764705882323


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 3 batch id 1 loss 0.4548601806163788 train acc 0.875
epoch 3 batch id 201 loss 0.4855831265449524 train acc 0.8680037313432836
epoch 3 batch id 401 loss 0.5832599997520447 train acc 0.869856608478803
epoch 3 batch id 601 loss 0.3788071274757385 train acc 0.8694103577371048
epoch 3 batch id 801 loss 0.39459431171417236 train acc 0.8717423533083646
epoch 3 batch id 1001 loss 0.5814250707626343 train acc 0.8717376373626373
epoch 3 batch id 1201 loss 0.3396674394607544 train acc 0.8734257910074937
epoch 3 batch id 1401 loss 0.46280455589294434 train acc 0.8741858493932905
epoch 3 batch id 1601 loss 0.4102008044719696 train acc 0.8756636477201749
epoch 3 batch id 1801 loss 0.30783751606941223 train acc 0.8766050111049417
epoch 3 batch id 2001 loss 0.40754953026771545 train acc 0.877467516241879

epoch 3 train acc 0.8780744112697856


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 3 test acc 0.980392156862752


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 4 batch id 1 loss 0.3000028133392334 train acc 0.9375
epoch 4 batch id 201 loss 0.37646347284317017 train acc 0.8922574626865671
epoch 4 batch id 401 loss 0.44257235527038574 train acc 0.89214463840399
epoch 4 batch id 601 loss 0.31388139724731445 train acc 0.8914049500831946
epoch 4 batch id 801 loss 0.23440377414226532 train acc 0.8926927278401997
epoch 4 batch id 1001 loss 0.3362177312374115 train acc 0.8930756743256744
epoch 4 batch id 1201 loss 0.2426133155822754 train acc 0.8947621773522065
epoch 4 batch id 1401 loss 0.3846658766269684 train acc 0.8951641684511064
epoch 4 batch id 1601 loss 0.27625352144241333 train acc 0.8969394128669581
epoch 4 batch id 1801 loss 0.37761151790618896 train acc 0.8979473209328152
epoch 4 batch id 2001 loss 0.2523356080055237 train acc 0.8990036231884058

epoch 4 train acc 0.8998035222579165


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 4 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 5 batch id 1 loss 0.32684168219566345 train acc 0.875
epoch 5 batch id 201 loss 0.2846899926662445 train acc 0.9123134328358209
epoch 5 batch id 401 loss 0.3000118136405945 train acc 0.9115102867830424
epoch 5 batch id 601 loss 0.390426903963089 train acc 0.9114236688851913
epoch 5 batch id 801 loss 0.1587885320186615 train acc 0.9134480337078652
epoch 5 batch id 1001 loss 0.24528901278972626 train acc 0.9148195554445554
epoch 5 batch id 1201 loss 0.19768504798412323 train acc 0.9161896336386345
epoch 5 batch id 1401 loss 0.26349183917045593 train acc 0.9168228051391863
epoch 5 batch id 1601 loss 0.17049042880535126 train acc 0.9186543566520925
epoch 5 batch id 1801 loss 0.25545406341552734 train acc 0.9195932815102721
epoch 5 batch id 2001 loss 0.4104589819908142 train acc 0.920250812093953

epoch 5 train acc 0.9209194663409191


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 5 test acc 0.980392156862752


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 6 batch id 1 loss 0.133785679936409 train acc 0.9375
epoch 6 batch id 201 loss 0.3497188091278076 train acc 0.927782960199005
epoch 6 batch id 401 loss 0.22515493631362915 train acc 0.9296680174563591
epoch 6 batch id 601 loss 0.2501625120639801 train acc 0.9292325291181365
epoch 6 batch id 801 loss 0.18635913729667664 train acc 0.9302629525593009
epoch 6 batch id 1001 loss 0.18683794140815735 train acc 0.9309752747252747
epoch 6 batch id 1201 loss 0.14751401543617249 train acc 0.9325171731890092
epoch 6 batch id 1401 loss 0.2336578518152237 train acc 0.9330389007851535
epoch 6 batch id 1601 loss 0.1755433976650238 train acc 0.9346794971892567
epoch 6 batch id 1801 loss 0.240845188498497 train acc 0.9350447667962243
epoch 6 batch id 2001 loss 0.09947893768548965 train acc 0.9356025112443778

epoch 6 train acc 0.9359062722151734


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 6 test acc 0.980392156862752


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 7 batch id 1 loss 0.14195466041564941 train acc 0.96875
epoch 7 batch id 201 loss 0.4454287588596344 train acc 0.9420087064676617
epoch 7 batch id 401 loss 0.2792515456676483 train acc 0.9431499376558603
epoch 7 batch id 601 loss 0.15286588668823242 train acc 0.9430636439267887
epoch 7 batch id 801 loss 0.071864552795887 train acc 0.9452832397003745
epoch 7 batch id 1001 loss 0.22844162583351135 train acc 0.9454920079920079
epoch 7 batch id 1201 loss 0.033172883093357086 train acc 0.946698064113239
epoch 7 batch id 1401 loss 0.16712787747383118 train acc 0.9472698072805139
epoch 7 batch id 1601 loss 0.08882782608270645 train acc 0.9481183635227982
epoch 7 batch id 1801 loss 0.2688029408454895 train acc 0.9482405607995558
epoch 7 batch id 2001 loss 0.12338897585868835 train acc 0.9486975262368815

epoch 7 train acc 0.9487275204942004


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 7 test acc 0.980392156862752


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 8 batch id 1 loss 0.1089828833937645 train acc 0.984375
epoch 8 batch id 201 loss 0.22256456315517426 train acc 0.9501710199004975
epoch 8 batch id 401 loss 0.3196576237678528 train acc 0.9517612219451371
epoch 8 batch id 601 loss 0.09642317146062851 train acc 0.9521630615640599
epoch 8 batch id 801 loss 0.07157982885837555 train acc 0.9527933832709113
epoch 8 batch id 1001 loss 0.1652129739522934 train acc 0.9536401098901099
epoch 8 batch id 1201 loss 0.030277661979198456 train acc 0.9546601790174855
epoch 8 batch id 1401 loss 0.16873051226139069 train acc 0.9552997858672377
epoch 8 batch id 1601 loss 0.04226575419306755 train acc 0.9562285290443473
epoch 8 batch id 1801 loss 0.2830953299999237 train acc 0.9565258883953359
epoch 8 batch id 2001 loss 0.280060738325119 train acc 0.9567169540229885

epoch 8 train acc 0.956950509093036


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 8 test acc 0.9607843137254817


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 9 batch id 1 loss 0.059671029448509216 train acc 0.984375
epoch 9 batch id 201 loss 0.21403521299362183 train acc 0.9582555970149254
epoch 9 batch id 401 loss 0.16770794987678528 train acc 0.9587749376558603
epoch 9 batch id 601 loss 0.05295838043093681 train acc 0.9585066555740432
epoch 9 batch id 801 loss 0.0895855650305748 train acc 0.959640293383271
epoch 9 batch id 1001 loss 0.08348409086465836 train acc 0.9605550699300699
epoch 9 batch id 1201 loss 0.0797133818268776 train acc 0.9614253746877602
epoch 9 batch id 1401 loss 0.18394438922405243 train acc 0.9619468236973591
epoch 9 batch id 1601 loss 0.09826437383890152 train acc 0.9626990943160525
epoch 9 batch id 1801 loss 0.1611204296350479 train acc 0.9627810938367574
epoch 9 batch id 2001 loss 0.1691845804452896 train acc 0.963034107946027

epoch 9 train acc 0.9633432767324461


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 9 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 10 batch id 1 loss 0.01841643452644348 train acc 1.0
epoch 10 batch id 201 loss 0.14860746264457703 train acc 0.9650963930348259
epoch 10 batch id 401 loss 0.11354953050613403 train acc 0.9648534912718204
epoch 10 batch id 601 loss 0.008930115960538387 train acc 0.9650062396006656
epoch 10 batch id 801 loss 0.038416724652051926 train acc 0.966019038701623
epoch 10 batch id 1001 loss 0.1344355195760727 train acc 0.9663617632367633
epoch 10 batch id 1201 loss 0.05118847265839577 train acc 0.9668895711906744
epoch 10 batch id 1401 loss 0.16118215024471283 train acc 0.9673559064953604
epoch 10 batch id 1601 loss 0.04571232199668884 train acc 0.9680180356027482
epoch 10 batch id 1801 loss 0.16006360948085785 train acc 0.9682641588006663
epoch 10 batch id 2001 loss 0.238925039768219 train acc 0.9684298475762119

epoch 10 train acc 0.9685779027076641


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 10 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 11 batch id 1 loss 0.04738565534353256 train acc 0.984375
epoch 11 batch id 201 loss 0.1630384773015976 train acc 0.9692941542288557
epoch 11 batch id 401 loss 0.055518344044685364 train acc 0.9696851620947631
epoch 11 batch id 601 loss 0.0849698930978775 train acc 0.9697899334442596
epoch 11 batch id 801 loss 0.10697045922279358 train acc 0.9701544943820225
epoch 11 batch id 1001 loss 0.04338927939534187 train acc 0.9711694555444556
epoch 11 batch id 1201 loss 0.0988185703754425 train acc 0.9716772481265612
epoch 11 batch id 1401 loss 0.06960476189851761 train acc 0.9718058529621699
epoch 11 batch id 1601 loss 0.02074417471885681 train acc 0.9724098219862586
epoch 11 batch id 1801 loss 0.05097923055291176 train acc 0.9724024847307051
epoch 11 batch id 2001 loss 0.042564477771520615 train acc 0.9724278485757122

epoch 11 train acc 0.9726006769160165


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 11 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 12 batch id 1 loss 0.02083381451666355 train acc 1.0
epoch 12 batch id 201 loss 0.1473212093114853 train acc 0.9741138059701493
epoch 12 batch id 401 loss 0.2566220164299011 train acc 0.9742440773067331
epoch 12 batch id 601 loss 0.020185356959700584 train acc 0.9740536605657238
epoch 12 batch id 801 loss 0.09670394659042358 train acc 0.9746605805243446
epoch 12 batch id 1001 loss 0.027599522843956947 train acc 0.9748220529470529
epoch 12 batch id 1201 loss 0.016140524297952652 train acc 0.9750728559533722
epoch 12 batch id 1401 loss 0.07908423990011215 train acc 0.9754193433261956
epoch 12 batch id 1601 loss 0.010044471360743046 train acc 0.9759525296689568
epoch 12 batch id 1801 loss 0.12490025907754898 train acc 0.9761764297612437
epoch 12 batch id 2001 loss 0.05116100609302521 train acc 0.9761759745127436

epoch 12 train acc 0.9763509637448371


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 12 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 13 batch id 1 loss 0.005278869066387415 train acc 1.0
epoch 13 batch id 201 loss 0.17424818873405457 train acc 0.978467039800995
epoch 13 batch id 401 loss 0.05331182852387428 train acc 0.9779847256857855
epoch 13 batch id 601 loss 0.10386092215776443 train acc 0.9772514559068219
epoch 13 batch id 801 loss 0.08574889600276947 train acc 0.9774500624219725
epoch 13 batch id 1001 loss 0.01668880321085453 train acc 0.978380994005994
epoch 13 batch id 1201 loss 0.0072428323328495026 train acc 0.9782863238967527
epoch 13 batch id 1401 loss 0.019147703424096107 train acc 0.9785755710206995
epoch 13 batch id 1601 loss 0.017036980018019676 train acc 0.978968222985634
epoch 13 batch id 1801 loss 0.145708367228508 train acc 0.9790394225430317
epoch 13 batch id 2001 loss 0.10944550484418869 train acc 0.9791432408795602

epoch 13 train acc 0.979355495640202


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 13 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 14 batch id 1 loss 0.020343966782093048 train acc 1.0
epoch 14 batch id 201 loss 0.13569189608097076 train acc 0.9787002487562189
epoch 14 batch id 401 loss 0.07481526583433151 train acc 0.979231608478803
epoch 14 batch id 601 loss 0.05781889706850052 train acc 0.9786553660565723
epoch 14 batch id 801 loss 0.0049534570425748825 train acc 0.9795763108614233
epoch 14 batch id 1001 loss 0.17463354766368866 train acc 0.9801916833166833
epoch 14 batch id 1201 loss 0.05513133108615875 train acc 0.9805110324729392
epoch 14 batch id 1401 loss 0.10299401730298996 train acc 0.9806722876516774
epoch 14 batch id 1601 loss 0.005223841406404972 train acc 0.9810567613991256
epoch 14 batch id 1801 loss 0.046413034200668335 train acc 0.9812690866185453
epoch 14 batch id 2001 loss 0.0866517722606659 train acc 0.9814702023988006

epoch 14 train acc 0.9814851996328591


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 14 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 15 batch id 1 loss 0.0038802947383373976 train acc 1.0
epoch 15 batch id 201 loss 0.08390944451093674 train acc 0.9835199004975125
epoch 15 batch id 401 loss 0.0489528626203537 train acc 0.9829722568578554
epoch 15 batch id 601 loss 0.019501985982060432 train acc 0.9823471297836939
epoch 15 batch id 801 loss 0.1032957062125206 train acc 0.9823267790262172
epoch 15 batch id 1001 loss 0.07048313319683075 train acc 0.982736013986014
epoch 15 batch id 1201 loss 0.0274916123598814 train acc 0.9829439009159034
epoch 15 batch id 1401 loss 0.04755758121609688 train acc 0.9832708779443254
epoch 15 batch id 1601 loss 0.056261271238327026 train acc 0.9834673641474079
epoch 15 batch id 1801 loss 0.020650461316108704 train acc 0.9836982926152138
epoch 15 batch id 2001 loss 0.011080196127295494 train acc 0.983851824087956

epoch 15 train acc 0.9838300252409362


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 15 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 16 batch id 1 loss 0.003126733936369419 train acc 1.0
epoch 16 batch id 201 loss 0.04571525380015373 train acc 0.9848414179104478
epoch 16 batch id 401 loss 0.020187964662909508 train acc 0.9849205112219451
epoch 16 batch id 601 loss 0.027365129441022873 train acc 0.9846869800332779
epoch 16 batch id 801 loss 0.01768716610968113 train acc 0.9846676029962547
epoch 16 batch id 1001 loss 0.0058767106384038925 train acc 0.9851242507492507
epoch 16 batch id 1201 loss 0.09527122229337692 train acc 0.9851165695253955
epoch 16 batch id 1401 loss 0.010934336110949516 train acc 0.9853452890792291
epoch 16 batch id 1601 loss 0.003931816201657057 train acc 0.9855754216114928
epoch 16 batch id 1801 loss 0.006949355825781822 train acc 0.9856416574125486
epoch 16 batch id 2001 loss 0.022855915129184723 train acc 0.9857727386306847

epoch 16 train acc 0.9857876319412574


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 16 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 17 batch id 1 loss 0.015223503112792969 train acc 1.0
epoch 17 batch id 201 loss 0.07734429091215134 train acc 0.9870180348258707
epoch 17 batch id 401 loss 0.003234289586544037 train acc 0.9869077306733167
epoch 17 batch id 601 loss 0.007177378050982952 train acc 0.986662853577371
epoch 17 batch id 801 loss 0.015037070959806442 train acc 0.9870864544319601
epoch 17 batch id 1001 loss 0.0064973668195307255 train acc 0.9873251748251748
epoch 17 batch id 1201 loss 0.006410971283912659 train acc 0.9872111781848459
epoch 17 batch id 1401 loss 0.007248562760651112 train acc 0.9872635617416131
epoch 17 batch id 1601 loss 0.012115886434912682 train acc 0.9873809337913804
epoch 17 batch id 1801 loss 0.009412660263478756 train acc 0.9874982648528595
epoch 17 batch id 2001 loss 0.010394535958766937 train acc 0.9875921414292853

epoch 17 train acc 0.9874655805415328


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 17 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 18 batch id 1 loss 0.0028574944008141756 train acc 1.0
epoch 18 batch id 201 loss 0.0033608663361519575 train acc 0.9885727611940298
epoch 18 batch id 401 loss 0.0031584177631884813 train acc 0.9876870324189526
epoch 18 batch id 601 loss 0.002877212595194578 train acc 0.9876507903494176
epoch 18 batch id 801 loss 0.06775888055562973 train acc 0.9876911672908864
epoch 18 batch id 1001 loss 0.007554926909506321 train acc 0.988292957042957
epoch 18 batch id 1201 loss 0.014088032767176628 train acc 0.9882389675270608
epoch 18 batch id 1401 loss 0.06271469593048096 train acc 0.9885684332619558
epoch 18 batch id 1601 loss 0.049903277307748795 train acc 0.988747267332917
epoch 18 batch id 1801 loss 0.00366342649795115 train acc 0.9887128678511938
epoch 18 batch id 2001 loss 0.01551581546664238 train acc 0.988677536231884

epoch 18 train acc 0.9886200665442864


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 18 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 19 batch id 1 loss 0.0013852736447006464 train acc 1.0
epoch 19 batch id 201 loss 0.020034540444612503 train acc 0.9880286069651741
epoch 19 batch id 401 loss 0.02465527132153511 train acc 0.9881156483790524
epoch 19 batch id 601 loss 0.0023251958191394806 train acc 0.9886647254575707
epoch 19 batch id 801 loss 0.10688649117946625 train acc 0.9892322097378277
epoch 19 batch id 1001 loss 0.004281818401068449 train acc 0.9893231768231768
epoch 19 batch id 1201 loss 0.00811002403497696 train acc 0.9894098667776853
epoch 19 batch id 1401 loss 0.009018558077514172 train acc 0.9896390970735189
epoch 19 batch id 1601 loss 0.03905987739562988 train acc 0.9898500936914428
epoch 19 batch id 1801 loss 0.04634706303477287 train acc 0.9899361465852304
epoch 19 batch id 2001 loss 0.005630082916468382 train acc 0.9900049975012494

epoch 19 train acc 0.9899968448829738


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 19 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 20 batch id 1 loss 0.0020739620085805655 train acc 1.0
epoch 20 batch id 201 loss 0.004213545937091112 train acc 0.9899720149253731
epoch 20 batch id 401 loss 0.004050294402986765 train acc 0.9900639027431422
epoch 20 batch id 601 loss 0.0025070500560104847 train acc 0.9900686356073212
epoch 20 batch id 801 loss 0.025433406233787537 train acc 0.9902660736579276
epoch 20 batch id 1001 loss 0.019147612154483795 train acc 0.9906187562437563
epoch 20 batch id 1201 loss 0.007384947035461664 train acc 0.9905417360532889
epoch 20 batch id 1401 loss 0.002917373552918434 train acc 0.9907320663811563
epoch 20 batch id 1601 loss 0.03758469969034195 train acc 0.9908748438475953
epoch 20 batch id 1801 loss 0.001153155928477645 train acc 0.9909425319267073
epoch 20 batch id 2001 loss 0.008762332610785961 train acc 0.9910044977511244

epoch 20 train acc 0.9911011358421294


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 20 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 21 batch id 1 loss 0.001535283518023789 train acc 1.0
epoch 21 batch id 201 loss 0.00967868510633707 train acc 0.9915267412935324
epoch 21 batch id 401 loss 0.0015946095809340477 train acc 0.9913107855361596
epoch 21 batch id 601 loss 0.07451540976762772 train acc 0.9910825707154742
epoch 21 batch id 801 loss 0.025506949052214622 train acc 0.9911438826466916
epoch 21 batch id 1001 loss 0.030207108706235886 train acc 0.9916489760239761
epoch 21 batch id 1201 loss 0.03211531788110733 train acc 0.9914914654454621
epoch 21 batch id 1401 loss 0.0014359628548845649 train acc 0.9915350642398287
epoch 21 batch id 1601 loss 0.010416851378977299 train acc 0.9916458463460337
epoch 21 batch id 1801 loss 0.004608296323567629 train acc 0.9916539422543031
epoch 21 batch id 2001 loss 0.059804949909448624 train acc 0.9916916541729135

epoch 21 train acc 0.9917249885268472


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 21 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 22 batch id 1 loss 0.0012462002923712134 train acc 1.0
epoch 22 batch id 201 loss 0.002672388218343258 train acc 0.992148631840796
epoch 22 batch id 401 loss 0.0015729257138445973 train acc 0.9924407730673317
epoch 22 batch id 601 loss 0.03944787010550499 train acc 0.9924084858569051
epoch 22 batch id 801 loss 0.0024486558977514505 train acc 0.9924313358302123
epoch 22 batch id 1001 loss 0.0011479586828500032 train acc 0.9926323676323676
epoch 22 batch id 1201 loss 0.0016265443991869688 train acc 0.9924932348043297
epoch 22 batch id 1401 loss 0.0007927644182927907 train acc 0.992706102783726
epoch 22 batch id 1601 loss 0.0035845006350427866 train acc 0.9927486727045597
epoch 22 batch id 1801 loss 0.0009490923839621246 train acc 0.9926429761243754
epoch 22 batch id 2001 loss 0.02173277921974659 train acc 0.9926208770614693

epoch 22 train acc 0.9926069871500688


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 22 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 23 batch id 1 loss 0.0013854113640263677 train acc 1.0
epoch 23 batch id 201 loss 0.0014092159690335393 train acc 0.9927705223880597
epoch 23 batch id 401 loss 0.016380595043301582 train acc 0.9924018079800498
epoch 23 batch id 601 loss 0.003972676116973162 train acc 0.9919665141430949
epoch 23 batch id 801 loss 0.013212691061198711 train acc 0.9922167602996255
epoch 23 batch id 1001 loss 0.0016621070681139827 train acc 0.9926948051948052
epoch 23 batch id 1201 loss 0.12245899438858032 train acc 0.9926233347210658
epoch 23 batch id 1401 loss 0.0011150866048410535 train acc 0.9926949500356888
epoch 23 batch id 1601 loss 0.009471889585256577 train acc 0.9928365084322298
epoch 23 batch id 1801 loss 0.013611676171422005 train acc 0.9928598695169351
epoch 23 batch id 2001 loss 0.10975956171751022 train acc 0.9927926661669165

epoch 23 train acc 0.9927647430013767


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 23 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 24 batch id 1 loss 0.0008498056558892131 train acc 1.0
epoch 24 batch id 201 loss 0.018474554643034935 train acc 0.9926150497512438
epoch 24 batch id 401 loss 0.0013391821412369609 train acc 0.9926355985037406
epoch 24 batch id 601 loss 0.010936247184872627 train acc 0.9928244592346089
epoch 24 batch id 801 loss 0.0030405274592339993 train acc 0.9928995006242197
epoch 24 batch id 1001 loss 0.0009107948280870914 train acc 0.9932099150849151
epoch 24 batch id 1201 loss 0.018326377496123314 train acc 0.9930396544546212
epoch 24 batch id 1401 loss 0.005890007596462965 train acc 0.992996074232691
epoch 24 batch id 1601 loss 0.05669831484556198 train acc 0.9931097751405371
epoch 24 batch id 1801 loss 0.006409929133951664 train acc 0.9931982232093282
epoch 24 batch id 2001 loss 0.008485089987516403 train acc 0.9931596701649176

epoch 24 train acc 0.9932308398347865


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 24 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 25 batch id 1 loss 0.0011035238858312368 train acc 1.0
epoch 25 batch id 201 loss 0.0008124569430947304 train acc 0.9939365671641791
epoch 25 batch id 401 loss 0.0011447641300037503 train acc 0.9932980049875312
epoch 25 batch id 601 loss 0.005341087933629751 train acc 0.9931364392678869
epoch 25 batch id 801 loss 0.010990194045007229 train acc 0.9932311173533084
epoch 25 batch id 1001 loss 0.0008538949768990278 train acc 0.9934284465534465
epoch 25 batch id 1201 loss 0.0118870185688138 train acc 0.9935340341382182
epoch 25 batch id 1401 loss 0.1570751667022705 train acc 0.9934979478943612
epoch 25 batch id 1601 loss 0.02138819731771946 train acc 0.9935196752029981
epoch 25 batch id 1801 loss 0.024842269718647003 train acc 0.9935539283731261
epoch 25 batch id 2001 loss 0.008546793833374977 train acc 0.9934798225887056

epoch 25 train acc 0.9934459614502065


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 25 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 26 batch id 1 loss 0.0007226355373859406 train acc 1.0
epoch 26 batch id 201 loss 0.002518701832741499 train acc 0.9939365671641791
epoch 26 batch id 401 loss 0.0033181835897266865 train acc 0.9937266209476309
epoch 26 batch id 601 loss 0.005959045607596636 train acc 0.9936044093178037
epoch 26 batch id 801 loss 0.013550963252782822 train acc 0.9932506242197253
epoch 26 batch id 1001 loss 0.0018335239728912711 train acc 0.9935221028971029
epoch 26 batch id 1201 loss 0.006302074529230595 train acc 0.9936771440466278
epoch 26 batch id 1401 loss 0.09553298354148865 train acc 0.9937544610992148
epoch 26 batch id 1601 loss 0.015773648396134377 train acc 0.993773422860712
epoch 26 batch id 1801 loss 0.00128120556473732 train acc 0.9938055247084953
epoch 26 batch id 2001 loss 0.18598362803459167 train acc 0.9936281859070465

epoch 26 train acc 0.9936324001835705


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 26 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 27 batch id 1 loss 0.0006270776502788067 train acc 1.0
epoch 27 batch id 201 loss 0.07446005940437317 train acc 0.9938588308457711
epoch 27 batch id 401 loss 0.0006912374519743025 train acc 0.9932980049875312
epoch 27 batch id 601 loss 0.001055510831065476 train acc 0.9931364392678869
epoch 27 batch id 801 loss 0.03867262601852417 train acc 0.9930945692883895
epoch 27 batch id 1001 loss 0.0009674589964561164 train acc 0.9933347902097902
epoch 27 batch id 1201 loss 0.02673119679093361 train acc 0.9934169442131557
epoch 27 batch id 1401 loss 0.13527408242225647 train acc 0.9935537116345468
epoch 27 batch id 1601 loss 0.024507388472557068 train acc 0.9935587133041849
epoch 27 batch id 1801 loss 0.02165064588189125 train acc 0.9935626041088285
epoch 27 batch id 2001 loss 0.011639207601547241 train acc 0.9934720139930034

epoch 27 train acc 0.9934889857732905


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 27 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 28 batch id 1 loss 0.0005779717466793954 train acc 1.0
epoch 28 batch id 201 loss 0.07959127426147461 train acc 0.9935478855721394
epoch 28 batch id 401 loss 0.0007456211023963988 train acc 0.9930642144638404
epoch 28 batch id 601 loss 0.023279773071408272 train acc 0.9930844425956739
epoch 28 batch id 801 loss 0.0020544847939163446 train acc 0.9934066791510612
epoch 28 batch id 1001 loss 0.0006036842241883278 train acc 0.9936001498501499
epoch 28 batch id 1201 loss 0.011945932172238827 train acc 0.9936120940882598
epoch 28 batch id 1401 loss 0.048132095485925674 train acc 0.9937544610992148
epoch 28 batch id 1601 loss 0.01524168811738491 train acc 0.9937539038101186
epoch 28 batch id 1801 loss 0.0034359737765043974 train acc 0.9936493614658523
epoch 28 batch id 2001 loss 0.009691826067864895 train acc 0.9936672288855573

epoch 28 train acc 0.9937256195502524


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 28 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 29 batch id 1 loss 0.0007039100164547563 train acc 1.0
epoch 29 batch id 201 loss 0.01764519326388836 train acc 0.9936256218905473
epoch 29 batch id 401 loss 0.000551644538063556 train acc 0.9933759351620948
epoch 29 batch id 601 loss 0.008981674909591675 train acc 0.9930584442595674
epoch 29 batch id 801 loss 0.0018086571944877505 train acc 0.9934456928838952
epoch 29 batch id 1001 loss 0.0005105070886202157 train acc 0.9936938061938062
epoch 29 batch id 1201 loss 0.012225037440657616 train acc 0.9937291840133222
epoch 29 batch id 1401 loss 0.06413360685110092 train acc 0.9939440578158458
epoch 29 batch id 1601 loss 0.02267826348543167 train acc 0.9939393347907558
epoch 29 batch id 1801 loss 0.004432546440511942 train acc 0.9939530122154359
epoch 29 batch id 2001 loss 0.03234993293881416 train acc 0.9939483383308346

epoch 29 train acc 0.9939479118861864


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 29 test acc 1.0


HBox(children=(FloatProgress(value=0.0, max=2179.0), HTML(value='')))

epoch 30 batch id 1 loss 0.0005420205998234451 train acc 1.0
epoch 30 batch id 201 loss 0.0223513413220644 train acc 0.9941697761194029
epoch 30 batch id 401 loss 0.0005109423073008657 train acc 0.9938824812967582
epoch 30 batch id 601 loss 0.0033700866624712944 train acc 0.9936304076539102
epoch 30 batch id 801 loss 0.0024389263708144426 train acc 0.9937968164794008
epoch 30 batch id 1001 loss 0.0006548871169798076 train acc 0.9938186813186813
epoch 30 batch id 1201 loss 0.01670895516872406 train acc 0.9939243338884263
epoch 30 batch id 1401 loss 0.0039558419957757 train acc 0.9941336545324768
epoch 30 batch id 1601 loss 0.01473530288785696 train acc 0.9941638038725796
epoch 30 batch id 1801 loss 0.003709059441462159 train acc 0.9941612298722932
epoch 30 batch id 2001 loss 0.013545439578592777 train acc 0.9941669790104948

epoch 30 train acc 0.9941558627810922


HBox(children=(FloatProgress(value=0.0, max=545.0), HTML(value='')))


epoch 30 test acc 1.0


In [20]:
class BERTDataset1(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) 

        self.sentences = [transform([clean_text(i)]) for i in dataset['data']]
        self.labels = [np.int32(0) for i in dataset['data']]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))
        #return (self.sentences[i] + (0, ))

    def __len__(self):
        return (len(self.sentences))


test_set = BERTDataset1(test, 3, 2, tok, max_len, True, False)
test_input = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=2)

In [None]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

checkpoint = torch.load('/content/drive/MyDrive/Dacon/torchckpt/model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [22]:
# import torch, gc
# gc.collect()
# torch.cuda.empty_cache()

model.eval()


result = []

for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_input)):
  token_ids = token_ids.long().to(device)
  segment_ids = segment_ids.long().to(device)
  valid_length= valid_length
  label = label.long().to(device)
  out = model(token_ids, valid_length, segment_ids)
  for i in out:
      logits = i
      logits = logits.detach().cpu().numpy()
      final = np.argmax(logits)
  result.append(final)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  # Remove the CWD from sys.path while we load stuff.


HBox(children=(FloatProgress(value=0.0, max=43576.0), HTML(value='')))




In [23]:
print(result)

[0, 0, 0, 0, 1, 16, 0, 0, 0, 14, 0, 0, 23, 0, 0, 19, 0, 0, 0, 0, 27, 0, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 34, 0, 0, 1, 0, 5, 14, 0, 0, 0, 0, 0, 0, 43, 19, 0, 0, 0, 0, 0, 0, 43, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 18, 0, 0, 0, 34, 0, 0, 0, 0, 43, 0, 0, 0, 43, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 27, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 19, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 29, 0, 8, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 23, 0, 19, 0, 19, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 19, 0, 0, 0, 0, 10, 0, 0, 20, 0, 0, 45, 0, 0, 0, 45, 0, 36, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 45, 2, 0, 0, 0, 14, 0, 0, 24, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 14, 0, 8, 0, 0, 18, 0, 0, 0, 0, 0, 0, 5, 0, 23, 0, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0,

In [26]:
submission = pd.read_csv('/content/drive/MyDrive/Dacon/sample_submission.csv')

submission['label'] = np.array(result)

submission.to_csv('/content/drive/MyDrive/Dacon/kobert_baseline.csv', index=False)