In [1]:
import torch,torchvision,os,pyhessian,time
import torchvision.transforms as transforms
import numpy as np
import pyhessian
from utils.util import get_loader,evaluate
from utils.layer import qConv2d,qLinear
from utils.train import QAVAT_train
import os

# Make CUDA kernel launches synchronous so device-side errors surface at the
# offending call (debugging aid; slows execution).
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# Pretrained CIFAR-10 ResNet-20 fetched through torch.hub, moved to the GPU.
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models",
    "cifar10_resnet20",
    pretrained=True,
).cuda()

Using cache found in /home/zihao/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [None]:
from copy import deepcopy
# Clone the FP32 model twice and overwrite each clone's weights from a
# checkpoint file in the working directory (judging by the names: an INT4
# and a BFP12 quantized variant).
# NOTE(review): the object returned by torch.load(...) is *called* before it
# is passed to load_state_dict, so each checkpoint presumably stores a
# callable (e.g. a pickled bound `state_dict` method) rather than a plain
# state dict -- confirm against the code that wrote these files.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
int4model = deepcopy(model)
int4model.load_state_dict(torch.load('int4model.ckpt')())

bfp12model = deepcopy(model)
bfp12model.load_state_dict(torch.load('bfp12model.ckpt')())

In [2]:
# Train / test loaders for CIFAR10 (128-sample training batches, 512-sample
# test batches), each configured afterwards to use 4 worker processes.
train, test = get_loader(
    'cifar10'.upper(),
    batch_size=128,
    test_batch_size=512,
)
train.num_workers = test.num_workers = 4

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Evaluate the BFP12 checkpoint model on the test loader with evaluate()'s
# default settings; the returned value is shown as the cell output.
evaluate(test,bfp12model)

In [None]:
# Collect each model's parameter list (first element returned by
# pyhessian.get_params_grad; the second element is discarded).
int4params, _ = pyhessian.get_params_grad(int4model)
bfp12params, _ = pyhessian.get_params_grad(bfp12model)
fp32params, _ = pyhessian.get_params_grad(model)

# Per-tensor quantization error: quantized parameter minus FP32 reference.
int4_error = [
    quantized - reference
    for quantized, reference in zip(int4params, fp32params)
]
bfp12_error = [
    quantized - reference
    for quantized, reference in zip(bfp12params, fp32params)
]

In [None]:
# Square root of group_product(error, error) for each quantization scheme.
# Presumably group_product sums the element-wise products over all tensors,
# which would make these the global L2 norms of the errors (as the `_l2`
# names suggest) -- confirm against pyhessian.
int4_l2 = pyhessian.group_product(int4_error,int4_error)**0.5
bfp12_l2 = pyhessian.group_product(bfp12_error,bfp12_error)**0.5

In [None]:
# Spread of the error inside each parameter tensor.
# NOTE: despite the `_var` suffix these lists hold standard deviations
# (.std()), not variances.
int4_error_var = [err.std() for err in int4_error]
bfp12_error_var = [err.std() for err in bfp12_error]

In [None]:
# For every parameter tensor, print the INT4 and BFP12 error spread divided
# by the corresponding overall error norm, followed by a blank line.
for x,y in zip(int4_error_var,bfp12_error_var):
    print(x/int4_l2, y/bfp12_l2, end='\n\n')

In [3]:
def replaceModuleByName(modelName,moduleName,newModuleName):
    '''
        Rebind a sub-module of a model, addressing everything by *name*.

        Builds the assignment statement
            <modelName><path> = <newModuleName>
        and runs it with exec(). <path> is derived from the dotted moduleName:
        tokens that parse as integers become index accesses ("[0]"), every
        other token becomes an attribute access (".conv1").

        modelName     : source text of an expression yielding the root object
        moduleName    : dotted path of the sub-module, e.g. 'layer1.0.conv1'
        newModuleName : source text of an expression yielding the replacement

        Both expressions are evaluated by exec() from inside this function,
        so the names they mention must be reachable from this function's
        module globals.

        NOTE(review): exec() runs arbitrary code -- only pass trusted,
        hard-coded strings.
    '''
    target = modelName
    for token in moduleName.split('.'):
        try:
            index = int(token)
        except ValueError:
            # Not an integer: plain attribute access. Only ValueError is
            # caught so that unrelated failures (e.g. KeyboardInterrupt)
            # are no longer silently swallowed.
            target += f'.{token}'
        else:
            target += f'[{index}]'

    exec(target+f'={newModuleName}')
    
