# KoBERT finetuning

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install torch
!pip install transformers

In [None]:
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm.notebook import tqdm

In [5]:
from tqdm import tqdm_notebook

In [6]:
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model

In [22]:
from transformers import AutoTokenizer, AutoModel

In [7]:
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [8]:
## CPU
# device = torch.device("cpu")

## GPU
device = torch.device("cuda:0")

In [9]:
bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")
# bertmodel = AutoModel.from_pretrained("beomi/KcELECTRA-base")

/content/.cache/kobert_v1.zip[██████████████████████████████████████████████████]
/content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece[██████████████████████████████████████████████████]


In [10]:
import pandas as pd
train_df = pd.read_csv("/content/drive/MyDrive/CUAI_Winter/train_.tsv", delimiter='\t')
valid_df = pd.read_csv("/content/drive/MyDrive/CUAI_Winter/valid_.tsv", delimiter='\t')

In [11]:
def tokens_to_sentence(tokens):
  return tokens.replace("'", '')[1:-1].split(', ')

In [12]:
dataset_train = nlp.data.TSVDataset("/content/drive/MyDrive/CUAI_Winter/train_.tsv", field_indices=[1,2], num_discard_samples=1)
dataset_valid = nlp.data.TSVDataset("/content/drive/MyDrive/CUAI_Winter/valid_.tsv", field_indices=[1,2], num_discard_samples=1)

In [13]:
print(dataset_train[0])

['아내 출산 되 신 나', '0']


In [None]:
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
# tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
# tok = get_tokenizer() # vocab path

In [16]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len,
                 pad, pair):
        self.transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, vocab=vocab, max_seq_length=max_len, pad=pad, pair=pair)
        # self.tokenizer = bert_tokenizer
        self.sentences = [self.transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

    def clean(self, x):
        emojis = ''.join(emoji.UNICODE_EMOJI.keys())
        pattern = re.compile(f'[^ .,?!/@$%~％·∼()\x00-\x7Fㄱ-힣{emojis}]+')
        url_pattern = re.compile(
            r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)')
        x = pattern.sub(' ', x)
        x = url_pattern.sub('', x)
        x = x.strip()
        x = repeat_normalize(x, num_repeats=2)
        return x
        
    def encode(self, x, **kwargs):
        return self.tokenizer.encode(
            str(x),
            padding='max_length',
            max_length=max_len,
            truncation=True,
            **kwargs,
        )


In [17]:
## Setting parameters
max_len = 230
batch_size = 10
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [24]:
data_train = BERTDataset(dataset_train, 0, 1, tokenizer, vocab, max_len, True, False)
data_valid = BERTDataset(dataset_valid, 0, 1, tokenizer, vocab, max_len, True, False)

In [None]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_valid, batch_size=batch_size, num_workers=5)

In [26]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=6,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [27]:
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

In [28]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [29]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [30]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [31]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [32]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [33]:
torch.cuda.empty_cache()

In [34]:
  import gc
  gc.collect()

33578

