In [None]:
import os

# Master switch for the notebook: when True, later cells pick a hard-coded seed and
# the smaller debugging configuration; when False the full experiment settings apply.
DEBUG = False

# Optional: silence wandb console output outside of debugging.
# if not DEBUG:
#     os.environ["WANDB_SILENT"] = "true"

In [None]:
from IPython.display import display
import warnings
warnings.filterwarnings('ignore')
from pyleaves.utils import set_tf_config
set_tf_config(num_gpus=1)

import wandb
from wandb.keras import WandbCallback
# wandb.login()

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D, ReLU, ELU, LeakyReLU, Flatten, Dense, Add, AveragePooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers.experimental.preprocessing import StringLookup, CategoryEncoding
import tensorflow_datasets as tfds
import numpy as np
##
# Draw a fresh random seed on every run and print it so the run can be replicated later;
# if DEBUG is set to True, seed all random generators with a hard coded value instead.
##
import random
# Integer bound: random.randint rejects float arguments such as 1e5 on recent Python versions.
nb_seed = 374 if DEBUG else random.randint(0, 10**5)
# Seed every generator used by the notebook: Python's `random`, NumPy and TensorFlow.
random.seed(nb_seed)
np.random.seed(nb_seed)
tf.random.set_seed(nb_seed)
print(f'nb_seed = {nb_seed}')

import os
import pprint
pp = pprint.PrettyPrinter(indent=4)

import matplotlib.pyplot as plt
plt.style.use('science')
from typing import List, Tuple, Union, Dict, NamedTuple
from omegaconf import OmegaConf

# from tfrecord_utils.img_utils import resize_repeat
# from boltons.funcutils import partial
# import logging
# logger = logging.getLogger('')

# Directory that receives this notebook's log file; it is created if it does not exist yet.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
LOG_DIR = '/media/data/jacob/GitHub/experiment_results/evolution_logs'
os.makedirs(LOG_DIR, exist_ok=True)
from paleoai_data.utils.logging_utils import get_logger
# Project logger targeting generation_evolution_logs.log inside LOG_DIR.
# append=True presumably keeps entries from earlier runs — confirm against get_logger.
logger = get_logger(logdir=LOG_DIR, filename='generation_evolution_logs.log', append=True)

VERBOSE = True
import pandas as pd
import json
import jsonpickle
from box import Box
from bunch import Bunch
import copy
import gc

from genetic_algorithm.datasets.plant_village import ClassLabelEncoder, load_and_preprocess_data
from genetic_algorithm import stateful

from genetic_algorithm.chromosome import ChromosomeSampler
from genetic_algorithm.organism.organism import Organism
from genetic_algorithm.generation.generation import Generation
from genetic_algorithm.plotting import log_high_loss_examples

In [None]:
# The full experiment and the quick DEBUG run share one config layout and differ only
# in the values below, so those values are selected once here and every sub-config is
# then built a single time (previously the whole tree was written out twice).
if DEBUG:
    run_settings = {'seed':6227,
                    'batch_size':16,
                    'input_shape':(64,64,3),
                    'epochs_per_organism':1,
                    'results_dir':'/media/data_cifs_lrs/projects/prj_fossils/users/jacob/experiments/Nov2020-Jan2021/debugging_trials',
                    'population_size':3,
                    'num_generations_per_phase':2,
                    'num_phases':3}
else:
    run_settings = {'seed':756, # alt: 237
                    'batch_size':8, # alt: 16
                    'input_shape':(224,224,3),
                    'epochs_per_organism':3,
                    'results_dir':'/media/data_cifs_lrs/projects/prj_fossils/users/jacob/experiments/Nov2020-Jan2021',
                    'population_size':5,
                    'num_generations_per_phase':3,
                    'num_phases':5}

# Experiment-level settings. `experiment_uid` is a random id that also names the model directory.
# NOTE(review): `seed` is stored here but nothing in this cell uses it to seed a generator — confirm where it is consumed.
exp_config = OmegaConf.create({'seed':run_settings['seed'],
                               'batch_size':run_settings['batch_size'],
                               'input_shape':run_settings['input_shape'],
                               'output_size':38,
                               'epochs_per_organism':run_settings['epochs_per_organism'],
                               'results_dir':run_settings['results_dir'],
                               'experiment_uid':str(np.random.randint(0,1e10))
                              })
exp_config.model_dir = os.path.join(exp_config.results_dir,exp_config.experiment_uid)

# Dataset loading / preprocessing settings handed to load_and_preprocess_data.
data_config = OmegaConf.create({'load':{},'preprocess':{}})

data_config['load'] = {'dataset_name':'plant_village',
                       'split':['train[0%:60%]','train[60%:70%]','train[70%:100%]'],
                       'data_dir':'/media/data/jacob/tensorflow_datasets'}

data_config['preprocess'] = {'batch_size':exp_config.batch_size,
                             'target_size':exp_config.input_shape[:2]}

# Evolution hyperparameters; only the population / generation / phase counts depend on DEBUG.
generation_config = OmegaConf.create({
                                      'population_size':run_settings['population_size'],
                                      'num_generations_per_phase':run_settings['num_generations_per_phase'],
                                      'fitSurvivalRate': 0.5,
                                      'unfitSurvivalProb':0.2,
                                      'mutationRate':0.1,
                                      'num_phases':run_settings['num_phases']
                                    })
# Per-organism settings, copied from the experiment config.
organism_config = OmegaConf.create({'input_shape':exp_config.input_shape,
                                    'output_size':exp_config.output_size,
                                    'epochs_per_organism':exp_config.epochs_per_organism,
                                    'model_dir':exp_config.model_dir,
                                    'experiment_uid':exp_config.experiment_uid})

config = OmegaConf.create({
                            'experiment':exp_config,
                            'data':data_config,
                            'generation':generation_config,
                            'organism':organism_config
})
# `config.pretty()` was deprecated in favour of OmegaConf.to_yaml and later removed;
# prefer the new call when it exists and fall back for older OmegaConf releases.
print(OmegaConf.to_yaml(config) if hasattr(OmegaConf, 'to_yaml') else config.pretty())

