In [None]:
import torch as th
from torch import nn
from torch.distributions import MultivariateNormal
from torch.nn import functional as F
from random import choice, shuffle, randint

In [None]:
from typing import Tuple


class SCM(nn.Module):
    """Random MLP that generates synthetic tabular samples from Gaussian noise.

    A stack of bias-free square linear layers with randomly chosen activations
    is driven by multivariate-normal noise.  A random boolean mask forces a
    fraction of the neurons of every layer to zero, and a random subset of the
    surviving neurons (across all layers) is read out as the feature columns.

    Args:
        drop_neuron_proba: probability in ``[0, 1]`` that a neuron is masked
            out (its activation is forced to zero).
        n_features: maximum number of neurons read out as features.  If fewer
            neurons survive the mask, all survivors are used and the output has
            fewer than ``n_features`` columns.
        layer_bounds: inclusive ``(min, max)`` range the number of layers is
            drawn from.
        node_bounds: inclusive ``(min, max)`` range the width shared by every
            layer is drawn from.

    Raises:
        ValueError: if ``drop_neuron_proba`` is outside ``[0, 1]`` or
            ``n_features`` is negative.
    """

    def __init__(self, drop_neuron_proba: float, n_features: int, layer_bounds: Tuple[int, int] = (4, 16), node_bounds: Tuple[int, int] = (16, 64)) -> None:
        super().__init__()

        if not 0.0 <= drop_neuron_proba <= 1.0:
            raise ValueError(
                f"drop_neuron_proba must be in [0, 1], got {drop_neuron_proba}"
            )
        if n_features < 0:
            raise ValueError(f"n_features must be non-negative, got {n_features}")

        n_layer = randint(layer_bounds[0], layer_bounds[1])
        hidden_size = randint(node_bounds[0], node_bounds[1])

        self.__mlp = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size, bias=False)
            for _ in range(n_layer)
        )

        # True = neuron kept, False = neuron zeroed.  This is a plain tensor
        # attribute (not a registered buffer).
        self.__mask = th.ge(th.rand(n_layer, hidden_size), drop_neuron_proba)

        act_fn = [F.relu, F.tanh, F.leaky_relu, F.elu]

        self.__act = [
            choice(act_fn) for _ in range(n_layer)
        ]

        self.__hidden_size = hidden_size

        # The noise covariance is A^T A for a random square A.  Handing that
        # product to MultivariateNormal makes it run a Cholesky factorisation,
        # which can fail the positive-definite check in float32 when A is
        # ill-conditioned.  With A = QR we have A^T A = R^T R, so R^T (rows
        # sign-flipped to make the diagonal positive) is the Cholesky factor
        # itself and can be passed directly as scale_tril.
        factor = th.randn(hidden_size, hidden_size)
        upper = th.linalg.qr(factor).R
        diag_sign = th.sign(upper.diagonal())
        diag_sign[diag_sign == 0] = 1.0
        scale_tril = (upper * diag_sign[:, None]).transpose(0, 1).tril()

        loc = th.randn(hidden_size)

        self.__distribution = MultivariateNormal(loc, scale_tril=scale_tril)

        # Coordinates of every kept neuron, one (layer, neuron) pair per row in
        # row-major order.
        active = self.__mask.nonzero()
        if active.size(0) > n_features:
            # Pick a uniform random subset instead of the first n_features
            # rows (which would only ever expose the earliest layers); sorting
            # keeps the output columns in row-major (layer, neuron) order.
            keep = th.randperm(active.size(0))[:n_features].sort().values
            active = active[keep]

        # Pair of 1-D index tensors: (layer indices, neuron indices).
        self.__features_per_layer = tuple(active.unbind(dim=1))

    @th.no_grad()
    def forward(self, batch_size: int) -> th.Tensor:
        """Sample ``batch_size`` rows of synthetic features.

        Args:
            batch_size: number of rows to generate.

        Returns:
            Tensor of shape ``(batch_size, k)`` where ``k`` is the number of
            selected neurons (at most ``n_features``).
        """
        out = self.__distribution.sample((batch_size,))
        outs = []

        for i, (layer, act) in enumerate(zip(self.__mlp, self.__act)):
            # Fresh noise is injected at every layer before the activation.
            out = layer(out) + self.__distribution.sample((batch_size,))
            out = act(out)
            # Zero the dropped neurons; the masked output also feeds the next layer.
            out = out * self.__mask[None, i, :]
            outs.append(out)

        # stack layers output
        # (batch, layer, hidden_features)
        outs_stacked = th.stack(outs, dim=1)

        # select features: paired (layer, neuron) indices -> (batch, k).
        # Written without star-unpacking inside the subscript so that it also
        # parses on Python versions older than 3.11.
        layer_idx, neuron_idx = self.__features_per_layer
        out = outs_stacked[:, layer_idx, neuron_idx]

        # TODO pick y

        return out

In [None]:
# Each neuron is dropped with probability 0.2; request up to 128 feature columns.
scm = SCM(0.2, 128)

In [None]:
# Draw a batch of 3 synthetic samples.
o = scm(3)

In [None]:
# Shape is (batch, number of selected features).
o.size()

In [None]:
o

In [None]:
# Toy 2x3 boolean mask used to prototype the feature-selection indexing.
t = th.rand(2, 3) > 0.5

In [None]:
t

In [None]:
# Coordinates of the True entries, one [row, col] pair per row.
t.nonzero()

In [None]:
# Split the coordinates column-wise into a (rows, cols) pair of (k, 1) index tensors.
t.nonzero().split(1, dim=1)

In [None]:
# Same coordinates as a plain Python list of [row, col] pairs.
t_2 = t.nonzero().tolist()

In [None]:
# `t` is 2-D, so every entry of `t_2` is a two-element [row, col] pair; unpacking
# three values here would raise ValueError as soon as `t` has a True entry.
features = [(i, j) for i, j in t_2]

In [None]:
# Randomise the order in place so that a later prefix slice is a random subset.
shuffle(features)

In [None]:
# Keep at most the first two pairs.
features = features[:2]

In [None]:
# Group the second coordinate of each pair under its first coordinate
# (presumably layer index -> list of neuron indices, going by the variable name).
features_per_layer = {}
for i, j in features:
    features_per_layer.setdefault(i, []).append(j)

In [None]:
features_per_layer