In [1]:
import argparse
import copy
import datetime
import os
import shutil
import socket
import sys
from pathlib import Path
import torch

import decode.evaluation
import decode.neuralfitter
import decode.neuralfitter.coord_transform
import decode.neuralfitter.utils
import decode.simulation
import decode.utils
from decode.neuralfitter.train.random_simulation import setup_random_simulation
from decode.neuralfitter.utils import log_train_val_progress
from decode.utils.checkpoint import CheckPoint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cpu'

In [1]:
# device_overwrite = 'cuda'
# debug = False
# num_worker_override = None
# no_log = False
# log_folder = '/home/lingjia/Documents/rPSF/log'
# log_comment = None

In [2]:
"""Load Parameters and back them up to the network output directory"""
param_file = '/home/lingjia/Documents/rPSF/NN/param_run_in.yaml'
# param_file = '/home/lingjia/Documents/rPSF/NN/param.yaml'
param_file = Path(param_file)
param = decode.utils.param_io.ParamHandling().load_params(param_file)

In [20]:
# print(type(param.Simulation.img_size))

parser = argparse.ArgumentParser(description='Training Args')

parser.add_argument('-i', '--device', default=None, 
                    help='Specify the device string (cpu, cuda, cuda:0) and overwrite param.',
                    type=str)

parser.add_argument('-p', '--param_file', default=None,
                    help='Specify your parameter file (.yml or .json).', type=str)

# parser.add_argument('-d', '--debug', default=False, action='store_true',
#                     help='Debug the specified parameter file. Will reduce ds size for example.')

parser.add_argument('-w', '--num_worker_override',default=None,
                    help='Override the number of workers for the dataloaders.',
                    type=int)

parser.add_argument('-n', '--no_log', default=False, action='store_true',
                    help='Set no log if you do not want to log the current run.')

# parser.add_argument('-l', '--log_folder', default='runs',
#                     help='Specify the (parent) folder you want to log to. If rel-path, relative to DECODE root.')

parser.add_argument('-c', '--log_comment', default=None,
                    help='Add a log_comment to the run.')

parser.add_argument('-d', '--data_path_override', default=None,
                    help='Specify your path to data', type=str)

parser.add_argument('-is', '--img_size_override', default=None,
                    help='Override img size', type=list)

# args = parser.parse_args()
args, unknown = parser.parse_known_args()

In [25]:
type(param.Simulation.psf_extent[0])
param.Simulation.psf_extent

[[-0.5, 39.5], [-0.5, 39.5], None]

In [26]:
xextent = [-0.5,39.5]
yextent = [-0.5,39.5]
img_size = [40,40]

bin_x = torch.linspace(*xextent, steps=img_size[0] + 1)
bin_y = torch.linspace(*yextent, steps=img_size[1] + 1)
bin_ctr_x = (bin_x + (bin_x[1] - bin_x[0]) / 2)[:-1]
bin_ctr_y = (bin_y + (bin_y[1] - bin_y[0]) / 2)[:-1]

In [28]:
print(bin_x)
print(bin_ctr_x)

tensor([-0.5000,  0.5000,  1.5000,  2.5000,  3.5000,  4.5000,  5.5000,  6.5000,
         7.5000,  8.5000,  9.5000, 10.5000, 11.5000, 12.5000, 13.5000, 14.5000,
        15.5000, 16.5000, 17.5000, 18.5000, 19.5000, 20.5000, 21.5000, 22.5000,
        23.5000, 24.5000, 25.5000, 26.5000, 27.5000, 28.5000, 29.5000, 30.5000,
        31.5000, 32.5000, 33.5000, 34.5000, 35.5000, 36.5000, 37.5000, 38.5000,
        39.5000])
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.])


In [10]:
# auto-set some parameters (will be stored in the backup copy)
# param = decode.utils.param_io.autoset_scaling(param)

# add meta information
param.Meta.version = decode.utils.bookkeeping.decode_state()

AttributeError: 'RecursiveNamespace' object has no attribute 'Meta'

In [30]:
"""Experiment ID"""
if not debug:
    if param.InOut.checkpoint_init is None:
        experiment_id = datetime.datetime.now().strftime(
            "%Y-%m-%d_%H-%M-%S") + '_' + socket.gethostname()
        from_ckpt = False
        if log_comment:
            experiment_id = experiment_id + '_' + log_comment
    else:
        from_ckpt = True
        experiment_id = Path(param.InOut.checkpoint_init).parent.name
