In [None]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import pandas as pd

sys.path.append('..')

import torch
from torch.utils.data import DataLoader
from sklearn import datasets
from sklearn.model_selection import train_test_split

# DEID libraries
from gojo import core
from gojo import deepl
from gojo import util
from gojo import plotting

In [None]:
# Load the Wine dataset shipped with scikit-learn
wine_dt = datasets.load_wine()

# Feature matrix and binary target. Samples whose class index equals 1 form the
# positive class and every other class is negative (one-vs-rest); the class names
# are available in wine_dt['target_names'].
X = wine_dt['data']
y = (wine_dt['target'] == 1).astype(int)

# Standardize the input features.
# NOTE(review): the scaling is applied to the whole dataset before the split, so the
# validation samples presumably contribute to the scaling statistics - confirm
# that this is intended.
std_X = util.zscoresScaling(X)

# Stratified, shuffled 80/20 split into training and validation sets
split_kw = dict(train_size=0.8, random_state=1997, shuffle=True, stratify=y)
X_train, X_valid, y_train, y_valid = train_test_split(std_X, y, **split_kw)

# Sanity check: shape of each split and fraction of positive samples in each one
X_train.shape, X_valid.shape, '%.3f' % y_train.mean(), '%.3f' % y_valid.mean()

In [None]:
# Wrap each split in a gojo torch dataset
train_dataset = deepl.loading.TorchDataset(X=X_train, y=y_train)
valid_dataset = deepl.loading.TorchDataset(X=X_valid, y=y_valid)

# Training samples are shuffled and served in mini-batches of 16
train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Validation samples are served unshuffled, in one batch holding the whole split
valid_dl = DataLoader(valid_dataset, batch_size=X_valid.shape[0], shuffle=False)

In [None]:
# Configuration of a basic feed-forward network: one input per feature column,
# a single sigmoid-activated output and ELU activations in between
# (layer_dims=[20] is presumably one hidden layer of 20 units - confirm).
ffn_kw = dict(
    in_feats=X_valid.shape[1],
    out_feats=1,
    layer_dims=[20],
    layer_activation=torch.nn.ELU(),
    output_activation=torch.nn.Sigmoid(),
)
model = deepl.ffn.createSimpleFFNModel(**ffn_kw)

# display the created model
model

In [None]:
# Select the compute device at runtime instead of hard-coding Apple's 'mps'
# backend, which makes this cell fail on machines without Metal support.
# `torch.backends.mps` does not exist in older torch builds, hence the getattr.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Train the model for 50 epochs with binary cross-entropy and Adam (lr=0.001),
# tracking gojo's default binary-classification metrics (0.5 decision threshold).
output = deepl.fitNeuralNetwork(
    deepl.iterSupervisedEpoch,
    model=model,
    train_dl=train_dl,
    valid_dl=valid_dl,
    n_epochs=50,
    loss_fn=torch.nn.BCELoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={'lr': 0.001},
    device=device,
    metrics=core.getDefaultMetrics('binary_classification', bin_threshold=0.5)
)

In [None]:
# list the entries returned by the fit ('train' and 'valid' are read in the next cell)
output.keys()

In [None]:
# keep the entries stored under 'train' and 'valid' in their own variables
train_info, valid_info = output['train'], output['valid']


In [None]:
# Convergence plot: mean loss per epoch for both splits, using the loss std
# column as the error term
convergence_plot_kw = dict(
    x='epoch',
    y='loss (mean)',
    err='loss (std)',
    legend_labels=['Train', 'Validation'],
    title='Model convergence',
    ls=['solid', 'dashed'],
    style='default',
    legend_pos='center right',
)
plotting.linePlot(train_info, valid_info, **convergence_plot_kw)


In [None]:
# Always raises AssertionError, so a "Run All" stops at this cell.
# NOTE(review): presumably a deliberate barrier in front of the scratch cells
# below - confirm, and remove it (or the scratch cells) before sharing the notebook.
assert False

In [None]:
import inspect

In [None]:
# debugging check: True only when valid_info is itself a class object (not an instance)
inspect.isclass(valid_info)

In [None]:
import matplotlib.pyplot as plt