### Load and preprocess data, and acquire the class_encoder
#### Finally, alter config values based on whether running in DEBUG mode or not

In [None]:
data, class_encoder = load_and_preprocess_data(config['data'])

# DEBUG runs use a fixed 150 steps for both training and validation; otherwise the step
# counts are the lengths of the train / val datasets.
config.organism.steps_per_epoch = 150 if DEBUG else len(data['train'])
config.organism.validation_steps = 150 if DEBUG else len(data['val'])

# Organism
An organism contains the following:

1. phase - Denotes which phase the organism belongs to
2. chromosome - A dictionary of genes (hyperparameters)
3. model - The `tf.keras` model corresponding to the chromosome
4. best_organism - The best organism in the previous **phase**

# Generation
This is a class that holds generations of models.

1. fitSurvivalRate - The amount of fit individuals we want in the next generation.
2. unfitSurvivalProb - The probability of sending unfit individuals
3. mutationRate - The mutation rate to change genes in an individual.
4. phase - The phase that the generation belongs to.
5. population_size - The amount of individuals that the generation consists of.
6. best_organism - The best organism (individual) in the last phase

# MAIN TRAINING LOOP

In [None]:
# Record the notebook's file name in the environment variable wandb reads for it.
os.environ['WANDB_NOTEBOOK_NAME'] = 'generation_main.ipynb'
# (Re)load the memory_profiler extension; the %%memit cell magic used in the next cell presumably comes from it.
%reload_ext memory_profiler
# Change into the repository directory, presumably so `from main import main` can be resolved — confirm.
# NOTE(review): absolute, machine-specific path.
%cd /media/data/jacob/GitHub/genetic_algorithm
%pwd

In [None]:
%%memit
# Run the full search under the memit cell magic (which must stay on the first line of the cell).
# `main` is imported here rather than with the other imports; presumably it is main.py in the
# directory the previous cell changed into — confirm.
from main import main
# Starts with no previous best organism; returns the best organism found and the final generation.
best_organism, last_generation = main(data=data, config=config, best_organism = None, class_encoder=class_encoder, verbose=True, debug=DEBUG)

In [None]:
# A bare expression is only rendered when it is the last statement of a cell, so the
# population has to be shown explicitly with display() from the middle of this cell.
display(last_generation.population)

# NOTE(review): the star import is kept because the helpers used below
# (get_1_epoch_from_tf_data, get_predictions) are not imported by name anywhere else
# in the visible notebook; prefer explicit imports once the needed names are confirmed.
from genetic_algorithm.plotting import *
import wandb
import os
# Must be set before wandb.init(); 'dryrun' presumably keeps the run local/unsynced — confirm against the wandb docs.
os.environ['WANDB_MODE'] = 'dryrun'

run = wandb.init(entity='jrose')

# log_multiclass_metrics(best_organism.test_data, 
#                        best_organism.model,
#                        data_split_name='test', 
#                        class_encoder=class_encoder,
#                        log_predictions=True,
#                        max_rows=1000,
#                        run=run,
#                        commit=True)

# Collect up to `max_rows` test examples and score them with the best model found.
dataset = data['test']
max_rows=100
model = best_organism.model
x, y_true = get_1_epoch_from_tf_data(dataset, max_rows=max_rows)
y_true = y_true.numpy()

y_prob, y_pred, losses = get_predictions(x, y_true, model)
# Collapse the second axis of the labels to class indices (assumes y_true is one-hot encoded — TODO confirm).
y_true = np.argmax(y_true, axis=1)
print(y_true.shape, y_pred.shape)

In [None]:
# Synthetic labels: 1000 random class ids in [0, 38) for both the "true" and the
# "predicted" vector, drawn in that order.
# NOTE(review): this rebinds y_true / y_pred, so any real predictions computed in an
# earlier cell are replaced by random data for every cell that runs after this one.
y_true, y_pred = (np.random.randint(0, 38, size=1000) for _ in range(2))
# NOTE(review): `labels` has 10 entries while `target_names` has 9, and neither is
# used by the visible cells.
labels = np.arange(0, 10)
target_names = [*"ABCDEFGHI"]

In [None]:
from sklearn.metrics import classification_report
import seaborn as sns


def plot_classification_report(y_true, y_pred, target_names=None):
    """Draw a classification report as three side-by-side annotated heatmaps.

    Args:
        y_true: True class labels, forwarded to ``classification_report``.
        y_pred: Predicted class labels, forwarded to ``classification_report``.
        target_names: Optional display names for the classes. They are now forwarded
            to ``classification_report``; previously this argument was silently
            ignored and ``None`` was always passed.

    Returns:
        None. The figure is created as a side effect, and the global seaborn palette
        is switched to the crayon palette.
    """
    report = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
    # Build the transposed table once instead of once per slice.
    report_df = pd.DataFrame(report).T

    # Slicing relies on the report layout: the last three rows are treated as the
    # summary rows and the last column as the support column
    # (NOTE(review): matches sklearn's default multiclass report — confirm if the call changes).
    per_class_metrics = report_df.iloc[:-3,:-1]
    class_support = report_df.iloc[:-3,-1:]
    mean_metrics = report_df.iloc[-3:,:-1]

    fig, ax = plt.subplots(1,3, figsize=(15,20))
    # sns.heatmap(pd.DataFrame(report).iloc[:-1,:].T, annot=True)
    # sns.heatmap(pd.DataFrame(report).T, annot=True, ax=ax[0])
#     cmap = sns.diverging_palette(230, 20, as_cmap=True)
    palette = sns.crayon_palette(list(sns.crayons.keys()))

    sns.set_palette(palette)

    sns.heatmap(mean_metrics, annot=True, ax=ax[0])#, cmap="Dark2")
    sns.heatmap(per_class_metrics, annot=True, ax=ax[1])#,cmap=cmap)#"YlOrBr_r")
    sns.heatmap(class_support, annot=True, ax=ax[2])

    # Title each panel so the figure can be read on its own.
    ax[0].set_title('Summary rows')
    ax[1].set_title('Per-class metrics')
    ax[2].set_title('Per-class support')

plot_classification_report(y_true, y_pred, target_names=None)

