# Fine-tuning scibert-nli for assertion classification

References
* https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/nli/training_nli_v2.py
* https://github.com/gsarti/covid-papers-browser/blob/master/scripts/finetune_nli.py
* https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/nli/training_nli.py

In [14]:
# Shell escape: print the NVIDIA driver / GPU status table for this runtime
!nvidia-smi

Thu Jan 20 20:26:41 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            Off  | 00000000:00:05.0 Off |                    0 |
| N/A   36C    P8     9W /  70W |      3MiB / 15109MiB |      0%      Default |
|       

In [81]:
# Move up one directory at a time until the working directory path no longer
# contains the substring "src".
# NOTE(review): presumably this is meant to land in the repository root so the
# relative "data/..." and "models/..." paths used below resolve — confirm.
path = %pwd
while "src" in path:
    %cd ..
    path = %pwd
    
from torch.utils.data import DataLoader
import math
from sentence_transformers import models, losses, datasets, util
from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from sentence_transformers.readers import NLIDataReader
import logging
from datetime import datetime
import argparse
import os


## Import data

In [16]:
# Dataset locations, relative to the current working directory.
train_data_path = "data/train"
val_data_path = "data/val"
ast_folder_name = "ast"
concept_folder_name = "concept"
rel_folder_name = "rel"
txt_folder_name = "txt"
nli_data_path = "data/nli"

# Checkpoint that is loaded for fine-tuning, and the base path the fine-tuned
# model is written to.
model_name = "gsarti/scibert-nli"
model_save_path = "models/scibert-nli-finetuned-nli"

batch_size = 32

# Fail fast when the NLI data is missing. An explicit raise is used instead of
# `assert`, because assertions are skipped when Python runs with -O.
if not os.path.exists(nli_data_path):
    raise FileNotFoundError("Please check prepare_nli_dataset.ipynb before")

In [18]:
# Use BERT for mapping tokens to embeddings (checkpoint named by model_name)
word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed sized sentence vector
# (CLS-token and max pooling are switched off)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

# Chain the transformer and the pooling layer into the sentence-embedding model
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [71]:
# Reader over the prepared NLI dataset folder
nli_reader = NLIDataReader(nli_data_path)
# Label count reported by the reader; passed to the softmax loss as num_labels
num_labels = nli_reader.get_num_labels()

In [72]:
# Convert the training examples to a shuffled DataLoader ready for training
train_data = SentencesDataset(nli_reader.get_examples('train.gz'), model=model)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

# Our training loss: softmax classification over num_labels classes.
# Alternative loss kept for reference:
# train_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.SoftmaxLoss(model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=num_labels)

# Validation examples, iterated in their stored order (no shuffling)
val_data = SentencesDataset(nli_reader.get_examples('val.gz'), model=model)
val_dataloader = DataLoader(val_data, shuffle=False, batch_size=batch_size)
# NOTE(review): EmbeddingSimilarityEvaluator presumably correlates embedding
# similarity with each example's label value. If the reader yields class ids
# rather than graded similarity scores, that correlation may not be a
# meaningful validation metric — consider a label-accuracy evaluator; confirm.
val_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(val_data, batch_size=batch_size)
val_evaluator

<sentence_transformers.evaluation.EmbeddingSimilarityEvaluator.EmbeddingSimilarityEvaluator at 0x7f5292c06390>

In [73]:
# Training schedule
num_epochs = 5

# Warm up over the first 10% of all training batches
# (batches per epoch times epochs), rounded up to a whole step.
warmup_steps = math.ceil(0.1 * (len(train_dataloader) * num_epochs))
print(f"Warmup-steps: {warmup_steps}")

Warmup-steps: 154


In [75]:

# Suffix the output path with the current local time.
# Note: the path is rebound in place, so re-running this cell appends a
# further timestamp to the already-suffixed value.
model_save_path = f"{model_save_path}-{datetime.now():%Y-%m-%d_%H-%M-%S}"
model_save_path


