# Model definitions

In [1]:
# https://github.com/alebeck/batchensemble-pytorch/blob/main/layers.py
# https://github.com/quanpn90/NMTGMinor/blob/607ed45d0a3287dcbb064f012d3101300a95e891/onmt/modules/batch_ensemble/batch_ensemble_linear.py
import math
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init
from lifelong_rl.torch import pytorch_util as ptu
from lifelong_rl.torch.modules import LayerNorm


def identity(x):
    return x


class BatchEnsembleConv2D(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, num_models, stride=1, padding=0, bias=True):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.num_models = num_models

        self.alpha = nn.Parameter(torch.empty(num_models, in_channels))
        self.gamma = nn.Parameter(torch.empty(num_models, out_channels))
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, bias=False)

        if bias:
            # use one bias vector per ensemble member
            self.bias = nn.Parameter(torch.empty(num_models, out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, x):
        """
        X: Tensor of shape (B * M, C_in, H, W)
        Ensemble members should be stacked in BATCH dimension.
        Dim 0 layout:
            ------ batch elem 0, model 0 ------
            -------batch elem 1, model 0 ------
                      ...
            ------ batch elem 0, model n ------
            -------batch elem 1, model n ------
                      ...
        """
        batch_size = x.shape[0]
        # arguably this is the actual batch size
        examples_per_model = batch_size // self.num_models

        alpha = self.alpha.tile(1, examples_per_model).view(
            batch_size, self.in_channels)[:, :, None, None]
        gamma = self.gamma.tile(1, examples_per_model).view(
            batch_size, self.out_channels)[:, :, None, None]

        x = self.conv(x * alpha) * gamma

        if self.bias is not None:
            bias = self.bias.tile(1, examples_per_model).view(
                batch_size, self.out_channels)[:, :, None, None]
            x = x + bias
        return x

    def reset_parameters(self):
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        # random sign initialization for fast weights as mentioned in paper
        with torch.no_grad():
            self.alpha.bernoulli_(0.5).mul_(2).add_(-1)
            self.gamma.bernoulli_(0.5).mul_(2).add_(-1)


class BatchEnsembleConv1D(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, num_models, stride=1, padding=0, bias=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.num_models = num_models

        self.alpha = nn.Parameter(torch.empty(num_models, in_channels))
        self.gamma = nn.Parameter(torch.empty(num_models, out_channels))
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size, stride, padding, bias=False)

        if bias:
            # use one bias vector per ensemble member
            self.bias = nn.Parameter(torch.empty(num_models, out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, x):
        """
        X: Tensor of shape (B * M, C_in, W)
        Ensemble members should be stacked in BATCH dimension.
        Dim 0 layout:
            ------ batch elem 0, model 0 ------
            -------batch elem 1, model 0 ------
                      ...
            ------ batch elem 0, model n ------
            -------batch elem 1, model n ------
                      ...
        """
        batch_size = x.shape[0]
        # arguably this is the actual batch size
        examples_per_model = batch_size // self.num_models

        alpha = self.alpha.tile(1, examples_per_model).view(
            batch_size, self.in_channels)[:, :, None]
        gamma = self.gamma.tile(1, examples_per_model).view(
            batch_size, self.out_channels)[:, :, None]

        x = self.conv(x * alpha) * gamma

        if self.bias is not None:
            bias = self.bias.tile(1, examples_per_model).view(
                batch_size, self.out_channels)[:, :, None]
            x = x + bias
        return x

    def reset_parameters(self):
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        # random sign initialization for fast weights as mentioned in paper
        with torch.no_grad():
            self.alpha.bernoulli_(0.5).mul_(2).add_(-1)
            self.gamma.bernoulli_(0.5).mul_(2).add_(-1)


class BatchEnsembleLinearPlus(nn.Module):

    def __init__(self, input_size, output_size, ensemble_size, bias=True, diversity = False):
        super().__init__()
        self.in_features = input_size
        self.out_features = output_size
        self.ensemble_size = ensemble_size
        self.weight_diversity = diversity

        self.W = nn.Parameter(torch.empty(output_size, input_size))  # m*n
        self.r = nn.Parameter(torch.empty(ensemble_size, input_size))  # M*m
        self.s = nn.Parameter(torch.empty(ensemble_size, output_size))  # M*n

        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, output_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, X):
        """
        Expects input in shape (B*M, C_in), dim 0 layout:
            ------ x0, model 0 ------
            -------x0, model 1 ------
                      ...
            ------ x1, model 0 ------
            -------x1, model 1 ------
                      ...
        """
        B = X.shape[0] // self.ensemble_size
        X = X.view(B, self.ensemble_size, -1)  # Reshape input to (B, M, C_in)
        R = self.r.unsqueeze(0)  # Add a dimension for broadcasting
        S = self.s.unsqueeze(0)  # Add a dimension for broadcasting
        bias = self.bias.unsqueeze(0)  # Add a dimension for broadcasting

        # Eq. 5 from BatchEnsembles paper
        output = torch.matmul((X * R), self.W.t()) * S + bias  # (B, M, C_out)
        
        # Flatten output back to (B*M, C_out)
        output = output.view(B * self.ensemble_size, -1)
        diver =  torch.tensor(0) 
        if self.weight_diversity:
          R1 = self.r/torch.norm(self.r,dim=1,keepdim=True)
          S1 = self.s/torch.norm(self.s,dim=1,keepdim=True)
          diver = 1 - (torch.mean(torch.matmul(R1,R1.t()) + torch.matmul(S1,S1.t())))/2

        return output,diver

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        nn.init.xavier_uniform_(self.W,gain=nn.init.calculate_gain('relu'))
        # Another way to initialize the fast weights
        #nn.init.normal_(self.r, mean=1., std=0.1)
        #nn.init.normal_(self.s, mean=1., std=0.1)

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        
        if True:
            with torch.no_grad():
              # random sign initialization from paper
                self.r.bernoulli_(0.5).mul_(2).add_(-1)
                self.s.bernoulli_(0.5).mul_(2).add_(-1)
        else:
            # nn.init.normal_(self.r, mean=1., std=0.5)
            # nn.init.normal_(self.s, mean=1., std=0.5)
            nn.init.normal_(self.r, mean=1., std=0.5)
            nn.init.normal_(self.r, mean=1., std=0.5)




class BatchEnsembleLinear(nn.Module):

    def __init__(self, input_size, output_size, ensemble_size, bias=True):
        super().__init__()
        self.in_features = input_size
        self.out_features = output_size
        self.ensemble_size = ensemble_size

        self.W = nn.Parameter(torch.empty(output_size, input_size))  # m*n
        self.r = nn.Parameter(torch.empty(ensemble_size, input_size))  # M*m
        self.s = nn.Parameter(torch.empty(ensemble_size, output_size))  # M*n

        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, output_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, X):
        """
        Expects input in shape (B*M, C_in), dim 0 layout:
            ------ x0, model 0 ------
            -------x0, model 1 ------
                      ...
            ------ x1, model 0 ------
            -------x1, model 1 ------
                      ...
        """
        B = X.shape[0] // self.ensemble_size
        X = X.view(B, self.ensemble_size, -1)  # Reshape input to (B, M, C_in)
        R = self.r.unsqueeze(0)  # Add a dimension for broadcasting
        S = self.s.unsqueeze(0)  # Add a dimension for broadcasting
        bias = self.bias.unsqueeze(0)  # Add a dimension for broadcasting

        # Eq. 5 from BatchEnsembles paper
        output = torch.matmul((X * R), self.W.t()) * S + bias  # (B, M, C_out)
        
        # Flatten output back to (B*M, C_out)
        output = output.view(B * self.ensemble_size, -1)
        return output

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        nn.init.xavier_uniform_(self.W,gain=nn.init.calculate_gain('relu'))
        # Another way to initialize the fast weights
        #nn.init.normal_(self.r, mean=1., std=0.1)
        #nn.init.normal_(self.s, mean=1., std=0.1)

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        
        if True:
            with torch.no_grad():
              # random sign initialization from paper
                self.r.bernoulli_(0.5).mul_(2).add_(-1)
                self.s.bernoulli_(0.5).mul_(2).add_(-1)
        else:
            # nn.init.normal_(self.r, mean=1., std=0.5)
            # nn.init.normal_(self.s, mean=1., std=0.5)
            nn.init.normal_(self.r, mean=1., std=0.5)
            nn.init.normal_(self.r, mean=1., std=0.5)

