References
https://skimai.com/fine-tuning-bert-for-sentiment-analysis/
https://github.com/huggingface/transformers


In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from copy import deepcopy
import json
import datetime

In [2]:
cache_dir = '/output/cache'
compress_data_filename = '/data/52WangRuicheng/aclImdb/aclImdb_v1.tar.gz'
embedding_filename = '/data/52WangRuicheng/glove6B/glove.6B.100d.txt'
embed_size = 100
num_class = 2
min_freq = 3        # 用到的词最少出现的次数
split_rate = 0.2

In [3]:
import os
import tarfile
from tqdm import tqdm
import pickle
import random

def read_imdb(type_:str) -> list:
    '''
    Read the imdb dataset, and return list of data like [[TEXT:str, LABEL:int],...]\n
    The unzipped file will be cached.
    '''
    if os.path.exists(os.path.join(cache_dir, f'{type_}.pkl')):
        print('Cache found.')
        with open(os.path.join(cache_dir, f'{type_}.pkl'), 'rb') as f:
            data = pickle.load(f)
        return data

    if not os.path.exists(os.path.join(cache_dir, "aclImdb")):
        print(f">>> Decompressing {compress_data_filename}...")
        with tarfile.open(compress_data_filename, 'r') as f:
            f.extractall(cache_dir)
        print(">>> Done. ")

    data = []
    print(">>> Reading cached decompressed files...")
    for label in ['pos', 'neg']:
        folder_name = os.path.join(cache_dir, 'aclImdb', type_, label)
        for file in tqdm(os.listdir(folder_name)):
            with open(os.path.join(folder_name, file), 'rb') as f:
                review = f.read().decode('utf-8').replace('\n', '').lower()
                data.append((review, 1 if label == 'pos' else 0))
    print(">>> Done")
    random.shuffle(data)  
    with open(os.path.join(cache_dir, f'{type_}.pkl'), 'wb') as f:
        pickle.dump(data, f, protocol=1)
    return data

train_raw, test_raw = read_imdb('train'), read_imdb('test')

Cache found.
Cache found.


In [4]:
from transformers import BertTokenizer
import re

# Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
def text_preprocessing(text):
    """
    - Remove entity mentions (eg. '@united')
    - Correct errors (eg. '&amp;' to '&')
    """
    text = re.sub(r'(@.*?)[\s]', ' ', text)
    text = re.sub(r'&amp;', '&', text)
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Create a function to tokenize a set of texts
def preprocessing(data):
    """Perform required preprocessing steps for pretrained BERT.
    @param    data (np.array): Array of texts to be processed.
    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model.
    """
    # Create empty lists to store outputs
    input_ids = []
    attention_masks = []
    labels = []
    
    # For every sentence...
    for sent, label in tqdm(data):
        encoded_sent = tokenizer.encode_plus(
            text=text_preprocessing(sent),  # Preprocess sentence
            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
            max_length=256,                # Max length to truncate/pad
            padding='max_length',          # Pad sentence to max length
            truncation=True,
            return_attention_mask=True      # Return attention mask
            )
        # Add the outputs to the lists
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))
        labels.append(label)

    return input_ids, attention_masks, labels

In [5]:
if os.path.exists(os.path.join(cache_dir, 'train_processed.pkl')):
    print('Cache found.')
    with open(os.path.join(cache_dir, 'train_processed.pkl'), 'rb') as f:
        input_ids, attention_masks, labels = pickle.load(f)
else:
    input_ids, attention_masks, labels = preprocessing(train_raw)
    input_ids, attention_masks, labels = torch.tensor(input_ids), torch.tensor(attention_masks), torch.tensor(labels)
    with open(os.path.join(cache_dir, 'train_processed.pkl'), 'wb') as f:
        pickle.dump((input_ids, attention_masks, labels), f, protocol=1)

Cache found.


In [6]:
if os.path.exists(os.path.join(cache_dir, 'test_processed.pkl')):
    print('Cache found.')
    with open(os.path.join(cache_dir, 'test_processed.pkl'), 'rb') as f:
        test_input_ids, test_attention_masks, test_labels = pickle.load(f)
else:
    test_input_ids, test_attention_masks, test_labels = preprocessing(test_raw)
    test_input_ids, test_attention_masks, test_labels = torch.tensor(test_input_ids), torch.tensor(test_attention_masks), torch.tensor(test_labels)
    with open(os.path.join(cache_dir, 'test_processed.pkl'), 'wb') as f:
        pickle.dump((test_input_ids, test_attention_masks, test_labels), f, protocol=1)

Cache found.


In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

# Create the BertClassfier class
class BertClassifier(nn.Module):
    """Bert Model for Classification Tasks.
    """
    def __init__(self, freeze_bert=False):
        super(BertClassifier, self).__init__()
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 64, 1

        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Instantiate an one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT model
        self.freeze_bert = freeze_bert
        
    def forward(self, input_ids, attention_mask):
        # Feed input to BERT
        if self.freeze_bert:
            with torch.no_grad():
                outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        else:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Extract the last hidden state of the token `[CLS]` for classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed input to classifier to compute logits
        logits = torch.sigmoid(self.classifier(last_hidden_state_cls)).squeeze()

        return logits

