#PREPROCESSING

In [None]:
!pip install rdflib

In [None]:
import pandas as pd
from rdflib import Graph, Literal, Namespace, URIRef
from urllib.parse import quote


EG = Namespace("http://example.com/")


def create_eg_uri(name) -> URIRef:
    """Build an ``http://example.com/`` URI whose local name is derived from *name*.

    The value is converted to ``str``, spaces become underscores, and every
    other reserved character -- including ``/`` -- is percent-encoded.

    ``/`` must be encoded because ``extract_data_from_graph`` recovers local
    names with ``split('/')[-1]``. With :func:`urllib.parse.quote`'s default
    ``safe='/'``, a value such as ``"a/b"`` would come back as just ``"b"``.

    Args:
        name: any value; it is passed through ``str()``.

    Returns:
        ``URIRef`` in the ``EG`` namespace.
    """
    name_str = str(name)
    quoted = quote(name_str.replace(" ", "_"), safe="")
    return EG[quoted]

def translate_df_to_rdf(data, mapping):
    """Turn every row of ``data`` into RDF triples about one subject URI.

    The subject of a row is ``create_eg_uri(row[mapping["uri"]])``. Each row
    then contributes:

    * ``(subject, EG:positive, label)`` when the row's ``label_utilize`` value
      equals 1, otherwise ``(subject, EG:negative, label)``;
    * ``(subject, EG:<mapping[col]>, value)`` for every other mapping entry.
      String values are whitespace-stripped.

    Args:
        data: pandas DataFrame that has a ``label_utilize`` column and every
            column named in ``mapping``.
        mapping: dict of column name -> predicate name. The special key
            ``"uri"`` names the column used as the row identifier.

    Returns:
        rdflib ``Graph`` containing all generated triples.
    """
    graph = Graph()
    for _, row in data.iterrows():
        subject = create_eg_uri(row[mapping["uri"]])

        label = row['label_utilize']
        label_predicate = create_eg_uri("positive" if label == 1 else "negative")
        graph.add((subject, label_predicate, Literal(label)))

        for column_name, predicate_name in mapping.items():
            if column_name == "uri":
                continue
            predicate = create_eg_uri(predicate_name)
            cell = row[column_name]
            if pd.notna(cell) and isinstance(cell, str):
                cell = cell.strip()
            graph.add((subject, predicate, Literal(cell)))
    return graph


# Column -> predicate mapping consumed by translate_df_to_rdf. The special "uri"
# key names the column whose value becomes each row's subject URI; every other
# predicate name is its column name with spaces replaced by underscores.
# NOTE(review): "uri" points at label_utilize, so all rows that share a label
# value share one subject URI and their triples merge in the graph -- confirm
# whether a per-row identifier column was intended.
mapping = {
    "uri": "label_utilize",
    **{
        f"{item} on table_{side}": f"{item}_on_table_{side}"
        for item in ("obj0", "obj1", "obj2", "obj3", "obj4", "tray", "jug")
        for side in ("north", "south", "east", "west")
    },
    "is_helper_exist": "is_helper_exist",
    **{
        f"goal at table_{place}": f"goal_at_table_{place}"
        for place in ("north", "south", "east", "west", "watersource", "obstacle", "handover")
    },
    **{
        f"action_{verb}": f"action_{verb}"
        for verb in ("pick", "place", "carry", "fill", "pour", "handover")
    },
    "label_utilize": "label_utilize",
}

# Build the RDF graph from the input table.
# NOTE(review): `data_df` is not defined anywhere in this part of the notebook;
# presumably a pandas DataFrame loaded in an earlier cell -- confirm.
rdf_graph = translate_df_to_rdf(data_df, mapping)

def extract_data_from_graph(graph):
    """Flatten an RDF graph into ``(subject, predicate, object)`` string triples.

    Subjects and predicates are reduced to their local name (the text after
    the last ``/``). Objects are converted with ``str`` and otherwise left unchanged.

    Args:
        graph: any iterable of 3-tuples, e.g. an rdflib ``Graph``.

    Returns:
        list of ``(subject, predicate, object)`` tuples of ``str``, in the
        graph's iteration order.
    """
    return [
        (str(s).split('/')[-1], str(p).split('/')[-1], str(o))
        for s, p, o in graph
    ]

# Flatten the graph into (subject, predicate, object) string triples; the
# facts/entities/relations cells below iterate over this list.
extracted_data = extract_data_from_graph(rdf_graph)



In [None]:
# Write the extracted triples as tab-separated "subject<TAB>relation<TAB>object" lines.
clean_output_file_path = '/content/CARL/datasets/tic/facts.txt'

# Characters removed from every field.
_FIELD_DROP_CHARS = str.maketrans("", "", "()',")


def _normalize_field(value):
    """Drop ( ) ' , characters and join internal whitespace runs with '_'."""
    return "_".join(str(value).translate(_FIELD_DROP_CHARS).split())


# Each field is normalized on its own. Joining the triple with spaces and
# re-splitting it would shift or truncate a field that contains whitespace,
# and would raise IndexError when a field is empty.
with open(clean_output_file_path, 'w', encoding='utf-8') as file:
    for entry in extracted_data:
        fields = [_normalize_field(value) for value in entry]
        # An empty field collapses the 3-column layout once the line is re-read
        # with line.strip().split('\t'), so such triples are skipped.
        if not all(fields):
            continue
        file.write("\t".join(fields) + "\n")

clean_output_file_path


In [None]:

# Write the entity / relation vocabularies that CARL's Dataset reads (one name per line).
entities_file_path = '/content/CARL/datasets/tic/entities.txt'
relations_file_path = '/content/CARL/datasets/tic/relations.txt'
entities = set()
relations = set()

# facts.txt fields have ( ) ' , removed, so the same is done here to keep names
# matching. Whitespace runs are joined with '_' so every name stays a single
# tab-free token.
_FIELD_DROP_CHARS = str.maketrans("", "", "()',")


def _normalize_field(value):
    """Drop ( ) ' , characters and join internal whitespace runs with '_'."""
    return "_".join(str(value).translate(_FIELD_DROP_CHARS).split())


for s, p, o in extracted_data:
    entities.add(_normalize_field(s))
    entities.add(_normalize_field(o))
    relations.add(_normalize_field(p))
# A field that normalizes to "" is not a usable name.
entities.discard("")
relations.discard("")


def clean_uri(uri):
    """Return the text after the last '/' in ``uri``."""
    return str(uri).split('/')[-1]


# Names are written as-is. extract_data_from_graph already reduced subjects and
# predicates to local names, and re-applying clean_uri would truncate literal
# objects that contain '/' (e.g. "1/2" -> "2"), so they would no longer match
# their facts.txt entries.
with open(entities_file_path, 'w', encoding='utf-8') as file:
    for entity in sorted(entities):
        file.write(f"{entity}\n")

with open(relations_file_path, 'w', encoding='utf-8') as file:
    for relation in sorted(relations):
        file.write(f"{relation}\n")

entities_file_path, relations_file_path


In [None]:
import random

# Shuffle the fact lines once and cut them into disjoint train/valid/test files.
# NOTE(review): this reads '/content/facts.txt', whereas the preprocessing cell
# writes '/content/CARL/datasets/tic/facts.txt' -- confirm which file is intended.
file_path = '/content/facts.txt'


with open(file_path, 'r') as file:
    lines = file.readlines()

# A final line without '\n' would be glued onto its neighbour by writelines()
# once it is shuffled into the middle of a split.
lines = [line if line.endswith('\n') else line + '\n' for line in lines]

train_split = 0.5
valid_split = 0.25
test_split = 0.25

# Shuffle once, then take consecutive slices so the three splits are disjoint.
# Re-shuffling before each split and taking a prefix every time would let the
# same triple appear in train, valid and test at once (test leakage).
random.shuffle(lines)
train_end = int(len(lines) * train_split)
valid_end = train_end + int(len(lines) * valid_split)
test_end = valid_end + int(len(lines) * test_split)
train_data = lines[:train_end]
valid_data = lines[train_end:valid_end]
test_data = lines[valid_end:test_end]


with open('/content/CARL/datasets/tic/train.txt', 'w') as f:
    f.writelines(train_data)

with open('/content/CARL/datasets/tic/valid.txt', 'w') as f:
    f.writelines(valid_data)

with open('/content/CARL/datasets/tic/test.txt', 'w') as f:
    f.writelines(test_data)

print("Data split and saved to train.txt, valid.txt, and test.txt")


In [None]:
# Build the fact file that CARL's Dataset loads when inv=True: it reads
# data_root + 'facts.txt' + '.inv' and registers an "inv_<rel>" relation for
# every relation in relations.txt.
# NOTE(review): this reads '/content/facts.txt', not the
# '/content/CARL/datasets/tic/facts.txt' written by the preprocessing cell -- confirm.
original_path = '/content/facts.txt'
inverted_path = '/content/CARL/datasets/tic/facts.txt.inv'


with open(original_path, 'r') as file:
    lines = file.readlines()

# Each "h r t" line (any whitespace) becomes "t<TAB>inv_r<TAB>h". Lines that do
# not split into exactly three fields are copied through unchanged.
# NOTE(review): only the inverted triples are written. With inv=True the Dataset
# loads this file *instead of* facts.txt, so the forward facts are not loaded
# from it -- confirm whether they should be written here too.
inverted_lines = []
for line in lines:
    parts = line.strip().split()
    if len(parts) == 3:
        inverted_line = f"{parts[2]}\tinv_{parts[1]}\t{parts[0]}\n"
        inverted_lines.append(inverted_line)
    else:
        inverted_lines.append(line)

with open(inverted_path, 'w') as file:
    file.writelines(inverted_lines)

print("Inverted file saved as facts.txt.inv")


#CARL MODEL

In [None]:
!git clone https://github.com/burning5112/CARL.git
%cd /content/CARL

In [None]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from random import sample
import random
from torch.nn.utils import clip_grad_norm_
import time


def construct_fact_dict(fact_rdf):
    """Group stored triples by their relation.

    Args:
        fact_rdf: iterable of 3-element triples ``[e1, rel, e2]``.

    Returns:
        dict mapping each relation to the list of original triples that use it,
        in input order.
    """
    grouped = {}
    for triple in fact_rdf:
        _, relation, _ = parse_rdf(triple)
        grouped.setdefault(relation, []).append(triple)
    return grouped


