### Import required libraries

In [1]:
import rdflib
from transformers import AutoTokenizer, AutoModelForSequenceClassification, T5ForConditionalGeneration, T5Tokenizer, pipeline
from transformers.utils import logging
from sentence_transformers import SentenceTransformer
import torch
import pandas as pd

import os
from pprint import pprint
from IPython.display import display, Markdown

from utils import triple_sentiment_analysis, test_entailment
from utils.sparql_queries import find_all_triples_q
from anchor_points_extractor import anchor_points_extractor
from graph_explorator import graph_explorator
from g2t_generator import g2t_generator

# Disable HuggingFace tokenizers parallelism and limit transformers logging to errors.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.set_verbosity_error()

# Place models on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Set the current goal and the triples to avoid as anchor points

For example, to run the approach on the subgoal “Use flood warning system to predict flood using weather forecastings”, which was generated from the triples `Flood warning system Predicts Flood` and `Flood warning system Analyzes Forecasting`, set the parameters below. Listing the source triples in `filtered_out_triples` prevents them from being selected again as anchor points:

```python
goal = "Use flood warning system to predict flood using weather forecastings"
filtered_out_triples = ["Flood warning system Predicts Flood", "Flood warning system Analyzes Forecasting"]
```

In [2]:
# Natural-language goal that candidate triples are scored and NLI-checked against.
goal = "Anticipate the impact of floods on people"
# Triples, written as "subject predicate object" strings, that must not be used as
# anchor points (e.g. the triples a subgoal was generated from). Empty: no exclusions.
filtered_out_triples = list()

### Load the example knowledge graph

In [3]:
domain_graph = rdflib.Graph()
domain_graph.parse("./flooding_graph.rdf")

# Run the "all triples" SPARQL query and stringify each bound RDF term so every
# triple becomes a plain [subject, predicate, object] list of str.
query_results = domain_graph.query(find_all_triples_q)
triples = [
    [str(row["subject"]), str(row["predicate"]), str(row["object"])]
    for row in query_results.bindings
]

# One row per triple, each paired with the current goal:
#   TRIPLE            - the triple as a single space-joined sentence
#   GOAL              - the goal text
#   TRIPLE_SERIALIZED - a list of triples (a single one here); later stages
#                       produce multi-triple lists in the same shape.
goal_triples_df = pd.DataFrame(
    [
        {"TRIPLE": " ".join(t), "GOAL": goal, "TRIPLE_SERIALIZED": [t]}
        for t in triples
    ]
)

### Load the models

In [4]:
# Sentence-embedding model, handed to anchor_points_extractor below.
model_sts = SentenceTransformer('all-mpnet-base-v2')

# NLI classifier and its tokenizer, used by test_entailment and graph_explorator.
model_nli_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
model_nli = AutoModelForSequenceClassification.from_pretrained(model_nli_name).to(device)
tokenizer_nli = AutoTokenizer.from_pretrained(model_nli_name)

# Graph-to-text T5 model used by g2t_generator.
# NOTE(review): the tokenizer is loaded from the base "t5-base" checkpoint rather
# than from the fine-tuned repo — presumably they share a vocabulary; confirm.
tokenizer_g2t = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
model_g2t = T5ForConditionalGeneration.from_pretrained("Inria-CEDAR/WebNLG20T5B").to(device)

# Sentiment classifier used to flag negative triples.
sentiment_model_path = "cardiffnlp/twitter-roberta-base-sentiment-latest"
sentiment_task = pipeline(
    "sentiment-analysis",
    model=sentiment_model_path,
    tokenizer=sentiment_model_path,
    device=device,
)

### Extract anchor points

In [5]:
# Score goal/triple pairs with the sentence-embedding model to select anchor points,
# skipping filtered_out_triples. Copy so this frame does not alias the extractor's result.
anchor_points_df = anchor_points_extractor(
    goal_triples_df,
    model_sts,
    filtered_out_triples,
).copy()
anchor_points_df

Unnamed: 0,TRIPLE,GOAL,TRIPLE_SERIALIZED,SCORE
53,Flood Causes Health and safety risks,Anticipate the impact of floods on people,"[[Flood, Causes, Health and safety risks]]",0.716851
27,Flood Causes Loss of life,Anticipate the impact of floods on people,"[[Flood, Causes, Loss of life]]",0.67423
14,Flood Causes Environmental damage,Anticipate the impact of floods on people,"[[Flood, Causes, Environmental damage]]",0.64711
19,Flood warning system Analyzes Forecasting,Anticipate the impact of floods on people,"[[Flood warning system, Analyzes, Forecasting]]",0.641735
31,Flood Causes Business disruption,Anticipate the impact of floods on people,"[[Flood, Causes, Business disruption]]",0.62965
35,Flood Causes Property damage,Anticipate the impact of floods on people,"[[Flood, Causes, Property damage]]",0.619106
52,Flood Causes Evacuation of residents,Anticipate the impact of floods on people,"[[Flood, Causes, Evacuation of residents]]",0.615268
5,Flood warning system Predicts Flood,Anticipate the impact of floods on people,"[[Flood warning system, Predicts, Flood]]",0.610021


