In [2]:
import argparse, datetime, numpy, os, shutil, sys, time
from cmath import inf
import socket
from pathlib import Path
import torch
# self-defined modules
import decode.evaluation
import decode.neuralfitter
import decode.neuralfitter.coord_transform
import decode.neuralfitter.utils
import decode.simulation
import decode.utils
# from decode.neuralfitter.train.random_simulation import setup_random_simulation
from decode.neuralfitter.utils import log_train_val_progress
from decode.utils.checkpoint import CheckPoint

In [3]:
def parse_args():
    """Parse the training command-line options.

    ``parse_known_args`` is used, so arguments the parser does not know are ignored
    (presumably so extra arguments injected by the notebook kernel do not abort the run).

    Returns:
        argparse.Namespace with the attributes device, param_file, num_worker_override,
        no_log, log_comment, data_path_override and img_size_override.
    """
    parser = argparse.ArgumentParser(description='Training Args')

    # (flags, keyword arguments) of every option, in registration order.
    option_table = [
        (('-i', '--device'),
         dict(default=None, type=str,
              help='Specify the device string (cpu, cuda, cuda:0) and overwrite param.')),
        (('-p', '--param_file'),
         dict(default=None, type=str,
              help='Specify your parameter file (.yml or .json).')),
        (('-w', '--num_worker_override'),
         dict(default=None, type=int,
              help='Override the number of workers for the dataloaders.')),
        (('-n', '--no_log'),
         dict(default=False, action='store_true',
              help='Set no log if you do not want to log the current run.')),
        (('-c', '--log_comment'),
         dict(default=None,
              help='Add a log_comment to the run.')),
        (('-d', '--data_path_override'),
         dict(default=None, type=str,
              help='Specify your path to data')),
        (('-is', '--img_size_override'),
         dict(default=None, type=int,
              help='Override img size')),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    known, _unknown = parser.parse_known_args()
    return known

In [4]:
args = parse_args()
# Notebook defaults: these overwrite whatever parse_args() returned. Each one can be
# replaced through an environment variable, so the notebook runs on another machine
# without editing the hard-coded absolute paths below.
args.device = os.environ.get('RPSF_DEVICE', 'cuda' if torch.cuda.is_available() else 'cpu')
args.param_file = os.environ.get('RPSF_PARAM_FILE', '/home/lingjia/Documents/rpsf/NN/param.yaml')
args.data_path_override = os.environ.get(
    'RPSF_DATA_PATH',
    '/media/hdd_4T/lingjia/rPSF/20220716_decode_variant_v2/data_train/30k_pt50L5')
args.img_size_override = int(os.environ.get('RPSF_IMG_SIZE', '96'))

In [5]:
def _pick_option(available, key, what):
    """Return ``available[key]``; an unknown key raises a KeyError that lists the valid choices."""
    if key not in available:
        raise KeyError(f"Unknown {what} {key!r}; available: {sorted(available)}.")
    return available[key]


def _try_log_graph(logger, model, param, device):
    """Best-effort: log the model graph for a random dummy batch of two frames; failures are printed, not raised."""
    try:
        dummy = torch.rand((2, param.HyperParameter.channels_in,
                            *param.Simulation.img_size), requires_grad=False).to(torch.device(device))
        logger.add_graph(model, dummy)
    except Exception as err:  # not a bare except: Ctrl-C / SystemExit must still propagate
        print(f"Did not log graph. ({type(err).__name__}: {err})")
        # raise RuntimeError("Your dummy input is wrong. Please update it.")


def _setup_datasets(param, tar_proc):
    """Build the (train, validation) rPSFDataset pair from ``param.InOut.data_path``."""
    # Hard-coded split: sample IDs 1-100 for training, 101-110 for validation.
    train_ids = list(range(1, 101))
    val_ids = list(range(101, 111))
    shared = dict(root_dir=param.InOut.data_path,
                  label_path=None,
                  n_max=param.HyperParameter.max_number_targets,
                  tar_proc=tar_proc,
                  img_shape=param.Simulation.img_size)
    train_ds = decode.neuralfitter.dataset.rPSFDataset(list_IDs=train_ids, **shared)
    test_ds = decode.neuralfitter.dataset.rPSFDataset(list_IDs=val_ids, **shared)
    return train_ds, test_ds


def _setup_post_processor(param):
    """Build the post-processor selected by ``param.PostProcessing`` (None, 'LookUp', 'SpatialIntegration' or legacy 'NMS')."""
    if param.PostProcessing is None:
        return decode.neuralfitter.post_processing.NoPostProcessing(xy_unit='px',
                                                                     px_size=param.Camera.px_size)

    if param.PostProcessing not in ('LookUp', 'SpatialIntegration', 'NMS'):  # 'NMS' kept as legacy alias
        raise NotImplementedError(f"Post-processing {param.PostProcessing!r} is not implemented.")

    stages = [
        # out_tar --> out_tar: phot*phot_max, z*z_max, bg*bg_max
        decode.neuralfitter.scale_transform.InverseParamListRescale(
            phot_max=param.Scaling.phot_max,
            z_max=param.Scaling.z_max,
            bg_max=param.Scaling.bg_max),
        # offset --> coordinate, e.g. 0.2 --> 10.2
        decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),
    ]
    if param.PostProcessing == 'LookUp':
        # NOTE(review): this mapping looks like a single-emitter-channel layout
        # (p, phot, x, y, z, bg); confirm it matches the architecture's output channels.
        stages.append(decode.neuralfitter.post_processing.LookUpPostProcessing(
            raw_th=param.PostProcessingParam.raw_th,
            pphotxyzbg_mapping=[0, 1, 2, 3, 4, -1],
            xy_unit='px',
            px_size=param.Camera.px_size))
    else:
        stages.append(decode.neuralfitter.post_processing.SpatialIntegration(
            raw_th=param.PostProcessingParam.raw_th,
            xy_unit='px'))
    return decode.neuralfitter.utils.processing.TransformSequence(stages)


