In [None]:
# default_exp engine.point_process

In [None]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits to the library are picked up live.
%load_ext autoreload
%autoreload 2

# Emitter distribution

> Definition of the class used to simulate random emitter positions and intensities

In [None]:
#export
from decode_fish.imports import *
from torch import distributions as D, Tensor
from torch.distributions import Distribution
from decode_fish.funcs.utils import *

In [None]:
#export
class PointProcessUniform(Distribution):
    """
    This class is part of the generative model and uses the probability local_rate to generate sample locations on the voxel grid.
    For each emitter we then sample x-, y- and z-offsets uniformly in the range [-0.5,0.5] to get continuous locations.
    Intensities are sampled from a gamma distribution torch.distributions.Gamma(int_conc, int_rate) which is shifted by int_loc.
    Together with the microscope.scale and the PSF this results in the overall brightness of an emitter.

    Args:
        local_rate torch.tensor . shape(BS, C, H, W, D): Local rate
        int_conc=0., int_rate=1., int_loc (float): parameters of the intensity gamma distribution.
            NOTE: torch.distributions.Gamma expects a concentration > 0, so a positive int_conc should be passed.
        sim_iters (int): instead of sampling once from local_rate, we sample sim_iters times from local_rate/sim_iters.
            This results in the same average number of sampled emitters but allows us to sample multiple emitters within one voxel.
        channels (int): number of channels. If > 1 every emitter is replicated n_bits times, each copy is assigned
            a channel index, and the channel entry of the returned output shape is set to this value.
        n_bits (int): number of intensity values (and, for channels > 1, location rows) generated per emitter.
        sim_z (bool): if False all z-offsets are set to 0 (2D data).
        codebook: indexed with a tensor of randomly drawn code indices when sampling with from_code_book=True.
            Presumably of shape (n_codes, n_bits) holding channel indices -- TODO confirm against the caller.
        phase_fac (float): upper bound of the uniform factor that scales the intensity of the phasing copies.
        int_option (int): 1: independent intensity for every bit; 2: one intensity per emitter shared by all bits;
            3: like 2 but every bit is additionally multiplied with noise drawn uniformly from [0.7, 1.5).

    """
    def __init__(self, local_rate: torch.tensor, int_conc=0., int_rate=1., int_loc=1., sim_iters: int = 5, channels=1, n_bits=1, sim_z=True, codebook=None, phase_fac=0.2, int_option=1):

        assert sim_iters >= 1
        self.local_rate = local_rate
        self.device = self._get_device(self.local_rate)
        self.sim_iters = sim_iters
        self.int_conc = int_conc
        self.int_rate = int_rate
        self.int_loc = int_loc
        self.channels = channels
        self.n_bits = n_bits
        self.sim_z = sim_z
        self.codebook = codebook
        self.phase_fac = phase_fac
        self.int_option = int_option

    def sample(self, from_code_book=False, phasing=False):
        """
        Draws sim_iters independent samples from local_rate/sim_iters and concatenates them.

        Args:
            from_code_book (bool): draw the channels of every emitter from self.codebook instead of uniformly at random.
                Only has an effect if channels > 1.
            phasing (bool): add a dimmer copy of every entry in the following channel. Only has an effect if channels > 1.

        Returns:
            tuple: (locations, x_offset, y_offset, z_offset, intensities, output_shape, codes) where locations is a list
            with one index tensor per dimension of local_rate, and codes holds the drawn code indices or is None
            if no codes were drawn (from_code_book=False or channels == 1).
        """
        scaled_rate = self.local_rate / self.sim_iters
        draws = [self._sample(scaled_rate, from_code_book, phasing) for _ in range(self.sim_iters)]

        locations, x_offset, y_offset, z_offset, intensities = (
            torch.cat([draw[k] for draw in draws], dim=0) for k in range(5))
        output_shape = draws[0][5]

        # Codes only exist when the channels were drawn from the code book; torch.cat can not handle None entries.
        code_draws = [draw[6] for draw in draws]
        codes = torch.cat(code_draws, dim=0) if all(c is not None for c in code_draws) else None

        return list(locations.T), x_offset, y_offset, z_offset, intensities, output_shape, codes

    def _sample(self, local_rate, from_code_book, phasing):
        """
        Draws a single sample from the given rate.

        Returns:
            tuple: (locations, x_offset, y_offset, z_offset, intensities, output_shape, code_draw). locations is the
            (N, n_dims) result of nonzero(); code_draw is None unless channels > 1 and from_code_book is set.
        """
        if self.int_option not in (1, 2, 3):
            raise ValueError(f"int_option has to be 1, 2 or 3, got {self.int_option!r}")

        output_shape = list(local_rate.shape)
        local_rate = torch.clamp(local_rate, 0., 1.)
        locations = D.Bernoulli(local_rate).sample()
        n_emitter = int(locations.sum().item())

        unit_offset = D.Uniform(low=-0.5, high=0.5)
        x_offset = unit_offset.sample(sample_shape=[n_emitter]).to(self.device)
        y_offset = unit_offset.sample(sample_shape=[n_emitter]).to(self.device)
        z_offset = unit_offset.sample(sample_shape=[n_emitter]).to(self.device)

        gamma = D.Gamma(self.int_conc, self.int_rate)
        if self.int_option == 1:
            # Independent intensity for every bit
            intensities = gamma.sample(sample_shape=[n_emitter*self.n_bits]).to(self.device) + self.int_loc
        else:
            # One intensity per emitter that is shared by all of its bits
            intensities = gamma.sample(sample_shape=[n_emitter]).to(self.device) + self.int_loc
            intensities = intensities.repeat_interleave(self.n_bits, 0)
            if self.int_option == 3:
                # Multiplicative per-bit noise on top of the shared intensity
                int_noise = D.Uniform(low=.7, high=1.5).sample(sample_shape=[n_emitter*self.n_bits]).to(self.device)
                intensities *= int_noise

        # If 2D data z-offset is 0
        if not self.sim_z:
            z_offset *= 0

        locations = locations.nonzero(as_tuple=False)

        # Has to be bound for every configuration because it is always part of the return value.
        code_draw = None

        if self.channels > 1:
            if from_code_book:
                code_draw = torch.randint(0, len(self.codebook), size=[n_emitter])
                ch_draw = self.codebook[code_draw]
            else:
                # n_bits distinct channels per emitter, all channels equally likely
                ch_draw = torch.multinomial(torch.ones([n_emitter, self.channels])/self.channels, self.n_bits, replacement=False)
            locations = locations.repeat_interleave(self.n_bits, 0)
            locations[:, 1] = ch_draw.reshape(-1)

            # Exact positions are shared, but not intensities. Problems due to drift?
            x_offset = x_offset.repeat_interleave(self.n_bits, 0)
            y_offset = y_offset.repeat_interleave(self.n_bits, 0)
            z_offset = z_offset.repeat_interleave(self.n_bits, 0)

            output_shape[1] = self.channels

            if phasing:
                # Duplicate every entry; the odd rows become the copies in the following channel
                locations = locations.repeat_interleave(2, 0)
                locations[1::2, 1] = locations[1::2, 1] + 1
                x_offset = x_offset.repeat_interleave(2, 0)
                y_offset = y_offset.repeat_interleave(2, 0)
                z_offset = z_offset.repeat_interleave(2, 0)
                intensities = intensities.repeat_interleave(2, 0)

                # Copies are dimmed by a random factor in [0, phase_fac)
                phase_facs = torch.rand(size=intensities[1::2].shape, device=intensities.device) * self.phase_fac
                intensities[1::2] = intensities[1::2]*phase_facs

                # Drop copies that were shifted past the last channel
                inds = locations[:, 1] < self.channels
                x_offset, y_offset, z_offset = x_offset[inds], y_offset[inds], z_offset[inds]
                intensities = intensities[inds]
                locations = locations[inds]

        return locations, x_offset, y_offset, z_offset, intensities, tuple(output_shape), code_draw


    def log_prob(self, locations, x_offset=None, y_offset=None, z_offset=None, intensities=None, output_shape=None):
        """
        Bernoulli log probability of the given locations under self.local_rate.
        Offsets and intensities are accepted but not used.
        """
        locations = list_to_locations(locations, output_shape)
        log_prob = D.Bernoulli(self.local_rate).log_prob(locations)
        return log_prob

    @staticmethod
    def _get_device(x):
        """Returns the device attribute of x."""
        return x.device
    

