In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from help_func import print_var_detail, create_path
from utils.radimgnet_loader_ipt import create_radimgnet_dataloader_multi
from utils.data_transform import DataTransform
from utils.sample_mask import RandomMask, EquiSpaceMask, RandomMaskGaussian1D, RandomMaskGaussian

In [None]:
# ====== Construct dataset ======

DATASET_PATH = 'PATH_TO_DATASET'
# set the input height and width
INPUT_HEIGHT = 224
INPUT_WIDTH = 224
# set the batch size and validation data split
VAL_SPLIT = 0.1
SEED = None  # forwarded as `seed` to every mask below
BATCH_SIZE = 1

# ====== Undersampling levels ======
# Level i pairs ACCELERATIONS[i] with CENTER_FRACTIONS[i]; every mask type below
# gets one DataTransform per level, in this order.
ACCELERATIONS = [2.0, 4.0, 6.0, 8.0, 10.0]
CENTER_FRACTIONS = [0.1, 0.08, 0.06, 0.04, 0.02]
assert len(ACCELERATIONS) == len(CENTER_FRACTIONS), 'each acceleration needs exactly one center fraction'

# Covariance passed to the 2D Gaussian mask (same value for every level).
GAUSSIAN_2D_COV = [[1.5, 0], [0, 1.5]]


def build_level_transforms(mask_cls, center_fraction_fn=None, **mask_kwargs):
    """Return one DataTransform per undersampling level for ``mask_cls``.

    Args:
        mask_cls: mask class, called as
            ``mask_cls(center_fraction=..., acceleration=..., size=..., seed=..., **mask_kwargs)``.
        center_fraction_fn: optional callable applied to each entry of
            CENTER_FRACTIONS before it is passed to ``mask_cls``; None keeps it unchanged.
        **mask_kwargs: extra keyword arguments forwarded unchanged to every ``mask_cls`` call.

    Returns:
        list of DataTransform with ``len(ACCELERATIONS)`` entries, ordered by level.
    """
    transforms = []
    for acceleration, center_fraction in zip(ACCELERATIONS, CENTER_FRACTIONS):
        if center_fraction_fn is not None:
            center_fraction = center_fraction_fn(center_fraction)
        mask = mask_cls(
            center_fraction=center_fraction,
            acceleration=acceleration,
            size=[1, INPUT_HEIGHT, INPUT_WIDTH],  # fresh list per mask, as in the original cell
            seed=SEED,
            **mask_kwargs,
        )
        transforms.append(DataTransform(mask_func=mask))
    return transforms


# Rows = mask types, columns = undersampling levels.
func_list = [
    build_level_transforms(RandomMask),            # Cartesian random
    build_level_transforms(EquiSpaceMask),         # Cartesian equispaced
    build_level_transforms(RandomMaskGaussian1D),  # 1D Gaussian
    # 2D Gaussian: the center fraction is square-rooted before use. Presumably this turns an
    # area fraction into a per-axis fraction — TODO confirm against utils.sample_mask.
    build_level_transforms(RandomMaskGaussian, center_fraction_fn=lambda frac: frac ** 0.5, cov=GAUSSIAN_2D_COV),
]

# One scale factor per (type, level); derived from func_list so the two shapes cannot drift apart.
scales = [[1] * len(level_transforms) for level_transforms in func_list]

In [None]:
# Arguments shared by the train and test loaders. Keeping data_dir / random_seed / val_split
# in one place means both loaders are always configured from identical split settings.
# NOTE(review): presumably create_radimgnet_dataloader_multi derives the train/val partition
# from random_seed + val_split, so they must match between the two calls — confirm.
_shared_loader_kwargs = dict(
    data_dir=DATASET_PATH,
    random_seed=0,
    val_split=VAL_SPLIT,
    image_size=(INPUT_HEIGHT, INPUT_WIDTH),
    is_distributed=False,
    scales=scales,
    func_list=func_list,
    num_workers=0,
    fix_scale_idx=None,
)

# Training split uses the configured BATCH_SIZE; evaluation runs one sample at a time.
dataloader_train = create_radimgnet_dataloader_multi(is_train=True, batch_size=BATCH_SIZE, **_shared_loader_kwargs)
dataloader_test = create_radimgnet_dataloader_multi(is_train=False, batch_size=1, **_shared_loader_kwargs)

