In [None]:
import numpy as np
import pandas as pd
from ipywidgets import widgets
from IPython.display import display
import re
from pymatgen.core import Composition
from torch.utils.data import DataLoader
from torch import cuda
from transformers import BertTokenizerFast
from seqeval.metrics import classification_report
import os
import json

import torch

import psie

import nltk
# nltk.download("punkt", quiet=True)

In [None]:
# Run on the GPU when torch reports one, otherwise on the CPU.
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Single-choice selector for the extraction target; "Solvus" is the only option offered.
print("Extraction Target: ")
radio_buttons = widgets.RadioButtons(
    options=["Solvus"],
    value="Solvus",
    description='',
)
display(radio_buttons)

In [None]:
# Sequence length, file names and checkpoint location used by the cells below.
MAX_LEN = 256
CORPUS = "extraction.json"
DATABASE_OUT = "relations_extraction"  # Name of the output file
BERT_VERSION = r'/pretrained_models/m3rg-iitd/matscibert'
MAIN_DIR = os.getcwd()

# The extraction target follows the radio-button selection.
if radio_buttons.value == "Solvus":
    extr_target = "Solvus"

# Fine-tuned relation classification model
MODEL_DIR = os.path.join("models", extr_target, "relation")

添加标识符

In [None]:
# Entity-marker strings that are added to the tokenizer vocabulary.
new_tokens = ["[E1]", "[/E1]", "[E2]", "[/E2]"]
tokenizer = BertTokenizerFast.from_pretrained('./models/Tc/relation')
tokenizer.add_tokens([*new_tokens])

In [None]:
# Read the JSON corpus from the working directory.
with open('./extraction.json', mode="r", encoding='utf-8') as f:
    data = json.loads(f.read())

In [None]:
# Convert the loaded corpus with psie.fromNer. Later cells read the result's
# 'sentence' and 'source' entries and look for [E1]/[E2] markers in each sentence.
# NOTE(review): this rebinds `data`, so re-running the cell passes the converted
# object back into fromNer — re-run the loading cell first.
data = psie.fromNer(data)
data

In [None]:
# Number of sentences after the conversion.
len(data['sentence'])

In [None]:
# Keep only the sentences that carry both entity markers; the label slot is left as None.
ner_dataset = {'sentence': [], 'isrelated': [], 'source': []}

for idx in range(len(data['sentence'])):
    sent = data['sentence'][idx]
    if '[E1]' not in sent or '[E2]' not in sent:
        continue
    ner_dataset['sentence'].append(str(sent))
    ner_dataset['isrelated'].append(None)
    ner_dataset['source'].append(data['source'][idx])

In [None]:
# Display the filtered dataset (sentences, unset labels, sources).
ner_dataset

In [None]:
# Number of sentences that contain both an [E1] and an [E2] marker.
len(ner_dataset['sentence'])

In [None]:
# Loader settings: batches of 8 served in the original order.
ner_params = dict(batch_size=8, shuffle=False, num_workers=0)

# Wrap the filtered sentences in a dataset and a loader.
ner = psie.RelationDataset(ner_dataset, tokenizer, max_len=MAX_LEN)
ner_loader = DataLoader(ner, **ner_params)

# Build the relation classifier, resize its embedding table to the tokenizer
# length (the marker tokens were added to it), and move it to the device.
model = psie.BertForRelations(
    pretrained='./corpus',
    dropout=0.2,
    use_cls_embedding=True,
)
model.bert.resize_token_embeddings(len(tokenizer))
model.to(device)

### Predictions on the BERT NER output

Load model

In [None]:

# Load the saved weights from ./relation.pt, mapped onto the active device.
# strict=False is passed to load_state_dict, so mismatched keys do not abort the load.
model.load_state_dict(
    torch.load('./relation.pt', map_location=torch.device(device)),
    strict=False,
)

In [None]:
# Classify every sentence in ner_dataset and keep the arg-max class per sentence.
# The loader has to be ner_loader so that predictions[i] lines up with
# ner_dataset['sentence'][i] in the extraction cell further down. val_loader is
# only created in the fine-tuning section and covers a shuffled subset.
pred = model.predict(ner_loader, device)
predictions = [np.argmax(logits.cpu().numpy()) for logits in pred]

In [None]:
# Display the predicted class index for each sentence.
predictions

预测

In [None]:
# First line: number of predictions; second line: sum of the predicted class indices.
print(len(predictions), sum(predictions), sep="\n")

In [None]:
# Collect the marked entities of every sentence predicted as class 1.
database = {"compound": [], extr_target: [], "sentence": [], "source": []}

# Non-greedy patterns stop at the first closing marker, so a sentence holding
# several marker pairs cannot yield a span that swallows the text between them.
e1_pattern = re.compile(re.escape("[E1]") + ".*?" + re.escape("[/E1]"))
e2_pattern = re.compile(re.escape("[E2]") + ".*?" + re.escape("[/E2]"))

for i in range(len(predictions)):
    if predictions[i] != 1:
        continue

    sentence = ner_dataset['sentence'][i]
    comp = e1_pattern.findall(sentence)
    temp = e2_pattern.findall(sentence)

    # Skip sentences where either marked span cannot be recovered.
    if comp and temp:
        # Strip the markers and all spaces from the first match of each entity.
        comp = comp[0].replace("[E1]", "").replace("[/E1]", "").replace(" ", "")
        temp = temp[0].replace("[E2]", "").replace("[/E2]", "").replace(" ", "")
        database["compound"].append(comp)
        database[extr_target].append(temp)
        database["sentence"].append(sentence)
        database["source"].append(ner_dataset['source'][i])

