In [1]:
from typing import Union, Optional, Tuple
from typing_extensions import Self
from numbers import Number
from typing_extensions import Annotated

import numpy as np
from pydantic import (BaseModel,
                      Field,
                      field_validator,
                      model_validator,
                      ValidationInfo,
                      AfterValidator,
                      computed_field)

In [16]:
# Define what to do with variables that should be single-valued, and what to do with those that are multi-valued
class BatchedVariable(BaseModel):
    """
    A numeric value bundled with its batch layout.

    `batched_values` holds the data (or None); `batch_size` and `inner_dimensions`
    record how that data is laid out, and the computed `shape` reports both as a pair.
    """
    model_config = {'arbitrary_types_allowed': True, 'validate_assignment': True}

    batched_values: Union[Number, np.ndarray, None]
    batch_size: int = 1          # number of batch entries (first axis)
    inner_dimensions: int = 1    # number of values per batch entry (second axis)

    @computed_field
    @property
    def shape(self) -> Tuple[int, int]:
        rows, cols = self.batch_size, self.inner_dimensions
        return rows, cols

def convert_single_valued(value: Union[float, np.ndarray, None]) -> BatchedVariable:
    """
    Wrap a single-valued variable in a BatchedVariable stored as a column vector.

    Every input is normalised to shape (batch_size, 1):
      * None             -> BatchedVariable with no values
      * scalar / 0-D     -> shape (1, 1), batch_size 1
      * 1-D array (n,)   -> shape (n, 1), batch_size n
      * 2-D (n, 1)/(1, n) -> shape (n, 1), batch_size n

    Raises:
        ValueError: for a 2-D array with no dimension equal to 1, or an array with
            more than two dimensions.
    """
    if value is None:
        return BatchedVariable(batched_values=None)
    # np.asarray turns Python/numpy scalars into 0-D arrays, so scalars follow the same
    # reshape path as arrays and end up (1, 1), matching BatchedVariable.shape.
    array = np.asarray(value)
    if array.ndim > 2:
        raise ValueError(f'Single-valued variable cannot be passed as >2D array; shape provided was {array.shape}')
    if array.ndim == 2 and 1 not in array.shape:
        # Adjacent literals are concatenated into ONE message (the original passed two
        # separate arguments, which produced a tuple-shaped error message).
        raise ValueError(f'Single-valued variable cannot be passed as a {array.shape} matrix; '
                         'one of the dimensions must be equal to 1!')
    column = array.reshape((-1, 1))
    return BatchedVariable(batched_values=column, batch_size=column.shape[0])
    
def convert_multi_valued(values: Union[float, np.ndarray, None]) -> BatchedVariable:
    """
    Wrap a multi-valued variable in a BatchedVariable.

      * None            -> BatchedVariable with no values
      * scalar / 0-D    -> one batch entry with a single inner dimension, shape (1, 1)
      * 1-D array (n,)  -> one batch entry with n inner dimensions, shape (1, n)
      * 2-D array (b, n) -> first axis is the batch, second the inner dimensions

    Raises:
        ValueError: if an array with more than two dimensions is provided.
    """
    if values is None:
        return BatchedVariable(batched_values=None)
    values = np.asarray(values)
    # MultiVar's annotation admits plain floats, which have no `.shape`; promote scalars
    # (and 0-D arrays) to a length-1 vector so they take the 1-D path below.
    if values.ndim == 0:
        values = values.reshape(1)
    if values.ndim == 1:  # assume these are the inner dimensions
        return BatchedVariable(batched_values=np.atleast_2d(values), inner_dimensions=values.shape[0])
    elif values.ndim == 2:  # assume first axis is batch, second is inner
        return BatchedVariable(batched_values=values, batch_size=values.shape[0], inner_dimensions=values.shape[1])
    else:
        raise ValueError(f'Multi-valued variable must be provided as a 1D or 2D array; shape provided was {values.shape}')
    
# Field types for containers: once pydantic has checked the value is a float or an
# ndarray, the matching converter wraps it in a BatchedVariable. The lambdas look the
# converter up by name at validation time rather than binding it here.
SingleVar = Annotated[Union[float, np.ndarray],
                      AfterValidator(lambda raw: convert_single_valued(raw))]
