In [1]:
import os

from monai.data   import DataLoader, Dataset, CacheDataset
from monai.config import print_config, USE_COMPILED
from monai.utils import set_determinism, first
from monai.networks.nets import GlobalNet
from monai.networks.blocks import Warp
from monai.apps import MedNISTDataset
import matplotlib.pyplot as plt
from torch.nn import MSELoss
from monai import transforms
import numpy as np
import torch

# Log library versions for provenance, then pin MONAI's global random state
# so the random augmentations and training below are repeatable.
print_config()
set_determinism(seed=99)

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


MONAI version: 1.2.0
Numpy version: 1.25.2
Pytorch version: 1.13.1.post200
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: c33f1ba588ee00229a309000e888f9817b4f1934
MONAI __file__: /opt/conda/envs/ml/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.12
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.21.0
Pillow version: 9.4.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.14.1
tqdm version: 4.66.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.0.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: 4.32.0
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://doc

In [2]:
# The repository root sits ROOT_DEPTH directory levels above the notebook's
# working directory; MedNIST is stored (and downloaded) under <root>/Datasets.
ROOT_DEPTH = 4
SEP        = os.path.sep

ROOT_PATH = os.getcwd()
for _ in range(ROOT_DEPTH):
    _parent = os.path.dirname(ROOT_PATH)
    if _parent == ROOT_PATH:
        # Not enough ancestors: fail loudly instead of silently resolving to the
        # filesystem root and downloading the dataset somewhere unexpected.
        raise RuntimeError(
            f'cannot go {ROOT_DEPTH} directory levels up from {os.getcwd()!r}; '
            'run this notebook from its original location'
        )
    ROOT_PATH = _parent

# os.path.join keeps separators consistent with the host OS.
DATA_PATH = os.path.join(ROOT_PATH, 'Datasets', 'MedNIST')

os.makedirs(DATA_PATH, exist_ok=True)

In [3]:
train_data = MedNISTDataset(
    DATA_PATH,
    section='training',
    download=True,
    transform=None,
)

# Image registration on the hand X-ray subset. Label 4 is the "Hand" class
# (the sample output below lists only .../MedNIST/Hand/... files).
# Each record pairs an image with itself; only the moving copy is randomly
# rotated/zoomed by the transform pipeline, giving a misalignment to recover.
HAND_LABEL = 4
train_datadict = [
    dict(fixed_hand=record['image'], moving_hand=record['image'])
    for record in train_data.data
    if record['label'] == HAND_LABEL
]

print(f'sample datas \n{train_datadict[:3]}')

2023-08-28 02:40:03,833 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-08-28 02:40:03,834 - INFO - File exists: /home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST.tar.gz, skipped downloading.
2023-08-28 02:40:03,835 - INFO - Non-empty folder exists in /home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST, skipped extracting.


Loading dataset: 100%|██████████| 47164/47164 [00:00<00:00, 147409.06it/s]

sample datas 
[{'fixed_hand': '/home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST/Hand/005758.jpeg', 'moving_hand': '/home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST/Hand/005758.jpeg'}, {'fixed_hand': '/home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST/Hand/007758.jpeg', 'moving_hand': '/home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST/Hand/007758.jpeg'}, {'fixed_hand': '/home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST/Hand/001798.jpeg', 'moving_hand': '/home/jovyan/dove/utils/TIL/Datasets/MedNIST/MedNIST/Hand/001798.jpeg'}]





In [4]:
# Both images are loaded and normalised identically; only the moving image is
# then randomly rotated and zoomed (prob=1.0, so on every sample).
train_transforms = transforms.Compose([
    # Read both images and put the channel axis first.
    transforms.LoadImageD(keys=['fixed_hand', 'moving_hand']),
    transforms.EnsureChannelFirstD(keys=['fixed_hand', 'moving_hand']),
    # Linearly map intensities [0, 255] -> [0, 1], clipping values outside.
    transforms.ScaleIntensityRangeD(
        keys=['fixed_hand', 'moving_hand'],
        a_min=0.0, a_max=255.0,
        b_min=0.0, b_max=1.0,
        clip=True,
    ),
    # Random in-plane rotation bounded by pi/4, output kept at input size.
    transforms.RandRotateD(
        keys=['moving_hand'],
        range_x=np.pi / 4,
        prob=1.0,
        keep_size=True,
        mode='bicubic',
    ),
    # Random zoom factor in [0.9, 1.1].
    transforms.RandZoomD(
        keys=['moving_hand'],
        min_zoom=0.9, max_zoom=1.1,
        prob=1.0,
        mode='bicubic',
        align_corners=False,
    ),
])



