In [181]:
import mlcroissant as mlc

# The distribution enumerates the dataset's resources: one FileObject for the
# hosting Git repository and one FileSet for the data files inside it.
distribution = [
    # The GitHub repository that hosts the model files.
    # NOTE(review): `sha256` carries the string "main" rather than a hash --
    # presumably the branch name to check out; confirm.
    mlc.FileObject(
        name="github-repository",
        id="github-repository",
        content_url="https://github.com/AllenInstitute/GRNN",
        encoding_format="git+https",
        sha256="main",
        description="Generalized Firing Rate Neurons repository on GitHub.",
    ),
    # Every file matching `model/*.jsonl` inside that repository.
    mlc.FileSet(
        name="jsonl-files",
        id="jsonl-files",
        contained_in=["github-repository"],
        includes="model/*.jsonl",
        encoding_format="application/jsonlines",
        description="JSON files are hosted on the GitHub repository.",
    ),
]
def _column_field(column, data_type, description):
    """Build a top-level field read straight from one column of the JSONL files.

    Args:
        column: Column name in the JSONL records. It is also used as the field
            name and as the last component of the field id (``jsonl/<column>``).
        data_type: The ``mlc.DataType`` member declared for the field.
        description: Human-readable description stored in the Croissant metadata.

    Returns:
        An ``mlc.Field`` sourced from the ``jsonl-files`` FileSet.
    """
    return mlc.Field(
        id=f"jsonl/{column}",
        name=column,
        description=description,
        data_types=data_type,
        source=mlc.Source(
            file_set="jsonl-files",
            extract=mlc.Extract(column=column),
        ),
    )


def _param_field(key, name, data_type, json_path, description):
    """Build a sub-field extracted from the ``params`` column with a JSONPath.

    Args:
        key: Last component of the field id (``jsonl/params/<key>``).
        name: Field name, i.e. the key under which the value appears in records.
        data_type: The ``mlc.DataType`` member declared for the field.
        json_path: JSONPath applied to the ``params`` column value.
        description: Human-readable description stored in the Croissant metadata.

    Returns:
        An ``mlc.Field`` marked ``repeated=True`` and sourced from the
        ``jsonl-files`` FileSet.
    """
    return mlc.Field(
        id=f"jsonl/params/{key}",
        name=name,
        description=description,
        data_types=data_type,
        repeated=True,
        source=mlc.Source(
            file_set="jsonl-files",
            extract=mlc.Extract(column="params"),
            # MUSTFIX (original author): this is not working.
            # NOTE(review): the recorded output shows one scalar per record for
            # these fields although `repeated=True` is set -- presumably the
            # full array was intended; confirm the desired json_path.
            transforms=[mlc.Transform(json_path=json_path)],
        ),
    )


# NOTE(review): the descriptions below replace placeholder text copied from an
# unrelated example ("The expected completion of the promt." / "The machine
# learning task appearing as the name of the file."). They are derived from the
# column names, the json_paths and the recorded output only; the dataset
# authors should refine them (units of the bin sizes, definition of EVR, ...).
record_sets = [
    # A RecordSet describes the records of the dataset: one per JSONL line.
    mlc.RecordSet(
        id="jsonl",
        name="jsonl",
        fields=[
            # Scalar columns, read directly from the JSONL files.
            _column_field(
                "cell_id",
                mlc.DataType.INTEGER,
                "Integer identifier of the recorded neuron (cell) the model was"
                " fit to.",
            ),
            _column_field(
                "cre-line",
                mlc.DataType.TEXT,
                "Transgenic Cre line of the recorded neuron, e.g."
                " 'Htr3a-Cre_NO152'.",
            ),
            _column_field(
                "bin_size",
                mlc.DataType.INTEGER,
                "Bin size recorded for the fit (`bin_size` column).",
            ),
            _column_field(
                "actv_bin_size",
                mlc.DataType.INTEGER,
                "Value of the `actv_bin_size` setting recorded for the fit.",
            ),
            _column_field(
                "val_evr",
                mlc.DataType.FLOAT,
                "EVR metric of the fitted model on the validation split"
                " (`val_evr` column).",
            ),
            _column_field(
                "test_evr",
                mlc.DataType.FLOAT,
                "EVR metric of the fitted model on the test split"
                " (`test_evr` column).",
            ),
            _column_field(
                "train_loss",
                mlc.DataType.FLOAT,
                "Loss of the fitted model on the training split.",
            ),
            _column_field(
                "test_loss",
                mlc.DataType.FLOAT,
                "Loss of the fitted model on the test split.",
            ),
            # `params` has no source of its own; its content is exposed only
            # through the sub-fields below.
            mlc.Field(
                id="jsonl/params",
                name="params",
                description="Parameters of the fitted GFR neuron model.",
                sub_fields=[
                    _param_field(
                        "a",
                        "a",
                        mlc.DataType.FLOAT,
                        "a[0][0]",
                        "Entry `a[0][0]` of the `a` parameter array in `params`.",
                    ),
                    _param_field(
                        "b",
                        "b",
                        mlc.DataType.FLOAT,
                        "b[0][0]",
                        "Entry `b[0][0]` of the `b` parameter array in `params`.",
                    ),
                    _param_field(
                        "ds",
                        "ds",
                        mlc.DataType.FLOAT,
                        "ds[0]",
                        "Entry `ds[0]` of the `ds` parameter array in `params`.",
                    ),
                    _param_field(
                        "bin_size",
                        "params_bin_size",
                        mlc.DataType.INTEGER,
                        "bin_size",
                        "The `bin_size` entry of `params`.",
                    ),
                    # TODO: `params.g` (max_current, max_firing_rate,
                    # poly_coeff, b, bin_size) is not exposed yet. Earlier
                    # attempts used a nested `jsonl/params/g` field with
                    # sub_fields and json_paths such as "g.max_firing_rate";
                    # they were left disabled by the original author.
                ],
            ),
        ],
    )
]

# Dataset-level metadata: ties the name, description and citation to the
# `distribution` and `record_sets` defined above.
metadata = mlc.Metadata(
    name="gfr-neurons",
    # NOTE(review): this URL points at github.com/Anonymous/GRNN while the
    # distribution's content_url uses github.com/AllenInstitute/GRNN --
    # confirm which one is intended.
    url="https://github.com/Anonymous/GRNN",
    distribution=distribution,
    record_sets=record_sets,
    # Plain text or markdown is accepted here. The literal is split per
    # sentence for readability; the concatenated value is unchanged.
    description=(
        "We present a dataset of over 1000 biologically-derived, parameterized, and differentiable neuronal models. "
        "Unlike traditional approaches to biologically realistic single-neuron models, which typically require many parameters and are not fully-differentiable, our proposed model is lightweight and fully-differentiable, enabling it to be easily integrated into artificial networks. "
        "The Generalized Firing Rate (GFR) neuron model consists of two sets of filters that integrate a neuron's input and firing rate history over multiple time scales to generate the neuron's membrane potential. "
        "These filters collectively form a kernel that can effectively capture the complex temporal dependencies typical of neuronal activity. "
        "The membrane potential is subsequently passed through a bio-realistic non-linearity to produce the neuron's firing rate. "
        "We fit GFR neurons to patch-clamp electrophysiological recordings on neurons from acute slices of mouse and human cortex. "
        "Clustering of obtained neuronal parameters reflects biological properties that partially align to transgenic lines. "
        "This dataset is not intended to represent the most accurate models, as more complex models outperform them; nonetheless, it is a collection of lightweight and interpretable models for practitioners interested in building bio-realistic machine learning models, providing competitive accuracy while being fully-differentiable. "
        "When integrated into a recurrent neural network (RNN), these models exhibit robust performance on computational tasks like the sequential MNIST, with faster learning compared to a vanilla RNN with a comparable number of parameters. "
        "Our dataset thus provides a practical framework for embedding biological realism into artificial neural network architectures, opening new possibilities for both neuroscience research and advanced machine learning applications."
    ),
    # BibTeX entry, one key per literal. The double space before `author` is
    # kept so the resulting string matches the previous value exactly.
    # NOTE(review): `eprint` and `primaryClass` look like placeholders --
    # confirm before publishing.
    cite_as=(
        "@article{gfr2024,"
        " title={A dataset of differentiable biologically-derived single neuron models},"
        "  author={Anonymous},"
        " year={2024},"
        " eprint={2024.0000},"
        " archivePrefix={arXiv},"
        " primaryClass={cs.CL} }"
    ),
)

In [56]:
# Print the issues report attached to `metadata` so that problems with the
# definition above are visible before the JSON-LD file is written.
print(metadata.issues.report())





In [182]:
import json

# Serialise the metadata to an indented JSON document and store it next to the
# notebook; the file is terminated with a newline.
with open("croissant.json", "w") as f:
    content = json.dumps(metadata.to_json(), indent=2)
    f.write(content + "\n")

In [183]:
# Load the dataset back from the JSON-LD file written above and iterate over
# the records of the "jsonl" record set.
dataset = mlc.Dataset(jsonld="croissant.json")
records = dataset.records(record_set="jsonl")

# Preview only: print the first 12 records (indices 0 through 11), then stop.
for i, record in enumerate(records):
    print(record)
    if i >= 11:
        break

{'cell_id': 566517779, 'cre-line': b'Chrna2-Cre_OE25', 'bin_size': 10, 'actv_bin_size': 20, 'val_evr': 0.594774315578198, 'test_evr': 0.5847444417389941, 'train_loss': 0.152364676816367, 'test_loss': 0.037812300461989, 'a': 8.994219779968262, 'b': 0.7104464769363401, 'ds': 1.0, 'params_bin_size': 10}
{'cell_id': 486875162, 'cre-line': b'Htr3a-Cre_NO152', 'bin_size': 10, 'actv_bin_size': 20, 'val_evr': 0.6407574082969301, 'test_evr': 0.648385856441786, 'train_loss': 0.114419631612205, 'test_loss': 0.101205808199368, 'a': 8.690719604492188, 'b': -0.541902422904968, 'ds': 1.0, 'params_bin_size': 10}
{'cell_id': 562651165, 'cre-line': b'Ndnf-IRES2-dgCre', 'bin_size': 10, 'actv_bin_size': 20, 'val_evr': 0.733286328159182, 'test_evr': 0.7744659475322031, 'train_loss': 0.079649135694039, 'test_loss': 0.11636616633488502, 'a': 8.911795616149902, 'b': -0.920267701148986, 'ds': 1.0, 'params_bin_size': 10}
{'cell_id': 608108585, 'cre-line': b'Tlx3-Cre_PL56', 'bin_size': 10, 'actv_bin_size': 20, '