In [230]:
import torch
from torch import nn

In [200]:
# Small 10 -> 10 fully connected layer used to inspect the layer type below.
layer = nn.Linear(in_features=10, out_features=10)

In [202]:
def permutation(n: int):
    """Return the axis order that moves axis 0 of an ``n``-dim tensor to the end.

    Suitable for ``Tensor.permute``: ``permutation(3) == [1, 2, 0]``.
    For ``n <= 1`` the result is ``[0]``.
    """
    return [*range(1, n), 0]

In [201]:
# Show the concrete class of the layer built above.
layer.__class__

torch.nn.modules.linear.Linear

In [313]:
class CrossStitchUnits(nn.Module):
    """Learnable per-channel linear mixing of two parallel feature streams.

    For every output channel ``c`` of the two (identically shaped) layers, the
    module learns two mixing rows and computes::

        res_1[c] = mat_1[c, 0] * x1[c] + mat_1[c, 1] * x2[c]
        res_2[c] = mat_2[c, 0] * x1[c] + mat_2[c, 1] * x2[c]

    The channel axis is the last axis for ``nn.Linear`` outputs and the axis
    directly before the spatial axes for ``nn.Conv1d`` / ``nn.Conv2d`` /
    ``nn.Conv3d`` outputs, so batched and unbatched conv outputs both work.

    Args:
        layer_1: Layer whose output is ``x1``. Only its type and weight
            shape/dtype/device are read; it is neither stored nor called.
        layer_2: Layer whose output is ``x2``; must have the same type and
            weight shape as ``layer_1``.
        init: Initial self-weight. Rows of ``mat_1`` start at
            ``[init, 1 - init]`` and rows of ``mat_2`` at ``[1 - init, init]``.

    Raises:
        TypeError: If the two layers have different types, or the type is not
            one of the supported layer classes.
        ValueError: If the two layers' weight shapes differ.
    """

    # (layer class, value stored in ``self.type``); the classes are disjoint.
    _LAYER_TYPES = (
        (nn.Linear, 'linear'),
        (nn.Conv1d, 'conv1D'),
        (nn.Conv2d, 'conv2D'),
        (nn.Conv3d, 'conv3D'),
    )
    # Number of spatial axes that follow the channel axis in each layer's output.
    _N_SPATIAL = {'linear': 0, 'conv1D': 1, 'conv2D': 2, 'conv3D': 3}

    def __init__(self, layer_1: nn.Module, layer_2: nn.Module, init: float = 0.9):
        super().__init__()
        # Real checks: the original `assert "<message>"` asserted a non-empty
        # string, which is always true, so nothing was ever validated.
        if type(layer_1) is not type(layer_2):
            raise TypeError(
                f"Layer types differ: {type(layer_1).__name__} vs {type(layer_2).__name__}"
            )
        layer_type = None
        for layer_cls, name in self._LAYER_TYPES:
            if isinstance(layer_1, layer_cls):
                layer_type = name
                break
        if layer_type is None:
            # Previously this surfaced later as an UnboundLocalError in forward().
            raise TypeError(f"Unsupported layer type: {type(layer_1).__name__}")
        if layer_1.weight.shape != layer_2.weight.shape:
            raise ValueError(
                f"Weight shapes differ: {tuple(layer_1.weight.shape)} "
                f"vs {tuple(layer_2.weight.shape)}"
            )

        # NOTE(review): an instance attribute named ``type`` shadows
        # ``nn.Module.type()``; kept because callers may read ``cross.type``.
        self.type = layer_type

        weight = layer_2.weight
        n_channels = weight.shape[0]  # out_features / out_channels
        # Match the layers' dtype/device; this also stops an integer ``init``
        # from producing an int tensor, which cannot be a trainable Parameter.
        row = torch.tensor([init, 1 - init], dtype=weight.dtype, device=weight.device)
        self.mat_1 = nn.Parameter(row.repeat(n_channels, 1))
        self.mat_2 = nn.Parameter(row.flip(0).repeat(n_channels, 1))

    def forward(self, x1, x2):
        """Mix ``x1`` and ``x2`` channel-wise.

        Args:
            x1: Output of ``layer_1``: ``(..., C)`` for linear layers, or
                ``([N,] C, *spatial)`` for conv layers.
            x2: Output of ``layer_2``; same shape as ``x1``.

        Returns:
            Tuple ``(res_1, res_2)``, each with the same shape as ``x1``.

        Raises:
            ValueError: If the inputs have too few axes, or their channel count
                does not match the mixing matrices.
        """
        n_spatial = self._N_SPATIAL[self.type]
        channel_axis = x1.ndim - 1 - n_spatial
        if channel_axis < 0:
            raise ValueError(
                f"Expected at least {n_spatial + 1} dims for a {self.type} "
                f"cross-stitch, got input of shape {tuple(x1.shape)}"
            )
        n_channels = self.mat_1.shape[0]
        if x1.shape[channel_axis] != n_channels:
            # Guard against silent broadcasting (e.g. a size-1 channel axis).
            raise ValueError(
                f"Expected {n_channels} channels on axis {channel_axis}, "
                f"got input of shape {tuple(x1.shape)}"
            )
        # Stack on a new trailing "stream" axis -> (..., C, *spatial, 2), then
        # move C right before it so each (C, 2) matrix broadcasts over the rest.
        data = torch.stack((x1, x2), dim=-1).movedim(channel_axis, -2)
        res_1 = self._mix(data, self.mat_1, channel_axis)
        res_2 = self._mix(data, self.mat_2, channel_axis)
        return res_1, res_2

    @staticmethod
    def _mix(data, mat, channel_axis):
        """Weighted sum over the stream axis, restoring channels to ``channel_axis``."""
        return (data * mat).sum(-1).movedim(-1, channel_axis)

In [314]:
# FIXME(review): stale cell — `cross`, `x1` and `x2` are only created by cells
# In[316]/In[317] further down, so this line raises NameError on Restart & Run All.
# It duplicates In[318]; consider removing it once the cell order is fixed.
res_1, res_2 = cross(x1, x2)

In [315]:
# Shape of the first branch's output (expected (batch, out_channels, length)).
# NOTE(review): `x1` is created in cell In[316] below — hidden-state dependency on a fresh kernel.
x1.shape

torch.Size([1, 3, 3])


In [316]:
# Build the two Conv1d branches here rather than relying on `a`/`b` from In[317]
# below, so this cell also works on Restart & Run All. In[317] re-creates them;
# CrossStitchUnits only reads their type and weight shape, so x1/x2 stay valid inputs.
torch.manual_seed(0)  # reproducible layer init and inputs
a = nn.Conv1d(2, 3, kernel_size=3)
b = nn.Conv1d(2, 3, kernel_size=3)
data = torch.rand((1, 2, 5))  # (batch, in_channels, length) input for Conv1d
x1 = a(data)
x2 = b(data)

In [317]:
# Two parallel Conv1d branches with identical shapes, joined by a cross-stitch
# unit that starts with an even 50/50 mix.
a = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=3)
b = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=3)
cross = CrossStitchUnits(layer_1=a, layer_2=b, init=0.5)

In [318]:
res_1, res_2 = cross(x1, x2)
# Channel-wise mixing must not change shapes: each result matches its input stream.
assert res_1.shape == x1.shape and res_2.shape == x2.shape