# Model Implementation for 3D Cell Tracking


In [2]:
!pip install torchsummary 



In [3]:
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
import torch.nn as nn
import numpy as np
import random 
import matplotlib.pyplot as plt
from torchvision import models
from torchsummary import summary

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Designed VGG Model from Exercise 9

We will use a VGG network to classify the synapse images. The input to the network will be a 2D image as provided by your dataloader. The output will be a vector of six floats, corresponding to the probability of the input to belong to the six classes.

Implement a VGG network with the following specifications:

* the constructor takes the size of the 2D input image as height and width
* the network starts with a downsample path consisting of:
    * one convolutional layer, kernel size (3, 3), to create 12 `fmaps`
    * `nn.BatchNorm2d` over those feature maps
    * `nn.ReLU` activation function
    * `nn.Conv2d` layer, kernel size (3, 3), to create 12 `fmaps`
    * `nn.BatchNorm2d` over those feature maps
    * `nn.ReLU` activation function
    * `nn.MaxPool2d` with a `downsample_factor` of (2, 2) at each level
* followed by three more downsampling paths like the one above, every time doubling the number of `fmaps` (i.e., the second one will have 24, the third 48, and the fourth 96). Make sure to keep track of the `current_fmaps` each time!
* then two times:
    * `nn.Linear` layer with `out_features=4096`. Be careful with the `in_features` of the first one, which depends on the size of the previous output!
    * `nn.ReLU` activation function
    * `nn.Dropout`
* Finally, one more fully connected layer with
    * `nn.Linear` to the 6 classes
    * no activation function 

