In [1]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'utils'))
from configuration import *
import torch
import pprint
from utils import *
import PrintedLearnableFilter as pNN

In [2]:
# Experiment configuration. Parsing an empty argv makes the parser use only its
# defaults, so the kernel's own sys.argv is ignored. `parser` comes from the
# star imports above; presumably an argparse parser — confirm.
args = parser.parse_args([])
for option, value in (('e_train', 0.1), ('N_train', 10), ('metric', 'temporal_acc')):
    setattr(args, option, value)
args = FormulateArgs(args)

print(f'Training network on device: {args.DEVICE}.')
MakeFolder(args)

# Build one loader per split. Every call also returns dataset metadata; the
# value kept in `datainfo` is the one from the final ('test') call.
loaders = {}
for split in ('train', 'valid', 'test'):
    loaders[split], datainfo = GetDataLoader(args, split)
train_loader, valid_loader, test_loader = loaders['train'], loaders['valid'], loaders['test']
pprint.pprint(datainfo)

Training network on device: cpu.
{'N_class': 2,
 'N_feature': 6,
 'N_test': 25,
 'N_time': 32,
 'N_train': 70,
 'N_valid': 23,
 'dataname': 'acuteinflammation'}


In [3]:
SetSeed(args.SEED)

# Checkpoint name encodes the dataset index, dataset name and seed; an existing
# file with this name means this (dataset, seed) run is skipped below.
setup = f"pLF_data_{args.DATASET:02d}_{datainfo['dataname']}_seed_{args.SEED:02d}.model"
model_path = os.path.join(args.savepath, setup)
print(f'Training setup: {setup}.')

msglogger = GetMessageLogger(args, setup)
# try/finally so the logger is closed even if training is interrupted
# (e.g. KeyboardInterrupt) or raises; otherwise it stays open across re-runs.
try:
    msglogger.info(f'Training network on device: {args.DEVICE}.')
    msglogger.info(f'Training setup: {setup}.')
    msglogger.info(datainfo)

    if os.path.isfile(model_path):
        print(f'{setup} exists, skip this training.')
        msglogger.info('Training was already finished.')
    else:
        # Positional args: (args, N_feature inputs, N_class outputs, 2).
        # NOTE(review): meaning of the trailing 2 is defined in PrintedLearnableFilter — confirm there.
        pnn = pNN.PrintedNeuralNetwork(args, datainfo['N_feature'], datainfo['N_class'], 2).to(args.DEVICE)

        # Fetch the learnable parameters once and reuse them for logging and the optimizer.
        params = pnn.GetParam()
        msglogger.info(f'Number of parameters that are learned in this experiment: {len(params)}.')

        lossfunction = pNN.LFLoss(args).to(args.DEVICE)
        optimizer = torch.optim.Adam(params, lr=args.LR)

        # Both trainers share one signature; pick by the PROGRESSIVE flag.
        train_fn = train_pnn_progressive if args.PROGRESSIVE else train_pnn
        pnn, best = train_fn(pnn, train_loader, valid_loader, lossfunction, optimizer, args, msglogger, UUID=setup)

        if best:
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(args.savepath, exist_ok=True)
            # NOTE(review): this pickles the whole module object; loading it needs the
            # class importable (and weights_only=False on recent torch versions).
            torch.save(pnn, model_path)
            msglogger.info('Training is finished.')
        else:
            msglogger.warning('Time out, further training is necessary.')
finally:
    CloseLogger(msglogger)

Training setup: pLF_data_00_acuteinflammation_seed_00.model.
torch.Size([10, 70, 2, 32])
torch.Size([10, 23, 2, 32])
| Epoch:      0 | Train loss: 6.9318e-01 | Valid loss: 6.9023e-01 | Train acc: 0.5186 | Valid acc: 0.6087 | patience:   0 | lr: 0.1 | Epoch time: 57.4 |
torch.Size([10, 70, 2, 32])
torch.Size([10, 23, 2, 32])
| Epoch:      1 | Train loss: 7.6135e-01 | Valid loss: 6.7503e-01 | Train acc: 0.5286 | Valid acc: 0.6087 | patience:   0 | lr: 0.1 | Epoch time: 58.3 |
torch.Size([10, 70, 2, 32])


KeyboardInterrupt: 

In [None]:
# Small standalone network for smoke-testing the forward pass.
# The module is imported as `pNN` (`PLF` was never defined and raised NameError).
# Positional args follow the training cell: (args, N_feature=3, N_class=5, 2);
# 3 inputs matches Channel in the synthetic-input cell. Moved to args.DEVICE so
# it lives on the same device as X.
test = pNN.PrintedNeuralNetwork(args, 3, 5, 2).to(args.DEVICE)

In [None]:
# Synthetic smoke-test input: 64 evenly spaced values in [0, 1], tiled to
# shape (Batch, Channel, 64) and placed on args.DEVICE.
Batch, Channel = 100, 3
X = torch.linspace(start=0, end=1, steps=64).repeat(Batch, Channel, 1).to(args.DEVICE)

print(X.shape)

In [None]:
# Forward the synthetic batch through the test network and display the output shape.
y = test(X)
y.size()

In [None]:
# Call UpdateVariation on the test network before re-running the forward pass.
# NOTE(review): argument meaning is not visible here; presumably
# (number of variation samples, variation magnitude), by analogy with
# args.N_train = 10 / args.e_train = 0.1 — confirm in PrintedLearnableFilter.
test.UpdateVariation(20, 0.1)

In [None]:
# Repeat the forward pass after the UpdateVariation call and display the resulting output shape.
y = test(X)
y.size()