# Introduction
**This notebook is considered to fine-tune the pre-trained SciBERT model to solve the NLI problem equivalent to ORKG's templates recommendation service.**

The notebook requires the data to be uploaded manually to google drive (please search for "TODO" below). Please check [this repository](https://gitlab.com/TIBHannover/orkg/nlp/experiments/orkg-templates-recommendation/) for further information about generating the data files.

For any other questions, please contact omar.araboghli@outlook.com or jennifer.dsouza@tib.eu

# Requirements

In [None]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

!pip uninstall -y torch
!pip install torch==1.8.2+cpu torchvision==0.9.2+cpu -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
!pip install transformers

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# TODO: adapt the file paths.
!cp '/content/drive/MyDrive/master_thesis/templates_recommendation/04_4/training_set.json' './training_set.json'
!cp '/content/drive/MyDrive/master_thesis/templates_recommendation/04_4/test_set.json' './test_set.json'
!cp '/content/drive/MyDrive/master_thesis/templates_recommendation/04_4/validation_set.json' './validation_set.json'

# Constants / Hyperparameters

In [4]:
from datetime import datetime

MODEL_FILE_NAME = 'orkgnlp-templates-recommendation-scibert.pt'
N_EPOCHS = 10
WARMUP_PERCENT = 0.2
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
EPS = 1e-6
MAX_SEQUENCE_LENGTH = 512
TRAINING_SET_PATH = './training_set.json'
TEST_SET_PATH = './test_set.json'
VALIDATION_SET_PATH = './validation_set.json'
CLASSES = {
    'entailment': 0,
    'contradiction': 1,
    'neutral': 2
}

# TODO: replace v1 with your directory name for the output of tensorboard events
RUNS_LOG_DIR = 'runs/v1/'
# TODO: please use your own google bucket name where the model should be uploaded
BUCKET = ''

# Classes

In [5]:
import json

from torch.utils.data import Dataset

class BERTDataset(Dataset):
    def __init__(self, path, get_attention_mask, get_token_type, tokenize, tokens_to_ids, target_transform):

        self.data = BERTDataset.read_json(path)
        self.get_attention_mask = get_attention_mask
        self.get_token_type = get_token_type
        self.tokenize = tokenize
        self.tokens_to_ids = tokens_to_ids
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data['entailments'] + self.data['contradictions'] + self.data['neutrals'])

    def __getitem__(self, idx):
        instances = self.data['entailments'] + self.data['contradictions'] + self.data['neutrals']

        # append the [PAD] character to the sequence, so that all sequences in 
        # a batch have the same length (MAX_SEQUENCE_LENGTH)
        item = self.tokenize(instances[idx]['sequence'])
        item += ['[PAD]'] * (MAX_SEQUENCE_LENGTH - len(item))
        
        attention_mask = self.get_attention_mask(item)
        token_type = self.get_token_type(item)
        label = self.target_transform(instances[idx]['target'])

        return self.tokens_to_ids(item)[:MAX_SEQUENCE_LENGTH], attention_mask[:MAX_SEQUENCE_LENGTH], token_type[:MAX_SEQUENCE_LENGTH], label

    @staticmethod
    def read_json(input_path):
        with open(input_path, encoding='utf-8') as f:
            json_data = json.load(f)

        return json_data


In [6]:
import torch.nn as nn

class BERTNLIModel(nn.Module):
    def __init__(self, bert_model, output_dim):
        super().__init__()

        self.bert = bert_model
        embedding_dim = bert_model.config.to_dict()['hidden_size']
        self.out = nn.Linear(embedding_dim, output_dim)

    def forward(self, sequence, attn_mask, token_type):
        embedded = self.bert(input_ids=sequence, attention_mask=attn_mask, token_type_ids=token_type)[1]
        output = self.out(embedded)
        return output

In [7]:
from transformers import BertModel, BertTokenizer

class SciBERT:
    DEFAULT_PATH = 'allenai/scibert_scivocab_uncased'

    @staticmethod
    def model(path=DEFAULT_PATH):
        return BertModel.from_pretrained(path)

    @staticmethod
    def tokenizer(path=DEFAULT_PATH):
        return BertTokenizer.from_pretrained(path)

# Utils

In [8]:
def get_attention_mask(tokens):
    return [1] * len(tokens)


def get_token_type(tokens):
    sep_index = tokens.index('[SEP]') + 1
    return [0] * sep_index + [1] * (len(tokens) - sep_index)


