In [70]:
from typing import Union, Optional, Tuple, Dict, Sized
from numbers import Number

import numpy as np
from pydantic import BaseModel, Field, field_validator, ValidationInfo

class GeneralContainer(BaseModel,
                       arbitrary_types_allowed=True,
                       validate_assignment=True):
    """
    General container class to store variables that are all numeric (that is, either floats or numpy arrays).

    Each non-None field contributes one or more columns to the array produced by ``to_numpy``:
    a scalar contributes one column, a 1-D array of length ``k`` contributes ``k`` columns, and an
    array of shape ``(batch, k)`` contributes ``k`` columns. Fields that are None are skipped.
    Columns are laid out in field-declaration order, and ``from_numpy`` reads them back in that order.
    """

    @property
    def all_fields(self) -> tuple[str, ...]:
        """Names of all declared model fields, in declaration order."""
        # Read model_fields from the class: accessing it on an instance is deprecated in Pydantic >= 2.11.
        return tuple(type(self).model_fields)

    def length_field(self, field_name: str) -> int:
        """
        Returns the number of columns the provided field occupies in ``to_numpy``'s output.

        A float (or other scalar) counts as 1, a 1-D array of length ``k`` counts as ``k``, and an
        array of shape ``(batch, k)`` counts as ``k``. If the field is None (or does not exist), returns 0.
        """
        field_val = getattr(self, field_name, None)
        if field_val is None:
            return 0
        # np.atleast_2d is the same normalisation to_numpy applies before stacking, so this count
        # always matches the stacked layout. It also copes with scalars, 0-D / 1-D arrays and
        # plain sequences, none of which can be indexed with `.shape[1]` directly.
        return np.atleast_2d(field_val).shape[1]

    def __len__(self) -> int:
        """ Returns total length (number of columns) of all numerical values stored """
        return sum(self.length_field(field_name) for field_name in self.all_fields)

    def to_numpy(self) -> np.ndarray:
        """
        Outputs everything that is stored as a np.ndarray.

        Each non-None field is promoted with ``np.atleast_2d`` and the results are stacked
        horizontally (``np.hstack``) in field-declaration order, so all fields must have the
        same number of rows.

        Raises:
            ValueError: if every field is None, or if the fields have incompatible numbers of rows.
        """
        relevant_vals = []
        for field_name in self.all_fields:
            field = getattr(self, field_name, None)
            if field is not None:
                relevant_vals.append(np.atleast_2d(field))
        if not relevant_vals:
            # np.hstack on an empty sequence would fail with a generic "need at least one array" error.
            raise ValueError('Cannot convert to numpy: all fields are None')
        vals = np.hstack(relevant_vals)
        # TODO (vventuri): should we return a flattened array if batch size == 1?
        # if vals.shape[0] == 1:
        #     return vals.flatten()
        return vals

    def from_numpy(self, values: np.ndarray) -> None:
        """
        Updates field values from a numpy array.

        Inverse of ``to_numpy``: columns are consumed in field-declaration order, each field taking
        as many columns as ``length_field`` currently reports for it. Fields that are None are left
        unchanged. ``values`` may be 1-D (a single sample) or 2-D with shape ``(batch, n_columns)``.
        Each new value is assigned with ``setattr``, so it goes through the model's assignment validation.
        """
        values = np.asarray(values)
        # We need to know where to start reading from in the array
        begin_index = 0
        for field_name in self.all_fields:
            field_len = self.length_field(field_name)
            if field_len <= 0:
                continue
            end_index = begin_index + field_len
            if values.ndim == 1:
                new_field_values = values[begin_index:end_index]
            else:
                new_field_values = values[:, begin_index:end_index]
            setattr(self, field_name, new_field_values)
            begin_index = end_index


In [75]:
class BatchedDummy(GeneralContainer,
                   arbitrary_types_allowed=True,
                   validate_assignment=True):
    """
    Example container whose array-valued fields are normalised to 2-D arrays.

    Scalars are stored unchanged. Arrays given for ``soc``, ``q0`` and ``hyst`` are reshaped into a
    single column ``(-1, 1)``, and ``i_rc`` is promoted to at least 2-D (a 1-D array becomes one row).
    ``q0`` and ``i_rc`` may be None, in which case they are skipped by the container's numpy conversion.
    """
    # Things the user will set
    soc: Union[float, np.ndarray] = Field(description='SOC')
    q0: Optional[Union[float, np.ndarray]] = \
        Field(default=None, description='Charge in the series capacitor. Units: Coulomb')
    i_rc: Optional[np.ndarray] = \
        Field(default=None, description='Currents through RC components. Units: Amp')
    hyst: Union[float, np.ndarray] = Field(default=0, description='Hysteresis voltage. Units: V')

    @field_validator('soc', 'q0', 'hyst', mode='after')
    @classmethod
    def convert_single_valued(cls, value: Optional[Union[float, np.ndarray]],
                              info: ValidationInfo) -> Optional[Union[float, np.ndarray]]:
        """
        Keeps scalars (and None) unchanged and reshapes arrays into a single column of shape ``(-1, 1)``.
        """
        # q0 is Optional, so an explicitly passed/assigned None reaches this validator;
        # calling `.reshape` on it would raise AttributeError for a declared-valid value.
        if value is None or isinstance(value, Number):
            return value
        return value.reshape((-1, 1))

    @field_validator('i_rc', mode='after')
    @classmethod
    def convert_multiple_to_2D(cls, values: Optional[np.ndarray],
                               info: ValidationInfo) -> Optional[np.ndarray]:
        """
        Promotes ``i_rc`` to at least 2-D so that each row is one batch entry; None is kept as None.
        """
        # i_rc is Optional, so an explicit None reaches this validator and must be passed through.
        if values is None:
            return None
        # np.atleast_2d maps a 1-D array of length k to shape (1, k), a 0-D array to (1, 1),
        # and returns arrays that are already >= 2-D unchanged.
        return np.atleast_2d(values)


# Single sample: scalar `soc`/`hyst` around a 1-D `i_rc`, then overwrite all columns from one flat row.
test0 = BatchedDummy(
    soc=0,
    i_rc=np.array([1, 2, 3]),
    hyst=4,
)
print(test0.to_numpy())
test0.from_numpy(values=np.arange(10, 15))
print(test0.to_numpy())

print('---------------------------------')

# Batch of two: per-sample `soc`/`hyst` around a (2, 3) `i_rc`; the update then supplies three rows.
test1 = BatchedDummy(
    soc=np.array([0, 1]),
    i_rc=np.array([[2, 3, 4], [5, 6, 7]]),
    hyst=np.array([8, 9]),
)
print(test1.to_numpy())
test1.from_numpy(values=np.arange(10, 25).reshape(3, 5))
print(test1.to_numpy())

[[0. 1. 2. 3. 4.]]
[[10 11 12 13 14]]
---------------------------------
[[0 2 3 4 8]
 [1 5 6 7 9]]
[[10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]
