In [1]:
import argparse
from cmath import inf
import datetime
import os
import shutil
import socket
import sys
from pathlib import Path
import numpy
import time
import torch

import decode.evaluation
import decode.neuralfitter
import decode.neuralfitter.coord_transform
import decode.neuralfitter.utils
import decode.simulation
import decode.utils
# from decode.neuralfitter.train.random_simulation import setup_random_simulation
from decode.neuralfitter.utils import log_train_val_progress
from decode.utils.checkpoint import CheckPoint
torch.set_printoptions(precision=4,sci_mode=False)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def parse_args():
    parser = argparse.ArgumentParser(description='Training Args')

    parser.add_argument('-i', '--device', default=None, 
                        help='Specify the device string (cpu, cuda, cuda:0) and overwrite param.',
                        type=str)

    parser.add_argument('-p', '--param_file', default=None,
                        help='Specify your parameter file (.yml or .json).', type=str)

    parser.add_argument('-w', '--num_worker_override',default=None,
                        help='Override the number of workers for the dataloaders.',
                        type=int)

    parser.add_argument('-n', '--no_log', default=False, action='store_true',
                        help='Set no log if you do not want to log the current run.')

    parser.add_argument('-c', '--log_comment', default=None,
                        help='Add a log_comment to the run.')

    parser.add_argument('-d', '--data_path_override', default=None,
                        help='Specify your path to data', type=str)

    parser.add_argument('-is', '--img_size_override', default=None,
                        help='Override img size', type=int)

    # args = parser.parse_args()
    args, _ = parser.parse_known_args()
    return args

In [3]:
args = parse_args()
args = parse_args()
args.device = 'cpu'
args.param_file='/home/lingjia/Documents/rPSF/NN/param.yaml'
# args.data_path_override='/media/hdd/rPSF/data/plain/train/0620_uniformFlux'
args.img_size_override=96

param_file = Path(args.param_file)
param = decode.utils.param_io.ParamHandling().load_params(param_file)

# add meta information - Meta=namespace(version='0.10.0'),
param.Meta.version = decode.utils.bookkeeping.decode_state()

"""Experiment ID"""
if param.InOut.checkpoint_init is None:
    experiment_id = datetime.datetime.now().strftime(
        "%Y-%m-%d-%H-%M-%S") + '_' + socket.gethostname()
    from_ckpt = False
    if args.log_comment:
        experiment_id = experiment_id + '_' + args.log_comment
else:
    from_ckpt = True
    experiment_id = Path(param.InOut.checkpoint_init).parent.name

"""Set up unique folder for experiment"""
if not from_ckpt:
    experiment_path = Path(param.InOut.experiment_out) / Path(experiment_id)
else:
    experiment_path = Path(param.InOut.checkpoint_init).parent

if not experiment_path.parent.exists():
    experiment_path.parent.mkdir()

if not from_ckpt:
    experiment_path.mkdir(exist_ok=False)

model_out = experiment_path / Path('model.pt')
ckpt_path = experiment_path / Path('ckpt.pt')

# Modify parameters
if args.num_worker_override is not None:
    param.Hardware.num_worker_train = args.num_worker_override

"""Hardware / Server stuff."""
if args.device is not None:
    device = args.device
    # param.Hardware.device_simulation = device_overwrite  # lazy assumption
else:
    device = param.Hardware.device

if args.data_path_override is not None:
    param.InOut.data_path = args.data_path_override

if args.img_size_override is not None:
    param.Simulation.img_size = [args.img_size_override,args.img_size_override]
    param.Simulation.psf_extent = [[-0.5, args.img_size_override-0.5],
                                    [-0.5, args.img_size_override-0.5], None]

    param.TestSet.frame_extent = param.Simulation.psf_extent
    param.TestSet.img_size = param.Simulation.img_size

# Backup the parameter file under the network output path with the experiments ID
param_backup_in = experiment_path / Path('param_run_in').with_suffix(param_file.suffix)
# shutil.copy(param_file, param_backup_in)

