# Medical image analysis with PyTorch Section 3

Create a not-so-deep (shallow) 3-D convolutional network.  

In [26]:
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


In [40]:
class ConvTestNet(nn.Module):
    """Minimal single-layer 3-D conv encoder/decoder for shape sanity checks.

    Pipeline: 3x3x3 conv (1 -> ``s`` channels, padding 1, no bias) -> ReLU
    -> 2x2x2 max-pool (stride 2) -> 3x3x3 transposed conv (``s`` -> 1 channel,
    stride 1, no padding) -> sigmoid.

    Args:
        s: number of feature channels produced by the first convolution.

    Shape:
        Input ``(N, 1, D, H, W)`` -> output ``(N, 1, D', H', W')`` where each
        spatial size maps ``n -> n // 2 + 2``: the pool halves it (floor), and
        the stride-1, kernel-3 transposed conv adds 2. The output therefore
        equals the input spatial size only for ``n`` in {3, 4}. Every spatial
        size must be >= 2, otherwise the max-pool has no valid window.
    """

    def __init__(self, s=32):
        super().__init__()
        # Encoder: padding=1 keeps the spatial size unchanged before pooling.
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=s, kernel_size=3, padding=1, bias=False)
        self.pool1 = nn.MaxPool3d(kernel_size=2)

        # Decoder: stride-1 transposed conv grows each spatial dim by kernel_size - 1 = 2.
        self.o_conv1 = nn.ConvTranspose3d(in_channels=s, out_channels=1, kernel_size=3)

    def forward(self, x):
        """Map a ``(N, 1, D, H, W)`` batch to per-voxel values in [0, 1]."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        # torch.sigmoid is the canonical form; F.sigmoid is a thin alias
        # (deprecated, with a warning, in older PyTorch releases).
        return torch.sigmoid(self.o_conv1(x))

In [41]:
# Seed before construction so the randomly initialised weights (and any
# values computed from them later) are reproducible on Restart & Run All.
torch.manual_seed(0)
model = ConvTestNet()
model

ConvTestNet(
  (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (pool1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (o_conv1): ConvTranspose3d(32, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1))
)


## Test input/output dims

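A quick sanity check of the output dimensions: pass a dummy batch through the network and compare the output shape with the input shape. For this $3\times3\times3$ input the shapes match, but only by coincidence. Max-pooling halves each spatial dimension ($3 \to 1$), and the stride-1, kernel-3 transposed convolution then adds 2 ($1 \to 3$). Other input sizes do not round-trip; for example, $8^3 \to 4^3 \to 6^3$.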

In [44]:
# Dummy batch: 32 single-channel 3x3x3 volumes laid out as (N, C, D, H, W), float32.
img = torch.zeros((32, 1, 3, 3, 3))
# Call the module itself rather than .forward() so any registered hooks run.
o = model(img)
# Shapes match for a 3x3x3 input only; a spatial size n generally maps to n // 2 + 2.
assert o.shape == img.shape, f"shape mismatch: {tuple(o.shape)} vs {tuple(img.shape)}"
o.shape

torch.Size([32, 1, 3, 3, 3])