In [1]:
from torch.utils.data import Sampler
import torch
from torchvision import transforms
import cv2
from tqdm import tqdm
from natsort import natsorted
import os
import numpy as np
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
from auxilary.utils import *
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.preprocessing import Normalizer
import itertools

from dataset import MonuSegDataSet
from torch.utils.data import DataLoader

In [2]:

class DinoPoweredSampler(Sampler):
    def __init__(self, images, dino_model, config, mode="train", dbscan_eps=2, training_phase='high-density'):
        '''
        Args:
            images: A list of image patches
            dino_model: The DINO model
            config: The config dictionary
            mode: The mode of the sampler. Can be "train", "val", "test" or "debug"
            dbscan_eps: The maximum distance between two samples for one to be considered as in the neighborhood of the other. This is not a maximum bound on the distances of points within a cluster. 
                        This is the most important DBSCAN parameter to choose appropriately for your data set and distance function.
        '''
        self.dino_model = dino_model
        self.mode = mode
        self.batch_size = config["batch_size"]
        self.debug = config["debug"]
        self.debugDilution = config["debugDilution"]
        self.batchVisualization = config["batchVisualization"]
        
        self.dbscan_eps = dbscan_eps
        #self.plotDir = config["expt_dir"]
        # Perform feature extraction, t-SNE, and DBSCAN here 

        self.training_phase = training_phase

        self.count_insufficientBatch = 0


        self.image_patches = images
        if config["reUseFeatures"]:
            print("Loading Features")
            self.features = np.load("Outputs/Features/"+self.mode+"-features.npy")
        else:
            print("Calculating Features")
            self.features = self.get_features()
            print("Shape of Features: ", self.features.shape)
            createDir(["Outputs/Features/"])
            np.save("Outputs/Features/"+self.mode+"-features.npy", self.features)


        #Scaling Features
        scaler = StandardScaler()
        self.scaled_features = scaler.fit_transform(self.features)

        normalizer = Normalizer(norm='l2')
        self.normalized_features = normalizer.fit_transform(self.scaled_features)

        self.image_patches_tsne = self.apply_tsne(plot=False)

        print("Applying DBSCAN")
        self.clusters = self.apply_dbscan()

        # plot the clusters
        _t = self.apply_tsne(plot=True)


        #np.save('Outputs/Features/image_clusters.npy', self.clusters)
        #print("Applying t-SNE")
        

        #np.save('Outputs/Features/image_patches_tsne.npy', self.image_patches_tsne)

        self.all_indices = set()

        print("Sampling Initialization Complete") 

    def plot_batches(self, all_indices, total_batches):

        #plotted indices
        plotted_indices = set()

        for batch_num in range(total_batches):
            plt.figure(figsize=(8, 8))
            
            # Plot all points in a light grey color as a background
            plt.scatter(self.image_patches_tsne[:, 0], self.image_patches_tsne[:, 1], color='lightgrey', alpha=0.5)
            
            # Highlight the selected images for this batch
            selected_image_indexes = all_indices[batch_num * self.batch_size: (batch_num + 1) * self.batch_size]
            selected_tsne = self.image_patches_tsne[selected_image_indexes]
            old_tsne = self.image_patches_tsne[list(plotted_indices)]
            plotted_indices.update(selected_image_indexes)
    
            if old_tsne is not None:
                plt.scatter(old_tsne[:, 0], old_tsne[:, 1], color='green', alpha=0.6)  # Previously selected points in blue
            plt.scatter(selected_tsne[:, 0], selected_tsne[:, 1], color='red', alpha=0.6)  # Selected points in red

            plt.title(f't-SNE visualization of images for batch {batch_num + 1}')
            plt.xlabel('t-SNE component 1')
            plt.ylabel('t-SNE component 2')
            
            # Save the plot with the batch number
            plt.savefig(f"Outputs/Batch_Plots/tsne_batch_{batch_num + 1}.png")
            plt.close()


    def __iter__(self):
        # Reset all_indices at the beginning of each iteration to start fresh
        self.all_indices = set()

        # Initialize an empty list to store all indices for the epoch
        all_indices = []

        # Calculate the total number of batches needed
        total_batches = self.__len__()

        # Loop to generate all indices for the epoch
        for _ in range(total_batches):
            # Sample indices for a batch
            batch_indices = self.sampleImages()

            # Check if we have already included enough indices
            if len(all_indices) + len(batch_indices) > len(self.image_patches):
                # If adding the current batch_indices exceeds the number of images,
                # trim the batch_indices to fit the remaining number of images
                batch_indices = batch_indices[:len(self.image_patches) - len(all_indices)]

            # Extend the all_indices list with the new batch indices
            all_indices.extend(batch_indices)

        # Shuffle the indices to ensure random order of image access
        np.random.shuffle(all_indices)

        # Plot the batches
        if self.batchVisualization:
            print("\nPlotting Batches for visualization")
            createDir(["Outputs/Batch_Plots/"])
            print("Total Batches: ", total_batches)
            self.plot_batches(all_indices, total_batches)

        # Yield each index at a time
        return iter(all_indices)



    def __len__(self):
        if self.debug:
            return len(self.image_patches) // self.debugDilution
        else:
            num_images = len(self.image_patches)
            return (num_images // self.batch_size) + int(num_images % self.batch_size > 0)
    
    
    #process image patches
    def get_features(self):
        features = []

        transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    

        for img in tqdm(self.image_patches):
            img_tensor = transform(img).unsqueeze(0)
            img_tensor = img_tensor.to('cuda')
            with torch.no_grad():
                feature = self.dino_model(img_tensor)
                feature = feature.cpu()
                features.append(feature.squeeze().numpy())

        return np.array(features)

    def save_batches(self, filename):
        all_batches = []
        total_batches = self.__len__()
        f_writer = open(filename, 'w')

        # Generate all batches
        for _ in range(total_batches):
            batch_indices = self.sampleImages()
            all_batches.append(batch_indices)
            f_writer.write(str(batch_indices) + '\n')
            if not len(batch_indices):
                break 




    def sample_from_cluster(self, cluster_indices, k=1, used_indices=None):
        if used_indices is None:
            used_indices = []

        # Extract the t-SNE coordinates for the current cluster
        cluster_tsne = self.image_patches_tsne[cluster_indices]

        # Calculate the centroid of the current cluster
        centroid = np.mean(cluster_tsne, axis=0)
        # Calculate the distances from each point in the cluster to the centroid
        distances = np.linalg.norm(cluster_tsne - centroid, axis=1)

        # Determine the maximum distance as the "radius" of the cluster
        max_distance = np.max(distances)
        # Set the threshold as a fraction of the maximum distance
        threshold_distance = max_distance / 2

        # Classify points as central or boundary
        central_indices = [cluster_indices[i] for i in range(len(cluster_indices)) if distances[i] <= threshold_distance and cluster_indices[i] not in used_indices]
        boundary_indices = [cluster_indices[i] for i in range(len(cluster_indices)) if distances[i] > threshold_distance and cluster_indices[i] not in used_indices]

        return central_indices[:k], boundary_indices[:k]
    

    def apply_tsne(self, plot = False):
        if not plot:
            tsne = TSNE(n_components=2, perplexity=40, learning_rate=200, random_state=42)
            image_patches_tsne = tsne.fit_transform(self.normalized_features)
            return image_patches_tsne
        # Plot the results
        plt.figure(1)
        plt.scatter(self.image_patches_tsne[:, 0], self.image_patches_tsne[:, 1], c=self.clusters)
        plt.colorbar()
        plt.title(f"t-SNE Visualization, DBSNE c - {len(np.unique(self.clusters))}")
        plt.xlabel('Y')
        plt.ylabel('X')
        #plt.imsave(self.plotDir+self.mode+"-tsne.png", image_patches_tsne)
        createDir(["Outputs/Plots/"])
        if self.mode == "debug":
            plt.savefig(f"Outputs/Plots/{self.mode}-tsne-{self.dbscan_eps}.png")
        else:
            plt.savefig("Outputs/Plots/"+self.mode+"-tsne.png")
        plt.clf()
        return None

    def apply_dbscan(self, eps = 2, min_samples = 5, metrics='euclidean',gen_plot = False):
        eps = self.dbscan_eps
        dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metrics)
        clusters = dbscan.fit_predict(self.image_patches_tsne)
        print("Unique clusters:", np.unique(clusters))  # You should see more than just -1

        if not gen_plot:
            return clusters
        
        # Plot the results
        plt.figure(figsize=(10, 10))

        # Scatter plot for each uniquely labeled cluster
        unique_clusters = np.unique(clusters)
        for cluster in unique_clusters:
            x = self.image_patches_tsne[clusters == cluster][:, 0]
            y = self.image_patches_tsne[clusters == cluster][:, 1]
            plt.scatter(x, y, label=f"Cluster {cluster}")

        plt.title("DBSCAN Clustering")
        plt.xlabel("1st component")
        plt.ylabel("2nd component")
        plt.legend()
        createDir(["Outputs/Plots/"])
        
        if self.mode == "debug":
            plt.savefig(f"Outputs/Plots/{self.mode}-dbscan-{self.dbscan_eps}.png")
        else:
            plt.savefig("Outputs/Plots/"+self.mode+"-dbscan.png")
            
        return clusters


    def sampleImages(self):
        valid_clusters = [c for c in np.unique(self.clusters) if c >= 0]
        np.random.shuffle(valid_clusters)

        batch_indices = []
        loopCount = 0
        while len(batch_indices) < self.batch_size:
            new_indices = []
        
            for cluster in valid_clusters:
                if len(batch_indices) >= self.batch_size:
                    break

                center_indices, boundary_indices = self.sample_from_cluster(cluster_indices=np.where(self.clusters == cluster)[0], k=1, used_indices=self.all_indices)

                # Select indices based on the current training phase
                if self.training_phase == 'high-density':
                    new_indices = [idx for idx in center_indices if idx not in self.all_indices]
                else:  # low-density phase
                    new_indices = [idx for idx in boundary_indices if idx not in self.all_indices]

                batch_indices.extend(new_indices)
                self.all_indices.update(new_indices)

            if len(new_indices) == 0:
                loopCount += 1

            if loopCount > 10:
                # Handle the case where new indices are not found
                # Repeat some of the already selected indices to fill the batch
                remaining_slots = self.batch_size - len(batch_indices)
                repeat_indices = batch_indices[:remaining_slots]
                batch_indices.extend(repeat_indices)
                break  # Exit the while loop as the batch is now full
        print("Batch Indices: ", batch_indices)

        return batch_indices[:self.batch_size]


