# RE Approach 1: BiLSTM+Multi-head attention+ Dynamic gate 

# NOTE:  To load the best trained model(saved in the same folder) and test it, please run the LAST 2 cells !

In [1]:
import json
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, AdamW
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

#If run with Google Colab, uncomment the following
#from google.colab import drive
#drive.mount('/content/drive')

In [2]:
#-----------------------------------------------------
# 1: Data Processing - Extracting relationships
#-----------------------------------------------------
def extract_relations(json_line):
    relations = []
    sent_text = json_line['sentText']
    for relation in json_line['relationMentions']:
        em1_text = relation['em1Text']
        em2_text = relation['em2Text']
        label = relation['label']
        relations.append((sent_text, em1_text, em2_text, label))
    return relations

In [3]:
#-----------------------------------------------------
# 2: Read the dataset and process the data
#-----------------------------------------------------

#If you need to load the dataset, replace the local file path with the following
train_input_file = './NYT11/trainNTY11.json'
valid_input_file = './NYT11/validNTY11.json'
test_input_file = './NYT11/testNTY11.json'

#If run with Google Colab, uncomment the following:
#train_input_file = '/content/drive/MyDrive/trainNTY11.json'
#valid_input_file = '/content/drive/MyDrive/validNTY11.json'
#test_input_file = '/content/drive/MyDrive/testNTY11.json'


train_data = []
val_data = []
test_data = []

with open(train_input_file, 'r') as file:
    for line in file:
        json_line = json.loads(line)
        relations = extract_relations(json_line)
        train_data.extend(relations)
train_df = pd.DataFrame(train_data, columns=['sentence', 'entity1', 'entity2', 'label'])


with open(valid_input_file, 'r') as file:
    for line in file:
        json_line = json.loads(line)
        relations = extract_relations(json_line)
        val_data.extend(relations)
valid_df = pd.DataFrame(val_data, columns=['sentence', 'entity1', 'entity2', 'label'])

with open(test_input_file, 'r') as file:
    for line in file:
        json_line = json.loads(line)
        relations = extract_relations(json_line)
        test_data.extend(relations)
test_df = pd.DataFrame(test_data, columns=['sentence', 'entity1', 'entity2', 'label'])


# Change the markup to a uniform format (remove the first character and replace/with _)
#train_df['label'] = train_df['label'].str[1:]
#valid_df['label'] = valid_df['label'].str[1:]
#test_df['label'] = test_df['label'].str[1:]
train_df['label'] = train_df['label'].str.replace('/', '_')
valid_df['label'] = valid_df['label'].str.replace('/', '_')
test_df['label'] = test_df['label'].str.replace('/', '_')

# Change all data types to strings
train_df = train_df.astype(str)
valid_df = valid_df.astype(str)
test_df = test_df.astype(str)

In [4]:
#-----------------------------------------------------
# 3: Data preprocessing functions
#-----------------------------------------------------

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
label_encoder = {label: i for i, label in enumerate(train_df['label'].unique())}
max_length = 96 
print(label_encoder)
def preprocess_data(row):
    sentence = row['sentence']
    entity1 = row['entity1']
    entity2 = row['entity2']

    # Tokenize sentence，get tokenized 后的 input_ids
    encoded_sentence = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    input_ids = encoded_sentence['input_ids'].squeeze()
    attention_mask = encoded_sentence['attention_mask'].squeeze()

    # Get the tokenized entity location
    tokenized_text = tokenizer.tokenize(sentence)
    entity1_tokens = tokenizer.tokenize(entity1)
    entity2_tokens = tokenizer.tokenize(entity2)

    entity1_pos = [i for i, token in enumerate(tokenized_text) if token in entity1_tokens]
    entity2_pos = [i for i, token in enumerate(tokenized_text) if token in entity2_tokens]

    # Prevent boundaries
    entity1_pos = min(entity1_pos[0] if entity1_pos else 0, max_length - 1)
    entity2_pos = min(entity2_pos[0] if entity2_pos else 0, max_length - 1)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'entity1_pos': torch.tensor(entity1_pos),
        'entity2_pos': torch.tensor(entity2_pos),
        'label': torch.tensor(label_encoder[row['label']])
    }
