In [1]:
import os
import sys
import h5py
from argparse import ArgumentParser
import torch

In [2]:
# Move the working directory two levels above the folder the kernel started in
# (presumably so that the project packages imported in the next cell resolve).
# The starting folder is remembered the first time the cell runs: computing the
# target from the *current* directory would climb two more levels on every re-run.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(os.path.dirname(NOTEBOOK_START_DIR)))
print(os.getcwd())

/home/AR32500/AR32500/MyPapers/box-prompt-learning-VFM/src


In [3]:
from Data.datamodule import SAMDataModule
from Models.SAM_WithPromptGenerator import SAMPromptLearning_Ours
from Utils.load_utils import get_dict_from_config, update_config_from_args
from Utils.utils import find_matching_key

In [4]:
def parse_args(argv=None):
    """Build the argument parser for this run and parse the arguments.

    Options named ``<section>__<key>`` (``data__``, ``train__``, ``logger__``) are
    the ones later handed to ``update_config_from_args`` together with the
    matching section name to override values of the loaded config dictionaries.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    parser = ArgumentParser()
    parser.add_argument('--type', type=str, help='train or test',
                        default='train')

    # These are the paths to the data and output folder
    parser.add_argument('--data_dir', default='/home/AR32500/net/data', type=str, help='Directory for data')
    parser.add_argument('--output_dir', default='/home/AR32500/AR32500/output_PromptLearningForSAM', type=str, help='Directory for output run')

    # These are config files located in src/Config
    # Other data configs that were used with this notebook:
    #   data_config/CAMUS_512.yaml, data_config/HC_640.yaml,
    #   data_config/MSDSpleen_512.yaml, data_config/MSDLiver_256.yaml
    parser.add_argument('--data_config', type=str,
                        default='data_config/ACDC_256.yaml')
    parser.add_argument('--model_config', type=str,
                        default='model_config/ours_samh_config.yaml')
    parser.add_argument('--module_config', type=str, default='model_config/module_hardnet_config.yaml')
    parser.add_argument('--train_config', type=str, default='train_config/train_config_200_100_00001.yaml')
    parser.add_argument('--logger_config', type=str, default='logger_config.yaml')
    parser.add_argument('--prompt_config', type=str,
                        default='prompt_config/box_tight.yaml')
    parser.add_argument('--loss_config', type=str, nargs='+',
                        help='type of loss to appply (does not matter here, just for initializating the model)',
                        default=['loss_config/WBCE_Dice/wbcedice_gtpromptedpred.yaml'])
    parser.add_argument('--seed', default=0, type=int)
    # type=int so that a value given on the command line is an int like the default,
    # instead of being left as a string.
    parser.add_argument('--num_gpu', default=1, type=int, help='number of GPU devices to use')
    parser.add_argument('--gpu_idx', default=[1], type=int, nargs='+', help='otherwise, gpu index, if we want to use a specific gpu')

    parser.add_argument('--logger__project_name', type=str, help='name of project in comet',
                        default='')

    # Training hyper-parameters that we should change according to the dataset
    # Arguments of data input and output
    # NOTE(review): store_true combined with default=True means this option is
    # always True and cannot be switched off from the command line.
    parser.add_argument('--data__compute_sam_embeddings', help='whether to use precomputed embeddings',
                        action="store_true", default=True)
    # NOTE(review): `(1)` is the plain int 1, not a tuple, whereas values passed on
    # the command line arrive as a list because of nargs='+'. The default is kept
    # unchanged because downstream code may rely on the int form.
    parser.add_argument('--data__class_to_segment', type=int, nargs='+', help='class values to segment',
                        default=(1))

    parser.add_argument('--train__train_indices', type=int, nargs='+', help='indices of training data for Segmentation task',
                        default=[])
    # Parsed like --train__train_indices: a list of ints rather than one raw string.
    parser.add_argument('--train__val_indices', type=int, nargs='+', help='indices of val data for Segmentation task',
                        default=[])
    parser.add_argument('--train__clip_gradient_norm_value', type=float, help='value to clip gradient norm (default: 0.0 = No clipping)',
                        default=1.0)
    return parser.parse_args(argv)

In [5]:
# Inside a notebook sys.argv holds the kernel launcher's own arguments, which
# argparse would reject. Swap in the arguments we want for the duration of the parse.
original_argv = sys.argv
sys.argv = ['ipykernel_launcher.py', '--type', 'train']
try:
    args = parse_args()
finally:
    # Restore sys.argv even when parsing aborts (argparse raises SystemExit on
    # invalid input), so a failed parse does not leave the kernel patched.
    sys.argv = original_argv

In [6]:
# Pick the GPU devices: an explicit list of GPU indices takes precedence,
# otherwise fall back to the requested number of GPUs.
gpu_devices = args.num_gpu if args.gpu_idx is None else args.gpu_idx
print(f'gpu_devices {gpu_devices}')

gpu_devices [1]


In [7]:
# Load each configuration dictionary from the config file named in the arguments.
train_config, data_config, model_config, module_config, logger_config = (
    get_dict_from_config(_config_file)
    for _config_file in (args.train_config, args.data_config, args.model_config,
                         args.module_config, args.logger_config)
)
train_config["loss"] = {}

# Register every loss config under its "type". When a type is already registered,
# a counter (1, 2, ...) is appended until the key is free.
for loss_file in args.loss_config:
    loss_cfg = get_dict_from_config(loss_file)
    base_name = loss_cfg["type"]
    unique_name, suffix = base_name, 1
    while unique_name in train_config["loss"]:
        unique_name = f"{base_name}{suffix}"
        suffix += 1
    train_config["loss"][unique_name] = loss_cfg

In [8]:
# Apply the command-line overrides to the data, logger and train configs.
# The last argument names the config section handed to update_config_from_args.
data_config = update_config_from_args(data_config, args, 'data')
logger_config = update_config_from_args(logger_config, args, 'logger')
train_config = update_config_from_args(train_config, args, 'train')
print(f'train_config {train_config}')

train_config {'train_indices': [], 'num_workers': 16, 'num_epochs': 200, 'batch_size': 4, 'optimizer': {'type': 'Adam', 'lr': 0.0001, 'weight_decay': 0.0001}, 'sched': {'update_interval': 'epoch', 'update_freq': 1, 'MultiStepLR': {'milestones': [100], 'gamma': 0.1}}, 'loss': {'WBCE_Dice': {'type': 'WBCE_Dice', 'weight': 1, 'start_epoch': 0, 'kwargs': {'target_str': 'gt_prompted_pred_masks', 'idc': [1], 'alpha_CE': 0.5, 'reduction': 'mean'}, 'other_kwargs': {'bounds_name': None, 'bounds_params': None, 'fn': None}}}, 'val_indices': [], 'clip_gradient_norm_value': 1.0}


In [9]:
# One image per batch: the saving loops below write one embedding file per image
# and assert a batch size of 1.
train_config['batch_size'] = 1

# Derive the device from the requested GPU index instead of hard-coding it, so the
# notebook follows --gpu_idx and still runs on hosts without that GPU layout.
# With the default arguments this resolves to 'cuda:1', as before.
if torch.cuda.is_available() and args.gpu_idx:
    device = f'cuda:{args.gpu_idx[0]}'
else:
    device = 'cpu'

In [10]:
# When a prompt config file is given, load it and attach it to a copy of the
# train config under the 'prompt' key.
if args.prompt_config != '':
    prompt_config = get_dict_from_config(args.prompt_config)
    train_config = dict(train_config, prompt=prompt_config)

In [11]:
# Display the loaded model configuration.
model_config

{'model_class': 'SAMPromptLearning_Ours',
 'model_name': 'vit_h',
 'sam_checkpoint': '/home/AR32500/net/models/sam/sam_vit_h_4b8939.pth',
 'image_size': 1024,
 'in_channels': 3,
 'out_channels': 2}

In [12]:
# Instantiate the model whose class name is given by model_config['model_class'].
# The class is looked up among the names already imported into this notebook.
model_cls = globals().get(model_config['model_class'])
if model_cls is None:
    # Fail with a clear message instead of "'NoneType' object is not callable".
    raise NameError(
        f"Model class {model_config['model_class']!r} is not imported in this notebook"
    )
full_model = model_cls(num_devices=1,
                       model_config=model_config,
                       module_config=module_config,
                       train_config=train_config,
                       seed=args.seed,
                       )
full_model.to(device)
# Evaluation mode: the model is only used to compute embeddings, not trained here.
full_model.eval()

>> 0th list of losses: WBCE_Dice - {'type': 'WBCE_Dice', 'weight': 1, 'start_epoch': 0, 'kwargs': {'target_str': 'gt_prompted_pred_masks', 'idc': [1], 'alpha_CE': 0.5, 'reduction': 'mean'}, 'other_kwargs': {'bounds_name': None, 'bounds_params': None, 'fn': None}}
sam_args: Namespace(model_class='SAMPromptLearning_Ours', model_name='vit_h', sam_checkpoint='/home/AR32500/net/models/sam/sam_vit_h_4b8939.pth', image_size=1024, in_channels=3, out_channels=2)
ImageNet pretrained weights for HarDNet85 is loaded


SAMPromptLearning_Ours(
  (activation_fct): Sigmoid()
  (sam): Sam(
    (image_encoder): ImageEncoderViT(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
      )
      (blocks): ModuleList(
        (0-31): 32 x Block(
          (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1280, out_features=3840, bias=True)
            (proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (lin1): Linear(in_features=1280, out_features=5120, bias=True)
            (lin2): Linear(in_features=5120, out_features=1280, bias=True)
            (act): GELU(approximate='none')
          )
        )
      )
      (neck): Sequential(
        (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): LayerNorm2d()
      

In [13]:
# Gather the bounds arguments of every loss whose "other_kwargs" names a bound,
# adding the number of output channels under 'C'.
bounds_args_list = []
for _loss_cfg in train_config["loss"].values():
    _other_kwargs = _loss_cfg["other_kwargs"]
    if _other_kwargs is None:
        continue
    if "bounds_name" not in _other_kwargs or _other_kwargs["bounds_name"] is None:
        continue
    bounds_args_list.append({**_other_kwargs, 'C': model_config["out_channels"]})

# Keyword arguments forwarded to the datasets through the datamodule.
kwargs = {
    'prompt': train_config.get('prompt'),
    'data_shape': data_config['data_shape'],
    'class_to_segment': data_config['class_to_segment'],
    'box_prior_args': find_matching_key(train_config["loss"], "TightBoxPrior", default={}).get('kwargs', None),
    'bounds_args_list': bounds_args_list,
    'compute_sam_embeddings': data_config['compute_sam_embeddings'],
    'model_image_size': model_config.get('image_size'),
    'sam_checkpoint': model_config.get('sam_checkpoint'),
}

# The same batch size is used for training and validation.
data_module = SAMDataModule(
    data_dir=args.data_dir,
    dataset_name=data_config["dataset_name"],
    batch_size=train_config["batch_size"],
    val_batch_size=train_config["batch_size"],
    num_workers=train_config["num_workers"],
    train_indices=train_config["train_indices"],
    dataset_kwargs=kwargs,
)

dataset_kwargs: {'prompt': {'prompt_type': ['box'], 'args': {'perturbation_bound': [0, 1]}}, 'data_shape': [256, 256], 'class_to_segment': 1, 'box_prior_args': None, 'bounds_args_list': [], 'compute_sam_embeddings': True, 'model_image_size': 1024, 'sam_checkpoint': '/home/AR32500/net/models/sam/sam_vit_h_4b8939.pth'}
Number of training images 42
Number of validation images 0
Number of test images 0


In [14]:
# Set up the datamodule, then build the training and validation dataloaders.
# NOTE(review): the saving loops below read data_module.train_loader and
# data_module.val_loader; presumably these calls populate them. Confirm in SAMDataModule.
data_module.setup()
data_module.train_dataloader()
data_module.val_dataloader()

<monai.data.dataloader.DataLoader at 0x7f734c164a90>

In [15]:
# Set up the datamodule for the 'test' stage and build the test dataloader.
# NOTE(review): the last saving loop reads data_module.test_loader; presumably
# this call populates it. Confirm in SAMDataModule.
data_module.setup('test')
data_module.test_dataloader()

<monai.data.dataloader.DataLoader at 0x7f734c167460>

In [16]:
# We save the embeddings to re-use later
base_savefolder = os.path.join(args.data_dir, 
                               data_config["dataset_name"])

In [17]:
# Display the embeddings folder for this SAM checkpoint: the checkpoint file name,
# with '.' replaced by '-', under <base_savefolder>/image_embeddings.
os.path.join(base_savefolder, 'image_embeddings', os.path.basename(model_config['sam_checkpoint']).replace('.', '-'))

'/home/AR32500/net/data/ACDC/preprocessed_sam/image_embeddings/sam_vit_h_4b8939-pth'

# Saving SAM embeddings

In [18]:
# Ask PyTorch to use deterministic algorithms. With warn_only=True, operations
# that have no deterministic implementation emit a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

In [19]:
context = 'train_2d_images'

# Folder that receives the embeddings of this split (the same for every image).
embed_dir = os.path.join(base_savefolder, 'image_embeddings',
                         os.path.basename(model_config['sam_checkpoint']).replace('.', '-'), context)

for batch_idx, batched_input in enumerate(data_module.train_loader):
    # One file is written per image, so the loader must yield a single image.
    assert len(batched_input['filename']) == 1
    filename = batched_input['filename'][0].split('.')[0]
    embed_path = os.path.join(embed_dir, filename) + '.h5'

    # Embeddings already on disk are not recomputed.
    if os.path.exists(embed_path):
        print(f"Dataset {filename} already exists")
        continue

    print(batch_idx, filename)
    os.makedirs(os.path.dirname(embed_path), exist_ok=True)

    # Only the image is needed by the encoder, so only it is moved to the device.
    with torch.no_grad():
        image_embeddings = full_model.sam.image_encoder(batched_input['data'].to(device))

    # Write to a temporary file and rename it once complete: an interrupted run
    # must not leave a truncated .h5 that later runs would skip as "already exists".
    tmp_path = embed_path + '.tmp'
    with h5py.File(tmp_path, 'w') as f:
        f.create_dataset(filename, data=image_embeddings[0].detach().cpu().numpy())
    os.replace(tmp_path, embed_path)

0 patient005_frame01_slice7


  ret = func(*args, **kwargs)
  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]


1 patient001_frame01_slice2
2 patient002_frame01_slice5
3 patient003_frame01_slice2
4 patient001_frame01_slice8
5 patient003_frame01_slice8
6 patient002_frame01_slice3
7 patient005_frame01_slice3
8 patient005_frame01_slice8
9 patient004_frame01_slice3
10 patient005_frame01_slice4
11 patient005_frame01_slice6
12 patient004_frame01_slice6
13 patient003_frame01_slice5
14 patient004_frame01_slice2
15 patient005_frame01_slice5
16 patient004_frame01_slice1
17 patient003_frame01_slice1
18 patient003_frame01_slice7
19 patient003_frame01_slice6
20 patient002_frame01_slice4
21 patient003_frame01_slice4
22 patient005_frame01_slice2
23 patient002_frame01_slice6
24 patient004_frame01_slice0
25 patient001_frame01_slice9
26 patient001_frame01_slice7
27 patient001_frame01_slice5
28 patient003_frame01_slice3
29 patient002_frame01_slice7
30 patient004_frame01_slice4
31 patient001_frame01_slice6
32 patient002_frame01_slice1
33 patient001_frame01_slice1
34 patient005_frame01_slice1
35 patient002_frame01_s

In [20]:
context = 'val_2d_images'

# Folder that receives the embeddings of this split (the same for every image).
embed_dir = os.path.join(base_savefolder, 'image_embeddings',
                         os.path.basename(model_config['sam_checkpoint']).replace('.', '-'), context)

for batch_idx, batched_input in enumerate(data_module.val_loader):
    # One file is written per image, so the loader must yield a single image.
    assert len(batched_input['filename']) == 1
    filename = batched_input['filename'][0].split('.')[0]
    embed_path = os.path.join(embed_dir, filename) + '.h5'

    # Embeddings already on disk are not recomputed.
    if os.path.exists(embed_path):
        print(f"Dataset {filename} already exists")
        continue

    print(batch_idx, filename)
    os.makedirs(os.path.dirname(embed_path), exist_ok=True)

    # Only the image is needed by the encoder, so only it is moved to the device.
    with torch.no_grad():
        image_embeddings = full_model.sam.image_encoder(batched_input['data'].to(device))

    # Write to a temporary file and rename it once complete: an interrupted run
    # must not leave a truncated .h5 that later runs would skip as "already exists".
    tmp_path = embed_path + '.tmp'
    with h5py.File(tmp_path, 'w') as f:
        f.create_dataset(filename, data=image_embeddings[0].detach().cpu().numpy())
    os.replace(tmp_path, embed_path)

In [21]:
context = 'test_2d_images'

# Folder that receives the embeddings of this split (the same for every image).
embed_dir = os.path.join(base_savefolder, 'image_embeddings',
                         os.path.basename(model_config['sam_checkpoint']).replace('.', '-'), context)

for batch_idx, batched_input in enumerate(data_module.test_loader):
    # One file is written per image, so the loader must yield a single image.
    assert len(batched_input['filename']) == 1
    filename = batched_input['filename'][0].split('.')[0]
    embed_path = os.path.join(embed_dir, filename) + '.h5'

    # Embeddings already on disk are not recomputed.
    if os.path.exists(embed_path):
        print(f"Dataset {filename} already exists")
        continue

    print(batch_idx, filename)
    os.makedirs(os.path.dirname(embed_path), exist_ok=True)

    # Only the image is needed by the encoder, so only it is moved to the device.
    with torch.no_grad():
        image_embeddings = full_model.sam.image_encoder(batched_input['data'].to(device))

    # Write to a temporary file and rename it once complete: an interrupted run
    # must not leave a truncated .h5 that later runs would skip as "already exists".
    tmp_path = embed_path + '.tmp'
    with h5py.File(tmp_path, 'w') as f:
        f.create_dataset(filename, data=image_embeddings[0].detach().cpu().numpy())
    os.replace(tmp_path, embed_path)