In [21]:
import os
import torch.utils.data
import numpy as np
import pandas as pd

import utils
import vision_transformer as vits
from constants import Column
import augmentations as aug

In [3]:
# Inference configuration, grouped by what each value feeds into.

# -- backbone definition --
arch = 'vit_small'   # name used to look up the model constructor in vision_transformer
patch_size = 16      # forwarded to the model constructor as `patch_size`
in_channels = 4      # forwarded to the model constructor as `in_chans`

# -- checkpoint to restore --
pretrained_weights = './checkpoints/pretrained_cross_batch_n16_guide0_ntc_norm.ckpt'
checkpoint_key = 'teacher'   # handed to utils.load_pretrained_weights together with the path

# -- feature extraction --
n_last_blocks = 4             # number of trailing blocks whose index-0 token is concatenated
avgpool_patchtokens = False   # when True, the mean of the last block's remaining tokens is appended

In [4]:
# Build the backbone: the constructor is looked up by name in the
# vision_transformer module and called with the configured geometry.
model = vars(vits)[arch](
    in_chans=in_channels,
    num_classes=0,
    patch_size=patch_size,
)

# Expected width of the extracted feature vector: one `model.embed_dim`-sized
# chunk per selected block, plus one more when averaged patch tokens are appended.
embed_dim = (n_last_blocks + int(avgpool_patchtokens)) * model.embed_dim

# Prepare for inference: evaluation mode, on the GPU when one is present.
model.eval()
if torch.cuda.is_available():
    model.cuda()

# Load the checkpoint into the model; all arguments are forwarded unchanged
# to the project helper.
utils.load_pretrained_weights(model, pretrained_weights, checkpoint_key, arch, patch_size)

Pretrained weights found at ./checkpoints/pretrained_cross_batch_n16_guide0_ntc_norm.ckpt and loaded with msg: <All keys matched successfully>


### Run inference on example images

In [24]:
# Load every example image, normalize it, and stack the results into one batch.
image_dir = './example_data/images'
npy_suffix = '.npy'

# os.listdir returns entries in arbitrary order and may contain entries that
# are not images, so keep only the .npy files and sort them so that the row
# order of the resulting features is reproducible across runs and machines.
files = sorted(name for name in os.listdir(image_dir) if name.endswith(npy_suffix))
if not files:
    raise FileNotFoundError(f'No {npy_suffix} files found in {image_dir}')

# here we used zscore as an example
# we suggest using the statistics calculated from NTCs in practice
normalizer = aug.Normalization(method='zscore')

image_list = []
keys = []
for file_name in files:
    # The key is the file name with only its trailing .npy extension removed
    keys.append(file_name[:-len(npy_suffix)])
    image = np.load(os.path.join(image_dir, file_name))
    image_list.append(normalizer(image))

# Stack along a new leading axis; float32 matches what torch.Tensor(...) produced
images = torch.as_tensor(np.stack(image_list, axis=0), dtype=torch.float32)



