In [1]:
from pathlib import Path
from datetime import datetime
import json

In [2]:
import os
# Hide every CUDA device from this process so TensorFlow runs on the CPU only.
# This is set before TensorFlow is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [3]:
import numpy as np
import pandas as pd
import tensorflow as tf

from IPython.display import clear_output
import matplotlib.pyplot as plt
from matplotlib import gridspec
%matplotlib inline

import seaborn as sns
# Restore matplotlib's default rc settings, then apply seaborn's default theme.
sns.reset_defaults()
sns.set()

# Show which devices TensorFlow can see, one per line under a header.
print('Physical Devices:', *tf.config.list_physical_devices(), sep='\n')

Physical Devices:
PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')


In [4]:
from zscomm.agent import Agent
from zscomm.comm_channel import CommChannel
from zscomm.synth_teacher import SyntheticTeacher
from zscomm.data import *
from zscomm.play_game import *
from zscomm.loss import *
from zscomm.experiment import Experiment
from zscomm.meta_experiment import *
from zscomm.vary_play_param_experiment import *
from zscomm.plot_game import plot_game
from zscomm.analysis import *

## Load Data:

In [5]:
# Notebook-wide configuration.
NUM_CLASSES = 3
BATCH_SIZE = 32
CHANNEL_SIZE = 5
USE_MNIST = False  # False selects the simple card data set instead of MNIST

# Choose the loader from the USE_MNIST switch and load the train/test split.
TRAIN_DATA, TEST_DATA = (
    get_mnist_data if USE_MNIST else get_simple_card_data
)(num_classes=NUM_CLASSES)

In [6]:
def generate_train_batch():
    """Return one batch built from TRAIN_DATA with the notebook-wide settings.

    Delegates to `generate_batch`, passing BATCH_SIZE and NUM_CLASSES.
    """
    batch_options = {'batch_size': BATCH_SIZE, 'num_classes': NUM_CLASSES}
    return generate_batch(TRAIN_DATA, **batch_options)


def generate_test_batch():
    """Return one batch built from TEST_DATA with the notebook-wide settings.

    Delegates to `generate_batch`, passing BATCH_SIZE and NUM_CLASSES.
    """
    batch_options = {'batch_size': BATCH_SIZE, 'num_classes': NUM_CLASSES}
    return generate_batch(TEST_DATA, **batch_options)

# Run Experiments

In [7]:
def _build_shared_agent_experiment(play_params, exp_kwargs):
    """Build an Experiment in which one new Agent is both teacher and student.

    play_params: dict handed to Experiment as `play_params`; its
        'channel_size' entry is also used to construct the Agent.
    exp_kwargs: dict of extra keyword arguments forwarded to Experiment.
    """
    agent = Agent(play_params['channel_size'], NUM_CLASSES)
    return Experiment(
        generate_train_batch, generate_test_batch,
        play_params=play_params,
        student=agent,
        teacher=agent,
        loss_fn=complete_loss_fn,
        **exp_kwargs
    )


def create_self_play_experiment(p_mutate=0.4, channel_size=5, **exp_kwargs):
    """Create a self-play Experiment (same Agent as teacher and student).

    The play parameters contain only 'channel_size' and 'p_mutate'; any
    further keyword arguments are passed straight through to Experiment.
    """
    return _build_shared_agent_experiment(
        {
            'channel_size': channel_size,
            'p_mutate': p_mutate,
        },
        exp_kwargs,
    )


def create_unkind_experiment(p_mutate=0.4, channel_size=5, **exp_kwargs):
    """Create a self-play Experiment with 'kind_mutations' switched off.

    Identical to create_self_play_experiment except that the play
    parameters additionally carry 'kind_mutations': False.
    """
    return _build_shared_agent_experiment(
        {
            'channel_size': channel_size,
            'p_mutate': p_mutate,
            'kind_mutations': False,
        },
        exp_kwargs,
    )

