In [1]:
import os

# Hide CUDA devices from this process; the device listing recorded further
# down shows only CPU devices.
# NOTE(review): presumably this has to be set before TensorFlow is imported.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

In [9]:
import sys

In [2]:
import numpy as np
import pandas as pd
import tensorflow as tf

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
sns.set()

# Print a header followed by every physical device TensorFlow reports,
# one device per line.
visible_devices = tf.config.list_physical_devices()
print('Physical Devices:', *visible_devices, sep='\n')

Physical Devices:
PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU')


In [3]:
from zscomm.agent import Agent
from zscomm.comm_channel import CommChannel
from zscomm.synth_teacher import SyntheticTeacher
from zscomm.data import *
from zscomm.play_game import *
from zscomm.loss import *
from zscomm.experiment import Experiment
from zscomm.meta_experiment import *
from zscomm.plot_game import plot_game
from zscomm.analysis import *

## Load Data:

In [4]:
# Notebook-wide settings.
NUM_CLASSES = 3
CHANNEL_SIZE = 5
BATCH_SIZE = 32
USE_MNIST = False  # False selects the simple card data instead of MNIST

# Pick the loader once, then unpack the (train, test) pair it returns.
TRAIN_DATA, TEST_DATA = (
    get_mnist_data if USE_MNIST else get_simple_card_data
)(num_classes=NUM_CLASSES)

In [5]:
def generate_train_batch():
    """Return one batch built from TRAIN_DATA by ``generate_batch``.

    The batch size and class count come from the notebook-level
    BATCH_SIZE and NUM_CLASSES settings.
    """
    batch_options = {
        'batch_size': BATCH_SIZE,
        'num_classes': NUM_CLASSES,
    }
    return generate_batch(TRAIN_DATA, **batch_options)


def generate_test_batch():
    """Return one batch built from TEST_DATA by ``generate_batch``.

    The batch size and class count come from the notebook-level
    BATCH_SIZE and NUM_CLASSES settings.
    """
    batch_options = {
        'batch_size': BATCH_SIZE,
        'num_classes': NUM_CLASSES,
    }
    return generate_batch(TEST_DATA, **batch_options)

# Run Experiments

In [6]:
def _create_shared_agent_experiment(first_activation, epochs, exp_kwargs):
    """Build an Experiment in which one Agent is both student and teacher.

    Args:
        first_activation: value forwarded to ``Agent(first_activation=...)``.
        epochs: forwarded to ``Experiment`` as ``max_epochs``.
        exp_kwargs: dict of extra keyword arguments for ``Experiment``. It is
            taken as a plain dict (not ``**``) so that no caller-supplied key
            can collide with this helper's own parameter names.

    Returns:
        The constructed ``Experiment``.
    """
    agent = Agent(CHANNEL_SIZE, NUM_CLASSES, first_activation=first_activation)

    # Built per call, so every experiment gets its own dict instance.
    play_params = {
        'channel_size': CHANNEL_SIZE,
        'p_mutate': 0.0
    }

    return Experiment(
        generate_train_batch, generate_test_batch,
        play_params=play_params,
        student=agent,
        teacher=agent,
        loss_fn=complete_loss_fn,
        max_epochs=epochs,
        **exp_kwargs
    )


def create_temporal_fixing_experiment(epochs=15, **exp_kwargs):
    """Create the temporal-fixing experiment.

    The shared agent is built with ``first_activation='relu'``; any extra
    keyword arguments are passed through to ``Experiment``.
    """
    return _create_shared_agent_experiment('relu', epochs, exp_kwargs)


def create_observation_fixing_experiment(epochs=15, **exp_kwargs):
    """Create the observation-fixing experiment.

    The shared agent is built with ``first_activation=None``; any extra
    keyword arguments are passed through to ``Experiment``.
    """
    return _create_shared_agent_experiment(None, epochs, exp_kwargs)

In [7]:
# Build the observation-fixing experiment with the factory defaults (epochs=15).
obs_fixing_experiment = create_observation_fixing_experiment()

In [8]:
# Run the experiment; the recorded output shows the run config, a per-epoch
# mean loss, and periodic test metrics.
obs_fixing_experiment.run()

Running experiment...
Run config:
 {'name': 'experiment', 'max_epochs': 15, 'steps_per_epoch': 50, 'epochs_optimised': 11, 'play_params': {'channel_size': 5, 'p_mutate': 0.0}, 'test_freq': 5, 'test_steps': 25, 'optimiser_config': {'name': 'RMSprop', 'learning_rate': 0.009999999776482582, 'decay': 0.0, 'rho': 0.8999999761581421, 'momentum': 0.0, 'epsilon': 1e-07, 'centered': False}, 'optimise_agents_separately': False, 'loss_fn': 'complete_loss_fn'}
