In [None]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'utils'))
from configuration import *
import torch
import pprint
import pNN_Power_Aware as pNN
from utils import *

# Parse with an empty argv so the parser's defaults are used instead of the
# Jupyter kernel's own command-line arguments.
args = parser.parse_args([])
args.projectname = 'test34'
args = FormulateArgs(args)

# Experiment-specific overrides of the parser defaults.
# NOTE(review): meanings of PATIENCE / mu / powerbalance are inferred from names
# only — confirm against configuration.py.
args.PATIENCE = 50             # presumably early-stopping patience
args.mu = 100
args.powerestimator = 'AL'     # 'AL' makes PT/FT use al_train_pnn_progressive
args.powerbalance = 5e-4       # presumably the weight of the power term

# Seed before loading data so any randomness inside GetDataLoader (e.g. a
# random train/valid/test split) is reproducible across runs.
SetSeed(args.SEED)

print(f'Training network on device: {args.DEVICE}.')
MakeFolder(args)

# All three calls return the same dataset description; the last one is kept.
train_loader, datainfo = GetDataLoader(args, 'train')
valid_loader, datainfo = GetDataLoader(args, 'valid')
test_loader, datainfo = GetDataLoader(args, 'test')
pprint.pprint(datainfo)

# Re-seed so model initialisation below starts from the same RNG state
# regardless of how much randomness data loading consumed.
SetSeed(args.SEED)
setup = f"data_{datainfo['dataname']}_seed_{args.SEED}_Penalty_{args.powerestimator}_Factor_{args.powerbalance}"
print(f'Training setup: {setup}.')

msglogger = GetMessageLogger(args, setup)
msglogger.info(f'Training network on device: {args.DEVICE}.')
msglogger.info(f'Training setup: {setup}.')
msglogger.info(datainfo)

# Layer widths: input features, configured hidden layers, output classes.
topology = [datainfo['N_feature']] + args.hidden + [datainfo['N_class']]

pnn = pNN.pNN(topology, args).to(args.DEVICE)

In [None]:
def PT(pnn, train_loader, valid_loader, args, msglogger, setup):
    """Pretrain ``pnn`` and checkpoint it under ``args.savepath``.

    The training routine is chosen by ``args.powerestimator``: ``'AL'`` uses
    ``al_train_pnn_progressive``; any other value uses ``train_pnn_progressive``.
    Both receive an Adam optimizer over ``pnn.GetParam()`` with learning rate
    ``args.LR`` and the UUID ``setup + '_PT'``.

    Args:
        pnn: network to pretrain.
        train_loader: training data loader, passed through to the trainer.
        valid_loader: validation data loader, passed through to the trainer.
        args: configuration namespace; reads ``DEVICE``, ``LR``,
            ``powerestimator`` and ``savepath``.
        msglogger: logger used for status messages.
        setup: run identifier used in the UUID and the checkpoint file name.

    Returns:
        The ``(pnn, best)`` tuple produced by the trainer. The model is saved to
        ``{args.savepath}/pNN_{setup}.model`` only when ``best`` is truthy;
        otherwise a warning is logged.
    """
    # Pretraining
    lossfunction = pNN.Lossfunction(args).to(args.DEVICE)
    optimizer = torch.optim.Adam(pnn.GetParam(), lr=args.LR)

    train_fn = al_train_pnn_progressive if args.powerestimator == 'AL' else train_pnn_progressive
    pnn, best = train_fn(pnn, train_loader, valid_loader, lossfunction, optimizer,
                         args, msglogger, UUID=setup + '_PT')

    if best:
        os.makedirs(args.savepath, exist_ok=True)
        torch.save(pnn, f'{args.savepath}/pNN_{setup}.model')
        msglogger.info('Pretraining is finished.')
    else:
        msglogger.warning('Time out, further training is necessary.')

    # Hand the trained model back; the trainer may return a different object
    # than the one passed in.
    return pnn, best
    
    
def FT(train_loader, valid_loader, args, msglogger, setup):
    """Load the pretrained pNN, prune it, then fine-tune it.

    The model is read from ``{args.savepath}/pNN_{setup}.model`` (the file
    written by ``PT``). It is then pruned and trained again with Adam at one
    tenth of ``args.LR``. ``al_train_pnn_progressive`` is used when
    ``args.powerestimator == 'AL'``, otherwise ``train_pnn_progressive``, with
    the UUID ``setup + '_FT'``.

    Args:
        train_loader: training data loader, passed through to the trainer.
        valid_loader: validation data loader, passed through to the trainer.
        args: configuration namespace; reads ``savepath``, ``DEVICE``, ``LR``
            and ``powerestimator``.
        msglogger: logger used for status messages.
        setup: run identifier used in the checkpoint file names and the UUID.

    Returns:
        The ``(pnn, best)`` tuple produced by the trainer. The fine-tuned model
        is saved to ``{args.savepath}/pNN_{setup}_FT.model`` only when ``best``
        is truthy; otherwise a warning is logged.

    Raises:
        FileNotFoundError: if the pretrained checkpoint does not exist, e.g.
            because ``PT`` did not save one.
    """
    import inspect

    model_path = f'{args.savepath}/pNN_{setup}.model'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Pretrained model not found: {model_path}. Run PT first.')

    # The checkpoint is a whole pickled module (torch.save(pnn, ...)). torch >= 2.6
    # defaults torch.load to weights_only=True, which refuses such files, so opt out
    # where the keyword exists. Acceptable only because the file is produced locally
    # by PT — never load untrusted checkpoints this way.
    load_kwargs = {'map_location': args.DEVICE}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    pnn = torch.load(model_path, **load_kwargs)

    # Pruning
    msglogger.info('Pruning...')
    print('Pruning...')
    # NOTE(review): assumes `pruning` is a property returning three counts and
    # three fractions — confirm in pNN_Power_Aware.
    N1, N2, N3, P1, P2, P3 = pnn.pruning
    information = f'{N1} ({P1*100:.2f}%) resistors, {N2} ({P2*100:.2f}%) activations and {N3} ({P3*100:.2f}%) negation circuits are pruned.'
    msglogger.info(information)
    print(information)

    # Fine tuning with a smaller learning rate than pretraining.
    lossfunction = pNN.Lossfunction(args).to(args.DEVICE)
    msglogger.info('Fine tuning...')
    print('Fine tuning...')
    optimizer = torch.optim.Adam(pnn.GetParam(), lr=args.LR / 10.)

    train_fn = al_train_pnn_progressive if args.powerestimator == 'AL' else train_pnn_progressive
    pnn, best = train_fn(pnn, train_loader, valid_loader, lossfunction, optimizer,
                         args, msglogger, UUID=setup + '_FT')

    if best:
        os.makedirs(args.savepath, exist_ok=True)
        torch.save(pnn, f'{args.savepath}/pNN_{setup}_FT.model')
        msglogger.info('Fine tuning is finished.')
    else:
        msglogger.warning('Time out, further training is necessary.')

    return pnn, best

In [None]:
# Stage 1: pretrain the freshly built network and checkpoint it under args.savepath.
PT(pnn, train_loader, valid_loader, args=args, msglogger=msglogger, setup=setup)

In [None]:
# Stage 2: reload the pretrained checkpoint, prune it and fine-tune it.
FT(train_loader, valid_loader, args=args, msglogger=msglogger, setup=setup)

In [None]:
import torch

torch.manual_seed(0)

torch.r