train_data = train_df.apply(preprocess_data, axis=1).tolist()
valid_data = valid_df.apply(preprocess_data, axis=1).tolist()
test_data = test_df.apply(preprocess_data, axis=1).tolist()

{'None': 0, '_location_location_contains': 1, '_location_administrative_division_country': 2, '_location_country_administrative_divisions': 3, '_location_country_capital': 4, '_people_person_children': 5, '_people_person_place_lived': 6, '_people_person_nationality': 7, '_business_company_place_founded': 8, '_location_neighborhood_neighborhood_of': 9, '_people_person_place_of_birth': 10, '_sports_sports_team_location': 11, '_sports_sports_team_location_teams': 12, '_people_deceased_person_place_of_death': 13, '_business_company_founders': 14, '_business_person_company': 15, '_business_company_major_shareholders': 16, '_business_company_shareholder_major_shareholder_of': 17, '_people_ethnicity_people': 18, '_people_person_ethnicity': 19, '_business_company_advisors': 20, '_people_person_religion': 21, '_people_ethnicity_geographic_distribution': 22, '_people_person_profession': 23, '_business_company_industry': 24}


In [5]:
#-----------------------------------------------------
# 4. Custom dataset class function definition
#-----------------------------------------------------
class RelationshipDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Create DataLoader
train_dataset = RelationshipDataset(train_data)
valid_dataset = RelationshipDataset(valid_data)
test_dataset = RelationshipDataset(test_data)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [6]:
#*****************************************************
#-----------------------------------------------------
# 5:【Optimization part】Gating unit design function

# Note: 
# This DynamicGate module dynamically combines two features (original and transformed) by using a gating mechanism. It uses a linear layer followed by a Sigmoid activation to calculate the gating coefficients, which determine how much of each feature contributes to the final output.
#      Gate Layer: Generates coefficients to control the fusion of features.
#      Forward Pass: Concatenates the features, calculates the gate, and performs weighted fusion.
# The module allows the model to adaptively adjust the influence of each feature based on the input.
#-----------------------------------------------------
#*****************************************************

class DynamicGate(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
       # Gated network design (using linear layers to generate conditioning coefficients)
        self.gate_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_size * 4, hidden_size * 2),  # Enter concatenated double dimensions
            torch.nn.Sigmoid()
        )

    def forward(self, original, transformed):
        # Concatenate both features as gating input
        combined = torch.cat([original, transformed], dim=-1)  # [batch, seq_len, hidden*4]

        # Generate dynamic gating coefficients
        gate = self.gate_layer(combined)  # [batch, seq_len, hidden*2]

        # Gated fusion (element-wise adjustment)）
        return gate * original + (1 - gate) * transformed


In [7]:
#*****************************************************
#-----------------------------------------------------
# 6: 【Optimization part】Overall model design function: BiLSTMModel+ Multi-head Attention + Dynamic gate

# Note: 
# The BiLSTM (Bidirectional LSTM) layer processes input data in both forward and backward directions, allowing 
# the model to capture information from the past and future of a sequence, which is helpful for tasks like sequence classification.

# The Multi-head Attention layer focuses on different parts of the input sequence, learning which parts are most 
# important for the task. It assigns different attention weights to different parts of the sequence, helping the model focus on relevant information.
#-----------------------------------------------------
#*****************************************************

import torch

