# Object Oriented Probabilistic Programming in Pyro

### Introduction

#### Why do we need Object Oriented Probabilistic Models?

The causal scene generation project generates 2D images from natural-language captions, where a causal model describes the relationships between the entities in the caption. The Flickr8k dataset contains captions and their corresponding images. Using these captions together with the causal model, we can reason probabilistically about questions such as the following:

**Caption: A girl is going to a wooden building.**



![girl-walking](./pictures/stock1.jpg)



**Question: Had a boy been walking, would he still be going towards a wooden building?**

![boy-walking](./pictures/stock2.jpg)

To achieve this, we first need to model real-world concepts as entities that can be expressed in an object-oriented programming language. In our case, we need classes for a girl and a boy (the entities), a building (the environment), and the actions the entities can perform, such as walking or running (the actions).

Object-oriented programming (OOP) has been around since the 70s[[1](https://en.wikipedia.org/wiki/Object-oriented_programming)]. However, OOP is not prevalent in the probabilistic programming (PP) community. At the time of this writing, we are not aware of dedicated OO probabilistic programming support in Python. In this tutorial, we attempt to model our domain entities as Python objects and then perform causal inference on them using Pyro.


In [1]:
import abc
import random
import itertools
import sqlite3
import numpy as np
import matplotlib.pyplot as plt

import torch
import pyro
import pyro.distributions as dist

pyro.set_rng_seed(101)
sqlite_file = './data/SocialNetwork'
conn = sqlite3.connect(sqlite_file)

### Social Media Class Entity Diagram
![Social Media Relationship](./pictures/social_media_diagram.jpg)

### Interpretation

1. Classes are drawn as `rectangle` boxes.
2. Attributes drawn inside an `oval` are random variables. For Person, the interest can be Politics or Sports in our case.
3. Thick arrows show probabilistic dependencies: for example, Post.Topic depends on Person.Interest.
4. Dashed lines show asymmetric dependencies. For example, Comment.match (a Boolean random variable) and Connection.Type (Family/Friend/Acquaintance) capture the asymmetric dependency between Comment.Post.poster and Comment.Commenter.


### Model Schema in Database

![Class-Diagram](./pictures/class_diagram.png)

#### NOTE: Although the diagram shows plain ids, the primary keys and foreign keys need values of the form `TABLE_NAME+FK/PK_ID`. Hence, postId in the comment table would have the value post1, post2, etc.

### Implementation

The notebook provides an example of object-oriented probabilistic programming in Pyro, using a social media network with classes for people, posts, connections and comments. It also shows one way to read data from a knowledge base, here SQLite.

The key ideas introduced in this approach are:
- Using unique IDs read from the knowledge base to create persistent trace objects across model runs.
- Creating abstract `observe` and `infer` methods for each class, which provide an interface to set the value of an object instance and to infer a particular value for an instance.
- Introducing a Universe class (here `SocialMedia`) which contains the CPTs and provides an interface between conditions on individual class instances and Pyro's condition and intervention functions.
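
The interface described in the list above can be sketched without Pyro. In this minimal illustration, `ToyPerson`, `INTEREST_VALUES` and `merge_evidence` are hypothetical stand-ins (they are not part of the notebook) for an entity class, the `values` dictionary and `SocialMedia.evidence`, and plain integers take the place of tensors.

```python
# Hypothetical, Pyro-free sketch of the observe/infer interface.
INTEREST_VALUES = ['sports', 'politics']


class ToyPerson:
    def __init__(self, iid):
        # The site name is derived from the knowledge-base ID, so it is
        # identical on every model run.
        self.trace_var = 'Interest%d' % iid

    def observe(self, topic):
        # Map the label to its index and pair it with the site name.
        return {self.trace_var: INTEREST_VALUES.index(topic)}

    def infer(self):
        # Name of the site whose posterior we want.
        return self.trace_var


def merge_evidence(conditions):
    # Combine the single-entry dictionaries into one evidence dictionary.
    merged = {}
    for c in conditions:
        merged.update(c)
    return merged


amy, brian = ToyPerson(1), ToyPerson(2)
evidence = merge_evidence([amy.observe('politics'), brian.observe('sports')])
print(evidence)     # {'Interest1': 1, 'Interest2': 0}
print(amy.infer())  # Interest1
```

Because every object only hands back site names and values, the inference code never needs to know which class a piece of evidence came from.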

In [2]:
class SocialMedia(metaclass=abc.ABCMeta):
    """
    The SocialMedia class is not one of the entities we model in our domain. It serves as a convenient 
    meta class to capture the probabilities for the random variables in the inheriting class, abstract methods 
    to be implemented by the inheriting class and utility methods shared by the inheriting class. All the domain 
    entities inherit the SocialMedia class.
    """
    
    # Values the random variables can take
    values = {
        'interest': ['sports', 'politics'], 
        'topic': ['sports', 'politics'],
        'connection': ['family', 'close friend', 'acquintance'],
        'match': ['no', 'yes'],
    }
    
    # Conditional probability tables
    cpts = {
        'interest': torch.tensor([.5, .5]), 
        'topic': torch.tensor([[.9, .1], [.1, .9]]),
        'connection': torch.tensor([1/3, 1/3, 1/3]),
        'match': torch.tensor([
            [
                [[0, 1.], [0, 1.], [0, 1.]], 
                [[.2, .8], [.5, .5], [.9, .1]],
            ],
            [
                [[.2, .8], [.5, .5], [.9, .1]],
                [[0, 1.], [0, 1.], [0, 1.]],
            ],
            
        ]),
    }
    
    # Dictionary containing the connections sampled in the current model run
    existing_connections = dict()
    
    @abc.abstractmethod
    def observe(self):
        """
        Abstract method to make sure subclasses implement the observe function which allows the user 
        to set evidence pertaining to a class instance. 
        """
        pass
    
    @abc.abstractmethod
    def infer(self):
        """
        Abstract method to make sure subclasses implement the infer function which allows the user 
        infer a particular value for a class instance.
        """
        pass
    
    @staticmethod
    def evidence(conditions):
        """
        Forms a dictionary to be passed to pyro from a list of conditions formed by calling
        the observe function of each class object with the given observation. 
        
        :param list(dict) conditions: List of dictionaries, each dictionary having key as the 
            trace object related to the class object and its observed value.
        
        :return: Dictionary where the keys are trace objects and values are the observed values. 
        :rtype: dict(str, torch.tensor)
        """
        
        cond_dict = {}
        for c in conditions:
            cond_dict[list(c.keys())[0]] = list(c.values())[0]
        
        return cond_dict
    
    @staticmethod
    def condition(model, evidence, infer, val, num_samples = 1000):
        """
        Uses pyro condition function with importance sampling to get the conditional probability 
        of a particular value for the random variable under inference. 
        
        :param func model: Probabilistic model defined with pyro sample methods.
        :param dict(str, torch.tensor) evidence: Dictionary with trace objects and their observed values.
        :param str infer: Trace object which needs to be inferred.
        :param int val: Value of the trace object for which the probabilities are required.
        :param int num_samples: Number of samples to run the inference algorithm.
        
        :return: Estimated probability of the trace object taking the value provided.
        :rtype: float
        """
        
        conditioned_model = pyro.condition(model, data = evidence)
        posterior = pyro.infer.Importance(conditioned_model, num_samples=num_samples).run()
        marginal = pyro.infer.EmpiricalMarginal(posterior, infer)
        samples = np.array([marginal().item() for _ in range(num_samples)])
        
        return sum([1 for s in samples if s.item() == val])/num_samples
    
    @staticmethod
    def intervention(model, evidence, infer, val, num_samples = 1000):
        """
        Uses pyro do function with importance sampling to get the intervention probability 
        of a particular value for the random variable under inference.
        
        :param func model: Probabilistic model defined with pyro sample methods.
        :param dict(str, torch.tensor) evidence: Dictionary with trace objects and their observed values.
        :param str infer: Trace object which needs to be inferred.
        :param int val: Value of the trace object for which the probabilities are required.
        :param int num_samples: Number of samples to run the inference algorithm.
        
        :return: Estimated probability of the trace object taking the value provided.
        :rtype: float
        """
        
        intervention_model = pyro.do(model, data = evidence)
        posterior = pyro.infer.Importance(intervention_model, num_samples=num_samples).run()
        marginal = pyro.infer.EmpiricalMarginal(posterior, infer)
        samples = np.array([marginal().item() for _ in range(num_samples)])
        
        return sum([1 for s in samples if s.item() == val])/num_samples
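
The `match` table is indexed as `[commenter interest][post topic][connection]`, and its last axis holds `[P(no), P(yes)]`. As a rough illustration of that lookup, the sketch below copies the same numbers into nested lists; `MATCH`, `VALUES` and `p_match_yes` are hypothetical names used only here, and the spelling `'acquintance'` is kept to match the label used in the class above.

```python
VALUES = {
    'interest': ['sports', 'politics'],
    'topic': ['sports', 'politics'],
    'connection': ['family', 'close friend', 'acquintance'],
}

# Same numbers as cpts['match'], as nested lists:
# MATCH[commenter_interest][post_topic][connection] -> [P(no), P(yes)]
MATCH = [
    [
        [[0, 1.], [0, 1.], [0, 1.]],
        [[.2, .8], [.5, .5], [.9, .1]],
    ],
    [
        [[.2, .8], [.5, .5], [.9, .1]],
        [[0, 1.], [0, 1.], [0, 1.]],
    ],
]


def p_match_yes(interest, topic, connection):
    # Translate the labels to indices, then read P(match = yes).
    i = VALUES['interest'].index(interest)
    t = VALUES['topic'].index(topic)
    c = VALUES['connection'].index(connection)
    return MATCH[i][t][c][1]


print(p_match_yes('sports', 'sports', 'acquintance'))    # 1.0
print(p_match_yes('sports', 'politics', 'family'))       # 0.8
print(p_match_yes('politics', 'sports', 'acquintance'))  # 0.1
```

When the commenter's interest equals the post's topic the comment is certain; otherwise the connection type decides, with family most likely to comment.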

In [3]:
class Person(SocialMedia):
    """
    Defines a person on the social media. 
    
    Attributes:
    - interest: Sampled from uniform distribution over [sports, politics]
    """
    
    def __init__(self, iid):
        """
        Creates the trace object for the person and also samples the person's interest
        using pyro.
        
        :param int iid: The unique ID obtained from the knowledge base.
        """
        self.iid = iid
        
        self.trace_var = 'Interest%d' % (self.iid)
        self.interest = pyro.sample(self.trace_var, dist.Categorical(self.cpts['interest']))
    
    def observe(self, topic):
        """
        Overrides the :func: `~SocialMedia.observe~` abstract method to set the trace object
        of the person's interest to a particular topic.
        
        :param str topic: The observed topic for the trace object.
        
        :return: Dictionary containing the trace variable as key and the topic index as value.
        :rtype: dict(str, torch.tensor)
        """
        topic_idx = self.values['interest'].index(topic)
        return {self.trace_var: torch.tensor(topic_idx)}
    
    def infer(self):
        """
        Overrides the :func: `~SocialMedia.infer~` abstract method to return the trace object
        of the person's interest.
        
        :return: The trace object related to the person's interest. 
        :rtype: str
        """
        return self.trace_var

In [4]:
class Post(SocialMedia):
    """
    Defines a post on the social media. 
    
    Attributes:
    - poster: The person who posted the post
    - topic: Sampled using pyro based on cpt defined in :class: `~SocialMedia~` based on 
             poster's interest. 
    """
    
    def __init__(self, poster, iid):
        """
        Creates the trace object for the post and samples the post's topic 
        based on the poster's interest. 
        
        :param Person poster: Person who posted the post. 
        :param int iid: The unique ID obtained from the knowledge base.
        """
        self.iid = iid
        self.poster = poster
        
        self.trace_var = 'Topic%d' % (self.iid)
        self.topic = pyro.sample(self.trace_var, dist.Categorical(self.cpts['topic'][self.poster.interest]))
    
    def observe(self, topic):
        """
        Overrides the :func: `~SocialMedia.observe~` abstract method to set the trace object
        of the post's topic to a particular topic.
        
        :param str topic: The observed topic for the trace object.
        
        :return: Dictionary containing the trace variable as key and the topic index as value.
        :rtype: dict(str, torch.tensor)
        """
        topic_idx = self.values['topic'].index(topic)
        return {self.trace_var: torch.tensor(topic_idx)}
    
    def infer(self):
        """
        Overrides the :func: `~SocialMedia.infer~` abstract method to return the trace object
        of the post's topic.
        
        :return: The trace object related to the post's topic. 
        :rtype: str
        """
        return self.trace_var

In [5]:
class Connection(SocialMedia):
    """
    Defines the connection between two people on the social network.
    
    Attributes:
    - Person1: First person in the connection
    - Person2: Second person in the connection
    - connection: Connection type between the two people which is sampled from a uniform distribution
                containing ['family', 'close friend', 'acquintance'] using pyro
    """
    
    def __init__(self, person1, person2):
        """
        Creates trace object for connection between two people on the social network. Samples the 
        connection type from a uniform distribution using pyro. 
        
        - The connections in this implementation are asymmetric, i.e. Connection(person1, person2) is not the 
        same as Connection(person2, person1).
        - A connection is only sampled once during a single model run. Suppose person1 has two posts and person2 
        comments on both of them: we don't want to sample Connection(person1, person2) twice, because that could 
        yield two different connection types between person1 and person2, which doesn't make sense in the real world. 
        After a connection is sampled for the first time, it is stored in a dictionary and retrieved from it 
        whenever it is needed thereafter. 
        
        :param Person person1: First person involved in the connection
        :param Person person2: Second person involved in the connection
        """
        self.person1 = person1
        self.person2 = person2 
        
        self.trace_var = 'Connection%d%d' % (self.person1.iid, self.person2.iid)
        
        if (self.person1.iid, self.person2.iid) not in self.existing_connections.keys():
            self.connection = pyro.sample(self.trace_var, dist.Categorical(self.cpts['connection']))
            self.existing_connections[(person1.iid, person2.iid)] = dict()
            self.existing_connections[(person1.iid, person2.iid)]['connection'] = self.connection
            self.existing_connections[(person1.iid, person2.iid)]['trace_var'] = self.trace_var
        else:
            self.connection = self.existing_connections[(person1.iid, person2.iid)]['connection']
        
    def observe(self, connection):
        """
        Overrides the :func: `~SocialMedia.observe~` abstract method to set the trace object
        of the connection to a particular connection type.
        
        :param str connection: Type of connection
        
        :return: Dictionary containing the trace variable as key and the connection index as value.
        :rtype: dict(str, torch.tensor)
        """
        connection_idx = self.values['connection'].index(connection)
        return {self.existing_connections[(self.person1.iid, self.person2.iid)]['trace_var']: \
                torch.tensor(connection_idx)}
    
    def infer(self):
        """
        Overrides the :func: `~SocialMedia.infer~` abstract method to return the trace object
        of the connection type
        
        :return: The trace object related to the connection type 
        :rtype: str
        """
        return self.existing_connections[(self.person1.iid, self.person2.iid)]['trace_var']
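
The sample-once behaviour of `Connection` is essentially memoization keyed on the ordered pair of person IDs. The following Pyro-free sketch illustrates the idea under that reading; `get_connection` is a hypothetical helper, the module-level dictionary stands in for `SocialMedia.existing_connections`, and `random.Random.choice` replaces the `pyro.sample` call.

```python
import random

# Hypothetical stand-in for SocialMedia.existing_connections.
existing_connections = {}
CONNECTION_TYPES = ['family', 'close friend', 'acquintance']


def get_connection(iid1, iid2, rng):
    # The key is an ordered pair, so (1, 2) and (2, 1) are separate entries.
    key = (iid1, iid2)
    if key not in existing_connections:
        existing_connections[key] = rng.choice(CONNECTION_TYPES)
    return existing_connections[key]


rng = random.Random(0)
first = get_connection(1, 2, rng)
second = get_connection(1, 2, rng)   # cached: no new draw
reverse = get_connection(2, 1, rng)  # separate draw for the reverse direction
print(first == second)               # True
print(sorted(existing_connections))  # [(1, 2), (2, 1)]
```

Clearing the dictionary at the start of each model run, as `model()` does, is what lets the connection be resampled in the next run.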

In [6]:
class Comment(SocialMedia):
    """
    Defines a comment on a post on the social media. 
    
    Attributes:
    - post: The post to which the comment is related
    - commenter: The person who commented on the post
    - connection: The connection type between the poster and commenter
    - match: Whether the commenter comments on the post or not. Sampled using cpt present in the :class: `~SocialMedia~`
            class using the commenter's interest, post's topic and the connection between the commenter 
            and poster
    """
    
    def __init__(self, post, commenter, iid):
        """
        Creates the trace object for comment and samples whether the commenter 
        will comment on the post or not.
        
        Creating an instance of the Comment class does not by itself mean that 
        the comment exists. The `match` attribute of the class defines whether 
        the comment exists or not. We create these objects so we 
        can run analysis on the probabilities of someone commenting/not commenting by changing 
        the value of `match` to either 1 or 0.
        
        :param Post post: The post to which the comment is related.
        :param Person commenter: The person who commented on the post.
        :param int iid: The unique ID obtained from the knowledge base.
        """
        self.iid = iid
        self.post = post
        self.commenter = commenter
        self.trace_var = 'Comment%d' % (self.iid)
        
        self.connection = Connection(self.post.poster, self.commenter).connection
        self.match = pyro.sample(self.trace_var, \
                                 dist.Categorical(self.cpts['match'][self.commenter.interest]\
                                                  [self.post.topic][self.connection]))
    
    def observe(self, match):
        """
        Overrides the :func: `~SocialMedia.observe~` abstract method to set the trace object
        of the comment to whether it is a match or not.
        
        :param int match: 1 if there is a match, 0 if not.
        
        :return: Dictionary containing the trace variable as key and match as value.
        :rtype: dict(str, torch.tensor)
        """
        return {self.trace_var: torch.tensor(match)}
    
    def infer(self):
        """
        Overrides the :func: `~SocialMedia.infer~` abstract method to return the trace object
        of the comment's match.
        
        :return: The trace object related to the comment's match. 
        :rtype: str
        """
        return self.trace_var

In [7]:
def model():
    """
    Defines the pyro model. It creates objects for each of the classes in the social media based on the 
    knowledge base. 
    
    - The `SocialMedia.existing_connections` are cleared each run as they are used to store connections for a 
    single run and need to be cleared so they can be sampled again for the next run. 
    - Connections are created between all ordered pairs of people on the social media. Each one is only sampled once 
    during the model run and is thereafter read from the dictionary `SocialMedia.existing_connections`. 
    """
    SocialMedia.existing_connections = dict()
    
    persons = conn.execute("SELECT * from Persons")
    for person in persons:
        globals()[person[1]] = Person(int(person[0]))
    
    persons = conn.execute("SELECT * from Persons")
    for pp in itertools.permutations([p[1] for p in persons], 2):
        Connection(globals()[pp[0]], globals()[pp[1]])

    posts = conn.execute("SELECT * from Posts")
    for post in posts:
        globals()[post[1]] = Post(globals()[post[2]], int(post[0]))

    comments = conn.execute('SELECT * from Comments')
    for comment in comments:
        globals()[comment[1]] = Comment(globals()[comment[2]], globals()[comment[3]], int(comment[0]))
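
Since `model()` indexes each row by position, the tables need a particular column order: a numeric ID first, then the object's name, then the names of the objects it refers to. The sketch below builds an in-memory database with that layout. The column names and the handful of rows are assumptions made for illustration (the actual sample data is only shown as an image further down), and `demo_conn` is a hypothetical name chosen to avoid clashing with the notebook's `conn`.

```python
import sqlite3

# In-memory database with hypothetical column names; only the column
# order matters to model(), which indexes rows by position.
demo_conn = sqlite3.connect(':memory:')
demo_conn.executescript("""
    CREATE TABLE Persons  (id INTEGER, name TEXT);
    CREATE TABLE Posts    (id INTEGER, name TEXT, poster TEXT);
    CREATE TABLE Comments (id INTEGER, name TEXT, post TEXT, commenter TEXT);
""")
demo_conn.executemany("INSERT INTO Persons VALUES (?, ?)",
                      [(1, 'Amy'), (2, 'Brian'), (3, 'Cheryl')])
demo_conn.execute("INSERT INTO Posts VALUES (1, 'Post1', 'Amy')")
demo_conn.execute("INSERT INTO Comments VALUES (1, 'Comment1', 'Post1', 'Brian')")

# Unpack a row the same way model() reads comment[0] .. comment[3].
for iid, name, post_name, commenter_name in demo_conn.execute("SELECT * from Comments"):
    print(int(iid), name, post_name, commenter_name)
# 1 Comment1 Post1 Brian
```

The name columns double as Python variable names in `model()`, which is why objects such as `Post1` and `Amy` can be referenced directly in the cells that follow.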

In [8]:
model()

### Sample Data in Database

![sample-data](./pictures/table-values.jpg)

### DAG Representation

![dag-representation](./pictures/dag.png)

## Conditioning

In [9]:
evidence = SocialMedia.evidence([Post1.observe('politics'), \
                                 Post2.observe('sports'), \
                                 Post3.observe('politics'), 
                                Comment1.observe(1), 
                                Comment2.observe(1), 
                                Comment3.observe(1), 
                                Comment4.observe(1)])

In [10]:
print("Probability Amy's interest is politics = %.3f" % (SocialMedia.condition(model, evidence, Amy.infer(), 1)))
print("Probability Brian's interest is politics = %.3f" % (SocialMedia.condition(model, evidence, Brian.infer(), 1)))
print("Probability Cheryl's interest is politics = %.3f" % (SocialMedia.condition(model, evidence, Cheryl.infer(), 1)))
print("Probability Brian is Amy's family = %.3f" % (SocialMedia.condition(model, evidence, Connection(Amy, Brian).infer(), 0)))
print("Probability Cheryl is Amy's family = %.3f" % (SocialMedia.condition(model, evidence, Connection(Amy, Cheryl).infer(), 0)))
print("Probability Cheryl is Brian's family = %.3f" % (SocialMedia.condition(model, evidence, Connection(Brian, Cheryl).infer(), 0)))

Probability Amy's interest is politics = 0.977
Probability Brian's interest is politics = 0.175
Probability Cheryl's interest is politics = 0.788
Probability Brian is Amy's family = 0.604
Probability Cheryl is Amy's family = 0.389
Probability Cheryl is Brian's family = 0.295


#### Inference

- Amy is most likely interested in politics because she’s posted twice on politics. 
- Cheryl is also likely to be interested in politics, even though she hasn’t posted, because she’s commented on two of Amy’s posts on politics.
- Brian, on the other hand, posted on sports, so he’s probably not interested in politics. 
- Since Brian is probably not interested in politics yet commented on one of Amy’s posts on politics, he’s likely to be Amy’s family. 
- Finally, because Cheryl never commented on Brian’s posts, the probability that Cheryl is Brian’s family stays close to its prior of 1/3 (the sampled estimate is 0.295).
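
As a rough cross-check of the first and third bullets, the posterior over a person's interest can be computed exactly when only the post evidence is used. The sketch below assumes, following the bullets above, that Amy wrote the two politics posts and Brian wrote the sports post; it ignores the comment evidence, so its numbers will not coincide with the sampled estimates. `interest_posterior` is a hypothetical helper.

```python
INTEREST_PRIOR = [0.5, 0.5]            # [sports, politics]
TOPIC_CPT = [[0.9, 0.1], [0.1, 0.9]]   # TOPIC_CPT[interest][topic]


def interest_posterior(observed_topics):
    # Bayes' rule by enumeration over the two possible interests.
    joint = []
    for interest, prior in enumerate(INTEREST_PRIOR):
        likelihood = 1.0
        for topic in observed_topics:
            likelihood *= TOPIC_CPT[interest][topic]
        joint.append(prior * likelihood)
    total = sum(joint)
    return [j / total for j in joint]


SPORTS, POLITICS = 0, 1
amy = interest_posterior([POLITICS, POLITICS])
brian = interest_posterior([SPORTS])
print(round(amy[POLITICS], 3))    # 0.988
print(round(brian[POLITICS], 3))  # 0.1
```

The sampled value for Brian (0.175) is higher than this posts-only figure, which is consistent with his comment on a politics post being additional evidence.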

## Intervention

In [11]:
evidence2 = SocialMedia.evidence([Post1.observe('politics'), \
                                 Post2.observe('sports'), \
                                 Post3.observe('politics')])

In [12]:
print("Conditional probability Brian will comment on Amy's first post = %.3f" %(SocialMedia.condition(model, evidence2, Comment1.infer(), 1)))
print("Intervention probability Brian will comment on Amy's first post = %.3f" %(SocialMedia.intervention(model, evidence2, Comment1.infer(), 1)))

Conditional probability Brian will comment on Amy's first post = 0.498
Intervention probability Brian will comment on Amy's first post = 0.703


#### Inference

- When we condition, it is highly likely that Amy is interested in "Politics" and Brian is interested in "Sports". The probability of Brian commenting therefore depends mainly on their connection type:
$$P(\text{Brian commenting on Post1} \mid \text{evidence2}) \approx (0.33 \times 0.8) + (0.33 \times 0.5) + (0.34 \times 0.1) = 0.463$$

- When we apply the do operation instead, we remove the effect of each person's interest on the topic of their post, so the fixed topics tell us nothing about Amy's or Brian's interest. Brian's interest keeps its 0.5/0.5 prior: with probability 0.5 it matches the post's topic and he comments with certainty, otherwise the connection type decides:
$$P(\text{Brian commenting on Post1} \mid do(\text{evidence2})) \approx (0.5 \times 1) + 0.5 \times (0.33 \times 0.8 + 0.33 \times 0.5 + 0.34 \times 0.1) = 0.731$$
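
The two approximations can be reproduced with a few lines of arithmetic. This sketch uses exact thirds for the connection prior instead of the rounded 0.33/0.33/0.34, so the results differ slightly in the third decimal; the variable names are hypothetical.

```python
# P(match = yes) per connection type when interest and topic differ
# (family, close friend, acquaintance), read from the match CPT.
P_YES_MISMATCH = [0.8, 0.5, 0.1]
P_CONNECTION = [1 / 3, 1 / 3, 1 / 3]

# Average over the connection type for a mismatching commenter.
p_mismatch = sum(pc * py for pc, py in zip(P_CONNECTION, P_YES_MISMATCH))

# Conditioning approximation: Brian is treated as certainly interested in sports.
p_condition = p_mismatch

# Intervention: Brian's interest keeps its 0.5/0.5 prior, and a matching
# interest always leads to a comment.
p_do = 0.5 * 1.0 + 0.5 * p_mismatch

print(round(p_condition, 3))  # 0.467
print(round(p_do, 3))         # 0.733
```

Both figures are approximations of the sampled estimates above (0.498 and 0.703), which also carry Monte Carlo noise from the 1000 importance samples.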