# **Libraries / path definition**

In [None]:
!pip install peft
!pip install -i https://pypi.org/simple/ bitsandbytes -U
!pip install git+https://github.com/huggingface/transformers


In [None]:
import sys
import os
from google.colab import drive
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import AutoModel
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig

In [None]:
# Root of the project on Google Drive.
# NOTE(review): `drive` is imported above but never mounted in this notebook;
# presumably drive.mount('/content/gdrive') has to be run first — confirm.
project_dir = '/content/gdrive/MyDrive/'

# Dataset selection. Images and masks are opened with PIL (tiff or any other
# PIL-readable format) from:
#   <project_dir>/data/<dataset_name>/<sample_name>/images/...
#   <project_dir>/data/<dataset_name>/<sample_name>/masks/...
# Images and masks are paired by sorted filename, so both folders must contain
# matching files that sort into the same order. Example layout:
# data
# |_some_dataset
#    |_some_sample
#       |_images
#       |  |_image1.tiff
#       |  |_image2.tiff
#       |  |_...
#       |_masks
#          |_mask1.tiff
#          |_mask2.tiff
#          |_...
dataset_name = "some_dataset"
sample_name = "some_sample"
# Number of classes of the fine-tuned segmentation head (must match the checkpoint,
# since it fixes the shape of the head's 1x1 convolution).
num_classes = 3
# Side length of the centre crop. It should be a multiple of 14, the DINOv2 patch
# size assumed by the `crop_size // 14` patch-grid computations below.
crop_size = 350

# The fine-tuned checkpoint is read from <project_dir>/data/<weights_directory>/<model_name>.
weights_directory = "for_pca"
model_name = "model_best_iou.pth"

# Backbone feature dimension per DINOv2 variant (not used by the cells in this
# notebook; kept for reference).
#feat_dim = 384  # vits14
feat_dim = 768  # vitb14
# feat_dim = 1024  # vitl14
# feat_dim = 1536  # vitg14

# Per-channel normalization statistics on the [0, 1] scale (one value replicated
# over the 3 RGB channels). Adapt to your dataset.
mean = np.array([123.07921846875976] * 3) / 255.0
std = np.array([84.04993142526148] * 3) / 255.0

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

########

# Add the DinoV2 code directory to the system path for module imports.
sys.path.append(os.path.join(project_dir, "code/DinoV2/"))

# Data root inside the project directory (datasets and checkpoints live here).
data_directory = os.path.join(project_dir, 'data')

# **Model definition and helper functions**




In [None]:
# Directory of the selected sample; preprocess_images expects it to contain an
# 'images' folder and a 'masks' folder.
input_directory = os.path.join(data_directory, dataset_name, sample_name)

class LinearHead(nn.Module):
    """
    Linear segmentation head applied to a DINOv2 feature map.

    BatchNorm2d followed by a 1x1 convolution (a per-position linear classifier),
    then bilinear upsampling of the logits to a fixed 560x560 resolution.

    Parameters:
    - embedding_size: number of channels of the input feature map
    - num_classes: number of output logit channels
    """

    def __init__(self, embedding_size=768, num_classes=3):
        super().__init__()
        self.embedding_size = embedding_size
        # Creation order and Sequential order are kept fixed: checkpoint keys are
        # head.0.* (BatchNorm2d) and head.1.* (Conv2d).
        norm = nn.BatchNorm2d(embedding_size)
        classifier = nn.Conv2d(embedding_size, num_classes, kernel_size=1, padding=0, bias=True)
        resize = nn.Upsample(size=(560, 560), mode='bilinear', align_corners=False)
        self.head = nn.Sequential(norm, classifier, resize)

    def forward(self, inputs):
        """Map ``inputs["features"]`` (B, embedding_size, H, W) to logits (B, num_classes, 560, 560)."""
        return self.head(inputs["features"])

