In [1]:
import utils_data as ut
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertForMaskedLM, BertTokenizer
tokenizerBERT = BertTokenizer.from_pretrained('pranav-s/MaterialsBERT', model_max_length=512)
modelBERT = BertForMaskedLM.from_pretrained('pranav-s/MaterialsBERT')

In [2]:
classes = {'POLYMER': 1,
           'ORGANIC': 2,
           'MONOMER': 3,
           'PROP_NAME': 4,
           'INORGANIC': 5,
           'MATERIAL_AMOUNT': 6,
           'POLYMER_FAMILY': 7,
           'PROP_VALUE': 8,
           'O': 0}
max_length = 512
batch_size = 3
class NERBERTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = modelBERT.base_model
        self.linear = nn.Linear(768, len(classes) + 1)
        
    def forward(self, token):
        encoder_output= self.bert(token)  # torch.LongTensor of shape (batch_size, sequence_length)
        linear_output = self.linear(encoder_output.last_hidden_state)
        class_output = F.softmax(linear_output, dim=2)
        return class_output

In [3]:
model = NERBERTModel()

In [4]:
num_data = 100
data_list = ut.read_data('train.json', max_length)[:num_data]

In [8]:
token_tensors_all_list = [ut.list2token(tokenizerBERT, d['words'], max_length) for d in data_list]
data = torch.cat(token_tensors_all_list, dim=0)
data_batches = ut.to_batches(data, batch_size)
target_tensors_all_list = [ut.cat2digit(classes, d['ner'], max_length) for d in data_list]
target = torch.stack(target_tensors_all_list, dim=0)
target_batches = ut.to_batches(target, batch_size)

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)

In [12]:
epochs = 2
train_losses = []
for epoch in range(epochs):
    for b, X in enumerate(data_batches):
        y_pred = model(X)
        y_pred = torch.swapaxes(y_pred, 1, 2)
        y = target_batches[b]
        
        loss = criterion(y_pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'epoch: {epoch:2}  batch: {b:4}  loss: {loss.item():10.8f}')
    train_losses.append(loss)
    

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


epoch:  0  batch:    0  loss: 2.34516931
epoch:  0  batch:    1  loss: 2.07965207
epoch:  0  batch:    2  loss: 1.89227831
epoch:  0  batch:    3  loss: 1.86859322
epoch:  0  batch:    4  loss: 1.74767876
epoch:  0  batch:    5  loss: 1.58365214
epoch:  0  batch:    6  loss: 1.53962743
epoch:  0  batch:    7  loss: 1.52590048
epoch:  0  batch:    8  loss: 1.53084755
epoch:  0  batch:    9  loss: 1.52945483
epoch:  0  batch:   10  loss: 1.53164637
epoch:  0  batch:   11  loss: 1.55527401
epoch:  0  batch:   12  loss: 1.53272164
epoch:  0  batch:   13  loss: 1.52031934
epoch:  0  batch:   14  loss: 1.50896406
epoch:  0  batch:   15  loss: 1.51920211
epoch:  0  batch:   16  loss: 1.53089154
epoch:  0  batch:   17  loss: 1.52350187
epoch:  0  batch:   18  loss: 1.51018870
epoch:  0  batch:   19  loss: 1.50567687
epoch:  0  batch:   20  loss: 1.53294098
epoch:  0  batch:   21  loss: 1.52948940
epoch:  0  batch:   22  loss: 1.52433884
epoch:  0  batch:   23  loss: 1.53214121
epoch:  0  batch

In [20]:
y_pred[0, :, 0]

tensor([9.9898e-01, 1.2063e-04, 8.4833e-05, 7.8012e-05, 8.3950e-05, 1.2519e-04,
        7.3147e-05, 1.1248e-04, 8.3863e-05, 2.5312e-04],
       grad_fn=<SelectBackward0>)

In [18]:
y

tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
         0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0

In [10]:
data.shape

torch.Size([100, 512])