<a href="https://colab.research.google.com/github/KarineAyrs/science_work/blob/main/training/monai1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
%cd drive/MyDrive/vit-pytorch

/content/drive/MyDrive/vit-pytorch


In [3]:
#!git clone https://github.com/KarineAyrs/vit-pytorch.git

In [4]:
#%cd vit-pytorch

In [5]:
!ls

examples  LICENSE      README.md  tests        vit_pytorch.egg-info
images	  MANIFEST.in  setup.py   vit_pytorch


In [6]:
!python -m pip install -Ue .

Obtaining file:///content/drive/MyDrive/vit-pytorch
Collecting einops>=0.3
  Downloading einops-0.4.0-py3-none-any.whl (28 kB)
Installing collected packages: einops, vit-pytorch
  Running setup.py develop for vit-pytorch
Successfully installed einops-0.4.0 vit-pytorch-0.26.7


In [7]:
!pip install monai

Collecting monai
  Downloading monai-0.8.1-202202162213-py3-none-any.whl (721 kB)
[K     |████████████████████████████████| 721 kB 4.1 MB/s 
Installing collected packages: monai
Successfully installed monai-0.8.1


In [8]:
#!pip install vit-pytorch

In [9]:
# vit_pytorch mae на BraTS, загружаем данные через monai
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
import numpy as np
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch

from monai.transforms import (
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureChannelFirstd,
    EnsureTyped,
    Resize
)
from monai.utils import set_determinism

import torch
from vit_pytorch import ViT, MAE
import random
from tqdm.notebook import tqdm



In [10]:
print_config()


vit_config = dict({'image_size': 50,
                   'patch_size':5,
                   'num_classes':3,
                   'dim':960,
                   'depth':6,
                   'heads':12,
                   'mlp_dim':1920,
                   'channels':4})


train_config = dict({'device':'cuda', 
                     'max_epochs':20, 
                     'root_dir':'/content/drive/MyDrive/Task01_BrainTumour'})


roi_size = [50,50,50]
seed = 42

MONAI version: 0.8.1
Numpy version: 1.21.5
Pytorch version: 1.10.0+cu111
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 71ff399a3ea07aef667b23653620a290364095b1

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.0.2
scikit-image version: 0.18.3
Pillow version: 7.1.2
Tensorboard version: 2.8.0
gdown version: 4.2.1
TorchVision version: 0.11.1+cu111
tqdm version: 4.62.3
lmdb version: 0.99
psutil version: 5.4.8
pandas version: 1.3.5
einops version: 0.4.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [11]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

In [12]:
seed_everything(seed)
set_determinism(seed=0)

In [13]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(np.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d


In [14]:
train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        EnsureTyped(keys=["image", "label"]),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),  
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image", "label"]),
    ]
)


In [15]:
# here we don't cache any data in case out of memory issue
train_ds = DecathlonDataset(
    root_dir=train_config['root_dir'],
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=2,
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=2)
val_ds = DecathlonDataset(
    root_dir=train_config['root_dir'],
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=2,
)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)

Task01_BrainTumour.tar: 7.09GB [05:54, 21.5MB/s]                            


2022-02-17 17:10:12,869 - INFO - Downloaded: /content/drive/MyDrive/Task01_BrainTumour/Task01_BrainTumour.tar
2022-02-17 17:11:14,422 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2022-02-17 17:11:14,987 - INFO - Non-empty folder exists in /content/drive/MyDrive/Task01_BrainTumour/Task01_BrainTumour, skipped extracting.


In [16]:
# # pick one image from DecathlonDataset to visualize and check the 4 channels
# print(f"image shape: {val_ds[2]['image'].shape}")
# plt.figure("image", (24, 6))
# for i in range(4):
#     plt.subplot(1, 4, i + 1)
#     plt.title(f"image channel {i}")
#     plt.imshow(val_ds[2]["image"][i, :, :, 60].detach().cpu(), cmap="gray")
# plt.show()
# # also visualize the 3 channels label corresponding to this image
# print(f"label shape: {val_ds[2]['label'].shape}")
# plt.figure("label", (18, 6))
# for i in range(3):
#     plt.subplot(1, 3, i + 1)
#     plt.title(f"label channel {i}")
#     plt.imshow(val_ds[2]["label"][i, :, :, 60].detach().cpu())
# plt.show()


In [17]:
model = ViT(
    image_size=vit_config['image_size'],
    patch_size=vit_config['patch_size'],
    num_classes=vit_config['num_classes'],
    dim=vit_config['dim'], 
    depth=vit_config['depth'],
    heads=vit_config['heads'],
    mlp_dim=vit_config['mlp_dim'], 
    channels=vit_config['channels'],
)


mae = MAE(
    encoder=model,
    masking_ratio=0.75,  # the paper recommended 75% masked patches
    decoder_dim=512,  # paper showed good results with just 512
    decoder_depth=6  # anywhere from 1 to 8
).to(train_config['device'])

In [18]:
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_config['max_epochs'])
scaler = torch.cuda.amp.GradScaler()

In [None]:
for epoch in range(train_config['max_epochs']):
    model.train()

    epoch_loss = 0
    epoch_accuracy = 0

    for batch_data in tqdm(train_loader):

        data = batch_data['image'].to(train_config['device'])

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = mae(data)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
  
    lr_scheduler.step()
    epoch_loss += loss / len(train_loader)

    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for batch_data in val_loader:
            data = batch_data['image'].to(train_config['device'])

            val_loss = mae(data)

            epoch_val_loss += val_loss / len(val_loader)

    print(
        f"Epoch : {epoch + 1} - train_loss : {epoch_loss:.4f} - val_loss : {epoch_val_loss:.4f} \n"
    )


torch.save(model.state_dict(), '/content/drive/MyDrive/vit_pretrain/vit_mae/vit_mae')

with open('/content/drive/MyDrive/vit_pretrain/vit_mae/vit_mae.txt', 'w') as f:
    for k,v in vit_config.items():
      f.write(str(k)+':'+str(v)+'\n')

    for k,v in train_config.items():
      f.write(str(k)+':'+str(v)+'\n')




  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 1 - train_loss : 0.0016 - val_loss : 0.8873 



  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 2 - train_loss : 0.0016 - val_loss : 0.8893 



  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 3 - train_loss : 0.0017 - val_loss : 0.8867 



  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 4 - train_loss : 0.0016 - val_loss : 0.8845 



  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 5 - train_loss : 0.0016 - val_loss : 0.8875 



  0%|          | 0/388 [00:00<?, ?it/s]

Epoch : 6 - train_loss : 0.0017 - val_loss : 0.8888 



  0%|          | 0/388 [00:00<?, ?it/s]