def target_transform(label_text):
  return CLASSES[label_text]


def collate_fn(instances):
  """
    produces DataLoader's batches with different shapes
  """
  batch = []

  for i in range(len(instances[0])):
      batch.append([instance[i] for instance in instances])

  return batch


def categorical_accuracy(predictions, y):
    max_predictions = predictions.argmax(dim=1, keepdim=True)
    correct = (max_predictions.squeeze(1) == y).float()

    return correct.sum() / len(y)

# Data Preparation

In [None]:
from torch.utils.data import DataLoader

# Tokenize
bert_tokenizer = SciBERT.tokenizer()

# Data
training_data = BERTDataset(TRAINING_SET_PATH, get_attention_mask, get_token_type, bert_tokenizer.tokenize, bert_tokenizer.convert_tokens_to_ids, target_transform)
validation_data = BERTDataset(VALIDATION_SET_PATH, get_attention_mask, get_token_type, bert_tokenizer.tokenize, bert_tokenizer.convert_tokens_to_ids, target_transform)

# Data loaders
training_data_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
validation_data_loader = DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# Learning Loop

In [None]:
import torch

def train(model, data_loader, optimizer, criterion, scheduler):
  epoch_loss = 0
  epoch_acc = 0
  model.train()

  for i, batch in enumerate(data_loader):
    optimizer.zero_grad() # clear gradients first

    sequence = torch.tensor(batch[0]).to(device)
    attn_mask = torch.tensor(batch[1]).to(device)
    token_type = torch.tensor(batch[2]).to(device)
    label = torch.tensor(batch[3]).to(device)

    predictions = model(sequence, attn_mask, token_type)
    loss = criterion(predictions, label)
    acc = categorical_accuracy(predictions, label)

    loss.backward()

    xm.optimizer_step(optimizer)
    xm.mark_step()
    scheduler.step()

    epoch_loss += loss.item()
    epoch_acc += acc.item()

  return epoch_loss / len(data_loader), epoch_acc / len(data_loader)

In [None]:
def evaluate(model, data_loader, criterion, device):
  epoch_loss = 0
  epoch_acc = 0
  model.eval()

  with torch.no_grad():

    for i, batch in enumerate(data_loader):
      sequence = torch.tensor(batch[0]).to(device)
      attn_mask = torch.tensor(batch[1]).to(device)
      token_type = torch.tensor(batch[2]).to(device)
      label = torch.tensor(batch[3]).to(device)

      predictions = model(sequence, attn_mask, token_type)
      loss = criterion(predictions, label)
      acc = categorical_accuracy(predictions, label)

      epoch_loss += loss.item()
      epoch_acc += acc.item()

  return epoch_loss / len(data_loader), epoch_acc / len(data_loader)

In [None]:
import math
from torch.utils.tensorboard import SummaryWriter
from transformers import AdamW, get_constant_schedule_with_warmup

writer = SummaryWriter(log_dir=RUNS_LOG_DIR)
total_steps = math.ceil(N_EPOCHS * len(training_data) * 1. / BATCH_SIZE)
warmup_steps = int(total_steps * WARMUP_PERCENT)

bert_nli_model = BERTNLIModel(SciBERT.model(), output_dim=len(CLASSES)).to(device)
optimizer = AdamW(bert_nli_model.parameters(), lr=LEARNING_RATE, eps=EPS, correct_bias=False)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
criterion = nn.CrossEntropyLoss().to(device)

best_valid_loss = float('inf')

