In [1]:
import os
import torch
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import nibabel as nib
from sklearn.model_selection import KFold

from monai.utils import first
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference

from monai.config import print_config
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.networks.layers import Norm

from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandSpatialCropd,
    SpatialPadd,
    ScaleIntensityRanged,
    Spacingd,
    RandAffined,
    RandGaussianSmoothd,
    RandGaussianNoised,
)

from monai.data import (
    DataLoader,
    Dataset,
    decollate_batch,
)

# For descriptive error messages: ask CUDA to run kernel launches synchronously so an
# error is reported at the call that caused it. NOTE(review): this is a debugging aid
# and is expected to slow training — consider unsetting it for full runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Print the MONAI / dependency versions (captured in the cell output below).
print_config()
# Device every model/tensor in this notebook is moved to.
# NOTE(review): hard-coded to CUDA, there is no CPU fallback.
device = torch.device('cuda')
# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.2.1+cu118
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /u/home/s/<username>/miniconda3/envs/kaggle/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.11.4
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.17.1+cu118
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.1
einops version: 0.7.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNK

In [2]:
# Release GPU memory cached by PyTorch's allocator (useful when re-running the notebook in the same kernel).
torch.cuda.empty_cache()

In [3]:
# Work from the project directory: later cells use relative paths
# (e.g. '../ai_contest2024/sample_submission.csv', saved checkpoints, 'submission.csv').
# NOTE(review): absolute, user-specific path — consider making this configurable.
os.chdir('/u/home/s/skikuchi/scratch/MedAI6/Unet')

In [4]:
# Dataset root; the cells below read imagesTr/, labelsTr/, imagesTs/ and sample_submission.csv from it.
data_dir = '/u/home/s/skikuchi/scratch/MedAI6/ai_contest2024'
# NOTE(review): 'aolta' looks like a typo for 'aorta' (the class_names mapping used for the
# submission spells it 'aorta'). Left unchanged in case other code relies on this spelling — confirm.
organs = ['gallbladder','liver','pancreas','spleen','kidney_left','kidney_right','adrenal_gland_left','adrenal_gland_right','aolta','stomach','duodenum']

train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii")))
test_images = sorted(glob.glob(os.path.join(data_dir, "imagesTs", "*.nii")))

# Images and labels are paired purely by sorted order; zip() would silently truncate on a
# count mismatch, so fail loudly instead of training on mis-paired or missing cases.
if not train_images or len(train_images) != len(train_labels):
    raise RuntimeError(
        f"Expected matching, non-empty image/label sets under {data_dir}: "
        f"found {len(train_images)} images and {len(train_labels)} labels"
    )

data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]

# Read the sample submission from data_dir directly rather than via a path relative to the
# current working directory, so this cell no longer depends on the os.chdir cell.
sub_df = pd.read_csv(os.path.join(data_dir, 'sample_submission.csv'))
print('num of train images: ',len(train_images),'\nnum of train labels: ',len(train_labels),'\nnum of test images: ',len(test_images))


print('data path\n',train_images[0],'\n',train_labels[0])
sub_df.head()

num of train images:  357 
num of train labels:  357 
num of test images:  600
data path
 /u/home/s/skikuchi/scratch/MedAI6/ai_contest2024/imagesTr/aicontest2024ver2_0002_0000.nii 
 /u/home/s/skikuchi/scratch/MedAI6/ai_contest2024/labelsTr/aicontest2024ver2_0002.nii


Unnamed: 0,id,prediction
0,0486_gallbladder,1 1
1,0486_liver,1 1
2,0486_pancreas,1 1
3,0486_spleen,1 1
4,0486_kidney_left,1 1


In [5]:
def apply_window(image, level, width):
    """Clip intensities to the window centred on ``level`` with total width ``width``.

    Values below ``level - width / 2`` or above ``level + width / 2`` are
    saturated to those bounds with ``np.clip``; values inside the window are
    returned unchanged.
    """
    half_width = width / 2
    return np.clip(image, level - half_width, level + half_width)