In [6]:
# Sanity check: run one sample through the full pipeline (no caching) and
# confirm the fixed and moving images come out with matching spatial shapes.
check_ds     = Dataset(data=train_datadict, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1, shuffle=True)
check_data   = first(check_loader)

# First [0] selects the single batch item, second [0] the single channel,
# leaving a 2-D image.
fixed_image  = check_data['fixed_hand'][0][0]
moving_image = check_data['moving_hand'][0][0]

print(f'moving image shape : {moving_image.shape}')
print(f'fixed  image shape : {fixed_image.shape}')

moving image shape : torch.Size([64, 64])
fixed  image shape : torch.Size([64, 64])


In [8]:
# Train on a fixed subset so caching and each epoch stay short.
# CacheDataset precomputes transform results up front (the "Loading dataset"
# progress bar); the random transforms still run on every fetch.
N_TRAIN      = 1000
train_ds     = CacheDataset(data=train_datadict[:N_TRAIN], transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)

Loading dataset: 100%|██████████| 1000/1000 [00:01<00:00, 886.93it/s]


In [9]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# The network receives the moving and fixed images stacked on the channel
# axis (hence in_channels=2); its output is used as the displacement field
# for warp_layer in the training loop.
model = GlobalNet(
    image_size=(64, 64),
    spatial_dims=2,
    in_channels=2,
    num_channel_initial=16,
    depth=3,
).to(device)

image_loss = MSELoss()

# The compiled Warp backend expects an integer interpolation order (3), the
# pure-PyTorch backend a mode name; both use 'border' padding.
warp_layer = Warp(3 if USE_COMPILED else 'bilinear', 'border').to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)



In [11]:
epochs = 200
losses = []  # mean per-batch MSE for each epoch

for epoch in range(1, epochs + 1):
    print(f'[{epoch} / {epochs}]')
    model.train()

    running_total = 0
    step = 0
    for batch in train_loader:
        step += 1
        optimizer.zero_grad()

        moving = batch['moving_hand'].to(device)
        fixed  = batch['fixed_hand'].to(device)

        # Predict a displacement field from the stacked pair, then resample the
        # moving image with it so it can be compared against the fixed image.
        ddf        = model(torch.cat((moving, fixed), dim=1))
        pred_image = warp_layer(moving, ddf)

        loss_ = image_loss(pred_image, fixed)
        loss_.backward()
        optimizer.step()

        running_total += loss_.item()

    # Average of per-batch losses over this epoch.
    loss = running_total / step
    losses.append(loss)
    print(f'epoch {epoch} avg loss : {loss:.3f}')

[1 / 200]
epoch 1 avg loss : 0.052
[2 / 200]
epoch 2 avg loss : 0.047
[3 / 200]
epoch 3 avg loss : 0.046
[4 / 200]
epoch 4 avg loss : 0.044
[5 / 200]
epoch 5 avg loss : 0.042
[6 / 200]
epoch 6 avg loss : 0.040
[7 / 200]
epoch 7 avg loss : 0.040
[8 / 200]
epoch 8 avg loss : 0.037
[9 / 200]
epoch 9 avg loss : 0.036
[10 / 200]
epoch 10 avg loss : 0.036
[11 / 200]
epoch 11 avg loss : 0.034
[12 / 200]
epoch 12 avg loss : 0.033
[13 / 200]
epoch 13 avg loss : 0.033
[14 / 200]
epoch 14 avg loss : 0.031
[15 / 200]
epoch 15 avg loss : 0.029
[16 / 200]
epoch 16 avg loss : 0.029
[17 / 200]
epoch 17 avg loss : 0.028
[18 / 200]
epoch 18 avg loss : 0.027
[19 / 200]
epoch 19 avg loss : 0.027
[20 / 200]
epoch 20 avg loss : 0.025
[21 / 200]
epoch 21 avg loss : 0.026
[22 / 200]
epoch 22 avg loss : 0.026
[23 / 200]
epoch 23 avg loss : 0.024
[24 / 200]
epoch 24 avg loss : 0.024
[25 / 200]
epoch 25 avg loss : 0.023
[26 / 200]
epoch 26 avg loss : 0.023
[27 / 200]
epoch 27 avg loss : 0.022
[28 / 200]
epoch 28