param_backup = experiment_path / Path('param_run').with_suffix(param_file.suffix)
decode.utils.param_io.ParamHandling().write_params(param_backup, param)

if sys.platform in ('linux', 'darwin'):
    os.nice(param.Hardware.unix_niceness)
elif param.Hardware.unix_niceness is not None:
    print(f"Cannot set niceness on platform {sys.platform}. You probably do not need to worry.")

torch.set_num_threads(param.Hardware.torch_threads)

"""Setup Log System"""
if args.no_log:
    logger = decode.neuralfitter.utils.logger.NoLog()

else:
    log_folder = experiment_path

    logger = decode.neuralfitter.utils.logger.MultiLogger(
        [decode.neuralfitter.utils.logger.SummaryWriter(log_dir=log_folder,
                                                        filter_keys=["dx_red_mu", "dx_red_sig",
                                                                        "dy_red_mu",
                                                                        "dy_red_sig", "dz_red_mu",
                                                                        "dz_red_sig",
                                                                        "dphot_red_mu",
                                                                        "dphot_red_sig"]),
            decode.neuralfitter.utils.logger.DictLogger()])


In [4]:
"""Set model, optimiser, loss and schedulers"""
models_available = {
    'SigmaMUNet': decode.neuralfitter.models.SigmaMUNet,
    'DoubleMUnet': decode.neuralfitter.models.model_param.DoubleMUnet,
    'SimpleSMLMNet': decode.neuralfitter.models.model_param.SimpleSMLMNet,
}

model = models_available[param.HyperParameter.architecture]
model = model.parse(param)

model_ls = decode.utils.model_io.LoadSaveModel(model, output_file=model_out)

model = model_ls.load_init()
model = model.to(torch.device(device))

# Small collection of optimisers
optimizer_available = {
    'Adam': torch.optim.Adam,
    'AdamW': torch.optim.AdamW
}

optimizer = optimizer_available[param.HyperParameter.optimizer]
optimizer = optimizer(model.parameters(), **param.HyperParameter.opt_param)

"""Loss function."""
criterion = decode.neuralfitter.loss.GaussianMMLoss(
    xextent=param.Simulation.psf_extent[0],
    yextent=param.Simulation.psf_extent[1],
    img_shape=param.Simulation.img_size,
    device=device,
    chweight_stat=param.HyperParameter.chweight_stat)

"""Learning Rate and Simulation Scheduling"""
lr_scheduler_available = {
    'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
    'StepLR': torch.optim.lr_scheduler.StepLR
}
lr_scheduler = lr_scheduler_available[param.HyperParameter.learning_rate_scheduler]
lr_scheduler = lr_scheduler(optimizer, **param.HyperParameter.learning_rate_scheduler_param)

"""Checkpointing"""
checkpoint = CheckPoint(path=ckpt_path)

"""Setup gradient modification"""
grad_mod = param.HyperParameter.grad_mod

"""Log the model (Graph) """
try:
    dummy = torch.rand((2, param.HyperParameter.channels_in,
                        *param.Simulation.img_size), requires_grad=False).to(torch.device(device))
    logger.add_graph(model, dummy)

except:
    print("Did not log graph.")
    # raise RuntimeError("Your dummy input is wrong. Please update it.")

"""Setup Target generator consisting possibly multiple steps in a transformation sequence."""
tar_proc = decode.neuralfitter.utils.processing.TransformSequence(
    [
        # param_tar --> phot/max, z/z_max, bg/bg_max
        decode.neuralfitter.scale_transform.ParameterListRescale(
            phot_max=param.Scaling.phot_max,
            z_max=param.Scaling.z_max,
            bg_max=param.Scaling.bg_max)
    ])

# print(test_ds.label_gen())

