# 3D Convolutional Neural Network for Tomographic Alignment

## Expanding the Network: ResNet and Engineering the Data

Previous analysis suggests the network was too small to identify the features needed for alignment, so the next logical step is to expand it. Simply stacking more layers usually yields diminishing returns, but residual neural networks (ResNets) have proven to train successfully at much greater depths. This model structure is used here in an attempt to reach some form of convergence.

While the ResNet appears to improve performance on larger training sets, convergence on the test set remains a major problem. The input data is therefore modified so that each projection is replaced by its difference from the mean "projection" of its stack, which should help the network recognize differences between projections more effectively. The hope is that this gives the neural network a better handle on the specific alignment problem.
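The mean-difference idea can be sketched on a single projection stack using NumPy broadcasting. This is only an illustrative sketch: the function name `subtract_mean_projection` and the random toy stack are assumptions made for the example, and it presumes the stack has already been squeezed so that axis 0 indexes the projections.

```python
import numpy as np

def subtract_mean_projection(stack):
    """Return the stack with its mean projection (taken over axis 0) subtracted."""
    stack = np.asarray(stack, dtype=np.float64)
    # keepdims=True keeps a leading axis of length 1 so the mean broadcasts over every projection
    mean_projection = stack.mean(axis=0, keepdims=True)
    return stack - mean_projection

# Toy stack (hypothetical data): 4 projections of size 2 x 3
rng = np.random.default_rng(0)
toy_stack = rng.random((4, 2, 3))
diff = subtract_mean_projection(toy_stack)

print(diff.shape)
# After subtraction, the mean over the projection axis is zero up to rounding
print(np.allclose(diff.mean(axis=0), 0.0))
```

Broadcasting removes the inner loop over projections, which may matter once the number of stacks grows.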

In [1]:
# Import essential packages
import os
import math
import numpy as np
import matplotlib.pyplot as plt

# Import tomography and imaging packages
import tomopy
from skimage.transform import rotate, AffineTransform
from skimage import transform as tf

# Import neural net packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.profiler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchinfo import summary

In [2]:
# Checking to ensure environment and cuda are correct
print("Working Environment: {}".format(os.environ['CONDA_DEFAULT_ENV']))
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
print("Cuda Version: {}".format(torch.version.cuda))
print("Cuda Availability: {}".format(torch.cuda.is_available()))
!pwd

Working Environment: pytorch
Cuda Version: 11.8
Cuda Availability: True
/home/liam/Projects/Tomographic Alignment


In [3]:
# Replace every projection with its difference from the mean projection of its stack
def data_mean_difference(data):
    
    # Create copy of dataset for projections to be modified, also initializes array for mean projections
    projections = data[:, 0].copy()
    mean_projections = np.zeros((projections.shape[0]), dtype = object)
    
    # Iterate through each projection stack
    for i in range (projections.shape[0]):

        # Get rid of extra dimensions for neural networks and create a mean projection
        projections[i] = np.squeeze(projections[i])
        mean_projections[i] = np.mean(projections[i], axis = 0)

        # Iterate through every projection in stack
        for j in range (projections[i].shape[0]):

            # Change current projection to difference between current projection and mean projection
            projections[i][j] = projections[i][j] - mean_projections[i]
            
        # Expand dimensions to original form for neural networks
        projections[i] = np.expand_dims(projections[i], axis = 0)
        projections[i] = np.expand_dims(projections[i], axis = 0)

    # Create data difference array
    data_diff = data.copy()

    # Replace all data with the new difference projections
    for i in range (data.shape[0]):

        data_diff[i][0] = projections[i]
        
    return data_diff

In [4]:
# Loading data, 250 entries of 128 resolution shepp3ds
res = 128
entries = 250
data = []

for i in range(entries):
    data.append(np.load('./shepp{}-{}/shepp{}-{}_{}.npy'.format(res, entries, res, entries, i), 
                        allow_pickle = True))
    
data = np.asarray(data)

data = data_mean_difference(data)

In [5]:
# Checking shape of training and testing splits
trainset, testset = np.split(data, [int(entries * 4 / 5)])
print("Shape of Training Dataset: {}".format(trainset.shape))
print("Shape of Testing Dataset: {}".format(testset.shape))

Shape of Training Dataset: (200, 2)
Shape of Testing Dataset: (50, 2)


In [6]:
# Normalize data
def norm(proj):
    proj = (proj - torch.min(proj)) / (torch.max(proj) - torch.min(proj))
    return proj

# Get inplanes for resnet
def get_inplanes():
    return [64, 128, 256, 512]


# Preset for a 3x3x3 kernel convolution
def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=False)


# Preset for a 1x1x1 kernel convolution
def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)

# Basic block for resnet
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
    
# Bottleneck block for resnet
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = conv3x3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
    

# Resnet structure
class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 block_inplanes,
                 n_input_channels=1,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 shortcut_type='B',
                 widen_factor=1.0,
                 n_classes=360):
        super().__init__()

        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        self.conv1 = nn.Conv3d(n_input_channels,
                               self.in_planes,
                               kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(conv1_t_size // 2, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       shortcut_type,
                                       stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2),
                                out.size(3), out.size(4))
        if isinstance(out.data, torch.cuda.FloatTensor):
            zero_pads = zero_pads.cuda()

        out = torch.cat([out.data, zero_pads], dim=1)

        return out

    # make layer helper function
    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            # Projection shortcut: 1x1x1 convolution matches channels and stride
            downsample = nn.Sequential(
                conv1x1x1(self.in_planes, planes * block.expansion, stride),
                nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample))
        self.in_planes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


# Generates form of resnet
def generate_model(model_depth, **kwargs):
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    if model_depth == 10:
        model = ResNet(BasicBlock, [1, 1, 1, 1], get_inplanes(), **kwargs)
    elif model_depth == 18:
        model = ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(), **kwargs)
    elif model_depth == 34:
        model = ResNet(BasicBlock, [3, 4, 6, 3], get_inplanes(), **kwargs)
    elif model_depth == 50:
        model = ResNet(Bottleneck, [3, 4, 6, 3], get_inplanes(), **kwargs)
    elif model_depth == 101:
        model = ResNet(Bottleneck, [3, 4, 23, 3], get_inplanes(), **kwargs)
    elif model_depth == 152:
        model = ResNet(Bottleneck, [3, 8, 36, 3], get_inplanes(), **kwargs)
    elif model_depth == 200:
        model = ResNet(Bottleneck, [3, 24, 36, 3], get_inplanes(), **kwargs)

    return model

In [7]:
model = generate_model(50)
summary(model, (1, 1, 180, 128, 184))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 360]                  --
├─Conv3d: 1-1                            [1, 64, 180, 64, 92]      21,952
├─BatchNorm3d: 1-2                       [1, 64, 180, 64, 92]      128
├─ReLU: 1-3                              [1, 64, 180, 64, 92]      --
├─MaxPool3d: 1-4                         [1, 64, 90, 32, 46]       --
├─Sequential: 1-5                        [1, 256, 90, 32, 46]      --
│    └─Bottleneck: 2-1                   [1, 256, 90, 32, 46]      --
│    │    └─Conv3d: 3-1                  [1, 64, 90, 32, 46]       4,096
│    │    └─BatchNorm3d: 3-2             [1, 64, 90, 32, 46]       128
│    │    └─ReLU: 3-3                    [1, 64, 90, 32, 46]       --
│    │    └─Conv3d: 3-4                  [1, 64, 90, 32, 46]       110,592
│    │    └─BatchNorm3d: 3-5             [1, 64, 90, 32, 46]       128
│    │    └─ReLU: 3-6                    [1, 64, 90, 32, 46]       --
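The first output shapes in the summary above can be checked by hand with the standard output-size formula for convolution and pooling layers (dilation 1): `floor((n + 2 * padding - kernel) / stride) + 1` per axis. The sketch below applies it to the first convolution and the max pool, using the kernel, stride, and padding values from the `ResNet` constructor; the helper name `conv_output_size` is an assumption made for this example.

```python
def conv_output_size(size, kernel, stride, padding):
    """Output length along one axis of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# Spatial shape passed to summary(): (depth, height, width)
input_shape = (180, 128, 184)

# First convolution: kernel 7 on every axis, stride (1, 2, 2), padding 3 on every axis
after_conv1 = tuple(
    conv_output_size(n, 7, s, 3) for n, s in zip(input_shape, (1, 2, 2))
)

# Max pooling: kernel 3, stride 2, padding 1 on every axis
after_pool = tuple(conv_output_size(n, 3, 2, 1) for n in after_conv1)

print(after_conv1)  # (180, 64, 92)
print(after_pool)   # (90, 32, 46)
```

Both results agree with the `Conv3d: 1-1` and `MaxPool3d: 1-4` rows of the summary.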


In [8]:
# Train the model

# Create TensorBoard writer to log loss over each epoch
writer = SummaryWriter()

# Set device to CUDA if available, initialize model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device: {}'.format(device))
net = generate_model(50)
net.to(device)

# Set up optimizer and loss function, set number of epochs
optimizer = optim.SGD(net.parameters(), lr = 5e-2, momentum=0.9, weight_decay = 0)
criterion = nn.MSELoss(reduction = 'mean')
criterion.to(device)
num_epochs = 500

# Initializing variables to show statistics
iteration = 0
test_iteration = 0
loss_list = []
test_loss_list = []
epoch_loss_averages = []
test_epoch_loss_averages = []

# Iterates over dataset multiple times
for epoch in range(num_epochs):
    
    epoch_loss = 0
    test_epoch_loss = 0
    
    for i, data in enumerate(trainset, 0):
        
        inputs, truths = norm(torch.from_numpy(data[0]).to(device)), torch.from_numpy(data[1]).to(device).float()
        optimizer.zero_grad()

        outputs = net(inputs).to(device)
        loss = criterion(outputs, truths)
        writer.add_scalar("Loss / Train", loss, epoch) # adds training loss scalar
        loss_list.append(loss.cpu().detach().numpy())
        epoch_loss += loss.cpu().detach().numpy()
        loss.backward()
        optimizer.step()

        iteration += 1
        if iteration % trainset.shape[0] == 0:
            epoch_loss_averages.append(epoch_loss / trainset.shape[0])
            print('Epoch: {}   Training Loss: {} '.format(epoch, epoch_loss / trainset.shape[0]))
            
    # Validation pass: no_grad skips building the autograd graph, since no backward pass is needed here
    # (the network is left in training mode, so the numerical results match the original run)
    with torch.no_grad():
        for i, test_data in enumerate(testset, 0):

            inputs, truths = norm(torch.from_numpy(test_data[0]).to(device)), torch.from_numpy(test_data[1]).to(device).float()
            outputs = net(inputs).to(device)
            test_loss = criterion(outputs, truths)

            writer.add_scalar("Loss / Test", test_loss, epoch) # adds testing loss scalar
            test_loss_list.append(test_loss.cpu().detach().numpy())
            test_epoch_loss += test_loss.cpu().detach().numpy()

            test_iteration +=1
            if test_iteration % testset.shape[0] == 0:
                test_epoch_loss_averages.append(test_epoch_loss / testset.shape[0])
                print('Epoch: {}   Validation Loss: {} '.format(epoch, test_epoch_loss / testset.shape[0]))
            
writer.flush()
writer.close()

Device: cuda:0
Epoch: 0   Training Loss: 8.910967659950256 
Epoch: 0   Validation Loss: 8.727133646011353 
Epoch: 1   Training Loss: 8.56054787158966 
Epoch: 1   Validation Loss: 8.72879846572876 
Epoch: 2   Training Loss: 8.546824946403504 
Epoch: 2   Validation Loss: 8.73028904914856 
Epoch: 3   Training Loss: 8.53639750480652 
Epoch: 3   Validation Loss: 8.733432140350342 
Epoch: 4   Training Loss: 8.533109016418457 
Epoch: 4   Validation Loss: 8.733615894317627 
Epoch: 5   Training Loss: 8.522382926940917 
Epoch: 5   Validation Loss: 8.736032991409301 
Epoch: 6   Training Loss: 8.511893513202667 
Epoch: 6   Validation Loss: 8.739968938827515 
Epoch: 7   Training Loss: 8.50456552028656 
Epoch: 7   Validation Loss: 8.740758304595948 
Epoch: 8   Training Loss: 8.500386874675751 
Epoch: 8   Validation Loss: 8.739876985549927 
Epoch: 9   Training Loss: 8.498206617832183 
Epoch: 9   Validation Loss: 8.740848627090454 
Epoch: 10   Training Loss: 8.493323023319244 
Epoch: 10   Validation L

Epoch: 88   Training Loss: 8.46782990604639 
Epoch: 88   Validation Loss: 8.858855066299439 
Epoch: 89   Training Loss: 8.449937668144702 
Epoch: 89   Validation Loss: 8.874694061279296 
Epoch: 90   Training Loss: 8.445013651177288 
Epoch: 90   Validation Loss: 8.790194654464722 
Epoch: 91   Training Loss: 8.437083059251309 
Epoch: 91   Validation Loss: 8.780454216003418 
Epoch: 92   Training Loss: 8.426953628957271 
Epoch: 92   Validation Loss: 8.763729438781738 
Epoch: 93   Training Loss: 8.419134220927953 
Epoch: 93   Validation Loss: 8.747516841888428 
Epoch: 94   Training Loss: 8.421218005120755 
Epoch: 94   Validation Loss: 8.738847455978394 
Epoch: 95   Training Loss: 8.41088564403355 
Epoch: 95   Validation Loss: 8.755626983642578 
Epoch: 96   Training Loss: 8.416103798672557 
Epoch: 96   Validation Loss: 8.750886497497559 
Epoch: 97   Training Loss: 8.400067842714488 
Epoch: 97   Validation Loss: 8.75445424079895 
Epoch: 98   Training Loss: 8.385491701811553 
Epoch: 98   Valid

Epoch: 174   Training Loss: 7.883467222601175 
Epoch: 174   Validation Loss: 9.252546548843384 
Epoch: 175   Training Loss: 7.933497635424137 
Epoch: 175   Validation Loss: 9.208754997253418 
Epoch: 176   Training Loss: 7.898731876313686 
Epoch: 176   Validation Loss: 9.146307563781738 
Epoch: 177   Training Loss: 7.844026287645102 
Epoch: 177   Validation Loss: 9.191887817382813 
Epoch: 178   Training Loss: 7.813667964413762 
Epoch: 178   Validation Loss: 9.169699277877807 
Epoch: 179   Training Loss: 7.809930724501609 
Epoch: 179   Validation Loss: 9.198466663360596 
Epoch: 180   Training Loss: 7.819246979728341 
Epoch: 180   Validation Loss: 9.142744302749634 
Epoch: 181   Training Loss: 7.805179300904274 
Epoch: 181   Validation Loss: 9.101352910995484 
Epoch: 182   Training Loss: 7.734932668954134 
Epoch: 182   Validation Loss: 9.145942058563232 
Epoch: 183   Training Loss: 7.705524940788746 
Epoch: 183   Validation Loss: 9.15460651397705 
Epoch: 184   Training Loss: 7.71279370188

Epoch: 259   Validation Loss: 11.329025077819825 
Epoch: 260   Training Loss: 4.570484665036202 
Epoch: 260   Validation Loss: 11.138452339172364 
Epoch: 261   Training Loss: 4.479416221678257 
Epoch: 261   Validation Loss: 11.055455074310302 
Epoch: 262   Training Loss: 4.50043534964323 
Epoch: 262   Validation Loss: 11.096573181152344 
Epoch: 263   Training Loss: 4.244444440603257 
Epoch: 263   Validation Loss: 10.962744045257569 
Epoch: 264   Training Loss: 4.444811956882477 
Epoch: 264   Validation Loss: 11.06393310546875 
Epoch: 265   Training Loss: 4.151600767970085 
Epoch: 265   Validation Loss: 11.214322090148926 
Epoch: 266   Training Loss: 4.004123290777207 
Epoch: 266   Validation Loss: 11.150833435058594 
Epoch: 267   Training Loss: 4.125936575531959 
Epoch: 267   Validation Loss: 11.319279556274415 
Epoch: 268   Training Loss: 3.7784598967432976 
Epoch: 268   Validation Loss: 11.658071212768554 
Epoch: 269   Training Loss: 3.8922485002875327 
Epoch: 269   Validation Loss: 

RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Plot epoch loss to test for convergence
plt.plot(epoch_loss_averages)
plt.plot(test_epoch_loss_averages)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

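As the log shows, training loss on this small dataset decreases steadily (from about 8.9 to below 4 by epoch 269), which means this network MAY be able to be a solution. Validation loss, however, rises over the same span (from about 8.7 to above 11), so the model is overfitting. The next step is to test on a significantly larger dataset to see whether more data addresses the overfitting problem. If training loss still converges on the larger set while validation loss does not, then the overfitting must be dealt with in another way.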
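One rough way to quantify the divergence between the two curves is to find the epoch with the lowest validation loss and track the gap between validation and training loss per epoch. The sketch below is an assumption about how that could be done, not part of the original experiment: the helper name `summarize_overfitting` is hypothetical, and the sample values are the epoch 0 to 4 losses from the log above, rounded to three decimals. In the notebook, `epoch_loss_averages` and `test_epoch_loss_averages` would be passed in instead.

```python
def summarize_overfitting(train_losses, val_losses):
    """Return the epoch with the lowest validation loss and the per-epoch
    gap (validation loss minus training loss)."""
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    gaps = [v - t for t, v in zip(train_losses, val_losses)]
    return best_epoch, gaps

# Epochs 0-4 from the training log, rounded to three decimals
train = [8.911, 8.561, 8.547, 8.536, 8.533]
val = [8.727, 8.729, 8.730, 8.733, 8.734]

best_epoch, gaps = summarize_overfitting(train, val)
print(best_epoch)
print([round(g, 3) for g in gaps])
```

A gap that keeps widening after the best epoch is the usual signature of overfitting, and the best epoch gives a natural point for comparing runs on the larger dataset.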