# Imports

In [None]:
# standard lib imports
from functools import partial
from os.path import join as join_path, abspath
from sys import path as sys_path
from typing import Optional

# Numeric libraries
from matplotlib.pyplot import figure
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from torch.optim.sgd import SGD
from sklearn.datasets import make_moons

# inhouse imports
    # enabling imports of "adjacent" modules
# os.path.join("..") is just "..", so resolve the parent directory directly
# and register it once so that the `lib.*` imports below resolve.
module_path = abspath("..")
if module_path not in sys_path:
    sys_path.append(module_path)

from lib.definitions import meta_ds, meta_sample, modeling_ds
from lib.learning_utils import collate_modeling_samples
from lib.nn_blocks import fc_classifier
from lib.nn_optimize import train, validate
from lib.learning_metrics import md_classification_accuracy


#region Visualization 
from bokeh.plotting import figure as bokeh_figure, output_notebook, show, ColumnDataSource
from bokeh.io.notebook import push_notebook 
from bokeh.layouts import column
output_notebook()
#endregion 

# Configurations

In [None]:
# ---- Data ----
N_TRAIN = 300  # number of training points
N_VALID = 100  # number of validation points
# Both splits are generated from the same SEED, so equal sizes would
# reproduce exactly the same points for training and validation.
assert N_TRAIN != N_VALID, "You will get identical sampling for training and validation"
NOISE = 0.15  # noise level passed to make_moons
SEED = 5      # random_state passed to make_moons

# ---- Optimization ----
BATCH_SZ = 8
EPOCH_CNT = 1000

# ---- Model ----
LINEAR_BLOCKS_DIMS = [16, 16]  # hidden layer widths
# NOTE(review): presumably one activation per linear layer (two hidden layers
# plus the output layer). CrossEntropyLoss expects raw logits, so a final
# 'sigmoid' may squash them and slow training; confirm how fc_classifier
# applies the last entry.
LINEAR_BLOCKS_ACTIVATIONS = ['relu', 'relu', 'sigmoid']

# Generate Data

In [None]:
# Draw the training and validation splits with identical generation settings;
# only the sample count differs between them.
x_train, y_train = make_moons(N_TRAIN, noise=NOISE, random_state=SEED)
x_valid, y_valid = make_moons(N_VALID, noise=NOISE, random_state=SEED)

# Visualize Data 

In [None]:
# Scatter both splits: dots for training, crosses for validation;
# black for label 1 and red for label 0.
fig = figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)

# make_moons assigns label 0 to the upper moon and label 1 to the lower moon.
# Boolean masks keep the selected coordinates 1-D (np.argwhere adds a
# trailing axis), and passing `label=` to each call ties every legend entry
# to its own series instead of relying on the order of a separate list.
train_one_mask = y_train == 1
train_zero_mask = y_train == 0
valid_one_mask = y_valid == 1
valid_zero_mask = y_valid == 0

ax.plot(x_train[train_one_mask, 0], x_train[train_one_mask, 1], "k.",
        label="Training Data Lower")
ax.plot(x_train[train_zero_mask, 0], x_train[train_zero_mask, 1], "r.",
        label="Training Data Upper")
ax.plot(x_valid[valid_one_mask, 0], x_valid[valid_one_mask, 1], "kx",
        markersize=10, label="Validation Data Lower")
ax.plot(x_valid[valid_zero_mask, 0], x_valid[valid_zero_mask, 1], "rx",
        markersize=10, label="Validation Data Upper")
ax.legend();

# Create Meta-data datasets

In [None]:
class moon_sample(meta_sample):
    """One point of a moons dataset, addressed by its row index.

    Args:
        idx: row index of the point in the generated arrays.
        ds_identity: identity string of the parent dataset; it is combined
            with ``idx`` to form this sample's identity.
    """

    def __init__(self, idx, ds_identity):
        # e.g. "3_noise_0.15_N_300_seed_5"
        identity = f"{idx}_{ds_identity}"
        super().__init__(identity)
        self.idx = idx

    def __repr__(self) -> str:
        # `identity` is presumably stored by meta_sample.__init__; confirm in lib.definitions.
        return self.identity