## Batch ensemble

In [2]:
class BatchEnsembleFlattenMLP(nn.Module):

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            init_w=3e-3,
            hidden_init=ptu.fanin_init,
            w_scale=1,
            b_init_value=0.1,
            layer_norm=None,
            batch_norm=False,
            final_init_scale=None,
            norm_input=False,
            obs_norm_mean=None,
            obs_norm_std=None,
    ):
        super().__init__()

        self.ensemble_size = ensemble_size
        self.input_size = input_size
        self.output_size = output_size

        self.sampler = np.random.default_rng()

        self.hidden_activation = F.relu
        self.output_activation = identity
        
        self.layer_norm = layer_norm

        self.norm_input = norm_input
        if self.norm_input:
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)

        self.fcs = []

        if batch_norm:
            raise NotImplementedError

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = BatchEnsembleLinear(
                ensemble_size=ensemble_size,
                input_size=in_size,
                output_size=next_size,
            )
            self.__setattr__('fc%d'% i, fc)
            self.fcs.append(fc)
            in_size = next_size

        self.last_fc = BatchEnsembleLinear(
            ensemble_size=ensemble_size,
            input_size=in_size,
            output_size=output_size,
        )
        if final_init_scale is None:
            self.last_fc.W.data.uniform_(-init_w, init_w)
            self.last_fc.bias.data.uniform_(-init_w, init_w)

    def forward(self, *inputs, **kwargs):
        """Calculate the forward pass of Q(s, a).

        Args:
            inputs: list[observation,action]: list of tensors containing the observation and action size B x obs_dim , B x act_dim 

        Returns:
            Q(s,a): return Q(s,a) size B x 1, where emsamble members output are stack along dim 0 [q_m0 , q_m1, ...,q_mN, q_m0, q_m1, ...]^T 
        """
        
        inputs = [inputs[0], inputs[1]]
        if self.norm_input:
            inputs[0] = (inputs[0] - self.obs_norm_mean) / self.obs_norm_std

        flat_inputs = torch.cat(inputs, dim=-1)
        dim=len(flat_inputs.shape)
        if kwargs.get("sample",False):
            flat_inputs = flat_inputs.repeat_interleave(self.ensemble_size,0)

        # input normalization
        h = flat_inputs

        # standard feedforward network
        for _, fc in enumerate(self.fcs):
            h = fc(h)
            h = self.hidden_activation(h)
            if hasattr(self, 'layer_norm') and (self.layer_norm is not None):
                h = self.layer_norm(h)
        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)
        return output

    def sample(self, *inputs):
        preds = self.forward(*inputs,sample=True)
        B = preds.shape[0] // self.ensemble_size
        #(B*Self.ensemble_size,1) => (self.ensemble_size, B, 1)
        preds = preds.view(B,self.ensemble_size,-1 )
        # Return min, mean and std of the ensemble
        return torch.min(preds, dim=1)[0],preds.mean(dim=1), preds.std(dim=1)
 


    def fit_input_stats(self, data, mask=None):
        raise NotImplementedError



