In [1]:
!pip install mxnet-cu101
!pip install gluonnlp pandas tqdm
!pip install sentencepiece==0.1.85
!pip install transformers==2.1.1
!pip install torch==1.3.1
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting mxnet-cu101
[?25l  Downloading https://files.pythonhosted.org/packages/40/26/9655677b901537f367c3c473376e4106abc72e01a8fc25b1cb6ed9c37e8c/mxnet_cu101-1.7.0-py2.py3-none-manylinux2014_x86_64.whl (846.0MB)
[K     |███████████████████████████████▌| 834.1MB 1.2MB/s eta 0:00:10tcmalloc: large alloc 1147494400 bytes == 0x393e6000 @  0x7f1150846615 0x591e47 0x4cc179 0x4cc2db 0x50a1cc 0x50beb4 0x507be4 0x509900 0x50a2fd 0x50beb4 0x507be4 0x509900 0x50a2fd 0x50cc96 0x58e683 0x50c127 0x58e683 0x50c127 0x58e683 0x50c127 0x58e683 0x50c127 0x5095c8 0x50a2fd 0x50beb4 0x507be4 0x509900 0x50a2fd 0x50beb4 0x5095c8 0x50a2fd
[K     |████████████████████████████████| 846.0MB 20kB/s 
Collecting graphviz<0.9.0,>=0.8.1
  Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl
Installing collected packages: graphviz, mxnet-cu101
  Found existing installation: graphviz 0.10.1
    Uninstalling graphv

In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import WarmupLinearSchedule
import json
import transformers
from google.colab import drive
drive.mount('/content/drive/')

device = torch.device("cuda:0")
bertmodel, vocab = get_pytorch_kobert_model()
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
using cached model
using cached model
using cached model


In [2]:
with open('/content/drive/MyDrive/dacon_ExtSum/train.jsonl') as f :
  train_data = [json.loads(i.strip()) for i in f.readlines()]

MAX_LEN, count_SENTENTS = (512, 30)
df = []
labels = []
for d in train_data :
  token_ids = [2]
  segment_ids = [0]
  attention_mask = [1]
  label = d['extractive']
  y = [0] * count_SENTENTS
  data = d['article_original']
  segment = 0
  for i in data :
    if len(token_ids) > 511 :
      token_ids = token_ids[:511]+[3]
      segment_ids = segment_ids[:511]+[segment]
      break
    token = tok(i)+['[SEP]']
    ids = tok.convert_tokens_to_ids(token)
    token_ids += ids
    if segment == 0 :
      segment_ids += [0]*len(ids)
      segment = 1
    else :
      segment_ids += [1]*len(ids)
      segment = 0
  for i in label :
    if int(i) >= 30 : continue;
    y[int(i)] = 1
  token_ids = token_ids + [1] * (MAX_LEN - len(token_ids))
  attention_mask = [1]* len(segment_ids) + [0] * (MAX_LEN - len(segment_ids))
  segment_ids = segment_ids + [segment] * (MAX_LEN - len(segment_ids))
  df.append((torch.tensor(token_ids[:512]).long(),torch.tensor(attention_mask[:512]).long(),torch.tensor(segment_ids[:512]).long(),torch.tensor(y).float()))

In [8]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert.to(device)
        self.dr_rate = dr_rate
        self.classifier = nn.Linear(768,30,bias=False)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def forward(self, token_ids, valid_length, segment_ids):
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids, attention_mask = valid_length)
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [9]:
warmup_ratio = 0.1
num_epochs = 30
max_grad_norm = 1
log_interval = 200
learning_rate =  1e-5
train_set = torch.utils.data.DataLoader(df,batch_size=10,shuffle=True)
model = BERTClassifier(bertmodel,dr_rate=0.5).to(device)

# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.BCEWithLogitsLoss()
t_total = len(train_set)
warmup_step = int(t_total * warmup_ratio)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)

In [None]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    loss_val = []
    acc = []
    train_set = torch.utils.data.DataLoader(df,batch_size=10,shuffle=True)
    model.train()
    for batch_id, (token_ids, attention_mask,token_type, label) in enumerate(tqdm_notebook(train_set)):
        optimizer.zero_grad()
        token_ids = token_ids.to(device)
        token_type = token_type.to(device)
        attention_mask = attention_mask.to(device)
        label = label.float().to(device)
        out = model(token_ids, attention_mask, token_type)
        loss = loss_fn(out, label)
        loss.backward()
        loss_val.append(loss.item())
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#        acc = binary_acc(out,label)
#        acc_val.append(acc)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        #train_acc += binary_acc(out,label)
        #train_acc += calc_accuracy(out, label)
        if batch_id % 200 == 0:
            print("epoch {} batch id {} loss {}".format(e+1, batch_id+1, loss.data.cpu().numpy()))
        if batch_id == 2000 :
            torch.save({
                        'epoch': e,
                        'train_no': batch_id,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss':loss
                      }, '/content/drive/MyDrive/dacon_ExtSum/checkpoint_ext_bias_false.tar')
    print("epoch {} loss {}".format(e+1, sum(loss_val)/len(loss_val)))
    torch.save({
            'epoch': e,
            'train_no': batch_id,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss':loss
          }, '/content/drive/MyDrive/dacon_ExtSum/checkpoint_ext_bias_false.tar')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=4281.0), HTML(value='')))