class BiLSTMModelWithAttention(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, num_classes):
        super(BiLSTMModelWithAttention, self).__init__()
        # 1. embedding layer
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)

        # 2. BiLSTM layer
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True)

        # 3. Multi-head Attention layer
        self.multihead_attn = torch.nn.MultiheadAttention(
            embed_dim=hidden_size * 2,
            num_heads=num_heads,
            batch_first=True)

        # 4. Dynamic gate layer
        self.dynamic_gate = DynamicGate(hidden_size)

       
        # Modify MLP input size dynamically based on hidden_size
        mlp_input_dim = hidden_size * 10  # Adjust based on actual concatenation
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(mlp_input_dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),

            torch.nn.Linear(512, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),

            torch.nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, entity1_pos, entity2_pos):
        # Get embedding representation [batch_size, seq_len, hidden_size]
        embeds = self.embedding(input_ids)

        # BiLSTM output [batch_size, seq_len, hidden_size*2]
        lstm_out, _ = self.lstm(embeds)

        # Handle attention mask
        key_padding_mask = (attention_mask == 0)  # [batch_size, seq_len]

        # Multi-head Attention Computation [batch_size, seq_len, hidden_size*2]
        attn_out, attn_weights = self.multihead_attn(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            key_padding_mask=key_padding_mask,
            average_attn_weights=False)

        # Compute dynamic gate weights
        combined = torch.cat([lstm_out, attn_out], dim=-1)  # [batch, seq_len, hidden*4]
        gate_weights = self.dynamic_gate.gate_layer(combined)

        # Apply dynamic gating
        refined_features = self.dynamic_gate(lstm_out, attn_out)

        # Extract entity features
        batch_size = input_ids.size(0)

        # Ensure the index does not exceed the maximum range
        # Ensure that the index does not exceed the maximum range
        entity1_pos = torch.min(entity1_pos, torch.tensor(max_length-1))
        entity2_pos = torch.min(entity2_pos, torch.tensor(max_length-1))

        entity1_hidden = refined_features[torch.arange(batch_size), entity1_pos]  # [batch, hidden*2]
        entity2_hidden = refined_features[torch.arange(batch_size), entity2_pos]  # [batch, hidden*2]

        # Compute additional features
        feature_diff = torch.abs(entity1_hidden - entity2_hidden)
        feature_mul = entity1_hidden * entity2_hidden

        # Concatenate all features [batch, hidden_size*5]
        combined = torch.cat([entity1_hidden, entity2_hidden, feature_diff, feature_mul, refined_features.mean(dim=1)], dim=1)

        # Classification through MLP
        logits = self.mlp(combined)

        return logits, attn_weights, gate_weights

In [8]:
#-----------------------------------------------------
# 7. Set up the training loop
#-----------------------------------------------------

num_labels = len(train_df['label'].unique())

input_size = len(tokenizer.vocab)   # Vocabulary size
embedding_size = 256  # Dimension of word embedding
hidden_size = 256     # LSTM hidden layer dimension (bidirectional total dimension)
num_layers = 2
num_classes = num_labels
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 5

# Initializing the model
bi_lstm_model = BiLSTMModelWithAttention(
    vocab_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_heads=16,
    num_classes=num_classes
).to(device)

optimizer = torch.optim.AdamW(bi_lstm_model.parameters(), lr=5e-5)

In [9]:
#-----------------------------------------------------
# 8. Train and evaluate function definitions
#-----------------------------------------------------

def evaluate(model, dataloader):
    model.eval()
    total_preds = []
    total_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            entity1_pos = batch['entity1_pos'].to(device)
            entity2_pos = batch['entity2_pos'].to(device)
            labels = batch['label'].to(device)

            logits, attn_weights, gate_weights = model(input_ids, attention_mask, entity1_pos, entity2_pos)

            # Use torch.max only for logits
            _, preds = torch.max(logits, dim=1)

            total_preds.extend(preds.cpu().tolist())
            total_labels.extend(labels.cpu().tolist())

    return accuracy_score(total_labels, total_preds)

In [10]:
#-----------------------------------------------------
# 9. Training and evaluation process
#-----------------------------------------------------

best_valid_accuracy = 0.0
best_model_path = 'best_BiLSTMmodelmlp.pth'  # save best model path

