<a href="https://colab.research.google.com/github/KarineAyrs/science_work/blob/main/training/monai.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
%cd drive/MyDrive/vit-pytorch

In [3]:
#!git clone https://github.com/KarineAyrs/vit-pytorch.git

In [4]:
#%cd vit-pytorch

In [5]:
!ls

examples  LICENSE      README.md  tests        vit_pytorch.egg-info
images	  MANIFEST.in  setup.py   vit_pytorch


In [6]:
!python -m pip install -Ue .

Obtaining file:///content/drive/MyDrive/vit-pytorch
Installing collected packages: vit-pytorch
  Attempting uninstall: vit-pytorch
    Found existing installation: vit-pytorch 0.26.7
    Can't uninstall 'vit-pytorch'. No files were found to uninstall.
  Running setup.py develop for vit-pytorch
Successfully installed vit-pytorch-0.26.7


In [7]:
!pip install monai



In [8]:
#!pip install vit-pytorch

In [9]:
# vit_pytorch mae на BraTS, загружаем данные через monai
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
import numpy as np
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch

from monai.transforms import (
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureChannelFirstd,
    EnsureTyped,
    Resize
)
from monai.utils import set_determinism

import torch
from vit_pytorch import ViT, MAE
import random
from tqdm.notebook import tqdm



In [10]:
print_config()


vit_config = dict({'image_size': 50,
                   'patch_size':5,
                   'num_classes':3,
                   'dim':960,
                   'depth':6,
                   'heads':12,
                   'mlp_dim':1920,
                   'channels':4})


train_config = dict({'device':'cuda', 
                     'max_epochs':20, 
                     'root_dir':'/content/drive/MyDrive/Task01_BrainTumour'})


roi_size = [50,50,50]
seed = 42

MONAI version: 0.8.0
Numpy version: 1.21.5
Pytorch version: 1.10.0+cu111
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 714d00dffe6653e21260160666c4c201ab66511b

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.0.2
scikit-image version: 0.18.3
Pillow version: 7.1.2
Tensorboard version: 2.8.0
gdown version: 4.2.1
TorchVision version: 0.11.1+cu111
tqdm version: 4.62.3
lmdb version: 0.99
psutil version: 5.4.8
pandas version: 1.3.5
einops version: 0.4.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [11]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

In [12]:
seed_everything(seed)
set_determinism(seed=0)

In [13]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(np.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d


In [14]:
train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        EnsureTyped(keys=["image", "label"]),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),  
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image", "label"]),
    ]
)


In [15]:
# here we don't cache any data in case out of memory issue
train_ds = DecathlonDataset(
    root_dir=train_config['root_dir'],
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=2,
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=2)
val_ds = DecathlonDataset(
    root_dir=train_config['root_dir'],
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=2,
)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)

2022-02-16 15:56:35,333 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2022-02-16 15:56:35,335 - INFO - File exists: /content/drive/MyDrive/Task01_BrainTumour/Task01_BrainTumour.tar, skipped downloading.
2022-02-16 15:56:35,354 - INFO - Non-empty folder exists in /content/drive/MyDrive/Task01_BrainTumour/Task01_BrainTumour, skipped extracting.


In [16]:
# # pick one image from DecathlonDataset to visualize and check the 4 channels
# print(f"image shape: {val_ds[2]['image'].shape}")
# plt.figure("image", (24, 6))
# for i in range(4):
#     plt.subplot(1, 4, i + 1)
#     plt.title(f"image channel {i}")
#     plt.imshow(val_ds[2]["image"][i, :, :, 60].detach().cpu(), cmap="gray")
# plt.show()
# # also visualize the 3 channels label corresponding to this image
# print(f"label shape: {val_ds[2]['label'].shape}")
# plt.figure("label", (18, 6))
# for i in range(3):
#     plt.subplot(1, 3, i + 1)
#     plt.title(f"label channel {i}")
#     plt.imshow(val_ds[2]["label"][i, :, :, 60].detach().cpu())
# plt.show()


In [17]:
model = ViT(
    image_size=vit_config['image_size'],
    patch_size=vit_config['patch_size'],
    num_classes=vit_config['num_classes'],
    dim=vit_config['dim'], 
    depth=vit_config['depth'],
    heads=vit_config['heads'],
    mlp_dim=vit_config['mlp_dim'], 
    channels=vit_config['channels'],
)


mae = MAE(
    encoder=model,
    masking_ratio=0.75,  # the paper recommended 75% masked patches
    decoder_dim=512,  # paper showed good results with just 512
    decoder_depth=6  # anywhere from 1 to 8
).to(train_config['device'])

for epoch in range(train_config['max_epochs']):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{train_config['max_epochs']}")
    model.train()

    epoch_loss = 0
    epoch_accuracy = 0

    for batch_data in tqdm(train_loader):
        data = batch_data['image'].to(train_config['device'])
        loss = mae(data)
        loss.backward()
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for batch_data in val_loader:
            data = batch_data['image'].to(train_config['device'])

            val_loss = mae(data)

            epoch_val_loss += val_loss / len(val_loader)

    print(
        f"Epoch : {epoch + 1} - train_loss : {epoch_loss:.4f} - val_loss : {epoch_val_loss:.4f} \n"
    )


torch.save(model.state_dict(), '/content/drive/MyDrive/vit_pretrain/vit_mae/vit_mae')

with open('/content/drive/MyDrive/vit_pretrain/vit_mae/vit_mae.txt', 'w') as f:
    for k,v in vit_config.items():
      f.write(str(k)+':'+str(v))

    for k,v in train_config.items():
      f.write(str(k)+':'+str(v))




----------
epoch 1/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 1 - train_loss : 1.1856 - val_loss : 1.1400 

----------
epoch 2/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 2 - train_loss : 1.1848 - val_loss : 1.1421 

----------
epoch 3/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 3 - train_loss : 1.1860 - val_loss : 1.1393 

----------
epoch 4/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 4 - train_loss : 1.1858 - val_loss : 1.1373 

----------
epoch 5/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 5 - train_loss : 1.1895 - val_loss : 1.1403 

----------
epoch 6/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 6 - train_loss : 1.1837 - val_loss : 1.1414 

----------
epoch 7/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 7 - train_loss : 1.1862 - val_loss : 1.1379 

----------
epoch 8/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 8 - train_loss : 1.1822 - val_loss : 1.1383 

----------
epoch 9/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 9 - train_loss : 1.1883 - val_loss : 1.1413 

----------
epoch 10/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 10 - train_loss : 1.1887 - val_loss : 1.1430 

----------
epoch 11/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 11 - train_loss : 1.1835 - val_loss : 1.1406 

----------
epoch 12/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 12 - train_loss : 1.1832 - val_loss : 1.1393 

----------
epoch 13/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 13 - train_loss : 1.1865 - val_loss : 1.1407 

----------
epoch 14/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 14 - train_loss : 1.1862 - val_loss : 1.1414 

----------
epoch 15/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 15 - train_loss : 1.1855 - val_loss : 1.1409 

----------
epoch 16/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 16 - train_loss : 1.1869 - val_loss : 1.1437 

----------
epoch 17/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 17 - train_loss : 1.1874 - val_loss : 1.1396 

----------
epoch 18/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 18 - train_loss : 1.1863 - val_loss : 1.1421 

----------
epoch 19/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 19 - train_loss : 1.1873 - val_loss : 1.1384 

----------
epoch 20/20


  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 20 - train_loss : 1.1884 - val_loss : 1.1404 



In [18]:
!nvidia-smi

Wed Feb 16 19:26:19 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P0    87W / 149W |   1845MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces