In [1]:
import torch
from utils.loader import DicomDataset3D
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import utils.notebooks as nb

# Wrap the DicomDataset3D instances built from data/train.csv and data/test.csv
# in loaders that yield two samples per batch.
train_dataloader = DataLoader(
    dataset=DicomDataset3D("data/train.csv"),
    batch_size=2,
)
test_dataloader = DataLoader(
    dataset=DicomDataset3D("data/test.csv"),
    batch_size=2,
)

In [2]:
# Take the first batch the training loader yields and keep its first sample
# (features and labels) around for inspection.
train_features, train_labels = next(iter(train_dataloader))
feat1 = train_features[0]
label1 = train_labels[0]

# Report the batch-level shapes, then the per-sample dtypes.
print(f'inputs shape: {train_features.shape}')
print(f'labels shape: {train_labels.shape}')
print(f'inputs type: {feat1.dtype}')
print(f'labels type: {label1.dtype}')

inputs shape: torch.Size([2, 1, 91, 512, 512])
labels shape: torch.Size([2, 1, 91, 512, 512])
inputs type: torch.float32
labels type: torch.float32


### Single ConvBlock

In [3]:
from models.WNet import *
# Instantiate one ConvBlock with arguments (1, 8) and run the sampled batch
# through it; the cell's final expression displays the resulting shape.
conv_block = ConvBlock(1, 8)

# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because later cells read it.
input = train_features
x = conv_block(input)

x.shape

torch.Size([2, 8, 91, 512, 512])

### Encoder

In [4]:
# Values handed to Encoder (and reused by the Decoder cell).
layers = [8, 16, 32]

# Fresh ConvBlock(1, 8) in front of the encoder; construction order is kept
# (ConvBlock first, then Encoder).
conv_block = ConvBlock(1, 8)
encoder = Encoder(layers)

# Run the batch through the ConvBlock and then the encoder, and print the
# shape of every element the encoder returns.
x = encoder(conv_block(input))
for ftr in x:
    print(ftr.shape)

ConvBlock(
  (block): Sequential(
    (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): ReLU()
    (5): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
ConvBlock(
  (block): Sequential(
    (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): ReLU()
    (5): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


### Decoder

In [None]:
# Build the decoder under its own name so the `encoder` object created in the
# Encoder section is not overwritten by a Decoder instance.
decoder = Decoder(layers)

# TODO: the decoder is only constructed here and is not yet applied to any
# features, so the shape displayed by this cell is still conv_block's output.
x = conv_block(input)
x.shape

torch.Size([2, 8, 91, 512, 512])