## Setup

### Environment

In [None]:
# Install MONAI (weekly build with the tqdm extra) only when `import monai` fails.
# NOTE(review): `python` and `pip` are resolved by the shell here, which may not
# be the kernel's environment — confirm, or consider `%pip` with pinned versions.
!python -c "import monai" || pip install -q "monai-weekly[tqdm]"
# Install matplotlib only when `import matplotlib` fails.
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Disabled pure-Python alternative to the shell/magic setup cell above: the
# whole snippet is wrapped in a string literal, so none of it is executed.
# It would install monai-weekly[tqdm] and matplotlib through `sys.executable -m pip`
# when the imports fail, and call plt.ion() in place of `%matplotlib inline`.
# NOTE(review): as the only expression in this cell, the string itself is
# presumably echoed as the cell output — consider deleting the cell.
'''
import sys
import subprocess

# monai 설치 확인 및 설치
try:
    import monai
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "monai-weekly[tqdm]"])

# matplotlib 설치 확인 및 설치
try:
    import matplotlib
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib"])

# matplotlib inline을 대체
import matplotlib.pyplot as plt
plt.ion()  # 인터랙티브 모드 활성화
'''

### Library

In [None]:
import os
import sys
import csv
import PIL
import copy
import torch
import random
import tempfile
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm

import keras
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

from scipy.stats import pearsonr

import monai
from monai.data import Dataset, CacheDataset, DataLoader
from monai.utils import first, set_determinism

from monai.transforms import (
    Compose,
    Lambdad,
    Resized,
    Randomizable,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    RepeatChanneld,
    Transposed
)

In [None]:
# Resolve the working directory: use MONAI_DATA_DIRECTORY when the variable is
# set, otherwise fall back to a freshly created temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

# Seed MONAI and Keras with the same value, then ask TensorFlow to run its
# ops deterministically.
set_determinism(42)
keras.utils.set_random_seed(42)
tf.config.experimental.enable_op_determinism()

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    print("We got a GPU")
    try:
        # Memory growth has to be configured identically on every visible GPU;
        # enabling it on the first device only fails on multi-GPU machines.
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        # Raised when the GPUs have already been initialised (for example when
        # this cell is re-run); the setting can no longer be changed then.
        print(f"Could not enable GPU memory growth: {err}")
else:
    print("Sorry, no GPU for you...")

## Camcan Dataset