A = BatchEnsembleFlattenMLP(ensemble_size=10, hidden_sizes=[64,64], input_size=4, output_size=1, norm_input=True, obs_norm_mean=np.array([0,0]), obs_norm_std=np.array([1,1]))
print(A)

BatchEnsembleFlattenMLP(
  (fc0): BatchEnsembleLinear()
  (fc1): BatchEnsembleLinear()
  (last_fc): BatchEnsembleLinear()
)


## Batch Ensemble Plus

In [3]:
# Improve efficiency



class BatchEnsembleFlattenMLPPlus(nn.Module):

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            init_w=3e-3,
            hidden_init=ptu.fanin_init,
            w_scale=1,
            b_init_value=0.1,
            layer_norm=None,
            batch_norm=False,
            final_init_scale=None,
            norm_input=False,
            obs_norm_mean=None,
            obs_norm_std=None,
            hidden_activate=F.gelu, # THANH
            diversity_regularize = False # THANH
    ):
        super().__init__()

        self.ensemble_size = ensemble_size
        self.ensemble_num = ensemble_size
        self.input_size = input_size
        self.output_size = output_size

        self.sampler = np.random.default_rng()

        self.hidden_activation = hidden_activate
        self.output_activation = identity
        
        self.layer_norm = layer_norm

        self.norm_input = norm_input
        if self.norm_input:
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)

        self.fcs = []
        self.diversity_regularize= diversity_regularize

        if batch_norm:
            raise NotImplementedError

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = BatchEnsembleLinearPlus(
                ensemble_size=ensemble_size,
                input_size=in_size,
                output_size=next_size,
                diversity = self.diversity_regularize,
            )
            self.__setattr__('fc%d'% i, fc)
            self.fcs.append(fc)
            in_size = next_size

        self.last_fc = BatchEnsembleLinearPlus(
            ensemble_size=ensemble_size,
            input_size=in_size,
            output_size=output_size,
            diversity = self.diversity_regularize,
        )


    def forward(self, *inputs, **kwargs):
        if self.norm_input:
            obs = (inputs[0] - self.obs_norm_mean) / self.obs_norm_std
            flat_inputs = torch.cat([obs, inputs[1]], dim=-1)
        else:
            flat_inputs = torch.cat([inputs[0], inputs[1]], dim=-1)
        
        if kwargs.get("sample", False):
            flat_inputs = flat_inputs.repeat_interleave(self.ensemble_size, 0)

        # input normalization
        h = flat_inputs

        # standard feedforward network
        diversity = 0
        for _, fc in enumerate(self.fcs):
            h,div = fc(h)
            diversity +=div 
            h = self.hidden_activation(h)
            if hasattr(self, 'layer_norm') and (self.layer_norm is not None):
                h = self.layer_norm(h)
        preactivation,div = self.last_fc(h)
        diversity +=div
        output = self.output_activation(preactivation)
        return output,diversity

    def sample(self, *inputs):
        preds,*_ = self.forward(*inputs,sample=True)
        B = preds.shape[0] // self.ensemble_size
        #(B*Self.ensemble_size,1) => (self.ensemble_size, B, 1)
        preds = preds.view(B, self.ensemble_size, -1)
        # Return min, mean and std of the ensemble
        return torch.min(preds, dim=1)[0],0#,preds.mean(dim=1), preds.std(dim=1)
    
    def fit_input_stats(self, data, mask=None):
        raise NotImplementedError



