In [2]:
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field

import numpy as np
import typing as tp


@dataclass
class DataNode:
    """Node that holds a payload plus an ordered list of child nodes.

    Attributes:
        data: The payload carried by this node; any type is accepted.
        children: Child nodes, annotated as ``DataNode`` or ``FunctionNode``
            instances. Each instance gets its own, initially empty, list.
    """

    # payload; required positional argument
    data: tp.Any
    # default_factory gives every instance a fresh list (no shared mutable default)
    children: list[DataNode | FunctionNode] = field(default_factory=list)
    
    
@dataclass(eq=False)
class FunctionNode:
    """Base class for nodes that group other nodes as ``children``.

    The class must be a dataclass for ``field(default_factory=list)`` to be
    honoured. Without the decorator, the class attribute ``children`` is the
    raw ``dataclasses.Field`` object, shared by every instance, rather than
    a list.

    ``eq=False`` keeps identity-based ``==`` and hashing, matching the
    behaviour of the previous undecorated class.

    Attributes:
        children: Child nodes; each instance gets its own, initially empty,
            list unless one is passed to the constructor.
    """

    children: list[DataNode] = field(default_factory=list)


In [4]:
import uuid
import numpy as np
import pandas as pd

@dataclass
class ID(DataNode):
    """A data node whose payload is a string identifier.

    If no ``data`` is passed, a new random UUID4 string is generated for each
    instance.
    """

    # default_factory runs once per instance. A plain default such as
    # ``str(uuid.uuid4())`` would be evaluated once, at class definition, and
    # every default-constructed ID would then share the same value.
    data: str = field(default_factory=lambda: str(uuid.uuid4()))

@dataclass
class Array(DataNode):
    """A data node whose payload is a NumPy array, with scalar summaries.

    Each statistic calls the corresponding ndarray reduction with no
    ``axis`` argument, i.e. over the whole array, and converts the result
    to a builtin ``float``.
    """

    data: np.ndarray

    def _reduce_to_float(self, reduction: str) -> float:
        # Look up the named reduction method on the payload (e.g. "min"),
        # call it, and coerce the resulting NumPy scalar to a builtin float.
        reducer = getattr(self.data, reduction)
        return float(reducer())

    def min(self) -> float:
        """Return the smallest element of ``data`` as a float."""
        return self._reduce_to_float("min")

    def max(self) -> float:
        """Return the largest element of ``data`` as a float."""
        return self._reduce_to_float("max")

    def mean(self) -> float:
        """Return the arithmetic mean of ``data`` as a float."""
        return self._reduce_to_float("mean")

    def std(self) -> float:
        """Return the standard deviation of ``data`` (ndarray.std defaults) as a float."""
        return self._reduce_to_float("std")
    
    
@dataclass
class Energy(Array):
    """An ``Array`` node whose ``name`` field defaults to ``"energy"``."""

    name: str = field(default="energy")

@dataclass
class Objective(Array):
    """An ``Array`` node whose ``name`` field defaults to ``"objective"``."""

    name: str = field(default="objective")

@dataclass
class ConstraintViolation(Array):
    """An ``Array`` node whose ``name`` field defaults to ``"constraint_violation"``."""

    name: str = field(default="constraint_violation")
    

class Table(FunctionNode):
    """A function node whose children are annotated as ``Record`` objects.

    This class is not itself decorated as a dataclass. The annotations below
    only declare intended attribute types; they do not create attributes or
    constructor parameters.
    """

    # presumably a tabular view of the records — TODO confirm how it is populated
    data: pd.DataFrame
    children: list[Record]

@dataclass(eq=False)
class Record(FunctionNode):
    """A function node holding the data nodes that make up one record.

    It is declared as a dataclass with its own ``children`` default, so a
    record can be built as ``Record([node, ...])`` and always owns a
    per-instance list, regardless of how ``FunctionNode`` is declared.
    ``eq=False`` keeps identity-based equality and hashing.
    """

    children: list[DataNode] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Map each child's label to its payload.

        The label is the child's ``name`` attribute when it has one (e.g. an
        ``Energy`` node gives ``"energy"``). Otherwise it is the child's class
        name (e.g. ``"ID"``).

        Returns:
            A dict mapping label -> ``child.data``, in child order.

        Raises:
            ValueError: If two children resolve to the same label; otherwise
                one payload would silently overwrite the other.
        """
        result: dict = {}
        for child in self.children:
            key = getattr(child, "name", type(child).__name__)
            if key in result:
                raise ValueError(f"duplicate field {key!r} in Record")
            result[key] = child.data
        return result

    def to_series(self) -> pd.Series:
        """Return ``to_dict()`` as an object-dtype ``pandas.Series``.

        ``dtype=object`` stores each payload (including whole arrays) as a
        single cell and skips pandas' dtype inference, including for an empty
        record.
        """
        return pd.Series(self.to_dict(), dtype=object)

class Artifact(FunctionNode):
    """A function node annotated as carrying a dictionary payload.

    This class is not decorated as a dataclass, so ``data`` below is only a
    type annotation. No constructor assigns it — TODO confirm how it is
    meant to be set.
    """

    data: dict
    


'Energy'

In [11]:
import jijzept as jz
import jijmodeling as jm

from pathlib import Path

# Resolve the JijZept config under the current user's home directory instead
# of a hardcoded absolute path that only exists on one machine.
JIJZEPT_CONFIG = Path.home() / ".jijzept" / "config.toml"
NUM_VARS = 5  # number of binary variables in the toy problem

sampler = jz.JijSASampler(config=str(JIJZEPT_CONFIG))

# Toy problem: objective `x[:]` with the one-hot constraint `x[:] == 1`.
problem = jm.Problem("sample")
x = jm.Binary("x", NUM_VARS)
problem += x[:]
problem += jm.Constraint("onehot", x[:] == 1)

res = sampler.sample_model(problem, {})

<bound method SampleSet.feasible of JijModelingResponse(record=Record(solution={'x': [(([4],), [1], (5,))]}, num_occurrences=[1]), evaluation=Evaluation(energy=[0.0], objective=[1.0], constraint_violations={'onehot': [0.0]}, penalty={}), measuring_time=MeasuringTime(solve=SolvingTime(preprocess=None, solve=0.012308359146118164, postprocess=None), system=SystemTime(post_problem_and_instance_data=0.8866491317749023, request_queue=0.3785228729248047, fetch_problem_and_instance_data=None, fetch_result=2.327130079269409, deserialize_solution=7.677078247070312e-05), total=3.5964653491973877))>

In [67]:
# Sanity check: float() coercion of a NumPy mean (as Array.mean does) yields a builtin float.
isinstance(float(np.zeros(10).mean()), float)

True

In [66]:
# Sanity check: reducing along an axis yields an ndarray rather than a scalar.
isinstance(np.zeros((10, 10)).mean(axis=0), np.ndarray)

True