In [None]:
import mvstudio.data
import numpy as np
import pickle
from IPython.display import display
from collections import defaultdict

# Build a handle on the ManiVault data hierarchy and render it in the notebook;
# the hierarchy is one of the places where the dataset IDs passed to
# count_cells_per_cluster() can be looked up.
h = mvstudio.data.Hierarchy()
display(h)

In [None]:
def count_cells_per_cluster(cellMeansID:str, clusterID:str, cellSegID=None):
    """
    Count how many cells of a cell-means dataset fall in every cluster and optionally collect the cell identifiers per cluster.

    Datasets are looked up by dataset ID. An ID can be found in mvstudio.data.Hierarchy() or by
    right-clicking the item in ManiVault Studio and choosing 'Copy dataset ID'.

    A cell is assigned to a cluster when its row in the cell-means dataset is byte-for-byte equal to
    one of the rows the cluster selects from the top-level dataset the cluster object belongs to.

    Parameters:
        cellMeansID: dataset ID (str) of the Cluster Means object obtained by creating a Mean Dataset
            from cell mask clusters in ManiVault Studio.
        clusterID: dataset ID (str) of the cluster (mean-shift) object computed over multiple grouped
            Cluster Means objects, of which the object given by cellMeansID should be one.
        cellSegID: optional dataset ID of the cell mask cluster object holding the cell
            names/identifiers (generated by Analyze -> Extract Clusters on a cell mask item).
            Its names are indexed by the row position of the cell in the cell-means dataset.

    Returns:
        None when the IDs are not strings or when cellMeansID points at a whole dataset / image for
        which no usable point set is found.
        Without cellSegID: a dict mapping cluster name -> cell count (numpy float values).
        With cellSegID: a tuple (counts dict, defaultdict mapping cluster name -> list of cell names).
    """

    def dprint(text):
        # Emit through both channels: stdout and the notebook's rich display.
        print(text)
        display(text)

    id_hint = 'Please provide cellMeansID and clusterID in string format'
    # Validate explicitly: a non-string ID does not raise SyntaxError at runtime,
    # so the except clause below cannot be relied upon to catch it.
    if not isinstance(cellMeansID, str) or not isinstance(clusterID, str):
        dprint(id_hint)
        return None

    h = mvstudio.data.Hierarchy()
    try:
        cellMeans = h.getItemByDataID(cellMeansID)
        clusters = h.getItemByDataID(clusterID)
    except SyntaxError:
        # Kept for backward compatibility with the original guard.
        dprint(id_hint)
        return None

    # A hierarchy id of length 1 denotes a top-level item, i.e. the whole dataset was passed.
    if len(cellMeans._hierarchy_id) == 1:
        dprint('It looks like the ID of the whole dataset is given, provide ID of the cell (or cluster) Means dataset')
        if len(cellMeans._children) > 0:
            alternative = False
            # Fall back to a 'Points' child; when several exist the last one wins.
            for i in cellMeans._children:
                if i._type.name == 'Points':
                    cellMeansID = i.datasetId
                    dprint(f'Found point set that might be cell means dataset with ID: {cellMeansID}')
                    dprint(f'Using {cellMeansID} for cell means')
                    cellMeans = h.getItemByDataID(cellMeansID)
                    alternative = True
            if not alternative:
                dprint('Exiting.....')
                return None
        else:
            dprint('Exiting.....')
            return None

    if cellMeans._type.name == 'Image':
        dprint('It looks like the ID of the image dataset is given, provide ID of the cell (or cluster) Means dataset')
        dprint('Exiting.....')
        return None

    # The cluster indices refer to rows of the top-level item above the cluster object.
    meansDatasetIndex = [clusters._hierarchy_id[0]]
    meansDataset = h.getItemByIndex(meansDatasetIndex)

    c_names = clusters.cluster.names
    nr_clusters = len(c_names)
    per_cluster_counts = np.zeros(nr_clusters)

    cmname = cellMeans._name
    if cellSegID:
        per_cluster_cells = defaultdict(list)
        cellSegData = h.getItemByDataID(cellSegID)
        cSN_name = cellSegData.name
        dprint(f'Using cell segmentation in {cSN_name} for cells in {cmname}')
        cellSegNames = cellSegData.cluster.names

    dprint('Matching cells to clusters.....')
    # Serialise every cell row once instead of once per cluster (loop-invariant work).
    cell_keys = [p.tobytes() for p in cellMeans.points]
    for clusterId in range(nr_clusters):
        points = meansDataset.points[clusters.cluster.indices[clusterId]]
        # Store the raw bytes rather than their hash so a hash collision can never
        # produce a false match; the set still gives constant-time membership tests.
        point_set = {p.tobytes() for p in points}
        for j, key in enumerate(cell_keys):
            if key in point_set:
                per_cluster_counts[clusterId] += 1
                if cellSegID:
                    cellName = cellSegNames[j]
                    per_cluster_cells[c_names[clusterId]].append(cellName)

    dprint(f'Total cells per cluster for {cmname}: ')
    # Actually report the per-cluster totals announced by the header above.
    for i in range(nr_clusters):
        name = c_names[i]
        count = int(per_cluster_counts[i])
        dprint(f"{name} : {count}")

    if cellSegID:
        return dict(zip(c_names, per_cluster_counts)), per_cluster_cells
    else:
        return dict(zip(c_names, per_cluster_counts))

In [None]:
# Dataset IDs used by both example calls (copy them from ManiVault Studio via
# right-click -> 'Copy dataset ID', or from the displayed hierarchy).
CELL_MEANS_ID = '218f357e-be63-49ed-82f4-25bd56e56cb1'
CLUSTER_ID = 'ab1cc3a8-bcc9-43d2-a8ff-f17d9a2129b8'
CELL_SEG_ID = '999715c1-13e2-4c86-a46a-4f53992fdb4c'

# With cell identifiers per cluster: the result is a (counts, names-per-cluster) pair.
d = count_cells_per_cluster(cellMeansID=CELL_MEANS_ID, clusterID=CLUSTER_ID, cellSegID=CELL_SEG_ID)
# count_cells_per_cluster returns None when it gives up on the supplied IDs;
# guard so the cell does not fail with a TypeError on d[0].
if d is not None:
    display('Cell counts: ', d[0])
    display('Cell names per cluster: ', d[1])

# Without cell identifiers: the result is the counts dict only.
d = count_cells_per_cluster(cellMeansID=CELL_MEANS_ID, clusterID=CLUSTER_ID)
if d is not None:
    display('Cell counts: ', d)
