In [None]:
import sys
from google.colab import drive

# Mount Google Drive so the workflow folder (models, corpora, and presumably
# the `psie` package imported later) is reachable under /content/drive.
drive.mount('/content/drive')
# Make the workflow folder importable. The path is relative to the current
# working directory (presumably /content on Colab — TODO confirm).
sys.path.insert(1, "drive/MyDrive/workflow/")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install pymatgen transformers nltk ipywidgets seqeval[gpu]
!jupyter nbextension enable --py widgetsnbextension

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Enabling notebook extension jupyter-js-widgets/extension...
Paths used for configuration of notebook: 
    	/root/.jupyter/nbconfig/notebook.json
Paths used for configuration of notebook: 
    	
      - Validating: [32mOK[0m
Paths used for configuration of notebook: 
    	/root/.jupyter/nbconfig/notebook.json


In [None]:
import numpy as np
import pandas as pd
from ipywidgets import widgets
from IPython.display import display
import re
from pymatgen.core import Composition
from torch.utils.data import DataLoader
from torch import cuda
from transformers import BertTokenizerFast
from seqeval.metrics import classification_report
import os
import json

import torch

import psie

import nltk
# Download the NLTK "punkt" data without progress output; presumably needed by
# psie for sentence splitting (TODO confirm). The cell displays the returned
# status (True on success).
nltk.download("punkt", quiet=True)

True

In [None]:
# Use the GPU when PyTorch can see one; `device` is later passed to
# model.to(...) and model.predict(...).
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [None]:
# Let the user pick the property to extract. NOTE(review): downstream results
# depend on this widget's state; under "Run All" the default ("Band Gap") is used.
radio_buttons = widgets.RadioButtons(
    options=["Band Gap", "Curie Temperature"],
    value="Band Gap",
    description='',
)
print("Extraction Target: ")
display(radio_buttons)

Extraction Target: 


RadioButtons(options=('Band Gap', 'Curie Temperature'), value='Band Gap')

In [None]:
# Map the widget choice to the short key used in the model/extraction paths
# and as the name of the extracted-value column. An explicit lookup fails
# loudly on an unexpected value instead of leaving `extr_target` undefined
# (or silently stale from a previous run).
TARGET_KEYS = {"Curie Temperature": "Tc", "Band Gap": "Gap"}
if radio_buttons.value not in TARGET_KEYS:
  raise ValueError(f"Unsupported extraction target: {radio_buttons.value!r}")
extr_target = TARGET_KEYS[radio_buttons.value]

MAX_LEN = 256                                                                  # max_len passed to psie.RelationDataset
MAIN_DIR = os.path.join("drive", "MyDrive", "workflow")                        # Workflow root on the mounted Drive
MODEL_DIR = os.path.join("models", extr_target, "relation")                    # Fine-tuned Relation cls model
CORPUS = "multiple_mentions_test_extraction.json"                              # NER-stage output to classify
DATABASE_OUT = "relations_extraction"                                           # Name of the output file

Add the entity-marker tokens used in the relation-extraction step (`[E1]`, `[/E1]`, `[E2]`, `[/E2]`) to the BERT model's vocabulary so that these tags are not split into multiple subwords.

In [None]:
# Load the tokenizer stored alongside the fine-tuned relation model, then
# register the entity markers as whole tokens. The displayed value is the
# number of tokens actually added to the vocabulary.
tokenizer = BertTokenizerFast.from_pretrained(os.path.join(MAIN_DIR, MODEL_DIR))
new_tokens = ["[E1]", "[/E1]", "[E2]", "[/E2]"]
tokenizer.add_tokens(new_tokens)

4

In [None]:
# Load the NER-stage output for the selected target and convert it with
# psie.fromNer into the form read below (indexed by 'sentence' and 'source').
# JSON text is UTF-8 by specification, so pass the encoding explicitly rather
# than depending on the platform's locale default.
with open(os.path.join(MAIN_DIR, "extraction", extr_target, CORPUS), "r", encoding="utf-8") as f:
  data = json.load(f)

data = psie.fromNer(data)

In [None]:
# Keep only sentences that contain both entity markers ([E1] = compound,
# [E2] = value). The relation label is unknown at inference time, hence None.
ner_dataset = {'sentence': [], 'isrelated': [], 'source': []}

sentences = data['sentence']
for idx in range(len(sentences)):
    sent = sentences[idx]
    if '[E1]' not in sent or '[E2]' not in sent:
        continue
    ner_dataset['sentence'].append(str(sent))
    ner_dataset['isrelated'].append(None)
    ner_dataset['source'].append(data['source'][idx])

In [None]:
# Wrap the candidate sentences in the project's relation dataset and batch
# them in their original order (shuffle=False keeps predictions aligned with
# ner_dataset rows).
ner = psie.RelationDataset(ner_dataset, tokenizer, max_len=MAX_LEN)

ner_params = dict(batch_size=8, shuffle=False, num_workers=0)
ner_loader = DataLoader(ner, **ner_params)

# Build the relation classifier from the fine-tuned BERT directory and grow
# its embedding matrix to the enlarged tokenizer vocabulary (marker tokens
# included) before the fine-tuned weights are loaded.
model = psie.BertForRelations(
    pretrained=os.path.join(MAIN_DIR, MODEL_DIR),
    dropout=0.2,
    use_cls_embedding=True,
)
model.bert.resize_token_embeddings(len(tokenizer))
model.to(device)

BertForRelations(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31094, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

# Load model

In [None]:
# Load the fine-tuned relation-classifier weights from <MAIN_DIR>/<MODEL_DIR>.pt
# onto the selected device. Presumably the checkpoint was saved with the
# enlarged vocabulary, so this relies on resize_token_embeddings having run first.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (on recent PyTorch consider weights_only=True).
model.load_state_dict(torch.load(os.path.join(MAIN_DIR, MODEL_DIR+ ".pt"), map_location=torch.device(device)))

<All keys matched successfully>

### Predictions on the BERT NER output

In [None]:
# Run the relation classifier over every candidate sentence and keep the
# arg-max class of each entry (class 1 is treated as "related" below).
pred = model.predict(ner_loader, device)

# NOTE(review): assumes each element of `pred` holds the class scores of a
# single sentence; if predict() returned per-batch tensors, np.argmax without
# `axis` would flatten the batch — TODO confirm.
predictions = [np.argmax(scores.cpu().numpy()) for scores in pred]

In [None]:
# Collect the sentences the classifier labels as related (class 1). The text
# between [E1] ... [/E1] is the compound; the text between [E2] ... [/E2] is
# the extracted value, stored under the `extr_target` column.
# Non-greedy groups stop at the first closing tag, and DOTALL lets an entity
# span a line break (a plain `.*` silently dropped such sentences).
E1_PATTERN = re.compile(r"\[E1\](.*?)\[/E1\]", re.DOTALL)
E2_PATTERN = re.compile(r"\[E2\](.*?)\[/E2\]", re.DOTALL)

# Predictions are matched to sentences by position, so the counts must agree.
assert len(predictions) == len(ner_dataset['sentence']), "predictions are not aligned with sentences"

database = {"compound": [], extr_target: [], "sentence": [], "source": []}

for i, label in enumerate(predictions):
  if label != 1:
    continue
  sentence = ner_dataset['sentence'][i]
  comp_match = E1_PATTERN.search(sentence)
  value_match = E2_PATTERN.search(sentence)
  if comp_match is None or value_match is None:
    continue  # one of the tagged entities is missing; nothing to record

  # Remove every space from the entity text; presumably these are
  # tokenization artefacts — TODO confirm this is also right for values.
  database["compound"].append(comp_match.group(1).replace(" ", ""))
  database[extr_target].append(value_match.group(1).replace(" ", ""))
  database["sentence"].append(sentence)
  database["source"].append(ner_dataset['source'][i])

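Each extracted chemical entity is parsed into a pymatgen `Composition` object and its reduced formula is computed. Entries that fail to parse are printed for debugging and excluded from the saved database. The reduced formula serves only as a validity check and is not stored.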

In [None]:
# Turn the column lists into a DataFrame. It has a default RangeIndex, so the
# position `i` below is also a valid label for database['sentence'][i].
database = pd.DataFrame(database)

# Row positions whose compound string pymatgen can parse.
valid_i = []

for i, comp in enumerate(database['compound']):
  try:
    # Used only as a validity check; the reduced formula is not stored.
    Composition(comp).get_reduced_formula_and_factor()
    valid_i.append(i)
  except Exception:
    # Catch parse failures only: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and make this loop uninterruptible.
    print(comp, '\t', database['sentence'][i], '\n\n')                          # The entries that raise an exception are printed for debugging purpose

In [None]:
# Report how many extracted relations passed the Composition check.
print(f"Database entries: {len(valid_i)} / {len(database['sentence'])}")

Database entries: 0 / 0


In [None]:
# Save only the rows whose compound parsed, next to the input corpus.
# The DataFrame index is written as the first (unnamed) CSV column.
database.iloc[valid_i].to_csv(
    os.path.join(MAIN_DIR, "extraction", extr_target, f"{DATABASE_OUT}.csv")
)

In [None]:
# Preview the first rows of the validated table.
database.iloc[valid_i].head(5)

Unnamed: 0,compound,Gap,sentence,source