MultiVar = Annotated[Union[float, np.ndarray],
                     AfterValidator(lambda raw: convert_multi_valued(raw))]

In [17]:
class GeneralContainer(BaseModel,
                       arbitrary_types_allowed=True,
                       validate_assignment=True):
    """
    General container for numeric variables, each stored as a BatchedVariable.

    Subclasses declare the variables as pydantic fields. Fields that are None hold no data
    and are skipped. `to_numpy` flattens the container into an array of shape
    (batch_size, len(self)), and `from_numpy` refills the fields from such an array.
    """

    @property
    def all_fields(self) -> tuple[str, ...]:
        """Names of all declared fields, in declaration order."""
        # Read from the class: accessing `model_fields` on an instance is deprecated
        # in pydantic >= 2.11.
        return tuple(type(self).model_fields)

    def _stored_variable(self, field_name: str):
        """Return the BatchedVariable held by `field_name`, or None if the field holds no data."""
        field_val = getattr(self, field_name, None)
        if field_val is None:
            return None
        if not hasattr(field_val, 'batched_values'):
            # e.g. a field default that pydantic never passed through the field's validator
            raise TypeError(f'Field {field_name!r} holds a {type(field_val).__name__}, not a BatchedVariable; '
                            'was its default left unvalidated?')
        if field_val.batched_values is None:
            return None
        return field_val

    def length_field(self, field_name: str) -> int:
        """
        Number of columns (inner dimensions) the field contributes to `to_numpy`.
        Returns 0 when the field is None or holds no values.
        """
        stored = self._stored_variable(field_name)
        return 0 if stored is None else stored.inner_dimensions

    def __len__(self) -> int:
        """ Returns total length of all numerical values stored """
        return sum(self.length_field(field_name) for field_name in self.all_fields)

    def to_numpy(self) -> np.ndarray:
        """
        Stack every stored variable side by side, in field order.

        Variables stored with a single batch entry are repeated across the batch, so a
        scalar (e.g. a default) can sit next to fully batched variables.

        Returns:
            Array of shape (batch_size, len(self)).

        Raises:
            ValueError: if no field holds data, or if two variables have different
                batch sizes that are both larger than 1.
        """
        blocks = [np.atleast_2d(var.batched_values)
                  for var in map(self._stored_variable, self.all_fields) if var is not None]
        if not blocks:
            raise ValueError('Container holds no values; nothing to convert to a numpy array')
        batch_size = max(block.shape[0] for block in blocks)
        blocks = [np.broadcast_to(block, (batch_size, block.shape[1])) if block.shape[0] == 1 else block
                  for block in blocks]
        # TODO (vventuri): should we return a flattened array if batch size == 1?
        return np.hstack(blocks)

    def from_numpy(self, values: np.ndarray) -> None:
        """
        Update field values from a numpy array laid out as `to_numpy` produces it.

        Args:
            values: 1-D array of length len(self) (one batch entry) or 2-D array of
                shape (batch_size, len(self)).

        Raises:
            ValueError: if `values` is not 1-D/2-D or its last axis is not len(self) long.
        """
        values = np.asarray(values)
        if values.ndim not in (1, 2):
            raise ValueError(f'Expected a 1D or 2D array; shape provided was {values.shape}')
        expected = len(self)
        if values.shape[-1] != expected:
            # Checked up front so no field is updated before a mismatch is detected.
            raise ValueError(f'Expected {expected} values along the last axis; shape provided was {values.shape}')
        # We need to know where to start reading from in the array
        begin_index = 0
        for field_name in self.all_fields:
            field_len = self.length_field(field_name)
            if field_len == 0:
                continue
            end_index = begin_index + field_len
            # `...` selects the same columns for 1-D and 2-D input alike.
            setattr(self, field_name, values[..., begin_index:end_index])
            begin_index = end_index


In [18]:
class BatchedDummy(GeneralContainer,
                   arbitrary_types_allowed=True,
                   validate_assignment=True):
    """
    Example container with single-valued variables (soc, q0, hyst) and one
    multi-valued variable (i_rc).
    """
    # Things the user will set
    soc: SingleVar = Field(description='SOC')
    q0: Optional[SingleVar] = \
        Field(default=None, description='Charge in the series capacitor. Units: Coulomb')
    i_rc: Optional[MultiVar] = \
        Field(default=None, description='Currents through RC components. Units: Amp')
    # pydantic does not validate defaults unless asked to. Without validate_default the
    # default 0 would stay a bare int instead of a BatchedVariable, and the
    # GeneralContainer methods would fail on it.
    hyst: SingleVar = Field(default=0, validate_default=True, description='Hysteresis voltage. Units: V')


