# 3D Convolutional Neural Network for Tomographic Alignment

## Regular CNN

In order to test potential methods for performing automatic tomographic alignment using neural networks, we start with a standard model using a three dimensional convolution. The main problem with tomographic alignment is that a stack of two dimensional projections have to be processed simultameously for optimal results. While two dimensional convolutions can be used with channels corresponding with each projection angle, this is likely similar in computational complexity to a three dimensional neural network. Therefore the approach for this test is similar to video classification, where each frame in a video is is instead each projection angle.

In order to test if this method can provide a convergence, phantoms will be artificially misaligned to create a training and testing set. But first packages for tomography, image transformations, and neural networks have to be imported.

In [1]:
# Import essential packages
import os
import numpy as np
import matplotlib.pyplot as plt

# Import tomography and imaging packages
import tomopy
from skimage.transform import rotate, AffineTransform
from skimage import transform as tf

# Import neural net packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.profiler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchinfo import summary

Since the model will be a computationally complex CNN, we must ensure that the GPU is being used for calculations or else computation will be far too slow.

In [2]:
# Checking to ensure environment and cuda are correct
print("Working Environment: {}".format(os.environ['CONDA_DEFAULT_ENV']))
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
print("Cuda Version: {}".format(torch.version.cuda))
print("Cuda Availability: {}".format(torch.cuda.is_available()))

Working Environment: pytorch
Cuda Version: 11.8
Cuda Availability: True


Now that the packages have been imported and CUDA is set up correctly, the next step is to create the dataset to be used for training and testing the neural network. The misalignment function is created to perform different random misalignments on the phantom set of tomographic scans. The shape of all of the data is then checked for errors and the data is split into training and testing sets.

In [3]:
# Define function for artificial misalignment
def misalign(prj, mis_axis, ang_tilt = False, noise = False, background = False):
    num_prj, col, row = prj.shape
    dx = mis_axis[:, 0]
    dy = mis_axis[:, 1]
    prj_tmp = tomopy.shift_images(prj, dx, dy)
    
    for i in range(num_prj):
        d_row, d_col, d_ang = mis_axis[i]
        
        if ang_tilt == True:
            prj_tmp[i, :, :] = rotate(prj[i,:,:], d_ang)
        else:
            prj_tmp[i, :, :] = prj[i,:,:]
        
        if noise == True:
            prj_tmp[i, :, :] = random_noise(prj_tmp[i, :, :], mode = 'gaussian')
            
        if background == True:
            prj_tmp[i, :, :] = prj_tmp[i, :, :]+np.random.random()/5
            prj_tmp[i, :, :] = prj_tmp[i, :, :]/prj_tmp[i, :, :].max()  
            
    return prj_tmp

In [4]:
# Creating ground truth tomography
data = tomopy.shepp3d(256)
ang = tomopy.angles(180)
proj = tomopy.project(data, ang)

In [5]:
# Create dataset to store misaligned projections and 
entries = 120
dataset = np.zeros((entries, 2), dtype = object)

for i in range(entries):
    # Randomly determined misalignment axis
    mis_axis = np.random.normal(0, 1, (200, 3))
    mis_axis[:, :1] = mis_axis[:, :1]*4
    mis_axis = np.round(mis_axis).astype(int)
    mis_axis_in = np.expand_dims(mis_axis, axis = 0)
    
    proj_mis = misalign(proj.copy(), mis_axis, ang_tilt = True)
    proj_mis = np.expand_dims(proj_mis, axis = 0)
    proj_mis = np.expand_dims(proj_mis, axis = 0)
    
    dataset[i, 0] = proj_mis
    dataset[i, 1] = np.concatenate((mis_axis_in[:, :180, 0], mis_axis_in[:, :180, 1]), axis = 1)

In [6]:
print(dataset[0][1])

