In [None]:
import numpy as np
import torch
import os
import skimage
import matplotlib.pyplot as plt
from torchvision.transforms import v2

In [None]:
# Hard-coded absolute path to one example PNG; adjust it for your environment.
img_path = "/scr/data/LINCS/DP-project/outputs/max_concentration_set/SQ00015128/A07/8/1@516.1053445229701x147.44633392226152.png"
# Read the file into an array.
# NOTE(review): presumably the channels are tiled side by side along the width
# (that is the layout fold_channels expects) — confirm against img.shape.
img = skimage.io.imread(img_path)

In [None]:
# Show the image exactly as loaded, before any channel folding.
plt.imshow(img)

In [None]:
def fold_channels(image, channel_width, mode="ignore"):
    """Fold a horizontally tiled multi-channel image into an (h, w, c) array.

    The input stores its channels side by side along the width axis, so
    channel ``k`` occupies columns ``k * channel_width`` up to (but not
    including) ``(k + 1) * channel_width``.

    Args:
        image: 2-D array of shape ``(h, channel_width * c)``.
        channel_width: Width in pixels of a single channel tile.
        mode: How to treat the last channel, which is regarded as a mask.
            ``"ignore"`` keeps every channel. ``"drop"`` removes the last
            channel. ``"apply"`` removes the last channel and zeroes the
            remaining channels wherever the mask is zero. Any other value
            behaves like ``"ignore"``.

    Returns:
        Array of shape ``(h, channel_width, c)`` for ``"ignore"`` and
        ``(h, channel_width, c - 1)`` for ``"drop"`` and ``"apply"``.

    Raises:
        ValueError: If the image width is not a multiple of ``channel_width``
            (raised by ``np.reshape``).
    """
    # Fortran order makes the column index vary faster than the channel index,
    # which matches the side-by-side tile layout of the input.
    output = np.reshape(image, (image.shape[0], channel_width, -1), order="F")

    if mode == "drop":
        # Drop mask channel (last)
        output = output[:, :, :-1]
    elif mode == "apply":
        # Binarize the mask before multiplying: a 0/255 mask multiplied into
        # uint8 pixels would wrap around instead of keeping the pixel value.
        # The trailing singleton axis broadcasts the mask over all channels.
        mask = output[:, :, -1:] != 0
        output = output[:, :, :-1] * mask

    return output

In [None]:
# Fold the tiled image into (h, w, c). The image height is passed as the
# channel width, i.e. every channel tile is treated as square.
fold = fold_channels(img, img.shape[0])

In [None]:
# Display the first three folded channels together as one RGB composite.
plt.imshow(fold[:,:,0:3])

In [None]:
def channel_to_rgb(channel):
    """Turn one grayscale channel into an ImageNet-normalized RGB batch.

    The channel is replicated three times so that it can be fed to a network
    that expects RGB input, then normalized with the ImageNet mean and
    standard deviation.

    Args:
        channel: 2-D array of shape ``(h, w)``. Integer data is rescaled to
            ``[0, 1]`` by dividing by the maximum value of its dtype; other
            dtypes are only cast to float32 and are assumed to already be
            in ``[0, 1]``.

    Returns:
        Float32 tensor of shape ``(1, 3, h, w)``.
    """
    channel = np.asarray(channel)
    if np.issubdtype(channel.dtype, np.integer):
        # The ImageNet statistics below are defined for intensities in [0, 1];
        # raw integer pixels (e.g. 0..255) must be rescaled first.
        channel = channel.astype(np.float32) / np.iinfo(channel.dtype).max
    else:
        channel = channel.astype(np.float32)

    # Replicate the single channel along a new leading axis: (h, w) -> (3, h, w).
    px = np.repeat(channel[np.newaxis, :, :], 3, axis=0)
    # Add the batch axis: (3, h, w) -> (1, 3, h, w).
    tensor = torch.from_numpy(px)[None, ...]
    normalized_tensor = v2.functional.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return normalized_tensor

# Load the ViT model

In [None]:
# Load model
# Index of the CUDA device to use; falls back to the CPU when CUDA is unavailable.
# NOTE(review): the index is hard-coded — confirm the machine has that many GPUs.
gpu = 5
device = f"cuda:{gpu}" if torch.cuda.is_available() else 'cpu'

# Fetch the DINOv2 ViT-S/14 (with registers) entry point through torch.hub.
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
# Put the module in evaluation mode before running inference.
dinov2_vits14_reg.eval()
# Move the model to the selected device.
dinov2_vits14_reg.to(device)

In [None]:
# Compare the array shapes before and after folding.
img.shape, fold.shape

In [None]:
# Convert each of the first 5 folded channels into a normalized pseudo-RGB
# tensor and concatenate them along the batch axis.
# NOTE(review): the 17-pixel border crop presumably makes the spatial size a
# multiple of the ViT patch size (14) — confirm against fold.shape.
# NOTE(review): range(5) assumes at least 5 channels of interest — confirm.
image_batch = torch.cat([channel_to_rgb(fold[17:-17,17:-17,i]) for i in range(5)])
image_batch.shape

In [None]:
# Inference only: disable autograd so that no computation graph or
# intermediate activations are retained for this forward pass.
with torch.no_grad():
    output = dinov2_vits14_reg.forward_features(image_batch.to(device))

In [None]:
# Take the entry stored under "x_norm_clstoken", copy it to host memory and
# convert it to a NumPy array.
# NOTE(review): by its name this is the normalized CLS-token embedding —
# confirm against the DINOv2 source.
features = output["x_norm_clstoken"].cpu().detach().numpy()

In [None]:
# Inspect the dimensions of the extracted feature array.
features.shape

In [None]:
# Write the features to a compressed .npz archive in the current working
# directory. The array is passed positionally, so NumPy stores it under the
# default key "arr_0".
np.savez_compressed("features.npz", features)