## Imports

In [None]:
import pyro
import torch
import pandas as pd
from tqdm.auto import tqdm
from typing import Callable
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, CategoricalNB
from sklearn import metrics as sklearn_metrics
from sklearn.preprocessing import label_binarize

import warnings
# Suppress UserWarning messages so they do not clutter the notebook output.
warnings.filterwarnings("ignore", category=UserWarning)

# Turn on Pyro's validation checks.
pyro.enable_validation(True)
# Render floats in displayed DataFrames with three decimal places.
pd.options.display.float_format = '{:.3f}'.format

from preprocess import preprocess

## Utils

In [None]:
# Dataset splits shared across the notebook cells. They start out empty, are
# assigned from the preprocessed CSV files further down, and are read by
# show_metrics.
TRAIN: pd.DataFrame = None
VAL: pd.DataFrame = None
TEST: pd.DataFrame = None

def show_metrics(predict: Callable[[pd.DataFrame], pd.DataFrame], show_only_test: bool = False):
    """Score ``predict`` on the global TRAIN, VAL and TEST splits.

    For each split, in that order, the ``y`` column is taken as ground truth
    and ``predict`` is called on the remaining columns. Accuracy, weighted
    precision / recall / F1 and a weighted one-vs-rest ROC AUC (computed on
    one-hot encoded hard predictions over the labels 0..5) are evaluated.

    Args:
        predict: Callable mapping a feature frame to predicted labels.
        show_only_test: Currently has no effect, because the calls that
            displayed the per-split tables are disabled.

    Returns:
        A one-row DataFrame holding the metrics of the test split.
    """
    label_space = range(6)
    for split_name, frame in (('train', TRAIN), ('val', VAL), ('test', TEST)):
        targets = frame['y']
        predictions = predict(frame.drop(columns='y'))
        # The AUC below is computed on one-hot encodings of the labels.
        targets_onehot = label_binarize(targets, classes=label_space)
        predictions_onehot = label_binarize(predictions, classes=label_space)
        scores = pd.DataFrame([{
            'accuracy': sklearn_metrics.accuracy_score(targets, predictions),
            'precision': sklearn_metrics.precision_score(targets, predictions, average='weighted'),
            'recall': sklearn_metrics.recall_score(targets, predictions, average='weighted'),
            'f1': sklearn_metrics.f1_score(targets, predictions, average='weighted'),
            'auc': sklearn_metrics.roc_auc_score(
                targets_onehot, predictions_onehot, average='weighted', multi_class='ovr'),
        }])
        # Only the test-split table is handed back to the caller.
        if split_name == 'test':
            return scores

## Implementation sklearn

In [None]:
# Benchmark every scikit-learn naive Bayes variant under each preprocessing
# configuration and gather the test-split metrics into one table.
preprocessing_variants = {
    'one hot': dict(categorical_to_numerical_scale=False, categorical_to_one_hot=True, continous_to_discrete=False),
    'ordinal': dict(categorical_to_numerical_scale=True, categorical_to_one_hot=False, continous_to_discrete=False),
    'ordinal 20 bins': dict(categorical_to_numerical_scale=True, categorical_to_one_hot=False, continous_to_discrete=True),
}
rows = []
for plabel, pkwargs in preprocessing_variants.items():
    # Regenerate the preprocessed CSV files, then reload the three splits.
    preprocess(**pkwargs)
    TRAIN = pd.read_csv('data/preprocessed_train.csv')
    VAL = pd.read_csv('data/preprocessed_val.csv')
    TEST = pd.read_csv('data/preprocessed_test.csv')
    # Fresh estimators for every preprocessing variant.
    candidates = {
        # https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html#sklearn.naive_bayes.GaussianNB
        'GaussianNB': GaussianNB(priors=None),
        # https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.MultinomialNB.html#sklearn.naive_bayes.MultinomialNB
        'MultinomialNB': MultinomialNB(alpha=1.0, force_alpha=True, fit_prior=True, class_prior=None),
        # https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.BernoulliNB.html#sklearn.naive_bayes.BernoulliNB
        'BernoulliNB': BernoulliNB(alpha=1.0, force_alpha=True, binarize=0.0, fit_prior=True, class_prior=None),
        # https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.CategoricalNB.html#sklearn.naive_bayes.CategoricalNB
        'CategoricalNB': CategoricalNB(alpha=1.0, force_alpha=True, fit_prior=True, class_prior=None, min_categories=None),
    }
    for mlabel, model in candidates.items():
        model.fit(X=TRAIN.drop(columns='y'), y=TRAIN['y'])
        res = show_metrics(model.predict, show_only_test=True)
        rows.append(res.assign(preprocessing=plabel, model=mlabel))

