<a href="https://colab.research.google.com/github/NakamuraSTS/PyTorch_advanced/blob/main/BERT_sentiment_estimation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd "drive/MyDrive/pytorch_advanced/8_nlp_sentiment_bert"

/content/drive/MyDrive/pytorch_advanced/8_nlp_sentiment_bert


In [4]:
!pip install attrdict

Collecting attrdict
  Downloading https://files.pythonhosted.org/packages/ef/97/28fe7e68bc7adfce67d4339756e85e9fcf3c6fd7f0c0781695352b70472c/attrdict-2.0.1-py2.py3-none-any.whl
Installing collected packages: attrdict
Successfully installed attrdict-2.0.1


In [5]:
import json

config_file = "./weights/bert_config.json"

json_file = open(config_file, 'r')
config = json.load(json_file)

config

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'hidden_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 30522}

In [6]:
from attrdict import AttrDict

config = AttrDict(config)
config.hidden_size

768

In [7]:
import torch.nn as nn
import torch

class BertLayerNorm(nn.Module):
  def __init__(self, hidden_size, eps=1e-12):
    super(BertLayerNorm, self).__init__()

    self.gamma = nn.Parameter(torch.ones(hidden_size))
    self.beta = nn.Parameter(torch.zeros(hidden_size))
    self.variance_epsilon = eps

  def forward(self, x):
    u = x.mean(-1, keepdim=True)
    s = (x - u).pow(2).mean(-1, keepdim=True)
    x = (x - u) / torch.sqrt(s + self.variance_epsilon)
    return self.gamma * x + self.beta

In [8]:
class BertEmbeddings(nn.Module):
  def __init__(self, config):
    super(BertEmbeddings, self).__init__()

    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, input_ids, token_type_ids=None):
    words_embedding = self.word_embeddings(input_ids)

    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    position_embeddings = self.position_embeddings(position_ids)

    embeddings = words_embedding + position_embeddings + token_type_embeddings

    embeddings = self.LayerNorm(embeddings)
    embeddings = self.dropout(embeddings)

    return embeddings

In [9]:
import math

class BertLayer(nn.Module):
  def __init__(self, config):
    super(BertLayer, self).__init__()

    self.attention = BertAttention(config)

    self.intermediate = BertIntermediate(config)

    self.output = BertOutput(config)

  def forward(self, hidden_states, attention_mask, attention_show_flg=False):
    '''
    hidden_states: output tensor of Embedder [batch_size, seq_len, hidden_size]
    attention_mask: the same mask of transformer
    attention_show_flg: a flag of returning weights of self-attention
    '''

    if attention_show_flg == True:
      attention_output, attention_probs = self.attention(hidden_states, attention_mask, attention_show_flg)
      intermediate_output = self.intermediate(attention_output)
      layer_output = self.output(intermediate_output, attention_output)
      return layer_output, attention_probs

    elif attention_show_flg == False:
      attention_output = self.attention(hidden_states, attention_mask, attention_show_flg)
      intermediate_output = self.intermediate(attention_output)
      layer_output = self.output(intermediate_output, attention_output)
      return layer_output

class BertAttention(nn.Module):
  def __init__(self, config):
    super(BertAttention, self).__init__()
    self.selfattn = BertSelfAttention(config)
    self.output = BertSelfOutput(config)

  def forward(self, input_tensor, attention_mask, attention_show_flg):
    if attention_show_flg ==True:
      self_output, attention_probs = self.selfattn(input_tensor, attention_mask, attention_show_flg)
      attention_output = self.output(self_output, input_tensor)
      return attention_output, attention_probs

    elif attention_show_flg == False:
      self_output = self.selfattn(input_tensor, attention_mask, attention_show_flg)
      attention_output = self.output(self_output, input_tensor)
      return attention_output

class BertSelfAttention(nn.Module):
  def __init__(self, config):
    super(BertSelfAttention, self).__init__()

    self.num_attention_heads = config.num_attention_heads

    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)

    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

  def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

  def forward(self, hidden_states, attention_mask, attention_show_flg=False):
    mixed_query_layer = self.query(hidden_states)
    mixed_key_layer = self.key(hidden_states)
    mixed_value_layer = self.value(hidden_states)
    
    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(self.attention_head_size)

    attention_scores = attention_scores + attention_mask

    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    attention_probs = self.dropout(attention_probs)

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    if attention_show_flg ==True:
      return context_layer, attention_probs
    elif attention_show_flg == False:
      return context_layer

