In [1]:
import argparse
import copy
import datetime
import os
import shutil
import socket
import sys
from pathlib import Path

import torch

import decode.evaluation
import decode.neuralfitter
import decode.neuralfitter.coord_transform
import decode.neuralfitter.utils
import decode.simulation
import decode.utils
from decode.neuralfitter.train.random_simulation import setup_random_simulation
from decode.neuralfitter.utils import log_train_val_progress
from decode.utils.checkpoint import CheckPoint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def parse_args():
    parser = argparse.ArgumentParser(description='Training Args')

    parser.add_argument('-i', '--device', default=None,
                        help='Specify the device string (cpu, cuda, cuda:0) and overwrite param.',
                        type=str, required=False)

    parser.add_argument('-p', '--param_file',
                        help='Specify your parameter file (.yml or .json).',
                        required=True)

    parser.add_argument('-d', '--debug', default=False, action='store_true',
                        help='Debug the specified parameter file. Will reduce ds size for example.')

    parser.add_argument('-w', '--num_worker_override',
                        help='Override the number of workers for the dataloaders.',
                        type=int)

    parser.add_argument('-n', '--no_log', default=False, action='store_true',
                        help='Set no log if you do not want to log the current run.')

    parser.add_argument('-l', '--log_folder', default='runs',
                        help='Specify the (parent) folder you want to log to. If rel-path, relative to DECODE root.')

    parser.add_argument('-c', '--log_comment', default=None,
                        help='Add a log_comment to the run.')

    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()
    return args

In [3]:
param_file = '/home/lingjia/Documents/rPSF/NN/param_run.yaml'
device_overwrite = 'cuda'
debug = False
num_worker_override = None
no_log = False
log_folder = '/home/lingjia/Documents/rPSF/log'
log_comment = None

In [4]:
"""Load Parameters and back them up to the network output directory"""
param_file = Path(param_file)
param = decode.utils.param_io.ParamHandling().load_params(param_file)

# auto-set some parameters (will be stored in the backup copy)
param = decode.utils.param_io.autoset_scaling(param)

# add meta information
param.Meta.version = decode.utils.bookkeeping.decode_state()

"""Experiment ID"""
if not debug:
    if param.InOut.checkpoint_init is None:
        experiment_id = datetime.datetime.now().strftime(
            "%Y-%m-%d_%H-%M-%S") + '_' + socket.gethostname()
        from_ckpt = False
        if log_comment:
            experiment_id = experiment_id + '_' + log_comment
    else:
        from_ckpt = True
        experiment_id = Path(param.InOut.checkpoint_init).parent.name
else:
    experiment_id = 'debug'
    from_ckpt = False

"""Set up unique folder for experiment"""
if not from_ckpt:
    experiment_path = Path(param.InOut.experiment_out) / Path(experiment_id)
else:
    experiment_path = Path(param.InOut.checkpoint_init).parent

if not experiment_path.parent.exists():
    experiment_path.parent.mkdir()

if not from_ckpt:
    if debug:
        experiment_path.mkdir(exist_ok=True)
    else:
        experiment_path.mkdir(exist_ok=False)

model_out = experiment_path / Path('model.pt')
ckpt_path = experiment_path / Path('ckpt.pt')

# Backup the parameter file under the network output path with the experiments ID
param_backup_in = experiment_path / Path('param_run_in').with_suffix(param_file.suffix)
shutil.copy(param_file, param_backup_in)

param_backup = experiment_path / Path('param_run').with_suffix(param_file.suffix)
decode.utils.param_io.ParamHandling().write_params(param_backup, param)

if debug:
    decode.utils.param_io.ParamHandling.convert_param_debug(param)

if num_worker_override is not None:
    param.Hardware.num_worker_train = num_worker_override

"""Hardware / Server stuff."""
if device_overwrite is not None:
    device = device_overwrite
    param.Hardware.device_simulation = device_overwrite  # lazy assumption
else:
    device = param.Hardware.device

if torch.cuda.is_available():
    _, device_ix = decode.utils.hardware._specific_device_by_str(device)
    if device_ix is not None:
        # do this instead of set env variable, because torch is inevitably already imported
        torch.cuda.set_device(device)
elif not torch.cuda.is_available():
    device = 'cpu'

if param.Hardware.torch_multiprocessing_sharing_strategy is not None:
    torch.multiprocessing.set_sharing_strategy(
        param.Hardware.torch_multiprocessing_sharing_strategy)

if sys.platform in ('linux', 'darwin'):
    os.nice(param.Hardware.unix_niceness)
elif param.Hardware.unix_niceness is not None:
    print(f"Cannot set niceness on platform {sys.platform}. You probably do not need to worry.")

