In [1]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'utils'))
from configuration import *
import torch
import pprint
import pNN
from utils import *

# Start from the parser defaults, then override the settings specific to this experiment.
args = parser.parse_args([])
vars(args).update({
    'e_train': 0.05,    # used as InvRT.epsilon (relative half-width of the sampled variation)
    'N_train': 10,      # used as InvRT.N (number of sampled variation instances)
    'e_fault': 1,       # appears in the saved-model name as "faults_<e_fault>"
    'N_fault': 20,      # leading dimension of the probe tensor z below
    'SEED': 0,
    'report_freq': 1,
})
args = FormulateArgs(args)

In [2]:
# Seed all RNGs first so everything stochastic below (any randomness inside the data
# loaders, network initialisation) is reproducible for a given args.SEED.
SetSeed(args.SEED)

print(f'Training network on device: {args.DEVICE}.')
MakeFolder(args)

# NOTE(review): datainfo is overwritten by each call; the value kept is the one
# returned for the 'test' split - presumably identical across splits, confirm.
train_loader, datainfo = GetDataLoader(args, 'train')
valid_loader, datainfo = GetDataLoader(args, 'valid')
test_loader, datainfo = GetDataLoader(args, 'test')
pprint.pprint(datainfo)

# Identifier used both for log messages and as the saved-model file name.
setup = f"data_{args.DATASET:02d}_{datainfo['dataname']}_seed_{args.SEED:02d}_epsilon_{args.e_train}_faults_{args.e_fault:1d}.model"
print(f'Training setup: {setup}.')

msglogger = GetMessageLogger(args, setup)
msglogger.info(f'Training network on device: {args.DEVICE}.')
msglogger.info(f'Training setup: {setup}.')
msglogger.info(datainfo)

# Layer widths: input features -> hidden layers from args -> one output per class.
topology = [datainfo['N_feature']] + args.hidden + [datainfo['N_class']]
msglogger.info(f'Topology of the network: {topology}.')

pnn = pNN.pNN(topology, args).to(args.DEVICE)


Training network on device: cpu.
{'N_class': 2,
 'N_feature': 6,
 'N_test': 25,
 'N_train': 70,
 'N_valid': 23,
 'dataname': 'acuteinflammation'}
Training setup: data_00_acuteinflammation_seed_00_epsilon:0.05.model.


In [3]:
lossfunction = pNN.pNNLoss(args).to(args.DEVICE)
optimizer = torch.optim.Adam(pnn.GetParam(), lr=args.LR)

# Both trainers return the network and a flag; a truthy flag is treated as
# "training finished", a falsy one as a time-out (see the branches below).
if args.PROGRESSIVE:
    pnn, best = train_pnn_progressive(pnn, train_loader, valid_loader, lossfunction, optimizer, args, msglogger, UUID=setup)
else:
    pnn, best = train_pnn(pnn, train_loader, valid_loader, lossfunction, optimizer, args, msglogger, UUID=setup)

if best:
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(f'{args.savepath}/', exist_ok=True)
    torch.save(pnn, f'{args.savepath}/{setup}')
    msglogger.info('Training is finished.')
else:
    msglogger.warning('Time out, further training is necessary.')

| Epoch:      0 | Train loss: 7.4705e+00 | Valid loss: 1.8939e+01 | Train acc: 0.4983 ± 0.0296 | Valid acc: 0.5339 ± 0.1350 | patience:   0 | lr: 0.1 | Epoch time: 3.8 |
| Epoch:      1 | Train loss: 6.9889e+00 | Valid loss: 6.6272e+00 | Train acc: 0.5060 ± 0.0565 | Valid acc: 0.6239 ± 0.1255 | patience:   0 | lr: 0.1 | Epoch time: 3.8 |
| Epoch:      2 | Train loss: 6.7421e+00 | Valid loss: 6.1706e+00 | Train acc: 0.5757 ± 0.0912 | Valid acc: 0.7061 ± 0.1396 | patience:   0 | lr: 0.1 | Epoch time: 3.8 |


