# Model Implementation for 3D Cell Tracking


In [50]:
# !pip install torchsummary 
# !pip install gunpowder
# !pip install zarr
# !pip install matplotlib
#!pip install tensorboard

Collecting tensorboard
  Using cached tensorboard-2.10.0-py3-none-any.whl (5.9 MB)
Collecting absl-py>=0.4
  Using cached absl_py-1.2.0-py3-none-any.whl (123 kB)
Collecting markdown>=2.6.8
  Using cached Markdown-3.4.1-py3-none-any.whl (93 kB)
Collecting google-auth-oauthlib<0.5,>=0.4.1
  Using cached google_auth_oauthlib-0.4.6-py2.py3-none-any.whl (18 kB)
Collecting tensorboard-plugin-wit>=1.6.0
  Using cached tensorboard_plugin_wit-1.8.1-py3-none-any.whl (781 kB)
Collecting protobuf<3.20,>=3.9.2
  Using cached protobuf-3.19.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
Collecting tensorboard-data-server<0.7.0,>=0.6.0
  Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
Collecting werkzeug>=1.0.1
  Using cached Werkzeug-2.2.2-py3-none-any.whl (232 kB)
Collecting google-auth<3,>=1.6.3
  Using cached google_auth-2.11.0-py2.py3-none-any.whl (167 kB)
Collecting grpcio>=1.24.3
  Downloading grpcio-1.48.1-cp39-cp39-manylinux_2_17_x86_

In [57]:
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
import torch.nn as nn
import numpy as np
import random 
import matplotlib.pyplot as plt
from torchvision import models
from torchsummary import summary

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import gunpowder as gp
import zarr
import math
#%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


# Data

In [12]:
# provide path to zarr directory
zarrdir = '/mnt/shared/celltracking/data/cho/01.zarr'

data = zarr.open(zarrdir)
loader = []

In [13]:
# provide path to zarr directory
#zarrdir = '/mnt/shared/celltracking/data/cho/01.zarr'

# pre-selected volumes from cho dataset
coord_paired = ((0,0,80, 175),(1,0,80, 180))
coord_unpaired = ((0,0,80, 175),(1,0,92, 232))

# specify subvolume size and volume source
volSize = (1,5,64, 64)
coord = coord_paired[0]

# declare arrays to use in the pipeline
key ='raw'
raw = gp.ArrayKey(key)

# create "pipeline" consisting only of a data source
source = gp.ZarrSource(
    zarrdir,  # the zarr container
    {raw: key},  # which dataset to associate to the array key
    {raw: gp.ArraySpec(interpolatable=True, voxel_size = (1,1,1,1))}  # meta-information
)
pipeline = source

# specify request paired
request_vol1p = gp.BatchRequest()
request_vol2p = gp.BatchRequest()
request_vol1p[raw] = gp.Roi(coord_paired[0], volSize)
request_vol2p[raw] = gp.Roi(coord_paired[1], volSize)

# specify request unpaired
request_vol1u = gp.BatchRequest()
request_vol2u = gp.BatchRequest()
request_vol1u[raw] = gp.Roi(coord_unpaired[0], volSize)
request_vol2u[raw] = gp.Roi(coord_unpaired[1], volSize)

# build the pipeline
with gp.build(pipeline):
    batch_vol1p = pipeline.request_batch(request_vol1p)
    batch_vol2p = pipeline.request_batch(request_vol2p)
    batch_vol1u = pipeline.request_batch(request_vol1u)
    batch_vol2u = pipeline.request_batch(request_vol2u)
    
# # specify request 
# request_vol1 = gp.BatchRequest()
# request_vol2 = gp.BatchRequest()
# request_vol1[raw] = gp.Roi(coord_paired[0], volSize)
# request_vol2[raw] = gp.Roi(coord_paired[1], volSize)

# # build the pipeline
# with gp.build(pipeline):
#     batch_vol1 = pipeline.request_batch(request_vol1)
#     batch_vol2 = pipeline.request_batch(request_vol2)
    
# show the content of the batch
#print(f"batch returned: {batch_vol1}")

# plot first slice of volume

# fig, (ax1, ax2) = plt.subplots(1, 2)
# ax1.imshow(np.flipud(batch_vol1[raw].data[0,1,:,:]))
# ax2.imshow(np.flipud(batch_vol2[raw].data[0,1,:,:]))

