In [55]:
import string
import numpy as np
import matplotlib.pyplot as plt
import os
from prettytable import PrettyTable
from IPython.display import Latex
from keras.datasets import cifar10
import tensorflow as tf

In [56]:
seed = 99
np.random.seed(seed)
# `tf.set_random_seed` only exists in TensorFlow 1.x; TF 2.x renamed it to
# `tf.random.set_seed`. Use whichever this installation provides so the cell
# runs under both.
_set_tf_seed = getattr(tf, 'set_random_seed', None) or tf.random.set_seed
_set_tf_seed(seed)
clean_data_size = 200  # number of "clean" examples drawn per class

# Load the CIFAR10 data.
(x_train_, y_train_), (x_test, y_test) = cifar10.load_data()

# Separate the validation set: the first 10,000 training labels are held out for
# validation and the rest (indices 10,000-49,999) form the training pool.
# Only the label arrays are split here.
y_validation = y_train_[0:10000]
y_train = y_train_[10000:50000]

# Generate the clean dataset: for every class draw `clean_data_size` distinct
# positions of y_train (without replacement) from the global numpy RNG seeded
# above. One concatenate replaces repeated np.append calls; the draws, and so
# the selected indices, are unchanged.
clean_index = np.concatenate([
    np.random.choice(np.where(y_train == label)[0], clean_data_size, replace=False)
    for label in range(10)
]).astype(int)
assert len(clean_index) == 10 * clean_data_size

# Remove the clean examples from the training pool. Later cells interpret
# saved indices as positions in this reduced y_train.
y_train = np.delete(y_train, clean_index, axis=0)

In [120]:
for positive_threshold in [0.85, '0.90', 0.95]:
    for add_criterion in [90, 95]:
        for learning_rate in [0.0001, 0.0003]:
            dirs = 'saved_precision/positive_threshold_'+str(positive_threshold)+'_add_criterion_'+str(add_criterion)+\
                    '_learning_rate_'+str(np.round(learning_rate,4))
            if not os.path.exists(dirs):
                continue
        #     file = open(dirs + 'parameter.txt')
        #     lines = file.readlines()
        #     for line in lines:
        #         print(line.split()[0])
            print('positive threshold: ', positive_threshold)
            print('add criterion: ', add_criterion)
            print('learning rate: ', learning_rate)
            validation_precision_matrix = [['/' for _ in range(10)] for _ in range(10)]
            validation_number_matrix = [['/' for _ in range(10)] for _ in range(10)]
            training_precision_matrix = [['/' for _ in range(10)] for _ in range(10)]
            training_number_matrix = [['/' for _ in range(10)] for _ in range(10)]

            for label in range(10):
                try:
                    validation_file = open(dirs + '/validation_label%d.txt'%label)
                    training_file = open(dirs + '/training_label%d.txt'%label)

                    lines = training_file.readlines()
                    training_precision = '['
                    for line in lines:
                        training_precision += line.replace('\n','').replace(' ', ',')
                    training_precision += ']'
                    training_precision = eval(training_precision.replace('][','], [').replace(',,,,', ',').replace(',,,', ',').replace(',,', ',').replace('[,', '['))

                    true_positive_index = list(np.where(y_train == label)[0])
                    for k in range(len(training_precision)):
                        index = training_precision[k]
                        TP = len(list(set(index) & set(true_positive_index)))
                        if len(index) > 0:
                            training_precision_matrix[k][label] = np.round(TP/len(index), 3)
                        else:
                            training_precision_matrix[k][label] = 0
                        training_number_matrix[k][label] = len(index)

                    lines = validation_file.readlines()
                    validation_precision = '['
                    for line in lines:
                        validation_precision += line.replace('\n','').replace(' ', ',')
                    validation_precision += ']'
                    validation_precision = eval(validation_precision.replace('][','], [').replace(',,,', ',').replace(',,', ',').replace('[,', '['))

                    true_positive_index = list(np.where(y_validation == label)[0])
                    for k in range(len(validation_precision)):
                        index = validation_precision[k]
                        TP = len(list(set(index) & set(true_positive_index)))
                        if len(index) > 0:
                            validation_precision_matrix[k][label] = np.round(TP/len(index), 3)
                        else:
                            validation_precision_matrix[k][label] = 0
                        validation_number_matrix[k][label] = len(index)
                except:
                    pass



            print('\nResult on validation set')        

            table = PrettyTable(['iteration', 'airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'])
            for k in range(6):
                table.add_row([k]+[str(validation_precision_matrix[k][i])+'('+str(validation_number_matrix[k][i])+')' for i in range(10)])
            print(table)

            print('\nResult on training set')        

            table = PrettyTable(['iteration', 'airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'])
            for k in range(6):
                table.add_row([k]+[str(training_precision_matrix[k][i])+'('+str(training_number_matrix[k][i])+')' for i in range(10)])
            print(table, '\n')


positive threshold:  0.85
add criterion:  95
learning rate:  0.0001

Result on validation set
+-----------+------------+------------+------+------+------+------+------+-------+------+-------+
| iteration |  airplane  | automobile | bird | cat  | deer | dog  | frog | horse | ship | truck |
+-----------+------------+------------+------+------+------+------+------+-------+------+-------+
|     0     | 0.766(274) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     1     | 0.76(354)  |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     2     | 0.748(377) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     3     | 0.76(371)  |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     4     | 0.764(386) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     5     |    /(/)    |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
+-----------+-----------


Result on validation set
+-----------+------------+------------+------+------+------+------+------+-------+------+-------+
| iteration |  airplane  | automobile | bird | cat  | deer | dog  | frog | horse | ship | truck |
+-----------+------------+------------+------+------+------+------+------+-------+------+-------+
|     0     | 0.752(303) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     1     | 0.762(387) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     2     | 0.751(422) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     3     | 0.738(446) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     4     | 0.739(449) |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     5     |    /(/)    |    /(/)    | /(/) | /(/) | /(/) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
+-----------+------------+------------+------+------+------+------+------+-------+------+---