In [55]:
import string
import numpy as np
import matplotlib.pyplot as plt
import os
from prettytable import PrettyTable
from IPython.display import Latex
from keras.datasets import cifar10
import tensorflow as tf

In [56]:
# Reproducibility: NumPy and TensorFlow are seeded with the same value.
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Number of samples drawn per class for the clean subset.
clean_data_size = 200

# Load the CIFAR10 data.
(x_train_, y_train_), (x_test, y_test) = cifar10.load_data()

# Separate validation set: the first 10000 labels are held out for validation,
# entries 10000..49999 are kept as training labels.
y_validation = y_train_[:10000]
y_train = y_train_[10000:50000]

# Generate clean dataset: for each of the 10 classes, draw `clean_data_size`
# distinct positions (relative to `y_train`) whose label equals that class.
per_class_picks = []
for label in range(10):
    positive_index = list(np.where(y_train == label)[0])
    per_class_picks.append(
        np.random.choice(positive_index, clean_data_size, replace=False)
    )
clean_index = np.concatenate(per_class_picks).astype(int)

# Drop the clean samples so the remaining `y_train` rows line up with the
# indices used by the later cells.
y_train = np.delete(y_train, clean_index, axis=0)

In [210]:
# Summarise the benchmark runs: one table row per learning rate, one column per
# noise level, each cell formatted as "validation accuracy(test accuracy)".
# Runs whose accuracy file does not exist are shown as "/(/)".
print('Validation accuracy and (test accuracy) of pre-activation ResNet20 on CIFAR-10.')
lr_list = [0.01, 0.003, 0.001, 0.0003, 0.0001]
noise_list = [0, 0.3, 0.5, 0.7, 0.8, 0.9]
table = PrettyTable([' ', 'noise=0%', '30%', '50%', '70%', '80%', '90%'])
for lr in lr_list:
    # '/' is the placeholder for a (learning rate, noise) pair with no result.
    test_acc = ['/'] * len(noise_list)
    validation_acc = ['/'] * len(noise_list)
    for noise_j, noise in enumerate(noise_list):
        dirs = 'saved_precision/benchmark_ResNet20_pre-activation_lr_%.4f_noise_%.1f'%(lr, noise)
        accuracy_path = dirs + '/accuracy_file0.txt'
        if not os.path.exists(accuracy_path):
            continue
        # Context manager so the handle is closed (the original leaked it).
        with open(accuracy_path) as accuracy_file:
            lines = accuracy_file.readlines()
        # Line 6 of the file is used verbatim as the test accuracy; strip it so
        # a trailing newline cannot leak into the table cell.
        test_acc[noise_j] = lines[5].strip()
        # Line 4 is evaluated and its maximum reported as validation accuracy
        # (presumably a list literal of per-epoch accuracies -- TODO confirm).
        # NOTE(security): eval() executes whatever the file contains; only use
        # this on accuracy files produced by your own training runs.
        validation_acc[noise_j] = np.max(eval(lines[3]))
    table.add_row(['lr='+str(lr)] + [str(validation_acc[i])+'('+str(test_acc[i])+')' for i in range(len(noise_list))])
print(table)


