In [1]:
import importlib
import numpy as np
import tmd
import os
import sklearn.tree
import sklearn 

import classification_functions as fcts
from typing import List, Dict

In [2]:
# Root directory of the reconstructed morphologies, with one sub-folder per group
# (e.g. "L5_UPC"); the trailing slash matters because paths are built by appending.
data_path = "./Data/Reconstructed/"
# Training groups; their order fixes the 1-based class labels assigned by load_data.
folders_to_treat = [
    "L5_UPC",
    "L5_TPC_A",
    "L5_TPC_B",
    "L5_TPC_C",
]
training_neurons = fcts.load_neurons_from_folders(data_path, folders_to_treat)

Neurons loaded from L5_UPC: 27
Neurons loaded from L5_TPC_A: 64
Neurons loaded from L5_TPC_B: 38
Neurons loaded from L5_TPC_C: 30


In [3]:
# Morphology file of the single neuron to classify.
# NOTE(review): this file lives in L2_TPC_A, which is not one of the training groups in
# folders_to_treat (all L5_*), so the classifier can only answer with an L5 label —
# confirm this out-of-distribution test is intentional.
cell = f"{data_path}L2_TPC_A/C170797A-P1.asc"

In [4]:
def load_data(
    all_neurons: Dict[str, List],
    types: List[str],
    neurite_type: str,
    pers_hom_function: str = "radial_distances",
):
    """
    Compute a persistence diagram for every neuron of the selected groups.

    Neurons whose diagram comes back empty (presumably neurons lacking the requested
    neurite type) are skipped, so the outputs can be shorter than the input.

    Parameters:
        all_neurons (dict): Mapping {group_name: list_of_neurons}.
        types (list): Group names to include; their order defines the class labels.
        neurite_type (str): Neurite to analyse, e.g. 'basal_dendrite', 'apical_dendrite', 'axon', 'dendrite'.
        pers_hom_function (str): Passed to tmd as ``feature`` (e.g. 'radial_distances', 'path').

    Returns:
        labels (list): 1-based integer label per kept neuron (position of its group in ``types`` + 1).
        pers_diagrams (list): Persistence diagram per kept neuron, aligned with ``labels``.
    """
    # (label, diagram) pairs in traversal order, so both outputs stay aligned.
    kept = []

    for class_id, group_name in enumerate(types, start=1):
        for neuron in all_neurons[group_name]:
            diagram = tmd.methods.get_ph_neuron(
                neuron, feature=pers_hom_function, neurite_type=neurite_type
            )
            if not diagram:
                # Empty diagram: nothing to vectorise, drop this neuron.
                continue
            kept.append((class_id, diagram))

    labels = [class_id for class_id, _ in kept]
    pers_diagrams = [diagram for _, diagram in kept]
    return labels, pers_diagrams

### Data Loading

In [5]:
# Apical-dendrite diagrams for every training group, using the 'radial_distances'
# feature; labels are 1-based indices into folders_to_treat.
labels, pers_diagrams = load_data(
    training_neurons,
    folders_to_treat,
    "apical_dendrite",
    pers_hom_function="radial_distances",
)


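### Build Persistence Images

Each training diagram is turned into a flattened persistence image. The image limits (`xlim`, `ylim`) are fixed from the training diagrams, so the test image shares the same grid.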

In [6]:
# Image bounds are taken from the training diagrams only; the test cell below reuses
# the same xlim/ylim so its persistence image is built on the identical grid.
xlim, ylim = tmd.vectorizations.get_limits(pers_diagrams)

# Build each flattened persistence image once. Previously the identical list was
# computed twice (as train_images and X_train), doubling the cost of this cell.
X_train = [
    tmd.analysis.persistence_image_data(d, xlim=xlim, ylim=ylim).flatten()
    for d in pers_diagrams
]
train_images = X_train  # kept as an alias for any code still using the old name
y_train = labels

# cls.fit needs exactly one label per feature vector.
assert len(X_train) == len(y_train), "each persistence image needs exactly one label"

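### Test Sample

Build the persistence image of the single neuron to classify. It uses the same feature, neurite type and image limits as the training set.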

In [7]:
# Vectorise the single test neuron exactly like the training data: same feature,
# same neurite type, and the training xlim/ylim so the image grids match.
test_neuron = tmd.io.load_neuron_from_morphio(cell)
pers2test = tmd.methods.get_ph_neuron(
    test_neuron, feature="radial_distances", neurite_type="apical_dendrite"
)
# load_data silently drops neurons with an empty diagram; a single test neuron cannot
# be dropped, so fail early with a clear message instead of passing an empty diagram
# to persistence_image_data.
if not pers2test:
    raise ValueError(
        f"Test neuron {cell!r} has an empty apical_dendrite persistence diagram; "
        "cannot build its persistence image."
    )
pers_image2test = tmd.analysis.persistence_image_data(pers2test, xlim=xlim, ylim=ylim)
X_test = [pers_image2test.flatten()]

In [8]:
# Train a decision tree on the flattened persistence images.
# sklearn.tree is already imported in the first cell. DecisionTreeClassifier randomly
# permutes the features at each split, so without a fixed random_state, ties between
# equally good splits can resolve differently on every run. Seeding it makes the
# fitted tree and the prediction reproducible on Restart & Run All.
cls = sklearn.tree.DecisionTreeClassifier(random_state=0)
cls.fit(X_train, y_train)

# Predict. Labels are 1-based indices into folders_to_treat (the `types` argument
# given to load_data), so map them back to readable group names as well.
prediction = cls.predict(X_test)
print("Prediction:", prediction)
print("Predicted group:", [folders_to_treat[label - 1] for label in prediction])

Prediction: [1]