epoch 1 batch id 1 loss 0.6605191230773926
epoch 1 batch id 201 loss 0.4917415380477905
epoch 1 batch id 401 loss 0.35226669907569885
epoch 1 batch id 601 loss 0.2796275019645691
epoch 1 batch id 801 loss 0.2633858323097229
epoch 1 batch id 1001 loss 0.21214692294597626
epoch 1 batch id 1201 loss 0.266189306974411
epoch 1 batch id 1401 loss 0.19854900240898132
epoch 1 batch id 1601 loss 0.2375704050064087
epoch 1 batch id 1801 loss 0.17870382964611053
epoch 1 batch id 2001 loss 0.1893962323665619
epoch 1 batch id 2201 loss 0.17029829323291779
epoch 1 batch id 2401 loss 0.19141048192977905
epoch 1 batch id 2601 loss 0.18971236050128937
epoch 1 batch id 2801 loss 0.20089414715766907
epoch 1 batch id 3001 loss 0.19681507349014282
epoch 1 batch id 3201 loss 0.19948090612888336
epoch 1 batch id 3401 loss 0.21351014077663422
epoch 1 batch id 3601 loss 0.1671810895204544
epoch 1 batch id 3801 loss 0.16547101736068726
epoch 1 batch id 4001 loss 0.20304271578788757
epoch 1 batch id 4201 loss 0.

HBox(children=(FloatProgress(value=0.0, max=4281.0), HTML(value='')))

epoch 2 batch id 1 loss 0.19091126322746277
epoch 2 batch id 201 loss 0.19822032749652863
epoch 2 batch id 401 loss 0.21144381165504456
epoch 2 batch id 601 loss 0.19720223546028137
epoch 2 batch id 801 loss 0.2585524618625641
epoch 2 batch id 1001 loss 0.1989128738641739
epoch 2 batch id 1201 loss 0.1998414695262909
epoch 2 batch id 1401 loss 0.1819145828485489
epoch 2 batch id 1601 loss 0.18370023369789124
epoch 2 batch id 1801 loss 0.17408941686153412
epoch 2 batch id 2001 loss 0.18562805652618408
epoch 2 batch id 2201 loss 0.16968035697937012
epoch 2 batch id 2401 loss 0.17559108138084412
epoch 2 batch id 2601 loss 0.1698753535747528
epoch 2 batch id 2801 loss 0.17111174762248993
epoch 2 batch id 3001 loss 0.21456167101860046
epoch 2 batch id 3201 loss 0.1600462943315506
epoch 2 batch id 3401 loss 0.2347215861082077
epoch 2 batch id 3601 loss 0.20106202363967896
epoch 2 batch id 3801 loss 0.22516591846942902
epoch 2 batch id 4001 loss 0.20522639155387878
epoch 2 batch id 4201 loss 

HBox(children=(FloatProgress(value=0.0, max=4281.0), HTML(value='')))

