In [1]:
import math
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init
from lifelong_rl.torch import pytorch_util as ptu
from lifelong_rl.torch.modules import LayerNorm


def identity(x):
    """Return the argument unchanged (used in this file as a no-op output activation)."""
    return x


class BatchEnsembleConv2D(nn.Module):
    """2-D convolution shared by an ensemble, with per-member rank-1 scaling.

    One bias-free ``nn.Conv2d`` kernel is shared by all ``num_models`` members.
    Every member owns an input-channel scale (``alpha``), an output-channel
    scale (``gamma``) and, when ``bias=True``, its own bias vector.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_models, stride=1, padding=0, bias=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # An integer kernel size is stored as a square (k, k) tuple.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.num_models = num_models

        # Fast weights: one row per ensemble member.
        self.alpha = nn.Parameter(torch.empty(num_models, in_channels))
        self.gamma = nn.Parameter(torch.empty(num_models, out_channels))
        # Slow weights: a single kernel shared by every member.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride,
            padding,
            bias=False,
        )

        if not bias:
            self.register_parameter('bias', None)
        else:
            # one bias vector per ensemble member
            self.bias = nn.Parameter(torch.empty(num_models, out_channels))

        self.reset_parameters()

    def _per_example(self, member_param, total_rows, rows_per_member):
        """Expand a (num_models, C) parameter to (total_rows, C, 1, 1).

        Row ``m`` of ``member_param`` is repeated for the ``rows_per_member``
        consecutive batch rows that belong to member ``m``.
        """
        width = member_param.shape[1]
        expanded = member_param.tile(1, rows_per_member).view(total_rows, width)
        return expanded[:, :, None, None]

    def forward(self, x):
        """
        Apply the shared convolution with each member's own scaling and bias.

        x: Tensor of shape (B * M, C_in, H, W)
        Ensemble members should be stacked in the BATCH dimension, member-major.
        Dim 0 layout:
            ------ batch elem 0, model 0 ------
            ------ batch elem 1, model 0 ------
                      ...
            ------ batch elem 0, model n ------
            ------ batch elem 1, model n ------
                      ...
        """
        total_rows = x.shape[0]
        # number of examples each member processes (the "real" batch size)
        rows_per_member = total_rows // self.num_models

        in_scale = self._per_example(self.alpha, total_rows, rows_per_member)
        out_scale = self._per_example(self.gamma, total_rows, rows_per_member)

        result = self.conv(x * in_scale) * out_scale
        if self.bias is None:
            return result
        return result + self._per_example(self.bias, total_rows, rows_per_member)

    def reset_parameters(self):
        """Initialise the per-member bias (uniform) and fast weights (random signs)."""
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv.weight)
            limit = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -limit, limit)

        # random sign (+1 / -1) initialization for the fast weights; the
        # original comment attributes this scheme to the BatchEnsemble paper
        with torch.no_grad():
            for fast_weight in (self.alpha, self.gamma):
                fast_weight.bernoulli_(0.5).mul_(2).sub_(1)


class BatchEnsembleConv1D(nn.Module):
    """1-D convolution shared by an ensemble, with per-member rank-1 scaling.

    One bias-free ``nn.Conv1d`` kernel is shared by all ``num_models`` members.
    Every member owns an input-channel scale (``alpha``), an output-channel
    scale (``gamma``) and, when ``bias=True``, its own bias vector.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_models, stride=1, padding=0, bias=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.num_models = num_models

        # Fast weights: one row per ensemble member.
        self.alpha = nn.Parameter(torch.empty(num_models, in_channels))
        self.gamma = nn.Parameter(torch.empty(num_models, out_channels))
        # Slow weights: a single kernel shared by every member.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )

        if not bias:
            self.register_parameter('bias', None)
        else:
            # one bias vector per ensemble member
            self.bias = nn.Parameter(torch.empty(num_models, out_channels))

        self.reset_parameters()

    def _per_example(self, member_param, total_rows, rows_per_member):
        """Expand a (num_models, C) parameter to (total_rows, C, 1).

        Row ``m`` of ``member_param`` is repeated for the ``rows_per_member``
        consecutive batch rows that belong to member ``m``.
        """
        width = member_param.shape[1]
        expanded = member_param.tile(1, rows_per_member).view(total_rows, width)
        return expanded[:, :, None]

    def forward(self, x):
        """
        Apply the shared convolution with each member's own scaling and bias.

        x: Tensor of shape (B * M, C_in, W)
        Ensemble members should be stacked in the BATCH dimension, member-major.
        Dim 0 layout:
            ------ batch elem 0, model 0 ------
            ------ batch elem 1, model 0 ------
                      ...
            ------ batch elem 0, model n ------
            ------ batch elem 1, model n ------
                      ...
        """
        total_rows = x.shape[0]
        # number of examples each member processes (the "real" batch size)
        rows_per_member = total_rows // self.num_models

        in_scale = self._per_example(self.alpha, total_rows, rows_per_member)
        out_scale = self._per_example(self.gamma, total_rows, rows_per_member)

        result = self.conv(x * in_scale) * out_scale
        if self.bias is None:
            return result
        return result + self._per_example(self.bias, total_rows, rows_per_member)

    def reset_parameters(self):
        """Initialise the per-member bias (uniform) and fast weights (random signs)."""
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv.weight)
            limit = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -limit, limit)

        # random sign (+1 / -1) initialization for the fast weights; the
        # original comment attributes this scheme to the BatchEnsemble paper
        with torch.no_grad():
            for fast_weight in (self.alpha, self.gamma):
                fast_weight.bernoulli_(0.5).mul_(2).sub_(1)