In [6]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# 2D albumentations pipelines (not referenced by the MONAI training code in the visible cells).
# The original cell had its second half indented at top level, which is an IndentationError;
# everything is now at module level.
# NOTE(review): `cv2` and `Cut2DFrom3D` are neither imported nor defined anywhere in the
# visible notebook, so this cell still needs them provided before it can run.
train_aug_list = [
    A.Normalize(mean=0, std=1, max_pixel_value=255, always_apply=True),

    # Random geometric jitter: scale, translation, rotation and shear.
    A.Affine(scale={"x":(0.7, 1.3), "y":(0.7, 1.3)}, translate_percent={"x":(0, 0.1), "y":(0, 0.1)}, rotate=(-30, 30), shear=(-20, 20), p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4, p=0.5),
    # One of two small-kernel blurs.
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.MedianBlur(blur_limit=3, p=0.2),
    ], p=1.0),
    # One of two non-rigid distortions.
    A.OneOf([
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=10, border_mode=1, p=0.5),
        A.GridDistortion(num_steps=5, distort_limit=0.1, border_mode=1, p=0.5)
    ], p=0.4),
    # Either resize to 128x128, or pad to at least 512x512 and take a random 512x512 crop.
    # NOTE(review): the two branches produce different output sizes — confirm this is intended.
    A.OneOf([
        A.Resize(128, 128, cv2.INTER_LINEAR, p=1),
        A.Compose([
            A.PadIfNeeded(512, 512, position="random", border_mode=cv2.BORDER_REPLICATE, p=1.0),
            A.RandomCrop(512, 512, p=1.0)
        ], p=1.0),
    ], p=1.0),
    A.GaussNoise(var_limit=0.05, p=0.2),

    ToTensorV2(transpose_mask=True),
]
train_aug = A.Compose(train_aug_list)

valid_aug_list = [
    Cut2DFrom3D(p=1.0),
    A.Resize(512, 512, cv2.INTER_LINEAR, p=1),
    ToTensorV2(transpose_mask=True),
]
valid_aug = A.Compose(valid_aug_list)



In [7]:
def _shared_preprocessing():
    """Build fresh instances of the deterministic steps common to training and validation."""
    both = ["image", "label"]
    return [
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        # Reorient every volume to RAS axis codes.
        Orientationd(keys=both, axcodes="RAS"),
        # Resample to 1.0 x 1.0 x 1.0 pixdim: bilinear for the image, nearest for the label.
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        # Map image intensities from [-175, 250] to [0, 1], clipping values outside the range.
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        # Crop image and label to the foreground of the image.
        CropForegroundd(keys=both, source_key="image"),
    ]


# Training pipeline: shared steps, then pad up to at least 128^3 and take one
# fixed-size random 128^3 crop per case.
orig_train_transforms = Compose(
    _shared_preprocessing()
    + [
        SpatialPadd(keys=["image", "label"], spatial_size=(128, 128, 128)),
        RandSpatialCropd(keys=["image", "label"], roi_size=(128, 128, 128), random_size=False),
    ]
)

# Validation pipeline: shared steps only (whole cropped volume, no padding or random crop).
orig_val_transforms = Compose(_shared_preprocessing())

In [13]:
# Hold-out split: take only the first of five unshuffled K-fold splits.
kf = KFold(n_splits=5)
train_index, val_index = next(iter(kf.split(data_dicts)))
_all_cases = np.array(data_dicts)
train_files = _all_cases[train_index]
val_files = _all_cases[val_index]

# Training draws shuffled batches of four random crops; validation feeds one whole volume at a time.
train_ds = Dataset(data=train_files, transform=orig_train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

val_ds = Dataset(data=val_files, transform=orig_val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=1,
    pin_memory=True,
)

# 3D residual U-Net: single input channel, 12 output channels (background + 11 organ labels).
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=12,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = model.to(device)
# To resume from an earlier checkpoint:
#model.load_state_dict(torch.load("best_metric_model2.pth"))

# Dice + cross-entropy on softmaxed logits against one-hot-encoded integer labels.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Gradient scaler for the mixed-precision training loop.
scaler = torch.cuda.amp.GradScaler()
# Mean Dice with the background channel excluded.
dice_metric = DiceMetric(include_background=False, reduction="mean")

In [14]:
import time

max_epochs = 25
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# Post-processing for the Dice metric: argmax + one-hot for the logits, one-hot for the
# integer labels (12 channels each).
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=12)])
post_label = Compose([AsDiscrete(to_onehot=12)])

# Per-step wall-clock timings: T1 = batch transfer to the device, T2 = forward pass + loss,
# T3 = backward pass + optimizer step (+ progress print).
T1, T2, T3 = [], [], []

