In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
import pandas as pd
import pickle
from torch.utils.data import RandomSampler
import random
import scipy
import torch.nn.functional as F
from PIL import Image
from glob import glob
import wandb
import re
from adjustText import adjust_text
import seaborn as sns
import scipy
import statannot
import argparse

In [3]:
from MedSAM_HCP.utils_hcp import *
from MedSAM_HCP.dataset import *

In [5]:
def load_model(model_type, model_path, num_classes, device='cuda'):
    """Build a SAM ViT-B multiclass model from a checkpoint, converting it first if needed.

    Training checkpoints stored as ``{'model': state_dict, ...}`` (keys typically
    carrying a DDP ``module.`` prefix) are converted to a plain state dict with the
    prefix removed. The result is saved next to the original as
    ``<name>_sam_readable.pth``, and that file is passed to
    ``build_sam_vit_b_multiclass``. Any other checkpoint is used as-is.

    Args:
        model_type: Model-type string. ``'multitask_unprompted'`` and
            ``'pooltask_yolov7_prompted'`` are built with ``num_classes`` output
            masks. Every other type is treated as a single-task model and is
            always built with 3 output masks, regardless of ``num_classes``.
        model_path: Path to a ``.pth`` checkpoint.
        num_classes: Number of output masks for the multitask / pooled-task types.
        device: Device the model is moved to (defaults to ``'cuda'`` as before).

    Returns:
        The built model, moved to ``device`` and switched to eval mode.

    Note:
        The conversion re-runs, and overwrites ``<name>_sam_readable.pth``, on
        every call with a training-format checkpoint.
    """
    # The weights are only inspected / re-saved here, so keep them on the CPU
    # rather than allocating a throwaway copy on the GPU.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Explicit type check instead of catching AttributeError around the whole
    # conversion, which would also have hidden genuine failures inside it.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        prefix = 'module.'
        # Strip the DDP prefix where present; keys without it are kept unchanged
        # (e.g. a checkpoint saved from an unwrapped model).
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['model'].items()
        }

        stem, _ = os.path.splitext(os.path.basename(model_path))
        new_path = os.path.join(os.path.dirname(model_path), f'{stem}_sam_readable.pth')
        torch.save(state_dict, new_path)
        print(f'model path converted to sam readable format and saved to {new_path}')
        model_path = new_path
    else:
        print('model path in sam readable format already')

    if model_type in ('multitask_unprompted', 'pooltask_yolov7_prompted'):
        num_masks = num_classes
    else:
        # single-task checkpoints are always built with 3 output masks
        num_masks = 3

    model = build_sam_vit_b_multiclass(num_masks, checkpoint=model_path).to(device)
    model.eval()
    return model
def _resolve_unique_checkpoint(pattern):
    """Return the single path matching glob ``pattern``; raise if there are zero or several."""
    matches = glob(pattern)
    if not matches:
        raise FileNotFoundError(f'no checkpoint matches {pattern}')
    if len(matches) > 1:
        raise ValueError(
            f'expected exactly one checkpoint for {pattern}, found {len(matches)}: {matches}')
    return matches[0]


def load_model_from_label_and_type(model_type, label):
    """Locate the checkpoint for ``model_type`` (and, for single-task types, ``label``) and load it.

    Args:
        model_type: One of ``'singletask_unprompted'``, ``'multitask_unprompted'``,
            ``'singletask_medsam_prompted'``, ``'singletask_yolov7_prompted'``,
            ``'singletask_yolov7_longer_prompted'``, ``'pooltask_yolov7_prompted'``.
        label: Label index. Only the single-task YOLOv7 types use it, to pick the
            ``label{label}`` checkpoint directory; the other types ignore it.

    Returns:
        The model returned by ``load_model``.

    Raises:
        ValueError: ``model_type`` is unknown, or a single-task glob matched
            more than one checkpoint.
        NotImplementedError: ``'singletask_unprompted'`` or ``'singletask_medsam_prompted'``.
        FileNotFoundError: A single-task glob matched no checkpoint.
    """
    valid_types = ('singletask_unprompted', 'multitask_unprompted',
                   'singletask_medsam_prompted', 'singletask_yolov7_prompted',
                   'singletask_yolov7_longer_prompted', 'pooltask_yolov7_prompted')
    # Explicit check (not `assert`) so it still fires under `python -O`.
    if model_type not in valid_types:
        raise ValueError(f'unknown model_type {model_type!r}; expected one of {valid_types}')

    if model_type in ('singletask_unprompted', 'singletask_medsam_prompted'):
        raise NotImplementedError(model_type)

    results_root = '/gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347'

    if model_type == 'multitask_unprompted':
        model_path = (f'{results_root}/ce_only_resume_training_from_checkpoint_8-9-23/'
                      'MedSAM_finetune_hcp_ya_constant_bbox_all_tasks-20230810-115803/'
                      'medsam_model_best.pth')
        num_classes = 103
    elif model_type == 'singletask_yolov7_prompted':
        model_path = _resolve_unique_checkpoint(
            f'{results_root}/second_round_w_bbox_yolov7_finetunes_longer_8-17-23/'
            f'label{label}/*/medsam_model_best.pth')
        num_classes = 3  # note we have to pass in 3 so that we get the singletask sam model, which predicts 3 masks, even though the more accurate number would be 2
    elif model_type == 'singletask_yolov7_longer_prompted':
        model_path = _resolve_unique_checkpoint(
            f'{results_root}/second_round_w_bbox_yolov7_finetunes_60epochs_8-20-23/'
            f'label{label}/*/medsam_model_best.pth')
        num_classes = 3
    else:  # 'pooltask_yolov7_prompted'
        model_path = (f'{results_root}/pooled_labels_ckpt_continue_8-22-23/'
                      'model_best_20230822-115028.pth')
        num_classes = 103  # have to pass in 103 here unfortunately because this model was accidentally trained to output 103 masks, even though only the first one is actually used and loss-propagated through

    return load_model(model_type, model_path, num_classes)