class BertSelfOutput(nn.Module):

  def __init__(self, config):
    super(BertSelfOutput, self).__init__()

    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states

def gelu(x):
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class BertIntermediate(nn.Module):
  def __init__(self, config):
    super(BertIntermediate, self).__init__()

    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    self.intermediate_act_fn = gelu

  def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.intermediate_act_fn(hidden_states)
    return hidden_states

class BertOutput(nn.Module):
  def __init__(self, config):
    super(BertOutput, self).__init__()

    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)

    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states

In [10]:
class BertEncoder(nn.Module):
  def __init__(self, config):
    super(BertEncoder, self).__init__()

    self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

  def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, attention_show_flg=False):
    all_encoder_layers = []

    for layer_module in self.layer:
      if attention_show_flg == True:
        hidden_states, attention_probs = layer_module(hidden_states, attention_mask, attention_show_flg)
      elif attention_show_flg ==False:
        hidden_states = layer_module(hidden_states, attention_mask, attention_show_flg)

      if output_all_encoded_layers:
        all_encoder_layers.append(hidden_states)

    if not output_all_encoded_layers:
      all_encoder_layers.append(hidden_states)

    if attention_show_flg == True:
      return all_encoder_layers, attention_probs
    elif attention_show_flg == False:
      return all_encoder_layers

In [11]:
class BertPooler(nn.Module):
  def __init__(self, config):
    super(BertPooler, self).__init__()

    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.activation = nn.Tanh()

  def forward(self, hidden_states):
    first_token_tensor = hidden_states[:, 0]
    pooled_output = self.dense(first_token_tensor)
    pooled_output = self.activation(pooled_output)
    return pooled_output

In [12]:
input_ids = torch.LongTensor([[31, 51, 12, 23, 99], [15, 5, 1, 0, 0]])
print("tensor size of input word id: ", input_ids.shape)

attention_mask = torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
print("tensor size of input mask: ", attention_mask.shape)

token_type_ids = torch.LongTensor([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]])
print("tensor size of input sentence id: ", token_type_ids.shape)

embeddings = BertEmbeddings(config)
encoder = BertEncoder(config)
pooler = BertPooler(config)

extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
print("tensor size of extended mask: ", extended_attention_mask.shape)

out1 = embeddings(input_ids, token_type_ids)
print("tensor size of BertEmbedding output: ", out1.shape)

out2 = encoder(out1, extended_attention_mask)
print("tensor size of BertEncoder output: ", out2[0].shape)

out3 = pooler(out2[-1])
print("tensor size of BertPooler output: ", out3.shape)

tensor size of input word id:  torch.Size([2, 5])
tensor size of input mask:  torch.Size([2, 5])
tensor size of input sentence id:  torch.Size([2, 5])
tensor size of extended mask:  torch.Size([2, 1, 1, 5])
tensor size of BertEmbedding output:  torch.Size([2, 5, 768])
tensor size of BertEncoder output:  torch.Size([2, 5, 768])
tensor size of BertPooler output:  torch.Size([2, 768])


In [13]:
import torch

class BertModel(nn.Module):
  def __init__(self, config):
    super(BertModel, self).__init__()

    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.pooler = BertPooler(config)

  def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, attention_show_flg=False):
    if attention_mask is None:
      attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)

    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    embedding_output = self.embeddings(input_ids, token_type_ids)

    if attention_show_flg == True:
      encoded_layers, attention_probs = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers, attention_show_flg)
    elif attention_show_flg == False:
      encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers, attention_show_flg)

    pooled_output = self.pooler(encoded_layers[-1])

    if not output_all_encoded_layers:
      encoded_layers = encoded_layers[-1]

    if attention_show_flg == True:
      return encoded_layers, pooled_output, attention_probs
    elif attention_show_flg == False:
      return encoded_layers, pooled_output

In [56]:
net_bert = BertModel(config)

encoded_layers, pooled_output, attention_probs = net(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, attention_show_flg=True)

print("tensor size of encoded_layers: ", encoded_layers.shape)
print("tensor size of pooled_output: ", pooled_output.shape)
print("tensor size of attention_probs: ", attention_probs.shape)

