In [55]:
import ast
import os
import re
import string

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython.display import Latex
from keras.datasets import cifar10
from prettytable import PrettyTable

In [56]:
# Seed numpy and tensorflow up front so the random sampling below is repeatable.
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)
clean_data_size = 200  # number of indices drawn per label for the clean subset

# Load the CIFAR10 data.
(x_train_, y_train_), (x_test, y_test) = cifar10.load_data()

# Separate validation set: labels 0..9999 are held out, labels 10000..49999
# are kept as the training labels.
y_validation = y_train_[:10000]
y_train = y_train_[10000:50000]

# Generate clean dataset: for every label draw `clean_data_size` distinct
# positions (relative to y_train) whose label matches.
per_label_picks = []
for label in range(10):
    positive_index = list(np.where(y_train == label)[0])
    per_label_picks.append(
        np.random.choice(positive_index, clean_data_size, replace=False)
    )
clean_index = np.concatenate(per_label_picks).astype(int)

# Drop the sampled positions, so y_train only holds the labels that were not
# picked for the clean subset.
y_train = np.delete(y_train, clean_index, axis=0)

In [183]:
# Column headers of the result tables; column i shows the numbers for label i.
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Number of rows (k = 0 .. MAX_ROWS-1) shown in every table.
MAX_ROWS = 10


def parse_index_lists(text):
    """Parse the text of a saved result file into a list of index lists.

    The text is expected to be a sequence of bracketed groups of numbers,
    separated by any amount of whitespace and/or commas and possibly wrapped
    over several lines, e.g. "[  3  17 256][4 9]" -> [[3, 17, 256], [4, 9]].
    An empty text yields an empty list.

    The text is normalised and handed to ``ast.literal_eval`` (not ``eval``),
    so file contents can never execute code.

    Raises:
        ValueError, SyntaxError: if the normalised text is not a valid literal.
    """
    # Collapse every run of whitespace/commas to one comma, then remove the
    # commas that ended up directly inside a bracket and separate the groups.
    compact = re.sub(r'[\s,]+', ',', text)
    compact = compact.replace('[,', '[').replace(',]', ']').replace('][', '],[')
    return ast.literal_eval('[' + compact.strip(',') + ']')


def precision_rows(index_lists, true_positive_index, max_rows=MAX_ROWS):
    """Compute (precision, count) for each of the first ``max_rows`` index lists.

    Args:
        index_lists: list of index lists, as returned by ``parse_index_lists``.
        true_positive_index: iterable of the indices that really carry the label.
        max_rows: upper bound on the number of lists that are evaluated.

    Returns:
        A list of ``(precision, count)`` tuples. ``count`` is ``len(index)``;
        ``precision`` is the number of distinct indices that are also in
        ``true_positive_index`` divided by ``count`` and rounded to 3 decimals,
        or ``0`` for an empty list.
    """
    positives = set(true_positive_index)  # built once instead of once per row
    rows = []
    for index in index_lists[:max_rows]:
        if len(index) > 0:
            true_positives = len(positives.intersection(index))
            precision = np.round(true_positives / len(index), 3)
        else:
            precision = 0
        rows.append((precision, len(index)))
    return rows


def fill_label_column(path, labels, label, precision_matrix, number_matrix):
    """Fill column ``label`` of both matrices from the result file at ``path``.

    Args:
        path: path of the text file holding the index lists for ``label``.
        labels: label array the stored indices refer to; only read after the
            file was parsed successfully.
        label: label whose column is filled.
        precision_matrix, number_matrix: row-major tables (list of lists) that
            are updated in place; at most ``len(precision_matrix)`` rows are
            written.

    A file that cannot be opened leaves the column untouched (the caller's
    placeholders stay in place). A file that cannot be parsed is reported on
    stdout and also leaves the column untouched.
    """
    try:
        with open(path) as result_file:
            text = result_file.read()
    except OSError:
        # No result for this label in this run: keep the placeholders.
        return
    try:
        true_positive_index = np.where(labels == label)[0]
        rows = precision_rows(parse_index_lists(text), true_positive_index,
                              len(precision_matrix))
    except (ValueError, SyntaxError, TypeError) as error:
        print('skipping %s: cannot parse index lists (%s)' % (path, error))
        return
    for k, (precision, count) in enumerate(rows):
        precision_matrix[k][label] = precision
        number_matrix[k][label] = count