A = BatchEnsembleFlattenMLPPlus(ensemble_size=10, hidden_sizes=[64,64], input_size=4, output_size=1, norm_input=True, obs_norm_mean=np.array([0,0]), obs_norm_std=np.array([1,1]))
print(A)

BatchEnsembleFlattenMLPPlus(
  (fc0): BatchEnsembleLinearPlus()
  (fc1): BatchEnsembleLinearPlus()
  (last_fc): BatchEnsembleLinearPlus()
)


In [4]:
a = torch.randn(10)
print(a.shape)
a.unsqueeze(1)
print(a.shape)
a.view(-1).shape

torch.Size([10])
torch.Size([10])


torch.Size([10])

## Rank-1 Ensemble 

In [5]:

class BatchEnsembleLinearRank1(nn.Module):

    def __init__(self, input_size, output_size, ensemble_size, bias=True, diversity = False):
        super().__init__()
        self.in_features = input_size
        self.out_features = output_size
        self.ensemble_size = ensemble_size
        self.weight_diversity = diversity

        self.W = nn.Parameter(torch.empty(output_size, input_size))  # m*n
        self.r = nn.Parameter(torch.empty(ensemble_size, input_size))  # M*m
        self.s = nn.Parameter(torch.empty(ensemble_size, output_size))  # M*n

        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, output_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, X):
        '''
        X: (B, M, C_in)
        return (B, M, C_out)

        '''
        R = self.r.unsqueeze(0)  # Add a dimension for broadcasting
        S = self.s.unsqueeze(0)  # Add a dimension for broadcasting
        bias = self.bias.unsqueeze(0)  # Add a dimension for broadcasting

        # Eq. 5 from BatchEnsembles paper
        output = torch.matmul((X * R), self.W.t()) * S + bias  # (B, M, C_out)

        diver =  torch.tensor(0) 
        if self.weight_diversity:
          R1 = self.r/torch.norm(self.r,dim=1,keepdim=True)
          S1 = self.s/torch.norm(self.s,dim=1,keepdim=True)
          diver = 1 - (torch.mean(torch.matmul(R1,R1.t()) + torch.matmul(S1,S1.t())))/2

        return output,diver

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        nn.init.xavier_uniform_(self.W,gain=nn.init.calculate_gain('relu'))
        # Another way to initialize the fast weights
        #nn.init.normal_(self.r, mean=1., std=0.1)
        #nn.init.normal_(self.s, mean=1., std=0.1)

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        
        if True:
            with torch.no_grad():
              # random sign initialization from paper
                self.r.bernoulli_(0.5).mul_(2).add_(-1)
                self.s.bernoulli_(0.5).mul_(2).add_(-1)
        else:
            # nn.init.normal_(self.r, mean=1., std=0.5)
            # nn.init.normal_(self.s, mean=1., std=0.5)
            nn.init.normal_(self.r, mean=1., std=0.5)
            nn.init.normal_(self.r, mean=1., std=0.5)


