In [1]:
from pathlib import Path
from datetime import datetime
import json
import simplejson

# Folder this notebook exports experiment artefacts to and later reloads them from.
EXPERIMENT_FOLDER = "./experiments/channel_permutation_2"
# parents=True: on a fresh checkout "./experiments" may not exist yet, in which
# case a plain mkdir() raises FileNotFoundError.
Path(EXPERIMENT_FOLDER).mkdir(parents=True, exist_ok=True)

In [2]:
import os
# Hide all CUDA devices from this process; the device listing printed further
# down shows only a CPU device. Set here, ahead of the TensorFlow import in the
# next cell, so the variable is already in place when TensorFlow loads.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [3]:
import numpy as np
import pandas as pd
import tensorflow as tf

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
# Reset seaborn's settings first, then apply its standard theme.
sns.reset_defaults()
sns.set()

# Header line followed by one line per device TensorFlow reports.
print('Physical Devices:', *tf.config.list_physical_devices(), sep='\n')

Physical Devices:
PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')


In [4]:
from zscomm.agent import Agent
from zscomm.comm_channel import CommChannel
from zscomm.synth_teacher import SyntheticTeacher
from zscomm.data import *
from zscomm.play_game import *
from zscomm.loss import *
from zscomm.experiment import Experiment
from zscomm.meta_experiment import *
from zscomm.plot_game import plot_game
from zscomm.analysis import *

## Load Data

In [5]:
# Notebook-wide settings used by the batch generators and agent constructors below.
NUM_CLASSES = 3    # passed to get_simple_card_data, generate_batch and Agent
BATCH_SIZE = 32    # batch_size handed to generate_batch
CHANNEL_SIZE = 5   # NOTE(review): only referenced from commented-out cells in this notebook

# Train/test splits; get_simple_card_data comes from one of the zscomm star-imports above.
TRAIN_DATA, TEST_DATA = get_simple_card_data(num_classes=NUM_CLASSES)

In [6]:
def generate_train_batch():
    """Return one batch produced by ``generate_batch`` from ``TRAIN_DATA``.

    The batch is sized by the notebook-level ``BATCH_SIZE`` and built for
    ``NUM_CLASSES`` classes.
    """
    batch_options = {'batch_size': BATCH_SIZE, 'num_classes': NUM_CLASSES}
    return generate_batch(TRAIN_DATA, **batch_options)


def generate_test_batch():
    """Return one batch produced by ``generate_batch`` from ``TEST_DATA``.

    Uses the same ``BATCH_SIZE`` and ``NUM_CLASSES`` settings as the
    training batch generator.
    """
    batch_options = {'batch_size': BATCH_SIZE, 'num_classes': NUM_CLASSES}
    return generate_batch(TEST_DATA, **batch_options)

## Run Experiments

In [7]:
# def create_message_permutation_experiment(channel_size=5, **exp_kwargs):
    
#     agent = Agent(channel_size, NUM_CLASSES)

#     play_params =  {
#         'channel_size': channel_size,
#         'p_mutate': 0,
#         'message_permutation': True,
#     }
    
#     return Experiment(
#         generate_train_batch, generate_test_batch,
#         play_params=play_params, 
#         student=agent,
#         teacher=agent,
#         loss_fn=complete_loss_fn,
#         **exp_kwargs
#     )


def create_message_permutation_separate_experiment(channel_size=5, epochs=150, **exp_kwargs):
    """Build an ``Experiment`` that uses separate student and teacher losses.

    One ``Agent`` instance is passed as both the student and the teacher.
    The student loss is ``student_pred_matches_test_class``; the teacher
    loss is that same loss multiplied by a fixed factor of 10. Play
    parameters are constant, with ``channel_temp`` fixed at 1.0.

    Args:
        channel_size: Passed to ``Agent`` and stored in the play parameters.
        epochs: Forwarded to ``Experiment`` as ``max_epochs``.
        **exp_kwargs: Extra keyword arguments forwarded unchanged to
            ``Experiment``.

    Returns:
        The configured ``Experiment`` instance.
    """
    shared_agent = Agent(channel_size, NUM_CLASSES)

    beta = 10

    def teacher_loss_fn(outputs, targets):
        # Teacher is trained on the student's loss, scaled by beta.
        return beta * student_pred_matches_test_class(outputs, targets)

    fixed_play_params = dict(
        channel_size=channel_size,
        p_mutate=0,
        message_permutation=True,
        channel_temp=1.0,
    )

    experiment_options = dict(
        play_params=fixed_play_params,
        student=shared_agent,
        teacher=shared_agent,
        student_loss_fn=student_pred_matches_test_class,
        teacher_loss_fn=teacher_loss_fn,
        max_epochs=epochs,
        lr=1e-2,
        step_print_freq=10,
    )
    return Experiment(
        generate_train_batch,
        generate_test_batch,
        **experiment_options,
        **exp_kwargs
    )


