In [1]:
!pip install mxnet-cu101
!pip install gluonnlp pandas tqdm
!pip install sentencepiece==0.1.85
!pip install transformers==2.1.1
!pip install torch==1.3.1



In [2]:
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting git+https://****@github.com/SKTBrain/KoBERT.git@master
  Cloning https://****@github.com/SKTBrain/KoBERT.git (to revision master) to /tmp/pip-req-build-3j67522r
  Running command git clone -q 'https://****@github.com/SKTBrain/KoBERT.git' /tmp/pip-req-build-3j67522r
Building wheels for collected packages: kobert
  Building wheel for kobert (setup.py) ... [?25l[?25hdone
  Created wheel for kobert: filename=kobert-0.1.1-cp36-none-any.whl size=12820 sha256=f6ee4400c65e465cc7ee83b6b02b096151f36a58d74106ac51660817f92eb046
  Stored in directory: /tmp/pip-ephem-wheel-cache-o0cnxmem/wheels/a2/b0/41/435ee4e918f91918be41529283c5ff86cd010f02e7525aecf3
Successfully built kobert


In [3]:
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [4]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
import time

from tqdm import tqdm, tqdm_notebook

In [5]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [6]:
from transformers import AdamW
from transformers.optimization import WarmupLinearSchedule

In [7]:
##GPU 사용 시
device = torch.device("cuda:0")

In [8]:
bertmodel, vocab = get_pytorch_kobert_model()

using cached model
using cached model


In [9]:
dataset_train = nlp.data.TSVDataset("/gdrive/My Drive/Colab Notebooks/integrated_data/rand_train_data.txt", field_indices=[1,2,3], num_discard_samples=1)
dataset_test = nlp.data.TSVDataset('/gdrive/My Drive/Colab Notebooks/integrated_data/rand_test_data.txt', field_indices=[1,2,3], num_discard_samples=1)

In [10]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [11]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, rating_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.ratings = [np.int32(i[rating_idx]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]
       

    def __getitem__(self, i):
        return (self.sentences[i] + (self.ratings[i], ) + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))


In [12]:
## Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 30
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [13]:
data_train = BERTDataset(dataset_train, 0, 1, 2, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, 2, tok, max_len, True, False)

In [14]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [15]:
origindataset_test = nlp.data.TSVDataset('/gdrive/My Drive/Colab Notebooks/origin_data/edit_ratings_test.txt', field_indices=[1,2,3], num_discard_samples=1)
origin_test = BERTDataset(origindataset_test, 0, 1, 2, tok, max_len, True, False) ## 전처리 잘못해서 두개 위치가 다름..

In [16]:
origin_test_dataloader = torch.utils.data.DataLoader(origin_test, batch_size=batch_size, num_workers=5)

In [17]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        np_classifier = self.classifier(out)
        return out, np_classifier

In [18]:
class new_bert(nn.Module):
  def __init__(self, model, hidden_size = 768):
    super(new_bert, self).__init__()
    self.main_model = model
    self.fc1 = nn.Linear(hidden_size, 512)
    self.fc2 = nn.Linear(512, 11)

  def forward(self, token_ids, valid_length, segment_ids):
    x, np_cls = self.main_model(token_ids, valid_length, segment_ids)
    x = self.fc1(x)
    x = F.gelu(x)
    rating_pred = self.fc2(x)
    return rating_pred, np_cls

In [19]:
old_model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)
continue_train = True
start_point = 8
if continue_train:
  old_model.load_state_dict(torch.load('/gdrive/My Drive/nlp/model_'+str(start_point)))
  print("load model")

load model


In [20]:

model = new_bert(old_model).to(device)
for param in model.main_model.parameters():
  param.requires_grad = False
  print(param)

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
        -8.3703e-02,  2.0324e-02,  4.4611e-03,  6.0379e-02,  7.5472e-02,
         3.3576e-02, -5.0742e-02,  6.0371e-02, -7.4994e-04,  3.5519e-02,
         7.3245e-02,  2.2619e-02, -1.5524e-02,  5.1840e-03,  3.0499e-02,
        -4.1665e-02, -1.5196e-02,  7.0477e-02, -9.8715e-02,  1.1301e-01,
         2.9148e-02, -1.8262e-03,  5.3836e-02,  8.1751e-03, -3.7369e-02,
        -4.2939e-02,  7.3634e-02, -1.9948e-02,  1.7513e-02, -2.7256e-03,
         2.6411e-02, -2.5554e-02, -1.2968e-02, -3.8595e-02,  6.5910e-02,
         3.7191e-02, -1.4987e-02, -3.0821e-02,  6.2824e-02, -2.5902e-02,
        -3.5882e-02,  2.9081e-02,  7.8296e-02, -6.2917e-02, -6.0682e-03,
        -2.1960e-02, -4.0536e-02,  8.4482e-02, -1.0291e-01, -8.7210e-03,
        -2.6455e-02,  5.2199e-03, -4.6123e-02, -3.5413e-02, -3.6678e-04,
        -1.5286e-02,  4.5764e-03,  9.9232e-03, -8.5327e-03,  3.4890e-02,
         2.7781e-02, -4.0610e-05,  2.0502e-02, -3.0779e-03,  4.1487e-02,
 

In [13]:
"""
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)
continue_train = True
start_point = 8
if continue_train:
  model.load_state_dict(torch.load('/gdrive/My Drive/nlp/model_'+str(start_point)))
  print("load model")
"""

load model


In [21]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [22]:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [23]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [24]:
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)

