# VOLTA inference

This notebook demonstrates the steps needed to run inference with a pretrained VOLTA model.

First, load the pretrained weights into memory:

In [3]:
import torch

checkpoint_path = 'pretrained_weights/ovarian/checkpoint.pth.tar'
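# map_location="cpu" lets the checkpoint load on machines without a GPU.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True; if loading fails because the
# checkpoint stores non-tensor objects, pass weights_only=False (only for files you trust).
checkpoint = torch.load(checkpoint_path, map_location="cpu")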
state_dict = checkpoint['state_dict']
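The ovarian checkpoint used here loads with `strict=True` as-is. If you adapt this notebook to a checkpoint saved from a model wrapped in `torch.nn.DataParallel` or `DistributedDataParallel`, its keys typically carry a `module.` prefix, and strict loading will reject them. The following is a minimal sketch of a prefix-stripping helper; `strip_prefix` and the toy key names are hypothetical and not part of the VOLTA code.

```python
# Hypothetical helper: strip a wrapper prefix (e.g. "module." added by
# DataParallel / DistributedDataParallel) from checkpoint keys.
def strip_prefix(state_dict, prefix="module."):
    """Return a new dict whose keys have `prefix` removed where present."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy stand-in for checkpoint['state_dict']; real values would be tensors.
toy_state_dict = {"module.encoder.conv1.weight": 0, "module.encoder.fc.bias": 1}
clean = strip_prefix(toy_state_dict)
print(sorted(clean))  # ['encoder.conv1.weight', 'encoder.fc.bias']
```

Keys without the prefix pass through unchanged, so the helper is safe to apply to checkpoints that were never wrapped.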

Next, construct the model, load the weights into it, and switch it to evaluation mode. Note that the model configuration below is the same for all of our pretrained models across all datasets.

In [8]:
import backbones
import moco.mocov3.builder

model = moco.mocov3.builder.MoCoV3(
    base_encoder=backbones.__dict__["preact_resnet18"], 
    dim=64, 
    m=0.999, 
    mlp=[128], 
    prediction_head=32, 
    mlp_embedding=False, 
    spectral_normalization=False, 
    queue_size=65536, 
    teacher=True)

print('loading model message: ', model.load_state_dict(state_dict, strict=True))
model = model.eval()  # inference behaviour for layers such as batch normalization and dropout

loading model message:  <All keys matched successfully>
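The `<All keys matched successfully>` message means that, with `strict=True`, no parameter or buffer name was missing from or surplus to the checkpoint. A minimal sketch of that key comparison in plain Python follows; `compare_keys` and the toy names are hypothetical, and the real `load_state_dict` additionally checks tensor shapes, which this sketch does not.

```python
# Hypothetical sketch of the key comparison that strict loading performs:
# every key the model expects must be in the checkpoint, and vice versa.
def compare_keys(model_keys, checkpoint_keys):
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent in checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)  # present in checkpoint, unused by the model
    return missing, unexpected


missing, unexpected = compare_keys(
    ["encoder.conv1.weight", "encoder.fc.bias"],
    ["encoder.conv1.weight", "predictor.0.weight"],
)
print("missing:", missing)        # missing: ['encoder.fc.bias']
print("unexpected:", unexpected)  # unexpected: ['predictor.0.weight']
```

With the real objects, `compare_keys(model.state_dict().keys(), state_dict.keys())` should return two empty lists, consistent with the message above.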


Now, prepare the transformation pipeline. It consists of three steps:
1. Resize the cell to the model's input size ($32\times32$ for all of our models).
2. Normalize the cell with the normalization parameters of the dataset the model was trained on (here, the ovarian dataset).
3. Convert the cell image (a NumPy array) to a PyTorch tensor.

In [13]:
import albumentations
import albumentations.pytorch
import cv2
from dataset.ovarian.transform import get_cell_normalization # Note: use the same normalization as the training data of the pretrained model

transforms = albumentations.Compose([
    albumentations.Resize(height=32, width=32, interpolation=cv2.INTER_CUBIC), # step 1: resize
    get_cell_normalization(), # step 2: normalization
    albumentations.pytorch.ToTensorV2(transpose_mask=True), # step 3: to tensor
    ])
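The three steps can be mirrored in plain NumPy to make each one visible. This is a minimal sketch, not the notebook's pipeline: it swaps bicubic interpolation for nearest-neighbour resizing, and the `MEAN`/`STD` values are placeholders because the actual ovarian statistics returned by `get_cell_normalization()` are not shown here. It also assumes that function behaves like `albumentations.Normalize` with its default `max_pixel_value=255`, i.e. it computes `(pixel / 255 - mean) / std` per channel.

```python
import numpy as np

# Placeholder statistics (assumption): the real per-channel values come from
# get_cell_normalization() for the ovarian dataset and are NOT reproduced here.
MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
STD = np.array([0.25, 0.25, 0.25], dtype=np.float32)


def resize_nearest(img, height, width):
    """Nearest-neighbour resize of an H x W x C array (stand-in for INTER_CUBIC)."""
    rows = np.arange(height) * img.shape[0] // height
    cols = np.arange(width) * img.shape[1] // width
    return img[rows[:, None], cols[None, :]]


def preprocess(img, size=32):
    """Resize, normalize, and move channels first, mirroring the three steps above."""
    img = resize_nearest(img, size, size)                 # step 1: resize
    img = img.astype(np.float32) / 255.0                  # scale uint8 pixels to [0, 1]
    img = (img - MEAN) / STD                              # step 2: per-channel normalization
    return np.ascontiguousarray(img.transpose(2, 0, 1))   # step 3: HWC -> CHW, as ToTensorV2 does


cell = np.random.randint(0, 256, size=(50, 40, 3), dtype=np.uint8)
x = preprocess(cell)
print(x.shape, x.dtype)  # (3, 32, 32) float32
```

The only difference from what the model sees is the array type: `ToTensorV2` returns a `torch.Tensor` with the same channels-first layout.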


In [15]:
from PIL import Image
import numpy as np

source_image_path = 'examples/example1.png'

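# read the image with PIL, force 3-channel RGB (PNGs may carry an alpha channel), and convert to a numpy array
img = np.array(Image.open(source_image_path).convert('RGB'))

# resize, normalize, and convert to tensor
img = transforms(image=img)['image']

# pass the image through the model: unsqueeze(0) adds a batch dimension,
# and torch.no_grad() disables gradient tracking, which inference does not need
with torch.no_grad():
    embedding = model(img.unsqueeze(0))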

# print the shape of the embedding
print('final embedding shape: ', embedding.shape)

final embedding shape:  torch.Size([1, 512])
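Two details in the cell above are easy to get wrong: PNG files are often RGBA or greyscale, while the normalization expects three channels, and the model expects a leading batch axis. A minimal NumPy sketch of both follows; `ensure_rgb` is a hypothetical helper, and with PIL, `Image.open(path).convert("RGB")` does the same channel conversion (and also handles palette images).

```python
import numpy as np


def ensure_rgb(img):
    """Return an H x W x 3 array from a greyscale, RGB, or RGBA array (hypothetical helper)."""
    if img.ndim == 2:                  # greyscale: replicate into 3 channels
        return np.stack([img] * 3, axis=-1)
    if img.shape[-1] == 4:             # RGBA: drop the alpha channel
        return img[..., :3]
    return img                         # already RGB


rgba_cell = np.zeros((32, 32, 4), dtype=np.uint8)
chw = ensure_rgb(rgba_cell).transpose(2, 0, 1)  # (3, 32, 32), the layout ToTensorV2 produces
batch = chw[np.newaxis]                         # (1, 3, 32, 32), like tensor.unsqueeze(0)
print(batch.shape)  # (1, 3, 32, 32)
```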


In [19]:
import os
from PIL import Image
import numpy as np


# get the list of all .png images in the examples folder
# (os.listdir returns files in arbitrary order; wrap it in sorted() for a reproducible order)
source_image_paths = [os.path.join('examples', f) for f in os.listdir('examples') if f.endswith('.png')]

print(f'found {len(source_image_paths)} images: {source_image_paths}')

# read each image with PIL as 3-channel RGB and convert it to a numpy array
imgs = [np.array(Image.open(f).convert('RGB')) for f in source_image_paths]

# resize, normalize, convert to tensor, and stack into a single batch
batch = torch.stack([transforms(image=i)['image'] for i in imgs])

# pass the batch through the model without tracking gradients
with torch.no_grad():
    embedding = model(batch)

# print the shape of the embedding
print('final embedding shape: ', embedding.shape)

found 3 images: ['examples/example3.png', 'examples/example2.png', 'examples/example1.png']
final embedding shape:  torch.Size([3, 512])
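For a folder with many cells, stacking every image into one tensor may not fit in memory. A minimal sketch of a reproducible file listing plus fixed-size batching, using only the standard library, follows; `list_pngs` and `batched` are hypothetical helpers, not part of the VOLTA code, and the path list in the example is a toy stand-in.

```python
import os
from itertools import islice


def list_pngs(folder):
    """Return the .png paths in `folder`, sorted so the order is reproducible."""
    return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".png"))


def batched(items, batch_size):
    """Yield consecutive lists of at most `batch_size` items."""
    it = iter(items)
    while chunk := list(islice(it, batch_size)):
        yield chunk


paths = [f"examples/example{i}.png" for i in range(1, 6)]  # toy path list
print([len(b) for b in batched(paths, 2)])  # [2, 2, 1]
```

Each chunk would then go through the same read, transform, stack, and `model(...)` steps as in the cell above, with the per-batch embeddings joined afterwards via `torch.cat`.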
