In [3]:
import os
import nibabel as nib
import numpy as np
import random
import nibabel as nib
import numpy as np
from scipy import ndimage
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

2023-09-11 14:32:29.922830: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-11 14:32:29.941602: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# random seed (for reproducibility)
seed = 1
# set random seed for numpy
np.random.seed(seed)
# set random seed for pytorch
torch.manual_seed(seed)

<torch._C.Generator at 0x7f6c9b584f10>

In [5]:
torch.cuda.empty_cache()
print(torch.cuda.memory_summary(abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [6]:
DATA_PATH = os.path.join(os.getcwd(), 'Data/MIRIAD/miriad')
config = {
    'img_size': 128,
    'depth' : 64,
    'batch_size' : 16
}

In [7]:
AD_path, HC_path = [], []
for scan_dir in os.listdir(DATA_PATH):
    scan_dir_path = os.path.join(DATA_PATH, scan_dir)
    if os.path.isdir(scan_dir_path):
        for visit in os.listdir(scan_dir_path):
            scan_visit = os.path.join(scan_dir_path, visit)
            for f in os.listdir(scan_visit):
                f_path = os.path.join(scan_visit, f)
                if f.endswith(".nii"):
                    if 'AD' in f:
                        AD_path.append(f_path)
                    else:
                        HC_path.append(f_path)
                        
        
print("Total number of images are: {}\n".format(len(AD_path)+len(HC_path)))     
incons = []
for path in AD_path+HC_path:
    scan = nib.load(path)
    data = scan.get_fdata()
    if data.shape != (256, 256, 124):
        print('Shape inconsistancy found! {} for {}'.format(data.shape, path))
        incons.append(path)
        
print("\n\nRemove shape insconsitent images.")
for path in incons:
    if 'AD' in path:
        AD_path.remove(path)
    else:
        HC_path.remove(path)
        
print("After Removing shape inconsistent images total number of images {}.".format(len(AD_path)+len(HC_path)))
print("Number of Alzheimer infected MRI scans: {}".format(len(AD_path)))
print("Number of Healthy MRI scans: {}".format(len(HC_path)))

Total number of images are: 708

Shape inconsistancy found! (256, 256, 123) for /home/arindam/Alzheimer/Data/MIRIAD/miriad/miriad_192_AD_M/miriad_192_AD_M_06_MR_2/miriad_192_AD_M_06_MR_2.nii


Remove shape insconsitent images.
After Removing shape inconsistent images total number of images 707.
Number of Alzheimer infected MRI scans: 464
Number of Healthy MRI scans: 243


In [8]:
def train_test_split(path_list, train_split_ratio=0.75, val_split_ratio=0.15):
    test_split_ratio = 1 - (train_split_ratio + val_split_ratio)
    random.shuffle(path_list)
    
    train_image_paths = path_list[:int(train_split_ratio*len(path_list))]
    valid_image_paths = path_list[int(train_split_ratio*len(path_list)):int((train_split_ratio+val_split_ratio)*len(path_list))]
    test_image_paths = path_list[int((train_split_ratio+val_split_ratio)*len(path_list)):]
    
    if len(train_image_paths)+len(valid_image_paths)+len(test_image_paths)==len(path_list):
        print("Everything is fine. No of images in train, val and test set is {}, {} and {} respectively.".format(len(train_image_paths), len(valid_image_paths), len(test_image_paths)))
    else:
        print("Something wrong. Go and start debug!")
    
    return train_image_paths, valid_image_paths, test_image_paths


train_ad_image_paths, val_ad_image_paths, test_ad_image_paths = train_test_split(AD_path)
train_hc_image_paths, val_hc_image_paths, test_hc_image_paths = train_test_split(HC_path)

train_img_paths = train_ad_image_paths + train_hc_image_paths
val_img_paths = val_ad_image_paths + val_hc_image_paths
test_img_paths = test_ad_image_paths + test_hc_image_paths

print("Total no of train, validation and test images are {}, {} and {} respectively.".format(len(train_img_paths), len(val_img_paths), len(test_img_paths)))

Everything is fine. No of images in train, val and test set is 348, 69 and 47 respectively.
Everything is fine. No of images in train, val and test set is 182, 36 and 25 respectively.
Total no of train, validation and test images are 530, 105 and 72 respectively.


In [9]:
class ProcessScan:
    def __init__(self):
        pass
    
    def read_nifti_file(self, filepath):
        """Read and load volume"""
        scan = nib.load(filepath)
        scan = scan.get_fdata()
        return scan


    def normalize(self, volume):
        """Normalize the volume"""
        volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
        return volume.astype('float32')


    def resize_volume(self, img, desired_width=256, desired_height=256, desired_depth=64):
        """Resize the volume"""
        width_factor = desired_width / img.shape[0]
        height_factor = desired_height / img.shape[1]
        depth_factor = desired_depth / img.shape[-1]

        #img = ndimage.rotate(img, 90, reshape=False)
        img = zoom(img, (width_factor, height_factor, depth_factor), order=1)
        return img


    def process_scan(self, path):
        """Read and resize volume"""
        volume = self.read_nifti_file(path)
        volume = self.normalize(volume)
        volume = self.resize_volume(volume, config['img_size'], config['img_size'], config['depth'])

        return volume
    
    def label_extract(self, path):
        """Label Extraction"""
        path = path.split('/')[-1]
        if 'AD' in path:
            return 1
        elif 'HC' in path:
            return 0
    
scan_process = ProcessScan()

In [10]:
class MIRIADAlzheimerDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_filepath = self.image_paths[idx]
        image = scan_process.process_scan(image_filepath)
        image = image[:, :, int(image.shape[-1]/2)]   # Get the middle slice
        image = image.reshape(1, image.shape[0], image.shape[1]) # image.shape[-1])
        
        label = scan_process.label_extract(image_filepath)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        
        return image, label
    


train_dataset = MIRIADAlzheimerDataset(train_img_paths)
valid_dataset = MIRIADAlzheimerDataset(val_img_paths)
test_dataset = MIRIADAlzheimerDataset(test_img_paths)

In [11]:
print('The shape of tensor for 50th image in train dataset: ',train_dataset[49][0].shape)
print('The label for 50th image in train dataset: ',train_dataset[49][1])

The shape of tensor for 50th image in train dataset:  (1, 128, 128)
The label for 50th image in train dataset:  1


In [12]:
train_loader = DataLoader(
    train_dataset, batch_size=config['batch_size'], shuffle=True
)

valid_loader = DataLoader(
    valid_dataset, batch_size=config['batch_size'], shuffle=True
)


test_loader = DataLoader(
    test_dataset, batch_size=config['batch_size'], shuffle=False
)

#batch of image tensor
print("Size of one batch ---> {}".format(next(iter(train_loader))[0].shape))

Size of one batch ---> torch.Size([16, 1, 128, 128])


In [13]:
#import matplotlib.pyplot as plt
#%matplotlib inline
    
# obtain one batch of training images
#dataiter = iter(train_loader)
#images, labels = next(dataiter)
#images = images.numpy()

# plot the images in the batch, along with the corresponding labels
#fig = plt.figure(figsize=(25, 4))
#for idx in np.arange(config['batch_size']):
#    ax = fig.add_subplot(2, config['batch_size']//2, idx+1, xticks=[], yticks=[])
#    ax.imshow(np.squeeze(images[idx]), cmap='gray')
#    # print out the correct label for each image
#    # .item() gets the value contained in a Tensor
#    ax.set_title(str(labels[idx].item()))

In [14]:
# Available device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def squash(x, dim=-1):
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / (squared_norm.sqrt() + 1e-8)


class PrimaryCaps(nn.Module):
    """Primary capsule layer."""

    def __init__(self, num_conv_units, in_channels, out_channels, kernel_size, stride):
        super(PrimaryCaps, self).__init__()

        # Each conv unit stands for a single capsule.
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels * num_conv_units,
                              kernel_size=kernel_size,
                              stride=stride)
        self.out_channels = out_channels

    def forward(self, x):
        # Shape of x: (batch_size, in_channels, height, weight)
        # Shape of out: out_capsules * (batch_size, out_channels, height, weight)
        out = self.conv(x)
        # Flatten out: (batch_size, out_capsules * height * weight, out_channels)
        batch_size = out.shape[0]
        return squash(out.contiguous().view(batch_size, -1, self.out_channels), dim=-1)


class DigitCaps(nn.Module):
    """Digit capsule layer."""

    def __init__(self, in_dim, in_caps, out_caps, out_dim, num_routing):
        """
        Initialize the layer.

        Args:
            in_dim: 		Dimensionality of each capsule vector.
            in_caps: 		Number of input capsules if digits layer.
            out_caps: 		Number of capsules in the capsule layer
            out_dim: 		Dimensionality, of the output capsule vector.
            num_routing:	Number of iterations during routing algorithm
        """
        super(DigitCaps, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        self.device = device
        self.W = nn.Parameter(0.01 * torch.randn(1, out_caps, in_caps, out_dim, in_dim),
                              requires_grad=True)

    def forward(self, x):
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        # W @ x =
        # (1, out_caps, in_caps, out_dim, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, out_caps, in_caps, out_dims, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, out_caps, in_caps, out_dim)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during routing iterations to prevent gradients from flowing
        temp_u_hat = u_hat.detach()

        b = torch.zeros(batch_size, self.out_caps, self.in_caps, 1).to(self.device)

        for route_iter in range(self.num_routing - 1):
            # (batch_size, out_caps, in_caps, 1) -> Softmax along out_caps
            c = b.softmax(dim=1)

            # element-wise multiplication
            # (batch_size, out_caps, in_caps, 1) * (batch_size, in_caps, out_caps, out_dim) ->
            # (batch_size, out_caps, in_caps, out_dim) sum across in_caps ->
            # (batch_size, out_caps, out_dim)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along out_dim
            v = squash(s)
            # dot product agreement between the current output vj and the prediction uj|i
            # (batch_size, out_caps, in_caps, out_dim) @ (batch_size, out_caps, out_dim, 1)
            # -> (batch_size, out_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # last iteration is done on the original u_hat, without the routing weights update
        c = b.softmax(dim=1)
        s = (c * u_hat).sum(dim=2)
        # apply "squashing" non-linearity along out_dim
        v = squash(s)

        return v

In [33]:
in_ch_conv1 = 1
out_ch_conv1 = 128
kernel_conv1 = 9

in_ch_primary = out_ch_conv1
out_ch_primary = 16
num_conv_units_primary = 16

digit_caps = out_ch_primary


class CapsNet(nn.Module):
    """Basic implementation of capsule network layer."""

    def __init__(self):
        super(CapsNet, self).__init__()

        # Conv2d layer
        self.conv = nn.Conv2d(in_ch_conv1, out_ch_conv1, kernel_conv1, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, stride=2)

        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units=num_conv_units_primary,
                                        in_channels=in_ch_primary,
                                        out_channels=out_ch_primary,
                                        kernel_size=kernel_conv1,
                                        stride=2)

        # Digit capsule
        self.digit_caps = DigitCaps(in_dim=digit_caps,
                                    in_caps=num_conv_units_primary * 11 * 11,
                                    out_caps=2,
                                    out_dim=256,
                                    num_routing=3)
        

        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(2* 256, 1024),
            nn.ReLU(inplace=True),
            #nn.Linear(256, 1024),
            #nn.ReLU(inplace=True),
            nn.Linear(1024, 128*128),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.relu(self.conv(x))
        out = self.relu(self.maxpool(out))
        out = self.primary_caps(out)
        out = self.digit_caps(out)

        # Shape of logits: (batch_size, out_capsules)
        logits = torch.norm(out, dim=-1)
        pred = torch.eye(2).to(device).index_select(dim=0, index=torch.argmax(logits, dim=1))

        # Reconstruction
        batch_size = out.shape[0]
        reconstruction = self.decoder((out * pred.unsqueeze(2)).contiguous().view(batch_size, -1))

        return logits , reconstruction

In [34]:
class CapsuleLoss(nn.Module):
    """Combine margin loss & reconstruction loss of capsule network."""

    def __init__(self, upper_bound=0.9, lower_bound=0.1, lmda=0.5):
        super(CapsuleLoss, self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = 5e-4
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, logits, reconstructions):
        # Shape of left / right / labels: (batch_size, num_classes)
        left = (self.upper - logits).relu() ** 2  # True negative
        right = (logits - self.lower).relu() ** 2  # False positive
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)

        # Reconstruction loss
        reconstruction_loss = self.mse(reconstructions.contiguous().view(images.shape), images)

        # Combine two losses
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss

In [35]:
# Load model
model = CapsNet().to(device)
criterion = CapsuleLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

In [36]:
total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Total No of Parameters {} \nTotal no of trainable parameters {}".format(total_params, total_trainable_params))


Total No of Parameters 35843584 
Total no of trainable parameters 35843584


In [37]:
for p in model.parameters():
    print(p.numel())

10368
128
2654208
256
15859712
524288
1024
16777216
16384


In [32]:
# Train
EPOCHES = 50
model.train()
for ep in range(EPOCHES):
    batch_id = 1
    correct, total, total_loss = 0, 0, 0.
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(device)
        labels = torch.eye(2).index_select(dim=0, index=labels).to(device)
        logits, reconstruction = model(images)
        #logits = model(images)

        # Compute loss & accuracy
        loss = criterion(images, labels, logits, reconstruction)
        correct += torch.sum(
            torch.argmax(logits, dim=1) == torch.argmax(labels, dim=1)).item()
        total += len(labels)
        accuracy = correct / total
        total_loss += loss
        loss.backward()
        optimizer.step()
        print('Epoch {}, batch {}, loss: {}, accuracy: {}'.format(ep + 1,
                                                                  batch_id,
                                                                  total_loss / batch_id,
                                                                  accuracy))
        batch_id += 1
    scheduler.step(ep)
    print('Total loss for epoch {}: {}'.format(ep + 1, total_loss))
    model.eval()
    vcorrect, vtotal = 0, 0
    for vimages, vlabels in val_loader:
        # Add channels = 1
        vimages = vimages.to(device)
        # Categogrical encoding
        vlabels = torch.eye(10).index_select(dim=0, index=vlabels).to(device)
        vlogits, vreconstructions = model(vimages)
        vpred_labels = torch.argmax(logits, dim=1)
        vcorrect += torch.sum(vpred_labels == torch.argmax(vlabels, dim=1)).item()
        vtotal += len(vlabels)
    print('Validation Accuracy: {}\n'.format(vcorrect / vtotal))
    model.train(True)

