In [None]:
%config Completer.use_jedi = False

In [None]:
from typing import Dict, Hashable, Any

import numpy as np
import xarray as xr

from xarray.core.variable import Variable

from xarray.backends.common import BACKEND_ENTRYPOINTS

import tiledb
from tiledb.cf.engines.xarray_engine import TileDBBackendEntrypoint

# Register the TileDB backend entrypoint class under the "tiledb" key of
# xarray's backend-entrypoint mapping (a module-level side effect).
# NOTE(review): presumably this is what lets ``engine="tiledb"`` resolve when
# opening datasets with xarray -- confirm against the installed xarray version.
BACKEND_ENTRYPOINTS["tiledb"] = TileDBBackendEntrypoint

In [None]:
def write_tdb_array(
    path: str, data: np.ndarray, metadata: Dict[Hashable, Any] = None
):
    """Fill an existing TileDB array with ``data`` and attach metadata to it.

    Args:
        path: Location of the TileDB array; it is opened in write mode.
        data: Values assigned to the full array via ``array[:] = data``.
        metadata: Optional key/value pairs copied one by one into the array's
            ``meta`` mapping. ``None`` (the default) writes no metadata.
    """
    # Normalise the "no metadata" case so the copy loop below needs no guard.
    entries = {} if metadata is None else metadata

    with tiledb.open(path, "w") as tdb_array:
        tdb_array[:] = data
        for meta_key, meta_value in entries.items():
            tdb_array.meta[meta_key] = meta_value

# Metadata-key prefixes used to namespace per-variable / per-dimension
# attributes inside the flat TileDB array metadata.
# NOTE(review): these names were referenced by ``to_tiledb`` but never defined
# or imported in this notebook (NameError at runtime). The values below are
# assumed to match the prefixes the tiledb.cf xarray engine reads back --
# confirm against the installed tiledb.cf version.
_ATTR_PREFIX = "__tiledb_attr."
_DIM_PREFIX = "__tiledb_dim."


def _build_tdb_dims(dataset: xr.Dataset) -> list:
    """Return one ``tiledb.Dim`` per coordinate that is also a dataset dimension."""
    tdb_dims = []
    for name, coord in dataset.coords.items():
        if name not in dataset.dims:
            # Non-dimension coordinates do not contribute to the array domain.
            continue
        dtype = coord.dtype
        if dtype.kind not in "iuM":
            raise NotImplementedError(
                f"TDB Arrays don't work yet with this dtype coord {dtype}"
            )
        # The domain bounds are taken from the first and last coordinate value.
        min_value = coord.data[0]
        max_value = coord.data[-1]
        if dtype.kind == "M":
            # Datetime coordinates are replaced by a 1-based int32 position index.
            domain = (1, len(coord.data))
            dtype = np.int32
        else:
            # test for NetCDF dimension type coord starting at 0
            if min_value == 0:
                min_value, max_value = min_value + 1, max_value + 1
            domain = (min_value, max_value)
        tdb_dims.append(tiledb.Dim(name=name, domain=domain, dtype=dtype))
    return tdb_dims


def _store_prefixed_attrs(prefix: str, attrs_by_name: dict, metadata: dict) -> None:
    """Flatten ``{owner: attrs}`` into ``metadata`` under ``prefix + owner`` keys."""
    for owner, attrs in attrs_by_name.items():
        key_prefix = f"{prefix}{owner}"
        if isinstance(attrs, dict):
            for attr_name, value in attrs.items():
                # datetime64 scalars are stored in their string form.
                if isinstance(value, np.datetime64):
                    value = str(value)
                metadata[f"{key_prefix}.{attr_name}"] = value
        else:
            metadata[key_prefix] = attrs


def to_tiledb(dataset: xr.Dataset, path: str):
    """Write ``dataset`` to a dense TileDB array at ``path``.

    Any array already present at ``path`` is removed and re-created.

    Layout:
        * Every coordinate that is also a dataset dimension becomes a TileDB
          dimension. Integer coordinates use ``(first, last)`` as the domain
          (shifted by +1 when the first value is 0); datetime coordinates are
          replaced by an int32 domain ``(1, len(coord))``.
        * Every data variable becomes a TileDB attribute of the same name and
          dtype.
        * Dataset attrs are written as array metadata. Attributes of each data
          variable are written under ``_ATTR_PREFIX + <variable>.<attr>``.
          Dataset attrs whose key equals a data-variable or dimension name are
          written under ``_ATTR_PREFIX`` / ``_DIM_PREFIX`` respectively and
          take precedence over variable attrs with the same resulting key.

    Args:
        dataset: Dataset to persist.
            NOTE(review): all data variables are assumed to span the full
            domain built from the dimension coordinates -- not verified here.
        path: Target URI/path of the TileDB array.

    Raises:
        NotImplementedError: If a dimension coordinate is not of integer,
            unsigned-integer or datetime dtype.
        ValueError: If ``dataset`` contains no data variables.
    """
    dom = tiledb.Domain(*_build_tdb_dims(dataset))

    tdb_attrs = []
    var_data = {}
    var_attrs = {}
    for var_name in dataset.data_vars:
        var = dataset[var_name]
        tdb_attrs.append(tiledb.Attr(name=var.name, dtype=var.dtype))
        var_data[var.name] = var.data
        var_attrs[var.name] = var.attrs

    if not tdb_attrs:
        # Fail with a clear message before touching anything stored at ``path``.
        raise ValueError("dataset has no data variables to write")

    schema = tiledb.ArraySchema(domain=dom, attrs=tdb_attrs, sparse=False)

    if tiledb.array_exists(path):
        tiledb.remove(path)
    tiledb.DenseArray.create(path, schema)

    # Split dataset-level attrs into plain metadata and entries that are keyed
    # by a data-variable or dimension name.
    metadata = {}
    data_var_attrs = {}
    dim_attrs = {}
    for key, value in dataset.attrs.items():
        if key in dataset.data_vars:
            data_var_attrs[key] = value
        elif key in dataset.dims:
            dim_attrs[key] = value
        else:
            metadata[key] = value

    # Variable attrs first so dataset-level entries keep their precedence.
    _store_prefixed_attrs(_ATTR_PREFIX, var_attrs, metadata)
    _store_prefixed_attrs(_ATTR_PREFIX, data_var_attrs, metadata)
    _store_prefixed_attrs(_DIM_PREFIX, dim_attrs, metadata)

    # A single-attribute array takes the bare ndarray (unchanged behaviour);
    # a multi-attribute array must be assigned a {name: values} mapping.
    if len(var_data) == 1:
        data = next(iter(var_data.values()))
    else:
        data = var_data

    write_tdb_array(path, data, metadata)


In [None]:
# Build a small one-dimensional demo dataset and write it with ``to_tiledb``.
data = np.arange(16, dtype=np.uint64)
data_attrs = {"description": "data_var called 'data' with simple int range"}
# ``to_tiledb`` only accepts integer / datetime dimension coordinates; a float
# coordinate (e.g. from ``np.linspace``) raises NotImplementedError, so use an
# integer index with one entry per data value.
coords = {'x': np.arange(1, len(data) + 1)}
GeoTransform = str(list((0, 1.0, 0, 0, 0, 1.0)))
ds_attrs = {"description": "gis ds", "GeoTransform": GeoTransform}
var = {"data": Variable(["x"], data, data_attrs)}
dataset = xr.Dataset(data_vars=var, coords=coords, attrs=ds_attrs)

# Output location kept in one place so it is easy to change per machine.
output_path = '../../../../Downloads/test_gis_dataset'
to_tiledb(dataset, output_path)