<a href="https://colab.research.google.com/github/ark1st/2020_AI/blob/master/model_imdb_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/50/0c/7d5950fcd80b029be0a8891727ba21e0cd27692c407c51261c3c921f6da3/transformers-4.1.1-py3-none-any.whl (1.5MB)
[K     |████████████████████████████████| 1.5MB 15.1MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 56.6MB/s 
Collecting tokenizers==0.9.4
[?25l  Downloading https://files.pythonhosted.org/packages/0f/1c/e789a8b12e28be5bc1ce2156cf87cb522b379be9cadc7ad8091a4cc107c4/tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 54.6MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=e206b3291688

In [2]:
import re
import sys
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchtext import data
from torchtext import datasets

from transformers import BertTokenizer, BertModel

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

print(len(tokenizer.vocab))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…


30522


In [5]:
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']
print(max_input_length)

512


In [6]:
def new_tokenizer(sentence):
  tokens = tokenizer.tokenize(sentence)
  tokens = tokens[:max_input_length - 2]
  return tokens

In [7]:
def PreProcessingText(input_sentence):
    input_sentence = input_sentence.lower() # 소문자화
    input_sentence = re.sub('<[^>]*>', repl= ' ', string = input_sentence) # "<br />" 처리
    input_sentence = re.sub('[!"$%&\()*+,-./:;<=>?@[\\]^_`{|}~]', repl= ' ', string = input_sentence) # 특수문자 처리 ("'" 제외)
    input_sentence = re.sub('\s+', repl= ' ', string = input_sentence) # 연속된 띄어쓰기 처리
    if input_sentence:
        return input_sentence

def PreProc(list_sentence):
    return [tokenizer.convert_tokens_to_ids(PreProcessingText(x)) for x in list_sentence]

In [8]:
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = new_tokenizer,
                  preprocessing = PreProc,
                  init_token = tokenizer.cls_token_id,
                  eos_token = tokenizer.sep_token_id,
                  pad_token = tokenizer.pad_token_id,
                  unk_token = tokenizer.unk_token_id)

LABEL = data.LabelField(dtype = torch.float)

In [10]:
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

In [11]:
LABEL.build_vocab(train_data)

In [12]:
train_data, valid_data = train_data.split(random_state = random.seed(0), split_ratio=0.8)

In [13]:
# Data Length
print(f'Train Data Length : {len(train_data.examples)}')
print(f'Test Data Length : {len(test_data.examples)}')

Train Data Length : 20000
Test Data Length : 25000


In [14]:

# Data Fields
train_data.fields

{'label': <torchtext.data.field.LabelField at 0x7f6b036bd6a0>,
 'text': <torchtext.data.field.Field at 0x7f6b036bd9b0>}

In [15]:

# Data Sample
print('---- Data Sample ----')
print('Input : ')
print(tokenizer.convert_ids_to_tokens(vars(train_data.examples[2])['text']))

---- Data Sample ----
Input : 
['i', 'guess', 'there', 'are', 'some', 'out', 'there', 'that', 'remember', 'nicole', 'egg', '##ert', 'from', 'her', 'little', 'girl', 'days', 'on', 'such', 'tv', 'shows', 'as', 't', '[UNK]', 'j', '[UNK]', 'hooker', '[UNK]', 'charles', 'in', 'charge', '[UNK]', 'and', 'who', "'", 's', 'the', 'boss', '[UNK]', 'you', 'per', '##vert', '##s', '[UNK]', 'you', '[UNK]', 'maybe', 'you', 'remember', 'her', 'from', 'bay', '##watch', 'when', 'she', 'grew', 'up', 'and', 'got', 'breast', 'implant', '##s', '[UNK]', 'no', 'matter', '[UNK]', 'you', 'will', 'certainly', 'forget', 'her', 'in', 'this', 'supposed', 'comedy', 'about', 'man', '[UNK]', 'eating', 'aliens', '[UNK]', '[UNK]', 'br', '[UNK]', '[UNK]', '[UNK]', 'br', '[UNK]', '[UNK]', 'there', 'are', 'so', 'many', 'things', 'that', 'do', 'not', 'make', 'sense', 'and', 'are', 'never', 'explained', '[UNK]', 'how', 'did', 'she', 'recognize', 'the', 'alien', '[UNK]', 'why', 'was', 'the', 'alien', 'hot', 'for', 'pa', '##pr'

In [16]:
# Label Info
print(f'Label Size : {len(LABEL.vocab)}')

print('Lable Examples : ')
for idx, (k, v) in enumerate(LABEL.vocab.stoi.items()):
    print('\t', k, v)

Label Size : 2
Lable Examples : 
	 neg 0
	 pos 1


In [17]:
model_config = {}

In [18]:
model_config['batch_size'] = 8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size=model_config['batch_size'],
    device=device)

In [19]:
# Check batch data
sample_for_check = next(iter(train_iterator))
print(sample_for_check)
print(sample_for_check.text)
print(sample_for_check.label)


[torchtext.data.batch.Batch of size 8]
	[.text]:[torch.cuda.LongTensor of size 8x512 (GPU 0)]
	[.label]:[torch.cuda.FloatTensor of size 8 (GPU 0)]
tensor([[  101,  2023, 11322,  ...,     0,     0,     0],
        [  101,  2023,  2001,  ...,  2256,  9479,   102],
        [  101, 16637,   100,  ...,     0,     0,     0],
        ...,
        [  101, 17453, 18856,  ...,     0,     0,     0],
        [  101,  6373,  3632,  ...,     0,     0,     0],
        [  101,  2023,  3185,  ...,     0,     0,     0]], device='cuda:0')
tensor([1., 1., 1., 0., 0., 0., 0., 1.], device='cuda:0')


In [20]:
bert = BertModel.from_pretrained('bert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




In [21]:
model_config['emb_dim'] = bert.config.to_dict()['hidden_size']

In [22]:
print(model_config['emb_dim'])

768


In [23]:
class SentenceClassification(nn.Module):
    def __init__(self, **model_config):
        super(SentenceClassification, self).__init__()
        self.bert = bert
        self.fc = nn.Linear(model_config['emb_dim'],
                            model_config['output_dim'])
        
    def forward(self, x):
        pooled_cls_output = self.bert(x)[1]
        return self.fc(pooled_cls_output)

In [24]:
def train(model, iterator, optimizer, loss_fn, idx_epoch, **model_params):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train() 
    batch_size = model_params['batch_size']

    for idx, batch in enumerate(iterator):
        
        # Initializing
        optimizer.zero_grad()
        
        # Forward 
        predictions = model(batch.text).squeeze()
        loss = loss_fn(predictions, batch.label)

        acc = binary_accuracy(predictions, batch.label)
        
        sys.stdout.write(
                    "\r" + f"[Train] Epoch : {idx_epoch:^3}"\
                    f"[{(idx + 1) * batch_size} / {len(iterator) * batch_size} ({100. * (idx + 1) / len(iterator) :.4}%)]"\
                    f"  Loss: {loss.item():.4}"\
                    f"  Acc : {acc.item():.4}"\
                    )

        # Backward 
        loss.backward()
        optimizer.step()
        
        # Update Epoch Performance
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss/len(iterator) , epoch_acc/len(iterator)

In [31]:
def evaluate(model, iterator, loss_fn, idx_epoch, **model_params):
    
    epoch_loss = 0
    epoch_acc = 0
    
    batch_size = model_params['batch_size']
    
    # evaluation mode
    model.eval()
    with torch.no_grad():
        for idx, batch in enumerate(iterator):
            predictions = model(batch.text).squeeze()
            loss = loss_fn(predictions, batch.label)
            acc = binary_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

            sys.stdout.write(
                    "\r" + f"[Eval] Epoch : {idx_epoch:^3}"\
                    f"[{(idx + 1) * batch_size} / {len(iterator) * batch_size} ({100. * (idx + 1) / len(iterator) :.4}%)]"\
                    f"  Loss: {loss.item():.4}"\
                    f"  Acc : {acc.item():.4}"\
                    )
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [26]:
model_config.update(dict(output_dim = 1))

In [27]:

def binary_accuracy(preds, y):
    # rounded_preds = torch.argmax(preds, axis=1) 
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    acc = correct.sum()/len(correct)
    return acc


model = SentenceClassification(**model_config)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
loss_fn = nn.BCEWithLogitsLoss().to(device)
model = model.to(device)

In [28]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(model)

109483009

In [29]:

N_EPOCH = 4

best_valid_loss = float('inf')
model_name = "BERT"

print('---------------------------------')
print(f'Model name : {model_name}')
print('---------------------------------')

for epoch in range(N_EPOCH):
    train_loss, train_acc = train(model, train_iterator, optimizer, loss_fn, epoch, **model_config)
    print('')
    print(f'\t Epoch : {epoch} | Train Loss : {train_loss:.4} | Train Acc : {train_acc:.4}')
    valid_loss, valid_acc = evaluate(model, valid_iterator, loss_fn, epoch, **model_config)
    print('')
    print(f'\t Epoch : {epoch} | Valid Loss : {valid_loss:.4} | Valid Acc : {valid_acc:.4}')
    # print('')
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), f'./{model_name}.pt')
        print(f'\t Model is saved at {epoch}-epoch')

---------------------------------
Model name : BERT
---------------------------------
	 Epoch : 0 | Train Loss : 0.2972 | Train Acc : 0.8676
	 Epoch : 0 | Valid Loss : 0.2129 | Valid Acc : 0.9176
	 Model is saved at 0-epoch
	 Epoch : 1 | Train Loss : 0.1522 | Train Acc : 0.9457
	 Epoch : 1 | Valid Loss : 0.2117 | Valid Acc : 0.9144
	 Model is saved at 1-epoch
	 Epoch : 2 | Train Loss : 0.0923 | Train Acc : 0.9696
	 Epoch : 2 | Valid Loss : 0.2333 | Valid Acc : 0.9214
	 Epoch : 3 | Train Loss : 0.05436 | Train Acc : 0.9819
	 Epoch : 3 | Valid Loss : 0.2699 | Valid Acc : 0.9192


In [30]:
# Test set
# model.load_state_dict(torch.load(f'./{model_name}.pt'))
epoch = 0
test_loss, test_acc = evaluate(model, test_iterator, loss_fn, epoch, **model_config)
print('')
print(f'Test Loss : {test_loss:.4} | Test Acc : {test_acc:.4}')

Test Loss : 0.2656 | Test Acc : 0.9204
