In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/ee-pretrained/event-classififcation-pretrained.pt
/kaggle/input/ee-pretrained/trigger-detection-pretrained.pt
/kaggle/input/ee-pretrain-biobert/trigger-detection-biobert.pt
/kaggle/input/ee-pretrain-biobert/event-cls-biobert.pt


In [2]:
from collections import OrderedDict

import torch
from torch import nn
from transformers import BertTokenizerFast, BertForTokenClassification, BertTokenizer, BertModel

In [3]:
def load_trigger_model(model_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Tải tokenizer và mô hình từ đường dẫn đã lưu.
    """
    tokenizer = BertTokenizerFast.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

    model = BertForTokenClassification.from_pretrained('dmis-lab/biobert-base-cased-v1.2', num_labels=2)

    # Tải state_dict từ tệp .pt
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return tokenizer, model, device

In [4]:
def trigger_predict(sentence, tokenizer, model, device, max_len=128):
    """
    Dự đoán nhãn cho một câu đầu vào.
    """
    words = sentence.split()
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_offsets_mapping=False,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    logits = outputs.logits
    preds = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()
    word_ids = encoding.word_ids(batch_index=0)

    labels = []
    previous_word_idx = None
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue
        elif word_idx != previous_word_idx:
            labels.append(preds[idx])
            previous_word_idx = word_idx

    return list(zip(words, labels))

In [5]:
label_to_index_model2 = {
    "Negative_regulation": 0,
    "Gene_expression": 1,
    "Regulation": 2,
    "Transcription": 3,
    "Positive_regulation": 4,
    "Binding": 5,
    "Localization": 6,
    "Phosphorylation": 7,
    "Protein_catabolism": 8,
}

In [6]:
class BertClassifier(nn.Module):
    """
    BERT-based classifier for event classification.
    """

    def __init__(self, num_classes):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained("dmis-lab/biobert-base-cased-v1.2")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        dropout_output = self.dropout(pooled_output)
        return self.classifier(dropout_output)

In [7]:
def load_event_model(model_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Load the Event Classification tokenizer and model from the saved path.
    """
    # Check if the model file exists
    # if not os.path.isfile(model_path):
    #     raise FileNotFoundError(f"Event Classification model file '{model_path}' does not exist.")

    # Load the tokenizer for Event Classification
    tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

    # Initialize the Event Classification model with the appropriate number of classes
    num_classes = len(label_to_index_model2)
    model = BertClassifier(num_classes=num_classes)

    # Load the state_dict from the .pt file
    state_dict = torch.load(model_path, map_location=device)

    # If the state_dict has 'module.' prefix (from DataParallel), remove it
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[len('module.'):]
        new_state_dict[k] = v

    # Load the state_dict into the model
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()

    return tokenizer, model, device

In [8]:
def prepare_input(words, tokenizer, max_len=32, is_split_into_words=True):
    """
    Prepare the input for the models by tokenizing the words.
    :param words: List of words to tokenize.
    :param tokenizer: Tokenizer to use.
    :param max_len: Maximum length of the token sequence.
    :param is_split_into_words: Whether the input is already split into words.
    :return: input_ids and attention_mask tensors.
    """
    encoding = tokenizer(
        words,
        is_split_into_words=is_split_into_words,
        return_offsets_mapping=False,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    return input_ids, attention_mask


index_to_label_model2 = {v: k for k, v in label_to_index_model2.items()}

In [9]:
def predict_event(word, tokenizer, model, device, max_len=32):
    """
    Predict the event type for a single trigger word using the Event Classification model.
    :param word: The trigger word to classify.
    :param tokenizer: Tokenizer for Event Classification model.
    :param model: Event Classification model.
    :param device: Device to run the model on.
    :param max_len: Maximum length for tokenization.
    :return: Predicted event type label.
    """
    model.eval()
    # Tokenize the single word
    input_ids, attention_mask = prepare_input([word], tokenizer, max_len)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs
        preds = torch.argmax(logits, dim=1).cpu().numpy()

    # Map the prediction to the corresponding label
    predicted_label = index_to_label_model2.get(preds[0], "Unknown")
    return predicted_label

In [20]:
def main():
    trigger_model_path = '/kaggle/input/ee-pretrain-biobert/trigger-detection-biobert.pt'
    event_model_path = '/kaggle/input/ee-pretrain-biobert/event-cls-biobert.pt'
    trigger_tokenizer, trigger_model, device = load_trigger_model(trigger_model_path)

    sentence = " Four of the 6 mAECA activated EC , manifested by increased IL-6 and vWF secretion ."
    predictions = trigger_predict(sentence, trigger_tokenizer, trigger_model, device)
    print("Event triggers:")
    trigger_words = []
    for word, label in predictions:
        if label == 1:
            print(f"{word}: {'Trigger'}")
            trigger_words.append(word)
    # Load the Event Classification model
    try:
        event_tokenizer, event_model, device_event = load_event_model(event_model_path)
        print("Successfully loaded Event Classification tokenizer and model.")
    except FileNotFoundError as e:
        print(e)
        return

    # If there are trigger words, perform Event Classification
    if trigger_words:
        print("\nEvent Classification Results:")
        for trigger_word in trigger_words:
            event_label = predict_event(trigger_word, event_tokenizer, event_model, device_event)
            print(f"{trigger_word}: {event_label}")
    else:
        print("\nNo Trigger Words detected in the sentence.")

In [21]:
if __name__ == "__main__":
    main()

Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load(model_path, map_location=device)


Event triggers:
increased: Trigger
secretion: Trigger


  state_dict = torch.load(model_path, map_location=device)


Successfully loaded Event Classification tokenizer and model.

Event Classification Results:
increased: Positive_regulation
secretion: Localization
