In [None]:
# default_exp engine.gmm_loss

In [None]:
%load_ext autoreload
%autoreload 2

# Loss function

In [None]:
#export
from decode_fish.imports import *
from torch import distributions as D, Tensor
from torch.distributions import Distribution

In [None]:
#export
class PointProcessGaussian(Distribution):
    def __init__(self, logits: torch.Tensor, xyzi_mu: torch.Tensor, xyzi_sigma: torch.Tensor, **kwargs):
        """ Defines our loss function, given logits, xyzi_mu and xyzi_sigma.

        The count loss first constructs a Gaussian approximation to the predicted number of emitters by summing the mean and the variance of the Bernoulli detection probability map,
        and then maximizes the probability of the true number of emitters under this distribution.
        The localization loss models the distribution of sub-pixel localizations with a coordinate-wise independent Gaussian probability distribution with a 3D standard deviation.
        For imprecise localizations, this probability is maximized for large sigmas, for precise localizations for small sigmas.
        The distribution of all localizations over the entire image is approximated as a weighted average of individual localization distributions, where the weights correspond to the probability of detection.

        Args:
            logits: detection logits, shape (B,1,D,H,W)
            xyzi_mu: predicted means, shape (B,3+C,D,H,W) where C is the number of
                intensity channels passed to `log_prob` (i.e. (B,4,D,H,W) for C=1)
            xyzi_sigma: predicted standard deviations, same shape as xyzi_mu
            **kwargs: accepted and ignored, so that a dict of model outputs containing
                additional entries can be unpacked directly into the constructor.
        """
        self.logits = logits
        self.xyzi_mu = xyzi_mu
        self.xyzi_sigma = xyzi_sigma

    def log_prob(self, locations, x_offset, y_offset, z_offset, intensities, n_bits, channels, min_int_sig):
        """ Creates the distributions for the count and localization loss and evaluates the log probability for the given set of localizations under those distributions.

            Args:
                locations: tuple of index tensors (batch, channel, z, y, x) with the voxel locations of the emitters.
                x_offset, y_offset, z_offset: continuous within-pixel offsets. Have length of the number of emitters in the whole batch.
                intensities: brightness of emitters. Has length of the number of emitters in the whole batch.
                n_bits: number of consecutive entries that belong to one emitter (only used when channels > 1).
                channels: number of intensity channels; every Gaussian component has 3 + channels dimensions.
                min_int_sig: constant added to the predicted intensity sigmas.

            Returns:
                count_prob: count loss. Has length of batch_size
                spatial_prob: localization loss. Has length of batch_size
        """

        gauss_dim = 3 + channels
        batch_size = self.logits.shape[0]
        if channels == 1:
            xyzi, s_mask = get_true_labels(batch_size, locations, x_offset, y_offset, z_offset, intensities)
        else:
            xyzi, s_mask = get_true_labels_mf(batch_size, n_bits, channels,
                                              locations, x_offset, y_offset, z_offset, intensities)
        counts = s_mask.sum(-1)

        # Count loss: Gaussian approximation of a sum of independent Bernoulli detections.
        P = torch.sigmoid(self.logits)
        count_mean = P.sum(dim=[2, 3, 4]).squeeze(-1)
        count_var = (P - P ** 2).sum(dim=[2, 3, 4]).squeeze(-1)
        count_dist = D.Normal(count_mean, torch.sqrt(count_var))
        count_prob = count_dist.log_prob(counts)

        # Mixture weights: detection probabilities normalised per sample.
        mixture_probs = P / P.sum(dim=[2, 3, 4], keepdim=True)

        # Every voxel is one mixture component. The voxels are enumerated densely in
        # row-major (z, y, x) order instead of via torch.nonzero(P): if a sigmoid
        # underflows to exactly 0, nonzero() drops that voxel and the per-sample
        # reshape below would either fail or mix voxels of different samples.
        depth, height, width = P.shape[2:]
        device = P.device
        z_idx = torch.arange(depth, device=device).view(-1, 1, 1).expand(depth, height, width)
        y_idx = torch.arange(height, device=device).view(1, -1, 1).expand(depth, height, width)
        x_idx = torch.arange(width, device=device).view(1, 1, -1).expand(depth, height, width)
        # (1, D*H*W, 3) voxel centres in (x, y, z) order.
        voxel_centres = torch.stack([x_idx, y_idx, z_idx], dim=-1).reshape(1, -1, 3) + 0.5

        # (B, C, D, H, W) -> (B, D*H*W, C): same ordering as the flattened mixture weights.
        xyzi_mu = self.xyzi_mu.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, gauss_dim)
        # Predicted offsets are relative to the voxel; shift them to absolute coordinates.
        xyz_mu = (xyzi_mu[..., :3] + voxel_centres).to(xyzi_mu.dtype)
        xyzi_mu = torch.cat([xyz_mu, xyzi_mu[..., 3:]], dim=-1)

        xyzi_sig = self.xyzi_sigma.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, gauss_dim)
        # Only the intensity sigmas receive the additional constant.
        xyzi_sig = torch.cat([xyzi_sig[..., :3], xyzi_sig[..., 3:] + min_int_sig], dim=-1)

        mix = D.Categorical(mixture_probs.reshape(batch_size, -1))
        # Small constant keeps the scale strictly positive.
        comp = D.Independent(D.Normal(xyzi_mu, xyzi_sig + 0.00001), 1)
        spatial_gmm = D.MixtureSameFamily(mix, comp)

        # MixtureSameFamily expects the batch dimension last-but-event, so the emitter
        # axis is moved to the front for evaluation and moved back afterwards.
        spatial_prob = spatial_gmm.log_prob(xyzi.transpose(0, 1)).transpose(0, 1)
        # Padded (non-existing) emitters are masked out before summing per sample.
        spatial_prob = (spatial_prob * s_mask).sum(-1)

        return count_prob, spatial_prob

