In [40]:
##############################################################
# IMPORTS
##############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import re
import random

from sklearn.metrics import roc_auc_score

# For SPARQL queries, if you want to query the endpoint
# pip install SPARQLWrapper
from SPARQLWrapper import SPARQLWrapper, JSON

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)


##############################################################
# A) EMBEDDING MODEL (ConvE) + TRAINING CODE
##############################################################

def parse_reified_triples(ttl_path):
    """
    Parse (s, p, o, label) from the training file which includes hasTruthValue (0 or 1).
    """
    fact_dict = {}
    with open(ttl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            # match lines like:
            # <factIRI> <predicate> <object> .
            # or
            # <factIRI> <http://swc2017.aksw.org/hasTruthValue> "0.0"^^...
            m = re.match(r'^<([^>]+)>\s+<([^>]+)>\s+(.*)\s+\.$', line)
            if not m:
                continue
            factIRI = m.group(1)
            predIRI = m.group(2)
            objStr  = m.group(3)

            if factIRI not in fact_dict:
                fact_dict[factIRI] = {
                    'subject': None,
                    'predicate': None,
                    'object': None,
                    'label': None
                }

            if predIRI == 'http://swc2017.aksw.org/hasTruthValue':
                val_match = re.match(r'^"([^"]+)"\^\^<([^>]+)>$', objStr)
                if val_match:
                    val = float(val_match.group(1))
                    fact_dict[factIRI]['label'] = val
            elif predIRI == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#subject':
                s_match = re.match(r'^<([^>]+)>$', objStr)
                if s_match:
                    fact_dict[factIRI]['subject'] = s_match.group(1)
            elif predIRI == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#predicate':
                p_match = re.match(r'^<([^>]+)>$', objStr)
                if p_match:
                    fact_dict[factIRI]['predicate'] = p_match.group(1)
            elif predIRI == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#object':
                o_match = re.match(r'^<([^>]+)>$', objStr)
                if o_match:
                    fact_dict[factIRI]['object'] = o_match.group(1)
    
    results = []
    for fid, vals in fact_dict.items():
        s = vals['subject']
        p = vals['predicate']
        o = vals['object']
        lbl= vals['label']
        if s and p and o and (lbl is not None):
            results.append((s, p, o, lbl))
    return results


def build_index(triples_4tuple):
    """
    Return entity2id, relation2id, plus reverse maps.
    """
    entities = set()
    relations = set()
    for s, p, o, lbl in triples_4tuple:
        entities.add(s)
        entities.add(o)
        relations.add(p)
    entities = sorted(list(entities))
    relations = sorted(list(relations))
    entity2id = {e: i for i, e in enumerate(entities)}
    relation2id = {r: i for i, r in enumerate(relations)}
    return entity2id, relation2id


def convert_to_idx(triples_4tuple, entity2id, relation2id):
    data_idx = []
    for s, p, o, lbl in triples_4tuple:
        s_idx = entity2id[s]
        p_idx = relation2id[p]
        o_idx = entity2id[o]
        data_idx.append((s_idx, p_idx, o_idx, lbl))
    return data_idx


class FactDataset(Dataset):
    def __init__(self, data_idx, num_entities):
        self.data = data_idx
        self.num_entities = num_entities

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]  # (s_idx, p_idx, o_idx, label)
    
    def collate_fn(self, batch):
        """
        batch is list of (s, p, o, label)
        We'll add 1 negative triple per positive or negative (whichever).
        """
        s = torch.LongTensor([x[0] for x in batch])
        p = torch.LongTensor([x[1] for x in batch])
        o = torch.LongTensor([x[2] for x in batch])
        lbl = torch.FloatTensor([x[3] for x in batch])  # 0 or 1

        # We'll generate a random corruption:
        batch_size = len(batch)
        rand_ents = torch.randint(0, self.num_entities, (batch_size,))
        mask_corrupt_head = (torch.rand(batch_size) > 0.5)
        s_neg = s.clone()
        o_neg = o.clone()
        s_neg[mask_corrupt_head] = rand_ents[mask_corrupt_head]
        o_neg[~mask_corrupt_head] = rand_ents[~mask_corrupt_head]

        # label for corrupted = 0
        lbl_neg = torch.zeros(batch_size)

        s_final = torch.cat([s, s_neg], dim=0)
        p_final = torch.cat([p, p], dim=0)
        o_final = torch.cat([o, o_neg], dim=0)
        lbl_final= torch.cat([lbl, lbl_neg], dim=0)

        return s_final.to(device), p_final.to(device), o_final.to(device), lbl_final.to(device)


