# Heterogeneous Information Network - Link Prediction

![image.png](./media/linkprediction31.png)

In [18]:
import random
import math
import time

from typing import Generator, List, Dict

from itertools import islice, chain
from collections import defaultdict

import numpy as np
import tabulate

## Graph Classes

In [2]:
class Walk:
    """
    A class representing a single walk in the graph.

    A walk is an ordered list of nodes. It is created empty and the caller
    assigns ``path`` afterwards.
    """
    # Number of decimal places used when rendering the score as text.
    DECIMALS: int = 4

    def __init__(self) -> None:
        """
        Constructing a Walk with an empty path.
        """
        self.path: List['Node'] = []

    @property
    def score(self) -> float:
        """
        Average of the ``score`` of every node on the path.

        An empty path scores 0.0 rather than raising ZeroDivisionError, so a
        freshly constructed walk can always be scored and printed.
        """
        if not self.path:
            return 0.0
        return sum(node.score for node in self.path) / len(self.path)

    def __str__(self) -> str:
        """
        String serializer.

        Renders the rounded walk score followed by every node on the path,
        joined with ``' - '``.
        """
        assert self.DECIMALS > 0
        return ' - '.join([
            f'Link [{round(self.score, self.DECIMALS)}]',
            *(str(node) for node in self.path),
        ])

