In [55]:
import os
import re
import string

import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from IPython.display import Latex
from keras.datasets import cifar10
import tensorflow as tf

In [56]:
seed = 99
np.random.seed(seed)
# `tf.set_random_seed` exists only in TensorFlow 1.x; TensorFlow 2.x renamed it
# to `tf.random.set_seed`. Use whichever this installation provides.
if hasattr(tf, 'set_random_seed'):
    tf.set_random_seed(seed)
else:
    tf.random.set_seed(seed)
clean_data_size = 200  # number of "clean" samples drawn per class

# Load the CIFAR10 data.
(x_train_, y_train_), (x_test, y_test) = cifar10.load_data()

# Separate validation set: the first 10000 training samples are held out,
# the remaining 40000 form the training set.
y_validation = y_train_[0:10000]
y_train = y_train_[10000:50000]

# Generate clean dataset: for each class draw `clean_data_size` indices into
# y_train without replacement. The legacy global numpy RNG (seeded above) is
# kept deliberately, and classes are drawn in the same order, so the selection
# stays reproducible -- presumably it must match the runs whose results are
# saved on disk (TODO confirm).
clean_chunks = []
for label in range(10):
    positive_index = list(np.where(y_train == label)[0])
    clean_chunks.append(np.random.choice(positive_index, clean_data_size, replace=False))
clean_index = np.concatenate(clean_chunks).astype(int)

# Remove the clean samples from the training labels. Any index computed
# against y_train after this point refers to this reduced array.
y_train = np.delete(y_train, clean_index, axis=0)

In [161]:
for positive_threshold in ['0.90', 0.95]:
    for add_criterion in [90, 95]:
        for learning_rate in ['0.00003', 0.0001, 0.0003]:
            dirs = 'saved_precision/positive_threshold_'+str(positive_threshold)+'_add_criterion_'+str(add_criterion)+\
                    '_learning_rate_'+str(learning_rate)
            if not os.path.exists(dirs):
                continue
        #     file = open(dirs + 'parameter.txt')
        #     lines = file.readlines()
        #     for line in lines:  
        #         print(line.split()[0])
            print('positive threshold: ', positive_threshold)
            print('add criterion: ', add_criterion)
            print('learning rate: ', learning_rate)
            validation_precision_matrix = [['/' for _ in range(10)] for _ in range(10)]
            validation_number_matrix = [['/' for _ in range(10)] for _ in range(10)]
            training_precision_matrix = [['/' for _ in range(10)] for _ in range(10)]
            training_number_matrix = [['/' for _ in range(10)] for _ in range(10)]

            for label in range(10):
                try:
                    validation_file = open(dirs + '/validation_label%d.txt'%label)
                    training_file = open(dirs + '/training_label%d.txt'%label)

                    lines = training_file.readlines()
                    training_precision = '['
                    for line in lines:
                        training_precision += line.replace('\n','').replace(' ', ',')
                    training_precision += ']'
                    training_precision = eval(training_precision.replace('][','], [').replace(',,,,', ',').replace(',,,', ',').replace(',,', ',').replace('[,', '['))

                    true_positive_index = list(np.where(y_train == label)[0])
                    for k in range(len(training_precision)):
                        index = training_precision[k]
                        TP = len(list(set(index) & set(true_positive_index)))
                        if len(index) > 0:
                            training_precision_matrix[k][label] = np.round(TP/len(index), 3)
                        else:
                            training_precision_matrix[k][label] = 0
                        training_number_matrix[k][label] = len(index)

                    lines = validation_file.readlines()
                    validation_precision = '['
                    for line in lines:
                        validation_precision += line.replace('\n','').replace(' ', ',')
                    validation_precision += ']'
                    validation_precision = eval(validation_precision.replace('][','], [').replace(',,,', ',').replace(',,', ',').replace('[,', '['))

                    true_positive_index = list(np.where(y_validation == label)[0])
                    for k in range(len(validation_precision)):
                        index = validation_precision[k]
                        TP = len(list(set(index) & set(true_positive_index)))
                        if len(index) > 0:
                            validation_precision_matrix[k][label] = np.round(TP/len(index), 3)
                        else:
                            validation_precision_matrix[k][label] = 0
                        validation_number_matrix[k][label] = len(index)
                except:
                    pass



            print('\nResult on validation set')        

            table = PrettyTable(['iteration', 'airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'])
            for k in range(10):
                table.add_row([k]+[str(validation_precision_matrix[k][i])+'('+str(validation_number_matrix[k][i])+')' for i in range(10)])
            print(table)

            print('Result on training set')        

            table = PrettyTable(['iteration', 'airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'])
            for k in range(10):
                table.add_row([k]+[str(training_precision_matrix[k][i])+'('+str(training_number_matrix[k][i])+')' for i in range(10)])
            print(table, '\n')