In [3]:
config = readConfig("config.sys")
config["batch_size"] = 16
config["reUseFeatures"] = False
config["batchVisualization"] = True
config["training_phase"] = 'high-density'

In [4]:
trainPaths = config["trainDataset"]
sampleTrainImages = load_images(trainPaths)
dino_model = load_sampling_model(modelType="giga")
train_dataset = MonuSegDataSet(config["trainDataset"])

loading Images from path: Dataset/trainNormal2/


100%|██████████| 63360/63360 [01:01<00:00, 1023.16it/s]
Using cache found in /home/blue/.cache/torch/hub/facebookresearch_dinov2_main
xFormers not available
xFormers not available


In [5]:
sampler = DinoPoweredSampler(sampleTrainImages, dino_model, config)
sampler.save_batches("log/batch_log.txt")

Calculating Features


  0%|          | 0/31680 [00:00<?, ?it/s]

100%|██████████| 31680/31680 [38:22<00:00, 13.76it/s] 


Shape of Features:  (31680, 1536)
Directory Outputs/Features/ already exists
Applying DBSCAN
Unique clusters: [-1  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22
 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37]
Directory Outputs/Plots/ already exists
Sampling Initialization Complete
Batch Indices:  [14256, 23760, 17637, 1602, 25368, 3168, 4758, 9510, 22176, 9531, 1584, 17454, 3, 7929, 26931, 26943]
Batch Indices:  [51, 30096, 12675, 19008, 1604, 23775, 9511, 26932, 22177, 25369, 23761, 11088, 20592, 15861, 3169, 7947]
Batch Indices:  [4759, 17455, 20593, 15843, 23776, 9512, 12672, 25370, 20655, 4, 30097, 25350, 1617, 22178, 12676, 19009]
Batch Indices:  [25641, 14257, 11089, 19010, 23777, 26933, 17638, 23762, 22179, 20656, 20594, 15844, 9532, 12673, 4760, 22215]
Batch Indices:  [26944, 7948, 9533, 19023, 15862, 17639, 17456, 52, 9516, 23784, 6336, 28545, 30098, 26937, 25374, 5]
Batch Indices:  [30105, 12674, 28546, 19024, 9591, 20598, 7949, 22180, 25642, 12677, 11090

<Figure size 640x480 with 0 Axes>

In [23]:
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], sampler=sampler)

