In [6]:
from signalling import *
import seaborn as sns 
import imageio
from matplotlib import pyplot as plt
from matplotlib import animation

In [7]:
sender, receiver = Sender(10, 10), Receiver(10, 10)
world = World(10)
past_rewards = 0
matrices = []
for epoch in range(3000):
    world_state = world.emit_state()
    message = sender.send_message(world_state)
    action = receiver.act(message)
    reward = world.evaluate_action(action)
    receiver.learn_from_feedback(reward)
    sender.learn_from_feedback(reward)
    past_rewards += reward
    if epoch % 25 == 0:
        plt.tight_layout(pad=0)
        plot = sns.heatmap(
            np.exp(receiver.action_weights)/np.exp(receiver.action_weights).sum(axis=0), 
            square=True, cbar=False, annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('messages')
        plt.ylabel('actions')
        plt.title(f'Receiver\'s weights, rollout {epoch}')
        plt.savefig(f"receiver_{epoch}.png")
        plt.clf()
        
        plot = sns.heatmap(
            np.exp(sender.message_weights)/np.exp(sender.message_weights).sum(axis=0), 
            square=True, cbar=False,annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('world states')
        plt.ylabel('messages')
        plt.title(f'Sender\'s weights, rollout {epoch}')
        plt.savefig(f"sender_{epoch}.png")
        plt.clf()
           
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, last 100 epochs reward: {past_rewards/100}')
        print(world_state, message, action, reward)
        past_rewards = 0

print("Observation to message mapping:")
print(sender.message_weights.argmax(1))
print("Message to action mapping:")
print(receiver.action_weights.argmax(1))

Epoch 0, last 100 epochs reward: -0.01
4 8 9 -1
Epoch 100, last 100 epochs reward: -0.8
1 2 9 -1
Epoch 200, last 100 epochs reward: -0.9
9 1 4 -1
Epoch 300, last 100 epochs reward: -0.76
9 1 7 -1
Epoch 400, last 100 epochs reward: -0.68
6 5 3 -1
Epoch 500, last 100 epochs reward: -0.64
4 9 1 -1
Epoch 600, last 100 epochs reward: -0.5
8 1 3 -1
Epoch 700, last 100 epochs reward: -0.3
4 4 4 1
Epoch 800, last 100 epochs reward: -0.28
7 0 5 -1
Epoch 900, last 100 epochs reward: -0.04
1 9 1 1
Epoch 1000, last 100 epochs reward: -0.02
4 4 4 1
Epoch 1100, last 100 epochs reward: 0.06
6 5 7 -1
Epoch 1200, last 100 epochs reward: -0.04
9 4 4 -1
Epoch 1300, last 100 epochs reward: 0.02
7 6 7 1
Epoch 1400, last 100 epochs reward: 0.02
4 4 4 1
Epoch 1500, last 100 epochs reward: -0.04
7 6 7 1
Epoch 1600, last 100 epochs reward: 0.02
6 3 1 -1
Epoch 1700, last 100 epochs reward: 0.02
2 2 8 -1
Epoch 1800, last 100 epochs reward: 0.04
4 4 4 1
Epoch 1900, last 100 epochs reward: 0.22
9 9 1 -1
Epoch 2000

<Figure size 432x288 with 0 Axes>

In [8]:
def make_gif(filename_base):
    images = []
    for filename in [f'{filename_base}_{i}.png' for i in range(3000) if i % 25 == 0]:
        images.append(imageio.imread(filename))
    imageio.mimsave(f'{filename_base}.gif', images)
    
make_gif('sender')
make_gif('receiver')