# **Libraries / path definition**

In [None]:
!pip install peft
!pip install -i https://pypi.org/simple/ bitsandbytes -U
!pip install git+https://github.com/huggingface/transformers

In [None]:
import sys
import os
from google.colab import drive
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import AutoModel
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig

In [None]:
# Project root on Google Drive. Assumes Drive is already mounted at
# /content/gdrive (e.g. via drive.mount('/content/gdrive')).
project_dir = '/content/gdrive/MyDrive/'

# Dataset / sample to analyse. Images and masks are TIFF files (or any other
# format PIL can read). Expected directory structure under data_directory:
# data
# |_<dataset_name>
#    |_<sample_name>
#       |_images
#       |  |_image1.tiff
#       |  |_image2.tiff
#       |  |_...
#       |_masks
#          |_mask1.tiff
#          |_mask2.tiff
#          |_...
# Images and masks are paired by sorted filename, so both folders must hold
# the same number of files.
dataset_name = "Alhammadi"
sample_name = "sample3"

# Number of segmentation classes.
num_classes = 3

# Side length (pixels) the center crop is resized to before feature extraction.
# Should be a multiple of the DINOv2 patch size (14): 560 -> 40 x 40 patches.
crop_size = 560

weights_directory = os.path.join(project_dir, 'runs', 'weights_folder')
model_name = "model_best_iou.pth"

# Embedding size of the chosen DINOv2 backbone.
# feat_dim = 384  # vits14
feat_dim = 768  # vitb14
# feat_dim = 1024  # vitl14
# feat_dim = 1536  # vitg14

# Number of layers of the DINOv2 backbone concatenated before passing the
# features to the head.
n_features = 1

# Normalization statistics rescaled to [0, 1]: one value (0-255 scale) repeated
# for the three RGB channels -- presumably computed on grayscale intensities.
# Adapt to your dataset.
mean = np.array([123.07921846875976] * 3) / 255.0
std = np.array([84.04993142526148] * 3) / 255.0

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

########

# Make the DinoV2 code directory importable.
sys.path.append(os.path.join(project_dir, "code/DinoV2/"))

# Root folder of all datasets inside the project directory.
data_directory = os.path.join(project_dir, 'data')

# **Model definition and helper functions**




In [None]:
# Folder of the sample to analyse: <data_directory>/<dataset_name>/<sample_name>.
# It is expected to contain the 'images' and 'masks' subfolders.
input_directory = os.path.join(
    data_directory,
    dataset_name,
    sample_name,
)