In [None]:
# ====== Construct model ======
from modeling.image_encoder import ImageEncoderViT
from modeling.prompt_encoder import PromptEncoderMulti
from modeling.image_decoder import ImageDecoderMulti
from modeling.transformer import TwoWayTransformer
from modeling.mript import MRIPT
from functools import partial

print(torch.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)

# ViT image-encoder hyperparameters
encoder_embed_dim = 1024
encoder_depth = 24
encoder_num_heads = 16
encoder_global_attn_indexes = [5, 11, 17, 23]
# Feature width fed to the encoder; prompt/decoder width is in_feat * output_dim_factor.
in_feat = 64
output_dim_factor = 4
prompt_embed_dim = in_feat * output_dim_factor

# The model below is configured with a single `image_size` used for BOTH spatial
# dimensions (see the (image_size, image_size) tuples passed to the prompt encoder),
# so a non-square input would be silently misconfigured. Fail loudly instead.
assert INPUT_HEIGHT == INPUT_WIDTH, (
    f'model is configured for square inputs, got {INPUT_HEIGHT}x{INPUT_WIDTH}'
)
image_size = INPUT_HEIGHT
vit_patch_size = 4
n_colors = 1
mlp_ratio = 4
# Side length of the encoder's patch grid.
image_embedding_size = image_size // vit_patch_size

SETTINGS_model_mutation_mode = 'type'  # 'type', 'level' or 'combine'

image_encoder = ImageEncoderViT(
    depth=encoder_depth,
    in_chans=in_feat,
    embed_dim=encoder_embed_dim,
    img_size=image_size,
    mlp_ratio=mlp_ratio,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=encoder_num_heads,
    patch_size=vit_patch_size,
    qkv_bias=True,
    use_rel_pos=True,
    global_attn_indexes=encoder_global_attn_indexes,
    window_size=14,
    out_chans=prompt_embed_dim,
    )

# num_level / num_type mirror the (type x level) layout of `scales` built in the dataset cell.
prompt_encoder = PromptEncoderMulti(
    num_level=len(scales[0]),
    num_type=len(scales),
    embed_dim=prompt_embed_dim,
    image_embedding_size=(image_embedding_size, image_embedding_size),
    input_image_size=(image_size, image_size),
    )

image_decoder = ImageDecoderMulti(
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
        attention_downsample_rate=1,
    ),
    transformer_dim=prompt_embed_dim,
    output_dim_factor=output_dim_factor
    )

_model = MRIPT(
    n_feats=in_feat,
    n_colors=n_colors,
    scale=scales,
    conv_kernel_size=3,
    res_kernel_size=5,
    image_encoder=image_encoder,
    prompt_encoder=prompt_encoder,
    image_decoder=image_decoder,
    mode=SETTINGS_model_mutation_mode
    )

In [None]:
# ====== Checkpoint directory ======
# Placeholder value: point this at a real, writable directory before running.
PATH_MODEL = 'PATH_TO_SAVE_MODEL'
create_path(PATH_MODEL)  # help_func helper; presumably creates the directory if missing — confirm
print('PATH_MODEL:', PATH_MODEL)

In [None]:
# ====== Train ======
from modeling.mript_trainer import TrainerMulti

learning_rate = 1e-5
NUM_EPOCH = 5

# Move the model to its target device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed over
# parameters that already live where training will use them.
# nn.Module.to modifies the module in place, so this is a no-op if the trainer
# (or a previous run of this cell) has already placed the model on `device`.
_model = _model.to(device)

optimizer = torch.optim.Adam(_model.parameters(), lr=learning_rate, amsgrad=False)
criterion = nn.L1Loss()  # mean absolute error (default reduction='mean')
criteon = criterion  # backward-compatible alias for the original, misspelled name
trainer = TrainerMulti(
    loader_train=dataloader_train,
    loader_test=dataloader_test,
    my_model=_model,
    my_loss=criterion,
    optimizer=optimizer,
    PATH_MODEL=PATH_MODEL,
    device=device,
    NUM_EPOCH=NUM_EPOCH,
    RESUME_EPOCH=0,
    if_save=True
)
# NOTE(review): train()'s return value replaces _model — presumably the trained model; confirm.
_model = trainer.train(show_step=1, show_test=True)