In [None]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

  cpuset_checked))


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 1.9448007345199585 train acc 0.2
epoch 1 batch id 201 loss 1.742891550064087 train acc 0.1696517412935325
epoch 1 batch id 401 loss 1.8331588506698608 train acc 0.1733167082294267
epoch 1 batch id 601 loss 1.8093633651733398 train acc 0.1703826955074874
epoch 1 batch id 801 loss 1.8607618808746338 train acc 0.1676654182272155
epoch 1 batch id 1001 loss 1.8215272426605225 train acc 0.16683316683316574
epoch 1 batch id 1201 loss 1.9112136363983154 train acc 0.16552872606161376
epoch 1 batch id 1401 loss 1.7885513305664062 train acc 0.16695217701641502
epoch 1 batch id 1601 loss 1.7555360794067383 train acc 0.16589631480324624
epoch 1 batch id 1801 loss 1.7868907451629639 train acc 0.16801776790671735
epoch 1 batch id 2001 loss 1.8051459789276123 train acc 0.16776611694152885
epoch 1 batch id 2201 loss 1.699618935585022 train acc 0.16901408450704225
epoch 1 batch id 2401 loss 1.769626259803772 train acc 0.16859641815910076
epoch 1 batch id 2601 loss 1.7126367092132

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 1 test acc 0.17680311890838205


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 1.796655297279358 train acc 0.1
epoch 2 batch id 201 loss 1.7140686511993408 train acc 0.16716417910447778
epoch 2 batch id 401 loss 1.9030015468597412 train acc 0.1718204488778059
epoch 2 batch id 601 loss 1.8173259496688843 train acc 0.17271214642262875
epoch 2 batch id 801 loss 1.848510980606079 train acc 0.1739076154806487
epoch 2 batch id 1001 loss 1.7584949731826782 train acc 0.17132867132867016
epoch 2 batch id 1201 loss 1.8775545358657837 train acc 0.17277268942547727
epoch 2 batch id 1401 loss 1.831079125404358 train acc 0.1713062098501053
epoch 2 batch id 1601 loss 1.7891435623168945 train acc 0.17089319175515144
epoch 2 batch id 1801 loss 1.8648998737335205 train acc 0.17062742920599583
epoch 2 batch id 2001 loss 1.7909908294677734 train acc 0.17006496751624145
epoch 2 batch id 2201 loss 1.72427237033844 train acc 0.17024079963652894
epoch 2 batch id 2401 loss 1.7890485525131226 train acc 0.17113702623906746
epoch 2 batch id 2601 loss 1.74153232574462

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 2 test acc 0.15906432748538019


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 1.8471606969833374 train acc 0.1
epoch 3 batch id 201 loss 1.7058327198028564 train acc 0.17462686567164193
epoch 3 batch id 401 loss 1.8722221851348877 train acc 0.1740648379052371
epoch 3 batch id 601 loss 1.8368362188339233 train acc 0.17437603993344392
epoch 3 batch id 801 loss 1.7926361560821533 train acc 0.17540574282147242
epoch 3 batch id 1001 loss 1.7393949031829834 train acc 0.17382617382617255
epoch 3 batch id 1201 loss 1.8895946741104126 train acc 0.17227310574521076
epoch 3 batch id 1401 loss 1.7948577404022217 train acc 0.17180585296216808
epoch 3 batch id 1601 loss 1.7793381214141846 train acc 0.17076826983135396
epoch 3 batch id 1801 loss 1.79409921169281 train acc 0.17184897279289205
epoch 3 batch id 2001 loss 1.8356395959854126 train acc 0.17296351824087913
epoch 3 batch id 2201 loss 1.7777283191680908 train acc 0.17305770104497958
epoch 3 batch id 2401 loss 1.7701385021209717 train acc 0.17251145356101663
epoch 3 batch id 2601 loss 1.738565444

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 3 test acc 0.15906432748538019


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 1.7779678106307983 train acc 0.2
epoch 4 batch id 201 loss 1.7595056295394897 train acc 0.1601990049751245
epoch 4 batch id 401 loss 1.817226767539978 train acc 0.16957605985037436
epoch 4 batch id 601 loss 1.8099000453948975 train acc 0.16955074875207973
epoch 4 batch id 801 loss 1.7906297445297241 train acc 0.17066167290886336
epoch 4 batch id 1001 loss 1.7581231594085693 train acc 0.17032967032966934
epoch 4 batch id 1201 loss 1.8091624975204468 train acc 0.1701082431307232
epoch 4 batch id 1401 loss 1.7598094940185547 train acc 0.1703783012134175
epoch 4 batch id 1601 loss 1.8112945556640625 train acc 0.17051842598375896
epoch 4 batch id 1801 loss 1.7458031177520752 train acc 0.17223764575235928
epoch 4 batch id 2001 loss 1.8034626245498657 train acc 0.17256371814092947
epoch 4 batch id 2201 loss 1.722765564918518 train acc 0.1732394366197185
epoch 4 batch id 2401 loss 1.741219162940979 train acc 0.17138692211578577
epoch 4 batch id 2601 loss 1.7895643711090

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 4 test acc 0.15906432748538019


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 1.8102086782455444 train acc 0.1
epoch 5 batch id 201 loss 1.7649142742156982 train acc 0.16467661691542296
epoch 5 batch id 401 loss 1.783509612083435 train acc 0.1715710723192023
epoch 5 batch id 601 loss 1.8168761730194092 train acc 0.16505823627287833
epoch 5 batch id 801 loss 1.8205842971801758 train acc 0.16104868913857617
epoch 5 batch id 1001 loss 1.7882463932037354 train acc 0.16583416583416494
epoch 5 batch id 1201 loss 1.856406807899475 train acc 0.16677768526228018
epoch 5 batch id 1401 loss 1.7631800174713135 train acc 0.16638115631691486
epoch 5 batch id 1601 loss 1.8103668689727783 train acc 0.16608369768894304
epoch 5 batch id 1801 loss 1.810765027999878 train acc 0.1676290949472508
epoch 5 batch id 2001 loss 1.8084529638290405 train acc 0.167216391804098
epoch 5 batch id 2201 loss 1.7393643856048584 train acc 0.16869604725124984
epoch 5 batch id 2401 loss 1.7524511814117432 train acc 0.1690962099125372
epoch 5 batch id 2601 loss 1.76532518863677

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 5 test acc 0.17017543859649129


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 1.7538890838623047 train acc 0.3
epoch 6 batch id 201 loss 1.8046514987945557 train acc 0.1761194029850747
epoch 6 batch id 401 loss 1.838772177696228 train acc 0.17780548628428938
epoch 6 batch id 601 loss 1.8179384469985962 train acc 0.1758735440931778
epoch 6 batch id 801 loss 1.8026502132415771 train acc 0.17303370786516775
epoch 6 batch id 1001 loss 1.7609155178070068 train acc 0.17432567432567328
epoch 6 batch id 1201 loss 1.828481674194336 train acc 0.17202331390507772
epoch 6 batch id 1401 loss 1.7810519933700562 train acc 0.1723054960742312
epoch 6 batch id 1601 loss 1.7940394878387451 train acc 0.17320424734540812
epoch 6 batch id 1801 loss 1.8165152072906494 train acc 0.17473625763464712
epoch 6 batch id 2001 loss 1.8090165853500366 train acc 0.17406296851574238
epoch 6 batch id 2201 loss 1.7743241786956787 train acc 0.17505679236710633
epoch 6 batch id 2401 loss 1.776526689529419 train acc 0.17413577675968425
epoch 6 batch id 2601 loss 1.795495271682

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 6 test acc 0.17017543859649129


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 1.7783353328704834 train acc 0.1
epoch 7 batch id 201 loss 1.7982292175292969 train acc 0.164179104477612
epoch 7 batch id 401 loss 1.7747533321380615 train acc 0.16508728179551155
epoch 7 batch id 601 loss 1.804547667503357 train acc 0.16755407653910126
epoch 7 batch id 801 loss 1.7934682369232178 train acc 0.16404494382022422
epoch 7 batch id 1001 loss 1.7537370920181274 train acc 0.1659340659340649
epoch 7 batch id 1201 loss 1.8242870569229126 train acc 0.1676103247293908
epoch 7 batch id 1401 loss 1.7982454299926758 train acc 0.1678800856531033
epoch 7 batch id 1601 loss 1.8085343837738037 train acc 0.1695190505933777
epoch 7 batch id 1801 loss 1.7884864807128906 train acc 0.17157134925041548
epoch 7 batch id 2001 loss 1.8354400396347046 train acc 0.1709145427286352
epoch 7 batch id 2201 loss 1.7935457229614258 train acc 0.17242162653339377
epoch 7 batch id 2401 loss 1.7819675207138062 train acc 0.17192836318200772
epoch 7 batch id 2601 loss 1.76828312873840

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 7 test acc 0.17680311890838205


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 1.8100589513778687 train acc 0.0
epoch 8 batch id 201 loss 1.8037183284759521 train acc 0.16865671641791052
epoch 8 batch id 401 loss 1.765131950378418 train acc 0.17655860349127195
epoch 8 batch id 601 loss 1.8105148077011108 train acc 0.17753743760399313
epoch 8 batch id 801 loss 1.7928835153579712 train acc 0.1764044943820218
epoch 8 batch id 1001 loss 1.7733970880508423 train acc 0.1737262737262725
epoch 8 batch id 1201 loss 1.8024886846542358 train acc 0.17527060782680962
epoch 8 batch id 1401 loss 1.7621629238128662 train acc 0.1758029978586709
epoch 8 batch id 1601 loss 1.8301843404769897 train acc 0.1754528419737652
epoch 8 batch id 1801 loss 1.801291823387146 train acc 0.17590227651304768
epoch 8 batch id 2001 loss 1.82666015625 train acc 0.17611194402798586
epoch 8 batch id 2201 loss 1.773604393005371 train acc 0.1770104497955478
epoch 8 batch id 2401 loss 1.7784391641616821 train acc 0.17667638483965073
epoch 8 batch id 2601 loss 1.8029890060424805 tr

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 8 test acc 0.17680311890838205


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 1.811819314956665 train acc 0.1
epoch 9 batch id 201 loss 1.7930396795272827 train acc 0.17761194029850763
epoch 9 batch id 401 loss 1.771790862083435 train acc 0.1733167082294268
epoch 9 batch id 601 loss 1.794023871421814 train acc 0.17071547420965044
epoch 9 batch id 801 loss 1.8113987445831299 train acc 0.17103620474406941
epoch 9 batch id 1001 loss 1.7496306896209717 train acc 0.17072927072926963
epoch 9 batch id 1201 loss 1.8205373287200928 train acc 0.17252289758534417
epoch 9 batch id 1401 loss 1.8108692169189453 train acc 0.17416131334760732
epoch 9 batch id 1601 loss 1.8156620264053345 train acc 0.1757651467832594
epoch 9 batch id 1801 loss 1.7960647344589233 train acc 0.17545807884508563
epoch 9 batch id 2001 loss 1.8150157928466797 train acc 0.17536231884057962
epoch 9 batch id 2201 loss 1.7982158660888672 train acc 0.1746478873239441
epoch 9 batch id 2401 loss 1.7505033016204834 train acc 0.17496876301541092
epoch 9 batch id 2601 loss 1.789083242416

  0%|          | 0/513 [00:00<?, ?it/s]