class DinoV2Segmentor(nn.Module):
    """Hugging Face DINOv2 backbone wrapped for patch-feature extraction.

    The backbone ``facebook/dinov2-<size>`` can optionally be loaded in 4-bit
    NF4 (bitsandbytes) and/or wrapped with LoRA adapters (PEFT), so that a
    checkpoint fine-tuned with the same configuration can be loaded into it.

    This class builds no segmentation head: ``forward`` returns the backbone's
    patch tokens. ``num_classes`` and ``n_features`` are only stored, and
    ``head_type`` is only used for argument validation.
    """

    # Hidden size of each DINOv2 variant, keyed by the Hugging Face model suffix.
    emb_size = {
        "small": 384,
        "base": 768,
        "large": 1024,
        "giant": 1536,
    }

    def __init__(self, num_classes, size="base", n_features=1, peft=False, quantize=False, head_type="linear"):
        """Load the backbone.

        Args:
            num_classes: number of segmentation classes (stored only).
            size: DINOv2 variant, one of the keys of ``emb_size``.
            n_features: number of backbone layers meant to be concatenated
                (stored only; ``forward`` always uses the last layer).
            peft: wrap the backbone with LoRA adapters on all linear layers.
            quantize: load the backbone in 4-bit NF4 with bfloat16 compute.
            head_type: head type name; only used to reject the unsupported
                combination of ``n_features > 1`` with a "cnn" head.

        Raises:
            ValueError: for an unknown ``size`` or an unsupported
                ``n_features`` / ``head_type`` combination.
        """
        super().__init__()
        # Explicit raise instead of assert so the check survives `python -O`.
        if size not in self.emb_size:
            raise ValueError(f"Invalid size {size!r}; expected one of {sorted(self.emb_size)}")
        if n_features > 1 and head_type == "cnn":
            raise ValueError("Multi feature concatenation with cnn head is not supported currently, feel free to customize the code if required ;)")
        self.num_classes = num_classes
        self.n_features = n_features
        self.peft = peft
        self.embedding_size = self.emb_size[size]

        model_id = f'facebook/dinov2-{size}'
        if quantize:
            self.quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.backbone = AutoModel.from_pretrained(model_id, quantization_config=self.quantization_config)
            # PEFT helper that readies a k-bit quantized model for adapter training.
            self.backbone = prepare_model_for_kbit_training(self.backbone)
        else:
            self.backbone = AutoModel.from_pretrained(model_id)

        if peft:
            peft_config = LoraConfig(inference_mode=False, r=32, lora_alpha=32, lora_dropout=0.1, target_modules="all-linear", use_rslora=True)
            self.backbone = get_peft_model(self.backbone, peft_config)
            self.backbone.print_trainable_parameters()
        print(f"Number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def forward(self, x, is_training=False):
        """Return the backbone's patch tokens for ``x``.

        Args:
            x: tensor passed to the backbone as ``pixel_values``; presumably
                (batch, 3, H, W) with H and W multiples of 14 -- confirm.
            is_training: unused; kept so existing callers keep working.

        Returns:
            ``last_hidden_state`` without its first token, i.e. (batch,
            num_patches, embedding_size).
        """
        outputs = self.backbone(pixel_values=x)
        # Token 0 is the CLS token in the (register-free) facebook/dinov2-*
        # checkpoints; keep only patch tokens, e.g. [1, 1600, 768] for one
        # 560x560 image with the "base" backbone.
        return outputs.last_hidden_state[:, 1:]

def create_transforms(crop_size, mean, std, center_crop=252):
    """Build the two preprocessing pipelines used for feature extraction.

    Both pipelines take the central ``center_crop`` x ``center_crop`` region,
    resize it so its shorter side is ``crop_size`` (the crop is square, so the
    result is ``crop_size`` x ``crop_size``), and convert it to a float tensor
    in [0, 1].

    Args:
        crop_size: output side length; should be a multiple of 14 (the DINOv2
            patch size) so the image splits into whole patches.
        mean: per-channel mean used by ``transform1``'s normalization.
        std: per-channel std used by ``transform1``'s normalization.
        center_crop: side length of the central region cropped from the source
            image before resizing (default 252).

    Returns:
        (transform1, transform2). ``transform1`` additionally normalizes with
        ``mean``/``std`` and produces the model input. ``transform2`` does not
        normalize and produces display tensors.

    Note:
        ``Resize`` uses bilinear interpolation by default. When ``transform2``
        is applied to label masks, values at class boundaries are blended.
    """
    def _crop_resize_to_tensor():
        # Fresh list each call so the two Compose objects share no list.
        return [
            transforms.CenterCrop(center_crop),
            transforms.Resize(crop_size),
            transforms.ToTensor(),
        ]

    transform1 = transforms.Compose(
        _crop_resize_to_tensor() + [transforms.Normalize(mean=mean, std=std)]
    )
    transform2 = transforms.Compose(_crop_resize_to_tensor())

    return transform1, transform2

def preprocess_images(input_directory, transform1, transform2, device, model, dinov2_vitl14, crop_size):
    """Load every image/mask pair of a sample and extract patch features.

    Images and masks are read from ``<input_directory>/images`` and
    ``<input_directory>/masks`` and paired by sorted filename.

    Args:
        input_directory: sample folder containing ``images`` and ``masks``.
        transform1: transform producing the normalized model input.
        transform2: transform producing the un-normalized tensor kept for
            display. It is also applied to the masks.
        device: device the fine-tuned ``model`` lives on.
        model: fine-tuned backbone, called as ``model(x, is_training=False)``.
            It is expected to return (batch, num_patches, dim) patch tokens.
        dinov2_vitl14: base DINOv2 torch-hub model. Its ``forward_features``
            output must contain ``'x_norm_patchtokens'``. The name is
            historical; any DINOv2 hub variant works.
        crop_size: side length of the transformed images. It is used to check
            that each image yields ``(crop_size // 14) ** 2`` patches.

    Returns:
        (images, total_features_raw, total_features_linear, ground_truth):
        - ``images`` and ``ground_truth``: lists with one display tensor and
          one mask tensor per image.
        - ``total_features_raw`` / ``total_features_linear``: base and
          fine-tuned features, flattened to (n_images * n_patches, dim).

    Raises:
        ValueError: if no images are found, or if the image and mask counts
            differ (pairs would otherwise be silently misaligned).
    """
    image_dir = os.path.join(input_directory, 'images')
    mask_dir = os.path.join(input_directory, 'masks')
    image_names = sorted(os.listdir(image_dir))
    mask_names = sorted(os.listdir(mask_dir))

    if not image_names:
        raise ValueError(f"No images found in {image_dir!r}")
    if len(image_names) != len(mask_names):
        raise ValueError(
            f"Found {len(image_names)} images but {len(mask_names)} masks in "
            f"{input_directory!r}; images and masks are paired by sorted "
            "filename, so the counts must match."
        )

    total_features_raw = []
    total_features_linear = []
    ground_truth = []
    images = []
    # Feed the base model on whatever device it currently lives on.
    raw_device = next(dinov2_vitl14.parameters()).device

    with torch.no_grad():
        for img_name, gt_name in tqdm(zip(image_names, mask_names), total=len(image_names)):
            # `with` closes the file handle (multi-frame TIFFs may stay open otherwise).
            with Image.open(os.path.join(image_dir, img_name)) as img_file:
                img = img_file.convert('RGB')
            img_t = transform1(img)
            images.append(transform2(img))

            with Image.open(os.path.join(mask_dir, gt_name)) as gt_file:
                gt = gt_file.convert('L')
            ground_truth.append(transform2(gt))

            batch = img_t.unsqueeze(0)

            # Fine-tuned features. `.float()` guards against reduced-precision
            # outputs (e.g. bf16 compute) that numpy/sklearn cannot ingest.
            features_linear = model(batch.to(device), is_training=False)
            total_features_linear.append(features_linear[0].float().cpu())

            # Base (not fine-tuned) DINOv2 normalized patch tokens.
            features_dict = dinov2_vitl14.forward_features(batch.to(raw_device))
            total_features_raw.append(features_dict['x_norm_patchtokens'][0].float().cpu())

    # One row per patch; the explicit patch count makes a size mismatch fail loudly.
    n_rows = len(images) * (crop_size // 14) * (crop_size // 14)
    total_features_raw = torch.cat(total_features_raw, dim=0).reshape(n_rows, -1)
    total_features_linear = torch.cat(total_features_linear, dim=0).reshape(n_rows, -1)

    return images, total_features_raw, total_features_linear, ground_truth

def perform_pca(total_features, n_components=3):
    """Project features onto their top principal components, rescaled to [0, 1].

    Every input dimension is standardized (zero mean, unit variance) before the
    PCA fit. Each resulting component is then min-max scaled to [0, 1], so that
    three components can be shown directly as RGB.

    Args:
        total_features: 2-D array-like of shape (n_samples, n_features).
        n_components: number of principal components to keep.

    Returns:
        (projected, pca):
        - ``projected``: array of shape (n_samples, n_components) scaled
          to [0, 1].
        - ``pca``: the fitted ``PCA`` object.
    """
    from sklearn.preprocessing import MinMaxScaler, StandardScaler

    standardized = StandardScaler().fit_transform(total_features)
    fitted_pca = PCA(n_components=n_components)
    components = fitted_pca.fit_transform(standardized)
    return MinMaxScaler().fit_transform(components), fitted_pca

def display_random_images(preprocessed_images, pca_features_rgb_raw, pca_features_rgb_linear, ground_truth, index=None, patch_size=14):
    """Plot one image next to its two PCA feature maps and its mask.

    Args:
        preprocessed_images: list of (C, H, W) display tensors/arrays.
        pca_features_rgb_raw: (n_images * n_patches, 3) PCA colors of the base
            model's features.
        pca_features_rgb_linear: same for the fine-tuned model's features.
        ground_truth: list of (1, H, W) mask tensors/arrays.
        index: image to show. ``None`` (default) picks one uniformly at random.
        patch_size: backbone patch size, used to derive the patch grid.

    Returns:
        (image (H, W, C), raw PCA map (H/p, W/p, 3), fine-tuned PCA map
        (H/p, W/p, 3), mask (H, W, 1)) for the displayed index.

    Raises:
        ValueError: if ``preprocessed_images`` is empty.
    """
    if len(preprocessed_images) == 0:
        raise ValueError("No images to display")

    if index is None:
        index = random.randint(0, len(preprocessed_images) - 1)

    preprocessed_image_np = np.transpose(np.array(preprocessed_images[index]), (1, 2, 0))
    ground_truth_np = np.transpose(np.array(ground_truth[index]), (1, 2, 0))

    # Patch grid derived from a single image instead of converting the whole
    # list. Height and width are handled separately, so non-square images work.
    grid_h = preprocessed_image_np.shape[0] // patch_size
    grid_w = preprocessed_image_np.shape[1] // patch_size
    pca_raw_img = pca_features_rgb_raw.reshape(-1, grid_h, grid_w, 3)
    pca_linear_img = pca_features_rgb_linear.reshape(-1, grid_h, grid_w, 3)

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
    ax1.imshow(preprocessed_image_np, cmap='gray')
    ax2.imshow(pca_raw_img[index])
    ax3.imshow(pca_linear_img[index])
    ax4.imshow(ground_truth_np, cmap='gray')
    for ax in (ax1, ax2, ax3, ax4):
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return preprocessed_image_np, pca_raw_img[index], pca_linear_img[index], ground_truth_np

# Initialize the fine-tuned DINOv2 with the same configuration used for
# training, so the (LoRA / quantized) parameter names match the checkpoint.
model = DinoV2Segmentor(num_classes=num_classes, size='base', peft=True, quantize=True, head_type='linear', n_features=n_features)
model.to(device)

# Load the state dictionary. weights_directory is an absolute path, so the
# previous os.path.join(data_directory, weights_directory, ...) already
# discarded data_directory; join only what is actually used.
checkpoint_path = os.path.join(weights_directory, model_name)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device)

# strict=False: the checkpoint presumably also holds parameters (e.g. a
# segmentation head) that this feature-only model does not define. Report the
# skipped keys so a total mismatch (nothing loaded) does not pass silently.
load_result = model.load_state_dict(checkpoint, strict=False)
print(
    f"Checkpoint loaded: {len(load_result.missing_keys)} missing keys, "
    f"{len(load_result.unexpected_keys)} unexpected keys"
)
model.eval()

# Load the base (not fine-tuned) DINOv2 ViT-B/14 from torch hub for comparison.
# The variable keeps its historical "vitl14" name used by the calls below.
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14.eval()

# **PCA**

In [None]:
# Model-input (normalized) and display (un-normalized) transforms.
transform1, transform2 = create_transforms(crop_size, mean, std)

# Patch features from the base hub model (raw) and the fine-tuned model (linear).
(
    preprocessed_images,
    total_features_raw,
    total_features_linear,
    ground_truth,
) = preprocess_images(
    input_directory,
    transform1,
    transform2,
    device,
    model,
    dinov2_vitl14,
    crop_size,
)

# Reduce each feature set independently to 3 components (shown as RGB).
pca_features_raw, pca_raw = perform_pca(total_features_raw)
pca_features_linear, pca_linear = perform_pca(total_features_linear)

# **Display**

In [None]:
# Show a random image next to its base / fine-tuned PCA maps and its mask.
scan, pca_nonfinetuned, pca_finetuned, gt = display_random_images(
    preprocessed_images,
    pca_features_raw,
    pca_features_linear,
    ground_truth,
)