# **Libraries / path definition**

In [None]:
!pip install torchmetrics

In [None]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import VisionDataset
import torchvision.transforms.functional as Fv
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel, get_scheduler, BitsAndBytesConfig
import random
from sklearn.neighbors import KNeighborsClassifier
from sklearn.manifold import TSNE
import torchmetrics
import math

In [None]:
from google.colab import drive

# Mount Google Drive so the project folder is reachable from this Colab runtime.
drive.mount('/content/gdrive')

# Project root on the mounted drive (alternative root: '/content/gdrive/MyDrive/').
project_dir = '/content/gdrive/MyDrive/Schism/'

# Dataset folder name, and the folder holding its .npy arrays.
dataset_name = "Alhammadi"
data_directory = os.path.join(project_dir, 'npy_data', dataset_name)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# `dataset_name` is the name of the folder containing the datasets. All data are stored as
# NumPy `.npy` arrays (e.g. converted from TIFF or an equivalent format).
# The expected data directory structure is as follows:
# For classification (only the *_img.npy files are read; mask files may be present):
# npy_data
# |_dataset_name
# |  |_berea_img.npy
# |  |_berea_mask.npy
# |  |_buff_berea_img.npy
# |  |_buff_berea_mask.npy
# |  |_...

# For segmentation (one image file and one mask file per sample):
# npy_data
# |_dataset_name
# |  |_sample1_img.npy
# |  |_sample1_mask.npy
# |  |_sample2_img.npy
# |  |_sample2_mask.npy
# |  |_sample3_img.npy
# |  |_sample3_mask.npy
# |  |_...


# **Various functions**

