In [1]:
import torch

from bert_data_utils import get_raw_imdb_data

In [2]:
# Load the dataset; the three returned values are unpacked as the
# train / validation / test splits.
train_data, valid_data, test_data = get_raw_imdb_data()

In [3]:
from transformers import BertTokenizer

In [4]:
# Load the pretrained tokenizer for the 'bert-base-cased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [5]:
def bert_tokenized_data(tokenizer, data, max_seq_len=128, pad_to_max_len=True):
    """Encode every example in `data` with `tokenizer` and build label tensors.

    Args:
        tokenizer: Object exposing ``encode_plus(text, max_length=...,
            pad_to_max_length=...)`` that returns a mapping with the keys
            'input_ids', 'token_type_ids' and 'attention_mask'.
        data: Re-iterable collection of examples. Each example has a ``text``
            attribute (sequence of string tokens) and a ``label`` attribute.
        max_seq_len: Forwarded to ``encode_plus`` as ``max_length``.
        pad_to_max_len: Forwarded to ``encode_plus`` as ``pad_to_max_length``.

    Returns:
        Four lists with one tensor per example, in input order:
        (input_ids, token_type_ids, attn_mask, labels). Each label is a
        one-element tensor holding 1 when ``label == 'pos'`` and 0 otherwise.
    """
    # Tokens are re-joined with single spaces, e.g. "I am so ... good ."
    texts = [' '.join(example.text) for example in data]
    # One-element tensors: [1] for 'pos', [0] for anything else.
    labels = [torch.tensor([int(example.label == 'pos')]) for example in data]

    input_ids, token_type_ids, attn_mask = [], [], []
    for text in texts:
        encoded = tokenizer.encode_plus(
            text,
            max_length=max_seq_len,
            pad_to_max_length=pad_to_max_len,
        )
        input_ids.append(torch.tensor(encoded['input_ids']))
        token_type_ids.append(torch.tensor(encoded['token_type_ids']))
        attn_mask.append(torch.tensor(encoded['attention_mask']))

    return input_ids, token_type_ids, attn_mask, labels

In [6]:
# Tokenize each split with its own data. The validation and test tensors were
# previously built from `train_data`, so evaluation ran on training examples.
train_input_ids, train_token_type_ids, train_attn_mask, train_labels = bert_tokenized_data(tokenizer, train_data)
valid_input_ids, valid_token_type_ids, valid_attn_mask, valid_labels = bert_tokenized_data(tokenizer, valid_data)
test_input_ids, test_token_type_ids, test_attn_mask, test_labels = bert_tokenized_data(tokenizer, test_data)

In [7]:
from bert_dataset import Corpus
from torch.utils.data import Dataset, DataLoader

In [8]:
from torch.utils.data import Dataset, DataLoader


class Corpus(Dataset):
    """Dataset over four parallel, index-aligned sequences.

    Item ``i`` is the tuple
    ``(input_ids[i], token_type_ids[i], attn_masks[i], labels[i])`` and the
    dataset length is ``len(input_ids)``.

    Note: the preceding cell also imports a ``Corpus`` from ``bert_dataset``;
    this definition rebinds that name when the cell runs.
    """

    def __init__(self, input_ids:list, token_type_ids:list, attn_masks:list, labels:list):
        super().__init__()
        self.input_ids, self.token_type_ids = input_ids, token_type_ids
        self.attn_masks, self.labels = attn_masks, labels

    def __getitem__(self, index: int):
        # Pick the same position out of every parallel sequence.
        columns = (self.input_ids, self.token_type_ids, self.attn_masks, self.labels)
        return tuple(column[index] for column in columns)

    def __len__(self):
        # Length is taken from input_ids only.
        return len(self.input_ids)

In [9]:
# Wrap each tokenized split in a Corpus dataset.
# NOTE: a later cell defines a function that is also named `train`, which
# rebinds this name; the loaders below keep their own reference to the datasets.
train = Corpus(train_input_ids, train_token_type_ids, train_attn_mask, train_labels)
valid = Corpus(valid_input_ids, valid_token_type_ids, valid_attn_mask, valid_labels)
test = Corpus(test_input_ids, test_token_type_ids, test_attn_mask, test_labels)

# DataLoader settings for training.
params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 6}

# Evaluation uses a fixed order: metrics are averaged per batch and the last
# batch can be smaller, so shuffling would make validation/test numbers (and
# therefore best-checkpoint selection) vary from run to run.
eval_params = {**params, 'shuffle': False}

train_loader = DataLoader(train, **params)
valid_loader = DataLoader(valid, **eval_params)
test_loader = DataLoader(test, **eval_params)

In [10]:
from transformers import BertModel, BertPreTrainedModel, BertForSequenceClassification

In [11]:
# Define the BERT model.
# The checkpoint name is the same 'bert-base-cased' string the tokenizer was
# loaded with.
bert_config = 'bert-base-cased'
model = BertForSequenceClassification.from_pretrained(bert_config)

