<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">Understanding Ocean Deoxygenation</b>
</p>
<hr style="color:#5A7D9F;">

In [1]:
# ---------- Librairies ----------
import matplotlib.pyplot as plt

# ---------- Jupyter ----------
%matplotlib inline
plt.rcParams.update({"font.size": 13})

# Making sure modules are reloaded when modified
%reload_ext autoreload
%autoreload 2

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">Score-Based Modeling</b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
# Standard library
from typing import Tuple

# Third-party
import matplotlib.pyplot as plt
import torch
import yaml

# Project
from poseidon.diffusion.denoiser import PoseidonDenoiser
from poseidon.training import training
from poseidon.utils import MemoryUsage

# NOTE(review): `MemoryUsage` comes from `poseidon.utils`; presumably it reports the
# current memory usage — confirm in that module.
MemoryUsage()

In [11]:
# The four `config_*` arguments are read from the YAML file, so this cell does not
# depend on hidden kernel state.
# NOTE(review): the section names used below are assumptions — TODO confirm them
# against scripts/configs/training.yml.
with open("scripts/configs/training.yml", "r") as file:
    # Plain configuration file: safe_load only builds standard YAML types, whereas
    # yaml.Loader can instantiate arbitrary Python objects.
    configs = yaml.safe_load(file)

config_dataset = configs["dataset"]
config_backbone = configs["backbone"]
config_nn = configs["nn"]
config_training = configs["training"]

training(config_dataset, config_backbone, config_nn, config_training, toy_problem=True)

<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">Samplers</b>
</p>
<hr style="color:#5A7D9F;">

In [None]:
import torch.nn as nn


class PoseidonSampler(nn.Module):
    r"""Deterministic Euler sampler for a trajectory of states.

    Sampling follows the noise schedule of Karras et al. (2022), going from
    ``sigma_max`` down to ``sigma_min`` over ``steps`` levels. The trajectory is
    denoised in overlapping "blankets": every state is paired with a window of
    ``2k + 1`` consecutive states, the windows are flattened and denoised jointly,
    and each state keeps the estimate taken from its own slot in its window.

    NOTE(review): a later notebook cell redefines ``PoseidonSampler``, which
    shadows this definition once it runs; consider keeping only one of them.

    Arguments:
        steps: Number of noise levels in the schedule (at least 2).
        dimensions: ``(channels, latitude, longitude)`` of a single state.
        sigma_max: Largest noise level (first level of the schedule).
        sigma_min: Smallest noise level (last level of the schedule).
        rho: Exponent shaping the spacing of the schedule.

    Raises:
        ValueError: If ``steps < 2``.
    """

    def __init__(
        self,
        steps: int,
        dimensions: Tuple[int, int, int],
        sigma_max: float = 80.0,
        sigma_min: float = 0.002,
        rho: float = 7.0,
    ):
        super().__init__()

        # With fewer than two levels the schedule divides by zero (NaN timesteps)
        if steps < 2:
            raise ValueError(f"steps must be at least 2, got {steps}")

        # Karras et al. (2022) schedule: linear interpolation in sigma^(1/rho) space
        ramp = torch.arange(steps) / (steps - 1)
        sigma_max_rho = sigma_max ** (1 / rho)
        sigma_min_rho = sigma_min ** (1 / rho)
        self.timesteps = (sigma_max_rho + ramp * (sigma_min_rho - sigma_max_rho)) ** rho

        # Shape of a single state
        self.channels, self.latitude, self.longitude = dimensions

    def forward(
        self, denoiser: "PoseidonDenoiser", trajectory_size: int, time: torch.Tensor, k: int
    ) -> torch.Tensor:
        """Sample a trajectory of ``trajectory_size`` states.

        Arguments:
            denoiser: Network called as ``denoiser(x_t=..., sigma_t=..., c=...)`` on the
                flattened blankets; its output must reshape to
                ``(trajectory_size, 2k + 1, channels, latitude, longitude)``.
            trajectory_size: Number of states; must equal ``len(time)``.
            time: Time conditioning, passed through ``time_tokenizer``.
            k: Half-width of each blanket (blankets hold ``2k + 1`` states).

        Returns:
            The sampled trajectory, shape ``(trajectory_size, channels, latitude, longitude)``,
            on the CPU.

        Raises:
            ValueError: If ``k < 0`` or ``2k + 1 > trajectory_size``.
        """
        assert trajectory_size == len(time), "Trajectory size must be equal to the time size"

        window = 2 * k + 1
        starts = self._blanket_starts(trajectory_size, k)
        # Blanket b gathers states starts[b], ..., starts[b] + 2k
        blanket_index = starts[:, None] + torch.arange(window)
        # Slot of each state inside its own blanket (0 at the left edge, k inside, up to 2k at the right edge)
        own_slot = torch.arange(trajectory_size) - starts
        rows = torch.arange(trajectory_size)

        # Pure noise at the first (largest) noise level: x ~ N(0, sigma_max^2 I)
        x = self.timesteps[0] * torch.randn(
            trajectory_size, self.channels, self.latitude, self.longitude
        )

        # NOTE(review): `time_tokenizer` is not defined in any visible cell — TODO confirm its origin
        time = time_tokenizer(time).cuda()

        for sigma_t, sigma_next in zip(self.timesteps[:-1], self.timesteps[1:]):
            # Blankets are rebuilt (and flattened) from the *current* x at every step
            blankets = x[blanket_index].flatten(1).cuda()
            sigma = (sigma_t * torch.ones(trajectory_size, 1)).cuda()

            # Only the network evaluation runs on the GPU; the result comes back to the CPU
            denoised = denoiser(x_t=blankets, sigma_t=sigma, c=time).cpu()
            denoised = denoised.reshape(
                trajectory_size, window, self.channels, self.latitude, self.longitude
            )

            # Keep, for every state, the estimate from its own slot
            x_denoised = denoised[rows, own_slot]

            # Euler step of dx/dsigma = -sigma * score, with score = (D(x) - x) / sigma^2
            score = (x_denoised - x) / sigma_t**2
            x = x + (-sigma_t * score) * (sigma_next - sigma_t)

        return x

    @staticmethod
    def _blanket_starts(trajectory_size: int, k: int) -> torch.Tensor:
        """Index of the first state of the ``2k + 1`` blanket assigned to each state.

        Blankets are centred on their state and shifted inwards near the ends of the
        trajectory so that every blanket lies within ``[0, trajectory_size)``.

        Raises:
            ValueError: If ``k < 0`` or ``2k + 1 > trajectory_size`` (no blanket fits).
        """
        window = 2 * k + 1
        if k < 0 or window > trajectory_size:
            raise ValueError(
                f"Expected 0 <= k and 2 * k + 1 <= trajectory_size, "
                f"got k={k} and trajectory_size={trajectory_size}"
            )
        return torch.clamp(torch.arange(trajectory_size) - k, min=0, max=trajectory_size - window)