for epoch in range(num_epochs):
    bi_lstm_model.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        entity1_pos = batch['entity1_pos'].to(device)
        entity2_pos = batch['entity2_pos'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        # Forward pass including position parameters
        #Note:attn_weights, gate_weights are not used, just reserved for printing the progress of the process if is needed.
        logits, attn_weights, gate_weights = bi_lstm_model(input_ids, attention_mask, entity1_pos, entity2_pos)

        # Calculate the cross-entropy loss
        loss = torch.nn.functional.cross_entropy(logits, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': total_loss / (batch_idx + 1)})

    # Calculate accuracy for Train and Valid set
    train_accuracy = evaluate(bi_lstm_model, train_dataloader)
    valid_accuracy = evaluate(bi_lstm_model, valid_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} completed. Train Accuracy: {train_accuracy:.4f}. validation Accuracy: {valid_accuracy:.4f}")
    
    # Save the model if validation accuracy is better than previous
    if valid_accuracy 
    > best_valid_accuracy:
        best_valid_accuracy = valid_accuracy
        torch.save(bi_lstm_model.state_dict(), best_model_path)
        print(f"Best model saved with accuracy: {best_valid_accuracy:.4f}")

Epoch 1/5:   0%|          | 0/20991 [00:00<?, ?it/s]

Epoch 1/5 completed. Train Accuracy: 0.8672. validation Accuracy: 0.8516
Best model saved with accuracy: 0.8516


Epoch 2/5:   0%|          | 0/20991 [00:00<?, ?it/s]

Epoch 2/5 completed. Train Accuracy: 0.8963. validation Accuracy: 0.8715
Best model saved with accuracy: 0.8715


Epoch 3/5:   0%|          | 0/20991 [00:00<?, ?it/s]

Epoch 3/5 completed. Train Accuracy: 0.9159. validation Accuracy: 0.8817
Best model saved with accuracy: 0.8817


Epoch 4/5:   0%|          | 0/20991 [00:00<?, ?it/s]

Epoch 4/5 completed. Train Accuracy: 0.9254. validation Accuracy: 0.8836
Best model saved with accuracy: 0.8836


Epoch 5/5:   0%|          | 0/20991 [00:00<?, ?it/s]

Epoch 5/5 completed. Train Accuracy: 0.9336. validation Accuracy: 0.8842
Best model saved with accuracy: 0.8842


In [11]:
#-----------------------------------------------------
# 10. Assessment module function
#-----------------------------------------------------

relation_cls_label_map = {
    0: 'None',
    1: '_location_location_contains',
    2: '_location_administrative_division_country',
    3: '_location_country_administrative_divisions',
    4: '_location_country_capital',
    5: '_people_person_children',
    6: '_people_person_place_lived',
    7: '_people_person_nationality',
    8: '_business_company_place_founded',
    9: '_location_neighborhood_neighborhood_of',
    10: '_people_person_place_of_birth',
    11: '_sports_sports_team_location',
    12: '_sports_sports_team_location_teams',
    13: '_people_deceased_person_place_of_death',
    14: '_business_company_founders',
    15: '_business_person_company',
    16: '_business_company_major_shareholders',
    17: '_business_company_shareholder_major_shareholder_of',
    18: '_people_ethnicity_people',
    19: '_people_person_ethnicity',
    20: '_business_company_advisors',
    21: '_people_person_religion',
    22: '_people_ethnicity_geographic_distribution',
    23: '_people_person_profession',
    24: '_business_company_industry'
}


# Categories to ignore (usually the "unrelated" category)
ignore_rel_list = ['None']
def get_threshold(data, preds):
    max_f1 = -1.0
    best_th = -1.0
    cur_th = 0.0

    while cur_th < 1.0:
        pred_pos, gt_pos, correct_pos, total_correct, total_samples= get_F1(data, preds, threshold=cur_th)
        p = float(correct_pos) / (pred_pos + 1e-8)
        r = float(correct_pos) / (gt_pos + 1e-8)
        cur_f1 = (2 * p * r) / (p + r + 1e-8)

        if cur_f1 > max_f1:
            max_f1 = cur_f1
            best_th = cur_th
        cur_th += 0.01  # The best threshold was searched with a step size of 0.01

    return best_th

def get_F1(data, preds_probs, threshold=0.0):
    gt_pos = 0
    pred_pos = 0
    correct_pos = 0
    total_correct = 0  
    total_samples = len(data) 

    for i in range(len(preds_probs)):
        true_label_idx = data[i]['label'].item()
        org_rel_name = relation_cls_label_map[true_label_idx]
        pred_val = np.argmax(preds_probs[i])
        pred_rel_name = relation_cls_label_map[pred_val]
        max_prob = np.max(preds_probs[i])

        # Adjusted prediction: If the prediction is non-None and the probability > threshold is retained, otherwise it is treated as None
        if pred_rel_name not in ignore_rel_list and max_prob > threshold:
            adjusted_pred = pred_val
        else:
            adjusted_pred = 0  # the index


        if adjusted_pred == true_label_idx:
            total_correct += 1

        if org_rel_name not in ignore_rel_list:
            gt_pos += 1
        if (pred_rel_name not in ignore_rel_list) and (max_prob > threshold):
            pred_pos += 1
        if (org_rel_name == pred_rel_name) and (org_rel_name not in ignore_rel_list) and (max_prob > threshold):
            correct_pos += 1

    return pred_pos, gt_pos, correct_pos, total_correct, total_samples

def evaluate_metricsNew(data, preds_probs, threshold=0.0):
    pred_pos, gt_pos, correct_pos, total_correct, total_samples = get_F1(data, preds_probs, threshold)

    precision = correct_pos / (pred_pos + 1e-8)
    recall = correct_pos / (gt_pos + 1e-8)
    f1 = (2 * precision * recall) / (precision + recall + 1e-8)
    accuracy = total_correct / total_samples

    # Outputting each metric
    print(f"Threshold = {threshold:.2f}")
    print(f"Accuracy = {accuracy:.4f}")
    print(f"Precision = {precision:.4f}")
    print(f"Recall = {recall:.4f}")
    print(f"F1 = {f1:.4f}")
    return precision, recall, f1, accuracy

def evaluate_metric_modelNew(model, dataloader, dataset, threshold=0.0):
    model.eval()
    all_pred_probs = []
    all_valid_probs = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            entity1_pos = batch['entity1_pos'].to(device)
            entity2_pos = batch['entity2_pos'].to(device)

            outputs = model(input_ids, attention_mask, entity1_pos, entity2_pos)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
            all_pred_probs.extend(probs)
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            entity1_pos = batch['entity1_pos'].to(device)
            entity2_pos = batch['entity2_pos'].to(device)

            outputs = model(input_ids, attention_mask, entity1_pos, entity2_pos)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
            all_valid_probs.extend(probs)

    if len(all_pred_probs) != len(dataset):
        raise ValueError(f"The number of predictions ({len(all_pred_probs)}) does not match the number of dataset samples({len(dataset)})")
    threshold = get_threshold(valid_dataset, all_valid_probs)
    #print(threshold)

    # Call the evaluation function and return the accuracy
    precision, recall, f1, accuracy = evaluate_metricsNew(dataset, all_pred_probs, threshold)
    return accuracy, precision, recall, f1

In [12]:

#-----------------------------------------------------
# 11. Load the best trained model and output the score of test set
#-----------------------------------------------------

bi_lstm_model.load_state_dict(torch.load(best_model_path))
bi_lstm_model.eval()

accuracy, precision, recall, f1 = evaluate_metric_modelNew(bi_lstm_model, test_dataloader, test_dataset, threshold=0.01)

  bi_lstm_model.load_state_dict(torch.load(best_model_path))


Threshold = 0.42
Accuracy = 0.5662
Precision = 0.4171
Recall = 0.6385
F1 = 0.5046


# --------------------------------------END----------------------------------------

# 

# To load the best trained model(saved in the same folder) and test it, please run the following 2 cells !  You can input your test cases in the second cell below.

In [1]:
#-----------------------------------------------------
# 12: For teachers to test our model
#-----------------------------------------------------

import sys
import subprocess

def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# List of packages to install
required_packages = ["pandas", "torch", "transformers", "tqdm", "scikit-learn", "numpy", "gdown"]

# Install the missing libraries
for package in required_packages:
    try:
        __import__(package)
        print(f"{package} Installed")
    except ImportError:
        print(f"Installing {package} ...")
        install(package)

#-----------------------------------------------------
# Load the model and related variables
#-----------------------------------------------------
import torch
import os
import gdown
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        
# Google Drive File link
file_id = "1w7zZ1a3-nOzq_Vb3YvJyoJAVp1e_H5NQ"
best_model_path = "./best_BiLSTMmodel.pth"  # Local storage path

# If the model file does not exist, it is downloaded
if not os.path.exists(best_model_path):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", best_model_path, quiet=False)
    print("Model downloaded successfully!")


#Related parameter introduction
num_labels=25
input_size = len(tokenizer.vocab)   # Vocabulary size
embedding_size = 256  # Dimension of word embedding
hidden_size = 256     # LSTM hidden layer dimension (bidirectional total dimension)
num_layers = 2
max_length=96
num_classes = num_labels
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
label_encoder = {'None': 0, 'location_location_contains': 1, 'location_administrative_division_country': 2, 'location_country_administrative_divisions': 3, 'location_country_capital': 4, 'people_person_children': 5, 'people_person_place_lived': 6, 'people_person_nationality': 7, 'business_company_place_founded': 8, 'location_neighborhood_neighborhood_of': 9, 'people_person_place_of_birth': 10, 'sports_sports_team_location': 11, 'sports_sports_team_location_teams': 12, 'people_deceased_person_place_of_death': 13, 'business_company_founders': 14, 'business_person_company': 15, 'business_company_major_shareholders': 16, 'business_company_shareholder_major_shareholder_of': 17, 'people_ethnicity_people': 18, 'people_person_ethnicity': 19, 'business_company_advisors': 20, 'people_person_religion': 21, 'people_ethnicity_geographic_distribution': 22, 'people_person_profession': 23, 'business_company_industry': 24}

#Related Functions introduction
class DynamicGate(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
       # Gated network design (using linear layers to generate conditioning coefficients)
        self.gate_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_size * 4, hidden_size * 2),  # Enter concatenated double dimensions
            torch.nn.Sigmoid()
        )

    def forward(self, original, transformed):
        # Concatenate both features as gating input
        combined = torch.cat([original, transformed], dim=-1)  # [batch, seq_len, hidden*4]

        # Generate dynamic gating coefficients
        gate = self.gate_layer(combined)  # [batch, seq_len, hidden*2]

        # Gated fusion (element-wise adjustment)）
        return gate * original + (1 - gate) * transformed

class BiLSTMModelWithAttention(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, num_classes):
        super(BiLSTMModelWithAttention, self).__init__()
        # 1. embedding layer
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)

        # 2. BiLSTM layer
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True)

        # 3. Multi-head Attention layer
        self.multihead_attn = torch.nn.MultiheadAttention(
            embed_dim=hidden_size * 2,
            num_heads=num_heads,
            batch_first=True)

        # 4. Dynamic gate layer
        self.dynamic_gate = DynamicGate(hidden_size)

       
        # Modify MLP input size dynamically based on hidden_size
        mlp_input_dim = hidden_size * 10  # Adjust based on actual concatenation
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(mlp_input_dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),

            torch.nn.Linear(512, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),

            torch.nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, entity1_pos, entity2_pos):
        # Get embedding representation [batch_size, seq_len, hidden_size]
        embeds = self.embedding(input_ids)

        # BiLSTM output [batch_size, seq_len, hidden_size*2]
        lstm_out, _ = self.lstm(embeds)

        # Handle attention mask
        key_padding_mask = (attention_mask == 0)  # [batch_size, seq_len]

        # Multi-head Attention Computation [batch_size, seq_len, hidden_size*2]
        attn_out, attn_weights = self.multihead_attn(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            key_padding_mask=key_padding_mask,
            average_attn_weights=False)

        # Compute dynamic gate weights
        combined = torch.cat([lstm_out, attn_out], dim=-1)  # [batch, seq_len, hidden*4]
        gate_weights = self.dynamic_gate.gate_layer(combined)

        # Apply dynamic gating
        refined_features = self.dynamic_gate(lstm_out, attn_out)

        # Extract entity features
        batch_size = input_ids.size(0)

        # Ensure the index does not exceed the maximum range
        # Ensure that the index does not exceed the maximum range
        entity1_pos = torch.min(entity1_pos, torch.tensor(max_length-1))
        entity2_pos = torch.min(entity2_pos, torch.tensor(max_length-1))

        entity1_hidden = refined_features[torch.arange(batch_size), entity1_pos]  # [batch, hidden*2]
        entity2_hidden = refined_features[torch.arange(batch_size), entity2_pos]  # [batch, hidden*2]

        # Compute additional features
        feature_diff = torch.abs(entity1_hidden - entity2_hidden)
        feature_mul = entity1_hidden * entity2_hidden

        # Concatenate all features [batch, hidden_size*5]
        combined = torch.cat([entity1_hidden, entity2_hidden, feature_diff, feature_mul, refined_features.mean(dim=1)], dim=1)

        # Classification through MLP
        logits = self.mlp(combined)

        return logits, attn_weights, gate_weights
    