def linePlot(*dfs, x: str, y: str, err: str = None, err_alpha: float = 0.3,
             labels: list = None, figsize: tuple = (6, 3.5), dpi: int = 100,
             style: str = 'ggplot', legend_pos: str = 'upper right',
             legend_size: int = 12, colors: list = None, grid_alpha: float = 0.5,
             xlabel_size: float = 13, ylabel_size: float = 13,
             title: str = '', title_size: float = 15, save: str = None,
             save_kw: dict = None, show: bool = True
            ):
    """ Draw one line per input dataframe on a single set of axes.

    Parameters
    ----------
    *dfs
        Dataframes to plot. Each one must contain the columns `x`, `y` and, when
        given, `err`; the columns are read as `df[col].values`.
    x : str
        Column used for the x-axis values (also used as the x-axis label).
    y : str
        Column used for the y-axis values (also used as the y-axis label).
    err : str, default=None
        Optional column with the error term. When given, a band spanning
        `y - err` to `y + err` is drawn around each line.
    err_alpha : float, default=0.3
        Opacity of the error band.
    labels : list, default=None
        Legend label of each dataframe. Defaults to '(1)', '(2)', ...
    figsize : tuple, default=(6, 3.5)
        Figure size.
    dpi : int, default=100
        Figure resolution.
    style : str, default='ggplot'
        Matplotlib style applied while the figure is drawn.
    legend_pos : str, default='upper right'
        Legend location.
    legend_size : int, default=12
        Legend font size.
    colors : list, default=None
        Color of each line (and of its error band). When None the colors come
        from the style's color cycle.
    grid_alpha : float, default=0.5
        Opacity of the grid.
    xlabel_size, ylabel_size : float, default=13
        Font size of the axis labels.
    title : str, default=''
        Figure title.
    title_size : float, default=15
        Font size of the title.
    save : str, default=None
        When provided, the figure is saved to this location.
    save_kw : dict, default=None
        Extra keyword arguments forwarded to `savefig`.
    show : bool, default=True
        Whether to display the figure via `plt.show()`.

    Raises
    ------
    KeyError
        If any dataframe lacks one of the requested columns.
    ValueError
        If the number of labels differs from the number of dataframes, or if
        fewer colors than dataframes are provided.
    """
    # validate every input before any figure is created, so that a bad call does
    # not leave a half-drawn figure behind
    required_cols = [x, y] if err is None else [x, y, err]
    for i, df in enumerate(dfs):
        missing = [col for col in required_cols if col not in df]
        if missing:
            raise KeyError(
                'Input dataframe at position %d is missing column(s): %r' % (i, missing))

    if labels is None:
        labels = ['(%d)' % (i + 1) for i in range(len(dfs))]
    elif len(labels) != len(dfs):
        # zip() below would otherwise silently drop the unmatched dataframes
        raise ValueError(
            'Number of labels (%d) does not match the number of dataframes (%d)' % (
                len(labels), len(dfs)))

    if colors is not None and len(colors) < len(dfs):
        raise ValueError(
            'Number of colors (%d) is lower than the number of dataframes (%d)' % (
                len(colors), len(dfs)))

    # plot information
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        for i, (label, df) in enumerate(zip(labels, dfs)):
            color = None if colors is None else colors[i]
            x_values = df[x].values
            y_values = df[y].values

            (line,) = ax.plot(x_values, y_values, label=label, color=color)

            if err is not None:
                err_values = df[err].values
                # reuse the color of the line so the band always matches it,
                # including when the colors are provided by the user
                ax.fill_between(
                    x_values,
                    y_values + err_values,
                    y_values - err_values,
                    alpha=err_alpha,
                    facecolor=line.get_color())

        # figure layout
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.grid(alpha=grid_alpha)
        ax.legend(loc=legend_pos, prop=dict(size=legend_size))
        ax.set_xlabel(x, size=xlabel_size)
        ax.set_ylabel(y, size=ylabel_size)
        ax.set_title(title, size=title_size)

        # save figure if specified (on the figure itself rather than on whatever
        # pyplot considers the current figure)
        if save:
            save_kw = {} if save_kw is None else save_kw
            fig.savefig(save, **save_kw)

        if show:
            plt.show()
        
# try the local implementation on the train / validation histories
linePlot(
    train_info, valid_info,
    x='epoch', y='loss (mean)', err='loss (std)',
    labels=['Train', 'Validation'],
    title='FOO',
)

In [None]:
# raw values stored in the 'epoch' column of valid_info
valid_info['epoch'].values

In [None]:
# display the full validation history
# NOTE(review): if this is a large table, consider showing only a preview
valid_info

In [None]:
# display the full training history
# NOTE(review): if this is a large table, consider showing only a preview
train_info