In [5]:
import json
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Data Loading and Vocab Mappings

In [6]:
def load_hyperfacts(json_path):
    """
    Read hyper-relational facts from a JSON file.

    The file must hold a JSON array whose entries look like:
      {
        "drug1": "DrugA",
        "relation": "interactWith",
        "drug2": "DrugC",
        "attributes": {
            "adverseEvent": "ConditionX",
            "PRR": 400.0
        }
      }

    Args:
        json_path: Path of the JSON file. It is decoded as UTF-8 (the JSON
            interchange encoding) rather than the platform default, so
            non-ASCII names load identically on every OS.

    Returns:
        A list with one tuple per entry:
          [
            (h_str, r_str, t_str, [(k1_str, v1_str), (k2_str, v2_str), ...]),
            ...
          ]
        The key/value list is empty when "attributes" is missing, null or
        empty. Non-string attribute values (numbers, booleans) are converted
        with str() so every value can be used as an entity-vocabulary key.

    Raises:
        KeyError: if an entry lacks "drug1", "relation" or "drug2".
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    facts = []
    for entry in data:
        h = entry["drug1"]
        r = entry["relation"]
        t = entry["drug2"]
        # "attributes" may be absent or an explicit JSON null; both mean "no qualifiers".
        attributes = entry.get("attributes") or {}
        # Values may be strings or numbers; numbers are stringified so that
        # every qualifier value is a plain string.
        kv_pairs = [
            (k, v if isinstance(v, str) else str(v))
            for k, v in attributes.items()
        ]
        facts.append((h, r, t, kv_pairs))
    return facts

def load_conditions_from_json(conditions_path):
    """
    Load the candidate condition names from a JSON file.

    Args:
        conditions_path: Path of a JSON file that holds an array of strings,
            e.g. ["Dissociative disorder", "Incision site haemorrhage", ...].
            The file is decoded as UTF-8 instead of the platform default
            encoding so names with non-ASCII characters load the same way
            everywhere.

    Returns:
        Whatever the JSON document decodes to (expected: a list of strings).
    """
    with open(conditions_path, "r", encoding="utf-8") as f:
        return json.load(f)

def build_vocab_and_mappings(facts):
    """
    Build integer vocabularies for the symbols that occur in the hyperfacts.

    Heads, tails and qualifier values share the entity vocabulary; main
    relations and qualifier keys share the relation vocabulary. IDs are
    handed out in order of first appearance, starting at 0.

    Returns:
      entity2id (dict): entity string -> integer ID
      relation2id (dict): relation string -> integer ID
    """
    entity2id, relation2id = {}, {}

    def _register(table, symbol):
        # A new symbol receives the next free ID, which is the current table
        # size; symbols already present keep their ID.
        if symbol not in table:
            table[symbol] = len(table)

    for head, rel, tail, qualifiers in facts:
        _register(entity2id, head)
        _register(entity2id, tail)
        _register(relation2id, rel)

        # Qualifier key behaves like a relation, qualifier value like an entity.
        for key, value in qualifiers:
            _register(relation2id, key)
            _register(entity2id, value)

    return entity2id, relation2id



## HINGE Model Definition

In [7]:
class HINGEModel(nn.Module):
    """
    Compact PyTorch take on the HINGE scoring network.

    A fact is scored in four steps:
      1. the base triple (h, r, t) is embedded, stacked into a 3 x E "image"
         and convolved with (3, 3) filters;
      2. every qualifier (k, v) extends the triple to a 5 x E image that is
         convolved with (5, 3) filters;
      3. the triple features and all quintuple features are merged with an
         element-wise minimum;
      4. a single linear layer maps the merged features to one score.
    """

    def __init__(self, num_entities, num_relations, embedding_dim=100, num_filters=400):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_filters = num_filters

        # Embedding tables: entities (heads, tails, qualifier values) and
        # relations (main relations, qualifier keys).
        self.entity_emb = nn.Embedding(num_entities, embedding_dim)
        self.relation_emb = nn.Embedding(num_relations, embedding_dim)

        # Kernel height equals the number of stacked rows, so the convolution
        # collapses the height to 1 and slides only along the embedding axis.
        self.triple_conv = nn.Conv2d(
            1, num_filters, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)
        )
        self.quintuple_conv = nn.Conv2d(
            1, num_filters, kernel_size=(5, 3), stride=(1, 1), padding=(0, 0)
        )

        # Width-3 kernels without padding leave embedding_dim - 2 columns per filter.
        self.proj = nn.Linear(num_filters * (embedding_dim - 2), 1)

        # Xavier-initialise every weight matrix, in declaration order.
        for layer in (
            self.entity_emb,
            self.relation_emb,
            self.triple_conv,
            self.quintuple_conv,
            self.proj,
        ):
            nn.init.xavier_uniform_(layer.weight)

    def _conv_features(self, conv, rows):
        """
        Stack the [B, E] embedding rows into a [B, 1, len(rows), E] image,
        apply conv + ReLU and flatten to [B, num_filters * (E - 2)].
        """
        image = torch.stack(rows, dim=1).unsqueeze(1)
        activated = F.relu(conv(image))
        # Height is 1 after the convolution; drop it, then merge filters and columns.
        return torch.flatten(activated.squeeze(2), start_dim=1)

    def triple_forward(self, h, r, t):
        """
        Features of the base triple (h, r, t).
        Inputs: h, r, t => ID tensors of shape [batch_size]
        Output: tensor of shape [batch_size, num_filters*(embedding_dim-2)]
        """
        rows = [self.entity_emb(h), self.relation_emb(r), self.entity_emb(t)]
        return self._conv_features(self.triple_conv, rows)

    def quintuple_forward(self, h, r, t, k, v):
        """
        Features of the triple extended by one qualifier, (h, r, t, k, v).
        Output has the same shape as the output of triple_forward.
        """
        rows = [
            self.entity_emb(h),
            self.relation_emb(r),
            self.entity_emb(t),
            self.relation_emb(k),
            self.entity_emb(v),
        ]
        return self._conv_features(self.quintuple_conv, rows)

    def forward(self, h, r, t, kv_pairs):
        """
        Score a hyper-relational fact that carries N qualifiers.

        h, r, t  => ID tensors of shape [batch_size]
        kv_pairs => list of (k_id, v_id) tensor pairs, each [batch_size];
                    an empty list scores the bare triple.

        Returns the score tensor of shape [batch_size, 1].
        """
        merged_feat = self.triple_forward(h, r, t)  # [B, F]

        if len(kv_pairs) > 0:
            # One feature vector per qualifier, stacked to [N, B, F].
            per_pair = torch.stack(
                [self.quintuple_forward(h, r, t, k_id, v_id) for (k_id, v_id) in kv_pairs],
                dim=0,
            )
            # Element-wise min of the triple features with each qualifier's
            # features, followed by a min over the N qualifiers => [B, F].
            triple_tiled = merged_feat.unsqueeze(0).expand(per_pair.size(0), -1, -1)
            pairwise = torch.min(triple_tiled, per_pair)
            merged_feat = pairwise.min(dim=0).values

        # Linear projection => [B, 1]
        return self.proj(merged_feat)



## Example usage (training skeleton + top-K condition inference)

In [8]:
def main(hyperfacts_path="hyperfacts.json",
         conditions_path="conditions.json",
         num_epochs=1,
         query_drugA="Temazepam",
         query_drugC="sildenafil",
         query_relation="interactWith",
         attribute_key="adverseEvent",
         top_k=5,
         seed=None):
    """
    End-to-end demo: load hyperfacts, train a HINGEModel with one negative
    sample per fact, then print the best-scoring conditions for one drug pair.

    Every parameter is optional; main() with no arguments runs this cell's
    demo configuration.

    Args:
        hyperfacts_path: JSON file read by load_hyperfacts.
        conditions_path: JSON file with the candidate condition strings.
        num_epochs: Number of passes over the facts.
        query_drugA: Head entity of the query triple.
        query_drugC: Tail entity of the query triple.
        query_relation: Relation of the query triple.
        attribute_key: Qualifier key under which each candidate condition is
            attached to the query triple.
        top_k: Number of best-scoring conditions to print.
        seed: If not None, seeds `random` and `torch` so that initial weights
            and negative samples are reproducible. None leaves the global RNG
            state untouched.

    Returns:
        None. Progress and predictions are printed.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # ------------------------------------------------
    # A) Load data
    # ------------------------------------------------
    facts = load_hyperfacts(hyperfacts_path)
    all_conditions = load_conditions_from_json(conditions_path)

    # ------------------------------------------------
    # B) Build vocab
    # ------------------------------------------------
    entity2id, relation2id = build_vocab_and_mappings(facts)
    print(f"Num entities: {len(entity2id)}, Num relations: {len(relation2id)}")

    # ------------------------------------------------
    # C) Instantiate model
    # ------------------------------------------------
    model = HINGEModel(num_entities=len(entity2id),
                       num_relations=len(relation2id),
                       embedding_dim=100,
                       num_filters=400)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Candidate pools for negative sampling. They are loop-invariant, so they
    # are built once here instead of on every training step.
    entity_pool = list(entity2id.values())
    relation_pool = list(relation2id.values())

    def _as_id(index):
        """Wrap one integer ID as a LongTensor of shape [1] (batch of one)."""
        return torch.tensor([index], dtype=torch.long)

    def _draw_other(pool, true_index):
        """Uniformly pick an ID from pool that differs from true_index (None if impossible)."""
        if len(pool) < 2:
            return None
        picked = random.choice(pool)
        # Pool IDs are distinct, so with >= 2 entries this loop terminates.
        while picked == true_index:
            picked = random.choice(pool)
        return picked

    def _corrupt(r_index, t_index):
        """
        Return (relation_id, tail_id) with exactly one slot replaced by a
        different ID. The slot is chosen with probability 0.5 each; when the
        chosen slot has no alternative the other slot is tried. Returns None
        when neither slot can be corrupted.
        """
        tail_first = random.random() < 0.5
        for corrupt_tail in (tail_first, not tail_first):
            if corrupt_tail:
                other = _draw_other(entity_pool, t_index)
                if other is not None:
                    return r_index, other
            else:
                other = _draw_other(relation_pool, r_index)
                if other is not None:
                    return other, t_index
        return None

    # ------------------------------------------------
    # D) Training skeleton (not a full example)
    # ------------------------------------------------
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        for (h_str, r_str, t_str, kv_strs) in facts:
            r_index = relation2id[r_str]
            t_index = entity2id[t_str]

            # Sample the negative first: the corrupted slot must differ from
            # the true one, otherwise the "negative" is the positive fact and
            # the two loss terms cancel each other's gradient.
            corrupted = _corrupt(r_index, t_index)
            if corrupted is None:
                # Degenerate vocabulary (one entity and one relation): no negative exists.
                continue
            neg_r_index, neg_t_index = corrupted

            h_id = _as_id(entity2id[h_str])
            r_id = _as_id(r_index)
            t_id = _as_id(t_index)
            kv_pairs_ids = [
                (_as_id(relation2id[k_str]), _as_id(entity2id[v_str]))
                for (k_str, v_str) in kv_strs
            ]

            pos_score = model(h_id, r_id, t_id, kv_pairs_ids)
            neg_score = model(h_id, _as_id(neg_r_index), _as_id(neg_t_index), kv_pairs_ids)

            # Softplus loss: log(1 + exp(-pos_score)) + log(1 + exp(neg_score)).
            # The scores have shape [1, 1]; sum() turns the loss into a scalar.
            loss = (F.softplus(-pos_score) + F.softplus(neg_score)).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch}, total_loss = {total_loss:.4f}")

    # ------------------------------------------------
    # E) Inference: predict top-K conditions
    # ------------------------------------------------
    model.eval()

    if (query_drugA not in entity2id) or (query_drugC not in entity2id) or (query_relation not in relation2id):
        print("One of these is not in the vocab. Please fix or handle OOV.")
        return

    h_id = _as_id(entity2id[query_drugA])
    r_id = _as_id(relation2id[query_relation])
    t_id = _as_id(entity2id[query_drugC])

    # Each candidate condition is scored as the qualifier (attribute_key, condition)
    # attached to the query triple.
    if attribute_key not in relation2id:
        print(f"Relation '{attribute_key}' not in vocab. Please fix.")
        return
    key_id = _as_id(relation2id[attribute_key])

    candidate_scores = []
    with torch.no_grad():
        for cond_str in all_conditions:
            # Conditions never seen in the hyperfacts have no embedding; skip them.
            if cond_str not in entity2id:
                continue
            cond_id = _as_id(entity2id[cond_str])
            score = model(h_id, r_id, t_id, [(key_id, cond_id)])
            candidate_scores.append((cond_str, score.item()))

    # Highest score first, then keep the top_k best candidates.
    candidate_scores.sort(key=lambda x: x[1], reverse=True)
    best = candidate_scores[:top_k]

    print(f"\nPredicted hyper-relations for ({query_drugA}, {query_drugC}):")
    for cond_str, sc in best:
        print(f"  - {query_drugA} + {query_drugC} -> {cond_str} (Scoring: {sc:.4f})")


