In [1]:
import pickle

from train_model import train_step, test_step
from utils.load_data import get_data
from utils.make_dict import train_bow, get_bow

In [2]:
# Run configuration shared by the cells below; train_step / test_step receive
# this dict directly. C and sigma here are only starting values: the tuned pair
# is written back into args before the final training run.
args = dict(
    dataset='cifar10',     # passed to get_data; 'cifar10' also triggers the (-1, 32, 32, 3) reshape
    dataroot='./data',     # data directory handed to get_data
    model='custom_SVM',    # model name interpreted inside train_model — TODO confirm valid options
    kernel='gaussian',     # presumably the SVM kernel that sigma parameterizes — TODO confirm
    validation=0.1,        # presumably the held-out validation fraction — TODO confirm
    C=5.0,                 # default only; replaced by best_C before final training
    sigma=1.0,             # default only; replaced by best_sigma before final training
    batch=1000,            # presumably a batch size used inside train_model — TODO confirm
    dict_size=100,         # passed as num_dict to train_bow / get_bow (codebook size)
    train=True,            # read inside train_model, not in this notebook
    load_cluster=False,    # True: reload ./cluster.dump instead of refitting the codebook
    cuda=True,             # presumably enables GPU use inside train_model — TODO confirm
    depth=50,              # not read in this notebook; presumably for tree/forest models
    forest=100,            # not read in this notebook; presumably for tree/forest models
)

In [3]:
# Search grid for the tuning cell: every (C, sigma) combination is trained once.
# C is presumably the SVM regularization strength and sigma the Gaussian kernel
# width (args['kernel'] == 'gaussian') — TODO confirm against train_model.
# NOTE(review): the saved outputs of the tuning cell report C = 0.8 / 0.9 / 0.95
# and sigma = 1.5, none of which appear in these lists, and never report
# C = 0.5 or sigma = 0.01. Those outputs came from an earlier version of this
# grid; Restart & Run All to refresh them before trusting "Best C / Best sigma".
hyper_C = [
    0.1, 0.5, 1.0,
    5.0, 10.0,
]
hyper_sigma = [
    0.01, 0.1,
    1.0, 5.0,
]

In [4]:
# Load the training split, fit (or reload) the bag-of-words codebook, and encode
# every training image as a BoW feature vector.
trainX, trainy = get_data(dataset=args['dataset'], train=True, dataroot=args['dataroot'])

# NOTE(review): with order='F', if get_data returns flat CIFAR rows
# (channel-major, row-major pixels) each image comes out with height and width
# swapped. That is harmless as long as the test images get the identical
# reshape (they do), but confirm the layout get_data actually returns.
trainX = trainX.reshape((-1, 32, 32, 3), order='F') if args['dataset'] == 'cifar10' else trainX

if not args['load_cluster']:
    # num_select=10000 is presumably the number of sampled descriptors used to
    # fit the codebook — TODO confirm in utils.make_dict.
    cluster = train_bow(trainX, num_dict=args['dict_size'], num_select=10000)
    with open("./cluster.dump", "wb") as cache_file:
        pickle.dump(cluster, cache_file)
else:
    # NOTE(review): pickle.load can execute arbitrary code — only load a dump this
    # notebook wrote. The cache does not record dict_size, so a dump fitted with a
    # different dict_size is reused silently; refit after changing dict_size.
    with open("./cluster.dump", "rb") as cache_file:
        cluster = pickle.load(cache_file)

trainFeature = get_bow(trainX, cluster, num_dict=args['dict_size'])

In [5]:
# Grid search over every (C, sigma) pair: train with train_step, score each pair
# by its mean validation accuracy, and keep the best pair for the final model.
# Each trial runs on its own copy of args, so the shared config is not left
# holding whichever pair happened to run last. The winning pair is written back
# into args explicitly before final training.
best_C = None
best_sigma = None
# Start below any achievable score so the first trial always becomes the
# incumbent. With 0.0, a grid where every trial scored 0 left best_C as None and
# the summary print crashed on "%f" % None.
best_valid = float('-inf')
grid_results = []  # (C, sigma, mean train acc, mean valid acc) for every trial

for C in hyper_C:
    for sigma in hyper_sigma:
        trial_args = dict(args, C=C, sigma=sigma)

        # train_step returns (models, train accuracies, validation accuracies).
        # The lists are presumably one entry per trained sub-model — TODO confirm.
        _, train_acc_list, valid_acc_list = train_step(trial_args, trainFeature, trainy)

        tra = sum(train_acc_list) / len(train_acc_list)
        val = sum(valid_acc_list) / len(valid_acc_list)
        grid_results.append((C, sigma, tra, val))

        # Strict '>' keeps the earliest pair on ties; a NaN score never wins.
        if val > best_valid:
            best_valid = val
            best_C = C
            best_sigma = sigma

        print("C: %f Sigma: %f Train accuracy: %f Valid accuracy: %f"%(C, sigma, tra, val))

if best_C is None:
    # Only reachable with an empty grid or all-NaN scores; fail here rather than
    # pushing None into args in the next cell.
    raise ValueError("Hyperparameter search selected nothing; check hyper_C / hyper_sigma")

