# KoBERT finetuning

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset TSV and the checkpoint path used later
# (both under /content/drive/MyDrive/...) are reachable from this runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install ipywidgets (presumably for the tqdm.notebook progress bars when the notebook
# runs under VS Code) and KoBERT straight from its GitHub master branch.
# NOTE(review): master is unpinned, so a re-run may pick up a different KoBERT revision.
!pip install ipywidgets  # for vscode
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting git+https://****@github.com/SKTBrain/KoBERT.git@master
  Cloning https://****@github.com/SKTBrain/KoBERT.git (to revision master) to /tmp/pip-req-build-p0o2s6si
  Running command git clone -q 'https://****@github.com/SKTBrain/KoBERT.git' /tmp/pip-req-build-p0o2s6si


In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm.notebook import tqdm

In [None]:
from kobert import get_tokenizer
from kobert import get_pytorch_kobert_model
from sklearn.model_selection import train_test_split
                                                         


In [None]:
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [None]:
## Device selection
# Prefer the first GPU. When no CUDA device is available, fall back to the CPU so that
# every later `.to(device)` call still works (training will just be much slower)
# instead of failing on a CPU-only runtime.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the pretrained KoBERT PyTorch model and its vocabulary. Files are cached under
# .cache/ (the output shows the cached kobert_v1.zip and .spiece files being reused).
bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")

using cached model. /content/.cache/kobert_v1.zip
using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [None]:
# Download the NSMC movie-review train/test files into .cache/.
# NOTE(review): training data for this notebook is loaded from a Drive TSV instead,
# so these downloads appear unused by the active pipeline. Consider removing this cell.
!wget -O .cache/ratings_train.txt http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_train.txt
!wget -O .cache/ratings_test.txt http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_test.txt

--2022-03-30 13:21:54--  http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_train.txt
Resolving skt-lsl-nlp-model.s3.amazonaws.com (skt-lsl-nlp-model.s3.amazonaws.com)... 52.219.60.54
Connecting to skt-lsl-nlp-model.s3.amazonaws.com (skt-lsl-nlp-model.s3.amazonaws.com)|52.219.60.54|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 14628807 (14M) [text/plain]
Saving to: ‘.cache/ratings_train.txt’


2022-03-30 13:21:56 (7.62 MB/s) - ‘.cache/ratings_train.txt’ saved [14628807/14628807]

--2022-03-30 13:21:56--  http://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/datasets/nsmc/ratings_test.txt
Resolving skt-lsl-nlp-model.s3.amazonaws.com (skt-lsl-nlp-model.s3.amazonaws.com)... 52.219.60.54
Connecting to skt-lsl-nlp-model.s3.amazonaws.com (skt-lsl-nlp-model.s3.amazonaws.com)|52.219.60.54|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4893335 (4.7M) [text/plain]
Saving to: ‘.cache/ratings_test.txt’


