In [1]:
import numpy as np
from moirae.models.base import GeneralContainer

In [2]:
class MixedContainer(GeneralContainer):
    array_1d: np.ndarray
    array_2d: np.ndarray

test = MixedContainer(array_1d=np.array([0,1,2]), array_2d=np.arange(12).reshape((3,4)))
print(test.model_dump())
# Get new values to update, as a collection of 3 rows
new_vals = np.array([[10, 11, 12, 13, 14],
                     [20, 21, 32, 23, 24],
                     [30, 31, 32, 33, 34]])
test.from_numpy(new_vals)
print(test.model_dump())

{'array_1d': array([0, 1, 2]), 'array_2d': array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])}
{'array_1d': array([[10, 11, 12, 13, 14],
       [20, 21, 32, 23, 24],
       [30, 31, 32, 33, 34]]), 'array_2d': array([], shape=(0, 5), dtype=int64)}


In [None]:
class _PropertyHelper():
    def __init__(self, outer_class, property_name: str, single_valued: bool = False) -> None:
        self.outer = outer_class
        self.prop_name = property_name
        self.single_val = single_valued
        if single_valued:
            self.len = 1
    
    def setter(self, new_vals: Union[float, np.ndarray]) -> None:
        batched_vals = getattr(self.outer, '_batched_vals', None)
        if batched_vals is None:  # If it does not exist yet, we have to create it
            if isinstance(new_vals, Number):
                new_vals = np.array([[new_vals]])
                self.len = 1
            else:
                if len(new_vals.shape) > 2:
                    raise ValueError('New values passed to \'' + self.prop_name + 
                                        '\' have more than two dimensions; their shape is ', new_vals.shape)
                if len(new_vals.shape) == 2:
                    if self.single_val:
                        if (1 not in new_vals.shape):
                            raise ValueError('Property \'' + self.name + '\' is single valued!')
                        new_vals = new_vals.reshape((-1,1))
                    else:
                        self.len = new_vals.shape[1]
                if len(new_vals.shape) == 1:
                    self.len = 1
                    new_vals = new_vals.reshape((new_vals.shape[0], 1))
            batched_vals = new_vals
        else:  # If values have already been batched, we need to make sure things will work with new values
            batched_shape = batched_vals.shape
            if isinstance(new_vals, Number):
                new_vals = new_vals * np.ones((batched_shape[0], 1))
                self.len = 1
            else:
                if len(new_vals.shape) > 2:
                    raise ValueError('New values passed to \'' + self.prop_name + 
                                        '\' have more than two dimensions; their shape is ', new_vals.shape)
                if len(new_vals.shape) == 1:
                    if self.single_val:
                        if new_vals.shape[0] != batched_shape[0]:
                            raise ValueError('New values passed to \'' + self.prop_name + 
                                                '\' do not match batch cardinality of %d!' % batched_shape[0])
                        new_vals = new_vals.reshape((-1, 1))
                    else:
                        # This is basically only for the I_RC currents. We want to keep the behavior from before, in
                        # which an array of multiple values indicated the currents through different RC components. 
                        # The assumption here therefore is that, if a single array of shape (X,) is provided, there are 
                        # X RC components, and these values need to be converted into (batch, X) array
                        new_vals = np.tile(new_vals, reps=(batched_shape[0], 1)) 
                        self.len = new_vals.shape(0)                    
                batched_vals = np.hstack((batched_vals, new_vals))
        # Update everything in outer class
        self.outer._batched_values = batched_vals
        self.outer._batched_names += (self.prop_name,)
    
    def getter(self):
        # find where 

In [None]:
from typing import Union, Optional, Tuple
from numbers import Number

import numpy as np
from pydantic import BaseModel, Field

class BatchedDummy(BaseModel,
                   arbitrary_types_allowed=True,
                   validate_assignment=True):
    # Things the user will set
    soc: Union[float, np.ndarray] = Field(description='SOC')
    q0: Optional[Union[float, np.ndarray]] = \
        Field(default=None, description='Charge in the series capacitor. Units: Coulomb')
    i_rc: Optional[np.ndarray] = \
        Field(default=None, description='Currents through RC components. Units: Amp')
    hyst: Union[float, np.ndarray] = Field(default=0, description='Hysteresis voltage. Units: V')
    # Hidden variables?
    _batched_vals: np.ndarray = Field(description='All values batched', init=False)
    _batched_names: Optional[Tuple[str, ...]] = Field(deafult = tuple(), description='Stored property names')


In [33]:
from typing import Union

import numpy as np

class BasicBatch():
    def __init__(self, single: Union[float, np.ndarray], mult: np.ndarray):
        """
        Silly class to test out ideas. Here, 'single' indicates a variable that is single-valued, so, when an array is 
        passed, the assumption is that it is batched. On the other hand, 'mult' is a multi-valued variable, so 
        everything passed must be an array, and if it is 2D, it means it is batched. 
        Let's assume that everything is passed either as a single instance, or batched instances, to simplify. We will
        also assume that the batch cardinality is consistent accross all variables. 
        """
        self._fully_built = False
        self.single = single
        self.mult = mult

    @property
    def single(self) -> np.ndarray:
        return self._batched_values[:,0]
    @single.setter
    def single(self, value):
        batched_vals = getattr(self, '_batched_values', None)
        if batched_vals is not None:
            batched_vals[:,0] = value
        else:
            batched_vals = np.atleast_2d(value).T
        self._batched_values = batched_vals

    @property
    def mult(self) -> np.ndarray:
        return self._batched_values[:,1:]
    @mult.setter
    def mult(self, values) -> None:
        batched_vals = getattr(self, '_batched_values', None)
        if batched_vals is not None:
            if self._fully_built:
                batched_vals[:,1:] = values
            else:
                batched_vals = np.hstack((batched_vals, np.atleast_2d(values)))
                self._fully_built = True
        else:
            batched_vals = np.atleast_2d(values)
        self._batched_values = batched_vals

test0 = BasicBatch(single=0, mult=np.array([1]))
print(f'Single: {test0.single}')
print(f'Mult: {test0.mult}')
print(f'Batched: {test0._batched_values}')
print('------------------------------')
test1 = BasicBatch(single=0, mult=np.array([1,2]))
print(f'Single: {test1.single}')
print(f'Mult: {test1.mult}')
print(f'Batched: {test1._batched_values}')
print('------------------------------')
test2 = BasicBatch(single=np.array([0,1,2]), mult=np.array([[3],[4],[5]]))
print(f'Single: {test2.single}')
print(f'Mult: {test2.mult}')
print(f'Batched: {test2._batched_values}')
print('------------------------------')
test3 = BasicBatch(single=np.array([0,1,2]), mult=np.array([[3,4],[5,6],[7,8]]))
print(f'Single: {test3.single}')
print(f'Mult: {test3.mult}')
print(f'Batched: {test3._batched_values}')

Single: [0]
Mult: [[1]]
Batched: [[0 1]]
------------------------------
Single: [0]
Mult: [[1 2]]
Batched: [[0 1 2]]
------------------------------
Single: [0 1 2]
Mult: [[3]
 [4]
 [5]]
Batched: [[0 3]
 [1 4]
 [2 5]]
------------------------------
Single: [0 1 2]
Mult: [[3 4]
 [5 6]
 [7 8]]
Batched: [[0 3 4]
 [1 5 6]
 [2 7 8]]