# Load saved optimal model
def load_best_model(model, model_path):
    #model.load_state_dict(torch.load(model_path))
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) 
    model.eval()
    print(f"BiLSTM Model loaded successfully!")
    return model

#Function definition for relation prediction on an input sentence and two entities using the loaded model
def predict_relationship_with_saved_model(sentence, entity1, entity2, model, tokenizer, label_encoder, max_length=256):
    tokens = tokenizer.tokenize(sentence)
    entity1_pos = [i for i, token in enumerate(tokens) if token in tokenizer.tokenize(entity1)]
    entity2_pos = [i for i, token in enumerate(tokens) if token in tokenizer.tokenize(entity2)]

    # tokenize the input sentence
    encoded_sentence = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt')

    input_ids = encoded_sentence['input_ids'].to(device)
    attention_mask = encoded_sentence['attention_mask'].to(device)
    entity1_pos = torch.tensor(entity1_pos[0] if entity1_pos else 0).unsqueeze(0).to(device)
    entity2_pos = torch.tensor(entity2_pos[0] if entity2_pos else 0).unsqueeze(0).to(device)

    # Make predictions using the loaded model
    with torch.no_grad():
        outputs = model(input_ids, attention_mask, entity1_pos, entity2_pos)
        logits = outputs[0]
        _, preds = torch.max(logits, dim=1)

    # The labels are obtained by reverse mapping
    reverse_label_encoder = {v: k for k, v in label_encoder.items()}
    predicted_label = reverse_label_encoder[preds.item()]

    return predicted_label

