In [1]:
import argparse, datetime, numpy, os, shutil, sys, time
from cmath import inf
import socket
from pathlib import Path
import torch
# self-defined modules
import decode.evaluation
import decode.neuralfitter
import decode.neuralfitter.coord_transform
import decode.neuralfitter.utils
import decode.simulation
import decode.utils
# from decode.neuralfitter.train.random_simulation import setup_random_simulation
from decode.neuralfitter.utils import log_train_val_progress
from decode.utils.checkpoint import CheckPoint


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def parse_args():
    """Build the training command line and parse it.

    Unknown arguments are tolerated (``parse_known_args``), so the function can
    also be called from inside a Jupyter kernel whose own argv would otherwise
    be rejected.

    Returns:
        argparse.Namespace holding the recognised options.
    """
    cli = argparse.ArgumentParser(description='Training Args')

    # (option strings, add_argument keyword options), in registration order.
    option_specs = (
        (('-i', '--device'),
         dict(default=None, type=str,
              help='Specify the device string (cpu, cuda, cuda:0) and overwrite param.')),
        (('-p', '--param_file'),
         dict(default=None, type=str,
              help='Specify your parameter file (.yml or .json).')),
        (('-w', '--num_worker_override'),
         dict(default=None, type=int,
              help='Override the number of workers for the dataloaders.')),
        (('-n', '--no_log'),
         dict(default=False, action='store_true',
              help='Set no log if you do not want to log the current run.')),
        (('-c', '--log_comment'),
         dict(default=None,
              help='Add a log_comment to the run.')),
        (('-d', '--data_path_override'),
         dict(default=None, type=str,
              help='Specify your path to data')),
        (('-is', '--img_size_override'),
         dict(default=None, type=int,
              help='Override img size')),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    known, _unknown = cli.parse_known_args()
    return known

In [8]:
# Notebook defaults for this machine. They only fill in options that were not
# given on the command line, so explicit CLI values always take precedence.
DEFAULT_DEVICE = 'cuda'
DEFAULT_PARAM_FILE = '/home/lingjia/Documents/rpsf/NN/param.yaml'
DEFAULT_DATA_PATH = '/media/hdd_4T/lingjia/rPSF/20220716_decode_variant_v2/data_train/30k_pt50L5'
DEFAULT_IMG_SIZE = 96

args = parse_args()
if args.device is None:
    args.device = DEFAULT_DEVICE
if args.param_file is None:
    args.param_file = DEFAULT_PARAM_FILE
if args.data_path_override is None:
    args.data_path_override = DEFAULT_DATA_PATH
if args.img_size_override is None:
    args.img_size_override = DEFAULT_IMG_SIZE

In [9]:
def _setup_post_processor(param):
    """Build the post-processor selected by ``param.PostProcessing``.

    Args:
        param: parameter namespace.

    Returns:
        A ``NoPostProcessing`` instance when ``param.PostProcessing`` is None,
        otherwise a ``TransformSequence`` (rescale -> offset-to-coordinate ->
        the selected post-processing step).

    Raises:
        NotImplementedError: ``param.PostProcessing`` is not None, 'LookUp',
            'SpatialIntegration' or 'NMS'.
    """
    mode = param.PostProcessing

    if mode is None:
        return decode.neuralfitter.post_processing.NoPostProcessing(
            xy_unit='px', px_size=param.Camera.px_size)

    if mode == 'LookUp':
        return decode.neuralfitter.utils.processing.TransformSequence([
            # out_tar --> out_tar: phot*phot_max, z*z_max, bg*bg_max
            decode.neuralfitter.scale_transform.InverseParamListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max),
            # offset --> coordinate
            decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),
            decode.neuralfitter.post_processing.LookUpPostProcessing(
                raw_th=param.PostProcessingParam.raw_th,
                pphotxyzbg_mapping=[0, 1, 2, 3, 4, -1],
                xy_unit='px',
                px_size=param.Camera.px_size),
        ])

    if mode in ('SpatialIntegration', 'NMS'):  # 'NMS' kept as legacy alias
        return decode.neuralfitter.utils.processing.TransformSequence([
            # out_tar --> out_tar: phot*phot_max, z*z_max, bg*bg_max
            decode.neuralfitter.scale_transform.InverseParamListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max),
            # offset --> coordinate, e.g. 0.2 --> 10.2
            decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),
            decode.neuralfitter.post_processing.SpatialIntegration(
                raw_th=param.PostProcessingParam.raw_th,
                xy_unit='px'),
        ])

    raise NotImplementedError(f"Unsupported post-processing mode: {mode!r}")