## Investigating the effect of channel size:

In [8]:
# Root folder for this sweep. parents=True creates any missing ancestors, so a
# fresh checkout (no ./experiments directory yet) does not raise FileNotFoundError.
BASE_FOLDER = './experiments/kind_chan_size'
Path(BASE_FOLDER).mkdir(parents=True, exist_ok=True)
# EXPERIMENT_FOLDER = f'{BASE_FOLDER}/{datetime.now().strftime("%d-%m_%H-%M")}'

# EXPERIMENT_FOLDER = "./experiments/message_mutation_3"
# EXPERIMENT_FOLDER = "./experiments/message_mutation_1/24-09_09-40"
# EXPERIMENT_FOLDER = "./experiments/message_mutation_1/03-10_02-03"

# EXPERIMENT_FOLDER = "./experiments/message_mutation_1/05-10_03-00"
# NOTE(review): this path lives under 'kindness_chan_size', not under BASE_FOLDER
# ('kind_chan_size') -- confirm which directory is the intended one.
EXPERIMENT_FOLDER = './experiments/kindness_chan_size/08-10_01-16'
Path(EXPERIMENT_FOLDER).mkdir(parents=True, exist_ok=True)
EXPERIMENT_FOLDER

'./experiments/kindness_chan_size/08-10_01-16'

In [9]:
# Sweep `channel_size` over four values using the self-play experiment factory
# (which does not set 'kind_mutations'), with three experiments per value.
kind_experiment = VaryPlayParamExperiment(
    create_experiment_fn=create_self_play_experiment,
    param_name='channel_size',
    param_vals=[10, 15, 20, 30],
    num_experiments_per_val=3,
    max_epochs=250,
    save_location='./experiments/kind_chan_size',
)

In [29]:

# NOTE(review): `curr_exp` is not assigned at notebook scope by any cell visible here
# (the only assignment is a local inside measure_zero_shot_coordination_internal), so
# this inspection cell presumably relied on leftover kernel state -- confirm or remove.
curr_exp

<zscomm.meta_experiment.MetaExperiment at 0x1963a0d6ba8>

In [32]:
def measure_zero_shot_coordination(experiment_1,
                                   experiment_2,
                                   num_tests=5, 
                                   **zs_play_kwargs):
    """Cross-play two experiments in both directions and collect test metrics.

    First the teacher of `experiment_1` plays with the student of
    `experiment_2`, then the roles are swapped. In both directions the test
    batches come from `experiment_1.generate_test_batch` and the played games
    are summarised with `experiment_1.extract_test_metrics`.

    Args:
        experiment_1, experiment_2: objects exposing `teacher` and `student`;
            `experiment_1` must also expose `generate_test_batch` and
            `extract_test_metrics`.
        num_tests: forwarded to `test_game`.
        **zs_play_kwargs: extra keyword arguments forwarded to `test_game`.

    Returns:
        A two-element list: metrics for (teacher 1, student 2) followed by
        metrics for (teacher 2, student 1).
    """
    pairings = (
        (experiment_1, experiment_2),
        (experiment_2, experiment_1),
    )

    results = []
    for teacher_source, student_source in pairings:
        cross_play_games = test_game(teacher_source.teacher,
                                     student_source.student,
                                     experiment_1.generate_test_batch,
                                     num_tests=num_tests,
                                     **zs_play_kwargs)
        results.append(experiment_1.extract_test_metrics(cross_play_games))

    return results

