# In [None]:
import pandas as pd

class PredictionProcessor:
    """
    Greedily assign a unique predicted label to each cell-type (cluster) column.

    Each column of ``predictions_df`` is treated as one cell type; its values are
    the labels predicted for it. Every label is handed out at most once.

    Algorithm overview:
    - Trim leading/trailing whitespace from the column names.
    - For every still-unassigned column, count how often each not-yet-used label
      occurs and keep the ``top_n`` most frequent ones (2 by default).
    - Summarise the first three ranks of every column into one row per column.
    - Pick the column whose most frequent remaining label has the highest count,
      record that (column -> label) pair and mark the label as used.
    - Drop the assigned column from the pool and repeat until every column has a
      label or no unused labels remain.
    """

    def __init__(self, predictions_df):
        """
        Initialize the processor.

        Parameters:
        predictions_df (DataFrame): one column per cell type; each column holds the
            labels predicted for it. ``step_0_fix_columns`` renames the columns of
            this very object in place.
        """
        self.predictions_df = predictions_df
        # Columns that still need a label; shrinks by one per assignment.
        self.free_ones = list(predictions_df.columns)
        # Maps column name -> assigned label.
        self.annotation_dict = {}
        # Labels already assigned; they are excluded from further counting.
        self.used_labels = set()

    def step_0_fix_columns(self):
        """
        Strip leading/trailing whitespace from the column names.

        Column names without a ``strip`` method (e.g. integers) are left untouched.
        ``free_ones`` is renamed the same way so it keeps matching the columns.
        """
        old_columns = list(self.predictions_df.columns)
        new_columns = []
        for name in old_columns:
            try:
                new_columns.append(name.strip())
            except AttributeError:  # non-string column label
                new_columns.append(name)
        self.predictions_df.columns = new_columns
        # free_ones was captured from the original names in __init__; without this
        # rename, step_1 would look up columns that no longer exist (KeyError).
        renamed = dict(zip(old_columns, new_columns))
        self.free_ones = [renamed.get(name, name) for name in self.free_ones]

    def step_1_create_organized(self, top_n: int = 2) -> pd.DataFrame:
        """
        Build a side-by-side table of the most frequent unused labels per column.

        Parameters:
        top_n (int): how many of the most frequent labels to keep per column
            (default 2). ``step_2_calculate_highest`` reads at most 3 ranks.

        Returns:
        DataFrame: for every column in ``free_ones`` a pair of columns
            ``[<column name>, "count"]`` holding its labels (already-used labels and
            missing values excluded) in descending count order. Shorter pairs are
            padded with NaN. An empty DataFrame if ``free_ones`` is empty.
        """
        used = list(self.used_labels)
        pieces = []
        for column_name in self.free_ones:
            column = self.predictions_df[column_name]
            counts = column[~column.isin(used)].value_counts().iloc[:top_n]
            # Name both columns explicitly: the column names produced by
            # value_counts().reset_index() differ between pandas 1.x and 2.x.
            pieces.append(counts.rename_axis(column_name).reset_index(name="count"))
        if not pieces:
            return pd.DataFrame()
        # One concat aligns all pairs on the row position, padding with NaN.
        return pd.concat(pieces, axis=1)

    def step_2_calculate_highest(self, Organized: pd.DataFrame) -> pd.DataFrame:
        """
        Summarise the three highest-ranked labels of every column in ``Organized``.

        Parameters:
        Organized (DataFrame): output of ``step_1_create_organized``, i.e. pairs of
            ``[<column name>, "count"]`` columns.

        Returns:
        DataFrame: one row per column pair, with columns ``Cell_Type``,
            ``first_row_cell_type``, ``first_row_count``, ``second_row_cell_type``,
            ``second_row_count``, ``third_row_cell_type``, ``third_row_count``.
            Ranks beyond the rows present in ``Organized`` get label ``None`` and
            count ``0``.
        """
        n_rows = len(Organized)
        records = []
        for i in range(0, Organized.shape[1], 2):
            record = [Organized.columns[i]]
            for rank in range(3):
                if rank < n_rows:
                    # Positional access: every count column is named "count".
                    record += [Organized.iat[rank, i], Organized.iat[rank, i + 1]]
                else:
                    record += [None, 0]
            records.append(record)
        return pd.DataFrame(
            records,
            columns=['Cell_Type', 'first_row_cell_type', 'first_row_count',
                     'second_row_cell_type', 'second_row_count',
                     'third_row_cell_type', 'third_row_count'],
        )

    def step_3_update_annotation_dict(self, Highest_Counts_DF: pd.DataFrame, target_row: int = 1):
        """
        Assign a label to the column with the highest top-label count.

        The column is always chosen by ``first_row_count``. Ties are broken by row
        order in ``Highest_Counts_DF`` (stable sort). ``target_row`` then selects
        which of that column's ranked labels is assigned.

        Parameters:
        Highest_Counts_DF (DataFrame): output of ``step_2_calculate_highest``.
        target_row (int): 1, 2 or 3. Assign the first-, second- or third-ranked
            label of the chosen column.

        Returns:
        The assigned label. It is also stored in ``annotation_dict`` and
        ``used_labels``.

        Raises:
        ValueError: if ``target_row`` is not 1, 2 or 3.
        """
        label_columns = {1: "first_row_cell_type", 2: "second_row_cell_type", 3: "third_row_cell_type"}
        if target_row not in label_columns:
            raise ValueError("target_row must be 1, 2, or 3")
        best = Highest_Counts_DF.sort_values(
            by="first_row_count", ascending=False, kind="stable"
        ).iloc[0]
        cluster_label = best["Cell_Type"]
        found_label = best[label_columns[target_row]]
        self.annotation_dict[cluster_label] = found_label
        self.used_labels.add(found_label)
        return found_label

    def step_4_update_free_ones(self, found_label=None):
        """
        Remove every column that already has an entry in ``annotation_dict``.

        Parameters:
        found_label: the label just assigned. It is not needed for the update and
            is kept only so existing calls keep working. Removal is keyed on the
            column that was annotated, not on a column whose name happens to equal
            the assigned label.
        """
        self.free_ones = [column for column in self.free_ones if column not in self.annotation_dict]

    def process(self):
        """
        Run the assignment steps until every column has a label or no unused labels remain.

        Prints the number of remaining columns at the start of each iteration.

        Returns:
        dict: ``annotation_dict``, mapping column name -> assigned label.
        """
        self.step_0_fix_columns()
        while self.free_ones:
            print(f"Remaining columns: {len(self.free_ones)}")
            Organized = self.step_1_create_organized()
            if Organized.empty:
                break  # no unused labels left for any remaining column
            Highest_Counts_DF = self.step_2_calculate_highest(Organized)
            found_label = self.step_3_update_annotation_dict(Highest_Counts_DF)
            # Each iteration annotates one column, so this removes one column
            # and guarantees the loop terminates.
            self.step_4_update_free_ones(found_label)

        return self.annotation_dict



# In [None]:
def calculate_row_accuracy(row):
    """
    Return the fraction of non-missing entries whose value equals their column name.

    Parameters:
    row: any object with ``.items()`` yielding (column name, value) pairs, e.g. a
        DataFrame row.

    Returns:
    float: correct / non-missing count, or NaN when every value is missing.
    """
    observed = [(col, value) for col, value in row.items() if pd.notna(value)]
    if not observed:
        # float("nan") instead of np.nan: numpy is not imported in this file.
        return float("nan")
    correct = sum(1 for col, value in observed if col == value)
    return correct / len(observed)