In [None]:
# `df` is not assigned anywhere above this cell, so on a fresh kernel this raised a
# NameError. Load the table here under its own name instead of relying on a later cell;
# presumably the intended data is the seaborn "penguins" example that a later cell
# loads into `df` — confirm.
penguins = sns.load_dataset("penguins")
sns.pairplot(penguins, hue="sex", height=2.5)

In [None]:
# `report` is only a local variable inside plot_classification_report, so it has to be
# built at notebook scope before it can be inspected here (same call, dict output).
report = classification_report(y_true, y_pred, output_dict=True)
pd.DataFrame(report).T

In [None]:
# Transposed report with its last column removed (presumably the 'support' column — confirm).
# Relies on `report` having been bound at notebook scope by an earlier cell.
sample_report = pd.DataFrame(report).T.iloc[:,:-1]

In [None]:
# Display the trimmed report table.
sample_report

In [None]:
# DataFrame.plot() opens a new figure of its own unless it is given an Axes, which left
# the 18x9 figure created here empty; drawing on that figure's Axes applies the size.
fig = plt.figure(figsize=(18,9))
sample_report.plot(ax=fig.gca())#kind='bar', width=2.5)

In [None]:
# The dangling positional `width` after a keyword argument was a SyntaxError; it carried
# no value, so it is dropped. Plots the report without its last three rows, transposed.
sns.barplot(data=sample_report.iloc[:-3,:].T)

In [None]:
# Count how often each (y_pred, y_true) combination occurs; value_counts() on the
# two-column frame yields a Series indexed by those combinations.
# NOTE(review): this rebinds `df`, a name other cells use for unrelated tables.
df = pd.DataFrame({'y_pred':y_pred,
             'y_true':y_true}).value_counts()

# NOTE(review): barplot receives the Series without x/y arguments — confirm this renders the intended chart.
sns.barplot(data=df)

In [None]:
# Bar chart of how many rows of sample_report share each 'recall' value.
# NOTE(review): recall is a continuous score, so most values are likely to be unique — confirm this is the intended view.
sns.countplot(x='recall', data=sample_report)

In [None]:
sns.set_theme(style="white")

df = sns.load_dataset("penguins")

g = sns.JointGrid(data=df, x="body_mass_g", y="bill_depth_mm", space=0)

# JointGrid.plot_joint supplies x and y from the grid itself, and fill / clip / thresh /
# levels are kdeplot options. Passing sns.heatmap with extra data/x/y keywords raised a
# TypeError, so the joint panel is drawn with sns.kdeplot (the variant that had been
# left commented out in this cell).
g.plot_joint(sns.kdeplot,
             fill=True, clip=((2200, 6800), (10, 25)),
             thresh=0, levels=100, cmap="rocket")



In [None]:
import seaborn as sns
import numpy as np
sns.set_theme(style="ticks")

# Synthetic alternative kept for reference:
# rs = np.random.RandomState(11)
# x = rs.gamma(2, size=1000)
# y = -.5 * x + rs.normal(size=1000)

# Hexbin joint distribution of predicted (x axis) versus true (y axis) labels.
x, y = y_pred, y_true

sns.jointplot(x=x, y=y, kind="hex", color="#4CB391")

In [None]:
# fill / clip / thresh / levels are kdeplot options and plot_joint passes the grid's
# x and y itself, so sns.heatmap could not be used here (it raised a TypeError);
# sns.kdeplot draws the filled density these arguments describe.
g.plot_joint(sns.kdeplot,
             fill=True, clip=((2200, 6800), (10, 25)),
             thresh=0, levels=100, cmap="rocket")

# Histograms of each variable on the marginal axes.
g.plot_marginals(sns.histplot, color="#03051A", alpha=1, bins=25)

In [None]:
sns.set_theme(style="whitegrid")

# One tall figure; both bar layers below are drawn onto its single Axes.
f, ax = plt.subplots(figsize=(6, 15))

# Example car crash dataset, ordered from most to fewest total crashes.
crashes = sns.load_dataset("car_crashes").sort_values("total", ascending=False)

# Layer 1: total crashes, with "b" resolved against the pastel colour codes.
sns.set_color_codes("pastel")
sns.barplot(data=crashes, x="total", y="abbrev",
            color="b", label="Total", ax=ax)

# Layer 2: alcohol-involved crashes, with "b" resolved against the muted colour codes.
sns.set_color_codes("muted")
sns.barplot(data=crashes, x="alcohol", y="abbrev",
            color="b", label="Alcohol-involved", ax=ax)

# Legend, axis limits and labels, then remove the left and bottom spines.
ax.legend(ncol=2, loc="lower right", frameon=True)
ax.set(xlim=(0, 24),
       ylabel="",
       xlabel="Automobile collisions per billion miles")
sns.despine(left=True, bottom=True)

In [None]:
sns.set_theme()

# Brain networks example dataset; its columns form a three-level header.
df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0)

# Keep only the columns whose "network" level is one of the chosen network ids.
used_networks = [1, 5, 6, 7, 8, 12, 13, 17]
used_columns = df.columns.get_level_values("network").astype(int).isin(used_networks)
df = df.loc[:, used_columns]

# One colour per selected network, keyed by the network id as a string.
network_pal = sns.husl_palette(8, s=.45)
network_lut = {str(net): colour for net, colour in zip(used_networks, network_pal)}

# Per-column colour vector, drawn along the sides of the clustered matrix.
networks = df.columns.get_level_values("network")
network_colors = pd.Series(networks, index=df.columns).map(network_lut)

# Clustered heatmap of the column correlations, colour-coded by network.
g = sns.clustermap(df.corr(),
                   cmap="vlag", center=0,
                   row_colors=network_colors, col_colors=network_colors,
                   dendrogram_ratio=(.1, .2),
                   cbar_pos=(.02, .32, .03, .2),
                   linewidths=.75,
                   figsize=(12, 13))

# Remove the row dendrogram axes from the figure.
g.ax_row_dendrogram.remove()

In [None]:
# List the top-level keys of the report dict; relies on `report` having been bound by an earlier cell.
report.keys()