def build_result_table(precision_matrix, number_matrix):
    """Return a PrettyTable with one 'precision(count)' cell per (k, label)."""
    table = PrettyTable(['k'] + CLASS_NAMES)
    for k in range(len(precision_matrix)):
        table.add_row([k] + [str(precision) + '(' + str(count) + ')'
                             for precision, count in zip(precision_matrix[k], number_matrix[k])])
    return table


# Some grid values are strings on purpose: str(0.90) is '0.9' and str(0.00003)
# is '3e-05', so the string form is what keeps '0.90' / '0.00003' verbatim in
# the directory name (presumably matching the names written by the training run).
for positive_threshold in ['0.90', 0.95]:
    for add_criterion in [90, 95]:
        for learning_rate in ['0.00003', 0.0001]:
            dirs = 'saved_precision/positive_threshold_'+str(positive_threshold)+'_add_criterion_'+str(add_criterion)+\
                    '_learning_rate_'+str(learning_rate)
            if not os.path.exists(dirs):
                # No saved results for this hyper-parameter combination.
                continue
            print('positive threshold: ', positive_threshold)
            print('add criterion: ', add_criterion)
            print('learning rate: ', learning_rate)
            # '/' marks a (k, label) cell for which no result is available.
            validation_precision_matrix = [['/' for _ in CLASS_NAMES] for _ in range(MAX_ROWS)]
            validation_number_matrix = [['/' for _ in CLASS_NAMES] for _ in range(MAX_ROWS)]
            training_precision_matrix = [['/' for _ in CLASS_NAMES] for _ in range(MAX_ROWS)]
            training_number_matrix = [['/' for _ in CLASS_NAMES] for _ in range(MAX_ROWS)]
            for label in range(len(CLASS_NAMES)):
                # Training and validation are handled independently, so a
                # problem with one file no longer hides the other one.
                fill_label_column(dirs + '/training_label%d.txt'%label, y_train, label,
                                  training_precision_matrix, training_number_matrix)
                fill_label_column(dirs + '/validation_label%d.txt'%label, y_validation, label,
                                  validation_precision_matrix, validation_number_matrix)
            print('\nResult on validation set')
            print(build_result_table(validation_precision_matrix, validation_number_matrix))
            print('Result on training set')
            print(build_result_table(training_precision_matrix, training_number_matrix), '\n')


positive threshold:  0.90
add criterion:  90
learning rate:  0.0001

Result on validation set
+---+------------+------------+------------+------+------------+------+------+-------+------+-------+
| k |  airplane  | automobile |    bird    | cat  |    deer    | dog  | frog | horse | ship | truck |
+---+------------+------------+------------+------+------------+------+------+-------+------+-------+
| 0 | 0.73(381)  |    /(/)    |  0.9(30)   | /(/) | 0.784(97)  | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 1 | 0.729(451) |    /(/)    | 0.948(58)  | /(/) | 0.748(135) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 2 | 0.711(505) |    /(/)    | 0.919(86)  | /(/) | 0.733(150) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 3 | 0.71(507)  |    /(/)    | 0.892(102) | /(/) | 0.733(165) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 4 | 0.709(519) |    /(/)    | 0.89(109)  | /(/) | 0.718(177) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 5 |    /(/)    |    /(/)    |    /(/)    | /(/) | 0.709(172) | /(/) | /(/) |  /(/) | /(/


Result on validation set
+---+------------+------------+------+-----------+-----------+------------+------------+------------+------------+------------+
| k |  airplane  | automobile | bird |    cat    |    deer   |    dog     |    frog    |   horse    |    ship    |   truck    |
+---+------------+------------+------+-----------+-----------+------------+------------+------------+------------+------------+
| 0 | 0.887(53)  | 0.887(204) | /(/) | 0.583(12) |  1.0(13)  | 0.778(27)  | 0.921(379) | 0.94(216)  | 0.941(102) | 0.885(252) |
| 1 | 0.886(132) | 0.899(268) | /(/) | 0.771(35) | 0.852(27) | 0.828(64)  | 0.857(615) | 0.926(323) | 0.95(159)  | 0.892(397) |
| 2 | 0.864(184) | 0.914(325) | /(/) | 0.696(46) | 0.875(32) | 0.798(104) | 0.83(675)  | 0.901(395) | 0.954(195) | 0.88(475)  |
| 3 | 0.853(217) | 0.916(347) | /(/) | 0.615(65) | 0.884(43) | 0.786(126) | 0.815(697) | 0.885(436) | 0.946(221) | 0.864(515) |
| 4 | 0.843(223) | 0.915(364) | /(/) | 0.614(70) | 0.843(51) | 0.772(127) | 0.