def setup_trainer(logger, model_out, ckpt_path, device, param):
    """
    Set up model, optimiser, loss, LR scheduler, checkpointing, datasets and post-processing.

    Args:
        logger: logger used to (try to) record the model graph
        model_out: output file passed to ``LoadSaveModel``
        ckpt_path: path handed to ``CheckPoint``
        device: device string, e.g. 'cpu', 'cuda', 'cuda:0'
        param: parameter namespace loaded from the parameter file

    Returns:
        tuple (train_ds, test_ds, model, model_ls, optimizer, criterion, lr_scheduler,
        grad_mod, post_processor, matcher, checkpoint)

    Raises:
        KeyError: unknown architecture, optimizer or learning-rate scheduler name
        NotImplementedError: unsupported ``param.PostProcessing``
    """
    # Model
    models_available = {
        'SigmaMUNet': decode.neuralfitter.models.SigmaMUNet_variant,
        'DoubleMUnet': decode.neuralfitter.models.model_param.DoubleMUnet,
        'SimpleSMLMNet': decode.neuralfitter.models.model_param.SimpleSMLMNet,
    }
    model = _pick_option(models_available, param.HyperParameter.architecture, 'architecture')
    model = model.parse(param)

    model_ls = decode.utils.model_io.LoadSaveModel(model, output_file=model_out)
    model = model_ls.load_init()
    model = model.to(torch.device(device))

    # Optimiser
    optimizer_available = {
        'Adam': torch.optim.Adam,
        'AdamW': torch.optim.AdamW,
        'SGD': torch.optim.SGD
    }
    optimizer = _pick_option(optimizer_available, param.HyperParameter.optimizer, 'optimizer')
    optimizer = optimizer(model.parameters(), **param.HyperParameter.opt_param)

    # Loss function
    criterion = decode.neuralfitter.loss.GaussianMMLoss(
        xextent=param.Simulation.psf_extent[0],
        yextent=param.Simulation.psf_extent[1],
        img_shape=param.Simulation.img_size,
        device=device,
        chweight_stat=param.HyperParameter.chweight_stat)

    # Learning-rate scheduling
    lr_scheduler_available = {
        'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'StepLR': torch.optim.lr_scheduler.StepLR
    }
    lr_scheduler = _pick_option(lr_scheduler_available,
                                param.HyperParameter.learning_rate_scheduler,
                                'learning-rate scheduler')
    lr_scheduler = lr_scheduler(optimizer, **param.HyperParameter.learning_rate_scheduler_param)

    # Checkpointing and gradient modification
    checkpoint = CheckPoint(path=ckpt_path)
    grad_mod = param.HyperParameter.grad_mod

    _try_log_graph(logger, model, param, device)

    # Target generator: param_tar --> phot/phot_max, z/z_max, bg/bg_max
    tar_proc = decode.neuralfitter.utils.processing.TransformSequence(
        [
            decode.neuralfitter.scale_transform.ParameterListRescale(
                phot_max=param.Scaling.phot_max,
                z_max=param.Scaling.z_max,
                bg_max=param.Scaling.bg_max)
        ])

    train_ds, test_ds = _setup_datasets(param, tar_proc)
    post_processor = _setup_post_processor(param)

    # Evaluation: matcher between predicted and ground-truth emitters
    matcher = decode.evaluation.match_emittersets.GreedyHungarianMatching.parse(param)

    return train_ds, test_ds, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, checkpoint


