In [1]:
# Importing the libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from sklearn.metrics import confusion_matrix
from dataLoader import *
from sklearn import neighbors
import time
from prettytable import PrettyTable
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_mean_pool, global_max_pool, global_add_pool

torch.cuda.empty_cache()

In [2]:
# Load the datasets
DATA_PATH = os.path.join('/home/arindam/Alzheimer', 'Data/MIRIAD/miriad')
config = {
    'img_size': 256,
    'depth' : 64,
    'batch_size' : 8
}

# Modify the above config in the Dataloader to change the batch size, image size and depth of the model
train_loader, valid_loader, test_loader = LoadDatasets(return_type='loader')

Total number of images are: 708

Shape inconsistancy found! (256, 256, 123) for /home/arindam/Alzheimer/Data/MIRIAD/miriad/miriad_192_AD_M/miriad_192_AD_M_06_MR_2/miriad_192_AD_M_06_MR_2.nii


Remove shape insconsitent images.
After Removing shape inconsistent images total number of images 707.
Number of Alzheimer infected MRI scans: 464
Number of Healthy MRI scans: 243
Everything is fine. No of images in train, val and test set is 371, 46 and 47 respectively.
Everything is fine. No of images in train, val and test set is 194, 24 and 25 respectively.
Total no of train, validation and test images are 565, 70 and 72 respectively.


In [3]:
# Specify the kernerl size, stride and number of features

kc, kh, kw = 16, 16, 16  # kernel size
dc, dh, dw = 16, 16, 16  # stride
num_features = kc*kh*kw
num_classes = 2
num_of_patch_in_each_image = config['img_size']//16 * config['img_size']//16 * config['depth']//16


# Build the graph using KNN. The graph is built on the patches of the images

def build_graph(batched_images):
    batch, patch_batch = [], []
    for img in batched_images:
        patches = img.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
        patches = patches.contiguous().view(patches.size(0), -1, kc, kh, kw)
        patch_batch.append(patches)
        patches = patches.contiguous().view(patches.size(0), -1, kc*kh*kw)
        batch.append(patches)
        
    patched_images = torch.cat(batch, dim=0)  # Shape -> (batch_size, num_of_patch_in_each_image, num_features)
    patch_batch = torch.cat(patch_batch, dim=0)  # Shape -> (batch_size, num_of_patch_in_each_image, kc, kh, kw) This is required for local 3D CNNs
    batch_adj = []
    for i in range(patched_images.shape[0]):
        patches = patched_images[i]
        adj = torch.as_tensor(neighbors.kneighbors_graph(patches, n_neighbors = 64).toarray(), dtype=torch.float32)  # No of neighbors = 64
        adj = adj.reshape(1, adj.shape[0], adj.shape[1])
        batch_adj.append(adj)
        
    adj = torch.cat(batch_adj, dim=0)
    return patch_batch.type(torch.FloatTensor), adj
        

In [4]:
# Check a sample batch size
for data in train_loader:
    images, labels = data
    patched_images, adj = build_graph(images)  
    print(patched_images.shape, adj.shape, labels.shape)
    break

torch.Size([8, 1024, 16, 16, 16]) torch.Size([8, 1024, 1024]) torch.Size([8, 1, 2])