Epoch 0, Time Taken (mm:ss): 0:8, Mean Loss: 3.885
Test Loss: 4.654, Ground Truth F1-Score: 0.323, Student Error: 1.107, Teacher Error: 0.547, Protocol Diversity: 0.333, Protocol Entropy: 1.6,
Epoch 1, Time Taken (mm:ss): 0:7, Mean Loss: 3.828
Epoch 2, Time Taken (mm:ss): 0:7, Mean Loss: 3.79
Epoch 3, Time Taken (mm:ss): 0:7, Mean Loss: 2.053
Epoch 4, Time Taken (mm:ss): 0:6, Mean Loss: 1.003
Epoch 5, Time Taken (mm:ss): 0:7, Mean Loss: 1.005
Test Loss: 1.0, Ground Truth F1-Score: 1.0, Student Error: 0.0, Teacher Error: 0.0, Protocol Diversi

In [9]:
# Display the experiment's `results` attribute.
# NOTE(review): the TypeError recorded for the following cell suggests this
# was None when the notebook was last run — confirm the run completed.
obs_fixing_experiment.results

In [10]:
# Show (teacher, student) responsiveness as ONE tuple, rounded to 4 places.
# The parentheses matter: without them the trailing comma followed by a
# newline made two separate expression statements, so only the student value
# would have been displayed.
# NOTE(review): the recorded TypeError means `results` was None when this cell
# last ran — presumably the run above had not finished; confirm.
(
    round(obs_fixing_experiment.results['teacher_responsiveness'], 4),
    round(obs_fixing_experiment.results['student_responsiveness'], 4),
)

TypeError: 'NoneType' object is not subscriptable

In [None]:
# Build the temporal-fixing experiment with the factory defaults (epochs=15).
temp_fixing_experiment = create_temporal_fixing_experiment()

In [None]:
# Run the temporal-fixing experiment.
temp_fixing_experiment.run()

In [None]:
# Display the experiment's `results` attribute.
temp_fixing_experiment.results

In [None]:
# (teacher, student) responsiveness of the temporal-fixing experiment,
# each rounded to 4 places.
tuple(
    round(temp_fixing_experiment.results[key], 4)
    for key in ('teacher_responsiveness', 'student_responsiveness')
)

In [None]:
# Draw one heatmap per message-map builder. The tests are re-run before each
# plot, exactly as in the two hand-written stanzas this loop replaces.
# NOTE(review): the index-based map is also labelled 'Class' on the y-axis —
# confirm that label is intended.
for build_message_map in (create_mean_class_message_map,
                          create_mean_index_message_map):
    games_played, _ = temp_fixing_experiment.run_tests()
    mean_class_message_map = build_message_map(games_played)
    sns.heatmap(mean_class_message_map, vmin=0, vmax=1)
    plt.ylabel('Class')
    plt.xlabel('Symbol')
    plt.title('Communication Protocol')
    plt.show()

In [None]:
# Draw one heatmap per message-map builder. The tests are re-run before each
# plot, exactly as in the two hand-written stanzas this loop replaces.
# NOTE(review): the index-based map is also labelled 'Class' on the y-axis —
# confirm that label is intended.
for build_message_map in (create_mean_class_message_map,
                          create_mean_index_message_map):
    games_played, _ = obs_fixing_experiment.run_tests()
    mean_class_message_map = build_message_map(games_played)
    sns.heatmap(mean_class_message_map, vmin=0, vmax=1)
    plt.ylabel('Class')
    plt.xlabel('Symbol')
    plt.title('Communication Protocol')
    plt.show()

In [None]:
def test_at_different_pms(experiment):
    results = []
    for i in range(6):
        override_play_params = {
            'p_mutate': i / 5.
        }
        _, test_metrics = experiment.run_tests(override_play_params) 
        results.append(test_metrics)
    return results

In [None]:
# Mutation probabilities 0.0, 0.2, ..., 1.0; used as the x values when the
# sweep results are plotted.
pms = [step / 5. for step in range(6)]

In [None]:
# Take the first LSTM layer among the student's layers and print, for each of
# its weights, the mean, standard deviation and max of the absolute values.
lstm_layers = [
    layer
    for layer in temp_fixing_experiment.student.layers
    if isinstance(layer, tf.keras.layers.LSTM)
]
lstm_layer, *_ = lstm_layers
for weight in lstm_layer.weights:
    magnitudes = np.abs(weight.numpy())
    print(weight.name,
          'mean:', magnitudes.mean(),
          '+-', magnitudes.std(),
          'max:', magnitudes.max())