[[ -2   1  -5  -3   1   3  -7   3   1  -1   2   5   8   3   0   0   3  -2
    7   5   2  -1   3  -6   1  -2   3  -8 -10  -8  -2   2 -12  -3  -2   6
   -2   6  -4   0  -1   4   1  -2   8  -2  -4   2   5   0  -3  -1   9   4
   -1   1   5   3   2  -7   1   3   8   0  -4   3  -4  -1  -2   0  -2   1
    1  -4   7   1  -4   0   1   4   4  -1   5  -7   2   2  -6  -1   5   5
   -2   0   2   3  -3   0  -4  -2   0  -1  -7   1  -2   1   2   1  -4  -6
    0  -6  -4   5   5  -3  -4  -2  10   8  -2  -3   6  -4   5  -4  11   0
    1  -5   7  -2   4   2   3  -3   0   4  -1  -5   2   2   1  -3  -6  -8
   -9  -4  -3   2   1 -10  -2   9  -5  -2  -4  -4   2   1  -3  -7  -7   2
    2  -2   2   4  -1  -1  -1  -1   5   0   2  -4   1   0   6  10   9   2
   -1   0   1  -1   1   1   0  -1   1   1   0   1   1   1   1  -1  -1   1
    0  -1  -2   0   0   2  -1   0  -1  -1  -2  -1  -1   0   1   0   0  -1
   -2   3   0  -1  -1  -1  -1   0   1   0   0   1   1   2   0   2  -1   0
   -1   0   0  -1   1   0  -1   0   0 

In [7]:
# Checking shape of dataset
print(dataset.shape)
print(dataset[0].shape)
print(dataset[0][0].shape)
print(dataset[0][1].shape)

(120, 2)
(2,)
(1, 1, 180, 256, 366)
(1, 360)


In [8]:
# Checking shape of training and testing splits
trainset, testset = np.split(dataset, [int(entries* 4 / 5)])
print(trainset.shape)
print(testset.shape)

(96, 2)
(24, 2)


Now that the data has been set up, the CUDA cache should be cleared and the model will be implemented.

In [9]:
torch.cuda.empty_cache()
print("Cleared Cache.")

Cleared Cache.


In [65]:
# Normalize data
def norm(proj):
    proj = (proj - torch.min(proj)) / (torch.max(proj) - torch.min(proj))
    return proj

def g_norm(shift):
    mean_tmp = torch.mean(shift)
    std_tmp = torch.std(shift)
    shift = (shift - mean_tmp) / std_tmp
    # shift = (shift - torch.min(shift)) / (torch.max(shift) - torch.min(shift))
    return 10 * shift

# 3D CNN to determine shift parameters

class CNN_3D_aligner(nn.Module):
    def __init__(self):
        super(CNN_3D_aligner, self).__init__()

        self.group1 = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(16),
            nn.Sigmoid(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 
        )
        
        self.group2 = nn.Sequential(
            nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(32),
            nn.Sigmoid(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 
        )
        
        self.group3 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(64),
            nn.Sigmoid(),
            nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(64),
            nn.Sigmoid(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 
        )
        
        self.group4 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(128),
            nn.Sigmoid(),
            nn.Conv3d(128, 128, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(128),
            nn.Sigmoid(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 
        )
        
        self.group5 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(256),
            nn.Sigmoid(),
            nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=1), 
            nn.BatchNorm3d(256),
            nn.Sigmoid(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 
        )
        
        self.group6 = nn.Sequential(
            nn.Conv3d(256, 16, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(16),
            nn.Sigmoid()
        )
        
        self.flatten = nn.Flatten()
        
        self.fc1 = nn.Sequential(
            nn.Linear(7040, 512),
            nn.Sigmoid(),
            nn.Dropout(0.25),
            nn.Linear(512, 256)
        )
        
        self.fc2 = nn.Linear(256, 360)

    def forward(self, x):
        
        x = norm(x)
        
        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.group4(x)
        x = self.group5(x)
        x = self.group6(x)
        
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = g_norm(x)
        
        return x

In order to ensure the network works and understand its structure before training data on it, use the summary function in order to get an understanding of the network and fix any linear algebra errors in creating the neural network.

In [66]:
# Test model shape
model = CNN_3D_aligner()
summary(model, (1, 1, 180, 256, 366))

Layer (type:depth-idx)                   Output Shape              Param #
CNN_3D_aligner                           [1, 360]                  --
├─Sequential: 1-1                        [1, 16, 90, 128, 183]     --
│    └─Conv3d: 2-1                       [1, 16, 180, 256, 366]    448
│    └─BatchNorm3d: 2-2                  [1, 16, 180, 256, 366]    32
│    └─Sigmoid: 2-3                      [1, 16, 180, 256, 366]    --
│    └─MaxPool3d: 2-4                    [1, 16, 90, 128, 183]     --
├─Sequential: 1-2                        [1, 32, 45, 64, 91]       --
│    └─Conv3d: 2-5                       [1, 32, 90, 128, 183]     13,856
│    └─BatchNorm3d: 2-6                  [1, 32, 90, 128, 183]     64
│    └─Sigmoid: 2-7                      [1, 32, 90, 128, 183]     --
│    └─MaxPool3d: 2-8                    [1, 32, 45, 64, 91]       --
├─Sequential: 1-3                        [1, 64, 22, 32, 45]       --
│    └─Conv3d: 2-9                       [1, 64, 45, 64, 91]       55,360
│    └

Now the model can be trained, making sure to move all of the elements of the training process to the GPU to optimize computational speed.

In [None]:
# Train the model

# Dataloader for the trainset
trainload = DataLoader(trainset, batch_size=1, shuffle=True)

# Create writer and profiler to analyze loss over each epoch
writer = SummaryWriter()
prof = torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/net3d'),
    record_shapes=True, profile_memory=True, with_stack=True)
prof.start()

# Set device to CUDA if available, initialize model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device: {}'.format(device))
net = CNN_3D_aligner()
net.to(device)

# Set up optimizer and loss function, set number of epochs
optimizer = optim.SGD(net.parameters(), lr = 1e-2, momentum = 0.9)
criterion = nn.MSELoss(reduction = 'mean')
criterion.to(device)
num_epochs = 50

# Iniitializing variables to show statistics
iteration = 0
loss_list = []
epoch_loss_averages = []

# Iterates over dataset multiple times
for epoch in range(num_epochs):
    epoch_loss = 0
    for i, data in enumerate(trainset, 0):
        inputs, truths = norm(torch.from_numpy(data[0]).to(device)), torch.from_numpy(data[1]).to(device).float()
        optimizer.zero_grad()

        outputs = net(inputs).to(device)
        loss = criterion(outputs, truths)
        if i == 0 and epoch == num_epochs - 1:
            print(truths)
            print("_"*75)
            print(outputs)
        writer.add_scalar("Loss / Train", loss, epoch) # adds training loss scalar
        loss_list.append(loss.cpu().detach().numpy())
        epoch_loss += loss.cpu().detach().numpy()
        loss.backward()
        optimizer.step
        prof.step

        iteration += 1
        if iteration % trainset.shape[0] == 0:
            epoch_loss_averages.append(epoch_loss / trainset.shape[0])
            print('Iteration: {}   Loss: {} '.format(iteration, epoch_loss / trainset.shape[0]))
            
prof.stop()
writer.flush()
writer.close()

Device: cuda:0
Iteration: 96   Loss: 108.38527830441792 
Iteration: 192   Loss: 108.40818643569946 
Iteration: 288   Loss: 108.17373450597127 
Iteration: 384   Loss: 108.26605947812398 
Iteration: 480   Loss: 108.37963628768921 
Iteration: 576   Loss: 108.22290086746216 
Iteration: 672   Loss: 108.23287963867188 
Iteration: 768   Loss: 108.33371567726135 
Iteration: 864   Loss: 108.22760232289632 


Now in order to observe convergence or lack thereof graphs of loss per iteration as well as a moving average based on each epoch are created for analysis.

In [None]:
# Plot epoch loss to test for convergence
plt.plot(epoch_loss_averages)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

As seen above, this neural network fails to create convergence for the dataset. However, the fact that the loss does not stay completely constant is promising and it is possible that a deeper neural network can prove to solve this problem. The main restrictions at this point are memory allocation errors and different frameworks will have to be used in order to create a neural network with the depth necessary to observe convergence.