## Exploring segmentation capabilities of DINO features on ultrasound

In [None]:
import torch
from torchvision.transforms import Normalize

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('animation', html='jshtml')
import pandas as pd
from tqdm import tqdm

from sklearn.decomposition import PCA

from model_lora_vit import get_vit, load_lora_vit_from_dino_ckpt
from data_transforms import get_deterministic_transform
from dataloader_tmed import TMED2

# Cache torch.hub downloads under ../pretrained_weights instead of the default hub directory.
torch.hub.set_dir("../pretrained_weights")

In [None]:
# configure the GPU: CUDA device index 1 when CUDA is available, otherwise the CPU
# NOTE(review): index 1 assumes at least two visible GPUs — confirm for your machine
device = "cpu"
if torch.cuda.is_available():
    device = 1

In [None]:
# load the backbone model, ensure params are consistent with ckpt
# experiment name -> (checkpoint path, LoRA rank); a None path means the
# default DINO weights returned by get_vit (the 'imagenet' experiment)
experiment = 'full'
experiment_configs = {
    'imagenet': (None, 0),
    'full': ('../logs/training_base/checkpoint.pth', 0),
    'lora4': ('../logs/training_1/checkpoint0009.pth', 4),
}
if experiment not in experiment_configs:
    raise ValueError(
        f"unknown experiment {experiment!r}; expected one of {sorted(experiment_configs)}"
    )
ckpt_path, lora_rank = experiment_configs[experiment]

arch = 'vit_small'
patch_size = 8
if ckpt_path is None:
    # load the default DINO model
    model = get_vit(arch, patch_size, lora_rank=0)
else:
    model = get_vit(arch, patch_size, lora_rank)
    load_lora_vit_from_dino_ckpt(model, ckpt_path)
model.to(device).eval()

In [None]:
# load the dataset: parasternal-only TMED2 splits with the 'tufts' label scheme,
# all sharing the transform from get_deterministic_transform()
transform = get_deterministic_transform()


def _tmed2_split_loader(split):
    """Build the TMED2 dataset for ``split`` and an unshuffled batch-8 DataLoader over it."""
    dataset = TMED2(
        split=split,  # train/val/test/all/unlabeled
        transform=transform,
        parasternal_only=True,
        label_scheme_name='tufts',
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
    return dataset, loader


tr_dataset, tr_dataloader = _tmed2_split_loader("train")
va_dataset, va_dataloader = _tmed2_split_loader("val")

### First take a look at the image pipeline 
DINO expects input batches of shape B x C x H x W.

Images are normalized with the ImageNet per-channel mean and standard deviation, the same statistics DINO was trained with.

When visualizing a batch, we have to undo this normalization and permute each image from C x H x W to H x W x C.

In [None]:
def unnormalize(im):
    """Reverse ImageNet/DINO normalization of an image.

    Assumes ``im`` is a floating-point tensor of shape 3 x H x W (extra leading
    dimensions broadcast) normalized as ``(x - mean) / std`` with the standard
    ImageNet statistics below. Presumably these match get_deterministic_transform
    — confirm.

    Returns ``im * std + mean`` as a NEW tensor; ``im`` is not modified. The
    previous in-place version overwrote the caller's tensor, e.g. a view into
    a batch.
    """
    # ImageNet channel statistics; note the third std is 0.225 (not 0.255)
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=im.dtype, device=im.device).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=im.dtype, device=im.device).view(-1, 1, 1)
    return im * std + mean

# return unnormalized frames for viewing
def unnormalize_batch_of_frames(imgs):
    """Return a detached copy of a batch of frames with normalization undone.

    Each frame is passed through unnormalize_frame and written into a cloned
    copy, so the input tensor itself is left untouched.
    """
    restored = imgs.detach().clone()
    for idx, frame in enumerate(restored):
        restored[idx] = unnormalize_frame(frame)
    return restored

def unnormalize_frame(img):
    """Reverse ImageNet/DINO normalization of a single frame.

    Assumes ``img`` is a floating-point 3 x H x W tensor (extra leading
    dimensions broadcast) normalized as ``(x - mean) / std`` with the standard
    ImageNet statistics below. Presumably these match get_deterministic_transform
    — confirm.

    Returns a new tensor ``img * std + mean``; ``img`` is not modified.
    """
    # ImageNet channel statistics; note the third std is 0.225 (not 0.255)
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img * std + mean

def show_batch_of_frames(imgs):
    """Draw every image of a B x C x H x W batch side by side in one row, without axis ticks."""
    n_imgs, _, _, _ = imgs.shape
    _, axs = plt.subplots(ncols=n_imgs, squeeze=False, figsize=(2 * n_imgs, 6))
    for ax, frame in zip(axs[0], imgs):
        # matplotlib wants H x W x C
        ax.imshow(frame.detach().cpu().numpy().transpose(1, 2, 0))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# visualize one batch