In [None]:
# Build a confusion-matrix object from the true and predicted label vectors.
# NOTE(review): neither ConfusionMatrix nor plot_confusion_matrix is imported by name in
# the visible cells; presumably both arrive via `from genetic_algorithm.plotting import *` — confirm.
cm = ConfusionMatrix(actual_vector=y_true, predict_vector=y_pred)


# Plot it normalised, without per-cell annotations, using the "YlGnBu" colormap.
fig, ax = plot_confusion_matrix(cm,normalize=True, title='Confusion matrix', annot=False, cmap="YlGnBu")

In [None]:
# Lists every name currently defined in the notebook namespace.
# NOTE(review): looks like leftover debugging — consider removing before sharing.
dir()

In [None]:
# from genetic_algorithm.plotting import *

# model = best_organism.model
# test_dataset = data['test']
# max_rows=1000
# k = 32
# x, y_true = get_1_epoch_from_tf_data(test_dataset, max_rows=max_rows)
# y_true = y_true.numpy()
# y_prob, y_pred, losses = get_predictions(x, y_true, model)
# highest_k_losses, hardest_k_examples, hardest_k_true_labels, hardest_k_predictions = get_hardest_k_examples(x, y_true, y_pred, losses, model, k=k)

# %debug

# %debug

# dir()

In [None]:
# from genetic_algorithm.chromosome import sampler
# sampler(0)

# phases = []
# sampler=ChromosomeSampler()
# phases.append(sampler(phase=0))
# phases.append(sampler(phase=1))

# child_chromosome = sampler(phase=0)#.generate_chromosome(phase=0)
# child = Organism(chromosome=child_chromosome,
#                  data=data,
#                  config=config['data'],
#                  phase=0,
#                  generation_number=0,
#                  organism_id=0,
#                  best_organism=None)

# (child._chromosome._state)

# (child._chromosome.get_state())

# # %%writefile main.py

# # def main()
# # print(config)
# # best_organism = None
# # for phase in range(config.generation.num_phases):
# #     print("PHASE {}".format(phase))
# #     generation = Generation(data=data,
# #                             generation_config=config['generation'],
# #                             organism_config=config['organism'],
# #                             phase=phase,
# #                             previous_best_organism=best_organism,
# #                             verbose=VERBOSE,
# #                             DEBUG=DEBUG)
    
# #     best_organism = generation.run_phase()

# (child_chromosome.get_chromosome(True))

============================================
## TODO
========



* Implement saving
* Log per-class plots for TP, TN, FP, FN, Recall, Precision
* Fix test data loading for get_hardest_k_samples

* Create dummy test model
* Feed a ResNet backbone in as the previous best model
* Follow that by feeding an arbitrary loaded model

* Refactor according to keras idiomatic programmer style

* Generate a DAG containing every model architecture, drawing connections between nodes indicating a type of step (e.g. mutation or crossover) to see the evolution of structure as a tree
    * Consider passing the run id down from each surviving organism to unify the runs linking each unique model
    * Provide each organism a self._parent attribute pointing to identifying info about the parent.


### Experiment Idea (12/1/20)
* Optimize each generation by maximizing validation recall, rather than validation accuracy

### Experiment 2
* Add initializers HeNormal and HeUniform to chromosome
[HeNormal](https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal)

### Experiment 3
- Label smoothing


### Experiment 4 (Added 2 AM 12/2/20)
- Integrate multi-dataset transfer learning workflow

1. Alternate phase_0 -> dataset_0, phase_1 -> dataset_1
    Then, either
    a. phase_2 -> back to phase_0
    or
    b. phase_2 -> Interleave samples from dataset_0 and dataset_1

### Experiment 5 (Added 5:15 AM 12/2/20)

#### Signal vs Noise Scenario for metalearning/hparam search
Based on: https://wandb.ai/stacey/pytorch_intro/reports/Meaning-and-Noise-in-Hyperparameter-Search--Vmlldzo0Mzk5MQ
Perform 2 versions of an Experiment:
Question: "How do we know when the observed effect sizes from our model-tuning efforts are meaningful?"
Step 1. Perform k-identical trials with everything but the random seed fixed.

Step 2. For another k trials, fix the random seed but perform hparam search/metalearning algorithm for each trial

Hypothesis: The relative discrepancy in performance between version 1 (noise) and version 2 (signal) should reflect only the signal and no noise

Future Work: "A promising next direction would be to quantify the number of runs/samples needed to ensure statistically significant results relative to a noise baseline."

============================================

### Left off Running Notebook (7 AM 12/1/20):
- Running full plant_village 3 phase, 3 generation, 5 organism-per-population search

In [None]:
### 1. Using tfds.features.ClassLabel

# feature_labels = tfds.features.ClassLabel(names=vocab)
# data = ['Potato___healthy',
#         'Potato___Late_blight',
#         'Raspberry___healthy',
#         'Soybean___healthy',
#         'Squash___Powdery_mildew',
#         'Strawberry___healthy',
#         'Strawberry___Leaf_scorch',
#         'Tomato___Bacterial_spot',
#         'Tomato___Early_blight',
#         'Tomato___healthy']

# data += data[::-1]
# print([feature_labels.str2int(label) for label in data])
# data = train_data
# data_enc = data.map(lambda x,y: (x, feature_labels.int2str(y)))

### 2. Using StringLookup and CategoryEncoding Layers

# layer = StringLookup(vocabulary=vocab, num_oov_indices=0, mask_token=None)
# i_layer = StringLookup(vocabulary=layer.get_vocabulary(), invert=True)
# int_data = layer(data)

# print(len(layer.get_vocabulary()))
# print(len(class_encoder.class_list))
# print(set(layer.get_vocabulary())==set(class_encoder.class_list))

# i_layer = StringLookup(vocabulary=layer.get_vocabulary(), invert=True)
# int_data = layer(data)

# print(layer(data))
# print(i_layer(int_data))

In [None]:
# # from tensorflow.keras.layers.experimental.preprocessing import StringLookup, CategoryEncoding
# # data = tf.constant(["a", "b", "c", "b", "c", "a"])
# # # Use StringLookup to build an index of the feature values
# # indexer = StringLookup()
# # indexer.adapt(data)
# # # Use CategoryEncoding to encode the integer indices to a one-hot vector
# # encoder = CategoryEncoding(output_mode="binary")
# # encoder.adapt(indexer(data))
# # # Convert new test data (which includes unknown feature values)
# # test_data = tf.constant(["a", "b", "c", "d", "e", ""])
# # encoded_data = encoder(indexer(test_data))
# # print(encoded_data)

