## Imports

In [1]:
import numpy as np
from tqdm.notebook import trange

from khan_helpers import Experiment
from khan_helpers.constants import FIG_DIR, N_PARTICIPANTS

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

Experiment & Participant classes, helper functions, and variables used across multiple notebooks can be found in `/mnt/code/khan_helpers/khan_helpers`, or on GitHub, [here](https://github.com/contextlab/efficient-learning-khan/tree/master/code/khan_helpers).<br />You can also view source code directly from the notebook with:<br /><pre>    from khan_helpers.functions import show_source<br />    show_source(foo)</pre>

In [2]:
# Instantiate the project's Experiment helper; the cells below read the lecture/question
# embeddings and the per-participant objects from this single `exp` instance.
exp = Experiment()

## Set plotting params

In [3]:
# Figure-wide styling used by every plot in this notebook.
# pdf.fonttype 42 makes matplotlib embed TrueType fonts in saved PDFs.
plt.rcParams.update({'pdf.fonttype': 42})
sns.set_context('paper')
# colormap name passed to every heatmap/imshow call below
cmap = 'bone'

## Uniformly shift and rescale data for plotting over heatmap

In [4]:
# Number of grid points along each axis of the plotting grid; the shifted
# coordinates below are rescaled so that they span [0, resolution].
resolution = 100

# Stack all lecture and question coordinates to find the bounding box of the data
# (two columns, unpacked below as x and y).
embeddings = np.vstack((exp.forces_embedding,
                        exp.bos_embedding,
                        exp.question_embeddings))
# Floor each extreme to a whole number and pad by 3 units on every side
x_min, y_min = embeddings.min(axis=0) // 1 - 3
x_max, y_max = embeddings.max(axis=0) // 1 + 3

# Build the grid with np.linspace rather than np.arange: with a fractional step,
# np.arange can return one element too many (or too few) because of floating-point
# rounding, whereas linspace(endpoint=False) always yields exactly `resolution`
# evenly spaced values in [min, max). `retstep=True` also returns the spacing,
# i.e. (max - min) / resolution.
xs, x_step = np.linspace(x_min, x_max, resolution, endpoint=False, retstep=True)
ys, y_step = np.linspace(y_min, y_max, resolution, endpoint=False, retstep=True)
# Every (x, y) grid vertex, with x varying fastest
vertices = np.array([(x_coord, y_coord) for y_coord in ys for x_coord in xs])

# Shift and rescale the embeddings *together with* the grid vertices so that all of
# them undergo the same transformation: the minimum over everything maps to 0 and
# the maximum maps to `resolution` along each axis.
all_coords = [exp.forces_embedding, exp.bos_embedding, exp.question_embeddings, vertices]
# Row offsets at which the stacked array is split back into its four parts
split_inds = np.cumsum([arr.shape[0] for arr in all_coords])[:-1]
# np.vstack returns a new array, so the in-place ops below do not touch `exp`'s data
shifted = np.vstack(all_coords)
shifted -= shifted.min(axis=0)
shifted /= (shifted.max(axis=0) / resolution)

# Drop the last chunk (the grid vertices); they were only needed to fix the scaling
forces, bos, questions = np.vsplit(shifted, split_inds)[:-1]

## Plot individual knowledge & learning maps

*NOTE: PLOT TAKES ~20 MINUTES TO GENERATE*

In [10]:
# NOTE: plot takes ~20 minutes to generate


def _draw_map(ax, data, vmin, vmax, cbar_ax=None):
    """Draw `data` on `ax` as a rasterized image, optionally with a colorbar.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw the map on.
    data : 2D array-like
        Map values to display.
    vmin, vmax : float
        Color limits, used for both the colorbar and the image.
    cbar_ax : matplotlib Axes, optional
        Axes to draw the colorbar in. If None, no colorbar is drawn.
    """
    # sns.heatmap sets up the axes (no ticks, inverted y-axis) and the colorbar
    sns.heatmap(data,
                vmin=vmin,
                vmax=vmax,
                xticklabels=[],
                yticklabels=[],
                cmap=cmap,
                cbar=cbar_ax is not None,
                ax=ax,
                cbar_ax=cbar_ax)
    # rasterize: replace the heatmap's mesh with an image of the same data
    ax.collections[0].remove()
    ax.imshow(data, vmin=vmin, vmax=vmax, aspect='auto', cmap=cmap)


def _scatter_questions(ax, qset_data):
    """Scatter the questions in `qset_data` on `ax`, colored by lecture group.

    Uses the shifted question coordinates in the notebook-level `questions` array.
    """
    # NOTE(review): `GroupBy.groups` maps each group key to the *row index labels* of
    # that group, not to the values of the 'qID' column -- this presumably relies on
    # the data being indexed by qID; confirm against `Participant.get_data`.
    # NOTE(review): the unpacking also assumes exactly three 'lecture' groups whose
    # sorted order is (g, f, b) -- presumably general, forces, bos; confirm.
    g_qIDs, f_qIDs, b_qIDs = qset_data.groupby('lecture')['qID'].groups.values()
    # IDs are shifted down by 1 to index rows of `questions`; draw order is kept as
    # red, green, blue so overlapping markers stack the same way everywhere
    for qIDs, color in ((f_qIDs, 'r'), (b_qIDs, 'g'), (g_qIDs, 'b')):
        ax.scatter(questions[qIDs - 1, 0], questions[qIDs - 1, 1],
                   c=color, marker='o', s=50, alpha=.7)


# each participant occupies 4 grid rows: 2 for the knowledge maps, 2 for the learning maps
n_rows = N_PARTICIPANTS * 4
n_cols = 6

fig = plt.figure(figsize=(14, 8 * N_PARTICIPANTS))
# set up subplot layout
axarr = []
for sub_n in range(N_PARTICIPANTS):
    row_ix = sub_n * 4
    # row of 3 knowledge maps
    kmap1 = plt.subplot2grid((n_rows, n_cols), (row_ix, 0), colspan=2, rowspan=2)
    kmap2 = plt.subplot2grid((n_rows, n_cols), (row_ix, 2), colspan=2, rowspan=2)
    kmap3 = plt.subplot2grid((n_rows, n_cols), (row_ix, 4), colspan=2, rowspan=2)
    # row of 2 learning maps, offset by one column so they sit centered under the row above
    lmap1 = plt.subplot2grid((n_rows, n_cols), (row_ix + 2, 1), colspan=2, rowspan=2)
    lmap2 = plt.subplot2grid((n_rows, n_cols), (row_ix + 2, 3), colspan=2, rowspan=2)
    axarr.append([kmap1, kmap2, kmap3, lmap1, lmap2])

# call tight_layout before adding colorbar axes so bounding boxes are set properly
plt.tight_layout()

# hack to add colorbar axes to each row without cutting into rightmost heatmap
cbar_axarr = []
for sub_axes in axarr:
    kmap3_bbox = sub_axes[2].get_position()
    lmap2_bbox = sub_axes[4].get_position()
    # each colorbar sits just right of its row's last map: same bottom/height,
    # 5% of the map's width
    kmap_cax = fig.add_axes([
        kmap3_bbox.xmax * 1.01,
        kmap3_bbox.y0,
        kmap3_bbox.width * 0.05,
        kmap3_bbox.height
    ])
    lmap_cax = fig.add_axes([
        lmap2_bbox.xmax * 1.01,
        lmap2_bbox.y0,
        lmap2_bbox.width * 0.05,
        lmap2_bbox.height
    ])
    cbar_axarr.append([kmap_cax, lmap_cax])

# loop over participants/axes/colorbar axes
for sub_n in trange(N_PARTICIPANTS):
    sub = exp.participants[sub_n]
    sub_axes = axarr[sub_n]
    sub_kmap_axes, sub_lmap_axes = sub_axes[:3], sub_axes[3:]
    kmap_cbar_ax, lmap_cbar_ax = cbar_axarr[sub_n]
    # one knowledge map per question set; each learning map is the difference
    # between two consecutive knowledge maps
    kmaps = [sub.get_kmap(f'forces_bos_qset{qset}') for qset in range(3)]
    lmaps = [kmaps[qset + 1] - kmaps[qset] for qset in range(2)]

    # plot knowledge maps
    for i, kmap in enumerate(kmaps):
        ax = sub_kmap_axes[i]
        # colorbar is drawn only for the last plot in the row
        _draw_map(ax, kmap, vmin=0, vmax=1,
                  cbar_ax=kmap_cbar_ax if i == 2 else None)

        # plot questions from current question set
        _scatter_questions(ax, sub.get_data(qset=i))

        # overlay previously viewed lectures
        if i > 0:
            ax.plot(forces[:, 0], forces[:, 1], 'r--', alpha=.7, linewidth=2)
        if i == 2:
            ax.plot(bos[:, 0], bos[:, 1], 'g--', alpha=.7, linewidth=2)

        ax.set_title(f'P{sub_n + 1} knowledge: question block {i + 1}',
                     fontsize='large')

        # undo automatic y-axis inversion from sns.heatmap
        ax.invert_yaxis()

    # plot learning maps
    for i, lmap in enumerate(lmaps):
        ax = sub_lmap_axes[i]
        # colorbar is drawn only for the last plot in the row
        _draw_map(ax, lmap, vmin=-1, vmax=1,
                  cbar_ax=lmap_cbar_ax if i == 1 else None)

        # plot questions from "before" and "after" question sets
        _scatter_questions(ax, sub.get_data(qset=range(i, i + 2)))

        # overlay lecture viewed between the two question sets
        if i == 0:
            ax.plot(forces[:, 0], forces[:, 1], 'r--', alpha=.7, linewidth=2)
        else:
            ax.plot(bos[:, 0], bos[:, 1], 'g--', alpha=.7, linewidth=2)

        ax.set_title(fr'P{sub_n + 1} learning: question block {i + 1} $\rightarrow$ {i + 2}',
                     fontsize='large')
        # undo automatic y-axis inversion from sns.heatmap
        ax.invert_yaxis()

# NOTE: DPI must be <= (2 ** 16) / 400 for matplotlib to save file
fig.savefig(FIG_DIR.joinpath('individual_maps_equal_scale.pdf'), bbox_inches='tight', dpi=150)
# figure is too large to display in notebook without crashing Jupyter kernel
plt.close(fig)

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))


