#  Import packages

In [None]:
import numpy as np
import h5py
import glob
import re
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.stats import pearsonr

import sys
sys.path.append('./..')
from src.training_utils import data_load, extract_floats, split_dataset, predict_multi_by_name, plot_violin_and_statistics,cross_mean_err_calculator

from tensorflow import keras
from keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Conv2D,Flatten,Dropout,MaxPooling2D,BatchNormalization,AveragePooling2D,LeakyReLU,GlobalAveragePooling2D,ReLU

from cmcrameri import cm
import seaborn as sns
import pandas as pd

# Print NumPy floats with 3 digits of precision and suppress scientific notation for small values.
np.set_printoptions(precision=3, suppress=True)

# Set seed (optional)

In [None]:
# Seed used for reproducible runs. Comment this assignment out to run unseeded.
# NOTE(review): in a live kernel the name survives from an earlier execution of
# this cell, so restart the kernel after commenting it out.
fixed_seed = 216

if 'fixed_seed' not in locals():
    # No seed defined: leave the random number generators untouched.
    print("Running program with random seed.")
else:
    # Hand the seed to Keras before anything random is executed.
    keras.utils.set_random_seed(fixed_seed)
    print("Running program with fixed seed:", fixed_seed)


# Setup GPU

First, follow instructions [here](https://gist.github.com/zrruziev/b93e1292bf2ee39284f834ec7397ee9f), or alternatively run:
```bash
for a in /sys/bus/pci/devices/*; do echo 0 | sudo tee -a $a/numa_node; done
```
We do this as a workaround for [this error](https://github.com/tensorflow/tensorflow/issues/42738):

In [None]:
# Turn on memory growth for every GPU TensorFlow can see.
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Report the visible GPUs, the default GPU device name and the TensorFlow version.
visible_gpus = tf.config.list_physical_devices("GPU")
default_gpu_name = tf.test.gpu_device_name()
print(visible_gpus, default_gpu_name)
print("TF Version:", tf.__version__)

# Define Functions

In [None]:
def violin_plotter(v, y_val, adjustment, legloc="upper left"):
    """Plot per-class prediction errors and print summary statistics.

    Creates a two-row figure. The upper panel is a violin plot of the signed
    prediction difference ``v - y_val`` grouped by the true value. The lower
    panel shows, for each unique true value, the absolute difference between
    the mean prediction and the true value, with the standard deviation of
    the predictions as error bars and a dotted zero line for reference.

    For every unique true value the mean and standard deviation of the
    matching predictions are printed, followed by the overlap ratio, the
    (min, max, mean) standard deviation and Pearson's correlation
    coefficient between ``y_val`` and ``v``.

    Args:
        v: 1-D array-like of predicted values, aligned element-wise with
            ``y_val``.
        y_val: 1-D array-like of true values; its unique values define the
            groups that are plotted and summarised.
        adjustment: y-coordinate on the lower panel at which the
            "<n> sigma" labels (distance of the true value from the mean
            prediction, in standard deviations) are written.
        legloc: Legend location for the lower panel.

    Returns:
        None. The figure is left as the current matplotlib figure.
    """
    v = np.asarray(v)
    y_val = np.asarray(y_val)

    # np.unique already returns the values sorted, so compute them once.
    unique_vals = np.unique(y_val)

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(9, 6), dpi=600)

    df = pd.DataFrame({"predicted": v - y_val, "actual": y_val})

    sns.violinplot(
        ax=ax1,
        data=df,
        x="actual",
        y="predicted",
        color="w",
        alpha=0.7,
        density_norm="width",
        linewidth=1,
        inner="box",
        inner_kws={"box_width": 4, "color": "0.2"},
    )

    ax1.set_xlabel("Actual turning rate")
    ax1.set_ylabel(r"Prediction Difference $P_{pred}-P_{true}$")

    std = []
    means = []
    overlap = []
    std_div = []
    # Tolerance used when checking whether the true value lies inside the
    # [min, max] range of its predictions.
    accuracy = 5e-3
    print ("Prediction means and standard deviations.")
    for val in unique_vals:
        v_mapped = v[y_val == val]
        stdev = np.std(v_mapped)
        mean = np.mean(v_mapped)
        std.append(stdev)
        means.append(mean)
        overlap.append((val + accuracy >= np.min(v_mapped)) & (val - accuracy <= np.max(v_mapped)))
        # Distance of the true value from the mean prediction in units of
        # the standard deviation (inf/nan if all predictions are identical).
        within_std = abs(val-mean)/stdev
        print (f"Actual value {val}: Average = {mean:.5f} +- {stdev:.5f}; Expected value within {within_std:.3f} stdevs of mean")
        std_div.append(within_std)

    print(f"With accuracy {accuracy}, overlap ratio:", np.sum(overlap)/len(overlap))
    print("(Min, Max, Avg) STD:", np.min(std), np.max(std), np.mean(std))
    print("Pearson's correlation coeff: ", pearsonr(y_val, v).statistic)

    ax2.errorbar(
        unique_vals,
        np.abs(np.asarray(means) - unique_vals),
        yerr=std,
        ecolor='black',
        elinewidth=0.5,
        capsize=3,
        color='purple',
        label=r'$|\langle P_{pred} \rangle -P_{true}|$',
    )
    ax2.plot(
        unique_vals,
        np.zeros(unique_vals.shape[0]),
        color='red',
        label='True value line',
        linestyle='dotted',
        alpha=0.5,
    )

    ax2.legend(loc=legloc)

    # Raw f-string: "\sigma" is mathtext, not a Python escape sequence.
    for x_pos, n_sigma in zip(unique_vals, std_div):
        ax2.text(x_pos, adjustment, rf"${n_sigma:.3f} \sigma$", ha="center")

    # The formatter must be set after set_xscale, which installs its own.
    ax2.set_xscale("log")
    ax2.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
    ax2.set_xticks(unique_vals)

    ax2.set_xlabel("Actual turning rate")
    ax2.set_ylabel("Absolute mean prediction difference")

    fig.tight_layout()

# Import and prepare data

Dataset 1 is loaded with orientation information, dataset 2 is monochrome (no orientation), and dataset 3 is scrambled.

In [None]:
# Reference values:
#   training alphas:    [0.016,0.023,0.034,0.050,0.073,0.107,0.157,0.231,0.340,0.500]
#   all densities:      [0.05,0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.60,0.65,0.70,0.75,0.80,0.85,0.90,0.95]
#   midway interpolation alphas (midpoints of the training alphas):
#                       [0.019,0.028,0.042,0.061,0.090,0.132,0.194,0.286,0.420]
# NOTE: this cell loads the finer, "more arbitrary" interpolation set below,
# not the midpoints. Swap the list to use the midway set instead.
interpolation_alphas = [
    0.017, 0.019, 0.020, 0.022, 0.024, 0.027, 0.029, 0.032, 0.035,
    0.038, 0.042, 0.045, 0.054, 0.059, 0.065, 0.071, 0.077, 0.085,
    0.092, 0.101, 0.110, 0.121, 0.132, 0.144, 0.172, 0.188, 0.206,
    0.225, 0.246, 0.268, 0.293, 0.321, 0.350, 0.383, 0.419, 0.457,
]
interpolation_path = "../data/no-rolling/interpolation-set/"

# Each call receives its own copy of the alpha list, exactly as when the
# literal was written out three times.
# Set 1: with orientation.
x1, y1, shape1 = data_load(
    alphas=list(interpolation_alphas),
    densities=[0.25],
    orientation=True,
    scrambled=False,
    path=interpolation_path,
)
# Set 2: monochrome (no orientation, not scrambled).
x2, y2, shape2 = data_load(
    alphas=list(interpolation_alphas),
    densities=[0.25],
    orientation=False,
    scrambled=False,
    path=interpolation_path,
)
# Set 3: scrambled.
x3, y3, shape3 = data_load(
    alphas=list(interpolation_alphas),
    densities=[0.25],
    orientation=False,
    scrambled=True,
    path=interpolation_path,
)

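We have N × (number of unique alphas) snapshots in total. Such a set would normally be split into a training set and a validation set with an 80/20 ratio; in this notebook the split point is set to the full length of the data, so nothing is held back for training and every snapshot is used for validation: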

In [None]:
print("Orientation model:")
# Fraction of each dataset passed as the `last` split point. With 1 the whole
# set is used for validation and nothing for training.
# NOTE(review): the exact meaning of `last` is defined by split_dataset in
# src.training_utils; confirm there before changing the fraction.
validation_fraction = 1

# Each dataset is split at its own length (previously len(x1) was reused for
# all three, which is only correct if the sets happen to have equal size).
x_train1, y_train1, x_val1, y_val1 = split_dataset(x1, y1, last=int(len(x1) * validation_fraction))
x_train2, y_train2, x_val2, y_val2 = split_dataset(x2, y2, last=int(len(x2) * validation_fraction))
x_train3, y_train3, x_val3, y_val3 = split_dataset(x3, y3, last=int(len(x3) * validation_fraction))



In [None]:
# Show the same snapshot index from each of the three validation sets side by
# side, using a 5-level "gnuplot" colormap.
sample_index = 500
preview_cmap = plt.get_cmap(name="gnuplot", lut=5)

fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3)
for axis, snapshots in zip((ax1, ax2, ax3), (x_val1, x_val2, x_val3)):
    axis.matshow(snapshots[sample_index], cmap=preview_cmap)

# Predict multiple models

In [None]:
# Names of the trained models in each family (ten per family).
models_one = [
    'orientation0216', 'orientation0226', 'orientation0236', 'orientation0246',
    'orientation0256', 'orientation0266', 'orientation0276', 'orientation0286',
    'orientation0296', 'orientation0306',
]
models_two = [
    'monochrome0216', 'monochrome0226', 'monochrome0236', 'monochrome0246',
    'monochrome0256', 'monochrome0266', 'monochrome0276', 'monochrome0286',
    'monochrome0296', 'monochrome0306',
]
models_three = [
    'scrambled0216', 'scrambled0226', 'scrambled0236', 'scrambled0246',
    'scrambled0256', 'scrambled0266', 'scrambled0276', 'scrambled0286',
    'scrambled0296', 'scrambled0306',
]

# Naming convention: <A>_pred_of_<B> / <A>_actuals_of_<B> are the predictions
# and matching true values of model family A evaluated on dataset B.

# Family one (orientation models) on all three datasets.
one_pred_of_one, one_actuals_of_one = predict_multi_by_name(models_one, x_val1, y_val1)
one_pred_of_two, one_actuals_of_two = predict_multi_by_name(models_one, x_val2, y_val2)
one_pred_of_three, one_actuals_of_three = predict_multi_by_name(models_one, x_val3, y_val3)

# Family two (monochrome models): own dataset first, then the other two.
two_pred_of_two, two_actuals_of_two = predict_multi_by_name(models_two, x_val2, y_val2)
two_pred_of_one, two_actuals_of_one = predict_multi_by_name(models_two, x_val1, y_val1)
two_pred_of_three, two_actuals_of_three = predict_multi_by_name(models_two, x_val3, y_val3)

# Family three (scrambled models): own dataset first, then the other two.
three_pred_of_three, three_actuals_of_three = predict_multi_by_name(models_three, x_val3, y_val3)
three_pred_of_one, three_actuals_of_one = predict_multi_by_name(models_three, x_val1, y_val1)
three_pred_of_two, three_actuals_of_two = predict_multi_by_name(models_three, x_val2, y_val2)

# Plot orientation results

## Predict one on one

In [None]:
# Family-one models evaluated on dataset one; sigma labels placed at y = -0.001.
violin_plotter(v=one_pred_of_one, y_val=one_actuals_of_one, adjustment=-0.001)

## Predict one on two

In [None]:
# Family-one models evaluated on dataset two; sigma labels placed at y = 0.04.
violin_plotter(v=one_pred_of_two, y_val=one_actuals_of_two, adjustment=0.04, legloc='lower left')

## Predict one on three

In [None]:
# Family-one models evaluated on dataset three; sigma labels placed at y = 0.04.
violin_plotter(v=one_pred_of_three, y_val=one_actuals_of_three, adjustment=0.04, legloc='lower left')

# Plot monochrome results

## Predict two on two

In [None]:
# Family-two models evaluated on dataset two; sigma labels placed at y = -0.002.
violin_plotter(v=two_pred_of_two, y_val=two_actuals_of_two, adjustment=-0.002)

## Predict two on one

In [None]:
# Family-two models evaluated on dataset one; sigma labels placed at y = 0.07.
violin_plotter(v=two_pred_of_one, y_val=two_actuals_of_one, adjustment=0.07, legloc='lower left')

## Predict two on three

In [None]:
# Family-two models evaluated on dataset three; sigma labels placed at y = 0.07.
violin_plotter(v=two_pred_of_three, y_val=two_actuals_of_three, adjustment=0.07, legloc='lower left')

# Plot Scrambled Results

## Predict three on three

In [None]:
# Family-three models evaluated on dataset three; sigma labels placed at y = 0.07.
violin_plotter(v=three_pred_of_three, y_val=three_actuals_of_three, adjustment=0.07, legloc='lower left')

## Predict three on one

In [None]:
# Family-three models evaluated on dataset one; sigma labels placed at y = 0.07.
violin_plotter(v=three_pred_of_one, y_val=three_actuals_of_one, adjustment=0.07, legloc='lower left')

## Predict three on two

In [None]:
# Family-three models evaluated on dataset two; sigma labels placed at y = 0.07.
violin_plotter(v=three_pred_of_two, y_val=three_actuals_of_two, adjustment=0.07, legloc='lower left')

## Combined plots

In [None]:
# Self-interpolation comparison: every model family evaluated on its own kind
# of data. Left panel: mean absolute error with a standard-deviation band.
# Right panel: box plot of the same errors with whiskers spanning min..max.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), dpi=600)

# NOTE: the earlier, seaborn-free way of drawing the left panel was dropped
# for its inelegance, but the recipe is kept here for posterity as a way of
# plotting error bands by hand (one plot/fill_between pair per family, shown
# for family one; the cross-prediction variant used three_cases=False):
#
#   means1, std1, means2, std2, means3, std3 = cross_mean_err_calculator(
#       one_pred_of_one, one_actuals_of_one,
#       two_pred_of_two, two_actuals_of_two,
#       three_pred_of_three, three_actuals_of_three, three_cases=True)
#   alphas = np.sort(np.unique(one_actuals_of_one))
#   err1 = np.abs(means1 - alphas)
#   ax1.plot(alphas, err1, 'b-', label="Trained Orientation Predicting Orientation")
#   ax1.fill_between(alphas, err1 - std1, err1 + std1, color='b', alpha=0.25)
#   ax1.plot(alphas, np.zeros(alphas.shape[0]), color='black',
#            label='True value line', linestyle='dotted', alpha=0.5)

# One frame per family: absolute prediction error, true value and family label.
df_one = pd.DataFrame({
    "predicted": np.abs(one_pred_of_one - one_actuals_of_one),
    "actuals": one_actuals_of_one,
    "Data Type": 'Orientation',
})
df_two = pd.DataFrame({
    "predicted": np.abs(two_pred_of_two - two_actuals_of_two),
    "actuals": np.abs(two_actuals_of_two),
    "Data Type": 'Monochrome',
})
df_three = pd.DataFrame({
    "predicted": np.abs(three_pred_of_three - three_actuals_of_three),
    "actuals": np.abs(three_actuals_of_three),
    "Data Type": 'Scrambled',
})
cdf = pd.concat([df_one, df_two, df_three])

family_palette = {"Orientation": "blue", "Monochrome": "red", "Scrambled": "green"}

sns.lineplot(
    data=cdf,
    x="actuals",
    y="predicted",
    hue="Data Type",
    errorbar="sd",
    palette=family_palette,
    ax=ax1,
)

sns.boxplot(
    data=cdf,
    x="actuals",
    y="predicted",
    hue="Data Type",
    fill=False,
    gap=.4,
    whis=(0, 100),
    width=.5,
    palette=family_palette,
    ax=ax2,
)

for panel in (ax1, ax2):
    panel.set_xlabel("Actual turning rate")
    panel.set_ylabel("Absolute mean prediction difference")
ax1.set_title("Standard Deviation Comparison of Self-Interpolation")
ax2.set_title("Interquartile Range Comparison of Self-Interpolation")

# Log x-axis with plain-number tick labels; the formatter has to be installed
# after set_xscale, which otherwise replaces it.
ax1.set_xscale("log")
ax1.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
ax1.set_xticks(np.unique(y_val1))
ax1.set_ylim([0, 0.12])
ax2.set_ylim([0, 0.25])

# To drop the hue title from the right-hand legend one could use:
#   handles, labels = ax2.get_legend_handles_labels()
#   ax2.legend(handles=handles[1:], labels=labels[1:])
ax1.legend(loc="upper left")

fig.tight_layout()

In [None]:
# Cross-evaluation frames: dfAB holds the absolute errors of model family A
# evaluated on dataset B, labelled with the family that made the prediction
# (1 = Orientation, 2 = Monochrome, 3 = Scrambled).
df12 = pd.DataFrame({
    "predicted": np.abs(one_pred_of_two - one_actuals_of_two),
    "actuals": one_actuals_of_two,
    "Data Type": 'Orientation',
})
df21 = pd.DataFrame({
    "predicted": np.abs(two_pred_of_one - two_actuals_of_one),
    "actuals": two_actuals_of_one,
    "Data Type": 'Monochrome',
})
df13 = pd.DataFrame({
    "predicted": np.abs(one_pred_of_three - one_actuals_of_three),
    "actuals": one_actuals_of_three,
    "Data Type": 'Orientation',
})
df31 = pd.DataFrame({
    "predicted": np.abs(three_pred_of_one - three_actuals_of_one),
    "actuals": three_actuals_of_one,
    "Data Type": 'Scrambled',
})
df23 = pd.DataFrame({
    "predicted": np.abs(two_pred_of_three - two_actuals_of_three),
    "actuals": two_actuals_of_three,
    "Data Type": 'Monochrome',
})
df32 = pd.DataFrame({
    "predicted": np.abs(three_pred_of_two - three_actuals_of_two),
    "actuals": three_actuals_of_two,
    "Data Type": 'Scrambled',
})

fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(12, 16), dpi=600)

family_palette = {"Orientation": "blue", "Monochrome": "red", "Scrambled": "green"}

# One grid row per dataset being predicted. df_one / df_two / df_three are the
# self-prediction frames built in the combined-plots cell above. The frames
# are always concatenated in Orientation, Monochrome, Scrambled order.
# The second entry is the legend location for the box-plot panel; None keeps
# the legend seaborn created.
panel_rows = [
    ([df_one, df21, df31], 'upper right'),   # predicting dataset one
    ([df12, df_two, df32], 'upper left'),    # predicting dataset two
    ([df13, df23, df_three], None),          # predicting dataset three
]

for (line_ax, box_ax), (frames, box_legend_loc) in zip(ax, panel_rows):
    cdf = pd.concat(frames)

    sns.lineplot(
        data=cdf,
        x="actuals",
        y="predicted",
        hue="Data Type",
        errorbar="sd",
        palette=family_palette,
        ax=line_ax,
    )

    sns.boxplot(
        data=cdf,
        x="actuals",
        y="predicted",
        hue="Data Type",
        fill=False,
        gap=.4,
        whis=(0, 100),
        width=.5,
        palette=family_palette,
        ax=box_ax,
    )

    line_ax.legend(loc='upper left')
    if box_legend_loc is not None:
        box_ax.legend(loc=box_legend_loc)

# Labels and titles for every panel; only the line plots (column 0) get a log
# x-axis with plain-number tick labels.
column_titles = ["Standard Deviation", "Interquartile Range"]
row_titles = ["Orientation", "Monochrome", "Scrambled"]
for j, examined in enumerate(row_titles):
    for i, examiner in enumerate(column_titles):
        panel = ax[j][i]
        panel.set_xlabel("Actual turning rate")
        panel.set_ylabel("Absolute mean prediction difference")
        panel.set_title(f"{examiner} Comparison Interpolating on {examined}")
        if i != 0:
            continue
        # The formatter must be set after set_xscale, which installs its own.
        panel.set_xscale("log")
        panel.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
        panel.set_xticks(np.unique(y_val1))

fig.tight_layout()