def parse_rdf(rdf):
    """Unpack a stored triple and return it with the entity positions swapped.

    A triple read from disk is ``[e1, rel, e2]``; this returns ``(e2, rel, e1)``.
    Callers unpack the result as ``(head, rel, tail)``.

    The former debug ``print`` ran on every call, and this helper is called once
    per fact, which flooded the output and slowed preprocessing.

    Raises:
        ValueError: if ``rdf`` does not have exactly three elements.
    """
    rdf_tail, rdf_rel, rdf_head = rdf
    return rdf_head, rdf_rel, rdf_tail


class Dictionary(object):
    """Bidirectional mapping between relation names and dense integer ids.

    Ids are assigned in insertion order starting at 0. Adding a name that is
    already present does nothing.
    """

    def __init__(self):
        self.rel2idx_ = {}
        self.idx2rel_ = {}
        self.idx = 0  # next id to hand out

    def add_relation(self, rel):
        """Give ``rel`` the next free id unless it already has one."""
        if rel in self.rel2idx_:
            return
        new_id = self.idx
        self.rel2idx_[rel] = new_id
        self.idx2rel_[new_id] = rel
        self.idx = new_id + 1

    @property
    def rel2idx(self):
        """Name -> id mapping."""
        return self.rel2idx_

    @property
    def idx2rel(self):
        """Id -> name mapping."""
        return self.idx2rel_

    def __len__(self):
        return len(self.idx2rel_)


def load_entities(path):
    """Read one entity name per line and assign ids by line position.

    Each line is whitespace-stripped. The id of a line is its 0-based index in
    the file. If a name repeats, ``ent2idx`` keeps the id of its last occurrence.

    Args:
        path: UTF-8 text file with one entity per line.

    Returns:
        ``(idx2ent, ent2idx)`` dictionaries.
    """
    with open(path, 'r', encoding='utf-8') as f:
        names = [line.strip() for line in f]
    idx2ent = dict(enumerate(names))
    ent2idx = {name: position for position, name in enumerate(names)}
    return idx2ent, ent2idx


class Dataset(object):
    """Knowledge-graph dataset read from tab-separated files under ``data_root``.

    ``data_root`` is concatenated directly with the file names, so it must end
    with a path separator. Files read:

    * ``entities.txt`` / ``relations.txt``: one name per line;
    * ``facts.txt`` (``facts.txt.inv`` when ``inv`` is true), ``train.txt``,
      ``valid.txt``, ``test.txt``: one ``e1<TAB>rel<TAB>e2`` triple per line.

    Args:
        data_root: directory prefix, including the trailing separator.
        sparsity: fraction of the facts kept, drawn with ``random.sample``.
        inv: read ``facts.txt.inv`` and register an ``inv_<rel>`` relation for
            every relation listed in ``relations.txt``.
    """

    def __init__(self, data_root, sparsity=1, inv=False):

        entity_path = data_root + 'entities.txt'
        self.idx2ent_, self.ent2idx_ = load_entities(entity_path)

        relation_path = data_root + 'relations.txt'
        self.rdict = Dictionary()
        self.load_relation_dict(relation_path)

        # Independent copy: the head dictionary additionally receives "None" below.
        self.head_rdict = copy.deepcopy(self.rdict)

        fact_path = data_root + 'facts.txt'
        train_path = data_root + 'train.txt'
        valid_path = data_root + 'valid.txt'
        test_path = data_root + 'test.txt'

        if inv:
            fact_path += '.inv'

        self.rdf_data_ = self.load_data_(fact_path, train_path, valid_path, test_path, sparsity)
        self.fact_rdf_, self.train_rdf_, self.valid_rdf_, self.test_rdf_ = self.rdf_data_

        if inv:
            # Snapshot the names first, because add_relation grows rel2idx_.
            rel_list = list(self.rdict.rel2idx_.keys())
            for rel in rel_list:
                inv_rel = "inv_" + rel
                self.rdict.add_relation(inv_rel)
                self.head_rdict.add_relation(inv_rel)

        # Head label used for sampled rules whose path has no closing relation.
        self.head_rdict.add_relation("None")

    def load_rdfs(self, path):
        """Return the triples in ``path`` as lists of tab-separated fields.

        Blank lines are skipped. They would otherwise produce ``['']`` entries
        that fail the 3-way unpacking in ``parse_rdf``.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip().split('\t') for line in f if line.strip()]

    def load_data_(self, fact_path, train_path, valid_path, test_path, sparsity):
        """Load facts (randomly subsampled to ``sparsity``) and the three splits.

        Returns:
            ``(fact, train, valid, test)`` lists of triples. The facts come back
            in ``random.sample`` order even when ``sparsity`` is 1.
        """
        fact = self.load_rdfs(fact_path)
        fact = sample(fact, int(len(fact) * sparsity))

        train = self.load_rdfs(train_path)
        valid = self.load_rdfs(valid_path)
        test = self.load_rdfs(test_path)
        return fact, train, valid, test

    def load_relation_dict(self, relation_path):
        """Register every non-blank line of ``relation_path`` in ``self.rdict``, in file order."""
        with open(relation_path, encoding='utf-8') as f:
            for r in f:
                relation = r.strip()
                if relation:
                    self.rdict.add_relation(relation)

    def get_relation_dict(self):
        return self.rdict

    def get_head_relation_dict(self):
        return self.head_rdict

    @property
    def idx2ent(self):
        return self.idx2ent_

    @property
    def ent2idx(self):
        return self.ent2idx_

    @property
    def fact_rdf(self):
        return self.fact_rdf_

    @property
    def train_rdf(self):
        return self.train_rdf_

    @property
    def valid_rdf(self):
        return self.valid_rdf_

    @property
    def test_rdf(self):
        return self.test_rdf_


def sample_anchor_rdf(rdf_data, num):
    """Return ``num`` randomly chosen triples, or the input itself if it is not larger.

    When ``num >= len(rdf_data)`` the very same list object is returned,
    without copying or shuffling.
    """
    if num >= len(rdf_data):
        return rdf_data
    return sample(rdf_data, num)


def construct_descendant(rdf_data):
    """Index triples by the head entity that ``parse_rdf`` returns.

    ``parse_rdf`` swaps the entity positions, so each stored ``[e1, rel, e2]``
    is indexed under ``e2`` as ``(rel, e1)``.

    The former per-fact debug ``print`` was removed; it flooded the output on
    any realistic fact set.

    Args:
        rdf_data: iterable of 3-element triples.

    Returns:
        dict mapping each head entity to its ``(rel, tail)`` pairs, in input order.
    """
    entity2desced = {}
    for rdf_ in rdf_data:
        h_, r_, t_ = parse_rdf(rdf_)
        entity2desced.setdefault(h_, []).append((r_, t_))
    return entity2desced


def connected(entity2desced, head, tail):
    """Return the first relation that links ``head`` directly to ``tail``, else ``False``.

    Args:
        entity2desced: mapping head -> list of ``(relation, tail)`` pairs.
        head: entity to start from.
        tail: entity to reach in one hop.
    """
    for relation, descendant in entity2desced.get(head, ()):
        if descendant == tail:
            return relation
    return False


def construct_rule_seq(rdf_data, anchor_rdf, entity2desced, sample_max_path_len=2, PRINT=False):
    """Sample rule strings by depth-first search starting from one anchor triple.

    Starting from ``parse_rdf(anchor_rdf) = (anchor_h, anchor_r, anchor_t)``,
    paths are extended through ``entity2desced`` (head -> [(rel, tail), ...]).
    ``cur_r`` accumulates the path's relations joined by ``'|'``. For every
    popped path that ends at ``cur_t``:

    * if ``anchor_h`` has a direct edge to ``cur_t`` (see ``connected``) and
      ``cur_t != anchor_t``, the rule ``"<path>-<that relation>"`` is emitted;
    * if there is no direct edge, ``"<path>-None"`` is emitted with
      probability about 0.1 (``random.random() > 0.9``).

    Args:
        rdf_data: not used in the body.
        anchor_rdf: stored triple to start from.
        entity2desced: adjacency lists, presumably built by
            ``construct_descendant`` -- confirm at the call site.
        sample_max_path_len: a path is only extended while it has fewer than
            this many relations.
        PRINT: print every emitted rule.

    Returns:
        ``(rule_seq, record)``: the emitted rule strings (duplicates possible)
        and a de-duplicated list of ``(cur_h, r_, t_)`` tuples (see NOTE below).
    """
    len2seq = {}  # NOTE(review): unused.
    anchor_h, anchor_r, anchor_t = parse_rdf(anchor_rdf)

    # Stack entries: (node before the last hop, '|'-joined path relations, path end).
    # The first entry popped is the anchor triple itself.
    stack = [(anchor_h, anchor_r, anchor_t)]

    # Human-readable path strings, pushed and popped in lockstep with `stack`.
    stack_print = ['{}-{}-{}'.format(anchor_h, anchor_r, anchor_t)]

    pre_path = anchor_h  # NOTE(review): unused.
    # expended_node: path ends already expanded (a list, so membership is O(n)).
    rule_seq, expended_node = [], []
    record = []
    while len(stack) > 0:
        cur_h, cur_r, cur_t = stack.pop(-1)
        cur_print = stack_print.pop(-1)
        deced_list = []

        if cur_t in entity2desced:
            deced_list = entity2desced[cur_t]

        # Extend the path by one hop unless it is already long enough or this end
        # was expanded before. Do not step straight back to the previous node or
        # to the anchor head.
        if len(cur_r.split('|')) < sample_max_path_len and len(deced_list) > 0 and cur_t not in expended_node:
            for r_, t_ in deced_list:
                if t_ != cur_h and t_ != anchor_h:
                    stack.append((cur_t, cur_r + '|' + r_, t_))
                    stack_print.append(cur_print + '-{}-{}'.format(r_, t_))
        expended_node.append(cur_t)

        # A path that closes back on anchor_h through a direct edge becomes a rule body.
        rule_head_rel = connected(entity2desced, anchor_h, cur_t)
        if rule_head_rel and cur_t != anchor_t:
            rule = cur_r + '-' + rule_head_rel
            rule_seq.append(rule)
            # NOTE(review): r_ / t_ are leftovers from the most recent expansion
            # loop above, not the last hop of the path being recorded. If no
            # expansion has run yet they are unbound (UnboundLocalError). Perhaps
            # (cur_h, cur_r, cur_t) was intended -- confirm.
            if (cur_h, r_, t_) not in record:
                record.append((cur_h, r_, t_))
            if PRINT:
                print('rule body:\n{}'.format(cur_print))
                print('rule head:\n{}-{}-{}'.format(anchor_h, rule_head_rel, cur_t))
                print('rule:\n{}\n'.format(rule))
        # No direct edge: keep about 10% of such paths as "None"-headed samples.
        elif rule_head_rel == False and random.random() > 0.9:
            rule = cur_r + '-' + "None"
            rule_seq.append(rule)
            # NOTE(review): same stale r_ / t_ issue as above.
            if (cur_h, r_, t_) not in record:
                record.append((cur_h, r_, t_))
            if PRINT:
                print('rule body:\n{}'.format(cur_print))
                print('rule head:\n{}-{}-{}'.format(anchor_h, rule_head_rel, cur_t))
                print('rule:\n{}\n'.format(rule))
    return rule_seq, record


def body2idx(body_list, head_rdict):
    """Map each ``'|'``-joined rule body to its list of relation ids.

    Raises:
        KeyError: if a relation name is not in ``head_rdict.rel2idx``.
    """
    lookup = head_rdict.rel2idx
    return [[lookup[rel] for rel in body.split('|')] for body in body_list]


def inv_rel_idx(head_rdict):
    """Return the ids (0..len-1) whose relation name contains the substring ``"inv_"``."""
    names = head_rdict.idx2rel
    return [rel_id for rel_id in range(len(names)) if "inv_" in names[rel_id]]


def idx2body(index, head_rdict):
    """Join the relation names for the ids in ``index`` with ``'|'``."""
    names = head_rdict.idx2rel
    return "|".join(names[rel_id] for rel_id in index)


def rule2idx(rule, head_rdict):
    """Encode ``"b1|b2-head"`` as ``[id(b1), id(b2), -1, id(head)]``.

    ``-1`` separates the body ids from the head id.
    """
    body, head = rule.split('-')
    lookup = head_rdict.rel2idx
    body_ids = [lookup[rel] for rel in body.split('|')]
    return body_ids + [-1, lookup[head]]


def idx2rule(index, head_rdict):
    """Decode ``[body ids..., separator, head id]`` back to ``"b1|b2-head"``.

    The second-to-last element (the separator) is ignored.
    """
    names = head_rdict.idx2rel
    body = "|".join([names[rel_id] for rel_id in index[0:-2]])
    return f"{body}-{names[index[-1]]}"


def enumerate_body(relation_num, body_len, rdict):
    """List every body of ``body_len`` relation ids (Cartesian product), plus their names.

    Bodies are produced in ``itertools.product`` order, e.g. for 2 relations
    and length 2: [0, 0], [0, 1], [1, 0], [1, 1].

    Returns:
        ``(all_body_idx, all_body)``: lists of id lists and of name lists.
    """
    import itertools

    id_space = range(relation_num)
    all_body_idx = [list(combo) for combo in itertools.product(id_space, repeat=body_len)]
    names = rdict.idx2rel
    all_body = [[names[rel_id] for rel_id in combo] for combo in all_body_idx]
    return all_body_idx, all_body


In [None]:
def print_msg(msg):
    """Print ``msg`` framed as ``## msg ##`` between two rows of ``#`` of equal width."""
    banner = "## {} ##".format(msg)
    border = "#" * len(banner)
    print(f"{border}\n{banner}\n{border}")