In [3]:
class Node:
    """
    An entity in RCP.

    Every node is created with random attributes and is one of four kinds:
    person, company, investment or relationship. The kind decides how the
    node is scored and which kinds of neighbours it generates.
    """
    # The four entity kinds.
    PERSON: str = 'PERSON'
    COMPANY: str = 'COMPANY'
    INVESTMENT: str = 'INVESTMENT'
    RELATIONSHIP: str = 'RELATIONSHIP'
    # Despite the names, these are used as the UPPER bound of a uniform
    # random draw for the number of generated neighbours (companies use
    # ten times the bound).
    AVERAGE_RELATIONSHIPS: int = 3000
    AVERAGE_INVESTMENTS: int = 300
    # Weights of relationship quality, priority score and popularity
    # in the weighted-average score.
    RQ_WEIGHT: float = 5.0
    PS_WEIGHT: float = 2.0
    JOB_WEIGHT: float = 8.0
    # Divisors applied to log(size) for companies / investments.
    COMPANY_NORMALIZATION_FACTOR: float = 20.0
    INVESTMENT_NORMALIZATION_FACTOR: float = 10.0
    # Number of decimal places used when rendering scores as text.
    DECIMALS: int = 4
    # Number of persons produced by get_contacts_list(). Defined here so the
    # method works without the attribute being patched in from outside.
    CONTACTS_LIST_SIZE: int = 100

    def __init__(self) -> None:
        """
        Entity constructor.

        Builds a PERSON with random attributes; the ``get_random_*``
        factories overwrite ``type`` afterwards.
        """
        self.UUID: int = random.randint(1, 10000000)
        self.type: str = self.PERSON
        self.priority_score: float = random.randint(0, 3)
        self.relationship_quality: float = random.randint(0, 3)
        self.size: float = random.randint(5, 500000)
        self.popularity: float = random.uniform(0, 1)
        # Child nodes keyed by node id; filled in by the graph search.
        self.out_paths: dict = {}
        # Both caches stay None until first read (see the properties below).
        self._irq: float = None
        self._score: float = None
        # NOTE: `job` is generated but is not read by the scoring code.
        self.job: str = random.choice([
            'employee',
            'consultant',
            'advisor',
            'board_member',
            'partner',
            'investor',
        ])

    def __str__(self) -> str:
        """
        Function that generates a pretty string for a given entity.

        The format is ``<id> [<score>] [<irq>]`` with both numbers rounded
        to ``DECIMALS`` places.
        """
        assert self.DECIMALS > 0
        return ' '.join([
            f'{self.id}',
            f'[{round(self.score, self.DECIMALS)}]',
            f'[{round(self.irq, self.DECIMALS)}]',
        ])

    def walks(self) -> Generator:
        """
        Flattens the tree stored in ``out_paths``.

        Yields one Walk per root-to-leaf combination; a node without
        children yields a single walk containing only itself.
        """
        if self.out_paths:
            for subnode in self.out_paths.values():
                for subwalk in subnode.walks():
                    walk: Walk = Walk()
                    walk.path = [self, *subwalk.path]
                    yield walk
        else:
            walk: Walk = Walk()
            walk.path = [self, ]
            yield walk

    @property
    def id(self) -> str:
        """
        Identifier built from the entity type and its UUID.
        """
        return f'{self.type}#{self.UUID}'

    def is_person(self) -> bool:
        """
        Returns True if the entity is a Person.
        """
        return self.type == self.PERSON

    def is_company(self) -> bool:
        """
        Returns True if the entity is a Company.
        """
        return self.type == self.COMPANY

    def is_investment(self) -> bool:
        """
        Returns True if the entity is an Investment.
        """
        return self.type == self.INVESTMENT

    def is_relationship(self) -> bool:
        """
        Returns True if the entity is a Relationship.
        """
        return self.type == self.RELATIONSHIP

    def trim(self, value: float) -> float:
        """
        Function that corrects scores and ensures that
        they are always contained between 0 and 1.
        """
        return min(max(value, 0), 1)

    @property
    def irq(self) -> float:
        """
        IRQ score getter.

        Defaults to the node's own score until a value is assigned.
        """
        if self._irq is None:
            self._irq = self.score
        return self._irq

    @irq.setter
    def irq(self, value: float) -> None:
        """
        IRQ score setter.
        """
        self._irq = value

    @property
    def score(self) -> float:
        """
        Function that assigns a single score to a given entity.
        - Companies are scored as 1 - log(size) / normalization factor.
        - Investments are scored as log(size) / normalization factor.
        - Persons are scored as the weighted average of relationship
          quality and priority score.
        - Relationships add popularity (weighted by JOB_WEIGHT) to the
          person formula.

        The value is computed once and cached. For a type outside the four
        known kinds nothing is computed and None is returned.
        """
        if self._score is None:
            assert self.RQ_WEIGHT > 0
            assert self.PS_WEIGHT > 0
            assert self.JOB_WEIGHT > 0
            assert self.COMPANY_NORMALIZATION_FACTOR > 0
            assert self.INVESTMENT_NORMALIZATION_FACTOR > 0
            if self.is_person():
                assert self.priority_score <= 3
                assert self.priority_score >= 0
                assert self.relationship_quality <= 3
                assert self.relationship_quality >= 0
                # Both signals live on a 0..3 scale, hence the division by 3.
                self._score = sum([
                    self.RQ_WEIGHT * self.relationship_quality / 3,
                    self.PS_WEIGHT * self.priority_score / 3,
                ]) / sum([
                    self.RQ_WEIGHT,
                    self.PS_WEIGHT,
                ])
            elif self.is_company():
                assert self.size > 0
                # Larger companies score lower.
                self._score = self.trim(
                    1 - math.log(self.size) / self.COMPANY_NORMALIZATION_FACTOR
                )
            elif self.is_relationship():
                assert self.popularity >= 0
                assert self.popularity <= 1
                assert self.priority_score <= 3
                assert self.priority_score >= 0
                assert self.relationship_quality <= 3
                assert self.relationship_quality >= 0
                self._score = sum([
                    self.RQ_WEIGHT * self.relationship_quality / 3,
                    self.PS_WEIGHT * self.priority_score / 3,
                    self.JOB_WEIGHT * self.popularity,
                ]) / sum([
                    self.RQ_WEIGHT,
                    self.PS_WEIGHT,
                    self.JOB_WEIGHT,
                ])
            elif self.is_investment():
                assert self.size > 0
                # Larger investments score higher.
                self._score = self.trim(
                    math.log(self.size) / self.INVESTMENT_NORMALIZATION_FACTOR
                )
        return self._score

    def get_all_relationships(self) -> Generator:
        """
        Generates a random list of relationships for any given type of entity.
        1. For Persons, it generates Investments and Relationships.
        2. For Companies, it generates Investments and Relationships.
        3. For Relationships, it generates Persons and Companies.
        4. For Investments, it generates Persons and Companies.

        The neighbour counts are drawn when this method is called; the
        neighbours themselves are created lazily as the result is consumed.
        Raises NotImplementedError for an unknown entity type.
        """
        assert self.AVERAGE_RELATIONSHIPS > 0
        assert self.AVERAGE_INVESTMENTS > 0
        # Pick the factory first and call only the winner, so that no
        # throwaway node is built for the kind that was not chosen.
        endpoints: tuple = (self.get_random_person, self.get_random_company)
        if self.is_person():
            return chain(
                (
                    self.get_random_investment()
                    for _ in range(random.randint(0, self.AVERAGE_INVESTMENTS))
                ),
                (
                    self.get_random_relationship()
                    for _ in range(random.randint(0, self.AVERAGE_RELATIONSHIPS))
                )
            )
        elif self.is_company():
            return chain(
                (
                    self.get_random_investment()
                    for _ in range(random.randint(0, self.AVERAGE_INVESTMENTS * 10))
                ),
                (
                    self.get_random_relationship()
                    for _ in range(random.randint(0, self.AVERAGE_RELATIONSHIPS * 10))
                )
            )
        elif self.is_relationship():
            return (
                random.choice(endpoints)()
                for _ in range(random.randint(0, self.AVERAGE_RELATIONSHIPS))
            )
        elif self.is_investment():
            return (
                random.choice(endpoints)()
                for _ in range(random.randint(0, self.AVERAGE_INVESTMENTS))
            )
        else:
            raise NotImplementedError()

    @classmethod
    def get_random_person(cls) -> 'Node':
        """
        Generates a random Person.
        """
        node: cls = cls()
        node.type = cls.PERSON
        return node

    @classmethod
    def get_random_company(cls) -> 'Node':
        """
        Generates a random Company.
        """
        node: cls = cls()
        node.type = cls.COMPANY
        return node

    @classmethod
    def get_random_relationship(cls) -> 'Node':
        """
        Generates a random Relationship.
        """
        node: cls = cls()
        node.type = cls.RELATIONSHIP
        return node

    @classmethod
    def get_random_investment(cls) -> 'Node':
        """
        Generates a random Investment.
        """
        node: cls = cls()
        node.type = cls.INVESTMENT
        return node

    @classmethod
    def get_contacts_list(cls) -> Generator:
        """
        Generates a random list of contacts.

        Yields ``CONTACTS_LIST_SIZE`` random persons.
        """
        assert cls.CONTACTS_LIST_SIZE > 0
        for _ in range(cls.CONTACTS_LIST_SIZE):
            yield cls.get_random_person()