2022-03-30 13:21:58 (3.61 

In [None]:
# Korean single-turn dialogue dataset, exported from a spreadsheet as TSV.
# Column 0 is read as the sentence and column 1 as the label (see the BERTDataset call
# below). The first row is discarded, presumably a header — confirm against the file.
DATA_TSV_PATH = "/content/drive/MyDrive/캡디/한국어_단발성_대화_데이터셋.xlsx - Sheet1.tsv"
dataset = nlp.data.TSVDataset(
    DATA_TSV_PATH,
    field_indices=[0, 1],
    num_discard_samples=1,
)

In [None]:
# Stratified 90/10 train/test split on the label column (index 1), so every label class
# keeps the same proportion in both splits. A fixed random_state makes the split, and
# therefore the reported test accuracy, reproducible across kernel restarts.
dataset_array = np.array(dataset)
print(dataset_array.shape)

TEST_SIZE = 0.1
SPLIT_SEED = 42
dataset_train, dataset_test = train_test_split(
    dataset,
    test_size=TEST_SIZE,
    stratify=dataset_array[:, 1],
    random_state=SPLIT_SEED,
)

print(len(dataset_train), len(dataset_test))

(38594, 2)
30875 7719


In [None]:
# get_tokenizer() appears to return the cached SentencePiece model file (see the .spiece
# path in the output). Wrap it with the KoBERT vocab. lower=False keeps the original casing,
# consistent with the "cased" model file name.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [None]:
class BERTDataset(Dataset):
    """Pre-tokenized sentence-classification dataset for KoBERT.

    Every row of ``dataset`` is encoded once, up front, by gluonnlp's
    ``BERTSentenceTransform``. Each item is the transform's output tuple with the
    integer label appended; the training loop unpacks it as
    ``(token_ids, valid_length, segment_ids, label)``.

    Args:
        dataset: iterable of indexable rows.
        sent_idx: column index holding the sentence text.
        label_idx: column index holding the label (anything ``np.int32`` accepts).
        bert_tokenizer: tokenizer handed to ``BERTSentenceTransform``.
        max_len: ``max_seq_length`` passed to the transform.
        pad: whether the transform pads sequences to ``max_len``.
        pair: whether rows are sentence pairs (forwarded to the transform).
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        encode = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # The transform expects a list of sentences, hence the one-element list.
        self.sentences = []
        for row in dataset:
            self.sentences.append(encode([row[sent_idx]]))

        self.labels = []
        for row in dataset:
            self.labels.append(np.int32(row[label_idx]))

    def __getitem__(self, i):
        features = self.sentences[i]
        return features + (self.labels[i],)

    def __len__(self):
        return len(self.labels)


In [None]:
## Setting parameters

# Tokenization / batching
max_len = 64          # max_seq_length handed to BERTSentenceTransform
batch_size = 64

# Optimization
learning_rate = 5e-5
num_epochs = 30
warmup_ratio = 0.1    # fraction of all training steps used for LR warm-up
max_grad_norm = 1     # gradient-norm clipping threshold

# Logging
log_interval = 200    # print running training stats every N batches

In [None]:
# Both splits share one encoding setup: sentence in column 0, label in column 1,
# single-sentence input (pair=False), padded to max_len.
dataset_kwargs = dict(sent_idx=0, label_idx=1, bert_tokenizer=tok,
                      max_len=max_len, pad=True, pair=False)
data_train = BERTDataset(dataset_train, **dataset_kwargs)
data_test = BERTDataset(dataset_test, **dataset_kwargs)

In [None]:
import os

# The truncated "cpuset_checked" message in the original output is presumably PyTorch's
# warning that 5 workers exceed what this runtime suggests, so cap workers by CPU count.
NUM_WORKERS = min(5, os.cpu_count() or 1)

# Reshuffle the training set every epoch. Without shuffle=True the loader yields the same
# batch order each epoch. Evaluation order does not affect accuracy, so the test loader is
# left unshuffled.
train_dataloader = torch.utils.data.DataLoader(
    data_train, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
test_dataloader = torch.utils.data.DataLoader(
    data_test, batch_size=batch_size, num_workers=NUM_WORKERS)

  cpuset_checked))


In [None]:
class BERTClassifier(nn.Module):
    """Pretrained BERT encoder followed by a linear classification head.

    Args:
        bert: encoder module called as
            ``bert(input_ids=..., token_type_ids=..., attention_mask=...)``. It must
            return a 2-tuple whose second element is the pooled output of width
            ``hidden_size``.
        hidden_size: width of the pooled output fed to the classifier.
        num_classes: number of output logits.
        dr_rate: dropout probability applied to the pooled output. ``None`` (or 0)
            disables it.
        params: unused; kept for backward compatibility.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=7,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate is not None:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with 1.0 for the first ``valid_length[i]`` positions of row i.

        ``token_ids`` is indexed as (batch, seq_len). ``valid_length`` holds one length
        per row and may live on any device; the mask is created on ``token_ids``' device.
        """
        seq_len = token_ids.size(1)
        lengths = torch.as_tensor(valid_length, device=token_ids.device).reshape(-1, 1)
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        # Vectorized replacement for the per-row Python loop: one comparison per batch
        # instead of one small device write per row.
        return (positions < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return classifier logits of shape (batch, num_classes)."""
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask)
        # Bug fix: the dropout output used to be overwritten by `out = pooler`, so
        # dr_rate silently had no effect.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [None]:
