In [None]:
# Execution environment; later cells branch on this value (GPU selection, dataset root, fonts).
VALID_ENVS = ('local', 'server')
env = 'server' # 'local' or 'server'
# The original check listed 'loval', so env = 'local' was rejected by mistake.
assert env in VALID_ENVS, "Training environment must be 'local' or 'server'"
# Which downstream model to build and train.
VALID_MODES = ('pyradiomics', 'MAE', 'ensemble')
mode = 'MAE' # 'pyradiomics' or 'MAE' or 'ensemble'
assert mode in VALID_MODES, "Model's mode must be 'pyradiomics' or 'MAE' or 'ensemble'"

### Setup imports

In [None]:
import os, time, copy

# Prevent the kernel from dying (works around the duplicate OpenMP runtime error)
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# GPU setting (done here, before torch is imported further down in this cell)
if env == 'server':
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # Arrange GPU devices starting from 0
    os.environ["CUDA_VISIBLE_DEVICES"] = "2" # Set the GPU 2 to use
    
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from tqdm import tqdm
import logging, sys
import shutil

import numpy as np
import pandas as pd

# python에서 List, Dict, Tuple, Set와 같은 파이썬 내장 자료구조에 대한 타입을 명시해야할 때 사용
from typing import Sequence, Union


# Pytorch
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, StackDataset
# dataset
from sklearn import datasets

# Label encoder: Categorical label to Numerical label
from sklearn.preprocessing import LabelEncoder
# z-normalization
from sklearn.preprocessing import StandardScaler

# Monai
import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data.utils import pad_list_data_collate
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
)
from monai.transforms import MapTransform
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    ScaleIntensityd,
    NormalizeIntensityd,
    Resized,
    SpatialPadd,
    RandCropByPosNegLabeld,
    RandSpatialCropSamplesd,
    CenterSpatialCropd,
    CropForegroundd,
    RandAffined,
    EnsureTyped,
)
from monai.utils import set_determinism, first

from einops import repeat, rearrange
# PyTorch torch.load 함수를 기반으로 동작하며, 딕셔너리에 모델의 가중치, epoch 정보 등을 저장하여 모델을 재구성함.
from timm.models.layers import trunc_normal_ # timm: pretrained model 제공, trunc_normal_: model initialize할 때 사용

# Model performance
from sklearn.metrics import roc_auc_score, RocCurveDisplay
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.metrics import accuracy_score, roc_auc_score, r2_score
from scipy import stats

# True when torch can see a CUDA device.
pin_memory = torch.cuda.is_available()

# Send INFO-level log records to stdout, then print MONAI's configuration report.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

In [None]:
# Suppress all warnings
import warnings
warnings.filterwarnings("ignore")

# Palette setting
import seaborn as sns
sns.set_palette('Pastel1')
palette1 = sns.color_palette('Pastel1', 8) # second argument: number of colors to generate
palette = sns.color_palette('Pastel2', 8) # second argument: number of colors to generate
# sns.palplot(palette)

# Fix broken Korean glyphs in matplotlib figures
import matplotlib
if env == 'local':
    matplotlib.rcParams['font.family'] ='Malgun Gothic'
    matplotlib.rcParams['axes.unicode_minus'] =False

### Training setting

In [None]:
# Path of the checkpoint to load
model_dir = './model_MAE/20231205_load0404_1500_lr8e-04/last_model.pth'

# Whether to exclude the shape features from the pyradiomics feature set
noShape = False

In [None]:
# Optimizer hyperparameters: the pyradiomics MLP uses smaller values, while the
# 'ensemble' mode and every other mode share the same pair.
lr, weight_decay = (1e-5, 5e-6) if mode == 'pyradiomics' else (1e-4, 5e-5)
# Learning rate rendered in scientific notation with no decimals (e.g. 1e-04).
lr_str = f"{lr:.0e}"

# Training schedule
epochs, batch_size = 100, 8

In [None]:
# seed = 23
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# np.random.seed(seed)

In [None]:
# Image keys only, and the same keys followed by the segmentation key.
mo_img = ['flair', 't1', 't1ce', 't2']
modality = mo_img + ['seg'] # 'flair', 't1', 't1ce', 't2', 'seg'

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Make working directory
from datetime import datetime

# Experiment folder name: <YYYYMMDD>_ep<epochs>_lr<lr_str>_05last
expd = datetime.today().strftime("%Y%m%d")+"_"+f"ep{epochs}_lr{lr_str}_05last"
path00 = f'./model_downstream/model_{mode}'

root_dir = os.path.join(path00,f'{expd}')
if not os.path.isdir(root_dir): # create the folder when it does not exist yet
    # makedirs also creates the missing parent folders (path00); plain os.mkdir
    # raises FileNotFoundError when ./model_downstream/model_<mode> is absent.
    os.makedirs(root_dir)
    print(f"Success in making {expd}~!")
elif os.listdir(root_dir):
    # Refuse to reuse a folder that already contains files.
    raise UserWarning(f"'{expd}' is already exist and not empty.")
else:
    print(f"[WARNING] {expd} is already exist but empty")

### Data Preprocessing

#### MAE

In [None]:
# Dataset root depends on where the notebook runs.
if env == 'local':
    data_dir = '../../Datasets/Dataset002_BRATS2017/'
elif env == 'server':
    data_dir = '/store8/njrue/Datasets/Dataset002_BRATS2017/'
split_json = 'BraTS2017_ipiu.json'

# NOTE(review): this rebinds `datasets`, which the import cell bound to `sklearn.datasets`;
# the sklearn module is no longer reachable under that name after this line.
datasets = data_dir+split_json
# Read the "training" and "validation" lists from the datalist JSON.
train_files = load_decathlon_datalist(datasets, is_segmentation=False, data_list_key="training")
val_files = load_decathlon_datalist(datasets, is_segmentation=False, data_list_key="validation")