Epoch 1, batch 1, loss: 26.686803817749023, accuracy: 0.5
Epoch 1, batch 2, loss: 26.844375610351562, accuracy: 0.5625
Epoch 1, batch 3, loss: 26.78057861328125, accuracy: 0.6041666666666666
Epoch 1, batch 4, loss: 26.483919143676758, accuracy: 0.609375
Epoch 1, batch 5, loss: 26.033588409423828, accuracy: 0.575
Epoch 1, batch 6, loss: 25.49132537841797, accuracy: 0.5833333333333334
Epoch 1, batch 7, loss: 24.90448760986328, accuracy: 0.5982142857142857
Epoch 1, batch 8, loss: 24.246593475341797, accuracy: 0.6015625
Epoch 1, batch 9, loss: 23.480009078979492, accuracy: 0.6180555555555556
Epoch 1, batch 10, loss: 22.681442260742188, accuracy: 0.625
Epoch 1, batch 11, loss: 21.871740341186523, accuracy: 0.6306818181818182
Epoch 1, batch 12, loss: 21.068254470825195, accuracy: 0.6354166666666666
Epoch 1, batch 13, loss: 20.29008674621582, accuracy: 0.6442307692307693
Epoch 1, batch 14, loss: 19.56147575378418, accuracy: 0.65625
Epoch 1, batch 15, loss: 18.882957458496094, accuracy: 0.6458