In [24]:
epochs = 1
# Training loop
for epoch in range(epochs):
    i = 0
    for batch in train_loader:
        pass

    print(f"Epoch {epoch+1}/{epochs} completed.")

Batch Indices:  [31680, 25359, 1593, 12696, 26931, 23790, 22182, 41184, 28533, 22179, 6354, 39600, 30099, 4777, 31686, 33264]
Batch Indices:  [3, 33265, 31687, 38022, 36435, 28534, 39601, 28515, 9507, 15840, 12672, 25360, 34848, 36432, 3198, 38016]
Batch Indices:  [11088, 6355, 4, 22183, 7944, 41190, 6345, 33266, 20592, 3168, 9508, 30100, 20604, 39602, 25392, 34849]
Batch Indices:  [28516, 28535, 41191, 25393, 36433, 11089, 5, 9509, 1594, 12673, 31688, 7945, 22180, 17433, 20605, 3199]
Batch Indices:  [9516, 39603, 36436, 34850, 20593, 33270, 19008, 28569, 31681, 41192, 11090, 22184, 17434, 31098, 31701, 23791]
Batch Indices:  [4785, 14286, 26932, 41193, 31702, 1595, 30101, 33271, 11106, 6, 20606, 3200, 6356, 7946, 19009, 39604]
Batch Indices:  [3169, 25394, 39605, 28517, 12674, 30102, 31703, 20594, 7962, 9510, 11100, 38017, 20616, 31682, 34851, 22206]
Batch Indices:  [36434, 30103, 26933, 23763, 28518, 9511, 25446, 39609, 11101, 41185, 7, 7963, 20617, 23792, 14287, 33272]
Batch Indices