def get_sample_mask(bs, locations):
    """
    Builds a padding mask that assigns every emitter a (sample, slot) position.

    Args:
        bs: batch size (number of rows of the mask).
        locations: sequence of index tensors; only `locations[0]`, the batch index
            of every emitter, is used.

    Returns:
        s_mask: float tensor of shape (bs, max_counts) that is 1 where a slot is
            occupied by an emitter and 0 for padding. max_counts is the largest
            number of emitters in any sample (at least 1).
        s_arr: for every emitter, its slot index within its sample.
    """
    batch_idx = locations[0]
    # Allocate on the device of the inputs instead of the current CUDA device.
    device = batch_idx.device

    batch_loc, counts_ = torch.unique(batch_idx, return_counts=True)

    counts = torch.zeros(bs, dtype=torch.long, device=device)
    counts[batch_loc] = counts_

    # With no emitters at all, keep one (empty) slot so the mask has a valid shape.
    max_counts = max(int(counts.max()), 1)
    s_arr = cum_count_per_group(batch_idx)
    s_mask = torch.zeros(bs, max_counts, dtype=torch.float32, device=device)
    s_mask[batch_idx, s_arr] = 1

    return s_mask, s_arr
    
def get_true_labels(bs, locations, x_os, y_os, z_os, *args):
    """
    Arranges the emitter parameters into a padded per-sample tensor.

    Args:
        bs: batch size.
        locations: sequence of index tensors; entries 0, 2, 3 and 4 are used as
            batch, z, y and x voxel index of every emitter.
        x_os, y_os, z_os: within-voxel offsets, one value per emitter.
        *args: additional per-emitter 1D tensors (e.g. intensities) that are
            appended as extra columns after x, y, z.

    Returns:
        gt_list: tensor of shape (bs, max_counts, 3 + len(args)) with the emitter
            parameters, zero-padded.
        s_mask: mask of shape (bs, max_counts) marking the occupied slots.
    """
    s_mask, s_arr = get_sample_mask(bs, locations)
    max_counts = s_mask.shape[1]

    # Absolute coordinates: voxel index + voxel centre (0.5) + sub-voxel offset.
    # .float() keeps the tensors on their own device instead of forcing CUDA.
    x = x_os + locations[4].float() + 0.5
    y = y_os + locations[3].float() + 0.5
    z = z_os + locations[2].float() + 0.5

    gt_vars = torch.stack([x, y, z, *args], dim=1)
    gt_list = torch.zeros(bs, max_counts, gt_vars.shape[1],
                          dtype=gt_vars.dtype, device=gt_vars.device)

    gt_list[locations[0], s_arr] = gt_vars
    return gt_list, s_mask