data_iter = iter(tr_dataloader)
x, [y, y_view] = next(data_iter)
print(
    f"target_AS: {y}\n"
    f"view: {y_view}\n"
    f"cine shape: {x.shape}"
)

# clone first: x[0] is a view into x, so an in-place unnormalize would corrupt
# the batch that is fed to the model further below
vis_cine = unnormalize(x[0].clone())
# clip to the valid float RGB range expected by imshow
plt.imshow(np.clip(vis_cine.cpu().numpy().transpose(1, 2, 0), 0, 1))
plt.show()
plt.close('all')

### Perform PCA on the DINO features
The DINO ViT produces one feature vector per image patch.

One useful way to visualize these patch features is to first reduce the 384-D features to 3-D with PCA, then display each of the three components as an RGB color channel.

This helps us humans see which patches are close to each other in feature space.

Sufficiently large differences between patch features suggest that two patches contain semantically different content. This gives an approximate "segmentation", which works well on natural images but does not necessarily transfer to medical images.

In [None]:
# compute patch embeddings for every frame in the batch x (used for the PCA plots below);
# no_grad avoids building an autograd graph we never backpropagate through
with torch.no_grad():
    features = model.get_intermediate_layers(x.to(device))[0]
print(features.shape)

In [None]:
# number of patches per side: the first token is treated as the CLS token and
# the remaining tokens must form a square H x H patch grid
n_patch_tokens = features.shape[1] - 1
H = int(np.sqrt(n_patch_tokens))
assert H * H == n_patch_tokens
print(H)

In [None]:
def plot_single_image_pca(features, patch_dim=16):
    """Show a per-image 3-component PCA of patch features as an RGB picture.

    Args:
        features: n_patches x D array of patch features for one image
            (CLS token already removed).
        patch_dim: side length of the square patch grid; must satisfy
            ``patch_dim**2 == n_patches``. Pass None to infer it from n_patches.
    """
    pca_features, _ = single_image_pca(features)
    if patch_dim is None:
        patch_dim = int(np.sqrt(pca_features.shape[0]))

    # rescale jointly over all three components to [0, 255]; a constant projection
    # would otherwise divide by zero and turn into garbage after the uint8 cast
    lo, hi = pca_features.min(), pca_features.max()
    span = hi - lo if hi > lo else 1.0
    pca_features = (pca_features - lo) / span * 255

    plt.imshow(pca_features.reshape(patch_dim, patch_dim, 3).astype(np.uint8))
    plt.show()

def single_image_pca(features):
    """Fit a 3-component PCA on ``features`` (n_samples x D) and project onto it.

    Returns ``(projected, pca)``: the n_samples x 3 projection and the fitted
    PCA object, kept so later ``.transform()`` calls reuse the same basis.
    """
    fitted = PCA(n_components=3).fit(features)
    return fitted.transform(features), fitted

# plot the per-image PCA of every frame in the batch
batch_features = features.detach().cpu().numpy()  # B x (1 + H*H) x D
for img_features in batch_features:
    patch_features = img_features[1:, :]  # drop token 0 (CLS), leaving H*H patch rows
    plot_single_image_pca(patch_features, H)

### Background removal
By thresholding the first principal component (at zero), we can separate the "foreground" of the image from the "background".

This is not always consistent: the sign of a principal component is arbitrary, so which side ends up labelled "background" can flip from image to image. Still, it generally segments the ultrasound beam area.

We can subsequently run a PCA on only the foreground features to visualize the similarity between foreground patches.

In [None]:
# background removal on the PCA of one image based on the first principal component
def single_image_background(features, threshold=0):
    """Return a boolean background mask for one image's patches.

    Fits a 3-component PCA on the patch features (n_patches x D) and marks a
    patch as background when its first principal component is <= ``threshold``
    (previously this argument was accepted but ignored, always using 0).
    Because a principal component's sign is arbitrary, which side is labelled
    background may differ between images.

    Returns:
        boolean array of shape (n_patches,), True for background patches.
    """
    pca_features, _ = single_image_pca(features)  # n_patches x 3
    return pca_features[:, 0] <= threshold
    
# show the background mask of every frame in the batch as an H x H image
for img_tokens in features[:, 1:, :]:  # token 0 (CLS) dropped
    patch_features = img_tokens.detach().cpu().numpy()
    mask = single_image_background(patch_features)
    plt.imshow(mask.reshape(H, H))
    plt.show()

