# Preparing images for the paper

In [None]:
# Autoreload modules: enable IPython's autoreload extension so that edits to
# imported Python modules (e.g. the project code under Modules/) are picked up
# without restarting the kernel. Mode 2 reloads all modules (except those
# explicitly excluded) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# To have access to modules: make the "Modules" directory that sits next to the
# current working directory (i.e. <parent of cwd>/Modules) importable.
# The entry is added only once, so re-running this cell does not keep growing
# sys.path with duplicates.
import os
import sys

modules_dir = os.path.join(os.path.dirname(os.path.realpath('')), 'Modules')
if modules_dir not in sys.path:
    sys.path.append(modules_dir)

In [None]:
from torchvision.transforms import Resize, Compose

from matplotlib import pyplot as plt

from dataloader.dataset import ADNI3Channels, ADNI
from dataloader.dataloader import ADNILoader
from atlas.atlas import AAL, AAL3Channels

import nibabel as nib
import numpy as np
from utils.image import save_image

## Raw images

In [None]:
# Load the NIfTI volume stored in raw.nii as a floating-point array
# (nibabel's get_fdata) and report its dimensions.
raw_volume = nib.load("raw.nii")
image = raw_volume.get_fdata()
print(image.shape)

In [None]:
# Show one slice along each of the three volume axes (index 90 on axis 0,
# 80 on axis 1, 40 on axis 2), each rotated by 90 degrees for display,
# side by side with axis decorations hidden.
fig, axes = plt.subplots(ncols=3, dpi=300, frameon=False)
slices = (image[90, :, :], image[:, 80, :], image[:, :, 40])
for ax, plane in zip(axes, slices):
    ax.imshow(np.rot90(plane))
    ax.axis("off")

## 60-channel pre-processed images

In [None]:
# Build the training dataset without transforms (rotate=True) and inspect the
# shape of its first sample.
train_ds = ADNI("../Data/Training/", transforms=None, rotate=True)
idx = 0
first_sample = train_ds[idx]
image, label = first_sample
print(image.shape)

In [None]:
# Tile the first 60 channels (axis 0) of the sample in a 4 x 15 grid,
# filled row by row, with axis decorations hidden.
n_rows, n_cols = 4, 15
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, frameon=False,
                         figsize=(5, 1.5), dpi=300)
for channel, ax in enumerate(axes.flat):
    ax.imshow(image[channel, :, :])
    ax.axis("off")

## 3-channel pre-processed images

In [None]:
# Build the 3-channel training dataset without transforms (rotate=True) and
# inspect the shape of its first sample.
train_ds = ADNI3Channels("../Data/Training/", transforms=None, rotate=True)
idx = 0
first_sample = train_ds[idx]
image, label = first_sample
print(image.shape)

In [None]:
# Plot channels 0, 1 and 2 of the sample side by side, then hand the array to
# the project's save_image helper under the name "3channel"
# (presumably writes it to disk — see utils.image).
fig, axes = plt.subplots(ncols=3, dpi=300, frameon=False)
for channel, ax in enumerate(axes):
    ax.imshow(image[channel, :, :])
    ax.axis("off")

save_image("3channel", image)

## Patches

In [None]:
# Resize every sample to 384 x 384 via a torchvision pipeline, then inspect the
# shape of the first sample of the 3-channel training dataset.
image_size = (384, 384)
resize = Resize(size=image_size)
train_transforms = Compose([resize])

train_ds = ADNI3Channels(
    "../Data/Training/",
    transforms=train_transforms,
    rotate=True,
)
idx = 0
first_sample = train_ds[idx]
image, label = first_sample
print(image.shape)

In [None]:
# Plot channels 0, 1 and 2 of the resized sample side by side with axis
# decorations hidden.
fig, axes = plt.subplots(ncols=3, dpi=300, frameon=False)
for channel, ax in enumerate(axes):
    ax.imshow(image[channel, :, :])
    ax.axis("off")

In [None]:
# Cut channel 2 of the resized image into an even 3 x 3 grid of patches and
# show each patch in the matching cell of a 3 x 3 figure.
# Patch edges are derived from the image's own height/width instead of being
# hard-coded: fixed edges such as [0, 127, 254, 380] on a 384-pixel image leave
# the last rows/columns out of every patch and make the patches unequal.
# Computing k * size // n_patches covers every pixel exactly once, with
# patches of equal size up to integer rounding (128 px each for 384 px).
n_patches = 3
height, width = image.shape[-2], image.shape[-1]
r = [k * height // n_patches for k in range(n_patches + 1)]  # row edges
c = [k * width // n_patches for k in range(n_patches + 1)]   # column edges

fig, axes = plt.subplots(nrows=n_patches, ncols=n_patches, frameon=False,
                         figsize=(3, 3), dpi=300)
for row in range(n_patches):
    for col in range(n_patches):
        # Slices are half-open, so consecutive patches share no pixels.
        patch = image[2, r[row]:r[row + 1], c[col]:c[col + 1]]
        axes[row, col].imshow(patch)
        axes[row, col].axis("off")

## Atlas

In [None]:
# Paths to the resized AAL atlas volume and to its ROI label file.
# Note: despite the *_dir names, both point at individual files.
aal_dir = "../Data/AAL/Resized_AAL.nii"
labels_dir = "../Data/AAL/ROI_MNI_V4.txt"

In [None]:
# Load the AAL atlas volume and its labels (rotate=True) and report the volume
# shape.
aal_atlas = AAL(aal_dir=aal_dir, labels_dir=labels_dir, rotate=True)
atlas_data, atlas_labels = aal_atlas.get_data()

print(atlas_data.shape)

In [None]:
# Tile the first 60 slices (axis 0) of the atlas volume in a 4 x 15 grid,
# filled row by row, with axis decorations hidden.
n_rows, n_cols = 4, 15
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, frameon=False,
                         figsize=(5, 1.5), dpi=300)
for channel, ax in enumerate(axes.flat):
    ax.imshow(atlas_data[channel, :, :])
    ax.axis("off")

In [None]:
# Load the 3-channel variant of the AAL atlas and its labels (rotate=True) and
# report the volume shape.
aal3_atlas = AAL3Channels(aal_dir=aal_dir, labels_dir=labels_dir, rotate=True)
atlas_data, atlas_labels = aal3_atlas.get_data()

print(atlas_data.shape)

In [None]:
# Plot channels 0, 1 and 2 of the 3-channel atlas side by side, then hand the
# array to the project's save_image helper under the name "AAL3"
# (presumably writes it to disk — see utils.image).
fig, axes = plt.subplots(ncols=3, dpi=300, frameon=False)
for channel, ax in enumerate(axes):
    ax.imshow(atlas_data[channel, :, :])
    ax.axis("off")

save_image("AAL3", atlas_data)