torch.set_num_threads(param.Hardware.torch_threads)

"""Setup Log System"""
if no_log:
    logger = decode.neuralfitter.utils.logger.NoLog()

else:
    log_folder = log_folder + '/' + experiment_id

    logger = decode.neuralfitter.utils.logger.MultiLogger(
        [decode.neuralfitter.utils.logger.SummaryWriter(log_dir=log_folder,
                                                        filter_keys=["dx_red_mu", "dx_red_sig",
                                                                        "dy_red_mu",
                                                                        "dy_red_sig", "dz_red_mu",
                                                                        "dz_red_sig",
                                                                        "dphot_red_mu",
                                                                        "dphot_red_sig"]),
            decode.neuralfitter.utils.logger.DictLogger()])

In [5]:

def setup_trainer(simulator_train, simulator_test, logger, model_out, ckpt_path, device, param):
    """Set model, optimiser, loss and schedulers"""
    models_available = {
        'SigmaMUNet': decode.neuralfitter.models.SigmaMUNet,
        'DoubleMUnet': decode.neuralfitter.models.model_param.DoubleMUnet,
        'SimpleSMLMNet': decode.neuralfitter.models.model_param.SimpleSMLMNet,
    }

    model = models_available[param.HyperParameter.architecture]
    model = model.parse(param)

    model_ls = decode.utils.model_io.LoadSaveModel(model,
                                                   output_file=model_out)

    model = model_ls.load_init()
    model = model.to(torch.device(device))

    # Small collection of optimisers
    optimizer_available = {
        'Adam': torch.optim.Adam,
        'AdamW': torch.optim.AdamW
    }

    optimizer = optimizer_available[param.HyperParameter.optimizer]
    optimizer = optimizer(model.parameters(), **param.HyperParameter.opt_param)

    """Loss function."""
    criterion = decode.neuralfitter.loss.GaussianMMLoss(
        xextent=param.Simulation.psf_extent[0],
        yextent=param.Simulation.psf_extent[1],
        img_shape=param.Simulation.img_size,
        device=device,
        chweight_stat=param.HyperParameter.chweight_stat)

    """Learning Rate and Simulation Scheduling"""
    lr_scheduler_available = {
        'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'StepLR': torch.optim.lr_scheduler.StepLR
    }
    lr_scheduler = lr_scheduler_available[param.HyperParameter.learning_rate_scheduler]
    lr_scheduler = lr_scheduler(optimizer, **param.HyperParameter.learning_rate_scheduler_param)

    """Checkpointing"""
    checkpoint = CheckPoint(path=ckpt_path)

    """Setup gradient modification"""
    grad_mod = param.HyperParameter.grad_mod

    """Log the model"""
    try:
        dummy = torch.rand((2, param.HyperParameter.channels_in,
                            *param.Simulation.img_size), requires_grad=False).to(
            torch.device(device))
        logger.add_graph(model, dummy)

    except:
        print("Did not log graph.")
        # raise RuntimeError("Your dummy input is wrong. Please update it.")

    """Transform input data, compute weight mask and target data"""
    frame_proc = decode.neuralfitter.scale_transform.AmplitudeRescale.parse(param)
    bg_frame_proc = None

    if param.HyperParameter.emitter_label_photon_min is not None:
        em_filter = decode.neuralfitter.em_filter.PhotonFilter(
            param.HyperParameter.emitter_label_photon_min)
    else:
        em_filter = decode.neuralfitter.em_filter.NoEmitterFilter()

    tar_frame_ix_train = (0, 0)
    tar_frame_ix_test = (0, param.TestSet.test_size)

    """Setup Target generator consisting possibly multiple steps in a transformation sequence."""
    tar_gen = decode.neuralfitter.utils.processing.TransformSequence(
        [
            decode.neuralfitter.target_generator.ParameterListTarget(
                n_max=param.HyperParameter.max_number_targets,
                xextent=param.Simulation.psf_extent[0],
                yextent=param.Simulation.psf_extent[1],
                ix_low=tar_frame_ix_train[0],
                ix_high=tar_frame_ix_train[1],
                squeeze_batch_dim=True),

            decode.neuralfitter.target_generator.DisableAttributes.parse(param),

            decode.neuralfitter.scale_transform.ParameterListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max)
        ])

    # setup target for test set in similar fashion, however test-set is static.
    tar_gen_test = copy.deepcopy(tar_gen)
    tar_gen_test.com[0].ix_low = tar_frame_ix_test[0]
    tar_gen_test.com[0].ix_high = tar_frame_ix_test[1]
    tar_gen_test.com[0].squeeze_batch_dim = False
    tar_gen_test.com[0].sanity_check()

    if param.Simulation.mode == 'acquisition':
        train_ds = decode.neuralfitter.dataset.SMLMLiveDataset(
            simulator=simulator_train,
            em_proc=em_filter,
            frame_proc=frame_proc,
            bg_frame_proc=bg_frame_proc,
            tar_gen=tar_gen, weight_gen=None,
            frame_window=param.HyperParameter.channels_in,
            pad=None, return_em=False)

        train_ds.sample(True)

    elif param.Simulation.mode == 'samples':
        train_ds = decode.neuralfitter.dataset.SMLMLiveSampleDataset(
            simulator=simulator_train,
            em_proc=em_filter,
            frame_proc=frame_proc,
            bg_frame_proc=bg_frame_proc,
            tar_gen=tar_gen,
            weight_gen=None,
            frame_window=param.HyperParameter.channels_in,
            return_em=False,
            ds_len=param.HyperParameter.pseudo_ds_size)

    test_ds = decode.neuralfitter.dataset.SMLMAPrioriDataset(
        simulator=simulator_test,
        em_proc=em_filter,
        frame_proc=frame_proc,
        bg_frame_proc=bg_frame_proc,
        tar_gen=tar_gen_test, weight_gen=None,
        frame_window=param.HyperParameter.channels_in,
        pad=None, return_em=False)

    test_ds.sample(True)

    """Set up post processor"""
    if param.PostProcessing is None:
        post_processor = decode.neuralfitter.post_processing.NoPostProcessing(xy_unit='px',
                                                                              px_size=param.Camera.px_size)

    elif param.PostProcessing == 'LookUp':
        post_processor = decode.neuralfitter.utils.processing.TransformSequence([

            decode.neuralfitter.scale_transform.InverseParamListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max),

            decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

            decode.neuralfitter.post_processing.LookUpPostProcessing(
                raw_th=param.PostProcessingParam.raw_th,
                pphotxyzbg_mapping=[0, 1, 2, 3, 4, -1],
                xy_unit='px',
                px_size=param.Camera.px_size)
        ])

    elif param.PostProcessing in ('SpatialIntegration', 'NMS'):  # NMS as legacy support
        post_processor = decode.neuralfitter.utils.processing.TransformSequence([

            decode.neuralfitter.scale_transform.InverseParamListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max),

            decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

            decode.neuralfitter.post_processing.SpatialIntegration(
                raw_th=param.PostProcessingParam.raw_th,
                xy_unit='px',
                px_size=param.Camera.px_size)
        ])

    else:
        raise NotImplementedError

    """Evaluation Specification"""
    matcher = decode.evaluation.match_emittersets.GreedyHungarianMatching.parse(param)

    return train_ds, test_ds, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, checkpoint