### Transform negative triples

In [6]:
# Label each anchor point with the sentiment of its first serialized triple.
# Element [0] of triple_sentiment_analysis' result is the label string compared
# against "negative" below. Predicates in neutral_predicates are passed through —
# presumably they are forced to neutral; confirm in utils.
# A new frame is built so anchor_points_df, displayed by the previous cell, is
# left unchanged (no in-place column add or rename).
anchor_points_labeled = anchor_points_df.assign(
    SENTIMENT=anchor_points_df["TRIPLE_SERIALIZED"].apply(
        lambda serialized: triple_sentiment_analysis(
            serialized[0], sentiment_task, neutral_predicates=["is a type of"]
        )[0]
    )
).rename(columns={"TRIPLE": "PREMISE", "GOAL": "HYPOTHESIS"})

# NLI premises: negative triples are rephrased as something to prevent;
# all other triples are used verbatim.
transformed_triples_premise = [
    "Prevent that " + premise if sentiment == "negative" else premise
    for premise, sentiment in zip(
        anchor_points_labeled["PREMISE"], anchor_points_labeled["SENTIMENT"]
    )
]

# Positional (.values) columns give the new frame a fresh 0..n-1 index.
transformed_anchor_points = pd.DataFrame(
    {
        "PREMISE": transformed_triples_premise,
        "HYPOTHESIS": anchor_points_labeled["HYPOTHESIS"].values,
        "PREMISE_SERIALIZED": anchor_points_labeled["TRIPLE_SERIALIZED"].values,
    }
)

transformed_anchor_points

Unnamed: 0,PREMISE,HYPOTHESIS,PREMISE_SERIALIZED
0,Prevent that Flood Causes Health and safety risks,Anticipate the impact of floods on people,"[[Flood, Causes, Health and safety risks]]"
1,Prevent that Flood Causes Loss of life,Anticipate the impact of floods on people,"[[Flood, Causes, Loss of life]]"
2,Prevent that Flood Causes Environmental damage,Anticipate the impact of floods on people,"[[Flood, Causes, Environmental damage]]"
3,Flood warning system Analyzes Forecasting,Anticipate the impact of floods on people,"[[Flood warning system, Analyzes, Forecasting]]"
4,Prevent that Flood Causes Business disruption,Anticipate the impact of floods on people,"[[Flood, Causes, Business disruption]]"
5,Prevent that Flood Causes Property damage,Anticipate the impact of floods on people,"[[Flood, Causes, Property damage]]"
6,Prevent that Flood Causes Evacuation of residents,Anticipate the impact of floods on people,"[[Flood, Causes, Evacuation of residents]]"
7,Flood warning system Predicts Flood,Anticipate the impact of floods on people,"[[Flood warning system, Predicts, Flood]]"


### Verify entailment

In [7]:
# Run the NLI model on every (PREMISE, HYPOTHESIS) pair; the result carries the
# ENTAILMENT / NEUTRAL / CONTRADICTION scores and an NLI_LABEL per row.
entailment_result = test_entailment(
    transformed_anchor_points,
    tokenizer_nli,
    model_nli,
)
entailment_result

Unnamed: 0,PREMISE,HYPOTHESIS,PREMISE_SERIALIZED,ENTAILMENT,NEUTRAL,CONTRADICTION,NLI_LABEL
6,Prevent that Flood Causes Evacuation of residents,Anticipate the impact of floods on people,"[[Flood, Causes, Evacuation of residents]]",0.86817,0.11147,0.02036,ENTAILMENT
1,Prevent that Flood Causes Loss of life,Anticipate the impact of floods on people,"[[Flood, Causes, Loss of life]]",0.78269,0.17343,0.04388,ENTAILMENT
0,Prevent that Flood Causes Health and safety risks,Anticipate the impact of floods on people,"[[Flood, Causes, Health and safety risks]]",0.70614,0.21396,0.0799,ENTAILMENT
4,Prevent that Flood Causes Business disruption,Anticipate the impact of floods on people,"[[Flood, Causes, Business disruption]]",0.41095,0.35846,0.23059,ENTAILMENT
7,Flood warning system Predicts Flood,Anticipate the impact of floods on people,"[[Flood warning system, Predicts, Flood]]",0.3862,0.59991,0.01389,NEUTRAL
5,Prevent that Flood Causes Property damage,Anticipate the impact of floods on people,"[[Flood, Causes, Property damage]]",0.3506,0.34778,0.30162,ENTAILMENT
3,Flood warning system Analyzes Forecasting,Anticipate the impact of floods on people,"[[Flood warning system, Analyzes, Forecasting]]",0.21787,0.76808,0.01405,NEUTRAL
2,Prevent that Flood Causes Environmental damage,Anticipate the impact of floods on people,"[[Flood, Causes, Environmental damage]]",0.16007,0.38834,0.45159,CONTRADICTION