def measure_zero_shot_coordination_internal():
    """Pairwise zero-shot evaluation inside the in-progress meta-experiment.

    Finds the first entry of `kind_experiment.experiments` whose status is
    'In Progress' and, for every unordered pair of its sub-experiments, runs
    `measure_zero_shot_coordination` twice:

    * with the first experiment's play params but 'p_mutate' forced to 0 and
      'message_permutation' forced to False ("vanilla" params);
    * with the first experiment's play params unchanged ("training" params).

    Takes no arguments because it is assigned directly onto meta-experiment
    instances and is therefore called without `self`.

    Returns:
        A list with one dict per pair, holding the keys
        'vanilla_params_test_metrics' and 'training_params_test_metrics'.

    Raises:
        ValueError: if no entry is currently 'In Progress'.
    """
    # Imported locally so the function does not rely on a star-import
    # happening to expose `combinations`.
    from itertools import combinations

    in_progress = [entry['experiment'] for entry in kind_experiment.experiments
                   if entry['status'] == 'In Progress']
    if not in_progress:
        raise ValueError("no meta-experiment with status 'In Progress' was found")
    curr_exp = in_progress[0]

    results = []
    # Iterate over the in-progress meta-experiment itself, not over whatever
    # a notebook-level loop variable happens to be bound to.
    for item_1, item_2 in combinations(curr_exp.experiments, 2):
        e1 = item_1['experiment']
        e2 = item_2['experiment']
        play_params = e1.get_play_params()

        vanilla_params_test_metrics = measure_zero_shot_coordination(
            e1, e2, **{
                **play_params,
                'p_mutate': 0, 'message_permutation': False
            }
        )
        training_params_test_metrics = measure_zero_shot_coordination(
            e1, e2, **play_params
        )

        results.append({
            'vanilla_params_test_metrics': vanilla_params_test_metrics,
            'training_params_test_metrics': training_params_test_metrics,
        })

    return results

# Install the notebook-defined pairwise routine on every meta-experiment entry.
# It is assigned on the instance (not the class), so it is called without `self`.
# NOTE(review): `item` stays bound at notebook scope after this loop; take care that
# later code does not pick it up by accident.
for item in kind_experiment.experiments:
    item['experiment'].measure_zero_shot_coordination = measure_zero_shot_coordination_internal

In [33]:
# Start the channel_size sweep configured above.
# NOTE: the recorded output shows this run was stopped with 1 of 4 values complete.
kind_experiment.run()

Running vary_pm_experiment...
meta_experiment_channel_size=10 results:  [0.23125, 0.3125, 0.38125, 0.35625, 0.36250000000000004, 0.2875]
Running experiment 2 (1/4 complete):
Run Stopped.


In [None]:
# Same channel_size sweep as the "kind" setup, but built with
# create_unkind_experiment (play params carry 'kind_mutations': False).
unkind_experiment = VaryPlayParamExperiment(
    create_experiment_fn=create_unkind_experiment,
    param_name='channel_size',
    param_vals=[10, 15, 20, 30],
    num_experiments_per_val=3,
    max_epochs=250,
    save_location='./experiments/unkind_chan_size',
)

In [4]:
# NOTE(review): `vary_pm_experiment` is not defined by any cell in this notebook (the
# recorded output is a NameError). Presumably this cell was carried over from a
# mutation-probability notebook and should read from kind_experiment /
# unkind_experiment instead -- confirm the variable and the 'p_mutate' key.

def _format_duration(seconds):
    """Render a duration in seconds as 'H hours, M mins and S seconds'."""
    return (f"{int(seconds / 3600)} hours, {int(seconds / 60) % 60} mins and "
            f"{int(seconds) % 60} seconds")


# Report training time per sub-experiment, then the total over the whole sweep.
grand_total = 0
for item_1 in vary_pm_experiment.experiments:
    for item_2 in item_1['experiment'].experiments:
        total_time = sum(
            x['seconds_taken'] for x in item_2['experiment'].training_history
        )
        print(_format_duration(total_time), 'taken for experiment',
              f"pm={item_1['p_mutate']}_{item_2['index']}")
        grand_total += total_time
    print()

print(_format_duration(grand_total), 'taken for whole experiment')

NameError: name 'vary_pm_experiment' is not defined