"""Set up post processor"""
if param.PostProcessing is None:
    post_processor = decode.neuralfitter.post_processing.NoPostProcessing(xy_unit='px',
                                                                            px_size=param.Camera.px_size)

elif param.PostProcessing == 'LookUp':
    post_processor = decode.neuralfitter.utils.processing.TransformSequence([

        decode.neuralfitter.scale_transform.InverseParamListRescale(
            phot_max=param.Scaling.phot_max,
            z_max=param.Scaling.z_max,
            bg_max=param.Scaling.bg_max),

        decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

        decode.neuralfitter.post_processing.LookUpPostProcessing(
            raw_th=param.PostProcessingParam.raw_th,
            pphotxyzbg_mapping=[0, 1, 2, 3, 4, -1],
            xy_unit='px',
            px_size=param.Camera.px_size)
    ])

elif param.PostProcessing in ('SpatialIntegration', 'NMS'):  # NMS as legacy support
    post_processor = decode.neuralfitter.utils.processing.TransformSequence([
        # out_tar --> out_tar: photo*photo_max, z*z_max, bg*bg_max
        decode.neuralfitter.scale_transform.InverseParamListRescale(
            phot_max=param.Scaling.phot_max,
            z_max=param.Scaling.z_max,
            bg_max=param.Scaling.bg_max),
        # offset --> coordinate e.g., 0.2 --> 10.2 
        decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

        decode.neuralfitter.post_processing.SpatialIntegration(
            raw_th=param.PostProcessingParam.raw_th, # 0.5
            xy_unit='px')
    ])



Model instantiated.
Model initialised as specified in the constructor.


In [7]:
x = torch.randn(1,1,96,96)
y_out = model(x)
print(f'y_out: {y_out.shape}')

# define output, target, weight
output = y_out

param_tar = torch.zeros(1,100,4)
param_tar_v = torch.randn(15,4)
param_tar[0,:15,:] = param_tar_v

mask_tar = torch.zeros(1,100,1)
mask_tar[0,:15,:] = 1

bg = torch.ones((96,96))*5
target = (param_tar, mask_tar, bg)

weight = None

_forward_checks(output, target, weight)
tar_param, tar_mask, tar_bg = target
p, pxyz_mu, pxyz_sig, bg = _format_model_output(output)
_bg_loss = torch.nn.MSELoss(reduction='none')
bg_loss = _bg_loss(bg, tar_bg).sum(-1).sum(-1)
# gmm_loss = _compute_gmm_loss(p, pxyz_mu, pxyz_sig, tar_param, tar_mask)

y_out: torch.Size([1, 10, 96, 96])


  return F.mse_loss(input, target, reduction=self.reduction)


In [10]:
from abc import ABC, abstractmethod  # abstract class
from typing import Union, Tuple

import torch
from deprecated import deprecated
from torch import distributions

# from . import MixtureSameFamily as mixture
# from ..simulation import psf_kernel
import decode.generic.utils

p, pxyz_mu, pxyz_sig, pxyz_tar, mask = p, pxyz_mu, pxyz_sig, tar_param, tar_mask
# def _compute_gmm_loss(self, p, pxyz_mu, pxyz_sig, pxyz_tar, mask) -> torch.Tensor:
"""
Computes the Gaussian Mixture Loss.

Args:
    p: the model's detection prediction (sigmoid already applied) size N x H x W
    pxyz_mu: prediction of parameters (phot, xyz) size N x C=4 x H x W
    pxyz_sig: presdiction of uncertainties / sigma values (phot, xyz) size N x C=4 x H x W
    pxyz_tar: ground truth values (phot, xyz) size N x M x 4 (M being max number of tars)
    mask: activation mask of ground truth values (phot, xyz) size N x M

Returns:
    torch.Tensor (size N x 1)
"""

batch_size = pxyz_mu.size(0)
log_prob = 0
print(p.shape)
p_mean = p.sum(-1).sum(-1)
p_var = (p - p ** 2).sum(-1).sum(-1)  # var estimate of bernoulli
print(p_var.shape)