def setup_trainer(logger, model_out, ckpt_path, device, param, train_ids=None, val_ids=None):
    """Set up model, optimiser, loss, schedulers, data sets and post-processing.

    Args:
        logger: logger the model graph is written to via ``add_graph``.
        model_out: output file handed to ``LoadSaveModel``.
        ckpt_path: path of the training checkpoint.
        device: torch device string (e.g. 'cpu', 'cuda', 'cuda:0').
        param: parameter namespace as loaded by ``ParamHandling``.
        train_ids: IDs passed as ``list_IDs`` to the training data set;
            defaults to 1..100.
        val_ids: IDs passed as ``list_IDs`` to the validation data set;
            defaults to 101..110.

    Returns:
        tuple (train_ds, test_ds, model, model_ls, optimizer, criterion,
        lr_scheduler, grad_mod, post_processor, matcher, checkpoint).

    Raises:
        KeyError: unknown architecture, optimiser or learning-rate scheduler name.
        NotImplementedError: unknown ``param.PostProcessing`` mode.
    """
    # Model: look up the architecture, build it from param, initialise and move it.
    models_available = {
        'SigmaMUNet': decode.neuralfitter.models.SigmaMUNet_variant,
        'DoubleMUnet': decode.neuralfitter.models.model_param.DoubleMUnet,
        'SimpleSMLMNet': decode.neuralfitter.models.model_param.SimpleSMLMNet,
    }
    model = models_available[param.HyperParameter.architecture].parse(param)
    model_ls = decode.utils.model_io.LoadSaveModel(model, output_file=model_out)
    model = model_ls.load_init().to(torch.device(device))

    # Small collection of optimisers
    optimizer_available = {
        'Adam': torch.optim.Adam,
        'AdamW': torch.optim.AdamW,
        'SGD': torch.optim.SGD,
    }
    optimizer = optimizer_available[param.HyperParameter.optimizer](
        model.parameters(), **param.HyperParameter.opt_param)

    # Loss function
    criterion = decode.neuralfitter.loss.GaussianMMLoss(
        xextent=param.Simulation.psf_extent[0],
        yextent=param.Simulation.psf_extent[1],
        img_shape=param.Simulation.img_size,
        device=device,
        chweight_stat=param.HyperParameter.chweight_stat)

    # Learning-rate scheduling
    lr_scheduler_available = {
        'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'StepLR': torch.optim.lr_scheduler.StepLR,
    }
    lr_scheduler = lr_scheduler_available[param.HyperParameter.learning_rate_scheduler](
        optimizer, **param.HyperParameter.learning_rate_scheduler_param)

    checkpoint = CheckPoint(path=ckpt_path)

    # Gradient modification setting, passed through to the caller unchanged.
    grad_mod = param.HyperParameter.grad_mod

    # Log the model graph. This is best effort only: a failure must not stop training,
    # but (unlike a bare `except:`) Ctrl-C / SystemExit still propagate and the
    # reason is reported.
    try:
        dummy = torch.rand((2, param.HyperParameter.channels_in, *param.Simulation.img_size),
                           requires_grad=False).to(torch.device(device))
        logger.add_graph(model, dummy)
    except Exception as err:
        print(f"Did not log graph: {err}")

    # Target generator; a transformation sequence so more steps can be chained.
    tar_proc = decode.neuralfitter.utils.processing.TransformSequence(
        [
            # param_tar --> phot/max, z/z_max, bg/bg_max
            decode.neuralfitter.scale_transform.ParameterListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max)
        ])

    # Split train & val set by sample ID.
    if train_ids is None:
        train_ids = list(range(1, 101))
    if val_ids is None:
        val_ids = list(range(101, 111))

    def make_dataset(ids):
        # Train and validation sets differ only in their IDs.
        return decode.neuralfitter.dataset.rPSFDataset(
            root_dir=param.InOut.data_path,
            list_IDs=ids,
            label_path=None,
            n_max=param.HyperParameter.max_number_targets,
            tar_proc=tar_proc,
            img_shape=param.Simulation.img_size)

    train_ds = make_dataset(train_ids)
    test_ds = make_dataset(val_ids)

    post_processor = _setup_post_processor(param)

    # Evaluation specification
    matcher = decode.evaluation.match_emittersets.GreedyHungarianMatching.parse(param)

    return (train_ds, test_ds, model, model_ls, optimizer, criterion, lr_scheduler,
            grad_mod, post_processor, matcher, checkpoint)