class DinoV2Segmentor(nn.Module):
    """
    Segmentation model: a Hugging Face DINOv2 backbone followed by a segmentation head.

    The backbone can optionally be loaded in 4-bit NF4 (bitsandbytes) and/or wrapped
    with LoRA adapters (peft). With ``is_training=False`` the model returns the
    backbone's patch-token features; with ``is_training=True`` it returns the logits
    of the segmentation head.

    Parameters:
    - num_classes: number of segmentation classes predicted by the head
    - size: DINOv2 variant, one of the keys of ``emb_size``
    - peft: wrap the backbone with LoRA adapters and let gradients flow through the
      backbone in training mode
    - quantize: load the backbone in 4-bit NF4 with bfloat16 compute
    - head_type: key into ``head`` selecting the segmentation head class
    """

    # Available segmentation heads, selected by ``head_type``.
    head = {
        "linear" : LinearHead
    }

    # Hidden size of each DINOv2 variant, selected by ``size``.
    emb_size = {
        "small" : 384,
        "base" : 768,
        "large" : 1024,
    }

    def __init__(self, num_classes, size="base", peft=False, quantize=False, head_type="linear"):
        super(DinoV2Segmentor, self).__init__()
        self.num_classes = num_classes
        self.peft = peft
        self.embedding_size = self.emb_size[size]
        if quantize:
            # 4-bit NF4 weights with double quantization; compute runs in bfloat16.
            self.quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.backbone = AutoModel.from_pretrained(f'facebook/dinov2-{size}', quantization_config=self.quantization_config)
            self.backbone = prepare_model_for_kbit_training(self.backbone)
        else:
            self.backbone = AutoModel.from_pretrained(f'facebook/dinov2-{size}')

        if peft:
            # LoRA adapters (rank 32, rank-stabilized scaling) on all linear layers.
            peft_config = LoraConfig(inference_mode=False, r=32, lora_alpha=32, lora_dropout=0.1, target_modules="all-linear", use_rslora=True)
            self.backbone = get_peft_model(self.backbone, peft_config)
            self.backbone.print_trainable_parameters()
        self.seg_head = self.build_head(head_type)
        print(self.seg_head)

    def forward(self, x, is_training=False):
        """
        Run the backbone and, in training mode, the segmentation head.

        Parameters:
        - x: pixel tensor passed to the backbone as ``pixel_values``; assumed to be
          (B, C, H, W) with H and W multiples of 14 — TODO confirm for other inputs
        - is_training: False -> return patch features; True -> return head logits

        Returns:
        - is_training=False: ``last_hidden_state`` without token 0, shape (B, N, D)
        - is_training=True: logits produced by ``self.seg_head``
        """
        # Backbone gradients are only needed when LoRA adapters are being trained.
        with torch.set_grad_enabled(self.peft and is_training):
            # Drop token 0 (the CLS token in HF DINOv2 outputs); keep the patch tokens.
            features = self.backbone(pixel_values=x).last_hidden_state[:, 1:]
        if not is_training:
            return features

        # The head starts with BatchNorm2d, which requires a 4D (B, D, h, w) map, so
        # fold the (B, h*w, D) token sequence back onto the 14-pixel patch grid
        # (tokens assumed row-major, as produced by the HF patch embedding).
        batch_size, _, height, width = x.shape
        grid_h, grid_w = height // 14, width // 14
        feature_map = features.reshape(batch_size, grid_h, grid_w, -1).permute(0, 3, 1, 2)
        inputs = {"features": feature_map, "image": x}
        return self.seg_head(inputs)

    def build_head(self, head_type):
        """Instantiate the head registered under ``head_type`` for this model's embedding size and class count."""
        return self.head[head_type](embedding_size=self.embedding_size, num_classes=self.num_classes)

def segment_and_plot(pca_features, threshold, patch_h, patch_w, num_threshold_to_display, title):
    """
    Threshold the first PCA component into background/foreground masks and plot the
    background mask of randomly drawn images.

    Parameters:
    - pca_features: PCA-projected features, one row per patch; column 0 is thresholded
    - threshold: patches whose first component is strictly greater than this are background
    - patch_h: patch-grid height of one image
    - patch_w: patch-grid width of one image
    - num_threshold_to_display: number of randomly drawn images (with replacement) to plot
    - title: title of every subplot

    Returns:
    - (pca_features_fg, pca_features_bg): complementary boolean masks over the rows
    """
    background = pca_features[:, 0] > threshold
    foreground = ~background

    patches_per_image = patch_h * patch_w
    image_count = len(foreground.reshape(-1, patch_h, patch_w))
    n_panels = int(num_threshold_to_display)

    for panel in range(1, n_panels + 1):
        chosen = random.randint(0, image_count - 1)
        start = chosen * patches_per_image
        plt.subplot(1, n_panels, panel)
        plt.imshow(background[start:start + patches_per_image].reshape(patch_h, patch_w))
        plt.title(title)
    plt.show()

    return foreground, background