In [19]:
def train_epoch(model:nn.Module, optimizer:torch.optim.Optimizer, dataloader:DataLoader, loss_fn) -> None:
    model.train()
    epoch_loss = 0.
    num_batch = len(dataloader)
    for i_batcch, (seq, mask, label) in enumerate(dataloader):
        seq, mask, label = seq.cuda(), mask.cuda(), label.cuda().float()
        y = model(seq, mask)
        loss = loss_fn(y, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        print(f'[Train] Loss {epoch_loss / (i_batcch + 1):>10.5f}, [{i_batcch + 1:5d} / {num_batch:5d}]', end='\r')
    epoch_loss = epoch_loss / num_batch
    print()
    return {'train_loss':epoch_loss}


def val_epoch(model:nn.Module, dataloader:DataLoader, loss_fn):
    model.eval()
    epoch_loss = 0.
    acc = 0
    num_batch = len(dataloader)
    with torch.no_grad():
        for i_batcch, (seq, mask, label) in enumerate(dataloader):
            seq, mask, label = seq.cuda(), mask.cuda(), label.cuda().float()
            y = model(seq, mask)
            loss = loss_fn(y, label)

            epoch_loss += loss.item()
            acc += (torch.round(y) == label).float().mean().item()
            print(f'[Val  ] [{i_batcch + 1} / {num_batch}]', end='\r')
    acc = acc / num_batch
    epoch_loss = epoch_loss / num_batch
    print(f'[Val  ] Loss {epoch_loss:>10.5f}, Acc {acc:>10.5f}.')
    return {'val_acc':acc, 'val_loss':epoch_loss}

def test(model:nn.Module, dataloader:DataLoader):
    model.eval()

    acc = 0
    num_batch = len(dataloader)
    with torch.no_grad():
        for i_batcch, (seq, mask, label) in enumerate(dataloader):
            seq, mask, label = seq.cuda(), mask.cuda(), label.cuda().float()
            y = model(seq, mask)

            acc += (torch.round(y) == label).float().mean().item()
            print(f'[Test  ] [{i_batcch + 1} / {num_batch}]', end='\r')
    acc = acc / num_batch
    print(f'[Test ] Acc {acc:>10.5f}.')
    return {'test_acc':acc}

In [9]:
sp = int(input_ids.shape[0] * 0.8)
train_input_ids, train_attention_masks, train_labels = input_ids[:sp], attention_masks[:sp], labels[:sp]
val_input_ids, val_attention_masks, val_labels = input_ids[sp:], attention_masks[sp:], labels[sp:]

train_dataloader = DataLoader(list(zip(train_input_ids.unbind(0), train_attention_masks.unbind(0), train_labels.unbind(0))), batch_size=32, shuffle=True)
val_dataloader = DataLoader(list(zip(val_input_ids.unbind(0), val_attention_masks.unbind(0), val_labels.unbind(0))), batch_size=32, shuffle=True)
test_dataloader = DataLoader(list(zip(test_input_ids.unbind(0), test_attention_masks.unbind(0), test_labels.unbind(0))), batch_size=32, shuffle=True)

In [30]:
model = BertClassifier(freeze_bert=False)
model = nn.DataParallel(model).cuda()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Split train & validation dataset
import numpy as np

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
bce_loss_fn = nn.BCELoss()

num_epochs = 15

best_val_loss = np.inf
best_step = 0
best_model = None

logs = {'test_acc':0., 'train_loss':[], 'val_acc':[], 'val_loss':[]}
for i in range(num_epochs):
    print(f'Epoch {i}')
    
    train_log = train_epoch(model, optimizer, train_dataloader, bce_loss_fn)
    val_log = val_epoch(model, val_dataloader, bce_loss_fn)

    for value, key in enumerate(train_log):
        logs[key].append(value)
    for value, key in enumerate(val_log):
        logs[key].append(value)
    
    val_loss = val_log['val_loss']
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_step = i
        best_model = deepcopy(model)
    elif i - best_step > 3:
        break

test_log = test(model, test_dataloader)
logs['test_acc'] = test_log['test_acc']


log_save_dir = '/output'
with open(os.path.join(log_save_dir, f'logs_{datetime.datetime.now()}.json'), 'w') as f:
    json.dump(logs, f)
torch.save(best_model, os.path.join(log_save_dir, f'model_{datetime.datetime.now()}.pth'))

Epoch 0
[Train] Loss    0.07761, [  625 /   625]
[Val  ] Loss    0.29928, Acc    0.90107.
Epoch 1
[Train] Loss    0.04217, [  248 /   625]

In [33]:
test_log = test(model, test_dataloader)

[Test ] Acc    0.91392.
