In [338]:
import numpy as np
from IPython.display import display, HTML

Slides by C. Potts describing the Rational Speech Acts (RSA) model: https://web.stanford.edu/class/linguist130a/screencasts/130a-screencast-rsa.pdf

In [339]:
def display_table(table, row_labels, col_labels):
    '''
    Render a small 2-D table as an HTML <table> in the notebook output.

    table      -- iterable of rows; every cell is rendered with str()
    row_labels -- strings shown as the header cell of each row, looked up by row position
    col_labels -- strings shown as the column headers (after an empty corner cell)
    '''
    header_row = "<tr><th></th><th>" + "</th><th>".join(col_labels) + "</th></tr>"
    body_rows = [
        "<tr><th>" + row_labels[idx] + "</th><td>" + "</td><td>".join(map(str, cells)) + "</td></tr>"
        for idx, cells in enumerate(table)
    ]
    display(HTML("<table>" + header_row + "".join(body_rows) + "</table>"))

def safelog(vals):
    '''
    Element-wise natural logarithm that maps 0 to -inf without a warning.

    NumPy normally emits a divide-by-zero RuntimeWarning for log(0). Zero
    entries are expected here (false lexicon entries), and -inf is the
    wanted result, so only that particular warning is silenced.
    '''
    quiet = np.errstate(divide='ignore')
    with quiet:
        logged = np.log(vals)
    return logged

def normalize(A):
    '''
    Rescale each row of a 2-D array so that it sums to 1.

    An all-zero row has nothing to normalize and comes out as NaN (0/0),
    exactly as plain NumPy division produces.
    '''
    row_totals = A.sum(axis=1, keepdims=True)
    return A / row_totals

In [340]:
### RSA ###

def l_lit(A, prior_probs=None):
    '''
    Literal listener: L0(r | s) proportional to [[s]](r) * P(r).

    A           -- lexicon laid out like `meanings`: rows = messages s,
                   columns = referents r, entries are truth values
    prior_probs -- prior over referents, broadcast across the columns of A;
                   when omitted, the notebook-level `prior` is used (looked up
                   at call time, so the function no longer silently depends
                   on a global defined in a later cell)

    Returns an array with the same shape as A whose rows sum to 1.
    '''
    if prior_probs is None:
        prior_probs = prior
    return normalize(A * prior_probs)

def s_prag(A, strength=None, costs=None):
    '''
    Pragmatic speaker: S1(s | r) proportional to exp(alpha * (log L0(r | s) - C(s))).

    A        -- literal-listener matrix, e.g. l_lit(meanings)
                (rows = messages, columns = referents)
    strength -- rationality parameter alpha; defaults to the notebook-level `alpha`
    costs    -- one cost per message; defaults to the notebook-level `C`

    Returns the transposed layout (rows = referents, columns = messages),
    with each row summing to 1.
    '''
    if strength is None:
        strength = alpha
    if costs is None:
        costs = C
    # Transpose so that rows index referents: the speaker conditions on r.
    # Costs broadcast over the message axis and are SUBTRACTED from the
    # utility, so that costlier messages become less likely.
    # (Before this fix they were added, which rewarded costly messages.)
    return normalize(np.exp(strength * (safelog(A.T) - costs)))

def l_prag(A, prior_probs=None):
    '''
    Pragmatic listener: L1(r | s) proportional to S1(s | r) * P(r).

    A           -- pragmatic-speaker matrix as returned by s_prag
                   (rows = referents, columns = messages)
    prior_probs -- prior over referents; when omitted, the notebook-level
                   `prior` is used (looked up at call time)

    Returns a matrix with rows = messages and columns = referents
    (the same layout as the lexicon), with each row summing to 1.
    '''
    if prior_probs is None:
        prior_probs = prior
    # Transpose back to message rows, then weight each referent column by its prior.
    return normalize(A.T * prior_probs)

In [341]:
meanings = np.array([[0, 1], [1, 1]])  # the lexicon: rows = messages s, columns = referents r
C = np.array([0, 0])  # cost of each message (one entry per row of `meanings`)
prior = np.array([0.5, 0.5])  # prior over referents (one entry per column of `meanings`)
alpha = 1.  # pragmatic reasoning strength
n_s = meanings.shape[0]  # number of messages (rows of the lexicon)
n_r = meanings.shape[1]  # number of referents (columns of the lexicon)
# NOTE(review): the saved outputs further down display a prior of [0.3, 0.7],
# so they were produced with a different value from the one set here.
# Restart & Run All to bring them back in sync.

# Sanity checks: each parameter vector must line up with the matching axis of the lexicon.
assert C.shape == (n_s,), "expected one cost per message"
assert prior.shape == (n_r,), "expected one prior probability per referent"
assert np.isclose(prior.sum(), 1.0), "prior must sum to 1"

The initial lexicon is the following (rows are messages, columns are referents; 1 means the message is true of the referent):

| | r_1 | r_2 |
|---|---|---|
|s_1 (hat) | 0| 1| 
|s_2 (glasses) | 1| 1| 

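The message 'hat' is true only of referent r_2, whereas 'glasses' is true of both referents r_1 and r_2. RSA therefore predicts that a pragmatic listener who hears 'glasses' will favour r_1: a speaker who meant r_2 would more likely have used the more informative 'hat'.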


In [342]:
# Display the lexicon with the message costs appended as an extra column,
# then the prior over referents as a one-row table.
n_messages, n_referents = meanings.shape  # rows = messages, columns = referents
message_labels = [f's_{i + 1}' for i in range(n_messages)]
referent_labels = [f'r_{i + 1}' for i in range(n_referents)]

print('### Lexicon initial ###')
display_table(np.hstack([meanings, C[:, None]]), message_labels, referent_labels + ["C"])
print('### Priors ###')
# The prior broadcasts over the referent axis, so label it p(r), not p(s).
display_table(prior[None, :], ["p(r)"], referent_labels)

### Lexicon initial ###


Unnamed: 0,r_1,r_2,C
s_1,0,1,0
s_2,1,1,0


### Priors ###


Unnamed: 0,s_1,s_2
p(s),0.3,0.7


In [343]:
print('### Literal listener ###')
n_messages, n_referents = meanings.shape  # rows = messages, columns = referents
# L0(r | s): one row per message, one column per referent, plus the message costs.
display_table(
    np.hstack([np.round(l_lit(meanings), 2), C[:, None]]),
    [f's_{i + 1}' for i in range(n_messages)],
    [f'r_{i + 1}' for i in range(n_referents)] + ["C"],
)

### Litteral listener ###


Unnamed: 0,r_1,r_2,C
s_1,0.0,1.0,0.0
s_2,0.3,0.7,0.0


In [344]:
print('### Pragmatic speaker ###')
n_messages, n_referents = meanings.shape  # rows = messages, columns = referents
# S1 is p(s | r): its rows are referents and its columns are messages (no cost column here).
display_table(
    np.round(s_prag(l_lit(meanings)), 2),
    [f'r_{i + 1}' for i in range(n_referents)],
    [f's_{i + 1}' for i in range(n_messages)],
)
print("### NB: the matrix is transposed, since the speaker's probability is p(s | r) ###")

### Pragmatic speaker ###


Unnamed: 0,s_1,s_2
r_1,0.0,1.0
r_2,0.59,0.41


### NB: the matrix is transposed, since the speakers probability is p(s | r) ###


In [345]:
meanings_upd = l_prag(s_prag(l_lit(meanings)))  # RSA reasoning: L1(r | s)

n_messages, n_referents = meanings.shape  # rows = messages, columns = referents
print('### Pragmatic listener ###')
display_table(
    np.hstack([np.round(meanings_upd, 2), C[:, None]]),
    [f's_{i + 1}' for i in range(n_messages)],
    [f'r_{i + 1}' for i in range(n_referents)] + ["C"],
)
print('### Prior ###')
# The prior is over referents, so label it p(r) with r_i columns.
display_table(prior[None, :], ["p(r)"], [f'r_{i + 1}' for i in range(n_referents)])

### Pragmatic listener ###


Unnamed: 0,r_1,r_2,C
s_1,0.0,1.0,0.0
s_2,0.51,0.49,0.0


### Prior ###


Unnamed: 0,s_1,s_2
p(s),0.3,0.7
