## Generating BiomedCLIP Embeddings

In [None]:
import os 
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import json
import pandas as pd
from tqdm import tqdm
from PIL import Image
import pickle
import sys
import open_clip
import torch
from torch.utils.data import Dataset

# Load model
# BiomedCLIP checkpoint from the Hugging Face hub; open_clip returns the model
# together with its train/validation image transforms, plus a matching tokenizer.
BIOMEDCLIP_MODEL = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(BIOMEDCLIP_MODEL)
tokenizer = open_clip.get_tokenizer(BIOMEDCLIP_MODEL)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
model.to(device)
# open_clip hands the model back in training mode; switch to eval mode so any
# dropout / stochastic-depth layers are disabled and embeddings are deterministic.
model.eval()

# CT scan dataset
class LoadDataset(Dataset):
    """One scan volume exposed as a dataset of preprocessed 2-D slices.

    The file is read with ``sitk.ReadImage`` and converted with
    ``sitk.GetArrayFromImage``; item ``i`` is ``volume[i]`` (the i-th slice
    along the array's first axis) run through ``preprocess_fn``.

    Args:
        path: Path to a volume file readable by SimpleITK.
        preprocess_fn: Callable taking a ``PIL.Image`` and returning the model
            input (e.g. an open_clip image transform).
        window: Optional ``(low, high)`` intensity window. When given, every
            slice is clipped to ``[low, high]`` and linearly rescaled to
            ``uint8`` in ``[0, 255]`` before being turned into a PIL image.
            ``None`` (the default) hands the raw slice array to
            ``Image.fromarray`` unchanged.

    Raises:
        ValueError: If ``window`` is given with ``low >= high``.

    NOTE(review): without ``window``, non-uint8 slices (e.g. int16 CT values)
    become 16/32-bit or float PIL images; if ``preprocess_fn`` converts them to
    RGB, PIL presumably clips intensities to 0-255 - confirm the inputs are
    already scaled, or pass a window.
    """

    def __init__(self, path, preprocess_fn, window=None):
        self.path = path
        self.preprocess_fn = preprocess_fn
        # Validate before the (potentially slow) volume read so bad args fail fast.
        self.window = self._validate_window(window)
        self.volume = self._load_volume(self.path)
        # Per-slice views into the volume (basic indexing, no copy); kept as an
        # attribute for callers that access slices directly.
        self.slices = [self.volume[i] for i in range(self.volume.shape[0])]

    @staticmethod
    def _validate_window(window):
        """Return ``window`` as a ``(low, high)`` float pair, or ``None``."""
        if window is None:
            return None
        low, high = window
        if not low < high:
            raise ValueError(f"window must satisfy low < high, got {window!r}")
        return float(low), float(high)

    def _load_volume(self, path):
        """Read ``path`` with SimpleITK and return its voxel array."""
        vol = sitk.ReadImage(path)
        return sitk.GetArrayFromImage(vol)

    def __len__(self):
        return len(self.slices)

    def _apply_window(self, im):
        """Clip ``im`` to ``self.window`` and rescale to uint8; no-op when unset."""
        if self.window is None:
            return im
        low, high = self.window
        clipped = np.clip(im.astype(np.float32), low, high)
        return np.round((clipped - low) / (high - low) * 255.0).astype(np.uint8)

    def _preprocess_image(self, im):
        return self.preprocess_fn(Image.fromarray(self._apply_window(im)))

    def __getitem__(self, index):
        return {"image": self._preprocess_image(self.slices[index])}

## Generate Embeddings with BiomedCLIP

In [None]:
# Input arguments - fill these in before running the notebook.
# Folder containing the input NIfTI volumes (every regular file in it is
# treated as one scan).
dir_path = ''

# Output embedding directory: one "<scan name>.pkl" is written per input scan.
output_dir = ""

# Fail with a clear message instead of the opaque FileNotFoundError that
# os.listdir('') / os.makedirs('') would raise on the empty placeholders.
if not dir_path or not output_dir:
    raise ValueError("Set dir_path (input NIfTI folder) and output_dir before running.")
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Collect the scans to embed: every regular file in dir_path except the
# bookkeeping files .amlignore and .DS_Store (matched as substrings, as before).
# Sorted because os.listdir order is arbitrary; this keeps runs reproducible.
train_data = []
for img in sorted(os.listdir(dir_path)):
    if '.amlignore' in img or '.DS_Store' in img:
        continue
    img_path = os.path.join(dir_path, img)
    if not os.path.isfile(img_path):
        continue  # sub-directories are not single volume files
    train_data.append(img_path)

def scan_stem(path):
    """Return the file name of ``path`` without its image extension.

    A trailing ``.gz`` is dropped first, then only the last remaining
    extension, so ``scan.nii.gz`` -> ``scan`` while dots elsewhere in the name
    survive (``pt.001.nii.gz`` -> ``pt.001``). Distinct scans therefore map to
    distinct output files instead of silently overwriting each other.
    """
    name = os.path.basename(path)
    if name.lower().endswith('.gz'):
        name = name[:-3]
    return os.path.splitext(name)[0]


for scan_path in tqdm(train_data):
    tqdm.write(scan_path)  # plain print() would garble the progress bar
    image_name = scan_stem(scan_path)

    # One dataset per scan: each item is one preprocessed slice of the volume.
    dataset = LoadDataset(scan_path, preprocess_val)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        pin_memory=False,
        shuffle=False,  # keep slice order so output row i corresponds to slice i
        num_workers=2,
        drop_last=False,
        batch_size=32,
    )

    # Image embeddings, one row per slice. These ARE L2-normalized: the
    # previous model(images)[0] went through open_clip's forward(), which calls
    # encode_image(..., normalize=True). Calling encode_image directly makes
    # that explicit; pass normalize=False to store raw projections instead.
    embds = []
    with torch.no_grad():  # inference only - don't build an autograd graph
        for batch in dataloader:
            images = batch["image"].to(device)
            embds.append(model.encode_image(images, normalize=True).cpu().numpy())
    if not embds:
        tqdm.write(f"Skipping {scan_path}: volume has no slices")
        continue
    embds = np.concatenate(embds, axis=0)
    # NOTE: pickle files should only ever be loaded from trusted locations.
    with open(os.path.join(output_dir, f"{image_name}.pkl"), "wb") as f:
        pickle.dump(embds, f)