In [82]:
import torch as th
from torch import nn
from torch.nn import functional as F
from random import choice, shuffle

In [94]:
class SCM(nn.Module):
    def __init__(self, x_dim: int, hidden_size: int, n_layer: int, drop_neuron_proba: float, n_features: int) -> None:
        super().__init__()
        
        self.__mlp = nn.ModuleList(
            nn.Linear(x_dim if i == 0 else hidden_size, hidden_size)
            for i in range(n_layer)
        )
        
        self.__mask = th.ge(th.rand(n_layer, hidden_size), drop_neuron_proba)
        
        act_fn = [F.relu, F.tanh, F.leaky_relu, F.elu]
        
        self.__act = [
            choice(act_fn) for _ in range(n_layer)
        ]
        
        self.__features_per_layer = self.__mask.nonzero()[:n_features].split(1, dim=1)
    
    def forward(self, x: th.Tensor) -> th.Tensor:
        out = x
        outs = []
        
        for i, (layer, act) in enumerate(zip(self.__mlp, self.__act)):
            out = layer(out)
            out = act(out)
            # TODO add epsilon ?
            out = out * self.__mask[None, i, :]
            outs.append(out)
        
        # stack layers output
        # (batch, layer, hidden_features)
        outs_stacked = th.stack(outs, dim=1)
        
        # select features
        out = outs_stacked[:, *self.__features_per_layer].squeeze(-1)
        
        # TODO pick y
        
        return out

In [95]:
scm = SCM(8, 16, 3, 0.5, 4)

In [96]:
x = th.randn(5, 8)

In [97]:
o = scm(x)

In [98]:
o.size()

torch.Size([5, 4])

In [99]:
o

tensor([[-5.8394e-01, -4.2783e-01, -1.9807e-03,  1.4071e-01],
        [ 1.7129e-01, -3.0003e-01, -2.3668e-03, -5.7277e-04],
        [-4.6973e-01, -1.3460e-01, -2.1900e-03,  4.6558e-02],
        [ 3.1427e-01, -6.4476e-01, -2.2367e-03, -8.2763e-05],
        [-5.2258e-01, -5.9966e-01, -1.9117e-03,  1.6734e-01]],
       grad_fn=<SqueezeBackward1>)

In [53]:
t = th.rand(2, 3) > 0.5

In [54]:
t

tensor([[False,  True,  True],
        [False,  True,  True]])

In [56]:
t.nonzero()

tensor([[0, 1],
        [0, 2],
        [1, 1],
        [1, 2]])

In [55]:
t.nonzero().split(1, dim=1)

(tensor([[0],
         [0],
         [1],
         [1]]),
 tensor([[1],
         [2],
         [1],
         [2]]))

In [35]:
t_2 = t.nonzero().tolist()

In [21]:
features = [ (i, j) for _, i, j in t_2]

In [38]:
shuffle(features)

In [42]:
features = features[:2]

In [43]:
features_per_layer = {}
for i, j in features:
    features_per_layer.setdefault(i, [])
    features_per_layer[i].append(j)

In [44]:
features_per_layer

{1: [2, 1]}