In [42]:
# One row per zero-shot coordination score of every completed meta-experiment.
# NOTE(review): `vary_pm_experiment` is not defined anywhere in this notebook, and the
# recorded run of this cell ended in "TypeError: list indices must be integers or
# slices, not str" -- presumably item['results'] is a list rather than a dict with a
# 'zs_coord_f1_scores' key. Confirm the results schema before relying on this cell.
zs_coord_df = pd.DataFrame([
    {
        'Mutation Probability': item['p_mutate'],
        'Zero-Shot Coordination Score': score,
    }
    for item in vary_pm_experiment.experiments
    if item['status'] == 'Complete'
    for score in item['results']['zs_coord_f1_scores']
])
zs_coord_df.head()

TypeError: list indices must be integers or slices, not str

In [None]:
def _self_play_rows(meta_items):
    """Yield one record per completed sub-experiment of each completed meta-experiment."""
    for meta_item in meta_items:
        if meta_item['status'] != 'Complete':
            continue
        for sub_item in meta_item['experiment'].experiments:
            if sub_item['status'] == 'Complete':
                yield {
                    'Mutation Probability': meta_item['p_mutate'],
                    'Self-play Performance': sub_item['results']['mean_ground_truth_f1'],
                }


# Self-play F1 of each finished sub-experiment, keyed by its meta-experiment's p_mutate.
self_play_df = pd.DataFrame(list(_self_play_rows(vary_pm_experiment.experiments)))

In [None]:
# Re-apply the plotting theme before the figures below (same two calls as in the
# import cell): restore default rc settings, then set seaborn's default theme.
sns.reset_defaults()
sns.set()

In [None]:
plt.figure(figsize=(8, 4))

# Dashed horizontal baseline at 1/NUM_CLASSES; it is drawn from x=-2 to x=2 so it
# spans the whole visible x-range set further down.
ax = sns.lineplot(x=[-2, 2], y=[1/NUM_CLASSES, 1/NUM_CLASSES],
                  color=(0.1, 0.1, 0.1, 0.5), label='Baseline')
ax.lines[0].set_linestyle("--")

# Mean line plus raw points for zero-shot scores, then the same for self-play.
sns.lineplot(data=zs_coord_df, x='Mutation Probability',
             y='Zero-Shot Coordination Score', label='Zero-shot Performance', ax=ax)
sns.scatterplot(data=zs_coord_df, x='Mutation Probability',
                y='Zero-Shot Coordination Score', marker='x', ax=ax)
sns.lineplot(data=self_play_df, x='Mutation Probability',
             y='Self-play Performance', label='Self-play Performance', ax=ax)
sns.scatterplot(data=self_play_df, x='Mutation Probability',
                y='Self-play Performance', marker='x', ax=ax)

# A single point far outside the axis limits, plotted only so that a grey
# 'Raw Data' marker appears in the legend.
sns.scatterplot(x='x', y='y', data=pd.DataFrame([{'x': 100, 'y': 100}]),
                color=(0.1, 0.1, 0.1, 0.5), marker='x', label='Raw Data', ax=ax)

ax.set_ylim([0, 1.05])
ax.set_xlim([-.05, 1.05])
ax.set_title('The Effect of Mutations on Zero-Shot Coordination')
ax.set_ylabel('Performance')
ax.set_xlabel('Mutation Probability')
ax.legend(loc=4)
plt.show()

In [None]:
# Summary statistics of the zero-shot scores at mutation probability 0.4
# (exact float match on the column value).
zs_coord_df.loc[zs_coord_df['Mutation Probability'] == 0.4].describe()

In [None]:
# Summary statistics of the zero-shot scores at mutation probability 1.0
# (exact float match on the column value).
zs_coord_df.loc[zs_coord_df['Mutation Probability'] == 1.0].describe()