In [25]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [None]:
for e in range(start_point,num_epochs):
    train_acc_1 = 0.0
    train_acc_2 = 0.0
    test_acc_1 = 0.0
    test_acc_2 = 0.0
  
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, rating, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        rating = rating.long().to(device)
        label = label.long().to(device)
       
        rating_pred, np_cls = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(rating_pred, rating)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc_1 += calc_accuracy(rating_pred, rating)
        train_acc_2 += calc_accuracy(np_cls, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train_1 acc {} train_2 acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc_1 / (batch_id+1),  train_acc_2 / (batch_id+1)))
    print("epoch {} train_1 acc {} train_2 acc {}".format(e+1, train_acc_1 / (batch_id+1), train_acc_2 / (batch_id+1)))
    model.eval()
    torch.save(model.state_dict(), '/gdrive/My Drive/Colab Notebooks/freeze_model_save/model_'+str(e+1))
    for batch_id, (token_ids, valid_length, segment_ids, rating, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        rating = rating.long().to(device)
        label = label.long().to(device)
        rating_pred, np_cls = model(token_ids, valid_length, segment_ids)
        test_acc_1 += calc_accuracy(rating_pred, rating)
        test_acc_2 += calc_accuracy(np_cls, label)
    print("epoch {} test_1 acc {} test_2 acc {}".format(e+1, test_acc_1 / (batch_id+1), test_acc_2 / (batch_id+1)))
  
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 9 batch id 1 loss 2.4307363033294678 train_1 acc 0.0625 train_2 acc 0.046875
epoch 9 batch id 201 loss 2.368537664413452 train_1 acc 0.0896299751243781 train_2 acc 0.049440298507462684
epoch 9 batch id 401 loss 2.2726821899414062 train_1 acc 0.1241817331670823 train_2 acc 0.05026496259351621
epoch 9 batch id 601 loss 2.141249418258667 train_1 acc 0.14205490848585692 train_2 acc 0.050774750415973374
epoch 9 batch id 801 loss 2.0875203609466553 train_1 acc 0.15113920099875156 train_2 acc 0.05140059300873907
epoch 9 batch id 1001 loss 2.124476909637451 train_1 acc 0.1580919080919081 train_2 acc 0.051027097902097904
epoch 9 batch id 1201 loss 2.174044132232666 train_1 acc 0.16292412572855952 train_2 acc 0.050934117402164865
epoch 9 batch id 1401 loss 2.092926025390625 train_1 acc 0.166700124910778 train_2 acc 0.050901142041399
epoch 9 batch id 1601 loss 2.0795934200286865 train_1 acc 0.16894714241099312 train_2 acc 0.051022798251093064
epoch 9 batch id 1801 loss 2.041609287261963 tra

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 9 test_1 acc 0.20432308306709265 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 10 batch id 1 loss 1.8995561599731445 train_1 acc 0.140625 train_2 acc 0.0
epoch 10 batch id 201 loss 2.078573703765869 train_1 acc 0.19340796019900497 train_2 acc 0.051694651741293535
epoch 10 batch id 401 loss 1.9548087120056152 train_1 acc 0.19073410224438903 train_2 acc 0.05084943890274314
epoch 10 batch id 601 loss 2.0180912017822266 train_1 acc 0.19028182196339435 train_2 acc 0.050930740432612316
epoch 10 batch id 801 loss 1.9560022354125977 train_1 acc 0.1918305243445693 train_2 acc 0.05102996254681648
epoch 10 batch id 1001 loss 2.001473903656006 train_1 acc 0.19185501998001997 train_2 acc 0.05098026973026973
epoch 10 batch id 1201 loss 2.2401626110076904 train_1 acc 0.19178028726061616 train_2 acc 0.050400707743547046
epoch 10 batch id 1401 loss 2.132573366165161 train_1 acc 0.1909238936473947 train_2 acc 0.05032119914346895
epoch 10 batch id 1601 loss 2.0186238288879395 train_1 acc 0.1914428482198626 train_2 acc 0.05053482198625859
epoch 10 batch id 1801 loss 1.86495864

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 10 test_1 acc 0.21590455271565495 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 11 batch id 1 loss 2.05789852142334 train_1 acc 0.171875 train_2 acc 0.046875
epoch 11 batch id 201 loss 1.951984167098999 train_1 acc 0.1993936567164179 train_2 acc 0.048118781094527364
epoch 11 batch id 401 loss 2.0691845417022705 train_1 acc 0.19392923940149626 train_2 acc 0.05022599750623442
epoch 11 batch id 601 loss 2.0036840438842773 train_1 acc 0.19561148086522462 train_2 acc 0.05051476705490848
epoch 11 batch id 801 loss 1.885964035987854 train_1 acc 0.19699984394506867 train_2 acc 0.05093242821473159
epoch 11 batch id 1001 loss 1.957834243774414 train_1 acc 0.1962880869130869 train_2 acc 0.05085539460539461
epoch 11 batch id 1201 loss 1.8702197074890137 train_1 acc 0.19663301415487094 train_2 acc 0.05045274771024146
epoch 11 batch id 1401 loss 2.0255424976348877 train_1 acc 0.19644450392576732 train_2 acc 0.05056655960028551
epoch 11 batch id 1601 loss 1.9987037181854248 train_1 acc 0.1961957370393504 train_2 acc 0.05069097439100562
epoch 11 batch id 1801 loss 2.0277535

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 11 test_1 acc 0.21146166134185304 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 12 batch id 1 loss 2.0424182415008545 train_1 acc 0.1875 train_2 acc 0.0625
epoch 12 batch id 201 loss 2.0451724529266357 train_1 acc 0.19185323383084577 train_2 acc 0.052394278606965175
epoch 12 batch id 401 loss 1.9572877883911133 train_1 acc 0.1952150872817955 train_2 acc 0.050537718204488775
epoch 12 batch id 601 loss 2.1207218170166016 train_1 acc 0.19678140599001664 train_2 acc 0.050124792013311145
epoch 12 batch id 801 loss 1.961578130722046 train_1 acc 0.19561485642946316 train_2 acc 0.05062031835205993
epoch 12 batch id 1001 loss 1.9531315565109253 train_1 acc 0.19539835164835165 train_2 acc 0.0501998001998002
epoch 12 batch id 1201 loss 1.853119134902954 train_1 acc 0.19563124479600333 train_2 acc 0.05016652789342215
epoch 12 batch id 1401 loss 2.0628366470336914 train_1 acc 0.1969798358315489 train_2 acc 0.05024312990720914
epoch 12 batch id 1601 loss 2.064058780670166 train_1 acc 0.19645924422236102 train_2 acc 0.05048602435977514
epoch 12 batch id 1801 loss 1.9604300

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 12 test_1 acc 0.21665335463258786 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 13 batch id 1 loss 1.9818390607833862 train_1 acc 0.15625 train_2 acc 0.046875
epoch 13 batch id 201 loss 2.000199317932129 train_1 acc 0.1943407960199005 train_2 acc 0.049906716417910446
epoch 13 batch id 401 loss 2.1150479316711426 train_1 acc 0.19299407730673318 train_2 acc 0.050187032418952615
epoch 13 batch id 601 loss 1.967086911201477 train_1 acc 0.19119176372712146 train_2 acc 0.050748752079866885
epoch 13 batch id 801 loss 2.076580762863159 train_1 acc 0.1933715667915106 train_2 acc 0.049996098626716605
epoch 13 batch id 1001 loss 1.9115948677062988 train_1 acc 0.19324425574425574 train_2 acc 0.05027784715284715
epoch 13 batch id 1201 loss 1.9649653434753418 train_1 acc 0.19427820566194837 train_2 acc 0.05043973771856786
epoch 13 batch id 1401 loss 2.0235273838043213 train_1 acc 0.19496118843683083 train_2 acc 0.05038811563169165
epoch 13 batch id 1601 loss 1.898380994796753 train_1 acc 0.19568824172392255 train_2 acc 0.05048602435977514
epoch 13 batch id 1801 loss 1.908

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 13 test_1 acc 0.22304313099041534 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 14 batch id 1 loss 1.9368128776550293 train_1 acc 0.15625 train_2 acc 0.046875
epoch 14 batch id 201 loss 2.0347228050231934 train_1 acc 0.1986162935323383 train_2 acc 0.05076181592039801
epoch 14 batch id 401 loss 1.9203156232833862 train_1 acc 0.2003584788029925 train_2 acc 0.052096321695760596
epoch 14 batch id 601 loss 2.060039520263672 train_1 acc 0.20055116472545756 train_2 acc 0.05228265391014975
epoch 14 batch id 801 loss 1.9169577360153198 train_1 acc 0.20049157303370788 train_2 acc 0.05247347066167291
epoch 14 batch id 1001 loss 1.9517823457717896 train_1 acc 0.1996441058941059 train_2 acc 0.05160464535464535
epoch 14 batch id 1201 loss 2.1186411380767822 train_1 acc 0.1989227726894255 train_2 acc 0.05146752706078268
epoch 14 batch id 1401 loss 2.0776357650756836 train_1 acc 0.1984185403283369 train_2 acc 0.05146993219129194
epoch 14 batch id 1601 loss 2.090657949447632 train_1 acc 0.1981671611492817 train_2 acc 0.051422938788257336
epoch 14 batch id 1801 loss 2.0229568

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 14 test_1 acc 0.21260982428115016 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 15 batch id 1 loss 2.0085318088531494 train_1 acc 0.140625 train_2 acc 0.0
epoch 15 batch id 201 loss 2.049593210220337 train_1 acc 0.2015702736318408 train_2 acc 0.04936256218905473
epoch 15 batch id 401 loss 2.0961596965789795 train_1 acc 0.1961892144638404 train_2 acc 0.04971945137157107
epoch 15 batch id 601 loss 2.0101687908172607 train_1 acc 0.1971453826955075 train_2 acc 0.049786813643926786
epoch 15 batch id 801 loss 2.0380964279174805 train_1 acc 0.19818976279650438 train_2 acc 0.05058130461922597
epoch 15 batch id 1001 loss 1.9569405317306519 train_1 acc 0.19675636863136864 train_2 acc 0.05091783216783217
epoch 15 batch id 1201 loss 1.8571687936782837 train_1 acc 0.19767381348875937 train_2 acc 0.05105120732722731
epoch 15 batch id 1401 loss 1.920436143875122 train_1 acc 0.19750401498929337 train_2 acc 0.0509680585296217
epoch 15 batch id 1601 loss 2.0232889652252197 train_1 acc 0.19710337289194255 train_2 acc 0.05084712679575266
epoch 15 batch id 1801 loss 1.8546286821

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 15 test_1 acc 0.222444089456869 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 16 batch id 1 loss 2.2142436504364014 train_1 acc 0.1875 train_2 acc 0.078125
epoch 16 batch id 201 loss 1.8801426887512207 train_1 acc 0.19208644278606965 train_2 acc 0.05006218905472637
epoch 16 batch id 401 loss 2.0026979446411133 train_1 acc 0.19451371571072318 train_2 acc 0.050031172069825436
epoch 16 batch id 601 loss 2.1315133571624756 train_1 acc 0.19584546589018303 train_2 acc 0.04989080698835274
epoch 16 batch id 801 loss 2.0613510608673096 train_1 acc 0.19692181647940074 train_2 acc 0.05071785268414482
epoch 16 batch id 1001 loss 2.020188093185425 train_1 acc 0.19770854145854147 train_2 acc 0.05065247252747253
epoch 16 batch id 1201 loss 1.980172872543335 train_1 acc 0.19766080349708576 train_2 acc 0.050843047460449625
epoch 16 batch id 1401 loss 1.9008078575134277 train_1 acc 0.1971359743040685 train_2 acc 0.05064462883654532
epoch 16 batch id 1601 loss 2.0216774940490723 train_1 acc 0.19744495627732667 train_2 acc 0.05084712679575266
epoch 16 batch id 1801 loss 2.036

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 16 test_1 acc 0.21680311501597443 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 17 batch id 1 loss 1.9265596866607666 train_1 acc 0.15625 train_2 acc 0.03125
epoch 17 batch id 201 loss 1.9324111938476562 train_1 acc 0.1993936567164179 train_2 acc 0.05006218905472637
epoch 17 batch id 401 loss 1.9527020454406738 train_1 acc 0.19918952618453864 train_2 acc 0.050810473815461346
epoch 17 batch id 601 loss 1.931112289428711 train_1 acc 0.20055116472545756 train_2 acc 0.050982737104825294
epoch 17 batch id 801 loss 1.9063222408294678 train_1 acc 0.20051107990012484 train_2 acc 0.050951935081148564
epoch 17 batch id 1001 loss 1.8603237867355347 train_1 acc 0.19942557442557443 train_2 acc 0.050699300699300696
epoch 17 batch id 1201 loss 1.923812747001648 train_1 acc 0.1982332431307244 train_2 acc 0.050843047460449625
epoch 17 batch id 1401 loss 2.062765598297119 train_1 acc 0.19748170949321914 train_2 acc 0.051135349750178444
epoch 17 batch id 1601 loss 2.084160327911377 train_1 acc 0.1985282635852592 train_2 acc 0.051383900687070584
epoch 17 batch id 1801 loss 1.99

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 17 test_1 acc 0.21315894568690097 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 18 batch id 1 loss 1.931749701499939 train_1 acc 0.265625 train_2 acc 0.015625
epoch 18 batch id 201 loss 2.088292360305786 train_1 acc 0.2025031094527363 train_2 acc 0.05262748756218905
epoch 18 batch id 401 loss 2.157688856124878 train_1 acc 0.20012468827930174 train_2 acc 0.050888403990024936
epoch 18 batch id 601 loss 1.8901671171188354 train_1 acc 0.19815931780366056 train_2 acc 0.05267262895174709
epoch 18 batch id 801 loss 1.9323116540908813 train_1 acc 0.1990675717852684 train_2 acc 0.052375936329588015
epoch 18 batch id 1001 loss 1.8671799898147583 train_1 acc 0.19814560439560439 train_2 acc 0.051901223776223776
epoch 18 batch id 1201 loss 1.9406116008758545 train_1 acc 0.19690622398001664 train_2 acc 0.05194889675270608
epoch 18 batch id 1401 loss 2.011737585067749 train_1 acc 0.19626605995717344 train_2 acc 0.05182682012847966
epoch 18 batch id 1601 loss 2.0831305980682373 train_1 acc 0.19612742036227357 train_2 acc 0.051676686445971266
epoch 18 batch id 1801 loss 1.98

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 18 test_1 acc 0.2191493610223642 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 19 batch id 1 loss 2.1375579833984375 train_1 acc 0.15625 train_2 acc 0.078125
epoch 19 batch id 201 loss 1.9728219509124756 train_1 acc 0.1978389303482587 train_2 acc 0.05480410447761194
epoch 19 batch id 401 loss 2.129586696624756 train_1 acc 0.19930642144638405 train_2 acc 0.05314837905236908
epoch 19 batch id 601 loss 1.9868892431259155 train_1 acc 0.19917325291181365 train_2 acc 0.05173668885191348
epoch 19 batch id 801 loss 1.945776104927063 train_1 acc 0.19830680399500625 train_2 acc 0.051868757802746565
epoch 19 batch id 1001 loss 2.078188419342041 train_1 acc 0.20003434065934067 train_2 acc 0.05140172327672328
epoch 19 batch id 1201 loss 1.950484275817871 train_1 acc 0.19919598251457118 train_2 acc 0.05063488759367194
epoch 19 batch id 1401 loss 2.0640876293182373 train_1 acc 0.19779398643825838 train_2 acc 0.05094575303354747
epoch 19 batch id 1601 loss 2.0205013751983643 train_1 acc 0.19831355402873205 train_2 acc 0.05075929106808245
epoch 19 batch id 1801 loss 1.91639

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 19 test_1 acc 0.21260982428115016 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 20 batch id 1 loss 1.9415745735168457 train_1 acc 0.140625 train_2 acc 0.0625
epoch 20 batch id 201 loss 1.9141672849655151 train_1 acc 0.1978389303482587 train_2 acc 0.05464863184079602
epoch 20 batch id 401 loss 1.8971225023269653 train_1 acc 0.20074812967581046 train_2 acc 0.05334320448877806
epoch 20 batch id 601 loss 1.982966661453247 train_1 acc 0.20107113144758734 train_2 acc 0.05113872712146423
epoch 20 batch id 801 loss 2.075005054473877 train_1 acc 0.19998439450686642 train_2 acc 0.05169319600499376
epoch 20 batch id 1001 loss 2.1062965393066406 train_1 acc 0.1993006993006993 train_2 acc 0.05126123876123876
epoch 20 batch id 1201 loss 1.9902384281158447 train_1 acc 0.1989878226477935 train_2 acc 0.051181307243963366
epoch 20 batch id 1401 loss 2.0600829124450684 train_1 acc 0.19926614917915775 train_2 acc 0.05124687723054961
epoch 20 batch id 1601 loss 1.9662666320800781 train_1 acc 0.19934806371018113 train_2 acc 0.051481495940037474
epoch 20 batch id 1801 loss 2.07645

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 20 test_1 acc 0.22119608626198084 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 21 batch id 1 loss 1.979065179824829 train_1 acc 0.234375 train_2 acc 0.0625
epoch 21 batch id 201 loss 2.1155812740325928 train_1 acc 0.1948849502487562 train_2 acc 0.05472636815920398
epoch 21 batch id 401 loss 1.9013484716415405 train_1 acc 0.19981296758104738 train_2 acc 0.05303148379052369
epoch 21 batch id 601 loss 2.0590174198150635 train_1 acc 0.19938123960066556 train_2 acc 0.05204866888519135
epoch 21 batch id 801 loss 1.9498004913330078 train_1 acc 0.19949672284644196 train_2 acc 0.05153714107365793
epoch 21 batch id 1001 loss 1.8126894235610962 train_1 acc 0.200487012987013 train_2 acc 0.0505275974025974
epoch 21 batch id 1201 loss 2.0030264854431152 train_1 acc 0.1997163821815154 train_2 acc 0.05105120732722731
epoch 21 batch id 1401 loss 1.9869900941848755 train_1 acc 0.1997791755888651 train_2 acc 0.051068433261955745
epoch 21 batch id 1601 loss 1.8939770460128784 train_1 acc 0.19864537788881947 train_2 acc 0.05130582448469707
epoch 21 batch id 1801 loss 2.00750017

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 21 test_1 acc 0.2147064696485623 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 22 batch id 1 loss 2.019176721572876 train_1 acc 0.28125 train_2 acc 0.0625
epoch 22 batch id 201 loss 1.999448537826538 train_1 acc 0.2068563432835821 train_2 acc 0.05068407960199005
epoch 22 batch id 401 loss 2.0211520195007324 train_1 acc 0.20031951371571072 train_2 acc 0.050031172069825436
epoch 22 batch id 601 loss 1.9093433618545532 train_1 acc 0.19956322795341097 train_2 acc 0.05069675540765391
epoch 22 batch id 801 loss 1.9580702781677246 train_1 acc 0.19844335205992508 train_2 acc 0.05087390761548065
epoch 22 batch id 1001 loss 1.9642534255981445 train_1 acc 0.19819243256743257 train_2 acc 0.05118319180819181
epoch 22 batch id 1201 loss 1.868367314338684 train_1 acc 0.19876665278934222 train_2 acc 0.05116829725228976
epoch 22 batch id 1401 loss 2.0624120235443115 train_1 acc 0.19871966452533904 train_2 acc 0.05086768379728765
epoch 22 batch id 1601 loss 2.0219292640686035 train_1 acc 0.1992992660836977 train_2 acc 0.050612898188632106
epoch 22 batch id 1801 loss 2.014967

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 22 test_1 acc 0.22109624600638977 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 23 batch id 1 loss 1.8114372491836548 train_1 acc 0.328125 train_2 acc 0.015625
epoch 23 batch id 201 loss 2.102952480316162 train_1 acc 0.20063743781094528 train_2 acc 0.051694651741293535
epoch 23 batch id 401 loss 2.0212855339050293 train_1 acc 0.20016365336658354 train_2 acc 0.051667705735660846
epoch 23 batch id 601 loss 2.022690534591675 train_1 acc 0.1989912645590682 train_2 acc 0.05202267054908486
epoch 23 batch id 801 loss 1.9659802913665771 train_1 acc 0.20031601123595505 train_2 acc 0.05163467540574282
epoch 23 batch id 1001 loss 2.0717391967773438 train_1 acc 0.2001436063936064 train_2 acc 0.05126123876123876
epoch 23 batch id 1201 loss 2.068775177001953 train_1 acc 0.20000260199833472 train_2 acc 0.05067391756869276
epoch 23 batch id 1401 loss 2.0001375675201416 train_1 acc 0.2000468415417559 train_2 acc 0.050343504639543186
epoch 23 batch id 1601 loss 2.0185091495513916 train_1 acc 0.20005074953154278 train_2 acc 0.05046650530918176
epoch 23 batch id 1801 loss 1.880

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 23 test_1 acc 0.22369209265175719 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 24 batch id 1 loss 2.0924084186553955 train_1 acc 0.1875 train_2 acc 0.09375
epoch 24 batch id 201 loss 1.9658490419387817 train_1 acc 0.20607898009950248 train_2 acc 0.049828980099502485
epoch 24 batch id 401 loss 2.0155069828033447 train_1 acc 0.20094295511221946 train_2 acc 0.05069357855361596
epoch 24 batch id 601 loss 2.0730695724487305 train_1 acc 0.20107113144758734 train_2 acc 0.050748752079866885
epoch 24 batch id 801 loss 2.038428544998169 train_1 acc 0.2017595193508115 train_2 acc 0.051205524344569285
epoch 24 batch id 1001 loss 1.997470498085022 train_1 acc 0.20189185814185814 train_2 acc 0.05193244255744256
epoch 24 batch id 1201 loss 1.8628227710723877 train_1 acc 0.20160283097418819 train_2 acc 0.05154558701082431
epoch 24 batch id 1401 loss 2.0034592151641846 train_1 acc 0.20116211634546752 train_2 acc 0.05180451463240542
epoch 24 batch id 1601 loss 1.89228355884552 train_1 acc 0.2005875234228607 train_2 acc 0.051676686445971266
epoch 24 batch id 1801 loss 1.94345

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 24 test_1 acc 0.21735223642172524 test_2 acc 0.037689696485623


HBox(children=(FloatProgress(value=0.0, max=2188.0), HTML(value='')))

epoch 25 batch id 1 loss 1.8602238893508911 train_1 acc 0.21875 train_2 acc 0.015625
epoch 25 batch id 201 loss 1.9787602424621582 train_1 acc 0.20024875621890548 train_2 acc 0.052782960199004976
epoch 25 batch id 401 loss 1.959503412246704 train_1 acc 0.20191708229426433 train_2 acc 0.05283665835411471
epoch 25 batch id 601 loss 2.0397467613220215 train_1 acc 0.19875727953410982 train_2 acc 0.05202267054908486
epoch 25 batch id 801 loss 1.9645236730575562 train_1 acc 0.19969179151061173 train_2 acc 0.051361579275905116
epoch 25 batch id 1001 loss 1.9063388109207153 train_1 acc 0.19998751248751248 train_2 acc 0.051292457542457544


In [24]:
! pip install pandas



In [26]:
import pandas as pd
start = pd.read_csv('/gdrive/My Drive/Colab Notebooks/origin_data/ratings_test.txt', sep="\t")
start['rating'] = 1

In [27]:
print(start)

            id  ... rating
0      6270596  ...      1
1      9274899  ...      1
2      8544678  ...      1
3      6825595  ...      1
4      6723715  ...      1
...        ...  ...    ...
49995  4608761  ...      1
49996  5308387  ...      1
49997  9072549  ...      1
49998  5802125  ...      1
49999  6070594  ...      1

[50000 rows x 4 columns]


In [30]:
start.to_csv('/gdrive/My Drive/Colab Notebooks/origin_data/edit_ratings_test.txt',index = False, sep="\t")

NameError: ignored

In [77]:
dataset_test = nlp.data.TSVDataset('/gdrive/My Drive/Colab Notebooks/origin_data/edit_ratings_test.txt', field_indices=[1,2,3], num_discard_samples=1)
data_test = BERTDataset(dataset_test, 0, 1, 2, tok, max_len, True, False)

In [36]:
torch.save(model.state_dict(), '/gdrive/My Drive/Colab Notebooks/model_save/model_30')
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [78]:
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [79]:
test_acc_3 = 0.0
for batch_id, (token_ids, valid_length, segment_ids, label, rating) in enumerate(tqdm_notebook(test_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    rating = rating.long().to(device)
    label = label.long().to(device)
    out_1, out_2 = model(token_ids, valid_length, segment_ids)
    test_acc_3 += calc_accuracy(out_2, label)
print("epoch {} test_3 acc {}".format(1, test_acc_3 / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 1 test_3 acc 0.8959998401534527


In [None]:
function ClickConnect(){
    console.log("코랩 연결 끊김 방지"); 
    document.querySelector("colab-toolbar-button#connect").click() 
}
setInterval(ClickConnect, 60 * 1000)

In [41]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        np_classifier = self.classifier(out)
        return out, np_classifier

In [42]:
!wget https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1

--2020-08-30 16:35:07--  https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.1, 2620:100:6018:1::a27d:301
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/977gbwh542gdy94/ratings_test.txt [following]
--2020-08-30 16:35:07--  https://www.dropbox.com/s/dl/977gbwh542gdy94/ratings_test.txt
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc58866a943685d51dc0f6bceb12.dl.dropboxusercontent.com/cd/0/get/A-cxRRdXVZNHs31EyUDfHo9PoRzJ3vooKXgI6uG6boPbU8mUzm-ejnFn_KXOnumzMsvy5wCDKsVPzvW1fpe_aXjABX8KpW3QJRz1H1YqPAXZ-xaTkOtWSuEcRRrSAtKEufc/file?dl=1# [following]
--2020-08-30 16:35:07--  https://uc58866a943685d51dc0f6bceb12.dl.dropboxusercontent.com/cd/0/get/A-cxRRdXVZNHs31EyUDfHo9PoRzJ3vooKXgI6uG6boPbU8mUzm-ejnFn_KXOnumzMsvy5wCDKsVPzvW1fpe_aXjABX8KpW3QJRz1

In [47]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [48]:
dataset_test = nlp.data.TSVDataset("ratings_test.txt?dl=1", field_indices=[1,2], num_discard_samples=1)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [49]:
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)
continue_train = True
start_point = 8

if continue_train:
  model.load_state_dict(torch.load('/gdrive/My Drive/nlp/model_'+str(start_point)))
  print("load model")

load model


In [61]:
dataset_test = nlp.data.TSVDataset('/gdrive/My Drive/Colab Notebooks/origin_data/edit_ratings_test.txt', field_indices=[1,2,3], num_discard_samples=1)
data_test = BERTDataset(dataset_test, 0, 1, 2, tok, max_len, True, False)

In [63]:
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [67]:
test_acc = 0.0 # rating, label 통합본.
for batch_id, (token_ids, valid_length, segment_ids, label, rating) in enumerate(tqdm_notebook(test_dataloader)):
    model.eval()
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    bert_weight, np_classifier = model(token_ids, valid_length, segment_ids)
    test_acc += calc_accuracy(np_classifier, label)
print("epoch {} test acc {}".format(1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


epoch 1 test acc 0.037689696485623