# load volume into napari
# viewer = napari.Viewer()
# viewer.add_image(batch[raw].data, name="volume 1");

In [14]:
vol1 = batch_vol1p[raw].data
vol2 = batch_vol2p[raw].data

vol1 = np.reshape(vol1, (1,64, 64, 5))
vol2 = np.reshape(vol2, (1,64, 64, 5))
y = 1

vol1 = np.expand_dims(vol1, axis =0)
vol2 = np.expand_dims(vol2, axis=0)

loader.append((vol1, vol2, y))
np.shape(loader)

(1, 3)

# Augmentations

In [18]:
#Augmentations=

# chose a random source (i.e., sample) from the above
#random_location = gp.RandomLocation()

# elastically deform the batch
Elastic_augment=gp.ElasticAugment(
    [2,10,10],
    [0,2,2],
    [0,0*math.pi/2.0],
    prob_slip=0.05,
    prob_shift=0.05,
    max_misalign=25)

# apply transpose and mirror augmentations
# Simple_augment=gp.SimpleAugment(transpose_only=[2, 3], mirror_only=[]) 
Simple_augment=gp.SimpleAugment(transpose_only=[], mirror_only=[])

# scale and shift the intensity of the raw array
Intensity_augment=gp.IntensityAugment(
    raw,
    scale_min=0.9,
    scale_max=1.1,
    shift_min=-0.1,
    shift_max=0.1,
    )

Noise_augment = gp.NoiseAugment(raw, mode="gaussian")

pipeline = (
    source + gp.Normalize(raw)+
            Intensity_augment+Elastic_augment + Simple_augment + Noise_augment
           )

# build the pipeline
with gp.build(pipeline):
    for x in range(10):
        batch_vol1p = pipeline.request_batch(request_vol1p)
        batch_vol2p = pipeline.request_batch(request_vol2p)
        batch_vol1u = pipeline.request_batch(request_vol1u)
        batch_vol2u = pipeline.request_batch(request_vol2u)
        # show the content of the batch
        #print(f"batch returned: {batch}")

        # plot first slice of volume

        # fig, (ax1, ax2) = plt.subplots(1, 2)
        # ax1.imshow(np.flipud(batch_vol1p[raw].data[0,1,:,:]))
        # ax2.imshow(np.flipud(batch_vol2p[raw].data[0,1,:,:]))
        # plt.show()

with gp.build(pipeline):
    for x in range(10):
        batch_vol1p = pipeline.request_batch(request_vol1p)
        batch_vol2p = pipeline.request_batch(request_vol2p)
        batch_vol1u = pipeline.request_batch(request_vol1u)
        batch_vol2u = pipeline.request_batch(request_vol2u)


In [17]:
# specify request 
#plt.imshow(np.flipud(batch_vol1[raw].data[0,0,:,:]))

# request = gp.BatchRequest()
# request[raw] = gp.Roi(coord, volSize)

# with gp.build(pipeline):
#     batch_vol1_aug = pipeline.request_batch(request)
#     batch_vol2_aug = pipeline.request_batch(request)

# fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
# ax1.imshow(np.flipud(batch_vol1[raw].data[0,0,:,:]))
# ax1.set_title('input vol1')
# ax2.imshow(np.flipud(batch_vol1_aug[raw].data[0,0,:,:]))
# ax2.set_title('aug vol1')
# ax3.imshow(np.flipud(batch_vol2[raw].data[0,0,:,:]))
# ax3.set_title('input vol2')
# ax4.imshow(np.flipud(batch_vol2_aug[raw].data[0,0,:,:]))
# ax4.set_title('aug vol2')

Add augmentations to loader

In [None]:
# vol1 = batch_vol1_aug[raw].data
# vol2 = batch_vol2_aug[raw].data

# vol1 = np.reshape(vol1, (1,64, 64, 5))
# vol2 = np.reshape(vol2, (1,64, 64, 5))
# y = 1

# vol1 = np.expand_dims(vol1, axis =0)
# vol2 = np.expand_dims(vol2, axis=0)

# loader.append((vol1, vol2, y))
# np.shape(loader)

In [44]:
# loader[2][2]

1