In [None]:
class EfficientClassifDataset(VisionDataset):
    """Rock-type classification dataset read from memory-mapped ``.npy`` stacks.

    One ``{rock}_img.npy`` stack is opened per entry of ``rock_names`` (mask
    files are not read). Indices are interleaved over rocks: item ``idx`` is
    image ``idx // len(rock_names)`` of rock ``idx % len(rock_names)``, and its
    label is that rock's position in ``rock_names``.

    Each item is a random ``crop_size`` crop, scaled to [0, 1], resized to
    ``img_res`` x ``img_res`` (bilinear) and normalised with the rock's entry
    in ``self.data_stats``.

    Args:
        data_dir: Folder containing the ``{rock}_img.npy`` files.
        rock_names: Rock names; each must be a key of ``self.data_stats``.
        num_samples: Images used per rock; ``len(self)`` is
            ``num_samples * len(rock_names)``. Defaults to the length of the
            shortest stack so that every index is backed by an image.
        num_classes: Stored as ``self.num_classes``; not used by this class.
        crop_size: ``(height, width)`` of the random crop.
        p: Stored as ``self.p``; not used by this class.
        train: Stored as ``self.train``; not used by this class.
        img_res: Side length the crop is resized to (256 by default).
    """

    def __init__(self, data_dir, rock_names, num_samples=None, num_classes=1, crop_size = (996,996), p=0.5, train=True, img_res=256):
        super().__init__(data_dir, transforms=None)
        print("Loading data ...")
        # Per-rock [mean, std] used by normalize() in __getitem__: grey-level
        # statistics repeated over the 3 channels and scaled to [0, 1].
        self.data_stats = {
            "berea" : [
                np.array([80.06]*3)/255.0, np.array([35.56]*3)/255.0
            ],
            "bentheimer" : [
                np.array([98.06]*3)/255.0, np.array([54.95]*3)/255.0
            ],
            "parker" : [
                np.array([96.57]*3)/255.0, np.array([29.23]*3)/255.0
            ],
            "kirby" : [
                np.array([99.19]*3)/255.0, np.array([39.49]*3)/255.0
            ],
            "buff_berea" : [
                np.array([99.53]*3)/255.0, np.array([45.07]*3)/255.0
            ],
            "leopard" : [
                np.array([103.04]*3)/255.0, np.array([44.41]*3)/255.0
            ],
            "bandera_brown" : [
                np.array([103.25]*3)/255.0, np.array([37.77]*3)/255.0
            ],
            "bandera_gray" : [
                np.array([100.98]*3)/255.0, np.array([35.74]*3)/255.0
            ],
            "berea_sister" : [
                np.array([74.89]*3)/255.0, np.array([32.01]*3)/255.0
            ],
            "castle_gate" : [
                np.array([110.37]*3)/255.0, np.array([53.73]*3)/255.0
            ],
        }
        # Read-only memory maps: images are read from disk on access instead of
        # loading whole stacks into RAM.
        self.img_data = [np.lib.format.open_memmap(data_dir+f"/{rock}_img.npy", dtype=np.uint8, mode='r') for rock in rock_names]
        self.rock_names = rock_names
        self.crop_size = crop_size
        self.p = p
        self.train = train
        self.img_res = img_res

        self.num_classes = num_classes

        if num_samples is None:
            # Use the shortest stack so every (rock, image) pair exists.
            self.num_samples = min(len(stack) for stack in self.img_data)
        else:
            self.num_samples = num_samples

        self.num_datasets = len(self.img_data)


    def get_random_crop_params(self, img):
        """Get parameters for a random crop of size ``self.crop_size``.

        Args:
            img (numpy.ndarray): Image of shape ``(H, W, ...)``; only its first
                two dimensions are used.

        Returns:
            tuple: ``(i, j, h, w)`` -- top row, left column, height and width
            of the crop.

        Raises:
            ValueError: If the crop is larger than the image.
        """
        h, w = img.shape[:2]
        th, tw = self.crop_size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()

        return i, j, th, tw


    def __getitem__(self, idx):
        """Return ``(normalised image, class index, un-normalised image)``.

        Negative indices count from the end; indices outside
        ``[-len(self), len(self))`` raise ``IndexError`` so that plain
        iteration over the dataset stops at ``len(self)``.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index {idx} out of range for dataset of length {length}")

        dataset_index = idx % self.num_datasets
        data_idx = idx // self.num_datasets
        img = self.img_data[dataset_index][data_idx]

        # Random crop on the memory-mapped view, then copy only the cropped region.
        i, j, h, w = self.get_random_crop_params(img)
        img = img[i:i+h, j:j+w, :].copy()

        # HWC uint8 in [0, 255] -> CHW float in [0, 1], resized to img_res x img_res.
        img = torch.from_numpy(img.transpose((2, 0, 1))).contiguous()/255.0
        img = F.interpolate(input=img.unsqueeze(0), size=(self.img_res, self.img_res), mode="bilinear", align_corners=False).squeeze(0)

        mean, std = self.data_stats[self.rock_names[dataset_index]]

        return Fv.normalize(img, mean, std).float(), dataset_index, img

    def __len__(self):
        return self.num_datasets * self.num_samples

class EfficientSegmentationDataset(VisionDataset):
    """Segmentation dataset read from memory-mapped ``.npy`` image/mask stacks.

    For every entry of ``rock_names`` the files ``{rock}_img.npy`` and
    ``{rock}_mask.npy`` are opened. Item ``idx`` is slice
    ``idx // len(rock_names)`` of rock ``idx % len(rock_names)``. Image and
    mask are centre-cropped to ``crop_size``. The image is scaled to [0, 1],
    resized to ``img_res`` (bicubic) and normalised with ``self.data_stats``.
    The mask is divided by 255 and resized to ``mask_res`` (nearest).

    When ``num_classes > 2`` the mask is turned into integer labels
    ``floor(mask * num_classes) - 1``.
    NOTE(review): this assumes mask grey level ~= 255 * k / num_classes
    encodes class k (1-based) -- confirm against the mask files.

    Args:
        data_dir: Folder containing the ``.npy`` files.
        rock_names: Sample names; each must be a key of ``self.data_stats``.
        num_classes: Number of segmentation classes (see above).
        num_samples: Slices used per sample; defaults to the shortest stack.
        crop_size: ``(height, width)`` of the centre crop.
        img_res: Side length of the resized image.
        mask_res: Side length of the resized mask.
        save_dir: Stored as ``self.save_dir``; not used by this class.
    """

    def __init__(self, data_dir, rock_names, num_classes=3, num_samples=None, crop_size = (224,224), img_res=560, mask_res=128, save_dir=None):
        super().__init__(data_dir, transforms=None)
        print("Loading data ...")
        # Per-sample [mean, std] passed to normalize(): grey-level statistics
        # repeated over the 3 channels and scaled to [0, 1].
        self.data_stats = {
            "sample1" : [
                np.array([123.07921846875976]*3)/255.0, np.array([84.04993142526148]*3)/255.0
            ],

            "sample2" : [
                np.array([117.92807255795907]*3)/255.0, np.array([80.61479412614699]*3)/255.0
            ],

            "sample3" : [
                np.array([119.7933619436969]*3)/255.0, np.array([80.18348841827216]*3)/255.0
            ],

        }
        # Read-only memory maps: slices are read from disk on access.
        self.img_data = [np.lib.format.open_memmap(data_dir+f"/{rock}_img.npy", dtype=np.uint8, mode='r') for rock in rock_names]
        self.mask_data = [np.lib.format.open_memmap(data_dir+f"/{rock}_mask.npy", dtype=np.uint8, mode='r') for rock in rock_names]
        self.rock_names = rock_names
        self.crop_size = crop_size
        self.IMG_RES = img_res
        self.mask_res = mask_res
        self.save_dir = save_dir
        self.num_classes = num_classes

        if num_samples is None:
            # Shortest image/mask stack, so every (sample, slice) pair exists.
            self.num_samples = min(len(stack) for stack in self.img_data + self.mask_data)
        else:
            self.num_samples = num_samples

        self.num_datasets = len(self.img_data)

    def center_crop(self, image, mask):
        """Crop ``image`` and ``mask`` to ``self.crop_size`` around their centre.

        Returns views (no copy) of the inputs.

        Raises:
            ValueError: If the crop is larger than the image.
        """
        height, width = image.shape[:2]
        crop_height, crop_width = self.crop_size

        if height < crop_height or width < crop_width:
            raise ValueError("Crop size must be smaller than the image size")

        top = (height - crop_height) // 2
        left = (width - crop_width) // 2
        cropped_image = image[top:top + crop_height, left:left + crop_width]
        cropped_mask = mask[top:top + crop_height, left:left + crop_width]

        return cropped_image, cropped_mask

    def __getitem__(self, idx):
        """Return ``(normalised image, mask, un-normalised image)`` for ``idx``.

        Negative indices count from the end; indices outside
        ``[-len(self), len(self))`` raise ``IndexError``.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index {idx} out of range for dataset of length {length}")

        dataset_index = idx % self.num_datasets
        data_idx = idx // self.num_datasets

        # Crop the memory-mapped views first, then copy only the cropped region
        # (np.array copies into a plain, writable ndarray).
        img, mask = self.center_crop(self.img_data[dataset_index][data_idx], self.mask_data[dataset_index][data_idx])
        img, mask = np.array(img), np.array(mask)

        # HWC uint8 in [0, 255] -> CHW float in [0, 1]; mask -> float in [0, 1].
        img = torch.from_numpy(img.transpose((2, 0, 1))).contiguous()/255.0
        mask = torch.from_numpy(mask).contiguous()/255.0

        img = F.interpolate(input=img[None], size=(self.IMG_RES, self.IMG_RES), mode="bicubic", align_corners=False)[0]
        # Nearest neighbour keeps the mask values discrete.
        mask = F.interpolate(input=mask[None, None], size=(self.mask_res, self.mask_res), mode="nearest")[0, 0]

        mean, std = self.data_stats[self.rock_names[dataset_index]]

        if self.num_classes > 2:
            mask = (mask * self.num_classes).long() - 1

        return Fv.normalize(img, mean, std).float(), mask, img

    def __len__(self):
        return self.num_datasets * self.num_samples

