In [1]:
import argparse
from pathlib import Path
import numpy as np
import torch
import pandas as pd
import torch.nn as nn
from sklearn.metrics import (
    roc_auc_score, f1_score, accuracy_score, balanced_accuracy_score, confusion_matrix
)

import wandb 

from ISUPMedSAM import IMG_SIZE, MedSAMSliceSpatialAttn
from segment_anything import sam_model_registry

from triplet_loss_utils import (
    get_histo_by_isup,
    triplet_loss_batch,
)
import train_utils
from train_utils import (
    build_datasets_and_loaders,
    evaluate_loader,
    format_perclass_acc_auc,
    format_sens_spec,
    print_operating_points_table,
    EarlyStopper,
    set_seed,
    wandb_init, wandb_log, wandb_finish,
    save_embeddings,
)
from sklearn.linear_model import LogisticRegression
import numpy as np
import pandas as pd
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler

from dataset_picai_slices import map_binary_all, map_binary_low_high, map_isup3, PicaiSliceDataset, map_isupc3
from ISUPMedSAM import IMG_SIZE

from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, f1_score, confusion_matrix,
    roc_auc_score, auc, roc_curve
)



In [2]:
MANIFEST="/project/aip-medilab/shared/picai/manifests/slices_manifest.csv"
model_dir="/project/6106383/obed/medproj/pca_contrastive/mri_model_medsam_finetune_2D/LATEST/results_isupc3_ob/two_stage-head_and_proj/ckpt_triplet_best_macro_auc.pt"
sam_checkpoint = "/datasets/exactvu_pca/checkpoint_store/sam/medsam_vit_b_cpu.pth"
n_classes = 6
proj_dim = 512
device = 'cuda' if torch.cuda.is_available else 'cpu'


In [3]:
sam = sam_model_registry["vit_b"]()
sam.load_state_dict(torch.load(sam_checkpoint, map_location="cpu"), strict=True)
model = MedSAMSliceSpatialAttn(
    sam_model=sam, num_classes=n_classes,
    proj_dim=proj_dim, attn_dim=256,
    head_hidden=256, head_dropout=0.1,
    use_pre_neck=True, pixel_mean_std=None,
).to(device)




In [4]:
model.load_state_dict(torch.load(model_dir, map_location='cpu')['model'])

<All keys matched successfully>

In [5]:
_ds = PicaiSliceDataset(
        manifest_csv=MANIFEST,
        folds=None,
        use_skip=True,
        label6_column='label6',
        target='isup6',
        channels=("path_T2","path_ADC","path_HBV"),
        missing_channel_mode="zeros",
        pct_lower=0.5, pct_upper=99.5,
        cache_size=64,
    )

In [11]:
len(_ds)

28139

In [12]:
import pandas as pd

In [19]:
df = pd.read_csv(MANIFEST)

In [20]:
df.head(5)

Unnamed: 0,case_id,fold,z,label6,label3,has_lesion,area_frac,path_T2,path_ADC,path_HBV,...,path_mask_prostate,bbox_prostate_z0,bbox_prostate_z1,bbox_prostate_h0,bbox_prostate_h1,bbox_prostate_w0,bbox_prostate_w1,skip,patient_ISUP,merged_ISUP
0,10000_1000000,0,0,0,0,0,0.0,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,...,/project/aip-medilab/shared/picai/picai_preppe...,0,28,181,381,263,403,0,0,0
1,10000_1000000,0,1,0,0,0,0.0,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,...,/project/aip-medilab/shared/picai/picai_preppe...,0,28,181,381,263,403,0,0,0
2,10000_1000000,0,2,0,0,0,0.0,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,...,/project/aip-medilab/shared/picai/picai_preppe...,0,28,181,381,263,403,0,0,0
3,10000_1000000,0,3,0,0,0,0.0,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,...,/project/aip-medilab/shared/picai/picai_preppe...,0,28,181,381,263,403,0,0,0
4,10000_1000000,0,4,0,0,0,0.0,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,/project/aip-medilab/shared/picai/picai_preppe...,...,/project/aip-medilab/shared/picai/picai_preppe...,0,28,181,381,263,403,0,0,0