import torch
def cross_entropy(y, y_pre, weight):
    """Weighted cross-entropy ``-sum(weight * y * log(y_pre)) / y_pre.shape[0]``.

    Args:
        y: target distribution / one-hot tensor.
        y_pre: predicted probabilities, broadcast-compatible with ``y``.
        weight: weights broadcast against ``y * log(y_pre)``.

    Returns:
        Scalar tensor: the weighted sum divided by the size of ``y_pre``'s first
        dimension (the batch size, presumably).

    Unreachable code that followed the ``return`` (it referenced the undefined
    ``args`` and ``os``) has been removed.
    """
    res = weight * (y * torch.log(y_pre))
    loss = -torch.sum(res)
    return loss / y_pre.shape[0]


In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.distributions import Categorical
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encodings (not trained).

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: encoding width; odd widths are supported.
        max_len: number of positions precomputed. ``forward`` cannot return
            more than this many rows.
    """

    def __init__(self, d_model, max_len=3):
        super(PositionalEmbedding, self).__init__()
        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        # Registered as a buffer: saved with the module and moved by .to(), never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions, shape (1, L, d_model)."""
        return self.pe[:, :x.size(1)]


class ProbAttention(nn.Module):
    """ProbSparse-style attention: full scores are computed only for the ``u``
    queries with the highest sparsity measure. All other query positions keep a
    default context: the mean of V, or the cumulative sum of V when ``mask_flag`` is set.

    Args:
        mask_flag: use the cumulative-sum context and masking.
            NOTE(review): the masked path calls ``ProbMask``, which is not
            defined in this part of the notebook.
        factor: multiplier for the number of sampled keys / selected queries
            (``factor * ceil(ln L)``).
        scale: score scale. When falsy, ``1/sqrt(D)`` is used.
        attention_dropout: rate for ``self.dropout``, which is not applied anywhere below.
        output_attention: also return a dense (B, H, L_V, L_V) attention map.
    """

    def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=True):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):
        """Score each query on ``sample_k`` randomly sampled keys, keep the
        ``n_top`` queries with the largest (max - sum / L_K) score, and return
        their full scores against all keys.

        Q, K are laid out (B, H, L, E). Returns ``(Q_K, M_top)``, where Q_K is
        (B, H, n_top, L_K) and M_top holds the selected query indices, (B, H, n_top).
        """
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        # The same random key sample per query position is shared across batch and heads.
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
        # Sparsity measure; note the sum over sampled keys is divided by L_K, not sample_k.
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]
        Q_reduce = Q[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   M_top, :]
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Default context for every query: mean of V over the length axis
        (broadcast to L_Q rows), or the cumulative sum of V when masking."""
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else:
            assert (L_Q == L_V)
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the selected query rows of ``context_in`` with softmax(scores) @ V.

        Returns ``(context, attns)``. ``attns`` is uniform ``1/L_V`` except for
        the selected rows, or None when ``output_attention`` is false.
        NOTE(review): ``attns`` is sized (B, H, L_V, L_V) but indexed by query
        positions, so this presumably assumes L_Q == L_V (self-attention).
        The ``attn_mask`` argument is replaced, not used, when ``mask_flag`` is set.
        """
        B, H, L_V, D = V.shape
        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores,
                                 device=V.device)
            scores.masked_fill_(attn_mask.mask,
                                -np.inf)
        attn = torch.softmax(scores,
                             dim=-1)
        context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = torch.matmul(attn,
                                                                                                            V).type_as(
            context_in)
        if self.output_attention:
            attns = (torch.ones([B,
                                 H,
                                 L_V,
                                 L_V]) / L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def forward(self, queries, keys, values, attn_mask=None):
        """Inputs are laid out (B, L, H, D). Returns ``(context, attn)``, with
        context back in (B, L_Q, H, D) layout."""
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape
        # -> (B, H, L, D)
        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)
        # Keys sampled per query and number of "active" queries: factor * ceil(ln L),
        # capped at the sequence length.
        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item()
        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q
        scores_top, index = self._prob_QK(queries,
                                          keys,
                                          sample_k=U_part,
                                          n_top=u)
        # NOTE(review): `scale` always has a value here (falls back to 1/sqrt(D)),
        # so the None check below is always true.
        scale = self.scale or 1. / math.sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(context,
                                             values,
                                             scores_top,
                                             index,
                                             L_Q,
                                             attn_mask)
        return context.transpose(2, 1).contiguous(), attn


