In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project's local packages importable: append
# <parent of the current working directory>/Modules to sys.path
# (presumably where `dataloader` and `utils` live — confirm).
import sys,os
sys.path.append(os.path.dirname(os.path.realpath('')) + '/Modules')

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torchvision.transforms import Compose, Resize

from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification

from matplotlib import pyplot as plt

import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader

from utils.report import sklearn_classification_report

# Dataset and Dataloader Setup

In [None]:
# Every split gets the same deterministic preprocessing: a resize to 384x384.
image_size = (384, 384)
resize = Resize(size=image_size)

# Three independent Compose pipelines (one per split) wrapping the shared resize.
train_transforms, valid_transforms, test_transforms = (Compose([resize]) for _ in range(3))

In [None]:
# One ADNI3Channels dataset per split, each paired with its own transform pipeline.
train_ds, valid_ds, test_ds = (
    ADNI3Channels(split_root, transforms=split_transforms)
    for split_root, split_transforms in (
        ("../Data/Training/", train_transforms),
        ("../Data/Validation/", valid_transforms),
        ("../Data/Test/", test_transforms),
    )
)

In [None]:
# Peek at one training sample and report how many samples each split holds.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# One panel per input channel, drawn side by side without axes.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for channel_idx, ax in enumerate(axes):
    ax.imshow(image[channel_idx])
    ax.axis("off")

In [None]:
# Class index <-> diagnosis name. Presumably the ADNI groups: cognitively normal,
# mild cognitive impairment, Alzheimer's disease — confirm against the dataset.
id2label = {0: "CN", 1: "MCI", 2: "AD"}
# Derived from id2label so the two maps can never drift apart.
label2id = {name: index for index, name in id2label.items()}

print(id2label[label.item()])

In [None]:
train_batch_size = 8
valid_batch_size = 2
test_batch_size = 2

hparams = {'train_ds': train_ds,
           'valid_ds': valid_ds,
           'test_ds': test_ds,
           'train_batch_size': train_batch_size,
           'valid_batch_size': valid_batch_size,
           'test_batch_size': test_batch_size,
           # Cap worker processes at the number of available CPUs.
           'num_workers': min(20, os.cpu_count() or 1),
           'train_shuffle': True,
           'valid_shuffle': False,
           'test_shuffle': False,
           # Only the training loader drops its short final batch.
           'train_drop_last': True,
           'valid_drop_last': False,
           'test_drop_last': False,
          }

# A single ADNILoader instance serves all three splits.
adni_loader = ADNILoader(**hparams)
train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Sanity-check one training batch: (images, labels).
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Model Development

