In [1]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'utils'))
from configuration import *
import torch
import pprint
import pNN_Power_Aware as pNN
from utils import *

# Parse with an empty argv so only the parser's defaults are used (a notebook
# has no command line). `parser` arrives via one of the star imports above.
args = parser.parse_args([])
args.projectname = 'test32'
# NOTE(review): FormulateArgs is an opaque helper; presumably it derives
# fields such as DEVICE / savepath from the raw args — confirm.
args = FormulateArgs(args)

# Per-run overrides applied on top of the formulated arguments.
args.PATIENCE = 50
args.mu = 10
args.powerestimator = 'AL'   # selects the al_* training routine in PT / FT
args.powerbalance = 5e-4
    
print(f'Training network on device: {args.DEVICE}.')
MakeFolder(args)

# Each call rebinds `datainfo`; only the value returned for the 'test' split
# survives and is printed / logged below.
train_loader, datainfo = GetDataLoader(args, 'train')
valid_loader, datainfo = GetDataLoader(args, 'valid')
test_loader, datainfo = GetDataLoader(args, 'test')
pprint.pprint(datainfo)

# NOTE(review): the seed is set after the data loaders are built; if
# GetDataLoader does any random splitting/shuffling at construction time,
# that part is not covered by this seed — confirm.
SetSeed(args.SEED)
# `setup` is the run identifier reused for the logger, checkpoint UUIDs and
# saved model file names.
setup = f"data_{datainfo['dataname']}_seed_{args.SEED}_Penalty_{args.powerestimator}_Factor_{args.powerbalance}"
print(f'Training setup: {setup}.')

msglogger = GetMessageLogger(args, setup)
msglogger.info(f'Training network on device: {args.DEVICE}.')
msglogger.info(f'Training setup: {setup}.')
msglogger.info(datainfo)

# Layer sizes: input features, then the configured hidden sizes, then classes.
topology = [datainfo['N_feature']] + args.hidden + [datainfo['N_class']]

pnn = pNN.pNN(topology, args).to(args.DEVICE)

Training network on device: cpu.
{'N_class': 2,
 'N_feature': 6,
 'N_test': 25,
 'N_train': 70,
 'N_valid': 23,
 'dataname': 'acuteinflammation'}
Training setup: data_acuteinflammation_seed_0_Penalty_AL_Factor_0.0005.


In [2]:
def PT(pnn, train_loader, valid_loader, args, msglogger, setup):
    """Pretrain ``pnn`` and save it to disk when training reports success.

    Args:
        pnn: Model to train; must expose ``GetParam()`` for the optimizer.
        train_loader: Training data, forwarded to the training routine.
        valid_loader: Validation data, forwarded to the training routine.
        args: Run configuration. This function reads ``DEVICE``, ``LR``,
            ``powerestimator`` and ``savepath``; the whole object is also
            forwarded to the loss function and the training routine.
        msglogger: Logger that receives the status messages.
        setup: Run identifier; used for the checkpoint UUID (``setup + '_PT'``)
            and for the saved model's file name.

    Returns:
        The model handed back by the training routine, whether or not it
        was saved.
    """
    # Pretraining
    lossfunction = pNN.Lossfunction(args).to(args.DEVICE)
    optimizer = torch.optim.Adam(pnn.GetParam(), lr=args.LR)

    # 'AL' selects the al_* trainer; both trainers take the same arguments.
    if args.powerestimator == 'AL':
        train = al_train_pnn_progressive
    else:
        train = train_pnn_progressive
    pnn, best = train(pnn, train_loader, valid_loader, lossfunction, optimizer,
                      args, msglogger, UUID=setup + '_PT')

    if best:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(f'{args.savepath}/', exist_ok=True)
        torch.save(pnn, f'{args.savepath}/pNN_{setup}.model')
        msglogger.info('Pretraining is finished.')
    else:
        msglogger.warning('Time out, further training is necessary.')

    return pnn
    