tensor size of encoded_layers:  torch.Size([2, 5, 768])
tensor size of pooled_output:  torch.Size([2, 768])
tensor size of attention_probs:  torch.Size([2, 12, 5, 5])


In [57]:
weights_path = "./weights/pytorch_model.bin"
loaded_state_dict = torch.load(weights_path)

for s in loaded_state_dict.keys():
  print(s)

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.gamma
bert.embeddings.LayerNorm.beta
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.gamma
bert.encoder.layer.0.attention.output.LayerNorm.beta
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.gamma
bert.encoder.layer.0.output.LayerNorm.beta
bert.encoder.layer.1.attention.self.query.weight
bert.encode

In [58]:
net_bert = BertModel(config)
net.eval()

param_names = []

for name, param in net.named_parameters():
  print(name)
  param_names.append(name)

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.gamma
embeddings.LayerNorm.beta
encoder.layer.0.attention.selfattn.query.weight
encoder.layer.0.attention.selfattn.query.bias
encoder.layer.0.attention.selfattn.key.weight
encoder.layer.0.attention.selfattn.key.bias
encoder.layer.0.attention.selfattn.value.weight
encoder.layer.0.attention.selfattn.value.bias
encoder.layer.0.attention.output.dense.weight
encoder.layer.0.attention.output.dense.bias
encoder.layer.0.attention.output.LayerNorm.gamma
encoder.layer.0.attention.output.LayerNorm.beta
encoder.layer.0.intermediate.dense.weight
encoder.layer.0.intermediate.dense.bias
encoder.layer.0.output.dense.weight
encoder.layer.0.output.dense.bias
encoder.layer.0.output.LayerNorm.gamma
encoder.layer.0.output.LayerNorm.beta
encoder.layer.1.attention.selfattn.query.weight
encoder.layer.1.attention.selfattn.query.bias
encoder.layer.1.attention.selfattn.key.weight
e

In [59]:
new_state_dict = net_bert.state_dict().copy()

for index, (key_name, value) in enumerate(loaded_state_dict.items()):
  name = param_names[index]
  new_state_dict[name] = value
  print(str(key_name)+" => "+ str(name))

  if index+1 >= len(param_names):
    break

net_bert.load_state_dict(new_state_dict)

bert.embeddings.word_embeddings.weight => embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight => embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight => embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.gamma => embeddings.LayerNorm.gamma
bert.embeddings.LayerNorm.beta => embeddings.LayerNorm.beta
bert.encoder.layer.0.attention.self.query.weight => encoder.layer.0.attention.selfattn.query.weight
bert.encoder.layer.0.attention.self.query.bias => encoder.layer.0.attention.selfattn.query.bias
bert.encoder.layer.0.attention.self.key.weight => encoder.layer.0.attention.selfattn.key.weight
bert.encoder.layer.0.attention.self.key.bias => encoder.layer.0.attention.selfattn.key.bias
bert.encoder.layer.0.attention.self.value.weight => encoder.layer.0.attention.selfattn.value.weight
bert.encoder.layer.0.attention.self.value.bias => encoder.layer.0.attention.selfattn.value.bias
bert.encoder.layer.0.attention.output.dense.weight

<All keys matched successfully>

In [18]:
import collections

def load_vocab(vocab_file):
  vocab = collections.OrderedDict()

  ids_to_tokens = collections.OrderedDict()
  index = 0

  with open(vocab_file, "r", encoding="utf-8") as reader:
    while True:
      token = reader.readline()
      if not token:
        break
      token = token.strip()

      vocab[token] = index
      ids_to_tokens[index] = token
      index += 1

  return vocab, ids_to_tokens

vocab_file = "./vocab/bert-base-uncased-vocab.txt"
vocab, ids_to_tokens = load_vocab(vocab_file)

In [19]:
vocab

