In [1]:
%load_ext autoreload
%autoreload 2
### Set CUDA device
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

In [2]:
import sys, os
import pickle as pkl
from omegaconf import OmegaConf
import hydra
from time import time
sys.path.append('../')
from data_utils import get_data_module
from model.unet import get_unet_module
from model.score_adapter import (
    get_score_prediction_module, 
)

In [None]:

for dataset in ['mnmv2', 'pmri']:
    # configs
    unet_cfg = OmegaConf.load('../configs/unet/monai_unet.yaml')
    UNET_CKPTS = {
        "mnmv2": 'mnmv2_symphony_dropout-0-1_2025-01-14-15-19', 
        'pmri': 'pmri_runmc_dropout-0-1_2025-01-14-15-58',
    }

    if dataset == 'mnmv2':
        unet_cfg.out_channels = 4
        num_classes = 4
        data_cfg = OmegaConf.load('../configs/data/mnmv2.yaml')
        domain = 'Symphony'

    else:
        unet_cfg.out_channels = 1
        num_classes = 2
        data_cfg = OmegaConf.load('../configs/data/pmri.yaml')
        domain = 'RUNMC'
        sigma = 6.9899

# # dataset = 'mnmv2'
# unet_cfg.out_channels = 4
# num_classes = 4
# data_cfg = OmegaConf.load('../configs/data/mnmv2.yaml')
# data_cfg.domain = 'Symphony'
# data_cfg.non_empty_target = True
# unet_ckpt = UNET_CKPTS[data_cfg.dataset] #[cfg.data.domain]
# unet_cfg.checkpoint_path = f'../../pre-trained/monai-unets/{unet_ckpt}.ckpt'

    model_cfg = OmegaConf.load('../configs/model/score_predictor.yaml')
    model_cfg.num_classes = num_classes
    model_cfg.adversarial_training=True
    model_cfg.adversarial_prob = 0.5
    model_cfg.adversarial_step_size = 0.1
    model_cfg.loss_fn = 'dice'
    model_cfg.non_adversarial_target = False

    # init datamodule
    datamodule = get_data_module(
        cfg=data_cfg
    )
    ckpt = UNET_CKPTS[data_cfg.dataset]
        
    unet_cfg.checkpoint_path = f'../../{unet_cfg.checkpoint_dir}{ckpt}.ckpt'
    unet = get_unet_module(
        cfg=unet_cfg,
        metadata=OmegaConf.to_container(unet_cfg),
        load_from_checkpoint=True
    ).model

    model = get_score_prediction_module(
        data_cfg=data_cfg,
        model_cfg=model_cfg,
        unet=unet,
        metadata=OmegaConf.to_container(model_cfg), #TODO
        ckpt=None
    )

    datamodule.setup('fit')

    if dataset == 'mnmv2':
        data = datamodule.mnm_train

    else:
        data = datamodule.pmri_train

    batch_size = 15
    input = data[10:11]['input'].repeat(batch_size, 1, 1 ,1)

    start = time()
    for i in range(100):
        _ = model(input.cuda())

    time_taken = time() - start
    print(f'{time_taken / (100 * 15)} Seconds per image for forward passes before finetuning')
    print(f'{time_taken / (100)} Seconds per 15 images for forward passes before finetuning')

    start = time()
    for i in range(100):
        _ = unet(input.cuda())
        _ = model(input.cuda())

    time_taken = time() - start
    print(f'{time_taken / (100 * 15)} Seconds per image for forward passes after finetuning + unet')
    print(f'{time_taken / (100)} Seconds per 15 images for forward passes after finetuning + unet')

# ckpt = '../../pre-trained/score-predictors/mnmv2_Symphony_dice_adversarial_2025-01-15-11-29.ckpt'
# ckpt = '../../pre-trained/score-predictors/mnmv2_Symphony_dice_normal_2025-01-15-11-27.ckpt'

32768 8 64 64
4196936
32768 8 64 64
0.0009477243423461914 Seconds per image for forward passes before finetuning
0.014215865135192872 Seconds per 15 images for forward passes before finetuning
0.0018232946395874024 Seconds per image for forward passes after finetuning + unet
0.027349419593811035 Seconds per 15 images for forward passes after finetuning + unet
73728 8 96 96
9439816
73728 8 96 96
0.002008889039357503 Seconds per image for forward passes before finetuning
0.030133335590362548 Seconds per 15 images for forward passes before finetuning
0.003857900937398275 Seconds per image for forward passes after finetuning + unet
0.057868514060974124 Seconds per 15 images for forward passes after finetuning + unet


In [22]:

# init datamodule
datamodule = get_data_module(
    cfg=data_cfg
)

unet = get_unet_module(
    cfg=unet_cfg,
    metadata=OmegaConf.to_container(unet_cfg),
    load_from_checkpoint=True
).model

model = get_score_prediction_module(
    data_cfg=data_cfg,
    model_cfg=model_cfg,
    unet=unet,
    metadata=OmegaConf.to_container(model_cfg), #TODO
    ckpt=None
)

32768 8 64 64
4196936
32768 8 64 64


In [7]:
datamodule.setup('fit')

In [12]:
data = datamodule.mnm_train
batch_size = 15
input = data[10:11]['input'].repeat(batch_size, 1, 1 ,1)

In [20]:
start = time()
for i in range(100):
    _ = model(input.cuda()).cpu().detach().numpy()

end = time() - start

In [21]:
end / ( 100 * 15)

0.0009508248964945475