# Model Implementation for 3D Cell Tracking


In [5]:
!pip install torchsummary 



In [6]:
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
import torch.nn as nn
import numpy as np
import random 
import matplotlib.pyplot as plt
from torchvision import models
from torchsummary import summary

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Designed VGG Model from Exercise 9

We will use a VGG network to classify the synapse images. The input to the network will be a 2D image as provided by your dataloader. The output will be a vector of six floats, corresponding to the probability of the input to belong to the six classes.

Implement a VGG network with the following specificatons:

* the constructor takes the size of the 2D input image as height and width
* the network starts with a downsample path consisting of:
    * one convolutional layer, kernel size (3, 3), to create 12 `fmaps`
    * `nn.BatchNorm2d` over those feature maps
    * `nn.ReLU` activation function
    * `nn.Conv2d` layer, kernel size (3, 3), to create 12 `fmaps`
    * `nn.BatchNorm2d` over those feature maps
    * `nn.ReLU` activation function
    * `nn.MaxPool2d` with a `downsample_factor` of (2, 2) at each level
* followed by three more downsampling paths like the one above, every time doubling the number of `fmaps` (i.e., the second one will have 24, the third 48, and the fourth 96). Make sure to keep track of the `current_fmaps` each time!
* then two times:
    * `nn.Linear` layer with `out_features=4096`. Be careful withe in `in_features` of the first one, which will depend on the size of the previous output!
    * `nn.ReLU` activation function
    * `nn.DropOut`
* Finally, one more fully connected layer with
    * `nn.Linear` to the 6 classes
    * no activation function 