# Define the Model

In [19]:
class Vgg3D(torch.nn.Module):

    def __init__(self, input_size, output_classes, downsample_factors, fmaps=12):

        super(Vgg3D, self).__init__()

        self.input_size = input_size
        self.downsample_factors = downsample_factors
        self.output_classes = 2

        current_fmaps, h, w, d = tuple(input_size)
        current_size = (h, w,d)

        features = []
        for i in range(len(downsample_factors)):

            features += [
                torch.nn.Conv3d(current_fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool3d(downsample_factors[i])
            ]

            current_fmaps = fmaps
            fmaps *= 2

            size = tuple(
                int(c/d)
                for c, d in zip(current_size, downsample_factors[i]))
            check = (
                s*d == c
                for s, d, c in zip(size, downsample_factors[i], current_size))
            assert all(check), \
                "Can not downsample %s by chosen downsample factor" % \
                (current_size,)
            current_size = size

        self.features = torch.nn.Sequential(*features)

        classifier = [
            torch.nn.Linear(current_size[0] *current_size[1]*current_size[2] *current_fmaps,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,output_classes)
        ]

        self.classifier = torch.nn.Sequential(*classifier)
    
    def forward(self, raw):

        # add a channel dimension to raw
        # shape = tuple(raw.shape)
        # raw = raw.reshape(shape[0], 1, shape[1], shape[2])
        
        # compute features
        f = self.features(raw)
        f = f.view(f.size(0), -1)
        
        # classify
        y = self.classifier(f)

        return y

# Loss Functions

We'll probably need to test some different loss functions. List some here:
Contrastive loss
cosine similarity
triplet loss



In [20]:
class ContrastiveLoss(nn.Module):
    "Contrastive loss function"

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + (label)
            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )

        return loss_contrastive

In [21]:
input_size = (1, 64, 64, 5)
downsample_factors =[(2, 2, 1), (2, 2, 1), (2, 2, 1), (2, 2, 1)];
output_classes = 12

# create the model to train
model = Vgg3D(input_size, output_classes,  downsample_factors = downsample_factors)
model = model.to(device)