def setup_dataloader(param, train_ds, test_ds=None):
    """
    Wrap the datasets in DataLoaders that use the smlm collate function.

    The training loader shuffles, drops the last incomplete batch and pins memory.
    The test loader keeps order, keeps every sample and does not pin memory.

    Args:
        param: parameter namespace (batch size and worker count are read from it)
        train_ds: training dataset
        test_ds: optional validation dataset

    Returns:
        tuple (train_dl, test_dl); test_dl is None when no test_ds is given
    """
    # Settings shared by both loaders (the test loader also uses num_worker_train).
    common = dict(
        batch_size=param.HyperParameter.batch_size,
        num_workers=param.Hardware.num_worker_train,
        collate_fn=decode.neuralfitter.utils.dataloader_customs.smlm_collate,
    )

    train_dl = torch.utils.data.DataLoader(dataset=train_ds, drop_last=True, shuffle=True,
                                           pin_memory=True, **common)

    test_dl = None
    if test_ds is not None:
        test_dl = torch.utils.data.DataLoader(dataset=test_ds, drop_last=False, shuffle=False,
                                              pin_memory=False, **common)

    return train_dl, test_dl


In [6]:
# TRAIN AND TEST 1 EPOCH - BEFORE TUNING THE POST-PROCESSING PART
param_file = Path(args.param_file)
param = decode.utils.param_io.ParamHandling().load_params(param_file)

# add meta information, e.g. Meta=namespace(version='0.10.0')
param.Meta.version = decode.utils.bookkeeping.decode_state()

# Experiment ID / folder: a fresh timestamped folder for a new run, or the folder of
# the checkpoint being resumed.
from_ckpt = param.InOut.checkpoint_init is not None
if not from_ckpt:
    experiment_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    if args.log_comment:
        experiment_id = experiment_id + '_' + args.log_comment
    experiment_path = Path(param.InOut.experiment_out) / Path(experiment_id)
else:
    experiment_id = Path(param.InOut.checkpoint_init).parent.name
    experiment_path = Path(param.InOut.checkpoint_init).parent

# parents=True: a missing ancestor of experiment_out must not crash the run
# (a plain mkdir() raised FileNotFoundError in that case).
experiment_path.parent.mkdir(parents=True, exist_ok=True)

# A new run must not reuse an existing folder.
if not from_ckpt:
    experiment_path.mkdir(exist_ok=False)

model_out = experiment_path / Path('model.pt')
ckpt_path = experiment_path / Path('ckpt.pt')

# Modify parameters with the command-line / notebook overrides
if args.num_worker_override is not None:
    param.Hardware.num_worker_train = args.num_worker_override

# Hardware / server settings: an explicit device override wins over the parameter file.
if args.device is not None:
    device = args.device
    # param.Hardware.device_simulation = device_overwrite  # lazy assumption
else:
    device = param.Hardware.device

if args.data_path_override is not None:
    param.InOut.data_path = args.data_path_override

if args.img_size_override is not None:
    param.Simulation.img_size = [args.img_size_override, args.img_size_override]
    # Extent from -0.5 to size-0.5, i.e. unit pixels centred on integer coordinates;
    # the z extent is left as None.
    param.Simulation.psf_extent = [[-0.5, args.img_size_override - 0.5],
                                   [-0.5, args.img_size_override - 0.5], None]

    param.TestSet.frame_extent = param.Simulation.psf_extent
    param.TestSet.img_size = param.Simulation.img_size

# Back up the parameter file under the experiment folder: once as given, once with overrides applied
param_backup_in = experiment_path / Path('param_run_in').with_suffix(param_file.suffix)
shutil.copy(param_file, param_backup_in)

param_backup = experiment_path / Path('param_run').with_suffix(param_file.suffix)
decode.utils.param_io.ParamHandling().write_params(param_backup, param)

# Lower the process priority only when a niceness is configured: os.nice(None) raises
# TypeError, so a None value must skip the call on every platform.
if param.Hardware.unix_niceness is not None:
    if sys.platform in ('linux', 'darwin'):
        os.nice(param.Hardware.unix_niceness)
    else:
        print(f"Cannot set niceness on platform {sys.platform}. You probably do not need to worry.")

torch.set_num_threads(param.Hardware.torch_threads)

"""Setup Log System"""
# --no_log selects a NoLog logger. Otherwise a MultiLogger combines a SummaryWriter
# writing into the experiment folder with a DictLogger.
if args.no_log:
    logger = decode.neuralfitter.utils.logger.NoLog()