for epoch in range(max_epochs):
    # The original printed `time.time() - t` with `t` never defined (NameError at the end of
    # the first epoch). Presumably per-epoch wall time was intended, so it is measured here.
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        t1 = time.time()
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        t2 = time.time()
        T1.append(t2 - t1)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        t3 = time.time()
        T2.append(t3 - t2)
        scaler.scale(loss).backward()
        # No gradient clipping/inspection happens here, so an explicit scaler.unscale_() is
        # unnecessary: scaler.step() unscales the gradients itself.
        scaler.step(optimizer)
        scaler.update()
        step_loss = loss.item()
        epoch_loss += step_loss
        # len(train_loader) counts the final partial batch; floor division of the dataset
        # size by the batch size under-reported the number of steps.
        print(f"{step}/{len(train_loader)}, train_loss: {step_loss:.4f}")
        t4 = time.time()
        T3.append(t4 - t3)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                # Validation volumes are not cropped to a fixed size, so predict them with
                # overlapping 128^3 windows.
                roi_size = (128, 128, 128)
                sw_batch_size = 4
                with torch.cuda.amp.autocast():
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            print(metric)
            metric_values.append(metric)
            # Checkpoints are only written from the 6th epoch on (epoch index > 4).
            # NOTE(review): if training stops earlier, "test_best_metric_model.pth" is never
            # created and the later cells that load it will fail.
            if (metric > best_metric) and (epoch > 4):
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "test_best_metric_model.pth")
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
    print(f"time: {time.time()-epoch_start}")

----------
epoch 1/25
1/71, train_loss: 3.7248
2/71, train_loss: 3.7083
3/71, train_loss: 3.6746
4/71, train_loss: 3.6344
5/71, train_loss: 3.6101
6/71, train_loss: 3.5711
7/71, train_loss: 3.5570
8/71, train_loss: 3.5311
9/71, train_loss: 3.5016
10/71, train_loss: 3.4844
11/71, train_loss: 3.4497
12/71, train_loss: 3.4234


KeyboardInterrupt: 

In [15]:
# Display the per-step timings collected by the training loop, in seconds
# (T1: batch transfer to device, T2: forward pass + loss, T3: backward pass + optimizer step + progress print).
T1, T2, T3

([0.005949497222900391,
  0.0057027339935302734,
  0.00577998161315918,
  0.005666017532348633,
  0.005826711654663086,
  0.005689859390258789,
  0.005859851837158203,
  0.005692243576049805,
  0.005830526351928711,
  0.005636453628540039,
  0.005742073059082031,
  0.005641460418701172],
 [0.06929659843444824,
  0.06192278861999512,
  0.06008648872375488,
  0.060167789459228516,
  0.060628652572631836,
  0.0596613883972168,
  0.06324982643127441,
  0.06025385856628418,
  0.06132364273071289,
  0.05964016914367676,
  0.05991196632385254,
  0.05955195426940918],
 [0.0965573787689209,
  0.0877687931060791,
  0.0875389575958252,
  0.08627438545227051,
  0.08630633354187012,
  0.08518815040588379,
  0.0873258113861084,
  0.09030985832214355,
  0.08620476722717285,
  0.0850381851196289,
  0.0859229564666748,
  0.085601806640625])

In [12]:
# Wall-clock cost of an empty 1,000,000-iteration Python loop
# (presumably a baseline for judging the step timings above).
# NOTE(review): relies on `time` having been imported by the training cell, and this cell's
# execution count (12) predates that cell's (14) — it will not survive Restart & Run All in this position.
a = time.time()
for i in range(1000000):
    pass
time.time() - a

0.034300804138183594

