In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import os
import matplotlib.cm as cm
import torch
import torch.nn.functional as F
from core.bnet import BNet
from core.layers.bconv2d import compute_bit_cost, bit2min, bit2max
from model import get_model

label_fontsize = 12  # font size for the x/y axis labels of the weight histograms

def scale_w(w, scale, bit=8):
    """Divide ``w`` by ``scale`` and clip it to the ``bit``-bit integer range.

    The clip bounds sit half a step outside ``[bit2min(bit), bit2max(bit)]``,
    matching the ``bit2min - 0.5 .. bit2max + 0.5`` range used when the result
    is histogrammed below.

    Args:
        w: weight tensor.
        scale: divisor applied element-wise (tensor or scalar).
        bit: bit width passed to ``bit2min`` / ``bit2max``.

    Returns:
        The scaled and clipped tensor.
    """
    lower = bit2min(bit) - 0.5
    upper = bit2max(bit) + 0.5
    return torch.clamp(w / scale, min=lower, max=upper)

def get_layer_ind(K=3, syn_cnt=None):
    """Return the indices of the ``K`` largest per-layer counts, largest first.

    Args:
        K: number of indices to return. If ``K`` is at least the number of
            layers, every index is returned instead of raising, as
            ``np.argpartition`` would.
        syn_cnt: per-layer counts to rank. Defaults to the module-level
            ``net.dense_syn_cnt``, which keeps the original call
            ``get_layer_ind(K)`` working.

    Returns:
        numpy integer array of indices ordered by decreasing count. Ties keep
        their original (lower index first) order.
    """
    if syn_cnt is None:
        syn_cnt = net.dense_syn_cnt
    counts = np.asarray(syn_cnt)
    # A full stable sort is cheap for a per-layer array and, unlike
    # np.argpartition(-counts, K), does not fail when K >= len(counts).
    return np.argsort(-counts, kind='stable')[:K]

print('Start')

# ---- Experiment selection (alternative values noted alongside) ----
dssw = 1024                    # formatted as '{:08d}' into the run directory
act_bit = 32                   # 8 or 32; only appears in exp_name here
wgt_bit = 8                    # 8 or 32; weight bit width for costs/histograms
lamda = 1.0                    # 0.0 or 1.0
epoch = 0                      # > 0 loads the saved checkpoint
dataset_name = 'cifar10'       # 'cifar10' or 'cifar100'
optim_loss_type = 'act_tgt'    # 'act_tgt' or 'act_naive'
model_name = 'resnet18'

# Run directory name; note the 'bit8' field is a fixed part of the string.
exp_name = (
    f'{dataset_name}/{model_name}/{optim_loss_type}'
    f'_wd1_bs512_epoch200_lr0.05_0.01_lamda_{lamda:0.1f}_bit8_{act_bit}'
)

# Path template; its '{:08d}' placeholder is not filled in within this cell.
log_tpath = os.path.join('log', exp_name, '{:08d}', 'score', 'log')

bit_cost, _ = compute_bit_cost(wgt_bit)

# One bar colour per entry of bit_cost: black where the cost is zero,
# otherwise the jet colormap evaluated at cost / 7.
cost_color = [
    (0, 0, 0) if cost == 0 else cm.jet(cost / 7.0)
    for cost in (entry.item() for entry in bit_cost)
]

# Output directory of this run and the checkpoint loaded when epoch > 0.
# dssw is formatted into its own path component, so braces that might appear
# in exp_name are never treated as format fields.
score_tpath = os.path.join('log', exp_name, '{:08d}'.format(dssw))
model_tpath = os.path.join(score_tpath, 'model', 'final.pt')

print('path OK ' + score_tpath)

# NOTE(review): `CN` and `yaml` are not imported in this cell - presumably
# `from yacs.config import CfgNode as CN` and PyYAML; confirm they are in scope.
# NOTE(review): the config is always config/cifar10.yaml, even when
# dataset_name is 'cifar100' - confirm that is intended.
with open('config/cifar10.yaml', 'r') as cfg_file:  # closed right after parsing
    cfg = CN(yaml.safe_load(cfg_file))