def create_channel_permutation_experiment(channel_size=5, epochs=200, **exp_kwargs):
    """Build an ``Experiment`` whose channel temperature is annealed per epoch.

    One ``Agent`` instance is passed as both the student and the teacher,
    and a single loss (``student_pred_matches_test_class``) is used. The
    play parameters are supplied as a function of the epoch: the channel
    temperature decays exponentially from 10 to 0.1 over the first 200
    epochs and stays at 0.1 afterwards.

    Args:
        channel_size: Passed to ``Agent`` and stored in the play parameters.
        epochs: Forwarded to ``Experiment`` as ``max_epochs``. The annealing
            horizon is fixed at 200 epochs regardless of this value.
        **exp_kwargs: Extra keyword arguments forwarded unchanged to
            ``Experiment``.

    Returns:
        The configured ``Experiment`` instance.
    """
    shared_agent = Agent(channel_size, NUM_CLASSES)

    start_temp, end_temp = 10, 0.1
    temp_anneal_end_epoch = 200
    # Rate chosen so start_temp * exp(-rate * temp_anneal_end_epoch) equals end_temp.
    decay_rate = -np.log(end_temp / start_temp) / temp_anneal_end_epoch

    def play_params(epoch):
        # float() turns the numpy scalar into a plain Python float.
        channel_temp = (
            float(start_temp * np.exp(-decay_rate * epoch))
            if epoch < temp_anneal_end_epoch
            else end_temp
        )
        return dict(
            channel_size=channel_size,
            p_mutate=0,
            message_permutation=True,
            channel_temp=channel_temp,
        )

    experiment_options = dict(
        play_params=play_params,
        student=shared_agent,
        teacher=shared_agent,
        loss_fn=student_pred_matches_test_class,
        max_epochs=epochs,
        lr=1e-2,
        step_print_freq=10,
    )
    return Experiment(
        generate_train_batch,
        generate_test_batch,
        **experiment_options,
        **exp_kwargs
    )

In [8]:
# Meta-experiment over 6 experiments, each built by
# create_channel_permutation_experiment. EXPERIMENT_FOLDER is handed over as the
# export location; the analysis cells later load experiments back from it.
permutation_experiment = MetaExperiment(
    create_experiment_fn=create_channel_permutation_experiment,
    num_experiments=6,
    export_location=EXPERIMENT_FOLDER,
)

In [9]:
# Run the test games of the first experiment before any training has been run.
first_experiment = permutation_experiment.experiments[0]['experiment']
games_played, _ = first_experiment.run_tests()

In [10]:
# teacher = Agent(CHANNEL_SIZE, NUM_CLASSES)
# student = Agent(CHANNEL_SIZE, NUM_CLASSES)

# exp = permutation_experiment.experiments[0]['experiment']

# inputs, targets = generate_train_batch()

# with tf.GradientTape(persistent=True) as tape:
#     outputs = play_game(inputs, teacher, student, 
#                         training=True, 
#                         **exp.get_play_params())

#     loss = student_pred_matches_test_class(outputs, targets)
    
# teacher_grads = tape.gradient(loss, teacher.trainable_variables)
# student_grads = tape.gradient(loss, student.trainable_variables)

# for v, g in zip(teacher.trainable_variables, teacher_grads):
#     print(f'{v.name} teacher grad norm: {tf.reduce_sum(g**2)**0.5}')

# print()

# for v, g in zip(student.trainable_variables, student_grads):
#     print(f'{v.name} student grad norm: {tf.reduce_sum(g**2)**0.5}')
    
# # agent/dense_200/kernel:0 teacher grad norm: 0.004447852727025747
# # agent/dense_200/bias:0 teacher grad norm: 0.004131803754717112
# # agent/lstm_100/kernel:0 teacher grad norm: 0.00819376204162836
# # agent/lstm_100/recurrent_kernel:0 teacher grad norm: 0.001391303027048707
# # agent/lstm_100/bias:0 teacher grad norm: 0.008151191286742687
# # agent/dense_201/kernel:0 teacher grad norm: 0.004122724756598473
# # agent/dense_201/bias:0 teacher grad norm: 0.01262927521020174

In [11]:
# exp = permutation_experiment.experiments[0]['experiment']
# inputs, targets = generate_train_batch()
# outputs = play_game(inputs, exp.teacher, exp.student, 
#                     training=True, 
#                     **exp.get_play_params())

