In [1]:
import torch
from pathlib import Path
import os
import sys

In [2]:
# Resolve the neuralop datasets directory to an absolute path; the relative
# path is interpreted against the notebook's current working directory.
datasets_path=os.path.abspath('../../neuraloperator/neuralop/datasets')
# Last expression of the cell: show the directory contents as the cell output.
os.listdir(datasets_path)

['mesh_datamodule.py',
 'hdf5_dataset.py',
 '__pycache__',
 'zarr_dataset.py',
 'data',
 'dict_dataset.py',
 'burgers.py',
 'navier_stokes.py',
 'data_transforms.py',
 'pt_dataset.py',
 'tensor_dataset.py',
 'output_encoder.py',
 'tests',
 'darcy.py',
 'transforms.py',
 '__init__.py',
 'spherical_swe.py']

In [3]:
# Put the datasets directory first on the module search path so the files
# listed above can be imported by their bare module names.
sys.path.insert(0,datasets_path)

In [4]:
from abc import abstractmethod
from typing import List

import torch
from torch.utils.data import Dataset
from neuralop.training.patching import MultigridPatching2D

class Transform(torch.nn.Module):
    """
    Common interface for data transforms that can be applied
    to model inputs and undone on model outputs.

    The hooks below are tagged with ``abstractmethod``, but the class
    does not use ``ABCMeta`` as its metaclass, so the tags are not enforced:
    the class can be instantiated and each default hook returns ``None``.
    Subclasses are expected to override the hooks they need.
    """
    def __init__(self):
        super().__init__()

    @abstractmethod
    def transform(self):
        """Apply the forward transform (default: no-op returning None)."""
        return None

    @abstractmethod
    def inverse_transform(self):
        """Undo the forward transform (default: no-op returning None)."""
        return None

    @abstractmethod
    def cuda(self):
        """Move transform state to a CUDA device (default: no-op returning None)."""
        return None

    @abstractmethod
    def cpu(self):
        """Move transform state to the CPU (default: no-op returning None)."""
        return None

    @abstractmethod
    def to(self, device):
        """Move transform state to ``device`` (default: no-op returning None)."""
        return None
    
    
class Normalizer():
    """Callable that standardises data with a fixed mean and std.

    Parameters
    ----------
    mean : value subtracted from the data
    std : value the centred data is divided by (after adding ``eps``)
    eps : float, default is 1e-6
        added to ``std`` to keep the division safe
    """
    def __init__(self, mean, std, eps=1e-6):
        self.mean, self.std, self.eps = mean, std, eps

    def __call__(self, data):
        # (data - mean) / (std + eps), spelled out in two steps
        centred = data - self.mean
        scale = self.std + self.eps
        return centred / scale

class Composite(Transform):
    def __init__(self, transforms: List[Transform]):
        """Composite transform composes a list of
        Transforms into one Transform object.

        Transformations are not assumed to be commutative

        Parameters
        ----------
        transforms : List[Transform]
            list of transforms to be applied to data
            in order
        """
        # super() must be *called* to obtain the bound proxy; the bare
        # ``super.__init__()`` raises TypeError and nn.Module is never set up.
        super().__init__()
        self.transforms = transforms
    
    def transform(self, data_dict):
        """Apply every transform to ``data_dict`` in list order.

        Parameters
        ----------
        data_dict : object passed to, and returned by, each ``.transform()``

        Returns
        -------
        the output of the last transform (``data_dict`` itself if the list is empty)
        """
        for tform in self.transforms:
            data_dict = tform.transform(data_dict)
        return data_dict
    
    def inverse_transform(self, data_dict):
        """Undo the composite by applying each ``.inverse_transform()``
        in reverse list order.

        Parameters
        ----------
        data_dict : object passed to, and returned by, each ``.inverse_transform()``

        Returns
        -------
        the output of the first transform's inverse (``data_dict`` itself if the list is empty)
        """
        for tform in self.transforms[::-1]:
            data_dict = tform.inverse_transform(data_dict)
        return data_dict

    def to(self, device):
        """Move every transform that exposes ``.to()`` to ``device``.

        Transforms without a ``.to()`` method are kept unchanged rather than
        being dropped from the composite.

        Returns
        -------
        Composite
            self, to allow chaining
        """
        moved = []
        for tform in self.transforms:
            if hasattr(tform, 'to'):
                result = tform.to(device)
                # a ``.to()`` that returns nothing (e.g. the base Transform
                # default) must not replace the transform with None
                if result is not None:
                    tform = result
            moved.append(tform)
        self.transforms = moved
        return self


class MGPatchingTransform(Transform):
    def __init__(self, model: torch.nn.Module, levels: int, 
                 padding_fraction: float, stitching: float):
        """Wraps MultigridPatching2D to expose canonical
        transform .transform() and .inverse_transform() API

        Parameters
        ----------
        model: nn.Module
            model to wrap in MultigridPatching2D
        levels : int
            mg_patching level parameter for MultigridPatching2D
        padding_fraction : float
            mg_padding_fraction parameter for MultigridPatching2D
        stitching : float
            mg_patching_stitching parameter for MultigridPatching2D
        """
        # super() must be *called*; the bare ``super.__init__()`` raises
        # TypeError, which made this class impossible to construct.
        super().__init__()

        self.levels = levels
        self.padding_fraction = padding_fraction
        self.stitching = stitching
        self.patcher = MultigridPatching2D(model=model, levels=self.levels, 
                                      padding_fraction=self.padding_fraction,
                                      stitching=self.stitching)

    def transform(self, data_dict):
        """Patch the ``'x'`` and ``'y'`` entries of ``data_dict``.

        Both entries are passed together to ``self.patcher.patch`` and the
        results are written back into the same dict, which is returned.
        """
        x = data_dict['x']
        y = data_dict['y']

        x,y = self.patcher.patch(x,y)

        data_dict['x'] = x
        data_dict['y'] = y
        return data_dict
    
    def inverse_transform(self, data_dict):
        """Unpatch the ``'x'`` and ``'y'`` entries of ``data_dict``.

        Both entries are passed together to ``self.patcher.unpatch`` and the
        results are written back into the same dict, which is returned.
        """
        x = data_dict['x']
        y = data_dict['y']

        x,y = self.patcher.unpatch(x,y)

        data_dict['x'] = x
        data_dict['y'] = y
        return data_dict
    
    def to(self, _):
        # nothing to pass to device
        return self