# vocab = ["a", "b", "c", "d"]
# data = tf.constant([["a", "c", "d"], ["d", "z", "b"]])
# layer = StringLookup(vocabulary=vocab)
# i_layer = StringLookup(vocabulary=layer.get_vocabulary(), invert=True)
# int_data = layer(data)

# print(layer(data))
# print(i_layer(int_data))

In [None]:
# VERBOSE = True
# import pandas as pd
# import json
# from box import Box
# from bunch import Bunch
# # from pprint import pprint as pp
# import random

# ActivationLayers = Box(ReLU=ReLU, ELU=ELU, LeakyReLU=LeakyReLU)
# PoolingLayers = Box(MaxPool2D=MaxPool2D, AveragePooling2D=AveragePooling2D)

# class Chromosome(stateful.Stateful):#BaseChromosome):#(NamedTuple):
    
#     def __init__(self,
#                  hparams: Dict=None,
#                  name=''):
#         super().__init__()
#         self.set_state(hparams)

#     def get_state(self):
#         """Returns the current state of this object.
#         This method is called during `save`.
#         """
#         return self._state
        

#     def set_state(self, state):
#         """Sets the current state of this object.
#         This method is called during `reload`.
#         # Arguments:
#           state: Dict. The state to restore for this object.
#         """
#         self._state = state
        
#     @property
#     def deserialized_state(self):
#         state = copy.deepcopy(self.get_state())
#         state['activation_type'] = ActivationLayers[state['activation_type']]
#         state['pool_type'] = PoolingLayers[state['pool_type']]

# #         state['activation_type'] = [ActivationLayers[act_layer] for act_layer in state['activation_type']]
# #         state['pool_type'] = [PoolingLayers[pool_layer] for pool_layer in state['pool_type']]
#         return state


# import copy

# class ChromosomeOptions(stateful.Stateful): #BaseOptions): #object):
#     """
#     Container class for encapsulating variable-length lists of potential gene variants (individual hyperparameters).
#     To be used as a reservoir from which to sample a complete chromosome made up of 1 variant per gene.
    
#     This should be logged for describing the scope of a given AutoML experiment's hyperparameter search space

#     Gene: The unique identifier of a particular hyperparameter that may reference any of a set of possible variant values.
#     Variant: The particular value of a gene. Used to refer to the 1 value for a single chromosome instance, or 1 value from a set of gene options.

#     Args:
#         NamedTuple ([type]): [description]
#     """

#     def __init__(self,
#                  hparam_lists,
#                  phase=0,
#                  seed=None):
        
# #         self.__chromosomes = {k:v for k,v in locals().items() if k not in ['self', 'kwargs'] and not k.startswith('__')}
# #         print(self.__chromosomes)
        
#         self.set_state(hparam_lists, phase=phase, seed=seed)

#     def get_state(self):
#         """Returns the current state of this object.
#         This method is called during `save`.
#         """
#         return self.state
    
#     def get_config(self):
#         config = copy.deepcopy(self.state)
#         return config
        

#     def set_state(self, state, phase=0, seed=None):
#         """Sets the current state of this object.
#         This method is called during `reload`.
#         # Arguments:
#           state: Dict. The state to restore for this object.
#         """
#         self.set_seed(seed)
#         self.phase = phase
#         self.state = state

#     def set_seed(self, seed=None):
#         self.seed = seed
#         self.rng = np.random.default_rng(seed)
        
#     def sample_k_variants_from_gene(self, gene: str, k: int=1):
#         '''
#         Randomly sample the list of variants corresponding to the key indicated by the first arg, 'gene'. Produce a random sequence of length k, with the default==1.
        
#         Note: If k==1: this automatically returns a single unit from the variants list, which may or may not be a scalar object (e.g. int, str, float)
#         If k > 1: then the sampled variants will always be returned in a list.
        
#         '''
#         all_variants = self.chromosomes[gene]
#         variant_idx = self.rng.integers(low=0, high=len(all_variants), size=k)
#         sampled_variants = [all_variants[idx] for idx in variant_idx.tolist()]
#         if k==1:
#             sampled_variants = sampled_variants[0]
#         return sampled_variants
    
#     def generate_chromosome(self, phase: int=None, seed=None):
#         '''
#         Primary function for utilizing a ChromosomeOptions object during experimentation.
#         Running this function will randomly generate a new Chromosome instance for which each genetic variant is randomly sampled from this object's contained data,
#         in the form of mappings between gene names as keys, and lists of variants as values.
#         '''
#         return Chromosome(hparams={gene:self.sample_k_variants_from_gene(gene) for gene in self.chromosomes.keys()})
    
#     def generate_k_chromosomes(self, k: int=1, seed=None):
#         return [self.generate_chromosome(seed=seed) for _ in range(k)]
        
#     @property
#     def chromosomes(self):
#         return self.state

    
#     @property
#     def deserialized_state(self):
#         state = copy.deepcopy(self.state)
#         state['activation_type'] = [ActivationLayers[act_layer] for act_layer in state['activation_type']]
#         state['pool_type'] = [PoolingLayers[pool_layer] for pool_layer in state['pool_type']]
#         return state
    

# class ChromosomeSampler:
    
#     def __call__(self, phase: int):
        
#         if phase==0:
#             options = ChromosomeOptions({
# #                                       'b_include_layer':[True],
#                                       'a_filter_size':[(1,1), (3,3), (5,5), (7,7), (9,9)],
#                                       'a_include_BN':[True, False],
#                                       'a_output_channels':[8, 16, 32, 64, 128, 256, 512],
#                                       'activation_type':['ReLU', 'ELU', 'LeakyReLU'],
#                                       'b_filter_size':[(1,1), (3,3), (5,5), (7,7), (9,9)],
#                                       'b_include_BN':[True, False],
#                                       'b_output_channels':[8, 16, 32, 64, 128, 256, 512],
#                                       'include_pool':[True, False],
#                                       'pool_type':['MaxPool2D', 'AveragePooling2D'],
#                                       'include_skip':[True, False]
#                                       },
#                                       phase=phase)