KeyboardInterrupt: 

In [None]:
setattr(pnn.model[0].INV, 'Mask', pnn.model[0].FaultMaskNEG)  # give the first layer's INV its negative-weight fault mask

In [None]:
pnn.model[0].INV.Mask.size()

In [None]:
class InvRT(torch.nn.Module):
    """Negation (``INV``) circuit modelled through frozen surrogate networks.

    Nine trainable circuit parameters (R1, R2, R3, W1, L1, W2, L2, W3, L3) are stored as
    unconstrained values in ``rt_``. They are squashed into (0, 1) by a sigmoid and then
    denormalised with ``X_min`` / ``X_max``. Two frozen estimators loaded from ``./utils`` map
    the normalised vector (the nine values plus the three W/L ratios) to:

    * ``eta``: four coefficients of the transfer curve
      ``-(eta0 + eta1 * tanh((z - eta2) * eta3))`` (see ``output_variation``);
    * ``power``: a power estimate, averaged over the variation samples.

    Variation is modelled by drawing ``N = args.N_train`` copies of the parameters. Each copy
    is scaled element-wise by an independent uniform factor in ``[1 - e_train, 1 + e_train]``.

    ``Mask`` must be assigned before ``forward`` is called. For each output column, a mask
    value of 0 selects the nominal (learned) curve. A value ``k > 0`` selects the hard-coded
    fault curve ``eta_fault[k - 1]``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.N = args.N_train          # number of sampled variation instances
        self.epsilon = args.e_train    # relative half-width of the uniform variation
        self.N_fault = args.N_fault    # copied from args for reference; not used inside this class

        # R1, R2, R3, W1, L1, W2, L2, W3, L3 (pre-sigmoid values)
        self.rt_ = torch.nn.Parameter(torch.tensor([args.NEG_R1n, args.NEG_R2n, args.NEG_R3n, args.NEG_W1n, args.NEG_L1n, args.NEG_W2n, args.NEG_L2n, args.NEG_W3n, args.NEG_L3n]).to(args.DEVICE), requires_grad=True)
        # model: frozen surrogate for the transfer-curve coefficients eta.
        # map_location lets packages saved on another device load here.
        # NOTE: torch.load unpickles arbitrary objects - only load trusted package files.
        package = torch.load('./utils/neg_param.package', map_location=self.DEVICE)
        self.eta_estimator = package['eta_estimator'].to(self.DEVICE)
        self.eta_estimator.train(False)
        for name, param in self.eta_estimator.named_parameters():
            param.requires_grad = False
        self.X_max = package['X_max'].to(self.DEVICE)
        self.X_min = package['X_min'].to(self.DEVICE)
        self.Y_max = package['Y_max'].to(self.DEVICE)
        self.Y_min = package['Y_min'].to(self.DEVICE)
        # load power model (frozen surrogate for power consumption)
        package = torch.load('./utils/neg_power.package', map_location=self.DEVICE)
        self.power_estimator = package['power_estimator'].to(self.DEVICE)
        for name, param in self.power_estimator.named_parameters():
            param.requires_grad = False
        self.power_estimator.train(False)
        self.pow_X_max = package['X_max'].to(self.DEVICE)
        self.pow_X_min = package['X_min'].to(self.DEVICE)
        self.pow_Y_max = package['Y_max'].to(self.DEVICE)
        self.pow_Y_min = package['Y_min'].to(self.DEVICE)

        # Transfer-curve coefficients of the 18 hard-coded fault scenarios: one row per
        # fault, same 4-coefficient layout as `eta`. Mask value k selects row k - 1.
        self.eta_fault = torch.tensor([[ 8.3907e-01, -1.0000e+00,  0.0000e+00,  2.0978e-17],
                                       [-6.0647e-01, -1.0000e+00,  0.0000e+00,  1.0867e-08],
                                       [-9.9992e-01, -1.0000e+00,  0.0000e+00, -4.5596e-18],
                                       [ 8.3907e-01, -1.0000e+00,  0.0000e+00, -1.8533e-17],
                                       [ 3.1485e+01,  2.8551e-03, -9.9980e-02,  6.3016e+00],
                                       [-1.0000e+00, -1.0000e+00,  0.0000e+00,  4.4070e-16],
                                       [ 1.8307e+02, -1.8394e+02, -8.3575e+01,  4.2966e-02],
                                       [-6.0647e-01, -1.0000e+00,  0.0000e+00,  1.0867e-08],
                                       [ 1.2159e-01, -7.3578e-01, -7.9441e-02,  3.1090e+00],
                                       [ 8.3907e-01, -1.0000e+00,  0.0000e+00,  2.0978e-17],
                                       [-9.9992e-01, -1.0000e+00,  0.0000e+00, -4.5596e-18],
                                       [ 7.6517e-01, -8.0291e-03,  6.3714e-01,  1.2184e+00],
                                       [ 8.3907e-01, -1.0000e+00,  0.0000e+00,  2.0978e-17],
                                       [ 8.3907e-01, -1.0000e+00,  0.0000e+00, -1.8533e-17],
                                       [-1.0000e+00, -1.0000e+00,  0.0000e+00,  4.4070e-16],
                                       [ 1.0000e+00, -1.0000e+00,  0.0000e+00,  4.4982e-16],
                                       [-4.5913e-02, -7.2633e-01,  7.3493e-02,  1.0507e+01],
                                       [-9.9992e-01, -1.0000e+00,  0.0000e+00, -4.5596e-18]]).to(self.DEVICE)

        # Per-column fault selector; must be assigned before forward() is called.
        self.Mask = None

    @property
    def DEVICE(self):
        """Device taken from the current ``args`` (follows ``UpdateArgs``)."""
        return self.args.DEVICE

    @property
    def RT(self):
        """Denormalised nominal parameter vector of length 12.

        Entries 0-8 come from ``rt_``. Entries 9-11 (the ratio slots) are left at 0 before
        denormalisation, i.e. end up equal to ``X_min[9:]``. Assumes ``X_min`` / ``X_max``
        have 12 entries.
        """
        # keep values in (0,1)
        rt_temp = torch.sigmoid(self.rt_)
        RTn = torch.zeros([12]).to(self.DEVICE)
        RTn[:9] = rt_temp
        # denormalization
        RT = RTn * (self.X_max - self.X_min) + self.X_min
        return RT

    @property
    def RT_noisy(self):
        """``N`` noisy copies of ``RT``, shape (N, 12).

        Each copy is scaled element-wise by an independent factor drawn uniformly from
        ``[1 - epsilon, 1 + epsilon]``. Every access draws fresh noise.
        """
        RT_mean = self.RT.repeat(self.N, 1)
        # rand_like keeps the noise on the same device/dtype as the parameters.
        noise = ((torch.rand_like(RT_mean) * 2. - 1.) * self.epsilon) + 1.
        RT_variation = RT_mean * noise
        return RT_variation

    @property
    def RTn_extend(self):
        """Normalised surrogate input of shape (N, 12).

        Columns are the nine noisy parameters followed by W1/L1, W2/L2 and W3/L3.
        """
        # Draw ONE variation sample: reading self.RT_noisy repeatedly would draw new noise
        # per column, so the ratio columns would not match their own W/L columns.
        rt = self.RT_noisy
        RT_extend = torch.cat([rt[:, :9],
                               rt[:, 3:4] / rt[:, 4:5],
                               rt[:, 5:6] / rt[:, 6:7],
                               rt[:, 7:8] / rt[:, 8:9]], dim=1)
        return (RT_extend - self.X_min) / (self.X_max - self.X_min)

    @property
    def eta(self):
        """Denormalised transfer-curve coefficients, one row of 4 per variation sample."""
        # calculate eta
        eta_n = self.eta_estimator(self.RTn_extend)
        eta = eta_n * (self.Y_max - self.Y_min) + self.Y_min
        return eta

    @property
    def power(self):
        """Mean (over variation samples) denormalised power estimate."""
        # calculate power
        power_n = self.power_estimator(self.RTn_extend)
        power = power_n * (self.pow_Y_max - self.pow_Y_min) + self.pow_Y_min
        return power.mean()

    def output_variation(self, eta, z):
        """Apply sample i's transfer curve (row ``eta[i]``) to ``z[i]`` for i < N.

        ``z`` is indexed as ``z[i, :, :]``, so it must be rank 3 with ``z.shape[0] >= N``.
        Rows of ``z`` beyond N are left at zero in the result.
        """
        a = torch.zeros_like(z)
        for i in range(self.N):
            a[i,:,:] = -(eta[i,0] + eta[i,1] * torch.tanh((z[i,:,:] - eta[i,2]) * eta[i,3]))
        return a

    def output_faults(self, z, mask):
        """Evaluate the nominal curve and every fault curve, then pick per column via ``mask``.

        Returns a tensor of shape (z.shape[0], z.shape[1], mask.numel()). Column ``i`` comes
        from curve ``int(mask[i])``: 0 is nominal, k > 0 is ``eta_fault[k - 1]``.
        """
        result = [self.output_variation(self.eta, z)]

        for fault in range(self.eta_fault.shape[0]):
            eta_temp = self.eta_fault[fault,:].repeat(self.N, 1)
            result.append(self.output_variation(eta_temp, z))

        # (1 + n_faults, *z.shape): candidate outputs for every curve
        output = torch.stack(result)
        slices = [output[int(mask[i]), :, :, i] for i in range(mask.numel())]
        return torch.stack(slices, dim=2)

    def forward(self, z):
        """Apply ``output_faults`` to ``z[i]`` with mask row ``Mask[i]`` for each mask row.

        ``z`` must be rank 4 with ``z.shape[0] >= Mask.shape[0]``.

        Raises:
            RuntimeError: if ``Mask`` has not been assigned.
        """
        if self.Mask is None:
            raise RuntimeError('InvRT.Mask must be set before calling forward().')
        result = [self.output_faults(z[i, :, :, :], self.Mask[i]) for i in range(self.Mask.shape[0])]
        return torch.stack(result)

    def UpdateArgs(self, args):
        """Replace ``args`` (affects ``DEVICE``; ``N`` / ``epsilon`` stay as set in ``__init__``)."""
        self.args = args
    

In [None]:
# Standalone copy of the class above for debugging; its constructor loads the surrogate packages from ./utils.
inv = InvRT(args=args)

In [None]:
# Grab one batch from the training loader for inspection.
x, y = next(iter(train_loader))

# Random stand-in for the layer input, shaped (N_fault, N_train, 70, 8).
# NOTE(review): 70 and 8 are hard-coded batch/width values - confirm they match the data.
z = torch.randn(args.N_fault, args.N_train, 70, 8)

In [None]:
# Negative-weight fault mask of the first layer; inspect its size.
mask = pnn.model[0].FaultMaskNEG
mask.size()

In [None]:
inv.Mask = mask
# One fault scenario only: z[0] is a single (N_train, batch, width) block, mask[0] its mask row.
inv.output_faults(z[0], mask[0]).shape

In [None]:
inv(z).shape  # call the module (not .forward) so any registered hooks run

In [None]:
pnn.model[0].INV.Mask = mask
pnn.model[0].INV(z).shape  # call the module (not .forward) so any registered hooks run

In [None]:
inv.Mask.select(0, 0)

In [None]:
# forward() indexes z[i, :, :, :], so it needs a tensor, not a torch.Size.
# NOTE(review): dim 0 must be >= inv.Mask.shape[0] (forward loops over the mask rows).
inv(torch.randn(17, 10, 70, 8)).shape

In [None]:
# The layer needs an input tensor, not a torch.Size describing one.
# NOTE(review): dim 0 presumably must cover the layer's Mask rows - confirm against pNN.
pnn.model[0].INV(torch.randn(17, 10, 70, 8)).shape