class Encoder(nn.Module):
    """Relation-path encoder.

    Embeds a sequence of relation ids, merges adjacent positions pair by pair
    until a single position remains, and scores that vector with multi-head
    attention over every relation embedding.

    Args:
        relation_num: number of relations. Id ``relation_num`` is the
            ``padding_idx`` of ``self.emb``.
        emb_size: embedding width. Must be divisible by ``num_heads`` (8).
        device: device for the index / zero tensors created inside the methods.

    NOTE(review): slider_encoder, weight_w_1/2, weight_a_1/2, layernorm1/2 and
    slider_3_1/2 are built here but not used by any method of this class.
    """

    def __init__(self, relation_num, emb_size, device):

        super(Encoder, self).__init__()
        # One row per relation id plus a padding row at index relation_num.
        self.emb = nn.Embedding(relation_num + 1, emb_size, padding_idx=relation_num)
        self.hidden_size = emb_size
        self.relation_num = relation_num
        self.emb_size = emb_size
        self.device = device
        self.dropout_rate = 0.1
        self.num_heads = 8
        self.num_layers = 1
        self.slider_encoder = nn.LSTM(input_size=emb_size,
                                      hidden_size=self.hidden_size,
                                      num_layers=self.num_layers,
                                      batch_first=True,
                                      dropout=0
                                      )
        # Scores each candidate adjacent pair in reduce_rel_pairs.
        self.fc = nn.Linear(self.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        # Position-wise feed-forward used by feedForward().
        self.w_1 = nn.Linear(emb_size, emb_size)
        self.w_2 = nn.Linear(emb_size, emb_size)
        self.relu = nn.ReLU()
        self.weight_w_1 = nn.Linear(self.num_heads * (relation_num + 1), self.num_heads * (relation_num + 1))
        self.weight_w_2 = nn.Linear(self.num_heads * (relation_num + 1), relation_num + 1)
        self.layernorm1 = nn.LayerNorm(relation_num + 1)
        self.weight_a_1 = nn.Linear(self.num_heads * (relation_num + 1), self.num_heads * (relation_num + 1))
        self.weight_a_2 = nn.Linear(self.num_heads * (relation_num + 1), relation_num + 1)
        self.layernorm2 = nn.LayerNorm(relation_num + 1)
        # Merge two concatenated embeddings (2 * emb_size) into one (emb_size).
        self.slider1 = nn.Linear(2 * emb_size, 2 * emb_size)
        self.slider2 = nn.Linear(2 * emb_size, emb_size)
        self.slider_3_1 = nn.Linear(2 * emb_size, 2 * emb_size)
        self.slider_3_2 = nn.Linear(2 * emb_size, emb_size)
        assert self.emb_size % self.num_heads == 0
        self.head_emb_size = self.emb_size // self.num_heads
        # Projections for multiHeadAttention / transformer_attention.
        self.fc_q = nn.Linear(emb_size, emb_size)
        self.fc_k = nn.Linear(emb_size, emb_size)
        self.fc_v = nn.Linear(emb_size, emb_size)
        self.out = nn.Linear(emb_size, emb_size)
        # Projections for self_attention.
        self.self_fc_q = nn.Linear(emb_size, emb_size)
        self.self_fc_k = nn.Linear(emb_size, emb_size)
        self.self_fc_v = nn.Linear(emb_size, emb_size)
        self.self_out = nn.Linear(emb_size, emb_size)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.layernorm = nn.LayerNorm(emb_size)
        # Table for at most 3 positions. forward() adds it to the input embeddings,
        # so longer relation sequences would fail to broadcast.
        self.position_embedding = PositionalEmbedding(d_model=emb_size, max_len=3)
        self.probAttention = ProbAttention()

    def forward(self, inputs):
        """Encode relation-id sequences and score them against every relation.

        Args:
            inputs: relation ids looked up in ``self.emb``; presumably a
                LongTensor of shape (batch, seq_len) -- TODO confirm.

        Returns:
            ``(scores, loss_tensor, relation_emb_weigth)``:
            scores -- head-averaged attention logits of the reduced sequence over
            the ``relation_num`` relation embeddings plus the sequence itself,
            with dim 1 squeezed;
            loss_tensor -- attention entropies from every step, concatenated on
            the last dim;
            relation_emb_weigth -- attention map returned by ProbAttention.
        """
        inputs = self.emb(inputs)
        a = self.position_embedding(inputs)
        inputs = inputs + a
        batch_size, seq_len, emb_size = inputs.shape
        # Embeddings of all relation ids 0..relation_num-1, repeated for every example.
        idx_ = torch.LongTensor(range(self.relation_num)).repeat(batch_size, 1).to(self.device)
        relation_emb_ori = self.emb(idx_)
        # Split into heads and let the relation embeddings attend to each other.
        relation_emb = relation_emb_ori.reshape(batch_size, self.relation_num, self.num_heads, self.head_emb_size)
        relation_emb, relation_emb_weigth = self.probAttention(queries=relation_emb, keys=relation_emb,
                                                               values=relation_emb)
        relation_emb = relation_emb.reshape(batch_size, self.relation_num, -1)
        # Residual mix: mostly the raw embeddings, plus 0.1 x the attended ones.
        relation_emb = self.layernorm(relation_emb_ori + 0.1 * relation_emb)
        L = [inputs]
        loss_list = []
        idx = 0
        # seq_len - 1 merges, each removing one position.
        while idx < seq_len - 1:
            output, loss = self.reduce_rel_pairs(L[-1], relation_emb)
            L.append(output)
            loss_list.append(loss)
            idx += 1
        selected_rel_pair_after_attention_value, attn_weights, scores = self.multiHeadAttention(L[-1], relation_emb)
        loss = Categorical(probs=attn_weights).entropy()
        loss_list.append(loss)
        loss_tensor = torch.cat(loss_list, dim=-1)
        return self.predict_head(scores), loss_tensor, relation_emb_weigth

    def reduce_rel_pairs(self, inputs, relation_emb_with_attention):
        """Merge one adjacent pair of positions in ``inputs`` (batch, seq_len, emb_size).

        For seq_len > 2: every adjacent pair is projected by slider1/slider2, and
        the pair with the highest fc/sigmoid score is chosen per example
        (argmax, so the choice is not differentiable). It is refined by
        multi-head attention over the relation embeddings and the feed-forward
        layer, then written into the first slot of the pair. The second slot is
        zeroed and removed. For seq_len <= 2 the whole sequence is flattened and
        merged directly.

        Returns:
            ``(output, loss)``. ``output`` is presumably (batch, seq_len - 1,
            emb_size) in the merge branch and (batch, 1, emb_size) otherwise.
            ``loss`` is the attention entropy, or zeros of shape (batch, 1)
            when no attention ran.
        """

        batch_size, seq_len, emb_size = inputs.shape
        if seq_len > 2:
            rel_pairs = []
            idx = 0
            while idx < seq_len - 1:
                rel_pairs_emb = inputs[:, idx:idx + 2, :]
                rel_pairs_emb = rel_pairs_emb.reshape(batch_size, -1)
                rel_pairs_emb = self.dropout(self.relu(self.slider1(rel_pairs_emb)))
                rel_pairs_emb = self.dropout(self.slider2(rel_pairs_emb))
                h = self.layernorm(rel_pairs_emb)
                rel_pairs.append(h)
                idx += 1
            rel_pairs = torch.stack(rel_pairs, dim=1)
            choice_rel_pairs = self.dropout(self.fc(rel_pairs)).squeeze(-1)
            choice_rel_pairs = self.sigmoid(choice_rel_pairs)
            selected_rel_pair_idx = torch.argmax(choice_rel_pairs,
                                                 dim=-1)
            full_batch = torch.arange(batch_size).to(self.device)
            selected_rel_pair = rel_pairs[full_batch, selected_rel_pair_idx, :]
            selected_rel_pair = selected_rel_pair.unsqueeze(1)
            selected_rel_pair_after_attention_value, attn_weights, scores = self.multiHeadAttention(selected_rel_pair,
                                                                                                    relation_emb_with_attention)
            selected_rel_pair_after_attention_value = self.feedForward(selected_rel_pair_after_attention_value)
            loss = Categorical(probs=attn_weights).entropy()
            selected_rel_pair_after_attention_value = selected_rel_pair_after_attention_value.squeeze(1)
            # NOTE(review): detach() stops gradients from flowing back through the
            # positions copied unchanged from `inputs` -- confirm this is intended.
            output = inputs.detach().clone()
            zero = torch.zeros(emb_size).to(self.device)
            output[full_batch, selected_rel_pair_idx, :] = selected_rel_pair_after_attention_value
            output[full_batch, selected_rel_pair_idx + 1, :] = zero
            # Drops every all-zero position across the batch. This assumes only the
            # slot zeroed above is all-zero; otherwise the reshape mixes examples.
            output = output[~torch.all(output == 0, dim=-1)]
            output = output.reshape(batch_size, -1, emb_size)

        else:

            # NOTE(review): slider1 expects 2 * emb_size features, so this branch
            # presumably assumes seq_len == 2.
            inter = inputs.reshape(batch_size, -1)
            inter = self.dropout(self.relu(self.slider1(inter)))
            output = self.dropout(self.slider2(inter))
            output = self.layernorm(output).unsqueeze(1)
            loss = torch.zeros((batch_size, 1)).to(self.device)
        return output, loss

    def multiHeadAttention(self, inputs, relation_emb, mask=None):
        """Attend from ``inputs`` (queries) over the concatenation
        [relation_emb; inputs] (keys/values) along dim 1.

        Returns:
            ``(layernorm(out + inputs), attn_weights, scores)``. The last two are
            averaged over heads, shape (batch, seq_len, relation_emb_len + seq_len).
        """
        batch_size, seq_len, emb_size = inputs.shape
        query = self.dropout(self.fc_q(inputs)).view(batch_size, seq_len, self.num_heads, self.head_emb_size).transpose(
            1, 2)
        key = self.dropout(self.fc_k(torch.cat((relation_emb, inputs), dim=1))).view(batch_size, -1, self.num_heads,
                                                                                     self.head_emb_size).transpose(1, 2)
        value = self.dropout(self.fc_v(torch.cat((relation_emb, inputs), dim=1))).view(batch_size, -1, self.num_heads,
                                                                                       self.head_emb_size).transpose(1,
                                                                                                                     2)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_emb_size)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(scores, dim=-1)

        attn_output = torch.matmul(attn_weights, value)

        output = self.out(attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))

        scores = torch.mean(scores, dim=1)
        attn_weights = torch.mean(attn_weights, dim=1)

        return self.layernorm(output + inputs), attn_weights, scores

    def feedForward(self, inputs):
        """Two-layer position-wise MLP with a residual connection and layer norm."""
        inter = self.dropout(self.relu(self.w_1(inputs)))
        output = self.dropout(self.w_2(inter))
        return self.layernorm(output + inputs)

    def transformer_attention(self, inputs, relation_emb):
        """Single-head attention where each input position may attend to every
        relation embedding but, among the input positions, only to itself.

        Not called from forward(). NOTE(review): ``scores`` aliases ``scores_ori``,
        so the returned ``scores_ori`` already contains the -inf entries.
        """
        batch_size, seq_len, emb_size = inputs.shape
        query = self.dropout(self.fc_q(inputs))
        key = self.dropout(self.fc_k(torch.cat((relation_emb, inputs), dim=1)))
        value = self.dropout(self.fc_v(torch.cat((relation_emb, inputs), dim=1)))
        scores_ori = torch.matmul(query, key.transpose(-2, -1)) \
                     / math.sqrt(self.emb_size)
        # False (keep) for the relation columns; off-diagonal input columns masked.
        mask1 = torch.zeros((batch_size, seq_len, self.relation_num), dtype=torch.bool).to(self.device)
        I = torch.eye(seq_len).to(self.device)
        I = I.reshape((1, seq_len, seq_len))
        I = I.repeat(batch_size, 1, 1)
        mask2 = ~I.to(torch.bool)
        mask = torch.cat((mask1, mask2), dim=-1)
        scores = scores_ori
        scores[mask] = float('-inf')
        attn_weights = torch.softmax(scores, dim=-1)
        output = attn_weights @ value
        return self.layernorm(output), attn_weights, scores_ori

    def self_attention(self, relation_emb, mask=None):
        """Multi-head self-attention over ``relation_emb`` with a residual
        connection and layer norm. Not called from forward().

        Returns ``(output, head-averaged attn_weights, head-averaged scores)``.
        """
        batch_size, seq_len, emb_size = relation_emb.shape

        query = self.dropout(self.self_fc_q(relation_emb)).view(batch_size, seq_len, self.num_heads,
                                                                self.head_emb_size).transpose(1, 2)
        key = self.dropout(self.self_fc_k(relation_emb)).view(batch_size, seq_len, self.num_heads,
                                                              self.head_emb_size).transpose(1, 2)
        value = self.dropout(self.self_fc_v(relation_emb)).view(batch_size, seq_len, self.num_heads,
                                                                self.head_emb_size).transpose(1, 2)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_emb_size)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(scores, dim=-1)

        attn_output = torch.matmul(attn_weights, value)

        output = self.self_out(attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))

        scores = torch.mean(scores, dim=1)
        attn_weights = torch.mean(attn_weights, dim=1)

        return self.layernorm(output + relation_emb), attn_weights, scores

    def predict_head(self, prob):
        """Squeeze dimension 1 (a no-op unless that dimension has size 1)."""
        return prob.squeeze(1)

    def get_relation_emb(self, rel):
        """Look up embeddings for relation ids."""
        return self.emb(rel)

    def weightedAverage(self, scores, inputs):
        """Softmax-weighted average over [relation embeddings; inputs], using the
        same self-only mask as transformer_attention. Not called from forward().

        NOTE(review): writes -inf into the caller's ``scores`` tensor in place.
        """
        batch_size, seq_len, _ = scores.shape
        mask1 = torch.zeros((batch_size, seq_len, self.relation_num), dtype=torch.bool).to(self.device)
        I = torch.eye(seq_len).to(self.device)
        I = I.reshape((1, seq_len, seq_len))
        I = I.repeat(batch_size, 1, 1)
        mask2 = ~I.to(torch.bool)
        mask = torch.cat((mask1, mask2), dim=-1)
        scores[mask] = float('-inf')
        prob = self.dropout(torch.softmax(scores, dim=-1))
        idx_ = torch.LongTensor(range(self.relation_num)).repeat(batch_size, 1).to(self.device)
        relation_emb = self.emb(idx_)
        all_emb = torch.cat((relation_emb, inputs), dim=1)
        out = prob @ all_emb
        return self.layernorm(out)


