# Проверка работоспособности.

Я обучил 10 моделей one-vs-all — по одной на каждый класс.

Сейчас мы попробуем объединить предсказания всех этих моделей в единый файл с ответами и загрузить его на Kaggle.

Импорт необходимых библиотек

In [1]:
import sys

import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

sys.path.append('../../../')

from core.datasets import open_f


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
class MulticlassClassificationMetrics:
    """Confusion-matrix based metrics for single-label multiclass predictions.

    Every label that occurs in ``y_true`` or ``y_pred`` gets a one-vs-rest
    TP/TN/FP/FN row in ``matrix_error``. Precision, recall and F1 are then
    either averaged over those per-class rows ('macro') or computed from the
    counts pooled over all rows ('micro').

    Args:
        y_true: Ground-truth labels (array-like).
        y_pred: Predicted labels, same shape as ``y_true``.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` have different shapes.
    """

    _AVERAGINGS = ('macro', 'micro')

    def __init__(self, y_true, y_pred):
        self.y_true = np.array(y_true)
        self.y_pred = np.array(y_pred)
        if self.y_true.shape != self.y_pred.shape:
            # Different shapes would silently broadcast in the `==`
            # comparisons below, e.g. (N, 1) vs (N,) -> (N, N).
            raise ValueError(
                'y_true and y_pred must have the same shape, '
                f'got {self.y_true.shape} and {self.y_pred.shape}')
        self.matrix_error = self.get_matrix_error()

    def get_matrix_error(self):
        """Build the per-class one-vs-rest confusion table.

        Also stores the sorted union of observed labels in
        ``self.class_types``.

        Returns:
            pandas.DataFrame indexed by ``class_type`` with columns
            ``TP``, ``TN``, ``FP``, ``FN``.
        """
        self.class_types = np.union1d(self.y_true, self.y_pred)
        TP, FN, FP, TN = [], [], [], []
        for class_type in self.class_types:
            is_true = self.y_true == class_type
            is_pred = self.y_pred == class_type
            TP.append(np.sum(is_true & is_pred))
            FN.append(np.sum(is_true & ~is_pred))
            FP.append(np.sum(~is_true & is_pred))
            TN.append(np.sum(~is_true & ~is_pred))
        return pd.DataFrame({
            'class_type': self.class_types,
            'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN,
        }).set_index('class_type')

    @staticmethod
    def _safe_div(numerator, denominator):
        # Undefined ratios (x / 0) are reported as 0.0 instead of nan/inf.
        return float(numerator / denominator) if denominator else 0.0

    @staticmethod
    def _macro_mean(per_class):
        # Per-class 0/0 ratios (e.g. a class that is never predicted) count
        # as 0 before averaging.
        return float(np.mean(per_class.fillna(0)))

    def _check_averaging(self, averaging):
        if averaging not in self._AVERAGINGS:
            raise ValueError(
                f'averaging must be one of {self._AVERAGINGS}, got {averaging!r}')

    def accuracy(self):
        """Return the fraction of positions where prediction equals truth."""
        return self._safe_div(np.sum(self.y_true == self.y_pred), self.y_true.size)

    def precision(self, averaging='macro'):
        """Return precision TP / (TP + FP).

        Args:
            averaging: 'macro' (mean of per-class precision) or 'micro'
                (TP and FP summed over all classes).

        Raises:
            ValueError: For an unknown ``averaging``.
        """
        self._check_averaging(averaging)
        tp, fp = self.matrix_error['TP'], self.matrix_error['FP']
        if averaging == 'macro':
            return self._macro_mean(tp / (tp + fp))
        return self._safe_div(tp.sum(), tp.sum() + fp.sum())

    def recall(self, averaging='macro'):
        """Return recall TP / (TP + FN).

        Args:
            averaging: 'macro' (mean of per-class recall) or 'micro'
                (TP and FN summed over all classes).

        Raises:
            ValueError: For an unknown ``averaging``.
        """
        self._check_averaging(averaging)
        tp, fn = self.matrix_error['TP'], self.matrix_error['FN']
        if averaging == 'macro':
            return self._macro_mean(tp / (tp + fn))
        return self._safe_div(tp.sum(), tp.sum() + fn.sum())

    def f1_score(self, averaging='macro'):
        """Return the F1 score (harmonic mean of precision and recall).

        'macro' averages the per-class F1 = 2TP / (2TP + FP + FN); 'micro'
        combines micro precision and micro recall.

        Raises:
            ValueError: For an unknown ``averaging``.
        """
        self._check_averaging(averaging)
        if averaging == 'macro':
            tp = self.matrix_error['TP']
            fp = self.matrix_error['FP']
            fn = self.matrix_error['FN']
            return self._macro_mean(2 * tp / (2 * tp + fp + fn))
        p = self.precision(averaging)
        r = self.recall(averaging)
        return self._safe_div(2 * p * r, p + r)

    def metrics(self, averaging='macro'):
        """Return accuracy, precision, recall and F1 as a dict."""
        return {
            'accuracy': self.accuracy(),
            'precision': self.precision(averaging),
            'recall': self.recall(averaging),
            'f1_score': self.f1_score(averaging),
        }

    def __str__(self):
        """Render the confusion table.

        When ``display`` is available (IPython/Jupyter sessions), the table
        is shown there as a styled table and an empty string is returned.
        Otherwise ``display`` is undefined, and the plain-text table is
        returned instead.
        """
        try:
            show = display  # noqa: F821 -- only defined in IPython sessions
        except NameError:
            return self.matrix_error.to_string()
        header_style = {
            'selector': '*',
            'props': 'background-color: darkgreen; color: white; font-size: 12pt;',
        }
        hover_cell_style = {
            'selector': 'td:hover',
            'props': 'background-color: green; color: white;',
        }
        show(
            self.matrix_error
            .style
            .set_table_styles([header_style, hover_cell_style])
            .set_properties(**{'background-color': 'lightgreen',
                               'color': 'black', 'font-size': '12pt'})
        )
        return ''

    def __repr__(self):
        return (
            f'MulticlassClassificationMetrics(class_types={self.class_types})'
            .replace("'", '').replace(',', ';'))

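Загружаем датасет для проверки и берём из него случайную выборку из 10000 объектов. Внимание: данные берутся из `repaired_data_train`, то есть из обучающей выборки, поэтому полученные метрики, скорее всего, завышены.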

In [3]:
# NOTE(review): despite the name, this loads 'repaired_data_train'. If the
# one-vs-all models were trained on it, the scores below are optimistic.
# TODO: evaluate on a held-out split instead.
test_ds = open_f('repaired_data_train', back=3)

# Fixed seed so the same 10000-sample subset (and hence the same metrics)
# is reproduced on every run.
rng = np.random.default_rng(42)
n_samples = 10000
shuffle = rng.permutation(test_ds['labels'].shape[0])
# Select the subset indices first, so only n_samples rows are copied by the
# fancy indexing instead of the whole dataset.
sample_idx = shuffle[:n_samples]
test_ds_x = test_ds['images'][sample_idx] / 255
test_ds_y = test_ds['labels'][sample_idx]

# Рассмотрим все avg acc чекпоинты 

In [4]:
# Load one one-vs-all checkpoint per class, in class order 0..9, so that the
# k-th model's score ends up in column k of the stacked predictions.
nums = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
models = [
    tf.keras.models.load_model(f'../../checkpoints/model_{digit_name}_avg_categorical_accuracy.h5')
    for digit_name in nums
]

In [5]:
# Run every one-vs-all model on the same images and place its scores in one
# column, so column k of `ans` holds models[k]'s score for each sample.
# Collecting the columns in a list and stacking once avoids re-copying a
# growing array on every iteration.
# NOTE(review): assumes each model outputs exactly one score per sample, so
# `.ravel()` yields a vector of length len(test_ds_x) — confirm.
ans = np.column_stack([model.predict(test_ds_x).ravel() for model in models])



In [6]:
# Predicted class = index of the model (column) with the highest score.
full_ans = ans.argmax(axis=-1)

In [7]:
# Compare the combined one-vs-all predictions with the true labels.
metrics = MulticlassClassificationMetrics(y_true=test_ds_y, y_pred=full_ans)
print(metrics)  # __str__ renders the per-class TP/TN/FP/FN table
metrics.metrics(averaging='macro')

Unnamed: 0_level_0,TP,TN,FP,FN
class_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,638,9359,2,1
1,1889,8103,4,4
2,1524,8471,2,3
3,1121,8872,2,5
4,1015,8979,3,3
5,958,9036,4,2
6,786,9207,4,3
7,756,9241,2,1
8,663,9331,2,4
9,623,9374,2,1





{'accuracy': 0.9973,
 'precision': 0.9970656100703337,
 'recall': 0.9972163733742958,
 'f1_score': 0.9971402629008501}

# Теперь рассмотрим все acc чекпоинты

In [8]:
# Load one one-vs-all checkpoint per class, in class order 0..9, so that the
# k-th model's score ends up in column k of the stacked predictions.
nums = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
models = [
    tf.keras.models.load_model(f'../../checkpoints/model_{digit_name}_categorical_accuracy.h5')
    for digit_name in nums
]

In [9]:
# Run every one-vs-all model on the same images and place its scores in one
# column, so column k of `ans` holds models[k]'s score for each sample.
# Collecting the columns in a list and stacking once avoids re-copying a
# growing array on every iteration.
# NOTE(review): assumes each model outputs exactly one score per sample, so
# `.ravel()` yields a vector of length len(test_ds_x) — confirm.
ans = np.column_stack([model.predict(test_ds_x).ravel() for model in models])



In [10]:
# Predicted class = index of the model (column) with the highest score.
full_ans = ans.argmax(axis=-1)

In [11]:
# Compare the combined one-vs-all predictions with the true labels.
metrics = MulticlassClassificationMetrics(y_true=test_ds_y, y_pred=full_ans)
print(metrics)  # __str__ renders the per-class TP/TN/FP/FN table
metrics.metrics(averaging='macro')

Unnamed: 0_level_0,TP,TN,FP,FN
class_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,638,9358,3,1
1,1888,8103,4,5
2,1523,8466,7,4
3,1120,8870,4,6
4,1015,8977,5,3
5,958,9034,6,2
6,784,9206,5,5
7,754,9241,2,3
8,658,9332,1,9
9,623,9374,2,1





{'accuracy': 0.9961,
 'precision': 0.9960245649510334,
 'recall': 0.9957419379364476,
 'f1_score': 0.9958785800257404}

# Рассмотрим все avg loss чекпоинты 

In [12]:
# Load one one-vs-all checkpoint per class, in class order 0..9, so that the
# k-th model's score ends up in column k of the stacked predictions.
nums = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
models = [
    tf.keras.models.load_model(f'../../checkpoints/model_{digit_name}_avg_loss.h5')
    for digit_name in nums
]

In [13]:
# Run every one-vs-all model on the same images and place its scores in one
# column, so column k of `ans` holds models[k]'s score for each sample.
# Collecting the columns in a list and stacking once avoids re-copying a
# growing array on every iteration.
# NOTE(review): assumes each model outputs exactly one score per sample, so
# `.ravel()` yields a vector of length len(test_ds_x) — confirm.
ans = np.column_stack([model.predict(test_ds_x).ravel() for model in models])



In [14]:
# Predicted class = index of the model (column) with the highest score.
full_ans = ans.argmax(axis=-1)

In [15]:
# Compare the combined one-vs-all predictions with the true labels.
metrics = MulticlassClassificationMetrics(y_true=test_ds_y, y_pred=full_ans)
print(metrics)  # __str__ renders the per-class TP/TN/FP/FN table
metrics.metrics(averaging='macro')

Unnamed: 0_level_0,TP,TN,FP,FN
class_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,638,9359,2,1
1,1884,8099,8,9
2,1522,8470,3,5
3,1119,8872,2,7
4,1013,8977,5,5
5,959,9034,6,1
6,787,9206,5,2
7,755,9240,3,2
8,663,9332,1,4
9,623,9374,2,1





{'accuracy': 0.9963,
 'precision': 0.9962789188768684,
 'recall': 0.9965459920308402,
 'f1_score': 0.9964102073557164}

# Рассмотрим все loss чекпоинты

In [16]:
# Load one one-vs-all checkpoint per class, in class order 0..9, so that the
# k-th model's score ends up in column k of the stacked predictions.
nums = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
models = [
    tf.keras.models.load_model(f'../../checkpoints/model_{digit_name}_loss.h5')
    for digit_name in nums
]

In [17]:
# Run every one-vs-all model on the same images and place its scores in one
# column, so column k of `ans` holds models[k]'s score for each sample.
# Collecting the columns in a list and stacking once avoids re-copying a
# growing array on every iteration.
# NOTE(review): assumes each model outputs exactly one score per sample, so
# `.ravel()` yields a vector of length len(test_ds_x) — confirm.
ans = np.column_stack([model.predict(test_ds_x).ravel() for model in models])



In [18]:
# Predicted class = index of the model (column) with the highest score.
full_ans = ans.argmax(axis=-1)

In [19]:
# Compare the combined one-vs-all predictions with the true labels.
metrics = MulticlassClassificationMetrics(y_true=test_ds_y, y_pred=full_ans)
print(metrics)  # __str__ renders the per-class TP/TN/FP/FN table
metrics.metrics(averaging='macro')

Unnamed: 0_level_0,TP,TN,FP,FN
class_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,638,9357,4,1
1,1880,8101,6,13
2,1517,8468,5,10
3,1118,8870,4,8
4,1014,8974,8,4
5,959,9034,6,1
6,786,9203,8,3
7,755,9239,4,2
8,661,9333,0,6
9,622,9371,5,2





{'accuracy': 0.995,
 'precision': 0.9946372374872301,
 'recall': 0.9954298204828342,
 'f1_score': 0.9950272362747923}