class RandomMGPatch():
    """Random multi-grid patching of an ``(x, y)`` sample.

    A random position is drawn, both tensors are circularly shifted so that
    this position lands on the image centre, and multi-grid patches are then
    extracted around the centre.

    Parameters
    ----------
    levels : int, default is 2
        number of multi-grid levels; the coarsest level is subsampled
        with a stride of ``2**levels``
    """
    def __init__(self, levels=2):
        self.levels = levels
        self.step = 2**levels

    @staticmethod
    def _get_patches(shifted_image, step, height, width):
        """Take as input an image and return multi-grid patches centered around the middle of the image
        """
        if step == 1:
            return (shifted_image, )

        # Notice that we need to start cropping at start_h = (height - patch_size)//2
        # (//2 as we pad both sides)
        # Here, the extracted patch-size is half the size so patch-size = height//2
        # Hence the values height//4 and width // 4
        start_h = height//4
        start_w = width//4

        patches = RandomMGPatch._get_patches(
            shifted_image[:, start_h:-start_h, start_w:-start_w],
            step//2, height//2, width//2)

        return (shifted_image[:, ::step, ::step], *patches)

    def __call__(self, data):
        """Patch one sample.

        Parameters
        ----------
        data : tuple
            ``(x, y)``; ``x`` is unpacked as ``(channels, height, width)``
            and ``y`` is sliced the same way

        Returns
        -------
        tuple
            all patch levels of ``x`` concatenated along dim 0, and the
            finest (last) patch of ``y``
        """
        x, y = data
        _, height, width = x.shape
        center_h = height//2
        center_w = width//2

        # Sample a random patching position
        pos_h = int(torch.randint(low=0, high=height, size=(1,))[0])
        pos_w = int(torch.randint(low=0, high=width, size=(1,))[0])

        shifts = (center_h - pos_h, center_w - pos_w)

        # The shifts are spatial offsets, so they must be applied to the
        # height and width axes (the last two); rolling dims (0, 1) would
        # permute the channels and shift the height by the width offset.
        shifted_x = torch.roll(x, shifts, dims=(-2, -1))
        patches_x = self._get_patches(shifted_x, self.step, height, width)
        shifted_y = torch.roll(y, shifts, dims=(-2, -1))
        patches_y = self._get_patches(shifted_y, self.step, height, width)

        return torch.cat(patches_x, dim=0), patches_y[-1]


class MGPTensorDataset(Dataset):
    """Dataset over paired tensors that applies random multi-grid patching
    to every sample it returns.

    Parameters
    ----------
    x : torch.Tensor
        inputs, indexed along dim 0
    y : torch.Tensor
        targets, must have the same size as ``x`` along dim 0
    levels : int, default is 2
        number of multi-grid levels forwarded to ``RandomMGPatch``
    """
    def __init__(self, x, y, levels=2):
        assert (x.size(0) == y.size(0)), "Size mismatch between tensors"
        self.x = x
        self.y = y
        # keep the attribute in sync with the transform (it used to be
        # hard-coded to 2 regardless of the argument)
        self.levels = levels
        self.transform = RandomMGPatch(levels=levels)

    def __getitem__(self, index):
        """Return the patched ``(x, y)`` pair for sample ``index``."""
        return self.transform((self.x[index], self.y[index]))

    def __len__(self):
        """Number of samples (size of ``x`` along dim 0)."""
        return self.x.size(0)
    
    

def regular_grid(spatial_dims, grid_boundaries=[[0, 1], [0, 1]]):
    """
    Builds the two coordinate grids of a regular 2D mesh.

    Each axis is sampled with ``n`` evenly spaced points starting at the lower
    boundary and stopping one step short of the upper boundary.

    Parameters
    ----------
    spatial_dims : pair ``(height, width)``
        number of grid points along each axis
    grid_boundaries : list, optional
        ``[[x_min, x_max], [y_min, y_max]]``, by default [[0, 1], [0, 1]]

    Returns
    -------
    tuple of torch.Tensor
        ``(grid_x, grid_y)``, each of shape ``(height, width)`` in 'ij' indexing
    """
    def _axis(bounds, n_points):
        # n_points + 1 samples on [lo, hi], then drop the endpoint
        return torch.linspace(bounds[0], bounds[1], n_points + 1)[:-1]

    height, width = spatial_dims
    mesh_x, mesh_y = torch.meshgrid(
        _axis(grid_boundaries[0], height),
        _axis(grid_boundaries[1], width),
        indexing='ij',
    )

    # repeat(1, 1) leaves the values untouched and returns fresh tensors
    return mesh_x.repeat(1, 1), mesh_y.repeat(1, 1)


class PositionalEmbedding2D():
    """A simple positional embedding as a regular 2D grid
    """
    def __init__(self, grid_boundaries=[[0, 1], [0, 1]]):
        """PositionalEmbedding2D applies a simple positional 
        embedding as a regular 2D grid

        Parameters
        ----------
        grid_boundaries : list, optional
            coordinate boundaries of input grid, by default [[0, 1], [0, 1]]
        """
        self.grid_boundaries = grid_boundaries
        # most recently built grids and the (resolution, device, dtype)
        # they were built for
        self._grid = None
        self._res = None
        self._device = None
        self._dtype = None

    def grid(self, spatial_dims, device, dtype):
        """grid generates 2D grid needed for pos encoding
        and caches the grid associated with the most recently
        used resolution, device and dtype

        Parameters
        ----------
        spatial_dims : torch.size
             sizes of spatial resolution
        device : torch.device or str
            where to load data
        dtype : torch.dtype
            dtype to encode data

        Returns
        -------
        tuple of torch.tensor
            ``(grid_x, grid_y)``, each of shape ``(1, 1, height, width)``
        """
        device = torch.device(device)
        # handle case of multiple train resolutions; the cache must also be
        # refreshed when the device or dtype changes, otherwise a stale grid
        # would be concatenated with the incoming data
        stale = (
            self._grid is None
            or self._res != spatial_dims
            or self._device != device
            or self._dtype != dtype
        )
        if stale:
            grid_x, grid_y = regular_grid(spatial_dims,
                                      grid_boundaries=self.grid_boundaries)
            grid_x = grid_x.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            grid_y = grid_y.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            self._grid = grid_x, grid_y
            self._res = spatial_dims
            self._device = device
            self._dtype = dtype

        return self._grid

    def __call__(self, data, batched=True):
        """Append the two grid channels to ``data`` along dim 1.

        Parameters
        ----------
        data : torch.tensor
            4D ``(batch, channels, height, width)`` input; when ``batched`` is
            False a 3D input is given a temporary leading batch dim
        batched : bool, default is True
            whether ``data`` already carries a batch dimension

        Returns
        -------
        torch.tensor
            ``data`` with two extra channels (x then y coordinates)
        """
        if not batched:
            if data.ndim == 3:
                data = data.unsqueeze(0)
        batch_size = data.shape[0]
        x, y = self.grid(data.shape[-2:], data.device, data.dtype)
        out =  torch.cat((data, x.expand(batch_size, -1, -1, -1),
                          y.expand(batch_size, -1, -1, -1)),
                         dim=1)
        # in the unbatched case, the dataloader will stack N 
        # examples with no batch dim to create one
        if not batched and batch_size == 1: 
            return out.squeeze(0)
        else:
            return out

In [29]:
from math import prod

def count_model_params(model):
    """Counts every parameter element held by a PyTorch model

    Parameters
    ----------
    model : object exposing ``.parameters()`` (e.g. ``torch.nn.Module``)

    Notes
    -----
    A complex-valued element contributes two to the count
    (its real part and its imaginary part).
    """
    total = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += 2 * n_elements if param.is_complex() else n_elements
    return total


def count_tensor_params(tensor, dims=None):
    """Counts the elements of one tensor, over all dimensions or a chosen subset

    Parameters
    ----------
    tensor : torch.tensor
    dims : int list or None, default is None
        when given, only the sizes of these dimensions are multiplied together;
        otherwise every dimension is used

    Returns
    -------
    int
        product of the selected dimension sizes, doubled for complex tensors

    Notes
    -----
    A complex-valued element contributes two to the count
    (its real part and its imaginary part).
    """
    shape = tensor.shape
    sizes = list(shape) if dims is None else [shape[d] for d in dims]
    n_params = prod(sizes)
    return 2 * n_params if tensor.is_complex() else n_params

In [30]:

class UnitGaussianNormalizer(Transform):
    """
    UnitGaussianNormalizer normalizes data to be zero mean and unit std.
    """

    def __init__(self, mean=None, std=None, eps=1e-7, dim=None, mask=None):
        """
        mean : torch.tensor or None
            has to include batch-size as a dim of 1
            e.g. for tensors of shape ``(batch_size, channels, height, width)``,
            the mean over height and width should have shape ``(1, channels, 1, 1)``
        std : torch.tensor or None
        eps : float, default is 1e-7
            for safe division by the std
        dim : int list, default is None
            if not None, dimensions of the data to reduce over to compute the mean and std.

            .. important::

                Has to include the batch-size (typically 0).
                For instance, to normalize data of shape ``(batch_size, channels, height, width)``
                along batch-size, height and width, pass ``dim=[0, 2, 3]``

        mask : torch.Tensor or None, default is None
            If not None, a tensor with the same size as a sample,
            with value 0 where the data should be ignored and 1 everywhere else

        Notes
        -----
        Statistics are reduced with ``keepdim=True``, so the fitted mean and std
        keep the rank of the input with the reduced dims collapsed to size 1.

        Returns
        -------
        UnitGaussianNormalizer instance
        """
        super().__init__()

        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.register_buffer("mask", mask)

        self.eps = eps
        if mean is not None:
            self.ndim = mean.ndim
        if isinstance(dim, int):
            dim = [dim]
        self.dim = dim
        # running count of elements seen so far; 0 means "not fitted yet"
        self.n_elements = 0

    def fit(self, data_batch):
        """Compute mean and std from ``data_batch`` in one shot,
        replacing any previously fitted statistics."""
        self.update_mean_std(data_batch)

    def partial_fit(self, data_batch, batch_size=1):
        """Update the running statistics with ``data_batch``,
        consumed in chunks of ``batch_size`` along dim 0.

        A batch with any zero-sized dimension is ignored.
        """
        if 0 in list(data_batch.shape):
            return
        count = 0
        n_samples = len(data_batch)
        while count < n_samples:
            samples = data_batch[count : count + batch_size]
            # the first chunk initialises the statistics, later ones refine them
            if self.n_elements:
                self.incremental_update_mean_std(samples)
            else:
                self.update_mean_std(samples)
            count += batch_size

    def update_mean_std(self, data_batch):
        """(Re)initialise mean, squared mean, std and the element count
        from ``data_batch`` alone."""
        self.ndim = data_batch.ndim  # Note this includes batch-size
        if self.mask is None:
            self.n_elements = count_tensor_params(data_batch, self.dim)
            self.mean = torch.mean(data_batch, dim=self.dim, keepdim=True)
            self.squared_mean = torch.mean(data_batch**2, dim=self.dim, keepdim=True)
            self.std = torch.std(data_batch, dim=self.dim, keepdim=True)
        else:
            batch_size = data_batch.shape[0]
            # mask has no batch dim: drop dim 0 and shift the others down by one
            dim = [i - 1 for i in self.dim if i]
            shape = [s for i, s in enumerate(self.mask.shape) if i not in dim]
            self.n_elements = torch.count_nonzero(self.mask, dim=dim) * batch_size
            self.mean = torch.zeros(shape)
            self.std = torch.zeros(shape)
            self.squared_mean = torch.zeros(shape)
            # work on a copy so the caller's tensor (often a view into the
            # dataset) is not overwritten by the masking below
            data_batch = data_batch.clone()
            # NOTE(review): this zeroes entries where mask == 1, while the
            # constructor docs describe 0 as the ignored value; confirm the
            # intended polarity before relying on masked statistics.
            data_batch[:, self.mask == 1] = 0
            self.mean[self.mask == 1] = (
                torch.sum(data_batch, dim=dim, keepdim=True) / self.n_elements
            )
            self.squared_mean = (
                torch.sum(data_batch**2, dim=dim, keepdim=True) / self.n_elements
            )
            # NOTE(review): std is taken over every entry of the (zeroed)
            # batch, without excluding masked positions; verify.
            self.std = torch.std(data_batch, dim=self.dim, keepdim=True)

    def incremental_update_mean_std(self, data_batch):
        """Fold ``data_batch`` into the running mean, squared mean and std."""
        if self.mask is None:
            n_elements = count_tensor_params(data_batch, self.dim)
            dim = self.dim
        else:
            dim = [i - 1 for i in self.dim if i]
            n_elements = torch.count_nonzero(self.mask, dim=dim) * data_batch.shape[0]
            # copy first: the masking must not leak into the caller's tensor
            data_batch = data_batch.clone()
            # NOTE(review): same mask polarity question as in update_mean_std
            data_batch[:, self.mask == 1] = 0

        total = self.n_elements + n_elements
        self.mean = (1.0 / total) * (
            self.n_elements * self.mean + torch.sum(data_batch, dim=dim, keepdim=True)
        )
        self.squared_mean = (1.0 / total) * (
            self.n_elements * self.squared_mean
            + torch.sum(data_batch**2, dim=dim, keepdim=True)
        )
        self.n_elements = total

        # biased variance is E[x^2] - E[x]^2; multiplying by n / (n - 1) gives
        # the unbiased variance, and the std is the square root of *that*
        # (the correction belongs inside the square root).
        variance = (self.squared_mean - self.mean**2) * self.n_elements / (self.n_elements - 1)
        # rounding can push a near-zero variance slightly negative
        self.std = torch.sqrt(torch.clamp(variance, min=0))

    def transform(self, x):
        """Standardise ``x``: ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def inverse_transform(self, x):
        """Undo :meth:`transform`: ``x * (std + eps) + mean``."""
        return x * (self.std + self.eps) + self.mean

    def forward(self, x):
        """Same as :meth:`transform`."""
        return self.transform(x)

    def _move_stats(self, move):
        """Apply ``move`` to every fitted statistic; unfitted (None or
        missing) statistics are skipped so an unfitted normalizer can
        still be moved without error."""
        for name in ("mean", "std", "squared_mean"):
            value = getattr(self, name, None)
            if value is not None:
                setattr(self, name, move(value))
        return self

    def cuda(self):
        """Move the fitted statistics to the default CUDA device."""
        return self._move_stats(lambda t: t.cuda())

    def cpu(self):
        """Move the fitted statistics to the CPU."""
        return self._move_stats(lambda t: t.cpu())

    def to(self, device):
        """Move the fitted statistics to ``device``."""
        return self._move_stats(lambda t: t.to(device))

    @classmethod
    def from_dataset(cls, dataset, dim=None, keys=None, mask=None):
        """Return a dictionary of normalizer instances, fitted on the given dataset

        Parameters
        ----------
        dataset : pytorch dataset
            each element must be a dict {key: sample}
            e.g. {'x': input_samples, 'y': target_labels}
        dim : int list, default is None
            * If None, reduce over all dims (scalar mean and std)
            * Otherwise, must include batch-dimensions and all over dims to reduce over
        keys : str list or None
            if not None, a normalizer is instantiated only for the given keys;
            otherwise the keys of the first dataset element are used
        mask : torch.Tensor or None, default is None
            forwarded to every normalizer instance

        Returns
        -------
        dict
            ``{key: fitted normalizer}``; empty if no keys were given and
            the dataset has no elements to take them from
        """
        if not keys:
            # only the first element is needed to discover the keys
            first = next(iter(dataset), None)
            if first is None:
                return {}
            keys = list(first.keys())
        instances = {key: cls(dim=dim, mask=mask) for key in keys}
        for data_dict in dataset:
            for key, sample in data_dict.items():
                if key in keys:
                    # samples carry no batch dim; add one for partial_fit
                    instances[key].partial_fit(sample.unsqueeze(0))
        return instances

In [31]:
from torch.utils.data.dataset import Dataset


class TensorDataset(Dataset):
    """Dataset over paired tensors that yields ``{'x': ..., 'y': ...}`` dicts.

    Parameters
    ----------
    x : torch.Tensor
        inputs, indexed along dim 0
    y : torch.Tensor
        targets, must have the same size as ``x`` along dim 0
    transform_x : callable or None
        applied to each input sample when not None
    transform_y : callable or None
        applied to each target sample when not None
    """
    def __init__(self, x, y, transform_x=None, transform_y=None):
        assert (x.size(0) == y.size(0)), "Size mismatch between tensors"
        self.x, self.y = x, y
        self.transform_x, self.transform_y = transform_x, transform_y

    def __getitem__(self, index):
        sample_x = self.x[index]
        sample_y = self.y[index]

        # each transform is optional and applied independently
        sample_x = sample_x if self.transform_x is None else self.transform_x(sample_x)
        sample_y = sample_y if self.transform_y is None else self.transform_y(sample_y)

        return dict(x=sample_x, y=sample_y)

    def __len__(self):
        return self.x.size(0)

In [32]:

# NOTE(review): a class with this name is already defined in an earlier cell
# of this notebook; this later definition is the one in effect after Run All.
class PositionalEmbedding2D():
    """A simple positional embedding as a regular 2D grid
    """
    def __init__(self, grid_boundaries=[[0, 1], [0, 1]]):
        """PositionalEmbedding2D applies a simple positional 
        embedding as a regular 2D grid

        Parameters
        ----------
        grid_boundaries : list, optional
            coordinate boundaries of input grid, by default [[0, 1], [0, 1]]
        """
        self.grid_boundaries = grid_boundaries
        # most recently built grids and the (resolution, device, dtype)
        # they were built for
        self._grid = None
        self._res = None
        self._device = None
        self._dtype = None

    def grid(self, spatial_dims, device, dtype):
        """grid generates 2D grid needed for pos encoding
        and caches the grid associated with the most recently
        used resolution, device and dtype

        Parameters
        ----------
        spatial_dims : torch.size
             sizes of spatial resolution
        device : torch.device or str
            where to load data
        dtype : torch.dtype
            dtype to encode data

        Returns
        -------
        tuple of torch.tensor
            ``(grid_x, grid_y)``, each of shape ``(1, 1, height, width)``
        """
        device = torch.device(device)
        # handle case of multiple train resolutions; the cache must also be
        # refreshed when the device or dtype changes, otherwise a stale grid
        # would be concatenated with the incoming data
        stale = (
            self._grid is None
            or self._res != spatial_dims
            or self._device != device
            or self._dtype != dtype
        )
        if stale:
            grid_x, grid_y = regular_grid(spatial_dims,
                                      grid_boundaries=self.grid_boundaries)
            grid_x = grid_x.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            grid_y = grid_y.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            self._grid = grid_x, grid_y
            self._res = spatial_dims
            self._device = device
            self._dtype = dtype

        return self._grid

    def __call__(self, data, batched=True):
        """Append the two grid channels to ``data`` along dim 1.

        Parameters
        ----------
        data : torch.tensor
            4D ``(batch, channels, height, width)`` input; when ``batched`` is
            False a 3D input is given a temporary leading batch dim
        batched : bool, default is True
            whether ``data`` already carries a batch dimension

        Returns
        -------
        torch.tensor
            ``data`` with two extra channels (x then y coordinates)
        """
        if not batched:
            if data.ndim == 3:
                data = data.unsqueeze(0)
        batch_size = data.shape[0]
        x, y = self.grid(data.shape[-2:], data.device, data.dtype)
        out =  torch.cat((data, x.expand(batch_size, -1, -1, -1),
                          y.expand(batch_size, -1, -1, -1)),
                         dim=1)
        # in the unbatched case, the dataloader will stack N 
        # examples with no batch dim to create one
        if not batched and batch_size == 1: 
            return out.squeeze(0)
        else:
            return out

In [33]:
from abc import ABCMeta, abstractmethod

import torch
from neuralop.training.patching import MultigridPatching2D


class DataProcessor(torch.nn.Module, metaclass=ABCMeta):
    def __init__(self):
        """Abstract base for objects that pre- and post-process
        data around a model during training or inference.

        A concrete DataProcessor usable by the Trainer has to
        implement all of the following:

        - to(device): move whatever state is needed onto the device,
            following the usual PyTorch convention
        - preprocess(data): prepare a freshly loaded batch before it
            goes through the model's forward pass
        - postprocess(out): treat the model's outputs before the loss
            and backward pass
        - wrap(self, model):
            bundle a model with the preprocess and postprocess steps
            so that they form a single forward pass
        - forward(self, x):
            the forward pass, assuming a model has been wrapped
        """
        super().__init__()

    @abstractmethod
    def to(self, device):
        """Move processor state to ``device``."""
        return None

    @abstractmethod
    def preprocess(self, x):
        """Prepare a batch for the model."""
        return None

    @abstractmethod
    def postprocess(self, x):
        """Treat the model output."""
        return None

    @abstractmethod
    def wrap(self, model):
        """Attach ``model`` to this processor."""
        return None

    @abstractmethod
    def forward(self, x):
        """Run preprocess, the wrapped model, then postprocess."""
        return None


class DefaultDataProcessor(DataProcessor):
    def __init__(
        self, in_normalizer=None, out_normalizer=None, positional_encoding=None
    ):
        """A simple processor to pre/post process data before training/inferencing a model.

        Parameters
        ----------
        in_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the input samples
        out_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the target and predicted samples
        positional_encoding : Processor, optional, default is None
            class that appends a positional encoding to the input
        """
        super().__init__()
        self.in_normalizer = in_normalizer
        self.out_normalizer = out_normalizer
        self.positional_encoding = positional_encoding
        self.device = "cpu"

    def wrap(self, model):
        """Attach ``model`` so that :meth:`forward` can call it; returns self."""
        self.model = model
        return self

    def to(self, device):
        """Move both normalizers (when set) to ``device`` and remember the
        device for later batches; returns self."""
        if self.in_normalizer is not None:
            self.in_normalizer = self.in_normalizer.to(device)
        if self.out_normalizer is not None:
            self.out_normalizer = self.out_normalizer.to(device)
        self.device = device
        return self

    def preprocess(self, data_dict, batched=True):
        """Move ``'x'`` and ``'y'`` to the device, normalize and encode them.

        ``'x'`` is always normalized and positionally encoded (when the
        respective components are set). ``'y'`` is normalized only in
        training mode, so that evaluation compares against raw targets.

        Returns
        -------
        dict
            the same ``data_dict`` with updated ``'x'`` and ``'y'``
        """
        x = data_dict["x"].to(self.device)
        y = data_dict["y"].to(self.device)

        if self.in_normalizer is not None:
            x = self.in_normalizer.transform(x)
        if self.positional_encoding is not None:
            x = self.positional_encoding(x, batched=batched)
        # ``self.training`` is the nn.Module mode flag; ``self.train`` is a
        # bound method and therefore always truthy.
        if self.out_normalizer is not None and self.training:
            y = self.out_normalizer.transform(y)

        data_dict["x"] = x
        data_dict["y"] = y

        return data_dict

    def postprocess(self, output, data_dict):
        """Map the model output back to the original target space in eval mode.

        In training mode the output is returned untouched (the loss is taken
        against the normalized target). In eval mode only ``output`` is
        inverse-transformed: ``'y'`` was not normalized by :meth:`preprocess`
        in that mode, so it must not be inverse-transformed either.

        Returns
        -------
        tuple
            ``(output, data_dict)``
        """
        if self.out_normalizer is not None and not self.training:
            output = self.out_normalizer.inverse_transform(output)
        return output, data_dict

    def forward(self, **data_dict):
        """Run preprocess -> wrapped model -> postprocess on one batch.

        Requires :meth:`wrap` to have been called first.

        Returns
        -------
        tuple
            ``(output, data_dict)``
        """
        data_dict = self.preprocess(data_dict)
        output = self.model(data_dict["x"])
        # postprocess needs the batch dict and hands back (output, data_dict)
        output, data_dict = self.postprocess(output, data_dict)
        return output, data_dict


class IncrementalDataProcessor(torch.nn.Module):
    def __init__(self, 
                 in_normalizer=None, out_normalizer=None, 
                 positional_encoding=None, device = 'cpu',
                 subsampling_rates=[2, 1], dataset_resolution=16, dataset_indices=[2,3], epoch_gap=10, verbose=False):
        """An incremental processor to pre/post process data before training/inferencing a model
        In particular this processor first regularizes the input resolution based on the sub_list and dataset_indices
        in the spatial domain based on a fixed number of epochs. We incrementally increase the resolution like done 
        in curriculum learning to train the model. This is useful for training models on large datasets with high
        resolution data.

        Parameters
        ----------
        in_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the input samples
        out_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the target and predicted samples
        positional_encoding : Processor, optional, default is None
            class that appends a positional encoding to the input
        device : str, optional, default is 'cpu'
            device 'cuda' or 'cpu' where computations are performed
        subsampling_rates : list, optional, default is [2, 1]
            list of subsampling rates to use
        dataset_resolution : int, optional, default is 16
            resolution of the input data
        dataset_indices : list, optional, default is [2, 3]
            list of indices of the dataset to slice to regularize the input resolution - Spatial Dimensions
        epoch_gap : int, optional, default is 10
            number of epochs to wait before increasing the resolution
        verbose : bool, optional, default is False
            if True, print the current resolution
        """
        super().__init__()
        self.in_normalizer = in_normalizer
        self.out_normalizer = out_normalizer
        self.positional_encoding = positional_encoding
        self.device = device
        self.sub_list = subsampling_rates
        self.dataset_resolution = dataset_resolution
        self.dataset_indices = dataset_indices
        self.epoch_gap = epoch_gap
        self.verbose = verbose
        # ``mode`` gates the subsampling in preprocess; it is independent of
        # the nn.Module train/eval flag and is only changed by outside code
        self.mode = "Train"
        # current epoch, read by preprocess; expected to be set from outside
        self.epoch = 0
        
        self.current_index = 0
        self.current_logged_epoch = 0
        self.current_sub = self.index_to_sub_from_table(self.current_index)
        self.current_res = int(self.dataset_resolution / self.current_sub)   
        
        print(f'Original Incre Res: change index to {self.current_index}')
        print(f'Original Incre Res: change sub to {self.current_sub}')
        print(f'Original Incre Res: change res to {self.current_res}')
            
    def wrap(self, model):
        """Attach ``model`` so that :meth:`forward` can call it; returns self."""
        self.model = model
        return self

    def to(self, device):
        """Move both normalizers (when set) to ``device`` and remember the
        device for later batches; returns self."""
        if self.in_normalizer is not None:
            self.in_normalizer = self.in_normalizer.to(device)
        if self.out_normalizer is not None:
            self.out_normalizer = self.out_normalizer.to(device)
        self.device = device
        return self
    
    def epoch_wise_res_increase(self, epoch):
        """Advance to the next subsampling rate every ``epoch_gap`` epochs.

        The update happens at most once per epoch (tracked through
        ``current_logged_epoch``) and never at epoch 0.
        """
        # Update the current_sub and current_res values based on the epoch
        if epoch % self.epoch_gap == 0 and epoch != 0 and (
                self.current_logged_epoch != epoch):
            self.current_index += 1
            self.current_sub = self.index_to_sub_from_table(self.current_index)
            self.current_res = int(self.dataset_resolution / self.current_sub)
            self.current_logged_epoch = epoch

            if self.verbose:
                print(f'Incre Res Update: change index to {self.current_index}')
                print(f'Incre Res Update: change sub to {self.current_sub}')
                print(f'Incre Res Update: change res to {self.current_res}')

    def index_to_sub_from_table(self, index):
        """Look up the subsampling rate for ``index``; indices past the end
        of the table keep returning the last rate."""
        if index >= len(self.sub_list):
            return self.sub_list[-1]
        else:
            return self.sub_list[index]

    def regularize_input_res(self, x, y):
        """Subsample ``x`` and ``y`` along every dim in ``dataset_indices``,
        keeping one entry out of every ``current_sub``."""
        for idx in self.dataset_indices:
            indexes = torch.arange(0, x.size(idx), self.current_sub, device=self.device)
            x = x.index_select(dim=idx, index=indexes)
            y = y.index_select(dim=idx, index=indexes)
        return x, y
    
    def step(self, loss=None, epoch=None, x=None, y=None):
        """Update the resolution schedule for ``epoch`` and return the
        subsampled ``(x, y)``; returns None when either tensor is missing."""
        if x is not None and y is not None:
            self.epoch_wise_res_increase(epoch)
            return self.regularize_input_res(x, y)
        
    def preprocess(self, data_dict, batched=True):
        """Move ``'x'`` and ``'y'`` to the device, normalize, encode and
        (while ``mode == "Train"``) subsample them.

        ``'y'`` is normalized only when the module is in training mode, so
        that evaluation compares against raw targets.

        Returns
        -------
        dict
            the same ``data_dict`` with updated ``'x'`` and ``'y'``
        """
        x = data_dict['x'].to(self.device)
        y = data_dict['y'].to(self.device)

        if self.in_normalizer is not None:
            x = self.in_normalizer.transform(x)
        if self.positional_encoding is not None:
            x = self.positional_encoding(x, batched=batched)
        # ``self.training`` is the nn.Module mode flag; ``self.train`` is a
        # bound method and therefore always truthy.
        if self.out_normalizer is not None and self.training:
            y = self.out_normalizer.transform(y)
        
        if self.mode == "Train":
            x, y = self.step(epoch=self.epoch, x=x, y=y)
        
        data_dict['x'] = x
        data_dict['y'] = y

        return data_dict 

    def postprocess(self, output, data_dict):
        """Map the model output back to the original target space in eval mode.

        In training mode the output is returned untouched. In eval mode only
        ``output`` is inverse-transformed: ``'y'`` was not normalized by
        :meth:`preprocess` in that mode, so it must not be inverse-transformed.

        Returns
        -------
        tuple
            ``(output, data_dict)``
        """
        if self.out_normalizer is not None and not self.training:
            output = self.out_normalizer.inverse_transform(output)
        return output, data_dict
    
    def forward(self, **data_dict):
        """Run preprocess -> wrapped model -> postprocess on one batch.

        Requires :meth:`wrap` to have been called first.

        Returns
        -------
        tuple
            ``(output, data_dict)``
        """
        data_dict = self.preprocess(data_dict)
        output = self.model(data_dict['x'])
        # postprocess needs the batch dict and hands back (output, data_dict)
        output, data_dict = self.postprocess(output, data_dict)
        return output, data_dict
    
    
class MGPatchingDataProcessor(DataProcessor):
    def __init__(
        self,
        model: torch.nn.Module,
        levels: int,
        padding_fraction: float,
        stitching: float,
        device: str = "cpu",
        in_normalizer=None,
        out_normalizer=None,
        positional_encoding=None,
    ):
        """MGPatchingDataProcessor
        Applies multigrid patching to inputs out-of-place
        with an optional output encoder/other data transform

        Parameters
        ----------
        model: nn.Module
            model to wrap in MultigridPatching2D
        levels : int
            mg_patching level parameter for MultigridPatching2D
        padding_fraction : float
            mg_padding_fraction parameter for MultigridPatching2D
        stitching : float
            mg_patching_stitching parameter for MultigridPatching2D
        device : str, optional
            device 'cuda' or 'cpu' where computations are performed
        in_normalizer : neuralop.datasets.transforms.Transform, optional
            transform applied to model inputs in ``preprocess``, by default None
        out_normalizer : neuralop.datasets.transforms.Transform, optional
            transform applied to targets in ``preprocess`` and inverted on
            targets and model outputs in ``postprocess``, by default None
        positional_encoding : neuralop.datasets.transforms.PositionalEmbedding2D, optional
            appends pos encoding to x if used
        """
        super().__init__()
        self.levels = levels
        self.padding_fraction = padding_fraction
        self.stitching = stitching
        self.patcher = MultigridPatching2D(
            model=model,
            levels=self.levels,
            padding_fraction=self.padding_fraction,
            stitching=self.stitching,
        )
        self.device = device

        # set normalizers to none by default
        self.in_normalizer, self.out_normalizer = None, None
        if in_normalizer is not None:
            self.in_normalizer = in_normalizer.to(self.device)
        if out_normalizer is not None:
            self.out_normalizer = out_normalizer.to(self.device)
        self.positional_encoding = positional_encoding
        # The model used by forward() is attached later through wrap().
        self.model = None

    def to(self, device):
        """Record ``device`` and move the normalizers (if any) onto it.

        Returns ``self`` so that ``processor = processor.to(device)`` keeps
        working (previously this returned ``None``).
        """
        self.device = device
        if self.in_normalizer is not None:
            self.in_normalizer = self.in_normalizer.to(self.device)
        if self.out_normalizer is not None:
            self.out_normalizer = self.out_normalizer.to(self.device)
        return self

    def wrap(self, model):
        """Attach the model called by ``forward`` and return ``self``."""
        self.model = model
        return self

    def preprocess(self, data_dict, batched=True):
        """
        Move a batch to ``self.device``, normalize it, add the positional
        encoding and split it into multigrid patches.

        Only tensor-valued entries of ``data_dict`` are kept; a new dict is
        returned and the caller's dict is not modified.

        Params
        ------

        data_dict: dict
            dictionary keyed with 'x', 'y' etc
            represents one batch of data input to a model
        batched: bool
            whether the first dimension of 'x', 'y' represents batching
        """
        data_dict = {
            k: v.to(self.device) for k, v in data_dict.items() if torch.is_tensor(v)
        }
        x, y = data_dict["x"], data_dict["y"]
        # Each normalizer is optional on its own: y used to be normalized under
        # the in_normalizer check, which crashed when only in_normalizer was
        # set and skipped y when only out_normalizer was set (even though
        # postprocess always inverts with out_normalizer).
        if self.in_normalizer is not None:
            x = self.in_normalizer.transform(x)
        if self.out_normalizer is not None:
            y = self.out_normalizer.transform(y)
        if self.positional_encoding is not None:
            x = self.positional_encoding(x, batched=batched)
        data_dict["x"], data_dict["y"] = self.patcher.patch(x, y)
        return data_dict

    def postprocess(self, out, data_dict):
        """
        Postprocess model outputs: unpatch predictions and targets, then
        invert the output normalization if an output normalizer exists.

        Params
        ------

        data_dict: dict
            dictionary keyed with 'x', 'y' etc
            represents one batch of data input to a model
        out: torch.Tensor
            model output predictions

        Returns
        -------
        tuple
            ``(out, data_dict)`` with ``data_dict['y']`` replaced by the
            unpatched (and, if applicable, de-normalized) targets
        """
        y = data_dict["y"]
        out, y = self.patcher.unpatch(out, y)

        if self.out_normalizer is not None:
            y = self.out_normalizer.inverse_transform(y)
            out = self.out_normalizer.inverse_transform(out)

        data_dict["y"] = y

        return out, data_dict

    def forward(self, **data_dict):
        """Run preprocess -> wrapped model -> postprocess on one batch.

        The model attached through ``wrap`` receives the preprocessed batch as
        keyword arguments. Returns ``(output, data_dict)``.
        """
        data_dict = self.preprocess(data_dict)
        output = self.model(**data_dict)
        output, data_dict = self.postprocess(output, data_dict)
        return output, data_dict


In [34]:
# Absolute location of the bundled dataset files, resolved relative to the
# notebook's working directory.
data_path = os.path.abspath(
    os.path.join('..', '..', 'neuraloperator', 'neuralop', 'datasets', 'data')
)
data_path

'/home/rusted/Projects/Neural_operator/neuraloperator/neuralop/datasets/data'

In [35]:
# --- training split ---
train_resolution = 16   # selects darcy_train_<res>.pt / darcy_test_<res>.pt
channel_dim = 1         # axis at which the channel dimension is inserted
n_train = 1000          # number of training samples kept
batch_size = 32

# --- test splits (three parallel lists, one entry per resolution) ---
test_resolutions = [16, 32]
n_tests = [100, 50]
test_batch_sizes = [32, 32]

# --- normalization / encoding switches ---
encode_input = False
encode_output = True
encoding = "channel-wise"   # or "pixel-wise"
positional_encoding = True
grid_boundaries = [[0, 1], [0, 1]]

In [36]:
# Load the Darcy training set (darcy_train_<train_resolution>.pt); the earlier
# label wrongly called this the Navier-Stokes dataset.
data = torch.load(
    Path(data_path).joinpath(f"darcy_train_{train_resolution}.pt").as_posix()
)

In [37]:
# Peek at the loaded object; the next cell reads its 'x' and 'y' entries.
data

{'x': tensor([[[ True,  True,  True,  ...,  True,  True, False],
          [ True,  True,  True,  ...,  True, False, False],
          [ True,  True,  True,  ...,  True,  True, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],
 
         [[ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [False,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ..., False, False,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True, False,  ...,  True,  True,  True]],
 
         [[ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ..., False,

In [38]:
# Keep the first n_train samples, insert the channel axis and copy, so the
# loaded dict can be released afterwards. Only x is cast to float32.
x_train = data["x"][:n_train, :, :].unsqueeze(channel_dim).float().clone()
y_train = data["y"][:n_train, :, :].unsqueeze(channel_dim).clone()
del data

In [39]:
# Test-set sizes as configured, before the next cell pops the entry that
# belongs to the training resolution.
n_tests

[100, 50]

In [40]:
# Pull the training resolution's entry out of the three parallel test lists.
# NOTE: the lists are mutated in place, so this cell cannot be re-run without
# first re-running the configuration cell.
idx = test_resolutions.index(train_resolution)
del test_resolutions[idx]
n_test, test_batch_size = n_tests.pop(idx), test_batch_sizes.pop(idx)

In [41]:
# Test split stored at the training resolution.
data = torch.load((Path(data_path) / f"darcy_test_{train_resolution}.pt").as_posix())

In [42]:
# Peek at the loaded test object; the next cell reads its 'x' and 'y' entries.
data

{'x': tensor([[[False, False, False,  ..., False,  True,  True],
          [False, False, False,  ...,  True,  True,  True],
          [False, False, False,  ...,  True,  True,  True],
          ...,
          [False, False, False,  ...,  True,  True,  True],
          [False, False, False,  ...,  True,  True,  True],
          [False, False, False,  ...,  True,  True,  True]],
 
         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True]],
 
         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ...,  True,

In [43]:
# First n_test samples with a channel axis inserted; x is cast to float32.
x_test = data["x"][:n_test, :, :].unsqueeze(channel_dim).float().clone()
y_test = data["y"][:n_test, :, :].unsqueeze(channel_dim).clone()
del data

In [53]:
if encode_input:
    # reduce_dims is handed to UnitGaussianNormalizer as `dim`.
    if encoding == "channel-wise":
        reduce_dims = list(range(x_train.ndim))
    elif encoding == "pixel-wise":
        reduce_dims = [0]
    else:
        # Fail loudly instead of hitting a NameError, or silently reusing a
        # stale `reduce_dims` left in the kernel by another cell.
        raise ValueError(
            f"Unknown encoding {encoding!r}; expected 'channel-wise' or 'pixel-wise'"
        )

    # The normalizer is only fitted here; x_train / x_test are not transformed
    # in this cell.
    input_encoder = UnitGaussianNormalizer(dim=reduce_dims)
    input_encoder.fit(x_train)
else:
    input_encoder = None


In [45]:

if encode_output:
    # reduce_dims is handed to UnitGaussianNormalizer as `dim`.
    if encoding == "channel-wise":
        reduce_dims = list(range(y_train.ndim))
    elif encoding == "pixel-wise":
        reduce_dims = [0]
    else:
        # Fail loudly instead of hitting a NameError, or silently reusing a
        # stale `reduce_dims` left in the kernel by another cell.
        raise ValueError(
            f"Unknown encoding {encoding!r}; expected 'channel-wise' or 'pixel-wise'"
        )

    # The normalizer is only fitted here; y_train is not transformed in this cell.
    output_encoder = UnitGaussianNormalizer(dim=reduce_dims)
    output_encoder.fit(y_train)
else:
    output_encoder = None

In [46]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [47]:
# Keep the dataset tensors on the CPU. The training DataLoader is created with
# pin_memory=True and only CPU tensors can be pinned, so CUDA-resident tensors
# make batch iteration fail. This also matches the extra-resolution test
# datasets, which are built from CPU tensors.
train_db = TensorDataset(
    x_train,
    y_train,
)

In [48]:

# Shuffled training batches, loaded in the main process (num_workers=0).
train_loader = torch.utils.data.DataLoader(
    dataset=train_db, batch_size=batch_size, shuffle=True,
    num_workers=0, pin_memory=True, persistent_workers=False,
)

In [49]:

# Keep the dataset tensors on the CPU: the loader below pins memory, which
# only works for CPU tensors (CUDA-resident tensors make iteration fail).
test_db = TensorDataset(
    x_test,
    y_test,
)
test_loader = torch.utils.data.DataLoader(
    test_db,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    persistent_workers=False,
)

In [50]:

# Start from the loader at the training resolution, then add one loader per
# remaining test resolution.
# NOTE: the loop reuses (and overwrites) the notebook-level names n_test,
# test_batch_size, x_test, y_test, test_db and test_loader.
test_loaders = {train_resolution: test_loader}
for res, n_test, test_batch_size in zip(test_resolutions, n_tests, test_batch_sizes):
    print(
        f"Loading test db at resolution {res} with {n_test} samples "
        f"and batch-size={test_batch_size}"
    )
    data = torch.load((Path(data_path) / f"darcy_test_{res}.pt").as_posix())
    x_test = data["x"][:n_test, :, :].unsqueeze(channel_dim).float().clone()
    y_test = data["y"][:n_test, :, :].unsqueeze(channel_dim).clone()
    del data

    # Inputs are stored as loaded; input_encoder is not applied in this loop.
    test_db = TensorDataset(x_test, y_test)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_db, batch_size=test_batch_size, shuffle=False,
        num_workers=0, pin_memory=True, persistent_workers=False,
    )
    test_loaders[res] = test_loader


Loading test db at resolution 32 with 50 samples and batch-size=32


In [51]:

# Optional 2D positional embedding built from the configured grid boundaries.
pos_encoding = (
    PositionalEmbedding2D(grid_boundaries=grid_boundaries)
    if positional_encoding
    else None
)

# Bundle the fitted normalizers and the positional encoding into one processor.
data_processor = DefaultDataProcessor(
    in_normalizer=input_encoder,
    out_normalizer=output_encoder,
    positional_encoding=pos_encoding,
)


In [52]:
# Display the objects built above: the training loader, the per-resolution
# test loaders, and the data processor.
train_loader, test_loaders, data_processor

(<torch.utils.data.dataloader.DataLoader at 0x70e9ae7ab110>,
 {16: <torch.utils.data.dataloader.DataLoader at 0x70e9c5931cd0>,
  32: <torch.utils.data.dataloader.DataLoader at 0x70ea2455add0>},
 DefaultDataProcessor(
   (out_normalizer): UnitGaussianNormalizer()
 ))