In [18]:
# Mount Google Drive at /content/drive so the saved checkpoint can be read from it.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [19]:
# Question templates for every directed relation label.
# Each label maps to exactly two templates. `inference` substitutes the tagged
# entity text for the `<e1>` / `<e2>` placeholders: the first template is built
# around e1 (its expected answer is e2) and the second is built around e2 (its
# expected answer is e1).
questions = {
    "Cause-Effect(e1,e2)": ["What is caused by <e1>?", "What causes <e2>?"],
    "Cause-Effect(e2,e1)": ["What causes <e1>?", "What is caused by <e2>?"],
    "Instrument-Agency(e1,e2)": ["What uses <e1>?", "What is the instrument of <e2>?"],
    "Instrument-Agency(e2,e1)": ["What is the instrument of <e1>?", "What uses <e2>?"],
    "Product-Producer(e1,e2)": ["What produces <e1>?", "What is the product of <e2>?"],
    "Product-Producer(e2,e1)": ["What is the product of <e1>?", "What produces <e2>?"],
    "Content-Container(e1,e2)": ["What stores <e1>?", "What is stored inside <e2>?"],
    "Content-Container(e2,e1)": ["What is stored inside <e1>?", "What stores <e2>?"],
    "Entity-Origin(e1,e2)": ["What is the origin of <e1>?", "What entity originates from <e2>?"],
    "Entity-Origin(e2,e1)": ["What entity originates from <e1>?", "What is the origin of <e2>?"],
    "Entity-Destination(e1,e2)": ["What is the destination of <e1>?", "What entity is moving toward <e2>?"],
    "Entity-Destination(e2,e1)": ["What entity is moving toward <e1>?", "What is the destination of <e2>?"],
    "Component-Whole(e1,e2)": ["What entity does <e1> operate within?", "What functional component of <e2> is mentioned?"],
    "Component-Whole(e2,e1)": ["What functional component of <e1> is mentioned?", "What entity does <e2> operate within?"],
    "Member-Collection(e1,e2)": ["Which collection/organization does <e1> belong to?", "Who are the members of <e2>?"],
    "Member-Collection(e2,e1)": ["Who are the members of <e1>?", "Which collection/organization does <e2> belong to?"],
    "Message-Topic(e1,e2)": ["What is the main topic discussed in <e1>?", "Which message contains information about <e2>?"],
    "Message-Topic(e2,e1)": ["Which message contains information about <e1>?", "What is the main topic discussed in <e2>?"],
}

In [20]:
import torch
from transformers import AutoModel, AutoTokenizer

# Hugging Face checkpoint name, shared by the tokenizer and the encoder so their vocabularies match.
model_name = "huawei-noah/TinyBERT_General_6L_768D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tinybert = AutoModel.from_pretrained(model_name)