OrderedDict([('[PAD]', 0),
             ('[unused0]', 1),
             ('[unused1]', 2),
             ('[unused2]', 3),
             ('[unused3]', 4),
             ('[unused4]', 5),
             ('[unused5]', 6),
             ('[unused6]', 7),
             ('[unused7]', 8),
             ('[unused8]', 9),
             ('[unused9]', 10),
             ('[unused10]', 11),
             ('[unused11]', 12),
             ('[unused12]', 13),
             ('[unused13]', 14),
             ('[unused14]', 15),
             ('[unused15]', 16),
             ('[unused16]', 17),
             ('[unused17]', 18),
             ('[unused18]', 19),
             ('[unused19]', 20),
             ('[unused20]', 21),
             ('[unused21]', 22),
             ('[unused22]', 23),
             ('[unused23]', 24),
             ('[unused24]', 25),
             ('[unused25]', 26),
             ('[unused26]', 27),
             ('[unused27]', 28),
             ('[unused28]', 29),
             ('[unused29]', 30),
  

In [31]:
from utils.tokenizer import BasicTokenizer, WordpieceTokenizer

class BertTokenizer(object):
  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab, self.ids_to_tokens = load_vocab(vocab_file)

    never_split = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")

    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text):
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)
    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    ids = []
    for token in tokens:
      ids.append(self.vocab[token])
    return ids

  def convert_ids_to_tokens(self, ids):
    tokens = []
    for i in ids:
      tokens.append(self.ids_to_tokens[i])
    return tokens

In [32]:
text_1 = "[CLS] I accessed the bank account. [SEP]"
text_2 = "[CLS] He transferred the deposit money into the bank account. [SEP]"
text_3 = "[CLS] We play soccer at the bank of the river. [SEP]"

tokenizer = BertTokenizer(vocab_file="./vocab/bert-base-uncased-vocab.txt", do_lower_case=True)

tokenized_text_1 = tokenizer.tokenize(text_1)
tokenized_text_2 = tokenizer.tokenize(text_2)
tokenized_text_3 = tokenizer.tokenize(text_3)

print(tokenized_text_1)

['[CLS]', 'i', 'accessed', 'the', 'bank', 'account', '.', '[SEP]']


In [33]:
import numpy as np

indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
indexed_tokens_3 = tokenizer.convert_tokens_to_ids(tokenized_text_3)

bank_posi_1 = np.where(np.array(tokenized_text_1) == "bank")[0][0]
bank_posi_2 = np.where(np.array(tokenized_text_2) == "bank")[0][0]
bank_posi_3 = np.where(np.array(tokenized_text_3) == "bank")[0][0]

print(np.where(np.array(tokenized_text_1) == "bank"))

(array([4]),)


In [34]:
print(indexed_tokens_1)

[101, 1045, 11570, 1996, 2924, 4070, 1012, 102]


In [35]:
tokens_tensor_1 = torch.tensor([indexed_tokens_1])
tokens_tensor_2 = torch.tensor([indexed_tokens_2])
tokens_tensor_3 = torch.tensor([indexed_tokens_3])

bank_word_id = tokenizer.convert_tokens_to_ids(["bank"])[0]

print(tokens_tensor_1)
print(bank_word_id)

tensor([[  101,  1045, 11570,  1996,  2924,  4070,  1012,   102]])
2924


In [37]:
with torch.no_grad():
  encoded_layers_1, _ = net(tokens_tensor_1, output_all_encoded_layers=True)
  encoded_layers_2, _ = net(tokens_tensor_2, output_all_encoded_layers=True)
  encoded_layers_3, _ = net(tokens_tensor_3, output_all_encoded_layers=True)

In [38]:
bank_vector_0 = net.embeddings.word_embeddings.weight[bank_word_id]

bank_vector_1_1 = encoded_layers_1[0][0, bank_posi_1]
bank_vector_1_12 = encoded_layers_1[11][0, bank_posi_1]

bank_vector_2_1 = encoded_layers_2[0][0, bank_posi_2]
bank_vector_2_12 = encoded_layers_2[11][0, bank_posi_2]

bank_vector_3_1 = encoded_layers_3[0][0, bank_posi_3]
bank_vector_3_12 = encoded_layers_3[11][0, bank_posi_3]

In [39]:
import torch.nn.functional as F

print("similarity of first vector and first sentence 1 : ", F.cosine_similarity(bank_vector_0, bank_vector_1_1, dim=0))
print("similarity of first vector and 12th sentence 1 : ", F.cosine_similarity(bank_vector_0, bank_vector_1_12, dim=0))

print("similarity of 1st sentence 1 and 1st sentence 2 : ", F.cosine_similarity(bank_vector_1_1, bank_vector_2_1, dim=0))
print("similarity of 1st sentence 1 and 1st sentence 3 : ", F.cosine_similarity(bank_vector_1_1, bank_vector_3_1, dim=0))