def setup_dataloader(param, train_ds, test_ds=None):
    """Wrap the data sets in torch DataLoaders.

    Both loaders share the batch size and worker count from ``param`` and use
    ``smlm_collate`` as collate function.

    Args:
        param: parameter namespace (reads HyperParameter.batch_size and
            Hardware.num_worker_train).
        train_ds: training data set -- shuffled, incomplete last batch dropped,
            pinned memory.
        test_ds: optional validation data set -- kept in order, all samples
            used, no pinned memory.

    Returns:
        (train_dl, test_dl); test_dl is None when no test_ds is given.
    """
    shared = dict(
        batch_size=param.HyperParameter.batch_size,
        num_workers=param.Hardware.num_worker_train,
        collate_fn=decode.neuralfitter.utils.dataloader_customs.smlm_collate,
    )

    train_dl = torch.utils.data.DataLoader(
        dataset=train_ds, drop_last=True, shuffle=True, pin_memory=True, **shared)

    test_dl = None if test_ds is None else torch.utils.data.DataLoader(
        dataset=test_ds, drop_last=False, shuffle=False, pin_memory=False, **shared)

    return train_dl, test_dl


In [10]:
# BEFORE TRAINING THE MODEL / TUNING THE LOSS: load the parameters and create
# (or reuse) the experiment folder.
param_file = Path(args.param_file)
param = decode.utils.param_io.ParamHandling().load_params(param_file)

# add meta information (decode state / version)
param.Meta.version = decode.utils.bookkeeping.decode_state()

# Experiment ID and folder: a fresh run gets a time-stamped folder (plus the
# optional log comment) under param.InOut.experiment_out; a resumed run reuses
# the folder of its checkpoint.
from_ckpt = param.InOut.checkpoint_init is not None
if from_ckpt:
    experiment_path = Path(param.InOut.checkpoint_init).parent
    experiment_id = experiment_path.name
else:
    experiment_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    if args.log_comment:
        experiment_id = experiment_id + '_' + args.log_comment
    experiment_path = Path(param.InOut.experiment_out) / Path(experiment_id)

# parents=True: also create missing intermediate folders of the output root
# (a plain mkdir() fails with FileNotFoundError in that case).
experiment_path.parent.mkdir(parents=True, exist_ok=True)

if not from_ckpt:
    # exist_ok=False: never write into an existing run folder
    experiment_path.mkdir(exist_ok=False)

model_out = experiment_path / Path('model.pt')
ckpt_path = experiment_path / Path('ckpt.pt')

# Command-line overrides of the parameter file.
if args.num_worker_override is not None:
    param.Hardware.num_worker_train = args.num_worker_override

# Hardware / server: a device given on the command line wins over the parameter file.
if args.device is not None:
    device = args.device
else:
    device = param.Hardware.device

if args.data_path_override is not None:
    param.InOut.data_path = args.data_path_override

if args.img_size_override is not None:
    param.Simulation.img_size = [args.img_size_override, args.img_size_override]
    # Extent in px: `size` pixels spanning [-0.5, size - 0.5]; third entry left as None.
    param.Simulation.psf_extent = [[-0.5, args.img_size_override - 0.5],
                                   [-0.5, args.img_size_override - 0.5], None]

    param.TestSet.frame_extent = param.Simulation.psf_extent
    param.TestSet.img_size = param.Simulation.img_size