torch.Size([1, 96, 96])
torch.Size([1])


In [None]:

p_gauss = distributions.Normal(p_mean, torch.sqrt(p_var))

log_prob = log_prob + p_gauss.log_prob(mask.sum(-1)) * mask.sum(-1)
print(f'log_prob: {log_prob.shape}')
# print(log_prob)
prob_normed = p / p.sum(-1).sum(-1).view(-1, 1, 1)

print(p.shape)
p_inds = tuple((p + 1).nonzero(as_tuple=False).transpose(1, 0))
# print(p_inds)
print(pxyz_mu.shape)
pxyz_mu = pxyz_mu[p_inds[0], :, p_inds[1], p_inds[2]]


# p_inds = (torch.tensor([0,0]),torch.tensor([0,1]),torch.tensor([0,1]))
# xx = torch.randn((1,3,3))
# print(xx[p_inds[0], p_inds[1], p_inds[2]])

In [None]:

"""Hacky way to get all prob indices"""
p_inds = tuple((p + 1).nonzero(as_tuple=False).transpose(1, 0))
pxyz_mu = pxyz_mu[p_inds[0], :, p_inds[1], p_inds[2]]

"""Convert px shifts to absolute coordinates"""
pxyz_mu[:, 1] += self.bin_ctr_x[p_inds[1]].to(pxyz_mu.device)
pxyz_mu[:, 2] += self.bin_ctr_y[p_inds[2]].to(pxyz_mu.device)

"""Flatten img dimension --> N x (HxW) x 4"""
pxyz_mu = pxyz_mu.reshape(batch_size, -1, 4)
pxyz_sig = pxyz_sig[p_inds[0], :, p_inds[1], p_inds[2]].reshape(batch_size, -1, 4)

"""Set up mixture family"""
mix = distributions.Categorical(prob_normed[p_inds].reshape(batch_size, -1))
comp = distributions.Independent(distributions.Normal(pxyz_mu, pxyz_sig), 1)
gmm = distributions.mixture_same_family.MixtureSameFamily(mix, comp)
# print(f'gmm:{gmm}')

"""Calc log probs if there is anything there"""
if mask.sum():
    # print(f'pxyz_tar:{pxyz_tar.shape}')
    gmm_log = gmm.log_prob(pxyz_tar.transpose(0, 1)).transpose(0, 1)
    gmm_log = (gmm_log * mask).sum(-1)
    # print(f"LogProb: {log_prob.mean()}, GMM_log: {gmm_log.mean()}")
    log_prob = log_prob + gmm_log

# log_prob = log_prob.reshape(batch_size, 1)  # need?

loss = log_prob * (-1)


In [6]:
def _forward_checks(output: torch.Tensor, target: tuple, weight: None):

    if weight is not None:
        raise NotImplementedError(f"Weight must be None for this loss implementation.")

    if output.dim() != 4:
        raise ValueError(f"Output must have 4 dimensions (N,C,H,W).")

    if output.size(1) != 10:
        raise ValueError(f"Wrong number of channels.")

    if len(target) != 3:
        raise ValueError(f"Wrong length of target.")


def _format_model_output(output: torch.Tensor) -> tuple:
    """
    Transforms solely channel based model output into more meaningful variables.

    Args:
        output: model output

    Returns:
        tuple containing
            p: N x H x W
            pxyz_mu: N x 8 x H x W = phot1, phot2, phot3, x, y, z1, z2, z3
            pxyz_sig: N x 8 x H x W = phot1, phot2, phot3, x, y, z1, z2, z3
            bg: N x H x W
    """
    p = output[:, 0]
    pxyz_mu = output[:, 1:5]
    pxyz_sig = output[:, 5:-1]
    bg = output[:, -1]

    return p, pxyz_mu, pxyz_sig, bg