print("similarity of 12th sentence 1 and 12th sentence 2 : ", F.cosine_similarity(bank_vector_1_12, bank_vector_2_12, dim=0))
print("similarity of 12th sentence 1 and 12th sentence 3 : ", F.cosine_similarity(bank_vector_1_12, bank_vector_3_12, dim=0))

similarity of first vector and first sentence 1 :  tensor(0.6814, grad_fn=<DivBackward0>)
similarity of first vector and 12th sentence 1 :  tensor(0.2276, grad_fn=<DivBackward0>)
similarity of 1st sentence 1 and 1st sentence 2 :  tensor(0.8968)
similarity of 1st sentence 1 and 1st sentence 3 :  tensor(0.7584)
similarity of 12th sentence 1 and 12th sentence 2 :  tensor(0.8796)
similarity of 12th sentence 1 and 12th sentence 3 :  tensor(0.4814)


In [41]:
import re
import string

def preprocessing_text(text):
  text = re.sub('<br />', '', text)

  for p in string.punctuation:
    if (p == ".") or (p == ","):
      continue
    else:
      text = text.replace(p, " ")

  text = text.replace(".", " . ")
  text = text.replace(",", " , ")
  return text

tokenizer_bert = BertTokenizer(vocab_file="./vocab/bert-base-uncased-vocab.txt", do_lower_case=True)

def tokenizer_with_preprocessing(text, tokenizer=tokenizer_bert.tokenize):
  text = preprocessing_text(text)
  ret = tokenizer(text)
  return ret

In [44]:
import torchtext

max_length = 256

TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer_with_preprocessing, use_vocab=True,
                            lower=True, include_lengths=True, batch_first=True, 
                            fix_length=max_length, init_token="[CLS]",
                            eos_token="[SEP]", pad_token="[PAD]", unk_token="[UNK]")
LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

In [46]:
import random

train_val_ds, test_ds = torchtext.data.TabularDataset.splits(
    path="./data/", train="IMDb_train.tsv", test="IMDb_test.tsv", format="tsv",
    fields=[('Text', TEXT), ('Label', LABEL)]
)

train_ds, val_ds = train_val_ds.split(split_ratio=0.8, random_state=random.seed(1234))

In [48]:
vocab_bert, ids_to_tokens_bert = load_vocab(vocab_file="./vocab/bert-base-uncased-vocab.txt")

TEXT.build_vocab(train_ds, min_freq=1)
TEXT.vocab.stoi = vocab_bert

In [49]:
batch_size = 32

train_dl = torchtext.data.Iterator(train_ds, batch_size=batch_size, train=True)
val_dl = torchtext.data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)
test_dl = torchtext.data.Iterator(test_ds, batch_size=batch_size, train=False, sort=False)

dataloaders_dict = {"train": train_dl, "val": val_dl}

In [50]:
batch = next(iter(val_dl))
print(batch.Text)
print(batch.Label)

(tensor([[  101,  2087,  1997,  ...,  2023,  3185,   102],
        [  101, 10166,  2061,  ...,     0,     0,     0],
        [  101,  6823, 13038,  ...,     0,     0,     0],
        ...,
        [  101,  2758,  5557,  ...,  2412,  2113,   102],
        [  101,  2023,  2143,  ...,     0,     0,     0],
        [  101,  1045,  2001,  ...,     0,     0,     0]]), tensor([256, 226, 199, 256, 209, 233, 144, 256, 256, 221, 225, 256, 128, 184,
        127, 256, 256, 256,  74, 256, 215, 176,  71, 256, 188, 256, 132, 256,
        256, 256, 135, 232]))
tensor([1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0,
        0, 1, 1, 1, 0, 1, 1, 1])


In [51]:
text_minibatch_1 = (batch.Text[0][1]).numpy()

text = tokenizer_bert.convert_ids_to_tokens(text_minibatch_1)

print(text)

