### Load Data

In [1]:
from fastNLP.io import SSTLoader

# Loader for the Stanford Sentiment Treebank (SST) raw files.
loader = SSTLoader()
# Fetch the dataset into fastNLP's default cache directory and keep the path it reports.
data_dir = loader.download()
# Parse the files found there into a DataBundle of DataSets (train / dev / test).
data_bundle = loader.load(data_dir)

In [2]:
# Summarise which splits the bundle holds and how many instances each one has.
print(str(data_bundle))

In total 3 datasets:
	dev has 1101 instances.
	test has 2210 instances.
	train has 8544 instances.



### Preprocessing

In [3]:
from fastNLP.io import SSTPipe

# Preprocessing options:
# - granularity=5     -> five-way sentiment labels (the `target` vocab below has 5 entries).
# - train_subtree=True -> only the *training* trees are expanded into every sub-phrase,
#                         which is why train grows from 8544 sentences to ~318k instances
#                         while dev/test keep their original sizes.
# NOTE(review): lower=False keeps case in the word vocab, while the embedding later uses an
#               uncased BERT checkpoint — confirm BertEmbedding lower-cases words itself.
sst_pipe_options = dict(
    subtree=False,
    train_subtree=True,
    lower=False,
    granularity=5,
    tokenizer='spacy',
)
pipe = SSTPipe(**sst_pipe_options)

# Tokenize the raw text, then build the vocabularies and convert fields to indices.
# NOTE: this rebinds `data_bundle` to the processed result, so re-running the cell on its own
#       would process already-processed data; re-run from the loading cell instead.
data_bundle = pipe.process(data_bundle)

print(data_bundle)

In total 3 datasets:
	dev has 1101 instances.
	test has 2210 instances.
	train has 318582 instances.
In total 2 vocabs:
	words has 20204 entries.
	target has 5 entries.



In [4]:
# Show the first five training instances (a whole sentence followed by its sub-phrases).
first_five_train = data_bundle.get_dataset('train')[:5]
print(first_five_train)

