In [1]:
pip install transformers[torch]

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
from itertools import combinations

import torch
from transformers import DistilBertTokenizer, AutoModelForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


## Relation Extraction Task

In [7]:
def generate_new_data_relation_feature(text, entities):
    """Build one classifier input per unordered pair of entities.

    Args:
        text: The passage the entities belong to; it is embedded verbatim at
            the start of every generated input string.
        entities: Iterable of entity strings. Items equal to ``'O'`` are
            dropped, and a leading ``"B-"`` or ``"I-"`` prefix is stripped
            from the remaining items before pairing.

    Returns:
        A list of ``[input_text, entity_1, entity_2]`` lists, where
        ``input_text`` is ``"{text} [SEP] {entity_1} [SEP] {entity_2}"``.
        Pairs are produced in ``itertools.combinations`` order; fewer than
        two usable entities yields an empty list.
    """
    # Drop the 'O' marker first, then strip BIO-style prefixes from the rest.
    processed_entities = [
        entity[2:] if entity.startswith(("B-", "I-")) else entity
        for entity in entities
        if entity != 'O'
    ]

    # Pair the cleaned entities (not the raw input) so that 'O' items and
    # "B-"/"I-" prefixes never reach the model input or the returned names.
    return [
        [f"{text} [SEP] {entity_1} [SEP] {entity_2}", entity_1, entity_2]
        for entity_1, entity_2 in combinations(processed_entities, 2)
    ]

In [8]:
# Checkpoint identifiers, kept in one place so they are easy to swap.
# The tokenizer comes from the base DistilBERT checkpoint, while the
# sequence-classification weights come from a separate repository
# (presumably fine-tuned for relation extraction — confirm on the model card).
TOKENIZER_CHECKPOINT = "distilbert-base-uncased"
RELATION_MODEL_CHECKPOINT = "GoDillonAudris/distilbert-relation-extraction"

relation_extraction_tokenizer = DistilBertTokenizer.from_pretrained(
    TOKENIZER_CHECKPOINT
)
relation_extraction_model = AutoModelForSequenceClassification.from_pretrained(
    RELATION_MODEL_CHECKPOINT
)

# Relation class names, indexed by predicted class id: predict_relations looks
# up label_list[argmax(logits)]. The order presumably has to match the label
# ids the classifier was trained with — verify before reordering.
label_list = [
    "attributed-to",
    "authored-by",
    "beacons-to",
    "communicates-with",
    "compromises",
    "consists-of",
    "controls",
    "delivers",
    "downloads",
    "drops",
    "duplicate-of",
    "exfiltrates-to",
    "exploits",
    "has",
    "hosts",
    "impersonates",
    "indicates",
    "located-at",
    "no_relation",
    "originates-from",
    "owns",
    "related-to",
    "targets",
    "uses",
]

In [9]:
def predict_relations(text, entities):
    """Predict a relation label for every pair of entities in ``text``.

    Each pair produced by ``generate_new_data_relation_feature`` is tokenized
    and scored by ``relation_extraction_model``; the highest-scoring class id
    is mapped to its name through ``label_list``.

    Args:
        text: The passage the entities belong to.
        entities: Iterable of entity strings, passed through to
            ``generate_new_data_relation_feature``.

    Returns:
        A list of strings of the form ``"<entity_1> <label> <entity_2>"``,
        one per pair whose predicted label is not ``'no_relation'``. The list
        is empty when no pairs are generated or no relation is predicted.
    """
    features = generate_new_data_relation_feature(text, entities)

    relations = []
    # Inference only: disable autograd so no computation graph is built for
    # the forward passes.
    with torch.no_grad():
        for input_text, entity_1, entity_2 in features:
            # truncation=True caps the input at the tokenizer's maximum
            # length, so a long passage is cut instead of overflowing the
            # model's maximum sequence length.
            inputs = relation_extraction_tokenizer(
                input_text, return_tensors="pt", truncation=True
            )
            outputs = relation_extraction_model(**inputs)

            # One sequence per call, so argmax over dim=1 yields a single id.
            predicted_label = outputs.logits.argmax(dim=1).item()
            decoded_label = label_list[predicted_label]

            if decoded_label != 'no_relation':
                relations.append(f"{entity_1} {decoded_label} {entity_2}")

    return relations

## Pipeline

In [None]:
# Placeholder inputs: replace these with the real passage and its entity list.
# With the empty defaults below no entity pairs are generated, so `relations`
# comes back as an empty list.
text, entities = "", []

# Relation extraction task
relations = predict_relations(text, entities)