column_order = ['preprocessing', 'model', 'accuracy', 'precision', 'recall', 'f1', 'auc']
metrics = pd.concat(rows, ignore_index=True)[column_order]
display(metrics)

## Load data

In [None]:
# Reload the splits written by the most recent preprocessing run.
TRAIN = pd.read_csv('data/preprocessed_train.csv')
VAL = pd.read_csv('data/preprocessed_val.csv')
TEST = pd.read_csv('data/preprocessed_test.csv')

# Columns the Pyro classifier should treat as categorical features.
col_cat = [
    'Gender', 'Age', 'family_history_with_overweight', 'FAVC', 'FCVC', 'NCP',
    'CAEC', 'SMOKE', 'CH2O', 'SCC', 'FAF', 'TUE', 'CALC', 'MTRANS',
]
# Alternative selections kept for experimentation:
# col_cat = ['Gender','family_history_with_overweight','FAVC','CAEC','SCC','CALC','SMOKE','MTRANS']
# col_cat = []

# Cast the selected columns to the pandas 'category' dtype, split by split.
for frame in (TRAIN, VAL, TEST):
    for c in col_cat:
        frame[c] = frame[c].astype('category')

## Implementation pyro

In [None]:
class GaussianNBClassifier:
    """Naive Bayes classifier whose parameters are fitted with Pyro SVI.

    Columns of ``X`` with pandas dtype ``category`` are modelled with one
    Categorical distribution per class; every other column is modelled with
    one Normal distribution per class. All quantities are ``pyro.param``
    sites and the guide is empty, so each SVI step is a gradient step on the
    parameters themselves.

    Args:
        num_epochs: Number of SVI steps performed by ``fit``.
        lr: Learning rate used when no ``optimizer`` is supplied.
        optimizer: Optional Pyro optimizer; defaults to Adam with ``lr``.
        loss: Optional Pyro loss; defaults to ``Trace_ELBO``.
    """

    def __init__(self, num_epochs=500, lr=1e-2, optimizer=None, loss=None):
        self._num_epochs = num_epochs
        self._lr = lr
        self._optimizer = optimizer if optimizer is not None else pyro.optim.Adam({'lr': self._lr})
        self._loss = loss if loss is not None else pyro.infer.Trace_ELBO()
        # State filled in by fit(); the names match what the other methods read.
        self._num_cls = None
        self._numerical_cols = []
        self._categorical_cols = []
        self._c_probs = None
        self._num_dists = {}
        self._cat_logits = {}

    def fit(self, X, y, valX=None, valy=None):
        """Fit the parameters on ``X``/``y`` and track training curves.

        Clears the global Pyro param store, prints the detected numerical and
        categorical columns, then runs ``num_epochs`` SVI steps.

        Args:
            X: Feature DataFrame; ``category`` columns are treated as
                categorical, all others as numerical.
            y: Integer class labels aligned with ``X``.
            valX: Optional validation features.
            valy: Optional validation labels.

        Returns:
            Tuple ``(accs, losses, val_accs, val_losses)``. Losses are
            recorded every epoch, accuracies after every 10th epoch. The
            validation lists stay empty when ``valX`` is None.
        """
        pyro.clear_param_store()
        categorical_cols = X.select_dtypes('category').columns.values
        numerical_cols = [c for c in X.columns if c not in categorical_cols]
        print('num', numerical_cols)
        print('cat', categorical_cols)
        # Freeze the schema and the number of classes from the TRAINING data so
        # that evaluating the loss on a validation split cannot overwrite them.
        self._numerical_cols = numerical_cols
        self._categorical_cols = list(categorical_cols)
        self._num_cls = int(max(y)) + 1
        svi = pyro.infer.SVI(model=self._model, guide=self._guide,
            optim=self._optimizer, loss=self._loss)
        accs = []
        losses = []
        val_accs = []
        val_losses = []
        has_val = valX is not None
        for i in tqdm(range(self._num_epochs)):
            losses.append(svi.step(X, y))
            if has_val:
                val_losses.append(self._loss.loss(self._model, self._guide, valX, valy))
            if (i + 1) % 10 == 0:
                accs.append(sklearn_metrics.accuracy_score(y, self.predict(X)))
                if has_val:
                    val_accs.append(sklearn_metrics.accuracy_score(valy, self.predict(valX)))
        return accs, losses, val_accs, val_losses

    def predict(self, X):
        """Return the most probable class index for every row of ``X``.

        Args:
            X: Feature DataFrame with the columns seen during ``fit``.

        Returns:
            1-D torch tensor of predicted class indices, one per row.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self._num_cls is None:
            raise RuntimeError('fit() must be called before predict()')
        pred = pyro.infer.Predictive(model=self._model, guide=self._guide,
            num_samples=1, return_sites=('logP(c|x)',))
        log_pcx = pred(X)['logP(c|x)'].detach()
        # Drop the sample axis (and any other leading singleton axes) without
        # ever collapsing the row axis, which squeeze() would do for one row.
        log_pcx = log_pcx.reshape(-1, log_pcx.shape[-1])
        return torch.argmax(log_pcx, dim=-1)

    def _model(self, X, y=None):
        """Pyro model: observe the data when ``y`` is given, else score classes."""
        # Always re-read the parameters from the param store so that inference
        # never relies on tensors cached from an earlier training step.
        self._init_c_logits()
        self._init_num_params(X, self._numerical_cols)
        self._init_cat_params(X, self._categorical_cols)
        if y is None:  # inference mode
            self._get_classes_log_probs(X)
            return
        self._observe_classes(y)
        self._observe_numerical_features_given_classes(X, y)
        self._observe_categorical_features_given_classes(X, y)

    def _guide(self, X, y=None):
        pass  # This is meant to be an empty function

    def _init_c_logits(self):
        """Create or fetch the class prior (stored as probabilities on the simplex)."""
        self._c_probs = pyro.param('c_probs',
            lambda: torch.ones(self._num_cls).div(self._num_cls),
            constraint=torch.distributions.constraints.simplex)

    def _init_num_params(self, X, numerical_cols):
        """Create or fetch the per-class mean and std of each numerical column."""
        self._num_dists = {
            col: {
                'mu': pyro.param(f'{col}_mu', lambda: torch.zeros(self._num_cls)),
                'sigma': pyro.param(f'{col}_sigma', lambda: torch.ones(self._num_cls),
                    constraint=torch.distributions.constraints.positive),
            } for col in numerical_cols
        }

    def _init_cat_params(self, X, categorical_cols):
        """Create or fetch the per-class logits of each categorical column.

        The initialiser is a lambda, so ``X[col].cat.categories`` is only read
        when the parameter does not exist yet.
        """
        self._cat_logits = {
            col: pyro.param(f'{col}_logits', lambda: torch.ones([self._num_cls, len(X[col].cat.categories)]))
            for col in categorical_cols
        }

    def _observe_classes(self, y):
        """Observe the labels under the class prior so that ``c_probs`` is trained."""
        labels = torch.as_tensor(getattr(y, 'values', y)).long()
        with pyro.plate('data-classes', labels.shape[0]):
            pyro.sample('c', pyro.distributions.Categorical(probs=self._c_probs), obs=labels)

    def _observe_numerical_features_given_classes(self, X, y):
        """Observe every numerical column under the Normal of the row's class."""
        for c in range(self._num_cls):
            x_c = X[y==c]
            if x_c.shape[0] == 0:
                continue  # no rows of this class: nothing to observe
            with pyro.plate(f'data-numerical-{c}', x_c.shape[0]):
                for nc, v in self._num_dists.items():
                    pyro.sample(f'x_{nc}|c={c}',
                        pyro.distributions.Normal(v['mu'][c], v['sigma'][c]),
                        obs=torch.tensor(x_c[nc].values))

    def _observe_categorical_features_given_classes(self, X, y):
        """Observe every categorical column under the Categorical of the row's class."""
        for c in range(self._num_cls):
            x_c = X[y==c]
            if x_c.shape[0] == 0:
                continue  # no rows of this class: nothing to observe
            with pyro.plate(f'data-categorical-{c}', x_c.shape[0]):
                for cc, v in self._cat_logits.items():
                    pyro.sample(f'x_{cc}|c={c}',
                        pyro.distributions.Categorical(logits=v[c]),
                        obs=torch.tensor(x_c[cc].values))

    def _get_log_likelihood(self, X):
        """Return a (rows, classes) tensor of summed per-feature log-likelihoods."""
        # Convert each column once instead of once per class.
        num_values = {nc: torch.tensor(X[nc].values) for nc in self._num_dists}
        cat_values = {cc: torch.tensor(X[cc].values) for cc in self._cat_logits}
        log_lk = []
        for c in range(self._num_cls):
            lps = [
                pyro.distributions.Normal(v['mu'][c], v['sigma'][c]).log_prob(num_values[nc])
                for nc, v in self._num_dists.items()]
            lps.extend(
                pyro.distributions.Categorical(logits=v[c]).log_prob(cat_values[cc])
                for cc, v in self._cat_logits.items())
            log_lk.append(torch.stack(lps).sum(dim=0))
        return torch.stack(log_lk).t()

    def _get_classes_log_probs(self, X):
        """Register log prior + log-likelihood per class as site 'logP(c|x)'.

        The value is not normalised over classes, which does not affect the
        argmax taken in ``predict``.
        """
        log_lk = self._get_log_likelihood(X)
        pyro.deterministic('logP(c|x)', self._c_probs.log() + log_lk)