def FT(train_loader, valid_loader, args, msglogger, setup):
    """Load the pretrained model, prune it, fine-tune it and save the result.

    Args:
        train_loader: Training data, forwarded to the training routine.
        valid_loader: Validation data, forwarded to the training routine.
        args: Run configuration. This function reads ``DEVICE``, ``LR``,
            ``powerestimator`` and ``savepath``; the whole object is also
            forwarded to the loss function and the training routine.
        msglogger: Logger that receives the status messages.
        setup: Run identifier; selects the pretrained file
            ``pNN_{setup}.model`` and names the ``_FT`` outputs.

    Returns:
        The fine-tuned model handed back by the training routine (returned
        for parity with ``PT``; callers that ignore the result are unaffected).
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints this project produced itself.
    # map_location places the model on the configured device even if it was
    # saved from a different one.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # whole-module pickles like this one; pass weights_only=False there.
    pnn = torch.load(f'{args.savepath}/pNN_{setup}.model', map_location=args.DEVICE)

    # Pruning
    msglogger.info('Pruning...')
    print('Pruning...')
    # NOTE(review): `pruning` is read as an attribute, not called; presumably a
    # property that performs the pruning and yields counts N* and fractions P*.
    N1, N2, N3, P1, P2, P3 = pnn.pruning
    information = f'{N1} ({P1*100:.2f}%) resistors, {N2} ({P2*100:.2f}%) activations and {N3} ({P3*100:.2f}%) negation circuits are pruned.'
    msglogger.info(information)
    print(information)

    # Fine tuning, at a tenth of the configured learning rate.
    lossfunction = pNN.Lossfunction(args).to(args.DEVICE)
    msglogger.info('Fine tuning...')
    optimizer = torch.optim.Adam(pnn.GetParam(), lr=args.LR/10.)

    # 'AL' selects the al_* trainer; both trainers take the same arguments.
    if args.powerestimator == 'AL':
        train = al_train_pnn_progressive
    else:
        train = train_pnn_progressive
    pnn, best = train(pnn, train_loader, valid_loader, lossfunction, optimizer,
                      args, msglogger, UUID=setup + '_FT')

    if best:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(f'{args.savepath}/', exist_ok=True)
        torch.save(pnn, f'{args.savepath}/pNN_{setup}_FT.model')
        msglogger.info('Fine tuning is finished.')
    else:
        msglogger.warning('Time out, further training is necessary.')

    return pnn

In [3]:
# Launch pretraining. The returned model is not bound to a name; as the last
# expression of the cell it is echoed as the cell's output.
PT(
    pnn=pnn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    args=args,
    msglogger=msglogger,
    setup=setup,
)

| Epoch:      0 | Train loss: 0.6931 | Valid loss: 0.6872 | Train acc: 0.5286 | Valid acc: 0.6087 | patience:   0 | lr: 1.000e-01 | Epoch time: 0.2 | Power: 5.11e-04 | lambda: 0.000e+00 | mu: 1.000e+01 |
| Epoch:     10 | Train loss: 0.3654 | Valid loss: 0.3195 | Train acc: 1.0000 | Valid acc: 1.0000 | patience:   0 | lr: 1.000e-01 | Epoch time: 0.2 | Power: 1.29e-03 | lambda: 0.000e+00 | mu: 1.000e+01 |
| Epoch:     20 | Train loss: 0.2036 | Valid loss: 0.1993 | Train acc: 1.0000 | Valid acc: 1.0000 | patience:   0 | lr: 1.000e-01 | Epoch time: 0.2 | Power: 1.35e-03 | lambda: 0.000e+00 | mu: 1.000e+01 |
| Epoch:     30 | Train loss: 0.1803 | Valid loss: 0.1770 | Train acc: 1.0000 | Valid acc: 1.0000 | patience:   0 | lr: 1.000e-01 | Epoch time: 0.2 | Power: 1.19e-03 | lambda: 0.000e+00 | mu: 1.000e+01 |
| Epoch:     40 | Train loss: 0.1639 | Valid loss: 0.1628 | Train acc: 1.0000 | Valid acc: 1.0000 | patience:   0 | lr: 1.000e-01 | Epoch time: 0.2 | Power: 1.09e-03 | lambda: 0.000e+0

KeyboardInterrupt: 

In [None]:
K = 1
# Start at 1: range(10) begins at 0, and K / 0 raises ZeroDivisionError
# before anything is printed.
for n in range(1, 10):
    print(K / n)