In [15]:
import torch
import numpy as np
from numpy.random import default_rng
from abc import ABC, ABCMeta, abstractmethod
from deepymod.data import DeePyModGPULoader


class Subsampler(ABC, metaclass=ABCMeta):
    @abstractmethod
    def sample():
        raise NotImplementedError
    
    
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        subsampler : Subsampler = None,
        load_kwargs : dict = {},
        preprocess_kwargs : dict = {},
        subsample_kwargs : dict = {},
        normalize_coords :  bool = False,
        normalize_data : bool = False,
        device : str = None,
    ):
        """A dataset class that loads the data, preprocesses it and lastly applies subsampling to it
        Args: 
            subsampler (Subsampler): Function that applies some kind of subsampling to it
            load_kwargs (dict): optional arguments for the load method
            preprocess_kwargs (dict): optional arguments for the preprocess method
            subsample_kwargs (dict): optional arguments for the subsample method
            normalize_coords (bool): apply normalization to the coordinates
            normalize_data (bool): apply normalization to the data
            device (string): which device to send the data to 
        Returns: 
            (torch.utils.data.Dataset)"""
        self.subsampler = subsampler
        self.load_kwargs = load_kwargs
        self.preprocess_kwargs  = preprocess_kwargs 
        self.subsample_kwargs = subsample_kwargs  # so total number of samples is size(self.t_domain) * n_samples_per_frame
        self.device = device
        self.normalize_coords = normalize_coords
        self.normalize_data = normalize_data
        self.coords = None
        self.data = None
        if self.load:
            self.coords, self.data = self.load()
        if self.preprocess:
            self.coords, self.data = self.preprocess(self.coords, self.data, **self.preprocess_kwargs)
        if self.subsampler:
            self.coords, self.data = self.subsampler.sample(self.coords, self.data, **self.subsample_kwargs)
        if self.device:
            self.x.to(self.device)
            self.y.to(self.device)

    # Pytorch methods
    def __len__(self) -> int:
        """ Returns length of dataset. Required by pytorch"""
        return self.number_of_samples
    
    def __getitem__(self, idx: int) -> int:
        """ Returns coordinate and value. First axis of coordinate should be time."""
        return self.coords[idx], self.data[idx]
        
    # User defined methods
    def load(self):
        """Define a load function that loads a dataset from memory, another function or something else."""
        raise NotImplementedError
        
    # Logical methods    
    def preprocess(self, X : torch.tensor, y : torch.tensor, random_state : int =42, noise : float = None):
        """Add noise to the data and normalize the features
        Arguments:
            X (torch.tensor) : coordinates of the dataset
            y (torch.tensor) : values of the dataset
            random_state (int) : state for random number geerator
            noise (float) : standard deviations of noise to add
            """
        # add noise
        y_processed = y + self.add_noise(y, noise, random_state)
        # normalize coordinates
        if self.normalize_coords:
            X_processed = self.apply_normalize(X)
        else:
            X_processed = X
        # normalize data
        if self.normalize_data:
            y_processed = self.apply_normalize(y)
        else:
            y_processed = y      
        return X_processed, y_processed

    @staticmethod
    def add_noise(y, noise_level, random_state):
        """ Adds gaussian white noise of noise_level standard deviation.
        Args:
            y (torch.tensor): the data to which noise should be added
            noise_level (float): add white noise as a function of standard deviation
            random_state (int): the random state used for random number generation
        """
        noise = noise_level * torch.std(y).data
        y_noisy = y + torch.tensor(
            default_rng(random_state).normal(loc=0.0, scale=noise, size=y.shape),
            dtype=torch.float32,
        )  # TO DO: switch to pytorch rng
        return y_noisy

    @staticmethod
    def apply_normalize(X):
        """ minmax Normalize the data along the zeroth axis.
        Args:
            X (torch.tensor): data to be minmax normalized
        Returns: 
            (torch.tensor): minmaxed data"""
        X_norm = (X - X.min(dim=0).values) / (
            X.max(dim=0).values - X.min(dim=0).values
        ) * 2 - 1
        return X_norm

In [16]:
class Subsample_grid(Subsampler):
    @staticmethod
    def sample(grid, grid_data, number_of_samples):
        print(number_of_samples)
        """Subsample on the second axis for data in the format [u, x, t]"""
        # getting indices of samples
        x_idx = torch.linspace(0, grid.shape[1] - 1, number_of_samples, dtype=torch.long)# getting x locations
        # getting sample locations from indices
        subsampled_coords = torch.tensor(grid[:, :, x_idx].reshape(-1, 2))
        subsampled_data = torch.tensor(grid_data[:, :, x_idx].reshape(-1, 1))
        return subsampled_coords, subsampled_data

class Subsample_shifted_grid(Subsampler):
    @staticmethod
    def sub_sample_shifted_grid(grid, grid_data, number_of_samples):
        # getting indices of samples
        x_idx = torch.linspace(0, grid.shape[1] - 1, number_of_samples, dtype=torch.long)# getting x locations
        # getting sample locations from indices
        subsampled_coords = torch.tensor(grid[:, :, x_idx].reshape(-1, 2))
        subsampled_data = torch.tensor(grid_data[:, :, x_idx].reshape(-1, 1))
        return subsampled_coords, subsampled_data

class Subsample_random(Subsampler):
    @staticmethod
    def sub_sample_random(grid, grid_data, number_of_samples):
        # getting indices of samples
        x_idx = torch.linspace(0, grid.shape[1] - 1, number_of_samples, dtype=torch.long)# getting x locations
        # getting sample locations from indices
        subsampled_coords = torch.tensor(grid[:, :, x_idx].reshape(-1, 2))
        subsampled_data = torch.tensor(grid_data[:, :, x_idx].reshape(-1, 1))
        return subsampled_coords, subsampled_data


class MatlabDataset2D(Dataset):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def load(self):
    """ Output: Grid[N x M x L],  Data[N x M x O],
    N = Coordinate dimension 0
    M = Coordinate dimension 1
    L = Input data dimension
    O = Output data dimension
    """
    x0 = np.linspace(0, 2*np.pi, 100)
    x1 = np.linspace(-np.pi, np.pi, 100)
    X0, X1 = np.meshgrid(x0, x1)
    y = np.sinc(X0*X1)
    coords = torch.tensor(np.stack((X0, X1)))#.reshape(-1, 2))
    data = torch.tensor(y).unsqueeze(0)#.reshape(-1, 1))
    return coords, data

In [17]:
md = MatlabDataset(subsampler=Subsample_grid, preprocess_kwargs={"noise":0.01}, subsample_kwargs={"number_of_samples":10})

10


  subsampled_coords = torch.tensor(grid[:, :, x_idx].reshape(-1, 2))
  subsampled_data = torch.tensor(grid_data[:, :, x_idx].reshape(-1, 1))