print("Best C: %f Best sigma: %f"%(best_C, best_sigma))

100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


C: 0.100000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 58.956000


100%|██████████| 10/10 [00:01<00:00,  6.48it/s]


C: 0.100000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 34.932000


100%|██████████| 10/10 [00:01<00:00,  6.41it/s]


C: 0.100000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 50.892000


100%|██████████| 10/10 [00:02<00:00,  3.85it/s]


C: 0.100000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 75.052000


100%|██████████| 10/10 [00:01<00:00,  6.45it/s]


C: 0.800000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 58.768000


100%|██████████| 10/10 [00:01<00:00,  6.42it/s]


C: 0.800000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 50.792000


100%|██████████| 10/10 [00:01<00:00,  6.41it/s]


C: 0.800000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 43.098000


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


C: 0.800000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 35.132000


100%|██████████| 10/10 [00:01<00:00,  6.45it/s]


C: 0.900000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 58.988000


100%|██████████| 10/10 [00:01<00:00,  6.53it/s]


C: 0.900000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 51.072000


100%|██████████| 10/10 [00:01<00:00,  6.41it/s]


C: 0.900000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 58.760000


100%|██████████| 10/10 [00:02<00:00,  4.07it/s]


C: 0.900000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 58.744000


100%|██████████| 10/10 [00:01<00:00,  6.46it/s]


C: 0.950000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 66.964000


100%|██████████| 10/10 [00:01<00:00,  6.49it/s]


C: 0.950000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 50.952000


100%|██████████| 10/10 [00:01<00:00,  6.46it/s]


C: 0.950000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 34.794000


100%|██████████| 10/10 [00:02<00:00,  4.06it/s]


C: 0.950000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 58.928000


100%|██████████| 10/10 [00:01<00:00,  6.38it/s]


C: 1.000000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 43.020000


100%|██████████| 10/10 [00:01<00:00,  6.40it/s]


C: 1.000000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 58.698000


100%|██████████| 10/10 [00:01<00:00,  6.40it/s]


C: 1.000000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 50.820000


100%|██████████| 10/10 [00:02<00:00,  4.49it/s]


C: 1.000000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 42.970000


100%|██████████| 10/10 [00:01<00:00,  6.49it/s]


C: 5.000000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 58.778000


100%|██████████| 10/10 [00:01<00:00,  6.50it/s]


C: 5.000000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 26.154000


100%|██████████| 10/10 [00:01<00:00,  6.46it/s]


C: 5.000000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 83.122000


100%|██████████| 10/10 [00:02<00:00,  4.62it/s]


C: 5.000000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 47.394000


100%|██████████| 10/10 [00:01<00:00,  6.48it/s]


C: 10.000000 Sigma: 0.100000 Train accuracy: 100.000000 Valid accuracy: 34.832000


100%|██████████| 10/10 [00:01<00:00,  6.42it/s]


C: 10.000000 Sigma: 1.000000 Train accuracy: 100.000000 Valid accuracy: 50.888000


100%|██████████| 10/10 [00:01<00:00,  6.44it/s]


C: 10.000000 Sigma: 1.500000 Train accuracy: 100.000000 Valid accuracy: 42.140000


100%|██████████| 10/10 [00:02<00:00,  4.68it/s]

C: 10.000000 Sigma: 5.000000 Train accuracy: 100.000000 Valid accuracy: 58.858000
Best C: 5.000000 Best sigma: 1.500000





In [6]:
# Retrain with the selected hyperparameters. They are written into the shared
# `args` because test_step (evaluation cell) receives this same dict.
# NOTE(review): 'part' is set only here, never during the grid search; its
# meaning lives in train_model — TODO confirm what part=False changes.
args.update(C=best_C, sigma=best_sigma, part=False)
models, train_acc_list, valid_acc_list = train_step(
    args, trainFeature, trainy,
)

100%|██████████| 10/10 [00:01<00:00,  6.21it/s]


In [7]:
# Load the test split and encode it with the same codebook (`cluster`) fitted on
# the training images, so train and test features share one vocabulary.
testX, testy = get_data(dataset=args['dataset'], train=False, dataroot=args['dataroot'])
# Identical layout conversion to the one applied to the training images.
testX = testX.reshape((-1, 32, 32, 3), order='F') if args['dataset'] == 'cifar10' else testX
testFeature = get_bow(testX, cluster, num_dict=args['dict_size'])

In [8]:
# Evaluate the final models on the test features with the tuned args; the
# returned accuracies are averaged in the next cell (granularity of the list,
# presumably one entry per sub-model, is defined in train_model — TODO confirm).
test_acc_list = test_step(args, testFeature, testy, models)

 90%|█████████ | 9/10 [00:01<00:00,  5.60it/s]


In [9]:
# Report the mean of the test accuracies returned by test_step.
test_avg_acc = sum(test_acc_list) / len(test_acc_list)
print("Test average accuracy:", test_avg_acc)

Test average accuracy: 48.049
