In [None]:
import torch
from torch_geometric.nn import inits
from torch import Tensor
from torch.nn import Parameter
from tsl.nn.layers.norm import Norm
import torch.nn.functional as F

In [None]:
class GraphNorm(torch.nn.Module):
    """GraphNorm layer adapted from PyTorch Geometric's ``GraphNorm``.

    Based on https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/norm/graph_norm.html#GraphNorm
    and the `"GraphNorm: A Principled Approach to Accelerating Graph Neural
    Network Training" <https://arxiv.org/abs/2009.03294>`_ paper.

    Statistics are taken along dimension 0 of the input. The mean, scaled by
    the learnable per-channel ``mean_scale``, is subtracted. The result is
    then divided by its biased standard deviation plus ``eps``.

    Args:
        in_channels (int): Size of each input sample (the last dimension,
            which the per-channel parameters broadcast against).
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If set to :obj:`True`, a learnable
            per-channel scale (``weight``) and shift (``bias``) are applied
            to the normalised output. (default: :obj:`True`)
    """

    def __init__(self, in_channels: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.eps = eps

        if not affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.Tensor(in_channels))
            self.bias = Parameter(torch.Tensor(in_channels))

        # Learnable per-channel factor controlling how much of the mean is removed.
        self.mean_scale = torch.nn.Parameter(torch.Tensor(in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight to ones, bias to zeros and mean_scale to ones."""
        # NOTE(review): weight/bias are None when affine=False; this relies on
        # inits.ones / inits.zeros tolerating None.
        inits.ones(self.weight)
        inits.zeros(self.bias)
        inits.ones(self.mean_scale)

    def forward(self, x: Tensor) -> Tensor:
        # Layout noted by the author: [batch, step, node, channel/feature].
        # NOTE(review): the reduction runs over dim 0, which in that layout is
        # the batch axis, not the node axis of the original GraphNorm. Confirm
        # that this is intended.
        reduce_dim = 0
        centred = x - x.mean(dim=reduce_dim, keepdim=True) * self.mean_scale
        denom = centred.std(dim=reduce_dim, unbiased=False, keepdim=True) + self.eps
        normed = centred / denom

        if self.weight is None or self.bias is None:
            return normed
        return self.weight * normed + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'



In [None]:
class TemporalNorm(torch.nn.Module):
    r"""Normalise a spatio-temporal tensor along its time-step axis.

    The step axis is dimension ``-3``. Every entry is shifted by the mean
    taken over that axis and divided by the biased standard deviation over
    that axis plus ``eps``.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        mean_lr (float, optional): Accepted but not used by this module.
        gate_lr (float, optional): Accepted but not used by this module.
        scale_lr (float, optional): Accepted but not used by this module.
    """

    def __init__(self, in_channels, eps=1e-5, affine=True, mean_lr=0.00001, gate_lr=0.001, scale_lr=0.00001):
        super().__init__()
        self.in_channels = in_channels
        self.eps = eps

        if not affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.Tensor(in_channels))
            self.bias = Parameter(torch.Tensor(in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight to ones and bias to zeros."""
        # NOTE(review): both are None when affine=False; this relies on
        # inits.ones / inits.zeros tolerating None.
        inits.ones(self.weight)
        inits.zeros(self.bias)

    def forward(self, x: Tensor) -> Tensor:
        # x: [*, steps, nodes, features]. Time is axis -3.
        time_axis = -3
        mu = x.mean(dim=time_axis, keepdim=True)
        sigma = x.std(dim=time_axis, unbiased=False, keepdim=True)
        normed = (x - mu) / (sigma + self.eps)

        if self.weight is None or self.bias is None:
            return normed
        return normed * self.weight + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'



In [None]:
class UnitedNorm(torch.nn.Module):
    r"""Adapted from https://github.com/cyh1112/GraphNormalization/blob/master/norm/united_norm.py

    Applies a united spatial norm on the input by combining and weighting different normalization
    strategies on the given input, as stated in the `Learning Graph Normalization for Graph Neural
    Networks <https://arxiv.org/pdf/2009.11746.pdf>` paper.

    This variation of the idea uses Instance Norm instead of the originally specified
    adjacency-wise normalisation.

    The output is a per-channel combination of batch, instance, layer and graph normalisation of
    the input. The mixing weights are the softmax, taken across strategies, of four learnable gate
    vectors, so for every channel they are positive and sum to one.

    Args:
        in_channels (int): Size of each input sample.
        caller_class (optional): Object exposing a ``norm_weights`` list, e.g. the
            SpatioTemporalNormExperiment. On every forward pass the channel-averaged weight of
            each strategy is appended to it as a list of four detached 0-d tensors, in the order
            batch, instance, layer, graph. If :obj:`None`, nothing is recorded.
            (default: :obj:`None`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
    """
    def __init__(self, in_channels, caller_class=None, affine=True):
        super().__init__()
        self.in_channels = in_channels
        self.caller_class = caller_class

        if affine:
            self.weight = Parameter(torch.ones(self.in_channels))
            self.bias = Parameter(torch.zeros(self.in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # Trainable gate parameters that set each normalisation strategy's contribution.
        self.lambda_batch = Parameter(torch.ones(self.in_channels))
        self.lambda_instance = Parameter(torch.ones(self.in_channels))
        self.lambda_layer = Parameter(torch.ones(self.in_channels))
        self.lambda_graph = Parameter(torch.ones(self.in_channels))

        self.batch_norm = Norm(norm_type="batch", in_channels=self.in_channels, affine=False)
        self.instance_norm = Norm(norm_type="instance", in_channels=self.in_channels, affine=False)
        self.layer_norm = Norm(norm_type="layer", in_channels=self.in_channels, affine=False)
        self.graph_norm = GraphNorm(self.in_channels, affine=False)

    def forward(self, x):
        # Stack the gates into [n_strategies, in_channels] and softmax across strategies,
        # so each channel's mixing weights sum to one.
        gates = torch.stack(
            [self.lambda_batch, self.lambda_instance, self.lambda_layer, self.lambda_graph],
            dim=0,
        )
        lambda_softmax = F.softmax(gates, dim=0)
        x = (
            lambda_softmax[0] * self.batch_norm(x)
            + lambda_softmax[1] * self.instance_norm(x)
            + lambda_softmax[2] * self.layer_norm(x)
            + lambda_softmax[3] * self.graph_norm(x)
        )

        if self.weight is not None and self.bias is not None:
            x = x * self.weight + self.bias

        if self.caller_class is not None:
            # Track how each strategy's mean weight (across channels) evolves over steps.
            # Detach so the stored history does not keep autograd graph nodes alive
            # from one training step to the next.
            self.caller_class.norm_weights.append(list(lambda_softmax.detach().mean(-1)))

        return x

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'



In [None]:
class UnitedTemporalNorm(torch.nn.Module):
    r"""Applies a united norm using spatial and temporal normalization strategies.

    The output is a per-channel combination of batch, instance, layer, graph and temporal
    normalisation of the input. The mixing weights are the softmax, taken across strategies, of
    five learnable gate vectors, so for every channel they are positive and sum to one.

    Args:
        in_channels (int): Size of each input sample.
        caller_class (optional): Object exposing a ``norm_weights`` list, e.g. the
            SpatioTemporalNormExperiment. On every forward pass the channel-averaged weight of
            each strategy is appended to it as a list of five detached 0-d tensors, in the order
            batch, instance, layer, graph, temporal. If :obj:`None`, nothing is recorded.
            (default: :obj:`None`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
    """
    def __init__(self, in_channels, caller_class=None, affine=True):
        super().__init__()
        self.in_channels = in_channels
        self.caller_class = caller_class

        if affine:
            self.weight = Parameter(torch.ones(self.in_channels))
            self.bias = Parameter(torch.zeros(self.in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # Trainable gate parameters that set each normalisation strategy's contribution.
        self.lambda_batch = Parameter(torch.ones(self.in_channels))
        self.lambda_instance = Parameter(torch.ones(self.in_channels))
        self.lambda_layer = Parameter(torch.ones(self.in_channels))
        self.lambda_graph = Parameter(torch.ones(self.in_channels))
        self.lambda_temporal = Parameter(torch.ones(self.in_channels))
        # NOTE(review): lambda_sptpl and lambda_tplsp are registered but never used in
        # forward(), so they never receive a gradient. They are kept because removing
        # them would change this module's state_dict keys.
        self.lambda_sptpl = Parameter(torch.ones(self.in_channels))
        self.lambda_tplsp = Parameter(torch.ones(self.in_channels))

        self.batch_norm = Norm(norm_type="batch", in_channels=self.in_channels, affine=False)
        self.instance_norm = Norm(norm_type="instance", in_channels=self.in_channels, affine=False)
        self.layer_norm = Norm(norm_type="layer", in_channels=self.in_channels, affine=False)
        self.graph_norm = GraphNorm(self.in_channels, affine=False)
        self.temporal_norm = TemporalNorm(self.in_channels, affine=False)

    def forward(self, x):
        # Stack the gates into [n_strategies, in_channels] and softmax across strategies.
        gates = torch.stack(
            [
                self.lambda_batch,
                self.lambda_instance,
                self.lambda_layer,
                self.lambda_graph,
                self.lambda_temporal,
            ],
            dim=0,
        )
        lambda_softmax = F.softmax(gates, dim=0)
        x = (
            lambda_softmax[0] * self.batch_norm(x)
            + lambda_softmax[1] * self.instance_norm(x)
            + lambda_softmax[2] * self.layer_norm(x)
            + lambda_softmax[3] * self.graph_norm(x)
            + lambda_softmax[4] * self.temporal_norm(x)
        )

        if self.weight is not None and self.bias is not None:
            x = x * self.weight + self.bias

        if self.caller_class is not None:
            # Track how each strategy's mean weight (across channels) evolves over steps.
            # Detach so the stored history does not keep autograd graph nodes alive.
            self.caller_class.norm_weights.append(list(lambda_softmax.detach().mean(-1)))

        return x

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'

In [None]:
class SpatialThenTemporalNorm(torch.nn.Module):
    r"""Applies a united spatial norm on the input and then applies a temporal
    norm on the spatially normalized input.

    An optional affine transform is applied after both normalisations.

    Args:
        in_channels (int): Size of each input sample.
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`,
            applied after both normalisations. (default: :obj:`True`)
        caller_class (optional): Passed unchanged to the inner
            :class:`UnitedNorm`, which appends its strategy weights to
            ``caller_class.norm_weights``. (default: :obj:`None`)
    """
    def __init__(self, in_channels, affine=True, caller_class=None):
        super().__init__()
        self.in_channels = in_channels

        if affine:
            self.weight = Parameter(torch.ones(self.in_channels))
            self.bias = Parameter(torch.zeros(self.in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # The inner norms always carry their own affine parameters,
        # independently of ``affine`` above.
        self.temporal_norm = TemporalNorm(self.in_channels, affine=True)
        # caller_class is UnitedNorm's second parameter, so pass it explicitly.
        # Previously it was omitted, which made this constructor raise TypeError.
        self.spatial_norm = UnitedNorm(self.in_channels, caller_class=caller_class, affine=True)

    def forward(self, x):
        x = self.spatial_norm(x)
        x = self.temporal_norm(x)

        if self.weight is not None and self.bias is not None:
            x = x * self.weight + self.bias

        return x

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'



In [None]:
class TemporalThenSpatialNorm(torch.nn.Module):
    r"""Applies a temporal norm on the input and then applies a united spatial
    norm on the temporally normalized input.

    An optional affine transform is applied after both normalisations.

    Args:
        in_channels (int): Size of each input sample.
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`,
            applied after both normalisations. (default: :obj:`True`)
        caller_class (optional): Passed unchanged to the inner
            :class:`UnitedNorm`, which appends its strategy weights to
            ``caller_class.norm_weights``. (default: :obj:`None`)
    """
    def __init__(self, in_channels, affine=True, caller_class=None):
        super().__init__()
        self.in_channels = in_channels

        if affine:
            self.weight = Parameter(torch.ones(self.in_channels))
            self.bias = Parameter(torch.zeros(self.in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # The inner norms always carry their own affine parameters,
        # independently of ``affine`` above.
        self.temporal_norm = TemporalNorm(self.in_channels, affine=True)
        # caller_class is UnitedNorm's second parameter, so pass it explicitly.
        # Previously it was omitted, which made this constructor raise TypeError.
        self.spatial_norm = UnitedNorm(self.in_channels, caller_class=caller_class, affine=True)

    def forward(self, x):
        x = self.temporal_norm(x)
        x = self.spatial_norm(x)

        if self.weight is not None and self.bias is not None:
            x = x * self.weight + self.bias

        return x

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'


In [None]:
class SpatioTemporalNorm(torch.nn.Module):
    """Build the normalisation named by ``norm_type`` and apply it in ``forward``.

    Args:
        norm_type (string): Which normalisation to build. One of ``"instance"``,
            ``"batch"``, ``"layer"`` (tsl ``Norm``), ``"graph"``, ``"united"``,
            ``"temporal"``, ``"united_temporal"``, ``"spatial_then_temporal"``,
            ``"temporal_then_spatial"`` or ``"none"`` (identity).
        in_channels (int): Size of each input sample.
        caller_class: Forwarded only to the ``"united"`` and
            ``"united_temporal"`` variants. It is ignored otherwise.
        **kwargs: Extra keyword arguments forwarded to the selected module.
            They are ignored for ``"none"``.

    Raises:
        ValueError: If ``norm_type`` is not one of the names above.
    """
    def __init__(self, norm_type, in_channels, caller_class, **kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.in_channels = in_channels

        def tsl_norm(kind):
            # Deferred constructor for one of tsl's built-in Norm variants.
            return lambda: Norm(norm_type=kind, in_channels=in_channels, **kwargs)

        # Factories are lazy so only the selected normalisation is instantiated.
        factories = {
            "instance": tsl_norm("instance"),
            "batch": tsl_norm("batch"),
            "layer": tsl_norm("layer"),
            "graph": lambda: GraphNorm(in_channels=in_channels, **kwargs),
            "united": lambda: UnitedNorm(in_channels=in_channels, caller_class=caller_class, **kwargs),
            "temporal": lambda: TemporalNorm(in_channels=in_channels, **kwargs),
            "united_temporal": lambda: UnitedTemporalNorm(in_channels=in_channels, caller_class=caller_class, **kwargs),
            "spatial_then_temporal": lambda: SpatialThenTemporalNorm(in_channels=in_channels, **kwargs),
            "temporal_then_spatial": lambda: TemporalThenSpatialNorm(in_channels=in_channels, **kwargs),
            "none": torch.nn.Identity,
        }

        build = factories.get(norm_type)
        if build is None:
            raise ValueError("Please choose one of the following Norm Types: instance, batch, layer, graph, united, temporal, united_temporal, spatial_then_temporal, temporal_then_spatial, none")
        self.norm = build()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the selected normalisation to ``x``."""
        return self.norm(x)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.norm_type}, {self.in_channels})'