In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
sns.set()

# List every device TensorFlow can see (one per line) under a header.
print('Physical Devices:', *tf.config.list_physical_devices(), sep='\n')

Physical Devices:
PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
from zscomm.agent import Agent
from zscomm.comm_channel import CommChannel
from zscomm.synth_teacher import SyntheticTeacher
from zscomm.data import *
from zscomm.play_game import *
from zscomm.loss import *
from zscomm.experiment import Experiment
from zscomm.plot_game import plot_game
from zscomm.analysis import *

## Load Data

In [3]:
# Experiment configuration shared by every cell below.
NUM_CLASSES = 3
BATCH_SIZE = 32
CHANNEL_SIZE = 5
USE_MNIST = False  # False -> use the simple card dataset instead of MNIST

# Choose the loader once, then load both splits with the same number of classes.
_load_dataset = get_mnist_data if USE_MNIST else get_simple_card_data
TRAIN_DATA, TEST_DATA = _load_dataset(num_classes=NUM_CLASSES)

In [4]:
def _generate_batch_from(data):
    """Draw one batch from `data` using the notebook-level BATCH_SIZE and NUM_CLASSES."""
    return generate_batch(data, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES)


def generate_train_batch():
    """Return one batch sampled from the training split (TRAIN_DATA)."""
    return _generate_batch_from(TRAIN_DATA)


def generate_test_batch():
    """Return one batch sampled from the test split (TEST_DATA)."""
    return _generate_batch_from(TEST_DATA)

## Run Experiments

In [5]:
def create_other_play_experiment(p_mutate=0.3, **exp_kwargs):
    """Build an other-play experiment with separate student and teacher agents.

    Both agents are trained with the single combined loss ``complete_loss_fn``.

    Args:
        p_mutate: Stored as ``play_params['p_mutate']``; presumably the
            probability with which play mutates the channel during training
            (TODO confirm in zscomm.play_game).
        **exp_kwargs: Extra keyword arguments forwarded to ``Experiment``.
            Accepting them lets this factory be used as
            ``MetaExperiment(create_experiment_fn=...)`` with extra kwargs,
            like ``create_self_play_experiment``.

    Returns:
        The configured ``Experiment``.
    """
    # Two independently initialised agents, unlike self-play's shared agent.
    student = Agent(CHANNEL_SIZE, NUM_CLASSES)
    teacher = Agent(CHANNEL_SIZE, NUM_CLASSES)

    play_params = {'p_mutate': p_mutate}

    return Experiment(
        generate_train_batch, generate_test_batch,
        play_params=play_params,
        student=student,
        teacher=teacher,
        loss_fn=complete_loss_fn,
        **exp_kwargs
    )

def create_other_play_separate_optimise_experiment(p_mutate=0.3, **exp_kwargs):
    """Build an other-play experiment whose student and teacher use separate losses.

    The student is optimised with ``student_pred_matches_implied_class`` and
    the teacher with ``student_pred_matches_test_class``.

    Args:
        p_mutate: Stored as ``play_params['p_mutate']``; presumably the
            probability with which play mutates the channel during training
            (TODO confirm in zscomm.play_game).
        **exp_kwargs: Extra keyword arguments forwarded to ``Experiment``.
            Accepting them lets this factory be used as
            ``MetaExperiment(create_experiment_fn=...)`` with extra kwargs,
            like ``create_self_play_experiment``.

    Returns:
        The configured ``Experiment``.
    """
    student = Agent(CHANNEL_SIZE, NUM_CLASSES)
    teacher = Agent(CHANNEL_SIZE, NUM_CLASSES)

    play_params = {'p_mutate': p_mutate}

    return Experiment(
        generate_train_batch, generate_test_batch,
        play_params=play_params,
        student=student,
        teacher=teacher,
        student_loss_fn=student_pred_matches_implied_class,
        teacher_loss_fn=student_pred_matches_test_class,
        **exp_kwargs
    )

def create_self_play_experiment(p_mutate=0.3, **exp_kwargs):
    """Build a self-play experiment in which one agent plays both roles.

    The same ``Agent`` instance is passed as both student and teacher and is
    trained with ``complete_loss_fn``.

    Args:
        p_mutate: Stored as ``play_params['p_mutate']`` (semantics defined in
            zscomm; presumably a mutation probability — TODO confirm).
        **exp_kwargs: Extra keyword arguments forwarded to ``Experiment``.

    Returns:
        The configured ``Experiment``.
    """
    shared_agent = Agent(CHANNEL_SIZE, NUM_CLASSES)
    return Experiment(
        generate_train_batch,
        generate_test_batch,
        play_params={'p_mutate': p_mutate},
        student=shared_agent,
        teacher=shared_agent,
        loss_fn=complete_loss_fn,
        **exp_kwargs
    )

In [16]:
class MetaExperiment:
    """Create several experiments and train them one after another.

    Each experiment is built with
    ``create_experiment_fn(p_mutate, **make_experiment_kwargs)``. ``run``
    trains them in order. A KeyboardInterrupt (kernel interrupt) stops the run
    and leaves the interrupted experiment marked 'In Progress'. The next call
    to ``run`` calls ``train`` on that same experiment again; whether training
    resumes or restarts depends on ``Experiment.train``.

    Args:
        p_mutate: Passed positionally as the first argument to
            ``create_experiment_fn``.
        num_experiments: Number of experiments to create and run.
        create_experiment_fn: Factory returning an experiment object. Defaults
            to ``create_self_play_experiment``. The default is looked up when a
            MetaExperiment is constructed, not when the class is defined, so a
            re-run factory cell is picked up.
        **make_experiment_kwargs: Extra keyword arguments forwarded to
            ``create_experiment_fn``.
    """

    def __init__(
        self,
        p_mutate,
        num_experiments=3,
        create_experiment_fn=None,
        **make_experiment_kwargs
    ):
        if create_experiment_fn is None:
            # Late binding: a default captured at class-definition time would
            # silently keep using a stale factory after its cell is re-run.
            create_experiment_fn = create_self_play_experiment
        self.p_mutate = p_mutate
        self.num_experiments = num_experiments
        self.experiments = [
            {
                'experiment': create_experiment_fn(
                    p_mutate, **make_experiment_kwargs
                ),
                'status': 'Not Run',
                'results': None,
                'index': i,
            }
            for i in range(num_experiments)
        ]
        for item in self.experiments:
            experiment = item['experiment']
            # Route the experiment's own progress printing through this object,
            # so each refresh (presumably triggered from Experiment.train) also
            # shows the results of experiments that already finished.
            experiment.print_my_history = experiment.print_history
            experiment.print_history = self.print_history

    def print_history(self):
        """Print completed results, then the live history of the running experiment."""
        num_complete = sum(
            item['status'] == 'Complete' for item in self.experiments
        )
        for item in self.experiments:
            if item['status'] == 'Complete':
                print(f"Results of experiment {item['index']}:")
                if item['results'] is None:
                    print('No test metrics were recorded.')
                else:
                    item['experiment'].print_test_metrics(item['results'])
        for item in self.experiments:
            if item['status'] == 'In Progress':
                print(f"Running experiment {item['index']}",
                      f'({num_complete}/{self.num_experiments} complete):')
                item['experiment'].print_my_history()
                break

    def is_finished(self):
        """Return True once every experiment has status 'Complete'."""
        return all(item['status'] == 'Complete' for item in self.experiments)

    def get_experiment_to_run(self):
        """Return the in-progress item if any, else the first not-run item, else None."""
        for wanted_status in ('In Progress', 'Not Run'):
            for item in self.experiments:
                if item['status'] == wanted_status:
                    return item
        return None

    def get_experiment_results(self, experiment):
        """Return the last 'test_metrics' entry in experiment.training_history.

        Returns None when no entry contains test metrics (e.g. training ended
        before the first evaluation), instead of raising IndexError.
        """
        latest = None
        for entry in experiment.training_history:
            if 'test_metrics' in entry:
                latest = entry['test_metrics']
        return latest

    def run(self):
        """Train experiments in order until all complete or the user interrupts."""
        try:
            while not self.is_finished():
                experiment_item = self.get_experiment_to_run()
                if experiment_item is None:
                    # Only reachable if some status holds an unexpected value;
                    # stop instead of crashing on None['index'].
                    break
                experiment = experiment_item['experiment']

                experiment_item['status'] = 'In Progress'

                experiment.train(catch_interrupt=False)

                experiment_item['results'] = \
                    self.get_experiment_results(experiment)
                experiment_item['status'] = 'Complete'

        except KeyboardInterrupt:
            # Deliberate: an interrupt ends the run gracefully. The interrupted
            # experiment keeps status 'In Progress' so run() picks it first.
            pass

        clear_output()
        self.print_history()
        print('Run Stopped.')

In [17]:
# Default MetaExperiment settings (three self-play experiments) with p_mutate = 0.3.
meta_experiment = MetaExperiment(p_mutate=0.3)

In [18]:
# Train every experiment in turn. Interrupting the kernel stops the run early and
# prints a summary; calling run() again starts with the interrupted experiment.
meta_experiment.run()

Results of experiment 0:
Test Loss: 4.935, Ground Truth F1-Score: 0.356, Student Error: 1.416, Teacher Error: 1.326, Protocol Diversity: 0.53,
Results of experiment 1:
Test Loss: 4.688, Ground Truth F1-Score: 0.339, Student Error: 1.213, Teacher Error: 1.367, Protocol Diversity: 0.546,
Results of experiment 2:
Test Loss: 4.643, Ground Truth F1-Score: 0.311, Student Error: 1.239, Teacher Error: 1.22, Protocol Diversity: 0.526,
Run Stopped.


## Render Model Graph in TensorBoard

In [None]:
# # Set up logging.
# from datetime import datetime
# import tensorboard
# %load_ext tensorboard

# stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
# logdir = f'logs\\{stamp}'
# writer = tf.summary.create_file_writer(logdir)

# # Bracket the function call with
# # tf.summary.trace_on() and tf.summary.trace_export().
# tf.summary.trace_off()
# tf.summary.trace_on(graph=True, profiler=True)
# # Call only one tf.function when tracing.

# @tf.function
# def graph_training_step():
#     return only_teacher_training_step(agent_1)

# graph_training_step()

# with writer.as_default():
#     tf.summary.trace_export(
#         name="teacher_only_training_step",
#         step=0,
#         profiler_outdir=logdir)
    
# tf.summary.trace_off()

In [None]:
# %tensorboard --logdir logs

## Analyse Trained Models

In [None]:
# `MetaExperiment.run` binds `experiment` only as a local variable, and
# `experiment2` is never assigned, so on a fresh kernel this and the following
# cells raise NameError. Bind both explicitly from the trained meta-experiment.
# NOTE(review): assumes the first two runs are the ones meant to be compared.
experiment = meta_experiment.experiments[0]['experiment']
experiment2 = meta_experiment.experiments[1]['experiment']

games_played, test_metrics = experiment.run_tests()
test_metrics

In [None]:
games_played, _ = experiment.run_tests()
conf_matrix = compute_confusion_matrix(games_played)
ax = sns.heatmap(conf_matrix, annot=True, vmin=0, vmax=1)
# NOTE(review): these labels put predicted classes on the rows (y-axis) and actual
# classes on the columns (x-axis). The markdown note on this plot says the opposite
# (rows = true labels, columns = predicted). Confirm against compute_confusion_matrix.
ax.set_title('Ground Truth Confusion Matrix')
ax.set_ylabel('Predicted Class')
ax.set_xlabel('Actual Class');

In [None]:
# NOTE(review): `experiment2` must be bound before this cell runs.
games_played, _ = experiment2.run_tests()
conf_matrix = compute_confusion_matrix(games_played)
ax = sns.heatmap(conf_matrix, annot=True, vmin=0, vmax=1)
# NOTE(review): axis labels (rows = predicted, columns = actual) contradict the
# markdown note on this plot (rows = true, columns = predicted) — confirm which is right.
ax.set_title('Ground Truth Confusion Matrix')
ax.set_ylabel('Predicted Class')
ax.set_xlabel('Actual Class');

The rows correspond to the true labels and the columns to the predicted labels. Each column is divided by its sum in order to show the percentage of the time the model predicts the given class.

In [None]:
games_played, _ = experiment.run_tests()
mean_class_message_map = create_mean_class_message_map(games_played)
# Rows are classes and columns are channel symbols, per the axis labels below.
ax = sns.heatmap(mean_class_message_map, vmin=0, vmax=1)
ax.set_ylabel('Class')
ax.set_xlabel('Symbol')
ax.set_title('Communication Protocol')
plt.show()

In [None]:
# NOTE(review): `experiment2` must be bound before this cell runs.
games_played, _ = experiment2.run_tests()
mean_class_message_map = create_mean_class_message_map(games_played)
ax = sns.heatmap(mean_class_message_map, vmin=0, vmax=1)
ax.set_ylabel('Class')
ax.set_xlabel('Symbol')
ax.set_title('Communication Protocol')
plt.show()

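## Zero-shot Coordination

Pair the teacher from one independently trained run with the student from another (and vice versa), playing with `p_mutate=0`, to check whether the two runs converged on compatible communication protocols.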

In [None]:
# Cross-play: teacher from `experiment`, student from `experiment2`, no mutation.
inputs, targets = generate_test_batch()
outputs = play_game(inputs, experiment.teacher, experiment2.student,
                    training=False, p_mutate=0)
game = (inputs, targets, outputs)

# NOTE(review): every metric uses `experiment`'s helpers; this assumes they depend
# only on the games passed in — confirm.
teacher_error, protocol_diversity = experiment.get_teacher_test_metrics([game])
f1, student_error = experiment.get_student_test_metrics([game])
loss = experiment.get_test_loss([game])
print(f'Teacher Error: {teacher_error}, Protocol Diversity {protocol_diversity}, Student Error: {student_error}, F1 Score {f1}, Loss: {loss}')

In [None]:
# Visualise a game from the cross-play batch produced in the previous cell.
# select_batch=1 presumably chooses which batch element to draw — confirm in zscomm.plot_game.
plot_game(inputs, outputs, targets, select_batch=1)

In [None]:
# Swapped cross-play: teacher from `experiment2`, student from `experiment`, no mutation.
inputs, targets = generate_test_batch()
outputs = play_game(inputs, experiment2.teacher, experiment.student,
                    training=False, p_mutate=0)
game = (inputs, targets, outputs)

# NOTE(review): metrics are still computed with `experiment`'s helpers even though
# the teacher comes from `experiment2`; assumed to depend only on the games — confirm.
teacher_error, protocol_diversity = experiment.get_teacher_test_metrics([game])
f1, student_error = experiment.get_student_test_metrics([game])
loss = experiment.get_test_loss([game])
print(f'Teacher Error: {teacher_error}, Protocol Diversity {protocol_diversity}, Student Error: {student_error}, F1 Score {f1}, Loss: {loss}')

In [None]:
# Visualise the swapped cross-play game from the previous cell; no select_batch
# is given, so plot_game's default selection applies.
plot_game(inputs, outputs, targets)