summary(model, input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 12, 64, 64, 5]             336
       BatchNorm3d-2        [-1, 12, 64, 64, 5]              24
              ReLU-3        [-1, 12, 64, 64, 5]               0
            Conv3d-4        [-1, 12, 64, 64, 5]           3,900
       BatchNorm3d-5        [-1, 12, 64, 64, 5]              24
              ReLU-6        [-1, 12, 64, 64, 5]               0
         MaxPool3d-7        [-1, 12, 32, 32, 5]               0
            Conv3d-8        [-1, 24, 32, 32, 5]           7,800
       BatchNorm3d-9        [-1, 24, 32, 32, 5]              48
             ReLU-10        [-1, 24, 32, 32, 5]               0
           Conv3d-11        [-1, 24, 32, 32, 5]          15,576
      BatchNorm3d-12        [-1, 24, 32, 32, 5]              48
             ReLU-13        [-1, 24, 32, 32, 5]               0
        MaxPool3d-14        [-1, 24, 16

In [36]:
#Training length
epochs = 2000

#loss_function = torch.nn.BCELoss()
loss_function = torch.nn.CosineEmbeddingLoss()
#loss_function = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0005)

# Training Test

# Implementing the Siamese Network

The above training is just to test if the VGG model works for 3D data. Here, the training will take two pairs of images and calculate the loss from both pairs of images.

In [65]:
#%tensorboard --logdir models

logger = SummaryWriter()
%tensorboard --logdir runs

In [68]:
from tqdm import tqdm

def train(tb_logger = None, log_image_interval = 10):
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        
    loss=[] 
    counter=[]
    with gp.build(pipeline):
        for epoch in tqdm(range(epochs)):
            epoch_loss = 0
            epoch_loss_pos = 0
            epoch_loss_neg = 0
            for x in range(100):
                unpaired1 = pipeline.request_batch(request_vol1u)
                unpaired2 = pipeline.request_batch(request_vol2u)
                yu = -1
                
                paired1 = pipeline.request_batch(request_vol1p)
                paired2 = pipeline.request_batch(request_vol2p)
                yp = 1
                
                unpaired1 = unpaired1[raw].data[:,:,:,:]
                unpaired2 = unpaired2[raw].data[:,:,:,:]
                paired1 = paired1[raw].data[:,:,:,:]
                paired2 = paired2[raw].data[:,:,:,:]
                
                unpaired1 = np.reshape(unpaired1, (1,64, 64, 5))
                unpaired2 = np.reshape(unpaired2, (1,64, 64, 5))
                paired1 = np.reshape(paired1, (1,64, 64, 5))
                paired2 = np.reshape(paired2, (1,64, 64, 5))
                
                unpaired1 = np.expand_dims(unpaired1, axis =0)
                unpaired2 = np.expand_dims(unpaired2, axis=0)
                paired1 = np.expand_dims(paired1, axis =0)
                paired2 = np.expand_dims(paired2, axis=0)

                unpaired1 = torch.from_numpy(unpaired1).to(device).float()
                unpaired2 = torch.from_numpy(unpaired2).to(device).float()
                yu = torch.from_numpy(np.array([yu])).to(device).float()
                
                paired1 = torch.from_numpy(paired1).to(device).float()
                paired2 = torch.from_numpy(paired2).to(device).float() 
                yp = torch.from_numpy(np.array([yp])).to(device).float()

                optimizer.zero_grad()
                
                predp1 = model(unpaired1)
                predp2 = model(unpaired2)
                predu1 = model(paired1)
                predu2 = model(paired2)
                # print(model(unpaired1).shape)
                # print(predp1.shape)

                #loss = loss_function(pred, y)

                loss_contrastivep = loss_function(predp1,predp2,yp)
                loss_contrastiveu = loss_function(predu1,predu2,yu)

                loss_contrastivep.backward()
                loss_contrastiveu.backward()
                optimizer.step()    
                epoch_loss_pos += loss_contrastivep
                epoch_loss_neg += loss_contrastiveu
                epoch_loss += loss_contrastivep + loss_contrastiveu
                
                if tb_logger is not None:
                    step = epoch * 100 + x
                    tb_logger.add_scalar(
                        tag="positive_loss", scalar_value=epoch_loss_pos.item(), global_step=step
                    )
                    tb_logger.add_scalar(
                        tag="negative_loss", scalar_value=epoch_loss_neg.item(), global_step=step
                    )
                    tb_logger.add_scalar(
                        tag="total_loss", scalar_value=epoch_loss.item(), global_step = step
                    )
                    # check if we log images in this iteration
                    # if step % log_image_interval == 0:
                    #     tb_logger.add_images(
                    #         tag="in_unpaired1", img_tensor=unpaired1.to("cpu"), global_step=step
                    #     )
                    #     tb_logger.add_images(
                    #         tag="in_unpaired2", img_tensor=unpaired2.to("cpu"), global_step=step
                    #     )
                    #     tb_logger.add_images(
                    #         tag="in_paired1", img_tensor=paired1.to("cpu"), global_step=step
                    #     )
                    #     tb_logger.add_images(
                    #         tag="in_paired2", img_tensor=paired2.to("cpu"), global_step=step
                    #     )


            print(f"epoch {epoch}, total_loss = {epoch_loss}, positive_loss={epoch_loss_pos}, negative_loss={epoch_loss_neg}")
        
    return model

model = train(tb_logger = logger)

  0%|          | 1/2000 [00:33<18:51:41, 33.97s/it]

epoch 0, total_loss = 7.232059478759766, positive_loss=0.5443277955055237, negative_loss=6.687731742858887


  0%|          | 2/2000 [01:08<19:08:49, 34.50s/it]

epoch 1, total_loss = 10.415596008300781, positive_loss=4.628561973571777, negative_loss=5.78703498840332


  0%|          | 3/2000 [01:42<18:56:01, 34.13s/it]

epoch 2, total_loss = 2.9506771564483643, positive_loss=1.1045331954956055, negative_loss=1.8461437225341797


  0%|          | 4/2000 [02:15<18:42:53, 33.75s/it]

epoch 3, total_loss = 4.185060024261475, positive_loss=1.3474764823913574, negative_loss=2.837583065032959


  0%|          | 4/2000 [02:16<18:59:02, 34.24s/it]


KeyboardInterrupt: 

# Tracking / Linear Assignment