# KNN Label Transfer

> The `KNN_Label_transfer` module provides functionality for label transfer between datasets using k-nearest neighbors (KNN) algorithms. It includes methods for majority and weighted voting based on nearest neighbors, calculation of centroids for labeled data, and label assignment based on nearest centroids. This module is designed to facilitate the propagation of labels from a reference dataset with known labels to a query dataset where labels are unknown.


In [1]:
#| hide
from nbdev.showdoc import *

In [2]:
#| default_exp transfer


We use FAISS to find the nearest neighbor of each cell in the query dataset among the cells of the reference dataset, and then assign each query cell the annotation of its nearest reference neighbor. Meta's FAISS library can run these searches in batches and on the GPU, which significantly speeds up the process. The `IndexFlatL2` index performs an exact, exhaustive search: it computes the squared L2 (Euclidean) distance between every query vector and every vector stored in the reference index. A diagram illustrating this process is provided below. For more detailed information on vector search and on making the most of FAISS, refer to the link provided (image credit also included).
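To make the exhaustive search concrete, here is a minimal NumPy sketch of what a flat L2 index computes, including batching over the query set. The helper `exact_knn_squared_l2` is hypothetical (it is not part of FAISS or of this module) and runs on the CPU only; it illustrates the computation and is not meant to replace FAISS.

```python
import numpy as np


def exact_knn_squared_l2(reference, query, k, batch_size=1024):
    """Exhaustive k-NN search returning (distances, indices), each of shape (n_query, k).

    Distances are squared L2 and each row is sorted from nearest to farthest,
    mirroring the output layout of a flat L2 index search.
    """
    reference = np.asarray(reference, dtype=np.float32)
    query = np.asarray(query, dtype=np.float32)
    ref_sq = (reference ** 2).sum(axis=1)  # ||r||^2 for every reference vector
    batch_distances, batch_indices = [], []
    for start in range(0, query.shape[0], batch_size):
        q = query[start:start + batch_size]
        # ||q - r||^2 = ||q||^2 - 2 q.r + ||r||^2 for every (query, reference) pair in the batch
        d = (q ** 2).sum(axis=1)[:, None] - 2.0 * (q @ reference.T) + ref_sq[None, :]
        d = np.maximum(d, 0.0)  # clip tiny negative values caused by rounding
        idx = np.argsort(d, axis=1, kind="stable")[:, :k]  # k smallest distances per row
        batch_indices.append(idx)
        batch_distances.append(np.take_along_axis(d, idx, axis=1))
    return np.vstack(batch_distances), np.vstack(batch_indices)


reference = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
query = np.array([[0.9, 0.1], [4.0, 4.0]])
distances, indices = exact_knn_squared_l2(reference, query, k=2, batch_size=1)
print(indices.tolist())  # [[1, 0], [2, 1]]
```

The expansion \( \lVert q - r \rVert^2 = \lVert q \rVert^2 - 2\, q \cdot r + \lVert r \rVert^2 \) turns the distance computation for a whole batch into one matrix product, which is why batched exhaustive search maps well onto matrix-multiplication hardware such as GPUs.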

When K, the number of neighbors in our K-Nearest Neighbors (KNN) search, is greater than one, the neighbors' labels may disagree, so we need a rule for combining them into a single predicted label. Below, we implement three methods for this purpose:


# Majority Voting Method

The Majority Voting Method is a technique used to determine the final label for each data point in the query dataset. This is achieved by conducting a simple majority vote among the labels of the K nearest neighbors in the reference dataset. The process involves the following steps:

1. **Neighbor Identification**: For every data point in the query dataset, the K nearest neighbors in the reference dataset are identified.
2. **Vote Counting**: The labels of these K neighbors are tallied.
3. **Label Assignment**: The label with the highest tally is assigned to the query data point.

Note that this method gives each of the K neighbors an equal vote, regardless of how close it is to the query point.
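The tally and its tie-breaking behavior can be seen on a single query point. This sketch uses hypothetical cell-type labels and assumes, as FAISS does, that the neighbors are listed from nearest to farthest. `Counter.most_common` returns labels with equal counts in the order they were first encountered, so `most_common(1)`, which `knn_majority_voting` below also uses, resolves a tie in favor of the label of the nearest neighbor.

```python
from collections import Counter

# Labels of the k = 4 nearest neighbors of one query point,
# listed from nearest to farthest (the order a k-NN search returns them in).
neighbor_labels = ["T cell", "B cell", "B cell", "T cell"]

# Both labels receive 2 votes; the tie goes to the label seen first (the nearest neighbor).
winner = Counter(neighbor_labels).most_common(1)[0][0]
print(winner)  # T cell
```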


In [3]:
#| export
from typing import List
from collections import Counter

def knn_majority_voting(indices: List[List[int]], # A list of lists, where each sublist contains the indices of the k-nearest neighbors in the reference dataset for a given query point.
                        reference_labels: List[str] # A list of labels corresponding to the points in the reference dataset.
                        ) -> List[str]: # A list of labels for each point in the query dataset, determined by majority voting.
    """
    Assigns labels to query dataset points using majority voting from k-nearest neighbors.
    """
    query_labels = []
    for ind in indices:
        neighbor_labels = [reference_labels[i] for i in ind]
        label_counts = Counter(neighbor_labels)
        most_common_label = label_counts.most_common(1)[0][0]
        query_labels.append(most_common_label)
    return query_labels


# Weighted Voting Method
  
The Weighted Voting Method is a technique that considers both the labels and the distances of the K nearest neighbors from the reference dataset. This method is especially beneficial when the neighbors are at different distances from the query point. The process involves the following steps:
  
1. **Neighbor and Distance Identification**: For each query point, the K nearest neighbors and their distances from the query point are identified.
2. **Weighted Vote Calculation**: Each neighbor is assigned a weight that decreases with its distance; in the implementation below, the weight is the reciprocal of the distance, \( 1/(d + 10^{-6}) \), so closer neighbors count more.
3. **Vote Aggregation**: The weighted votes for each label are summed.
4. **Label Assignment**: The label with the highest total weight is assigned to the query point.
  
This method allows neighbors that are closer to the query point to have more influence, potentially leading to more accurate predictions, especially in scenarios where the nearest neighbors are not uniformly distributed around the query point.
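A small worked example shows how the two voting schemes can disagree. The labels and distances are made up for illustration; the weight \( 1/(d + 10^{-6}) \) matches the one used in `knn_weighted_voting` below.

```python
from collections import Counter

# Hypothetical neighbors of one query point: labels and distances (nearest first).
neighbor_labels = ["B", "A", "A"]
neighbor_distances = [0.1, 1.0, 1.25]

# Majority voting: every neighbor counts once.
majority = Counter(neighbor_labels).most_common(1)[0][0]

# Weighted voting: each neighbor contributes 1 / (d + 1e-6).
weighted_votes = {}
for label, d in zip(neighbor_labels, neighbor_distances):
    weighted_votes[label] = weighted_votes.get(label, 0.0) + 1.0 / (d + 1e-6)
weighted = max(weighted_votes, key=weighted_votes.get)

print(majority, weighted)  # A B
print({label: round(w, 2) for label, w in weighted_votes.items()})  # {'B': 10.0, 'A': 1.8}
```

Two distant "A" neighbors outvote one "B" neighbor under majority voting, but the single very close "B" neighbor carries at least ten times the weight of each "A" neighbor, so weighted voting picks "B".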


In [4]:
#| export
from typing import Dict, List

def knn_weighted_voting(indices: List[List[int]], # A list of lists, where each sublist contains the indices of the k-nearest neighbors in the reference dataset for a given query point.
                        distances: List[List[float]], # A list of lists, where each sublist contains the distances of the k-nearest neighbors from a given query point (smaller values mean closer neighbors).
                        reference_labels: List[str] # A list of labels corresponding to the points in the reference dataset.
                        ) -> List[str]: # A list of labels for each point in the query dataset, determined by weighted voting.
    """
    Assigns labels to query dataset points using weighted voting from k-nearest neighbors.
    """
    query_labels = []
    for ind, dist in zip(indices, distances):
        weighted_votes: Dict[str, float] = {}
        for i, d in zip(ind, dist):
            label = reference_labels[i]
            weight = 1 / (d + 1e-6)  # Adding a small constant to avoid division by zero
            weighted_votes[label] = weighted_votes.get(label, 0) + weight
        most_common_label = max(weighted_votes, key=weighted_votes.get)
        query_labels.append(most_common_label)
    return query_labels


# Centroid-Based Label Assignment
  
The Centroid-Based Label Assignment method assigns labels to data points in a query dataset based on their closeness to the centroids of different classes in a reference dataset. The centroids are the average position of all points within a specific class. Here's a detailed breakdown of the process:
  
1. **Centroid Calculation**: For each label in the reference dataset, compute the centroid. A centroid is the arithmetic mean position of all the points sharing the same label. This step involves adding up all data points of each label and dividing by the total number of points for that label.
  
2. **Label Assignment by Nearest Centroid**: For each point in the query dataset, identify its closest centroid. The closest centroid is the one with the least Euclidean distance from the query point. The label of this nearest centroid is then assigned to the query point.
  
This method operates under the assumption that data points of the same class are typically clustered together, with the centroid serving as the central point of these clusters. It is particularly useful in situations where data points of the same class form distinct, compact clusters. However, its performance may be suboptimal in scenarios where the class distribution is highly irregular or overlapping.
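The two steps can also be written with NumPy array operations instead of Python loops. The sketch below is an illustrative alternative, not the module's implementation, and `nearest_centroid_labels` is a hypothetical helper. It computes the same per-label means and Euclidean nearest-centroid assignment as `calculate_centroids` and `assign_labels_by_nearest_centroid`, although exact distance ties may be broken differently (here, by sorted label order).

```python
import numpy as np


def nearest_centroid_labels(reference, reference_labels, query):
    """Vectorized sketch: per-label centroids, then nearest centroid by Euclidean distance."""
    reference = np.asarray(reference, dtype=float)
    query = np.asarray(query, dtype=float)
    # Map each label to an integer code; `classes` holds the sorted unique labels.
    classes, codes = np.unique(np.asarray(reference_labels), return_inverse=True)
    # Sum the points of each class, then divide by the class sizes to get the means.
    sums = np.zeros((len(classes), reference.shape[1]))
    np.add.at(sums, codes, reference)
    centroids = sums / np.bincount(codes)[:, None]
    # Euclidean distance from every query point to every centroid: shape (n_query, n_classes).
    dists = np.linalg.norm(query[:, None, :] - centroids[None, :, :], axis=2)
    return classes[dists.argmin(axis=1)].tolist(), centroids


predicted, centroids = nearest_centroid_labels(
    [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]],
    ["A", "A", "B", "B"],
    [[1.0, 1.0], [9.0, 9.0]],
)
print(predicted)  # ['A', 'B']
```

The broadcasted difference array holds \( n_{\text{query}} \times n_{\text{classes}} \times d \) values, so very large query sets would need to be processed in batches.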


In [5]:
#| export
import numpy as np
from collections import defaultdict, Counter
from typing import Dict, List


def calculate_centroids(reference_data: List[List[float]], # A list of lists, where each sublist represents a data point in the reference dataset.
                        reference_labels: List[str] # A list of labels corresponding to the points in the reference dataset.
                        ) -> Dict[str, np.ndarray]: # Returns a dictionary where each key is a label and the corresponding value is the centroid of that label.
    """
    Calculates the centroids for each label in the reference dataset.
    """
    

    # Initialize a dictionary to store the sum of data points for each label.
    label_sums = defaultdict(lambda: np.zeros(len(reference_data[0])))
    
    # Count the number of occurrences of each label in the reference dataset.
    label_counts = Counter(reference_labels)
    
    # For each data point and its corresponding label in the reference dataset,
    # add the data point to the sum of data points for that label.
    for data, label in zip(reference_data, reference_labels):
        label_sums[label] += np.array(data)
    
    # Calculate the centroid for each label by dividing the sum of data points for that label by the count of that label.
    centroids = {label: label_sums[label] / count for label, count in label_counts.items()}
    
    return centroids


In [6]:
#| export
def assign_labels_by_nearest_centroid(query_data: List[List[float]], # A list of lists, where each sublist represents a data point in the query dataset.
                                      centroids: Dict[str, np.ndarray] # A dictionary where each key is a label and the corresponding value is the centroid of that label.
                                      ) -> List[str]: # Returns a list of labels assigned to each point in the query dataset based on the nearest centroid.
    """
    Assigns labels to each point in the query dataset based on the nearest centroid.
    """
    # Initialize an empty list to store the labels for the query data points.
    query_labels = []
    
    # For each data point in the query dataset,
    for data in query_data:
        # Convert the data point to a numpy array.
        data_point = np.array(data)
        
        # Find the label of the centroid closest to the data point.
        # The closest centroid is the one with the least Euclidean distance from the data point.
        closest_label = min(centroids.keys(), key=lambda label: np.linalg.norm(data_point - centroids[label]))
        
        # Append the label of the closest centroid to the list of labels for the query data points.
        query_labels.append(closest_label)
    
    # Return the list of labels for the query data points.
    return query_labels

# Running the Label Transfer
The `labels` function below implements label transfer with FAISS. Under the default setting, \( k = 1 \), it simply copies the label of each query point's single nearest reference neighbor, so no label consensus step is needed.

The function also lets you vary \( k \), the number of nearest neighbors considered for label assignment. When \( k > 1 \), a label consensus method decides the query point's label from the labels of its \( k \) nearest neighbors. The function supports three label consensus methods: 'majority_voting', 'weighted_voting', and 'centroid_based'.

'majority_voting' assigns the label that appears most frequently among the \( k \) nearest neighbors. In case of a tie, it selects the label of the closest neighbor among the tied labels: FAISS returns neighbors ordered from nearest to farthest, and `Counter.most_common` keeps equal counts in the order they were first encountered.

'weighted_voting' assigns the label based on a weighted vote in which closer neighbors have a higher weight. Each neighbor's weight is the reciprocal of its distance to the query point (with 'L2', FAISS reports squared Euclidean distances). Because this weighting assumes that smaller values mean closer neighbors, it is only meaningful with the 'L2' metric; 'IP' scores are similarities, where larger values mean closer.

'centroid_based' assigns labels to query points based on their closeness to the centroids of the classes in the reference dataset. It does not use the nearest-neighbor search or \( k \), and it always measures Euclidean distance to the centroids. This method is particularly useful when data points of the same class form distinct, compact clusters.

The distance metric for the neighbor search can be either 'L2' (squared Euclidean distance) or 'IP' (inner product); any other value raises a `ValueError`. If the embeddings are L2-normalized, 'IP' is equivalent to cosine similarity.
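For unit vectors, \( \lVert q - r \rVert^2 = 2 - 2\, q \cdot r \), so on L2-normalized embeddings the 'L2' and 'IP' metrics rank neighbors identically. The NumPy check below, on random data, illustrates this; it suggests that, under this normalization assumption, 'L2' can be used (for example together with 'weighted_voting') when cosine-based neighbors are wanted.

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.normal(size=(100, 8))
query = rng.normal(size=(5, 8))

# L2-normalize rows so that the inner product equals cosine similarity.
reference /= np.linalg.norm(reference, axis=1, keepdims=True)
query /= np.linalg.norm(query, axis=1, keepdims=True)

inner = query @ reference.T                                             # 'IP' scores: larger is closer
sq_l2 = ((query[:, None, :] - reference[None, :, :]) ** 2).sum(axis=2)  # 'L2' scores: smaller is closer

# Largest deviation from the identity ||q - r||^2 = 2 - 2 (q . r).
max_gap = np.abs(sq_l2 - (2 - 2 * inner)).max()
# The nearest reference point is the same under both metrics.
same_ranking = bool((inner.argmax(axis=1) == sq_l2.argmin(axis=1)).all())
print(max_gap < 1e-9, same_ranking)  # True True
```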


In [7]:
#| export

import time
from typing import List, Optional, Tuple, Union

import numpy as np


def labels(embedding_array_reference: np.ndarray, # A numpy array of shape (n_reference, dimension) representing the reference dataset.
           embedding_array_query: np.ndarray, # A numpy array of shape (n_query, dimension) representing the query dataset.
           reference_labels: List[str], # A list of labels for the reference dataset, one per row of `embedding_array_reference`.
           k: int = 1, # The number of nearest neighbors to consider for label assignment.
           use_gpu: bool = True, # Whether to use the GPU for computation (requires a GPU-enabled FAISS build).
           batch_size: Optional[int] = None, # The number of query points searched per batch. If None, the entire query dataset is processed in one batch.
           distance_metric: str = 'L2', # The distance metric to use. Can be 'L2' or 'IP'.
           label_consensus: str = 'majority_voting', # The label consensus method to use. Can be 'majority_voting', 'weighted_voting', or 'centroid_based'.
           timed: bool = False # Whether to also return the time taken for label transfer, in minutes.
           ) -> Union[List[str], Tuple[List[str], float]]: # A list of labels for the query dataset. If timed is True, a tuple of (labels, minutes elapsed).
    "Transfers labels from a reference dataset to a query dataset using FAISS."
    import faiss  # Imported here so the rest of the module can be used without FAISS installed.

    # Validate the options up front so that typos fail loudly instead of
    # silently falling back to 1-nearest-neighbor transfer.
    if distance_metric not in ('L2', 'IP'):
        raise ValueError("Invalid distance metric. Choose 'L2' or 'IP'.")
    if label_consensus not in ('majority_voting', 'weighted_voting', 'centroid_based'):
        raise ValueError("Invalid label consensus. Choose 'majority_voting', 'weighted_voting', or 'centroid_based'.")
    if k > 1 and label_consensus == 'weighted_voting' and distance_metric == 'IP':
        # Inverse-distance weights assume smaller scores mean closer neighbors,
        # but 'IP' scores are similarities (larger means closer).
        raise ValueError("'weighted_voting' requires distance_metric='L2'.")

    start_time = time.time()
    reference_labels = list(reference_labels)  # Ensure positional indexing (e.g. if a pandas Series is passed).

    if label_consensus == 'centroid_based':
        # Centroid assignment ignores k and does not need the FAISS index.
        centroids = calculate_centroids(embedding_array_reference, reference_labels)
        query_labels = assign_labels_by_nearest_centroid(embedding_array_query, centroids)
    else:
        # FAISS expects C-contiguous float32 arrays.
        reference = np.ascontiguousarray(embedding_array_reference, dtype=np.float32)
        query = np.ascontiguousarray(embedding_array_query, dtype=np.float32)
        dimension = reference.shape[1]

        res = faiss.StandardGpuResources() if use_gpu else None
        if use_gpu:
            res.noTempMemory()
        if distance_metric == 'L2':
            index = faiss.GpuIndexFlatL2(res, dimension) if use_gpu else faiss.IndexFlatL2(dimension)
        else:
            index = faiss.GpuIndexFlatIP(res, dimension) if use_gpu else faiss.IndexFlatIP(dimension)
        index.add(reference)

        num_query_points = query.shape[0]
        batch_size = batch_size or num_query_points
        query_labels = []
        for start in range(0, num_query_points, batch_size):
            batch_query = query[start:start + batch_size]
            # Each row of `indices` lists neighbors from nearest to farthest.
            distances, indices = index.search(batch_query, k)
            if k > 1 and label_consensus == 'majority_voting':
                batch_labels = knn_majority_voting(indices, reference_labels)
            elif k > 1 and label_consensus == 'weighted_voting':
                batch_labels = knn_weighted_voting(indices, distances, reference_labels)
            else:
                # k == 1: copy the label of the single nearest neighbor.
                batch_labels = [reference_labels[row[0]] for row in indices]
            query_labels.extend(batch_labels)

    duration_minutes = (time.time() - start_time) / 60
    if timed:
        return query_labels, duration_minutes
    return query_labels

In [8]:
#| hide
import nbdev; nbdev.nbdev_export()