else:
    log_folder = experiment_path

    # filter_keys: presumably metric names the SummaryWriter should leave out — TODO confirm
    logger = decode.neuralfitter.utils.logger.MultiLogger(
        [decode.neuralfitter.utils.logger.SummaryWriter(log_dir=log_folder,
                                                        filter_keys=["dx_red_mu", "dx_red_sig",
                                                                        "dy_red_mu",
                                                                        "dy_red_sig", "dz_red_mu",
                                                                        "dz_red_sig",
                                                                        "dphot_red_mu",
                                                                        "dphot_red_sig"]),
            decode.neuralfitter.utils.logger.DictLogger()])

# sim_train, sim_test = setup_random_simulation(param)
ds_train, ds_test, model, model_ls, optimizer, criterion, lr_scheduler, grad_mod, post_processor, matcher, ckpt = setup_trainer(logger, model_out, ckpt_path, device, param)

dl_train, dl_test = setup_dataloader(param, ds_train, ds_test)

# When resuming, restore model / optimiser / LR-scheduler state from the checkpoint
# (replacing the fresh CheckPoint from setup_trainer) and continue after the saved step.
if from_ckpt:
    ckpt = decode.utils.checkpoint.CheckPoint.load(param.InOut.checkpoint_init)
    model.load_state_dict(ckpt.model_state)
    optimizer.load_state_dict(ckpt.optimizer_state)
    lr_scheduler.load_state_dict(ckpt.lr_sched_state)
    first_epoch = ckpt.step + 1
    model = model.train()
    print(f'Resuming training from checkpoint {experiment_id}')
else:
    first_epoch = 0
# Best validation loss so far (float infinity, imported from cmath); not updated in this cell.
best_val_loss = inf
# Debug run: a single epoch, starting at first_epoch (0, or the step after a resumed checkpoint).
i = first_epoch
# for i in range(first_epoch, param.HyperParameter.epochs):
logger.add_scalar('learning/learning_rate', optimizer.param_groups[0]['lr'], i)
print(f'Epoch{i}')
# if i >= 1:
# One training epoch over dl_train; the return value is discarded.
_ = decode.neuralfitter.train_val_impl.train(
    model=model,
    optimizer=optimizer,
    loss=criterion,
    dataloader=dl_train,
    grad_rescale=param.HyperParameter.moeller_gradient_rescale,
    grad_mod=grad_mod,
    epoch=i,
    device=torch.device(device),
    logger=logger
)

# val_loss = average of the loss over all test batches
# test_out = network outputs with fields ["loss", "x", "y_out", "y_tar", "weight", "em_tar"]
val_loss, test_out = decode.neuralfitter.train_val_impl.test(
    model=model,
    loss=criterion,
    dataloader=dl_test,
    epoch=i,
    device=torch.device(device))
# print(val_loss)
# print(val_loss)

Model instantiated.
Model initialised as specified in the constructor.
Epoch0


0 Train Time:0.89 Loss_mean_ep:1.97e+07: 100%|████████████████████████| 2/2 [00:00<00:00,  2.18it/s]
0 Test Time:0.069 Loss_mean_ep:1.39e+05: 100%|████| 1/1 [00:00<00:00,  9.17it/s]


In [None]:
# PRINT LOG
# log_train_val_progress.post_process_log_test(loss_cmp=test_out.loss,
#                                             loss_scalar=val_loss,
#                                             x=test_out.x, y_out=test_out.y_out,
#                                             y_tar=test_out.y_tar,
#                                             weight=test_out.weight,
#                                             em_tar=ds_test.emitter(),
#                                             px_border=-0.5, px_size=1.,
#                                             post_processor=post_processor,
#                                             matcher=matcher, logger=logger,
#                                             step=i)


In [64]:
"""POST_PROCESSOR"""
# Rebuild the SpatialIntegration pipeline explicitly (the same three stages setup_trainer
# builds for PostProcessing == 'SpatialIntegration'/'NMS'), overwriting `post_processor`.
post_processor = decode.neuralfitter.utils.processing.TransformSequence([
    # out_tar --> out_tar: phot*phot_max, z*z_max, bg*bg_max
    decode.neuralfitter.scale_transform.InverseParamListRescale(
        phot_max=param.Scaling.phot_max,
        z_max=param.Scaling.z_max,
        bg_max=param.Scaling.bg_max),
    # offset --> coordinate e.g., 0.2 --> 10.2
    decode.neuralfitter.coord_transform.Offset2Coordinate.parse(param),

    # detection threshold from the parameter file (0.5 in the author's setup)
    decode.neuralfitter.post_processing.SpatialIntegration(
        raw_th=param.PostProcessingParam.raw_th, # 0.5
        xy_unit='px')])