class BatchEnsembleFlattenRank1(nn.Module):

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            layer_norm=None,
            batch_norm=False,
            norm_input=False,
            obs_norm_mean=None,
            obs_norm_std=None,
            hidden_activate=F.gelu, # THANH
            diversity_regularize = False # THANH
    ):
        super().__init__()

        self.ensemble_size = ensemble_size
        self.ensemble_num = ensemble_size
        self.input_size = input_size
        self.output_size = output_size

        self.sampler = np.random.default_rng()

        self.hidden_activation = hidden_activate
        self.output_activation = identity
        
        self.layer_norm = layer_norm

        self.norm_input = norm_input
        if self.norm_input:
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)

        self.fcs = []
        self.diversity_regularize= diversity_regularize

        if batch_norm:
            raise NotImplementedError

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = BatchEnsembleLinearRank1(
                ensemble_size=ensemble_size,
                input_size=in_size,
                output_size=next_size,
                diversity = self.diversity_regularize,
            )
            self.__setattr__('fc%d'% i, fc)
            self.fcs.append(fc)
            in_size = next_size

        self.last_fc = BatchEnsembleLinearRank1(
            ensemble_size=ensemble_size,
            input_size=in_size,
            output_size=output_size,
            diversity = self.diversity_regularize,
        )


    def forward(self, *inputs, **kwargs):
        if self.norm_input:
            obs = (inputs[0] - self.obs_norm_mean) / self.obs_norm_std
            flat_inputs = torch.cat([obs, inputs[1]], dim=-1)
        else:
            flat_inputs = torch.cat([inputs[0], inputs[1]], dim=-1)

        flat_inputs = flat_inputs.repeat_interleave(self.ensemble_size, 0).view(-1,self.ensemble_size, self.input_size)

        # input normalization
        h = flat_inputs

        # standard feedforward network
        diversity = 0
        for _, fc in enumerate(self.fcs):
            h,div = fc(h)
            diversity +=div 
            h = self.hidden_activation(h)
            if hasattr(self, 'layer_norm') and (self.layer_norm is not None):
                h = self.layer_norm(h)
        preactivation,div = self.last_fc(h)
        diversity +=div
        output = self.output_activation(preactivation) # (B,M, C_out)
        # Transpose to (M, B, C_out)
        output = output.transpose(0, 1).contiguous()
        
        return output #,diversity

    def sample(self, *inputs):
        preds = self.forward(*inputs)
        return torch.min(preds, dim=0)[0]
    
    def fit_input_stats(self, data, mask=None):
        raise NotImplementedError



A = BatchEnsembleFlattenMLPPlus(ensemble_size=10, hidden_sizes=[64,64], input_size=4, output_size=1, norm_input=True, obs_norm_mean=np.array([0,0]), obs_norm_std=np.array([1,1]))
print(A)

BatchEnsembleFlattenMLPPlus(
  (fc0): BatchEnsembleLinearPlus()
  (fc1): BatchEnsembleLinearPlus()
  (last_fc): BatchEnsembleLinearPlus()
)


## GAUSS ensemble 

In [16]:
class BatchEnsembleLinearGauss(nn.Module):

    def __init__(self, input_size, output_size, ensemble_size, bias=True, diversity = False):
        super().__init__()
        self.in_features = input_size
        self.out_features = output_size
        self.ensemble_size = ensemble_size
        self.weight_diversity = diversity

        self.W = nn.Parameter(torch.empty(output_size, input_size))  # m*n
        self.r = nn.Parameter(torch.empty(ensemble_size, input_size))  # M*m
        self.s = nn.Parameter(torch.empty(ensemble_size, output_size))  # M*n

        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, output_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, X):
        '''
        X: (B, M, C_in)
        return (B, M, C_out)

        '''
        R = self.r.unsqueeze(0)  # Add a dimension for broadcasting
        S = self.s.unsqueeze(0)  # Add a dimension for broadcasting
        bias = self.bias.unsqueeze(0)  # Add a dimension for broadcasting

        # Eq. 5 from BatchEnsembles paper
        output = torch.matmul((X * R), self.W.t()) * S + bias  # (B, M, C_out)

        diver =  torch.tensor(0) 
        if self.weight_diversity:
          R1 = self.r/torch.norm(self.r,dim=1,keepdim=True)
          S1 = self.s/torch.norm(self.s,dim=1,keepdim=True)
          diver = 1 - (torch.mean(torch.matmul(R1,R1.t()) + torch.matmul(S1,S1.t())))/2

        return output,diver

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        with torch.no_grad():
        # random sign initialization from paper
            self.r.bernoulli_(0.5).mul_(2).add_(-1)
            self.s.bernoulli_(0.5).mul_(2).add_(-1)