class TinyBERTForQA(torch.nn.Module):
    """Extractive-QA model: an encoder followed by a linear start/end span head.

    Args:
        tinybert: Encoder module. Its forward pass must accept ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` keyword arguments and
            return an object exposing ``last_hidden_state``.

    The submodule names ``tinybert`` and ``qa_outputs`` are part of the saved
    state_dict layout and must not be renamed.
    """

    def __init__(self, tinybert):
        super().__init__()
        self.tinybert = tinybert
        # Size the head from the encoder's own config so encoders with a
        # different width also work; fall back to 768 (the width of the
        # checkpoint used in this notebook) when no config is exposed.
        encoder_config = getattr(tinybert, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        # Two outputs per token: one start logit and one end logit.
        self.qa_outputs = torch.nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return per-token ``(start_logits, end_logits)``.

        Args:
            input_ids: Token ids passed through to the encoder.
            attention_mask: Attention mask passed through to the encoder.
            token_type_ids: Optional segment ids passed through to the encoder.

        Returns:
            A tuple ``(start_logits, end_logits)``; each has the shape of
            ``last_hidden_state`` with the final (feature) axis removed.
        """
        outputs = self.tinybert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = self.qa_outputs(outputs.last_hidden_state)
        # Split the two head outputs into separate start / end tensors.
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Wrap the pretrained encoder in the QA model; the fine-tuned weights are loaded in a later cell.
model = TinyBERTForQA(tinybert)

- Change the path below so that it points to the saved model checkpoint (`.pth` file) on your Drive.

In [21]:
# Path to the fine-tuned checkpoint (a saved state_dict) on the mounted Drive.
directory = "/content/drive/My Drive/relation_extraction_model_final.pth"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# runtime (the model is moved to the target device in a later cell).
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load(directory, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Switch to evaluation mode (disables dropout) for inference.
model.eval()

  model.load_state_dict(torch.load(directory))


TinyBERTForQA(
  (tinybert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [22]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

TinyBERTForQA(
  (tinybert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [23]:
import re
def _predict_span(clean_sentence, question):
    """Return the text span of ``clean_sentence`` that the model picks for ``question``.

    Returns an empty string when no usable span is predicted: the predicted
    end precedes the predicted start, or either end of the span falls inside
    the question segment.
    """
    encoding = tokenizer(
        clean_sentence,
        question,
        return_tensors="pt",
        truncation=True,
        max_length=80,
        padding="max_length",
        return_offsets_mapping=True
    )
    # Character offsets of every token, relative to the segment it came from.
    offsets = encoding["offset_mapping"].squeeze(0)
    # Per-token segment index: 0 = sentence, 1 = question, None = special/padding.
    sequence_ids = encoding.sequence_ids(0)

    with torch.no_grad():
        start_logits, end_logits = model(
            input_ids=encoding["input_ids"].to(device),
            attention_mask=encoding["attention_mask"].to(device),
            token_type_ids=encoding["token_type_ids"].to(device)
        )

    pred_start = torch.argmax(start_logits, dim=1).item()  # Highest-scoring start token
    pred_end = torch.argmax(end_logits, dim=1).item()  # Highest-scoring end token

    # An end before the start is not a valid span.
    if pred_start > pred_end:
        return ''
    # Offsets of question tokens index into the question string, so they
    # cannot be used to slice the sentence.
    if sequence_ids[pred_start] == 1 or sequence_ids[pred_end] == 1:
        return ''

    start_char = offsets[pred_start][0].item()
    end_char = offsets[pred_end][1].item()
    return clean_sentence[start_char:end_char]


def inference(sentence):
    """Predict the relation between the two tagged entities of ``sentence``.

    For each relation in ``questions`` (in dictionary order) both question
    templates are filled with the entity texts and answered by the QA model.
    A relation is accepted when both predicted spans are non-empty and at
    least one of them equals the expected entity; the first accepted relation
    is returned, and its (question, prediction) pairs are printed.

    Args:
        sentence: Text containing one ``<e1>...</e1>`` and one
            ``<e2>...</e2>`` entity annotation.

    Returns:
        The accepted relation label, or ``'Other'`` when no relation is accepted.

    Raises:
        ValueError: If ``sentence`` lacks the ``<e1>`` or ``<e2>`` annotation.
    """
    # Extract e1 and e2
    e1_match = re.search(r"<e1>(.*?)</e1>", sentence)
    e2_match = re.search(r"<e2>(.*?)</e2>", sentence)
    if e1_match is None or e2_match is None:
        raise ValueError(
            "sentence must contain both <e1>...</e1> and <e2>...</e2> tags: %r" % (sentence,)
        )
    e1_text = e1_match.group(1)
    e2_text = e2_match.group(1)

    # The model sees the sentence with the entity tags replaced by spaces.
    clean_sentence = re.sub(r"</?e[12]>", " ", sentence)

    for rel, q_list in questions.items():  # Loop for each relation
        # Keep the results of one relation together so that an unusable
        # prediction can never shift the pairing of later questions.
        results = []
        for position, q in enumerate(q_list):  # Loop for each question
            question = q.replace('<e1>', e1_text).replace('<e2>', e2_text)
            # Templates alternate: the first asks about e1 (answer e2), the
            # second asks about e2 (answer e1).
            true_answer = e2_text if position % 2 == 0 else e1_text
            predicted_word = _predict_span(clean_sentence, question)
            results.append((question, predicted_word, true_answer))

        every_answered = all(predicted != '' for _, predicted, _ in results)
        any_correct = any(predicted == expected for _, predicted, expected in results)
        if every_answered and any_correct:
            for question, predicted, _ in results:
                print((question, predicted))
            return rel
    return 'Other'

- Enter the sentences to classify here, marking the two entities with `<e1>…</e1>` and `<e2>…</e2>` tags.

In [26]:
# Example inputs: each sentence marks its two entities with <e1>...</e1> and <e2>...</e2> tags.
sentences = ["The most common <e1>audits</e1> were about <e2>waste</e2> and recycling.",
             "The <e1>computer</e1> is kept in a common <e2>area</e2> within our home.",
             "Broken <e1>bones</e1> (also called fractures) in the <e2>foot</e2> are very common.",
             "The <e1>company</e1> has mocked up a <e2>version</e2> of YouTube built around the HTML5 video tag, playing mini-movies inside a browser sans plug-ins.",
             "After handsome renovations at various locales, the <e1>company</e1> has remodeled a church into a <e2>home</e2>.",
             "This <e1>football match</e1> is <e2>interesting</e2>.",
             "The two <e1>countries</e1> are related through an unbroken common <e2>history</e2> spanning three-hundred and thirty-nine years.",
             "Basic diagrams also work well on the <e1>computer</e1> <e2>screen</e2> if they are carefully designed to match the grid of pixels on the screen.",
             "The city of Chicago lost 725000 residents between 1950 and 2000, yet 82 percent of the suburban <e1>growth</e1> was from outside the metropolitan <e2>area</e2>.",
             "While making <e1>observations</e1> the microfossil through the binocular microscope or on a computer <e2>monitor</e2>, the investigator needed to manually move the specimen.",
             "The <e1>storm</e1> was generated by an intense <e2>cold front</e2> moving across drought-affected areas in South Australia and NSW.",
             "<e1>Fainting</e1> is a common cause of <e2>unconsciousness</e2> and may occur when the casualty's heart rate is too slow to maintain sufficient blood pressure for the brain."]

In [27]:
# Classify each example sentence and print its predicted relation, followed by a separator line.
for sentence in sentences:
    relation = inference(sentence)
    print('The relation of the sentence is: ', relation)
    print('----------------------------------------------------------')

('What is the main topic discussed in audits?', 'waste')
('Which message contains information about waste?', 'audits')
The relation of the sentence is:  Message-Topic(e1,e2)
----------------------------------------------------------
The relation of the sentence is:  Other
----------------------------------------------------------
('What entity does bones operate within?', 'foot')
('What functional component of foot is mentioned?', 'bones')
The relation of the sentence is:  Component-Whole(e1,e2)
----------------------------------------------------------
The relation of the sentence is:  Other
----------------------------------------------------------
The relation of the sentence is:  Other
----------------------------------------------------------
The relation of the sentence is:  Other
----------------------------------------------------------
('Which message contains information about countries?', 'history')
('What is the main topic discussed in history?', 'The two  countries')
The r