In [19]:
"""INVERSEPARAMLISTRESCALE CHECK"""
# Channel-index reference for the network output, per number of emitter channels.
# Change the hard-coded indices in the cells below to match the layout in use.
# num channels = 3  (28 output channels in total)
# p=0,1,2 phot=3,4,5 x=6,7,8 y=9,10,11 z=12,13,14 phot_sig=15,16,17 x_sig=18,19,20, y_sig=21,22,23, z_sig=24,25,26, bg=27
# num channels = 2  (19 output channels in total; the y_out shape printed below)
# p=0,1 phot=2,3 x=4,5 y=6,7 z=8,9 phot_sig=10,11 x_sig=12,13, y_sig=14,15, z_sig=16,17, bg=18

'INVERSEPARAMLISTRESCALE CHECK'

In [7]:
"""SIMULATED INPUT X AND GROUND-TRUTH TARGET"""
# All tensors go to the notebook's `device` (where the model lives) instead of a
# hard-coded 'cuda', which mismatches a model on e.g. 'cuda:1' or 'cpu'.
# INPUT: one random 4x4 frame; presumably (batch, channels_in, H, W) with channels_in == 1 — TODO confirm
x = torch.randn(1,1,4,4).to(device)
# TARGET: room for 60 emitters with 4 parameters each; only the first 2 are populated
param_tar = torch.zeros(1,60,4)
param_tar_v = torch.randn(2,4)
param_tar[0,:2,:] = param_tar_v

# 1 for the two populated emitter slots, 0 elsewhere
mask_tar = torch.zeros(1,60)
mask_tar[0,:2] = 1

# constant background of 5 on the 4x4 frame
bg = torch.ones((1,4,4))*5

target = (param_tar.to(device), mask_tar.to(device), bg.to(device))
# WEIGHT
weight = None

In [17]:
"""OUTPUT OF SIMULATED INPUT X AND OFFSET2COORDINATE CHECK"""
# Forward the simulated frame; autograd stays enabled, so the output carries a grad_fn.
output = model(x)
print(f'y_out: {output.shape}')

from decode.neuralfitter.target_generator import UnifiedEmbeddingTarget
# Used here only for its pixel-centre bins (_bin_ctr_x / _bin_ctr_y) of a 4x4 grid
# spanning the test-set frame extent.
off_psf = UnifiedEmbeddingTarget(xextent=param.TestSet.frame_extent[0],
                                    yextent=param.TestSet.frame_extent[1],
                                    img_shape=(4,4), roi_size=1)

print(off_psf._bin_ctr_x)

# meshgrid with the default 'ij' indexing: xv varies along dim 0, yv along dim 1.
xv, yv = torch.meshgrid([off_psf._bin_ctr_x, off_psf._bin_ctr_y])
# 1 x H x W pixel-centre meshes, read as globals by _subpx_to_absolute in the next cell
_x_mesh = xv.unsqueeze(0)
_y_mesh = yv.unsqueeze(0)

# x offsets of the 2-emitter-channel layout (channels 4 and 5)
x_offset = output[:, 4:6]
print(f'x_offset:\n{x_offset}')
print(f'x_mesh:\n{_x_mesh[:,None].repeat(1,2, 1, 1)}')


y_out: torch.Size([1, 19, 4, 4])
tensor([11.5000, 35.5000, 59.5000, 83.5000])
x_offset:
tensor([[[[-0.5418,  0.8828, -0.8397, -0.3643],
          [ 0.7593, -0.8372,  0.5576, -0.8078],
          [ 0.3342, -0.0592, -0.9278, -0.5546],
          [ 0.4561,  0.5927,  0.2704,  0.7213]],

         [[-0.8458,  0.0293,  0.9037, -0.2904],
          [-0.7090,  0.9806, -0.0610, -0.5777],
          [-0.4695,  0.1654, -0.3924,  0.8914],
          [ 0.7908, -0.2937, -0.7683, -0.8147]]]], device='cuda:0',
       grad_fn=<SliceBackward>)
x_mesh:
tensor([[[[11.5000, 11.5000, 11.5000, 11.5000],
          [35.5000, 35.5000, 35.5000, 35.5000],
          [59.5000, 59.5000, 59.5000, 59.5000],
          [83.5000, 83.5000, 83.5000, 83.5000]],

         [[11.5000, 11.5000, 11.5000, 11.5000],
          [35.5000, 35.5000, 35.5000, 35.5000],
          [59.5000, 59.5000, 59.5000, 59.5000],
          [83.5000, 83.5000, 83.5000, 83.5000]]]])


