In [None]:
import os
import statistics
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchsummary import summary

from src.resnet import ResNet
from src.train import test
from src.dataset import GoogleSpeechCommandsDataset

# Flush denormal floats to zero on CPU (avoids very slow denormal arithmetic; the call
# returns False on CPUs that do not support it).
torch.set_flush_denormal(True)
# Use CUDA when present, otherwise fall back to CPU so the notebook still runs.
# NOTE(review): test() also reports a `gpu_time`; confirm it tolerates a CPU device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch so DataLoader shuffling and the torch.rand timing inputs are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Dataset locations. The defaults are the original author's machine; set the environment
# variables below to run elsewhere without editing the notebook.
data_test_dir = os.environ.get(
    'SPEECH_COMMANDS_TEST_DIR',
    '/home/marcel/Source/Python/DLOptimization/data/speech_commands_test_set_v0.02',
)
data_cache_dir = os.environ.get(
    'SPEECH_COMMANDS_CACHE_DIR',
    '/home/marcel/Source/Python/DLOptimization/cache/data/',
)
test_data = GoogleSpeechCommandsDataset(data_test_dir, data_cache_dir, encoder='mel', augment=0, train=False)
test_data.precache()
# Batched loader, used by the `test()` runs on `device`.
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)
# batch_size=1 loader, used for the per-sample latency loops of the quantized models.
test_loader_cpu = DataLoader(test_data, batch_size=1, shuffle=True)
criterion = nn.CrossEntropyLoss()

### Post-Training Static Quantization (PTSQ)

Quantize the FP32 baseline checkpoint to int8 and evaluate it sample-by-sample on CPU.

In [None]:
# FP32 baseline checkpoint; also the starting point for post-training static quantization.
base = 'models/resnet_ep55_acc99_sprs0.pt'
# map_location='cpu': quantized inference runs on CPU, and this avoids needing CUDA just to
# deserialize a checkpoint that may have been saved from GPU tensors.
model_base = ResNet.from_state_dict(torch.load(base, map_location='cpu'))
model_ptsq = ResNet.from_state_dict(torch.load(base, map_location='cpu'))
model_ptsq = model_ptsq.prepare_quantization(qat=False, inplace=True)

# Static quantization needs observed activation ranges before conversion. Without running
# data through the prepared model, the observers fall back to default qparams.
# NOTE(review): assumes prepare_quantization() inserts torch.ao.quantization observers and does
# not calibrate by itself (it receives no data). Confirm this against src/resnet.py.
# TODO: calibrate on a held-out train/validation split instead of the test set.
NUM_CALIBRATION_BATCHES = 10  # batches of `test_loader` fed through the observers
model_ptsq.eval()
with torch.no_grad():
    for calib_idx, (calib_data, *_) in enumerate(test_loader):
        if calib_idx >= NUM_CALIBRATION_BATCHES:
            break
        model_ptsq(calib_data)

model_ptsq = model_ptsq.quantize(inplace=True)

In [None]:
# Compare on-disk/parameter footprint as reported by ResNet.model_size().
size_base_mb = model_base.model_size()
print(f'Baseline Model Size: {size_base_mb:.3f} Mb')
size_ptsq_mb = model_ptsq.model_size()
print(f'Quantized Model Size: {size_ptsq_mb:.3f} Mb')

In [None]:
# Compact view of the quantized state dict: each tensor's shape and dtype instead of dumping
# every weight value into the notebook output. Non-tensor entries show their type name.
{
    name: (tuple(value.shape), value.dtype) if isinstance(value, torch.Tensor) else type(value).__name__
    for name, value in model_ptsq.state_dict().items()
}

In [None]:
# Per-sample CPU evaluation of the PTSQ model: accuracy, mean loss and forward latency.
model_ptsq.eval()  # make sure any dropout / batch-norm layers run in inference mode
PROGRESS_EVERY = 100  # print progress every N batches instead of on every sample
test_loss = 0
correct = 0
seen = 0
times = []
# Measure progress against the loader actually being iterated.
total = len(test_loader_cpu.dataset)
num_batches = len(test_loader_cpu)
with torch.no_grad():
    for batch_idx, (data, target, target_idx, target_lbl, data_idx, pitch_shift) in enumerate(test_loader_cpu):
        st = time.perf_counter()
        output = model_ptsq(data)
        et = time.perf_counter() - st
        times.append(et)

        test_loss += criterion(output, target).item()

        # Accuracy: arg-max prediction vs. target_idx (presumably the integer class index).
        pred = output.argmax(dim=1)
        correct += (pred == target_idx).sum().item()
        seen += len(data)

        if (batch_idx + 1) % PROGRESS_EVERY == 0 or batch_idx + 1 == num_batches:
            percent = 100. * (batch_idx + 1) / num_batches
            print(f'Test [{seen}/{total} ({percent:.0f}%)]\tCorrect: {correct}/{seen}\t'
                  f'ACC: {100. * correct / seen:.2f}%\tLoss: {test_loss / (batch_idx + 1):.4f}\t'
                  f'Time: {et:.4f}s')

In [None]:
# Summary of the per-sample CPU latencies collected in `times`.
# Average over the samples actually timed rather than a fixed count. Always define the std
# so the print never picks up an undefined or stale `cpu_time_std` from another cell.
cpu_time_mean = statistics.mean(times) if times else float('nan')
cpu_time_std = statistics.stdev(times) if len(times) > 1 else 0.0
print(f'Time: {cpu_time_mean:.4f}±{cpu_time_std:.4f} s')

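### Quantization-Aware Training (QAT) Model

Compare the FP32 baseline, the QAT-trained FP32 checkpoint, and its int8 conversion in size, latency, and accuracy.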

In [None]:
# QAT-trained checkpoint. `base` (the FP32 baseline path) is defined in the PTSQ section.
path = 'models/resnet_ep51_acc100_sprs23_qat.pt'
# Load the baseline first, then the QAT checkpoint, each as a fresh FP32 ResNet.
model_base, model_pre = (ResNet.from_state_dict(torch.load(ckpt)) for ckpt in (base, path))

In [None]:
# Build the int8 model from the QAT checkpoint: load -> prepare (qat=True) -> convert.
# NOTE(review): nothing runs between prepare and quantize here. Whether the checkpoint's
# learned quantization parameters survive depends on ResNet.from_state_dict and
# prepare_quantization. Confirm this in src/resnet.py.
model_int8 = (
    ResNet.from_state_dict(torch.load(path))
    .prepare_quantization(qat=True, inplace=True)
    .quantize(inplace=True)
)

In [None]:
# Footprint of baseline, QAT FP32 and QAT int8 models as reported by ResNet.model_size().
size_base_mb = model_base.model_size()
print(f'Baseline Model Size: {size_base_mb:.3f} Mb')
size_pre_mb = model_pre.model_size()
print(f'Original Model Size: {size_pre_mb:.3f} Mb')
size_int8_mb = model_int8.model_size()
print(f'Quantized Model Size: {size_int8_mb:.3f} Mb')

In [None]:
print('--- QAT ---')

N_TIMING_RUNS = 10  # forward passes averaged per model
# One shared input, shape (1, 1, 128, 111) as before, so every model sees the same data.
timing_input = torch.rand(1, 1, 128, 111)


def _mean_forward_time(model, x, n_runs=N_TIMING_RUNS):
    """Return the mean wall-clock seconds of `n_runs` forward passes of `model` on `x`.

    One untimed warm-up pass runs first so one-off allocation/initialisation costs are
    excluded. Gradients are disabled so autograd bookkeeping is not measured.
    """
    with torch.no_grad():
        model(x)  # warm-up, not timed
        start = time.perf_counter()
        for _ in range(n_runs):
            model(x)
        return (time.perf_counter() - start) / max(n_runs, 1)


bet = _mean_forward_time(model_base, timing_input)
print(f'Baseline Time: {bet:.5f} s')

oet = _mean_forward_time(model_pre, timing_input)
print(f'Original Time: {oet:.5f} s')

qet = _mean_forward_time(model_int8, timing_input)
print(f'Quantized Time: {qet:.5f} s')

In [None]:
# Batched evaluation of the FP32 baseline with the project's test() helper on `device`.
# Unpacking rebinds accuracy, cpu_time_mean, cpu_time_std, etc. in the notebook namespace.
model_base.to(device)
results_base = test(
    model_base, test_loader, criterion,
    cpu_tests=32, return_time=True, log_interval=10, num_steps=10, device=device,
)
accuracy, sparsity, significance, gpu_time, cpu_time_mean, cpu_time_std, num_params = results_base

In [None]:
# Batched evaluation of the QAT-trained FP32 model with the project's test() helper on `device`.
# Unpacking rebinds accuracy, cpu_time_mean, cpu_time_std, etc. in the notebook namespace.
model_pre.to(device)
results_pre = test(
    model_pre, test_loader, criterion,
    cpu_tests=32, return_time=True, log_interval=10, num_steps=10, device=device,
)
accuracy, sparsity, significance, gpu_time, cpu_time_mean, cpu_time_std, num_params = results_pre

In [None]:
# Per-sample CPU evaluation of the QAT int8 model: accuracy, mean loss and forward latency.
model_int8.eval()  # make sure any dropout / batch-norm layers run in inference mode
PROGRESS_EVERY = 100  # print progress every N batches instead of on every sample
test_loss = 0
correct = 0
seen = 0
times = []
# Measure progress against the loader actually being iterated.
total = len(test_loader_cpu.dataset)
num_batches = len(test_loader_cpu)
with torch.no_grad():
    for batch_idx, (data, target, target_idx, target_lbl, data_idx, pitch_shift) in enumerate(test_loader_cpu):
        st = time.perf_counter()
        output = model_int8(data)
        et = time.perf_counter() - st
        times.append(et)

        test_loss += criterion(output, target).item()

        # Accuracy: arg-max prediction vs. target_idx (presumably the integer class index).
        pred = output.argmax(dim=1)
        correct += (pred == target_idx).sum().item()
        seen += len(data)

        if (batch_idx + 1) % PROGRESS_EVERY == 0 or batch_idx + 1 == num_batches:
            percent = 100. * (batch_idx + 1) / num_batches
            print(f'Test [{seen}/{total} ({percent:.0f}%)]\tCorrect: {correct}/{seen}\t'
                  f'ACC: {100. * correct / seen:.2f}%\tLoss: {test_loss / (batch_idx + 1):.4f}\t'
                  f'Time: {et:.4f}s')

In [None]:
# Summary of the per-sample CPU latencies collected in `times`.
# Average over the samples actually timed rather than a fixed count. Always define the std
# so the print never reuses the `cpu_time_std` left behind by the earlier test() cells.
cpu_time_mean = statistics.mean(times) if times else float('nan')
cpu_time_std = statistics.stdev(times) if len(times) > 1 else 0.0
print(f'Time: {cpu_time_mean:.4f}±{cpu_time_std:.4f} s')

In [None]:
# Compact view of the int8 state dict: each tensor's shape and dtype instead of dumping
# every weight value into the notebook output. Non-tensor entries show their type name.
{
    name: (tuple(value.shape), value.dtype) if isinstance(value, torch.Tensor) else type(value).__name__
    for name, value in model_int8.state_dict().items()
}

In [None]:
# torchsummary defaults to device="cuda" and would feed CUDA tensors to the CPU-only quantized
# model, so run the summary on CPU instead.
# NOTE: quantized modules keep their weights as packed buffers rather than nn.Parameters, so
# torchsummary's parameter counts may read as 0 for them.
summary(model_int8, input_size=(1, 128, 111), device='cpu')