#         else:
#             options = ChromosomeOptions({
#                                       'b_include_layer':[True, False],
#                                       'a_filter_size':[(1,1), (3,3), (5,5), (7,7), (9,9)],
#                                       'a_include_BN':[True, False],
#                                       'a_output_channels':[8, 16, 32, 64, 128, 256, 512],
#                                       'activation_type':['ReLU', 'ELU', 'LeakyReLU'],
#                                       'b_filter_size':[(1,1), (3,3), (5,5), (7,7), (9,9)],
#                                       'b_include_BN':[True, False],
#                                       'b_output_channels':[8, 16, 32, 64, 128, 256, 512],
#                                       'include_pool':[True, False],
#                                       'pool_type':['MaxPool2D', 'AveragePooling2D'],
#                                       'include_skip':[True, False]
#                                       },
#                                       phase=phase)
#         return options.generate_chromosome(phase=phase)

In [None]:
# import gc


# class Organism:
#     def __init__(self,
#                  data: Dict[str,tf.data.Dataset],
#                  config=None,
#                  chromosome=None,
#                  phase=0,
#                  generation_number=0,
#                  organism_id=0,
#                  best_organism=None):
#         '''
        
#         Organism is an actor with a State that can take Action in the environment
        
#         config is a . accessible dict object containing model params that will stay constant during evolution
#         chromosome is a dictionary of genes
#         phase is the phase that the individual belongs to
#         best_organism is the best organism of the previous phase
        
#         TODO:
        
#         1. implement to_json and from_json methods for copies
#         2. Separate out step where organism is associated with a dataset
#         '''
#         self.data = data
#         self.train_data = data['train']
#         self.val_data = data['val']
#         self.test_data = data['test']
#         self.config = config
#         self.chromosome = chromosome
#         self.phase = phase
#         self.generation_number = generation_number
#         self.organism_id = organism_id
#         self.best_organism=best_organism

#         if phase > 0:
#             if best_organism is None:
#                 print(f'phase {phase} gen {generation} organism {organism_id}.\nNo previous best model, creating from scratch.')
#             else:
#                 self.last_model = best_organism.model
            
#         self.debug = DEBUG
    
#     @property
#     def name(self):
#         return f'phase_{self.phase}-gen_{self.generation_number}-organism_{self.organism_id}'
    
#     @property
#     def config(self):
#         return self._config
    
#     @config.setter
#     def config(self, config=None):
#         config = config or OmegaConf.create({})
#         print(config)
#         config.input_shape = config.input_shape or (224,224,3)
#         config.output_size = config.output_size or 38
#         config.epochs_per_organism = config.epochs_per_organism or 5
#         self._config = config
        
#     def get_metrics(self):
#         return [tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
#                 tf.keras.metrics.TruePositives(name='tp'),
#                 tf.keras.metrics.FalsePositives(name='fp'),
#                 tf.keras.metrics.TrueNegatives(name='tn'),
#                 tf.keras.metrics.FalseNegatives(name='fn'),
#                 tf.keras.metrics.Precision(name='precision'),
#                 tf.keras.metrics.Recall(name='recall'),
#                 tf.keras.metrics.AUC(name='auc')]
    
#     @property
#     def fitness_metric_name(self):
#         return 'accuracy'
   
#     @property
#     def chromosome(self):
#         return self._chromosome.deserialized_state
    
#     @chromosome.setter
#     def chromosome(self, chromosome):
#         self._chromosome = chromosome
        
    
#     def build_model(self):
#         '''
#         This is the function to build the keras model
#         '''
#         K.clear_session()
#         gc.collect()
#         inputs = Input(shape=self.config.input_shape)
#         if self.phase != 0:
#             # Slice the prev best model # Use the model as a layer # Attach new layer to the sliced model
#             intermediate_model = Model(inputs=self.last_model.input,
#                                        outputs=self.last_model.layers[-3].output)
#             for layer in intermediate_model.layers:
#                 # To make the iteration efficient
#                 layer.trainable = False
#             inter_inputs = intermediate_model(inputs)
#             x = Conv2D(filters=self.chromosome['a_output_channels'],
#                        padding='same',
#                        kernel_size=self.chromosome['a_filter_size'],
#                        use_bias=self.chromosome['a_include_BN'])(inter_inputs)
#             # This is to ensure that we do not randomly chose anothere activation
#             self.chromosome['activation_type'] = self.best_organism.chromosome['activation_type']
#         else:
#             # For PHASE 0 only
#             # input layer
#             x = Conv2D(filters=self.chromosome['a_output_channels'],
#                        padding='same',
#                        kernel_size=self.chromosome['a_filter_size'],
#                        use_bias=self.chromosome['a_include_BN'])(inputs)
            
#         if self.chromosome['a_include_BN']:
#             x = BatchNormalization()(x)
#         x = self.chromosome['activation_type']()(x)
#         if self.chromosome['include_pool']:
#             x = self.chromosome['pool_type'](strides=(1,1),
#                                              padding='same')(x)
#         if self.phase != 0 and self.chromosome['b_include_layer'] == False:
#             # Except for PHASE0, there is a choice for
#             # the number of layers that the model wants
#             if self.chromosome['include_skip']:
#                 y = Conv2D(filters=self.chromosome['a_output_channels'],
#                            kernel_size=(1,1),
#                            padding='same')(inter_inputs)
#                 x = Add()([y,x])
#             x = GlobalAveragePooling2D()(x)
#             x = Dense(self.config.output_size, activation='softmax')(x)
#         else:
#             # PHASE0, or layer b is included (b_include_layer is not False):
#             # build the second conv layer in the tail
#             x = Conv2D(filters=self.chromosome['b_output_channels'],
#                        padding='same',
#                        kernel_size=self.chromosome['b_filter_size'],
#                        use_bias=self.chromosome['b_include_BN'])(x)
#             if self.chromosome['b_include_BN']:
#                 x = BatchNormalization()(x)
#             x = self.chromosome['activation_type']()(x)
#             if self.chromosome['include_skip']:
#                 y = Conv2D(filters=self.chromosome['b_output_channels'],
#                            padding='same',
#                            kernel_size=(1,1))(inputs)
#                 x = Add()([y,x])
#             x = GlobalAveragePooling2D()(x)
#             x = Dense(self.config.output_size, activation='softmax')(x)
#         self.model = Model(inputs=[inputs], outputs=[x])
#         self.model.compile(optimizer='adam',
#                            loss='categorical_crossentropy',
#                            metrics=self.get_metrics())
        