def list_to_locations(locations, output_shape):
    """
    Turns a list of index tensors into a dense count volume.

    Args:
        locations: sequence with one index tensor per dimension of output_shape; entry k of every tensor
            together forms the coordinate of emitter k.
        output_shape: shape of the returned tensor.

    Returns:
        Tensor of shape output_shape on the device of locations[0]. Every voxel holds the number of
        emitters located in it, so repeated coordinates are counted multiple times.
    """
    counts = torch.zeros(output_shape, device=locations[0].device)
    # Stacking verifies that all index tensors have the same length; rows of coord are the per-dimension indices.
    coord = torch.stack(locations).long()
    if coord.numel() == 0:
        return counts
    # One accumulating scatter instead of a Python loop with one indexed write per emitter.
    increments = torch.ones(coord.shape[1:], dtype=counts.dtype, device=counts.device)
    counts.index_put_(tuple(coord), increments, accumulate=True)
    return counts

In [None]:
from decode_fish.funcs.merfish_eval import *

# get_benchmark is presumably provided by the star import above -- TODO confirm.
bench_df, code_ref, targets = get_benchmark()
# For every entry of code_ref keep the indices of its non-zero elements along the first axis.
code_inds = np.stack([np.nonzero(codeword)[0] for codeword in code_ref])