class moon_meta_ds(meta_ds):
    """Index-only meta-dataset describing one ``make_moons`` draw.

    Each item is a ``moon_sample`` carrying a row index plus an identity
    string built from the generation parameters.

    Args:
        samples_cnt: number of samples in the dataset.
        noise: noise level of the draw (used only in the identity string).
        seed: random seed of the draw (used only in the identity string).
    """

    def __init__(self, samples_cnt, noise, seed):
        # NOTE(review): meta_ds.__init__ is not called (unchanged from the
        # original); confirm the base class needs no initialisation.
        self.ds_identity = f"noise_{noise}_N_{samples_cnt}_seed_{seed}"
        self.samples_cnt = samples_cnt

    def __len__(self):
        return self.samples_cnt

    def __getitem__(self, ind: int) -> meta_sample:
        """Return the sample at position ``ind``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``ind`` lies outside ``[-samples_cnt, samples_cnt)``.
        """
        # Normalise negative indices so ds[-1] and ds[len(ds) - 1] produce the
        # same sample identity instead of "-1_..." vs "<N-1>_...".
        if ind < 0:
            ind += self.samples_cnt
        if not 0 <= ind < self.samples_cnt:
            raise IndexError(
                f"The dataset only has {self.samples_cnt} samples.")
        return moon_sample(ind, self.ds_identity)

# One meta-dataset per split; both use the same noise and seed and differ
# only in sample count (train first, then validation).
train_meta_ds, valid_meta_ds = (
    moon_meta_ds(samples_cnt=split_size, noise=NOISE, seed=SEED)
    for split_size in (N_TRAIN, N_VALID)
)


# Create modeling dataset

In [None]:
def x_creator(data_x : np.ndarray, sample : moon_sample):
    """Return the features of ``sample`` as a tensor of shape (1, 2).

    Args:
        data_x: feature matrix; row ``sample.idx`` is selected.
        sample: object whose ``idx`` attribute indexes ``data_x``.

    Returns:
        A new tensor copied from the selected row (same dtype as ``data_x``),
        reshaped to (1, 2).
    """
    row = data_x[sample.idx, :]
    features = torch.tensor(row)
    return features.reshape(1, 2)

def y_creator(data_y : np.ndarray, sample : moon_sample):
    """Return the label of ``sample`` as a tensor of shape (1, 1).

    Args:
        data_y: label vector; entry ``sample.idx`` is selected.
        sample: object whose ``idx`` attribute indexes ``data_y``.

    Returns:
        A new tensor holding the label (same dtype as ``data_y``), shape (1, 1).
    """
    # NOTE(review): CrossEntropyLoss expects class-index targets of shape (N,);
    # confirm that collate_modeling_samples / train flatten this (1, 1) label.
    label = torch.tensor(data_y[sample.idx])
    return label.reshape(1, 1)

# Pair each meta-dataset with creators bound to its own split's arrays, so a
# meta-sample's `idx` is resolved against the matching data.
train_model_ds, valid_model_ds = (
    modeling_ds(
        meta_ds=split_meta,
        x_creator=partial(x_creator, split_x),
        y_creator=partial(y_creator, split_y),
    )
    for split_meta, split_x, split_y in (
        (train_meta_ds, x_train, y_train),
        (valid_meta_ds, x_valid, y_valid),
    )
)

# Creating Batchers

In [None]:
# Both loaders share batch size and collate function and differ only in the
# dataset. No `shuffle` argument is passed, so each pass follows dataset order.
batcher_train, batcher_valid = (
    DataLoader(dataset=split_ds,
               batch_size=BATCH_SZ,
               collate_fn=collate_modeling_samples)
    for split_ds in (train_model_ds, valid_model_ds)
)

