In [21]:
import numpy as np
import pandas as pd
import scipy.stats
import scipy.spatial.distance

In [22]:
import os
import re

# Change directory to the root of the SpatialDeconBenchmark repository.
# The greedy pattern keeps everything up to the last "SpatialDeconBenchmark"
# found in the current working directory path.
_repo_root_match = re.match(r'^(.*SpatialDeconBenchmark)', os.getcwd())
if _repo_root_match is None:
    # Fail with an actionable message instead of an AttributeError on None.
    raise RuntimeError(
        "Cannot locate the SpatialDeconBenchmark repository root from the current "
        "working directory: {!r}".format(os.getcwd())
    )
os.chdir(_repo_root_match.group(0))

In [23]:
class Metrics:
    """
    Accuracy metrics for one deconvolution method against a ground-truth matrix.

    Both matrices are expected to be pandas DataFrames laid out spot (rows) by
    cell type (columns). `check_mtx_format()` validates the labels and reorders
    the deconvolved matrix to match the ground truth; the `calculate_*` methods
    run that check automatically if it has not been run yet.
    """

    def __init__(self, method_name, ground_truth_mtx, deconvolved_mtx):
        """
        Initialize the Metrics object with method, ground truth matrix, and deconvolved matrix.

        Parameters
        ----------
        method_name : label of the deconvolution method, stored as `self.method`.
        ground_truth_mtx : pandas.DataFrame, spots x cell types.
        deconvolved_mtx : pandas.DataFrame, spots x cell types (any row/column order).
        """
        self.method = method_name
        self.ground_truth_mtx = ground_truth_mtx
        self.deconvolved_mtx = deconvolved_mtx
        # Populated by check_mtx_format(); None means "not validated yet".
        self.num_cell_types = None
        self.num_spots = None
        # Populated by calculate_rmse() / calculate_jsd().
        self.rmse = None
        self.jsd = None

    @staticmethod
    def _label_mismatch(expected, actual, limit=5):
        """Describe which labels of `expected` are absent from `actual` and vice versa."""
        missing = sorted(str(label) for label in set(expected) - set(actual))
        unexpected = sorted(str(label) for label in set(actual) - set(expected))

        def preview(labels):
            # Spot indexes can hold thousands of labels; show only the first few.
            extra = len(labels) - limit
            suffix = " (+{} more)".format(extra) if extra > 0 else ""
            return "{}{}".format(labels[:limit], suffix)

        return "missing from deconvolved matrix: {}; not in ground truth: {}".format(
            preview(missing), preview(unexpected)
        )

    @staticmethod
    def _kl_divergence(p, q):
        """Sum of p * log(p / q) over entries where both are positive (0 * log 0 is taken as 0)."""
        p = np.asarray(p, dtype=float)
        q = np.asarray(q, dtype=float)
        keep = (p > 0) & (q > 0)
        return float(np.sum(p[keep] * np.log(p[keep] / q[keep])))

    def _ensure_aligned(self):
        """Run check_mtx_format() if the matrices have not been validated yet."""
        if self.num_cell_types is None or self.num_spots is None:
            self.check_mtx_format()

    def check_mtx_format(self):
        """
        Ensures matrices are in the correct format (grid by cell type) with matching indices and columns.

        Reorders `self.deconvolved_mtx` to the row and column order of the ground
        truth and sets `self.num_cell_types` / `self.num_spots`.

        Raises
        ------
        ValueError
            If the cell type labels or spot labels differ between the matrices
            (the message lists the offending labels), or if duplicate labels
            prevent a one-to-one alignment.
        """
        # Check if cell type labels match
        if set(self.ground_truth_mtx.columns) != set(self.deconvolved_mtx.columns):
            raise ValueError(
                "Cell type labels do not match: "
                + self._label_mismatch(self.ground_truth_mtx.columns, self.deconvolved_mtx.columns)
            )

        # Check if grid labels match
        if set(self.ground_truth_mtx.index) != set(self.deconvolved_mtx.index):
            raise ValueError(
                "Spot labels do not match: "
                + self._label_mismatch(self.ground_truth_mtx.index, self.deconvolved_mtx.index)
            )

        aligned = self.deconvolved_mtx.loc[self.ground_truth_mtx.index, self.ground_truth_mtx.columns]
        # Set equality ignores duplicates; a shape change after .loc reveals them.
        if aligned.shape != self.ground_truth_mtx.shape:
            raise ValueError(
                "Duplicate spot or cell type labels prevent one-to-one alignment of the matrices"
            )

        self.deconvolved_mtx = aligned
        self.num_cell_types = self.deconvolved_mtx.shape[1]
        self.num_spots = self.deconvolved_mtx.shape[0]

    def calculate_rmse(self):
        """
        Calculates RMSE from the ground truth and deconvolved matrix.

        Returns a dictionary with average RMSE ("Average RMSE", over all spots and
        cell types) and RMSE per cell type ("RMSE per cell type", a pandas Series
        indexed by the ground-truth columns). The result is also stored in `self.rmse`.
        """
        self._ensure_aligned()

        squared_error = (self.ground_truth_mtx - self.deconvolved_mtx) ** 2
        sse_per_cell_type = squared_error.sum(axis=0)

        rmse_per_cell_type = np.sqrt(sse_per_cell_type / self.num_spots)
        avg_rmse = np.sqrt(sse_per_cell_type.sum() / (self.num_spots * self.num_cell_types))

        self.rmse = {"Average RMSE": avg_rmse, "RMSE per cell type": rmse_per_cell_type}

        return self.rmse

    def calculate_jsd(self, use_scipy=True):
        """
        Compute Jensen-Shannon divergence per cell type using either the scipy implementation or a custom method.

        With `use_scipy=True` each cell type column (values across spots) is passed
        to `scipy.spatial.distance.jensenshannon` and the returned distance is
        squared to obtain the divergence. With `use_scipy=False` a histogram-based
        estimate is used instead.

        Returns JSD per cell type as a pandas Series indexed by the ground-truth
        columns; the result is also stored in `self.jsd`.
        """
        self._ensure_aligned()

        jsd_per_cell_type = []

        for col in range(self.num_cell_types):
            p_true = self.ground_truth_mtx.iloc[:, col]
            p_pred = self.deconvolved_mtx.iloc[:, col]

            if use_scipy:
                jsd = scipy.spatial.distance.jensenshannon(p_true, p_pred)**2
            else:
                # Custom JSD calculation (requires verification and testing).
                # NOTE(review): each density is sampled at its own quartiles and the
                # samples are not normalised, so this is a heuristic rather than a
                # true divergence between distributions -- confirm before relying on it.
                p_true_dist = scipy.stats.rv_histogram(np.histogram(p_true, bins=10))
                p_pred_dist = scipy.stats.rv_histogram(np.histogram(p_pred, bins=10))

                p_true_quantiles = [np.quantile(p_true, q) for q in [0.25, 0.5, 0.75]]
                p_pred_quantiles = [np.quantile(p_pred, q) for q in [0.25, 0.5, 0.75]]
                p_true_pdf = p_true_dist.pdf(p_true_quantiles)
                p_pred_pdf = p_pred_dist.pdf(p_pred_quantiles)

                mean_pdf = (p_true_pdf + p_pred_pdf) / 2

                # Zero-density samples contribute nothing instead of producing NaN.
                kld_true = self._kl_divergence(p_true_pdf, mean_pdf)
                kld_pred = self._kl_divergence(p_pred_pdf, mean_pdf)

                jsd = (kld_true + kld_pred) / 2

            jsd_per_cell_type.append(jsd)

        jsd_per_cell_type = pd.Series(jsd_per_cell_type, index=self.ground_truth_mtx.columns)
        self.jsd = jsd_per_cell_type

        return self.jsd