In [8]:
# for i in range(first_epoch, param.HyperParameter.epochs):
i=1
logger.add_scalar('learning/learning_rate', optimizer.param_groups[0]['lr'], i)
print(f'Epoch{i}')

if i >= 1:
    _ = decode.neuralfitter.train_val_impl.train(
        model=model,
        optimizer=optimizer,
        loss=criterion,
        dataloader=dl_train,
        grad_rescale=param.HyperParameter.moeller_gradient_rescale,
        grad_mod=grad_mod,
        epoch=i,
        device=torch.device(device),
        logger=logger
    )

Epoch1


1 Train Time:6.5e+01 Loss:2.41e+02: 100%|█████████████████████████| 281/281 [01:04<00:00,  4.34it/s]


In [6]:
import torch
import time
from typing import Union
from tqdm import tqdm
from collections import namedtuple
import sys

from decode.neuralfitter.utils import log_train_val_progress
from decode.evaluation.utils import MetricMeter
from torch import distributions

def ship_device(x, device: Union[str, torch.device]):
    """
    Ships the input to a pytorch compatible device (e.g. CUDA)

    Args:
        x:
        device:

    Returns:
        x

    """
    if x is None:
        return x

    elif isinstance(x, torch.Tensor):
        return x.to(device)

    elif isinstance(x, (tuple, list)):
        x = [ship_device(x_el, device) for x_el in x]  # a nice little recursion that worked at the first try
        return x

    elif device != 'cpu':
        raise NotImplementedError(f"Unsupported data type for shipping from host to CUDA device.")

In [7]:
model=model
optimizer=optimizer
loss=criterion
dataloader=dl_train
grad_rescale=param.HyperParameter.moeller_gradient_rescale
grad_mod=grad_mod
epoch=i
device=torch.device(device)
logger=logger

NameError: name 'i' is not defined

In [9]:
"""Actual Training"""
"""Some Setup things"""
model.train()
tqdm_enum = tqdm(dataloader, total=len(dataloader), smoothing=0.,ncols=100)  # progress bar enumeration
t0 = time.time()
loss_epoch = MetricMeter()
loss_gmm_epoch = MetricMeter()
loss_bg_epoch = MetricMeter()

# for batch_num, (x, y_tar, weight) in enumerate(tqdm_enum):  # model input (x), target (yt), weights (w)
    # x = frames, y_tar =  Tuple(param_tar, mask_tar, bg), weight = None
batch_num, (x, y_tar, weight) = next(enumerate(tqdm_enum))
# print(y_tar[0][0,:5,:])
# print(y_tar[1][0,:5])
# print(y_tar[1][0,-5:])
# print(y_tar[2][0,:1,:1])
# print(type(weight))

"""Monitor time to get the data"""
t_data = time.time() - t0

"""Ship the data to the correct device"""
x, y_tar, weight = ship_device([x, y_tar, weight], device)

"""Forward the data"""
y_out = model(x)
# print(y_out.shape) # [32, 10, 96, 96]

  0%|                                                                       | 0/281 [00:00<?, ?it/s]


In [10]:
loss = decode.neuralfitter.loss.GaussianMMLoss(
    xextent=param.Simulation.psf_extent[0],
    yextent=param.Simulation.psf_extent[1],
    img_shape=param.Simulation.img_size,
    device=device,
    chweight_stat=param.HyperParameter.chweight_stat)

"""Reset the optimiser, compute the loss and backprop it"""
loss_val = loss(y_out, y_tar, weight)
print(loss_val[0])

LogProb: -1177.076416015625, GMM_log: 30.081775665283203
tensor([ 1239.6512,     0.0309], device='cuda:2', grad_fn=<SelectBackward>)


In [11]:
output = y_out
target = y_tar
weight = weight

# def forward(self, output: torch.Tensor, target: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
#             weight: None) -> torch.Tensor:

tar_param, tar_mask, tar_bg = target
# p, pxyz_mu, pxyz_sig, bg = self._format_model_output(output)
p = output[:, 0]
pxyz_mu = output[:, 1:5]
pxyz_sig = output[:, 5:-1]
bg = output[:, -1]
# print(torch.max(p))
# print(pxyz_mu.shape)

"""" Background Loss """
_bg_loss = torch.nn.MSELoss(reduction='none')
bg_loss = _bg_loss(bg, tar_bg).sum(-1).sum(-1)
print(bg_loss[0]*2)
# print(torch.max(bg_loss))
# print(bg[0,0,:5])
# print(tar_bg[0,0,:5])


tensor(0.0309, device='cuda:2', grad_fn=<MulBackward0>)


In [16]:
p = p
pxyz_mu = pxyz_mu
pxyz_sig = pxyz_sig
pxyz_tar = tar_param
mask = tar_mask

batch_size = pxyz_mu.size(0) # 32
log_prob = 0

p_mean = p.sum(-1).sum(-1) # shape = [32]
# print(p_mean)

p_var = (p - p ** 2).sum(-1).sum(-1)  # var estimate of bernoulli
# print(p_var)

p_gauss = distributions.Normal(p_mean, torch.sqrt(p_var))
# print(p_gauss.log_prob(mask.sum(-1)))
# print(-((x - mu) ** 2) / (2 * sig**2) - math.log(sig) - math.log(math.sqrt(2 * math.pi)))

# print(mask.sum(-1))
# [39,  3,  1, 16,  3, 27, 12, 19,  4,  7,  5, 39, 27, 25, 28, 16, 19, 16,
#          5, 37, 15, 23, 28,  2, 19, 19, 28,  7, 26,  3, 21,  8],
#        device='cuda:2')

log_prob = log_prob + p_gauss.log_prob(mask.sum(-1)) * mask.sum(-1)
# print(log_prob)

prob_normed = p / p.sum(-1).sum(-1).view(-1, 1, 1)
# print(torch.max(prob_normed))
# 0.1479

In [18]:
img_shape=param.Simulation.img_size
xextent=param.Simulation.psf_extent[0]
yextent=param.Simulation.psf_extent[1]
_bin_x, _bin_y, bin_ctr_x, bin_ctr_y = decode.generic.utils.frame_grid(img_shape, xextent, yextent)

p = output[:, 0]
pxyz_mu = output[:, 1:5]
pxyz_sig = output[:, 5:-1]

"""Hacky way to get all prob indices"""
p_inds = tuple((p + 1).nonzero(as_tuple=False).transpose(1, 0))
pxyz_mu = pxyz_mu[p_inds[0], :, p_inds[1], p_inds[2]]

"""Convert px shifts to absolute coordinates"""
pxyz_mu[:, 1] += bin_ctr_x[p_inds[1]].to(pxyz_mu.device)
pxyz_mu[:, 2] += bin_ctr_y[p_inds[2]].to(pxyz_mu.device)

"""Flatten img dimension --> N x (HxW) x 4"""
pxyz_mu = pxyz_mu.reshape(batch_size, -1, 4)
pxyz_sig = pxyz_sig[p_inds[0], :, p_inds[1], p_inds[2]].reshape(batch_size, -1, 4)

"""Set up mixture family"""
mix = distributions.Categorical(prob_normed[p_inds].reshape(batch_size, -1))
comp = distributions.Independent(distributions.Normal(pxyz_mu, pxyz_sig), 1)
gmm = distributions.mixture_same_family.MixtureSameFamily(mix, comp)
# print(f'gmm:{gmm}')

"""Calc log probs if there is anything there"""
if mask.sum():
    # print(f'pxyz_tar:{pxyz_tar.shape}')
    gmm_log = gmm.log_prob(pxyz_tar.transpose(0, 1)).transpose(0, 1)
    gmm_log = (gmm_log * mask).sum(-1)
    # print(f"LogProb: {log_prob.mean()}, GMM_log: {gmm_log.mean()}")
    # log_prob = log_prob + gmm_log
    gmm_loss = gmm_log