Epoch 2, batch 1, loss: 7.895901679992676, accuracy: 0.875
Epoch 2, batch 2, loss: 7.820241928100586, accuracy: 0.71875
Epoch 2, batch 3, loss: 7.824572563171387, accuracy: 0.7083333333333334
Epoch 2, batch 4, loss: 7.768609523773193, accuracy: 0.75
Epoch 2, batch 5, loss: 7.805556774139404, accuracy: 0.725
Epoch 2, batch 6, loss: 7.797515869140625, accuracy: 0.71875
Epoch 2, batch 7, loss: 7.791528224945068, accuracy: 0.6785714285714286
Epoch 2, batch 8, loss: 7.771053791046143, accuracy: 0.6796875
Epoch 2, batch 9, loss: 7.7586669921875, accuracy: 0.6805555555555556
Epoch 2, batch 10, loss: 7.737706661224365, accuracy: 0.675
Epoch 2, batch 11, loss: 7.720767974853516, accuracy: 0.6704545454545454
Epoch 2, batch 12, loss: 7.728679656982422, accuracy: 0.6614583333333334
Epoch 2, batch 13, loss: 7.723946571350098, accuracy: 0.6634615384615384
Epoch 2, batch 14, loss: 7.722270965576172, accuracy: 0.6473214285714286
Epoch 2, batch 15, loss: 7.713057994842529, accuracy: 0.65
Epoch 2, batch