In [5]:
class GCN(nn.Module):

    """
    Base paper: https://arxiv.org/abs/1609.02907
    """
    
    def __init__(self, num_node_features, num_classes, hidden_channels, linear_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)

        # 3D Convolutional layer
        self.conv3d = nn.Conv3d(1, 16, 3, stride=1, padding=1)
        self.maxpool = nn.MaxPool3d(3, stride=2, padding=1)
        self.conv3d_2 = nn.Conv3d(16, 32, 3, stride=2, padding=1)
        self.maxpool_2 = nn.MaxPool3d(3, stride=2, padding=1)
        
        # Graph convolution layer
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels//2)
        self.conv3 = GCNConv(hidden_channels//2, hidden_channels//4)
        self.maxpool1d = nn.MaxPool1d(64) # 32 is the kernel size
        self.avgpool1d = nn.AvgPool1d(64) # 32 is the kernel size
        self.fc1 = nn.Linear(linear_channels, linear_channels*2)
        self.fc2 = nn.Linear(linear_channels*2, linear_channels//2)

        self.fc3 = nn.Linear(linear_channels, linear_channels*2)
        self.fc4 = nn.Linear(linear_channels*2, linear_channels//2)
        self.classify = nn.Linear(linear_channels, num_classes)
        
    
    def forward(self, x, adj, batch_size):
        # 1. Obtain Conv features
        x = x.view(x.shape[0]*x.shape[1], 1, x.shape[2], x.shape[3], x.shape[4])
        x = F.relu(self.conv3d(x))
        x = self.maxpool(x)
        x = F.relu(self.conv3d_2(x))
        x = self.maxpool_2(x)
        xs = x.shape  # Save the shape for reshaping later. format -> (batch_size * num_of_patch_in_each_image, 4, 4, 2)
        x = x.view(x.shape[0]*x.shape[1], -1)
        x = x.view(xs[0], -1)
        
        # 2. Obtain Diagonal Blocks. This is required for the GCNConv layer to work
        block = adj[0]
        for i in range(1, adj.shape[0]):
            block = torch.block_diag(block, adj[i])
        edge_index = block.to_sparse()._indices()
        
        # 2. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 3. Readout layer
        x = x.view(batch_size, -1)
        x_max = self.maxpool1d(x)
        x_max = F.relu(self.fc1(x_max))
        x_max = F.relu(self.fc2(x_max)) # [batch_size, linear_channels*2]  ---> [batch_size, linear_channels/2]

        x_avg = self.avgpool1d(x)
        x_avg = F.relu(self.fc3(x_avg))
        x_avg = F.relu(self.fc4(x_avg))
        x = torch.cat((x_max, x_avg), dim=1)
        
        #x = global_mean_pool(x, batch=torch.tensor([0, 1, 2, 3,  4, 5, 6, 7]).to(device='cuda'), size=batch_size)  # [batch_size, hidden_channels]
        

        # 4. Apply a final classifier
        #x = F.dropout(x, p=0.1, training=self.training)

        x = F.softmax(self.classify(x), dim=-1)
        
        return x

            

In [6]:
# Specify the device and model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GCN(num_node_features=256, 
                    num_classes=2,
                    hidden_channels=32,
                    linear_channels=128).to(device)

In [7]:
# Print the model summary
print(model)

GCN(
  (conv3d): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv3d_2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (maxpool_2): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv1): GCNConv(256, 32)
  (conv2): GCNConv(32, 16)
  (conv3): GCNConv(16, 8)
  (maxpool1d): MaxPool1d(kernel_size=64, stride=64, padding=0, dilation=1, ceil_mode=False)
  (avgpool1d): AvgPool1d(kernel_size=(64,), stride=(64,), padding=(0,))
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=128, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=64, bias=True)
  (classify): Linear(in_features=128, out_features=2, bias=True)
)


In [8]:
# Get number of parameters in the model
total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}, Trainable Parameters {total_trainable_params}\nTotal parameters: {total_params/1000000}M, Trainable Parameters {total_trainable_params/1000000}M")

Total parameters: 122394, Trainable Parameters 122394
Total parameters: 0.122394M, Trainable Parameters 0.122394M


In [9]:
# Get parameters for each layer of the model in a tabular format

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")

count_parameters(model)

+------------------+------------+
|     Modules      | Parameters |
+------------------+------------+
|  conv3d.weight   |    432     |
|   conv3d.bias    |     16     |
| conv3d_2.weight  |   13824    |
|  conv3d_2.bias   |     32     |
|    conv1.bias    |     32     |
| conv1.lin.weight |    8192    |
|    conv2.bias    |     16     |
| conv2.lin.weight |    512     |
|    conv3.bias    |     8      |
| conv3.lin.weight |    128     |
|    fc1.weight    |   32768    |
|     fc1.bias     |    256     |
|    fc2.weight    |   16384    |
|     fc2.bias     |     64     |
|    fc3.weight    |   32768    |
|     fc3.bias     |    256     |
|    fc4.weight    |   16384    |
|     fc4.bias     |     64     |
| classify.weight  |    256     |
|  classify.bias   |     2      |
+------------------+------------+
Total Trainable Params: 122394


In [10]:
# Optimizer
optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0.0005,
            weight_decay=1e-5,
            betas=(0.75, 0.999))

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], gamma=0.1)
loss_fn = F.cross_entropy

In [11]:
def train(train_loader):
    total_time_iter = 0
    model.train()
    start = time.time()
    train_loss, n_samples = 0, 0
    correct = 0
    for batch_idx, data in enumerate(train_loader):
        images, labels = data
        if labels.dim() == 3:
            labels = torch.squeeze(labels)
        patched_images, adj = build_graph(images)
        patched_images, adj, labels = patched_images.to(device), adj.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(patched_images, adj, batch_size=images.shape[0])
        loss = loss_fn(output, labels.to(dtype=torch.float32), reduction='mean')
        loss.backward()
        optimizer.step()
        time_iter = time.time() - start
        total_time_iter += time_iter
        train_loss += loss.item() * len(output)
        n_samples += len(output)
        #print(output)
        predicted, labels = torch.argmax(output, dim=1), torch.argmax(labels, dim=1)
        #print(predicted, labels)
        correct += (predicted == labels).sum()
        if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} (avg: {:.6f}) \tsec/iter: {:.4f}%\t Accuracy (avg) {:.3f}'.format(
                epoch, n_samples, len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item(), train_loss / n_samples, time_iter / (batch_idx + 1), 100*(correct/n_samples) ))
    scheduler.step()
    return total_time_iter / (batch_idx + 1)