# Wrap the pretrained encoder with the default 7-class linear head and move it to `device`.
# dr_rate is left at its default (None), so no dropout is configured on the pooled output.
# Encoder-internal dropout is overridden in a later cell.
model = BERTClassifier(bertmodel).to(device)

In [None]:
# AdamW parameter groups: weight decay 0.01 on ordinary weights, none on biases and
# LayerNorm weights. Applied to the encoder and then to the classifier head.
# (The LR schedule itself is created in a later cell.)
no_decay = ['bias', 'LayerNorm.weight']


def _split_by_decay(module):
    """Split module parameters into (decayed, not decayed) lists, keeping their order."""
    decayed, undecayed = [], []
    for name, param in module.named_parameters():
        if any(marker in name for marker in no_decay):
            undecayed.append(param)
        else:
            decayed.append(param)
    return decayed, undecayed


optimizer_grouped_parameters = []
for submodule in (model.bert, model.classifier):
    decayed, undecayed = _split_by_decay(submodule)
    optimizer_grouped_parameters.append({'params': decayed, 'weight_decay': 0.01})
    optimizer_grouped_parameters.append({'params': undecayed, 'weight_decay': 0.0})

In [None]:
# torch.optim.AdamW replaces transformers.AdamW, which is deprecated.
# eps=1e-6 keeps the transformers default (torch's own default is 1e-8). Every parameter
# group above sets its own weight_decay, so torch's default weight_decay is not used.
optimizer = optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-6)
loss_fn = nn.CrossEntropyLoss()



In [None]:
# The scheduler is stepped once per batch, so step counts are measured in batches.
steps_per_epoch = len(train_dataloader)
t_total = num_epochs * steps_per_epoch
# Truncated toward zero, exactly as int() does.
warmup_step = int(warmup_ratio * t_total)

In [None]:
# Warm up the LR over `warmup_step` batches, then follow transformers' cosine decay
# for the rest of `t_total`.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_step,
    num_training_steps=t_total,
)

In [None]:
def calc_accuracy(X, Y):
    """Fraction of rows whose arg-max over dim 1 of ``X`` equals the matching entry of ``Y``.

    ``X`` holds per-class scores (e.g. logits from the classifier) and ``Y`` the integer
    labels. The result is a NumPy scalar, computed on the CPU.
    """
    predicted = torch.max(X, 1).indices
    n_rows = predicted.size(0)
    n_correct = (predicted == Y).sum().detach().cpu().numpy()
    return n_correct / n_rows

In [None]:
# Override the encoder's internal dropout (the original note says the official code uses 0.1)
# with a stronger value for regularization. Iterate over the layers the encoder actually
# has, instead of assuming exactly 12.
ENCODER_DROPOUT = 0.5
for layer in model.bert.encoder.layer:
  layer.attention.self.dropout = nn.Dropout(ENCODER_DROPOUT)
  layer.attention.output.dropout = nn.Dropout(ENCODER_DROPOUT)
  layer.output.dropout = nn.Dropout(ENCODER_DROPOUT)


