# imports

In [1]:
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
import torch.nn as nn
import numpy as np
import random 
import matplotlib.pyplot as plt
from torchvision import models
from torchsummary import summary

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import gunpowder as gp
import zarr
import math
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
import skimage
import networkx
import pathlib
from tifffile import imread, imwrite
import tensorboard
import torch.nn.functional as F
import glob

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


# Model

In [2]:
# model parameters
input_size = (1, 64, 64, 5)
downsample_factors =[(2, 2, 1), (2, 2, 1), (2, 2, 1), (2, 2, 1)];
output_classes = 12

# model definition
class Vgg3D(torch.nn.Module):

    def __init__(self, input_size, output_classes, downsample_factors, fmaps=12):

        super(Vgg3D, self).__init__()

        self.input_size = input_size
        self.downsample_factors = downsample_factors
        self.output_classes = 2

        current_fmaps, h, w, d = tuple(input_size)
        current_size = (h, w,d)

        features = []
        for i in range(len(downsample_factors)):

            features += [
                torch.nn.Conv3d(current_fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool3d(downsample_factors[i])
            ]

            current_fmaps = fmaps
            fmaps *= 2

            size = tuple(
                int(c/d)
                for c, d in zip(current_size, downsample_factors[i]))
            check = (
                s*d == c
                for s, d, c in zip(size, downsample_factors[i], current_size))
            assert all(check), \
                "Can not downsample %s by chosen downsample factor" % \
                (current_size,)
            current_size = size

        self.features = torch.nn.Sequential(*features)

        classifier = [
            torch.nn.Linear(current_size[0] *current_size[1]*current_size[2] *current_fmaps,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,output_classes)
        ]

        self.classifier = torch.nn.Sequential(*classifier)
    
    def forward(self, raw):

        # add a channel dimension to raw
        # shape = tuple(raw.shape)
        # raw = raw.reshape(shape[0], 1, shape[1], shape[2])
        
        # compute features
        f = self.features(raw)
        f = f.view(f.size(0), -1)
        
        # classify
        y = self.classifier(f)

        return y

# create the model to train
model = Vgg3D(input_size, output_classes,  downsample_factors = downsample_factors)
model = model.to(device)
#summary(model, input_size)

# define loss function
loss_function = torch.nn.CosineEmbeddingLoss()
#loss_function = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0005)

# Load model

In [3]:
# path to state file
modelstateP = '/mnt/shared/celltracking/modelstates/aaron/'
stateFile = 'epoch_27'

model.load_state_dict(torch.load(modelstateP+stateFile))
model.eval()

Vgg3D(
  (features): Sequential(
    (0): Conv3d(1, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(12, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
    (7): Conv3d(12, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (8): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv3d(24, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (11): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), 

# Extract cell- and frame-wise model outputs

In [46]:
# set parameters
volSize = (1,5,64, 64)
zarrdir = '/mnt/shared/celltracking/data/cho/02.zarr'
raw = gp.ArrayKey('raw')

# extract centroids from annotated image stacks
annotationPath = pathlib.Path("/mnt/shared/celltracking/data/cho/")
annotations = np.stack([imread(xi) for xi in sorted((annotationPath / "02_GT/TRA").glob("*.tif"))])  # images
cells = []
for t, frame in enumerate(annotations):
    centers = skimage.measure.regionprops(frame)
    for c in centers:
        cells.append([c.label, t, int(c.centroid[1]), int(c.centroid[2])])

# define gp pipeline
pipeline_allCentroids = (gp.ZarrSource(
    zarrdir,  # the zarr container
    {raw: 'raw'},  # which dataset to associate to the array key
    {raw: gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True)})  # meta-information
    + gp.Pad(raw, None))

# constructs gp pipeline

gp.ArraySpec()

# loop over all cell centroids
predictions = []
i=0
for id,t,x,y in cells:
    # determine coordinates
    coord = (t,0,x-(volSize[2]/2),y-(volSize[3]/2))
    request = gp.BatchRequest()
    request[raw] = gp.Roi(coord, volSize)
    
    with gp.build(pipeline_allCentroids):
        batch = pipeline_allCentroids.request_batch(request)
        
    # show the content of the batch
    # print(f"batch returned: {batch}")

    # # plot first slice of volume
    # print(batch[raw].data.shape)
    # plt.imshow(np.flipud(batch[raw].data[0,0,:,:]))

    ## evaluate model for each centroid using gp pipeline
    vol = batch[raw].data
    vol = np.reshape(vol, (1,64, 64, 5))
    vol = np.expand_dims(vol, axis =0)
    vol = torch.from_numpy(vol).to(device).float()
    pred = model(vol)
    pred = pred.detach().cpu().numpy()
    
    # save pred into list with id + position information
    predictions.append([id, t, x, y, pred])
    i += 1
    if i%50==0:
        print(f'done with: {i}/{len(cells)} total')

done with: 50/614 total
done with: 100/614 total
done with: 150/614 total
done with: 200/614 total
done with: 250/614 total
done with: 300/614 total
done with: 350/614 total
done with: 400/614 total
done with: 450/614 total
done with: 500/614 total
done with: 550/614 total
done with: 600/614 total


In [44]:
len(predictions)

614

In [51]:
tracks = np.array(predictions)[:,0:4]

  tracks = np.array(predictions)[:,0:4]


In [21]:
np.array(predictions)[:,4].shape

  np.array(predictions)[:,4].shape


(614,)

### MATRIX T and T+1

In [52]:
from scipy.spatial import distance
distances = []
#row corresponds to index of cells in t
rows = [] 
#row corresponds to index of cells in t+1
cols = []

#Loop throug the times frames
for t in range(max(tracks[:,1])):

    #get index
    idxt=np.where(tracks[:,1]==t)[0]
    idxt_next=np.where(tracks[:,1]==t+1)[0]
    t_matrix=np.zeros((len(idxt),len(idxt_next)), dtype=float)

    for ii in range(0, len(idxt)):
        for jj in range(0, len(idxt_next)):
            #coordinate x,y cellN in t
            pt1=[tracks[ii,2], tracks[ii,3]]
            #pt1=tracks[ii,2:]-->embedding
            #coordinate x,y cellN in t next
            pt2=[tracks[jj,2], tracks[jj,3]] 
            #distance from pt1 and pt2
            dist=distance.euclidean(pt1,pt2)
            #dist = np.sqrt(np.sum(np.square(pt1-pt2)))
            #fill matrix with the distances
            
            t_matrix[ii,jj]= dist
            
    #print(t, len(idxt), len(idxt_next))
    distances.append(t_matrix)
    rows.append(idxt)
    cols.append(idxt_next)

In [53]:
distances_em[43].shape

(8, 7)

###  Embedding matrix distances

In [54]:
from scipy.spatial import distance

predictnp=np.array(predictions)

distances_em = []
#row corresponds to index of cells in t
rows_em = [] 
#row corresponds to index of cells in t+1
cols_em = []

#Loop throug the times frames
for t in range(max(predictnp[:,1])):

    #get index
    idxt=np.where(predictnp[:,1]==t)[0]
    idxt_next=np.where(predictnp[:,1]==t+1)[0]
    t_matrix_emb=np.zeros((len(idxt),len(idxt_next)), dtype=float)

    for ii in range(0, len(idxt)):
        for jj in range(0, len(idxt_next)):
            #coordinate x,y cellN in t
            pt1=predictnp[ii,4][0]
            #pt1=tracks[ii,2:]-->embedding
            #coordinate x,y cellN in t next
            pt2=predictnp[jj,4][0]
            #distance from pt1 and pt2
            dist=distance.euclidean(pt1,pt2)
            #dist = np.sqrt(np.sum(np.square(pt1-pt2)))
            #fill matrix with the distances
            
            t_matrix_emb[ii,jj]= dist
            
    #print(t, len(idxt), len(idxt_next))
    distances_em.append(t_matrix_emb)
    rows_em.append(idxt)
    cols_em.append(idxt_next)

  predictnp=np.array(predictions)


In [58]:
cost_matrix_t=distances[t] + distances_em[t]

In [78]:
predictnp

array([[1, 0, 160, 297,
        array([[  7770.195 , -18469.22  , -10354.515 ,  17674.717 ,   5795.583 ,
                 16273.54  ,   7105.2046,  -4123.329 ,  -5032.6187, -16819.05  ,
                  1590.4056,  10218.346 ]], dtype=float32)                      ],
       [4, 0, 273, 208,
        array([[  8536.617 , -21871.062 , -11443.676 ,  21673.814 ,   7095.0366,
                 19508.102 ,   8196.135 ,  -4920.857 ,  -4863.902 , -20336.3   ,
                  1596.6511,  12575.106 ]], dtype=float32)                      ],
       [7, 0, 356, 272,
        array([[  5733.0303 , -16769.85   ,  -4270.3696 ,  16714.781  ,
                 -2707.229  ,   8084.8296 ,   9418.933  ,   -472.37906,
                  4309.2476 , -22026.822  ,   5204.6724 ,  13365.835  ]],
              dtype=float32)                                             ],
       ...,
       [20, 91, 99, 296,
        array([[  6853.2026, -16915.326 ,  -8936.24  ,  17078.479 ,   5272.0557,
                 15176.544

In [76]:
def link_two_frames(cost_matrix): 
    
    cost_matrix = cost_matrix.copy()
    b = self.birth_cost_factor * min(self.threshold, cost_matrix.max())
    d = self.death_cost_factor * min(self.threshold, cost_matrix.max())
    no_link = max(cost_matrix.max(), max(b, d)) * 1e9
        
    cost_matrix[cost_matrix > self.threshold] = no_link
    lower_right = cost_matrix.transpose()

    deaths = np.full(shape=(cost_matrix.shape[0], cost_matrix.shape[0]), fill_value=no_link)
    np.fill_diagonal(deaths, d)
    births = np.full(shape=(cost_matrix.shape[1], cost_matrix.shape[1]), fill_value=no_link)
    np.fill_diagonal(births, b)
        
    square_cost_matrix = np.block([
        [cost_matrix, deaths],
        [births, lower_right],
    ])
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(square_cost_matrix)
        
    ids_from = []
    ids_to = []
    births = []
    deaths = []
    for row, col in zip(row_ind, col_ind):
        if row < cost_matrix.shape[0] and col < cost_matrix.shape[1]:
            ids_from.append(row)
            ids_to.append(col)

        if row >= cost_matrix.shape[0] and col < cost_matrix.shape[1]:
                births.append(col)
        if row < cost_matrix.shape[0] and col >= cost_matrix.shape[1]:
                deaths.append(row)

    ids_from = np.array(ids_from)
    ids_to = np.array(ids_to)
    births = np.array(births)
    deaths = np.array(deaths)
                        
    # Account for +1 offset of the dense labels
    ids_from += 1
    ids_to += 1
    births += 1
    deaths += 1
        
    links = {"links": (ids_from, ids_to), "births": births, "deaths": deaths}
    return links

In [77]:
for time in range(len(distances)):
    cost_matrix_t=distances[time] + distances_em[time]
    links_data=link_two_frames(cost_matrix)
    

NameError: name 'self' is not defined