+------------------------+--------+------------------------+---------+
| raw_words              | target | words                  | seq_len |
+------------------------+--------+------------------------+---------+
| The Rock is destine... | 1      | [21, 1215, 11, 5536... | 39      |
| The Rock               | 0      | [21, 1215]             | 2       |
| The                    | 0      | [21]                   | 1       |
| Rock                   | 0      | [1215]                 | 1       |
| is destined to be t... | 3      | [11, 5536, 8, 26, 2... | 37      |
+------------------------+--------+------------------------+---------+


In [5]:
# Word vocabulary built by the pipe; it maps tokens to integer ids and back.
vocab = data_bundle.get_vocab('words')
print(vocab)

Vocabulary(['The', 'Rock', 'is', 'destined', 'to']...)


In [6]:
# Round-trip a single token through the vocabulary: word -> id -> word.
index = vocab.to_index('new')
print(f"The index of the word 'new' is {index}")
print(f"index:{index} corresponds to the word {vocab.to_word(index)}")

The index of the word 'new' is 133
index:133 corresponds to the word new


### Word Embedding

In [7]:
from fastNLP.embeddings import BertEmbedding

# BERT embedding over the `words` vocabulary:
# - combines the last four hidden layers,
# - keeps the [CLS]/[SEP] positions (the classifier below reads the sentence representation),
# - stays trainable so BERT is fine-tuned together with the classifier.
bert_embed = BertEmbedding(
    data_bundle.get_vocab('words'),
    model_dir_or_name='en-base-uncased',
    layers='-4, -3, -2, -1',
    include_cls_sep=True,
    requires_grad=True,
)

loading vocabulary file /home/ubuntu/.fastNLP/embedding/bert-base-uncased/vocab.txt
Load pre-trained BERT parameters from file /home/ubuntu/.fastNLP/embedding/bert-base-uncased/pytorch_model.bin.


### Load Training/Testing/Validation Set

In [8]:
# Pull the three processed splits out of the bundle.
train_data = data_bundle.get_dataset('train')
test_data = data_bundle.get_dataset('test')
val_data = data_bundle.get_dataset('dev')

# Report split sizes (note: train is the sub-phrase-expanded set).
print(f"#entries in training set:{len(train_data)}\n"
      f"#entries in testing set:{len(test_data)}\n"
      f"#entries in validation set:{len(val_data)}\n")

#entries in training set:318582
#entries in testing set:2210
#entries in validation set:1101



In [9]:
# NOTE: when iterating with a DataSetIter, fields whose `is_input` is True are collected into
#       batch_x, and fields whose `is_target` is True are collected into batch_y.

# `print_field_meta` prints the table itself and also returns it; the trailing `;` stops
# Jupyter from additionally echoing the returned object's `<PrettyTable at 0x...>` repr.
train_data.print_field_meta();

+-------------+-----------+--------+-------+---------+
| field_names | raw_words | target | words | seq_len |
+-------------+-----------+--------+-------+---------+
|   is_input  |   False   | False  |  True |   True  |
|  is_target  |   False   |  True  | False |  False  |
| ignore_type |           | False  | False |  False  |
|  pad_value  |           |   0    |   0   |    0    |
+-------------+-----------+--------+-------+---------+


<prettytable.prettytable.PrettyTable at 0x7f74488c03d0>

### Create Model

In [10]:
import random

import numpy as np
import torch
from fastNLP.models import BertForSequenceClassification

# Seed every RNG that model construction and training may draw from (Python, NumPy,
# PyTorch CPU and all GPUs) *before* building the model, so re-runs are repeatable.
# Bit-exact results on GPU may still depend on nondeterministic CUDA kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# BERT embedding plus a classification head; the second argument is the number of labels,
# taken from the size of the `target` vocabulary (5 for SST-5).
model_bertnn = BertForSequenceClassification(bert_embed, len(data_bundle.get_vocab('target')))

print(model_bertnn)

BertForSequenceClassification(
  (bert): BertEmbedding(
    (dropout_layer): Dropout(p=0, inplace=False)
    (model): _BertWordModel(
      (encoder): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )

### Evaluation Metric

In [11]:
from fastNLP import AccuracyMetric
from fastNLP import Const

# Accuracy computed by matching field names:
# - `target` is the DataSet field holding the gold label,
# - `pred` is the key in the dict returned by the model's `forward`.
metrics = AccuracyMetric(target=Const.TARGET, pred=Const.OUTPUT)

### Loss Function & Optimizer

In [12]:
from fastNLP import CrossEntropyLoss

# Cross-entropy between the model's `pred` output and the gold `target` field.
loss = CrossEntropyLoss(target=Const.TARGET, pred=Const.OUTPUT)

In [13]:
from fastNLP import Adam

# Adam over all model parameters (BERT included) with a small fine-tuning learning rate.
optimizer = Adam(lr=2e-5, model_params=model_bertnn.parameters())


### Train the Model

In [14]:
from fastNLP import Trainer
import torch

# Training hyper-parameters.
N_EPOCHS = 5
BATCH_SIZE = 16
# fastNLP's Trainer treats `save_path` as a *directory* (created if missing) and writes the
# best-on-dev checkpoint inside it under a generated file name. A directory name is used here;
# the previous `.pt` value silently became a directory called `sst5-bert.pt`.
SAVE_DIR = './saved_models/sst5-bert'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Evaluate on the dev split after every epoch with `metrics`; the Trainer keeps and reloads
# the best-scoring model once training finishes.
trainer = Trainer(train_data=train_data, model=model_bertnn,
                  optimizer=optimizer,
                  loss=loss, device=device,
                  batch_size=BATCH_SIZE, dev_data=val_data,
                  metrics=metrics, n_epochs=N_EPOCHS, print_every=1,
                  save_path=SAVE_DIR)

trainer.train()

input fields after batch(if batch size is 2):
	words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 39]) 
	seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
target fields after batch(if batch size is 2):
	target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 

training epochs started 2021-04-19-11-41-27-421164


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=99560.0), HTML(value='')), layout=Layout(…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=69.0), HTML(value='')), layout=Layout(dis…

Evaluate data in 2.22 seconds!
Evaluation on dev at Epoch 1/5. Step:19912/99560: 
AccuracyMetric: acc=0.53406



HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=69.0), HTML(value='')), layout=Layout(dis…

Evaluate data in 2.2 seconds!
Evaluation on dev at Epoch 2/5. Step:39824/99560: 
AccuracyMetric: acc=0.545867



HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=69.0), HTML(value='')), layout=Layout(dis…

Evaluate data in 2.21 seconds!
Evaluation on dev at Epoch 3/5. Step:59736/99560: 
AccuracyMetric: acc=0.526794



HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=69.0), HTML(value='')), layout=Layout(dis…

Evaluate data in 2.2 seconds!
Evaluation on dev at Epoch 4/5. Step:79648/99560: 
AccuracyMetric: acc=0.510445



HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=69.0), HTML(value='')), layout=Layout(dis…

Evaluate data in 2.2 seconds!
Evaluation on dev at Epoch 5/5. Step:99560/99560: 
AccuracyMetric: acc=0.520436

Reloaded the best model.

In Epoch:2/Step:39824, got best dev performance:
AccuracyMetric: acc=0.545867


{'best_eval': {'AccuracyMetric': {'acc': 0.545867}},
 'best_epoch': 2,
 'best_step': 39824,
 'seconds': 9266.43}

### Test the Model

In [15]:
from fastNLP import Tester

# Score the best model (reloaded by the Trainer) on the held-out test split. This uses the same
# accuracy metric, batch size and device as the dev evaluation during training, so the test
# number is directly comparable to the dev numbers.
tester = Tester(test_data, model_bertnn, metrics=metrics,
                batch_size=BATCH_SIZE, device=device)
tester.test()

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=139.0), HTML(value='')), layout=Layout(di…

Evaluate data in 4.45 seconds!
[tester] 
AccuracyMetric: acc=0.533484


{'AccuracyMetric': {'acc': 0.533484}}