In [None]:
for epoch in range(N_EPOCHS):

  train_loss, train_acc = train(bert_nli_model, training_data_loader, optimizer, criterion, scheduler)
  valid_loss, valid_acc = evaluate(bert_nli_model, validation_data_loader, criterion)

  writer.add_scalar("Loss/train", train_loss, epoch)
  writer.add_scalar("Accuracy/train", train_acc, epoch)
  writer.add_scalar("Loss/validation", valid_loss, epoch)
  writer.add_scalar("Accuracy/validation", valid_acc, epoch)

  print(f'Epoch: {epoch + 1:02}')
  print(f'--> Train Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
  print(f'--> Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc * 100:.2f}%')

  if valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    xm.save(bert_nli_model.state_dict(), MODEL_FILE_NAME)
    writer.flush()

Epoch: 01
--> Train Loss: 0.600 | Train Acc: 75.88%
--> Val. Loss: 0.166 | Val. Acc: 95.41%
Epoch: 02
--> Train Loss: 0.160 | Train Acc: 94.27%
--> Val. Loss: 0.152 | Val. Acc: 96.43%
Epoch: 03
--> Train Loss: 0.100 | Train Acc: 96.71%
--> Val. Loss: 0.149 | Val. Acc: 95.41%
Epoch: 04
--> Train Loss: 0.072 | Train Acc: 97.47%
--> Val. Loss: 0.099 | Val. Acc: 96.88%
Epoch: 05
--> Train Loss: 0.035 | Train Acc: 98.97%
--> Val. Loss: 0.099 | Val. Acc: 97.77%
Epoch: 06
--> Train Loss: 0.027 | Train Acc: 99.17%
--> Val. Loss: 0.088 | Val. Acc: 97.32%
Epoch: 07
--> Train Loss: 0.020 | Train Acc: 99.28%
--> Val. Loss: 0.092 | Val. Acc: 96.75%
Epoch: 08
--> Train Loss: 0.008 | Train Acc: 99.85%
--> Val. Loss: 0.119 | Val. Acc: 97.32%
Epoch: 09
--> Train Loss: 0.002 | Train Acc: 100.00%
--> Val. Loss: 0.099 | Val. Acc: 97.77%
Epoch: 10
--> Train Loss: 0.006 | Train Acc: 99.79%
--> Val. Loss: 0.091 | Val. Acc: 98.21%


In [None]:
bert_nli_model.load_state_dict(torch.load(MODEL_FILE_NAME))

test_data = BERTDataset(TEST_SET_PATH, get_attention_mask, get_token_type, bert_tokenizer.tokenize, bert_tokenizer.convert_tokens_to_ids, target_transform)
test_data_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_loss, test_acc = evaluate(bert_nli_model, test_data_loader, criterion, device)

writer.add_scalar("Loss/test", test_loss, 0)
writer.add_scalar("Accuracy/test", test_acc, 0)
writer.flush()
writer.close()
print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc * 100:.2f}%')

Test Loss: 0.059 |  Test Acc: 97.28%


In [17]:
def predict_similar_template(model, tokenizer, template_string, paper_string, device):
    model.eval()

    # premise = templates_string, hypothesis = paper_string
    sequence = '[CLS] {} [SEP] {} [SEP]'.format(template_string, paper_string)
    sequence_tokens = tokenizer.tokenize(sequence)

    attention_mask = get_attention_mask(sequence_tokens)
    token_type = get_token_type(sequence_tokens)
    sequence_tokens = tokenizer.convert_tokens_to_ids(sequence_tokens)

    attention_mask = torch.tensor(attention_mask).unsqueeze(0).to(device)
    token_type = torch.tensor(token_type).unsqueeze(0).to(device)
    sequence_tokens = torch.tensor(sequence_tokens).unsqueeze(0).to(device)

    prediction = model(sequence_tokens, attention_mask, token_type)
    print(prediction)
    return prediction

premise = 'ontology learning from data sources has dataset application domain input format output format accuracy f1 score precision knowledge source has data source learning purpose terms learning axiom learning properties learning properties hierarchy learning rule learning class hierarchy learning learning method learning tool evaluation metrics related work validation tool training corpus testing corpus class learning implemented technologies relationships learning instance learning taxonomy learning validation comment assessment of acquired knowledge recall f measure'
hypothesis = 'multimedia ontology learning for automatic annotation and video browsing in this work, we offer an approach to combine standard multimedia analysis techniques with knowledge drawn from conceptual metadata provided by domain experts of a specialized scholarly domain, to learn a domain specific multimedia ontology from a set of annotated examples a standard bayesian network learning algorithm that learns structure and parameters of a bayesian network is extended to include media observables in the learning an expert group provides domain knowledge to construct a basic ontology of the domain as well as to annotate a set of training videos these annotations help derive the associations between high level semantic concepts of the domain and low level mpeg 7 based features representing audio visual content of the videos we construct a more robust and refined version of this ontology by learning from this set of conceptually annotated videos to encode this knowledge, we use mowl, a multimedia extension of web ontology language (owl) which is capable of describing domain concepts in terms of their media properties and of capturing the inherent uncertainties involved we use the ontology specified knowledge for recognizing concepts relevant to a video to annotate fresh addition to the video database with relevant concepts in the ontology these conceptual annotations are used to create hyperlinks in the video collection, to provide an effective video browsing interface to the user'