In [None]:
# Fine-tune for num_epochs epochs. After each epoch, report accuracy on the held-out split.
# Accuracies are averaged over batches, so the last (smaller) batch weighs as much as a full one.
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0

    # Training pass: dropout active, gradients tracked.
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is passed as-is; the model builds the attention mask from it.
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # per-batch LR schedule (t_total is counted in batches)
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # Evaluation pass: dropout off, and no autograd graph, which saves memory and time.
    model.eval()
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

  cpuset_checked))


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 1.9362025260925293 train acc 0.1875
epoch 1 batch id 201 loss 1.9759420156478882 train acc 0.14769900497512436
epoch 1 batch id 401 loss 1.9329965114593506 train acc 0.14974283042394015
epoch 1 train acc 0.15108599800628786


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 1 test acc 0.16942810976901884


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 1.9432365894317627 train acc 0.203125
epoch 2 batch id 201 loss 1.944503903388977 train acc 0.1630907960199005
epoch 2 batch id 401 loss 1.7453645467758179 train acc 0.1909678927680798
epoch 2 train acc 0.20739398819108962


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 2 test acc 0.32411858974358976


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 1.730364203453064 train acc 0.359375
epoch 3 batch id 201 loss 1.5838650465011597 train acc 0.33488805970149255
epoch 3 batch id 401 loss 1.4629899263381958 train acc 0.36038809226932667
epoch 3 train acc 0.3655885764128518


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 3 test acc 0.39939274740411107


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 1.3684591054916382 train acc 0.515625
epoch 4 batch id 201 loss 1.501288890838623 train acc 0.41068097014925375
epoch 4 batch id 401 loss 1.370251178741455 train acc 0.4246415211970075
epoch 4 train acc 0.4254119220151829


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 4 test acc 0.422848458359822


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 1.2360605001449585 train acc 0.546875
epoch 5 batch id 201 loss 1.200718641281128 train acc 0.4425528606965174
epoch 5 batch id 401 loss 1.2643078565597534 train acc 0.4558525561097257
epoch 5 train acc 0.4580769304501189


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 5 test acc 0.4553070036024582


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 1.2298601865768433 train acc 0.546875
epoch 6 batch id 201 loss 1.2266720533370972 train acc 0.4807991293532338
epoch 6 batch id 401 loss 1.1745747327804565 train acc 0.48803771820448877
epoch 6 train acc 0.4887122824169926


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 6 test acc 0.47024992053401143


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 1.171959400177002 train acc 0.5625
epoch 7 batch id 201 loss 1.275585412979126 train acc 0.49836753731343286
epoch 7 batch id 401 loss 1.2045049667358398 train acc 0.5090009351620948
epoch 7 train acc 0.5112194233571046


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 7 test acc 0.4758026064844247


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 1.1124495267868042 train acc 0.578125
epoch 8 batch id 201 loss 1.1216423511505127 train acc 0.5181902985074627
epoch 8 batch id 401 loss 1.138594388961792 train acc 0.5300810473815462
epoch 8 train acc 0.5314824399969328


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 8 test acc 0.47649462280144095


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 1.0623326301574707 train acc 0.625
epoch 9 batch id 201 loss 1.2095144987106323 train acc 0.5415889303482587
epoch 9 batch id 401 loss 1.0122380256652832 train acc 0.551823566084788
epoch 9 train acc 0.5533306015642971


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 9 test acc 0.4727497880906972


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 1.1167829036712646 train acc 0.578125
epoch 10 batch id 201 loss 1.2002960443496704 train acc 0.5649875621890548
epoch 10 batch id 401 loss 1.0283056497573853 train acc 0.5704878428927681
epoch 10 train acc 0.5710140134958975


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 10 test acc 0.4820473087518542


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 1.1038835048675537 train acc 0.625
epoch 11 batch id 201 loss 1.1081708669662476 train acc 0.5782804726368159
epoch 11 batch id 401 loss 0.9628331065177917 train acc 0.5820994389027432
epoch 11 train acc 0.5833309370447052


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 11 test acc 0.4850637052341598


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.9668719172477722 train acc 0.609375
epoch 12 batch id 201 loss 1.058518886566162 train acc 0.599657960199005
epoch 12 batch id 401 loss 0.9270355701446533 train acc 0.6025561097256857
epoch 12 train acc 0.6031530365769496


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 12 test acc 0.48721259800805256


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 1.0574263334274292 train acc 0.625
epoch 13 batch id 201 loss 1.0090155601501465 train acc 0.6100746268656716
epoch 13 batch id 401 loss 0.9097015261650085 train acc 0.6136611596009975
epoch 13 train acc 0.6146983551874856


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 13 test acc 0.48135529243483793


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 1.0331130027770996 train acc 0.640625
epoch 14 batch id 201 loss 0.9458537697792053 train acc 0.6294309701492538
epoch 14 batch id 401 loss 0.8880731463432312 train acc 0.6310395885286783
epoch 14 train acc 0.6307678667280117


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 14 test acc 0.48829201101928377


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 1.020033836364746 train acc 0.6875
epoch 15 batch id 201 loss 0.929844319820404 train acc 0.6417910447761194
epoch 15 batch id 401 loss 0.8706528544425964 train acc 0.6415991271820449
epoch 15 train acc 0.6423167797714899


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 15 test acc 0.49091107755880486


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 0.913838267326355 train acc 0.703125
epoch 16 batch id 201 loss 0.8658130764961243 train acc 0.660214552238806
epoch 16 batch id 401 loss 0.7789493203163147 train acc 0.6588216957605985
epoch 16 train acc 0.6575859309102062


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 16 test acc 0.4829512343716889


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 0.8577885031700134 train acc 0.78125
epoch 17 batch id 201 loss 0.8776838183403015 train acc 0.6662002487562189
epoch 17 batch id 401 loss 0.786490261554718 train acc 0.6666926433915212
epoch 17 train acc 0.6667325646039414


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 17 test acc 0.48884496185632553


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 0.8208760619163513 train acc 0.75
epoch 18 batch id 201 loss 0.833969235420227 train acc 0.6791822139303483
epoch 18 batch id 401 loss 0.7344396114349365 train acc 0.6759663341645885
epoch 18 train acc 0.6756695230427114


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 18 test acc 0.4877291269336724


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 0.7775917649269104 train acc 0.75
epoch 19 batch id 201 loss 0.8223102688789368 train acc 0.6898320895522388
epoch 19 batch id 401 loss 0.77601158618927 train acc 0.6858634663341646
epoch 19 train acc 0.6877719787592976


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 19 test acc 0.484325333757152


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 0.718590259552002 train acc 0.796875
epoch 20 batch id 201 loss 0.7582974433898926 train acc 0.6993159203980099
epoch 20 batch id 401 loss 0.7970662713050842 train acc 0.696656795511222
epoch 20 train acc 0.6962560386473431


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 20 test acc 0.4842789785971604


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 21 batch id 1 loss 0.7240841388702393 train acc 0.75
epoch 21 batch id 201 loss 0.740211009979248 train acc 0.7096548507462687
epoch 21 batch id 401 loss 0.7259489893913269 train acc 0.7063980673316709
epoch 21 train acc 0.7070405356184342


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 21 test acc 0.4853583916083916


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 22 batch id 1 loss 0.6937263011932373 train acc 0.78125
epoch 22 batch id 201 loss 0.6472786664962769 train acc 0.7151741293532339
epoch 22 batch id 401 loss 0.7580546140670776 train acc 0.7124376558603491
epoch 22 train acc 0.7112903535004985


  0%|          | 0/121 [00:00<?, ?it/s]

epoch 22 test acc 0.4824347054460691


  0%|          | 0/483 [00:00<?, ?it/s]

epoch 23 batch id 1 loss 0.6927011013031006 train acc 0.734375
epoch 23 batch id 201 loss 0.6673315167427063 train acc 0.716806592039801
epoch 23 batch id 401 loss 0.7236607074737549 train acc 0.7167238154613467


In [None]:
# Save only the weights (state_dict). These are the weights after the last completed
# epoch, not necessarily the epoch with the best test accuracy.
CHECKPOINT_PATH = '/content/drive/MyDrive/캡디/drop_05.pth'
weights = model.state_dict()
torch.save(weights, CHECKPOINT_PATH)