# imports

In [3]:
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
import torch.nn as nn
import numpy as np
import random 
import matplotlib.pyplot as plt
from torchvision import models
from torchsummary import summary

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import gunpowder as gp
import zarr
import math
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
import skimage
import networkx
import pathlib
from tifffile import imread, imwrite
import tensorboard
import torch.nn.functional as F
import glob

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


# Model

In [4]:
# model parameters
# input_size is unpacked by Vgg3D as (feature maps, h, w, d)
input_size = (1, 64, 64, 5)
# one (h, w, d) pooling factor per downsampling stage
downsample_factors = [(2, 2, 1)] * 4
output_classes = 12

# model definition
class Vgg3D(torch.nn.Module):
    """VGG-style 3D convolutional network with a fully connected head.

    The feature extractor is a sequence of stages, one per entry of
    ``downsample_factors``. Each stage is two ``Conv3d -> BatchNorm3d -> ReLU``
    groups followed by a ``MaxPool3d`` with that stage's factor. The number of
    feature maps doubles after every stage. The flattened features are fed
    through three ``Linear`` layers (with ReLU and Dropout in between) that
    end in ``output_classes`` outputs.

    Args:
        input_size: 4-sequence ``(channels, h, w, d)`` describing one sample.
        output_classes: Number of outputs of the final linear layer.
        downsample_factors: Sequence of 3-tuples, the max-pooling factor of
            each stage along ``(h, w, d)``.
        fmaps: Number of feature maps produced by the first stage.

    Raises:
        AssertionError: If a spatial size is not an exact multiple of the
            corresponding downsample factor.
    """

    def __init__(self, input_size, output_classes, downsample_factors, fmaps=12):

        super().__init__()

        self.input_size = input_size
        self.downsample_factors = downsample_factors
        # Store the value that actually sizes the last layer (this used to be
        # hard-coded to 2 regardless of the argument).
        self.output_classes = output_classes

        current_fmaps, h, w, d = tuple(input_size)
        current_size = (h, w, d)

        features = []
        for factors in downsample_factors:

            features += [
                torch.nn.Conv3d(current_fmaps, fmaps, kernel_size=3, padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(fmaps, fmaps, kernel_size=3, padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool3d(factors),
            ]

            current_fmaps = fmaps
            fmaps *= 2

            # Track the spatial size after pooling; it determines the input
            # width of the first linear layer below.
            size = tuple(c // f for c, f in zip(current_size, factors))
            divisible = (
                s * f == c
                for s, f, c in zip(size, factors, current_size))
            assert all(divisible), \
                "Can not downsample %s by chosen downsample factor" % \
                (current_size,)
            current_size = size

        self.features = torch.nn.Sequential(*features)

        flat_features = (
            current_size[0] * current_size[1] * current_size[2] * current_fmaps)
        classifier = [
            torch.nn.Linear(flat_features, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, output_classes),
        ]

        self.classifier = torch.nn.Sequential(*classifier)

    def forward(self, raw):
        """Run the feature extractor and classifier on ``raw``.

        Args:
            raw: Input tensor; presumably batched as
                ``(batch, channels, h, w, d)`` to match ``input_size``
                (the first dimension is kept when flattening).

        Returns:
            Tensor with one row per sample and ``output_classes`` columns.
        """
        # compute features
        f = self.features(raw)
        # flatten everything except the batch dimension
        f = f.view(f.size(0), -1)

        # classify
        y = self.classifier(f)

        return y

# build the network from the parameters above and move it to the chosen device
model = Vgg3D(
    input_size,
    output_classes,
    downsample_factors=downsample_factors,
).to(device)
#summary(model, input_size)

# loss function and optimizer
loss_function = nn.CosineEmbeddingLoss()
#loss_function = ContrastiveLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.0005,
)

# Load model

In [5]:
# path to state file
modelstateP = '/mnt/shared/celltracking/modelstates/aaron/'
stateFile = 'epoch_27'

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded when only the CPU is available.
state_dict = torch.load(modelstateP + stateFile, map_location=device)
model.load_state_dict(state_dict)
# switch BatchNorm/Dropout layers to inference behaviour
model.eval()

Vgg3D(
  (features): Sequential(
    (0): Conv3d(1, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(12, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
    (7): Conv3d(12, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (8): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv3d(24, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (11): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), 

# Extract cell- and frame-wise model outputs

In [None]:
# set parameters
volSize = (1, 5, 64, 64)  # shape of the ROI requested around each centroid
zarrdir = '/mnt/shared/celltracking/data/cho/02.zarr'
raw = gp.ArrayKey('raw')

# extract centroids from annotated image stacks
annotationPath = pathlib.Path("/mnt/shared/celltracking/data/cho/")
annotations = np.stack([imread(xi) for xi in sorted((annotationPath / "02_GT/TRA").glob("*.tif"))])  # images
# one [label, frame index, centroid[1], centroid[2]] record per labelled region
cells = []
for t, frame in enumerate(annotations):
    for region in skimage.measure.regionprops(frame):
        cells.append([region.label, t, int(region.centroid[1]), int(region.centroid[2])])

# define gp pipeline
pipeline_allCentroids = (gp.ZarrSource(
    zarrdir,  # the zarr container
    {raw: 'raw'},  # which dataset to associate to the array key
    {raw: gp.ArraySpec(voxel_size=(1, 1, 1, 1), interpolatable=True)})  # meta-information
    + gp.Pad(raw, None))

# loop over all cell centroids
predictions = []
# integer half-extents used to centre the ROI on each centroid
half_x = volSize[2] // 2
half_y = volSize[3] // 2

# Build the pipeline once for all requests instead of once per cell, and
# disable autograd because the model is only evaluated here.
with gp.build(pipeline_allCentroids), torch.no_grad():
    for i, (cell_id, t, x, y) in enumerate(cells, start=1):
        # determine coordinates
        coord = (t, 0, x - half_x, y - half_y)
        request = gp.BatchRequest()
        request[raw] = gp.Roi(coord, volSize)
        batch = pipeline_allCentroids.request_batch(request)

        ## evaluate model for each centroid using gp pipeline
        vol = batch[raw].data
        # NOTE(review): reshape reinterprets the (1, 5, 64, 64) buffer as
        # (1, 64, 64, 5) without moving the size-5 axis; kept as-is on the
        # assumption that training used the same layout - TODO confirm
        # (np.moveaxis/np.transpose would be needed to actually move the axis).
        vol = np.reshape(vol, (1, 64, 64, 5))
        vol = np.expand_dims(vol, axis=0)  # add the batch dimension
        vol = torch.from_numpy(vol).to(device).float()
        pred = model(vol).cpu().numpy()

        # save pred into list with id + position information of THIS cell
        predictions.append([cell_id, t, x, y, pred])
        if i % 50 == 0:
            print(f'done with: {i}/{len(cells)} total')