13832


In [None]:
# Constant rate of 0.3 on a [7,1,1,48,48] grid that lives on the GPU.
point_process = PointProcessUniform(
    local_rate=torch.ones([7,1,1,48,48]).cuda()*.3,
    int_conc=3, int_rate=1, int_loc=1,
    sim_iters=1,
    channels=16, n_bits=4,
    codebook=torch.tensor(code_inds),
    int_option=1,
)
# One draw with the channels taken from the code book and without phasing.
locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape, codes = point_process.sample(from_code_book=True, phasing=False)

In [None]:
# Write the complete sample to disk as a single list; the relative path is
# resolved against the current working directory.
torch.save([locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape, codes], '../data/sim_var_code_class.pt')

In [None]:
# for i in range(1000):
#     locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape, codes = point_process.sample(from_code_book=False, phasing=False)

In [None]:
def sample_to_df(locs, x_os, y_os, z_os, ints, codes, px_size_zyx=[100,100,100], channels=16, n_bits=4):
    """
    Converts a sample of the point process into a table with one row per emitter.

    Consecutive entries that share the same x offset are treated as one emitter; their
    intensities are spread over the columns int_0 ... int_{channels-1} according to the channel index.

    Args:
        locs: list of index tensors. locs[0] is used as frame index, locs[1] as channel index and
            locs[-3], locs[-2], locs[-1] as z, y and x voxel index.
        x_os, y_os, z_os: sub-voxel offsets, one value per entry of locs.
        ints: intensities, one value per entry of locs.
        codes: one code index per emitter (tensor, array or list); stored in the column code_inds.
        px_size_zyx: factors the z, y and x coordinates are multiplied with.
        channels: number of intensity columns.
        n_bits: currently unused; kept for interface compatibility.

    Returns:
        DataFrame with the columns loc_idx, frame_idx, x, y, z, int_0 ... int_{channels-1}, code_inds and ints
        (sum of all channel intensities). The index holds the position of the first entry of every emitter.
    """
    def _to_numpy(values):
        # Works for CPU and GPU tensors alike; anything else is passed through numpy.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    x = _to_numpy(locs[-1] + x_os + 0.5)
    y = _to_numpy(locs[-2] + y_os + 0.5)
    z = _to_numpy(locs[-3] + z_os + 0.5)

    # An emitter starts at the first entry and wherever the x offset changes.
    x_os_np = _to_numpy(x_os)
    if len(x_os_np):
        is_start = np.concatenate([[True], np.diff(x_os_np) != 0])
    else:
        is_start = np.zeros(0, dtype=bool)
    starts = np.flatnonzero(is_start)
    n_gt = len(starts)
    # Running emitter index of every entry
    loc_idx = np.cumsum(is_start) - 1

    ch_idx = _to_numpy(locs[1])

    df = DF({'loc_idx': loc_idx,
             'frame_idx': _to_numpy(locs[0]),
             'x': x*px_size_zyx[2],
             'y': y*px_size_zyx[1],
             'z': z*px_size_zyx[0]})

    int_arr = np.zeros([n_gt, channels])
    int_arr[loc_idx, ch_idx] = _to_numpy(ints)

    # Keep the first entry of every emitter; copy so that adding columns does not write into a slice of df.
    df = df.iloc[starts].copy()
    for ch in range(channels):
        df[f'int_{ch}'] = int_arr[:, ch]

    df['code_inds'] = _to_numpy(codes) if torch.is_tensor(codes) else codes
    df['ints'] = int_arr.sum(-1)

    return df