# Run the demo only when this code is executed as the main module
# (not when it is imported from elsewhere).
if __name__ == "__main__":
    main()


Num entities: 4605, Num relations: 3
Epoch 0, total_loss = 12913.0746

Predicted hyper-relations for (Temazepam, sildenafil):
  - Temazepam + sildenafil -> Tremor (Scoring: 2.2029)
  - Temazepam + sildenafil -> Drug interaction (Scoring: 2.1987)
  - Temazepam + sildenafil -> Tachycardia (Scoring: 2.1964)
  - Temazepam + sildenafil -> Abdominal pain (Scoring: 2.1925)
  - Temazepam + sildenafil -> Thrombocytopenia (Scoring: 2.1908)


In [9]:
def main(hyperfacts_path="hyperfacts.json",
         conditions_path="conditions.json",
         num_epochs=100,
         query_drugA="Temazepam",
         query_drugC="Prednisone",
         query_relation="interactWith",
         attribute_key="adverseEvent",
         top_k=5,
         seed=None):
    """
    End-to-end run: load hyperfacts, train a HINGEModel with one negative
    sample per fact, then print the best-scoring conditions for one drug pair.

    Every parameter is optional; main() with no arguments runs this cell's
    configuration (100 epochs, query Temazepam + Prednisone).

    Args:
        hyperfacts_path: JSON file read by load_hyperfacts.
        conditions_path: JSON file with the candidate condition strings.
        num_epochs: Number of passes over the facts.
        query_drugA: Head entity of the query triple.
        query_drugC: Tail entity of the query triple.
        query_relation: Relation of the query triple.
        attribute_key: Qualifier key under which each candidate condition is
            attached to the query triple.
        top_k: Number of best-scoring conditions to print.
        seed: If not None, seeds `random` and `torch` so that initial weights
            and negative samples are reproducible. None leaves the global RNG
            state untouched.

    Returns:
        None. Progress and predictions are printed.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # ------------------------------------------------
    # A) Load data
    # ------------------------------------------------
    facts = load_hyperfacts(hyperfacts_path)
    all_conditions = load_conditions_from_json(conditions_path)

    # ------------------------------------------------
    # B) Build vocab
    # ------------------------------------------------
    entity2id, relation2id = build_vocab_and_mappings(facts)
    print(f"Num entities: {len(entity2id)}, Num relations: {len(relation2id)}")

    # ------------------------------------------------
    # C) Instantiate model
    # ------------------------------------------------
    model = HINGEModel(num_entities=len(entity2id),
                       num_relations=len(relation2id),
                       embedding_dim=100,
                       num_filters=400)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Candidate pools for negative sampling. They are loop-invariant, so they
    # are built once here instead of on every training step.
    entity_pool = list(entity2id.values())
    relation_pool = list(relation2id.values())

    def _as_id(index):
        """Wrap one integer ID as a LongTensor of shape [1] (batch of one)."""
        return torch.tensor([index], dtype=torch.long)

    def _draw_other(pool, true_index):
        """Uniformly pick an ID from pool that differs from true_index (None if impossible)."""
        if len(pool) < 2:
            return None
        picked = random.choice(pool)
        # Pool IDs are distinct, so with >= 2 entries this loop terminates.
        while picked == true_index:
            picked = random.choice(pool)
        return picked

    def _corrupt(r_index, t_index):
        """
        Return (relation_id, tail_id) with exactly one slot replaced by a
        different ID. The slot is chosen with probability 0.5 each; when the
        chosen slot has no alternative the other slot is tried. Returns None
        when neither slot can be corrupted.
        """
        tail_first = random.random() < 0.5
        for corrupt_tail in (tail_first, not tail_first):
            if corrupt_tail:
                other = _draw_other(entity_pool, t_index)
                if other is not None:
                    return r_index, other
            else:
                other = _draw_other(relation_pool, r_index)
                if other is not None:
                    return other, t_index
        return None

    # ------------------------------------------------
    # D) Training skeleton (not a full example)
    # ------------------------------------------------
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        for (h_str, r_str, t_str, kv_strs) in facts:
            r_index = relation2id[r_str]
            t_index = entity2id[t_str]

            # Sample the negative first: the corrupted slot must differ from
            # the true one, otherwise the "negative" is the positive fact and
            # the two loss terms cancel each other's gradient.
            corrupted = _corrupt(r_index, t_index)
            if corrupted is None:
                # Degenerate vocabulary (one entity and one relation): no negative exists.
                continue
            neg_r_index, neg_t_index = corrupted

            h_id = _as_id(entity2id[h_str])
            r_id = _as_id(r_index)
            t_id = _as_id(t_index)
            kv_pairs_ids = [
                (_as_id(relation2id[k_str]), _as_id(entity2id[v_str]))
                for (k_str, v_str) in kv_strs
            ]

            pos_score = model(h_id, r_id, t_id, kv_pairs_ids)
            neg_score = model(h_id, _as_id(neg_r_index), _as_id(neg_t_index), kv_pairs_ids)

            # Softplus loss: log(1 + exp(-pos_score)) + log(1 + exp(neg_score)).
            # The scores have shape [1, 1]; sum() turns the loss into a scalar.
            loss = (F.softplus(-pos_score) + F.softplus(neg_score)).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch}, total_loss = {total_loss:.4f}")

    # ------------------------------------------------
    # E) Inference: predict top-K conditions
    # ------------------------------------------------
    model.eval()

    if (query_drugA not in entity2id) or (query_drugC not in entity2id) or (query_relation not in relation2id):
        print("One of these is not in the vocab. Please fix or handle OOV.")
        return

    h_id = _as_id(entity2id[query_drugA])
    r_id = _as_id(relation2id[query_relation])
    t_id = _as_id(entity2id[query_drugC])

    # Each candidate condition is scored as the qualifier (attribute_key, condition)
    # attached to the query triple.
    if attribute_key not in relation2id:
        print(f"Relation '{attribute_key}' not in vocab. Please fix.")
        return
    key_id = _as_id(relation2id[attribute_key])

    candidate_scores = []
    with torch.no_grad():
        for cond_str in all_conditions:
            # Conditions never seen in the hyperfacts have no embedding; skip them.
            if cond_str not in entity2id:
                continue
            cond_id = _as_id(entity2id[cond_str])
            score = model(h_id, r_id, t_id, [(key_id, cond_id)])
            candidate_scores.append((cond_str, score.item()))

    # Highest score first, then keep the top_k best candidates.
    candidate_scores.sort(key=lambda x: x[1], reverse=True)
    best = candidate_scores[:top_k]

    print(f"\nPredicted hyper-relations for ({query_drugA}, {query_drugC}):")
    for cond_str, sc in best:
        print(f"  - {query_drugA} + {query_drugC} -> {cond_str} (Scoring: {sc:.4f})")


# Run the training/inference routine only when this code is executed as the
# main module (not when it is imported from elsewhere).
if __name__ == "__main__":
    main()


Num entities: 4605, Num relations: 3
Epoch 0, total_loss = 12780.0887

Predicted hyper-relations for (Temazepam, Prednisone):
  - Temazepam + Prednisone -> Dizziness (Scoring: 2.2133)
  - Temazepam + Prednisone -> Chest pain (Scoring: 2.2014)
  - Temazepam + Prednisone -> Pneumonia (Scoring: 2.2001)
  - Temazepam + Prednisone -> Pain (Scoring: 2.1941)
  - Temazepam + Prednisone -> Muscle spasms (Scoring: 2.1936)