In [None]:
from audioop import reverse
from wsgiref import headers
from xml.dom.minidom import Element

import copy
import re
import torch
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import numpy as np
from scipy import sparse
from collections import defaultdict
import argparse
import gc
import os
import sys
import io



# Module-level accumulators for per-head / overall ranking metrics.
# Each name gets its own, independent defaultdict(list).
(
    head2Mean_rank_mrr,
    head2Mean_rank_hit_1,
    head2Mean_rank_hit_10,
    head2Top_rank_mrr,
    head2Top_rank_hit_1,
    head2Top_rank_hit_10,
    Mean_rank,
    Top_rank,
) = (defaultdict(list) for _ in range(8))


class RuleDataset(Dataset):
    """Map-style dataset yielding, per rule head, a confidence-weighted path matrix.

    Item ``i`` is ``(head, path_count)``. ``path_count`` is the sum, over the head's
    rules with ``conf_1 >= args.threshold``, of ``conf_1`` times the matrix product
    of the body relations' adjacency matrices (an ``e_num x e_num`` sparse matrix).

    Args:
        r2mat: dict relation name -> sparse ``e_num x e_num`` adjacency matrix.
        rules: dict head relation -> list of ``(head, body, conf_1, conf_2)`` tuples.
        e_num: number of entities (side length of every matrix).
        idx2rel: dict relation index -> relation name; fixes the item order.
        args: namespace providing ``threshold``.
    """

    def __init__(self, r2mat, rules, e_num, idx2rel, args):
        self.e_num = e_num
        self.r2mat = r2mat
        self.rules = rules
        self.idx2rel = idx2rel
        self.args = args
        # Items are the known relations (in index order) that actually have rules.
        # Mapping item i straight to idx2rel[i] with len(rules) items raised KeyError
        # whenever the first len(rules) relation indices were not exactly the rule heads.
        self.heads = [rel for rel in dict.fromkeys(idx2rel[i] for i in sorted(idx2rel))
                      if rel in rules]
        self.len = len(self.heads)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        rel = self.heads[idx]
        path_count = sparse.dok_matrix((self.e_num, self.e_num))
        for _head, body, conf_1, _conf_2 in self.rules[rel]:
            if conf_1 >= self.args.threshold:
                body_adj = sparse.eye(self.e_num)
                for b_rel in body:
                    # '@' is matrix multiplication for both spmatrix and sparray;
                    # '*' silently becomes element-wise on sparse arrays.
                    body_adj = body_adj @ self.r2mat[b_rel]
                path_count += body_adj * conf_1
        return rel, path_count

    @staticmethod
    def collate_fn(data):
        """Split a batch of ``(head, path_count)`` items into two parallel lists."""
        heads = [item[0] for item in data]
        path_counts = [item[1] for item in data]
        return heads, path_counts


def sum_from_n_to_len(n_non_zero, filtered_pred):
    """Return the sum of the integers ``n_non_zero + 1`` .. ``len(filtered_pred)`` inclusive (0 if empty)."""
    return sum(range(n_non_zero + 1, len(filtered_pred) + 1))


def sortSparseMatrix(m, r, rev=True, only_indices=False):
    """Sort the stored entries of row ``r`` of sparse matrix ``m`` by value.

    Returns ``(column, value)`` pairs (descending by default, ``rev=False`` for
    ascending; ties keep storage order), or just the columns when ``only_indices``.
    """
    row = m.getrow(r)
    pairs = sorted(zip(row.indices, row.data), key=lambda pair: pair[1], reverse=rev)
    if not only_indices:
        return pairs
    return [column for column, _ in pairs]


# Matches a variable list such as "(X, Y)" (each slot at most one non-digit char).
_VAR_PATTERN = re.compile(r"\(\D?, \D?\)")


def remove_var(r):
    """Strip every "(A, B)" variable list from rule text ``r``, leaving bare relation names."""
    return _VAR_PATTERN.sub("", r)


def parse_rule(r):
    """Split rule text ``"head <-- b1, b2, ..."`` into ``(head, [b1, b2, ...])`` after removing variables."""
    head, body_text = remove_var(r).split(" <-- ")
    return head, body_text.split(", ")


def _parse_confidences(conf):
    """Return ``(conf_1, conf_2)`` from a ``"c1 (c2)"`` field; a bare ``"c1"`` yields ``(c1, c1)``."""
    first, sep, rest = conf.partition('(')
    conf_1 = float(first)
    conf_2 = float(rest.strip().rstrip(')')) if sep else conf_1
    return conf_1, conf_2


def load_rules(rule_path, all_rules, all_heads):
    """Parse a rule file and merge its rules into the caller's containers.

    Each non-blank line looks like ``"<conf_1> (<conf_2>)\\t<head> <-- <b1>, <b2>"``,
    the format written by ``test()``. Confidences are parsed by token rather than by
    fixed character offsets, so values such as ``nan`` or wider numbers no longer
    break parsing. Blank lines are skipped.

    Args:
        rule_path: path of the rule file to read.
        all_rules: dict head -> list of ``(head, body, conf_1, conf_2)``; appended to in place.
        all_heads: list of heads in first-seen order; appended to in place.
    """
    with open(rule_path, 'r') as f:
        for line in f:
            line = line.strip('\n')
            if not line.strip():
                continue  # previously a blank line crashed the tab unpack below
            conf, r = line.split('\t', 1)
            conf_1, conf_2 = _parse_confidences(conf)
            head, body = parse_rule(r)

            all_rules.setdefault(head, []).append((head, body, conf_1, conf_2))
            if head not in all_heads:
                all_heads.append(head)


def construct_rmat(idx2rel, idx2ent, ent2idx, fact_rdf):
    """Build one sparse ``n_ent x n_ent`` 0/1 adjacency matrix per relation.

    Every relation name in ``idx2rel`` gets an empty ``dok_matrix``; each fact
    ``(h, r, t)`` in ``fact_rdf`` then sets entry ``[ent2idx[h], ent2idx[t]]`` of ``r``'s matrix to 1.
    """
    n_ent = len(idx2ent)
    r2mat = {rel: sparse.dok_matrix((n_ent, n_ent)) for rel in idx2rel.values()}
    for rdf in fact_rdf:
        h, r, t = parse_rdf(rdf)
        row, col = ent2idx[h], ent2idx[t]
        r2mat[r][row, col] = 1
    return r2mat


def get_gt(dataset):
    """Map each ``(head entity, relation)`` pair to the tail-entity indices seen in any split.

    Facts from fact/train/valid/test are all included, so the result can be used to
    filter other known answers during ranking.
    """
    ent2idx = dataset.ent2idx
    every_split = dataset.fact_rdf + dataset.train_rdf + dataset.valid_rdf + dataset.test_rdf
    tails_by_query = defaultdict(list)
    for rdf in every_split:
        h, r, t = parse_rdf(rdf)
        tails_by_query[(h, r)].append(ent2idx[t])
    return tails_by_query


def _filtered_rank_bounds(pred, target_idx, truth):
    """Count non-filtered entities scoring above / at-or-above the target entity.

    Args:
        pred: 1-D array of scores, one per entity index.
        target_idx: index of the query's correct tail.
        truth: other correct tails to exclude from the competition.

    Returns:
        ``(n_greater, n_greater_equal)`` for unfiltered entities. The second value is
        at least ``n_greater + 1`` because the target always ties with its own score;
        forcing it keeps the tie range non-empty even if that score is NaN.
    """
    keep = np.ones(pred.shape[0], dtype=bool)
    truth_idx = np.asarray(list(truth), dtype=np.int64)
    if truth_idx.size:
        keep[truth_idx] = False
    target_score = pred[target_idx]
    n_greater = int(np.count_nonzero(keep & (pred > target_score)))
    n_greater_equal = int(np.count_nonzero(keep & (pred >= target_score)))
    return n_greater, max(n_greater_equal, n_greater + 1)


