In [None]:
from google.colab import drive
# Mount Google Drive so that the dataset and checkpoint paths under
# /content/drive used later in this notebook are reachable.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell execution, so edits to local .py files are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install monai
!pip install einops

Collecting monai
  Downloading monai-1.2.0-202306081546-py3-none-any.whl (1.3 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.3 MB[0m [31m5.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m0.7/1.3 MB[0m [31m10.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.3/1.3 MB[0m [31m13.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: monai
Successfully installed monai-1.2.0
Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collec

In [None]:
import os
import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.metrics import classification_report

from monai.data import decollate_batch, DataLoader, CacheDataset, ThreadDataLoader
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Activations,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    ScaleIntensityd,
    EnsureTyped,
    Resized,
    CropForegroundd,
    SpatialPadd,
    CastToTyped,
    ConcatItemsd,
)
from monai.utils import set_determinism
from tqdm import tqdm

In [None]:
import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/facebookresearch/mae.git
    sys.path.append('./mae')
else:
    sys.path.append('..')
import models_mae

Running in Colab.
Collecting timm==0.4.5
  Downloading timm-0.4.5-py3-none-any.whl (287 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m287.4/287.4 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.4.5
Cloning into 'mae'...
remote: Enumerating objects: 39, done.[K
remote: Total 39 (delta 0), reused 0 (delta 0), pack-reused 39[K
Receiving objects: 100% (39/39), 829.54 KiB | 2.17 MiB/s, done.
Resolving deltas: 100% (12/12), done.


In [None]:
# Per-channel (3-element) mean and standard deviation; show_image uses them to
# undo input normalisation (image * std + mean) before displaying an image.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def show_image(image, title=''):
    """Draw a normalised ``[H, W, 3]`` image on the current matplotlib axes.

    The image is de-normalised with ``imagenet_std`` / ``imagenet_mean``,
    scaled to the 0-255 range, clipped and converted to integers before being
    passed to ``plt.imshow``. The axes are hidden and ``title`` is shown above
    the image.

    Args:
        image: image whose third dimension has exactly 3 channels.
        title: caption drawn above the image.
    """
    assert image.shape[2] == 3
    # Undo the normalisation and map to 8-bit intensity values.
    denormalised = (image * imagenet_std + imagenet_mean) * 255
    pixels = torch.clip(denormalised, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Instantiate an MAE architecture and load weights from a checkpoint file.

    Args:
        chkpt_dir: path to the checkpoint file (despite the name, a file path
            is what gets passed to ``torch.load``).
        arch: name of a model factory defined in ``models_mae``.

    Returns:
        The model with the checkpoint's ``'model'`` state dict loaded. Loading
        is non-strict, so key mismatches are printed instead of raising.
    """
    model_factory = getattr(models_mae, arch)
    model = model_factory()
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(chkpt_dir, map_location='cpu')
    load_report = model.load_state_dict(state['model'], strict=False)
    print(load_report)
    return model

def run_one_image(img, model):
    """Run MAE on a single image and plot original, masked and reconstructed views.

    Args:
        img: one image in ``[H, W, C]`` layout (converted with ``torch.tensor``).
        model: MAE model; called with ``mask_ratio=0.75`` and expected to
            provide ``unpatchify`` and ``patch_embed.patch_size``.

    Side effects:
        Sets ``plt.rcParams['figure.figsize']`` globally and shows a figure
        with four panels.
    """
    # Add a batch axis and switch to channels-first for the model.
    batch = torch.tensor(img).unsqueeze(dim=0)
    batch = torch.einsum('nhwc->nchw', batch)

    # Forward pass through MAE, then back to an image in channels-last layout.
    _loss, prediction, mask = model(batch.float(), mask_ratio=0.75)
    reconstruction = torch.einsum('nchw->nhwc', model.unpatchify(prediction)).detach().cpu()

    # Expand the per-patch mask to pixel resolution: (N, H*W, p*p*3) -> image.
    values_per_patch = model.patch_embed.patch_size[0]**2 * 3
    mask = mask.detach().unsqueeze(-1).repeat(1, 1, values_per_patch)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    batch = torch.einsum('nchw->nhwc', batch)

    # Input with the masked patches zeroed out.
    visible_only = batch * (1 - mask)

    # Reconstruction for masked patches, original pixels for visible ones.
    pasted = visible_only + reconstruction * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    panels = [
        (batch[0], "original"),
        (visible_only[0], "masked"),
        (reconstruction[0], "reconstruction"),
        (pasted[0], "reconstruction + visible"),
    ]
    for position, (panel, caption) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(panel, caption)

    plt.show()

In [None]:
from numpy.core.arrayprint import printoptions
import pandas as pd
import os
from enum import Enum
from monai.transforms import Transform

# Modalities fused by this experiment. Each entry is a column name of the
# session CSV read in create_oasis_3_multimodal_dataset.
# Names used across the ablation runs:
#   "MR T1w", "MR T2w", "MR T2*", "MR FLAIR", "MR TOF-MRA"
# Subsets of 2, 3, 4 or all 5 of these were tried by editing the list below.
modality_names = ["MR T2w", "MR TOF-MRA"]


def create_oasis_3_multimodal_dataset(csv_path: str, dataset_root: str, transform: Transform, cache_rate: float, missing_modality: str):
    """Build a MONAI ``CacheDataset`` from a semicolon-separated session table.

    Every CSV row becomes one sample dict holding, for each entry of the
    module-level ``modality_names``, an image path, plus the row's ``label``.
    Rows in which none of the selected modalities has a file are skipped.
    When a row lacks some (but not all) modalities, the missing ones are
    replaced by a placeholder volume chosen by ``missing_modality``.

    Args:
        csv_path: path to the ``;``-separated table with one column per
            modality (relative file paths) and a ``label`` column.
        dataset_root: directory that the relative file paths are joined to.
        transform: transform applied by the dataset to every sample.
        cache_rate: fraction of samples the ``CacheDataset`` pre-computes.
        missing_modality: ``"zeros"`` or ``"gauss"``; selects the placeholder
            file used for absent modalities.

    Returns:
        A ``CacheDataset`` over the collected sample dicts.

    Raises:
        ValueError: if a modality is missing for some row and
            ``missing_modality`` is not one of the supported keys.
    """
    # NOTE(review): the placeholder volumes are absolute paths and do not
    # follow ``dataset_root`` - confirm whether that is intended.
    placeholder_paths = {
        "zeros": "/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/empty_volume_2d.nii.gz",
        "gauss": "/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/gauss_2d_256.nii.gz",
    }

    # Empty cells become '' so that "no file" can be tested by truthiness.
    table = pd.read_csv(csv_path, sep=";").fillna('')

    samples = []
    for _, row in table.iterrows():
        sample = {}
        has_non_empty = False
        for modality in modality_names:
            file_path = row[modality]
            if file_path:
                has_non_empty = True
                sample[modality] = os.path.join(dataset_root, file_path)
            elif missing_modality in placeholder_paths:
                sample[modality] = placeholder_paths[missing_modality]
            else:
                raise ValueError(f"Invalid missing modality key {missing_modality}")
        if not has_non_empty:
            # No real image for any selected modality: nothing to learn from.
            continue
        sample["label"] = row["label"]
        samples.append(sample)

    # Report a short summary instead of dumping every sample path to the output.
    print(f"{csv_path}: using {len(samples)} of {len(table)} rows for modalities {modality_names}")
    return CacheDataset(data=samples, transform=transform, cache_rate=cache_rate, num_workers=5, copy_cache=False)

class SafeCropForegroundd:
    """Foreground crop that falls back to the uncropped sample when the crop is empty.

    Wraps ``CropForegroundd`` and applies it to a shallow copy of the input
    dict. If the cropped ``source_key`` entry has a zero-sized dimension after
    its first axis, the original dict is returned unchanged; otherwise the
    cropped dict is returned.
    """

    def __init__(self, keys, source_key, select_fn, margin=0):
        """Store ``source_key`` and build the wrapped ``CropForegroundd``."""
        self.source_key = source_key
        self.crop_foreground = CropForegroundd(keys=keys, source_key=source_key, select_fn=select_fn, margin=margin)

    def __call__(self, data):
        """Return the cropped sample, or ``data`` itself if the crop collapsed."""
        candidate = self.crop_foreground(data.copy())
        # Every axis after the first one must keep a non-zero extent.
        trailing_extents = candidate[self.source_key].shape[1:]
        if any(extent == 0 for extent in trailing_extents):
            return data
        return candidate

In [None]:
# --- Experiment settings -----------------------------------------------------
resolution = 224
cache_rate = 0.1
batch_size = 16
dataset_root = r"/content/drive/MyDrive/OASIS-3-MR-Sessions-2D"
missing_modality = "gauss"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
foreground_crop_threshold = 0.1

# --- Preprocessing pipeline --------------------------------------------------
# Load every modality, give images and label a leading channel axis, cast the
# label to float64 and rescale image intensities.
transform_list = [
    LoadImaged(keys=modality_names, image_only=True),
    EnsureChannelFirstd(keys=modality_names + ["label"], channel_dim="no_channel"),
    CastToTyped("label", dtype=np.float64),
    ScaleIntensityd(keys=modality_names),
]
# Each modality is cropped on its own foreground (values above the threshold).
transform_list += [
    SafeCropForegroundd(keys=name, source_key=name, select_fn=lambda x: x > foreground_crop_threshold, margin=5)
    for name in modality_names
]
# Resize the longest side to `resolution`, pad to a square, then stack the
# modalities into a single "image" entry placed on the target device.
transform_list += [
    Resized(keys=modality_names, spatial_size=resolution, size_mode="longest"),
    SpatialPadd(keys=modality_names, spatial_size=(resolution, resolution)),
    ConcatItemsd(keys=modality_names, name="image"),
    EnsureTyped(keys=["image"], device=device),
]
transform = Compose(transform_list)

# --- Training split ----------------------------------------------------------
train_table_path = r"/content/drive/MyDrive/MLMI/oasis_3_multimodal_train.csv"
train_ds = create_oasis_3_multimodal_dataset(
    csv_path=train_table_path,
    dataset_root=dataset_root,
    transform=transform,
    cache_rate=cache_rate,
    missing_modality=missing_modality,
)
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=batch_size, shuffle=True)

# --- Validation split --------------------------------------------------------
val_table_path = r"/content/drive/MyDrive/MLMI/oasis_3_multimodal_val_all.csv"
val_ds = create_oasis_3_multimodal_dataset(
    csv_path=val_table_path,
    dataset_root=dataset_root,
    transform=transform,
    cache_rate=cache_rate,
    missing_modality=missing_modality,
)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=batch_size, shuffle=True)

[{'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS31114_MR_d2658/anat7/NIFTI/sub-OAS31114_ses-d2658_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS31114_MR_d2658/anat1/NIFTI/sub-OAS31114_ses-d2658_acq-TOF_angio.nii.gz', 'label': 0}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS31232_MR_d0159/anat4/NIFTI/sub-OAS31232_sess-d0159_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/gauss_2d_256.nii.gz', 'label': 0}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS30055_MR_d0566/anat5/NIFTI/sub-OAS30055_sess-d0566_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS30055_MR_d0566/anat1/NIFTI/sub-OAS30055_sess-d0566_acq-TOF_angio.nii.gz', 'label': 0}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS30298_MR_d0292/anat5/NIFTI/sub-OAS30298_ses-d0292_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/gauss_2d_256.nii.gz', 'label': 1}, 

Loading dataset: 100%|██████████| 222/222 [01:40<00:00,  2.20it/s]


[{'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS30419_MR_d0621/anat4/NIFTI/sub-OAS30419_ses-d0621_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/gauss_2d_256.nii.gz', 'label': 0}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS30935_MR_d0427/anat5/NIFTI/sub-OAS30935_ses-d0427_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/gauss_2d_256.nii.gz', 'label': 0}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS31373_MR_d0133/anat2/NIFTI/sub-OAS31373_sess-d0133_acq-TSE_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS31373_MR_d0133/anat1/NIFTI/sub-OAS31373_sess-d0133_acq-TOF_angio.nii.gz', 'label': 0}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS30976_MR_d0065/anat4/NIFTI/sub-OAS30976_ses-d0065_T2w.nii.gz', 'MR TOF-MRA': '/content/drive/MyDrive/OASIS-3-MR-Sessions-2D/OAS/gauss_2d_256.nii.gz', 'label': 1}, {'MR T2w': '/content/drive/MyDrive/OASIS-3-

Loading dataset: 100%|██████████| 25/25 [00:10<00:00,  2.30it/s]


### Define network and optimizer

In [None]:
# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)

# Download the checkpoint unless a file of that name already exists
# (wget -nc does not overwrite an existing file).
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth

# Build the ViT-Large MAE and load the downloaded weights via prepare_model.
chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')

File ‘mae_visualize_vit_large.pth’ already there; not retrieving.

<All keys matched successfully>
Model loaded.


In [None]:
from einops import repeat, rearrange
import torch.nn as nn
class ViT_Classifier(torch.nn.Module):
    """Classifier that feeds fused modalities through an MAE encoder.

    A 1x1 convolution (followed by ReLU and batch norm) maps the
    ``len(modality_names)`` input channels to 3 channels, the MAE encoder is
    run with a mask ratio of 0, and a linear head turns the first output token
    into ``num_classes`` logits.
    """

    def __init__(self, mae, num_classes=1) -> None:
        """Wrap ``mae`` with a channel adapter and a linear classification head.

        Args:
            mae: model exposing ``pos_embed`` and ``forward_encoder``.
            num_classes: number of output logits.
        """
        super().__init__()
        in_channels = len(modality_names)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(3)
        )
        self.mae = mae
        # The head's input width is the last dimension of the positional embedding.
        embed_dim = mae.pos_embed.shape[-1]
        self.head = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """Return logits of shape ``(batch, num_classes)`` for ``img``."""
        adapted = self.conv(img)
        tokens, _mask, _ids_restore = self.mae.forward_encoder(adapted, 0)
        # tokens is (batch, tokens, channels); classify from the first token
        # (presumably the encoder's class token - confirm against models_mae).
        first_token = tokens[:, 0]
        return self.head(first_token)
# Load the pretrained MAE and wrap it in the classifier. The modules are moved
# to `device` (defined with a CPU fallback in the data-setup cell) instead of
# calling .cuda() unconditionally, which raises on a runtime without a GPU.
chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16').to(device)
print('Model loaded.')
model = ViT_Classifier(model_mae).to(device)

<All keys matched successfully>
Model loaded.


In [None]:
# Optimisation set-up: binary cross-entropy on logits, Adam with lr=1e-5 and a
# gradient scaler for mixed-precision training.
model = model.to(device)
loss_function = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
scaler = torch.cuda.amp.GradScaler()
max_epochs = 20
val_interval = 1  # validate every `val_interval` epochs
auc_metric = ROCAUCMetric()

# Where the best checkpoint is written. Create the directory up front so the
# first torch.save cannot fail after a full training epoch has already run.
out_model_dir = "/content/drive/MyDrive/M3AE_N/pretrained"
os.makedirs(out_model_dir, exist_ok=True)
model_file_name = "inputfusion_ad_cls_oasis_3_t2mra_gauss.pth"
model_file_name

'inputfusion_ad_cls_oasis_3_t2mra_gauss.pth'

### Training

In [None]:
# Training with per-epoch validation. The checkpoint with the best validation
# AUC is saved to out_model_dir / model_file_name.
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []  # mean training loss per epoch
metric_values = []      # validation AUC per validation run

# Converts raw logits to probabilities for the metrics.
y_pred_trans = Compose([Activations(sigmoid=True)])

for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    with tqdm(train_loader, unit="batch") as tepoch:
        for batch_data in tepoch:
            tepoch.set_description(f"Epoch {epoch + 1} / {max_epochs}")

            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            optimizer.zero_grad()
            # Mixed-precision forward pass; the scaler guards the backward pass
            # against fp16 gradient underflow.
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()
            num_batches += 1
            tepoch.set_postfix(loss=loss.item())

    # Record the mean loss of the epoch (guarding against an empty loader).
    epoch_loss_values.append(epoch_loss / max(num_batches, 1))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            # Accumulate logits and labels over the whole validation set.
            y_pred = torch.tensor([], dtype=torch.float32, device=device)
            y = torch.tensor([], dtype=torch.long, device=device)
            for val_data in val_loader:
                val_images, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                y = torch.cat([y, val_labels], dim=0)
            # Flatten per-sample items into 1-D tensors of length N.
            y_onehot = torch.cat([i for i in decollate_batch(y, detach=False)], dim=0)
            y_pred_act = torch.cat([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0)
            auc_metric(y_pred_act, y_onehot)
            result = auc_metric.aggregate()
            auc_metric.reset()
            metric_values.append(result)
            # Compare predictions and labels element-wise on matching 1-D
            # shapes. Comparing the (N,) predictions against the (N, 1) label
            # batch would broadcast to an N x N matrix and report a
            # meaningless "accuracy".
            predicted_classes = (y_pred_act > 0.5).long().reshape(-1)
            true_classes = y_onehot.long().reshape(-1)
            acc_value = torch.eq(predicted_classes, true_classes)
            acc_metric = acc_value.float().mean().item()
            del y_pred_act, y_onehot
            if result > best_metric:
                best_metric = result
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(out_model_dir, model_file_name))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current AUC: {result:.4f}"
                f" current accuracy: {acc_metric:.4f}"
                f" best AUC: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )

print(f"Training completed, best_metric (AUC): {best_metric:.4f} " f"at epoch: {best_metric_epoch}")

Epoch 1 / 20: 100%|██████████| 140/140 [50:38<00:00, 21.70s/batch, loss=0.209]


saved new best metric model
current epoch: 1 current AUC: 0.7942 current accuracy: 0.7898 best AUC: 0.7942 at epoch: 1


Epoch 2 / 20: 100%|██████████| 140/140 [02:35<00:00,  1.11s/batch, loss=0.876]


saved new best metric model
current epoch: 2 current AUC: 0.8123 current accuracy: 0.6664 best AUC: 0.8123 at epoch: 2


Epoch 3 / 20: 100%|██████████| 140/140 [02:35<00:00,  1.11s/batch, loss=0.629]


current epoch: 3 current AUC: 0.8029 current accuracy: 0.7339 best AUC: 0.8123 at epoch: 2


Epoch 4 / 20: 100%|██████████| 140/140 [02:34<00:00,  1.10s/batch, loss=0.494]


current epoch: 4 current AUC: 0.7806 current accuracy: 0.7386 best AUC: 0.8123 at epoch: 2


Epoch 5 / 20: 100%|██████████| 140/140 [02:33<00:00,  1.10s/batch, loss=0.123]


current epoch: 5 current AUC: 0.8032 current accuracy: 0.7107 best AUC: 0.8123 at epoch: 2


Epoch 6 / 20: 100%|██████████| 140/140 [02:32<00:00,  1.09s/batch, loss=0.00934]


current epoch: 6 current AUC: 0.7879 current accuracy: 0.6804 best AUC: 0.8123 at epoch: 2


Epoch 7 / 20: 100%|██████████| 140/140 [02:33<00:00,  1.10s/batch, loss=0.00291]


current epoch: 7 current AUC: 0.6918 current accuracy: 0.7502 best AUC: 0.8123 at epoch: 2


Epoch 8 / 20: 100%|██████████| 140/140 [02:32<00:00,  1.09s/batch, loss=0.00281]


current epoch: 8 current AUC: 0.7973 current accuracy: 0.6571 best AUC: 0.8123 at epoch: 2


Epoch 9 / 20: 100%|██████████| 140/140 [02:32<00:00,  1.09s/batch, loss=0.00124]


current epoch: 9 current AUC: 0.7760 current accuracy: 0.7572 best AUC: 0.8123 at epoch: 2


Epoch 10 / 20: 100%|██████████| 140/140 [02:31<00:00,  1.08s/batch, loss=0.000566]


current epoch: 10 current AUC: 0.7939 current accuracy: 0.7409 best AUC: 0.8123 at epoch: 2


Epoch 11 / 20: 100%|██████████| 140/140 [02:30<00:00,  1.07s/batch, loss=0.000446]


current epoch: 11 current AUC: 0.7844 current accuracy: 0.7363 best AUC: 0.8123 at epoch: 2


Epoch 12 / 20: 100%|██████████| 140/140 [02:29<00:00,  1.07s/batch, loss=0.000401]


current epoch: 12 current AUC: 0.7798 current accuracy: 0.7293 best AUC: 0.8123 at epoch: 2


Epoch 13 / 20: 100%|██████████| 140/140 [02:30<00:00,  1.08s/batch, loss=0.000353]


current epoch: 13 current AUC: 0.7670 current accuracy: 0.7526 best AUC: 0.8123 at epoch: 2


Epoch 14 / 20: 100%|██████████| 140/140 [02:30<00:00,  1.08s/batch, loss=0.000337]


current epoch: 14 current AUC: 0.7585 current accuracy: 0.7433 best AUC: 0.8123 at epoch: 2


Epoch 15 / 20: 100%|██████████| 140/140 [02:30<00:00,  1.08s/batch, loss=0.000253]


current epoch: 15 current AUC: 0.7586 current accuracy: 0.7479 best AUC: 0.8123 at epoch: 2


Epoch 16 / 20: 100%|██████████| 140/140 [02:30<00:00,  1.08s/batch, loss=0.000174]


current epoch: 16 current AUC: 0.7594 current accuracy: 0.7409 best AUC: 0.8123 at epoch: 2


Epoch 17 / 20: 100%|██████████| 140/140 [02:29<00:00,  1.07s/batch, loss=0.000222]


current epoch: 17 current AUC: 0.8058 current accuracy: 0.6595 best AUC: 0.8123 at epoch: 2


Epoch 18 / 20: 100%|██████████| 140/140 [02:29<00:00,  1.07s/batch, loss=0.000364]


current epoch: 18 current AUC: 0.7510 current accuracy: 0.7456 best AUC: 0.8123 at epoch: 2


Epoch 19 / 20: 100%|██████████| 140/140 [02:30<00:00,  1.07s/batch, loss=0.000193]


current epoch: 19 current AUC: 0.7653 current accuracy: 0.7339 best AUC: 0.8123 at epoch: 2


Epoch 20 / 20: 100%|██████████| 140/140 [02:31<00:00,  1.08s/batch, loss=0.0159]


current epoch: 20 current AUC: 0.7915 current accuracy: 0.6781 best AUC: 0.8123 at epoch: 2
Training completed, best_metric (AUC): 0.8123 at epoch: 2