['[CLS]', 'wow', 'so', 'much', 'fun', 'probably', 'a', 'bit', 'much', 'for', 'normal', 'american', 'kids', ',', 'and', 'really', 'it', 's', 'a', 'stretch', 'to', 'call', 'this', 'a', 'kid', 's', 'film', ',', 'this', 'movie', 'reminded', 'me', 'a', 'quite', 'a', 'bit', 'of', 'time', 'bandits', 'very', 'terry', 'gill', '##iam', 'all', 'the', 'way', 'through', '.', 'while', 'the', 'overall', 'narrative', 'is', 'pretty', 'much', 'straight', 'forward', ',', 'mi', '##ike', 'still', 'throws', 'in', 'a', 'lot', 'of', 'surreal', 'and', 'bun', '##uel', 'esq', '##ui', '##re', 'moments', '.', 'the', 'whole', 'first', 'act', 'violently', 'ju', '##xt', '##ap', '##oses', 'from', 'scene', 'to', 'scene', 'the', 'normal', 'family', 'life', 'of', 'the', 'main', 'kid', 'hero', ',', 'with', 'the', 'spirit', 'world', 'and', 'the', 'evil', 'than', 'is', 'ensuing', 'there', '##in', '.', 'and', 'while', 'the', 'ending', 'does', 'have', 'a', 'bit', 'of', 'an', 'ambiguous', 'aspect', 'that', 'are', 'common', 'of

In [52]:
class BertForIMDb(nn.Module):

  def __init__(self, net_bert):
    super(BertForIMDb, self).__init__()

    self.bert = net_bert

    self.cls = nn.Linear(in_features=768, out_features=2)

    nn.init.normal_(self.cls.weight, std=0.02)
    nn.init.normal_(self.cls.bias, 0)

  def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False, attention_show_flg=False):

    if attention_show_flg == True:
      encoded_layers, pooled_output, attention_probs = self.bert(
          input_ids, token_type_ids, attention_mask, output_all_encoded_layers, attention_show_flg
      )
    elif attention_show_flg == False:
      encoded_layers, pooled_output = self.bert(
          input_ids, token_type_ids, attention_mask, output_all_encoded_layers, attention_show_flg
      )

    vec_0 = encoded_layers[:, 0, :]
    vec_0 = vec_0.view(-1, 768)
    out = self.cls(vec_0)

    if attention_show_flg == True:
      return out, attention_probs
    elif attention_show_flg == False:
      return out

In [60]:
net = BertForIMDb(net_bert)

net.train()

print("network setting is completed.")

network setting is completed.


In [61]:
for name, param in net.named_parameters():
  param.requires_grad = False

for name, param in net.bert.encoder.layer[-1].named_parameters():
  param.requires_grad = True

for name, param in net.cls.named_parameters():
  param.requires_grad = True

In [62]:
import torch.optim as optim

optimizer = optim.Adam([
                        {'params': net.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
                        {'params': net.cls.parameters(), 'lr': 5e-5}
], betas=(0.9, 0.999))

criterion = nn.CrossEntropyLoss()

In [66]:
import time

def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  print("Using device: ", device)
  print("---------------start--------------------")

  net.to(device)

  torch.backends.cudnn.benchmark = True

  batch_size = dataloaders_dict["train"].batch_size

  for epoch in range(num_epochs):
    for phase in ['train', 'val']:
      if phase == "train":
        net.train()
      else:
        net.eval()

      epoch_loss = 0.0
      epoch_corrects = 0
      iteration = 1

      t_epoch_start = time.time()
      t_iter_start = time.time()

      for batch in (dataloaders_dict[phase]):

        inputs = batch.Text[0].to(device)
        labels = batch.Label.to(device)

        optimizer.zero_grad()

        with torch.set_grad_enabled(phase == 'train'):
          outputs = net(inputs, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False, attention_show_flg=False)

          loss = criterion(outputs, labels)

          _, preds = torch.max(outputs, 1)

          if phase == 'train':
            loss.backward()
            optimizer.step()

            if (iteration % 10 == 0):
              t_iter_finish = time.time()
              duration = t_iter_finish - t_iter_start
              acc = (torch.sum(preds == labels.data)).double()/batch_size
              print('Iteration {} || Loss: {:.4f} || 10iter: {:.4f} sec. || accuracy in this iteration: {}'.format(iteration, loss.item(), duration, acc))
              t_iter_start = time.time()

          iteration += 1

          epoch_loss += loss.item() * batch_size
          epoch_corrects += torch.sum(preds == labels.data)

      t_epoch_finish = time.time()
      epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
      epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)

      print('Epoch {}/{} | {:^5} | Loss: {:.4f} Acc: {:.4f}'.format(epoch+1, num_epochs, phase, epoch_loss, epoch_acc))

      t_epoch_start = time.time()

  return net