def kg_completion(rules, dataset, args):
    """Answer ``(h, r, ?)`` test queries with rule scores and report filtered MRR / Hits@k.

    For each head relation with rules, ``RuleDataset`` sums ``conf_1 * prod(body adjacency)``
    over the fact/train/valid graph into an entity x entity score matrix. Each test query
    whose relation has rules is ranked by the ``q_h`` row of that matrix, with the other
    known tails of ``(q_h, q_r)`` (from every split) filtered out.

    Two ranks are recorded:
    - the expected rank under random tie-breaking ("mean");
    - the optimistic rank that puts the target ahead of all ties ("top").

    Results are appended to the module-level accumulators, so repeated calls keep
    accumulating. They are also printed and written to
    ``../evaluate/<model>/<datasets>/...txt``.

    Args:
        rules: dict head relation -> list of ``(head, body, conf_1, conf_2)`` tuples.
        dataset: KG dataset providing the triple splits and entity/relation dictionaries.
        args: parsed command-line namespace.
    """
    fact_rdf, train_rdf, valid_rdf, test_rdf = dataset.fact_rdf, dataset.train_rdf, dataset.valid_rdf, dataset.test_rdf

    gt = get_gt(dataset)

    rdict = dataset.get_relation_dict()
    idx2rel = rdict.idx2rel

    idx2ent, ent2idx = dataset.idx2ent, dataset.ent2idx
    e_num = len(idx2ent)

    # Rule bodies are evaluated on the known graph only; test triples are excluded.
    r2mat = construct_rmat(idx2rel, idx2ent, ent2idx, fact_rdf + train_rdf + valid_rdf)

    rule_dataset = RuleDataset(r2mat, rules, e_num, idx2rel, args)
    rule_loader = DataLoader(
        rule_dataset,
        batch_size=args.data_batch_size,
        shuffle=False,
        num_workers=max(1, args.cpu_num // 2),
        collate_fn=RuleDataset.collate_fn,
    )

    # body2mat[head][h_idx] is the score row over all tail entities for (h, head, ?).
    body2mat = {}
    for heads, score_counts in rule_loader:
        for head, score_count in zip(heads, score_counts):
            body2mat[head] = score_count
    # Densify once so each query below is a cheap row lookup.
    for head in body2mat:
        body2mat[head] = body2mat[head].todense()

    for q_i, query_rdf in enumerate(test_rdf):
        q_h, q_r, q_t = parse_rdf(query_rdf)
        if q_r not in body2mat:
            continue
        print("{}\t{}\t{}".format(q_h, q_r, q_t))

        pred = np.squeeze(np.array(body2mat[q_r][ent2idx[q_h]]))
        target_idx = ent2idx[q_t]
        # Filtered setting: other correct tails of (q_h, q_r) do not compete with the target.
        truth = [t for t in gt[(q_h, q_r)] if t != target_idx]
        n_greater, n_greater_equal = _filtered_rank_bounds(pred, target_idx, truth)

        # Expected rank with uniformly random tie-breaking: the mean of ranks
        # n_greater + 1 .. n_greater_equal. It is accumulated term by term rather than
        # via the closed form so Hits@k thresholds stay reproducible with earlier runs.
        n_tied = n_greater_equal - n_greater
        mean_rank = 0
        for i in range(n_greater + 1, n_greater_equal + 1):
            mean_rank += i / n_tied
        Mean_rank['mrr'].append(1.0 / mean_rank)
        Mean_rank['hits_1'].append(1 if mean_rank <= 1 else 0)
        Mean_rank['hits_10'].append(1 if mean_rank <= 10 else 0)
        head2Mean_rank_mrr[q_r].append(1.0 / mean_rank)
        head2Mean_rank_hit_1[q_r].append(1 if mean_rank <= 1 else 0)
        head2Mean_rank_hit_10[q_r].append(1 if mean_rank <= 10 else 0)
        print("Number{}_use_mean_rank:{}".format(q_i, mean_rank))

        # Optimistic rank: the target is placed ahead of every tie.
        top_rank = n_greater + 1
        Top_rank['mrr'].append(1.0 / top_rank)
        Top_rank['hits_1'].append(1 if top_rank <= 1 else 0)
        Top_rank['hits_10'].append(1 if top_rank <= 10 else 0)
        head2Top_rank_mrr[q_r].append(1.0 / top_rank)
        head2Top_rank_hit_1[q_r].append(1 if top_rank <= 1 else 0)
        head2Top_rank_hit_10[q_r].append(1 if top_rank <= 10 else 0)
        print("Number{}_use_top_rank:{}".format(q_i, top_rank))

    summary = "{:<16}{:<20} Hits@1:{:<20} Hits@10:{:<20}\n{:>16}{:<20} Hits@1:{:<20} Hits@10:{:<20}\n".format(
        "expectation MRR:", np.mean(Mean_rank['mrr']), np.mean(Mean_rank['hits_1']),
        np.mean(Mean_rank['hits_10']), "TOP MRR:", np.mean(Top_rank['mrr']), np.mean(Top_rank['hits_1']),
        np.mean(Top_rank['hits_10']))
    print(summary)

    model_file = get_model_file(args)
    os.makedirs("../evaluate/{}/{}".format(args.model, args.datasets), exist_ok=True)
    with open("../evaluate/{}/{}/{}_{}-{}_[{}].txt".format(args.model, args.datasets, model_file, args.rule_len_low,
                                                           args.rule_len_high, args.threshold), 'w') as f:
        f.write(summary)

        f.write('\n{:<40}{:<20}{:<20}\n'.format("head", "expectation MRR", "TOP MRR"))
        # Both per-head dicts receive the same keys for every query, so look up by head.
        for head, mean_mrrs in head2Mean_rank_mrr.items():
            f.write('{:<40}{:<20}{:<20}\n'.format(head, np.mean(mean_mrrs), np.mean(head2Top_rank_mrr[head])))


def feq(relation, fact_rdf):
    """Return how many facts in ``fact_rdf`` have ``relation`` as their relation."""
    return sum(1 for _h, r, _t in map(parse_rdf, fact_rdf) if r == relation)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from random import sample
import random
from torch.nn.utils import clip_grad_norm_
import time
import pickle
import argparse
import numpy as np
import os


import debugpy
# Use the current CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Current device:", device) # Print the correct device being used


# Filled by test(): rule_conf[rule_len] holds the head-relation probabilities for every
# candidate body, and candidate_rule[rule_len] the matching "|"-joined bodies.
rule_conf = {}
candidate_rule = {}


def sample_training_data(sample_max_path_len, anchor_num, fact_rdf, entity2desced, head_rdict):
    """Sample anchor facts per head relation and encode the rules found around them.

    ``anchor_num`` is split evenly across the non-inverse head relations (skipping
    "None" and names containing "inv_"). For each sampled anchor, ``construct_rule_seq``
    extracts rules. Each rule is encoded with ``rule2idx`` as a LongTensor whose last
    entry is the head index.

    Returns:
        dict body length (encoding length - 2) -> list of encoded rule tensors.
    """
    print("Sampling training data...")
    n_head_relations = (len(head_rdict) - 1) // 2
    per_anchor_num = anchor_num // n_head_relations
    print("Number of head relation:{}".format(n_head_relations))
    print("Number of per_anchor_num: {}".format(per_anchor_num))

    fact_dict = construct_fact_dict(fact_rdf)
    anchors_rdf = []
    for head in head_rdict.rel2idx:
        if head == "None" or "inv_" in head:
            continue
        anchors_rdf.extend(sample_anchor_rdf(fact_dict[head], per_anchor_num))

    print("Total_anchor_num", len(anchors_rdf))
    rules_by_head = {}
    rules_by_len = {}
    n_train_rules = 0
    sample_number = 0

    for anchor_rdf in anchors_rdf:
        rule_seq, record = construct_rule_seq(fact_rdf, anchor_rdf, entity2desced, sample_max_path_len,
                                              PRINT=False)
        sample_number += len(record)
        n_train_rules += len(rule_seq)
        for rule in rule_seq:
            encoded = torch.LongTensor(rule2idx(rule, head_rdict))
            head_name = head_rdict.idx2rel[encoded[-1].item()]
            rules_by_head.setdefault(head_name, []).append(encoded)
            rules_by_len.setdefault(len(encoded) - 2, []).append(encoded)

    print("# head:{}".format(len(rules_by_head)))
    for head_name, encoded_rules in rules_by_head.items():
        print("head {}:{}".format(head_name, len(encoded_rules)))
    print("Fact set number:{} Sample number:{}".format(len(fact_rdf), sample_number))
    for body_len, encoded_rules in rules_by_len.items():
        print("sampled examples for rule of length {}: {}".format(body_len, len(encoded_rules)))
    print("length_of_train_rule:{}".format(n_train_rules))

    return rules_by_len


def train(args, dataset):
    """Sample rule instances and fit the Encoder to predict each rule's head relation.

    NOTE(review): a second ``def train`` later in this cell rebinds the name, so this
    definition is not the one ``__main__`` ends up calling.

    For every body length in ``[args.learned_rule_len_from_x_to_X,
    args.learned_rule_len_from_2_to_X]`` the model trains for ``args.epochs`` steps on
    random mini-batches of sampled rules. A per-length accuracy curve is saved under
    ``../figures`` and the model is pickled to ``../results``. Relies on the module-level
    ``device`` and ``train_file_name`` globals.
    """
    rdict = dataset.get_relation_dict()
    head_rdict = dataset.get_head_relation_dict()
    all_rdf = dataset.fact_rdf + dataset.train_rdf + dataset.valid_rdf
    entity2desced = construct_descendant(all_rdf)

    relation_num = rdict.__len__()

    sample_max_path_len = args.sample_max_path_len
    anchor_num = args.anchor
    len2train_rule_idx = sample_training_data(sample_max_path_len, anchor_num, all_rdf, entity2desced, head_rdict)
    print_msg("  Start training  ")
    batch_size = args.batch_size
    emb_size = args.embedding_size
    n_epoch = args.epochs
    lr = args.learning_rate
    body_len_range = list(range(args.learned_rule_len_from_x_to_X, args.learned_rule_len_from_2_to_X + 1))
    print("body_len_range", body_len_range)
    model = Encoder(relation_num, emb_size, device)
    if torch.cuda.is_available():
        model = model.cuda()
        if args.parallel:
            device_ids = [0, 1]
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    loss_func_head = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Log about ten times per rule length; max() avoids modulo-by-zero when epochs < 10.
    log_every = max(1, n_epoch // 10)

    model.train()
    start = time.time()

    train_acc = {}

    for rule_len in body_len_range:
        if rule_len not in len2train_rule_idx:
            # Nothing of this body length was sampled; indexing would raise KeyError.
            print(f"Warning: rule length {rule_len} not found in len2train_rule_idx")
            continue
        rule_ = len2train_rule_idx[rule_len]
        print("\nrule length:{}".format(rule_len))

        train_acc[rule_len] = []
        for epoch in range(n_epoch):
            model.zero_grad()

            sample_rule_ = sample(rule_, batch_size) if len(rule_) > batch_size else rule_
            # Encoded rules end with the head index; the last two entries are not body.
            body_ = [r_[0:-2] for r_ in sample_rule_]
            head_ = [r_[-1] for r_ in sample_rule_]
            print(head_)

            inputs_h = torch.stack(body_, 0).to(device)
            targets_h = torch.stack(head_, 0).to(device)

            pred_head, _entropy_loss, _ = model(inputs_h)

            loss_head = loss_func_head(pred_head, targets_h.reshape(-1))

            entropy_loss = _entropy_loss.mean()
            loss = args.alpha * loss_head + (1 - args.alpha) * entropy_loss

            if epoch % log_every == 0:
                print("### epoch:{}\tloss_head:{:.3}\tentropy_loss:{:.3}\tloss:{:.3}\t".format(epoch, loss_head,
                                                                                               entropy_loss, loss))

            train_acc[rule_len].append(
                ((pred_head.argmax(dim=1) == targets_h.reshape(-1)).sum() / pred_head.shape[0]).cpu().numpy())
            print(train_acc)
            loss.backward()
            # Clip only after backward(): before it the gradients are still the ones cleared
            # by zero_grad() above, so clipping there had no effect.
            clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

        import matplotlib.pyplot as plt

        # New figure per rule length; otherwise every saved plot also contains the
        # curves of the previous lengths drawn into the same current figure.
        plt.figure()
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.title("LogicFormer Epoch vs Accurary")
        train_acc[rule_len] = [float(x) for x in train_acc[rule_len]]
        plt.plot(train_acc[rule_len])

        os.makedirs("../figures/{}/{}/{}".format(args.model, args.datasets, train_file_name), exist_ok=True)
        plt.savefig('../figures/{}/{}/{}/{}.png'.format(args.model, args.datasets, train_file_name, rule_len))

    end = time.time()
    print("Time usage: {:.2}".format(end - start))

    print("Saving model...")
    os.makedirs("../results/{}/{}".format(args.model, args.datasets), exist_ok=True)
    with open('../results/{}/{}/{}'.format(args.model, args.datasets, train_file_name), 'wb') as g:
        pickle.dump(model, g)


def enumerate_body(relation_num, rdict, body_len):
    """Enumerate every length-``body_len`` sequence of relation indices in ``range(relation_num)``.

    Returns ``(index_bodies, name_bodies)``: lists of index lists in ``itertools.product``
    order, and the same bodies translated through ``rdict.idx2rel``.
    """
    import itertools
    names = rdict.idx2rel
    index_bodies = [list(combo) for combo in itertools.product(range(relation_num), repeat=body_len)]
    name_bodies = [[names[i] for i in combo] for combo in index_bodies]
    return index_bodies, name_bodies


def train(args, dataset):
    """Sample rule instances and fit the Encoder to predict each rule's head relation.

    For every body length in ``[args.learned_rule_len_from_x_to_X,
    args.learned_rule_len_from_2_to_X]`` that has sampled rules, the model trains for
    ``args.epochs`` steps on random mini-batches. Lengths without samples are skipped
    with a warning. A per-length accuracy curve is saved under ``../figures`` and the
    model is pickled to ``../results``. Relies on the module-level ``device`` and
    ``train_file_name`` globals.
    """
    rdict = dataset.get_relation_dict()
    head_rdict = dataset.get_head_relation_dict()
    all_rdf = dataset.fact_rdf + dataset.train_rdf + dataset.valid_rdf
    entity2desced = construct_descendant(all_rdf)

    relation_num = rdict.__len__()

    sample_max_path_len = args.sample_max_path_len
    anchor_num = args.anchor
    len2train_rule_idx = sample_training_data(sample_max_path_len, anchor_num, all_rdf, entity2desced, head_rdict)
    print("len2train_rule_idx:", len2train_rule_idx)  # Debugging line
    print_msg("  Start training  ")
    batch_size = args.batch_size
    emb_size = args.embedding_size
    n_epoch = args.epochs
    lr = args.learning_rate
    body_len_range = list(range(args.learned_rule_len_from_x_to_X, args.learned_rule_len_from_2_to_X + 1))
    print("body_len_range:", body_len_range)  # Debugging line
    model = Encoder(relation_num, emb_size, device)
    if torch.cuda.is_available():
        model = model.cuda()
        if args.parallel:
            device_ids = [0, 1]
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    loss_func_head = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Log about ten times per rule length; max() avoids modulo-by-zero when epochs < 10.
    log_every = max(1, n_epoch // 10)

    model.train()
    start = time.time()

    train_acc = {}

    for rule_len in body_len_range:
        if rule_len not in len2train_rule_idx:
            print(f"Warning: rule length {rule_len} not found in len2train_rule_idx")  # Error handling line
            continue
        rule_ = len2train_rule_idx[rule_len]
        print("\nrule length:{}".format(rule_len))

        train_acc[rule_len] = []
        for epoch in range(n_epoch):
            model.zero_grad()

            sample_rule_ = sample(rule_, batch_size) if len(rule_) > batch_size else rule_
            # Encoded rules end with the head index; the last two entries are not body.
            body_ = [r_[0:-2] for r_ in sample_rule_]
            head_ = [r_[-1] for r_ in sample_rule_]
            print("head", head_)

            inputs_h = torch.stack(body_, 0).to(device)
            targets_h = torch.stack(head_, 0).to(device)

            pred_head, _entropy_loss, _ = model(inputs_h)

            loss_head = loss_func_head(pred_head, targets_h.reshape(-1))
            print(targets_h)
            entropy_loss = _entropy_loss.mean()
            loss = args.alpha * loss_head + (1 - args.alpha) * entropy_loss

            if epoch % log_every == 0:
                print("### epoch:{}\tloss_head:{:.3}\tentropy_loss:{:.3}\tloss:{:.3}\t".format(epoch, loss_head,
                                                                                               entropy_loss, loss))

            train_acc[rule_len].append(
                ((pred_head.argmax(dim=1) == targets_h.reshape(-1)).sum() / pred_head.shape[0]).cpu().numpy())

            loss.backward()
            # Clip only after backward(): before it the gradients are still the ones cleared
            # by zero_grad() above, so clipping there had no effect.
            clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

        import matplotlib.pyplot as plt

        # New figure per rule length; otherwise every saved plot also contains the
        # curves of the previous lengths drawn into the same current figure.
        plt.figure()
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.title("LogicFormer Epoch vs Accurary")
        train_acc[rule_len] = [float(x) for x in train_acc[rule_len]]
        plt.plot(train_acc[rule_len])

        os.makedirs("../figures/{}/{}/{}".format(args.model, args.datasets, train_file_name), exist_ok=True)
        plt.savefig('../figures/{}/{}/{}/{}.png'.format(args.model, args.datasets, train_file_name, rule_len))

    end = time.time()
    print("Time usage: {:.2}".format(end - start))

    print("Saving model...")
    os.makedirs("../results/{}/{}".format(args.model, args.datasets), exist_ok=True)
    with open('../results/{}/{}/{}'.format(args.model, args.datasets, train_file_name), 'wb') as g:
        pickle.dump(model, g)


def test(args, dataset):
    """Score every candidate rule body with the trained model and optionally write top-k rules.

    NOTE(review): a second ``def test`` later in this cell rebinds the name, so this
    definition is not the one ``__main__`` ends up calling.

    For each body length L in the configured range:
    - All relation sequences of length L are enumerated and fed to the model in batches
      of ``args.batch_size``.
    - The softmax over head relations is stored in ``rule_conf[L]``, with the matching
      bodies in ``candidate_rule[L]``.
    - With ``args.get_rule``, the ``args.get_top_k`` best bodies per head are written to
      ``../rules/<model>/<datasets>/<model_file>/L.txt``.

    Relies on the module-level ``device``, ``train_file_name`` and ``model_file`` globals.
    """
    head_rdict = dataset.get_head_relation_dict()
    print(head_rdict)
    with open('../results/{}/{}/{}'.format(args.model, args.datasets, train_file_name), 'rb') as g:
        # train() stores the model with pickle.dump; torch.load expects torch.save's
        # format and rejects a plain pickle stream, so unpickle directly everywhere.
        # NOTE(review): a model pickled with CUDA tensors still needs CUDA to unpickle.
        # NOTE(review): pickle.load runs arbitrary code — only load trusted result files.
        model = pickle.load(g)
    model.to(device)
    if torch.cuda.is_available():
        print_msg(str(device))
    print_msg("  Start Eval  ")
    model.eval()

    r_num = head_rdict.__len__() - 1

    batch_size = args.batch_size

    for rule_len in range(args.learned_rule_len_from_x_to_X, args.learned_rule_len_from_2_to_X + 1):
        print("\nrule length:{}".format(rule_len))

        probs = []
        _, body = enumerate_body(r_num, head_rdict, body_len=rule_len)
        body_list = ["|".join(b) for b in body]
        candidate_rule[rule_len] = body_list
        # Consecutive batches; the last one may be shorter.
        for start in range(0, len(body_list), batch_size):
            bodies = body_list[start:start + batch_size]
            body_idx = body2idx(bodies, head_rdict)
            inputs = torch.LongTensor(np.array(body_idx)).to(device)

            with torch.no_grad():
                pred_head, _entropy_loss, _ = model(inputs)
                prob_ = torch.softmax(pred_head, dim=-1)
                probs.append(prob_.detach().cpu())

        rule_conf[rule_len] = torch.cat(probs, dim=0)
        if args.get_rule:
            print_msg("Generate Rule!")
            n_rel = head_rdict.__len__() - 1
            os.makedirs("../rules/{}/{}/{}".format(args.model, args.datasets, model_file), exist_ok=True)
            rule_path = '../rules/{}/{}/{}/{}.txt'.format(args.model, args.datasets, model_file, rule_len)
            # Sort each head's column independently: row k holds that head's k-th best body.
            sorted_val, sorted_idx = torch.sort(rule_conf[rule_len], 0, descending=True)
            n_rules, _ = sorted_val.shape
            with open(rule_path, 'w') as g:
                for r in range(n_rel):
                    head = head_rdict.idx2rel[r]
                    for rank in range(min(args.get_top_k, n_rules)):
                        conf = sorted_val[rank, r]
                        body_rels = candidate_rule[rule_len][sorted_idx[rank, r]].split('|')
                        msg = "{:.3f} ({:.3f})\t{} <-- ".format(conf, conf, head)
                        g.write(msg + ", ".join(body_rels) + '\n')
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def get_true_labels_for_bodies(bodies, dataset):
    """Return ``dataset.get_label(body)`` for each body, in order."""
    # NOTE(review): get_label is expected on the project Dataset class; it is not visible here.
    return [dataset.get_label(body) for body in bodies]

def _write_rule_file(args, head_rdict, rule_len):
    """Write the ``args.get_top_k`` highest-confidence bodies per head for ``rule_len``.

    Reads ``rule_conf[rule_len]`` / ``candidate_rule[rule_len]`` and the module-level
    ``model_file``. Lines are ``"conf (conf)\\thead <-- b1, b2"``, the format that
    ``load_rules`` reads back in ``__main__``.
    """
    print_msg("Generate Rule!")
    n_rel = head_rdict.__len__() - 1
    os.makedirs("../rules/{}/{}/{}".format(args.model, args.datasets, model_file), exist_ok=True)
    rule_path = '../rules/{}/{}/{}/{}.txt'.format(args.model, args.datasets, model_file, rule_len)
    # Sort each head's column independently: row k holds that head's k-th best body.
    sorted_val, sorted_idx = torch.sort(rule_conf[rule_len], 0, descending=True)
    n_rules, _ = sorted_val.shape
    with open(rule_path, 'w') as g:
        for r in range(n_rel):
            head = head_rdict.idx2rel[r]
            for rank in range(min(args.get_top_k, n_rules)):
                conf = sorted_val[rank, r]
                body_rels = candidate_rule[rule_len][sorted_idx[rank, r]].split('|')
                msg = "{:.3f} ({:.3f})\t{} <-- ".format(conf, conf, head)
                g.write(msg + ", ".join(body_rels) + '\n')


def test(args, dataset):
    """Run the trained model over all candidate bodies, report metrics and write rule files.

    For each body length L:
    - Every relation sequence of length L is scored in batches of ``args.batch_size``.
    - Per-body head probabilities go to ``rule_conf[L]``.
    - The argmax predictions are compared with ``dataset.get_label`` labels.

    With ``args.get_rule`` the top-k rules per head are written for each L. ``__main__``
    loads those files right after this call, so skipping this step made it fail.
    Accuracy and macro precision/recall/F1 are printed at the end. Relies on the
    module-level ``device``, ``train_file_name`` and ``model_file`` globals.
    """
    head_rdict = dataset.get_head_relation_dict()
    print(head_rdict)
    with open('../results/{}/{}/{}'.format(args.model, args.datasets, train_file_name), 'rb') as g:
        # train() stores the model with pickle.dump; torch.load expects torch.save's
        # format and rejects a plain pickle stream, so unpickle directly everywhere.
        # NOTE(review): a model pickled with CUDA tensors still needs CUDA to unpickle.
        # NOTE(review): pickle.load runs arbitrary code — only load trusted result files.
        model = pickle.load(g)
    model.to(device)
    if torch.cuda.is_available():
        print_msg(str(device))
    print_msg("  Start Eval  ")
    model.eval()

    r_num = head_rdict.__len__() - 1

    batch_size = args.batch_size
    all_labels = []
    all_predictions = []

    for rule_len in range(args.learned_rule_len_from_x_to_X, args.learned_rule_len_from_2_to_X + 1):
        print("\nrule length:{}".format(rule_len))

        probs = []
        _, body = enumerate_body(r_num, head_rdict, body_len=rule_len)
        body_list = ["|".join(b) for b in body]
        candidate_rule[rule_len] = body_list
        # Consecutive batches; the last one may be shorter.
        for start in range(0, len(body_list), batch_size):
            bodies = body_list[start:start + batch_size]
            body_idx = body2idx(bodies, head_rdict)
            inputs = torch.LongTensor(np.array(body_idx)).to(device)

            with torch.no_grad():
                pred_head, _entropy_loss, _ = model(inputs)

                prob_ = torch.softmax(pred_head, dim=-1)
                all_predictions.extend(prob_.argmax(dim=-1).cpu().numpy())

                # NOTE(review): depends on dataset.get_label(body), not visible here — confirm it exists.
                all_labels.extend(get_true_labels_for_bodies(bodies, dataset))

                probs.append(prob_.detach().cpu())

        rule_conf[rule_len] = torch.cat(probs, dim=0)
        if args.get_rule:
            _write_rule_file(args, head_rdict, rule_len)

    # Calculating metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, average='macro')
    recall = recall_score(all_labels, all_predictions, average='macro')
    f1 = f1_score(all_labels, all_predictions, average='macro')

    print("Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1 Score: {:.3f}".format(accuracy, precision, recall, f1))


def get_model_file(args):
    """Return the bracketed hyper-parameter tag (training settings plus top-k) naming rule/eval outputs."""
    settings = [
        ("epochs", args.epochs),
        ("alpha", args.alpha),
        ("anchor", args.anchor),
        ("max_sample", args.sample_max_path_len),
        ("emb", args.embedding_size),
        ("lr", args.learning_rate),
        ("batchsize", args.batch_size),
        ("topk", args.get_top_k),
    ]
    return "".join(f"[{key}{value}]" for key, value in settings)


def get_train_file_name(args):
    """Return the bracketed hyper-parameter tag naming the saved model and its figures."""
    settings = [
        ("epochs", args.epochs),
        ("alpha", args.alpha),
        ("anchor", args.anchor),
        ("max_sample", args.sample_max_path_len),
        ("emb", args.embedding_size),
        ("lr", args.learning_rate),
        ("batchsize", args.batch_size),
    ]
    return "".join(f"[{key}{value}]" for key, value in settings)

if __name__ == '__main__':
    # Pipeline: train (if no saved model) -> test/score rule bodies -> load mined rules -> KG completion.
    msg = "First Order Logic Rule Mining"
    print(msg)
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--file', help='Dummy argument to avoid error')
    parser.add_argument("--get_rule", default="1", action="store_true", help="increase output verbosity")
    parser.add_argument("--gpu", type=int, default=0, help="increase output verbosity")
    parser.add_argument("--parallel", default="", help="increase output verbosity")
    parser.add_argument("--sparsity", type=float, default=1, help="increase output verbosity")
    parser.add_argument("--model", default="show", help="训练谁的模型")  # which model is trained
    parser.add_argument("--datasets", default="tic", help="数据集")  # dataset name
    parser.add_argument("--epochs", type=int, default=1000, help="epochs")
    parser.add_argument("--alpha", type=float, default=0.8, help="increase output verbosity")
    parser.add_argument("--beta", type=float, default=3, help="control beta")
    parser.add_argument("--gamma", type=float, default=3, help="control gamma")
    parser.add_argument("--anchor", type=int, default=5000, help="increase output verbosity")
    parser.add_argument("--sample_max_path_len", type=int, default=2, help="采样路径长度")  # sampled path length
    parser.add_argument("--embedding_size", type=int, default=512, help="embedding_size")
    # type=float: with type=int any CLI value such as 0.0005 failed to parse.
    parser.add_argument("--learning_rate", type=float, default=0.001, help="learning_rate")
    parser.add_argument("--batch_size", type=int, default=1000, help="increase output verbosity")
    parser.add_argument("--get_top_k", type=int, default=100, help="得到规则的get_top_k")  # top-k rules kept per head
    parser.add_argument("--learned_rule_len_from_x_to_X", type=int, default=2, help="学习规则的长度")  # min learned rule length
    parser.add_argument("--learned_rule_len_from_2_to_X", type=int, default=2, help="学习规则的长度")  # max learned rule length
    parser.add_argument('--cpu_num', type=int, default=20)
    parser.add_argument("--data_batch_size", type=int, default=1)
    parser.add_argument("--threshold", type=float, default=0)
    parser.add_argument("--rule_len_low", type=int, default=2)
    parser.add_argument("--rule_len_high", type=int, default=2)
    args = parser.parse_args()
    model_file = get_model_file(args)
    train_file_name = get_train_file_name(args)
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
    data_path = '/content/CARL/datasets/{}/'.format(args.datasets)
    dataset = Dataset(data_root=data_path, sparsity=args.sparsity,
                      inv=True)
    print("Dataset:{}".format(data_path))
    model_path = "../results/{}/{}/{}".format(args.model, args.datasets, train_file_name)
    print("results at:{}".format(model_path))
    if not os.path.isfile(model_path):
        print("Train!")
        train(args, dataset)
    print_msg("Test!")
    test(args, dataset)
    all_rules = {}
    all_rule_heads = []
    for rule_len in range(args.rule_len_low, args.rule_len_high + 1):
        rule_file = "../rules/{}/{}/{}/{}.txt".format(args.model, args.datasets, model_file, rule_len)
        if not os.path.isfile(rule_file):
            # test() only writes rule files when rule generation ran; don't abort the whole run.
            print("Warning: rule file not found, skipping: {}".format(rule_file))
            continue
        load_rules(rule_file, all_rules, all_rule_heads)
    for head in all_rules:
        all_rules[head] = all_rules[head][:args.get_top_k * 5]
    fact_rdf, train_rdf, valid_rdf, test_rdf = dataset.fact_rdf, dataset.train_rdf, dataset.valid_rdf, dataset.test_rdf
    kg_completion(all_rules, dataset, args)