In [1]:
from net.hcq import HCQ
import numpy as np
from net.data import get_dataloader
import torch
import time
import json
from collections import OrderedDict
import os

## Accuracy-loss budget: 1.0 percentage point (single linkage)

In [None]:
# Experiment configuration.
# acc_loss: tolerated drop in validation accuracy, in percentage points
# (the minimal acceptable accuracy printed below is valid_acc - acc_loss).
acc_loss = 1.0
# Linkage criterion used both for HCQ initialization and for quantization.
linkage_name = 'single'
# Exponents b of the cluster-count lower bound 2**b, tightened pass by pass (32, 16, 8).
LOWER_BOUND_EXPONENTS = range(5, 2, -1)
# Maximum quantize/fine-tune rounds per layer and per lower bound.
MAX_ROUNDS_PER_LAYER = 10
# Arguments forwarded to net.fine_tune after each successful quantization round.
FINE_TUNE_EPOCHS = 1
FINE_TUNE_LR = 1e-4
FINE_TUNE_SHOW_INTERVAL = 50
FINE_TUNE_SAMPLE_RATE = 0.2
# Seed the global RNGs so data shuffling / sub-sampling is repeatable
# (assumes the project code draws from torch / numpy global RNGs — TODO confirm).
SEED = 0

torch.manual_seed(SEED)
np.random.seed(SEED)

file_path = 'model/1.1 Pipeline Experiments/acc loss {} linkage {}'.format(str(acc_loss), linkage_name)
os.makedirs(file_path, exist_ok=True)

train_loader, valid_loader, test_loader = get_dataloader(type='cifar10', data_root='../data', batch_size=64)
net = HCQ(10)
net.load_from_pretrained_model('model/resnet18_cifar10_baseline.pth')
valid_acc = net.compute_acc(valid_loader)
test_acc = net.compute_acc(test_loader)
print('Original valid accuracy:{:.2f}% | Original Test accuracy:{:.2f}% | Minimal valid accuracy:{:.2f}%'.format(
    valid_acc, test_acc, valid_acc - acc_loss))

# Quantize every layer of the network (loaded from the ResNet-18 checkpoint above)
# with the configured linkage function, progressively tightening the lower bound on
# the number of clusters. The per-layer histories accumulate across lower bounds,
# so later passes extend earlier results instead of overwriting them.
cluster_num_dict = OrderedDict()
acc_dict = OrderedDict()
time_dict = OrderedDict()
for lower_bound in LOWER_BOUND_EXPONENTS:
    print('\n########################## Lower bound {} ##########################\n'.format(lower_bound))
    for layer_name in net.layers:
        start_time = time.time()
        print('---------------------- Quantize Layer {} ------------------------'.format(layer_name))
        # Use the configured linkage here too, so initialization and quantization agree.
        code_book, weights = net.hcq_initialization(layer_name, linkage_name)
        total_cluster_num = []
        total_acc = []
        for iteration in range(MAX_ROUNDS_PER_LAYER):
            max_idx = np.max(code_book)
            code_book, centroids, clusters_num_list, acc_list = net.quantize_layer_under_acc_loss(
                layer_name=layer_name,
                code_book=code_book,
                linkage_name=linkage_name,
                clusters_num_lower_bound=2**lower_bound,
                baseline_acc=valid_acc - acc_loss,
                valid_loader=valid_loader)
            # Stop when a round leaves the largest code index unchanged
            # (presumably: no further clusters could be merged within the accuracy budget).
            if max_idx == np.max(code_book):
                break
            net.fine_tune(layer_name, code_book, FINE_TUNE_EPOCHS, FINE_TUNE_LR,
                          train_loader, valid_loader, test_loader,
                          show_interval=FINE_TUNE_SHOW_INTERVAL, sample_rate=FINE_TUNE_SAMPLE_RATE)
            # Store plain Python numbers: cluster counts stay integers (np.concatenate
            # onto [] produced float64) and json can serialize every element.
            total_cluster_num.extend(int(n) for n in clusters_num_list)
            total_acc.extend(float(a) for a in acc_list)
        cluster_num_dict.setdefault(layer_name, []).extend(total_cluster_num)
        acc_dict.setdefault(layer_name, []).extend(total_acc)
        time_dict[layer_name] = time_dict.get(layer_name, 0.0) + (time.time() - start_time)
        # Checkpoint after every layer so an interrupted run keeps completed work.
        net.save_model(os.path.join(file_path, 'model.pth'))

    # Record the cumulative histories after each lower-bound pass.
    for file_name, record in (('acc_dict.json', acc_dict),
                              ('clusters_nums_dict.json', cluster_num_dict),
                              ('time_dict.json', time_dict)):
        with open(os.path.join(file_path, file_name), 'w') as f:
            json.dump(record, f)

Orignal valid accuracy:100.00% | Original Test accuracy:92.60% | Minimal valid accuracy:99.00%

########################## Lower bound 5 ##########################

---------------------- Quantize Layer block 1 Conv 1 1 ------------------------
Clusters:240 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 0.38 mins
Clusters:220 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 0.77 mins
Clusters:200 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 1.15 mins
Clusters:190 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 1.52 mins
Clusters:180 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 1.90 mins
Clusters:170 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 2.28 mins
Clusters:160 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 2.67 mins
Clusters:150 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 3.06 mins
Clusters:140 | Validation Accuracy:99.95% | Accuracy change: 0.00% | Time: 

Clusters: 95 | Validation Accuracy:99.40% | Accuracy change: -0.05% | Time: 4.94 mins
Clusters: 90 | Validation Accuracy:99.40% | Accuracy change: -0.05% | Time: 5.32 mins
Clusters: 85 | Validation Accuracy:99.40% | Accuracy change: -0.05% | Time: 5.69 mins
Clusters: 80 | Validation Accuracy:99.35% | Accuracy change: -0.10% | Time: 6.07 mins
Clusters: 75 | Validation Accuracy:99.35% | Accuracy change: -0.10% | Time: 6.45 mins
Clusters: 70 | Validation Accuracy:99.35% | Accuracy change: -0.10% | Time: 6.83 mins
Clusters: 65 | Validation Accuracy:99.35% | Accuracy change: -0.10% | Time: 7.20 mins
Clusters: 62 | Validation Accuracy:99.35% | Accuracy change: -0.10% | Time: 7.58 mins
Clusters: 60 | Validation Accuracy:99.25% | Accuracy change: -0.20% | Time: 7.96 mins
Clusters: 58 | Validation Accuracy:99.30% | Accuracy change: -0.15% | Time: 8.34 mins
Clusters: 56 | Validation Accuracy:99.30% | Accuracy change: -0.15% | Time: 8.71 mins
Clusters: 54 | Validation Accuracy:99.25% | Accuracy c

Clusters: 56 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 8.64 mins
Clusters: 54 | Validation Accuracy:99.25% | Accuracy change: -0.10% | Time: 9.01 mins
Clusters: 52 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 9.39 mins
Clusters: 50 | Validation Accuracy:99.35% | Accuracy change: 0.00% | Time: 9.76 mins
Clusters: 48 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 10.13 mins
Clusters: 46 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 10.49 mins
Clusters: 44 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 10.85 mins
Clusters: 42 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 11.22 mins
Clusters: 40 | Validation Accuracy:99.30% | Accuracy change: -0.05% | Time: 11.59 mins
Clusters: 38 | Validation Accuracy:99.20% | Accuracy change: -0.15% | Time: 11.96 mins
Clusters: 36 | Validation Accuracy:99.20% | Accuracy change: -0.15% | Time: 12.33 mins
End quantization
Epoch: 001/001 | Batch 050/750 