# Create Model

In [None]:
def create_model():
    """Build a fresh fc_classifier with 2 inputs and 2 outputs.

    Hidden layer sizes and activations come from the LINEAR_BLOCKS_DIMS and
    LINEAR_BLOCKS_ACTIVATIONS configuration constants.
    """
    return fc_classifier(
        input_dim=2,
        output_dim=2,
        linear_block_sizes=LINEAR_BLOCKS_DIMS,
        linear_block_activations=LINEAR_BLOCKS_ACTIVATIONS,
    )
# Cast the parameters to float64, presumably to match the input tensors that
# torch.tensor builds from the float64 NumPy feature arrays.
# nn.Module.to returns the module itself, so the call can be chained.
model = create_model().to(torch.float64)
print(model)

# Creating Optimizer and loss function

In [None]:
# Plain SGD over all model parameters with a fixed learning rate.
optimizer = SGD(model.parameters(), lr=0.1)
loss = torch.nn.CrossEntropyLoss()


def metric(y_hat, y):
    """Return ``(accuracy, batch_size)`` for a single batch.

    The batch size (first dimension of ``y``) is returned alongside the
    accuracy, presumably so callers can weight per-batch accuracies when
    averaging over an epoch.
    """
    batch_size = y.shape[0]
    return md_classification_accuracy(y_hat, y), batch_size

In [None]:
# Smoke-test the untrained model and the metric on one training batch.
# Fetch the batch once: calling next(iter(batcher_train)) twice builds two
# independent iterators, which yield the same batch only while the loader
# is unshuffled, so predictions and targets could silently mismatch.
first_batch = next(iter(batcher_train))
with torch.no_grad():  # forward-only check, no autograd graph needed
    y_hat = model(first_batch[0])
y = first_batch[1]
metric(y_hat, y)

In [None]:
# Live accuracy plot: an initially empty data source that the training loop
# streams one row per epoch into.
training_ds = ColumnDataSource(data=dict(epoch=[], train_accu=[], valid_accu=[]))
accu_fig = bokeh_figure(width=1200, height=400, tools='hover,box_zoom,reset')

for column_name, colour, legend_text in (
    ('train_accu', 'black', 'Training Accuracy'),
    ('valid_accu', 'red', 'Validation Accuracy'),
):
    accu_fig.line(x='epoch', y=column_name, source=training_ds,
                  line_color=colour, legend_label=legend_text)

accu_fig.xaxis.axis_label = 'epoch'
accu_fig.yaxis.axis_label = 'accuracy'
# The legend exists only after the labelled lines above have been added.
accu_fig.legend.location = "top_left"
accu_fig.legend.click_policy = "hide"  # clicking an entry toggles its line

# notebook_handle=True allows push_notebook() to update this plot later.
show(accu_fig, notebook_handle=True);

# Training Loop

In [None]:
for epoch_i in range(EPOCH_CNT):
    # One optimisation pass over the training batches. The results are
    # unpacked as (average accuracy, average loss); assumes that order from
    # lib.nn_optimize.train, confirm there.
    train_avg_accu, train_avg_loss = train(
        epoch=epoch_i,
        train_loader=batcher_train,
        model=model,
        optimizer=optimizer,
        criterion=loss,
        metric_func=metric,
    )

    # Evaluation pass over the validation batches (arguments are positional,
    # as validate's parameter names are not visible here).
    valid_avg_accu, valid_avg_loss = validate(
        batcher_valid, model, loss, metric, epoch_i, print_fn=print)

    # Append this epoch's accuracies to the live Bokeh plot and redraw it.
    new_data = dict(epoch=[epoch_i],
                    train_accu=[train_avg_accu],
                    valid_accu=[valid_avg_accu])
    training_ds.stream(new_data)
    push_notebook()