bi_lstm_model = BiLSTMModelWithAttention(
    vocab_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_heads=16,
    num_classes=num_classes
).to(device)

# load best model
bi_lstm_model = load_best_model(bi_lstm_model, best_model_path)


Installing pandas ...
Collecting pandas
  Using cached pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl.metadata (89 kB)
Collecting numpy>=1.22.4 (from pandas)
  Downloading numpy-2.2.3-cp310-cp310-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.1-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.1-py2.py3-none-any.whl.metadata (1.4 kB)
Using cached pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl (11.3 MB)
Downloading numpy-2.2.3-cp310-cp310-macosx_14_0_arm64.whl (5.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hUsing cached pytz-2025.1-py2.py3-none-any.whl (507 kB)
Using cached tzdata-2025.1-py2.py3-none-any.whl (346 kB)
Installing collected packages: pytz, tzdata, numpy, pandas
Successfully installed numpy-2.2.3 pandas-2.2.3 pytz-2025.1 tzdata-2025.1
Installing torch ...
Collecting

Downloading...
From (original): https://drive.google.com/uc?id=1w7zZ1a3-nOzq_Vb3YvJyoJAVp1e_H5NQ
From (redirected): https://drive.google.com/uc?id=1w7zZ1a3-nOzq_Vb3YvJyoJAVp1e_H5NQ&confirm=t&uuid=4bf226e4-c3f9-465f-a440-96a7c964eea7
To: /Users/zhouqiaoqiao/Desktop/【61332】 Text Mining/CW最终提交材料/best_BiLSTMmodel.pth
100%|██████████| 53.9M/53.9M [00:02<00:00, 23.8MB/s]


Model downloaded successfully!
BiLSTM Model loaded successfully!


In [2]:
#-----------------------------------------------------
# 13: Input Module： Please input your test case here!
#-----------------------------------------------------

# The sentence and two entities (entity1 and entity2) are provided as input for relationship prediction.
# Users can modify these variables to test the model with different sentences and entities.

# In this example, 'sentence' is a string that represents the sentence where the entities appear.
# 'entity1' and 'entity2' are the two entities whose relationship you want to predict. You can replace 
# these with any sentence and entities of your choice.

# To test the module:
# 1. Modify the 'sentence' variable with your desired sentence that contains two entities.
# 2. Modify 'entity1' with the first entity (a person, organization, or any other entity) in the sentence.
# 3. Modify 'entity2' with the second entity (the relationship you want to identify between entity1 and entity2).

# After modifying these values, the model will predict the relationship between the two entities in the sentence.
# The predicted relationship will be printed out.


sentence = "otecna employed Kojo Annan , Kofi Annan 's son , as a contractor at the time it received the aid inspection contract"
entity1 = "Kojo Annan"
entity2 = "Kofi Annan"

# Relationship prediction using the loaded model
predicted_relationship = predict_relationship_with_saved_model(sentence, entity1, entity2, bi_lstm_model, tokenizer, label_encoder)
print(f"Predicted relationships：{predicted_relationship}")


Predicted relationships：people_person_children