In [None]:
def _train_loss_rows(meta_items):
    """Yield one record per training-history entry of every sub-experiment."""
    for meta_item in meta_items:
        for sub_item in meta_item['experiment'].experiments:
            history = sub_item['experiment'].training_history
            for epoch, training_item in enumerate(history):
                yield {
                    'Epoch': epoch,
                    'Experiment': f"$p_m={meta_item['p_mutate']}$",
                    'Subexperiment': f'subexperiment_{sub_item["index"]}',
                    'Train Loss': training_item['loss']
                }


# Long-format training loss: the epoch is the position within training_history.
df_train = pd.DataFrame(list(_train_loss_rows(vary_pm_experiment.experiments)))
df_train.head()

In [None]:
# Training loss per epoch, one line per 'Experiment' label.
plt.figure(figsize=(8, 4))
sns.lineplot(
    data=df_train, x='Epoch', y='Train Loss', hue='Experiment',
).set_title('Training History by Mutation Probability')
plt.show()

In [None]:
def _test_metric_rows(meta_items):
    """Yield one record per training-history entry that carries 'test_metrics'."""
    for meta_item in meta_items:
        for sub_item in meta_item['experiment'].experiments:
            history = sub_item['experiment'].training_history
            for epoch, training_item in enumerate(history):
                if 'test_metrics' not in training_item:
                    continue
                metrics = training_item['test_metrics']
                yield {
                    'Epoch': epoch,
                    'Experiment': f"$p_m={meta_item['p_mutate']}$",
                    'Subexperiment': f'subexperiment_{sub_item["index"]}',
                    'Performance': metrics['mean_ground_truth_f1'],
                    'Protocol Diversity': metrics['mean_protocol_diversity'],
                }


# Long-format test metrics; entries without 'test_metrics' are skipped.
df_test = pd.DataFrame(list(_test_metric_rows(vary_pm_experiment.experiments)))

In [None]:
# Self-play test performance per epoch, one line per 'Experiment' label.
plt.figure(figsize=(8, 4))
sns.lineplot(
    data=df_test, x='Epoch', y='Performance', hue='Experiment',
).set_title('Self-play Test Performance History')
plt.show()

In [None]:
# Left panel: test performance per epoch; right panel: protocol diversity per epoch.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
sns.lineplot(x='Epoch', y='Performance', hue='Experiment', 
             data=df_test, ax=axs[0]);
sns.lineplot(x='Epoch', y='Protocol Diversity', hue='Experiment', 
             data=df_test, ax=axs[1]);
# Title the whole figure: plt.title() would only label the current axes, which
# after plt.subplots is the last one created (the right-hand panel).
fig.suptitle('Self-play Test Performance')
plt.show()

In [None]:
# Pick one sub-experiment by hard-coded position (third meta-experiment, second
# sub-experiment), re-run its tests and display the resulting metrics.
exp = (
    vary_pm_experiment
    .experiments[2]['experiment']
    .experiments[1]['experiment']
)
games_played, test_metrics = exp.run_tests()
test_metrics

In [None]:
# Plot the first five recorded games (this indexes games_played[0..4], so it
# needs at least five entries).
for i in range(5):
    # Each entry unpacks as (inputs, targets, outputs); note that plot_game is
    # called with outputs before targets.
    inputs, targets, outputs = games_played[i]
    plot_game(inputs, outputs, targets, select_batch=2)

In [None]:
# Grid dimensions: number of meta-experiments, and sub-experiments per
# meta-experiment (taken from the first one).
num_meta_experiments = len(vary_pm_experiment.experiments)
exps_per_meta = len(vary_pm_experiment.experiments[0]['experiment'].experiments)

def make_cm_map_for_exp(i, j):
    """Re-run tests for sub-experiment j of meta-experiment i and return its
    mean class-message map."""
    sub_exp = (
        vary_pm_experiment.experiments[i]['experiment']
        .experiments[j]['experiment']
    )
    games, _metrics = sub_exp.run_tests()
    return create_mean_class_message_map(games)