In [30]:
# Extract features for the stacked example images and collect them, together
# with the metadata encoded in each key, in a DataFrame.
with torch.no_grad():
    # Move the batch to the GPU when one is available. It is named `batch`
    # so that the built-in input() is not shadowed.
    batch = images.cuda(non_blocking=True) if torch.cuda.is_available() else images

    if "vit" in arch:
        intermediate_output = model.get_intermediate_layers(batch, n_last_blocks)
        # Concatenate the index-0 token of every returned block output
        output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        if avgpool_patchtokens:
            # Append the mean over the remaining tokens of the last block.
            # NOTE(review): the unsqueeze/cat pairing needs both tensors to have
            # the same feature width, which looks like it only holds when
            # n_last_blocks == 1 — confirm before enabling with more blocks.
            patch_mean = torch.mean(intermediate_output[-1][:, 1:], dim=1)
            output = torch.cat((output.unsqueeze(-1), patch_mean.unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)
    else:
        output = model(batch)

    # .cpu() returns the tensor itself when it already lives on the CPU, so a
    # separate CUDA/CPU branch is unnecessary.
    output = output.cpu().numpy()

embed_dim = output.shape[-1]
output_df = pd.DataFrame(output, columns=[f'feature_{idx}' for idx in range(embed_dim)])

# Add metadata
output_df['key'] = keys
meta_columns = [Column.plate.value, Column.well.value, Column.tile.value, Column.gene.value,
                Column.sgRNA.value, 'meta_df_index']
# Split the ';'-separated key once, column-wise, instead of calling a Python
# lambda for every row of a frame that also carries all feature columns.
output_df[meta_columns] = output_df['key'].astype(str).str.split(';', expand=True)


### Run inference on example LMDB dataset

In [44]:
# Project-local module; it provides the InferenceDataset used in the next cell.
import dataset

# Re-execute the module so the kernel uses the current contents of dataset.py.
# The reload call is the last expression of the cell, so the notebook displays
# the reloaded module object as the cell output.
import importlib
importlib.reload(dataset)

<module 'dataset' from '/home/yaoh11/cellpaint/set-dino/dataset.py'>

In [45]:
# Inputs for the LMDB example: metadata table, dataset location and crop size
df = pd.read_csv('./example_data/metadata.csv', index_col=0)
dataset_path = './example_data/lmdb_dataset'
crop_size = 96

# Build the inference dataset with z-score normalization
normalizer = aug.Normalization(method='zscore')
ds = dataset.InferenceDataset(
    df,
    normalizer=normalizer,
    crop_size=crop_size,
    dataset_path=dataset_path,
)

# Iterate in the original order, 8 samples at a time, loading in the main
# process and keeping the final partial batch
dataloader = torch.utils.data.DataLoader(
    ds,
    drop_last=False,
    num_workers=0,
    shuffle=False,
    batch_size=8,
)



In [46]:
# Run the model over every batch of the data loader and collect the features,
# together with the metadata encoded in each key, in a DataFrame.
output_list = []
key_list = []
use_cuda = torch.cuda.is_available()  # loop-invariant, so query it once
with torch.no_grad():
    # Loop variables carry a `batch_` prefix so that the notebook-level
    # `images` / `keys` from the example-image cells are not overwritten, and
    # `batch` is used instead of `input` to avoid shadowing the built-in.
    for batch_images, batch_keys in dataloader:
        batch = batch_images.cuda(non_blocking=True) if use_cuda else batch_images

        if "vit" in arch:
            intermediate_output = model.get_intermediate_layers(batch, n_last_blocks)
            # Concatenate the index-0 token of every returned block output
            batch_output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if avgpool_patchtokens:
                # Append the mean over the remaining tokens of the last block.
                # NOTE(review): the unsqueeze/cat pairing needs both tensors to
                # have the same feature width, which looks like it only holds
                # when n_last_blocks == 1 — confirm before enabling with more blocks.
                patch_mean = torch.mean(intermediate_output[-1][:, 1:], dim=1)
                batch_output = torch.cat((batch_output.unsqueeze(-1),
                                          patch_mean.unsqueeze(-1)), dim=-1)
                batch_output = batch_output.reshape(batch_output.shape[0], -1)
        else:
            batch_output = model(batch)

        # .cpu() returns the tensor itself when it already lives on the CPU,
        # so a separate CUDA/CPU branch is unnecessary.
        output_list.append(batch_output.cpu().numpy())
        key_list.extend(batch_keys)

if not output_list:
    raise ValueError('The data loader yielded no batches; there is nothing to embed.')

output_all = np.concatenate(output_list, axis=0)
# Take the width from the concatenated array rather than from whichever
# per-batch variable happens to be left over after the loop.
embed_dim = output_all.shape[-1]
output_df = pd.DataFrame(output_all, columns=[f'feature_{idx}' for idx in range(embed_dim)])

# Add metadata
output_df['key'] = key_list
meta_columns = [Column.plate.value, Column.well.value, Column.tile.value, Column.gene.value,
                Column.sgRNA.value, 'meta_df_index']
# Split the ';'-separated key once, column-wise, instead of calling a Python
# lambda for every row of a frame that also carries all feature columns.
output_df[meta_columns] = output_df['key'].astype(str).str.split(';', expand=True)