In [None]:
class CamcanDataset(Randomizable, CacheDataset):
    """Cached dataset of 2D T1 slices paired with the subject's age.

    For every row of ``csv_file`` whose image
    ``<root_dir>/sub-<Subject>_defaced_T1.nii.gz`` exists, up to 40 slices
    around the centre of the third volume axis are extracted. Each sample is a
    dict ``{"image": 2D array, "age": float32 array of shape (1,)}``.

    The train/validation/test split is drawn per *subject*, not per slice, so
    that slices of one subject never end up in more than one section.
    Instances built with the same ``seed`` and fractions see the same subject
    permutation and therefore produce disjoint sections.

    Item access is inherited from ``CacheDataset`` so that the cache built in
    ``__init__`` is actually used instead of re-running the transform on the
    raw sample at every access.

    Args:
        root_dir: directory containing the NIfTI volumes.
        csv_file: CSV file with at least the columns ``Subject`` and ``Age``.
        section: one of ``"training"``, ``"validation"`` or ``"test"``.
        transform: transform applied to each sample dict.
        seed: seed for the random state that shuffles the subjects.
        val_frac: fraction of subjects assigned to the validation section.
        test_frac: fraction of subjects assigned to the test section.
        cache_num: forwarded to ``CacheDataset``.
        cache_rate: forwarded to ``CacheDataset``.
        num_workers: forwarded to ``CacheDataset``.
        progress: forwarded to ``CacheDataset``.
        condition_prob: stored on the instance; not read by this class.

    Raises:
        ValueError: if ``root_dir`` is not a directory or ``section`` is not
            one of the supported names.
    """

    def __init__(
        self,
        root_dir,
        csv_file,
        section,
        transform=None,
        seed=0,
        val_frac=0.2,
        test_frac=0.2,
        cache_num=sys.maxsize,
        cache_rate=1.0,
        num_workers=0,
        progress: bool = True,
        condition_prob = 0,
    ) -> None:
        if not os.path.isdir(root_dir):
            raise ValueError("Root directory root_dir must be a directory.")
        self.root_dir = root_dir
        self.csv_file = csv_file
        self.section = section
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.condition_prob = condition_prob
        self.set_random_state(seed=seed)

        data = self._generate_data_list()

        CacheDataset.__init__(
            self,
            data=data,
            transform=transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
            progress=progress,
        )

    def randomize(self, data: np.ndarray) -> None:
        """Shuffle ``data`` in place with this instance's random state."""
        self.R.shuffle(data)

    def _generate_data_list(self):
        """Load the slices of every available subject and return one section.

        Returns:
            list of sample dicts belonging to ``self.section``.
        """
        # Fail fast: reject a bad section name before reading any volume.
        if self.section not in ("training", "validation", "test"):
            raise ValueError(
                f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
            )

        # One entry per subject; each entry is the list of that subject's samples.
        subjects = []
        # newline='' is what the csv module expects for files it reads.
        with open(self.csv_file, mode='r', newline='') as file:
            for row in csv.DictReader(file):
                image_path = os.path.join(self.root_dir, f"sub-{row['Subject']}_defaced_T1.nii.gz")
                if not os.path.exists(image_path):
                    # Subjects without an image on disk are skipped on purpose.
                    continue
                volume = nib.load(image_path).get_fdata()
                age_value = float(row['Age'])

                # Central 40-slice window, clamped so that shallow volumes
                # neither wrap around to negative indices nor overrun the axis.
                depth = volume.shape[2]
                centre = depth // 2
                start = max(centre - 20, 0)
                stop = min(centre + 20, depth)

                subjects.append([
                    {
                        # Copy the slice: a view would keep the whole 3D
                        # volume alive for as long as the sample exists.
                        "image": volume[:, :, slice_idx].copy(),
                        "age": np.array([age_value], dtype='float32'),
                    }
                    for slice_idx in range(start, stop)
                ])

        # Shuffle and split at subject level to avoid leaking a subject's
        # slices across training, validation and test.
        n_subjects = len(subjects)
        order = np.arange(n_subjects)
        self.randomize(order)

        test_length = int(n_subjects * self.test_frac)
        val_length = int(n_subjects * self.val_frac)
        if self.section == "test":
            section_indices = order[:test_length]
        elif self.section == "validation":
            section_indices = order[test_length : test_length + val_length]
        else:  # "training"
            section_indices = order[test_length + val_length :]

        return [sample for i in section_indices for sample in subjects[i]]

## Load Dataset

### Train

In [None]:
# Dataset locations, relative to the notebook's working directory
data_dir = "./dataset_camcan_sy"
csv_file = "./phenotype.csv"

# Training preprocessing: add a channel axis, linearly map [0, 255] to [0, 1],
# turn the age into a float32 tensor, resize to 96x128, repeat the single
# channel three times and finally move the channel axis to the end.
train_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim='no_channel'),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=0.0,
        a_max=255.0,
        b_min=0.0,
        b_max=1.0,
    ),
    Lambdad(keys=["age"], func=lambda x: torch.tensor(x, dtype=torch.float32)),
    Resized(keys=["image"], spatial_size=(96,128)),
    RepeatChanneld(keys=["image"], repeats=3),  # (1, H, W) -> (3, H, W)
    Transposed(keys=["image"], indices=(1, 2, 0)),  # (3, H, W) -> (H, W, 3)
])

# Training dataset and its DataLoader
train_ds = CamcanDataset(
    root_dir=data_dir,
    csv_file=csv_file,
    section="training",
    transform=train_transforms,
    condition_prob=0.2,
)
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
)

### Validation & Test

