# Import libs

In [2]:
import torch
from timm.models import create_model
from musk import utils, modeling
from PIL import Image
from transformers import XLMRobertaTokenizer
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
import torchvision
# from huggingface_hub import login
# login(<HF Token>)
# Target accelerator for every cell below: CUDA device with index 4.
# NOTE(review): the index is hard-coded; adjust it for hosts with a different GPU layout.
device = torch.device("cuda", 4)

# Extract Image Embeddings

- Set `ms_aug = True` for:  
  - Linear probe classification  
  - Multiple Instance Learning  

- Set `ms_aug = False` for:  
  - Zero-shot tasks (e.g., image-image retrieval and image-text retrieval)


In [3]:
# >>>>>>>>>>>> load model >>>>>>>>>>>> #
# Build the MUSK architecture by name, then load weights from a local
# safetensors checkpoint (the loader prints "Load ckpt from ..." when it runs).
model_config = "musk_large_patch16_384"
model = create_model(model_config).eval()
local_ckpt = "/remote-home/share/lisj/Workspace/SOTA_NAS/encoder/musk/checkpoint/model.safetensors"
utils.load_model_and_may_interpolate(local_ckpt, model, 'model|module', '')
# Run the whole model in half precision on the selected device.
model.to(device, dtype=torch.float16)
model.eval()
# <<<<<<<<<<<< load model <<<<<<<<<<<< #

# >>>>>>>>>>>> process image >>>>>>>>>>> #
# Preprocessing: resize the shorter side to `img_size` (bicubic, antialiased),
# center-crop to a square, convert to a tensor and normalise with the
# Inception mean/std constants from timm.
img_size = 384
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        img_size,
        # Enum form of the legacy integer code 3 (bicubic); passing a bare int
        # is deprecated in torchvision.
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        antialias=True,
    ),
    torchvision.transforms.CenterCrop((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
])

# Input image. The context manager closes the underlying file as soon as the
# RGB copy has been made, instead of leaving the handle to the garbage collector.
image_path = '/remote-home/share/lisj/Workspace/SOTA_NAS/datasets/core/patches/train/1819360/146_315.png'
with Image.open(image_path) as image_file:
    img = image_file.convert("RGB")

# Add a leading batch dimension: the model is called on a batch of one image.
img_tensor = transform(img).unsqueeze(0)
with torch.inference_mode():
    image_embeddings = model(
        image=img_tensor.to(device, dtype=torch.float16),
        with_head=False,  # We only use the retrieval head for image-text retrieval tasks.
        out_norm=True,
        ms_aug=True,  # by default it is False, `image_embeddings` will be 1024-dim; if True, it will be 2048-dim.
    )[0]  # the model returns (vision_cls, text_cls); keep the vision part

print(image_embeddings.shape)

Load ckpt from /remote-home/share/lisj/Workspace/SOTA_NAS/encoder/musk/checkpoint/model.safetensors
torch.Size([1, 2048])


# Multimodal Retrieval Example

In [None]:
# >>>>>>>>>>>> load model >>>>>>>>>>>> #
# Build the MUSK architecture by name, then load weights from a local
# safetensors checkpoint (the loader prints "Load ckpt from ..." when it runs).
# The model is rebuilt here so this cell does not rely on the previous one
# for `model`; it still needs `device` from the import cell.
model_config = "musk_large_patch16_384"
model = create_model(model_config).eval()
local_ckpt = "/remote-home/share/lisj/Workspace/SOTA_NAS/encoder/musk/checkpoint/model.safetensors"
utils.load_model_and_may_interpolate(local_ckpt, model, 'model|module', '')
# Run the whole model in half precision on the selected device.
model.to(device, dtype=torch.float16)
model.eval()
# <<<<<<<<<<<< load model <<<<<<<<<<<< #

# >>>>>>>>>>>> process image >>>>>>>>>>> #
# Preprocessing: resize the shorter side to `img_size` (bicubic, antialiased),
# center-crop to a square, convert to a tensor and normalise with the
# Inception mean/std constants from timm.
img_size = 384
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        img_size,
        # Enum form of the legacy integer code 3 (bicubic); passing a bare int
        # is deprecated in torchvision.
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        antialias=True,
    ),
    torchvision.transforms.CenterCrop((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
])

# Input image. The context manager closes the underlying file as soon as the
# RGB copy has been made, instead of leaving the handle to the garbage collector.
image_path = '/remote-home/share/lisj/Workspace/SOTA_NAS/datasets/core/patches/train/1819360/146_315.png'
with Image.open(image_path) as image_file:
    img = image_file.convert("RGB")

# Add a leading batch dimension: the model is called on a batch of one image.
img_tensor = transform(img).unsqueeze(0)
with torch.inference_mode():
    image_embeddings = model(
        image=img_tensor.to(device, dtype=torch.float16),
        with_head=True,  # We only use the retrieval head for image-text retrieval tasks.
        out_norm=True
        )[0]  # the model returns (vision_cls, text_cls); keep the vision part
# <<<<<<<<<<< process image <<<<<<<<<<< #

# >>>>>>>>>>> process language >>>>>>>>> #
# load tokenizer for language input
tokenizer = XLMRobertaTokenizer("./musk/models/tokenizer.spm")

# Candidate class names; the final probabilities follow this order.
labels = [
        "healthy liver",          # normal
        "simple steatosis (NAFL)",# simple fatty liver
        "non-alcoholic steatohepatitis (NASH)"  # NASH
        # "alcoholic steatohepatitis (ASH)",  # alcoholic fatty liver
        ]

# Wrap every label in the same prompt template before tokenising.
texts = ['histopathology image of ' + item for item in labels]
text_ids = []
paddings = []
for txt in texts:
    # `utils.xlm_tokenizer` yields the token ids and a padding mask for one
    # prompt (max_len=100); each is turned into a tensor with a leading batch axis.
    txt_ids, pad = utils.xlm_tokenizer(txt, tokenizer, max_len=100)
    text_ids.append(torch.tensor(txt_ids).unsqueeze(0))
    paddings.append(torch.tensor(pad).unsqueeze(0))

# Stack the per-prompt rows into one batch (one row per label).
text_ids = torch.cat(text_ids)
paddings = torch.cat(paddings)
with torch.inference_mode():
    text_embeddings = model(
        text_description=text_ids.to(device),
        padding_mask=paddings.to(device),
        with_head=True,
        out_norm=True
    )[1]  # the model returns (vision_cls, text_cls); keep the text part
# <<<<<<<<<<<< process language <<<<<<<<<<< #

# >>>>>>>>>>>>> calculate similarity >>>>>>> #
with torch.inference_mode():
    # Scaled image-text similarity, one column per entry of `labels`.
    sim = model.logit_scale * image_embeddings @ text_embeddings.T
    # Softmax over the label axis: prob[0, i] is the probability assigned to labels[i].
    prob = sim.softmax(dim=-1)
    print(prob)

Load ckpt from /remote-home/share/lisj/Workspace/SOTA_NAS/encoder/musk/checkpoint/model.safetensors
text_embeddings.shape torch.Size([3, 1024])
tensor([[0.3359, 0.3257, 0.3386]], device='cuda:4', dtype=torch.float16)