# NOTE(review): `data`, `denoiser` and `time` are not defined in any visible cell
# (hidden kernel state) — create them above so Restart & Run All works.
N_STEPS = 6  # number of noise levels in the sampling schedule
TRAJECTORY_SIZE = 12  # must equal len(time[0]) (asserted by the sampler)
BLANKET_K = 1  # blankets hold 2 * BLANKET_K + 1 consecutive states

# Creation of the sampler; data.shape[2:] must hold exactly (channels, latitude, longitude)
ps = PoseidonSampler(steps=N_STEPS, dimensions=data.shape[2:])

# Generating a sample
sample_X = ps(denoiser, TRAJECTORY_SIZE, time[0], BLANKET_K)

In [None]:
import torch
import torch.nn as nn

from typing import Tuple


class PoseidonSampler(nn.Module):
    r"""
    Poseidon Sampler for diffusion process based on a custom noise schedule.

    The schedule is the one of Karras et al. (2022), from ``sigma_max`` down to
    ``sigma_min`` over ``steps`` levels. The trajectory is denoised in overlapping
    "blankets" of ``2k + 1`` consecutive states; each state keeps the estimate
    taken from its own slot in its blanket.

    NOTE(review): this cell redefines the ``PoseidonSampler`` of the previous cell
    and shadows it; consider keeping only one definition.

    Arguments:
        steps: Number of noise levels in the schedule (at least 2).
        dimensions: ``(channels, latitude, longitude)`` of a single state.
        sigma_max: Largest noise level (first level of the schedule).
        sigma_min: Smallest noise level (last level of the schedule).
        rho: Exponent shaping the spacing of the schedule.

    Raises:
        ValueError: If ``steps < 2``.
    """

    def __init__(
        self,
        steps: int,
        dimensions: Tuple[int, int, int],
        sigma_max: float = 80.0,
        sigma_min: float = 0.002,
        rho: float = 7.0,
    ):
        super().__init__()

        # With fewer than two levels the schedule divides by zero (NaN timesteps)
        if steps < 2:
            raise ValueError(f"steps must be at least 2, got {steps}")

        # Compute timesteps for noise schedule (linear in sigma^(1/rho) space)
        steps_tensor = torch.arange(steps)
        sigma_max_rho = sigma_max ** (1 / rho)
        sigma_min_rho = sigma_min ** (1 / rho)
        self.timesteps = (
            sigma_max_rho + (steps_tensor / (steps - 1)) * (sigma_min_rho - sigma_max_rho)
        ) ** rho

        # Store dimensions
        self.channels, self.latitude, self.longitude = dimensions

    def forward(
        self, denoiser: nn.Module, trajectory_size: int, time: torch.Tensor, k: int
    ) -> torch.Tensor:
        """
        Run the sampling process by iterating over the timesteps.

        Arguments:
            denoiser: The denoising network to apply at each step.
            trajectory_size: Number of states in the trajectory; must equal ``len(time)``.
            time: Time conditioning tensor, passed through ``time_tokenizer``.
            k: Half-width of each blanket (blankets hold ``2k + 1`` states).

        Returns:
            The sampled trajectory, shape ``(trajectory_size, channels, latitude, longitude)``,
            on the CPU.

        Raises:
            ValueError: If ``k < 0`` or ``2k + 1 > trajectory_size``.
        """
        assert trajectory_size == len(time), "Trajectory size must be equal to the time size"

        # Slot of each state inside its own blanket, used by `step` to reassemble the trajectory
        idx_recomposed = torch.arange(trajectory_size) - self._blanket_starts(trajectory_size, k)

        # Initial sample at the largest noise level: x ~ N(0, sigma_max^2 I)
        x = self.timesteps[0] * torch.randn(
            trajectory_size, self.channels, self.latitude, self.longitude
        )

        # NOTE(review): `time_tokenizer` is not defined in any visible cell — TODO confirm its origin
        time = time_tokenizer(time).cuda()

        for i in range(len(self.timesteps) - 1):
            # Rebuild the blankets from the current sample so the denoiser sees the updated trajectory
            batched_noise = self._batch_noise(x, trajectory_size, k).flatten(1)
            x = self.step(i, batched_noise, x, idx_recomposed, denoiser, time, trajectory_size, k)

        return x

    def step(
        self,
        i: int,
        batched_noise: torch.Tensor,
        x: torch.Tensor,
        idx_recomposed: torch.Tensor,
        denoiser: nn.Module,
        time: torch.Tensor,
        trajectory_size: int,
        k: int,
    ) -> torch.Tensor:
        """
        Perform one Euler step from ``timesteps[i]`` to ``timesteps[i + 1]``.

        Arguments:
            i: Current timestep index.
            batched_noise: Flattened blankets built from the current ``x``.
            x: Current sample, shape ``(trajectory_size, channels, latitude, longitude)``.
            idx_recomposed: Slot of each state inside its own blanket.
            denoiser: Denoising network.
            time: Tokenized time tensor.
            trajectory_size: Number of states in the trajectory.
            k: Half-width of each blanket.

        Returns:
            The updated sample (on the CPU).
        """
        sigma_t = self.timesteps[i]
        sigma_next = self.timesteps[i + 1]
        window = 2 * k + 1

        # Only the network evaluation runs on the GPU; its output is brought back to the CPU
        sigma = (sigma_t * torch.ones(trajectory_size, 1)).cuda()
        denoised = denoiser(x_t=batched_noise.cuda(), sigma_t=sigma, c=time).cpu()
        denoised = denoised.reshape(
            trajectory_size, window, self.channels, self.latitude, self.longitude
        )

        # Keep, for every state, the estimate from its own slot in its blanket
        x_denoised = denoised[torch.arange(trajectory_size), idx_recomposed]

        # Euler step of dx/dsigma = -sigma * score, with score = (D(x) - x) / sigma^2
        score = (x_denoised - x) / sigma_t**2
        return x + (-sigma_t * score) * (sigma_next - sigma_t)

    def _batch_noise(self, x: torch.Tensor, trajectory_size: int, k: int) -> torch.Tensor:
        """
        Gather the ``2k + 1`` blanket of every state.

        Arguments:
            x: Trajectory whose first dimension indexes the states.
            trajectory_size: Number of states in ``x``.
            k: Half-width of each blanket.

        Returns:
            Tensor of shape ``(trajectory_size, 2k + 1, *x.shape[1:])`` (not flattened).

        Raises:
            ValueError: If ``k < 0`` or ``2k + 1 > trajectory_size``.
        """
        starts = self._blanket_starts(trajectory_size, k)
        return x[starts[:, None] + torch.arange(2 * k + 1)]

    @staticmethod
    def _blanket_starts(trajectory_size: int, k: int) -> torch.Tensor:
        """
        Index of the first state of the blanket assigned to each state.

        Blankets are centred on their state and shifted inwards near the ends of the
        trajectory so that every blanket lies within ``[0, trajectory_size)``.

        Raises:
            ValueError: If ``k < 0`` or ``2k + 1 > trajectory_size`` (no blanket fits).
        """
        window = 2 * k + 1
        if k < 0 or window > trajectory_size:
            raise ValueError(
                f"Expected 0 <= k and 2 * k + 1 <= trajectory_size, "
                f"got k={k} and trajectory_size={trajectory_size}"
            )
        return torch.clamp(torch.arange(trajectory_size) - k, min=0, max=trajectory_size - window)


# NOTE(review): `data`, `denoiser` and `time` are not defined in any visible cell
# (hidden kernel state) — create them above so Restart & Run All works.
N_STEPS = 6  # number of noise levels in the sampling schedule
TRAJECTORY_SIZE = 12  # must equal len(time[0]) (asserted by the sampler)
BLANKET_K = 1  # blankets hold 2 * BLANKET_K + 1 consecutive states

# data.shape[2:] must hold exactly (channels, latitude, longitude)
ps = PoseidonSampler(N_STEPS, dimensions=data.shape[2:])
sample_X = ps(denoiser, TRAJECTORY_SIZE, time[0], BLANKET_K)