def process_pca_foreground(pca, total_features, pca_features_fg, pca_features_bg, patch_h, patch_w):
    """
    Re-fit PCA on the foreground patches only and build one RGB map per image.

    Parameters:
    - pca: PCA-like object with ``fit``/``transform``; it is re-fitted in place on the
      foreground rows (assumed to produce at least 3 components)
    - total_features: feature matrix, one row per patch
    - pca_features_fg: boolean foreground mask over the rows of ``total_features``
    - pca_features_bg: boolean background mask over the same rows
    - patch_h: patch-grid height of one image
    - patch_w: patch-grid width of one image

    Returns:
    - float array of shape (num_images, patch_h, patch_w, 3). Foreground patches hold
      the first three foreground components, each min-max scaled to [0, 1] (a constant
      component maps to 0). Background patches are 0.
    """
    # One RGB triplet per patch (the feature width is irrelevant to the output).
    pca_features_rgb = np.zeros((len(pca_features_fg), 3))

    # PCA cannot be fitted on zero rows; with no foreground the map stays all zeros.
    if np.any(pca_features_fg):
        foreground = total_features[pca_features_fg]
        pca.fit(foreground)
        components = np.asarray(pca.transform(foreground), dtype=float)[:, :3]
        lows = components.min(axis=0)
        spans = components.max(axis=0) - lows
        # Avoid 0/0 -> NaN for a component with no spread.
        spans[spans == 0] = 1.0
        pca_features_rgb[pca_features_fg] = (components - lows) / spans

    pca_features_rgb[pca_features_bg] = 0
    # Derive the image count from the patch count instead of assuming a fixed batch.
    return pca_features_rgb.reshape(-1, patch_h, patch_w, 3)