#     def fitnessFunction(self,
#                         train_data,
#                         val_data,
#                         generation_number):
#         '''
#         This function is used to calculate the
#         fitness of an individual.
#         '''
#         print('FFITNESS FUNCTION FFS')
#         print('vars():', vars())
#         self.run = wandb.init(**self.get_wandb_credentials(phase=self.phase,
#                                                 generation_number=generation_number),
#                    config=self.config)
        
#         self.model.fit(train_data,
#                        steps_per_epoch=self.config.steps_per_epoch,
#                        epochs=self.config.epochs_per_organism,
#                        callbacks=[WandbCallback()],
#                        verbose=1)
#         self.results = self.model.evaluate(val_data,
#                                            steps=self.config.validation_steps,
#                                            return_dict=True,
#                                            verbose=1)
#         self.fitness = self.results[self.fitness_metric_name]
#         print(self.name)
#         print('fitness:', self.fitness)
#         print('results:\n', len(self.results))
#         print(self.results)
        
        
# #     @results.setter
# #     def results(self, metrics_values):
# #         self._results = {name:values for name, value in zip(self.model.metrics_names, metrics_values)}
        
# #     @property
# #     def results(self):
# #         return self._results
        
#     def crossover(self,
#                   partner,
#                   generation_number):
#         '''
#         This function helps in making children from two
#         parent individuals.
#         '''
#         child_chromosome = {}
#         endpoint = np.random.randint(low=0, high=len(self.chromosome))
#         for idx, key in enumerate(self.chromosome):
#             if idx <= endpoint:
#                 child_chromosome[key] = self.chromosome[key]
#             else:
#                 child_chromosome[key] = partner.chromosome[key]
#         child = Organism(chromosome=child_chromosome,
#                          data=self.data,
#                          config=self.config,
#                          phase=self.phase,
#                          generation_number=generation_number,
#                          organism_id=f'{self.organism_id}+{partner.organism_id}',
#                          best_organism=self.best_organism)
        
#         child.build_model()
#         child.fitnessFunction(self.train_data,
#                               self.val_data,
#                               generation_number=generation_number)
#         return child
    
#     def mutation(self, generation_number):
#         '''
#         One randomly chosen gene is mutated.
#         '''
#         index = np.random.randint(0, len(self.chromosome))
#         key = list(self.chromosome.keys())[index]
#         if  self.phase != 0:
#             self.chromosome[key] = options[key][np.random.randint(len(options[key]))]
#         else:
#             self.chromosome[key] = options_phase0[key][np.random.randint(len(options_phase0[key]))]
#         self.build_model()
#         self.fitnessFunction(self.train_data,
#                              self.val_data,
#                              generation_number=generation_number)
    
#     def show(self):
#         '''
#         Util function to show the individual's properties.
#         '''
#         pp.pprint(self.config)
#         pp.pprint(self.chromosome)
        
        
#     def get_wandb_credentials(self, phase: int=None, generation_number: int=None):
#         phase = phase or self.phase
#         generation_number = generation_number or self.generation_number
#         if self.debug:
#             return get_wandb_credentials(phase=phase,
#                                           generation_number=generation_number,
#                                           entity="jrose",
#                                           project=f"vlga-plant_village-DEBUG")           
#         return get_wandb_credentials(phase=phase,
#                                       generation_number=generation_number,
#                                       entity="jrose",
#                                       project=f"vlga-plant_village")

        
    
# def get_wandb_credentials(phase: int,
#                           generation_number: int,
#                           entity="jrose",
#                           project=f"vlga-plant_village"):
    
#     return dict(entity=entity,
#                 project=project,
#                 group='KAGp{}'.format(phase),
#                 job_type='g{}'.format(generation_number))

## Schema for defining, loading, using, and logging configuration for hparam search


### 1. Begin Hooks

    a. at_search_begin
        Store hparam search space definitions in a file called `search_space.json`
    b. at_trial_begin

    c. at_train_begin

    d. at_epoch_begin

    e. at_batch_begin

### 2. End Hooks

    a. at_batch_end
    
    b. at_epoch_end

    c. at_train_end

    d. at_trial_end

    e. at_search_end


7. 