In [8]:
class Graph:
    """
    Class representing the social graph.

    Neighbours are obtained on demand from ``node.get_all_relationships()``
    and explored recursively up to ``max_depth``; several heuristics prune
    the candidates along the way.
    """
    # Number of candidates consumed from a generator per batch.
    MINI_BATCH_SIZE: int = 300
    # Expressed in percentage points (0-100): candidates are kept when they
    # reach the (100 - TOP_PERCENTILE)-th percentile of their batch.
    TOP_PERCENTILE: float = 0.2
    # Weight of a candidate's own IRQ when blended with its parent's IRQ.
    LEARNING_RATE: float = 0.4
    # Per-node fan-out cap: ceil(BRANCHING_FACTOR / (depth * BRANCHING_DECAY)).
    BRANCHING_FACTOR: int = 1000
    BRANCHING_DECAY: float = 10.0
    # Probability of accepting a candidate regardless of its score.
    NOVELTY_RATE: float = 0.01
    # Probability of dropping a candidate that passed every other check.
    DROP_RATE: float = 0.01
    # IRQ at or above which a candidate is always accepted.
    MIN_SCORE_ACCEPTED: float = 0.95
    # IRQ at or below which a candidate is always rejected.
    MAX_SCORE_REJECTED: float = 0.4
    # After each batch the accumulated links are trimmed to the top
    # percentile only when a uniform draw exceeds this value.
    SURVIVAL_RATE: float = 0.5

    def __init__(self) -> None:
        """
        Graph constructor.
        """
        self.max_depth: int = 5
        self.verbose: bool = False
        # Number of nodes visited so far, keyed by depth.
        self.status: defaultdict = defaultdict(int)

    def _get_mini_batch(self, nodes: Generator) -> Generator:
        """
        Function responsible for generating mini batches holding only the
        top-scoring candidates.

        Consumes ``nodes`` in chunks of ``MINI_BATCH_SIZE`` and yields, for
        each chunk, the list of nodes whose score reaches the chunk's
        (100 - TOP_PERCENTILE)-th percentile.
        """
        assert self.MINI_BATCH_SIZE > 0
        assert self.TOP_PERCENTILE > 0
        assert self.TOP_PERCENTILE < 100
        # islice() restarts from the beginning on a re-iterable such as a
        # list, which would loop forever; pin a single iterator instead.
        iterator = iter(nodes)
        while True:
            mini_batch: list = list(islice(iterator, self.MINI_BATCH_SIZE))
            if not mini_batch:
                break
            percentile: float = np.percentile([
                next_node.score
                for next_node in mini_batch
            ], (100 - self.TOP_PERCENTILE))
            yield [
                next_node
                for next_node in mini_batch
                if next_node.score >= percentile
            ]

    def _get_next_nodes(self, node: 'Node', depth: int = 1) -> Generator:
        """
        Function that finds a list of potential nodes to walk next.
        To reduce the combinatorial problem, it prunes the results using
        heuristics.

        Every candidate's IRQ is blended with the IRQ of ``node`` before the
        pruning decision. At most
        ceil(BRANCHING_FACTOR / (depth * BRANCHING_DECAY)) nodes are yielded.
        """
        assert self.BRANCHING_FACTOR > 1
        assert self.BRANCHING_DECAY > 1
        max_nodes: int = math.ceil(
            self.BRANCHING_FACTOR / (depth * self.BRANCHING_DECAY)
        )
        total: int = 0
        nodes: Generator = node.get_all_relationships()
        for mini_batch in self._get_mini_batch(nodes=nodes):
            for next_node in mini_batch:
                next_node.irq = sum([
                    self.LEARNING_RATE * next_node.irq,
                    (1 - self.LEARNING_RATE) * node.irq,
                ])
                if self._is_pruned(node=next_node):
                    continue
                yield next_node
                total += 1
                # Stop exactly at the cap (the previous `>` comparison let
                # one extra node through).
                if total >= max_nodes:
                    return

    def _is_pruned(self, node: 'Node') -> bool:
        """
        Function that defines whether or not to select a node
        as a relevant or irrelevant candidate to walk next.
        1. Randomly accepting records to provide novelty
           in the results. That way, new paths can be formed.
        2. Checking the IRQ of the candidate: at or below
           MAX_SCORE_REJECTED it is pruned, at or above
           MIN_SCORE_ACCEPTED it is kept; in between it is pruned
           when a uniform draw exceeds its IRQ.
        3. Randomly dropping some records to reduce the
           complexity of the recursive algorithm.

        Returns True when the node must be discarded.
        """
        assert self.NOVELTY_RATE > 0
        assert self.NOVELTY_RATE < 1
        assert self.DROP_RATE > 0
        assert self.DROP_RATE < 1
        assert self.MAX_SCORE_REJECTED > 0
        assert self.MAX_SCORE_REJECTED < 1
        assert self.MIN_SCORE_ACCEPTED > 0
        assert self.MIN_SCORE_ACCEPTED < 1
        assert self.MIN_SCORE_ACCEPTED > self.MAX_SCORE_REJECTED
        probability: float = random.uniform(0, 1)
        if probability < self.NOVELTY_RATE:
            return False
        if node.irq <= self.MAX_SCORE_REJECTED:
            return True
        if node.irq >= self.MIN_SCORE_ACCEPTED:
            return False
        probability = random.uniform(0, 1)
        if probability > node.irq:
            return True
        probability = random.uniform(0, 1)
        if probability < self.DROP_RATE:
            return True
        return False

    def find_links(self, nodes: Generator = None, depth: int = 1) -> Dict[str, 'Node']:
        """
        Obtains all the paths associated with a list of nodes.
        From time to time, it prints the status to stdout.

        A node is kept when it is a person or a company, or when it has at
        least one surviving child; its children are stored in ``out_paths``.
        Returns the kept nodes keyed by node id; ``nodes=None`` is treated
        as an empty input.
        """
        assert depth > 0
        assert self.max_depth >= depth
        links: Dict[str, 'Node'] = {}
        if nodes is None:
            return links
        for node in nodes:
            self.status[depth] += 1
            probability: float = random.uniform(0, 1)
            if self.verbose and probability > 0.99:
                print('Status:', dict(self.status))
            next_nodes: dict = self._find_links(
                node=node,
                depth=depth + 1,
            )
            if node.is_person() or node.is_company() or next_nodes:
                node.out_paths = next_nodes
                links[node.id] = node
        return links

    def _find_links(self, depth: int, node: 'Node') -> dict:
        """
        Obtains all the potential paths associated with a specific node.
        Randomly, it stops to prune the worst branches.

        Candidates are explored in batches of ``MINI_BATCH_SIZE``. After a
        batch, when a uniform draw exceeds ``SURVIVAL_RATE``, the links
        accumulated so far are reduced to those reaching the
        (100 - TOP_PERCENTILE)-th percentile of IRQ.
        Returns an empty dict when ``depth`` exceeds ``max_depth``.
        """
        assert depth > 0
        assert self.TOP_PERCENTILE > 0
        assert self.TOP_PERCENTILE < 100
        assert self.SURVIVAL_RATE > 0
        assert self.SURVIVAL_RATE < 1
        assert self.MINI_BATCH_SIZE > 0
        links: dict = {}
        if depth <= self.max_depth:
            next_nodes: Generator = self._get_next_nodes(node=node, depth=depth)
            while True:
                batch: list = list(islice(next_nodes, self.MINI_BATCH_SIZE))
                if not batch:
                    return links
                links.update(self.find_links(
                    nodes=batch,
                    depth=depth,
                ))
                probability: float = random.uniform(0, 1)
                if links and probability > self.SURVIVAL_RATE:
                    percentile: float = np.percentile([
                        link.irq
                        for link in links.values()
                    ], (100 - self.TOP_PERCENTILE))
                    links = {
                        link_id: link
                        for link_id, link in links.items()
                        if link.irq >= percentile
                    }
        # Past the maximum depth there is nothing to explore; always hand
        # back a dict instead of falling through with None.
        return links

