In [1]:
from tomoSegmentPipeline.utils.common import read_array, write_array
from tomoSegmentPipeline.utils import setup
from tomoSegmentPipeline.dataloader import tomoSegment_dataset
from cryoS2Sdrop.analyze import plot_centralSlices
from tomoSegmentPipeline.model import DeepFinder_model
from tomoSegmentPipeline.losses import Tversky_loss, Tversky1_loss

import numpy as np
import matplotlib.pyplot as plt
import torch
import os
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

# Base path exposed by the pipeline's setup module (presumably the project
# root -- confirm against tomoSegmentPipeline.utils.setup).
PARENT_PATH = setup.PARENT_PATH

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Turn off jedi-based tab completion in IPython.
%config Completer.use_jedi = False
# Reload imported modules automatically before each cell runs, so edits to
# the project packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Settings for the inspection dataset. dim_in matches the 84-voxel spatial
# edge of the batches printed further down; Ncl is presumably the number of
# segmentation classes and Lrnd a random-shift range (TODO confirm against
# tomoSegment_dataset). lr and weight_decay are not used in this cell.
Ncl, dim_in = 2, 84
lr, weight_decay = 1e-4, 0
Lrnd, augment_data = 0, False

# Resolve the input / target file paths for the two selected tomograms.
path_data, path_target = setup.get_paths(['tomo02', 'tomo04'], 'rawCET')

my_dataset = tomoSegment_dataset(
    path_data, path_target, dim_in, Ncl, Lrnd, augment_data
)

# Number of samples the dataset exposes.
len(my_dataset)

152

In [3]:
# Small, fixed-order loader used only to inspect a couple of samples.
batch_size = 2
dloader = DataLoader(
    dataset=my_dataset,
    batch_size=batch_size,
    pin_memory=True,
    shuffle=False,
)

In [4]:
# Fetch exactly one batch to inspect tensor shapes. A `for ... break` loop
# would silently do nothing on an empty loader, leaving `data` / `target`
# undefined (or stale from an earlier run) for the cells that use them later,
# so fail loudly instead.
batch = next(iter(dloader), None)
if batch is None:
    raise RuntimeError("dloader yielded no batches; check my_dataset and batch_size")

data, target = batch
print(data.shape, target.shape)

torch.Size([2, 1, 84, 84, 84]) torch.Size([2, 3, 84, 84, 84])


# Model

In [5]:
# Model / training configuration.
# NOTE(review): this cell rebinds names already set in earlier cells. Lrnd,
# augment_data and batch_size take different values here (18 / True / 22 vs
# 0 / False / 2); the dataset and loader built above are not rebuilt by this
# cell, so they keep their original settings.
Ncl, dim_in = 2, 84
lr, weight_decay = 1e-4, 0
Lrnd, augment_data = 18, True
batch_size = 22
pretrain_type = None

# Loss instance handed to the model constructor.
loss_fn = Tversky_loss()

model = DeepFinder_model(
    Ncl,
    loss_fn,
    lr,
    weight_decay,
    pretrain_type,
)

In [21]:
# Shape check only: skip autograd bookkeeping so the forward pass on the 3-D
# volumes does not retain intermediate activations for a backward pass that
# never happens.
with torch.no_grad():
    output_shape = model(data).shape
output_shape

torch.Size([2, 2, 84, 84, 84])

In [22]:
# Layer-by-layer summary for a single-channel cubic input of edge dim_in,
# evaluated on the CPU.
summary_input_size = (batch_size, 1) + (dim_in,) * 3
summary(model, input_size=summary_input_size, device='cpu')

Layer (type:depth-idx)                   Output Shape              Param #
DeepFinder_model                         [22, 2, 84, 84, 84]       --
├─Tversky_loss: 1-1                      --                        --
├─Sequential: 1-2                        [22, 32, 84, 84, 84]      --
│    └─Conv3d: 2-1                       [22, 32, 84, 84, 84]      896
│    └─ReLU: 2-2                         [22, 32, 84, 84, 84]      --
│    └─Conv3d: 2-3                       [22, 32, 84, 84, 84]      27,680
│    └─ReLU: 2-4                         [22, 32, 84, 84, 84]      --
├─Sequential: 1-3                        [22, 48, 42, 42, 42]      --
│    └─MaxPool3d: 2-5                    [22, 32, 42, 42, 42]      --
│    └─Conv3d: 2-6                       [22, 48, 42, 42, 42]      41,520
│    └─ReLU: 2-7                         [22, 48, 42, 42, 42]      --
│    └─Conv3d: 2-8                       [22, 48, 42, 42, 42]      62,256
│    └─ReLU: 2-9                         [22, 48, 42, 42, 42]      --
├─

# Loss function

In [6]:
# Forward pass on the batch fetched earlier. Autograd stays enabled, so loss
# values computed from `pred` carry a grad_fn.
pred = model(data)

# Stand-alone Tversky loss with its default arguments, evaluated in the
# following cells.
loss = Tversky_loss()

In [7]:
# Evaluate the loss on the GPU when one is present, otherwise on the CPU, so
# the cell also runs on machines without CUDA.
# NOTE(review): the recorded outputs in this notebook show model(data) with 2
# channels and target with 3 channels -- confirm the loss is meant to accept
# that combination.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss(pred.to(device), target.to(device))

tensor(1.1595, device='cuda:0', grad_fn=<RsubBackward1>)

In [8]:
# Sanity check: the loss of the target against itself (recorded result: 0).
# Uses the GPU when available and falls back to the CPU otherwise; the target
# is transferred once instead of twice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_on_device = target.to(device)
loss(target_on_device, target_on_device)

tensor(0., device='cuda:0')

In [9]:
# Repeat the prediction-vs-target evaluation with the Tversky1 variant.
# Runs on the GPU when available and on the CPU otherwise.
# Note: this rebinds `loss`, replacing the Tversky_loss instance used above.
loss = Tversky1_loss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss(pred.to(device), target.to(device))

tensor(0.7437, device='cuda:0', grad_fn=<RsubBackward1>)