class BatchEnsembleLinear(nn.Module):
    """Linear layer shared by an ensemble, with per-member rank-1 scaling.

    A single weight matrix ``W`` is shared by all ``ensemble_size`` members.
    Member ``m`` scales the input element-wise by ``r[m]`` and the output by
    ``s[m]``, and (when ``bias=True``) adds its own bias ``bias[m]``.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        ensemble_size: number of ensemble members M.
        bias: if False, no bias parameter is created and none is added.
    """

    def __init__(self, input_size, output_size, ensemble_size, bias=True):
        super().__init__()
        self.in_features = input_size
        self.out_features = output_size
        self.ensemble_size = ensemble_size

        self.W = nn.Parameter(torch.empty(output_size, input_size))  # shared, (out, in)
        self.r = nn.Parameter(torch.empty(ensemble_size, input_size))  # per member, (M, in)
        self.s = nn.Parameter(torch.empty(ensemble_size, output_size))  # per member, (M, out)

        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, output_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def forward(self, X):
        """
        Expects input in shape (B*M, C_in), example-major; dim 0 layout:
            ------ x0, model 0 ------
            ------ x0, model 1 ------
                      ...
            ------ x1, model 0 ------
            ------ x1, model 1 ------
                      ...

        Returns:
            Tensor of shape (B*M, C_out) with the same row layout.

        Raises:
            ValueError: if the number of rows is not a multiple of
                ``ensemble_size`` (previously a single row could silently
                broadcast against empty fast weights and yield an empty
                output).
        """
        n_rows = X.shape[0]
        if n_rows % self.ensemble_size != 0:
            raise ValueError(
                "BatchEnsembleLinear expects a batch that is a multiple of "
                "ensemble_size=%d, got %d rows" % (self.ensemble_size, n_rows)
            )
        B = n_rows // self.ensemble_size
        # repeat(B, 1) stacks the M member rows B times, so row i uses member i % M
        R = self.r.repeat(B, 1)
        S = self.s.repeat(B, 1)
        # Eq. 5 from BatchEnsembles paper
        out = torch.mm(X * R, self.W.T) * S  # (B*M, C_out)
        # The bias is optional: skip it when the layer was built with bias=False.
        if self.bias is not None:
            out = out + self.bias.repeat(B, 1)
        return out

    def reset_parameters(self):
        """Initialise the shared weight, the optional bias and the fast weights."""
        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))

        # Another way to initialize the fast weights
        #nn.init.normal_(self.r, mean=1., std=0.1)
        #nn.init.normal_(self.s, mean=1., std=0.1)

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        # random sign (+1 / -1) initialization of the fast weights
        with torch.no_grad():
            self.r.bernoulli_(0.5).mul_(2).add_(-1)
            self.s.bernoulli_(0.5).mul_(2).add_(-1)


