In [1]:
import utils_data as ut
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertForMaskedLM, BertTokenizer
# Load the tokenizer for the 'pranav-s/MaterialsBERT' checkpoint, capping sequences at 512 tokens.
tokenizerBERT = BertTokenizer.from_pretrained('pranav-s/MaterialsBERT', model_max_length=512)
# Load the matching masked-LM model; the NER model below reuses only its `base_model` encoder.
modelBERT = BertForMaskedLM.from_pretrained('pranav-s/MaterialsBERT')

In [33]:
# Mapping from NER tag name to integer class id; 'O' maps to 0.
classes = {'POLYMER': 1,
           'ORGANIC': 2,
           'MONOMER': 3,
           'PROP_NAME': 4,
           'INORGANIC': 5,
           'MATERIAL_AMOUNT': 6,
           'POLYMER_FAMILY': 7,
           'PROP_VALUE': 8,
           'O': 0}
# Length argument handed to the ut.* data helpers; same value as the tokenizer's model_max_length.
max_length = 512
# NOTE(review): batch_size is not referenced in the cells shown here — confirm it is still needed.
batch_size = 2
class NERBERTModel(nn.Module):
    """Token-level NER tagger: pretrained BERT encoder followed by a linear head.

    The encoder is taken from the module-level ``modelBERT`` (its ``base_model``),
    so the masked-LM head of that checkpoint is not used.
    """

    def __init__(self):
        super().__init__()
        self.bert = modelBERT.base_model
        # Read the encoder width from its config instead of hard-coding 768.
        hidden_size = self.bert.config.hidden_size
        # NOTE(review): len(classes) + 1 yields one more output than there are
        # entries in `classes`; presumably reserved for a padding/special label
        # produced by ut.cat2digit — confirm.
        self.linear = nn.Linear(hidden_size, len(classes) + 1)

    def forward(self, token, attention_mask=None):
        """Compute per-token class scores.

        Args:
            token: torch.LongTensor of token ids, shape (batch_size, sequence_length).
            attention_mask: optional mask forwarded to the encoder; ``None``
                keeps the encoder's default behaviour.

        Returns:
            Tensor of shape (batch_size, sequence_length, num_outputs) holding
            raw, unnormalised logits. No softmax is applied because
            ``nn.CrossEntropyLoss`` normalises internally; apply
            ``F.softmax(logits, dim=-1)`` at inference time if probabilities
            are needed.
        """
        encoder_output = self.bert(token, attention_mask=attention_mask)
        return self.linear(encoder_output.last_hidden_state)

In [34]:
# Instantiate the NER model (BERT encoder + linear tagging head) defined above.
model = NERBERTModel()

In [19]:
# Number of samples kept for this small experiment.
num_data = 5
# Read the training file and keep only the first `num_data` entries.
# NOTE(review): ut.read_data is not visible here; how it uses max_length should be confirmed.
data_list = ut.read_data('train.json', max_length)[:num_data]

In [25]:
# Encode each sample's words to token ids and concatenate along the first (batch) axis.
token_tensors_all_list = [ut.list2token(tokenizerBERT, d['words'], max_length) for d in data_list]
X = torch.cat(token_tensors_all_list, dim=0)
print(X.shape)
# Encode each sample's NER tags to integer class ids.
target_tensors_all_list = [ut.cat2digit(classes, d['ner'], max_length) for d in data_list]
# Build y from the *targets* (previously this concatenated the token tensors again,
# so y was just a copy of X).
# NOTE(review): assumes cat2digit yields one class id per token position for a single
# sample (list or tensor); each sample is normalised to shape (1, seq_len) — confirm.
y = torch.cat(
    [torch.as_tensor(t, dtype=torch.long).reshape(1, -1) for t in target_tensors_all_list],
    dim=0,
)
print(y.shape)

torch.Size([2, 512])
torch.Size([2, 512])


In [None]:
# Cross-entropy loss used as the training objective for the tag predictions.
criterion = nn.CrossEntropyLoss()
# Adam over all parameters of `model`, so the BERT encoder is updated together with the linear head.
# NOTE(review): lr=0.001 is high for fine-tuning a pretrained BERT encoder (values around 2e-5 to 5e-5 are common) — confirm intent.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
epochs = 1
train_losses = []  # mean training loss per epoch, stored as plain floats
# from_pretrained() hands back the encoder in eval mode; switch to training mode
# so dropout is active while fitting.
model.train()
for epoch in range(epochs):
    running_loss = 0.0
    num_steps = 0
    for data in data_list[:100]:
        token = ut.list2token(tokenizerBERT, data['words'], max_length)
        # NOTE(review): assumes cat2digit yields one integer class id per token position — confirm.
        target = torch.as_tensor(ut.cat2digit(classes, data['ner'], max_length), dtype=torch.long)
        prediction = model(token)
        # CrossEntropyLoss expects scores of shape (N, C) and class ids of shape (N,),
        # so fold the batch and sequence axes together before computing the loss.
        loss = criterion(prediction.reshape(-1, prediction.shape[-1]), target.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the value so the autograd graph is not kept alive.
        running_loss += loss.item()
        num_steps += 1
    # Guard against an empty data_list, which previously raised NameError on `loss`.
    epoch_loss = running_loss / num_steps if num_steps else float('nan')
    train_losses.append(epoch_loss)
    print(f'epoch: {epoch:2}  loss: {epoch_loss:10.8f}')