### Note:
The full inference script will be refactored and made available closer to publication. The following code is provided to reviewers for illustration purposes.

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

In [None]:
# Base class of all CLIP-like dual encoder models
class DualEncoderModel(nn.Module):
    """
    Interface for CLIP-style dual encoders that embed text and images into a
    single shared visual/language latent space.

    Subclasses are expected to override both ``encode_text`` and ``encode_image``;
    the base implementations only raise ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used by the base class.
        super().__init__()

    def encode_text(self, token_ids, attention_mask = None):
        """
        Embed a batch of token sequences into the shared visual/language space.

        token_ids: token ids of the text to encode
        attention_mask: optional mask that prevents attention to certain positions
        Must be overridden by subclasses.
        """
        raise NotImplementedError("to be implemented")

    def encode_image(self, image):
        """
        Embed image input into the shared visual/language space.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("to be implemented")

# create a dummy dual encoder model for illustration purposes
class DummyDualEncoderModel(DualEncoderModel):
    """
    Stand-in dual encoder that returns random embeddings, for illustration only.

    Each input item gets one row of ``emb_dim`` values drawn uniformly from [0, 1).
    The embeddings are created on the same device as the input whenever the input
    exposes a ``.device`` attribute. This keeps downstream matmuls against
    device-resident tensors (e.g. the zero-shot classifier) from failing with a
    device mismatch.
    """

    def __init__(self, emb_dim = 128):
        super().__init__()
        self.emb_dim = emb_dim

    @staticmethod
    def _random_embeddings(batch, emb_dim):
        """Return a (len(batch), emb_dim) random tensor on batch's device (if any)."""
        # Inputs without a .device (e.g. plain lists) fall back to torch's default device.
        return torch.rand(len(batch), emb_dim, device=getattr(batch, 'device', None))

    def encode_text(self, token_ids, attention_mask = None):
        # create random embeddings for illustration purposes; attention_mask is ignored
        return self._random_embeddings(token_ids, self.emb_dim)

    def encode_image(self, image):
        # create random embeddings for illustration purposes
        return self._random_embeddings(image, self.emb_dim)
    
def spatial_smoothing(logits, coords, ss_k = 8):
    """
    Replace each patch's logits with the mean over its spatial neighbourhood.

    logits: tensor with one row of class logits per patch
    coords: numpy array of patch coordinates, one row per patch (row i of coords
        corresponds to row i of logits); cast to float32 for the faiss kNN search
    ss_k: number of nearest neighbours requested from knn_lookup
    returns: tensor with the same number of rows as logits, each row averaged
        over the neighbour indices returned by knn_lookup
    """
    neighbour_idx = knn_lookup(coords.astype('float32'), k=ss_k, device=logits.device)
    # gather the logits of every neighbour: (num_patches, num_neighbours, num_classes)
    neighbourhood = logits[neighbour_idx]
    return torch.mean(neighbourhood, dim=1)

def knn_lookup(coords, device, k = 8):
    """
    Find the nearest spatial neighbours of every patch with an exact L2 faiss index.

    coords: (N, d) numpy array of patch coordinates (faiss works on float32)
    device: torch device on which the returned index tensor is placed
    k: number of neighbours per patch, not counting the query patch itself.
        It is clipped to N - 1, so faiss never has to pad results with -1 labels;
        a -1 label would silently index the last patch.
    returns: (N, min(k, N - 1) + 1) int64 tensor of row indices into coords
    """
    n, d = coords.shape
    k = max(0, min(k, n - 1))
    index = faiss.IndexFlatL2(d)
    index.add(coords)
    # ask for k + 1 results because the query point itself (distance 0) is
    # normally among them
    _, I = index.search(coords, k + 1)
    # faiss returns numpy arrays, which have no .to(); convert to a tensor first
    return torch.from_numpy(I).to(device)

def build_zero_shot_classifier(model, tokenizer, classnames, templates, device = 'cuda'):
    """
    Build zero-shot classifier weights from an ensemble of text prompts.

    For each class name, every template has its 'CLASSNAME' placeholder replaced
    by that name. The prompts are tokenized, embedded with ``model.encode_text``,
    L2-normalised, averaged and then re-normalised into one unit-norm class
    embedding.

    model: object providing encode_text(token_ids, attention_mask=...)
    tokenizer: tokenizer forwarded to tokenize()
    classnames: sequence of class-name strings
    templates: sequence of prompt templates containing the literal 'CLASSNAME'
    device: device for the token tensors and the returned weights
    returns: tensor whose columns (dim=1) are the per-class embeddings, in the
        order of classnames; assuming encode_text returns (num_prompts, emb_dim),
        the shape is (emb_dim, num_classes)
    """
    def _embed_class(name):
        prompts = [tpl.replace('CLASSNAME', name) for tpl in templates]
        ids, mask = tokenize(tokenizer, prompts)  # Tokenize with custom tokenizer
        ids = torch.from_numpy(np.array(ids)).to(device)
        mask = torch.from_numpy(np.array(mask)).to(device)
        prompt_embs = model.encode_text(ids, attention_mask=mask)
        # average the unit-norm prompt embeddings, then project back to the unit sphere
        centroid = F.normalize(prompt_embs, dim=-1).mean(dim=0)
        return centroid / centroid.norm()

    with torch.no_grad():
        columns = [_embed_class(name) for name in classnames]
        weights = torch.stack(columns, dim=1).to(device)
    return weights