In [20]:
def _subpx_to_absolute(x_offset, y_offset, x_mesh=None, y_mesh=None):
    """
    Convert subpixel offsets to absolute coordinates by adding the pixel-centre mesh.

    This is a plain function (no ``self``); it is called with the two offset tensors.

    Args:
        x_offset: x offsets, N x H x W or N x C x H x W
        y_offset: y offsets, same shape as ``x_offset``
        x_mesh: pixel-centre x coordinates broadcastable to ``x_offset`` (e.g. 1 x H x W);
            defaults to the notebook-level ``_x_mesh``
        y_mesh: same for y; defaults to the notebook-level ``_y_mesh``

    Returns:
        tuple (x_coord, y_coord) with the broadcast shape of offsets and meshes
    """
    if x_mesh is None:
        x_mesh = _x_mesh
    if y_mesh is None:
        y_mesh = _y_mesh
    # Broadcasting expands a 1 x H x W mesh over the batch (and channel) dims, so x and y
    # are handled identically whatever the number of leading dimensions.
    x_coord = x_mesh.to(x_offset.device) + x_offset
    y_coord = y_mesh.to(y_offset.device) + y_offset
    return x_coord, y_coord

def Offset2Coordinate(x: torch.Tensor, x_channels=slice(6, 9), y_channels=slice(9, 12),
                      x_mesh=None, y_mesh=None) -> torch.Tensor:
    """
    Replace the x / y offset channels of a network output by absolute coordinates.

    Args:
        x (torch.Tensor): network output, expected shape :math:`(N, C, H, W)`
        x_channels: channels holding x offsets. The default 6:9 matches the
            3-emitter-channel layout (x=6,7,8); for the 2-channel layout (x=4,5)
            pass ``slice(4, 6)``.
        y_channels: channels holding y offsets. The default 9:12 matches the
            3-emitter-channel layout (y=9,10,11); for the 2-channel layout (y=6,7)
            pass ``slice(6, 8)``.
        x_mesh, y_mesh: optional pixel-centre meshes forwarded to ``_subpx_to_absolute``

    Returns:
        torch.Tensor: converted copy of ``x``; ``x`` itself is not modified

    Raises:
        ValueError: if ``x`` is not 4-dimensional
    """
    if x.dim() != 4:
        raise ValueError("Wrong dimensionality. Needs to be N x C x H x W.")

    # Convert the selected offset channels to absolute coordinates
    x_coord, y_coord = _subpx_to_absolute(x[:, x_channels], x[:, y_channels], x_mesh, y_mesh)

    output_converted = x.clone()
    output_converted[:, x_channels] = x_coord
    output_converted[:, y_channels] = y_coord

    return output_converted

torch.set_printoptions(precision=4,sci_mode=False)
# print(output)
# print(Offset2Coordinate(output, x_channels=slice(4, 6), y_channels=slice(6, 8)))

In [28]:
"""SIMULATED OUTPUT X0"""
# Fixed seed so the fake network output below is reproducible.
torch.manual_seed(0)
# x0 = torch.randn(2,28,2,2)  # 3-emitter-channel layout (28 output channels)
x0 = torch.randn(3,19,4,4)  # batch 3, 2-emitter-channel layout (19 output channels), 4x4 px