#### Test

In [None]:
# Baseline run: fit the Pyro classifier with default settings for 1000 epochs
# and report its metrics (the test-split table is the cell's output).
model = GaussianNBClassifier(num_epochs=1000)
train_features, train_labels = TRAIN.drop(columns='y'), TRAIN['y']
model.fit(X=train_features, y=train_labels)
show_metrics(model.predict)

#### learning rate experiments

In [None]:
# Sweep the learning rate and collect the test metrics of each run.
learning_rates = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
rows = []
for lr in learning_rates:
    model = GaussianNBClassifier(num_epochs=1000, lr=lr)
    model.fit(X=TRAIN.drop(columns='y'), y=TRAIN['y'])
    rows.append(show_metrics(model.predict).assign(lr=lr))

metrics = pd.concat(rows, ignore_index=True)[['lr', 'accuracy', 'precision', 'recall', 'f1', 'auc']]
display(metrics)
# LaTeX export of the table:
# display(metrics.to_latex(index=False, float_format='%.3f'))

#### weight decay experiments

In [None]:
# Sweep Adam's weight decay at a fixed learning rate and collect test metrics.
base_lr = 1e-1
rows = []
for wd in [1e-3, 1e-4]:
    adam = pyro.optim.Adam({'lr': base_lr, 'weight_decay': wd})
    model = GaussianNBClassifier(num_epochs=1000, lr=base_lr, optimizer=adam)
    model.fit(X=TRAIN.drop(columns='y'), y=TRAIN['y'])
    rows.append(show_metrics(model.predict).assign(wd=wd))