else:
    experiment_id = 'debug'
    from_ckpt = False
print(experiment_id)

2022-04-02_12-01-47_user-WS-C621E-SAGE-Series


In [31]:
"""Set up unique folder for experiment"""
if not from_ckpt:
    experiment_path = Path(param.InOut.experiment_out) / Path(experiment_id)
else:
    experiment_path = Path(param.InOut.checkpoint_init).parent
print(param.InOut.experiment_out)
print(experiment_path)

/media/hdd/rPSF_results/decode_impl
/media/hdd/rPSF_results/decode_impl/2022-04-02_12-01-47_user-WS-C621E-SAGE-Series


In [32]:
model_out = experiment_path / Path('model.pt')
ckpt_path = experiment_path / Path('ckpt.pt')
print(model_out)
print(ckpt_path)

/media/hdd/rPSF_results/decode_impl/2022-04-02_12-01-47_user-WS-C621E-SAGE-Series/model.pt
/media/hdd/rPSF_results/decode_impl/2022-04-02_12-01-47_user-WS-C621E-SAGE-Series/ckpt.pt


In [33]:
_, device_ix = decode.utils.hardware._specific_device_by_str(device)
print(device_ix)

None


In [34]:
log_folder = log_folder + '/' + experiment_id
print(log_folder)

logger = decode.neuralfitter.utils.logger.MultiLogger(
    [decode.neuralfitter.utils.logger.SummaryWriter(log_dir=log_folder,
                                                    filter_keys=["dx_red_mu", "dx_red_sig",
                                                                    "dy_red_mu",
                                                                    "dy_red_sig", "dz_red_mu",
                                                                    "dz_red_sig",
                                                                    "dphot_red_mu",
                                                                    "dphot_red_sig"]),
        decode.neuralfitter.utils.logger.DictLogger()])

/home/lingjia/Documents/rPSF/log/2022-04-02_12-01-47_user-WS-C621E-SAGE-Series


In [35]:
sim_train, sim_test = setup_random_simulation(param)
print(sim_train)

<decode.simulation.simulator.Simulation object at 0x7fb33842b0d0>


In [36]:
# visualize cell
# print(param)
# print(param.InOut.checkpoint_init)
# print(experiment_id)
# print(param.Hardware.device)
# print(param.Hardware.torch_multiprocessing_sharing_strategy)
# print(param.Hardware.torch_threads)
# print(param.HyperParameter.architecture)
# print(device)
# print(param.HyperParameter.emitter_label_photon_min)
print(param.HyperParameter.optimizer)

AdamW


In [37]:
tar_frame_ix_train = (0, 0)
tar_frame_ix_test = (0, param.TestSet.test_size)

In [38]:
# Turn to decode.neuralfitter.target_generator.ParameterListTarget 
n_max=param.HyperParameter.max_number_targets
xextent=param.Simulation.psf_extent[0]
yextent=param.Simulation.psf_extent[1]
ix_low=tar_frame_ix_train[0]
ix_high=tar_frame_ix_train[1]
squeeze_batch_dim=True

print(n_max,xextent,yextent,ix_low,ix_high)

250 [-0.5, 39.5] [-0.5, 39.5] 0 0


In [39]:
phot_max = param.Scaling.phot_max
z_max = param.Scaling.z_max
bg_max = param.Scaling.bg_max
print(phot_max, z_max, bg_max)

31000.0 960.0 240.0


In [40]:
simulator_train = sim_train
simulator_test = sim_test

In [41]:
models_available = {
    'SigmaMUNet': decode.neuralfitter.models.SigmaMUNet,
    'DoubleMUnet': decode.neuralfitter.models.model_param.DoubleMUnet,
    'SimpleSMLMNet': decode.neuralfitter.models.model_param.SimpleSMLMNet,
}

model = models_available[param.HyperParameter.architecture]
model = model.parse(param)

model_ls = decode.utils.model_io.LoadSaveModel(model, output_file=model_out)

model = model_ls.load_init()
model = model.to(torch.device(device))

# Small collection of optimisers
optimizer_available = {
    'Adam': torch.optim.Adam,
    'AdamW': torch.optim.AdamW
}

optimizer = optimizer_available[param.HyperParameter.optimizer]
optimizer = optimizer(model.parameters(), **param.HyperParameter.opt_param)