In [29]:
def get_num_trainable_params(model):
    """Count the scalar entries of all parameters of ``model`` that require gradients.

    Args:
        model: Anything exposing ``.parameters()``, e.g. an ``nn.Module`` or a
            sliced ``nn.ModuleList``.

    Returns:
        int: Total number of trainable scalars (0 if there are none).
    """
    # numel() is exact for every shape. The previous np.prod(p.size()) returned
    # the float 1.0 for 0-d (scalar) parameters, turning the whole sum into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [10]:
# Pooled-task YOLOv7-prompted model. The loader does not use `label` for this
# model type; it is passed only to keep the call uniform.
model_type, label = 'pooltask_yolov7_prompted', 1
model = load_model_from_label_and_type(model_type=model_type, label=label)

In [30]:

# Pooled-task model: every trainable parameter in the mask decoder.
used_params = get_num_trainable_params(model.mask_decoder)

# This checkpoint was built with 103 output masks, but only the first mask is
# actually used and trained. The output hypernetwork MLPs after the first are
# therefore "fake" trainable parameters; subtract them.
unused_params = get_num_trainable_params(model.mask_decoder.output_hypernetworks_mlps[1:])
print(used_params, unused_params, used_params - unused_params)

18090440 14400224 3690216


In [31]:
# Single-task YOLOv7-prompted ("longer" run) model for label 1; `label` selects
# the per-label checkpoint directory.
model_type, label = 'singletask_yolov7_longer_prompted', 1
model = load_model_from_label_and_type(model_type=model_type, label=label)

model path converted to sam readable format and saved to /gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347/second_round_w_bbox_yolov7_finetunes_60epochs_8-20-23/label1/MedSAM_finetune_final_round-20230821-200628/medsam_model_best_sam_readable.pth


In [35]:
# Hypothetical cost of training one single-task model per label: the mask-decoder
# parameters one single-task model actually uses, times the number of labels.
NUM_SINGLETASK_LABELS = 102  # presumably the 103 multitask classes minus background — TODO confirm

used = get_num_trainable_params(model.mask_decoder)
# NOTE(review): the single-task model is built with 3 output masks (2 meaningful,
# per the loader's comment); confirm that only the first hypernetwork MLP should count.
unused = get_num_trainable_params(model.mask_decoder.output_hypernetworks_mlps[1:])
(used - unused) * NUM_SINGLETASK_LABELS

371169432

In [36]:
# Multitask unprompted model (built with 103 output masks). The loader does not
# use `label` for this model type.
model_type, label = 'multitask_unprompted', 1
model = load_model_from_label_and_type(model_type=model_type, label=label)

model path converted to sam readable format and saved to /gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347/ce_only_resume_training_from_checkpoint_8-9-23/MedSAM_finetune_hcp_ya_constant_bbox_all_tasks-20230810-115803/medsam_model_best_sam_readable.pth


In [38]:
# Display the module tree of the loaded model. The repr is long; displaying
# `model.mask_decoder` instead would focus on the part whose parameters are counted.
model

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [39]:
# Multitask model: count every trainable mask-decoder parameter. No output
# hypernetwork MLPs are subtracted here (presumably all 103 masks are trained — confirm).
used = get_num_trainable_params(model.mask_decoder)
print(f'{used}')

18090440