# Unbatched example: scalars plus a 1-D i_rc give a single row
test0 = BatchedDummy(soc=0, i_rc=np.array([1,2,3]), hyst=4)
print(test0.to_numpy())
np.testing.assert_array_equal(test0.to_numpy(), [[0, 1, 2, 3, 4]])
test0.from_numpy(values=np.array([10,11,12,13,14]))
print(test0.to_numpy())
np.testing.assert_array_equal(test0.to_numpy(), [[10, 11, 12, 13, 14]])
print('---------------------------------')
# Batched example: the first axis of every variable is the batch
test1 = BatchedDummy(soc=np.array([0,1]), i_rc=np.array([[2,3,4], [5,6,7]]), hyst=np.array([8,9]))
print(test1.to_numpy())
np.testing.assert_array_equal(test1.to_numpy(), [[0, 2, 3, 4, 8], [1, 5, 6, 7, 9]])
test1.from_numpy(values = 10 + np.arange(15).reshape((3,5)))
print(test1.to_numpy())
np.testing.assert_array_equal(test1.to_numpy(), 10 + np.arange(15).reshape((3, 5)))
print('---------------------------------')
# Arrays with more than two dimensions are rejected at construction time
# (pydantic's ValidationError is a subclass of ValueError).
try:
    BatchedDummy(soc=np.array([0,1]), i_rc=np.atleast_3d([[2,3,4], [5,6,7]]), hyst=np.array([8,9]))
except ValueError:
    pass
else:
    raise AssertionError('a 3-D i_rc should have been rejected')

[[0. 1. 2. 3. 4.]]
[[10 11 12 13 14]]
---------------------------------
[[0 2 3 4 8]
 [1 5 6 7 9]]
[[10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]
---------------------------------


In [19]:
from pydantic import PrivateAttr

class Dummy(BaseModel,
            arbitrary_types_allowed=True,
            validate_assignment=True,
            extra='allow'):
    """
    Experiment: keep the raw array in field `a` and expose its batched form as `_a`.
    """
    a: np.ndarray

    @field_validator('a', mode='after')
    @classmethod
    def get_batch(cls, values: np.ndarray) -> np.ndarray:
        """Fail fast, on construction and on assignment, for shapes that cannot be batched."""
        convert_multi_valued(values)  # raises ValueError for unsupported shapes; result unused
        return values

    # A read-only property, recomputed from `self.a` on every access, so `_a` always
    # matches this instance's current `a`. Storing it from a classmethod validator
    # (`cls._a = ...`) would set a class attribute shared by every instance.
    @property
    def _a(self) -> 'BatchedVariable':
        """Batched view of `a`."""
        return convert_multi_valued(self.a)
    
test0 = Dummy(a=np.arange(5))
# Read model_fields from the class: instance access is deprecated in pydantic >= 2.11
print(Dummy.model_fields.keys())
print(test0.model_dump())
print(test0._a)
print(test0.a)
assert test0._a.shape == (1, 5)
print('--------------------------------------------')
test1 = Dummy(a=np.arange(10).reshape((2,5)))
print(test1.a)
print(test1._a)
assert test1._a.shape == (2, 5)
test1.a = np.arange(5)
print(test1.a)
print(test1._a)
# The batched view follows the re-assigned field
assert test1._a.shape == (1, 5)

dict_keys(['a'])
{'a': array([0, 1, 2, 3, 4])}
batched_values=array([[0, 1, 2, 3, 4]]) batch_size=1 inner_dimensions=5 shape=(1, 5)
[0 1 2 3 4]
--------------------------------------------
[[0 1 2 3 4]
 [5 6 7 8 9]]
batched_values=array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]]) batch_size=2 inner_dimensions=5 shape=(2, 5)
[0 1 2 3 4]
batched_values=array([[0, 1, 2, 3, 4]]) batch_size=1 inner_dimensions=5 shape=(1, 5)
