In [2]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List 
import math 
import numpy as np 
import torch 


from diffusers.configuration_utils import ConfigMixin, register_to_config 
from diffusers.utils import BaseOutput, logging 
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin

from IPython import embed 


In [5]:
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor`, e.g. of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
    """

    # The updated sample produced by `step`.
    prev_sample: torch.FloatTensor







In [None]:
class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Multi-stage Euler scheduler for flow matching.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of timesteps in the global training schedule.
        shift (`float`, defaults to 1.0):
            Shift applied to the global sigma schedule.
        stages (`int`, defaults to 3):
            The number of stages.
        stage_range (`List[float]`, defaults to `[0, 1/3, 2/3, 1]`):
            Stage boundaries as fractions of the training schedule. Stage `i` spans `stage_range[i]` to
            `stage_range[i + 1]`, so at least `stages + 1` entries are required.
        gamma (`float`, defaults to 1/3):
            Parameter used to correct the start sigma of every stage after the first.

    Raises:
        ValueError: If `stage_range` has fewer than `stages + 1` entries.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(self,
                 num_train_timesteps: int = 1000,
                 shift: float = 1.0,
                 stages: int = 3,
                 # The default list is only read (indexed), never mutated.
                 stage_range: List = [0, 1/3, 2/3, 1],
                 gamma: float = 1/3):
        # Fail early with a clear message instead of an IndexError deep inside
        # `init_sigmas_for_each_stage`, which reads `stage_range[i + 1]` for every stage.
        if len(stage_range) < stages + 1:
            raise ValueError(
                f"`stage_range` needs at least `stages + 1` = {stages + 1} boundaries, got {len(stage_range)}."
            )

        super().__init__()
        # Per-stage bookkeeping, keyed by stage index; filled by `init_sigmas_for_each_stage`.
        self.timestep_ratios = {}
        self.timestep_per_stage = {}
        self.sigmas_per_stage = {}
        self.start_sigmas = {}
        self.end_sigmas = {}
        self.ori_start_sigmas = {}

        self.init_sigmas_for_each_stage()
        # `self.sigmas` is set by `init_sigmas`, which `init_sigmas_for_each_stage` calls.
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma



   





   
    


    
    


   






    

## 0.2 Init Sigmas

- This method initializes the `timesteps` and `sigmas` attributes of the class, which define the global training schedule: the training timesteps and their corresponding sigma values.
The sigmas are rescaled by the `shift` parameter $s$ via $\sigma' = \dfrac{s\,\sigma}{1 + (s - 1)\,\sigma}$, and the timesteps are the rescaled sigmas multiplied by `num_train_timesteps`.

In [None]:
def init_sigmas(self):
    """
    Initialize the global (training) timesteps and sigmas.

    Builds the timesteps `num_train_timesteps, ..., 2, 1`, normalizes them to sigmas by dividing by
    `num_train_timesteps`, and applies the shift transform
    `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`.

    Sets:
        self.timesteps: the shifted sigmas scaled back by `num_train_timesteps` (float32 tensor).
        self.sigmas: the shifted sigmas, moved to the CPU (float32 tensor), in the same order as `self.timesteps`.
        self._step_index, self._begin_index: reset to `None`.
    """
    num_train_timesteps = self.config.num_train_timesteps  # total number of training timesteps
    shift = self.config.shift                              # scaling factor used in the sigma transform

    # Evenly spaced values 1..num_train_timesteps, reversed into descending order.
    # `.copy()` makes the reversed view contiguous: `torch.from_numpy` rejects negative strides.
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

    sigmas = timesteps / num_train_timesteps              # normalize the timesteps
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # apply the shift transform

    self.timesteps = sigmas * num_train_timesteps         # scale the shifted sigmas back to timestep range

    # Reset the step bookkeeping.
    self._step_index = None
    self._begin_index = None

    self.sigmas = sigmas.to("cpu")

### 0.2.1 Initialize the sigmas for each stage

- This method initializes the timesteps and sigmas for each stage based on the configuration and the calculated per-stage ratios.

In [None]:
def init_sigmas_for_each_stage(self):
    """
    Initialize the timesteps and sigmas for each stage.

    Steps:
      1. Rebuild the global schedule via `self.init_sigmas()`.
      2. For each stage `i`, locate its boundaries in the global sigma schedule from
         `stage_range[i]` and `stage_range[i + 1]`. Record the uncorrected start sigma in
         `ori_start_sigmas`, correct the start sigma of every stage after the first using
         `gamma`, and record `start_sigmas` / `end_sigmas`.
      3. Split [0, 1] into per-stage `timestep_ratios`, proportional to each stage's
         distance `start_sigma - end_sigma`.
      4. For each stage, build `num_train_timesteps` timesteps (`timestep_per_stage`) spanning
         that stage's slice of the global timesteps, plus a 1 -> 0 sigma grid (`sigmas_per_stage`).
    """
    self.init_sigmas()  # (re)build the global timesteps and sigmas

    stages = self.config.stages
    training_steps = self.config.num_train_timesteps
    stage_range = self.config.stage_range

    # 1) Start / end sigma of each stage.
    stage_distance = []
    for i_s in range(stages):
        # Clamp both indices into the schedule.
        start_idx = min(max(int(stage_range[i_s] * training_steps), 0), training_steps - 1)
        end_idx = min(int(stage_range[i_s + 1] * training_steps), training_steps)

        start_sigma = self.sigmas[start_idx].item()
        # A stage that reaches the end of the schedule ends at sigma 0.
        end_sigma = self.sigmas[end_idx].item() if end_idx < training_steps else 0.0
        self.ori_start_sigmas[i_s] = start_sigma

        # Stages after the first get a gamma-corrected start sigma.
        if i_s != 0:
            ori_sigma = 1 - start_sigma
            gamma = self.config.gamma
            corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
            start_sigma = 1 - corrected_sigma

        stage_distance.append(start_sigma - end_sigma)
        self.start_sigmas[i_s] = start_sigma
        self.end_sigmas[i_s] = end_sigma

    # 2) Ratio of each stage, proportional to its sigma distance ("flow length").
    tot_distance = sum(stage_distance)
    for i_s in range(stages):
        start_ratio = 0.0 if i_s == 0 else sum(stage_distance[:i_s]) / tot_distance
        end_ratio = 1.0 if i_s == stages - 1 else sum(stage_distance[:i_s + 1]) / tot_distance
        self.timestep_ratios[i_s] = (start_ratio, end_ratio)

    # 3) Timesteps and sigmas for each stage.
    for i_s in range(stages):
        start_ratio, end_ratio = self.timestep_ratios[i_s]
        start_pos = min(int(start_ratio * training_steps), training_steps - 1)
        end_pos = min(int(end_ratio * training_steps), training_steps - 1)
        timestep_max = self.timesteps[start_pos].item()
        timestep_min = self.timesteps[end_pos].item()

        timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
        self.timestep_per_stage[i_s] = torch.from_numpy(timesteps[:-1])

        stage_sigmas = np.linspace(1, 0, training_steps + 1)
        self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])




## 0.3 Properties

In [1]:
@property
def step_index(self):
    """
    Current position in the timestep schedule.

    `None` until stepping starts; `step` advances it by one on every call.
    """
    current_index = self._step_index
    return current_index
    

@property
def begin_index(self):
    """
    Index of the first timestep, as configured by the pipeline through `set_begin_index`.
    """
    first_index = self._begin_index
    return first_index

In [2]:
def set_begin_index(self, begin_index: int = 0):
    """
    Record the index of the first timestep.

    Meant to be called by the pipeline before inference starts.

    Args:
        begin_index (`int`, defaults to 0):
            Index to store as the scheduler's begin index.
    """
    self._begin_index = begin_index


def sigma_to_t(self, sigma):
    """Convert a sigma to a timestep by scaling it with `num_train_timesteps`."""
    scale = self.config.num_train_timesteps
    return sigma * scale

## 0.4 Set Timesteps

- This method sets the timesteps and sigma values for a specific stage, preparing them for the inference process.

In [None]:
def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
    """
    Set the inference timesteps and sigmas for one stage.

    Args:
        num_inference_steps (`int`):
            Number of inference steps for this stage.
        stage_index (`int`):
            Which stage's schedule (from `timestep_per_stage` / `sigmas_per_stage`) to use.
        device (`str` or `torch.device`, *optional*):
            Device to move the resulting `timesteps` and `sigmas` to.

    Sets `self.timesteps` (length `num_inference_steps`) and `self.sigmas` (length
    `num_inference_steps + 1`, ending in 0), and resets `self._step_index` to `None`.
    """
    self.num_inference_steps = num_inference_steps
    self.init_sigmas()  # rebuild the global schedule before overriding it for this stage

    # Per-stage schedules are stored under `timestep_per_stage` by `init_sigmas_for_each_stage`.
    stage_timesteps = self.timestep_per_stage[stage_index]
    timestep_max = stage_timesteps[0].item()
    timestep_min = stage_timesteps[-1].item()

    # Linearly spaced timesteps from the stage's first to its last timestep.
    timesteps = np.linspace(timestep_max, timestep_min, num_inference_steps)
    self.timesteps = torch.from_numpy(timesteps).to(device=device)

    stage_sigmas = self.sigmas_per_stage[stage_index]
    sigma_max = stage_sigmas[0].item()
    sigma_min = stage_sigmas[-1].item()

    # Linearly spaced sigmas from the stage's first to its last sigma.
    ratios = np.linspace(sigma_max, sigma_min, num_inference_steps)
    sigmas = torch.from_numpy(ratios).to(device=device)
    # Append a trailing 0 so `step` can always read `sigmas[step_index + 1]`.
    self.sigmas = torch.cat([sigmas, torch.zeros(1, dtype=sigmas.dtype, device=sigmas.device)])

    self._step_index = None

## 0.5 Index for Timesteps

In [None]:
def index_for_timestep(self, timestep, schedule_timestep=None):
    """
    Return the position of `timestep` in a timestep schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timestep (*optional*): Schedule to search; defaults to `self.timesteps`.

    Returns:
        `int`: The second matching index when several entries match, otherwise the only match.
    """
    candidates = self.timesteps if schedule_timestep is None else schedule_timestep
    matches = (candidates == timestep).nonzero()

    # For the very first `step`, take the second match (or the only one if there is just one).
    # This way a sigma is not accidentally skipped when starting in the middle of the
    # denoising schedule (e.g. for image-to-image).
    chosen = 0 if len(matches) <= 1 else 1
    return matches[chosen].item()

In [3]:
def _init_step_index(self, timestep):

    if self.begin_index is None:
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        self._step_index = self.index_for_timestep(timestep)

    else:
        self._step_index = self._begin_index 

In [None]:
def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
    """
    Advance `sample` by one Euler step along the current sigma schedule.

    The update is `prev_sample = sample + (sigma_next - sigma) * model_output`, with
    `sigma = self.sigmas[step_index]` and `sigma_next = self.sigmas[step_index + 1]`.
    The step index starts at 0 when unset and is incremented after every call.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned model.
        timestep (`float` or `torch.FloatTensor`):
            The current timestep, taken from `scheduler.timesteps`. It is only validated here
            (integer values are rejected); the position in the schedule comes from the step index.
        sample (`torch.FloatTensor`):
            The current sample.
        generator (`torch.Generator`, *optional*):
            Not used by this scheduler.
        return_dict (`bool`):
            Whether to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or a tuple.

    Returns:
        [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
            If `return_dict` is `True`, a [`FlowMatchEulerDiscreteSchedulerOutput`] is returned,
            otherwise a tuple whose first element is the sample tensor.

    Raises:
        ValueError: If `timestep` is a Python `int` or an integer-typed tensor.
    """
    # Legacy type checks such as `torch.LongTensor` only match CPU int64/int32 tensors,
    # so inspect the dtype instead to catch integer timesteps on any device.
    is_integer_tensor = isinstance(timestep, torch.Tensor) and not (
        timestep.is_floating_point() or timestep.is_complex()
    )
    if isinstance(timestep, int) or is_integer_tensor:
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `PyramidFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )
        )

    if self.step_index is None:
        self._step_index = 0

    # Upcast to avoid precision issues when computing prev_sample.
    sample = sample.to(torch.float32)

    sigma = self.sigmas[self.step_index]
    sigma_next = self.sigmas[self.step_index + 1]

    prev_sample = sample + (sigma_next - sigma) * model_output

    # Cast the result back to the model output's dtype.
    prev_sample = prev_sample.to(model_output.dtype)

    # Advance to the next position in the schedule.
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)



In [4]:
def __len__(self):
    return self.config.num_train_timesteps