def get_true_labels_mf(bs, n_bits, channels, locations, x_os, y_os, z_os, int_ch):
    """
    Multi-channel variant of `get_true_labels`: every emitter occupies `n_bits`
    consecutive entries of the inputs (one per channel in which it is visible).

    Args:
        bs: batch size.
        n_bits: number of consecutive entries that belong to one emitter.
        channels: total number of intensity channels.
        locations: sequence of index tensors; entries 0, 1, 2, 3 and 4 are used as
            batch, channel, z, y and x index.
        x_os, y_os, z_os: within-voxel offsets, one value per entry.
        int_ch: intensities, one value per entry.

    Returns:
        gt_list: tensor of shape (bs, max_counts, 3 + channels) holding x, y, z and
            the per-channel intensities of every emitter, zero-padded.
        s_mask: mask of shape (bs, max_counts) marking the occupied slots.
    """
    # Keep the first of every n_bits entries; this assumes the n_bits entries of an
    # emitter share the same position - TODO confirm against the caller.
    xyz_locs = [l[::n_bits] for l in locations]
    x_os = x_os[::n_bits]
    y_os = y_os[::n_bits]
    z_os = z_os[::n_bits]

    s_mask, s_arr = get_sample_mask(bs, xyz_locs)
    max_counts = s_mask.shape[1]

    # Absolute coordinates: voxel index + voxel centre (0.5) + sub-voxel offset.
    # .float() keeps the tensors on their own device instead of forcing CUDA.
    x = x_os + xyz_locs[4].float() + 0.5
    y = y_os + xyz_locs[3].float() + 0.5
    z = z_os + xyz_locs[2].float() + 0.5

    # Scatter the n_bits intensities of every emitter into its channel columns.
    int_ch = int_ch.reshape(-1, n_bits)
    n_emitters = len(int_ch)
    intensity = torch.zeros([n_emitters, channels]).to(x.device)
    row_idx = torch.arange(n_emitters, device=x.device)[:, None]
    intensity[row_idx, locations[1].reshape(-1, n_bits)] = int_ch

    gt_vars = torch.stack([x, y, z], dim=1)
    gt_vars = torch.cat([gt_vars, intensity], dim=1)
    gt_list = torch.zeros(bs, max_counts, gt_vars.shape[1],
                          dtype=gt_vars.dtype, device=gt_vars.device)

    gt_list[xyz_locs[0], s_arr] = gt_vars
    return gt_list, s_mask

def grp_range(counts: torch.Tensor):
    """
    Concatenates the ranges 0..c-1 for every group size c in `counts`.
    Example:
        [3, 1, 2] --> [0, 1, 2, 0, 0, 1]

    Args:
        counts: 1D integer tensor of group sizes. Empty input and groups of
            size 0 are supported.

    Returns:
        1D int64 tensor of length counts.sum() on the device of `counts`.
    """
    assert counts.dim() == 1

    # Start position of every group in the concatenated output (exclusive cumsum).
    starts = counts.cumsum(0) - counts
    total = int(counts.sum())
    positions = torch.arange(total, dtype=torch.long, device=counts.device)
    # Global position minus the start of the own group gives the index within the group.
    return positions - torch.repeat_interleave(starts, counts)