KeyboardInterrupt: 

{'log': 'log/',
 'debug': False,
 'debugDilution': 20,
 'wandb': True,
 'normalization': 'reinhard',
 'targetImagePath': 'Dataset/MonuSegData/Training/TissueImages/TCGA-A7-A13F-01Z-00-DX1.png',
 'to_be_aug': 'Dataset/MonuSegData/',
 'out_dir': 'Dataset/MonuSegData/slidingAugNormal/',
 'augmented_dir': 'Dataset/MonuSegData/augmentated/',
 'tileHeight': 800,
 'tileWidth': 800,
 'slidingSize': 50,
 'augmentPerImage': 100,
 'finalTileHeight': 256,
 'finalTileWidth': 256,
 'splitRatio': 0.9,
 'trainDataset': 'Dataset/trainNormal/',
 'valDataset': 'Dataset/valNormal/',
 'testDataset': 'Dataset/testNormal/',
 'resumeModel': 'model/best_model.pth',
 'sampleImages': True,
 'dinoModelType': 'giga',
 'reUseFeatures': True,
 'batchVisualization': True,
 'trainingPhase': 'high-density',
 'class1': [0, 0, 0],
 'class2': [255, 255, 255],
 'model_type': 'UNet_3Plus',
 'input_img_type': 'rgb',
 'kernel_size': 3,
 'use_maxblurpool': False,
 'epochs': 30,
 'batch_size': 32,
 'learning_rate': 1e-06,
 'lr_