In [14]:
len(df)

33704

In [21]:
df = df[df['skip']==0]

In [26]:
df.columns

Index(['case_id', 'fold', 'z', 'label6', 'label3', 'has_lesion', 'area_frac',
       'path_T2', 'path_ADC', 'path_HBV', 'path_mask_lesion',
       'path_mask_prostate', 'bbox_prostate_z0', 'bbox_prostate_z1',
       'bbox_prostate_h0', 'bbox_prostate_h1', 'bbox_prostate_w0',
       'bbox_prostate_w1', 'skip', 'patient_ISUP', 'merged_ISUP'],
      dtype='object')

In [28]:
df['has_lesion'].unique()

array([0, 1])

In [22]:
len(df)

28139

In [6]:
def collate_resize_to_imgsize(batch):
    imgs, labels = [], []
    extras_keys = [k for k in batch[0].keys() if k not in ("image", "label")]
    extras = {k: [] for k in extras_keys}
    for s in batch:
        x = s["image"].unsqueeze(0)  # [1,C,H,W]
        x = F.interpolate(x, size=(IMG_SIZE, IMG_SIZE), mode="bilinear", align_corners=False).squeeze(0)
        imgs.append(x)
        labels.append(torch.as_tensor(s["label"], dtype=torch.long))
        for k in extras_keys:
            extras[k].append(s[k])
    return {"image": torch.stack(imgs, 0),
            "label": torch.stack(labels, 0),
            **extras}

In [7]:
_loader = DataLoader(
    _ds, batch_size=64, shuffle=False,
    num_workers=4, pin_memory=True,
    collate_fn=collate_resize_to_imgsize
)

In [8]:
@torch.no_grad()
def extract_embeddings(loader, model, device="cuda"):
    model.eval()
    embs, ys = [], []
    for batch in loader:
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)
        _, emb = model(x)
        embs.append(emb.cpu())
        ys.append(y.cpu())
    X = torch.cat(embs, 0).numpy() if embs else np.empty((0, 0), dtype=np.float32)
    y = torch.cat(ys, 0).numpy() if ys else np.empty((0,), dtype=np.int64)
    return X, y

In [9]:
model.eval() 

MedSAMSliceSpatialAttn(
  (encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Identity()
  )
  (pool): SpatialAttnPool2d(
    (norm): GroupNorm(1, 768, eps=1e-05, affine=True)
    (theta): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1))
    (gate): Conv2d(768, 256, kernel_size=(1, 1), stride=(1

In [18]:
x.shape

torch.Size([43, 3, 256, 256])

In [11]:
all_embeddings = []
case_ids = []
zs = []
for data in _loader:
    x = data['image'].to(device, non_blocking=True)
    y = data['label'].to(device, non_blocking = True)
    case_id = data['case_id']
    z = data['z']

    _, emb = model(x)
    
    all_embeddings.append(emb.detach().cpu().numpy())
    case_ids.append(case_id)
    zs.append(z)

In [12]:
embed_path = "/home/obed/projects/aip-medilab/shared/mri_embeddings_medsam"

In [13]:
import os

In [17]:
for i in range (len(all_embeddings)):
    embedding_batch = all_embeddings[i]
    id_batch = case_ids[i]
    z_batch = zs[i]
    for j in range(len(id_batch)):
        embedding = embedding_batch[j]
        id_case = id_batch[j]
        z_slice = z_batch[j]

        filename = str(id_case)+"_"+str(z_slice)+'.npy'
        file_path = os.path.join(embed_path, filename)
        np.save(file_path, embedding)



In [15]:
i

439

In [16]:
len(id_batch)

43