## 1. INTERESTING REFACTOR IDEA:
    TODO: Refactor chromosome structure to standardize the configuration options for repeated model structures
    ### (3 AM 11/27/20)

    e.g. Create a separate ConvOptions(NamedTuple) class to contain all 3 options:
        filter_size
        include_BN
        output_channels

    Then in each "ChromosomeOptions" (consider making each of those a chromosome, and upgrading what's now a chromosome to a full Genome),
    store one ConvOptions for layer a and another for layer b.


## 2. TODO: 
    Consider moving the mutation() method from Organism to Chromosome, while potentially keeping the crossover() method in Organism's namespace. The purpose is to keep functionality encapsulated as close as possible to the data/abstractions it operates on.


## 3. To Consider:
    How can I quantify the information coverage and computational complexity of a given set of chromosome options? 

        a. Start with the raw # of permutations of all chromosome options
        b. Adjust by the expected coverage for each variant. E.g. How much of the hyperparameter space are we covering in our naive uniform grid search?

In [None]:
# def softmax(x):
#     e_x = np.exp(x - np.max(x))
#     return e_x / e_x.sum()



# class Generation:
#     def __init__(self,
#                  data,
#                  generation_config,
#                  organism_config,
#                  phase,
#                  previous_best_organism,
#                  verbose: bool=False):
#         self.data = data
#         self.config = generation_config
#         self.organism_config = organism_config
#         self.population = []
#         self.generation_number = 0
#         self.phase = phase
#         # creating the first population: GENERATION_0
#         # can be thought of as the setup function
#         self.previous_best_organism = previous_best_organism or None
#         self.best = {}
#         self._initialized = False
#         self.initialize_population(verbose=verbose)
#         self.verbose = verbose
        
#     @property
#     def config(self):
#         return self._config
    
#     @config.setter
#     def config(self, config=None):
#         config = config or OmegaConf.create({})
#         config.population_size = config.population_size or 5
#         config.num_generations_per_phase = config.num_generations_per_phase or 3
#         config.fitSurvivalRate = config.fitSurvivalRate or 0.5
#         config.unfitSurvivalProb = config.unfitSurvivalProb or 0.2
#         config.mutationRate = config.mutationRate or 0.1
#         config.num_phases = config.num_phases or 5
        
#         self._config = config
#         self.__dict__.update(config)
        
        
#     def initialize_population(self, verbose=True):
#         '''
#         1. Create self.population_size individual organisms from scratch by randomly sampling an initial set of hyperparameters (a chromosome)
#         2. As each is instantiated, build its model
#         3. Assess their fitness one-by-one
#         4. Sort models by relative fitness so we have a (potentially) new Best Organism (best model)
#         5. Increment generation number to 1
#         '''

#         for idx in range(self.population_size):
#             if verbose:
#                 print('<'*10,' '*5,'>'*10)
#                 print(f'Creating, training then testing organism {idx} out of a maximum {self.population_size} from generation {self.generation_number} and phase {self.phase}')
#             org = Organism(chromosome=sampler(self.phase), #.get_state(),
#                            data=self.data,
#                            config=self.organism_config,
#                            phase=self.phase,
#                            generation_number=self.generation_number,
#                            organism_id=idx,
#                            best_organism=self.previous_best_organism)
#             org.build_model()
#             org.fitnessFunction(org.data['train'],
#                                 org.data['test'],
#                                 generation_number=self.generation_number)
#             self.population.append(org)

#         self._initialized = True
#         self.sortModel(verbose=verbose)
#         self.generation_number += 1
#         self.evaluate(run=self.population[0].run)

#     def sortModel(self, verbose: bool=True):
#         '''
#         sort the models according to the 
#         fitness in descending order.
#         '''
#         previous_best = self.best_fitness
#         fitness = [ind.fitness for ind in self.population]
#         sort_index = np.argsort(fitness)[::-1]
#         self.population = [self.population[index] for index in sort_index]

#         if self.best_organism_so_far.fitness > previous_best:
#             self.best['organism'] = self.best_organism_so_far
#             self.best['model'] = self.best_organism_so_far.model
#             self.best['fitness'] = self.best_organism_so_far.fitness
            
#             if verbose:
#                 print(f'''NEW BEST MODEL:
#                 Fitness = {self.best["fitness"]:.3f}
#                 Previous Fitness = {previous_best:.3f}
#                 Name = {self.best['organism'].name}
#                 chromosome = {self.best['organism'].chromosome}''')
        
#     @property
#     def best_organism_so_far(self):
#         if self._initialized:
#             return self.population[0]
#         else:
#             return self.previous_best_organism

#     @property
#     def best_fitness(self):
#         if self._initialized:
#             return self.population[0].fitness
#         elif self.previous_best_organism is not None:
#             return self.previous_best_organism.fitness
#         else:
#             return 0.0
        
        
#     def generate(self):
#         '''
#         Generate a new generation in the same phase
#         '''
#         number_of_fit = int(self.population_size * self.fitSurvivalRate)
#         new_pop = self.population[:number_of_fit]
#         for individual in self.population[number_of_fit:]:
#             if np.random.rand() <= self.unfitSurvivalProb:
#                 new_pop.append(individual)
#         for index, individual in enumerate(new_pop):
#             if np.random.rand() <= self.mutationRate:
#                 new_pop[index].mutation(generation_number=self.generation_number)
#         fitness = [ind.fitness for ind in new_pop]
#         children=[]
#         for idx in range(self.population_size-len(new_pop)):
#             parents = np.random.choice(new_pop, replace=False, size=(2,), p=softmax(fitness))
#             A=parents[0]
#             B=parents[1]
#             child=A.crossover(B, generation_number=self.generation_number)
#             children.append(child)
#         self.population = new_pop+children
#         self.sortModel()
#         self.generation_number+=1

#     def evaluate(self, run=None, last=False):
#         '''
#         Evaluate the generation
#         '''
#         print('EVALUATE')
#         fitness = [ind.fitness for ind in self.population]

#         BestOrganism = self.population[0]
#         if run is None:
#             run = BestOrganism.run
#         run.log({'population_size':len(fitness)}, commit=False)
#         run.log({'Best fitness': fitness[0]}, commit=False)
#         run.log({'Average fitness': sum(fitness)/len(fitness)})
        
#         self.population[0].show()
#         print('BEST ORGANISM', BestOrganism.name)
#         k=16
#         if last:
#             k=32
#         model_path = f'best-model-phase_{self.phase}.png'
#         tf.keras.utils.plot_model(BestOrganism.model, to_file=model_path)
#         run.log({"best_model": [wandb.Image(model_path, caption=f"Best Model phase_{self.phase}")]})
#         log_high_loss_examples(BestOrganism.test_data,
#                                BestOrganism.model, 
#                                k=k,
#                                run=run)
            
#         return BestOrganism

#     def run_generation(self):
#         print('RUN GENERATION')
#         self.generate()
#         last = False
#         if self.generation_number == self.num_generations_per_phase:
#             last = True
#         best_organism = self.evaluate(last=last)
#         return best_organism
        
#     def run_phase(self):#, num_generations_per_phase: int=1):
#         print('RUN PHASE')
#         while self.generation_number < self.num_generations_per_phase:
#             best_organism = self.run_generation()
#             print(f'FINISHED GENERATION {self.generation_number}')
#             print(vars())
            
#             if self.verbose:
#                 print(f'FINISHED generation {self.generation_number}. Best fitness = {best_organism.fitness}')
            
#         return self.population[0] #best_organism
        
#             return self.population[0]