In [None]:
# Take the first LSTM layer among the student's layers and print, for each of
# its weights, the mean, standard deviation and max of the absolute values.
lstm_layers = [
    layer
    for layer in obs_fixing_experiment.student.layers
    if isinstance(layer, tf.keras.layers.LSTM)
]
lstm_layer, *_ = lstm_layers
for weight in lstm_layer.weights:
    magnitudes = np.abs(weight.numpy())
    print(weight.name,
          'mean:', magnitudes.mean(),
          '+-', magnitudes.std(),
          'max:', magnitudes.max())

In [None]:
# Test metrics of the temporal-fixing experiment, one entry per p_mutate value.
temp_pm_tests = test_at_different_pms(temp_fixing_experiment)

In [None]:
# Display the collected metrics.
temp_pm_tests

In [None]:
# Test metrics of the observation-fixing experiment, one entry per p_mutate value.
obs_pm_tests = test_at_different_pms(obs_fixing_experiment)

In [None]:
# Display the collected metrics.
obs_pm_tests

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 3))

# Maps each test-metric key to the label used in the panel title and y-axis.
# NOTE(review): 'ground_truth_acc' lacks the 'mean_' prefix the other keys
# carry — confirm this key exists in the test metrics.
metrics = {
    'ground_truth_acc': 'Performance',
    'mean_student_error': 'Student Error',
    'mean_teacher_error': 'Teacher Error'
}

# (legend label, per-p_mutate test metrics) for each line, in drawing order.
sweeps = (
    ('Temporally-fixed', temp_pm_tests),
    ('Observation-fixed', obs_pm_tests),
)

for panel_index, (metric_key, metric_label) in enumerate(metrics.items()):
    ax = axs[panel_index]

    for series_label, pm_tests in sweeps:
        series_values = [test_metrics[metric_key] for test_metrics in pm_tests]
        sns.lineplot(x=pms, y=series_values, label=series_label, ax=ax)

    ax.set_xlim([-.05, 1.05])
    ax.set_title(f'The Effect of Mutations on {metric_label}')
    ax.set_xlabel('Mutation Probability')
    ax.set_ylabel(metric_label)

    # Only the middle panel keeps its legend.
    if panel_index != 1:
        ax.get_legend().remove()

plt.tight_layout()
plt.show()

In [None]:
# Re-run the tests with 'p_mutate' overridden to 0.8, keeping only the games
# played (the second returned value is discarded).
games_played, _ = temp_fixing_experiment.run_tests({'p_mutate': 0.8}) 

In [None]:
# Plot the first five recorded games.
for game_index in range(5):
    inputs, targets, outputs = games_played[game_index]
    # plot_game receives outputs before targets, unlike the stored tuple order.
    plot_game(inputs, outputs, targets, select_batch=0)

In [None]:
# Settings for the temporally-fixed meta-experiment, gathered in one mapping
# and then expanded into keyword arguments.
tf_meta_experiment_config = dict(
    create_experiment_fn=create_temporal_fixing_experiment,
    num_experiments=4,
    epochs=15,
    export_location='./experiments/temporally_fixed',
)
tf_meta_experiment = MetaExperiment(**tf_meta_experiment_config)

In [None]:
# Run the temporally-fixed meta-experiment (configured with num_experiments=4).
tf_meta_experiment.run()

In [None]:
# Settings for the observation-fixed meta-experiment, gathered in one mapping
# and then expanded into keyword arguments.
of_meta_experiment_config = dict(
    create_experiment_fn=create_observation_fixing_experiment,
    num_experiments=4,
    epochs=15,
    export_location='./experiments/observation_fixed',
)
of_meta_experiment = MetaExperiment(**of_meta_experiment_config)

In [None]:
# Run the observation-fixed meta-experiment (configured with num_experiments=4).
of_meta_experiment.run()

In [None]:
# Flatten the 'mean_ground_truth_f1' value of every entry in each result's
# 'vanilla_params_test_metrics' list.
zs_results = [
    pairing_metrics['mean_ground_truth_f1']
    for experiment_result in of_meta_experiment.results
    for pairing_metrics in experiment_result['vanilla_params_test_metrics']
]

# Mean and standard deviation, each rounded to 4 places.
zs_mean = round(float(np.mean(zs_results)), 4)
zs_std = round(float(np.std(zs_results)), 4)
print('Final mean zero-shot test performance for OF-agents:',
      zs_mean, '+-', zs_std)

In [None]:
# Flatten the 'mean_ground_truth_f1' value of every entry in each result's
# 'vanilla_params_test_metrics' list.
zs_results = [
    pairing_metrics['mean_ground_truth_f1']
    for experiment_result in tf_meta_experiment.results
    for pairing_metrics in experiment_result['vanilla_params_test_metrics']
]

# Mean and standard deviation, each rounded to 4 places.
zs_mean = round(float(np.mean(zs_results)), 4)
zs_std = round(float(np.std(zs_results)), 4)
print('Final mean zero-shot test performance for TF-agents:',
      zs_mean, '+-', zs_std)