predict_similar_template(bert_nli_model, bert_tokenizer, premise, hypothesis, device).argmax(dim=-1).item()

# Store files
## Tensorboard events

In [None]:
!zip -r runs.zip ./runs

from google.colab import files
files.download("./runs.zip")

## Upload to Google Cloud Storage

In [None]:
from google.colab import auth
auth.authenticate_user()

# Upload model to bucket
!gsutil cp {MODEL_FILE_NAME} gs://{BUCKET}

# Converting to PretrainedModel

With the subsequent code we convert a BERTNLIModel consisting of transformers.BertModel and torch.nn.Linear to a transformers.BertForSequenceClassification(PretrainedModel) that enables saving and loading it without the need of the class definition of BERTNLIModel, thus increases the interoperability.

In [None]:
from google.colab import auth
auth.authenticate_user()

# get model from bucket
!gsutil cp -r gs://{BUCKET}/{MODEL_FILE_NAME}  {MODEL_FILE_NAME} 

In [None]:
bert_nli_model = BERTNLIModel(SciBERT.model(), output_dim=len(CLASSES)).to(torch.device('cpu'))
bert_nli_model.load_state_dict(torch.load(MODEL_FILE_NAME, map_location=torch.device('cpu')))

In [None]:
from transformers import BertForSequenceClassification

bert_classification = BertForSequenceClassification.from_pretrained(
    'allenai/scibert_scivocab_uncased',
     num_labels=len(CLASSES),
     label2id=CLASSES,
     id2label={value:key for key, value in CLASSES.items()},
     problem_type='single_label_classification'
     )

state_dict = {}
for key, value in bert_nli_model.state_dict().items():
  if key == 'out.weight':
    state_dict['classifier.weight'] = value
  elif key == 'out.bias':
    state_dict['classifier.bias'] = value
  else:
    state_dict[key] = value

bert_classification.load_state_dict(state_dict)

In [None]:
# verify conversion results
premise = 'ontology learning from data sources has dataset application domain input format output format accuracy f1 score precision knowledge source has data source learning purpose terms learning axiom learning properties learning properties hierarchy learning rule learning class hierarchy learning learning method learning tool evaluation metrics related work validation tool training corpus testing corpus class learning implemented technologies relationships learning instance learning taxonomy learning validation comment assessment of acquired knowledge recall f measure'
hypothesis = 'multimedia ontology learning for automatic annotation and video browsing in this work, we offer an approach to combine standard multimedia analysis techniques with knowledge drawn from conceptual metadata provided by domain experts of a specialized scholarly domain, to learn a domain specific multimedia ontology from a set of annotated examples a standard bayesian network learning algorithm that learns structure and parameters of a bayesian network is extended to include media observables in the learning an expert group provides domain knowledge to construct a basic ontology of the domain as well as to annotate a set of training videos these annotations help derive the associations between high level semantic concepts of the domain and low level mpeg 7 based features representing audio visual content of the videos we construct a more robust and refined version of this ontology by learning from this set of conceptually annotated videos to encode this knowledge, we use mowl, a multimedia extension of web ontology language (owl) which is capable of describing domain concepts in terms of their media properties and of capturing the inherent uncertainties involved we use the ontology specified knowledge for recognizing concepts relevant to a video to annotate fresh addition to the video database with relevant concepts in the ontology these conceptual annotations are used to create hyperlinks in the video collection, to provide an effective video browsing interface to the user'
bert_tokenizer = SciBERT.tokenizer()

scibert_model_prediction = predict_similar_template(bert_nli_model, bert_tokenizer, premise, hypothesis, device=torch.device('cpu')).argmax(dim=-1).item()
scibert_classification_model_prediction = predict_similar_template(bert_classification, bert_tokenizer, premise, hypothesis, device=torch.device('cpu')).logits.argmax(dim=-1).item()
assert scibert_model_prediction == scibert_classification_model_prediction

In [None]:
import os

pretrained_dir = os.path.splitext(MODEL_FILE_NAME)[0]
bert_classification.save_pretrained(pretrained_dir)

!zip -r {pretrained_dir}.zip ./{pretrained_dir}

from google.colab import files
files.download('{}.zip'.format(pretrained_dir))