# Code documentation

# diffusion.py

<details>
<summary><code>forward_sde</code> Class Documentation</summary>

```python
import torch
class forward_sde:
    """Base class for forward SDEs on R^d.

    Stores the state-space dimension, the diffusion horizon, the asymptotic
    standard deviation and the target device, and builds ``final``: a
    zero-mean ``gaussian`` with covariance ``sigma_infty**2 * I_d``.
    Specific forward processes are meant to build on this class.
    """

    def __init__(self, dimension, final_time, sigma_infty, device=torch.device('cpu')):
        """Create the SDE description.

        Args:
            dimension: dimension d of the state space.
            final_time: diffusion time T of the SDE.
            sigma_infty: asymptotic standard deviation; ``final`` is built
                with covariance ``sigma_infty**2 * I_d``.
            device: torch device on which the mean and covariance of
                ``final`` are allocated.
        """
        self.d = dimension
        self.final_time = final_time
        self.sigma_infty = sigma_infty
        self.device = device
        # `gaussian` is defined elsewhere in the project; presumably its
        # signature is (dimension, mean, covariance) -- confirm against it.
        self.final = gaussian(dimension,
                              torch.zeros(dimension, device=device),
                              self.sigma_infty**2 * torch.eye(dimension, device=device))

    def to(self, device):
        """Move the SDE to ``device`` and return ``self``.

        Updates ``self.device`` and replaces ``final`` with
        ``final.to(device)``. Returning ``self`` matches the torch ``.to``
        convention, so ``sde = forward_sde(...).to(dev)`` binds the SDE
        rather than ``None``.
        """
        self.device = device
        self.final = self.final.to(device)
        return self
```

The `forward_sde` class is a base class for forward stochastic differential equations (SDEs). It serves as the foundation for specific forward SDE implementations such as the Ornstein–Uhlenbeck process, also called the Variance-Preserving SDE (VPSDE), and scaled Brownian motion, also called the Variance-Exploding SDE (VESDE).
### Attributes:

| Attribute       | Type           | Description                                                                                         |
|-----------------|----------------|-----------------------------------------------------------------------------------------------------|
| `d`             | `int`          | Dimension $d$ of the state space.                                                                   |
| `final_time`    | `float`        | Diffusion time $T$ of the SDE.                                                                      |
| `sigma_infty`   | `float`        | Asymptotic standard deviation $\sigma_\infty$.                                                      |
| `device`        | `torch.device` | Computational device (e.g., `'cpu'` or `'cuda'`).                                                   |
| `final`         | `gaussian`     | Zero-mean Gaussian with covariance $\sigma_\infty^2 I_d$, built at initialization on `device`.      |


### Methods:

#### `__init__(self, dimension, final_time, sigma_infty, device=torch.device('cpu'))`

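Initializes a `forward_sde` instance: stores the arguments as attributes and builds `final`, a zero-mean Gaussian with covariance $\sigma_\infty^2 I_d$ allocated on `device`.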

#### `to(self, device)`

Transfers the object to the specified computational device by updating `device` and moving `final` with `final.to(device)`.

---

In [6]:
class beta_parametric:
    """Parametric noise schedule beta(t) on [0, final_time].

    The schedule is

        beta(t) = beta_min + delta * (exp(a * t) - 1),
        delta   = (beta_max - beta_min) / (exp(a * final_time) - 1),

    so that beta(0) = beta_min and beta(final_time) = beta_max. When
    |a| < ``_LINEAR_THRESHOLD`` the linear limit (a -> 0)

        beta(t) = beta_min + delta * t,
        delta   = (beta_max - beta_min) / final_time

    is used instead, because the exponential form degenerates as a -> 0.
    The same threshold selects the branch in every method, including the
    computation of ``delta``.

    The exponential branches call ``torch.exp``, so ``t`` must be a
    ``torch.Tensor`` whenever |a| >= ``_LINEAR_THRESHOLD``.
    """

    # |a| below this value selects the linear (a -> 0) limit everywhere.
    _LINEAR_THRESHOLD = 1e-3

    def __init__(self, a, final_time, beta_min, beta_max):
        """Build the schedule.

        Args:
            a: curvature of the schedule (a ~ 0 gives a linear schedule).
            final_time: horizon T at which beta reaches ``beta_max``.
            beta_min: value of beta at t = 0.
            beta_max: value of beta at t = final_time.
        """
        self.a = a
        self.final_time = final_time
        self.beta_min = beta_min
        self.beta_max = beta_max
        # Must use the same branch rule as __call__/integrate/square_integrate.
        # Testing only `a == 0` here left 0 < |a| < threshold with an
        # exponential-form delta evaluated by the linear formula.
        self.delta = self._compute_delta()

    def _is_linear(self):
        """Return True when the linear (a -> 0) limit should be used."""
        return abs(self.a) < self._LINEAR_THRESHOLD

    def _compute_delta(self):
        """Return the scale delta that makes beta(final_time) == beta_max."""
        span = self.beta_max - self.beta_min
        if self._is_linear():
            return span / self.final_time
        return span / (math.exp(self.a * self.final_time) - 1.)

    def __call__(self, t):
        """Evaluate beta(t)."""
        if self._is_linear():
            return self.beta_min + self.delta * t
        return self.beta_min + self.delta * (torch.exp(self.a * t) - 1.)

    def integrate(self, t):
        """Return the integral of beta(s) ds over [0, t]."""
        if self._is_linear():
            return self.beta_min * t + 0.5 * self.delta * t**2
        return self.beta_min * t + self.delta * ((torch.exp(self.a * t) - 1) / self.a - t)

    def square_integrate(self, t):
        """Return the integral of beta(s)**2 ds over [0, t]."""
        if self._is_linear():
            return (self.beta_min**2 * t
                    + self.beta_min * self.delta * t**2
                    + (1. / 3) * self.delta**2 * t**3)
        # Writing beta(s) = c + delta * exp(a s) with c = beta_min - delta:
        #   int_0^t beta^2 = c^2 t + 2 c delta (e^{at} - 1)/a
        #                    + delta^2 (e^{2at} - 1)/(2a)
        c = self.beta_min - self.delta
        em1 = torch.exp(self.a * t) - 1.
        e2m1 = torch.exp(2 * self.a * t) - 1.
        return (c**2 * t
                + 2 * c * self.delta * em1 / self.a
                + self.delta**2 * e2m1 / (2 * self.a))

    def change_a(self, a):
        """Set a new curvature ``a`` and recompute ``delta`` accordingly."""
        self.a = a
        self.delta = self._compute_delta()