In [1]:
import numpy as np
from scipy.stats.qmc import Sobol
from pydantic import BaseModel, ValidationError
from typing import Type, Any, List, Dict

In [2]:
class SobolSampler:
    def __init__(self, pydantic_class: Type[BaseModel], dimensions: Dict[str, List[float]]):
        self.pydantic_class = pydantic_class
        self.dimensions = dimensions
        self.dimension_names = list(dimensions.keys())
        self.lower_bounds = np.array([v[0] for v in dimensions.values()])
        self.upper_bounds = np.array([v[1] for v in dimensions.values()])
        self.d = len(dimensions)
        self.sampler = Sobol(d=self.d, scramble=True)

    def sample(self, n_samples: int) -> List[BaseModel]:
        samples = self.sampler.random(n_samples)
        scaled_samples = self.lower_bounds + (self.upper_bounds - self.lower_bounds) * samples
        instances = []

        for sample in scaled_samples:
            data = {self.dimension_names[i]: sample[i] for i in range(self.d)}
            try:
                instance = self.pydantic_class(**data)
                instances.append(instance)
            except ValidationError as e:
                print(f"Validation error for sample {data}: {e}")

        return instances

In [3]:
# Example usage:

class ExampleModel(BaseModel):
    x: float
    y: float

In [10]:
dimensions = {
    'x': [0.0, 1.0],
    'y': [0.0, 1.0]
}

In [9]:
sampler = Sampler(ExampleModel, dimensions)

In [3]:
samples = sampler.sample(10)

for s in samples:
    print(s)

x=0.965993114747107 y=0.4378434782847762
x=0.3657691217958927 y=0.9745242334902287
x=0.1677903849631548 y=0.18230918888002634
x=0.5004286309704185 y=0.6553836856037378
x=0.6885354518890381 y=0.0008815405890345573
x=0.10417966265231371 y=0.5364559143781662
x=0.4296379489824176 y=0.36921071726828814
x=0.7776579912751913 y=0.8433913569897413
x=0.8349188230931759 y=0.2099787648767233
x=0.49739681277424097 y=0.7370263384655118


  sample = self._random(n, workers=workers)


In [5]:
more_samples = sampler.sample(10)

for s in more_samples:
    print(s)

x=0.6617371942847967 y=0.37738210428506136
x=0.013912404887378216 y=0.9138329084962606
x=0.45642897207289934 y=0.24269575905054808
x=0.8679332584142685 y=0.7160285785794258
x=0.7992169056087732 y=0.333571657538414
x=0.4001637762412429 y=0.8594986675307155
x=0.06699216458946466 y=0.046289561316370964
x=0.7336157970130444 y=0.5105786891654134
x=0.5451959101483226 y=0.14568642154335976
x=0.1307628881186247 y=0.672475672326982


In [7]:
sampler.dimensions

{'x': [0.0, 1.0], 'y': [0.0, 1.0]}