In [1]:
import os
import sys
import shutil
import argparse
import math
import IPython 
from PIL import Image
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

import h5py
import logging
import torch
import torch.nn as nn
import torchio as tio
import torchvision
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import pandas as pd
import skimage
from scipy import sparse
import scipy.io as sio
import matplotlib.pyplot as plt 
import torchxrayvision as xrv
import nibabel as nib
from typing import Any, Dict, Optional
from collections import OrderedDict
from monai.losses.dice import DiceLoss, DiceCELoss

from torchmetrics import Metric, MetricCollection
from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics.classification import (MultilabelAUROC, MultilabelF1Score, MultilabelAccuracy, MulticlassF1Score, 
                                        MulticlassAccuracy, MulticlassAUROC, Accuracy, BinaryF1Score, BinaryAUROC,
                                        JaccardIndex, MulticlassJaccardIndex, Dice, BinaryAUROC)

from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
import dinov2.distributed as distributed
from dinov2.models.unet import UNet
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.datasets import NIHChestXray, MC, Shenzhen, SARSCoV2CT, BTCV, BTCVSlice, AMOS, MSDHeart
from dinov2.data.datasets.medical_dataset import MedicalVisionDataset
from dinov2.data.loaders import make_data_loader
from dinov2.data.transforms import (make_segmentation_train_transforms, make_classification_eval_transform, make_segmentation_eval_transforms,
                                    make_classification_train_transform)
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import (is_padded_matrix, ModelWithIntermediateLayers, ModelWithNormalize, evaluate, extract_features, collate_fn_3d,
                               make_datasets, make_data_loaders)
from dinov2.eval.classification.utils import LinearClassifier, create_linear_input, setup_linear_classifiers, AllClassifiers
from dinov2.eval.metrics import build_segmentation_metrics, MetricAveraging, MetricType
from dinov2.eval.segmentation.utils import LinearDecoder, setup_decoders, DINOV2Encoder
from dinov2.utils import show_image_from_tensor

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'dinov2'

In [4]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

patch_h = 40
patch_w = 40
# feat_dim = 384 # vits14
# feat_dim = 768 # vitb14
feat_dim = 1024 # vitl14
# feat_dim = 1536 # vitg14

transform = T.Compose([
    # T.GaussianBlur(9, sigma=(0.1, 2.0)),
    T.Resize((patch_h * 14, patch_w * 14)),
    T.CenterCrop((patch_h * 14, patch_w * 14)),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14').cuda()
# dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# extract features
features = torch.zeros(4, patch_h * patch_w, feat_dim)
imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14)
for i in range(1, 5):
    img_path = f'nih_images/img_{i+1}.png'
    img = Image.open(img_path)
    img = torch.from_numpy(np.stack((img,)*3, axis=-1))
    print(img.shape)
    imgs_tensor[i] = transform(img)[:3]
with torch.no_grad():
    features_dict = dinov2_vitl14.forward_features(imgs_tensor)
    features = features_dict['x_norm_patchtokens']

# PCA for feature inferred
from sklearn.decomposition import PCA

features = features.reshape(4 * patch_h * patch_w, feat_dim)

pca = PCA(n_components=3)
pca.fit(features)
pca_features = pca.transform(features)

# visualize PCA components for finding a proper threshold
plt.subplot(1, 3, 1)
plt.hist(pca_features[:, 0])
plt.subplot(1, 3, 2)
plt.hist(pca_features[:, 1])
plt.subplot(1, 3, 3)
plt.hist(pca_features[:, 2])
plt.show()
plt.close()

# uncomment below to plot the first pca component
# pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 0].min()) / (pca_features[:, 0].max() - pca_features[:, 0].min())
# for i in range(4):
#     plt.subplot(2, 2, i+1)
#     plt.imshow(pca_features[i * patch_h * patch_w: (i+1) * patch_h * patch_w, 0].reshape(patch_h, patch_w))
# plt.show()
# plt.close()

# segment using the first component
pca_features_bg = pca_features[:, 0] < 10
pca_features_fg = ~pca_features_bg

# plot the pca_features_bg
for i in range(4):
    plt.subplot(2, 2, i+1)
    plt.imshow(pca_features_bg[i * patch_h * patch_w: (i+1) * patch_h * patch_w].reshape(patch_h, patch_w))
plt.show()

# PCA for only foreground patches
pca.fit(features[pca_features_fg]) # NOTE: I forgot to add it in my original answer
pca_features_rem = pca.transform(features[pca_features_fg])
for i in range(3):
    # pca_features_rem[:, i] = (pca_features_rem[:, i] - pca_features_rem[:, i].min()) / (pca_features_rem[:, i].max() - pca_features_rem[:, i].min())
    # transform using mean and std, I personally found this transformation gives a better visualization
    pca_features_rem[:, i] = (pca_features_rem[:, i] - pca_features_rem[:, i].mean()) / (pca_features_rem[:, i].std() ** 2) + 0.5

pca_features_rgb = pca_features.copy()
pca_features_rgb[pca_features_bg] = 0
pca_features_rgb[pca_features_fg] = pca_features_rem

pca_features_rgb = pca_features_rgb.reshape(4, patch_h, patch_w, 3)
for i in range(4):
    plt.subplot(2, 2, i+1)
    plt.imshow(pca_features_rgb[i][..., ::-1])
plt.savefig('features.png')
plt.show()
plt.close()

torch.Size([1024, 1024, 3])