def cum_count_per_group(arr):
    """
    For every element, counts how many equal elements precede it in `arr`.
    Example:
        [0, 0, 0, 1, 2, 2, 0] --> [0, 1, 2, 0, 0, 1, 3]

    An empty input is returned unchanged.
    """
    if arr.numel() == 0:
        return arr

    group_sizes = torch.unique(arr, return_counts=True)[1]
    # Running index within each group, laid out in sorted-group order.
    within_group = grp_range(group_sizes)

    # Stable sort keeps equal elements in order of appearance; the argsort of that
    # permutation is the position of every original element in the sorted layout.
    stable_order = np.argsort(arr.cpu().numpy(), kind='mergesort')
    sorted_position = np.argsort(stable_order, kind='mergesort')
    return within_group[sorted_position]

In [None]:
# Load a stored batch of network outputs and move every entry to the GPU.
model_out = torch.load('../data/model_batch_output_code.pt')
model_out.update({name: value.to('cuda') for name, value in model_out.items()})
model_out.keys()

dict_keys(['logits', 'xyzi_mu', 'xyzi_sigma', 'background'])

In [None]:
from decode_fish.engine.point_process import PointProcessUniform

# Sample emitters for a batch of two 48x48x48 volumes.
point_process = PointProcessUniform(
    local_rate=torch.ones([2, 1, 48, 48, 48]) * .0001,
    int_conc=1.0,
    sim_iters=5,
    channels=16,
    n_bits=4,
)
(locs_3d, x_os_3d, y_os_3d,
 z_os_3d, ints_3d, output_shape) = point_process.sample()

In [None]:
# Build the padded ground-truth tensor and its mask from the sampled emitters.
# Single-channel alternative:
#   get_true_labels(2, locs_3d, x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda())
locs_3d = [loc.cuda() for loc in locs_3d]
xyzi_true, s_mask = get_true_labels_mf(
    2, 4, 16, locs_3d,
    x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda(),
)
# Number of location entries, the mask itself, and the emitter count per sample.
print(len(locs_3d[0]), s_mask, s_mask.sum(-1), sep='\n')

116
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
       device='cuda:0')
tensor([12., 17.], device='cuda:0')


In [None]:
# Evaluate the count and localization log-probabilities of the sampled emitters
# under the stored network output.
ppg = PointProcessGaussian(**model_out)
# `log_prob` has no `int_chs` parameter; passing it raises a TypeError on a fresh kernel.
gmm_loss = ppg.log_prob(locs_3d, x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda(),
                        n_bits=4, channels=16, min_int_sig=0.1)
gmm_loss

(tensor([-2.7682, -2.3394], device='cuda:0', grad_fn=<SubBackward0>),
 tensor([ -6410.5063, -12788.6523], device='cuda:0', grad_fn=<SumBackward1>))

In [None]:
# Second element of the tuple returned by log_prob: the localization (spatial) log-probability per sample.
gmm_loss[1]

tensor([ -6410.5063, -12788.6523], device='cuda:0', grad_fn=<SumBackward1>)

In [None]:
# for i in range(1000):
#     gmm_loss = PointProcessGaussian(**model_out).log_prob(locs_3d, x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda())

In [None]:
!nbdev_build_lib

Converted 00_models.ipynb.
Converted 01_psf.ipynb.
Converted 02_microscope.ipynb.
Converted 03_noise.ipynb.
Converted 04_pointsource.ipynb.
Converted 05_gmm_loss.ipynb.
Converted 06_plotting.ipynb.
Converted 07_file_io.ipynb.
Converted 08_dataset.ipynb.
Converted 09_output_trafo.ipynb.
Converted 10_evaluation.ipynb.
Converted 11_emitter_io.ipynb.
Converted 12_utils.ipynb.
Converted 13_train.ipynb.
Converted 15_fit_psf.ipynb.
Converted 16_visualization.ipynb.
Converted 17_eval_routines.ipynb.
Converted 18_predict_funcs.ipynb.
Converted 19_MERFISH_routines.ipynb.
Converted 20_MERFISH_visualization.ipynb.
Converted 21_MERFISH_routines_2.ipynb.
Converted hyper.ipynb.
Converted index.ipynb.