class HuggingFaceClassificator(nn.Module):
    """Frozen Hugging Face vision backbone (DINOv2 by default) used as a feature extractor.

    Args:
        classification: If True, ``forward`` returns one flattened vector of
            all patch tokens per image, shape ``(B, n_patches * hidden_size)``.
            If False (segmentation), the patch-feature map is bilinearly resized
            to ``rescale_size`` x ``rescale_size`` and returned as one row per
            pixel, shape ``(B * rescale_size**2, hidden_size)``.
        rescale_size: Side of the output grid in segmentation mode.
        model_name: Model id passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, classification = True, rescale_size = 128, model_name='facebook/dinov2-base'):
        super(HuggingFaceClassificator, self).__init__()
        self.rescale_size = rescale_size
        self.classification = classification
        self.backbone = AutoModel.from_pretrained(model_name) # dinov2-base: 768 features, 14 px patches

    def forward(self, x):
        """Extract features from ``x`` of shape ``(B, C, H, W)`` without gradients."""
        # Feature width and patch size come from the backbone config; the
        # fallbacks are the dinov2-base values the code originally hard-coded.
        config = getattr(self.backbone, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        patch_size = getattr(config, "patch_size", 14)

        with torch.no_grad():
            # Drop the first token (the [CLS] token for DINOv2) and keep the patch tokens.
            tokens = self.backbone(pixel_values=x).last_hidden_state[:, 1:]
            if self.classification:
                return tokens.reshape(x.shape[0], -1)

            # Patch grid; height and width are handled separately so H != W works.
            grid_h = x.shape[2] // patch_size
            grid_w = x.shape[3] // patch_size
            features = tokens.reshape(x.shape[0], grid_h, grid_w, hidden_size).permute(0, 3, 1, 2)
            features = F.interpolate(input=features, size=(self.rescale_size, self.rescale_size), mode="bilinear", align_corners=False)
            return features.permute(0, 2, 3, 1).reshape(-1, hidden_size)


def train_test_builder(num_samples, dataset, model, classification=True):
    """Extract model features for ``dataset[0]`` .. ``dataset[num_samples - 1]``.

    Each ``dataset[i]`` must be an ``(x, y, img)`` triple: ``x`` is one model
    input without batch dimension, ``y`` its target and ``img`` an image kept
    for display.

    Args:
        num_samples: Number of items read from ``dataset``.
        dataset: Indexable returning ``(x, y, img)`` triples.
        model: Module mapping a batch of one input to features. Inputs are
            moved to the device of its parameters, so the model may live on GPU.
        classification: If True, one feature vector and one label per item.
            If False (segmentation), features are flattened to one row per
            pixel, ``(-1, feature_dim)``, and labels to a flat vector.

    Returns:
        tuple: ``(X, y, img_list)`` -- feature array, label array and the list
        of ``img`` items.
    """
    # Send inputs to wherever the model lives (the notebook calls model.to(device)).
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    X_array = []
    y_array = []
    img_list = []

    with torch.no_grad():
        for i in tqdm(range(num_samples), desc="Loading", total=num_samples):
            x, y, img = dataset[i]
            feat = model(x.unsqueeze(0).to(device)).squeeze()
            # .numpy() only works on CPU tensors.
            X_array.append(feat.detach().cpu().numpy())
            y_array.append(y.cpu().numpy() if torch.is_tensor(y) else y)
            img_list.append(img)

    X_train = np.array(X_array)
    y_array = np.array(y_array)
    if not classification:
        # One row per pixel and a matching flat label vector.
        feature_dim = X_train.shape[-1] if X_train.size else 768
        X_train = X_train.reshape(-1, feature_dim)
        y_array = y_array.reshape(-1)

    return X_train, y_array, img_list


def display_random_set(img, pred, label, idx=None):
    """Show one image next to its ground-truth mask and its kNN prediction.

    Args:
        img: Sequence of CHW image tensors.
        pred: Predicted masks, indexed like ``img``.
        label: Ground-truth masks, indexed like ``img``.
        idx: Index of the sample to show; drawn at random when None.
    """
    if idx is None:
        idx = random.randrange(len(img))

    gt = np.asarray(label[idx])
    pr = np.asarray(pred[idx])
    # Use one grey-level scale for both masks, so a given class gets the same
    # shade in both panels even when one of them lacks a class.
    vmin = min(gt.min(), pr.min())
    vmax = max(gt.max(), pr.max())

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    # Bicubic resizing can overshoot [0, 1]; clip explicitly for imshow.
    ax[0].imshow(img[idx].permute(1, 2, 0).clamp(0, 1).numpy(), cmap='gray')
    ax[0].set_title('Image')
    ax[1].imshow(gt, cmap='gray', vmin=vmin, vmax=vmax)
    ax[1].set_title('Ground Truth Mask')
    ax[2].imshow(pr, cmap='gray', vmin=vmin, vmax=vmax)
    ax[2].set_title('kNN')
    plt.show()

# **kNN (classification)**

## Parameters

In [None]:
# Rock types to classify; a sample's label is the position of its rock here.
class_names = [
    "berea", "buff_berea", "bandera_brown", "bandera_gray", "berea_sister",
    "castle_gate", "kirby", "parker", "leopard", "bentheimer",
]
# Number of items read for the training and the test feature sets.
num_train_samples, num_test_samples = 300, 200
# k values tried for the k-nearest-neighbours classifier.
n_neighbors = [5, 10, 50, 100, 200]
# Forwarded to EfficientClassifDataset.
num_classes = 1

## Training

In [None]:
# Both datasets map index i to image i // len(class_names) of rock
# i % len(class_names), and train_test_builder reads indices 0 .. n-1.
# Reading the test features from those same indices would evaluate the kNN on
# (crops of) the very images it was fitted on. Instead, the test set starts
# right after the last training index.
# NOTE(review): assumes every rock stack holds at least
# ceil((num_train_samples + num_test_samples) / len(class_names)) images -- confirm.
images_per_rock = math.ceil((num_train_samples + num_test_samples) / len(class_names))

train_dataset = EfficientClassifDataset(data_directory,
                                        class_names,
                                        num_samples=num_train_samples,
                                        num_classes=num_classes)

test_dataset = EfficientClassifDataset(data_directory,
                                       class_names,
                                       num_samples=images_per_rock,
                                       num_classes=num_classes)
test_subset = torch.utils.data.Subset(
    test_dataset, range(num_train_samples, num_train_samples + num_test_samples))

model = HuggingFaceClassificator(classification=True)
model.eval()
model.to(device)

X_train, y_train, _ = train_test_builder(num_train_samples, train_dataset, model)
X_test, y_test, _ = train_test_builder(num_test_samples, test_subset, model)

for k in n_neighbors:
    knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    accuracy = np.mean(y_pred == np.asarray(y_test))
    print(f"{k} neighbors -> accuracy : ", accuracy * 100)

## **t-SNE**

In [None]:
# One colour per rock type, in the same order as class_names.
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

# 2-D t-SNE embedding of the training features (fixed seed for reproducibility).
tsne = TSNE(n_components=2, random_state=42)
tsne_result = tsne.fit_transform(X_train)

# Scatter every class with its own colour and legend entry.
plt.figure(figsize=(8, 5))
for class_idx, name in enumerate(class_names):
    color = colors[class_idx]
    members = np.flatnonzero(y_train == class_idx)
    xs = tsne_result[members, 0]
    ys = tsne_result[members, 1]
    plt.scatter(xs, ys, s=15, c=color, label=name, alpha=0.5)

plt.xlabel('Component 1', fontsize=18)
plt.ylabel('Component 2', fontsize=18)
plt.legend(fontsize=14, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(False)
plt.tight_layout()
# To save the figure instead:
# plt.savefig('/content/gdrive/MyDrive/t-sne.eps', format='eps', bbox_inches='tight')
plt.show()


# **kNN (segmentation)**

## Parameters

In [None]:
# Samples used to fit the kNN, and the held-out sample it is evaluated on.
train_rocks, test_rocks = ['sample1', 'sample2'], ['sample3']
# Mask / prediction grid side, and DINOv2 input side (560 = 40 patches of 14 px).
predict_res, img_res = 128, 560
# Number of slices read for the training and the test feature sets.
num_train_samples = num_test_samples = 10
num_classes = 3
n_neighbors = 250  # k of the kNN classifier
n_jobs = 30        # parallel workers for KNeighborsClassifier

## Training

In [None]:
train_dataset = EfficientSegmentationDataset(data_directory,
                                             train_rocks,
                                             num_samples=num_train_samples,
                                             num_classes=num_classes,
                                             img_res=img_res,
                                             mask_res=predict_res)

test_dataset = EfficientSegmentationDataset(data_directory,
                                            test_rocks,
                                            num_samples=num_test_samples,
                                            num_classes=num_classes,
                                            img_res=img_res,
                                            mask_res=predict_res)

model = HuggingFaceClassificator(classification=False, rescale_size=predict_res)
model.eval()
model.to(device)

# One feature row and one label per pixel of the predict_res x predict_res grid.
X_train, y_train, _ = train_test_builder(num_train_samples, train_dataset, model, classification = False)
X_test, y_test, img = train_test_builder(num_test_samples, test_dataset, model, classification = False)

#jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=num_classes)

knn = KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=n_jobs, weights="uniform")
knn.fit(X_train, y_train)
pred = knn.predict(X_test)

# Back to one (predict_res, predict_res) mask per test slice. An explicit 3-D
# reshape keeps the slice axis even when num_test_samples == 1, whereas the
# previous reshape(..., 1).squeeze() dropped it and broke display_random_set.
prediction = pred.reshape(-1, predict_res, predict_res)
label = y_test.reshape(-1, predict_res, predict_res)
label = y_test.reshape(-1, num_test_samples, predict_res, predict_res, 1).squeeze()

## Display

In [None]:
# Show a random test slice: image, ground truth and kNN prediction.
display_random_set(img=img, pred=prediction, label=label)