class BatchEnsembleGaussMLP(nn.Module):

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            init_w=3e-3,
            layer_norm=None,
            batch_norm=False,
            final_init_scale=None,
            norm_input=False,
            obs_norm_mean=None,
            obs_norm_std=None,
            hidden_activate=F.gelu, # THANH
            diversity_regularize = False # THANH
    ):
        super().__init__()

        self.ensemble_size = ensemble_size
        self.ensemble_num = ensemble_size
        self.input_size = input_size
        self.output_size = output_size

        self.sampler = np.random.default_rng()

        self.hidden_activation = hidden_activate
        self.output_activation = identity
        
        self.norm_input = norm_input
        if self.norm_input:
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)

        self.layer_norm = layer_norm

        self.norm_input = norm_input

        self.fcs = []
        self.diversity_regularize= diversity_regularize

        if batch_norm:
            raise NotImplementedError

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = BatchEnsembleLinearGauss(
                ensemble_size=ensemble_size,
                input_size=in_size,
                output_size=next_size,
                diversity = self.diversity_regularize,
            )
            self.__setattr__('fc%d'% i, fc)
            self.fcs.append(fc)
            in_size = next_size

        self.mu = BatchEnsembleLinearGauss(
            ensemble_size=ensemble_size,
            input_size=in_size,
            output_size=output_size,
            diversity = self.diversity_regularize,
        )

        self.var = BatchEnsembleLinearGauss(
            ensemble_size=ensemble_size,
            input_size=in_size,
            output_size=output_size,
            diversity = self.diversity_regularize,
        )
        if final_init_scale is None:
            self.mu.W.data.uniform_(-init_w, init_w)
            self.var.W.data.uniform_(-init_w, init_w)
            self.mu.bias.data.uniform_(-init_w, init_w)
            self.var.bias.data.uniform_(-init_w, init_w)

    def forward(self, *inputs, **kwargs):

        if self.norm_input:
            obs = (inputs[0] - self.obs_norm_mean) / self.obs_norm_std
            flat_inputs = torch.cat([obs, inputs[1]], dim=-1)
        else:
            flat_inputs = torch.cat([inputs[0], inputs[1]], dim=-1)

        flat_inputs = flat_inputs.repeat_interleave(self.ensemble_size, 0).view(-1,self.ensemble_size, self.input_size)

        # input normalization
        h = flat_inputs

        # standard feedforward network
        diversity = 0
        for _, fc in enumerate(self.fcs):
            h,div = fc(h)
            diversity +=div 
            h = self.hidden_activation(h)
            if hasattr(self, 'layer_norm') and (self.layer_norm is not None):
                h = self.layer_norm(h)
        mu_pre,div = self.mu(h)
        diversity +=div
        var_pre,div = self.var(h)
        diversity +=div
        var = torch.exp(var_pre)
        var = var.transpose(0, 1).contiguous()
        mu = self.output_activation(mu_pre)
        mu = mu.transpose(0, 1).contiguous()
        return mu,var, diversity

    def sample(self, *inputs):

        preds,vars,*_ = self.forward(*inputs)
        return preds,vars
    
    def fit_input_stats(self, data, mask=None):
        raise NotImplementedError

## MIMO 

In [6]:
# https://github.com/noowad93/MIMO-pytorch
# (The link above may contain some errors)
# https://colab.research.google.com/drive/16i8Wd8hYYgZfVLs6MPVFZ2faccpnYkyk?usp=sharing#scrollTo=i_2GT74Ecu0c
# https://colab.research.google.com/drive/1JIgyVeEmlOH-j0oGmeDLV8iE5UYdreYm#scrollTo=Ac1oS_S8kQOB