positive threshold:  0.90
add criterion:  90
learning rate:  0.0001

Result on validation set
+-----------+------------+------------+------------+------+------------+------+------+-------+------+-------+
| iteration |  airplane  | automobile |    bird    | cat  |    deer    | dog  | frog | horse | ship | truck |
+-----------+------------+------------+------------+------+------------+------+------+-------+------+-------+
|     0     | 0.73(381)  |    /(/)    |  0.9(30)   | /(/) | 0.784(97)  | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     1     | 0.729(451) |    /(/)    | 0.948(58)  | /(/) | 0.748(135) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     2     | 0.711(505) |    /(/)    | 0.919(86)  | /(/) | 0.733(150) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     3     | 0.71(507)  |    /(/)    | 0.892(102) | /(/) | 0.733(165) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     4     | 0.709(519) |    /(/)    | 0.89(109)  | /(/) | 0.718(177) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
|     5     |    /(/)    |


Result on validation set
+-----------+------------+------------+-----------+-----------+-----------+------------+------+-------+------+-------+
| iteration |  airplane  | automobile |    bird   |    cat    |    deer   |    dog     | frog | horse | ship | truck |
+-----------+------------+------------+-----------+-----------+-----------+------------+------+-------+------+-------+
|     0     | 0.86(129)  | 0.87(269)  | 0.846(13) |  0.64(25) | 0.862(29) | 0.776(76)  | /(/) |  /(/) | /(/) |  /(/) |
|     1     | 0.859(213) | 0.896(376) | 0.963(27) | 0.702(57) | 0.837(43) | 0.761(134) | /(/) |  /(/) | /(/) |  /(/) |
|     2     | 0.854(274) | 0.887(432) | 0.925(40) | 0.627(75) | 0.783(60) | 0.716(162) | /(/) |  /(/) | /(/) |  /(/) |
|     3     | 0.843(286) | 0.901(474) | 0.976(42) | 0.603(78) | 0.766(64) |    /(/)    | /(/) |  /(/) | /(/) |  /(/) |
|     4     | 0.834(308) | 0.893(497) |  0.96(50) | 0.546(97) | 0.753(73) |    /(/)    | /(/) |  /(/) | /(/) |  /(/) |
|     5     | 0.825(31


Result on validation set
+-----------+------------+------------+------+-----------+-----------+------------+------+-------+------+-------+
| iteration |  airplane  | automobile | bird |    cat    |    deer   |    dog     | frog | horse | ship | truck |
+-----------+------------+------------+------+-----------+-----------+------------+------+-------+------+-------+
|     0     | 0.887(53)  | 0.887(204) | /(/) | 0.583(12) |  1.0(13)  | 0.778(27)  | /(/) |  /(/) | /(/) |  /(/) |
|     1     | 0.886(132) | 0.899(268) | /(/) | 0.771(35) | 0.852(27) | 0.828(64)  | /(/) |  /(/) | /(/) |  /(/) |
|     2     | 0.864(184) | 0.914(325) | /(/) | 0.696(46) | 0.875(32) | 0.798(104) | /(/) |  /(/) | /(/) |  /(/) |
|     3     | 0.853(217) | 0.916(347) | /(/) | 0.615(65) | 0.884(43) |    /(/)    | /(/) |  /(/) | /(/) |  /(/) |
|     4     | 0.843(223) | 0.915(364) | /(/) | 0.614(70) | 0.843(51) |    /(/)    | /(/) |  /(/) | /(/) |  /(/) |
|     5     | 0.838(240) | 0.928(405) | /(/) | 0.556(72) | 0.8