In [None]:
from pathlib import Path
from datetime import datetime
import json
import simplejson

EXPERIMENT_FOLDER = "./experiments/channel_permutation"
# parents=True: on a fresh checkout ./experiments may not exist yet, and
# mkdir(exist_ok=True) alone would raise FileNotFoundError for the nested path.
Path(EXPERIMENT_FOLDER).mkdir(parents=True, exist_ok=True)

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
# Restore matplotlib's default rcParams, then layer seaborn's default theme on top.
sns.reset_defaults()
sns.set()

# Header line followed by every device TensorFlow reports, one per line.
print('Physical Devices:', *tf.config.list_physical_devices(), sep='\n')

In [None]:
from zscomm.agent import Agent
from zscomm.comm_channel import CommChannel
from zscomm.synth_teacher import SyntheticTeacher
from zscomm.data import *
from zscomm.play_game import *
from zscomm.loss import *
from zscomm.experiment import Experiment
from zscomm.meta_experiment import *
from zscomm.plot_game import plot_game
from zscomm.analysis import *

## Load Data

In [None]:
NUM_CLASSES = 3
BATCH_SIZE = 32
# NOTE(review): CHANNEL_SIZE is not referenced by the experiment factories
# below, which default to channel_size=5 on their own.
CHANNEL_SIZE = 5
USE_MNIST = False  # True -> MNIST loader, False -> simple card loader

# Both loaders are called the same way, so pick one and invoke it once.
TRAIN_DATA, TEST_DATA = (
    get_mnist_data if USE_MNIST else get_simple_card_data
)(num_classes=NUM_CLASSES)

In [None]:
def _sample_batch(dataset):
    """Draw one batch from ``dataset`` via ``generate_batch``.

    Uses the notebook-wide BATCH_SIZE and NUM_CLASSES settings.
    """
    return generate_batch(dataset, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES)


def generate_train_batch():
    """Return one batch drawn from TRAIN_DATA."""
    return _sample_batch(TRAIN_DATA)


def generate_test_batch():
    """Return one batch drawn from TEST_DATA."""
    return _sample_batch(TEST_DATA)

# Run Experiments

In [None]:
def permutation_play_params(channel_size):
    """Return the game settings shared by both message-permutation experiments.

    Mutation probability is 0 and message permutation is enabled. A new dict
    is returned on every call, so no two experiments share the same object.

    Args:
        channel_size: Size of the communication channel.

    Returns:
        dict with keys 'channel_size', 'p_mutate' and 'message_permutation'.
    """
    return {
        'channel_size': channel_size,
        'p_mutate': 0,
        'message_permutation': True,
    }


def create_message_permutation_experiment(channel_size=5, **exp_kwargs):
    """Build a self-play Experiment trained with ``complete_loss_fn``.

    A single Agent instance plays both the teacher and the student role.

    Args:
        channel_size: Size of the communication channel, passed to both the
            Agent and the game parameters.
        **exp_kwargs: Forwarded unchanged to ``Experiment``.

    Returns:
        The configured (untrained) ``Experiment``.
    """
    agent = Agent(channel_size, NUM_CLASSES)

    return Experiment(
        generate_train_batch, generate_test_batch,
        play_params=permutation_play_params(channel_size),
        student=agent,
        teacher=agent,
        loss_fn=complete_loss_fn,
        **exp_kwargs
    )