In [18]:
def get_device():
    """Return the device string to run on: 'cuda' if CUDA is available, else 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

def get_num_corrects(logits, labels):
    """Count how many rows of `logits` predict the matching label.

    Args:
        logits: 2-D tensor; the predicted class of a row is the index of its
            largest value along dimension 1.
        labels: Either a 1-D tensor of class indices, or a tensor with a
            second dimension (e.g. shape (batch, 1)), in which case the
            maximum along dimension 1 is used as the target for each row.

    Returns:
        0-dim integer tensor holding the number of correct predictions.
    """
    predictions = logits.max(1)[1]
    # 1-D labels are used as-is; previously `labels.max(1)` raised on them.
    targets = labels.max(1)[0] if labels.dim() > 1 else labels
    return (predictions == targets).sum()

def train(model, optim, iterator, device):
    """Run one training epoch over `iterator`, updating `model` in place.

    Args:
        model: Module called with ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and ``labels`` keyword arguments. Its output
            must be indexable with the loss at position 0 and the logits at
            position 1.
        optim: Optimizer exposing ``zero_grad()`` and ``step()``.
        iterator: Sized iterable of batches; each batch is indexable and holds
            (input_ids, token_type_ids, attn_masks, labels) at positions 0-3.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``: the per-batch loss and the
        per-batch accuracy (in percent), each averaged over the number of
        batches.
    """
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for batch in iterator:
        optim.zero_grad()

        input_ids, token_type_ids, attn_masks, labels = (batch[i].to(device) for i in range(4))
        outputs = model(input_ids=input_ids, attention_mask=attn_masks, token_type_ids=token_type_ids, labels=labels)
        # Index instead of tuple-unpacking: this works for plain tuples, for
        # outputs carrying extra items, and for ModelOutput objects (whose
        # iteration yields keys rather than values).
        loss, logits = outputs[0], outputs[1]

        num_corrects = get_num_corrects(logits, labels)
        acc = 100.0 * num_corrects.item() / labels.size(0)
        epoch_loss += loss.item()
        epoch_acc += acc

        loss.backward()
        optim.step()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, otpim, iterator, device):
    """Run one evaluation pass over `iterator` without updating `model`.

    Args:
        model: Module called with ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and ``labels`` keyword arguments. Its output
            must be indexable with the loss at position 0 and the logits at
            position 1.
        otpim: Unused. The parameter (and its spelling) is kept so existing
            positional and keyword callers keep working.
        iterator: Sized iterable of batches; each batch is indexable and holds
            (input_ids, token_type_ids, attn_masks, labels) at positions 0-3.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``: the per-batch loss and the
        per-batch accuracy (in percent), each averaged over the number of
        batches.
    """
    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in iterator:
            # This line was under-indented relative to the loop body, which
            # made the whole cell a syntax error.
            input_ids, token_type_ids, attn_masks, labels = (batch[i].to(device) for i in range(4))
            outputs = model(input_ids=input_ids, attention_mask=attn_masks, token_type_ids=token_type_ids, labels=labels)
            # Index instead of tuple-unpacking so tuple and ModelOutput
            # return values are both handled.
            loss, logits = outputs[0], outputs[1]

            num_corrects = get_num_corrects(logits, labels)
            acc = 100.0 * num_corrects.item() / labels.size(0)
            epoch_loss += loss.item()
            epoch_acc += acc

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [19]:
N_EPOCHS = 5
# Adam's default learning rate (1e-3) is far above the range normally used to
# fine-tune BERT (about 2e-5 to 5e-5).
LEARNING_RATE = 2e-5

best_valid_loss = float('inf')

device = get_device()
# Move the model to the target device before building the optimizer.
model = model.to(device)
# `optim` was never imported in this notebook; reach Adam through `torch`.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(N_EPOCHS):

    train_loss, train_acc = train(model, optimizer, train_loader, device)
    valid_loss, valid_acc = evaluate(model, optimizer, valid_loader, device)

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './bert_base_cased_sentence_classification.pt')

    print(f'Train Loss: {train_loss} | Train Acc: {train_acc}%')
    print(f'Val Loss: {valid_loss} |  Val Acc: {valid_acc}%')

tensor(35) 64
0.546875
torch.Size([64, 1])
tensor(0.7217, grad_fn=<NllLossBackward>) tensor([[ 0.2341, -0.4232],
        [ 0.1322, -0.3166],
        [ 0.2847, -0.4331],
        [ 0.2817, -0.6284],
        [ 0.5515, -0.5937],
        [ 0.0871, -0.3191],
        [ 0.1095, -0.5467],
        [ 0.0249, -0.3276],
        [ 0.2986, -0.0950],
        [ 0.4858, -0.8031],
        [ 0.5034, -0.4779],
        [ 0.3899, -0.6690],
        [ 0.4067, -0.4875],
        [ 0.2457, -0.5320],
        [ 0.2018, -0.3728],
        [ 0.1991, -0.2996],
        [ 0.3407, -0.3948],
        [ 0.2467, -0.7128],
        [ 0.3178, -0.2568],
        [ 0.4158, -0.2844],
        [ 0.1323, -0.3908],
        [ 0.3363, -0.0263],
        [ 0.3544, -0.3458],
        [ 0.5303, -0.5723],
        [ 0.2716, -0.4681],
        [ 0.3004, -0.2291],
        [ 0.3256, -0.5786],
        [ 0.3188, -0.7150],
        [ 0.1128, -0.3715],
        [ 0.2284, -0.1627],
        [ 0.3739, -0.2684],
        [ 0.5286, -0.5271],
        [ 0.2384, -