In [None]:
# Preprocessing shared by the validation and test sections: add a channel
# axis, linearly map [0, 255] to [0, 1], turn the age into a float32 tensor,
# resize to 96x128, repeat the channel three times, move the channel axis last.
val_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim='no_channel'),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=0.0,
        a_max=255.0,
        b_min=0.0,
        b_max=1.0,
    ),
    Lambdad(keys=["age"], func=lambda x: torch.tensor(x, dtype=torch.float32)),
    Resized(keys=["image"], spatial_size=(96,128)),
    RepeatChanneld(keys=["image"], repeats=3),  # (1, H, W) -> (3, H, W)
    Transposed(keys=["image"], indices=(1, 2, 0)),  # (3, H, W) -> (H, W, 3)
])

# Validation dataset and loader (kept in dataset order)
val_ds = CamcanDataset(
    root_dir=data_dir,
    csv_file=csv_file,
    section="validation",
    transform=val_transforms,
)
val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=8,
    persistent_workers=True,
)

# Test dataset and loader (kept in dataset order)
test_ds = CamcanDataset(
    root_dir=data_dir,
    csv_file=csv_file,
    section="test",
    transform=val_transforms,
)
test_loader = DataLoader(
    test_ds,
    batch_size=8,
    shuffle=False,
    num_workers=8,
    persistent_workers=True,
)

## Fine-tuning Model

### PyTorch -> TensorFlow

In [None]:
def pytorch_to_numpy(data_loader):
    """Drain ``data_loader`` and return all images and ages as NumPy arrays.

    Args:
        data_loader: iterable of batches; each batch is a mapping whose
            ``"image"`` and ``"age"`` entries provide a ``.numpy()`` method.

    Returns:
        Tuple ``(images, ages)``: the per-batch arrays concatenated along the
        first axis, in iteration order.

    Raises:
        ValueError: from ``np.concatenate`` when the loader yields no batch.
    """
    image_chunks = []
    age_chunks = []

    for batch in data_loader:
        # Work on an independent deep copy and drop the loader's batch right
        # away. NOTE(review): presumably done so that worker-owned tensors
        # are released early — confirm.
        detached = copy.deepcopy(batch)
        del batch

        image_chunks.append(detached["image"].numpy())
        age_chunks.append(detached["age"].numpy())

    return np.concatenate(image_chunks), np.concatenate(age_chunks)

# Convert the PyTorch DataLoaders into in-memory NumPy arrays
train_images, train_ages = pytorch_to_numpy(data_loader=train_loader)
val_images, val_ages = pytorch_to_numpy(data_loader=val_loader)

### Load Pretrained Model

In [None]:
# Load the pretrained model without restoring its compile state
lr = 1e-4
checkpoint_path = './DeepBrainNet/Models/DBN_model.h5'
model = load_model(checkpoint_path, compile=False)

# Optimizer and loss: Adam on mean squared error, reporting mean absolute error
model.compile(
    optimizer=Adam(learning_rate=lr),
    loss='mse',
    metrics=['mae'],
)

# Callbacks: scale the learning rate by 0.1 after 5 epochs without improvement
# of the training loss, and write a checkpoint whenever val_loss improves.
reduce_lr = ReduceLROnPlateau(
    monitor='loss',
    factor=0.1,
    patience=5,
    verbose=1,
)
#early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)
#early_stopping_loss = EarlyStopping(monitor='loss', patience=10, restore_best_weights=True, verbose=1)
checkpoint = ModelCheckpoint(
    "./DBN_finetuned/best_{epoch}.h5",
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)

#model.summary()

### Train Model

In [None]:
epochs = 100

# Fine-tune on the in-memory NumPy arrays, validating after every epoch
history = model.fit(
    train_images,
    train_ages,
    validation_data=(val_images, val_ages),
    epochs=epochs,
    batch_size=8,
    callbacks=[reduce_lr, checkpoint],
    verbose=1,  # 1: detailed log, 0: silent, 2: one summary line per epoch
)

# Save the model as it stands after the last epoch
model.save("./DBN_finetuned/last.h5")