In [None]:
# Qualitative check of the saved checkpoint: for the first six validation cases, show the
# image, the ground-truth label and the predicted label map side by side.
model.load_state_dict(torch.load("test_best_metric_model.pth"))
model.eval()
with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        roi_size = (128, 128, 128)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data["image"].to(device), roi_size, sw_batch_size, model)
        predicted_labels = torch.argmax(val_outputs, dim=1).detach().cpu()

        # Each panel shows the slice 20 from the end along the last axis.
        panels = (
            (f"image {i}", val_data["image"][0, 0, :, :, -20], {"cmap": "gray"}),
            (f"label {i}", val_data["label"][0, 0, :, :, -20], {}),
            (f"output {i}", predicted_labels[0, :, :, -20], {}),
        )
        plt.figure("check", (18, 6))
        for position, (panel_title, slice_2d, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.title(panel_title)
            plt.imshow(slice_2d, **imshow_kwargs)
        plt.show()

        # Stop after six cases (i = 0..5).
        if i == 5:
            break

In [None]:
import gc

# Drop the training/validation datasets and loaders so their memory can be reclaimed.
# NOTE(review): after this cell these names no longer exist; re-running this cell, or any
# earlier cell that uses them, raises NameError until the setup cell is run again.
del train_ds, val_ds
del train_loader, val_loader

# Run garbage collection
gc.collect()

In [None]:
# Label value -> organ name, numbered from 1 (label 0 is treated as background by the encoder below).
class_names = dict(enumerate(
    [
        'gallbladder',
        'liver',
        'pancreas',
        'spleen',
        'kidney_left',
        'kidney_right',
        'adrenal_gland_left',
        'adrenal_gland_right',
        'aorta',
        'stomach',
        'duodenum',
    ],
    start=1,
))
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns the runs as a space-separated string of alternating
    "start length" values, where start is the 1-based position in the
    flattened array; an all-background mask gives an empty string.
    '''
    flat = img.flatten()
    # Zero sentinels on both ends so a run touching either border still has a transition.
    padded = np.concatenate([[0], flat, [0]])
    value_changes = padded[1:] != padded[:-1]
    runs = np.where(value_changes)[0] + 1
    # Even slots are run starts, odd slots are run ends: turn each end into a length.
    runs[1::2] = runs[1::2] - runs[::2]
    return ' '.join(map(str, runs))


def encode_rle_multiclass(img_data):
    """Run-length encode every organ class of a label map.

    Returns a list of ``(class_name, rle_string)`` tuples, one per entry of
    ``class_names`` and in that order. Classes that do not occur in
    ``img_data`` keep the placeholder encoding "1 1".
    """
    # Start every known class at the placeholder encoding.
    rle_by_class = dict.fromkeys(class_names, "1 1")

    # Encode the binary mask of each label value present in the data, skipping background (0).
    for label in np.unique(img_data):
        if label != 0:
            binary_mask = (img_data == label).astype(np.float32)
            rle_by_class[label] = rle_encode(binary_mask)

    return [(name, rle_by_class[label]) for label, name in class_names.items()]

In [None]:
test_images = sorted(glob.glob(os.path.join(data_dir, "imagesTs", "*.nii")))
# Case id = second-to-last '_'-separated token of the file name
# (a name shaped like "<prefix>_<id>_0000.nii" yields "<id>"). Parsing the basename rather than
# the full path keeps underscores in directory names from affecting the result.
test_ids = [os.path.basename(path).split('_')[-2] for path in test_images]
test_data = [{"image": image, "id": ID} for image, ID in zip(test_images, test_ids)]
# Collapse the class channel to a single integer label map (no one-hot, channel dim dropped).
post_pred = Compose([AsDiscrete(argmax=True, keepdim=False)])

# NOTE(review): predictions are RLE-encoded on the reoriented (RAS) / resampled (pixdim 1.0)
# grid produced by these transforms and are never mapped back to the original image grid.
# If the submission is scored on the original grid, the transforms need to be inverted first — confirm.
test_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(
            keys=["image"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

test_ds = Dataset(data=test_data, transform=test_transforms)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4)

# Load the checkpoint written by the training loop; change the path to use a different model.
model.load_state_dict(torch.load("test_best_metric_model.pth", map_location=device))
model.eval()

roi_size = (128, 128, 128)
sw_batch_size = 16

# Collect (id, prediction) rows in a list and build the DataFrame once at the end;
# growing a DataFrame with pd.concat inside the loop re-copies every previous row each time.
submission_rows = []

with torch.no_grad():
    # The batch gets its own name so it no longer overwrites the `test_data` list above.
    for test_batch in tqdm(test_loader):
        test_inputs = test_batch["image"].to(device)
        with torch.cuda.amp.autocast():
            test_outputs = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
        test_outputs = [post_pred(i) for i in decollate_batch(test_outputs)]

        # Pair each prediction with its own case id (also correct for batch sizes above 1).
        for file_id, output in zip(test_batch["id"], test_outputs):
            # Convert the prediction to a NumPy array
            output_np = output.cpu().numpy()
            # RLE-encode every class and queue one submission row per class
            for cls_name, rle in encode_rle_multiclass(output_np):
                submission_rows.append((f'{str(file_id)}_{cls_name}', rle))

submission_df = pd.DataFrame(submission_rows, columns=['id', 'prediction'])
# Save to a CSV file
submission_df.to_csv('submission.csv', index=False)