In [None]:
pd.DataFrame(train_files).columns

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert a BraTS label map into a channel-first float32 mask.

    BraTS label values:
    label 2 is the peritumoral edema (+roi3)
    label 4 is the GD-enhancing tumor (+roi2)
    label 1 is the necrotic tumor core (roi1)

    Only roi1 (``label == 1``) is currently emitted, so the result has a single
    channel. The roi2/roi3 definitions are kept below as commented-out code.
    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` in which every entry listed in
        ``self.keys`` is replaced by its stacked mask (new axis 0 = channel)."""
        d = dict(data)
        for key in self.keys:
            result = []

            # Classes as defined in the lab's paper
            # roi1
            result.append(d[key] == 1)
            # # roi2
            # result.append(np.logical_or(d[key] == 4, d[key] == 1))
            # # roi3
            # result.append(
            #     np.logical_or(np.logical_or(d[key] == 1, d[key] == 2), d[key] == 4)
            # )

            # Write the mask back under the key being processed. The previous code
            # always wrote to "seg", which only matched when keys == ['seg'].
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d

In [None]:
# Preprocessing pipeline applied to every sample dictionary.
train_transforms = Compose([
    LoadImaged(
        keys=modality
    ),
    # Channel axis is added to the image keys only; the label transform below
    # stacks the segmentation into a channel-first array itself.
    EnsureChannelFirstd(
        keys=mo_img
    ),
    ConvertToMultiChannelBasedOnBratsClassesd(keys=['seg']), 
    # Orientationd(keys=modality, axcodes="RAS"), # this does not seem to get applied to seg.
    # ScaleIntensityd(keys=mo_img),
    NormalizeIntensityd(
        keys=mo_img, 
        nonzero=True, 
        channel_wise=True
    ),
    EnsureTyped(
        keys=modality
    ),
    CropForegroundd(keys=modality, source_key=mo_img[0]),
    SpatialPadd(keys=modality, spatial_size=[128,128,128]), # no padding when the input is already larger than spatial_size
    RandCropByPosNegLabeld(
        keys=modality,
        label_key="seg",
        spatial_size=[128,128,128],
        pos=1,
        neg=0,
        num_samples=1,
    ),
    # RandSpatialCropSamplesd(
    #     keys=modality,
    #     roi_size=[128,128,128],
    #     random_size=False,
    #     num_samples=1,
    # ),
])

# Validation uses an independent copy of the same pipeline (random crop included).
val_transforms = copy.deepcopy(train_transforms)

#### Pyradiomics
- 최초 1회만 실행하면 됨. (전처리 결과를 새로운 csv 파일로 만들어 저장하기 때문에)

##### Load feature & Split X, y

In [None]:
# features = pd.read_csv('dataset/pyradiomics_feature.csv', sep=',')
# features[:5]

In [None]:
# le = LabelEncoder() # HGG=0, LGG=1

# X = features.drop('target',axis=1)
# y = pd.Series(le.fit_transform(features.target))

In [None]:
# # z-normalization
# z_norm = StandardScaler()

# z_norm.fit(X)
# X_norm = pd.DataFrame(z_norm.transform(X), columns=X.columns).fillna(0.0)

##### Split train / test

In [None]:
# Row indices reserved for validation: five contiguous index ranges joined in order.
val_idx = np.concatenate([
    np.arange(16, 20),
    np.arange(91, 108),
    np.arange(190, 210),
    np.arange(218, 220),
    np.arange(273, 285),
])
# Every other index in [0, len(train_files)+len(val_files)) is assigned to training.
train_idx = np.setdiff1d(np.arange(0,len(train_files)+len(val_files)),val_idx)

In [None]:
# new_dataset = pd.concat([X_norm, features.target], axis=1)
# valid_dataset = new_dataset.iloc[val_idx]
# train_dataset = new_dataset.iloc[train_idx]

In [None]:
# valid_dataset.to_csv('./dataset/pyradiomics_val.csv', index=False)
# train_dataset.to_csv('./dataset/pyradiomics_train.csv', index=False)

### Load Dataset

#### MAE

In [None]:
if mode != 'pyradiomics':
    # Image datasets are only needed for the 'MAE' and 'ensemble' modes.
    # cache_rate=1.0 asks CacheDataset to cache every sample.
    train_ds = CacheDataset(data=train_files, transform=train_transforms,
                            cache_rate=1.0,
                            num_workers=8,)
    # Use the notebook-level `pin_memory` flag (True only when CUDA is available)
    # instead of a hard-coded True, so CPU-only runs do not request pinned memory.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory,
                            collate_fn=pad_list_data_collate,)

    val_ds = CacheDataset(data=val_files, transform=val_transforms,
                            cache_rate=1.0,
                            num_workers=8,)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory,
                            collate_fn=pad_list_data_collate,)

In [None]:
if mode != 'pyradiomics':
    # Pull a single batch to sanity-check the shapes produced by the pipeline.
    check_train = first(train_loader)
    # NOTE(review): the reference text mentions 3 segmentation channels, but the label
    # transform in this notebook stacks a single mask — confirm the message is current.
    print("**Original Image shape: (1, 1, 240, 240, 155), Segmentation shape: (1, 3, 240, 240, 155)")
    print(f"\t   Image shape: {np.shape(check_train['t1'])}, Segmentation shape: {np.shape(check_train['seg'])}")

In [None]:
if mode != 'pyradiomics':
    # Show index 64 along the last axis of the first sample's first segmentation
    # channel, with the first two axes swapped.
    plt.figure(figsize=(3,3))
    plt.imshow(check_train['seg'][0][0].permute(1,0,2)[:, :, 64])
    plt.show()

In [None]:
if mode != 'pyradiomics':
    # Same view as the segmentation plot, for the 't1ce' image in grayscale.
    plt.figure(figsize=(3,3))
    plt.imshow(check_train['t1ce'][0][0].permute(1,0,2)[:, :, 64], cmap="gray")
    plt.show()

#### Pyradiomics

In [None]:
# Load the pyradiomics feature tables (presumably written by the commented-out
# preprocessing cells above — TODO confirm). Row shuffling is left disabled.
train_dataset = pd.read_csv('dataset/pyradiomics_train.csv', sep=',') # .sample(frac=1)
val_dataset = pd.read_csv('dataset/pyradiomics_val.csv', sep=',') # .sample(frac=1)

In [None]:
# Shape features dominate too strongly, so also build a variant of both tables without them.

# Columns are selected from the training table and removed from both tables.
shape_cols = [name for name in train_dataset.columns if 'shape' in name]
train_dataset_noShape = train_dataset.drop(columns=shape_cols)
val_dataset_noShape = val_dataset.drop(columns=shape_cols)

In [None]:
# Switch the working tables to the shape-free variants when requested.
if noShape == True:
    train_dataset = train_dataset_noShape
    val_dataset = val_dataset_noShape

In [None]:
le = LabelEncoder() # HGG=0, LGG=1

# Separate the features from the 'target' column. The encoder is fitted on the
# training labels and reused unchanged for the validation labels.
X_train = train_dataset.drop('target',axis=1)
y_train = pd.Series(le.fit_transform(train_dataset.target))
X_val = val_dataset.drop('target',axis=1)
y_val = pd.Series(le.transform(val_dataset.target))

In [None]:
# Turn the feature frames and label series into numpy arrays
X_train, y_train = np.array(X_train), np.array(y_train)
X_val, y_val = np.array(X_val), np.array(y_val)

In [None]:
# Turn the numpy arrays into float32 tensors; labels become (N, 1) column vectors
X_train_t = torch.from_numpy(X_train).float()
X_val_t = torch.from_numpy(X_val).float()
y_train_t = torch.from_numpy(y_train).float().reshape(-1,1)
y_val_t = torch.from_numpy(y_val).float().reshape(-1,1)

In [None]:
X_val_t.shape

In [None]:
# Create Dataset
training_data = TensorDataset(X_train_t, y_train_t)
val_data = TensorDataset(X_val_t, y_val_t)

# Create DataLoader
train_loader_py = DataLoader(training_data, batch_size=batch_size, shuffle=True,)
val_loader_py = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Print the shapes of the first batch only.
# NOTE(review): `X` and `y` stay bound at module level after this loop, and the model
# classes defined later read `X.size(-1)` for their input width.
for X, y in train_loader_py:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

#### Both: stack dataset

In [None]:
if mode == 'ensemble':
    # Stack two datasets: each item pairs the image sample with the
    # (features, label) tuple at the same index.
    # NOTE(review): this assumes both datasets list the subjects in the same order — confirm.
    stack_train_ds = StackDataset(train_ds, training_data)
    stack_val_ds = StackDataset(val_ds, val_data)

    # Create Dataloader
    stack_train_loader = DataLoader(stack_train_ds, batch_size=batch_size, shuffle=True,)
    stack_val_loader = DataLoader(stack_val_ds, batch_size=batch_size, shuffle=False,)

### Model define

#### MAE

In [None]:
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.nets import ViT

In [None]:
class MAE_MLP(nn.Module):
    """
    Binary classifier on top of a Vision Transformer (ViT) encoder, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Each encoded patch token is compressed to 64 features, all tokens are flattened,
    and an MLP head maps them to a single raw logit (no sigmoid inside the model).
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        masking_ratio: float = 0.0,
        revise_keys=[("model.", "")],
        **kwargs,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels or the number of channels for input
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_layers: number of transformer blocks.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            dropout_rate: faction of the input units to drop.
            spatial_dims: number of spatial dimensions.
            masking_ratio: stored on the instance; not used by ``forward``.
            revise_keys: passed to ``init_weights``, which currently ignores it.
        """

        super().__init__()
        self.spatial_dims = spatial_dims
        self.masking_ratio = masking_ratio

        self.encoder = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )

        # patch embedding block: split into the patch extraction step and the linear projection
        patch_embedding = self.encoder.patch_embedding
        self.to_patch, self.patch_to_emb = patch_embedding.patch_embeddings

        # Number of patch tokens per image, derived from the arguments instead of being
        # hard-coded (128^3 images with 16^3 patches give 8*8*8 = 512 tokens).
        img_dims = tuple(img_size) if isinstance(img_size, Sequence) else (img_size,) * spatial_dims
        patch_dims = tuple(patch_size) if isinstance(patch_size, Sequence) else (patch_size,) * spatial_dims
        n_patches = 1
        for full_len, patch_len in zip(img_dims, patch_dims):
            n_patches *= full_len // patch_len
        token_features = 64

        self.MAE_outhidden = nn.Sequential(
            # compress the hidden_size features of every patch token down to 64
            nn.Linear(in_features=hidden_size, out_features=256, bias=True),
            nn.Tanh(),
            nn.Linear(in_features=256, out_features=token_features, bias=True),
            nn.Tanh(),
        )
        self.flatten = nn.Flatten()
        self.MAE_hidden = nn.Sequential(
            # flattened input: n_patches * 64 (= 32768 for the 128^3 / 16^3 setting)
            nn.Linear(n_patches * token_features, 4096, bias=True),
            nn.ReLU(),
            nn.Linear(4096, 1024, bias=True),
            nn.ReLU(),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512,256,bias=True)
        )
        self.classifier = nn.Sequential(
            nn.Linear(256, 1, bias=True),
            # nn.Sigmoid(),
        )
        self.init_weights(revise_keys=revise_keys)

    def init_weights(self, pretrained=None, revise_keys=[]):
        """Initialize the weights of every submodule (encoder included).

        Linear weights get a truncated normal (std=0.02) and zero bias; LayerNorm gets
        weight 1 and bias 0.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None. Currently unused.
            revise_keys: currently unused.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward(self, x):
        """
        Args:
            x: input tensor must have isotropic spatial dimensions,
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.

        Returns:
            Raw logits with one value per sample (output of the final ``Linear(256, 1)``).
        """

        # get patches
        patches = self.to_patch(x)

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)

        tokens = tokens + self.encoder.patch_embedding.position_embeddings

        for blk in self.encoder.blocks:
            tokens = blk(tokens)
        encoded_tokens = tokens # output of the encoder blocks (used as the latent representation)
        encoded_tokens = self.MAE_outhidden(encoded_tokens)
        latent_space = self.flatten(encoded_tokens)

        hidden = self.MAE_hidden(latent_space)
        pred = self.classifier(hidden)

        return pred

# print(MAE_MLP(in_channels=4, img_size=[96,96,96],patch_size=[16,16,16]))

#### pyradiomics MLP

In [None]:
# Define MLP with pyradiomics features
class py_MLP(torch.nn.Module):
  """MLP over pyradiomics feature vectors that outputs one raw logit per sample."""

  def __init__(self, in_features=None):
    """
    Args:
      in_features: number of input features. When omitted, the width is read from
        the last dimension of the notebook-level batch ``X`` (the previous behaviour),
        so ``py_MLP()`` keeps working as before.
    """
    super(py_MLP, self).__init__()
    if in_features is None:
      # Backward-compatible fallback on the notebook-level tensor `X`.
      in_features = X.size(-1)

    self.input = nn.Sequential(
      nn.Linear(in_features, 512, bias=True),
      nn.Tanh(),
    )
    self.hidden = nn.Sequential(
      nn.Linear(512, 256, bias=True),
      nn.ReLU(),
      nn.Linear(256, 128, bias=True), # 256 > 128
      nn.ReLU(),
    )
    self.output = nn.Sequential(
      nn.Linear(128, 1, bias=True),
      # nn.Sigmoid(),
    )
    # torch.nn.init.xavier_uniform_(self.output.weight),

  def forward(self, x):
    """Map a ``(batch, in_features)`` tensor to ``(batch, 1)`` raw logits."""
    pred = self.input(x)
    hidden = self.hidden(pred)
    pred = self.output(hidden)
    return pred

print(py_MLP())

#### Ensemble

In [None]:
class en_MLP(nn.Module):
    """
    Ensemble binary classifier combining a Vision Transformer (ViT) image branch, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>",
    with an MLP branch over pyradiomics features.

    The two branch outputs (512 and 256 features) are concatenated and mapped to a
    single raw logit (no sigmoid inside the model).
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        masking_ratio: float = 0.0,
        revise_keys=[("model.", "")],
        py_in_features=None,
        **kwargs,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels or the number of channels for input
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_layers: number of transformer blocks.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            dropout_rate: faction of the input units to drop.
            spatial_dims: number of spatial dimensions.
            masking_ratio: stored on the instance; not used by ``forward``.
            revise_keys: passed to ``init_weights``, which currently ignores it.
            py_in_features: number of pyradiomics input features. When omitted, the
                width is read from the last dimension of the notebook-level batch ``X``
                (the previous behaviour).

        """

        super().__init__()
        self.spatial_dims = spatial_dims
        self.masking_ratio = masking_ratio

        self.encoder = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )

        # patch embedding block: split into the patch extraction step and the linear projection
        patch_embedding = self.encoder.patch_embedding
        self.to_patch, self.patch_to_emb = patch_embedding.patch_embeddings

        # Number of patch tokens per image, derived from the arguments instead of being
        # hard-coded (128^3 images with 16^3 patches give 8*8*8 = 512 tokens).
        img_dims = tuple(img_size) if isinstance(img_size, Sequence) else (img_size,) * spatial_dims
        patch_dims = tuple(patch_size) if isinstance(patch_size, Sequence) else (patch_size,) * spatial_dims
        n_patches = 1
        for full_len, patch_len in zip(img_dims, patch_dims):
            n_patches *= full_len // patch_len
        token_features = 64

        if py_in_features is None:
            # Backward-compatible fallback on the notebook-level tensor `X`.
            py_in_features = X.size(-1)

        self.MAE_outhidden = nn.Sequential(
            # compress the hidden_size features of every patch token down to 64
            nn.Linear(in_features=hidden_size, out_features=256, bias=True),
            nn.Tanh(),
            nn.Linear(in_features=256, out_features=token_features, bias=True),
            nn.Tanh(),
        )
        self.flatten = nn.Flatten()
        self.MAE_hidden = nn.Sequential(
            # flattened input: n_patches * 64 (= 32768 for the 128^3 / 16^3 setting)
            nn.Linear(n_patches * token_features, 4096, bias=True),
            nn.ReLU(),
            nn.Linear(4096, 1024, bias=True),
            nn.ReLU(),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(),
        )
        self.py_hidden = nn.Sequential(
            nn.Linear(py_in_features,512,bias=True),
            nn.Tanh(),
            nn.Linear(512,256,bias=True),
            nn.ReLU(),
        )
        self.hidden = nn.Sequential(
            # nn.Linear(X.size(-1)+1024,256, bias=True),
            # nn.ReLU(),
            # input: 512 image-branch features + 256 pyradiomics-branch features
            nn.Linear(512+256,512, bias=True),
            nn.ReLU(),
            nn.Linear(512,256, bias=True),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256, 1, bias=True),
            # nn.Sigmoid(),
        )

        self.init_weights(revise_keys=revise_keys)

    def init_weights(self, pretrained=None, revise_keys=[]):
        """Initialize the weights of every submodule (encoder included).

        Linear weights get a truncated normal (std=0.02) and zero bias; LayerNorm gets
        weight 1 and bias 0.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None. Currently unused.
            revise_keys: currently unused.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward(self, x_mae, x_py):
        """
        Args:
            x_mae: image tensor; must have isotropic spatial dimensions,
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
            x_py: pyradiomics feature tensor of shape ``[batch_size, py_in_features]``.

        Returns:
            Raw logits with one value per sample (output of the final ``Linear(256, 1)``).
        """

        # get patches
        patches = self.to_patch(x_mae)

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)

        tokens = tokens + self.encoder.patch_embedding.position_embeddings

        for blk in self.encoder.blocks:
            tokens = blk(tokens)
        encoded_tokens = tokens
        encoded_tokens = self.MAE_outhidden(encoded_tokens)
        latent_space = self.flatten(encoded_tokens)

        mae_hidden = self.MAE_hidden(latent_space)
        py_hidden = self.py_hidden(x_py)

        # fuse both branches along the feature axis
        concat_inputs = torch.concat((mae_hidden,py_hidden), dim=1)
        hidden = self.hidden(concat_inputs)
        pred = self.classifier(hidden)

        return pred

# print(en_MLP(in_channels=4, img_size=[96,96,96],patch_size=[16,16,16]))

### Load model

In [None]:
def load_encoder_weight(model_dir, model):
    """Load the encoder-related weights of a checkpoint into ``model``.

    The checkpoint is split by key prefix into the three sub-modules ``encoder``,
    ``to_patch`` and ``patch_to_emb``; each receives the entries whose key starts with
    ``"<name>."`` (prefix removed). All other checkpoint entries are ignored.

    The checkpoint's ``encoder.patch_embedding.position_embeddings`` is NOT used:
    the model's current position embeddings are kept, as in the previous implementation.

    Args:
        model_dir: path of the checkpoint file holding a state dict.
        model: module exposing ``encoder``, ``to_patch`` and ``patch_to_emb``.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        RuntimeError: from ``load_state_dict`` when the selected keys do not match a sub-module.
    """
    # NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
    state_dict = torch.load(model_dir, map_location=device)
    pos_key = 'encoder.patch_embedding.position_embeddings'

    targets = {
        'encoder': model.encoder,
        'to_patch': model.to_patch,
        'patch_to_emb': model.patch_to_emb,
    }
    for block, module in targets.items():
        prefix = block + '.'
        # Build the per-module dict directly instead of deep-copying the whole checkpoint
        # once per block, and match on the key prefix rather than on any substring.
        sub_state = {}
        for key, value in state_dict.items():
            if not key.startswith(prefix):
                continue
            if key == pos_key:
                # keep the model's own position embeddings
                value = model.encoder.patch_embedding.position_embeddings
            sub_state[key[len(prefix):]] = value
        module.load_state_dict(sub_state)
    return model

In [None]:
# Build the model matching `mode`; the image-based models additionally receive
# the encoder weights stored in `model_dir`.
if mode=='pyradiomics':
    model = py_MLP().to(device)
elif mode=='MAE':
    model = MAE_MLP(
        in_channels=4,
        img_size=[128,128,128],
        patch_size=[16,16,16],
        hidden_size=768,
        mlp_dim=3072,
        num_layers=12,
        num_heads=12,
        pos_embed='perceptron',
        dropout_rate=0.0,
        spatial_dims=3,
    ).to(device)
    model = load_encoder_weight(model_dir=model_dir, model=model)
else:
    model = en_MLP(
        in_channels=4,
        img_size=[128,128,128],
        patch_size=[16,16,16],
        hidden_size=768,
        mlp_dim=3072,
        num_layers=12,
        num_heads=12,
        pos_embed='perceptron',
        dropout_rate=0.0,
        spatial_dims=3,
    ).to(device)
    model = load_encoder_weight(model_dir=model_dir, model=model)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# Loss function
# loss_function = nn.BCEWithLogitsLoss()
loss_function = nn.BCELoss()
# CrossEntropyLoss: loss function used for multi-class classification

# BCEWithLogitsLoss: Sigmoid + BCELoss (more stable than applying Sigmoid and BCELoss separately)
# The output does not need to lie between 0 and 1

# BCELoss: the Sigmoid must not be added inside the model; it has to be applied separately to the model output. Training does not go well when the Sigmoid is placed inside the model.

### Training

#### Define function: MAE

In [None]:

def train_MAE(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of the image-only (MAE) classifier.

    Args:
        dataloader: yields dict batches with 'flair', 't1', 't1ce', 't2' tensors and a
            'label' entry; must expose ``.dataset`` and ``len()``.
        model: called as ``model(X_mae)``; a sigmoid is applied to its output here.
        loss_fn: called as ``loss_fn(sigmoid(output), y)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(epoch_loss, epoch_acc, y_list, pred_list)``: mean loss over batches, fraction
        of correct predictions over ``len(dataloader.dataset)`` (in [0, 1]), and flat
        lists of targets and predicted probabilities.
    """
    model.train()
    size = len(dataloader.dataset)

    pred_list, y_list = [], []
    train_running_loss = 0
    train_running_correct = 0
    for batch_data in dataloader:
        # Concatenate the four modalities along dim 1
        X_mae = torch.cat(
            [batch_data[name].to(device) for name in ('flair', 't1', 't1ce', 't2')], dim=1
        )
        y = (batch_data['label']).type(torch.float).reshape([-1,1]).to(device)

        # Compute the prediction error
        pred = torch.sigmoid(model(X_mae)) # forward pass
        y_list.extend(y.detach().cpu().flatten().tolist())
        pred_list.extend(pred.detach().cpu().flatten().tolist())
        loss = loss_fn(pred, y) # calculate the loss
        train_running_loss += loss.item()

        # Calculate the accuracy
        y_pred = (pred>=0.5).type(torch.int8)
        train_running_correct += torch.eq(y_pred, y).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = train_running_loss / len(dataloader)
    epoch_acc = train_running_correct / size
    # epoch_acc is a fraction; scale it so the printed '%' value is a real percentage.
    print(f"Train: \n Accuracy: {100 * epoch_acc:.2f}%, Avg loss: {epoch_loss:>7f}")
    return epoch_loss, epoch_acc, y_list, pred_list

In [None]:
def valid_MAE(dataloader, model, loss_fn):
    """Evaluate the image-only (MAE) classifier on one pass over ``dataloader``.

    Args:
        dataloader: yields dict batches with 'flair', 't1', 't1ce', 't2' tensors and a
            'label' entry; must expose ``.dataset`` and ``len()``.
        model: called as ``model(X_mae)``; a sigmoid is applied to its output here.
        loss_fn: called as ``loss_fn(sigmoid(output), y)``.

    Returns:
        ``(valid_loss, correct, y_list, pred_list)``: mean loss over batches, fraction
        of correct predictions over ``len(dataloader.dataset)`` (in [0, 1]), and flat
        lists of targets and predicted probabilities.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    pred_list, y_list = [], []
    valid_loss, correct = 0, 0
    model.eval()
    with torch.no_grad(): # no training
        for batch_data in dataloader:
            # Concatenate the four modalities along dim 1
            X_mae = torch.cat(
                [batch_data[name].to(device) for name in ('flair', 't1', 't1ce', 't2')], dim=1
            )
            y = (batch_data['label']).type(torch.float).reshape([-1,1]).to(device)

            pred = torch.sigmoid(model(X_mae))
            y_list.extend(y.cpu().flatten().tolist())
            pred_list.extend(pred.cpu().flatten().tolist())
            valid_loss += loss_fn(pred, y).item()
            # correct += torch.eq((pred>0).type(torch.int8), y).sum().item()
            correct += torch.eq((pred>=0.5).type(torch.int8), y).sum().item()
    valid_loss /= num_batches
    correct /= size
    # `correct` is a fraction; scale it so the printed '%' value is a real percentage.
    print(f"Validation: \n Accuracy: {100 * correct:.2f}%, Avg loss: {valid_loss:>8f} \n")

    return valid_loss, correct, y_list, pred_list

#### Define function: pyradiomics

In [None]:
def train_py(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of the pyradiomics MLP.

    Args:
        dataloader: yields ``(X, y)`` batches; must expose ``.dataset`` and ``len()``.
        model: called as ``model(X)``; a sigmoid is applied to its output here.
        loss_fn: called as ``loss_fn(sigmoid(output), y)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(epoch_loss, epoch_acc, y_list, pred_list)``: mean loss over batches, fraction
        of correct predictions over ``len(dataloader.dataset)`` (in [0, 1]), and flat
        lists of targets and predicted probabilities.
    """
    model.train()
    y_list, pred_list = [],[]
    size = len(dataloader.dataset)

    train_running_loss = 0
    train_running_correct = 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = torch.sigmoid(model(X)) # forward pass
        y_list.extend(y.detach().cpu().flatten().tolist())
        pred_list.extend(pred.detach().cpu().flatten().tolist())
        loss = loss_fn(pred, y) # calculate the loss
        train_running_loss += loss.item()

        # Calculate the accuracy
        # y_pred = (pred>0).type(torch.int8)
        y_pred = (pred>=0.5).type(torch.int8)
        train_running_correct += torch.eq(y_pred, y).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = train_running_loss / len(dataloader)
    epoch_acc = train_running_correct / size
    # epoch_acc is a fraction; scale it so the printed '%' value is a real percentage.
    print(f"Train: \n Accuracy: {100 * epoch_acc:.2f}%, Avg loss: {epoch_loss:>7f}")

    return epoch_loss, epoch_acc, y_list, pred_list

In [None]:
def valid_py(dataloader, model, loss_fn):
    """Evaluate the pyradiomics MLP on one pass over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches; must expose ``.dataset`` and ``len()``.
        model: called as ``model(X)``; a sigmoid is applied to its output here.
        loss_fn: called as ``loss_fn(sigmoid(output), y)``.

    Returns:
        ``(valid_loss, correct, y_list, pred_list)``: mean loss over batches, fraction
        of correct predictions over ``len(dataloader.dataset)`` (in [0, 1]), and flat
        lists of targets and predicted probabilities.
    """
    y_list, pred_list = [],[]
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    valid_loss, correct = 0, 0
    with torch.no_grad(): # no training
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = torch.sigmoid(model(X))
            y_list.extend(y.cpu().flatten().tolist())
            pred_list.extend(pred.cpu().flatten().tolist())
            valid_loss += loss_fn(pred, y).item()
            # correct += torch.eq((pred>0).type(torch.int8), y).sum().item()
            correct += torch.eq((pred>=0.5).type(torch.int8), y).sum().item()
    valid_loss /= num_batches
    correct /= size
    # `correct` is a fraction; scale it so the printed '%' value is a real percentage.
    print(f"Validation: \n Accuracy: {100 * correct:.2f}%, Avg loss: {valid_loss:>8f} \n")

    return valid_loss, correct, y_list, pred_list

#### Define function: ensemble(MAE + pyradiomics)

In [None]:
def train_en(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of the ensemble (MAE + pyradiomics) classifier.

    Args:
        dataloader: yields ``(batch_data, (X, y))`` pairs, where ``batch_data[0]`` is a
            dict with 'flair', 't1', 't1ce', 't2' tensors and ``(X, y)`` are the
            pyradiomics features and labels; must expose ``.dataset`` and ``len()``.
        model: called as ``model(X_mae, X_py)``; a sigmoid is applied to its output here.
        loss_fn: called as ``loss_fn(sigmoid(output), y)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(epoch_loss, epoch_acc, y_list, pred_list)``: mean loss over batches, fraction
        of correct predictions over ``len(dataloader.dataset)`` (in [0, 1]), and flat
        lists of targets and predicted probabilities.
    """
    model.train()
    size = len(dataloader.dataset)

    pred_list, y_list = [], []
    train_running_loss = 0
    train_running_correct = 0
    for batch_data, (X, y) in dataloader:
        # - batch_data: MRI images (4 modality)
        # - X, y: pyradiomics feature
        images = batch_data[0]
        # Concatenate the four modalities along dim 1
        X_mae = torch.cat(
            [images[name].to(device) for name in ('flair', 't1', 't1ce', 't2')], dim=1
        )

        X_py, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = torch.sigmoid(model(X_mae, X_py)) # forward pass
        y_list.extend(y.detach().cpu().flatten().tolist())
        pred_list.extend(pred.detach().cpu().flatten().tolist())

        loss = loss_fn(pred, y) # calculate the loss
        train_running_loss += loss.item()

        # Calculate the accuracy
        # y_pred = (pred>0).type(torch.int8)
        y_pred = (pred>=0.5).type(torch.int8)
        train_running_correct += torch.eq(y_pred, y).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = train_running_loss / len(dataloader)
    epoch_acc = train_running_correct / size
    # epoch_acc is a fraction; scale it so the printed '%' value is a real percentage.
    print(f"Train: \n Accuracy: {100 * epoch_acc:.2f}%, Avg loss: {epoch_loss:>7f}")

    return epoch_loss, epoch_acc, y_list, pred_list

In [None]:
def valid_en(dataloader, model, loss_fn):
    """Evaluate the ensemble (MAE + pyradiomics) model on ``dataloader`` without gradients.

    Every item yielded by ``dataloader`` is unpacked as ``batch_data, (X, y)``:
    ``batch_data[0]`` is indexed with the keys 'flair', 't1', 't1ce' and 't2'
    (MRI images, 4 modalities) and ``X`` / ``y`` come from the pyradiomics features.

    Args:
        dataloader: Iterable of ``(batch_data, (X, y))`` items exposing ``.dataset``.
        model: Module called as ``model(X_mae, X_py)``.
        loss_fn: Callable applied to the sigmoid of the model output and ``y``.

    Returns:
        ``(valid_loss, correct, y_list, pred_list)``: the summed batch losses
        divided by the number of batches, the number of matches at a 0.5 threshold
        divided by ``len(dataloader.dataset)``, and the per-sample targets and
        sigmoid outputs as Python numbers.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    sigmoid = nn.Sigmoid()
    modalities = ('flair', 't1', 't1ce', 't2')

    y_list, pred_list = [], []
    loss_total = 0
    hits = 0
    model.eval()
    with torch.no_grad():  # evaluation only, no gradient tracking
        for batch_data, (X, y) in dataloader:
            # Concatenate the four modalities along dim 1 to form the image input.
            images = batch_data[0]
            X_mae = torch.cat([images[name].to(device) for name in modalities], dim=1)

            X_py = X.to(device)
            y = y.to(device)

            pred = sigmoid(model(X_mae, X_py))
            y_list.extend([target.item() for target in y.cpu()])
            pred_list.extend([prob.item() for prob in pred.cpu()])

            loss_total += loss_fn(pred, y).item()
            # Threshold the sigmoid output at 0.5 and count the matches with y.
            hits += torch.eq((pred >= 0.5).type(torch.int8), y).sum().item()

    valid_loss = loss_total / num_batches
    correct = hits / size
    print(f"Validation: \n Accuracy: {correct:>0.5f}%, Avg loss: {valid_loss:>8f} \n")

    return valid_loss, correct, y_list, pred_list

#### Training

In [None]:
# Per-epoch history; the training loop appends one value per epoch to each list.
epoch_loss_values = []
val_loss_values = []
train_acc = []
val_acc = []

In [None]:
# state_dict = torch.load(os.path.join(root_dir, "best_model.pth"), map_location=device)
# model.load_state_dict(state_dict)

In [None]:
# Train for `epochs` epochs. After each epoch the weights are checkpointed whenever
# the validation loss is lower than in every earlier epoch, and the per-epoch
# loss/accuracy values are appended to the history lists.
best_model_path = os.path.join(root_dir, "best_model.pth")
last_model_path = os.path.join(root_dir, "last_model.pth")

for ep in range(epochs):
    print(f"Epoch {ep+1}\n-------------------------------")
    if mode == 'MAE':
        train_epoch_loss, train_epoch_acc, y_list, pred_list = train_MAE(train_loader, model, loss_function, optimizer)
        val_epoch_loss, val_epoch_acc, val_y_list, val_pred_list = valid_MAE(val_loader, model, loss_function)
    elif mode == 'pyradiomics':
        train_epoch_loss, train_epoch_acc, y_list, pred_list = train_py(train_loader_py, model, loss_function, optimizer)
        val_epoch_loss, val_epoch_acc, val_y_list, val_pred_list = valid_py(val_loader_py, model, loss_function)
    else:
        # Any other mode value is handled as the ensemble configuration.
        train_epoch_loss, train_epoch_acc, y_list, pred_list = train_en(stack_train_loader, model, loss_function, optimizer)
        val_epoch_loss, val_epoch_acc, val_y_list, val_pred_list = valid_en(stack_val_loader, model, loss_function)

    # val_loss_values holds only earlier epochs here (this epoch is appended below),
    # so this is true on the first epoch and on every strict improvement.
    if not val_loss_values or val_epoch_loss < min(val_loss_values):
        torch.save(model.state_dict(), best_model_path)
        # Report the file that was actually written (the old message said "model.pth").
        print(f"Saved Best Model State to {best_model_path}")

        # Keep the validation targets/predictions of the best epoch for the metric cells.
        best_y_list = copy.deepcopy(val_y_list)
        best_pred_list = copy.deepcopy(val_pred_list)

    epoch_loss_values.append(train_epoch_loss)
    train_acc.append(train_epoch_acc)
    val_loss_values.append(val_epoch_loss)
    val_acc.append(val_epoch_acc)

# Always keep the weights of the final epoch as well.
torch.save(model.state_dict(), last_model_path)
print(f"Saved Last Model State to {last_model_path}")
print("Train Done!")

In [None]:
# Indices into the history lists (0-based, one entry per epoch) of the lowest
# validation loss and the highest validation accuracy.
min_idx = torch.argmin(torch.tensor(val_loss_values))
max_idx = torch.argmax(torch.tensor(val_acc))

# Build each summary line once so the console and the file cannot drift apart.
# The second line reports val_acc, so it is labelled as accuracy (it used to say "loss").
min_loss_msg = f"Minimum validation loss is {val_loss_values[min_idx]:.5} in epoch {min_idx}"
max_acc_msg = f"Maximum validation accuracy is {val_acc[max_idx]:.8} in epoch {max_idx}"

print(min_loss_msg)
print(max_acc_msg)

# `f` is deliberately left open: later cells write more metrics to it and then close it.
f = open(os.path.join(root_dir, "result_summary.txt"), "w")
f.write(min_loss_msg + "\n")
f.write(max_acc_msg + "\n")

In [None]:
# Gather the per-epoch history into one table indexed by the 0-based epoch number
# and save it as CSV under root_dir.
df = pd.DataFrame(
    {
        'epoch': range(len(epoch_loss_values)),
        'train_loss': epoch_loss_values,
        'val_loss': val_loss_values,
        'train_acc': train_acc,
        'val_acc': val_acc,
    }
).set_index('epoch')
df.to_csv(os.path.join(root_dir, 'save_results.csv'))

In [None]:
# 2x2 summary figure: train loss, train vs. validation loss, validation accuracy
# as a line, and validation accuracy sampled into bars. Saved under root_dir.
metric_values = val_acc
fig = plt.figure("train/valid", (18, 12))

# Top-left: training loss per recorded epoch.
ax_train = fig.add_subplot(2, 2, 1)
ax_train.set_title('Iteration Train Loss')
ax_train.set_xlabel("Iteration")
ax_train.plot(epoch_loss_values, label='train loss')

# Top-right: training and validation loss on a shared 1-based x axis.
ax_losses = fig.add_subplot(2, 2, 2)
ax_losses.set_title("Iteration Average Loss")
n_val = len(val_loss_values)
x = list(range(1, n_val + 1))
y = [epoch_loss_values[k] for k in range(n_val)]
y2 = list(val_loss_values)
ax_losses.set_xlabel("Iteration")
ax_losses.plot(x, y, label='train loss')
ax_losses.plot(x, y2, label='validation loss')
ax_losses.legend()

# Bottom-left: validation accuracy as a line.
ax_metric = fig.add_subplot(2, 2, 3)
ax_metric.set_title("Val Metric")
x = list(range(1, len(metric_values) + 1))
y = metric_values
ax_metric.set_xlabel("Iteration")
ax_metric.plot(x, y)

# Bottom-right: validation accuracy sampled every `devide_num` entries
# (20% of the history length, at least 1) and drawn as bars.
ax_bars = fig.add_subplot(2, 2, 4)
ax_bars.set_title("Val Metric")
ax_bars.set_xlabel("Iteration")
devide_num = max(int(len(y) * 0.2), 1)
x2 = range(0, len(metric_values), devide_num)
y2 = [metric_values[k] for k in x2]
ax_bars.bar(x2, y2, color=palette[:len(x)], width=devide_num * 0.75)

plt.savefig(os.path.join(root_dir, f'results_graph_{expd}.png'), bbox_inches='tight')
plt.show()

### Performance Check

In [None]:
# Accuracy, Confusion matrix
from sklearn.metrics import roc_auc_score, RocCurveDisplay
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Turn the stored best-epoch validation scores into hard 0.0/1.0 labels at a 0.5 threshold.
best_pred_label = [1.0 if score >= 0.5 else 0.0 for score in best_pred_list]

In [None]:
# Validation AUC from the stored best-epoch scores (not the thresholded labels).
# Computed once and reused for both the console and the summary file.
val_auc = roc_auc_score(y_true=best_y_list, y_score=best_pred_list)
print(f"validation AUC :  {val_auc}")

# Terminate the line with "\n"; without it the next write to `f` is appended to the AUC line.
f.write(f"\nvalidation AUC :  {val_auc}\n")

In [None]:
# ROC curve of the stored best-epoch validation scores, with the diagonal of a
# random model drawn for reference; the figure is saved under root_dir.
_, ax = plt.subplots(figsize=(7, 7))
disp = RocCurveDisplay.from_predictions(
    y_true=best_y_list,
    y_pred=best_pred_list,
    ax=ax,
)
disp.ax_.set_title(f'ROC curve ({mode})')
ax.plot(
    [0, 1],
    [0, 1],
    color='#FE5A6D',
    label='Random Model',
)
ax.legend(loc='lower right')
plt.savefig(os.path.join(root_dir, 'ROC curve(Validation).png'))

In [None]:
cm = confusion_matrix(y_true=best_y_list, y_pred=best_pred_label)

# scikit-learn puts true labels on the rows and predicted labels on the columns, in
# sorted label order, so a binary matrix flattens to (TN, FP, FN, TP) with label 1 as
# the positive class -- the same convention roc_auc_score uses for the AUC above.
# The previous code had the two rows swapped, reporting TN/(TN+FP) as sensitivity.
# NOTE(review): if label 0 is the clinically positive class, swap the two names below.
tn, fp, fn, tp = cm.ravel()

# Sensitivity (true-positive rate): TP / (TP + FN)
sensitivity = tp / (tp + fn)
# Specificity (true-negative rate): TN / (TN + FP)
specificity = tn / (tn + fp)

print(f"validation Sensitivity :  {sensitivity}")
print(f"validation Specificity :  {specificity}")

f.write(f"validation Sensitivity :  {sensitivity}\n")
f.write(f"validation Specificity :  {specificity}\n")

In [None]:
# Confusion-matrix figure for the thresholded best-epoch validation predictions.
# Tick labels come from le.classes_ (presumably the fitted label encoder -- confirm).
_, ax = plt.subplots(figsize=(12, 10))
disp = ConfusionMatrixDisplay.from_predictions(
    y_true=best_y_list,
    y_pred=best_pred_label,
    display_labels=le.classes_,
    cmap=plt.cm.Blues,
    ax=ax,
)
disp.ax_.set_title(f'Confusion matrix ({mode})')
plt.savefig(os.path.join(root_dir, 'Confusion matrix(Validation).png'))

In [None]:
# Close result_summary.txt (opened in the summary cell) now that all metrics are written.
f.close()