In [29]:
"""SPATIALINTEGRATION CHECK - P1 IN SpatialIntegration FORWARD"""
from typing import Union, Callable
def _nms(p: torch.Tensor, p_aggregation: Callable, raw_th: float, split_th: float) -> torch.Tensor:
    """
    Non-Maximum Suppression, applied independently to every detection channel of ``p``.

    Prints the shapes of the intermediate tensors (debug output).

    Args:
        p: detection probabilities, N x C x H x W (C detection channels, e.g. x[:, 0:2])
        p_aggregation: callable combining the two clustered maps (see set_p_aggregation)
        raw_th: values strictly above this are candidates for local maxima
        split_th: values strictly above this that are not a local maximum form a second
            mask, so two emitters in adjacent pixels can both be kept

    Returns:
        clustered probabilities, N x C x H x W
    """

    with torch.no_grad():
        p_copy = p.clone()

        # NOTE: the candidate threshold is raw_th; 0.3 in the string below is only the usual value.
        """Probability values > 0.3 are regarded as possible locations"""
        p_clip = torch.where(p > raw_th, p, torch.zeros_like(p))[:, None]
        print(f'p_clip:{p_clip.shape}')

        # kernel (1,3,3) on N x 1 x C x H x W: a 3x3 max pool within each channel separately
        """localize maximum values within a 3x3 patch"""
        pool = torch.nn.functional.max_pool3d(p_clip, kernel_size=(1,3,3), stride=1, padding=(0,1,1))
        # the unclipped p is compared with the pooled clipped map: 1 where p is the local maximum
        max_mask1 = torch.eq(p[:, None], pool).float()
        print(f'max_mask1:{max_mask1.shape}')

        # diag = 0 gives a plus-shaped kernel: centre pixel + its 4 direct neighbours
        """Add probability values from the 4 adjacent pixels"""
        diag = 0.  # 1/np.sqrt(2)
        filt = torch.tensor([[diag, 1., diag], [1, 1, 1], [diag, 1, diag]]).unsqueeze(0).unsqueeze(0).to(p.device)
        # convolve each channel on its own, then restack to N x 1 x C x H x W
        conv = [torch.nn.functional.conv2d(p[:, None, idx], filt, padding=1) for idx in range(p.shape[1])]
        conv = torch.cat(conv,dim=1)[:,None]
        p_ps1 = max_mask1 * conv
        print(f'p_ps1:{p_ps1.shape}')

        # NOTE: the second-mask threshold is split_th; 0.6 in the string below is only the usual value.
        """
        In order do be able to identify two fluorophores in adjacent pixels we look for
        probablity values > 0.6 that are not part of the first mask
        """
        # zero out the local maxima already captured by max_mask1
        p_copy *= (1 - max_mask1[:, 0])
        # p_clip = torch.where(p_copy > split_th, p_copy, torch.zeros_like(p_copy))[:, None]
        max_mask2 = torch.where(p_copy > split_th, torch.ones_like(p_copy), torch.zeros_like(p_copy))[:, None]
        p_ps2 = max_mask2 * conv

        # combine both masks' clustered values; the result is thresholded downstream
        """This is our final clustered probablity which we then threshold (normally > 0.7)
        to get our final discrete locations"""
        p_ps = p_aggregation(p_ps1, p_ps2)
        assert p_ps.size(1) == 1

        return p_ps.squeeze(1)

def set_p_aggregation(p_aggr: Union[str, Callable]) -> Callable:
    """
    Resolve the probability-aggregation function by name, or pass a callable through.

    Args:
        p_aggr: 'sum' (torch.add), 'max' (element-wise torch.max), 'norm_sum'
            (sum clamped to [0, 1]) or any callable taking the two maps

    Returns:
        Callable combining two probability maps

    Raises:
        ValueError: if ``p_aggr`` is a string other than the names above
    """
    if not isinstance(p_aggr, str):
        return p_aggr

    def norm_sum(*args):
        # sum of the maps, clamped to [0, 1]
        return torch.clamp(torch.add(*args), 0., 1.)

    aggregations = {'sum': torch.add, 'max': torch.max, 'norm_sum': norm_sum}
    if p_aggr not in aggregations:
        raise ValueError(f"Unknown p_aggregation {p_aggr!r}; expected one of {sorted(aggregations)}.")
    return aggregations[p_aggr]


# Apply NMS to the two detection channels of a copy of x0.
# NOTE: this rebinds `x`, which previously held the simulated model input.
x = x0.clone()
p_aggregation = set_p_aggregation('norm_sum')
raw_th = 0.3  # candidate threshold; also read as a global by _filter further down
_split_th = 0.6
x[:, 0:2] = _nms(x[:, 0:2], p_aggregation, raw_th, _split_th)
# print(x)
# print(x)


p_clip:torch.Size([3, 1, 2, 4, 4])
max_mask1:torch.Size([3, 1, 2, 4, 4])
p_ps1:torch.Size([3, 1, 2, 4, 4])


In [31]:
"""SPATIALINTEGRATION CHECK - P2 IN LookUpPostProcessing INITIAL"""
# Channel indices for the 2-emitter-channel layout:
# p=0,1 phot=2,3 x=4,5 y=6,7 z=8,9 | phot_sig=10,11 x_sig=12,13 y_sig=14,15 z_sig=16,17 | bg=18 (== -1)
pphotxyz_mapping: Union[list, tuple] = (0,1, 2,3, 4,5, 6,7, 8,9)
photxyz_sigma_mapping: Union[list, tuple, None] = (10,11, 12,13, 14,15, 16,17)
# A single channel index: `(-1)` is just the int -1, not a tuple, so it is annotated as int.
# Use `(-1,)` instead if a 1-element tuple (keeping the channel dimension) is wanted.
bg_mapping: int = -1

assert len(pphotxyz_mapping) == 10, "Wrong length of mapping."
if photxyz_sigma_mapping is not None:
    assert len(photxyz_sigma_mapping) == 8, "Wrong length of sigma mapping."


