In [2]:
import transformers
import torch
import torch.nn as nn

In [4]:
# Reference layer: a stock nn.Linear taking 10 input features to 5 outputs.
layer = nn.Linear(10, 5)

In [6]:
# Inspect the weight layout: nn.Linear stores it as (out_features, in_features).
layer.weight.shape

torch.Size([5, 10])

In [30]:
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import math
from torch.nn import init, functional as F

class ParallelLinear(Module):
    """Linear layer whose weight matrix is stored as two column-wise halves.

    The layer computes ``y = x @ W.T + b`` where ``W`` is the column-wise
    concatenation ``[left_weight | right_weight]``. The feature (last)
    dimension of the input is split at ``in_features // 2``; each part is
    multiplied by its own weight block and the partial products are summed.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: if ``False`` the layer has no additive bias (``self.bias`` is
            ``None``).
        device: forwarded to ``torch.empty`` when allocating parameters.
        dtype: forwarded to ``torch.empty`` when allocating parameters.

    Attributes:
        left_weight: parameter of shape ``(out_features, in_features // 2)``.
        right_weight: parameter of shape
            ``(out_features, in_features - in_features // 2)``; it receives the
            extra column when ``in_features`` is odd.
        bias: parameter of shape ``(out_features,)`` or ``None``.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    left_weight: Tensor
    right_weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        left_features = in_features // 2
        right_features = in_features - left_features

        self.left_weight = Parameter(
            torch.empty((out_features, left_features), **factory_kwargs))
        self.right_weight = Parameter(
            torch.empty((out_features, right_features), **factory_kwargs))

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise all parameters from U(-1/sqrt(in_features), 1/sqrt(in_features)).

        The bound is derived from the full ``in_features`` rather than from
        the width of each half, so both halves and the bias share one bound
        and the concatenated weight has the distribution that
        ``kaiming_uniform_(a=sqrt(5))`` gives a single
        ``(out_features, in_features)`` matrix. For details, see
        https://github.com/pytorch/pytorch/issues/57109
        """
        bound = 1 / math.sqrt(self.in_features) if self.in_features > 0 else 0
        init.uniform_(self.left_weight, -bound, bound)
        init.uniform_(self.right_weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the split linear map to ``input``.

        Args:
            input: tensor whose last dimension has size ``in_features``; any
                number of leading (batch) dimensions is allowed.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of size ``out_features``.
        """
        split = self.in_features // 2
        # Slice the feature (last) axis so leading batch dimensions are kept.
        left_input = input[..., :split]
        right_input = input[..., split:]
        left_product = F.linear(left_input, self.left_weight)
        # F.linear accepts bias=None, so the bias-free configuration works too.
        right_product = F.linear(right_input, self.right_weight, self.bias)
        return left_product + right_product

    def extra_repr(self) -> str:
        """Return the constructor-style summary shown in ``repr(module)``."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

In [31]:
# Build the split-weight layer with the same sizes as `layer` (10 inputs -> 5 outputs).
parallelLinear = ParallelLinear(10, 5)

In [33]:
# 10-element float vector used as the input to both layers.
# torch.tensor is the documented factory for building a tensor from data;
# the legacy torch.Tensor(...) constructor is kept only for compatibility.
my_tensor = torch.tensor([1.0, 1.0, 1.2, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.1])

# Call the module itself rather than .forward(): Module.__call__ runs any
# registered hooks before and after dispatching to forward().
parallelLinear(my_tensor)

tensor([-0.5562,  0.8195,  0.3016, -1.4929,  1.5394], grad_fn=<AddBackward0>)

In [34]:
# Same input through the stock nn.Linear for comparison. Called via
# Module.__call__ (not .forward()) so registered hooks are honoured.
layer(my_tensor)

tensor([-0.0338,  0.4054, -0.5144, -0.6205,  0.2743], grad_fn=<AddBackward0>)

In [11]:
# When should the tensor's bias be added?

In [15]:
# Build a 3-element float vector and add a trailing axis, giving a (3, 1) column.
# torch.tensor is the documented factory for building a tensor from data.
sample_tensor = torch.tensor([0.0, 0.1, 0.2]).unsqueeze(dim=1)

In [26]:
# Display the column tensor produced by unsqueeze(dim=1).
sample_tensor

tensor([[0.0000],
        [0.1000],
        [0.2000]])