In [None]:
# Turn the collected columns into a DataFrame and display it.
# NOTE(review): `database` is rebound from a dict to a DataFrame here.
database = pd.DataFrame(database)
database

In [None]:
# Write the extracted table to a UTF-8 CSV file in the working directory.
# NOTE(review): DATABASE_OUT is defined in the configuration cell but is not used for this file name.
database.to_csv('./extraction_new.csv',encoding='utf-8')

### 微调

In [None]:
import random

# Fixed seed so the shuffle, and therefore the train/validation split, is
# identical on every Restart & Run All.
SPLIT_SEED = 42

# NOTE(review): ner_dataset['isrelated'] is filled with None when ner_dataset is
# built in this notebook; the labels must be populated before fine-tuning — confirm.

# Shuffle the dataset, keeping sentence, label and source aligned
combined = list(zip(ner_dataset['sentence'], ner_dataset['isrelated'], ner_dataset['source']))
if not combined:
    # zip(*[]) would otherwise fail with an opaque unpacking error.
    raise ValueError("ner_dataset is empty; there is nothing to split")
random.Random(SPLIT_SEED).shuffle(combined)
shuffled_sentences, shuffled_isrelated, shuffled_sources = zip(*combined)
# Dataset size
length = len(shuffled_sentences)
# Split proportionally: the first 70 % for training, the remainder for validation
train_size = int(length * 0.7)
train_dataset = {
    'sentence': shuffled_sentences[:train_size],
    'isrelated': shuffled_isrelated[:train_size],
    'source': shuffled_sources[:train_size]
}
val_dataset = {
    'sentence': shuffled_sentences[train_size:],
    'isrelated': shuffled_isrelated[train_size:],
    'source': shuffled_sources[train_size:]
}

In [None]:
# Loader settings shared by both splits: batches of 8 served in the stored order.
ner_params = {"batch_size": 8, "shuffle": False, "num_workers": 0}

# Tokenizer loaded from BERT_VERSION, extended with the entity-marker tokens.
new_tokens = ["[E1]", "[/E1]", "[E2]", "[/E2]"]
tokenizer = BertTokenizerFast.from_pretrained(BERT_VERSION)
tokenizer.add_tokens([*new_tokens])

# Datasets for the two splits.
train_ner = psie.RelationDataset(train_dataset, tokenizer, max_len=MAX_LEN)
val_ner = psie.RelationDataset(val_dataset, tokenizer, max_len=MAX_LEN)

# Loaders for the two splits.
train_loader = DataLoader(train_ner, **ner_params)
val_loader = DataLoader(val_ner, **ner_params)

In [None]:

# Build a new relation classifier from the BERT_VERSION checkpoint, resize its
# embedding table to the tokenizer length, and move it to the selected device.
# NOTE(review): this rebinds `model`, replacing the classifier used for the
# predictions earlier in the notebook.
model = psie.BertForRelations(pretrained=BERT_VERSION, dropout=0.2, use_cls_embedding=True)
model.bert.resize_token_embeddings(len(tokenizer))
model.to(device)

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# "Balanced" class weights for labels 0 and 1, computed from the training labels.
# `classes` is passed as an ndarray: compute_class_weight documents it as an
# ndarray, and recent scikit-learn releases validate the type and reject a list.
class_weights = compute_class_weight(
    'balanced',
    classes=np.array([0, 1]),
    y=train_dataset['isrelated'],
)
# Float32 tensor on the training device, as passed to model.finetuning(weight=...).
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
class_weights

In [None]:
# NOTE(review): recent transformers releases no longer export AdamW. It is kept
# here because correct_bias=False has no direct torch.optim.AdamW equivalent.
from transformers import AdamW, get_linear_schedule_with_warmup

# Initialise the optimiser and the learning-rate scheduler.
# Num_Epochs is not assigned anywhere in the visible notebook. Use it when the
# kernel provides it; otherwise fall back to a default so that Restart & Run All
# does not stop here with a NameError.
try:
    num_epochs = Num_Epochs
except NameError:
    num_epochs = 5  # TODO(review): placeholder default — set the intended epoch count
max_norm = 0.5  # Maximum norm for gradient clipping
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01, correct_bias=False)  # With weight decay

# The training loop calls scheduler.step() once per epoch, so the schedule length
# is counted in epochs. Counting batches (len(train_loader) * num_epochs) would
# stretch the warmup ramp over far more steps than are ever taken, leaving the
# learning rate at or near zero for the whole run.
total_steps = num_epochs
# 10 % warmup; this rounds down to 0 for fewer than 10 epochs.
# NOTE(review): with a non-zero warmup the linear ramp starts from lr = 0 in the
# first epoch — confirm this is acceptable for longer runs.
num_warmup_steps = int(total_steps * 0.1)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=num_warmup_steps,
                                            num_training_steps=total_steps)

In [None]:
# Per-epoch history of the values returned by model.finetuning.
tr_Loss_list, tr_Acc_list = [], []
val_Loss_list, val_Acc_list = [], []
val_f1_list, val_recall_list = [], []

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    metrics = model.finetuning(
        train_loader, val_loader, device, max_norm, optimizer, weight=class_weights
    )
    epoch_loss, val_loss_avg, train_accuracy, accuracy, f1, recall = metrics

    # Record each returned value in its history list.
    for history, value in (
        (tr_Loss_list, epoch_loss),
        (tr_Acc_list, train_accuracy),
        (val_Acc_list, accuracy),
        (val_Loss_list, val_loss_avg),
        (val_f1_list, f1),
        (val_recall_list, recall),
    ):
        history.append(value)

    # Update the learning rate
    scheduler.step()