In [None]:
# to summarize what we did:
# remove the background from each image, then compute one PCA over the leftover
# (foreground) patch features of the whole batch
def pca_based_part_segmentation(imgs, learned_pca_obj=None, animate=False):
    """Colour each foreground patch of a batch by a PCA shared across the batch.

    Background patches (per single_image_background) stay black. The remaining
    patch features of all images are pooled and projected to 3-D, and the
    projection is shown as RGB next to the un-normalized input image.

    Args:
        imgs: batch of normalized images, assumed B x C x H x W.
        learned_pca_obj: optional already-fitted PCA. When given, it is reused
            via ``.transform`` instead of fitting a new one, so colours stay
            comparable across calls (e.g. across frames of a video).
        animate: if True, build a matplotlib ArtistAnimation that steps through
            the batch instead of drawing one figure per image.

    Returns:
        dict with 'learned_pca_obj' (the PCA that was fitted or reused) and,
        when ``animate`` is True, 'animation'.

    Raises:
        ValueError: if every patch in the batch is classified as background.
    """
    # matplotlib.animation is not imported at file top; import it here so animate=True works
    from matplotlib import animation

    with torch.no_grad():
        features = model.get_intermediate_layers(imgs.to(device))[0]  # B x (1 + H*H) x D
    # drop token 0 (CLS) so every index below is a patch index
    patch_features = features[:, 1:, :].detach().cpu().numpy()
    B, n_patches, _ = patch_features.shape

    # the patches form a square H x H grid
    H = int(np.sqrt(n_patches))
    assert H**2 == n_patches

    # B x n_patches boolean mask, True for foreground (non-background) patches
    foreground = ~np.stack([single_image_background(patch_features[i]) for i in range(B)])

    # Pool the foreground patches of all images in (image, patch) row-major order.
    # Index patch_features (not the CLS-prefixed features) so each row is the
    # patch the mask refers to.
    non_background_features = patch_features[foreground]  # N' x D, N' <= B * H**2
    if non_background_features.shape[0] == 0:
        raise ValueError("every patch was classified as background; nothing to run PCA on")

    if learned_pca_obj is None:
        # fit a new PCA on this batch's foreground features
        nb_pca_features, pca_obj = single_image_pca(non_background_features)  # N' x 3
    else:
        # reuse the previously learned PCA basis
        pca_obj = learned_pca_obj
        nb_pca_features = pca_obj.transform(non_background_features)

    # rescale jointly to [0, 255] for display; guard against a constant projection
    lo, hi = nb_pca_features.min(), nb_pca_features.max()
    span = hi - lo if hi > lo else 1.0
    nb_pca_features = (nb_pca_features - lo) / span * 255

    # scatter the colours back into per-image patch grids (background stays 0 = black);
    # boolean-mask assignment fills in the same order as the gather above
    new_images = np.zeros((B, n_patches, 3))
    new_images[foreground] = nb_pca_features

    # show the final results
    ret = {'learned_pca_obj': pca_obj}
    # .cpu() so this also works when imgs lives on the GPU
    vis_imgs = np.clip(unnormalize_batch_of_frames(imgs).cpu().numpy(), 0, 1)
    if animate:
        frames = []
        fig, axs = plt.subplots(ncols=2, squeeze=False, figsize=(6, 4))
        for i in range(B):
            left = axs[0, 0].imshow(new_images[i].reshape(H, H, 3).astype(np.uint8))
            right = axs[0, 1].imshow(vis_imgs[i].transpose(1, 2, 0))
            frames.append([left, right])
        ret['animation'] = animation.ArtistAnimation(fig, frames)
        plt.close('all')
    else:
        for i in range(B):
            fig, axs = plt.subplots(ncols=2, squeeze=False, figsize=(6, 12))
            axs[0, 0].imshow(new_images[i].reshape(H, H, 3).astype(np.uint8))
            axs[0, 1].imshow(vis_imgs[i].transpose(1, 2, 0))
            plt.show()

    # the PCA is returned so it can be transferred to video frames
    return ret

learned_pca_obj = pca_based_part_segmentation(x)['learned_pca_obj']

In [None]:
# # use the PCA features for video
# def save_gif(name, ani):
#     # save the gif if needed
#     writergif = animation.PillowWriter(fps=3)
#     ani.save(name +'.gif',writer=writergif) 
#     print("GIF saved at " + name + ".gif")

# for i in range(len(sample_dict['cine'])):
#     video = sample_dict['cine'][i, :, :, :, :] # BxCxTxHxW
#     video = video.permute(1, 0, 2, 3)
#     ani = pca_based_part_segmentation(video, learned_pca_obj, animate=True)['animation']
#     #save_gif(str(i), ani)