def setup_dataloader(param, train_ds, test_ds=None):
    """Set's up dataloader"""

    train_dl = torch.utils.data.DataLoader(
        dataset=train_ds,
        batch_size=param.HyperParameter.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=param.Hardware.num_worker_train,
        pin_memory=True,
        collate_fn=decode.neuralfitter.utils.dataloader_customs.smlm_collate)

    if test_ds is not None:

        test_dl = torch.utils.data.DataLoader(
            dataset=test_ds,
            batch_size=param.HyperParameter.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=param.Hardware.num_worker_train,
            pin_memory=False,
            collate_fn=decode.neuralfitter.utils.dataloader_customs.smlm_collate)
    else:

        test_dl = None

    return train_dl, test_dl



In [6]:
sim_train, sim_test = setup_random_simulation(param)
ds_train, ds_test, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, ckpt = \
    setup_trainer(sim_train, sim_test, logger, model_out, ckpt_path, device, param)
dl_train, dl_test = setup_dataloader(param, ds_train, ds_test)

Model instantiated.
Model initialised as specified in the constructor.
Sampled dataset in 0.33s. 99786 emitters on 10001 frames.
Sampled dataset in 0.02s. 5035 emitters on 513 frames.


In [7]:
print(torch.__version__)
print(torch.cuda.is_available())

1.7.1
True


In [8]:
converges = False
n = 0
n_max = param.HyperParameter.auto_restart_param.num_restarts

In [9]:
i = 2

In [10]:
_ = decode.neuralfitter.train_val_impl.train(
    model=model,
    optimizer=optimizer,
    loss=criterion,
    dataloader=dl_train,
    grad_rescale=param.HyperParameter.moeller_gradient_rescale,
    grad_mod=grad_mod,
    epoch=i,
    device=torch.device(device),
    logger=logger
)

E: 2 - t: 0.17 - t_dat: 0.00082 - L: 95.5: 100%|██████████| 156/156 [00:27<00:00,  5.68it/s]    