# Swap every Conv2d / Linear layer of `model` for a qConv2d / qLinear built
# from the existing layer (via init_from) and moved to the GPU.
# NOTE(review): the positional zeros are presumably placeholders overridden
# by init_from -- confirm against utils.layer.
#
# named_modules() is a lazy generator over the module tree, and this loop
# rebinds sub-modules of that same tree, so iterate over a snapshot.
for name,module in list(model.named_modules()):
    if isinstance(module,torch.nn.Conv2d):
        newLayer = qConv2d(0,0,0,init_from=module).cuda()
    elif isinstance(module,torch.nn.Linear):
        newLayer = qLinear(0,0,init_from=module).cuda()
    else:
        continue
    # replaceModuleByName resolves both names through exec(), so the
    # replacement has to live in the global variable `newLayer`.
    replaceModuleByName('model',name,'newLayer')

In [None]:
# Evaluate the layer-replaced model on the test loader with noise_std=0.2
# and repeat=100 (debug output enabled), before any retraining.
evaluate(test,model,noise_std=0.2,repeat=100,debug=True)

In [6]:
# Training configuration shared by the training cells below. The per-key
# notes describe how QAVATPLUS_train in this notebook reads each entry;
# QAVAT_train presumably interprets them the same way.
C = {
    'epochs': 30,                 # number of training epochs
    'optimizer': 'SGD',           # optimizer selector ('SGD' or 'Adam' spellings)
    'lr': 1e-4,                   # base learning rate
    'decay_ep': 10,               # lr is scaled once every this many epochs ...
    'decay_ratio': 0.1,           # ... by this factor
    'device': 'cuda',             # device that batches are moved to
    'valPerEp': 1,                # validation period in epochs
    'noise_std': 0.2,             # noise level used in training and validation
    'valSample': 10,              # `repeat` count handed to evaluate()
    'trial_name': 'qavat_train',  # tensorboard run comment
}

In [None]:
# Retrain the layer-replaced model with QAVAT_train using the shared config
# C, reporting to the console after every epoch.
model.train()
QAVAT_train(
    model,
    train,
    test,
    config=C,
    imgSize=32,
    imgFlat=False,
    lossfunc=torch.nn.CrossEntropyLoss(),
    printPerEpoch=1,
)