### Explore graph to improve contextualization

Add neighboring triples from the knowledge graph to the anchor points to further contextualize them.

In [8]:
# Extend the entailment results with neighbouring triples from domain_graph.
# The NLI model is passed in, presumably to re-score the combined subgoals
# against the goal (see SCORE / NLI_LABEL in the output).
entailed_triples_df = graph_explorator(
    entailment_result,
    goal,
    domain_graph,
    tokenizer_nli,
    model_nli,
)
entailed_triples_df

Unnamed: 0,SUBGOALS,SUBGOALS_SERIALIZED,SCORE,NLI_LABEL
0,Prevent that Flood Causes Evacuation of residents,"[[Flood, Causes, Evacuation of residents]]",0.86817,ENTAILMENT
1,Prevent that Flood Causes Loss of life,"[[Flood, Causes, Loss of life]]",0.78269,ENTAILMENT
2,Prevent that Flood Causes Health and safety risks,"[[Flood, Causes, Health and safety risks]]",0.70614,ENTAILMENT
3,Prevent that Flood Causes Business disruption,"[[Flood, Causes, Business disruption]]",0.41095,ENTAILMENT
4,Flood Causes Loss of life. Flood warning syste...,"[[Flood, Causes, Loss of life], [Flood warning...",0.82918,ENTAILMENT
5,Prevent that Flood Causes Property damage,"[[Flood, Causes, Property damage]]",0.3506,ENTAILMENT
2,Flood warning system Recommends Evacuation of ...,"[[Flood warning system, Recommends, Evacuation...",0.65599,ENTAILMENT
4,Flood Causes Evacuation of residents. Prevent ...,"[[Flood, Causes, Evacuation of residents], [Fl...",0.7474,ENTAILMENT


### Generate text from the triples identified as relevant

In [9]:
if not entailed_triples_df.empty:
    # Flatten every row's list of serialized triples into a single list.
    all_triples_entailed = [
        triple
        for row_triples in entailed_triples_df["SUBGOALS_SERIALIZED"].tolist()
        for triple in row_triples
    ]

    # Order-preserving de-duplication. Only list-shaped triples are kept,
    # matching the filter applied during generation below.
    unique_triples_entailed = []
    for triple in all_triples_entailed:
        if isinstance(triple, list) and triple not in unique_triples_entailed:
            unique_triples_entailed.append(triple)

    display(Markdown("#### Unique triples"))
    pprint(unique_triples_entailed)
    display(Markdown("---------------------------"))

    display(Markdown("#### Generated texts"))

    # Triples already verbalized for an earlier row are skipped, so each triple
    # contributes to at most one generated sentence.
    triples_already_processed = []
    for _, row in entailed_triples_df.iterrows():
        triples_to_process = [
            triple
            for triple in row["SUBGOALS_SERIALIZED"]
            if isinstance(triple, list) and triple not in triples_already_processed
        ]
        if not triples_to_process:
            continue

        prediction = g2t_generator(triples_to_process, model=model_g2t, tokenizer=tokenizer_g2t)[0]
        # Negative anchor points were prefixed with "Prevent that " in the
        # transformation step; the label follows the last sentence of the
        # (possibly multi-sentence) subgoal text.
        last_sentence = row["SUBGOALS"].split(". ")[-1]
        label = "[AVOID] " if "Prevent that" in last_sentence else "[ACHIEVE] "
        print(label + prediction)

        triples_already_processed.extend(triples_to_process)

#### Unique triples

[['Flood', 'Causes', 'Evacuation of residents'],
 ['Flood', 'Causes', 'Loss of life'],
 ['Flood', 'Causes', 'Health and safety risks'],
 ['Flood', 'Causes', 'Business disruption'],
 ['Flood', 'Causes', 'Property damage'],
 ['Flood', 'Causes', 'Environmental damage']]


---------------------------

#### Generated texts

[AVOID] evacuation of residents is one of the causes of the flood
[AVOID] floods can cause loss of life
[AVOID] floods can cause health and safety risks
[AVOID] floods can cause business disruption
[AVOID] flooding can cause property damage
[AVOID] floods can cause environmental damage