"""Loss function."""
criterion = decode.neuralfitter.loss.GaussianMMLoss(
    xextent=param.Simulation.psf_extent[0],
    yextent=param.Simulation.psf_extent[1],
    img_shape=param.Simulation.img_size,
    device=device,
    chweight_stat=param.HyperParameter.chweight_stat)

"""Learning Rate and Simulation Scheduling"""
lr_scheduler_available = {
    'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
    'StepLR': torch.optim.lr_scheduler.StepLR
}
lr_scheduler = lr_scheduler_available[param.HyperParameter.learning_rate_scheduler]
lr_scheduler = lr_scheduler(optimizer, **param.HyperParameter.learning_rate_scheduler_param)

"""Checkpointing"""
checkpoint = CheckPoint(path=ckpt_path)

"""Setup gradient modification"""
grad_mod = param.HyperParameter.grad_mod

"""Log the model (Graph) """
try:
    dummy = torch.rand((2, param.HyperParameter.channels_in,
                        *param.Simulation.img_size), requires_grad=False).to(
        torch.device(device))
    logger.add_graph(model, dummy)

except:
    print("Did not log graph.")
    # raise RuntimeError("Your dummy input is wrong. Please update it.")

"""Transform input data, compute weight mask and target data"""
# frame_proc: x --> (x-offset)/scale
frame_proc = decode.neuralfitter.scale_transform.AmplitudeRescale.parse(param)
bg_frame_proc = None

if param.HyperParameter.emitter_label_photon_min is not None:
    # select emitters with photon > emitter_label_photon_min
    em_filter = decode.neuralfitter.em_filter.PhotonFilter(
        param.HyperParameter.emitter_label_photon_min)
else:
    em_filter = decode.neuralfitter.em_filter.NoEmitterFilter()

tar_frame_ix_train = (0, 0)
tar_frame_ix_test = (0, param.TestSet.test_size)

"""Setup Target generator consisting possibly multiple steps in a transformation sequence."""
tar_gen = decode.neuralfitter.utils.processing.TransformSequence(
    [
        # emitter_set --> Tuple (param_tar, mask_tar, bg)
        # param_tar: N*4, [photon, xyz], mask_tar: N*4, 0/1, indicate function 
        decode.neuralfitter.target_generator.ParameterListTarget(
            n_max=param.HyperParameter.max_number_targets,
            xextent=param.Simulation.psf_extent[0],
            yextent=param.Simulation.psf_extent[1],
            ix_low=tar_frame_ix_train[0],
            ix_high=tar_frame_ix_train[1],
            squeeze_batch_dim=True),

        decode.neuralfitter.target_generator.DisableAttributes.parse(param),

        # param_tar --> phot/max, z/z_max, bg/bg_max
        decode.neuralfitter.scale_transform.ParameterListRescale(
            phot_max=param.Scaling.phot_max,
            z_max=param.Scaling.z_max,
            bg_max=param.Scaling.bg_max)
    ])

# setup target for test set in similar fashion, however test-set is static.
tar_gen_test = copy.deepcopy(tar_gen)
tar_gen_test.com[0].ix_low = tar_frame_ix_test[0]
tar_gen_test.com[0].ix_high = tar_frame_ix_test[1]
tar_gen_test.com[0].squeeze_batch_dim = False
tar_gen_test.com[0].sanity_check()

Model instantiated.
Model initialised as specified in the constructor.


In [42]:
def setup_dataloader(param, train_ds, test_ds=None):
    """Set up dataloader"""

    train_dl = torch.utils.data.DataLoader(
        dataset=train_ds,
        batch_size=param.HyperParameter.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=param.Hardware.num_worker_train,
        pin_memory=True,
        collate_fn=decode.neuralfitter.utils.dataloader_customs.smlm_collate)

    if test_ds is not None:

        test_dl = torch.utils.data.DataLoader(
            dataset=test_ds,
            batch_size=param.HyperParameter.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=param.Hardware.num_worker_train,
            pin_memory=False,
            collate_fn=decode.neuralfitter.utils.dataloader_customs.smlm_collate)
    else:

        test_dl = None

    return train_dl, test_dl

In [43]:
train_ds = decode.neuralfitter.dataset.SMLMLiveDataset(
    simulator=simulator_train,
    em_proc=em_filter,
    frame_proc=frame_proc,
    bg_frame_proc=bg_frame_proc,
    tar_gen=tar_gen, weight_gen=None,
    frame_window=param.HyperParameter.channels_in,
    pad=None, return_em=False)

train_ds.sample(True)