epoch 1 loss 0.0599571972599496 [9.9927 seconds]
Epoch 1 validation [6.3210 seconds]
mean acc 0.8609 mean loss 0.5760
epoch 2 loss 0.05410358600099297 [9.8717 seconds]
Epoch 2 validation [6.1661 seconds]
mean acc 0.8564 mean loss 0.5793
epoch 3 loss 0.054404750749792744 [10.1709 seconds]
Epoch 3 validation [6.2437 seconds]
mean acc 0.8581 mean loss 0.5595
epoch 4 loss 0.056277968673049794 [10.3481 seconds]
Epoch 4 validation [6.2464 seconds]
mean acc 0.8769 mean loss 0.4877
epoch 5 loss 0.053023095947721276 [9.8010 seconds]
Epoch 5 validation [6.2264 seconds]
mean acc 0.8605 mean loss 0.5650
epoch 6 loss 0.05261811461118634 [9.9082 seconds]
Epoch 6 validation [6.2365 seconds]
mean acc 0.8446 mean loss 0.6462
epoch 7 loss 0.051778007239160484 [9.9243 seconds]
Epoch 7 validation [6.1288 seconds]
mean acc 0.8731 mean loss 0.5043
epoch 8 loss 0.052690284736359214 [9.9709 seconds]
Epoch 8 validation [6.2473 seconds]
mean acc 0.8625 mean loss 0.5589
epoch 9 loss 0.05002013053340108 [9.8643 s

In [None]:
# Re-evaluate the model after the training cell above, with the same
# settings as the earlier evaluation (noise_std=0.2, repeat=100, debug on).
evaluate(test,model,noise_std=0.2,repeat=100,debug=True)

In [None]:
import tensorboardX
def QAVATPLUS_train(model,train_loader,test_loader,config,imgSize=32,imgFlat=False,
                lossfunc=torch.nn.CrossEntropyLoss(),printPerEpoch=100):
    '''
    Train `model` with a second-order, variation-aware gradient.

    For every batch the parameter gradient g is replaced by
        g + noise_std^2 * (H w)
    where H w is the Hessian-parameter product of the batch loss, computed
    with pyhessian's helpers from a double-differentiable backward pass.
    Weight-noise injection (generate_variation) is not applied here.

    model        : torch module to train (modified in place)
    train_loader : iterable of (input, label) batches
    test_loader  : loader passed to evaluate() for validation
    config       : dict; keys read here are 'trial_name', 'optimizer'
                   ('sgd' or 'adam', case-insensitive), 'lr', 'epochs',
                   'decay_ratio', 'decay_ep', 'device', 'noise_std',
                   'valPerEp' (None disables validation) and 'valSample'
    imgSize      : side length used when flattening inputs
    imgFlat      : if True, inputs are reshaped to (-1, imgSize**2)
    lossfunc     : callable(output, label) -> scalar loss
    printPerEpoch: console reporting period in epochs

    Raises ValueError when config['optimizer'] is not recognized.
    Scalars and histograms are written through a tensorboardX SummaryWriter,
    which is closed even if training fails part-way.
    '''
    C = config

    # Validate the optimizer before creating the writer so that a bad config
    # does not leave an empty tensorboard run behind.
    optimizer_name = str(C['optimizer']).lower()
    if optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),lr=C['lr'],momentum=0.9)
    elif optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),lr=C['lr'])
    else:
        # Raising instead of exit(0): exit would stop the whole kernel and
        # report success.
        raise ValueError('unrecognized optimizer defined in config: %r' % (C['optimizer'],))

    tb = tensorboardX.SummaryWriter(comment=C['trial_name'])
    try:
        for epoch in range(C['epochs']):
            # lr decay: scale by decay_ratio once every decay_ep epochs
            current_lr = C['lr'] * (C['decay_ratio'] ** (epoch // C['decay_ep']))
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            start = time.time()
            total_loss = 0
            batch_count = 0
            # per epoch training, do
            for x, label in train_loader:
                model.train()
                if imgFlat:
                    x = x.view(-1,imgSize**2)
                x = x.to(C['device'])
                label = label.to(C['device'])
                optimizer.zero_grad()

                # use under-variation 2nd-order gradient estimation
                # g + sigma^2 (Hw)
                # Hw is the hessian-parameter product.
                # The forward/backward pass is done here (rather than inside
                # pyhessian.hessian) so that the loss value is available for
                # logging, the caller's lossfunc and device are honoured and
                # the model stays in train mode.
                output = model(x)
                l = lossfunc(output,label)
                # create_graph=True keeps the gradients differentiable,
                # which the Hessian-vector product below requires.
                l.backward(create_graph=True)

                params, gradsH = pyhessian.get_params_grad(model)
                hw = pyhessian.hessian_vector_product(gradsH, params, params)

                for p,hw_ in zip(params,hw):
                    p.grad.data += hw_ * C['noise_std']**2

                optimizer.step()
                total_loss += l.item()
                batch_count += 1

                # The gradients reference the autograd graph built above;
                # drop them so that graph can be released.
                for p in params:
                    p.grad = None

            # An empty loader leaves the epoch loss at 0 instead of dividing
            # by zero.
            if batch_count:
                total_loss /= batch_count
            tb.add_scalar('epoch loss',total_loss,epoch+1)
            tb.add_scalar('epoch time',time.time()-start,epoch+1)
            tb.add_scalar('learning rate',current_lr,epoch+1)

            # console output
            if epoch % printPerEpoch == printPerEpoch-1:
                print("epoch %s loss %s [%.4f seconds]"%(epoch+1,total_loss,time.time()-start))

            if C['valPerEp'] is None:
                continue

            # validation
            if epoch % C['valPerEp'] == 0:
                val = evaluate(test_loader,model,noise_std = C['noise_std'],repeat = C['valSample'],imgSize=imgSize,imgFlat=imgFlat,device = C['device'])

                tb.add_scalar('validation/mean accuracy',val['mean_acc'],epoch+1)
                tb.add_scalar('validation/qtl accuracy',val['qtl_acc'],epoch+1)
                tb.add_scalar('validation/mean loss',val['mean_loss'],epoch+1)
                tb.add_scalar('validation/qtl loss',val['qtl_loss'],epoch+1)
                tb.add_scalar('validation/validation time',val['test time'],epoch+1)
                tb.add_histogram('validation/accuracy',val['acc_list'],epoch+1)
                tb.add_histogram('validation/loss',val['loss_list'],epoch+1)

                print("Epoch %s validation [%.4f seconds]"%(epoch+1,val['test time']))
                print("mean acc %.4f mean loss %.4f"%(val['mean_acc'],val['mean_loss']))
    finally:
        tb.close()

In [None]:
# Continue training the same model with the second-order variant, logging
# under its own trial name. Note that the shared config dict C is updated
# in place.
C['trial_name'] = 'qavat_plus_train'
model.train()
QAVATPLUS_train(
    model,
    train,
    test,
    config=C,
    imgSize=32,
    imgFlat=False,
    lossfunc=torch.nn.CrossEntropyLoss(),
    printPerEpoch=1,
)