metrics = pd.concat(rows, ignore_index=True)[['wd', 'accuracy', 'precision', 'recall', 'f1', 'auc']]
display(metrics)
# LaTeX export of the table:
# display(metrics.to_latex(index=False, float_format='%.3f'))

#### loss and acc graph

In [None]:
# Train once more with the validation split attached so that both training
# and validation curves are recorded for the plots below.
model = GaussianNBClassifier(num_epochs=1000)
accs, losses, val_accs, val_losses = model.fit(
    X=TRAIN.drop(columns='y'),
    y=TRAIN['y'],
    valX=VAL.drop(columns='y'),
    valy=VAL['y'],
)

In [None]:
# Per-epoch loss curves for the training and validation splits.
fig, ax = plt.subplots(figsize=(8, 3))
for series, label, color in ((losses, 'Train', 'b'), (val_losses, 'Validation', 'r')):
    ax.plot(range(len(series)), series, label=label, color=color)
ax.set_xlabel('Epoch')
ax.set_ylabel('Value')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Accuracy curves for the training and validation splits.
# fit() records accuracy after every 10th epoch (epochs 10, 20, ...), so the
# k-th recorded value belongs at x = 10 * (k + 1); plotting it at 10 * k would
# shift the whole curve one evaluation interval to the left.
eval_every = 10  # must match the evaluation interval used inside fit()
plt.figure(figsize=(8, 3))
plt.plot([eval_every * (k + 1) for k in range(len(accs))], accs, label='Train', color='b')
plt.plot([eval_every * (k + 1) for k in range(len(val_accs))], val_accs, label='Validation', color='r')
plt.xlabel('Epoch')
plt.ylabel('Value')
plt.legend()
plt.tight_layout()
plt.show()