# Backup the parameter file under the network output path with the experiment ID:
# the file as given, and the (possibly overridden) parameters actually used.
param_backup_in = experiment_path / Path('param_run_in').with_suffix(param_file.suffix)
shutil.copy(param_file, param_backup_in)

param_backup = experiment_path / Path('param_run').with_suffix(param_file.suffix)
decode.utils.param_io.ParamHandling().write_params(param_backup, param)

# Process niceness: only touch it when a value is configured. os.nice(None) would
# raise TypeError on linux/darwin.
niceness = param.Hardware.unix_niceness
if niceness is not None:
    if sys.platform in ('linux', 'darwin'):
        os.nice(niceness)
    else:
        print(f"Cannot set niceness on platform {sys.platform}. You probably do not need to worry.")

torch.set_num_threads(param.Hardware.torch_threads)

# Logging: a no-op logger when logging is disabled. Otherwise a TensorBoard
# SummaryWriter (configured with the filter_keys below) combined with an
# in-memory DictLogger.
if not args.no_log:
    log_folder = experiment_path

    logger = decode.neuralfitter.utils.logger.MultiLogger([
        decode.neuralfitter.utils.logger.SummaryWriter(
            log_dir=log_folder,
            filter_keys=["dx_red_mu", "dx_red_sig",
                         "dy_red_mu", "dy_red_sig",
                         "dz_red_mu", "dz_red_sig",
                         "dphot_red_mu", "dphot_red_sig"]),
        decode.neuralfitter.utils.logger.DictLogger(),
    ])
else:
    logger = decode.neuralfitter.utils.logger.NoLog()

# Data sets, model, optimiser, loss, scheduler, post-processing and matcher.
(ds_train, ds_test, model, model_ls, optimizer, criterion, lr_scheduler,
 grad_mod, post_processor, matcher, ckpt) = setup_trainer(logger, model_out, ckpt_path, device, param)

dl_train, dl_test = setup_dataloader(param, ds_train, ds_test)

# Resume: restore model / optimiser / scheduler state and continue after the stored step.
if not from_ckpt:
    first_epoch = 0
else:
    ckpt = decode.utils.checkpoint.CheckPoint.load(param.InOut.checkpoint_init)
    model.load_state_dict(ckpt.model_state)
    optimizer.load_state_dict(ckpt.optimizer_state)
    lr_scheduler.load_state_dict(ckpt.lr_sched_state)
    first_epoch = ckpt.step + 1
    model = model.train()
    print(f'Resuming training from checkpoint {experiment_id}')

Model instantiated.
Model initialised as specified in the constructor.


In [11]:
# SIMULATED INPUT X AND GROUND-TRUTH TARGET: a tiny hand-made sample used to
# walk through the loss computation step by step.
torch.manual_seed(0)  # reproducible toy data, so the losses printed below are stable

N_TARGET_SLOTS = 60   # padded number of ground-truth rows per frame
N_ACTIVE = 2          # rows holding an actual emitter
N_TAR_PARAMS = 4      # parameters per emitter (phot, x, y, z)
TOY_H, TOY_W = 4, 4   # spatial size of the toy frame
TOY_BG = 5.           # constant ground-truth background

# INPUT: N x C x H x W = 1 x 1 x 4 x 4, on the configured device
x = torch.randn(1, 1, TOY_H, TOY_W).to(device)

# TARGET: parameters N x M x 4, activation mask N x M, background N x H x W
param_tar = torch.zeros(1, N_TARGET_SLOTS, N_TAR_PARAMS)
param_tar_v = torch.randn(N_ACTIVE, N_TAR_PARAMS)
param_tar[0, :N_ACTIVE, :] = param_tar_v

mask_tar = torch.zeros(1, N_TARGET_SLOTS)
mask_tar[0, :N_ACTIVE] = 1

bg = torch.ones((1, TOY_H, TOY_W)) * TOY_BG

target = (param_tar.to(device), mask_tar.to(device), bg.to(device))
# WEIGHT: the loss implementation only supports no weighting
weight = None