In [11]:
from tqdm import tqdm
dataloader = dl_test
tqdm_enum = tqdm(dataloader, total=len(dataloader), smoothing=0.)  # progress bar enumeration
ttt = iter(tqdm_enum)
(x, y_tar, weight) = next(ttt)
print(x.size())
print(y_tar[0].size())
print(y_tar[1].size())
print(y_tar[2].size())
# print(y_tar[0])

  0%|          | 0/8 [00:00<?, ?it/s]

torch.Size([64, 3, 40, 40])
torch.Size([64, 250, 4])
torch.Size([64, 250])
torch.Size([64, 40, 40])


In [12]:
val_loss, test_out = decode.neuralfitter.train_val_impl.test(
    model=model,
    loss=criterion,
    dataloader=dl_test,
    epoch=i,
    device=torch.device(device))

(Test) E: 2 - T: 0.59: 100%|██████████| 8/8 [00:00<00:00, 11.91it/s]


In [13]:
loss_cmp=test_out.loss
loss_scalar=val_loss
x=test_out.x
y_out=test_out.y_out
y_tar=test_out.y_tar
weight=test_out.weight
em_tar=ds_test.emitter
px_border=-0.5
px_size=1.
post_processor=post_processor
matcher=matcher
logger=logger
step=i

In [14]:
print(y_out.shape)
y_out[:,0,:,:] = 1

torch.Size([511, 10, 40, 40])


In [15]:
print(ds_test.emitter.frame_ix)
print(len(ds_test.emitter.frame_ix))
print(ds_test.emitter.frame_ix.max())
print(len(ds_test))

tensor([186, 510, 418,  ..., 431, 407, 424])
4823
tensor(510)
511


In [20]:
"""Post-Process"""
em_out, xyz, frame_ixx = post_processor.forward(y_out)
# print(em_out.frame_ix)
# print(em_out.xyz)
# len(em_out.get_subset_frame(3,3).xyz)
print(xyz.shape)
print(frame_ixx.shape)

torch.Size([817600])
torch.Size([817600, 3])
torch.Size([817600, 3])
torch.Size([817600])


In [18]:
print(xyz[:10,:])
print(frame_ixx[:10])

tensor([[-4.6997e-01, -9.7671e-01,  6.7097e+02],
        [ 9.4692e-01,  8.0565e-03,  9.2543e+02],
        [ 9.4171e-01,  1.0031e+00,  8.9564e+02],
        [ 9.7965e-01,  2.0030e+00,  9.1074e+02],
        [ 9.9014e-01,  3.0060e+00,  9.1005e+02],
        [ 9.8206e-01,  4.0024e+00,  8.5168e+02],
        [ 9.8583e-01,  5.0025e+00,  8.4808e+02],
        [ 9.8512e-01,  6.0028e+00,  8.5191e+02],
        [ 9.8585e-01,  7.0026e+00,  8.4587e+02],
        [ 9.8400e-01,  8.0027e+00,  8.4337e+02]])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [37]:
"""Match and Evaluate"""
tp, fp, fn, tp_match = matcher.forward(em_out, em_tar)


In [1]:
0.000==0

True

In [43]:
# print(tp)
# print(fp)
# print(fn)
# print(tp_match)
print(ds_train.emitters)

AttributeError: 'SMLMLiveDataset' object has no attribute 'emitters'

In [None]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = evaluation.SMLMEvaluation(weighted_eval=WeightedErrors(mode='crlb', reduction='gaussian')).forward(tp, fp, fn, tp_match)

"""Log"""
# raw frames
log_frames(x=x, y_out=y_out, y_tar=y_tar, weight=weight, em_out=em_out, em_tar=em_tar, tp=tp, tp_match=tp_match,
            logger=logger, step=step)

# KPIs
log_kpi(loss_scalar=loss_scalar, loss_cmp=loss_cmp, eval_set=result._asdict(), logger=logger, step=step)

# distributions
log_dists(tp=tp, tp_match=tp_match, pred=em_out, px_border=px_border, px_size=px_size, logger=logger, step=step)

In [27]:
output = em_out
target = ds_test.emitter
print(len(output))
print(len(target))


817600
4938


In [36]:
if len(output) >= 1 and len(target) >= 1:
    frame_low = output.frame_ix.min() if output.frame_ix.min() < target.frame_ix.min() else target.frame_ix.min()
    frame_high = output.frame_ix.max() if output.frame_ix.max() > target.frame_ix.max() else target.frame_ix.max()
elif len(output) >= 1:
    frame_low = output.frame_ix.min()
    frame_high = output.frame_ix.max()
elif len(target) >= 1:
    frame_low = target.frame_ix.min()
    frame_high = target.frame_ix.max()

In [37]:
print(frame_low)
print(frame_high)

tensor(0)
tensor(510)