def create_transforms(crop_size, mean, std):
    """
    Build the image pipelines used for feature extraction and for display.

    Parameters:
    - crop_size: side length of the centre crop
    - mean: per-channel mean for normalization (on the [0, 1] scale)
    - std: per-channel standard deviation for normalization (on the [0, 1] scale)

    Returns:
    - transform1: model-input pipeline. It applies centre crop, Gaussian blur and
      colour jitter to the PIL image, then ToTensor and Normalize(mean, std).
    - transform2: display pipeline (centre crop + ToTensor), also applied to masks.
    """
    # NOTE(review): ColorJitter and GaussianBlur with a sigma range are random
    # augmentations in torchvision, so features differ slightly between runs —
    # confirm this is intended at inference time.
    contrast_adjustment = transforms.ColorJitter(contrast=0.5, brightness=0.2)
    denoising = transforms.GaussianBlur(3, sigma=(0.1, 2.0))

    transform1 = transforms.Compose([
        transforms.CenterCrop(crop_size),
        denoising,
        # ColorJitter must run before Normalize: on float tensors it clamps its output
        # to [0, 1], which would erase every negative value produced by Normalize.
        contrast_adjustment,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    transform2 = transforms.Compose([
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ])

    return transform1, transform2

def preprocess_images(input_directory, transform1, transform2, device, model, dinov2_vitl14, crop_size):
    """
    Load every image/mask pair of a sample and extract patch features with both models.

    Images are read from ``<input_directory>/images`` and masks from
    ``<input_directory>/masks``. Both listings are sorted and paired by position.

    Parameters:
    - input_directory: sample directory containing the 'images' and 'masks' folders
    - transform1: pipeline producing the normalized model input from an RGB PIL image
    - transform2: pipeline producing display tensors (applied to images and masks)
    - device: device on which the fine-tuned ``model`` runs
    - model: fine-tuned model, called with ``is_training=False`` to get patch features
    - dinov2_vitl14: base DINOv2 model exposing ``forward_features`` (the name is
      historical; any DINOv2 size works)
    - crop_size: crop side length; each image contributes ``(crop_size // 14) ** 2``
      patch rows

    Returns:
    - images: list of display tensors (transform2 of each RGB image)
    - total_features_raw: base-model patch features, one row per patch
    - total_features_linear: fine-tuned patch features, same row layout
    - ground_truth: list of mask tensors (transform2 of each grayscale mask)

    Raises:
    - ValueError: if no images are found, or the two folders hold different file counts
    """
    image_dir = os.path.join(input_directory, 'images')
    mask_dir = os.path.join(input_directory, 'masks')
    image_paths = sorted(os.listdir(image_dir))
    mask_paths = sorted(os.listdir(mask_dir))

    # zip() would silently drop unmatched files and mis-pair the rest.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"{image_dir} has {len(image_paths)} files but {mask_dir} has {len(mask_paths)}; "
            "images and masks are paired by sorted filename, so the counts must match."
        )
    if not image_paths:
        raise ValueError(f"No images found in {image_dir}")

    # The hub model is not necessarily on the same device as the fine-tuned model.
    raw_device = next(dinov2_vitl14.parameters()).device

    total_features_raw = []
    total_features_linear = []
    ground_truth = []
    images = []

    with torch.no_grad():
        for img_name, gt_name in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):
            # Context managers release the file handles right after decoding.
            with Image.open(os.path.join(image_dir, img_name)) as raw_img:
                img = raw_img.convert('RGB')
            img_t = transform1(img)
            images.append(transform2(img))

            with Image.open(os.path.join(mask_dir, gt_name)) as raw_gt:
                gt = raw_gt.convert('L')
            ground_truth.append(transform2(gt))

            features_linear = model(img_t.unsqueeze(0).to(device), is_training=False)
            total_features_linear.append(features_linear.cpu())

            features_dict = dinov2_vitl14.forward_features(img_t.unsqueeze(0).to(raw_device))
            total_features_raw.append(features_dict['x_norm_patchtokens'].cpu())

    # Flatten to one row per patch across all images.
    num_rows = len(images) * (crop_size // 14) ** 2
    total_features_raw = torch.cat(total_features_raw, dim=0).reshape(num_rows, -1)
    total_features_linear = torch.cat(total_features_linear, dim=0).reshape(num_rows, -1)

    return images, total_features_raw, total_features_linear, ground_truth

def perform_pca(total_features, n_components=3):
    """
    Fit PCA on the features and min-max scale each projected component to [0, 1].

    Parameters:
    - total_features: feature matrix, one row per patch
    - n_components: number of principal components to keep

    Returns:
    - pca_features: projected features, one column per component, each scaled to
      [0, 1] (a constant component maps to 0 instead of NaN)
    - pca: the fitted PCA object
    """
    pca = PCA(n_components=n_components)
    pca.fit(total_features)
    pca_features = pca.transform(total_features)

    lows = pca_features.min(axis=0)
    spans = pca_features.max(axis=0) - lows
    # Avoid 0/0 -> NaN, which would silently break the later thresholding.
    spans[spans == 0] = 1
    pca_features = (pca_features - lows) / spans

    return pca_features, pca

def display_random_images(preprocessed_images, pca_features_rgb_raw, pca_features_rgb_linear, ground_truth):
    """
    Show one randomly chosen sample as four side-by-side panels.

    The panels are: the input image, its PCA map from the base model, its PCA map
    from the fine-tuned model, and its ground-truth mask.

    Parameters:
    - preprocessed_images: list of (C, H, W) display tensors
    - pca_features_rgb_raw: per-image RGB PCA maps of the base model
    - pca_features_rgb_linear: per-image RGB PCA maps of the fine-tuned model
    - ground_truth: list of (C, H, W) mask tensors
    """
    pick = random.randint(0, len(preprocessed_images) - 1)
    image_hwc = np.transpose(np.array(preprocessed_images[pick]), (1, 2, 0))
    mask_hwc = np.transpose(np.array(ground_truth[pick]), (1, 2, 0))

    panels = [
        (image_hwc, {'cmap': 'gray'}),
        (pca_features_rgb_raw[pick], {}),
        (pca_features_rgb_linear[pick], {}),
        (mask_hwc, {'cmap': 'gray'}),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, (data, imshow_kwargs) in zip(axes, panels):
        axis.imshow(data, **imshow_kwargs)
        axis.axis('off')

    plt.tight_layout()
    plt.show()

def pca_to_rgb(pca, pca_features, total_features, pca_features_fg, pca_features_bg, crop_size):
    """
    Re-fit PCA on foreground patches and turn the result into per-image RGB maps.

    Parameters:
    - pca: PCA-like object with ``fit``/``transform``; it is re-fitted in place on the
      foreground rows
    - pca_features: PCA features of all patches, 3 columns (used as the output template)
    - total_features: feature matrix, one row per patch
    - pca_features_fg: boolean foreground mask over the rows
    - pca_features_bg: boolean background mask over the rows
    - crop_size: crop side length; each image is a (crop_size // 14)^2 patch grid

    Returns:
    - array of shape (num_images, crop_size // 14, crop_size // 14, 3). Foreground
      patches hold the foreground PCA components, each min-max scaled to [0, 1]
      (a constant component maps to 0). Background patches are 0.
    """
    grid = crop_size // 14
    num_img = len(pca_features_fg.reshape(-1, grid, grid))

    # Work on a copy so the caller's pca_features are left untouched.
    pca_features_rgb = pca_features.copy()
    pca_features_rgb[pca_features_bg] = 0

    # PCA cannot be fitted on zero rows; with no foreground only the zeroed background remains.
    if np.any(pca_features_fg):
        foreground = total_features[pca_features_fg]
        pca.fit(foreground)
        components = np.asarray(pca.transform(foreground))
        lows = components.min(axis=0)
        spans = components.max(axis=0) - lows
        # Avoid 0/0 -> NaN for a component with no spread.
        spans[spans == 0] = 1
        pca_features_rgb[pca_features_fg] = (components - lows) / spans

    return pca_features_rgb.reshape(num_img, grid, grid, 3)

# Initialize the fine-tuned DINOv2 (base backbone, 4-bit quantized, LoRA adapters,
# linear head). This configuration must match the one the checkpoint was saved with,
# otherwise load_state_dict (strict by default) fails on missing/unexpected keys.
model = DinoV2Segmentor(num_classes=num_classes, size="base", peft=True, quantize=True, head_type="linear")
model.to(device)

# Load the fine-tuned weights. weights_only=True restricts unpickling to tensors and
# plain containers, so a tampered checkpoint file cannot execute arbitrary code.
checkpoint_path = os.path.join(data_directory, weights_directory, model_name)
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()

# Load the base (not fine-tuned) DINOv2 ViT-B/14 from torch.hub for comparison.
# NOTE: despite its name, `dinov2_vitl14` holds the vitb14 model, whose 768-dim
# features match the fine-tuned "base" backbone. The name is kept because later
# cells use it.
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
# Inference only: switch off any training-mode behaviour.
dinov2_vitl14.eval()

# **PCA**

In [None]:
# Build the model-input pipeline (transform1) and the display pipeline (transform2).
transform1, transform2 = create_transforms(crop_size, mean, std)

# Extract per-patch features for every image/mask pair in input_directory, from both
# the fine-tuned model ("linear") and the base hub model ("raw").
preprocessed_images, total_features_raw, total_features_linear, ground_truth = preprocess_images(input_directory, transform1, transform2, device, model, dinov2_vitl14, crop_size)

# Project each feature set onto its first 3 principal components, each min-max
# scaled to [0, 1].
pca_features_raw, pca_raw = perform_pca(total_features_raw)
pca_features_linear, pca_linear = perform_pca(total_features_linear)

# **Threshold setting**




In [None]:
# Adapt accordingly: patches whose first (min-max scaled) PCA component is greater
# than `threshold` are treated as background.
threshold = 0.95
# Number of randomly drawn images whose background mask is plotted for each model.
num_threshold_to_display = 4

# Split patches into foreground/background masks for the base and fine-tuned features.
pca_features_fg_raw, pca_features_bg_raw = segment_and_plot(pca_features_raw, threshold, (crop_size//14), (crop_size//14), num_threshold_to_display, 'No fine-tuning')
pca_features_fg_linear, pca_features_bg_linear = segment_and_plot(pca_features_linear, threshold, (crop_size//14), (crop_size//14),num_threshold_to_display, 'Fine-tuned')

# Re-run PCA on the foreground patches only and build per-image RGB maps with the
# background set to 0. Note that pca_to_rgb re-fits pca_raw / pca_linear in place.
pca_features_rgb_raw = pca_to_rgb(pca_raw, pca_features_raw, total_features_raw, pca_features_fg_raw, pca_features_bg_raw, crop_size)
pca_features_rgb_linear = pca_to_rgb(pca_linear, pca_features_linear, total_features_linear, pca_features_fg_linear, pca_features_bg_linear, crop_size)

# **Display**

In [None]:
# Show one random image next to its base-model PCA map, fine-tuned PCA map and ground-truth mask.
display_random_images(preprocessed_images, pca_features_rgb_raw, pca_features_rgb_linear, ground_truth)