Validation accuracy and (test accuracy) of pre-activation ResNet20 on CIFAR-10.
+-----------+----------------+----------------+----------------+----------------+----------------+----------------+
|           |    noise=0%    |      30%       |      50%       |      70%       |      80%       |      90%       |
+-----------+----------------+----------------+----------------+----------------+----------------+----------------+
|  lr=0.01  |      /(/)      |      /(/)      |      /(/)      |      /(/)      |      /(/)      |      /(/)      |
|  lr=0.003 | 0.8996(0.8945) | 0.8065(0.7975) | 0.6993(0.6915) | 0.5035(0.4894) |      /(/)      |      /(/)      |
|  lr=0.001 | 0.9225(0.9136) | 0.7951(0.7884) | 0.7186(0.7172) | 0.5927(0.5917) | 0.4351(0.434)  | 0.2122(0.2127) |
| lr=0.0003 | 0.9208(0.912)  | 0.7302(0.7252) | 0.6921(0.6835) | 0.5333(0.5222) | 0.3895(0.3955) | 0.2157(0.2099) |
| lr=0.0001 |      /(/)      |      /(/)      |      /(/)      | 0.4931(0.4849) |      /(/)      | 0.1842(0.

In [227]:
# The original notebook built the validation table but left its print
# commented out; set this to True to display it as well.
show_validation_table = False


def _parse_index_lists(lines):
    """Turn the lines of a result file into a list of index lists.

    The lines are concatenated (newlines dropped, spaces turned into commas),
    wrapped in an outer pair of brackets, comma runs are collapsed and the
    result is evaluated as a Python literal.
    """
    text = '['
    for line in lines:
        text += line.replace('\n','').replace(' ', ',')
    text += ']'
    # NOTE(security): eval() executes whatever the file contains; only use this
    # on result files produced by your own runs.
    return eval(text.replace('][','], [').replace(',,,,', ',').replace(',,,', ',').replace(',,', ',').replace('[,', '['))


def _fill_precision(index_lists, labels, label, precision_matrix, number_matrix):
    """Write precision and size of every index list into column `label`.

    Precision is the fraction of indices in a list whose entry in `labels`
    equals `label`, rounded to 3 decimals; an empty list is recorded as 0.
    """
    # Built once per class instead of once per k.
    true_positive = set(np.where(labels == label)[0])
    for k, index in enumerate(index_lists):
        TP = len(set(index) & true_positive)
        if len(index) > 0:
            precision_matrix[k][label] = np.round(TP/len(index), 3)
        else:
            precision_matrix[k][label] = 0
        number_matrix[k][label] = len(index)


def _render_table(precision_matrix, number_matrix):
    """Build the 10x10 per-class table with cells formatted as "precision(count)"."""
    rendered = PrettyTable(['k', 'airplane','auto','bird','cat','deer','dog','frog','horse','ship','truck'])
    for k in range(10):
        rendered.add_row([k]+[str(precision_matrix[k][i])+'('+str(number_matrix[k][i])+')' for i in range(10)])
    return rendered


# For every hyper-parameter combination that has a result directory, report the
# per-class precision (and list size) of each of the stored index lists.
for positive_threshold in [0.95]:
    for add_criterion in [90, 95]:
        for learning_rate in [0.0001]:
            dirs = 'saved_precision/positive_threshold_'+str(positive_threshold)+'_add_criterion_'+str(add_criterion)+\
                    '_learning_rate_'+str(learning_rate)
            if not os.path.exists(dirs):
                continue
            print('positive threshold: ', positive_threshold)
            print('add criterion: ', add_criterion)
            print('learning rate: ', learning_rate)
            # '/' marks cells for which no result is available.
            validation_precision_matrix = [['/' for _ in range(10)] for _ in range(10)]
            validation_number_matrix = [['/' for _ in range(10)] for _ in range(10)]
            training_precision_matrix = [['/' for _ in range(10)] for _ in range(10)]
            training_number_matrix = [['/' for _ in range(10)] for _ in range(10)]
            for label in range(10):
                try:
                    # Both files must exist for a class to be reported; the
                    # context manager closes them (the original leaked both).
                    with open(dirs + '/validation_label%d.txt'%label) as validation_file, \
                            open(dirs + '/training_label%d.txt'%label) as training_file:
                        training_lines = training_file.readlines()
                        validation_lines = validation_file.readlines()
                    _fill_precision(_parse_index_lists(training_lines), y_train, label,
                                    training_precision_matrix, training_number_matrix)
                    _fill_precision(_parse_index_lists(validation_lines), y_validation, label,
                                    validation_precision_matrix, validation_number_matrix)
                except FileNotFoundError:
                    # No results stored for this class: keep the '/' placeholders.
                    pass
                except Exception as err:
                    # Still best-effort (the loop continues), but malformed
                    # files are reported instead of being silently ignored.
                    print('Skipping label %d: %s: %s' % (label, type(err).__name__, err))
            print('\nResult on validation set')
            validation_table = _render_table(validation_precision_matrix, validation_number_matrix)
            if show_validation_table:
                print(validation_table)
            print('Result on training set')
            table = _render_table(training_precision_matrix, training_number_matrix)
            print(table, '\n')


positive threshold:  0.95
add criterion:  90
learning rate:  0.0001

Result on validation set
Result on training set
+---+-------------+-------------+------------+------------+------------+------+------+-------+------+-------+
| k |   airplane  |     auto    |    bird    |    cat     |    deer    | dog  | frog | horse | ship | truck |
+---+-------------+-------------+------------+------------+------------+------+------+-------+------+-------+
| 0 | 0.807(1193) | 0.867(1331) | 0.95(100)  | 0.638(127) | 0.798(173) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 1 | 0.788(1544) | 0.861(1825) | 0.93(158)  | 0.598(249) | 0.801(292) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 2 | 0.778(1665) | 0.865(2000) | 0.93(213)  | 0.552(366) | 0.784(366) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 3 | 0.772(1734) | 0.882(2000) | 0.922(270) | 0.559(399) | 0.769(407) | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 4 | 0.773(1763) |  0.89(2000) | 0.919(320) | 0.55(438)  | 0.77(435)  | /(/) | /(/) |  /(/) | /(/) |  /(/) |
| 5

In [220]:
def assemble_additional_data(positive_threshold, add_criterion, learning_rate, label, index, file):
    """Append one stored index list of a class to a text file.

    Reads ``training_label<label>.txt`` from the result directory named after
    the three hyper-parameters, parses it into a list of index lists and
    appends entry ``index`` of that list, followed by a newline, to ``file``.

    Args:
        positive_threshold, add_criterion, learning_rate: values whose ``str()``
            form the name of the result directory under ``saved_precision/``.
        label: class number selecting the ``training_label%d.txt`` file.
        index: position of the index list to copy.
        file: path of the output file; opened in append mode.

    Raises:
        FileNotFoundError: if the training file does not exist. The output
            file is only opened after parsing succeeded, so a failure no longer
            creates or touches it.
        IndexError: if ``index`` is out of range for the parsed lists.
    """
    dirs = 'saved_precision/positive_threshold_'+str(positive_threshold)+'_add_criterion_'+str(add_criterion)+\
                    '_learning_rate_'+str(learning_rate)

    # Context manager so the input handle is closed (the original leaked it).
    with open(dirs + '/training_label%d.txt'%label) as training_file:
        lines = training_file.readlines()

    # Concatenate the lines into one bracketed literal: newlines are dropped,
    # spaces become commas, and runs of commas are collapsed before evaluation.
    training_precision = '['
    for line in lines:
        training_precision += line.replace('\n','').replace(' ', ',')
    training_precision += ']'
    # NOTE(security): eval() executes whatever the file contains; only use this
    # on result files produced by your own runs.
    training_precision = eval(training_precision.replace('][','], [').replace(',,,,', ',').replace(',,,', ',').replace(',,', ',').replace('[,', '['))
    selected = str(training_precision[index]) + '\n'

    # A separate local name keeps the `file` argument intact, and the context
    # manager guarantees the output is flushed and closed.
    with open(file, 'a+') as out_file:
        out_file.write(selected)

In [222]:
# Output file collecting one selected index list per class, in class order.
additional_data_file = 'additional_data_index.txt'

# One entry per class: (add_criterion, class label, position of the index list
# to copy). positive_threshold and learning_rate are the same for all classes.
additional_data_selection = [
    (95, 0, 9),
    (95, 1, 9),
    (90, 2, 9),
    (90, 3, 9),
    (90, 4, 9),
    (95, 5, 9),
    (95, 6, 1),
    (95, 7, 7),
    (95, 8, 5),
    (95, 9, 3),
]

# assemble_additional_data only appends, so re-running this cell used to stack
# another ten rows onto the file while the cell that reads it assumes row i
# belongs to class i. Start from an empty file to keep the cell idempotent.
with open(additional_data_file, 'w'):
    pass

for add_criterion, label, index in additional_data_selection:
    assemble_additional_data(0.95, add_criterion, 0.0001, label, index, additional_data_file)

In [228]:
# Read the assembled index lists back (row i is used for class i) and report,
# per class, the precision of the selection against `y_train` and its size.
# Context manager so the handle is closed (the original leaked it).
with open('additional_data_index.txt') as index_file:
    lines = index_file.readlines()

# Concatenate the rows into one bracketed literal: newlines are dropped, spaces
# become commas, and runs of commas are collapsed before evaluation.
precision = '['
for line in lines:
    precision += line.replace('\n','').replace(' ', ',')
precision += ']'
# NOTE(security): eval() executes whatever the file contains; only use this on
# a file written by this notebook.
precision = eval(precision.replace('][','], [').replace(',,,,', ',').replace(',,,', ',').replace(',,', ',').replace('[,', '['))

table = PrettyTable(['label', 'precision', 'number of additional data'])
label_list = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

for label in range(10):
    true_positive_index = set(np.where(y_train == label)[0])
    index = precision[label]
    TP = len(set(index) & true_positive_index)
    # An empty selection is reported as 0, matching the per-class tables,
    # instead of raising ZeroDivisionError.
    if len(index) > 0:
        label_precision = np.round(TP/len(index), 3)
    else:
        label_precision = 0
    table.add_row([label_list[label], label_precision, len(index)])

print(table)

+------------+-----------+---------------------------+
|   label    | precision | number of additional data |
+------------+-----------+---------------------------+
|  airplane  |    0.88   |            1115           |
| automobile |    0.93   |            1675           |
|    bird    |   0.915   |            448            |
|    cat     |   0.548   |            529            |
|    deer    |   0.759   |            523            |
|    dog     |   0.839   |            597            |
|    frog    |   0.865   |            2000           |
|   horse    |   0.847   |            1898           |
|    ship    |    0.95   |            952            |
|   truck    |   0.854   |            1916           |
+------------+-----------+---------------------------+