model = get_model(cfg)
# input_shape is set after get_model() and before BNet() is constructed.
cfg.dataset.input_shape = [1, 3, 32, 32]
net = BNet(model, cfg)

if epoch > 0:
    # strict=False: missing/unexpected keys in the checkpoint are tolerated.
    net.load_state_dict(torch.load(model_tpath, map_location='cpu'), strict=False)

# Reported after the optional checkpoint load so the message reflects it.
print('model loaded ' + exp_name)

for sub_dir in ('png', 'pdf'):
    os.makedirs(os.path.join(score_tpath, sub_dir), exist_ok=True)

max_k_indices = get_layer_ind(3)
print(net.dense_syn_cnt / 1e6)

# For each layer selected in max_k_indices: scale its weights to the integer
# grid, histogram them, colour every level by its bit cost, and save PNG/PDF.
for idx, module in enumerate(net.DSS_cand):
    if idx not in max_k_indices:
        continue
    print(idx)
    print(net.dense_syn_cnt[idx] / 1e6)

    # A fresh figure per plotted layer, closed at the end of the iteration.
    # Re-activating a fixed figure number without clearing it stacks bars
    # from earlier layers whenever plt.show() does not close figures.
    fig = plt.figure(figsize=(6, 3), dpi=200)

    scale = module.scale
    w = scale_w(module.fweight, scale, wgt_bit)
    # 2**wgt_bit + 1 bins over [bit2min - 0.5, bit2max + 0.5]; presumably one
    # unit-width bin per integer level - TODO confirm bit2max - bit2min == 2**wgt_bit.
    # detach(): the histogram is for inspection only and needs no gradient.
    hist = torch.histogram(
        w.detach().cpu(),
        bins=2**wgt_bit + 1,
        range=(bit2min(wgt_bit) - 0.5, bit2max(wgt_bit) + 0.5),
        density=False,
    )
    counts, edges = hist[0], hist[1]

    # Mean bit cost per weight: per-level counts weighted by each level's cost.
    bit_cnt = counts.mul(bit_cost.cpu().squeeze()).sum() / torch.numel(w)
    print('avg bit cost: {:.3f}'.format(bit_cnt.item()))

    # Bar x positions: ceil of every left bin edge. [:-1] drops the extra
    # right edge for any wgt_bit (the old [:257] only fit wgt_bit == 8).
    bars = plt.bar(edges[:-1].ceil().numpy(), counts.numpy(),
                   color=cost_color, linewidth=1)

    # Colour key for bit costs 0..7, using the same colours as cost_color.
    for c in range(8):
        col = (0, 0, 0) if c == 0 else cm.jet(c / 7.0)
        plt.text(10 + c * 15, 1e5, str(c),
                 bbox=dict(edgecolor=col, facecolor='none'))

    # Switch to log scale first; a bottom limit of 0 is invalid on a log
    # axis, so only the top is pinned and the bottom is autoscaled.
    plt.yscale('log')
    plt.ylim(top=1e6)
    plt.xlabel("Weight", fontsize=label_fontsize)
    plt.ylabel("Frequency (log)", fontsize=label_fontsize)
    plt.xlim([-128.5, 128.5])  # NOTE(review): sized for wgt_bit == 8
    plt.yticks([1e1, 1e2, 1e3, 1e4, 1e5, 1e6])

    png_path = os.path.join(score_tpath, 'png', 'layer{:03d}_{:03d}_.png'.format(idx, epoch))
    pdf_path = os.path.join(score_tpath, 'pdf', 'layer{:03d}_{:03d}_.pdf'.format(idx, epoch))
    plt.savefig(png_path, bbox_inches='tight')
    plt.savefig(pdf_path, bbox_inches='tight')
    print(png_path)
    plt.show()
    plt.close(fig)