def tokenize(tokenizer, texts, max_length = 64):
    """
    Tokenize a batch of prompts into fixed-length token ids and attention masks.

    tokenizer: a Hugging Face-style tokenizer exposing ``batch_encode_plus``
    texts: list of prompt strings
    max_length: every sequence is truncated or padded to exactly this many tokens.
        The default of 64 matches the previously hard-coded value; pass a different
        value for text encoders with another context length.
    returns: (input_ids, attention_mask) as returned by the tokenizer, one entry
        per text
    """
    tokens = tokenizer.batch_encode_plus(
        texts,
        max_length=max_length,
        add_special_tokens=True,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',  # fixed length so the batch stacks into one array
        return_attention_mask=True,
    )
    return tokens['input_ids'], tokens['attention_mask']

def topK_pooling(logits, topK = (1,)):
    """
    Pool per-patch logits into slide-level scores by averaging the top-k patches per class.

    logits: N x C logits for each patch
    topK: tuple of the top number of patches to use for pooling; each k larger
        than N is effectively capped at N
    returns: (preds, values), two dicts keyed by each k in topK:
        values[k] is a 1 x C tensor; entry c is the mean of the k highest
            logits of class c (selected independently per class)
        preds[k] is a length-1 tensor holding the argmax class index of values[k]
    raises: ValueError if logits has no patches, topK is empty, or any k < 1.
        Without this check those inputs would yield NaN scores and an
        arbitrary prediction.
    """
    num_patches = logits.size(0)
    if num_patches == 0:
        raise ValueError('topK_pooling requires at least one patch')
    if not topK or any(k < 1 for k in topK):
        raise ValueError(f'topK must be a non-empty tuple of positive ints, got {topK}')

    # Averages (not sums) the top-k logits per class, for each requested k.
    maxK = min(max(topK), num_patches)  # Ensures k is not larger than the number of patches.
    values, _ = logits.topk(maxK, 0, True, True)  # maxK x C, sorted descending per class
    values = {k: values[:min(k, maxK)].mean(dim=0, keepdim=True) for k in topK}  # dict of 1 x C logit scores
    preds = {key: val.argmax(dim=1) for key, val in values.items()}  # dict of predicted class indices
    return preds, values



## Inference on WSIs
Suppose a WSI has 10,000 patches, each 256 x 256. We first embed each patch using the DualEncoderModel's *encode_image* method. For large WSIs, due to GPU memory constraints, these embeddings are created in mini-batches and cached to local SSD storage before inference can be performed. We use an embedding dimension of 512 in our model; as a result, the preprocessed input to MI-Zero for this WSI is a tensor of size 10,000 x 512.

In the example below, for simplicity, we assume the WSI is very small and has only 64 patches; therefore, we can perform embedding and inference in a single step without caching the features.

In [None]:
# set device 
device = 'cpu'

# dimensionality of the shared image/text embedding space
emb_dim = 512

# class names substituted into every prompt template below
classnames = ['lung adenocarcinoma', 'lung squamous cell carcinoma']

# prompt templates to ensemble; 'CLASSNAME' is the placeholder that
# build_zero_shot_classifier replaces with each class name
templates = [
    'an image of CLASSNAME.',
    'CLASSNAME is present.',
    'a histopathological image showing CLASSNAME.',
    'presence of CLASSNAME.',
    'CLASSNAME.',
]

# spatial smoothing relies on faiss for its kNN search, so enable it only if
# faiss can be imported
try:
    import faiss
except ImportError:
    print('faiss installation not found, disabling spatial smoothing')
    ss = False
else:
    ss = True

In [None]:
# instantiate the illustrative dual encoder and move it to the target device
model = DummyDualEncoderModel(emb_dim=emb_dim).to(device)

# a toy WSI made of 64 three-channel patches, each 256 x 256 pixels
imgs = torch.rand(64, 3, 256, 256).to(device)

# random integer (x, y) location in [0, 1000) per patch; only used to build the
# kNN graph for spatial smoothing
coords = np.random.randint(0, 1000, size=(imgs.shape[0], 2))

# embed every patch: this bag of instance features has shape 64 x emb_dim
image_features = model.encode_image(imgs)

In [None]:
# load a pretrained tokenizer; for illustration, we use BioClinicalBERT's tokenizer.
# `use_fast` is the documented AutoTokenizer keyword for selecting the fast tokenizer.
tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', use_fast=True)
# build the zero-shot classifier (one unit-norm column per class) from the prompt ensemble
classifier = build_zero_shot_classifier(model, tokenizer, classnames, templates, device)

In [None]:
# L2-normalise the patch embeddings; the class embeddings are already unit-norm,
# so each entry of the product below is a cosine similarity
image_features = F.normalize(image_features, dim=-1) 
# (num_patches x emb_dim) @ (emb_dim x num_classes) -> per-patch class logits
logits = image_features @ classifier

print('classifier weights: ', tuple(classifier.size()))
print('image embeddings: ', tuple(image_features.size()))


# optionally average each patch's logits with those of its spatial neighbours
if ss:
    logits = spatial_smoothing(logits, coords)

# mean pooling: average the logits over all patches -> 1 x num_classes
meanpool_logits = logits.mean(dim=0, keepdim=True)

# topK pooling: topK_pooling returns (preds, values) dicts keyed by k;
# keep only the 1 x num_classes pooled scores for k = 10
k = 10
_, topkpool_logits = topK_pooling(logits, topK = (k,))
topkpool_logits = topkpool_logits[k]

print('\ninstance logits: ', tuple(logits.size()))
print('mean pooled logits: ', tuple(meanpool_logits.size()))
print('topk pooled logits: ', tuple(topkpool_logits.size()))