In [None]:
# Execution environment switch; later cells branch on this value for the GPU setup and the dataset path.
env = 'server' # 'local' or 'server'

### Setup imports

In [None]:
import logging
import sys
import copy, tqdm
import time
from datetime import timedelta, datetime
import shutil

import os
# Prevent the kernel from dying (presumably the duplicate-OpenMP-runtime abort — confirm)
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# GPU setting (must be done before torch is imported so the visibility mask takes effect)
if env == 'server':
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # Arrange GPU devices starting from 0
    os.environ["CUDA_VISIBLE_DEVICES"] = "2" # Set the GPU 2 to use

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

# python에서 List, Dict, Tuple, Set와 같은 파이썬 내장 자료구조에 대한 타입을 명시해야할 때 사용
from typing import Sequence, Union

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data.utils import pad_list_data_collate
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
)
from monai.transforms import MapTransform
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    ScaleIntensityd,
    NormalizeIntensityd,
    Resized,
    SpatialPadd,
    RandCropByPosNegLabeld,
    RandSpatialCropSamplesd,
    CenterSpatialCropd,
    CropForegroundd,
    RandAffined,
    EnsureTyped,
)
from monai.utils import set_determinism, first

from einops import repeat, rearrange
# PyTorch torch.load 함수를 기반으로 동작하며, 딕셔너리에 모델의 가중치, epoch 정보 등을 저장하여 모델을 재구성함.
from timm.models.layers import trunc_normal_ # timm: pretrained model 제공, trunc_normal_: model initialize할 때 사용

# True when CUDA is available.
# NOTE(review): the DataLoader cells further down hard-code pin_memory instead of using this flag.
pin_memory = torch.cuda.is_available()

# Route INFO-level log records to stdout, then print the MONAI configuration report
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

In [None]:
# Suppress all warnings
import warnings
warnings.filterwarnings("ignore")

# Palette setting
import seaborn as sns
sns.set_palette('Pastel1')
palette1 = sns.color_palette('Pastel1', 8) # second argument: number of colors to generate
palette2 = sns.color_palette('Pastel2', 8) # second argument: number of colors to generate
# sns.palplot(palette)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print('Device:', device)

In [None]:
# Every dictionary key loaded per case: the four MRI sequences plus the segmentation label
modality = ['flair', 't1', 't1ce', 't2','seg'] # 'flair', 't1', 't1ce', 't2', 'seg'
# Image-only keys (no label)
mo_img = ['flair', 't1', 't1ce', 't2']

In [None]:
# Optimisation hyper-parameters
lr = 4e-4 # 5e-3
weight_decay = 5e-5
val_interval = 1  # validate every `val_interval` epochs
epochs = 3000

In [None]:
# Learning rate in compact scientific notation (e.g. 4e-04), used in the experiment folder name
lr_str = f"{lr:.0e}"

### Setup data directory
- 실험 결과 저장할 디렉토리 설정

In [None]:
# Make the working directory that receives this experiment's checkpoints and logs
from datetime import datetime

# Experiment tag: <YYYYMMDD>_ep<epochs>_lr<learning rate>
expd = datetime.today().strftime("%Y%m%d")+"_"+f"ep{epochs}_lr{lr_str}"
path00 = './model_MAE2D'

root_dir = os.path.join(path00,f'{expd}')
if not os.path.isdir(root_dir):
    # makedirs (rather than mkdir) so a missing parent directory `path00` is created too
    os.makedirs(root_dir)
    print(f"Success in making {expd}~!")
elif os.listdir(root_dir):
    # Refuse to reuse a populated directory so earlier results are never overwritten
    raise UserWarning(f"'{expd}' is already exist and not empty.")
else:
    print(f"[WARNING] {expd} is already exist but empty")

### Setup dataset

- split json 준비
- Transform 준비

In [None]:
# Pick the dataset root for the current environment
if env == 'local':
    data_dir = '../../Datasets/Dataset002_BRATS2017/'
elif env == 'server':
    data_dir = '/store8/njrue/Datasets/Dataset002_BRATS2017/'
else:
    # Fail here instead of with a NameError on `data_dir` a few lines later
    raise ValueError(f"env must be 'local' or 'server', got {env!r}")