# class_message_maps[i][j] is the map for sub-experiment j of meta-experiment i.
class_message_maps = [
    [make_cm_map_for_exp(meta_idx, sub_idx) for sub_idx in range(exps_per_meta)]
    for meta_idx in range(num_meta_experiments)
]

In [None]:
def plot_protocol_maps(maps, ylabel='Class', yticklabels=None):
    """Draw a grid of heatmaps, one row per meta-experiment and one column per
    sub-experiment.

    Args:
        maps: nested sequence indexed as maps[row][col], each entry being the
            matrix to show in that grid position.
        ylabel: label appended to the y-axis text of the first column.
        yticklabels: tick labels for the first column; when falsy, the labels
            1..NUM_CLASSES are used.

    Reads the notebook globals `num_meta_experiments`, `exps_per_meta`,
    `vary_pm_experiment` and `NUM_CLASSES`. Shows the figure; returns None.
    """
    cells_per_heatmap = 5
    n_rows, n_cols = num_meta_experiments, exps_per_meta

    plt.figure(figsize=(2*2*n_cols, 2*n_rows))
    # One extra grid column at the far right holds each row's colour bar.
    grid = gridspec.GridSpec(n_rows, cells_per_heatmap*n_cols + 1)

    if not yticklabels:
        yticklabels = [k + 1 for k in range(NUM_CLASSES)]

    final_row, final_col = n_rows - 1, n_cols - 1

    for row in range(n_rows):
        meta_exp = vary_pm_experiment.experiments[row]['experiment']
        for col in range(n_cols):
            sub_exp = meta_exp.experiments[col]['experiment']
            left = cells_per_heatmap * col
            ax = plt.subplot(grid[row, left:cells_per_heatmap*(col + 1)])

            # Only the last heatmap of each row draws a colour bar.
            show_cbar = col == final_col
            cbar_ax = plt.subplot(grid[row, -1]) if show_cbar else None

            sns.heatmap(maps[row][col], vmin=0, vmax=1, ax=ax,
                        cbar=show_cbar, cbar_ax=cbar_ax)

            # y-axis: labelled on the first column only.
            if col != 0:
                ax.set_yticks([])
            else:
                p_mutate = sub_exp.get_play_params().get('p_mutate', 0.0)
                ax.set_ylabel(f'$p_m = {p_mutate}$\n\n{ylabel}')
                ax.set_yticklabels(yticklabels)

            # x-axis: titles on the top row, 'Symbol' label on the bottom row
            # (unless it is also the top row), ticks hidden everywhere else.
            if row == final_row and row != 0:
                ax.set_xlabel('Symbol')
            else:
                if row == 0:
                    ax.set_title(f'Experiment {col+1}')
                ax.set_xticks([])

    plt.tight_layout()
    plt.show()

In [None]:
# Class-vs-symbol heatmaps for every sub-experiment (default 'Class' y-axis labels).
plot_protocol_maps(class_message_maps)

In [None]:
def make_im_map_for_exp(i, j):
    """Re-run tests for sub-experiment j of meta-experiment i and return its
    mean index-message map."""
    sub_exp = (
        vary_pm_experiment.experiments[i]['experiment']
        .experiments[j]['experiment']
    )
    games, _metrics = sub_exp.run_tests()
    return create_mean_index_message_map(games)

# index_message_maps[i][j] is the map for sub-experiment j of meta-experiment i.
index_message_maps = [
    [make_im_map_for_exp(meta_idx, sub_idx) for sub_idx in range(exps_per_meta)]
    for meta_idx in range(num_meta_experiments)
]

In [None]:
# Time-step-vs-symbol heatmaps for every sub-experiment; the tick labels 0..2 are
# hard-coded here.
plot_protocol_maps(index_message_maps, ylabel='Time Step Index', yticklabels=[0, 1, 2])

In [None]:
# Two-row figure for the first meta-experiment only: row 0 shows its class-message
# maps, row 1 its index-message maps, one column per sub-experiment.
n_rows = 2
fig = plt.figure(figsize=(2*2*exps_per_meta, 2*n_rows))

sqrs_per_plot = 5
# One extra grid column at the far right holds each row's colour bar.
gs = gridspec.GridSpec(n_rows, sqrs_per_plot*exps_per_meta+1)

# NOTE(review): `yticklabels` is assigned but not used below; the tick labels are
# hard-coded in the loop.
yticklabels = [i+1 for i in range(NUM_CLASSES)]


# maps[0] = class-message maps, maps[1] = index-message maps (first meta-experiment).
maps = [class_message_maps[0]] + [index_message_maps[0]]

# NOTE(review): the row loop uses the literal 2 rather than n_rows.
for i in range(2):
    for j in range(exps_per_meta):
        
        # Only the last heatmap of each row draws a colour bar.
        last_col = j == exps_per_meta - 1
        cbar_ax = plt.subplot(gs[i, -1]) if last_col else None
        ax = plt.subplot(gs[i, sqrs_per_plot*j:sqrs_per_plot*(j+1)])

        sns.heatmap(maps[i][j], vmin=0, vmax=1, ax=ax, 
                    cbar=last_col, cbar_ax=cbar_ax);

        # y-axis labels only on the first column; the label depends on the row.
        if j == 0 and i == 0: 
            ax.set_ylabel('Class')
            ax.set_yticklabels([1, 2, 3])
        elif j == 0 and i == 1:
            ax.set_ylabel('Time Step Index')
            ax.set_yticklabels([0, 1, 2])
        else:
            ax.set_yticks([])

        # Titles on the top row, 'Symbol' x-label on the bottom row.
        if i == 0:
            ax.set_title(f'Experiment {j+1}')
            ax.set_xticks([])
        elif i == 1:
            ax.set_xlabel('Symbol')

plt.tight_layout()
plt.show();

In [None]:
# Transposed protocol heatmaps (symbols on the y-axis) for the first sub-experiment:
# class-message map on the left, index-message map on the right.
# Depends on `maps` built in the preceding cell.
fig, axs = plt.subplots(1, 3, figsize=(5, 3))
cols = 11
gs = gridspec.GridSpec(1, cols)

ax = plt.subplot(gs[0, :cols//2])
cbar_ax = plt.subplot(gs[0, -1])

# cbar is set explicitly rather than read from a loop variable left over from
# another cell, so this cell does not depend on leftover kernel state.
# NOTE(review): this heatmap draws its colour bar into axs[2] from plt.subplots,
# while the second one uses the gridspec `cbar_ax` -- confirm whether a single
# shared colour bar was intended.
sns.heatmap(tf.transpose(maps[0][0]), vmin=0, vmax=1, ax=ax, 
            cbar=True, cbar_ax=axs[2]);

ax.set_xlabel('Class')
ax.set_xticklabels([1, 2, 3])
ax.set_ylabel('Symbol')

ax = plt.subplot(gs[0, cols//2:-1])

sns.heatmap(tf.transpose(maps[1][0]), vmin=0, vmax=1, ax=ax, 
            cbar=True, cbar_ax=cbar_ax);
ax.set_xlabel('Time Step Index')
ax.set_xticklabels([0, 1, 2])
# The y-axis ('Symbol') is shared with the left panel, so its ticks are hidden here.
ax.set_yticks([])

plt.show()

In [None]:
# Confusion matrix over the games returned by the earlier run_tests() call,
# drawn as an annotated heatmap on a fixed 0..1 colour scale.
conf_matrix = compute_confusion_matrix(games_played)
sns.heatmap(conf_matrix, annot=True, vmin=0, vmax=1).set(
    title='Ground Truth Confusion Matrix',
    ylabel='Predicted Class',
    xlabel='Actual Class',
)
plt.show()