class ConvE(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim=100, embed_shape=(10,10), num_filters=32, kernel_size=3):
        super(ConvE, self).__init__()
        self.num_entities = num_entities
        self.num_relations= num_relations
        self.embedding_dim= embedding_dim
        assert embed_shape[0]*embed_shape[1] == embedding_dim
        self.embed_shape = embed_shape

        # Embeddings
        self.emb_ent = nn.Embedding(num_entities, embedding_dim)
        self.emb_rel = nn.Embedding(num_relations, embedding_dim)
        bound = 1.0 / embedding_dim**0.5
        nn.init.uniform_(self.emb_ent.weight, a=-bound, b=bound)
        nn.init.uniform_(self.emb_rel.weight, a=-bound, b=bound)

        # 2D conv
        self.conv = nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=kernel_size, stride=1, padding=1)
        out_h = embed_shape[0]*2  # since we stack subject & relation vertically
        out_w = embed_shape[1]
        self.flat_size = num_filters*out_h*out_w
        self.fc = nn.Linear(self.flat_size, embedding_dim)

    def forward(self, s_idx, p_idx, o_idx):
        """
        Return logits, shape (B,)
        """
        # embeddings
        s_e = self.emb_ent(s_idx)  # (B, dim)
        p_e = self.emb_rel(p_idx)  # (B, dim)
        o_e = self.emb_ent(o_idx)  # (B, dim)

        B = s_e.size(0)
        h, w = self.embed_shape

        s_2d = s_e.view(B, 1, h, w)
        p_2d = p_e.view(B, 1, h, w)
        # stack along height => shape (B, 1, 2*h, w)
        stacked = torch.cat([s_2d, p_2d], dim=2)

        c = self.conv(stacked)   # (B, num_filters, 2h, w)
        c = F.relu(c)
        c = c.view(B, -1)
        c = self.fc(c)
        c = F.relu(c)

        # dot with o_e
        logits = torch.sum(c * o_e, dim=1)
        return logits
    
    def score_triple(self, s_idx, p_idx, o_idx):
        """
        Return a logit for (s_idx, p_idx, o_idx).
        Larger => more likely positive.
        We'll apply sigmoid outside if needed.
        """
        return self.forward(s_idx, p_idx, o_idx)


##############################################################
# B) LOADING REFERENCE KG & CLASS HIERARCHY
##############################################################