epoch 3 batch id 1 loss 0.1841743290424347
epoch 3 batch id 201 loss 0.22919289767742157
epoch 3 batch id 401 loss 0.19713033735752106
epoch 3 batch id 601 loss 0.1590227484703064
epoch 3 batch id 801 loss 0.19404251873493195
epoch 3 batch id 1001 loss 0.22196508944034576
epoch 3 batch id 1201 loss 0.17820079624652863
epoch 3 batch id 1401 loss 0.19880415499210358
epoch 3 batch id 1601 loss 0.1705121546983719
epoch 3 batch id 1801 loss 0.2114916294813156
epoch 3 batch id 2001 loss 0.21144478023052216
epoch 3 batch id 2201 loss 0.2500704228878021
epoch 3 batch id 2401 loss 0.18015627562999725
epoch 3 batch id 2601 loss 0.18371610343456268
epoch 3 batch id 2801 loss 0.16462862491607666
epoch 3 batch id 3001 loss 0.18435944616794586
epoch 3 batch id 3201 loss 0.20418712496757507
epoch 3 batch id 3401 loss 0.21711531281471252
epoch 3 batch id 3601 loss 0.20650339126586914
epoch 3 batch id 3801 loss 0.1998416632413864
epoch 3 batch id 4001 loss 0.22669784724712372
epoch 3 batch id 4201 loss

HBox(children=(FloatProgress(value=0.0, max=4281.0), HTML(value='')))

epoch 4 batch id 1 loss 0.2193673700094223
epoch 4 batch id 201 loss 0.17208434641361237
epoch 4 batch id 401 loss 0.20738805830478668
epoch 4 batch id 601 loss 0.17072026431560516
epoch 4 batch id 801 loss 0.18017412722110748
epoch 4 batch id 1001 loss 0.18065112829208374
epoch 4 batch id 1201 loss 0.2501077950000763
epoch 4 batch id 1401 loss 0.17448346316814423
epoch 4 batch id 1601 loss 0.20036758482456207
epoch 4 batch id 1801 loss 0.1704731285572052
epoch 4 batch id 2001 loss 0.1760600060224533
epoch 4 batch id 2201 loss 0.17327556014060974


In [3]:
with open('/content/drive/MyDrive/dacon_ExtSum/extractive_test_v2.jsonl') as f :
  test_data = [json.loads(i.strip()) for i in f.readlines()]
test = []
MAX_LEN = 512
for d in test_data :
  token_ids = [2]
  segment_ids = [0]
  attention_mask = [1]
  data = d['article_original']
  id = d['id']
  segment = 0
  for i in data :
    if len(token_ids) > 511 :
      token_ids = token_ids[:511]+[3]
      segment_ids = segment_ids[:511]+[segment]
      break
    token = tok(i)+['[SEP]']
    ids = tok.convert_tokens_to_ids(token)
    token_ids += ids
    if segment == 0 :
      segment_ids += [0]*len(ids)
      segment = 1
    else :
      segment_ids += [1]*len(ids)
      segment = 0
  token_ids = token_ids + [1] * (MAX_LEN - len(token_ids))
  attention_mask = [1]* len(segment_ids) + [0] * (MAX_LEN - len(segment_ids))
  segment_ids = segment_ids + [segment] * (MAX_LEN - len(segment_ids))
  test.append((torch.tensor(token_ids[:512]).long(),torch.tensor(attention_mask[:512]).long(),torch.tensor(segment_ids[:512]).long(),id))

In [4]:
model = BERTClassifier(bertmodel,dr_rate=0.1).to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/dacon_ExtSum/checkpoint_ext2.tar',map_location=torch.device('cpu'))['model_state_dict'])
model.eval()

BERTClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(8002, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True

In [5]:
sub = {}
for token_ids, attention_mask, token_type, id in test :
  token_ids = token_ids.unsqueeze(0).to(device)
  attention_mask = attention_mask.unsqueeze(0).to(device)
  token_type = token_type.unsqueeze(0).to(device)
  output = model(token_ids, attention_mask, token_type)
  sub[int(id)] = output

RuntimeError: ignored