## Tests

### Test 1: Finding links of degree 5 for a contact list of 100 people with an aggressive pruning strategy

In [21]:
# Entity generation and scoring parameters for this experiment.
Node.CONTACTS_LIST_SIZE: int = 100
Node.RQ_WEIGHT: float = 5.0
Node.PS_WEIGHT: float = 2.0
Node.JOB_WEIGHT: float = 8.0
Node.COMPANY_NORMALIZATION_FACTOR: float = 20.0
Node.INVESTMENT_NORMALIZATION_FACTOR: float = 11.0
Node.AVERAGE_RELATIONSHIPS: int = 100
Node.AVERAGE_INVESTMENTS: int = 50

# Search and pruning parameters. These overwrite the class-level defaults,
# so they stay in effect for every Graph created afterwards.
Graph.MINI_BATCH_SIZE: int = 300
# Percentage points: keep candidates at or above the 99.99th percentile.
Graph.TOP_PERCENTILE: float = 0.01
Graph.LEARNING_RATE: float = 0.1
Graph.BRANCHING_FACTOR: int = 100000
Graph.BRANCHING_DECAY: float = 4.0
Graph.NOVELTY_RATE: float = 0.00001
Graph.DROP_RATE: float = 0.00001
Graph.MIN_SCORE_ACCEPTED: float = 0.995
Graph.MAX_SCORE_REJECTED: float = 0.60
Graph.SURVIVAL_RATE: float = 0.4