Original One (from https://blog.paperspace.com/vgg-from-scratch-pytorch/)
https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/vgg.py#L24

# Create and Load Artificial Data
Make a fake dataset to test on the VGG model while waiting for data.


In [37]:
# original data size = 512x712x34
# Synthetic two-class stand-in used until the real volumes are available.
# Class 1 is unit Gaussian noise shifted by +0.5; class 0 is unshifted noise.
# Each volume already carries batch and channel axes: (1, 1, 128, 128, 128).
FAKE_DATA_SEED = 23061912  # fixed seed so a fresh "Run All" reproduces the same volumes
FAKE_VOLUME_SHAPE = (1, 1, 128, 128, 128)

rng = np.random.default_rng(FAKE_DATA_SEED)
# float32 is what the training loop casts to anyway; drawing it directly
# halves the memory of each volume compared with the float64 default.
fd_class1 = rng.standard_normal(FAKE_VOLUME_SHAPE, dtype=np.float32) + np.float32(0.5)
fd_class2 = rng.standard_normal(FAKE_VOLUME_SHAPE, dtype=np.float32)
y1 = 1
y2 = 0

# Minimal "loader": a list of (volume, label) pairs that the training loop iterates.
loader = [(fd_class1, y1), (fd_class2, y2)]

# TODO: once there is more than one volume per class, split each class
# 70% / 20% / 10% into train / validation / test with `random_split` and a
# seeded `torch.Generator` (e.g. manual_seed(FAKE_DATA_SEED)).




In [6]:
#sampler = balanced_sampler(train_data_C1)
#dataloader = DataLoader(train_data_C1, batch_size=8, drop_last=True)

# Define the Model

In [7]:
class Vgg3D(torch.nn.Module):
    """VGG-style convolutional classifier for 3D volumes.

    The feature extractor is a stack of levels, one per entry of
    ``downsample_factors``. Each level is Conv3d -> BatchNorm3d -> ReLU ->
    Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d, and the number of feature
    maps doubles from one level to the next. The flattened features are fed
    to three fully connected layers that emit ``output_classes`` raw scores
    (no final activation).

    Args:
        input_size: ``(channels, height, width, depth)`` of one input volume,
            without the batch axis.
        output_classes: Number of values produced per input volume.
        downsample_factors: One ``(fh, fw, fd)`` max-pooling factor per level.
            Every factor must divide the current spatial size exactly.
        fmaps: Number of feature maps produced by the first level.

    Raises:
        AssertionError: If a spatial size is not divisible by its factor.
    """

    def __init__(self, input_size, output_classes, downsample_factors, fmaps=12):

        super(Vgg3D, self).__init__()

        self.input_size = input_size
        self.downsample_factors = downsample_factors
        # Store the value the caller asked for (this used to be hard-coded to 2,
        # which disagreed with the width of the final Linear layer).
        self.output_classes = output_classes

        current_fmaps, h, w, d = tuple(input_size)
        current_size = (h, w, d)

        features = []
        for factor in downsample_factors:

            # Exact integer check; the 3x3x3 convolutions use padding=1 and so
            # keep the spatial size, which means only pooling shrinks it.
            assert all(c % f == 0 for c, f in zip(current_size, factor)), \
                "Can not downsample %s by chosen downsample factor" % \
                (current_size,)

            features += [
                torch.nn.Conv3d(current_fmaps, fmaps, kernel_size=3, padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(fmaps, fmaps, kernel_size=3, padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool3d(factor)
            ]

            # The next level consumes this level's maps and produces twice as many.
            current_fmaps = fmaps
            fmaps *= 2

            current_size = tuple(c // f for c, f in zip(current_size, factor))

        self.features = torch.nn.Sequential(*features)

        # Number of values per sample after flattening the last feature volume.
        flat_features = (
            current_size[0] * current_size[1] * current_size[2] * current_fmaps
        )

        classifier = [
            torch.nn.Linear(flat_features, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, output_classes)
        ]

        self.classifier = torch.nn.Sequential(*classifier)

    def forward(self, raw):
        """Classify a batch of volumes.

        Args:
            raw: Tensor of shape ``(batch, channels, height, width, depth)``.
                The channel axis must already be present; it is not added here.

        Returns:
            Tensor of shape ``(batch, output_classes)`` with raw scores.
        """
        # compute features
        f = self.features(raw)
        # flatten everything except the batch axis for the linear layers
        f = f.view(f.size(0), -1)

        # classify
        y = self.classifier(f)

        return y

# Training and Evaluation

# Loss Functions

We'll probably need to test several different loss functions. Candidates:

* Contrastive loss
* Cosine similarity
* Triplet loss



In [9]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    For each pair, with ``d`` the Euclidean distance between the two
    embeddings, the per-pair loss is::

        (1 - label) * d**2 + label * max(margin - d, 0)**2

    and the result is the mean over the batch. With this formula ``label == 0``
    penalises distance (pair is pulled together) and ``label == 1`` penalises
    closeness up to ``margin`` (pair is pushed apart).

    Args:
        margin: Distance beyond which a ``label == 1`` pair contributes no loss.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss for a batch of embedding pairs.

        Args:
            output1: First embeddings, one row per pair.
            output2: Second embeddings, same shape as ``output1``.
            label: Per-pair labels (0 or 1) that broadcast against the
                per-pair distances.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        # `nn.functional` is reached through the already-imported `nn`; the
        # previous bare `F` was never imported and raised NameError.
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + (label)
            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )

        return loss_contrastive

In [10]:
# Network configuration: one input channel and a 128^3 volume that is halved
# along every axis at each of the four levels.
input_size = (1, 128, 128, 128)
downsample_factors = [(2, 2, 2)] * 4
output_classes = 16

# Build the network and move it onto the selected device in one step.
model = Vgg3D(
    input_size,
    output_classes,
    downsample_factors=downsample_factors,
).to(device)

# Layer-by-layer overview of the network for the configured input size.
summary(model, input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 12, 128, 128, 128]             336
       BatchNorm3d-2    [-1, 12, 128, 128, 128]              24
              ReLU-3    [-1, 12, 128, 128, 128]               0
            Conv3d-4    [-1, 12, 128, 128, 128]           3,900
       BatchNorm3d-5    [-1, 12, 128, 128, 128]              24
              ReLU-6    [-1, 12, 128, 128, 128]               0
         MaxPool3d-7       [-1, 12, 64, 64, 64]               0
            Conv3d-8       [-1, 24, 64, 64, 64]           7,800
       BatchNorm3d-9       [-1, 24, 64, 64, 64]              48
             ReLU-10       [-1, 24, 64, 64, 64]               0
           Conv3d-11       [-1, 24, 64, 64, 64]          15,576
      BatchNorm3d-12       [-1, 24, 64, 64, 64]              48
             ReLU-13       [-1, 24, 64, 64, 64]               0
        MaxPool3d-14       [-1, 24, 32,

In [40]:
#Training length
epochs = 2000

loss_function = torch.nn.BCELoss()
#loss_function = torch.nn.CosineSimilarity()
#loss_function = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0005)

# Training Test

In [41]:
from tqdm import tqdm

def train():
    """Run the single-volume sanity-check training loop and return the model.

    Relies on the notebook-level ``model``, ``loader``, ``loss_function``,
    ``optimizer`` and ``epochs``. Each item of ``loader`` is expected to be a
    ``(numpy volume, integer label)`` pair, as built in the fake-data cell.

    Returns:
        The notebook-level ``model`` after training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Make sure dropout and batch norm are in training mode even if the model
    # was switched to eval mode by an earlier cell.
    model.train()

    for epoch in tqdm(range(epochs)):
        epoch_loss = 0.0
        for x, y in loader:

            x = torch.from_numpy(x).to(device).float()
            y = torch.from_numpy(np.array([y])).to(device).float()

            optimizer.zero_grad()

            # Average the output vector into one value per sample and squash it
            # to a probability, which is what the BCE loss expects.
            pred = model(x).mean(axis=1).sigmoid()
            batch_loss = loss_function(pred, y)

            batch_loss.backward()
            optimizer.step()
            # Accumulate a plain float: adding the tensor itself would keep
            # autograd history attached to the running total.
            epoch_loss += batch_loss.item()

        print(f"epoch {epoch}, training loss={epoch_loss}")

    return model

model = train()

  0%|                                        | 1/2000 [00:02<1:20:59,  2.43s/it]

epoch 0, training loss=2.3841860752327193e-07


  0%|                                        | 2/2000 [00:04<1:18:43,  2.36s/it]

epoch 1, training loss=2.3841887468734058e-06


  0%|                                        | 3/2000 [00:07<1:17:59,  2.34s/it]

epoch 2, training loss=2.6010806560516357


  0%|                                        | 4/2000 [00:09<1:17:32,  2.33s/it]

epoch 3, training loss=0.00338159897364676


  0%|                                        | 5/2000 [00:11<1:17:16,  2.32s/it]

epoch 4, training loss=0.01047830656170845


  0%|                                        | 6/2000 [00:14<1:17:05,  2.32s/it]

epoch 5, training loss=2.414116859436035


  0%|▏                                       | 7/2000 [00:16<1:16:57,  2.32s/it]

epoch 6, training loss=1.1742184142349288e-05


  0%|▏                                       | 8/2000 [00:18<1:16:53,  2.32s/it]

epoch 7, training loss=0.0


  0%|▏                                       | 9/2000 [00:20<1:16:54,  2.32s/it]

epoch 8, training loss=0.00011981251009274274


  0%|▏                                      | 10/2000 [00:23<1:16:53,  2.32s/it]

epoch 9, training loss=4.734710693359375


  1%|▏                                      | 11/2000 [00:25<1:16:50,  2.32s/it]

epoch 10, training loss=2.3841860752327193e-07


  1%|▏                                      | 12/2000 [00:27<1:16:44,  2.32s/it]

epoch 11, training loss=0.0


  1%|▎                                      | 13/2000 [00:30<1:16:42,  2.32s/it]

epoch 12, training loss=1.788139485370266e-07


  1%|▎                                      | 14/2000 [00:32<1:16:40,  2.32s/it]

epoch 13, training loss=0.0


  1%|▎                                      | 15/2000 [00:34<1:16:35,  2.32s/it]

epoch 14, training loss=0.0


  1%|▎                                      | 16/2000 [00:37<1:16:32,  2.31s/it]

epoch 15, training loss=1.585496102052275e-05


  1%|▎                                      | 17/2000 [00:39<1:16:30,  2.31s/it]

epoch 16, training loss=0.0


  1%|▎                                      | 18/2000 [00:41<1:16:28,  2.32s/it]

epoch 17, training loss=0.001489257556386292


  1%|▎                                      | 19/2000 [00:44<1:16:25,  2.31s/it]

epoch 18, training loss=0.0


  1%|▍                                      | 20/2000 [00:46<1:16:23,  2.31s/it]

epoch 19, training loss=0.0


  1%|▍                                      | 21/2000 [00:48<1:16:19,  2.31s/it]

epoch 20, training loss=0.0


  1%|▍                                      | 22/2000 [00:51<1:16:18,  2.31s/it]

epoch 21, training loss=5.960464477539063e-08


  1%|▍                                      | 23/2000 [00:53<1:16:20,  2.32s/it]

epoch 22, training loss=3.099446303167497e-06


  1%|▍                                      | 24/2000 [00:55<1:16:19,  2.32s/it]

epoch 23, training loss=5.960464477539063e-08


  1%|▍                                      | 25/2000 [00:58<1:16:18,  2.32s/it]

epoch 24, training loss=0.0


  1%|▌                                      | 26/2000 [01:00<1:16:15,  2.32s/it]

epoch 25, training loss=0.0


  1%|▌                                      | 27/2000 [01:02<1:16:12,  2.32s/it]

epoch 26, training loss=0.0


  1%|▌                                      | 28/2000 [01:04<1:16:08,  2.32s/it]

epoch 27, training loss=0.0


  1%|▌                                      | 29/2000 [01:07<1:16:03,  2.32s/it]

epoch 28, training loss=0.0


  2%|▌                                      | 30/2000 [01:09<1:16:01,  2.32s/it]

epoch 29, training loss=0.0


  2%|▌                                      | 31/2000 [01:11<1:16:01,  2.32s/it]

epoch 30, training loss=0.0


  2%|▌                                      | 32/2000 [01:14<1:16:00,  2.32s/it]

epoch 31, training loss=0.0


  2%|▋                                      | 33/2000 [01:16<1:15:58,  2.32s/it]

epoch 32, training loss=0.0


  2%|▋                                      | 34/2000 [01:18<1:15:58,  2.32s/it]

epoch 33, training loss=5.960464477539063e-08


  2%|▋                                      | 35/2000 [01:21<1:15:55,  2.32s/it]

epoch 34, training loss=0.0


  2%|▋                                      | 36/2000 [01:23<1:15:55,  2.32s/it]

epoch 35, training loss=5.960464477539063e-08


  2%|▋                                      | 37/2000 [01:25<1:15:53,  2.32s/it]

epoch 36, training loss=1.788139485370266e-07


  2%|▋                                      | 38/2000 [01:28<1:15:51,  2.32s/it]

epoch 37, training loss=0.0


  2%|▊                                      | 39/2000 [01:30<1:15:48,  2.32s/it]

epoch 38, training loss=5.960464477539063e-08


  2%|▊                                      | 40/2000 [01:32<1:15:45,  2.32s/it]

epoch 39, training loss=0.0


  2%|▊                                      | 41/2000 [01:35<1:15:43,  2.32s/it]

epoch 40, training loss=0.0


  2%|▊                                      | 42/2000 [01:37<1:15:43,  2.32s/it]

epoch 41, training loss=0.0


  2%|▊                                      | 43/2000 [01:39<1:15:41,  2.32s/it]

epoch 42, training loss=1.1920930376163597e-07


  2%|▊                                      | 44/2000 [01:42<1:15:39,  2.32s/it]

epoch 43, training loss=0.0


  2%|▉                                      | 45/2000 [01:44<1:15:38,  2.32s/it]

epoch 44, training loss=0.011329271830618382


  2%|▉                                      | 46/2000 [01:46<1:15:36,  2.32s/it]

epoch 45, training loss=0.0


  2%|▉                                      | 47/2000 [01:49<1:15:34,  2.32s/it]

epoch 46, training loss=0.0


  2%|▉                                      | 48/2000 [01:51<1:15:35,  2.32s/it]

epoch 47, training loss=1.7881409348774469e-06


  2%|▉                                      | 49/2000 [01:53<1:15:34,  2.32s/it]

epoch 48, training loss=0.2014390528202057


  2%|▉                                      | 50/2000 [01:56<1:15:29,  2.32s/it]

epoch 49, training loss=0.0


  3%|▉                                      | 51/2000 [01:58<1:15:28,  2.32s/it]

epoch 50, training loss=5.960464477539063e-08


  3%|█                                      | 52/2000 [02:00<1:15:24,  2.32s/it]

epoch 51, training loss=0.05431920662522316


  3%|█                                      | 53/2000 [02:02<1:15:21,  2.32s/it]

epoch 52, training loss=0.0


  3%|█                                      | 54/2000 [02:05<1:15:22,  2.32s/it]

epoch 53, training loss=0.0


  3%|█                                      | 55/2000 [02:07<1:15:18,  2.32s/it]

epoch 54, training loss=0.0


  3%|█                                      | 56/2000 [02:09<1:15:17,  2.32s/it]

epoch 55, training loss=0.0


  3%|█                                      | 57/2000 [02:12<1:15:14,  2.32s/it]

epoch 56, training loss=0.0


  3%|█▏                                     | 58/2000 [02:14<1:15:14,  2.32s/it]

epoch 57, training loss=0.0


  3%|█▏                                     | 59/2000 [02:16<1:15:12,  2.32s/it]

epoch 58, training loss=0.0


  3%|█▏                                     | 60/2000 [02:19<1:15:09,  2.32s/it]

epoch 59, training loss=0.0


  3%|█▏                                     | 61/2000 [02:21<1:15:09,  2.33s/it]

epoch 60, training loss=0.0


  3%|█▏                                     | 62/2000 [02:23<1:15:08,  2.33s/it]

epoch 61, training loss=2.3841887468734058e-06


  3%|█▏                                     | 63/2000 [02:26<1:15:03,  2.32s/it]

epoch 62, training loss=0.0


  3%|█▏                                     | 64/2000 [02:28<1:14:58,  2.32s/it]

epoch 63, training loss=0.0


  3%|█▎                                     | 65/2000 [02:30<1:15:00,  2.33s/it]

epoch 64, training loss=0.0


  3%|█▎                                     | 66/2000 [02:33<1:14:57,  2.33s/it]

epoch 65, training loss=0.0


  3%|█▎                                     | 67/2000 [02:35<1:14:53,  2.32s/it]

epoch 66, training loss=0.0


  3%|█▎                                     | 68/2000 [02:37<1:14:50,  2.32s/it]

epoch 67, training loss=0.0


  3%|█▎                                     | 69/2000 [02:40<1:14:48,  2.32s/it]

epoch 68, training loss=0.0


  4%|█▎                                     | 70/2000 [02:42<1:14:44,  2.32s/it]

epoch 69, training loss=7.152560215217818e-07


  4%|█▍                                     | 71/2000 [02:44<1:14:37,  2.32s/it]

epoch 70, training loss=0.0


  4%|█▍                                     | 72/2000 [02:47<1:14:32,  2.32s/it]

epoch 71, training loss=0.0


  4%|█▍                                     | 73/2000 [02:49<1:14:34,  2.32s/it]

epoch 72, training loss=0.0


  4%|█▍                                     | 74/2000 [02:51<1:14:34,  2.32s/it]

epoch 73, training loss=7.033372639853042e-06


  4%|█▍                                     | 75/2000 [02:54<1:14:32,  2.32s/it]

epoch 74, training loss=0.0


  4%|█▍                                     | 76/2000 [02:56<1:14:30,  2.32s/it]

epoch 75, training loss=0.0


  4%|█▌                                     | 77/2000 [02:58<1:14:26,  2.32s/it]

epoch 76, training loss=0.0


  4%|█▌                                     | 78/2000 [03:01<1:14:28,  2.32s/it]

epoch 77, training loss=3.4570753086882178e-06


  4%|█▌                                     | 79/2000 [03:03<1:14:24,  2.32s/it]

epoch 78, training loss=0.0


  4%|█▌                                     | 80/2000 [03:05<1:14:24,  2.33s/it]

epoch 79, training loss=0.0


  4%|█▌                                     | 81/2000 [03:08<1:14:24,  2.33s/it]

epoch 80, training loss=0.0


  4%|█▌                                     | 82/2000 [03:10<1:14:25,  2.33s/it]

epoch 81, training loss=0.0


  4%|█▌                                     | 83/2000 [03:12<1:14:26,  2.33s/it]

epoch 82, training loss=4.172333774477011e-06


  4%|█▋                                     | 84/2000 [03:15<1:14:21,  2.33s/it]

epoch 83, training loss=0.0


  4%|█▋                                     | 85/2000 [03:17<1:14:22,  2.33s/it]

epoch 84, training loss=2.3841860752327193e-07


  4%|█▋                                     | 86/2000 [03:19<1:14:27,  2.33s/it]

epoch 85, training loss=0.0


  4%|█▋                                     | 87/2000 [03:22<1:14:20,  2.33s/it]

epoch 86, training loss=0.0


  4%|█▋                                     | 88/2000 [03:24<1:14:14,  2.33s/it]

epoch 87, training loss=6.55653229841846e-06


  4%|█▋                                     | 89/2000 [03:26<1:14:07,  2.33s/it]

epoch 88, training loss=0.0


  4%|█▊                                     | 90/2000 [03:29<1:14:06,  2.33s/it]

epoch 89, training loss=0.0


  5%|█▊                                     | 91/2000 [03:31<1:14:05,  2.33s/it]

epoch 90, training loss=0.0


  5%|█▊                                     | 92/2000 [03:33<1:14:05,  2.33s/it]

epoch 91, training loss=0.0


  5%|█▊                                     | 93/2000 [03:36<1:14:08,  2.33s/it]

epoch 92, training loss=0.0


  5%|█▊                                     | 94/2000 [03:38<1:14:09,  2.33s/it]

epoch 93, training loss=0.00015498408174607903


  5%|█▊                                     | 95/2000 [03:40<1:14:07,  2.33s/it]

epoch 94, training loss=0.0


  5%|█▊                                     | 96/2000 [03:43<1:14:04,  2.33s/it]

epoch 95, training loss=0.0


  5%|█▉                                     | 97/2000 [03:45<1:14:05,  2.34s/it]

epoch 96, training loss=0.0


  5%|█▉                                     | 98/2000 [03:47<1:13:56,  2.33s/it]

epoch 97, training loss=0.0


  5%|█▉                                     | 99/2000 [03:50<1:13:54,  2.33s/it]

epoch 98, training loss=0.0


  5%|█▉                                     | 99/2000 [03:51<1:13:59,  2.34s/it]


KeyboardInterrupt: 

# Implementing the Siamese Network

The training above only tests whether the VGG model works on 3D data. Here, training will instead take pairs of images and calculate the loss from both images of each pair.

In [39]:
from tqdm import tqdm

def train():
    """Train the notebook-level model and return it.

    This currently performs the same single-volume loop as the ``train``
    defined in the "Training Test" section and shares its name, so whichever
    of the two cells ran last is the one in effect.

    Relies on the notebook-level ``model``, ``loader``, ``loss_function``,
    ``optimizer`` and ``epochs``. Each item of ``loader`` is expected to be a
    ``(numpy volume, integer label)`` pair.

    Returns:
        The notebook-level ``model`` after training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Make sure dropout and batch norm are in training mode even if the model
    # was switched to eval mode by an earlier cell.
    model.train()

    for epoch in tqdm(range(epochs)):
        epoch_loss = 0.0
        for x, y in loader:

            x = torch.from_numpy(x).to(device).float()
            y = torch.from_numpy(np.array([y])).to(device).float()

            # TODO: pair-based (Siamese) variant, sketched but not enabled yet:
            #   vol0, vol1, label = data
            #   output1, output2 = model(vol0), model(vol1)
            #   loss_contrastive = loss_function(output1, output2, label)

            optimizer.zero_grad()

            # Average the output vector into one value per sample and squash it
            # to a probability, which is what the BCE loss expects.
            pred = model(x).mean(axis=1).sigmoid()
            batch_loss = loss_function(pred, y)

            batch_loss.backward()
            optimizer.step()
            # Accumulate a plain float: adding the tensor itself would keep
            # autograd history attached to the running total.
            epoch_loss += batch_loss.item()

        print(f"epoch {epoch}, training loss={epoch_loss}")

    return model

model = train()

In [33]:
# Sanity check: show the array shape and the label of every loader entry.
for i, (x, y) in enumerate(loader):
    print(x.shape)
    print(y)
#print(x)
#print(y)

(1, 1, 128, 128, 128)
1
(1, 1, 128, 128, 128)
0