def create_message_permutation_is_all_you_need_experiment(channel_size=5, epochs=300, **exp_kwargs):
    """Build a self-play Experiment trained only on ``student_pred_matches_test_class``.

    A single Agent instance plays both the teacher and the student role.

    Args:
        channel_size: Size of the communication channel, passed to both the
            Agent and the game parameters.
        epochs: Passed to ``Experiment`` as ``max_epochs``.
        **exp_kwargs: Forwarded to ``Experiment``. ``lr`` (default 1e-2) and
            ``step_print_freq`` (default 25) may be overridden here.

    Returns:
        The configured (untrained) ``Experiment``.
    """
    agent = Agent(channel_size, NUM_CLASSES)

    # Defaults first, caller's values last: overriding lr/step_print_freq now
    # works instead of raising "got multiple values for keyword argument".
    tuning_kwargs = {'lr': 1e-2, 'step_print_freq': 25, **exp_kwargs}

    return Experiment(
        generate_train_batch, generate_test_batch,
        play_params=permutation_play_params(channel_size),
        student=agent,
        teacher=agent,
        loss_fn=student_pred_matches_test_class,
        max_epochs=epochs,
        **tuning_kwargs
    )

In [None]:
# Six runs of the "message permutation is all you need" setup, with results
# exported to EXPERIMENT_FOLDER.
permutation_experiment = MetaExperiment(
    export_location=EXPERIMENT_FOLDER,
    num_experiments=6,
    create_experiment_fn=create_message_permutation_is_all_you_need_experiment,
)

In [None]:
# Train every configured run (presumably the long-running step of the notebook).
# Outputs are expected under EXPERIMENT_FOLDER (the export_location above),
# which is where the analysis section reloads them from.
permutation_experiment.run()

In [None]:
# Total training time of each completed run: the sum of 'seconds_taken' over
# its training-history entries.
for run_index, item in enumerate(permutation_experiment.experiments):
    if item['status'] != 'Complete':
        continue
    total_time = sum(
        step['seconds_taken'] for step in item['experiment'].training_history
    )
    # Label each total: incomplete runs are skipped, so bare numbers could not
    # be matched back to their run.
    print(f"Run {run_index}: {total_time:.1f}s")

In [None]:
# Raw per-run records held by the MetaExperiment; each is a dict with at least
# 'status' and 'experiment' keys (the timing cell reads both).
permutation_experiment.experiments

## Analyse Results

In [None]:
def read_json_file(path):
    """Parse the JSON file at ``path`` (str or Path) and close it afterwards."""
    with Path(path).open(mode='r') as json_file:
        return json.load(json_file)


def load_channel_permutation_experiment(path):
    """Rebuild a saved channel-permutation ``Experiment`` from its run folder.

    Reads ``config.json``, ``results.json``, ``training_history.json`` and the
    ``agent_weights`` checkpoint from ``path``. It then recreates the shared
    teacher/student Agent and restores the epoch counter, training history and
    results onto a new ``Experiment``.

    Args:
        path: Directory of a single exported run (``str`` or ``Path``).

    Returns:
        The reconstructed ``Experiment``.
    """
    path = Path(path)

    config = read_json_file(path / 'config.json')
    results = read_json_file(path / 'results.json')
    history = read_json_file(path / 'training_history.json')

    agent = Agent(config['play_params']['channel_size'], NUM_CLASSES)
    agent.load_weights(str(path / 'agent_weights'))

    # A Python function cannot round-trip through JSON, so the loss is
    # reattached. The MetaExperiment in this notebook trains with
    # student_pred_matches_test_class.
    config['loss_fn'] = student_pred_matches_test_class

    # epochs_optimised is restored onto the instance below. optimiser_config is
    # dropped. NOTE(review): assumes every remaining key is an Experiment kwarg.
    kwargs = {
        k: v for k, v in config.items()
        if k not in ('epochs_optimised', 'optimiser_config')
    }
    experiment = Experiment(
        generate_train_batch, generate_test_batch,
        student=agent,
        teacher=agent,
        **kwargs
    )
    experiment.epoch = config['epochs_optimised']
    experiment.training_history = history
    experiment.results = results

    return experiment

In [None]:
# Each run is a sub-directory of EXPERIMENT_FOLDER. Path.glob yields entries in
# arbitrary (filesystem) order, so sort them to make the run order, the
# "Run {index}" labels and any index-based pick of a run reproducible.
experiments = [
    load_channel_permutation_experiment(run_dir)
    for run_dir in sorted(Path(EXPERIMENT_FOLDER).glob('*'))
    if run_dir.is_dir()
]
assert experiments, f"No experiment runs found under {EXPERIMENT_FOLDER}"