In [12]:
# plot_game(inputs, outputs, targets, select_batch=2)

In [13]:
# Print the layer summary of the first experiment's student model.
student_model = permutation_experiment.experiments[0]['experiment'].student
student_model.summary()

Model: "agent"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
concatenate (Concatenate)    multiple                  0         
_________________________________________________________________
dropout (Dropout)            multiple                  0         
_________________________________________________________________
lambda (Lambda)              multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  1664      
_________________________________________________________________
lstm (LSTM)                  multiple                  49408     
_________________________________________________________________
dense_1 (Dense)              multiple                  520       
_________________________________________________________________
lambda_1 (Lambda)            multiple                  0     

In [None]:
# Run every experiment in the meta-experiment. Long-running: the log below shows
# several seconds per epoch, with max_epochs=200 for each of the 6 experiments.
permutation_experiment.run()

Running meta_experiment...
Running experiment 0 (0/6 complete):
Running experiment...
Run config:
 {'name': 'experiment', 'max_epochs': 200, 'steps_per_epoch': 50, 'epochs_optimised': 93, 'play_params': {'channel_size': 5, 'p_mutate': 0, 'message_permutation': True, 'channel_temp': 1.1748975549395295}, 'test_freq': 5, 'test_steps': 25, 'optimiser_config': {'name': 'RMSprop', 'learning_rate': 0.009999999776482582, 'decay': 0.0, 'rho': 0.8999999761581421, 'momentum': 0.0, 'epsilon': 1e-07, 'centered': False}, 'optimise_agents_separately': False, 'loss_fn': 'student_pred_matches_test_class'}
Epoch 0, Time Taken (mm:ss): 0:7, Mean Loss: 1.1
Test Loss: 1.1, Ground Truth F1-Score: 0.344, Student Error: 1.099, Teacher Error: 1.631, Protocol Diversity: 0.358,
Epoch 1, Time Taken (mm:ss): 0:8, Mean Loss: 1.103
Epoch 2, Time Taken (mm:ss): 0:10, Mean Loss: 1.101
Epoch 3, Time Taken (mm:ss): 0:7, Mean Loss: 1.101
Epoch 4, Time Taken (mm:ss): 0:6, Mean Loss: 1.101
Epoch 5, Time Taken (mm:ss): 0:6,

In [None]:
# Print the summed 'seconds_taken' of every experiment whose status is 'Complete'.
for record in permutation_experiment.experiments:
    if record['status'] != 'Complete':
        continue
    history = record['experiment'].training_history
    total_time = sum(entry['seconds_taken'] for entry in history)
    print(total_time)

In [None]:
# Display the meta-experiment's records; each record is indexed elsewhere in this
# notebook by its 'status' and 'experiment' keys.
permutation_experiment.experiments

## Analyse Results

In [None]:
def load_channel_permutation_experiment(path):
    """Rebuild a saved channel-permutation ``Experiment`` from an export folder.

    Reads ``config.json``, ``results.json`` and ``training_history.json`` from
    ``path``, restores the agent weights saved under ``agent_weights``, and
    constructs an ``Experiment`` in which that one agent is both student and
    teacher.

    Args:
        path: Export directory of a single experiment (``str`` or
            ``pathlib.Path``).

    Returns:
        An ``Experiment`` whose ``epoch``, ``training_history`` and ``results``
        attributes are set from the saved files.

    Raises:
        FileNotFoundError: If one of the JSON files is missing.
        KeyError: If the config lacks ``play_params['channel_size']`` or
            ``epochs_optimised``.
    """
    path = Path(path)

    def read_json(filename):
        # The context manager closes the handle promptly instead of leaving
        # it open until garbage collection.
        with (path / filename).open(mode='r') as json_file:
            return json.load(json_file)

    config = read_json('config.json')
    results = read_json('results.json')
    history = read_json('training_history.json')

    agent = Agent(config['play_params']['channel_size'], NUM_CLASSES)
    agent.load_weights(str(path / 'agent_weights'))

    # The run config printed during training shows 'loss_fn' as a name string,
    # so the saved value is replaced here with the actual callable.
    config['loss_fn'] = student_pred_matches_test_class

    # These two keys are not forwarded to the Experiment constructor;
    # 'epochs_optimised' is restored onto the instance below.
    kwargs = {
        k: v for k, v in config.items()
        if k not in ['epochs_optimised', 'optimiser_config']
    }
    experiment = Experiment(
        generate_train_batch, generate_test_batch,
        student=agent,
        teacher=agent,
        **kwargs
    )
    experiment.epoch = config['epochs_optimised']
    experiment.training_history = history
    experiment.results = results

    return experiment