Epoch 5, batch 18, loss: 7.567408084869385, accuracy: 0.6701388888888888
Epoch 5, batch 19, loss: 7.565521240234375, accuracy: 0.6677631578947368
Epoch 5, batch 20, loss: 7.567768096923828, accuracy: 0.66875
Epoch 5, batch 21, loss: 7.568479537963867, accuracy: 0.6696428571428571
Epoch 5, batch 22, loss: 7.567031383514404, accuracy: 0.6619318181818182
Epoch 5, batch 23, loss: 7.568179607391357, accuracy: 0.6657608695652174
Epoch 5, batch 24, loss: 7.568295955657959, accuracy: 0.6666666666666666
Epoch 5, batch 25, loss: 7.5697021484375, accuracy: 0.655
Epoch 5, batch 26, loss: 7.570771217346191, accuracy: 0.6538461538461539
Epoch 5, batch 27, loss: 7.57048225402832, accuracy: 0.6504629629629629
Epoch 5, batch 28, loss: 7.571534633636475, accuracy: 0.6540178571428571
Epoch 5, batch 29, loss: 7.570863246917725, accuracy: 0.6530172413793104
Epoch 5, batch 30, loss: 7.572089195251465, accuracy: 0.65
Epoch 5, batch 31, loss: 7.569310665130615, accuracy: 0.6451612903225806
Epoch 5, batch 32, 

Epoch 9, batch 1, loss: 7.648087501525879, accuracy: 0.5625
Epoch 9, batch 2, loss: 7.615144729614258, accuracy: 0.625
Epoch 9, batch 3, loss: 7.613123893737793, accuracy: 0.6458333333333334
Epoch 9, batch 4, loss: 7.594395637512207, accuracy: 0.59375
Epoch 9, batch 5, loss: 7.588812351226807, accuracy: 0.6125
Epoch 9, batch 6, loss: 7.5758819580078125, accuracy: 0.6458333333333334
Epoch 9, batch 7, loss: 7.5697712898254395, accuracy: 0.6428571428571429
Epoch 9, batch 8, loss: 7.554168224334717, accuracy: 0.65625
Epoch 9, batch 9, loss: 7.563313007354736, accuracy: 0.6666666666666666


KeyboardInterrupt: 

In [None]:
# Eval
model.eval()
correct, total = 0, 0
for images, labels in test_loader:
    # Add channels = 1
    images = images.to(device)
    # Categogrical encoding
    labels = torch.eye(10).index_select(dim=0, index=labels).to(device)
    logits, reconstructions = model(images)
    pred_labels = torch.argmax(logits, dim=1)
    correct += torch.sum(pred_labels == torch.argmax(labels, dim=1)).item()
    total += len(labels)
print('Accuracy: {}'.format(correct / total))

# Save model
torch.save(model.state_dict(), './model/capsnet_ep{}_acc{}.pt'.format(EPOCHES, correct / total))