class MimoEnsembleFlattenMLP(nn.Module):
    def __init__(self, 
                ensemble_size,
                hidden_sizes,
                input_size,
                output_size,
                init_w=3e-3,
                w_scale=1,
                b_init_value=0.1,
                layer_norm=None,
                batch_norm=False,
                final_init_scale=None,
                norm_input=False,
                obs_norm_mean=None,
                obs_norm_std=None,
                width_multiplier = 1):
        super(MimoEnsembleFlattenMLP, self).__init__()
        self.ensemble_num = ensemble_size
        self.hidden_activation = torch.tanh
        
        self.input_layer = nn.Linear(input_size*ensemble_size, hidden_sizes[0]*width_multiplier)
        self.backbone_model = BackboneModel([layer_size*width_multiplier for layer_size in hidden_sizes],hidden_activation=self.hidden_activation)
        self.norm_input = norm_input
        if self.norm_input:
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)
        self.output_layer = nn.Linear(hidden_sizes[-1]*width_multiplier, output_size* ensemble_size)
        self.output_activation = identity
        # initialize weights
        init.xavier_uniform_(self.input_layer.weight)
        self.input_layer.bias.data.fill_(0)
        init.xavier_uniform_(self.output_layer.weight)
        self.output_layer.bias.data.fill_(0)

    def forward(self, *inputs, **kwargs):
        inputs = [inputs[0], inputs[1]]
        if self.norm_input:
            inputs[0].sub_(self.obs_norm_mean).div_(self.obs_norm_std)

        inputs = torch.cat(inputs, dim=-1)

        if kwargs.get("sample",None) == True:
            inputs = inputs.repeat_interleave(self.ensemble_num,0)

        return self.forward_(inputs)
        
    def forward_(self,input):
        dim = len(input.shape)
        # transform B*E to B//M*(E*M)
        B,E,*_= input.shape
        M = self.ensemble_num
        h = input.view(B//M,-1) 

        # standard feedforward network
        h = self.input_layer(h)
        h = self.hidden_activation(h)
        h = self.backbone_model(h)
        h = self.output_layer(h)
        output = self.output_activation(h)

        # if original dim was 1D, squeeze the extra created layer
        if  dim == 1:
            output = output.squeeze(1)
        return output.view(B, -1)


    def sample(self, *inputs):
        preds = self.forward(*inputs,sample = True)
        B = preds.shape[0] // self.ensemble_num
        return  torch.min(preds.view(B,self.ensemble_num,-1), dim=1)

class BackboneModel(nn.Module):
    def __init__(self, hidden_dim, hidden_activation=F.relu):
        super(BackboneModel, self).__init__()
        self.hidden_activation = hidden_activation
        for i, (in_dim, out_dim) in enumerate(zip(hidden_dim[:-1], hidden_dim[1:])):
            self.add_module(f"l{i}", nn.Linear(in_dim, out_dim))
        self.apply(self.init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _, layer in self.named_children():
            x = layer(x)
            x = self.hidden_activation(x)
        return x
    
    def init_weights(self,m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)


A = MimoEnsembleFlattenMLP(ensemble_size=10, hidden_sizes=[64,64], input_size=4, output_size=1, norm_input=True, obs_norm_mean=np.array([0,0]), obs_norm_std=np.array([1,1]),width_multiplier=2)

# from flopth import flopth
# lops, params = flopth(A, in_size=((10,),))
# print(lops, params)
print(A)

MimoEnsembleFlattenMLP(
  (input_layer): Linear(in_features=40, out_features=128, bias=True)
  (backbone_model): BackboneModel(
    (l0): Linear(in_features=128, out_features=128, bias=True)
  )
  (output_layer): Linear(in_features=128, out_features=10, bias=True)
)


## Test networks

In [18]:
from lifelong_rl.models.networks import ParallelizedEnsembleFlattenMLP
from lifelong_rl.policies.base.base import MakeDeterministic
from lifelong_rl.policies.models.tanh_gaussian_policy import TanhGaussianPolicy
from lifelong_rl.trainers.q_learning.sac import SACTrainer
import lifelong_rl.util.pythonplusplus as ppp
from torch.nn import functional as F

num_qs = 20 #variant['trainer_kwargs']['num_qs']
M = 512 #variant['policy_kwargs']['layer_size']
num_q_layers = 4 #variant['policy_kwargs']['num_q_layers']
num_p_layers = 4 #variant['policy_kwargs']['num_p_layers']
obs_dim = 7
action_dim = 4

# normalization
norm_input = False#variant['norm_input']
obs_norm_mean, obs_norm_std = torch.randn(obs_dim).numpy() , torch.randn(obs_dim).numpy() #variant['normalization_info']['obs_mean'], variant['normalization_info']['obs_std']

qfs, target_qfs = ppp.group_init(
    2,
    ParallelizedEnsembleFlattenMLP,
    ensemble_size=num_qs,
    hidden_sizes=[M] * num_q_layers,
    input_size=obs_dim + action_dim,
    output_size=1,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)

qfs3, target_qfs3 = ppp.group_init(
    2,
    BatchEnsembleFlattenRank1,
    ensemble_size=num_qs,
    hidden_sizes=[M] * num_q_layers,
    input_size=obs_dim + action_dim,
    output_size=1,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)


qfs4, target_qfs4 = ppp.group_init(
    2,
    BatchEnsembleGaussMLP,
    ensemble_size=num_qs,
    hidden_sizes=[M] * num_q_layers,
    input_size=obs_dim + action_dim,
    output_size=1,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)



policy = TanhGaussianPolicy(
    obs_dim=obs_dim,
    action_dim=action_dim,
    hidden_sizes=[M] * num_p_layers,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)


qfs1, target_qfs1 = ppp.group_init(
    2,
    BatchEnsembleFlattenMLPPlus,
    ensemble_size=num_qs,
    hidden_sizes=[M] * num_q_layers,
    input_size=obs_dim + action_dim,
    output_size=1,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)


qfs2, target_qfs2 = ppp.group_init(
    2,
    MimoEnsembleFlattenMLP,
    ensemble_size=num_qs,
    hidden_sizes=[M] * num_q_layers,
    input_size=obs_dim + action_dim,
    output_size=1,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)




In [19]:
# Test networks
import gtimer as gt
obs = torch.randn(500, obs_dim)
acts = torch.randn(500, action_dim)

print(qfs(obs, acts).shape)
print(target_qfs.sample(obs, acts).shape)

print(qfs3(obs, acts).shape)
print(target_qfs3.sample(obs, acts).shape)

print(qfs4(obs, acts)[0].shape)
print(target_qfs4.sample(obs, acts)[0].shape)


print(qfs1(obs, acts)[0].shape) if isinstance(qfs1(obs, acts),tuple) else print(qfs1(obs, acts).shape)
print(target_qfs1.sample(obs, acts)[0].shape)

print(qfs2(obs, acts).shape)
print(target_qfs2.sample(obs, acts)[0].shape)

gt.reset()
gt.start()
gt.stamp('start')
for i in range(10):
    qfs(obs, acts)
    target_qfs.sample(obs, acts)
gt.stamp('EnsembleFlattenMLP')

for i in range(10):
    qfs3(obs, acts)
    target_qfs3.sample(obs, acts)
gt.stamp('EnsembleFlattenRank1')

for i in range(10):
    qfs4(obs, acts)
    target_qfs4.sample(obs, acts)
gt.stamp('EnsembleGaussMLP')


for i in range(10):
    qfs1(obs, acts)
    target_qfs1.sample(obs, acts)
gt.stamp('BatchEnsembleFlattenMLP')
for i in range(10):
    qfs2(obs, acts)
    target_qfs2.sample(obs, acts)
gt.stamp('MimoEnsembleFlattenMLP')

print(gt.report())


torch.Size([20, 500, 1])
torch.Size([500, 1])
torch.Size([20, 500, 1])
torch.Size([500, 1])
torch.Size([500, 20, 1])
torch.Size([500, 20, 1])
torch.Size([500, 1])
torch.Size([500, 1])
torch.Size([500, 1])
torch.Size([500, 1])


---Begin Timer Report (root)---
Timer Name:          root (running)
Total Time (s):      7.65
Stamps Sum:          7.642
Self Time (Agg.):    0.0002102


Intervals
---------
start .............. 5.657e-05
EnsembleFlattenMLP . 1.723
EnsembleFlattenRank1  2.249
EnsembleGaussMLP ... 2.315
BatchEnsembleFlattenMLP  1.271
MimoEnsembleFlattenMLP  0.0829

---End Timer Report (root)---




## Test pytorch

In [None]:
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=float)
y = x.repeat_interleave(10, 0)  # repeat each element along dim 0 10 times
print(y)

x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=float)
y = x.repeat(10, 1)  # repeat dim 0 10 times, dim 1 1 time
print(y)

y = y.unsqueeze(0).reshape(2, 10, 3)
print(torch.min(y, dim=0, keepdim=True).values.shape)

x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=float)
A = 3
size = 4
print(x.repeat_interleave(size, 0))
print(x.reshape(-1, 1, A).repeat(1, size, 1).reshape(-1, A))
