In [None]:
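# Automatically reload imported modules before each cell runs, so edits to project code take effect without restarting the kernel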
%load_ext autoreload
%autoreload 2

In [None]:
# To have access to modules: put the `Modules` directory, which sits in the
# parent of the current working directory, on the import path.
import os
import sys

modules_dir = os.path.join(os.path.dirname(os.path.realpath('')), 'Modules')
# Re-running this cell must not keep appending the same entry.
if modules_dir not in sys.path:
    sys.path.append(modules_dir)

In [None]:
import numpy as np

import torch
import torch.nn as nn

from matplotlib import pyplot as plt

import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from dataloader.dataset import ADNI
from dataloader.dataloader import ADNILoader

from utils.report import sklearn_classification_report

# Dataset and Dataloader Setup

In [None]:
# The three splits are built with exactly the same options; only the source
# directory differs.
# NOTE(review): rotate=True is also applied to validation and test data --
# confirm it is a deterministic re-orientation and not random augmentation.
dataset_options = dict(transforms=None, extra_channel_dim=True, rotate=True)

train_ds = ADNI("../Data/Training/", **dataset_options)
valid_ds = ADNI("../Data/Validation/", **dataset_options)
test_ds = ADNI("../Data/Test/", **dataset_options)

In [None]:
# Fetch one training sample and report its shape and label, plus split sizes.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# Keep only index 0 of the leading axis (presumably the channel axis added by
# extra_channel_dim=True), then tile the first 60 slices of the remaining
# volume on a 6 x 10 grid.
image = image[0, :, :]
n_rows, n_cols = 6, 10
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(3, 2), dpi=300)
# axes.flat walks the grid row by row, so the running index equals
# row * n_cols + col.
for slice_idx, slice_ax in enumerate(axes.flat):
    slice_ax.imshow(image[slice_idx, :, :])
    slice_ax.axis("off")

In [None]:
# Class index -> class name. The reverse lookup is derived from it so the two
# tables always agree.
id2label = dict(enumerate(("CN", "MCI", "AD")))
label2id = {class_name: class_id for class_id, class_name in id2label.items()}

# Show the class name for `label`, i.e. the label of the training sample
# fetched in the sample-inspection cell.
print(id2label[label.item()])

In [None]:
train_batch_size = 8
valid_batch_size = 2
test_batch_size = 2

# Ask for at most 20 worker processes, but never more than this machine has
# CPUs; os.cpu_count() may return None, hence the fallback to 1.
num_workers = min(20, os.cpu_count() or 1)

# Settings handed to ADNILoader. Shuffling and drop_last are disabled for every
# split, so batches follow dataset order and no sample is discarded.
hparams = {'train_ds': train_ds,
           'valid_ds': valid_ds,
           'test_ds': test_ds,
           'train_batch_size': train_batch_size,
           'valid_batch_size': valid_batch_size,
           'test_batch_size': test_batch_size,
           'num_workers': num_workers,
           'train_shuffle': False,
           'valid_shuffle': False,
           'test_shuffle': False,
           'train_drop_last': False,
           'valid_drop_last': False,
           'test_drop_last': False,
          }

# Build the loader factory once and ask it for each split.
# NOTE(review): assumes ADNILoader only stores the settings above and keeps no
# per-call state -- confirm against dataloader/dataloader.py.
adni_loader = ADNILoader(**hparams)
train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Sanity check: shapes of the inputs and labels in the first training batch.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Model

In [None]:
class CNN(nn.Module):
    """3D convolutional classifier.

    Two convolutional stages (each: two 3x3x3 convolutions with ReLU, channel
    dropout, 3x3x3 max-pooling and batch normalisation) are followed by a
    three-layer fully connected head. The first linear layer expects exactly
    4480 flattened features, so the network only accepts input volumes whose
    size produces that many features after the second pooling step.

    Args:
        num_labels: Number of output classes, i.e. the width of the final
            linear layer. Defaults to 3.
    """

    def __init__(self, num_labels=3):
        super().__init__()
        self.num_labels = num_labels

        # Layer positions inside this Sequential are part of the checkpoint
        # format (state-dict keys are "model.<index>.<param>"), so layers must
        # not be inserted, removed or reordered.
        self.model = nn.Sequential(
            # Stage 1: 1 -> 5 -> 10 channels
            nn.Conv3d(in_channels=1, out_channels=5, kernel_size=(3, 3, 3)),
            nn.ReLU(),
            nn.Conv3d(in_channels=5, out_channels=10, kernel_size=(3, 3, 3)),
            nn.ReLU(),
            nn.Dropout3d(0.2),

            nn.MaxPool3d(kernel_size=(3, 3, 3)),
            nn.BatchNorm3d(num_features=10),

            # Stage 2: 10 -> 15 -> 20 channels
            nn.Conv3d(in_channels=10, out_channels=15, kernel_size=(3, 3, 3)),
            nn.ReLU(),
            nn.Conv3d(in_channels=15, out_channels=20, kernel_size=(3, 3, 3)),
            nn.ReLU(),
            nn.Dropout3d(0.2),

            nn.MaxPool3d(kernel_size=(3, 3, 3)),
            nn.BatchNorm3d(num_features=20),

            nn.Flatten(),

            # Fully connected head: 4480 -> 512 -> 128 -> num_labels
            nn.Linear(4480, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            # Output width follows num_labels instead of a hard-coded 3.
            nn.Linear(128, num_labels),
        )

    def forward(self, x):
        """Return the unnormalised class scores (logits) for a batch `x`."""
        return self.model(x)

In [None]:
# Build the network with its default arguments; the bare `model` expression
# makes the notebook display the module's layer listing.
model = CNN()
model

# Dissecting Model and Saving Activation Maps

In [None]:
class dissected_CNN(CNN):
    """CNN whose forward pass can also return the input of its final layer.

    The inherited ``self.model`` stack is split into ``features`` (every layer
    except the last) and ``classifier`` (the last layer). Both are built from
    the layer objects already held by ``self.model``, so no new weights are
    created; the same parameters simply become reachable under the additional
    ``features.*`` / ``classifier.*`` state-dict names.

    Args:
        return_last_activation: When True, ``forward`` returns
            ``(logits, activation)``; when False, only ``logits``.
        num_labels: Passed through to ``CNN.__init__``. Defaults to 3.
    """

    def __init__(self, return_last_activation=False, num_labels=3):
        super().__init__(num_labels=num_labels)
        # Split relative to the end of the stack rather than at a fixed index,
        # so the split stays correct if the parent's layer count changes.
        self.features = self.model[:-1]
        self.classifier = self.model[-1]
        self.return_last_activation = return_last_activation

    def forward(self, x):
        """Return logits, plus the pre-classifier activation when requested."""
        f = self.features(x)
        logits = self.classifier(f)
        if self.return_last_activation:
            return logits, f
        return logits

In [None]:
device = torch.device('cpu')
dissected_model = dissected_CNN(return_last_activation=False).to(device)

# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint written from GPU tensors can still be read on a CPU-only machine.
checkpoint = torch.load("../CNN/Best models/CNN_3D_loss.pt", map_location=device)

# strict=False tolerates key mismatches between the checkpoint and this model
# (the dissected model exposes extra `features.*` / `classifier.*` names).
# NOTE(review): it also hides a checkpoint whose keys do not match at all --
# inspect the returned missing/unexpected keys if the results look random.
dissected_model.load_state_dict(checkpoint, strict=False);

In [None]:
def predict(model, dataloader, device):
    """Collect ground-truth labels and predicted classes over a dataloader.

    The model is switched to eval mode and run without gradient tracking. The
    predicted class of each sample is the argmax of the logits along
    dimension 1.

    Args:
        model: Module returning either logits or a tuple/list whose first item
            is the logits (e.g. ``(logits, activations)``).
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred)`` as NumPy arrays concatenated over all batches.
        Both are empty int64 arrays when the dataloader yields nothing.
    """
    y_true = []
    y_pred = []

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            # A model that also returns activations hands back a tuple; only
            # the logits are needed here.
            logits = output[0] if isinstance(output, (tuple, list)) else output
            y_pred.append(logits.argmax(1).cpu().numpy())
            y_true.append(y.cpu().numpy())

    # np.concatenate rejects an empty list, so handle "no batches" explicitly.
    if not y_true:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    y_pred = np.concatenate(y_pred, axis=0)
    y_true = np.concatenate(y_true, axis=0)

    return y_true, y_pred

# Run the loaded model over the test split and hand the true and predicted
# labels to the project's classification-report helper.
y_true, y_pred = predict(dissected_model, test_dataloader, device)
sklearn_classification_report(y_true, y_pred)

# Preparing Activation and Label Dataframes

In [None]:
def extract_activations(model, dataloader, device):
    """Collect pre-classifier activations and labels for every sample.

    The model is put in eval mode and temporarily switched to return
    ``(logits, activation)``; its ``return_last_activation`` flag is restored
    afterwards so other cells can be re-run in any order.

    Args:
        model: Module with a ``return_last_activation`` attribute whose forward
            pass returns ``(logits, activation)`` when that attribute is True.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(activation_df, label_df)``: one row per sample, columns numbered
        from 0. Both frames are empty when the dataloader yields nothing.
    """
    activations = []
    labels = []

    previous_flag = model.return_last_activation
    model.return_last_activation = True
    model.eval()
    try:
        with torch.no_grad():
            for sample, label in dataloader:
                _, f = model(sample.to(device))
                # .cpu() keeps this working if `device` is ever not the CPU.
                activations.append(f.detach().cpu().numpy())
                labels.append(label.detach().cpu().numpy())
    finally:
        model.return_last_activation = previous_flag

    if not activations:
        return pd.DataFrame(), pd.DataFrame()

    # Build each frame once instead of re-concatenating a growing DataFrame
    # on every batch.
    activation_df = pd.DataFrame(np.concatenate(activations, axis=0))
    label_df = pd.DataFrame(np.concatenate(labels, axis=0))
    return activation_df, label_df


x_train_df, y_train_df = extract_activations(dissected_model, train_dataloader, device)
x_test_df, y_test_df = extract_activations(dissected_model, test_dataloader, device)

# PCA

In [None]:
# Give every frame a fresh 0..N-1 index so activations and labels line up row
# by row when concatenated side by side below; name the label column "target".
x_train_df = x_train_df.reset_index(drop=True)
y_train_df = y_train_df.rename(columns={0: "target"}).reset_index(drop=True)
x_test_df = x_test_df.reset_index(drop=True)
y_test_df = y_test_df.rename(columns={0: "target"}).reset_index(drop=True)

# Standardisation statistics are learned from the training activations only
# and then applied unchanged to the test activations.
scaler = StandardScaler()
x_train_std = scaler.fit_transform(x_train_df)
x_test_std = scaler.transform(x_test_df)

# Likewise the principal axes are fitted on the training set only; the test
# set is projected onto those same axes. Re-fitting on the test set would put
# the two splits in different coordinate systems and make a shared scatter
# plot meaningless.
pca = PCA(n_components=2)
pca_train = pca.fit_transform(x_train_std)
pca_test = pca.transform(x_test_std)

pca_columns = ['principal component 1', 'principal component 2']
pca_train_df = pd.DataFrame(pca_train, columns=pca_columns)
pca_test_df = pd.DataFrame(pca_test, columns=pca_columns)

final_train_df = pd.concat([pca_train_df, y_train_df], axis=1)
final_test_df = pd.concat([pca_test_df, y_test_df], axis=1)

# Stack both splits; the outer index level ('train' / 'test') records the split.
final_df = pd.concat([final_train_df, final_test_df], axis=0, keys=['train', 'test'])
final_df

In [None]:
# Scatter plot of the two principal components: one colour per class, one
# marker per split (circles for train, stars for test).
fig, ax = plt.subplots(figsize = (8, 8), dpi=300)
targets = [0, 1, 2]
colors = ['r', 'g', 'b']
split_styles = {'train': ('o', 'Train'), 'test': ('*', 'Test')}

# Select each split once. Filtering by class inside a split yields an empty
# frame (and an empty scatter) when a class is absent from that split, instead
# of a KeyError on the split label.
split_frames = {split: final_df.loc[split] for split in split_styles}

for target, color in zip(targets, colors):
    for split, (marker, split_name) in split_styles.items():
        split_df = split_frames[split]
        class_df = split_df[split_df['target'] == target]

        ax.scatter(class_df['principal component 1'],
                   class_df['principal component 2'],
                   s=150,
                   alpha=0.3,
                   c=color,
                   marker=marker,
                   label=f'{id2label[target]} ({split_name})'
                  )

ax.legend();
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2');
# plt.savefig('PCA_CNN_3D.png')