epoch 9 test acc 0.17680311890838205


  0%|          | 0/4088 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 1.7743806838989258 train acc 0.3
epoch 10 batch id 201 loss 1.7844946384429932 train acc 0.18756218905472632
epoch 10 batch id 401 loss 1.7960546016693115 train acc 0.1793017456359104
epoch 10 batch id 601 loss 1.814205527305603 train acc 0.17670549084858544
epoch 10 batch id 801 loss 1.8082119226455688 train acc 0.1766541822721591
epoch 10 batch id 1001 loss 1.7718303203582764 train acc 0.17432567432567322
epoch 10 batch id 1201 loss 1.808786153793335 train acc 0.17144046627810014
epoch 10 batch id 1401 loss 1.7787113189697266 train acc 0.17337615988579433
epoch 10 batch id 1601 loss 1.808118224143982 train acc 0.17307932542161025
epoch 10 batch id 1801 loss 1.8054920434951782 train acc 0.17518045530260884
epoch 10 batch id 2001 loss 1.8087962865829468 train acc 0.17386306846576682
epoch 10 batch id 2201 loss 1.78256094455719 train acc 0.1734211721944573
epoch 10 batch id 2401 loss 1.7861416339874268 train acc 0.1730112453144529
epoch 10 batch id 2601 loss 1.7

  0%|          | 0/513 [00:00<?, ?it/s]