In [38]:
"""SPATIALINTEGRATION CHECK - P3 IN LookUpPostProcessing FORWARD"""
# Select the probability + (phot, x, y, z) channels -> N x 10 x H x W. In the recorded
# run order, `x` here is the NMS-processed copy of x0, not the earlier model input.
x_mapped = x[:, pphotxyz_mapping]
print(f'x_mapped: {x_mapped.shape}')
# print(x_mapped)

"""Filter"""
def _filter(detection) -> torch.BoolTensor:
    """
    Mark the active pixels of a detection map.

    Args:
        detection: any tensor that should be thresholded
    Returns:
        boolean tensor, True where ``detection`` reaches the notebook-level ``raw_th``
        (inclusive comparison)
    """
    threshold = raw_th  # global threshold set in the NMS cell
    is_active = detection >= threshold
    return is_active

# Active pixels of the two detection channels (0 and 1): bool mask, N x 2 x H x W
active_px = _filter(x_mapped[:, 0:2])  # channels 0 and 1 are the detection channels
# detection values at the active pixels, flattened to 1-D
prob = x_mapped[:, 0:2][active_px]
print(f'prob: {prob}')


"""Look-Up in channels"""
# features, active_px = x_mapped[:, 2:], active_px
def _lookup_features(features: torch.Tensor, active_px: torch.Tensor, n_emitter_ch: int = 2) -> tuple:
    """
    Gather feature values at the active pixels, separately for every emitter channel.

    ``features`` stores ``n_emitter_ch`` consecutive channels per feature
    (e.g. phot, phot, x, x, ... for 2 emitter channels), so it is regrouped to
    (N, C // n_emitter_ch, n_emitter_ch, H, W) before indexing. Prints the regrouped shape.

    Args:
        features: size :math:`(N, C, H, W)`, with C divisible by ``n_emitter_ch``
        active_px: boolean mask of size :math:`(N, n_emitter_ch, H, W)`
        n_emitter_ch: number of emitter channels interleaved in ``features`` (default 2)

    Returns:
        torch.Tensor: batch-ix of every active pixel, size :math:`M`
        torch.Tensor: extracted features, size :math:`(C // n_emitter_ch, M)`

    Raises:
        ValueError: if C is not divisible by ``n_emitter_ch``
    """
    batch_size, nc, hh, ww = features.size()
    if nc % n_emitter_ch:
        raise ValueError(f"{nc} feature channels cannot be split into groups of {n_emitter_ch}.")
    # integer division: exact, unlike int(nc / 2)
    features = features.reshape(batch_size, nc // n_emitter_ch, n_emitter_ch, hh, ww)
    print(f'features reshape: {features.shape}')

    assert features.dim() == 5
    assert active_px.dim() == features.dim() - 1

    # nonzero() rows are [batch, emitter-channel, row, col]; keep the batch index. Boolean
    # indexing below walks the mask in the same row-major order, so both results line up.
    batch_ix = active_px.nonzero(as_tuple=False)[:, 0]
    features_active = features.permute(1, 0, 2, 3, 4)[:, active_px]

    return batch_ix, features_active

# Gather (phot, x, y, z) at every active pixel: frame_ix has size M, features is 4 x M
frame_ix, features = _lookup_features(x_mapped[:, 2:], active_px)
print(f'after looup features,frame_ix:{frame_ix.shape},features:{features.shape}')
print(f'features[0,:]: {features[0,:]}')
# x, y, z of every active pixel, M x 3
xyz = features[1:4].transpose(0, 1)
# print(xyz)


"""If sigma mapping is present, get those values as well."""
if photxyz_sigma_mapping is not None:
    # sigma channels regrouped the same way: (phot_sig, x_sig, y_sig, z_sig)
    sigma = x[:, photxyz_sigma_mapping]
    _, features_sigma = _lookup_features(sigma, active_px)

    # the sigmas are moved to the CPU (unlike xyz above)
    xyz_sigma = features_sigma[1:4].transpose(0, 1).cpu()
    phot_sigma = features_sigma[0].cpu()
else:
    xyz_sigma = None
    phot_sigma = None

x_mapped: torch.Size([3, 10, 4, 4])
prob: tensor([0.7372, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.3124, 1.0000, 1.0000, 1.0000, 1.0000, 0.5836, 1.0000, 1.0000])
features reshape: torch.Size([3, 4, 2, 4, 4])
after looup features,frame_ix:torch.Size([17]),features:torch.Size([4, 17])
features[0,:]: tensor([ 0.4397,  2.3022, -1.4689,  0.1778, -0.3952, -0.4462, -1.5312, -1.2341,
         1.8197,  0.9094,  1.2464,  0.1151,  1.6193,  0.4637, -0.3380, -0.2995,
         0.8155])
features reshape: torch.Size([3, 4, 2, 4, 4])