split_json = 'BraTS2017_ipiu.json'

# Decathlon-style split file holding the "training" and "validation" case lists
datasets = os.path.join(data_dir, split_json)
train_files = load_decathlon_datalist(datasets, is_segmentation=False, data_list_key="training")
val_files = load_decathlon_datalist(datasets, is_segmentation=False, data_list_key="validation")

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert a BraTS label map into a channel-first binary mask.

    BraTS label values:
    label 1 is the necrotic tumor core (roi1)
    label 4 is the GD-enhancing tumor (+roi2)
    label 2 is the peritumoral edema (+roi3)

    Only roi1 (``label == 1``) is currently produced, so the output has a
    single channel. The roi2 / roi3 channels are kept below as commented-out
    code for reference.
    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every configured key replaced by its float32 mask."""
        d = dict(data)
        for key in self.keys:
            result = []

            # Classes as defined in the lab's paper
            # roi1
            result.append(d[key] == 1)
            # # roi2
            # result.append(np.logical_or(d[key] == 4, d[key] == 1))
            # # roi3
            # result.append(
            #     np.logical_or(np.logical_or(d[key] == 1, d[key] == 2), d[key] == 4)
            # )

            # Write back under the key being processed (it used to be hard-coded to "seg",
            # which is only correct when the transform is built with keys=['seg']).
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d

In [None]:
# List the keys available in each training record
pd.DataFrame(train_files).columns

In [None]:
# Training pipeline: load -> channel-first images -> binary label -> per-channel
# normalisation -> foreground crop -> pad -> label-guided random 128^3 patches
train_transforms = Compose([
    LoadImaged(keys=modality),
    EnsureChannelFirstd(keys=mo_img),
    # Orientationd(keys=modality, axcodes="RAS"),
    ConvertToMultiChannelBasedOnBratsClassesd(keys=['seg']),
    # ScaleIntensityd(keys=mo_img),
    NormalizeIntensityd(keys=mo_img, nonzero=True, channel_wise=True),
    EnsureTyped(keys=modality),
    CropForegroundd(keys=modality, source_key=mo_img[0]),
    # No padding is applied where the input is already larger than spatial_size
    SpatialPadd(keys=modality, spatial_size=[128,128,128]),
    RandCropByPosNegLabeld(
        keys=modality, label_key="seg", spatial_size=[128,128,128], pos=1, neg=1, num_samples=16,
    ),
    # RandSpatialCropSamplesd(keys=modality, roi_size=[128,128,128], random_size=False, num_samples=2),
])

# transforms_noCrop = Compose([
#     LoadImaged(
#         keys=modality
#     ),
#     EnsureChannelFirstd(
#         keys=mo_img
#     ),
#     Orientationd(keys=modality, axcodes="RAS"),
#     # ConvertToMultiChannelBasedOnBratsClassesd(keys=['seg']),
#     # ScaleIntensityd(keys=mo_img),
#     NormalizeIntensityd(
#         keys=mo_img, 
#         nonzero=True, 
#         channel_wise=True
#     ),
#     EnsureTyped(
#         keys=modality
#     ),
#     CenterSpatialCropd(
#         keys=modality, roi_size=[240,240,128],
#     ),
#     CropForegroundd(keys=modality, source_key=mo_img[0]),
#     Resized(
#         keys=modality,
#         spatial_size=(128,128,128),
#         mode=['area','area','area','area',]
#     ),
# ])

# Validation uses an independent deep copy of the training pipeline.
# NOTE(review): the copy still contains RandCropByPosNegLabeld, so validation patches are randomly sampled — confirm this is intended.
val_transforms = copy.deepcopy(train_transforms)

### Check dataset

In [None]:
# Small cached dataset over ten training cases, used only to eyeball the pipeline output
check_ds = CacheDataset(
    data=train_files[50:60],
    transform=train_transforms,
    cache_num=5,
    cache_rate=1.0,
    num_workers=4,
)
check_loader = DataLoader(
    check_ds,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    collate_fn=pad_list_data_collate,
)
# Pull a single batch for the inspection cells below
check_data = first(check_loader)

In [None]:
# Compare the raw volume shape with the shape that comes out of the transform pipeline.
# NOTE(review): the reference line mentions 3 segmentation channels, but the label transform above currently builds a single channel — confirm.
print("**Original Image shape: (1, 1, 240, 240, 155), Segmentation shape: (1, 3, 240, 240, 155)")
print(f"\t   Image shape: {np.shape(check_data['t1'])}, Segmentation shape: {np.shape(check_data['seg'])}")

In [None]:
# Index of the cropped patch (within the collated batch) to inspect
sample = 1
# Channel 0 of each modality, with the first two spatial axes swapped for plotting
flair, t1, t1ce, t2 = (
    check_data[name][sample][0].permute(1,0,2) for name in ('flair', 't1', 't1ce', 't2')
)
# The label keeps its channel axis; the same two spatial axes are swapped
seg = check_data['seg'][sample].permute(0,2,1,3)

print(check_data['flair'].meta['filename_or_obj'],end='\n\n')
shape_report = (
    f"flair image shape: {flair.shape}, "
    f"t1 image shape: {t1.shape}, \n"
    f"t1ce image shape:  {t1ce.shape}, "
    f"t2 image shape: {t2.shape}, \n"
    f"seg image shape:   {seg.shape}, "
)
print(shape_report)

In [None]:
# Slice index along the last axis to display
zn = 0
plt.figure("check", (20, 6))
# One gray-scale panel per MRI sequence
for col, (title, vol) in enumerate(
    [("flair", flair), ("t1", t1), ("t1ce", t1ce), ("t2", t2)], start=1
):
    plt.subplot(1, 5, col)
    plt.title(title)
    # plt.imshow(check_ds[3][0]['flair'][0].permute(1,0,2)[:, :, zn], cmap="gray")
    plt.imshow(vol[:, :, zn], cmap="gray")
# Last panel: label channels summed into one map
plt.subplot(1, 5, 5)
plt.title("seg")
# plt.imshow(seg[0,:, :, zn])
plt.imshow(torch.sum(seg[:,:, :, zn],axis=0),vmin=0,vmax=3)
plt.show()

### Dataset Load

In [None]:
# Fully cached training set; shuffled every epoch
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,
    num_workers=8,
)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_list_data_collate,
)

# Fully cached validation set; iterated in a fixed order
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=8,
)
val_loader = DataLoader(
    val_ds,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_list_data_collate,
)

In [None]:
# Pull one training batch and compare its shape with the raw volume shape
check_train = first(train_loader)
print("**Original Image shape: (1, 1, 240, 240, 155), Segmentation shape: (1, 3, 240, 240, 155)")
print(f"\t   Image shape: {np.shape(check_train['t1'])}, Segmentation shape: {np.shape(check_train['seg'])}")

In [None]:
# Show slice 0 of the t1ce channel for item 5 of the training batch.
# NOTE(review): index 5 assumes the collated batch holds more than batch_size items (presumably batch_size x num_samples crops) — confirm.
plt.figure(figsize=(3,3))
plt.imshow(check_train['t1ce'][5][0].permute(1,0,2)[:, :, 0], cmap="gray")
plt.show()

### Set parameter

In [None]:
data_keys = modality
# One channel per key, plus two extra channels when the label key is present.
# NOTE(review): the model cell below is built with len(mo_img), not with this value — confirm whether in_channels is still needed.
in_channels = len(data_keys) + (2 if 'seg' in data_keys else 0)

### Define model

In [None]:
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.nets import ViT

In [None]:
class MAE(nn.Module):
    """Masked autoencoder (MAE) built around MONAI's ``ViT`` encoder.

    The input volume is split into patches by the encoder's patch-embedding
    block. A random subset of the patches is hidden, the visible ones go
    through the encoder transformer blocks, and a lightweight transformer
    decoder predicts the raw values of the hidden patches.

    Args:
        in_channels: number of input channels.
        img_size: spatial size of the input.
        patch_size: spatial size of one patch.
        hidden_size: encoder token width.
        mlp_dim: hidden width of the encoder MLP.
        num_layers: number of encoder transformer blocks.
        num_heads: number of encoder attention heads.
        pos_embed: patch-embedding type forwarded to ``ViT``. The
            ``patch_embeddings`` attribute is unpacked into two stages
            (patchify, linear projection), so it must be a two-element
            sequence; presumably only "perceptron" satisfies this — confirm
            against the installed MONAI version.
        dropout_rate: dropout used in encoder and decoder blocks.
        spatial_dims: number of spatial dimensions.
        decoder_dim: decoder token width.
        decoder_depth: number of decoder transformer blocks.
        decoder_heads: number of decoder attention heads.
        masking_ratio: fraction of patches to hide, strictly between 0 and 1.
        revise_keys: forwarded to ``init_weights``; currently unused there.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        decoder_dim: int = 768,
        decoder_depth: int = 1,
        decoder_heads: int = 8,
        masking_ratio: float = 0.75,
        revise_keys=[("model.", "")],
        **kwargs,
    ) -> None:

        super().__init__()
        self.spatial_dims = spatial_dims

        self.encoder = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )

        # patch embedding block: reuse the encoder's two stages separately so that
        # the raw patches (reconstruction target) are available before projection
        patch_embedding = self.encoder.patch_embedding
        self.to_patch, self.patch_to_emb = patch_embedding.patch_embeddings
        n_patches = patch_embedding.n_patches
        patch_dim = patch_embedding.patch_dim

        # connect encoder and decoder if mismatch dimension
        self.enc_to_dec = (
            nn.Linear(hidden_size, decoder_dim)
            if hidden_size != decoder_dim
            else nn.Identity()
        )

        # build up decoder transformer blocks
        self.decoder_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    decoder_dim, decoder_dim * 4, decoder_heads, dropout_rate
                )
                for _ in range(decoder_depth)
            ]
        )
        self.decoder_norm = nn.LayerNorm(decoder_dim)
        self.masking_ratio = masking_ratio
        assert (
            masking_ratio > 0 and masking_ratio < 1
        ), "masking ratio must be kept between 0 and 1"
        # learned placeholder token that stands in for every hidden patch in the decoder
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder_pos_emb = nn.Embedding(n_patches, decoder_dim)

        # embeddings to pixels
        self.to_pixels = nn.Linear(decoder_dim, patch_dim)

        self.init_weights(revise_keys=revise_keys)

    def init_weights(self, pretrained=None, revise_keys=[]):
        """Initialise Linear layers with a truncated normal (std 0.02) and zero bias, and LayerNorm to weight 1 / bias 0.

        ``pretrained`` and ``revise_keys`` are accepted for interface
        compatibility but are not used.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward(self, x):
        """Mask random patches of ``x`` and reconstruct them.

        Returns:
            A 6-tuple ``(pred_pixel_values, patches, batch_range,
            masked_indices, unmasked_indices, encoded_tokens)``:
            predictions for the hidden patches, all raw patches, a column
            vector of batch indices for fancy indexing, the indices of the
            hidden and visible patches, and the encoder output for the
            visible patches. ``patches[batch_range, masked_indices]`` is the
            target that matches ``pred_pixel_values``.
        """
        device = x.device

        # get patches
        patches = self.to_patch(x)
        batch, n_patches, *_ = patches.shape

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)
        tokens = tokens + self.encoder.patch_embedding.position_embeddings

        # calculate of patches needed to be masked, and get random indices
        num_masked = int(self.masking_ratio * n_patches)
        rand_indices = torch.rand(batch, n_patches, device=device).argsort(dim=-1)
        masked_indices, unmasked_indices = (
            rand_indices[:, :num_masked],
            rand_indices[:, num_masked:],
        )

        # get the unmasked tokens to be encoded
        batch_range = torch.arange(batch, device=device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss
        # masked_patches = patches[batch_range, masked_indices]

        for blk in self.encoder.blocks:
            tokens = blk(tokens)
        encoded_tokens = tokens

        # Out-of-place add: when enc_to_dec is nn.Identity, decoder_tokens IS
        # encoded_tokens, and an in-place `+=` would silently add the decoder position
        # embeddings to the encoder features returned to the caller.
        decoder_tokens = self.enc_to_dec(encoded_tokens)
        decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

        mask_tokens = repeat(self.mask_token, "d -> b n d", b=batch, n=num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # concat the masked tokens to the decoder tokens and attend with decoder
        decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim=1)
        for blk in self.decoder_blocks:
            decoder_tokens = blk(decoder_tokens)
        decoded_tokens = self.decoder_norm(decoder_tokens)

        # splice out the mask tokens and project to pixel values
        mask_tokens = decoded_tokens[:, :num_masked]
        pred_pixel_values = self.to_pixels(mask_tokens)

        return pred_pixel_values, patches, batch_range, masked_indices, unmasked_indices, encoded_tokens

### Call model

In [None]:
# Build the MAE on 128^3 inputs with 16^3 patches, one input channel per MRI sequence,
# and move it to the selected device
model = MAE(
    in_channels=len(mo_img),
    img_size=[128,128,128],
    patch_size=[16,16,16],
    hidden_size=768,
    mlp_dim=3072,
    num_layers=12,
    num_heads=12,
    pos_embed='perceptron',
    dropout_rate=0.0,
    spatial_dims=3,
    masking_ratio=0.75,
).to(device)

In [None]:
# Display the module tree of the model
model

### Load best model

In [None]:
# Start from the best checkpoint of an earlier run
model_dir = './model_MAE2D/20231221_load1220_ep1000_lr1e-04/'
# map_location lets a checkpoint saved on another device (e.g. a different GPU index)
# load onto whichever device this session uses
model.load_state_dict(torch.load(os.path.join(model_dir, "best_model.pth"), map_location=device))

### Train MAE

In [None]:
# Reconstruction loss between predicted and true values of the masked patches (L1; MSE kept as an alternative)
recon_loss = nn.L1Loss()
# recon_loss = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
start = time.time()
best_loss = float("inf")  # any finite validation loss beats the initial value
best_loss_epoch = -1
epoch_loss_values = []
val_loss_values = []

real_labels = [] # list for collecting the labels of the validation dataset
pred_labels = [] # list for collecting the best model's predictions on the validation dataset

# Loop-invariant; only needed by the optional per-step progress print below
epoch_len = len(train_ds) // train_loader.batch_size

for epoch in range(epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        # Stack the MRI sequences along the channel axis, in the order given by mo_img
        inputs = torch.cat([batch_data[m].to(device) for m in mo_img], dim=1)
        optimizer.zero_grad()
        pred_pixel_values, patches, batch_range, masked_indices, unmasked_indices, encoded_tokens = \
            model(inputs)
        # Loss is taken on the masked patches only
        loss = recon_loss(pred_pixel_values, patches[batch_range, masked_indices])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # print(f"{step:02}/{epoch_len:02}, train_loss: {loss.item():.4f}")

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0 or epoch + 1 == epochs:
        model.eval()

        val_loss = 0
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs = torch.cat([val_data[m].to(device) for m in mo_img], dim=1)
                val_pred_pixel_values, val_patches, val_batch_range, val_masked_indices, _, _ = \
                    model(val_inputs)
                # Use the VALIDATION outputs here; the previous code re-scored the last
                # training batch, so the reported validation loss never reflected val_loader.
                val_loss += recon_loss(
                    val_pred_pixel_values, val_patches[val_batch_range, val_masked_indices]
                ).item()

        val_loss /= len(val_loader)
        # Repeat the value for every epoch since the previous validation so both curves
        # have the same length (also when the final epoch is not a multiple of val_interval)
        val_loss_values.extend([val_loss] * (len(epoch_loss_values) - len(val_loss_values)))
        df = pd.DataFrame({'epoch':range(len(epoch_loss_values)),'train_loss':epoch_loss_values,
                           'val_loss':val_loss_values})
        df = df.set_index('epoch')
        df.to_csv(os.path.join(root_dir,'save_results.csv'))

        if val_loss < best_loss:
            best_loss = val_loss
            best_loss_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(root_dir,"best_model.pth"))
            print("saved new best metric model")
        else:
            torch.save(model.state_dict(), os.path.join(root_dir,f"save_model_ep{epoch+1}.pth"))

        print(f"Current epoch: {epoch+1}, average loss: {val_loss:.4f}")
        print(f"Minimum loss: {best_loss:.4f} at epoch {best_loss_epoch}")


print(f"Training completed, best_loss: {best_loss:.4f} at epoch: {best_loss_epoch}")

In [None]:
# Keep the final-epoch weights alongside the best checkpoint
last_model_path = os.path.join(root_dir, "last_model.pth")
torch.save(model.state_dict(), last_model_path)

In [None]:
# Wall-clock time elapsed since the training cell started
end = time.time()
sec = end - start
elapsed_hms = str(timedelta(seconds=sec)).split('.')[0]  # drop the fractional seconds
print(f"Training time: {elapsed_hms}")

In [None]:
# Persist the training duration next to the checkpoints; the context manager closes
# the file even if the write raises
with open(os.path.join(root_dir,"Training_time.txt"),"w") as f:
    f.write(f"Training time: {str(timedelta(seconds=sec)).split('.')[0]}")

In [None]:
fig = plt.figure("train/valid", (13, 5))

# Left panel: training loss on its own
ax_train = fig.add_subplot(1, 2, 1)
ax_train.set_title('Iteration Train Loss')
ax_train.set_xlabel("Iteration")
ax_train.plot(epoch_loss_values, label='train loss')

# Right panel: training and validation loss over the span where both exist
ax_both = fig.add_subplot(1, 2, 2)
ax_both.set_title("Iteration Average Loss")
x = range(len(val_loss_values))
y = epoch_loss_values[:len(val_loss_values)]
y2 = val_loss_values
ax_both.set_xlabel("Iteration")
ax_both.plot(x, y2, label='validation loss', color=palette1[1])
ax_both.plot(x, y, label='train loss', color=palette1[0])
ax_both.legend()
fig.tight_layout()
fig.savefig(os.path.join(root_dir, 'MAE_loss_graph.png'))

plt.show()

### Display reconstruction result

In [None]:
# model_dir = './model_MAE/20231205_load0404_1500_lr8e-04/'
# expd = 'load1204_4000_lr8e-04'
# Reload the best checkpoint of this run; map_location keeps the load independent of
# the device the checkpoint was saved from
model.load_state_dict(torch.load(os.path.join(root_dir, "best_model.pth"), map_location=device))
model.eval()

In [None]:
def show_image(image, title='', vmin=-4, vmax=4):
    """Draw a 2D ``image`` in gray-scale on the current axes.

    ``vmin`` / ``vmax`` fix the displayed intensity range; the axes are
    hidden and ``title`` is placed above the image.
    """
    plt.imshow(image, vmin=vmin, vmax=vmax, cmap='gray')
    plt.axis('off')
    plt.title(title, fontsize=16)

def run_one_image(img, model, modality='flair', expd='', slice=50, save_dir=root_dir, dim=3, sample=1):
    """Run the MAE on one batch and plot original / masked / reconstructed slices.

    Args:
        img: batch dictionary holding the 'flair', 't1', 't1ce' and 't2' tensors.
        model: the MAE; called as ``model(x)`` and expected to return the
            6-tuple produced by ``MAE.forward``.
        modality: which channel to display ('flair', 't1', 't1ce' or 't2').
        expd: suffix for the figure title and file name (start it with '_').
        slice: index along the last spatial axis to display.
        save_dir: directory that receives the PNG.
        dim: 3 for 8x8x8 patches of 16^3 voxels, 2 for 8x8 patches of 16^2.
        sample: index of the batch item to display; the colour range is taken
            from the same item.

    Raises:
        ValueError: if ``dim`` is neither 2 nor 3.
    """
    if dim == 3:
        dim_info = {'h':8, 'w':8, 'd':8, 'p1':16, 'p2':16, 'p3':16, 'c':len(mo_img)}
    elif dim == 2:
        dim_info = {'h':8, 'w':8, 'd':1, 'p1':16, 'p2':16, 'p3':1, 'c':len(mo_img)}
    else:
        raise ValueError(f"dim must be 2 or 3, got {dim}")
    modality_dict = {'flair':0, 't1':1, 't1ce':2, 't2':3}
    modality_dim = modality_dict[modality]

    # Stack the four sequences along the channel axis in the order of modality_dict
    x = torch.cat([img[key].to(device) for key in ('flair', 't1', 't1ce', 't2')], dim=1)

    # Intensity range of the displayed slice, as plain floats. Computed with torch so it
    # works wherever x lives, and taken from the item that is actually shown.
    ref_slice = x[sample, modality_dim, :, :, slice]
    vmin, vmax = ref_slice.min().item(), ref_slice.max().item()

    model.eval()
    with torch.no_grad():
        # run MAE
        pred_pixel_values, patches, batch_range, masked_indices, unmasked_indices, _ = model(x)

    # Move everything used for plotting to the CPU; a CPU tensor cannot be indexed with
    # index tensors that live on another device
    patches = patches.detach().cpu()
    pred_pixel_values = pred_pixel_values.detach().cpu()
    batch_range = batch_range.cpu()
    masked_indices = masked_indices.cpu()
    unmasked_indices = unmasked_indices.cpu()

    def _to_volume(patch_tensor):
        # Inverse of the patchify step: (b, n_patches, patch_dim) -> (b, c, H, W, D)
        return rearrange(patch_tensor, 'b (h w d) (p1 p2 p3 c)->b c (h p1) (w p2) (d p3)', **dim_info)

    # masked image: hidden patches painted with the darkest displayed value
    im_masked = patches.clone()
    im_masked[batch_range, masked_indices] = vmin
    im_masked = _to_volume(im_masked)

    # only reconstruction image: predictions on hidden patches, visible ones blanked
    y = patches.clone()
    y[batch_range, masked_indices] = pred_pixel_values
    y[batch_range, unmasked_indices] = vmin
    y = _to_volume(y)

    # MAE reconstruction pasted with visible patches
    im_paste = patches.clone()
    im_paste[batch_range, masked_indices] = pred_pixel_values
    im_paste = _to_volume(im_paste)

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [12, 4]
    plt.suptitle(f"MAE with pixel reconstruction{expd} ({modality})",fontsize=20)
    panels = [
        ("original", x.detach().cpu()),
        ("masked", im_masked),
        ("reconstruction", y),
        ("reconstruction + visible", im_paste),
    ]
    for col, (title, volume) in enumerate(panels, start=1):
        plt.subplot(1, 4, col)
        # swap the first two spatial axes for display, then pick item / channel / slice
        show_image(volume.permute(0,1,3,2,4)[sample,modality_dim,:,:,slice], title, vmin, vmax)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir,f'MAE_recon_results{expd}_{modality}.png'))
    plt.show()

In [None]:
# test_ds = CacheDataset(data=val_files, transform=transforms_noCrop,
#                         cache_rate=1.0,
#                         num_workers=8,)
# test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True,
#                         collate_fn=pad_list_data_collate,)

In [None]:
# Sub-directory of the experiment folder that receives the reconstruction figures
expd2 = 'recon_results(change_format)'
res_path = os.path.join(root_dir,expd2)
if not os.path.isdir(res_path):
    # Create the folder when it does not exist yet
    os.mkdir(res_path)
    print(f"Success in making {expd}/{expd2}~!")
elif os.listdir(res_path):
    raise UserWarning(f"'{expd}/{expd2}' is already exist and not empty.")
else:
    print(f"[WARNING] {expd}/{expd2} is already exist but empty")

In [None]:
print('MAE with pixel reconstruction:')
# One figure per validation batch and per MRI sequence, saved under res_path
for i, data in enumerate(val_loader):
    print(f'{i}th data: ')
    tag = f'_test{i:02}'  # if expd is given it must start with '_'
    for mo in mo_img:
        run_one_image(data, model, modality=mo, expd=tag, slice=0, save_dir=res_path, dim=3)

In [None]:
# Take the first batch yielded by the training loader for a reconstruction check on training data
img = first(train_loader)

In [None]:
print('MAE with pixel reconstruction:')
# Same visualisation on a training batch; figures go to the default save_dir
for mo in mo_img:
    run_one_image(
        img,
        model,
        modality=mo,
        expd='_train_data',  # if expd is given it must start with '_'
        slice=0,
        dim=3,
    )