In [None]:
class Model(nn.Module):
    """Hybrid CNN + ViT classifier.

    A small convolutional stem maps the 3-channel input to a 3-channel feature
    map, which is resized to ``image_size`` and classified by a pretrained ViT
    whose classification head is re-initialised for ``num_labels`` classes.

    Args:
        num_labels: Number of output classes.
        return_last_hidden_state: If True, ``forward`` returns
            ``(logits, hidden_states)``. Despite the name, ``hidden_states`` is
            the full tuple of ViT hidden states (embedding output followed by one
            entry per encoder layer), not only the last one.
        image_size: (height, width) the CNN feature map is resized to before it
            enters the ViT. The default matches the 384-px checkpoint and is
            independent of how the datasets are resized.
        checkpoint: Hugging Face model id of the pretrained ViT.
    """

    def __init__(self, num_labels=3, return_last_hidden_state=False,
                 image_size=(384, 384), checkpoint='google/vit-base-patch32-384'):
        super().__init__()
        self.return_last_hidden_state = return_last_hidden_state
        # The CNN stem shrinks the spatial size (unpadded convs + two 3x3
        # max-pools), so its output is resized back up for the ViT.
        self.resize = Resize(image_size)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=16),

            # Project back to 3 channels so the output looks like an RGB image to the ViT.
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=3, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=3)
        )

        self.vit = ViTForImageClassification.from_pretrained(
            checkpoint,
            # Always computed so forward() can expose them on request.
            output_hidden_states=True,
            num_labels=num_labels,
            hidden_dropout_prob=0.1,
            # The new head's size differs from the checkpoint's classifier.
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Classify a batch of 3-channel images.

        ``x`` is presumably a (batch, 3, H, W) float tensor — TODO confirm.
        Returns the logits, or ``(logits, hidden_states)`` when
        ``self.return_last_hidden_state`` is set.
        """
        features = self.resize(self.cnn(x))
        outputs = self.vit(features)

        if self.return_last_hidden_state:
            return outputs.logits, outputs.hidden_states
        return outputs.logits

# Loading the Trained Model and Evaluating on the Test Set

In [None]:
# Inference runs on the CPU.
device = torch.device('cpu')
model = Model(num_labels=3).to(device)

# Resizing already happens in the dataset transforms and normalisation is
# disabled, so the extractor presumably only turns images into `pixel_values`
# arrays (normalisation presumably matches how the model was trained — TODO confirm).
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384',
                                                        do_resize=False,
                                                        do_normalize=False)

# map_location remaps tensors that were saved on a GPU onto `device`; without it,
# loading a CUDA checkpoint on a CPU-only machine fails.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load("../Hybrid/Best models/Hybrid_loss.pt", map_location=device))

In [None]:
def predict(model, dataloader, device, extractor=None):
    """Run ``model`` over every batch of ``dataloader`` and collect predictions.

    Each batch is converted to ViT pixel values with ``extractor`` before it is
    passed to the model.

    Args:
        model: Module returning logits, or a tuple whose first element is the
            logits (e.g. ``Model`` with ``return_last_hidden_state=True``).
        dataloader: Iterable of ``(images, labels)`` batches. Images are
            presumably a (batch, C, H, W) tensor — TODO confirm for other loaders.
        device: Device the model lives on.
        extractor: Callable mapping a list of per-image arrays to a dict with a
            ``'pixel_values'`` entry. Defaults to the notebook-level
            ``feature_extractor``.

    Returns:
        ``(y_true, y_pred)`` as 1-D integer numpy arrays with one entry per
        sample (empty arrays if ``dataloader`` yields nothing).
    """
    if extractor is None:
        extractor = feature_extractor

    y_true = []
    y_pred = []

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            # One (C, H, W) array per image. Iterating the batch (instead of
            # splitting into `dataloader.batch_size` chunks) also handles a
            # shorter final batch when drop_last=False.
            images = list(np.asarray(x))
            pixel_values = torch.as_tensor(
                np.stack(extractor(images)['pixel_values'], axis=0)
            ).to(device)

            logits = model(pixel_values)
            if isinstance(logits, tuple):
                # The model may have been left in hidden-state mode.
                logits = logits[0]

            y_pred.append(logits.argmax(1).cpu().numpy())
            y_true.append(torch.as_tensor(y).cpu().numpy())

    if not y_true:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0)

# Evaluate the restored model on the held-out test split.
y_true, y_pred = predict(model=model, dataloader=test_dataloader, device=device)
sklearn_classification_report(y_true, y_pred)

# Preparing Activation and Label Dataframes

In [None]:
def _extract_activations(dataloader, layer=12):
    """Run the notebook ``model`` over ``dataloader`` and collect one hidden state per sample.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches.
        layer: Index into the ViT ``hidden_states`` tuple. Entry 0 is the
            embedding output and entry ``k`` is encoder layer ``k``, so 12 is the
            last layer of a 12-layer ViT-base.

    Returns:
        ``(activations_df, labels_df)`` with one row per sample.
        ``activations_df`` holds the flattened hidden state (145 tokens x 768
        dims for this checkpoint). ``labels_df`` has a single column ``0``
        holding the class label.
    """
    activation_batches = []
    label_batches = []
    for sample, label in dataloader:
        # One (C, H, W) array per image; works for a short final batch too.
        images = list(np.asarray(sample))
        pixel_values = torch.as_tensor(
            np.stack(feature_extractor(images)['pixel_values'], axis=0)
        ).to(device)
        _, hidden_states = model(pixel_values)
        hidden = hidden_states[layer]
        # Flatten using the batch's real size rather than the configured batch size.
        activation_batches.append(hidden.reshape(hidden.shape[0], -1).cpu().numpy())
        label_batches.append(label.cpu().numpy())

    if not activation_batches:
        return pd.DataFrame(), pd.DataFrame()
    # Concatenate once at the end instead of growing DataFrames inside the loop.
    return (pd.DataFrame(np.concatenate(activation_batches, axis=0)),
            pd.DataFrame(np.concatenate(label_batches, axis=0)))


model.eval()
_previous_hidden_flag = model.return_last_hidden_state
model.return_last_hidden_state = True
try:
    with torch.no_grad():
        x_train_df, y_train_df = _extract_activations(train_dataloader)
        x_test_df, y_test_df = _extract_activations(test_dataloader)
finally:
    # Restore the previous output mode so earlier cells (e.g. evaluation) still work on re-run.
    model.return_last_hidden_state = _previous_hidden_flag

# PCA

In [None]:
# Give every frame a clean 0..n-1 index and name the label column.
x_train_df = x_train_df.reset_index(drop=True)
y_train_df = y_train_df.rename(columns={0: "target"}).reset_index(drop=True)
x_test_df = x_test_df.reset_index(drop=True)
y_test_df = y_test_df.rename(columns={0: "target"}).reset_index(drop=True)

assert len(x_train_df) == len(y_train_df), "train activations/labels misaligned"
assert len(x_test_df) == len(y_test_df), "test activations/labels misaligned"

# Fit the scaler and PCA on the training activations only, then apply the SAME
# fitted transforms to the test activations. Calling fit_transform on the test
# split would refit both, placing train and test points in different coordinate
# systems (and leaking test statistics), so the joint scatter plot would be meaningless.
scaler = StandardScaler()
x_train_std = scaler.fit_transform(x_train_df)
x_test_std = scaler.transform(x_test_df)

pca = PCA(n_components=2)
pca_train = pca.fit_transform(x_train_std)
pca_test = pca.transform(x_test_std)

pc_columns = ['principal component 1', 'principal component 2']
pca_train_df = pd.DataFrame(pca_train, columns=pc_columns)
pca_test_df = pd.DataFrame(pca_test, columns=pc_columns)

final_train_df = pd.concat([pca_train_df, y_train_df], axis=1)
final_test_df = pd.concat([pca_test_df, y_test_df], axis=1)

# Stack both splits; the outer index level ('train' / 'test') identifies the split.
final_df = pd.concat([final_train_df, final_test_df], axis=0, keys=['train', 'test'])
final_df

In [None]:
# Scatter the 2-D PCA projection: colour encodes the class, marker encodes the split.
fig, ax = plt.subplots(figsize=(8, 8), dpi=300)
targets = [0, 1, 2]
colors = ['r', 'g', 'b']
split_styles = [('train', 'o', 'Train'), ('test', '*', 'Test')]

for target, color in zip(targets, colors):
    for split, marker, split_name in split_styles:
        split_df = final_df.loc[split]
        points = split_df[split_df['target'] == target]
        # A class may be missing from a split; skip it rather than fail on a KeyError.
        if points.empty:
            continue
        ax.scatter(points['principal component 1'],
                   points['principal component 2'],
                   s=150,
                   alpha=0.3,
                   c=color,
                   marker=marker,
                   label=f'{id2label[target]} ({split_name})')

ax.legend()
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
ax.set_title('PCA of ViT hidden-state activations');
# plt.savefig('PCA_Hybrid.png')