In [12]:
def validation(valid_loader):
    model.eval()
    start = time.time()
    valid_loss, correct, n_samples = 0, 0, 0
    for batch_idx, data in enumerate(valid_loader):
        images, labels = data
        if labels.dim() == 3:
            labels = torch.squeeze(labels)
        patched_images, adj = build_graph(images)
        patched_images, adj, labels = patched_images.to(device), adj.to(device), labels.to(device)
        output = model(patched_images, adj, batch_size=images.shape[0])
        loss = loss_fn(output, labels.to(dtype=torch.float32), reduction='sum')
        valid_loss += loss.item()
        n_samples += len(output)
        pred, labels = torch.argmax(output.data, dim=1), torch.argmax(labels, dim=1)

        correct += (pred == labels).sum()

    time_iter = time.time() - start

    valid_loss /= n_samples

    acc = 100. * correct / n_samples
    print('Validation set (epoch {}): Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%) Took {} sec\n'.format(epoch, 
                                                                                          valid_loss, 
                                                                                          correct, 
                                                                                          n_samples, acc, 
                                                                                          time_iter))
    return valid_loss, acc

In [17]:
from datetime import datetime

epochs=20 # Number of epochs --> 30+20 = 50
# Train the model. Save the model with the best validation accuracy. Only last epoch execution is shown
best_loss = 100000

for epoch in range(epochs):
    train(train_loader)
    valid_loss, acc = validation(valid_loader)
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), os.getcwd() + f'/best_model{datetime.now()}.pt')
        print('Model Saved')
    


Validation set (epoch 0): Average loss: 0.3926, Accuracy: 65/70 (92.86%) Took 18.370365142822266 sec

Model Saved
Validation set (epoch 1): Average loss: 0.3924, Accuracy: 65/70 (92.86%) Took 18.533689737319946 sec

Model Saved
Validation set (epoch 2): Average loss: 0.3924, Accuracy: 64/70 (91.43%) Took 18.1994731426239 sec

Model Saved
Validation set (epoch 3): Average loss: 0.3923, Accuracy: 64/70 (91.43%) Took 18.014457941055298 sec

Model Saved
Validation set (epoch 4): Average loss: 0.3922, Accuracy: 64/70 (91.43%) Took 18.20318913459778 sec

Model Saved
Validation set (epoch 5): Average loss: 0.3920, Accuracy: 65/70 (92.86%) Took 17.514727115631104 sec

Model Saved
Validation set (epoch 6): Average loss: 0.3920, Accuracy: 65/70 (92.86%) Took 17.827833890914917 sec

Model Saved
Validation set (epoch 7): Average loss: 0.3921, Accuracy: 64/70 (91.43%) Took 17.62971043586731 sec

Validation set (epoch 8): Average loss: 0.3919, Accuracy: 65/70 (92.86%) Took 17.685707092285156 sec

Mo

In [14]:
def test(test_loader):
    print('Test model ...')
    model.eval()
    start = time.time()
    test_loss, correct, n_samples = 0, 0, 0
    for batch_idx, data in enumerate(test_loader):
        images, labels = data
        patched_images, adj = build_graph(images)
        patched_images, adj, labels = patched_images.to(device), adj.to(device), labels.to(device)
        if labels.dim() == 3:
            labels = torch.squeeze(labels)
        output = model(patched_images, adj, batch_size=images.shape[0])
        loss = loss_fn(output, labels.to(device=device, dtype=torch.float32), reduction='sum')
        test_loss += loss.item()
        n_samples += len(output)
        pred = torch.argmax(output.data, 1)
        print(pred)

        correct += (pred == torch.argmax(labels, dim=1)).sum()

    time_iter = time.time() - start

    test_loss /= n_samples

    acc = 100. * correct / n_samples
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%). Took {} sec time'.format(test_loss, 
                                                                                correct, 
                                                                                n_samples, acc, time_iter))
    return test_loss, acc

In [18]:
# Load the model
loaded_model = GCN(num_node_features=256,
                    num_classes=2,
                    hidden_channels=32,
                    linear_channels=128).to(device)
loaded_model.load_state_dict(torch.load("best_model2023-09-22 16:11:28.118115.pt"))
loaded_model.eval()


def loaded_test(test_loader):
    print('Test model ...')
    loaded_model.eval()
    start = time.time()
    test_loss, correct, n_samples = 0, 0, 0
    for batch_idx, data in enumerate(test_loader):
        images, labels = data
        patched_images, adj = build_graph(images)
        patched_images, adj, labels = patched_images.to(device), adj.to(device), labels.to(device)
        if labels.dim() == 3:
            labels = torch.squeeze(labels)
        output = loaded_model(patched_images, adj, batch_size=images.shape[0])
        loss = loss_fn(output, labels.to(device=device, dtype=torch.float32), reduction='sum')
        test_loss += loss.item()
        n_samples += len(output)
        pred = torch.argmax(output.data, 1)
        #print(pred, torch.argmax(labels, dim=1))

        correct += (pred == torch.argmax(labels, dim=1)).sum()

    time_iter = time.time() - start

    test_loss /= n_samples

    acc = 100. * correct / n_samples
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%). Took {} sec time'.format(test_loss, 
                                                                                correct, 
                                                                                n_samples, acc, time_iter))
    return test_loss, acc

# Test the model
test_loss, acc = loaded_test(test_loader)

Test model ...
Test set: Average loss: 0.3908, Accuracy: 67/72 (93.06%). Took 17.238627910614014 sec time