Original One (from https://blog.paperspace.com/vgg-from-scratch-pytorch/)
https://github.com/pytorch/vision/blob/6db1569c89094cf23f3bc41f79275c45e9fcb3f3/torchvision/models/vgg.py#L24

# Create and Load Artificial Data
Make a fake dataset to test on the VGG model while waiting for data.


In [33]:
#original data size = 512x712x34 

#fd_class1 = np.random.randn(1, 1, 128,128,128) + 0.5
#fd_class2 = np.random.randn(1,1,128,128,128)
fd_class1 = np.random.randn(1, 1, 128,128,5) + 0.5
fd_class2 = np.random.randn(1,1,128,128,5)
y1 = 1
y2 = 0
# fd_class1 = np.expand_dims(fd_class1, axis=0)
# fd_class2 = np.expand_dims(fd_class2, axis=0)

loader = [(fd_class1, y1), (fd_class2,y2)]

#Split
# train_set_size = int(len(fd_class1) * 0.7)
# valid_set_size = int(len(fd_class1) * 0.2)
# test_data_size = len(fd_class1) - train_set_size - valid_set_size
    
# train_data_C1, val_data_C1, test_data_C1 = random_split(
#     fd_class1,
#     [train_set_size, valid_set_size, test_data_size],
#     generator=torch.Generator().manual_seed(23061912))

# train_data_C2, val_data_C2, test_data_C2 = random_split(
#     fd_class2,
#     [train_set_size, valid_set_size, test_data_size],
#     generator=torch.Generator().manual_seed(23061912))

#train = np.ndarray.flatten(fd_class1)
#train2 = np.ndarray.flatten(fd_class2)




In [8]:
#sampler = balanced_sampler(train_data_C1)
#dataloader = DataLoader(train_data_C1, batch_size=8, drop_last=True)

# Define the Model

In [9]:
class Vgg3D(torch.nn.Module):

    def __init__(self, input_size, output_classes, downsample_factors, fmaps=12):

        super(Vgg3D, self).__init__()

        self.input_size = input_size
        self.downsample_factors = downsample_factors
        self.output_classes = 2

        current_fmaps, h, w, d = tuple(input_size)
        current_size = (h, w,d)

        features = []
        for i in range(len(downsample_factors)):

            features += [
                torch.nn.Conv3d(current_fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool3d(downsample_factors[i])
            ]

            current_fmaps = fmaps
            fmaps *= 2

            size = tuple(
                int(c/d)
                for c, d in zip(current_size, downsample_factors[i]))
            check = (
                s*d == c
                for s, d, c in zip(size, downsample_factors[i], current_size))
            assert all(check), \
                "Can not downsample %s by chosen downsample factor" % \
                (current_size,)
            current_size = size

        self.features = torch.nn.Sequential(*features)

        classifier = [
            torch.nn.Linear(current_size[0] *current_size[1]*current_size[2] *current_fmaps,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,output_classes)
        ]

        self.classifier = torch.nn.Sequential(*classifier)
    
    def forward(self, raw):

        # add a channel dimension to raw
        # shape = tuple(raw.shape)
        # raw = raw.reshape(shape[0], 1, shape[1], shape[2])
        
        # compute features
        f = self.features(raw)
        f = f.view(f.size(0), -1)
        
        # classify
        y = self.classifier(f)

        return y

# Training and Evaluation

# Loss Functions

We'll probably need to test some different loss functions. List some here:
Contrastive loss
cosine similarity
triplet loss



In [10]:
class ContrastiveLoss(nn.Module):
    "Contrastive loss function"

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + (label)
            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )

        return loss_contrastive

In [36]:
input_size = (1, 64, 64, 2)
downsample_factors =[(2, 2, 1), (2, 2, 1), (2, 2, 1), (2, 2, 1)];
output_classes = 12

# create the model to train
model = Vgg3D(input_size, output_classes,  downsample_factors = downsample_factors)
model = model.to(device)

summary(model, input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 12, 64, 64, 2]             336
       BatchNorm3d-2        [-1, 12, 64, 64, 2]              24
              ReLU-3        [-1, 12, 64, 64, 2]               0
            Conv3d-4        [-1, 12, 64, 64, 2]           3,900
       BatchNorm3d-5        [-1, 12, 64, 64, 2]              24
              ReLU-6        [-1, 12, 64, 64, 2]               0
         MaxPool3d-7        [-1, 12, 32, 32, 2]               0
            Conv3d-8        [-1, 24, 32, 32, 2]           7,800
       BatchNorm3d-9        [-1, 24, 32, 32, 2]              48
             ReLU-10        [-1, 24, 32, 32, 2]               0
           Conv3d-11        [-1, 24, 32, 32, 2]          15,576
      BatchNorm3d-12        [-1, 24, 32, 32, 2]              48
             ReLU-13        [-1, 24, 32, 32, 2]               0
        MaxPool3d-14        [-1, 24, 16

In [12]:
#Training length
epochs = 2000

#loss_function = torch.nn.BCELoss()
loss_function = torch.nn.CosineEmbeddingLoss()
#loss_function = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0005)

# Training Test

In [52]:
from tqdm import tqdm

def train():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        
    loss=[] 
    counter=[]
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0
        for i, (x, y) in enumerate(loader):
            
            x = torch.from_numpy(x).to(device).float()
            y = torch.from_numpy(np.array([y])).to(device).float()
            
            #vol0, vol1 , label = data
            #vol0, vol1 , label = vol0.to(device), vol1.to(device) , label.to(device)
            
            optimizer.zero_grad()
            
            pred = model(x).mean(axis = 1).sigmoid()
            #print(model(x))
            #output1 = model(vol0)
            #output2 = model(vol1)
            
            #print(pred,y)
            loss = loss_function(pred, y)
            #loss_contrastive = loss_function(output1,output2,label)
            
            loss.backward()
            optimizer.step()    
            epoch_loss += loss
            
        print(f"epoch {epoch}, training loss={epoch_loss}")
    
    #show_plot(epoch, loss)
    
    return model

model = train()

  0%|                                        | 1/2000 [00:02<1:21:15,  2.44s/it]

epoch 0, training loss=3.0052499771118164


  0%|                                        | 2/2000 [00:04<1:18:54,  2.37s/it]

epoch 1, training loss=3.012202262878418


  0%|                                        | 3/2000 [00:07<1:18:07,  2.35s/it]

epoch 2, training loss=1.7127869129180908


  0%|                                        | 4/2000 [00:09<1:17:42,  2.34s/it]

epoch 3, training loss=1.0924696922302246


  0%|                                        | 5/2000 [00:11<1:17:26,  2.33s/it]

epoch 4, training loss=1.3951603174209595


  0%|                                        | 6/2000 [00:14<1:17:17,  2.33s/it]

epoch 5, training loss=3.6244781017303467


  0%|▏                                       | 7/2000 [00:16<1:17:10,  2.32s/it]

epoch 6, training loss=1.4776408672332764


  0%|▏                                       | 8/2000 [00:18<1:17:05,  2.32s/it]

epoch 7, training loss=0.8746283650398254


  0%|▏                                       | 9/2000 [00:20<1:17:01,  2.32s/it]

epoch 8, training loss=2.9282970428466797


  0%|▏                                      | 10/2000 [00:23<1:16:59,  2.32s/it]

epoch 9, training loss=0.2575957477092743


  1%|▏                                      | 11/2000 [00:25<1:16:58,  2.32s/it]

epoch 10, training loss=0.12791509926319122


  1%|▏                                      | 12/2000 [00:27<1:16:54,  2.32s/it]

epoch 11, training loss=0.32547080516815186


  1%|▎                                      | 13/2000 [00:30<1:16:57,  2.32s/it]

epoch 12, training loss=0.06220933794975281


  1%|▎                                      | 14/2000 [00:32<1:16:55,  2.32s/it]

epoch 13, training loss=0.05425200238823891


  1%|▎                                      | 15/2000 [00:34<1:16:53,  2.32s/it]

epoch 14, training loss=0.07212472707033157


  1%|▎                                      | 16/2000 [00:37<1:16:50,  2.32s/it]

epoch 15, training loss=0.12383498251438141


  1%|▎                                      | 17/2000 [00:39<1:16:49,  2.32s/it]

epoch 16, training loss=0.01637139357626438


  1%|▎                                      | 18/2000 [00:41<1:16:50,  2.33s/it]

epoch 17, training loss=0.03595505282282829


  1%|▎                                      | 19/2000 [00:44<1:16:47,  2.33s/it]

epoch 18, training loss=0.15058046579360962


  1%|▍                                      | 20/2000 [00:46<1:16:45,  2.33s/it]

epoch 19, training loss=1.123289704322815


  1%|▍                                      | 21/2000 [00:48<1:16:43,  2.33s/it]

epoch 20, training loss=0.010253466665744781


  1%|▍                                      | 22/2000 [00:51<1:16:38,  2.32s/it]

epoch 21, training loss=0.006728644948452711


  1%|▍                                      | 23/2000 [00:53<1:16:37,  2.33s/it]

epoch 22, training loss=0.006067739799618721


  1%|▍                                      | 24/2000 [00:55<1:16:35,  2.33s/it]

epoch 23, training loss=0.0001572479377500713


  1%|▍                                      | 25/2000 [00:58<1:16:32,  2.33s/it]

epoch 24, training loss=5.602850251307245e-06


  1%|▌                                      | 26/2000 [01:00<1:16:37,  2.33s/it]

epoch 25, training loss=0.0001462208601878956


  1%|▌                                      | 27/2000 [01:02<1:16:35,  2.33s/it]

epoch 26, training loss=3.576279254957626e-07


  1%|▌                                      | 28/2000 [01:05<1:16:28,  2.33s/it]

epoch 27, training loss=7.748606662971724e-07


  1%|▌                                      | 29/2000 [01:07<1:16:27,  2.33s/it]

epoch 28, training loss=0.0012869765050709248


  2%|▌                                      | 30/2000 [01:09<1:16:23,  2.33s/it]

epoch 29, training loss=0.00023771745327394456


  2%|▌                                      | 31/2000 [01:12<1:16:21,  2.33s/it]

epoch 30, training loss=4.5895640141679905e-06


  2%|▌                                      | 32/2000 [01:14<1:16:19,  2.33s/it]

epoch 31, training loss=1.2219012205605395e-05


  2%|▋                                      | 33/2000 [01:16<1:16:17,  2.33s/it]

epoch 32, training loss=4.529963462118758e-06


  2%|▋                                      | 34/2000 [01:19<1:16:14,  2.33s/it]

epoch 33, training loss=0.007000154349952936


  2%|▋                                      | 35/2000 [01:21<1:16:11,  2.33s/it]

epoch 34, training loss=1.0430863767396659e-05


  2%|▋                                      | 36/2000 [01:23<1:16:11,  2.33s/it]

epoch 35, training loss=0.0003237652126699686


  2%|▋                                      | 37/2000 [01:26<1:16:10,  2.33s/it]

epoch 36, training loss=0.00046669403673149645


  2%|▋                                      | 38/2000 [01:28<1:16:06,  2.33s/it]

epoch 37, training loss=5.9008771131630056e-06


  2%|▊                                      | 39/2000 [01:30<1:16:16,  2.33s/it]

epoch 38, training loss=0.0003932891704607755


  2%|▊                                      | 40/2000 [01:33<1:16:24,  2.34s/it]

epoch 39, training loss=1.4305124977909145e-06


  2%|▊                                      | 41/2000 [01:35<1:16:35,  2.35s/it]

epoch 40, training loss=5.042655539000407e-05


  2%|▊                                      | 42/2000 [01:37<1:16:43,  2.35s/it]

epoch 41, training loss=1.6689440599293448e-05


  2%|▊                                      | 43/2000 [01:40<1:16:30,  2.35s/it]

epoch 42, training loss=1.1920935776288388e-06


  2%|▊                                      | 44/2000 [01:42<1:16:15,  2.34s/it]

epoch 43, training loss=2.0265599687263602e-06


  2%|▉                                      | 45/2000 [01:44<1:16:04,  2.33s/it]

epoch 44, training loss=1.001360851660138e-05


  2%|▉                                      | 46/2000 [01:47<1:15:56,  2.33s/it]

epoch 45, training loss=2.3841860752327193e-07


  2%|▉                                      | 47/2000 [01:49<1:15:50,  2.33s/it]

epoch 46, training loss=9.119551577896345e-06


  2%|▉                                      | 48/2000 [01:51<1:15:49,  2.33s/it]

epoch 47, training loss=0.001830404158681631


  2%|▉                                      | 49/2000 [01:54<1:16:01,  2.34s/it]

epoch 48, training loss=9.668338316259906e-05


  2%|▉                                      | 50/2000 [01:56<1:16:10,  2.34s/it]

epoch 49, training loss=9.990236139856279e-05


  3%|▉                                      | 51/2000 [01:58<1:16:10,  2.35s/it]

epoch 50, training loss=5.090366175863892e-05


  3%|█                                      | 52/2000 [02:01<1:16:12,  2.35s/it]

epoch 51, training loss=0.008482864126563072


  3%|█                                      | 53/2000 [02:03<1:16:10,  2.35s/it]

epoch 52, training loss=0.019860263913869858


  3%|█                                      | 54/2000 [02:05<1:16:05,  2.35s/it]

epoch 53, training loss=0.5437625646591187


  3%|█                                      | 55/2000 [02:08<1:15:48,  2.34s/it]

epoch 54, training loss=0.0


  3%|█                                      | 56/2000 [02:10<1:15:39,  2.34s/it]

epoch 55, training loss=0.12990209460258484


  3%|█                                      | 57/2000 [02:12<1:15:34,  2.33s/it]

epoch 56, training loss=0.008733268827199936


  3%|█▏                                     | 58/2000 [02:15<1:15:42,  2.34s/it]

epoch 57, training loss=0.004572039004415274


  3%|█▏                                     | 59/2000 [02:17<1:15:47,  2.34s/it]

epoch 58, training loss=7.152560215217818e-07


  3%|█▏                                     | 60/2000 [02:19<1:15:46,  2.34s/it]

epoch 59, training loss=4.768382950715022e-06


  3%|█▏                                     | 61/2000 [02:22<1:15:42,  2.34s/it]

epoch 60, training loss=8.344653679159819e-07


  3%|█▏                                     | 62/2000 [02:24<1:15:33,  2.34s/it]

epoch 61, training loss=0.0


  3%|█▏                                     | 63/2000 [02:26<1:15:25,  2.34s/it]

epoch 62, training loss=5.841272468387615e-06


  3%|█▏                                     | 64/2000 [02:29<1:15:27,  2.34s/it]

epoch 63, training loss=2.4437933916487964e-06


  3%|█▎                                     | 65/2000 [02:31<1:15:17,  2.33s/it]

epoch 64, training loss=0.0


  3%|█▎                                     | 66/2000 [02:33<1:15:09,  2.33s/it]

epoch 65, training loss=0.0


  3%|█▎                                     | 67/2000 [02:36<1:15:05,  2.33s/it]

epoch 66, training loss=0.0


  3%|█▎                                     | 68/2000 [02:38<1:14:57,  2.33s/it]

epoch 67, training loss=1.1920930376163597e-07


  3%|█▎                                     | 69/2000 [02:40<1:14:51,  2.33s/it]

epoch 68, training loss=0.0009415408712811768


  4%|█▎                                     | 70/2000 [02:43<1:14:48,  2.33s/it]

epoch 69, training loss=0.0006705385749228299


  4%|█▍                                     | 71/2000 [02:45<1:14:47,  2.33s/it]

epoch 70, training loss=1.2695870282186661e-05


  4%|█▍                                     | 72/2000 [02:47<1:14:43,  2.33s/it]

epoch 71, training loss=0.0


  4%|█▍                                     | 73/2000 [02:50<1:14:46,  2.33s/it]

epoch 72, training loss=0.0004028892144560814


  4%|█▍                                     | 74/2000 [02:52<1:14:44,  2.33s/it]

epoch 73, training loss=0.0


  4%|█▍                                     | 75/2000 [02:54<1:14:48,  2.33s/it]

epoch 74, training loss=1.1920930376163597e-07


  4%|█▍                                     | 76/2000 [02:57<1:14:51,  2.33s/it]

epoch 75, training loss=0.004787262994796038


  4%|█▌                                     | 77/2000 [02:59<1:14:51,  2.34s/it]

epoch 76, training loss=1.1920930376163597e-07


  4%|█▌                                     | 78/2000 [03:01<1:14:49,  2.34s/it]

epoch 77, training loss=0.0


  4%|█▌                                     | 79/2000 [03:04<1:14:44,  2.33s/it]

epoch 78, training loss=1.5974172129062936e-05


  4%|█▌                                     | 80/2000 [03:06<1:14:38,  2.33s/it]

epoch 79, training loss=0.0


  4%|█▌                                     | 81/2000 [03:08<1:14:35,  2.33s/it]

epoch 80, training loss=0.0


  4%|█▌                                     | 82/2000 [03:11<1:14:32,  2.33s/it]

epoch 81, training loss=0.00512600177899003


  4%|█▌                                     | 83/2000 [03:13<1:14:24,  2.33s/it]

epoch 82, training loss=0.0


  4%|█▋                                     | 84/2000 [03:15<1:14:31,  2.33s/it]

epoch 83, training loss=0.0


  4%|█▋                                     | 85/2000 [03:18<1:14:29,  2.33s/it]

epoch 84, training loss=0.0


  4%|█▋                                     | 86/2000 [03:20<1:14:24,  2.33s/it]

epoch 85, training loss=0.0


  4%|█▋                                     | 87/2000 [03:22<1:14:21,  2.33s/it]

epoch 86, training loss=0.0


  4%|█▋                                     | 88/2000 [03:25<1:14:15,  2.33s/it]

epoch 87, training loss=0.0


  4%|█▋                                     | 89/2000 [03:27<1:14:12,  2.33s/it]

epoch 88, training loss=0.0


  4%|█▊                                     | 90/2000 [03:29<1:14:06,  2.33s/it]

epoch 89, training loss=0.0


  5%|█▊                                     | 91/2000 [03:32<1:14:00,  2.33s/it]

epoch 90, training loss=0.0


  5%|█▊                                     | 92/2000 [03:34<1:14:03,  2.33s/it]

epoch 91, training loss=0.0015056733973324299


  5%|█▊                                     | 93/2000 [03:36<1:13:58,  2.33s/it]

epoch 92, training loss=0.0


  5%|█▊                                     | 94/2000 [03:39<1:13:58,  2.33s/it]

epoch 93, training loss=0.0


  5%|█▊                                     | 95/2000 [03:41<1:13:55,  2.33s/it]

epoch 94, training loss=0.0


  5%|█▊                                     | 96/2000 [03:43<1:13:56,  2.33s/it]

epoch 95, training loss=2.3841860752327193e-07


  5%|█▉                                     | 97/2000 [03:46<1:13:51,  2.33s/it]

epoch 96, training loss=0.0


  5%|█▉                                     | 98/2000 [03:48<1:13:51,  2.33s/it]

epoch 97, training loss=0.0


  5%|█▉                                     | 99/2000 [03:50<1:13:45,  2.33s/it]

epoch 98, training loss=0.0


  5%|█▉                                    | 100/2000 [03:53<1:13:45,  2.33s/it]

epoch 99, training loss=0.0


  5%|█▉                                    | 101/2000 [03:55<1:13:40,  2.33s/it]

epoch 100, training loss=0.0


  5%|█▉                                    | 102/2000 [03:57<1:13:38,  2.33s/it]

epoch 101, training loss=0.0


  5%|█▉                                    | 103/2000 [04:00<1:13:36,  2.33s/it]

epoch 102, training loss=0.0


  5%|█▉                                    | 104/2000 [04:02<1:13:33,  2.33s/it]

epoch 103, training loss=0.0


  5%|█▉                                    | 105/2000 [04:04<1:13:28,  2.33s/it]

epoch 104, training loss=0.0


  5%|██                                    | 106/2000 [04:07<1:13:26,  2.33s/it]

epoch 105, training loss=0.0


  5%|██                                    | 107/2000 [04:09<1:13:26,  2.33s/it]

epoch 106, training loss=0.0


  5%|██                                    | 108/2000 [04:11<1:13:23,  2.33s/it]

epoch 107, training loss=0.0


  5%|██                                    | 109/2000 [04:14<1:13:23,  2.33s/it]

epoch 108, training loss=0.0


  6%|██                                    | 110/2000 [04:16<1:13:20,  2.33s/it]

epoch 109, training loss=0.0


  6%|██                                    | 111/2000 [04:18<1:13:22,  2.33s/it]

epoch 110, training loss=0.0


  6%|██▏                                   | 112/2000 [04:21<1:13:24,  2.33s/it]

epoch 111, training loss=0.0


  6%|██▏                                   | 113/2000 [04:23<1:13:17,  2.33s/it]

epoch 112, training loss=0.0


  6%|██▏                                   | 114/2000 [04:25<1:13:17,  2.33s/it]

epoch 113, training loss=0.0


  6%|██▏                                   | 115/2000 [04:28<1:13:12,  2.33s/it]

epoch 114, training loss=0.0


  6%|██▏                                   | 116/2000 [04:30<1:13:07,  2.33s/it]

epoch 115, training loss=0.0


  6%|██▏                                   | 117/2000 [04:32<1:13:02,  2.33s/it]

epoch 116, training loss=0.0


  6%|██▏                                   | 118/2000 [04:35<1:13:01,  2.33s/it]

epoch 117, training loss=4.768372718899627e-07


  6%|██▎                                   | 119/2000 [04:37<1:12:59,  2.33s/it]

epoch 118, training loss=0.0


  6%|██▎                                   | 120/2000 [04:39<1:13:00,  2.33s/it]

epoch 119, training loss=1.1920930376163597e-07


  6%|██▎                                   | 121/2000 [04:42<1:12:55,  2.33s/it]

epoch 120, training loss=0.0


  6%|██▎                                   | 122/2000 [04:44<1:12:50,  2.33s/it]

epoch 121, training loss=0.0


  6%|██▎                                   | 123/2000 [04:46<1:12:49,  2.33s/it]

epoch 122, training loss=0.0


  6%|██▎                                   | 124/2000 [04:49<1:12:49,  2.33s/it]

epoch 123, training loss=0.0


  6%|██▍                                   | 125/2000 [04:51<1:12:47,  2.33s/it]

epoch 124, training loss=0.0


  6%|██▍                                   | 126/2000 [04:53<1:12:46,  2.33s/it]

epoch 125, training loss=1.788139485370266e-07


  6%|██▍                                   | 127/2000 [04:56<1:12:47,  2.33s/it]

epoch 126, training loss=0.0


  6%|██▍                                   | 128/2000 [04:58<1:12:42,  2.33s/it]

epoch 127, training loss=0.0


  6%|██▍                                   | 129/2000 [05:00<1:12:41,  2.33s/it]

epoch 128, training loss=0.0


  6%|██▍                                   | 130/2000 [05:03<1:12:37,  2.33s/it]

epoch 129, training loss=2.527268952690065e-05


  7%|██▍                                   | 131/2000 [05:05<1:12:32,  2.33s/it]

epoch 130, training loss=0.0


  7%|██▌                                   | 132/2000 [05:07<1:12:24,  2.33s/it]

epoch 131, training loss=0.0


  7%|██▌                                   | 133/2000 [05:10<1:12:18,  2.32s/it]

epoch 132, training loss=0.0


  7%|██▌                                   | 134/2000 [05:12<1:12:16,  2.32s/it]

epoch 133, training loss=0.0


  7%|██▌                                   | 135/2000 [05:14<1:12:13,  2.32s/it]

epoch 134, training loss=0.0


  7%|██▌                                   | 136/2000 [05:17<1:12:17,  2.33s/it]

epoch 135, training loss=5.960466182841628e-07


  7%|██▌                                   | 137/2000 [05:19<1:12:23,  2.33s/it]

epoch 136, training loss=2.3841860752327193e-07


  7%|██▌                                   | 138/2000 [05:21<1:12:26,  2.33s/it]

epoch 137, training loss=3.3379157684976235e-05


  7%|██▋                                   | 139/2000 [05:24<1:12:26,  2.34s/it]

epoch 138, training loss=0.00031690849573351443


  7%|██▋                                   | 140/2000 [05:26<1:12:21,  2.33s/it]

epoch 139, training loss=0.0


  7%|██▋                                   | 141/2000 [05:28<1:12:13,  2.33s/it]

epoch 140, training loss=0.0


  7%|██▋                                   | 142/2000 [05:31<1:12:11,  2.33s/it]

epoch 141, training loss=9.536747711536009e-07


  7%|██▋                                   | 143/2000 [05:33<1:12:06,  2.33s/it]

epoch 142, training loss=0.0


  7%|██▋                                   | 144/2000 [05:35<1:11:58,  2.33s/it]

epoch 143, training loss=0.0


  7%|██▊                                   | 145/2000 [05:37<1:11:55,  2.33s/it]

epoch 144, training loss=0.0


  7%|██▊                                   | 146/2000 [05:40<1:11:55,  2.33s/it]

epoch 145, training loss=0.0


  7%|██▊                                   | 147/2000 [05:42<1:11:53,  2.33s/it]

epoch 146, training loss=0.0


  7%|██▊                                   | 148/2000 [05:44<1:11:51,  2.33s/it]

epoch 147, training loss=5.960464477539063e-08


  7%|██▊                                   | 149/2000 [05:47<1:11:57,  2.33s/it]

epoch 148, training loss=0.0002823870745487511


  8%|██▊                                   | 150/2000 [05:49<1:11:50,  2.33s/it]

epoch 149, training loss=0.0


  8%|██▊                                   | 151/2000 [05:51<1:11:49,  2.33s/it]

epoch 150, training loss=0.0


  8%|██▉                                   | 152/2000 [05:54<1:11:49,  2.33s/it]

epoch 151, training loss=0.0


  8%|██▉                                   | 153/2000 [05:56<1:11:43,  2.33s/it]

epoch 152, training loss=0.0


  8%|██▉                                   | 154/2000 [05:58<1:11:41,  2.33s/it]

epoch 153, training loss=0.0


  8%|██▉                                   | 155/2000 [06:01<1:11:41,  2.33s/it]

epoch 154, training loss=9.298368240706623e-06


  8%|██▉                                   | 156/2000 [06:03<1:11:35,  2.33s/it]

epoch 155, training loss=0.0


  8%|██▉                                   | 156/2000 [06:05<1:12:05,  2.35s/it]


KeyboardInterrupt: 

# Implementing the Siamese Network

The above training is just to test if the VGG model works for 3D data. Here, the training will take two pairs of images and calculate the loss from both pairs of images.

In [13]:
from tqdm import tqdm

def train():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        
    loss=[] 
    counter=[]
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0
        for i, (x, y) in enumerate(loader):
            
            x = torch.from_numpy(x).to(device).float()
            y = torch.from_numpy(np.array([y])).to(device).float()
            
            vol0, vol1 , label0, label1 = data
            vol0, vol1 , label = vol0.to(device), vol1.to(device), label.to(device)
            
            optimizer.zero_grad()
            
            output1 = model(vol0)
            output2 = model(vol1)
            
            if(label0 == label1):
                isMatch = 1;
            else:
                isMatch = 0;
                        
            loss = loss_function((output1, output2), isMatch)
            
            loss.backward()
            optimizer.step()    
            epoch_loss += loss
            
        print(f"epoch {epoch}, training loss={epoch_loss}")
    
    return model

def validate():    
    model.eval()
    #dataloader = DataLoader(validation_dataset, batch_size=32)
    
    return evaluate(dataloader, 'validate')

def test():    
    model.eval()
    dataloader = DataLoader(test_dataset, batch_size=32)
    
    return evaluate(dataloader, 'test')

model = train()

  0%|                                                  | 0/2000 [00:00<?, ?it/s]


NameError: name 'data' is not defined

In [32]:
for i, (x, y) in enumerate(loader):
    #print(i)
    #print(np.shape(x))
    print(y)
#print(x)
#print(y)


1
0