Sampled dataset in 21.48s. 250965 emitters on 10001 frames.


In [44]:
test_ds = decode.neuralfitter.dataset.SMLMAPrioriDataset(
    simulator=simulator_test,
    em_proc=em_filter,
    frame_proc=frame_proc,
    bg_frame_proc=bg_frame_proc,
    tar_gen=tar_gen_test, weight_gen=None,
    frame_window=param.HyperParameter.channels_in,
    pad=None, return_em=False)

test_ds.sample(True)

Sampled dataset in 1.15s. 12438 emitters on 513 frames.


In [45]:
ds_train = train_ds
ds_test = test_ds
dl_train, dl_test = setup_dataloader(param, ds_train, ds_test)

In [46]:
from tqdm import tqdm
dataloader = dl_train
tqdm_enum = tqdm(dataloader, total=len(dataloader), smoothing=0.)  # progress bar enumeration
ttt = iter(tqdm_enum)
(x, y_tar, weight) = next(ttt)

  0%|          | 0/156 [19:29<?, ?it/s]


In [47]:
print(x.size())
print(y_tar[0].size())
print(y_tar[1].size())
print(y_tar[2].size())

torch.Size([64, 1, 40, 40])
torch.Size([64, 250, 4])
torch.Size([64, 250])
torch.Size([64, 40, 40])


In [48]:
"""Loss function."""
criterion = decode.neuralfitter.loss.GaussianMMLoss(
    xextent=param.Simulation.psf_extent[0],
    yextent=param.Simulation.psf_extent[1],
    img_shape=param.Simulation.img_size,
    device=device,
    chweight_stat=param.HyperParameter.chweight_stat)

loss = criterion

"""Forward the data"""
y_out = model(x)

"""Reset the optimiser, compute the loss and backprop it"""
loss_val = loss(y_out, y_tar, weight)

In [None]:
print(loss_val.size())

In [78]:
import numpy as np
label_path = '/media/hdd/rPSF_data/rPSF/train/0620_uniformFlux/label.txt'
label_raw = np.loadtxt(label_path)
if label_raw.ndim < 2:
    label_raw = np.expand_dims(label_raw, axis=0)
labels = {}
for i in np.unique(label_raw[:,0]):
    i_bol = label_raw[:,0] == i
    labels[i] = label_raw[i_bol,:]
    if labels[i].ndim < 2:
        labels[i] = torch.tensor(labels[i]).unsqueeze(0)
    else:
        labels[i] = torch.tensor(labels[i])
    labels[i] =  labels[i][:,[4,1,2,3]]


In [79]:
print(labels[0])


tensor([[113.5610,  20.4191,  27.5939,  -4.3109],
        [105.1080, -24.3517, -25.3649,   6.2191],
        [139.0421,  -5.3202,  28.1096, -13.1525],
        [121.4389,  28.2700,   9.0004,   8.2418],
        [119.1001,  19.8701, -27.3673, -18.7267],
        [144.9537,  31.2455, -15.0621,  -8.9231],
        [100.4356,  10.5904,   3.1879, -18.1531],
        [133.7012, -31.5716,  31.1105, -16.1147],
        [133.6259,  23.7408,  31.6124,  12.9383],
        [106.9572,  29.5115, -23.2823,   7.7931],
        [120.3302,  12.1540,  32.0003,  -7.3160],
        [ 85.5178,  17.5263,  31.0874,  18.0089],
        [ 83.8938,  16.5330,  -0.9945, -18.6222],
        [117.7493,  19.5232,  28.6126,  -0.7169],
        [135.3023,  22.2580,  27.5873,  -7.5032],
        [146.3889,  20.4429,  26.6142,  -0.3294],
        [ 89.2350,  22.1361,  26.5680,  -7.3112],
        [120.3523,  19.4234,  26.3802,  -1.9900]], dtype=torch.float64)


In [None]:


for i in labels.keys():
    if len(labels[i].shape)==2:
        labels[i] = np.expand_dims(labels[i].T,0) # [1,n,3]
    elif len(labels[i].shape)==1: # labels[i] = 3*n, n is number of source points,
        labels[i] = np.expand_dims(labels[i].T,0) # [1,3]
        labels[i] = np.expand_dims(labels[i],0) # [1,1,3]


In [55]:
np.unique(label_raw[:,0])

array([1.000e+00, 2.000e+00, 3.000e+00, ..., 9.998e+03, 9.999e+03,
       1.000e+04])