graph: Graph = Graph()
graph.max_depth = 5
graph.verbose = False

# perf_counter() is monotonic and high resolution, which makes it the right
# clock for measuring an elapsed interval (time.time() can jump).
start: float = time.perf_counter()
links: Dict[str, Node] = graph.find_links(nodes=Node.get_contacts_list())
end: float = time.perf_counter()

In [41]:
print('Predicted Links:')
# Running sum of walk scores and number of walks seen. `total` is bound up
# front so it is defined (as 0) even when the search produced no walks.
score: float = 0
total: int = 0
walks: Generator = (
    walk
    for node in links.values()
    for walk in node.walks()
)
table: list = []
# Counting from 1 leaves `total` equal to the number of walks after the loop.
for total, walk in enumerate(walks, start=1):
    walk_score: float = walk.score
    # Show only the first 10 walks that contain both a relationship and a
    # company; the size check comes first so the path scans are skipped once
    # the table is full.
    if (
        len(table) < 10
        and any(node.is_relationship() for node in walk.path)
        and any(node.is_company() for node in walk.path)
    ):
        table.append([
            walk_score,
            *walk.path,
        ])
    score += walk_score
table: tabulate.JupyterHTMLStr = tabulate.tabulate(table, tablefmt='html')
table

Predicted Links:


0,1,2,3,4,5
0.830401,PERSON#7381945 [0.5714] [0.5714],INVESTMENT#5531634 [1] [0.6143],COMPANY#8232622 [0.5831] [0.6112],RELATIONSHIP#5404915 [0.9975] [0.6498],PERSON#2995914 [1.0] [0.6848]
0.803477,PERSON#7381945 [0.5714] [0.5714],INVESTMENT#5531634 [1] [0.6143],COMPANY#8232622 [0.5831] [0.6112],RELATIONSHIP#9467500 [0.9581] [0.6459],PERSON#7325944 [0.9048] [0.6718]
0.83771,PERSON#1833055 [0.9048] [0.9048],RELATIONSHIP#5085161 [0.891] [0.9034],PERSON#3326142 [1.0] [0.913],INVESTMENT#6752569 [1] [0.9217],COMPANY#5971017 [0.3928] [0.8688]
0.916806,PERSON#1833055 [0.9048] [0.9048],RELATIONSHIP#5085161 [0.891] [0.9034],PERSON#3326142 [1.0] [0.913],INVESTMENT#9076090 [1] [0.9217],COMPANY#5351098 [0.7883] [0.9084]
0.82446,PERSON#7413287 [0.6667] [0.6667],INVESTMENT#3017059 [1] [0.7],COMPANY#3740435 [0.5006] [0.6801],RELATIONSHIP#7650943 [0.9551] [0.7076],PERSON#587272 [1.0] [0.7368]
0.82446,PERSON#7413287 [0.6667] [0.6667],INVESTMENT#3017059 [1] [0.7],COMPANY#3740435 [0.5006] [0.6801],RELATIONSHIP#7650943 [0.9551] [0.7076],PERSON#6109930 [1.0] [0.7368]
0.82446,PERSON#7413287 [0.6667] [0.6667],INVESTMENT#3017059 [1] [0.7],COMPANY#3740435 [0.5006] [0.6801],RELATIONSHIP#7650943 [0.9551] [0.7076],PERSON#5666759 [1.0] [0.7368]
0.82446,PERSON#7413287 [0.6667] [0.6667],INVESTMENT#3017059 [1] [0.7],COMPANY#3740435 [0.5006] [0.6801],RELATIONSHIP#7650943 [0.9551] [0.7076],PERSON#9157930 [1.0] [0.7368]
0.826927,PERSON#1435767 [0.6667] [0.6667],INVESTMENT#4901247 [1] [0.7],COMPANY#4790623 [0.4705] [0.6771],RELATIONSHIP#5366584 [0.9974] [0.7091],PERSON#3718938 [1.0] [0.7382]
0.826927,PERSON#1435767 [0.6667] [0.6667],INVESTMENT#4901247 [1] [0.7],COMPANY#4790623 [0.4705] [0.6771],RELATIONSHIP#5366584 [0.9974] [0.7091],PERSON#5018285 [1.0] [0.7382]


In [38]:
# Guard the average so an empty result set reports 0.0 instead of raising
# ZeroDivisionError.
average: float = score / total if total else 0.0
print(f'Total Links: {total}')
print(f'Average Score: {round(average, 4)}')
print(f'Elapsed Time: {round(end - start, 4)}')

Total Links: 26829
Average Score: 0.9176
Elapsed Time: 14.2496