In [2]:
class BatchEnsembleFlattenMLP(nn.Module):
    """Q-network MLP built from ``BatchEnsembleLinear`` layers.

    The network concatenates its first two inputs (observation, action) on the
    last dimension and passes them through ``len(hidden_sizes)`` hidden layers
    with ReLU activations followed by a final ensemble linear layer.

    Args:
        ensemble_size: number of ensemble members M.
        hidden_sizes: widths of the hidden layers.
        input_size: width of the concatenated (observation, action) input.
        output_size: width of the network output.
        init_w: half-range of the uniform init applied to the last layer's
            shared weight and bias when ``final_init_scale`` is None.
        hidden_init, w_scale, b_init_value: accepted but not used by the code
            in this class.
        layer_norm: optional callable applied after every hidden activation.
        batch_norm: must be falsy; a truthy value raises NotImplementedError.
        final_init_scale: when not None, the last layer keeps its default
            initialisation.
        norm_input: if truthy, observations are normalised with
            ``obs_norm_mean`` / ``obs_norm_std`` before the forward pass.
        obs_norm_mean, obs_norm_std: numpy arrays used for the normalisation
            (converted through ``ptu.from_numpy``; 1e-6 is added to the std).
    """

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            init_w=3e-3,
            hidden_init=ptu.fanin_init,
            w_scale=1,
            b_init_value=0.1,
            layer_norm=None,
            batch_norm=False,
            final_init_scale=None,
            norm_input=False,
            obs_norm_mean=None,
            obs_norm_std=None,
    ):
        super().__init__()

        self.ensemble_size = ensemble_size
        self.input_size = input_size
        self.output_size = output_size

        self.sampler = np.random.default_rng()

        self.hidden_activation = F.relu
        self.output_activation = identity

        self.layer_norm = layer_norm

        self.norm_input = norm_input
        if self.norm_input:
            # The small epsilon keeps the division in forward() away from zero.
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)

        # Plain list used only for ordered iteration in forward(); the layers
        # are registered as sub-modules through the fc<i> attributes below.
        self.fcs = []

        if batch_norm:
            raise NotImplementedError

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = BatchEnsembleLinear(
                ensemble_size=ensemble_size,
                input_size=in_size,
                output_size=next_size,
            )
            setattr(self, 'fc%d' % i, fc)
            self.fcs.append(fc)
            in_size = next_size

        self.last_fc = BatchEnsembleLinear(
            ensemble_size=ensemble_size,
            input_size=in_size,
            output_size=output_size,
        )
        if final_init_scale is None:
            self.last_fc.W.data.uniform_(-init_w, init_w)
            self.last_fc.bias.data.uniform_(-init_w, init_w)

    def forward(self, *inputs, **kwargs):
        """Calculate the forward pass of Q(s, a).

        Args:
            inputs: (observation, action) tensors of shape N x obs_dim and
                N x act_dim. Row i is evaluated by ensemble member
                i % ensemble_size (the layout used by BatchEnsembleLinear).

        Returns:
            Q(s, a) of shape N x output_size, where ensemble members' outputs
            are stacked along dim 0 as [q_m0, q_m1, ..., q_mN, q_m0, q_m1, ...]^T.
        """
        inputs = [inputs[0], inputs[1]]
        # input normalization (observations only)
        if self.norm_input:
            inputs[0] = (inputs[0] - self.obs_norm_mean) / self.obs_norm_std

        flat_inputs = torch.cat(inputs, dim=-1)
        dim = len(flat_inputs.shape)

        h = flat_inputs

        # standard feedforward network
        for fc in self.fcs:
            h = fc(h)
            h = self.hidden_activation(h)
            if getattr(self, 'layer_norm', None) is not None:
                h = self.layer_norm(h)
        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)

        # if original dim was 1D, squeeze the extra created layer
        if dim == 1:
            output = output.squeeze(1)
        return output

    def sample(self, *inputs):
        """Evaluate every ensemble member on every input and aggregate.

        Each (observation, action) row is repeated ``ensemble_size`` times in
        consecutive rows, so row j of the forward pass holds
        (example j // M, member j % M).

        Returns:
            Tuple (min, mean, std) over the ensemble members, each of shape
            (B, output_size).
        """
        inputs = [
            inputs[0].repeat_interleave(self.ensemble_size, 0),
            inputs[1].repeat_interleave(self.ensemble_size, 0),
        ]
        preds = self.forward(*inputs)
        B = preds.shape[0] // self.ensemble_size
        # (B * M, out) -> (B, M, out). The member index is the fastest-varying
        # one in dim 0, so the ensemble axis must be dim 1; viewing as
        # (M, B, out) would mix predictions of different examples.
        preds = preds.view(B, self.ensemble_size, -1)
        # Return min, mean and std of the ensemble
        return preds.min(1)[0], preds.mean(1), preds.std(1)

    def fit_input_stats(self, data, mask=None):
        """Not implemented for this network."""
        raise NotImplementedError



