In [1]:
import torch
from transformers import BertConfig, BertTokenizerFast, BertForSequenceClassification, AdamW

from data import SentiDataset
from learner import SentimentLearner

# Fix the PyTorch RNG seed so repeated runs start from the same random state.
torch.manual_seed(41)

# Derive the configuration from the same checkpoint the weights are loaded from,
# overriding only num_labels. A bare BertConfig(...) relies on the library's
# default hyperparameters happening to match 'bert-base-uncased'; reading the
# checkpoint's own config removes that implicit coupling.
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=1)

### Hyperparameters

In [2]:
# Batch size handed to SentimentLearner via its batch_size argument.
BATCH_SIZE = 64
# Optimizer class (the class itself, not an instance) handed to SentimentLearner
# via optim_cls; this is the AdamW imported from transformers above.
OPTIM_CLS = AdamW

### Load data

In [3]:
# Filename template for the splits; '{}' is filled with the split name.
FILENAME = 'senti.{}.tsv'

# One tokenizer instance is shared by all three datasets.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def _load_split(split):
    """Build the SentiDataset for the given split name using the shared tokenizer."""
    return SentiDataset(FILENAME.format(split), tokenizer=tokenizer)


# Splits are constructed in train -> dev -> test order.
train_set, valid_set, test_set = map(_load_split, ('train', 'dev', 'test'))

# Fine-tune BERT

### Load model and learner

In [4]:
# Load the 'bert-base-uncased' checkpoint into a sequence-classification model
# shaped by `config`. The recorded warning below reports that some checkpoint
# weights are unused and some model weights are not initialised from the
# checkpoint (presumably the new classification head -- the warning text is cut
# off, so confirm on a fresh run).
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [5]:
# The three dataset splits the learner trains, validates and tests on.
dataset_splits = dict(
    train_set=train_set,
    valid_set=valid_set,
    test_set=test_set,
)

# Optimisation settings; note the learning rate is fixed here rather than in
# the hyperparameter cell.
training_settings = dict(
    batch_size=BATCH_SIZE,
    optim_cls=OPTIM_CLS,
    lr=1e-6,
)

# Wire the model, data and settings into the project's training wrapper.
learner = SentimentLearner(model=model, **dataset_splits, **training_settings)

### Train model

In [6]:
# Path the learner writes model parameters to; reused later to reload them.
bert_filename = 'bert.pt'
# Train for 10 epochs. NOTE(review): judging by the recorded output, parameters
# are saved only on epochs where validation loss improves (epochs 1-4 here), so
# the file would hold the best-validation-loss model -- confirm in learner.py.
learner.train(epochs=10, filename=bert_filename)

Epoch 01: 100%|██████████| 1053/1053 [06:39<00:00,  2.64it/s, Loss=0.498, Acc=0.762]


	Train Loss: 0.428	Train Acc: 80.39%
	Valid Loss: 0.281	Valid Acc: 89.56%
	Model parameters saved to bert.pt


Epoch 02: 100%|██████████| 1053/1053 [06:43<00:00,  2.61it/s, Loss=0.237, Acc=0.905]


	Train Loss: 0.263	Train Acc: 89.76%
	Valid Loss: 0.250	Valid Acc: 90.25%
	Model parameters saved to bert.pt


Epoch 03: 100%|██████████| 1053/1053 [06:38<00:00,  2.64it/s, Loss=0.224, Acc=0.952]


	Train Loss: 0.226	Train Acc: 91.20%
	Valid Loss: 0.239	Valid Acc: 91.06%
	Model parameters saved to bert.pt


Epoch 04: 100%|██████████| 1053/1053 [07:50<00:00,  2.24it/s, Loss=0.0502, Acc=1]    


	Train Loss: 0.201	Train Acc: 92.23%
	Valid Loss: 0.231	Valid Acc: 91.40%
	Model parameters saved to bert.pt


Epoch 05: 100%|██████████| 1053/1053 [07:22<00:00,  2.38it/s, Loss=0.43, Acc=0.81]   


	Train Loss: 0.183	Train Acc: 93.06%
	Valid Loss: 0.232	Valid Acc: 91.06%


Epoch 06: 100%|██████████| 1053/1053 [08:11<00:00,  2.14it/s, Loss=0.0756, Acc=1]    


	Train Loss: 0.168	Train Acc: 93.74%
	Valid Loss: 0.241	Valid Acc: 91.17%


Epoch 07: 100%|██████████| 1053/1053 [09:03<00:00,  1.94it/s, Loss=0.0895, Acc=0.952]


	Train Loss: 0.154	Train Acc: 94.37%
	Valid Loss: 0.245	Valid Acc: 91.51%


Epoch 08: 100%|██████████| 1053/1053 [08:47<00:00,  2.00it/s, Loss=0.0409, Acc=1]    


	Train Loss: 0.144	Train Acc: 94.79%
	Valid Loss: 0.249	Valid Acc: 91.17%


Epoch 09: 100%|██████████| 1053/1053 [08:32<00:00,  2.05it/s, Loss=0.139, Acc=0.952] 


	Train Loss: 0.133	Train Acc: 95.22%
	Valid Loss: 0.262	Valid Acc: 91.17%


Epoch 10: 100%|██████████| 1053/1053 [08:23<00:00,  2.09it/s, Loss=0.0351, Acc=1]    


	Train Loss: 0.125	Train Acc: 95.46%
	Valid Loss: 0.266	Valid Acc: 90.83%


### Load best model to evaluate

In [7]:
# Restore the parameters that training wrote to bert_filename, so evaluation
# uses the saved checkpoint rather than the weights left after the final epoch.
learner.load_model_params(bert_filename)

In [8]:
%%time
# Print loss and accuracy on the test split; %%time reports how long the evaluation takes.
learner.print_test_results()

	 Test Loss: 0.208	 Test Acc: 91.60%
CPU times: user 2.86 s, sys: 1.88 s, total: 4.74 s
Wall time: 4.75 s