### Test 2: Finding the best 7-degree connections that my top 10 contacts can provide me

In [42]:
# Entity generation and scoring parameters for this experiment.
Node.CONTACTS_LIST_SIZE: int = 10
Node.RQ_WEIGHT: float = 5.0
Node.PS_WEIGHT: float = 2.0
Node.JOB_WEIGHT: float = 8.0
Node.COMPANY_NORMALIZATION_FACTOR: float = 20.0
Node.INVESTMENT_NORMALIZATION_FACTOR: float = 11.0
Node.AVERAGE_RELATIONSHIPS: int = 100
Node.AVERAGE_INVESTMENTS: int = 50

# Search and pruning parameters. These overwrite the class-level defaults,
# so they stay in effect for every Graph created afterwards.
Graph.MINI_BATCH_SIZE: int = 300
# Percentage points: keep candidates at or above the 99.99th percentile.
Graph.TOP_PERCENTILE: float = 0.01
Graph.LEARNING_RATE: float = 0.5
Graph.BRANCHING_FACTOR: int = 100000
Graph.BRANCHING_DECAY: float = 4.0
Graph.NOVELTY_RATE: float = 0.001
Graph.DROP_RATE: float = 0.0001
Graph.MIN_SCORE_ACCEPTED: float = 0.995
Graph.MAX_SCORE_REJECTED: float = 0.80
Graph.SURVIVAL_RATE: float = 0.8

graph: Graph = Graph()
graph.verbose = False
graph.max_depth = 7

# perf_counter() is monotonic and high resolution, which makes it the right
# clock for measuring an elapsed interval (time.time() can jump).
start: float = time.perf_counter()
links: Dict[str, Node] = graph.find_links(nodes=Node.get_contacts_list())
end: float = time.perf_counter()

In [46]:
print('Predicted Links:')
# Running sum of walk scores and number of walks seen. `total` is bound up
# front so it is defined (as 0) even when the search produced no walks.
score: float = 0
total: int = 0
walks: Generator = (
    walk
    for node in links.values()
    for walk in node.walks()
)
table: list = []
# Counting from 1 leaves `total` equal to the number of walks after the loop.
for total, walk in enumerate(walks, start=1):
    walk_score: float = walk.score
    # Show only the first 10 walks that go beyond the starting contact.
    if len(walk.path) > 1 and len(table) < 10:
        table.append([
            walk_score,
            *walk.path,
        ])
    score += walk_score
table: tabulate.JupyterHTMLStr = tabulate.tabulate(table, tablefmt='html')
table

Predicted Links:


0,1,2,3,4,5,6,7
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#1661585 [1.0] [0.9821],INVESTMENT#9217356 [1] [0.9911],PERSON#8931110 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#1661585 [1.0] [0.9821],INVESTMENT#9217356 [1] [0.9911],PERSON#2778544 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#1661585 [1.0] [0.9821],INVESTMENT#4094396 [1] [0.9911],PERSON#9486684 [1.0] [0.9955]
0.945578,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#1973016 [1] [0.9911],PERSON#1074942 [0.9048] [0.9479]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#3846577 [1] [0.9911],PERSON#5432201 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#3846577 [1] [0.9911],PERSON#7119745 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#3846577 [1] [0.9911],PERSON#1809349 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#385984 [1] [0.9911],PERSON#1651586 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#385984 [1] [0.9911],PERSON#1429475 [1.0] [0.9955]
0.959184,PERSON#2067254 [0.7143] [0.7143],INVESTMENT#5640058 [1] [0.8571],PERSON#4004944 [1.0] [0.9286],INVESTMENT#2128276 [1] [0.9643],PERSON#519321 [1.0] [0.9821],INVESTMENT#385984 [1] [0.9911],PERSON#8025144 [1.0] [0.9955]


In [44]:
# Guard the average so an empty result set reports 0.0 instead of raising
# ZeroDivisionError.
average: float = score / total if total else 0.0
print(f'Total Links: {total}')
print(f'Average Score: {round(average, 4)}')
print(f'Elapsed Time: {round(end - start, 4)}')

Total Links: 63603
Average Score: 0.9391
Elapsed Time: 34.5552
