# Imports

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from typing import List

import numpy as np
from IPython.core.display import display
from pandas import DataFrame
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from dao.ower.ower_dir import OwerDir, Sample

# Config

In [None]:
# Human-readable names of the target classes, one per column of the per-sample
# class vectors (column order assumed to match the dataset — TODO confirm).
class_labels = ['married', 'male', 'American', 'actor']

# Number of target classes; derived from the labels so the two cannot disagree.
class_count = len(class_labels)

# Sentence count passed to OwerDir (presumably sentences per sample — TODO confirm).
sent_count = 3

# Load train/valid datasets

In [None]:
# Open the OWER dataset directory; check() is run before any reading
# (presumably verifies the expected files exist — TODO confirm in OwerDir).
ower_dir_path = Path('data/ower/ower-v3-fb-irt-3')
ower_dir = OwerDir('OWER Dataset Directory', ower_dir_path, class_count, sent_count)
ower_dir.check()

# read_datasets() returns four values; the third one (presumably the test
# split — TODO confirm) is not used in this notebook and is discarded.
train_set, valid_set, _, vocab = ower_dir.read_datasets()
train_set: List[Sample]
valid_set: List[Sample]

# Calculate class frequencies

In [None]:
def _stack_classes(samples):
    """Return the class vectors (middle field of each 3-field sample) as one NumPy array.

    Raises ValueError, like the plain unpacking it wraps, if ``samples`` is
    empty or its items do not have exactly three fields.
    """
    _, classes, _ = zip(*samples)
    return np.array(classes)


train_classes_stack = _stack_classes(train_set)
valid_classes_stack = _stack_classes(valid_set)

# Column means = per-class share of positive samples (assuming 0/1 labels).
print('train_class_freqs =', train_classes_stack.mean(axis=0))
print('valid_class_freqs =', valid_classes_stack.mean(axis=0))

# Dummy Classifiers

In [None]:
def _dummy_metrics(strategy, fit_gt, eval_gt, rng, n_runs=100):
    """Return mean (accuracy, precision, recall, F1) of a DummyClassifier baseline.

    The dummy is fitted on ``fit_gt`` and evaluated against ``eval_gt``.
    DummyClassifier ignores the input features, so the label vectors are simply
    reused as X. Metrics are averaged over ``n_runs`` predictions because the
    'uniform' and 'stratified' strategies predict at random.

    ``rng`` must be a ``np.random.RandomState`` instance, not an int: an int
    seed would be re-applied on every predict() call and make all runs identical.
    """
    classifier = DummyClassifier(strategy=strategy, random_state=rng)
    classifier.fit(fit_gt, fit_gt)

    metrics_list = []
    for _ in range(n_runs):
        pred = classifier.predict(eval_gt)
        # zero_division=0 for every metric: constant predictions (e.g. all zeros
        # from 'most_frequent') make precision/recall/F1 undefined, which would
        # otherwise emit an UndefinedMetricWarning on every run.
        metrics_list.append((
            accuracy_score(eval_gt, pred),
            precision_score(eval_gt, pred, zero_division=0),
            recall_score(eval_gt, pred, zero_division=0),
            f1_score(eval_gt, pred, zero_division=0),
        ))

    return np.mean(metrics_list, axis=0)


# One seeded generator shared by all baselines: random strategies stay random
# across runs but the whole cell is reproducible.
rng = np.random.RandomState(0)
df_cols = ['Accuracy', 'Precision', 'Recall', 'F1']
splits = (('train', train_classes_stack), ('valid', valid_classes_stack))

for strategy in ('uniform', 'stratified', 'most_frequent'):
    print(strategy)

    for split_name, class_stack in splits:
        # Fit each per-class dummy on the TRAIN labels only, so validation
        # labels never inform the baseline evaluated on the valid split.
        mean_metrics = [_dummy_metrics(strategy, train_gt, eval_gt, rng)
                        for train_gt, eval_gt in zip(train_classes_stack.T, class_stack.T)]

        df_data = [[f'{100 * value:.1f}%' for value in metrics] for metrics in mean_metrics]

        print(split_name)
        df = DataFrame(data=df_data, index=class_labels, columns=df_cols)
        display(df)