In [19]:
# OUTPUT OF SIMULATED INPUT X AND BASIC PROCESS AND BG LOSS
# Forward the toy frame through the network and report the raw output shape.
output = model(x)
print('y_out: {}'.format(output.shape))

# SHAPE CHECK OF OUTPUT AND TARGET
def _forward_checks(output: torch.Tensor, target: tuple, weight: None):
    if weight is not None:
        raise NotImplementedError(f"Weight must be None for this loss implementation.")

    if output.dim() != 4:
        raise ValueError(f"Output must have 4 dimensions (N,C,H,W).")

    if output.size(1) != 19:
        raise ValueError(f"Wrong number of channels.")

    if len(target) != 3:
        raise ValueError(f"Wrong length of target.")

# Fail early if the toy output / target do not have the layout the loss below expects.
_forward_checks(output=output, target=target, weight=weight)

# FORMAT MODEL OUTPUT AND GROUND-TRUTH
def _format_model_output(output: torch.Tensor) -> tuple:
    """
    Transforms solely channel based model output into more meaningful variables.
    Args:
        output: model output
    Returns:
        tuple containing
            p: N x 3 x H x W
            pxyz_mu: N x 12 x H x W = 3phot, 3x, 3y, 3z
            pxyz_sig: N x 12 x H x W = 3phot, 3x, 3y, 3z
            bg: N x H x W
    """
    p = output[:, 0:2] # 0,1
    pxyz_mu = output[:, 2:10] # 2,3, 4,5, 6,7, 8,9
    pxyz_sig = output[:, 10:-1] # 10,11, 12,13, 14,15, 16,17
    bg = output[:, -1] # 27
    return p, pxyz_mu, pxyz_sig, bg

# Split the ground truth and the raw network output into their components.
tar_param, tar_mask, tar_bg = target
p, pxyz_mu, pxyz_sig, bg = _format_model_output(output)

# BG LOSS: per-pixel squared error, summed over the two spatial dims -> one value per frame.
_bg_loss = torch.nn.MSELoss(reduction='none')
bg_sq_err = _bg_loss(bg, tar_bg)
bg_loss = bg_sq_err.sum(-1).sum(-1)
print('bg loss: {:.3f}'.format(bg_loss.item()))
# The Gaussian-mixture part of the loss is worked through step by step in the following cells.

y_out: torch.Size([1, 19, 4, 4])
bg loss: 277.627


In [23]:
# GMM LOSS - PROBABILITY (COUNT) TERM
# LOAD MODULES
# NOTE(review): only `torch` / `distributions` are used in these cells; the other
# imports look like leftovers from the loss module -- move to the top import cell.
from abc import ABC, abstractmethod  # abstract class
from typing import Union, Tuple
import torch
from deprecated import deprecated
from torch import distributions
import decode.generic.utils

# Step-by-step version of `_compute_gmm_loss(p, pxyz_mu, pxyz_sig, tar_param, tar_mask)`.
# p, pxyz_mu and pxyz_sig come straight from the previous cell; only the ground truth
# is aliased to the names used by the loss implementation.
pxyz_tar, mask = tar_param, tar_mask

# Shapes in this notebook:
#   p:        N x 2 x H x W  detection prediction of the two components
#   pxyz_mu:  N x 8 x H x W  means (phot, x, y, z) of both components
#   pxyz_sig: N x 8 x H x W  sigmas, same layout as pxyz_mu
#   pxyz_tar: N x M x 4      ground truth (phot, x, y, z), M = max. number of targets
#   mask:     N x M          1 for active ground-truth rows, 0 for padding
# The resulting log-probability / loss has one value per frame (size N).
batch_size, nc, hh, ww = pxyz_mu.size()
print(f'pxyz_mu shape: [{batch_size} {nc} {hh} {ww}]')
log_prob = 0

# Number of detections modelled as a Gaussian whose mean / variance are the summed
# Bernoulli means p and variances p - p^2 over both components and all pixels.
p_mean = p.sum(-1).sum(-1).sum(-1)
p_var = (p - p ** 2).sum(-1).sum(-1).sum(-1)  # var estimate of bernoulli
print(f'p shape: {p.shape} mean: {p_mean.item():.4f} var: {p_var.item():.4f}')
p_gauss = distributions.Normal(p_mean, torch.sqrt(p_var))