In [24]:
# Benchmark every seqFISH deconvolution result against the ground truth and
# write one summary row per method (mean JSD, overall RMSE, RMSE per cell type).
ground_truth = pd.read_csv(os.getcwd() + "/data/seqFISH/ground_truth.csv", index_col=0)
results_dir = os.getcwd() + "/results/methods/seqFISH/"

records = []
# Sorted so the row order of the output table is reproducible across runs.
for result in sorted(os.listdir(results_dir)):
    # Skip hidden entries and anything that is not a CSV (e.g. .ipynb_checkpoints).
    if result.startswith(".") or not result.lower().endswith(".csv"):
        continue

    name_parts = result.split("_")
    if len(name_parts) < 2:
        raise ValueError(
            "Cannot derive a method name from {!r}; expected '<prefix>_<method>.csv'".format(result)
        )
    method = name_parts[1].split(".")[0]

    mtx = pd.read_csv(os.path.join(results_dir, result), index_col=0)
    metrics = Metrics(method, ground_truth, mtx)
    try:
        metrics.check_mtx_format()
    except ValueError as err:
        # Name the offending result file so label mismatches are easy to trace.
        raise ValueError("{}: {}".format(result, err)) from err
    metrics.calculate_jsd()
    metrics.calculate_rmse()

    data = {
        'method': metrics.method,
        'JSD': metrics.jsd.mean(),
        'total_RMSE': metrics.rmse['Average RMSE'],
    }

    # Report every cell type; the previous `[1:]` slice silently dropped the first one.
    data.update(metrics.rmse['RMSE per cell type'].to_dict())
    records.append(data)

if not records:
    raise ValueError("No result CSV files found in {!r}".format(results_dir))

# Build the table once from the collected records instead of concatenating
# one single-row DataFrame per method.
benchmark = pd.DataFrame(records)

benchmark_dir = os.getcwd() + "/results/benchmark"
os.makedirs(benchmark_dir, exist_ok=True)
benchmark.to_csv(os.path.join(benchmark_dir, "seqFISH_Benchmark_Results.csv"), index=False)

Index(['astrocytes', 'eNeuron', 'endo_mural', 'iNeuron', 'microglia', 'olig'], dtype='object')
Index(['astrocytes', 'endo_mural', 'eNeuron', 'iNeuron', 'microglia', 'olig'], dtype='object')
Index(['astrocytes', 'eNeuron', 'endo_mural', 'iNeuron', 'microglia', 'olig'], dtype='object')
Index(['astrocytes', 'endo_mural', 'eNeuron', 'iNeuron', 'microglia', 'olig'], dtype='object')
Index(['astrocytes', 'eNeuron', 'endo_mural', 'iNeuron', 'microglia', 'olig'], dtype='object')
Index(['iNeuron', 'eNeuron', 'olig', 'microglia', 'endo_mural', 'astrocytes'], dtype='object')
Index(['astrocytes', 'eNeuron', 'endo_mural', 'iNeuron', 'microglia', 'olig'], dtype='object')
Index(['Olig', 'eNeuron', 'endo_mural', 'iNeuron', 'astrocytes', 'microglia'], dtype='object')


ValueError: Cell type labels do not match