# Build a small BatchEnsemble Q-network and print its module tree.
A = BatchEnsembleFlattenMLP(
    ensemble_size=10,
    hidden_sizes=[64, 64],
    input_size=4,
    output_size=1,
    norm_input=True,
    obs_norm_mean=np.array([0, 0]),
    obs_norm_std=np.array([1, 1]),
)
print(A)

BatchEnsembleFlattenMLP(
  (fc0): BatchEnsembleLinear()
  (fc1): BatchEnsembleLinear()
  (last_fc): BatchEnsembleLinear()
)


In [17]:
# https://github.com/noowad93/MIMO-pytorch
# (The link above may contain some errors)
# https://colab.research.google.com/drive/16i8Wd8hYYgZfVLs6MPVFZ2faccpnYkyk?usp=sharing#scrollTo=i_2GT74Ecu0c
# https://colab.research.google.com/drive/1JIgyVeEmlOH-j0oGmeDLV8iE5UYdreYm#scrollTo=Ac1oS_S8kQOB

class MimoEnsembleFlattenMLP(nn.Module):
    """Multi-input multi-output (MIMO) ensemble Q-network.

    ``ensemble_size`` consecutive input rows are concatenated into one wide
    row, passed through a single MLP with tanh activations, and the wide
    output is split back into one prediction per original row.

    ``init_w``, ``w_scale``, ``b_init_value``, ``layer_norm``, ``batch_norm``
    and ``final_init_scale`` are accepted but not used by the code in this
    class.
    """

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            init_w=3e-3,
            w_scale=1,
            b_init_value=0.1,
            layer_norm=None,
            batch_norm=False,
            final_init_scale=None,
            norm_input=False,
            obs_norm_mean=None,
            obs_norm_std=None,
            width_multiplier=1,
    ):
        super().__init__()
        self.ensemble_num = ensemble_size
        self.hidden_activation = torch.tanh

        # Every hidden width is scaled by width_multiplier.
        scaled_sizes = [layer_size * width_multiplier for layer_size in hidden_sizes]

        self.input_layer = nn.Linear(input_size * ensemble_size, scaled_sizes[0])
        self.backbone_model = BackboneModel(scaled_sizes, hidden_activation=self.hidden_activation)

        self.norm_input = norm_input
        if self.norm_input:
            self.obs_norm_mean, self.obs_norm_std = ptu.from_numpy(obs_norm_mean), ptu.from_numpy(obs_norm_std + 1e-6)

        self.output_layer = nn.Linear(scaled_sizes[-1], output_size * ensemble_size)
        self.output_activation = identity

        # initialize weights: Xavier-uniform weights, zero biases
        init.xavier_uniform_(self.input_layer.weight)
        self.input_layer.bias.data.fill_(0)
        init.xavier_uniform_(self.output_layer.weight)
        self.output_layer.bias.data.fill_(0)

    def forward(self, *inputs, **kwargs):
        """Run the network on (observation, action).

        When called with ``sample=True`` every row is first repeated
        ``ensemble_num`` times so that each input occupies all slots of one
        wide row.
        """
        obs, act = inputs[0], inputs[1]
        if self.norm_input:
            obs = (obs - self.obs_norm_mean) / self.obs_norm_std

        joined = torch.cat([obs, act], dim=-1)

        sample_flag = kwargs.get("sample", None)
        if sample_flag == True:
            joined = joined.repeat_interleave(self.ensemble_num, 0)

        return self.forward_(joined)

    def forward_(self, input):
        """Forward pass on an already concatenated (B, E) input; returns (B, -1)."""
        n_dims = len(input.shape)
        # reshape (B, E) into (B // M, E * M): M consecutive rows share one wide row
        B, _n_features, *_extra = input.shape
        M = self.ensemble_num
        hidden = input.view(B // M, -1)

        # standard feedforward network
        hidden = self.hidden_activation(self.input_layer(hidden))
        hidden = self.backbone_model(hidden)
        output = self.output_activation(self.output_layer(hidden))

        # if original dim was 1D, squeeze the extra created layer
        if n_dims == 1:
            output = output.squeeze(1)
        return output.view(B, -1)

    def sample(self, *inputs):
        """Return ``torch.min`` (values, indices) over the ensemble axis per input."""
        preds = self.forward(*inputs, sample=True)
        B = preds.shape[0] // self.ensemble_num
        per_member = preds.view(B, self.ensemble_num, -1)
        return torch.min(per_member, dim=1)

class BackboneModel(nn.Module):
    """Stack of linear layers, each followed by the same activation.

    ``hidden_dim`` lists consecutive layer widths; layer ``l<i>`` maps
    ``hidden_dim[i]`` to ``hidden_dim[i + 1]``.
    """

    def __init__(self, hidden_dim, hidden_activation=F.relu):
        super().__init__()
        self.hidden_activation = hidden_activation
        for idx in range(len(hidden_dim) - 1):
            linear = nn.Linear(hidden_dim[idx], hidden_dim[idx + 1])
            self.add_module(f"l{idx}", linear)
        # Re-initialise every linear layer registered above.
        self.apply(self.init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every child layer in registration order, activation after each."""
        for layer in self.children():
            x = self.hidden_activation(layer(x))
        return x

    def init_weights(self, m):
        """Xavier-uniform weights and a constant 0.01 bias for ``nn.Linear`` modules."""
        if type(m) is not nn.Linear:
            return
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


# Build a small MIMO Q-network (hidden widths doubled) and print its module tree.
A = MimoEnsembleFlattenMLP(
    ensemble_size=10,
    hidden_sizes=[64, 64],
    input_size=4,
    output_size=1,
    norm_input=True,
    obs_norm_mean=np.array([0, 0]),
    obs_norm_std=np.array([1, 1]),
    width_multiplier=2,
)

# from flopth import flopth
# lops, params = flopth(A, in_size=((10,),))
# print(lops, params)
print(A)

MimoEnsembleFlattenMLP(
  (input_layer): Linear(in_features=40, out_features=128, bias=True)
  (backbone_model): BackboneModel(
    (l0): Linear(in_features=128, out_features=128, bias=True)
  )
  (output_layer): Linear(in_features=128, out_features=10, bias=True)
)


In [18]:
from lifelong_rl.models.networks import ParallelizedEnsembleFlattenMLP
from lifelong_rl.policies.base.base import MakeDeterministic
from lifelong_rl.policies.models.tanh_gaussian_policy import TanhGaussianPolicy
from lifelong_rl.trainers.q_learning.sac import SACTrainer
import lifelong_rl.util.pythonplusplus as ppp
from torch.nn import functional as F

num_qs = 10  # variant['trainer_kwargs']['num_qs']
M = 512  # variant['policy_kwargs']['layer_size']
num_q_layers = 4  # variant['policy_kwargs']['num_q_layers']
num_p_layers = 4  # variant['policy_kwargs']['num_p_layers']
obs_dim = 7
action_dim = 4

# normalization
norm_input = None  # variant['norm_input']
# variant['normalization_info']['obs_mean'], variant['normalization_info']['obs_std']
obs_norm_mean, obs_norm_std = None, None


def build_q_pair(network_cls):
    """Create an (online, target) pair of Q-ensembles of the given class.

    All three ensemble variants below are built with the same settings; only
    the network class differs. A fresh ``hidden_sizes`` list is created on
    every call.
    """
    return ppp.group_init(
        2,
        network_cls,
        ensemble_size=num_qs,
        hidden_sizes=[M] * num_q_layers,
        input_size=obs_dim + action_dim,
        output_size=1,
        layer_norm=None,
        norm_input=norm_input,
        obs_norm_mean=obs_norm_mean,
        obs_norm_std=obs_norm_std,
    )


# Reference implementation: fully parallelized ensemble.
qfs, target_qfs = build_q_pair(ParallelizedEnsembleFlattenMLP)

policy = TanhGaussianPolicy(
    obs_dim=obs_dim,
    action_dim=action_dim,
    hidden_sizes=[M] * num_p_layers,
    layer_norm=None,
    norm_input=norm_input,
    obs_norm_mean=obs_norm_mean,
    obs_norm_std=obs_norm_std,
)

# BatchEnsemble variant.
qfs1, target_qfs1 = build_q_pair(BatchEnsembleFlattenMLP)

# MIMO variant.
qfs2, target_qfs2 = build_q_pair(MimoEnsembleFlattenMLP)





In [19]:
# Test networks: print the output shapes of forward() and sample()

obs = torch.randn(20, obs_dim)
acts = torch.randn(20, action_dim)

# Parallelized ensemble: sample() result is used directly.
print(qfs(obs, acts).shape)
print(target_qfs.sample(obs, acts).shape)

# BatchEnsemble and MIMO ensembles: sample() returns a tuple; inspect element 0.
for online_q, target_q in ((qfs1, target_qfs1), (qfs2, target_qfs2)):
    print(online_q(obs, acts).shape)
    print(target_q.sample(obs, acts)[0].shape)


# def get_noised_obs(obs, actions, eps):
#         M, N, A = obs.shape[0], obs.shape[1], actions.shape[1]
#         size = 5
#         obs_std = 1
#         delta_s = 2 * eps * obs_std * (torch.rand(size, N, device=ptu.device) - 0.5)
#         tmp_obs = obs.reshape(-1, 1, N).repeat(1, size, 1).reshape(-1, N)
#         delta_s = delta_s.reshape(1, size, N).repeat(M, 1, 1).reshape(-1, N)
#         noised_obs = tmp_obs + delta_s
#         return M, A, size, noised_obs, delta_s

# get_noised_obs(obs, acts, 0.1)


torch.Size([10, 20, 1])
torch.Size([20, 1])
torch.Size([20, 1])
torch.Size([20, 1])
torch.Size([20, 1])
torch.Size([20, 1])


In [None]:
import torch

A = 3
size = 4
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=float)

# repeat_interleave: each row is repeated 10 times in a row (x0 ... x0, x1 ... x1)
y = x.repeat_interleave(10, 0)
print(y)

# repeat: the whole tensor is stacked 10 times along dim 0 (x0, x1, x0, x1, ...)
y = x.repeat(10, 1)
print(y)

y = y.reshape(2, 10, 3)
print(y.min(dim=0, keepdim=True).values.shape)

# Both expressions below repeat every row `size` times in consecutive rows.
print(x.repeat_interleave(size, 0))
print(x.reshape(-1, 1, A).repeat(1, size, 1).reshape(-1, A))