# log_prob = log_prob.reshape(batch_size, 1)  # need?

gmm_loss = gmm_loss * (-1)
log_prob = log_prob * (-1)
print(gmm_loss[0])
print(log_prob[0])
print((log_prob[0]+gmm_loss[0])*2)

tensor(-21.7802, device='cuda:2', grad_fn=<SelectBackward>)
tensor(641.6058, device='cuda:2', grad_fn=<SelectBackward>)
tensor(1239.6512, device='cuda:2', grad_fn=<MulBackward0>)


: 

In [34]:
import math
x = 39
mu = 130.7310
sig = math.sqrt(37.3684)

p_gauss = distributions.Normal(mu,sig)
print(p_gauss.log_prob(x))
print(-((x - mu) ** 2) / (2 * sig**2) - math.log(sig) - math.log(math.sqrt(2 * math.pi)))

tensor(-115.3188)
-115.31881669101394


In [None]:
if grad_rescale:  # rescale gradients so that they are in the same order for the last layer
    weight, _, _ = model.rescale_last_layer_grad(loss_val, optimizer)
    loss_val = loss_val * weight

optimizer.zero_grad()
loss_val.mean().backward()

"""Gradient Modification"""
if grad_mod:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.03, norm_type=2)

"""Update model parameters"""
optimizer.step()

"""Monitor overall time"""
t_batch = time.time() - t0

"""Logging"""
loss_mean, loss_cmp = loss.log(loss_val)  # compute individual loss components
loss_gmm = loss_cmp['gmm']
loss_bg = loss_cmp['bg']
del loss_val
loss_epoch.update(loss_mean)
loss_gmm_epoch.update(loss_gmm)
loss_bg_epoch.update(loss_bg)
tqdm_enum.set_description(f"{epoch} Train Time:{t_batch:.2} Loss:{loss_mean:.3}")

# t0 = time.time()

# log_train_val_progress.log_train(loss_p_batch=loss_epoch.vals, loss_mean=loss_epoch.mean, logger=logger, step=epoch)

log_train_val_progress.log_train(loss_p_batch=loss_epoch.vals, loss_mean=loss_epoch.mean, logger=logger, step=epoch,loss_gmm_mean=loss_gmm_epoch.mean,loss_bg_mean=loss_bg_epoch.mean)

return loss_epoch.mean

In [None]:

# val_loss=avg of loss for all batches
# test_out=list of network_output: ["loss", "x", "y_out", "y_tar", "weight", "em_tar"]
val_loss, test_out = decode.neuralfitter.train_val_impl.test(
    model=model,
    loss=criterion,
    dataloader=dl_test,
    epoch=i,
    device=torch.device(device))
# print(val_loss)

if best_val_loss - val_loss >1e-4:
    best_val_loss = val_loss
    # model_ls.save(model, None, epoch_idx='best')

t0 = time.time()
if i%3 == 0:
    """Post-Process and Evaluate"""
    log_train_val_progress.post_process_log_test(loss_cmp=test_out.loss,
                                                loss_scalar=val_loss,
                                                x=test_out.x, y_out=test_out.y_out,
                                                y_tar=test_out.y_tar,
                                                weight=test_out.weight,
                                                em_tar=ds_test.emitter(),
                                                px_border=-0.5, px_size=1.,
                                                post_processor=post_processor,
                                                matcher=matcher, logger=logger,
                                                step=i)
else:
    log_train_val_progress.log_kpi_simplified(loss_scalar=val_loss,
                                            loss_cmp=test_out.loss,
                                            logger=logger,
                                            step=i)

t_log = time.time() - t0
print(f'log time:{t_log}')

if i >= 1:
    if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        lr_scheduler.step(val_loss)
    else:
        lr_scheduler.step()

# print("Training finished after reaching maximum number of epochs.")