In [None]:
# from decode_fish.funcs.output_trafo import sample_to_df
# locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape = point_process.sample(from_code_book=False, phasing=False)
# Turn the sample drawn above into a table using the sample_to_df defined in this notebook.
target_df = sample_to_df(locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, codes, px_size_zyx=[100,100,100])

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
# Show the table built by sample_to_df.
target_df

Unnamed: 0,loc_idx,frame_idx,x,y,z,int_0,int_1,int_2,int_3,int_4,...,int_8,int_9,int_10,int_11,int_12,int_13,int_14,int_15,code_inds,ints
0,0,0,631.309143,85.199944,32.783012,1.799818,0.000000,0.000000,0.000000,0.000000,...,0.000000,3.552146,0.000000,0.000000,0.000000,0.000000,5.447395,3.346759,48,14.146118
4,1,0,846.326538,65.319008,45.174450,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,4.403565,0.000000,3.796365,0.000000,1.412127,0.000000,0.000000,37,15.346783
8,2,0,946.850403,85.634987,16.053158,0.000000,0.000000,0.000000,0.000000,0.000000,...,4.444443,2.605409,0.000000,2.428115,0.000000,0.000000,3.989892,0.000000,56,13.467859
12,3,0,1361.559570,65.815376,72.125893,4.617892,0.000000,5.371887,0.000000,0.000000,...,0.000000,1.658730,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,35,14.680487
16,4,0,1611.186646,11.061340,32.804089,0.000000,0.000000,0.000000,4.718218,0.000000,...,0.000000,0.000000,0.000000,4.441515,0.000000,2.958400,0.000000,3.162205,132,15.280338
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5484,1371,1,3709.324219,4762.822754,45.909714,0.000000,0.000000,0.000000,7.288084,4.221876,...,0.000000,0.000000,0.000000,3.988392,0.000000,0.000000,1.837208,0.000000,137,17.335558
5488,1372,1,4088.468994,4728.019531,8.726340,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,5.008657,1.933038,0.000000,0.000000,0.000000,81,17.790283
5492,1373,1,4259.401855,4759.837891,82.290825,0.000000,0.000000,0.000000,0.000000,0.000000,...,7.226398,0.000000,0.000000,0.000000,0.000000,4.744279,4.272971,0.000000,127,19.073029
5496,1374,1,4308.152344,4743.708496,1.135719,0.000000,4.613276,0.000000,0.000000,0.000000,...,0.000000,0.000000,1.598869,0.000000,0.000000,0.000000,0.000000,0.000000,26,15.876626


In [None]:
# Number of rows in the table.
len(target_df)

1376

In [None]:
# Label-based lookup of the row whose index label is 0.
target_df.loc[0]

loc_idx        0.000000
frame_idx      0.000000
x            631.309143
y             85.199944
z             32.783012
int_0          1.799818
int_1          0.000000
int_2          0.000000
int_3          0.000000
int_4          0.000000
int_5          0.000000
int_6          0.000000
int_7          0.000000
int_8          0.000000
int_9          3.552146
int_10         0.000000
int_11         0.000000
int_12         0.000000
int_13         0.000000
int_14         5.447395
int_15         3.346759
code_inds     48.000000
ints          14.146118
Name: 0, dtype: float64

In [None]:
# Run the nbdev command line tool that converts the notebooks into library modules.
!nbdev_build_lib

Converted 00_models.ipynb.
Converted 01_psf.ipynb.
Converted 02_microscope.ipynb.
Converted 03_noise.ipynb.
Converted 04_pointsource.ipynb.
Converted 05_gmm_loss.ipynb.
Converted 06_plotting.ipynb.
Converted 07_file_io.ipynb.
Converted 08_dataset.ipynb.
Converted 09_output_trafo.ipynb.
Converted 10_evaluation.ipynb.
Converted 11_emitter_io.ipynb.
Converted 12_utils.ipynb.
Converted 13_train.ipynb.
Converted 15_fit_psf.ipynb.
Converted 16_visualization.ipynb.
Converted 17_eval_routines.ipynb.
Converted 18_predict_funcs.ipynb.
Converted 19_MERFISH_routines.ipynb.
Converted 20_MERFISH_visualization.ipynb.
Converted 22_MERFISH_codenet.ipynb.
Converted 23_MERFISH_comparison.ipynb.
Converted index.ipynb.