'models/scibert-nli-finetuned-nli-2022-01-20_21-38-12'

In [77]:
# Train the model on the (dataloader, loss) pair defined earlier.
# evaluation_steps is 10% of the batches in one epoch, truncated to an int;
# output_path is the timestamped save path.
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=val_evaluator,
          epochs=num_epochs,
          evaluation_steps=int(len(train_dataloader)*0.1),
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          use_amp=True          # Set to True, if your GPU supports FP16 operations
          )

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Iteration:   0%|          | 0/307 [00:00<?, ?it/s]

Iteration:   0%|          | 0/307 [00:00<?, ?it/s]

Iteration:   0%|          | 0/307 [00:00<?, ?it/s]

Iteration:   0%|          | 0/307 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [80]:
# Load the stored model from model_save_path and score it with the validation
# evaluator (val_evaluator) — not with the STS benchmark.
# NOTE(review): train_loss was constructed with the previous `model` object;
# presumably it has to be rebuilt before training the reloaded model — confirm.
model = SentenceTransformer(model_save_path)
# Leftover STS-benchmark test evaluation, disabled; `sts_reader` is not defined
# in the cells visible here.
# test_data = SentencesDataset(examples=sts_reader.get_examples("sts-test.csv"), model=model)
# test_dataloader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
# evaluator = EmbeddingSimilarityEvaluator(test_dataloader)

model.evaluate(val_evaluator)

0.04244127578751249

In [40]:
# Train the model: same call and arguments as the earlier training cell.
# evaluation_steps is 10% of the batches in one epoch, truncated to an int.
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=val_evaluator,
          epochs=num_epochs,
          evaluation_steps=int(len(train_dataloader)*0.1),
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          use_amp=True          # Set to True, if your GPU supports FP16 operations
          )

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/307 [00:00<?, ?it/s]

Iteration:   0%|          | 0/307 [00:00<?, ?it/s]

In [116]:
def get_label(sent, problem):
    """Pick the assertion label whose hypothesis is most similar to ``sent``.

    Every label has a hypothesis template that is filled in with ``problem``.
    The sentence and the filled-in hypotheses are embedded with the
    module-level ``model`` and compared by cosine similarity.

    Args:
        sent: Sentence to classify.
        problem: Text substituted for ``{problem}`` in each hypothesis.

    Returns:
        A ``(label, scores)`` tuple. ``label`` is the key of the best-scoring
        hypothesis. ``scores`` is a list with one 1x1 similarity tensor per
        label, in the order the labels are declared below.
    """
    # NOTE(review): there is no hypothesis for a plain "present" assertion —
    # confirm that this is intentional.
    label2hyp = {
        "absent": "Patient currently doesn't have {problem}",
        "possible": "Patient may have {problem}",
        "conditional": "Patient has {problem} only under certain conditions",
        "hypothetical": "Patient may develop {problem}",
        "associated_with_someone_else": "{problem} is associated with someone else who is not the patient",
    }
    labels = list(label2hyp)
    hypotheses = [label2hyp[label].format(problem=problem) for label in labels]

    embedding = model.encode([sent])
    # Encode all hypotheses in one batch instead of one model call per label.
    anchors = model.encode(hypotheses)

    # One row (the sentence) by one column per hypothesis.
    similarities = util.cos_sim(embedding, anchors)
    # Keep the per-label 1x1 tensors that callers receive as `scores`.
    scores = [similarities[:, i:i + 1] for i in range(len(labels))]

    # Rank on the float score only, so a tie never falls through to comparing
    # label names; on a tie the first-declared label wins.
    best_label, _ = max(zip(labels, scores), key=lambda pair: pair[1].item())
    return best_label, scores

# Quick manual check of get_label on a single sentence
get_label("Patient may not develop diabetes under conditions", "diabetes")

('hypothetical',
 [tensor([[0.9923]]),
  tensor([[0.9944]]),
  tensor([[0.9941]]),
  tensor([[0.9955]]),
  tensor([[0.9903]])])