In [None]:
experiments = []
# sorted(): directory listing order depends on the filesystem, but later cells
# rely on list position ("Run {index}" labels, experiments[4]). Sorting keeps
# those positions stable between runs and machines.
for path in sorted(Path(EXPERIMENT_FOLDER).glob('*')):
    # Each experiment is exported to its own sub-directory; skip anything else.
    if path.is_dir():
        exp = load_channel_permutation_experiment(path)
        experiments.append(exp)
        print('Loaded experiment from:', path)

In [None]:
def did_converge_to_global_optima(experiment):
    """Return True when the experiment's mean ground-truth F1 is above 0.9."""
    return experiment.results['mean_ground_truth_f1'] > 0.9

def did_converge_to_local_optima(experiment):
    """Return True when the mean ground-truth F1 is in the range (0.6, 0.9].

    The upper bound is inclusive so that a score of exactly 0.9, which the
    global-optimum check excludes, is still classed as converged rather than
    falling between the two categories.
    """
    return 0.9 >= experiment.results['mean_ground_truth_f1'] > 0.6

def get_category(experiment):
    """Map an experiment to the convergence label used as the plot hue.

    Args:
        experiment: Object whose ``results`` mapping holds
            ``'mean_ground_truth_f1'``.

    Returns:
        ``'Converged to Global Optima'``, ``'Converged to Local Optima'`` or
        ``'Did Not Converge'``.
    """
    if did_converge_to_global_optima(experiment):
        return 'Converged to Global Optima'
    if did_converge_to_local_optima(experiment):
        return 'Converged to Local Optima'
    return 'Did Not Converge'

In [None]:
# One row per (experiment, epoch) with that epoch's training loss.
loss_rows = [
    dict(
        Epoch=epoch_idx,
        Loss=history_entry['loss'],
        Experiment=f"Run {run_idx}",
        Category=get_category(run),
    )
    for run_idx, run in enumerate(experiments)
    for epoch_idx, history_entry in enumerate(run.training_history)
]
df = pd.DataFrame(loss_rows)

# Runs that did not converge are left out of the loss plot.
converged_mask = df['Category'] != 'Did Not Converge'
df = df[converged_mask]

plt.figure(figsize=(8, 5))
sns.lineplot(data=df, x='Epoch', y='Loss', hue='Category');

In [None]:
# One row per (experiment, epoch) for the epochs that recorded test metrics.
metric_rows = [
    {
        'Epoch': epoch_idx,
        'Performance': history_entry['test_metrics']['mean_ground_truth_f1'],
        'Protocol Diversity': history_entry['test_metrics']['mean_protocol_diversity'],
        'Experiment': f"Run {run_idx}",
        'Category': get_category(run)
    }
    for run_idx, run in enumerate(experiments)
    for epoch_idx, history_entry in enumerate(run.training_history)
    if 'test_metrics' in history_entry
]
df = pd.DataFrame(metric_rows)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# (column to plot, panel title) for the left and right panels respectively.
panel_specs = [
    ('Performance', 'Test Performance'),
    ('Protocol Diversity', 'Protocol Diversity'),
]
for ax, (column, title) in zip(axs, panel_specs):
    ax.set_title(title)
    sns.lineplot(x='Epoch', y=column, hue='Category', data=df, ax=ax)
    ax.set_ylim([0, 1.05])

In [None]:
# Play the test games of one loaded experiment; the cells below visualise them.
selected_experiment = experiments[4]
games_played, _ = selected_experiment.run_tests()

In [None]:
# for i in range(5):
#     inputs, targets, outputs = games_played[i]
#     plot_game(inputs, outputs, targets, select_batch=0)

In [None]:
# Heatmap of the class-to-message map computed from the played games.
mean_class_message_map = create_mean_class_message_map(games_played)
protocol_ax = sns.heatmap(mean_class_message_map, vmin=0, vmax=1)
protocol_ax.set_ylabel('Class')
protocol_ax.set_xlabel('Symbol')
protocol_ax.set_title('Communication Protocol')
plt.show()

In [None]:
# Annotated heatmap of the confusion matrix computed from the played games.
conf_matrix = compute_confusion_matrix(games_played)
confusion_ax = sns.heatmap(conf_matrix, annot=True, vmin=0, vmax=1)
confusion_ax.set_title('Ground Truth Confusion Matrix')
confusion_ax.set_ylabel('Predicted Class')
confusion_ax.set_xlabel('Actual Class');