In [None]:
def did_converge(experiment, threshold=0.9):
    """Return whether a run's stored ground-truth F1 exceeds ``threshold``.

    Args:
        experiment: Object with a ``results`` dict containing
            ``'mean_ground_truth_f1'``.
        threshold: Strict lower bound the score must exceed (default 0.9).

    Returns:
        ``results['mean_ground_truth_f1'] > threshold``.
    """
    return experiment.results['mean_ground_truth_f1'] > threshold

In [None]:
# One row per training-history entry per run, tagged by convergence category.
loss_df = pd.DataFrame([
    {
        'Epoch': epoch,
        'Loss': train_item['loss'],
        'Experiment': f"Run {index}",
        'Category': 'Converged' if did_converge(experiment) else 'Not Converged'
    }
    for index, experiment in enumerate(experiments)
    for epoch, train_item in enumerate(experiment.training_history)
])

# hue='Category' makes seaborn aggregate the runs within each category
# instead of drawing one line per run.
fig, ax = plt.subplots(figsize=(8, 5))
sns.lineplot(x='Epoch', y='Loss', hue='Category', data=loss_df, ax=ax)
ax.set_title('Training Loss by Convergence Category');

In [None]:
# Only history entries carrying 'test_metrics' contribute a row.
df = pd.DataFrame([
    {
        'Epoch': step,
        'Performance': record['test_metrics']['mean_ground_truth_f1'],
        'Protocol Diversity': record['test_metrics']['mean_protocol_diversity'],
        'Experiment': f"Run {run_idx}",
        'Category': ('Converged' if did_converge(run) else 'Not Converged'),
    }
    for run_idx, run in enumerate(experiments)
    for step, record in enumerate(run.training_history)
    if 'test_metrics' in record
])

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Same recipe for both panels: title, per-category line plot, shared [0, 1.05] y-range.
for ax, (column, title) in zip(axs, [('Performance', 'Test Performance'),
                                     ('Protocol Diversity', 'Protocol Diversity')]):
    ax.set_title(title)
    sns.lineplot(x='Epoch', y=column, hue='Category', data=df, ax=ax)
    ax.set_ylim([0, 1.05])

In [None]:
# Run the tests of one saved run, keeping only the first returned value (the played games).
# NOTE(review): index 4 is hard-coded and depends on the order in which the
# loading cell collected runs — confirm this is the intended run.
games_played, _ = experiments[4].run_tests()

In [None]:
# Plot up to the first five games. Slicing avoids the IndexError that
# range(5) raised when fewer games were returned.
# NOTE(review): plot_game receives (inputs, outputs, targets), a different
# order from the unpacking — confirm against plot_game's signature.
for inputs, targets, outputs in games_played[:5]:
    plot_game(inputs, outputs, targets, select_batch=0)

In [None]:
# Heatmap of the class -> symbol map built from the played games, colour scale fixed to [0, 1].
mean_class_message_map = create_mean_class_message_map(games_played)
protocol_ax = sns.heatmap(mean_class_message_map, vmin=0, vmax=1)
protocol_ax.set_title('Communication Protocol')
protocol_ax.set_xlabel('Symbol')
protocol_ax.set_ylabel('Class')
plt.show()

In [None]:
# Annotated confusion-matrix heatmap, colour scale fixed to [0, 1].
# NOTE(review): the axis labels assume compute_confusion_matrix puts predicted
# classes on rows and actual classes on columns — confirm against zscomm.analysis.
conf_matrix = compute_confusion_matrix(games_played)
cm_ax = sns.heatmap(conf_matrix, annot=True, vmin=0, vmax=1)
cm_ax.set(
    title='Ground Truth Confusion Matrix',
    ylabel='Predicted Class',
    xlabel='Actual Class',
);