# Log-likelihood of the true emitter count, weighted by that count.
n_true = mask.sum(-1)
log_prob = log_prob + p_gauss.log_prob(n_true) * n_true
print(f'log_prob version 1: {log_prob.item():.3f}')


pxyz_mu shape: [1 8 4 4]
p shape: torch.Size([1, 2, 4, 4]) mean: 0.1208 var: 0.1197
log_prob version 1: -29.212


In [27]:
# GMM LOSS - LOCALIZATION TERM
# Every (component, pixel) position of p is one mixture component; its weight is p
# normalised over both components and all pixels of the frame.
prob_normed = p / p.sum(-1).sum(-1).sum(-1).view(-1, 1, 1, 1)
print(prob_normed.shape)

# Hacky way to get all prob indices: (p + 1) is non-zero wherever p != -1 (assumes p
# holds probabilities). nonzero() returns them in row-major order, batch first.
# p_inds: [0] batch index, [1] component (channel of p), [2] row (H) index, [3] column (W) index
p_inds = tuple((p + 1).nonzero(as_tuple=False).transpose(1, 0))

# Means are stored as one (component 0, component 1) channel pair per parameter:
# N x 8 x H x W -> N x 4 x 2 x H x W -> N x 2 x 4 x H x W -> one row of params per index.
# New names are used so pxyz_mu / pxyz_sig stay untouched and the cell can be re-run
# (re-binding them reshapes already-flattened tensors into garbage on a re-run).
n_par = nc // 2
mu_comp = pxyz_mu.reshape(batch_size, n_par, 2, hh, ww).transpose(2, 1)
print(mu_comp.shape)
mu_comp = mu_comp[p_inds[0], p_inds[1], :, p_inds[2], p_inds[3]]
print(mu_comp.shape)

# Convert px shifts to absolute coordinates: add the centre of the pixel each component
# belongs to (x along H, y along W). A constant shift cannot do this; it moves every
# component by the same amount. Assumes pixel centres at integer px coordinates, as for
# psf_extent = [-0.5, size - 0.5]. NOTE(review): confirm against Offset2Coordinate.
mu_comp[:, 1] += p_inds[2].to(mu_comp.dtype)
mu_comp[:, 2] += p_inds[3].to(mu_comp.dtype)

# Flatten img dimension --> N x (2*H*W) x n_par (row order matches p_inds)
mu_comp = mu_comp.reshape(batch_size, -1, n_par)
print(mu_comp.shape)
sig_comp = pxyz_sig.reshape(batch_size, n_par, 2, hh, ww).transpose(2, 1)
sig_comp = sig_comp[p_inds[0], p_inds[1], :, p_inds[2], p_inds[3]].reshape(batch_size, -1, n_par)

# Set up mixture family: batch shape (N,), each component an n_par-dimensional Gaussian
# with independent dimensions.
mix = distributions.Categorical(prob_normed[p_inds].reshape(batch_size, -1))
comp = distributions.Independent(distributions.Normal(mu_comp, sig_comp), 1)
gmm = distributions.mixture_same_family.MixtureSameFamily(mix, comp)

# Calc log probs if there is anything there. `log_prob` (count term from the previous
# cell) is not modified, so re-running this cell does not add the GMM term twice.
total_log_prob = log_prob
if mask.sum():
    # targets N x M x 4 -> M x N x 4 to broadcast over the mixture's batch shape (N,);
    # the result M x N is transposed back to N x M
    gmm_log = gmm.log_prob(pxyz_tar.transpose(0, 1)).transpose(0, 1)
    # padded target rows are masked out, then summed per frame -> N
    gmm_log = (gmm_log * mask).sum(-1)
    total_log_prob = total_log_prob + gmm_log

loss = total_log_prob * (-1)

torch.Size([1, 2, 4, 4])
torch.Size([1, 2, 4, 4, 4])
torch.Size([32, 4])
torch.Size([1, 32, 4])