def load_reference_kg(ref_kg_path):
    """
    Example: parse domain/range statements and any other facts you want from the reference KG dump.
    Returns dictionaries:
      domain_of: relationIRI -> set of possible domain classes
      range_of:  relationIRI -> set of possible range classes
      type_of:   entityIRI -> set of classes (rdfs:subClassOf or rdf:type)
    You can expand as needed.
    """
    domain_of = {}
    range_of  = {}
    type_of   = {}
    
    with open(ref_kg_path, 'r', encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if not line or line.startswith('#'):
                continue
            # e.g. <relIRI> <rdfs:domain> <classIRI> .
            m = re.match(r'^<([^>]+)>\s+<([^>]+)>\s+<([^>]+)>\s+\.$', line)
            if not m:
                continue
            subj = m.group(1)   # e.g. relation IRI
            pred = m.group(2)
            obj  = m.group(3)

            if pred == 'http://www.w3.org/2000/01/rdf-schema#domain':
                # domain of "subj" is "obj"
                domain_of.setdefault(subj, set()).add(obj)
            elif pred == 'http://www.w3.org/2000/01/rdf-schema#range':
                # range of "subj" is "obj"
                range_of.setdefault(subj, set()).add(obj)
            elif pred == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type':
                # type_of[entity].add(class)
                type_of.setdefault(subj, set()).add(obj)
        print(len(domain_of),"!!! len domain_of")
        print(len(type_of), "!!! len type_of")
    return domain_of, range_of, type_of

def load_class_hierarchy(class_hierarchy_path):
    """
    Parse lines like:
      <classA> <rdfs:subClassOf> <classB> .
    We store subClassOf relationships in a dict for inference (transitive closure if desired).
    """
    subclass_of = {}  # e.g. subclass_of[A] = set([B, ...])
    with open(class_hierarchy_path, 'r', encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if not line or line.startswith('#'):
                continue
            m = re.match(r'^<([^>]+)>\s+<http://www.w3.org/2000/01/rdf-schema#subClassOf>\s+<([^>]+)>\s+\.$', line)
            if not m:
                continue
            child = m.group(1)
            parent= m.group(2)
            subclass_of.setdefault(child, set()).add(parent)
    return subclass_of

def expand_types_with_subclass(type_of, subclass_of, max_depth=5):
    """
    Expand each entity's type set by traversing up the subclass hierarchy.
    e.g. if entity e is typed as ChildClass, and ChildClass subClassOf ParentClass,
         we also add ParentClass to e's type set.
    Simple BFS up to a certain depth or until no more changes.
    """
    for ent, classes in list(type_of.items()):
        to_visit = list(classes)
        visited = set(classes)
        depth=0
        while to_visit and depth<max_depth:
            nxt = []
            for c in to_visit:
                # c might have parents
                if c in subclass_of:
                    for parent in subclass_of[c]:
                        if parent not in visited:
                            visited.add(parent)
                            nxt.append(parent)
            to_visit = nxt
            depth+=1
        # update
        type_of[ent] = visited
    print(depth,"!!! depth")
    return type_of


##############################################################
# C) DOMAIN-RANGE & SUBCLASS REASONING
##############################################################

def domain_range_consistency(subjectIRI, predicateIRI, objectIRI, 
                             domain_of, range_of, type_of):
    """
    Return a score in [0,1] indicating how consistent the triple is with
    domain/range constraints in the reference KG + type_of dictionary.
    Example approach:
      - If predicateIRI is in domain_of => check if subjectIRI's types intersect
        domain_of[predicateIRI] (or their superclasses).
      - If predicateIRI is in range_of => check if objectIRI's types intersect
        range_of[predicateIRI].
      - Score could be average of domain_match + range_match, or you can do your own logic.
    """
    # if either entity is missing from type_of => we can't confirm or deny => partial
    subj_types = type_of.get(subjectIRI, set())
    obj_types  = type_of.get(objectIRI, set())

    # domain match
    doms = domain_of.get(predicateIRI, set())  # possible domain classes
    if not doms:
        domain_score = 0.2  # we don't know domain => partial guess
    else:
        # if there's an intersection between subj_types and doms => good
        if subj_types.intersection(doms):
            domain_score = 1.0
        else:
            domain_score = 0.0

    # range match
    rngs = range_of.get(predicateIRI, set())
    if not rngs:
        range_score = 0.2
    else:
        if obj_types.intersection(rngs):
            range_score = 1.0
        else:
            range_score = 0.0

    # final
    # we can do the average or min or something
    final_score = (domain_score + range_score)/2.0
    return final_score


##############################################################
# D) SCORING LOGIC (COMBINING EMBEDDINGS + KNOWLEDGE-BASED CHECKS)
##############################################################

def hybrid_score(model, 
                 sIRI, pIRI, oIRI, 
                 entity2id, relation2id, 
                 domain_of, range_of, type_of,
                 alpha=0.9):
    """
    If (sIRI, pIRI, oIRI) is fully in-vocab => produce embedding score in [0,1].
    If partially OOV => combine domain-range check with fallback to .5 or partial embedding.

    alpha controls how strongly we weigh the knowledge-based check vs. embedding.
    """
    # 1) Check if in-vocab
    inVocab = (sIRI in entity2id) and (pIRI in relation2id) and (oIRI in entity2id)

    # 2) Knowledge-based domain-range check
    kg_consistency = domain_range_consistency(sIRI, pIRI, oIRI, domain_of, range_of, type_of)

    if inVocab:
        # Embedding-based logit => convert to [0,1] by sigmoid
        s_idx = torch.tensor([entity2id[sIRI]], device=device)
        p_idx = torch.tensor([relation2id[pIRI]], device=device)
        o_idx = torch.tensor([entity2id[oIRI]], device=device)
        with torch.no_grad():
            logit = model.score_triple(s_idx, p_idx, o_idx)  # shape (1,)
            emb_score = torch.sigmoid(logit).item()
        # combine with domain-range if you want
        final_score = alpha*emb_score + (1-alpha)*kg_consistency
    else:
        # at least one is OOV => we have no embedding info
        # => fallback to knowledge-based check alone or combine with 0.5
        final_score = alpha*kg_consistency + (1-alpha)*0.5

    return final_score


##############################################################
# E) GENERATE FINAL TTL (TEST SET) 
##############################################################

def parse_test_reified_triples(ttl_path):
    """
    Return list of (factIRI, subjectIRI, predicateIRI, objectIRI) from test file
    (No hasTruthValue line).
    """
    fact_dict={}
    with open(ttl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if not line or line.startswith('#'):
                continue
            m = re.match(r'^<([^>]+)>\s+<([^>]+)>\s+(.*)\s+\.$', line)
            if not m:
                continue
            factIRI = m.group(1)
            predIRI = m.group(2)
            objStr  = m.group(3)
            if factIRI not in fact_dict:
                fact_dict[factIRI]={'subject':None, 'predicate':None, 'object':None}
            if predIRI == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#subject':
                s_match = re.match(r'^<([^>]+)>$', objStr)
                if s_match:
                    fact_dict[factIRI]['subject']= s_match.group(1)
            elif predIRI == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#predicate':
                p_match = re.match(r'^<([^>]+)>$', objStr)
                if p_match:
                    fact_dict[factIRI]['predicate']= p_match.group(1)
            elif predIRI == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#object':
                o_match = re.match(r'^<([^>]+)>$', objStr)
                if o_match:
                    fact_dict[factIRI]['object']= o_match.group(1)
            # ignore type statement, etc.

    results=[]
    for fid,vals in fact_dict.items():
        s=vals['subject']
        p=vals['predicate']
        o=vals['object']
        if s and p and o:
            results.append((fid, s, p, o))
    return results


def generate_results_ttl(model, 
                         test_path, 
                         out_path, 
                         entity2id, 
                         relation2id, 
                         domain_of, 
                         range_of, 
                         type_of,
                         alpha=0.9):
    """
    For each test fact, produce a final veracity in [0,1], then write to out_path as TTL lines.
    """
    test_facts = parse_test_reified_triples(test_path)
    print("Number of test facts:", len(test_facts))

    model.eval()
    with open(out_path, 'w', encoding='utf-8') as f:
        for (factIRI, sIRI, pIRI, oIRI) in test_facts:
            score = hybrid_score(model, sIRI, pIRI, oIRI, 
                                 entity2id, relation2id,
                                 domain_of, range_of, type_of,
                                 alpha=alpha)
            # write line
            line = f'<{factIRI}> <http://swc2017.aksw.org/hasTruthValue> "{score}"^^<http://www.w3.org/2001/XMLSchema#double> .\n'
            f.write(line)
    print(f"Results written to {out_path}")


##############################################################
# MAIN PIPELINE EXAMPLE
##############################################################

def main():
    #######################
    # 1) Load + Train Model
    #######################
    training_ttl = "fokg-sw-train-2024.nt"  # adapt
    all_facts = parse_reified_triples(training_ttl)
    print("Train facts:", len(all_facts))

    entity2id, relation2id = build_index(all_facts)
    data_idx = convert_to_idx(all_facts, entity2id, relation2id)

    # Shuffle & split
    random.shuffle(data_idx)
    split_pt = int(0.8*len(data_idx))
    train_data = data_idx[:split_pt]
    valid_data = data_idx[split_pt:]

    train_ds = FactDataset(train_data, num_entities=len(entity2id))
    valid_ds = FactDataset(valid_data, num_entities=len(entity2id))

    train_loader= DataLoader(train_ds, batch_size=1024, shuffle=True, 
                             collate_fn=train_ds.collate_fn, drop_last=True)
    valid_loader= DataLoader(valid_ds, batch_size=1024, shuffle=False, 
                             collate_fn=valid_ds.collate_fn, drop_last=False)

    model = ConvE(num_entities=len(entity2id),
                  num_relations=len(relation2id),
                  embedding_dim=100,
                  embed_shape=(10,10),
                  num_filters=32,
                  kernel_size=3).to(device)

    optim = torch.optim.Adam(model.parameters(), lr=0.0005)
    EPOCHS=30

    def train_epoch(loader):
        model.train()
        total_loss=0
        for batch in loader:
            s, p, o, lbl = batch
            logits = model.forward(s, p, o)
            loss   = F.binary_cross_entropy_with_logits(logits, lbl)
            optim.zero_grad()
            loss.backward()
            optim.step()
            total_loss+=loss.item()
        return total_loss
    
    def evaluate_auc(loader):
        model.eval()
        all_labels=[]
        all_scores=[]
        with torch.no_grad():
            for batch in loader:
                s, p, o, lbl= batch
                batch_size = s.shape[0]//2
                # first half => real
                s_real= s[:batch_size]
                p_real= p[:batch_size]
                o_real= o[:batch_size]
                lbl_real=lbl[:batch_size]
                logits= model.forward(s_real, p_real, o_real)
                all_scores.extend(logits.cpu().numpy().tolist())
                all_labels.extend(lbl_real.cpu().numpy().tolist())
        # check if we have both 0 and 1
        if len(set(all_labels))<2:
            return 0.0
        return roc_auc_score(all_labels, all_scores)

    for epoch in range(1, EPOCHS+1):
        train_loss= train_epoch(train_loader)
        val_auc= evaluate_auc(valid_loader)
        print(f"[Epoch {epoch}] TrainLoss={train_loss:.4f}, ValAUC={val_auc:.4f}")

    #######################
    # 2) Load Reference KG + Class Hierarchy
    #######################
    ref_kg_path= "reference-kg.nt"  # adapt
    domain_of, range_of, type_of = load_reference_kg(ref_kg_path)

    class_hierarchy_path= "classHierarchy.nt" # adapt
    subclass_of= load_class_hierarchy(class_hierarchy_path)

    # Expand type_of with transitive subClassOf
    type_of= expand_types_with_subclass(type_of, subclass_of)

    #######################
    # 3) Predict on Test
    #######################
    test_ttl= "fokg-sw-test-2024.nt"
    out_ttl= "results3.ttl"

    generate_results_ttl(model, 
                         test_ttl, 
                         out_ttl, 
                         entity2id,
                         relation2id,
                         domain_of,
                         range_of,
                         type_of,
                         alpha=0.9)

if __name__=="__main__":
    main()


Using device: cpu
Train facts: 1000
[Epoch 1] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 2] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 3] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 4] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 5] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 6] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 7] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 8] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 9] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 10] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 11] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 12] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 13] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 14] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 15] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 16] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 17] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 18] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 19] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 20] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 21] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 22] TrainLoss=0.0000, ValAUC=0.4653
[Epoch 23] TrainLoss=0.0000