In [67]:
num_epochs = 2
net_trained = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

Using device:  cuda:0
---------------start--------------------
Iteration 10 || Loss: 0.6572 || 10iter: 1.7839 sec. || accuracy in this iteration: 0.6875
Iteration 20 || Loss: 0.6754 || 10iter: 1.7485 sec. || accuracy in this iteration: 0.53125
Iteration 30 || Loss: 0.6522 || 10iter: 1.7493 sec. || accuracy in this iteration: 0.6875
Iteration 40 || Loss: 0.6055 || 10iter: 1.7521 sec. || accuracy in this iteration: 0.875
Iteration 50 || Loss: 0.5998 || 10iter: 1.7533 sec. || accuracy in this iteration: 0.65625
Iteration 60 || Loss: 0.5424 || 10iter: 1.7547 sec. || accuracy in this iteration: 0.78125
Iteration 70 || Loss: 0.6364 || 10iter: 1.7546 sec. || accuracy in this iteration: 0.65625
Iteration 80 || Loss: 0.4751 || 10iter: 1.7524 sec. || accuracy in this iteration: 0.78125
Iteration 90 || Loss: 0.4884 || 10iter: 1.7558 sec. || accuracy in this iteration: 0.75
Iteration 100 || Loss: 0.3799 || 10iter: 1.7572 sec. || accuracy in this iteration: 0.84375
Iteration 110 || Loss: 0.3705 || 

In [72]:
from tqdm import tqdm

save_path = './weights/bert_fine_tuning_IMDb.pth'
torch.save(net_trained.state_dict(), save_path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net_trained.eval()
net_trained.to(device)

epoch_corrects = 0

for batch in tqdm(test_dl):
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  inputs = batch.Text[0].to(device)
  labels = batch.Label.to(device)

  with torch.set_grad_enabled(False):

    outputs = net_trained(inputs, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False, attention_show_flg=False)

    loss = criterion(outputs, labels)
    _, preds = torch.max(outputs, 1)
    epoch_corrects += torch.sum(preds == labels.data)

epoch_acc = epoch_corrects.double() / len(test_dl.dataset)
print('Test data {} accuracy : {:.4f} : '.format(len(test_dl.dataset), epoch_acc))

100%|██████████| 782/782 [02:04<00:00,  6.27it/s]

Test data 25000 accuracy : 0.9018 : 





In [74]:
batch_size = 64
test_dl = torchtext.data.Iterator(test_ds, batch_size=batch_size, train=False, sort=False)

In [75]:
batch = next(iter(test_dl))

inputs = batch.Text[0].to(device)
labels = batch.Label.to(device)

outputs, attention_probs = net_trained(inputs, token_type_ids=None,
                                       attention_mask=None, output_all_encoded_layers=False, 
                                       attention_show_flg=True)

_, preds = torch.max(outputs, 1)

In [78]:
def highlight(word, attn):
  html_color = '#%02X%02X%02X' % (255, int(255*(1- attn)), int(255*(1-attn)))
  return '<span style="background-color: {}"> {} </span>'.format(html_color, word)

def mk_html(index, batch, preds, normalized_weights, TEXT):

  sentence = batch.Text[0][index]
  label = batch.Label[index]
  pred = preds[index]

  if label == 0:
    label_str = "Negative"
  else:
    label_str = "Positive"

  if pred == 0:
    pred_str = "Negative"
  else:
    pred_str = "Positive"

  html = 'True Label: {}<br> Prediction Label: {}<br><br>'.format(label_str, pred_str)

  for i in range(12):
    attens = normalized_weights[index, i, 0, :]
    attens /= attens.max()

    html += '[Visualize Attention of BERT_' + str(i+1) + ']<br>'
    for word, attn in zip(sentence, attens):
      if tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0] == "[SEP]":
        break

      html += highlight(tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0], attn)
    html += "<br><br>"

  all_attens = attens*0
  for i in range(12):
    attens += normalized_weights[index, i, 0, :]
  attens /= attens.max()

  html += '[Visualize Attention of BERT_ALL]<br>'
  for word, attn in zip(sentence, attens):
    if tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0] == "[SEP]":
      break

    html += highlight(tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0], attn)
  html += "<br><br>"

  return html

In [79]:
from IPython.display import HTML

index = 3
html_output = mk_html(index, batch, preds, attention_probs, TEXT)
HTML(html_output)