In [None]:
import torch,torchvision,os,pyhessian,time
import torchvision.transforms as transforms
import numpy as np
import pyhessian
from utils.util import get_loader,evaluate
from utils.layer import qConv2d,qLinear
from utils.train import QAVAT_train
import os

# Set before any CUDA work happens in this kernel; CUDA_LAUNCH_BLOCKING=1 is
# normally used so GPU errors surface at the call that caused them.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Pretrained CIFAR-10 ResNet-20 fetched through torch.hub, then moved to the GPU.
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models",
    "cifar10_resnet20",
    pretrained=True,
)
model = model.cuda()

In [None]:
from copy import deepcopy
def _load_checkpoint_into_copy(base_model, ckpt_path):
    """Return a deep copy of ``base_model`` with weights loaded from ``ckpt_path``.

    Parameters
    ----------
    base_model : the model whose architecture (and device placement) is cloned.
    ckpt_path  : path handed to ``torch.load``.

    The loaded object may be either a state dict or a zero-argument callable
    that returns one. This notebook's checkpoints were historically loaded and
    then *called* (presumably a pickled bound ``state_dict`` method -- TODO
    confirm how they were saved); a plain state-dict checkpoint is now
    accepted as well instead of failing with ``TypeError``.
    """
    model_copy = deepcopy(base_model)
    checkpoint = torch.load(ckpt_path)
    # Only invoke the loaded object when it is actually callable.
    state_dict = checkpoint() if callable(checkpoint) else checkpoint
    model_copy.load_state_dict(state_dict)
    return model_copy


# Quantized variants share the fp32 architecture; only the weights differ.
int4model = _load_checkpoint_into_copy(model, 'int4model.ckpt')

bfp12model = _load_checkpoint_into_copy(model, 'bfp12model.ckpt')

In [None]:
# Build the train/test loaders through the project helper.
dataset_name = 'cifar10'.upper()
train, test = get_loader(dataset_name, batch_size=128, test_batch_size=512)

# Same worker count for both loaders.
# NOTE(review): assumes get_loader returns objects that honour a
# post-construction `num_workers` assignment -- confirm in utils.util.
for loader in (train, test):
    loader.num_workers = 4

In [None]:
# Baseline check: run the project `evaluate` helper on the test loader with the
# bfp12 model. As the last expression of the cell, its return value is displayed.
# NOTE(review): presumably a test-accuracy figure -- confirm in utils.util.
evaluate(test,bfp12model)

In [None]:
# get_params_grad returns a pair; only the first element (the parameter list)
# is kept for each model, the second is discarded into `_`.
int4params, _ = pyhessian.get_params_grad(int4model)
bfp12params, _ = pyhessian.get_params_grad(bfp12model)
fp32params, _ = pyhessian.get_params_grad(model)

# Element-wise error per parameter tensor: quantized minus fp32 reference.
# zip pairs the lists positionally, so all three models must list their
# parameters in the same order.
int4_error = [quantized - reference
              for quantized, reference in zip(int4params, fp32params)]
bfp12_error = [quantized - reference
               for quantized, reference in zip(bfp12params, fp32params)]

In [None]:
# Square root of the group product of each error list with itself.
# NOTE(review): presumably the global L2 norm over all tensors -- confirm
# against pyhessian.group_product.
int4_sq_norm = pyhessian.group_product(int4_error, int4_error)
int4_l2 = int4_sq_norm ** 0.5

bfp12_sq_norm = pyhessian.group_product(bfp12_error, bfp12_error)
bfp12_l2 = bfp12_sq_norm ** 0.5

In [None]:
# One statistic per parameter tensor.
# NOTE: despite the `_var` suffix, these lists hold standard deviations
# (`.std()`), not variances.
int4_error_var = [err_tensor.std() for err_tensor in int4_error]
bfp12_error_var = [err_tensor.std() for err_tensor in bfp12_error]

In [None]:
# For every parameter tensor, show the error std normalised by the overall
# L2 error of the corresponding model (int4 first, bfp12 second), followed
# by a blank line.
for int4_std, bfp12_std in zip(int4_error_var, bfp12_error_var):
    print(int4_std / int4_l2, bfp12_std / bfp12_l2, end='\n\n')

In [None]:
def replaceModuleByName(modelName,moduleName,newModuleName):
    '''
        Assign the object named ``newModuleName`` to the sub-module located at
        the dotted path ``moduleName`` under the object named ``modelName``.

        All three arguments are *strings*. They are assembled into a single
        assignment statement that is run with ``exec``; for example
        ``('model', 'layer1.0.conv1', 'newLayer')`` executes
        ``model.layer1[0].conv1=newLayer``.

        modelName:     source text naming the root object.
        moduleName:    dotted path as produced by ``named_modules()``; path
                       components that parse as integers become index
                       lookups (``[i]``), all others attribute lookups.
        newModuleName: source text naming the replacement object.

        ``modelName`` and ``newModuleName`` are resolved in this function's
        scope, so in practice they must name module-level (notebook global)
        objects.

        NOTE: this executes interpolated text -- only pass trusted names.
    '''
    target = modelName
    for token in moduleName.split('.'):
        try:
            target += f'[{int(token)}]'
        except ValueError:
            # Not an integer: treat the component as an attribute name.
            # (Narrowed from a bare `except:` so unrelated errors and
            # KeyboardInterrupt are no longer swallowed.)
            target += f'.{token}'

    exec(f'{target}={newModuleName}')
    
# Swap every Conv2d / Linear in `model` for its quantization-aware counterpart,
# each initialised from the layer it replaces (`init_from=module`).
# The module list is snapshotted first so the tree is not mutated while the
# live named_modules() generator is still walking it.
for name, module in list(model.named_modules()):
    if isinstance(module, torch.nn.Conv2d):
        newLayer = qConv2d(0,0,0,init_from=module).cuda()
    elif isinstance(module, torch.nn.Linear):
        newLayer = qLinear(0,0,init_from=module).cuda()
    else:
        continue
    # replaceModuleByName works on *names*: `model` and `newLayer` must stay
    # notebook-level globals for its exec-based assignment to resolve them.
    replaceModuleByName('model',name,'newLayer')

In [None]:
# Evaluate `model` on the test loader with noise_std=0.2, repeat=100 and
# debug output enabled. As the cell's last expression, the result is displayed.
# NOTE(review): presumably injects noise of that std and aggregates over 100
# repetitions -- confirm against evaluate in utils.util.
evaluate(test,model,noise_std=0.2,repeat=100,debug=True)

In [None]:
# Settings dictionary passed to QAVAT_train as `config=C`.
# The meaning of each key is defined by QAVAT_train (utils.train), which is
# not visible in this notebook.
C = {
    'epochs': 30,
    'optimizer': 'SGD',
    'lr': 1e-3,
    'decay_ep': 10,
    'decay_ratio': 0.1,
    'device': 'cuda',
    'valPerEp': 1,
    'noise_std': 0.2,
    'valSample': 10,
    'trial_name': 'qavat_train',
}

In [None]:
# Put the layer-swapped model in training mode, then run the project's
# QAVAT training routine with the settings in `C`.
model.train()
QAVAT_train(
    model,
    train,
    test,
    config=C,
    imgSize=32,
    imgFlat=False,
    lossfunc=torch.nn.CrossEntropyLoss(),
    printPerEpoch=1,
)