In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn.parameter import Parameter, UninitializedParameter
import math

In [3]:
class ChannelWise(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

NameError: name 'Tensor' is not defined

In [7]:
class ChannelWise(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = False, device=None, dtype=None,
                 input_split_size: int = 1):
        super(ChannelWise, self).__init__(in_features, out_features, bias, device=None, dtype=None)
        self.factory_kwargs = {'device': device, 'dtype': dtype}

        self.in_features = in_features
        self.out_features = out_features
        self.input_split_size = input_split_size

    def forward(self, input):
        input_tensor = torch.chunk(input, self.input_split_size, dim=1)

        length = len(input_tensor)
        split_input_size = math.ceil(self.in_features / length)
        split_output_size = math.ceil(self.out_features / length)

        output_tensor = []
        for i in range(length):
            weight = Parameter(torch.empty((split_output_size, split_input_size), **self.factory_kwargs))
            bias = Parameter(torch.empty(split_output_size, **self.factory_kwargs))
            output_tensor.append(F.linear(input_tensor[i], weight, bias))
        return torch.cat(output_tensor, dim=1)

In [20]:
import torch
import torch.nn as nn

class ChannelWiseLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        # 继承父类的初始化函数
        super(ChannelWiseLinear, self).__init__()
        
        # 输入特征数
        self.in_features = in_features
        # 输出特征数
        self.out_features = out_features
        # 是否使用偏置项
        self.bias = bias
        
        # 创建可学习的权重参数
        self.weights = nn.Parameter(torch.Tensor(out_features, in_features))
        # 创建可学习的偏置项参数（如果使用）
        if bias:
            self.biases = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('biases', None)
        
        # 初始化参数
        self.reset_parameters()
        
    def reset_parameters(self):
        # 使用 Kaiming 初始化权重参数
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        # 如果使用偏置项，则使用均匀分布初始化偏置项参数
        if self.biases is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.biases, -bound, bound)

    def forward(self, input):
        # 创建与输入大小相同的零张量
        output = torch.zeros_like(input)
        # 循环遍历每个输出特征
        for i in range(self.out_features):
            # 使用线性操作处理该特征
            output[:, i] = torch.nn.functional.linear(input, self.weights[i, :], self.biases[i])
        return output


<class 'torch.nn.modules.linear.Linear'>


In [35]:
x = torch.randn(10, 3)

In [36]:
x

tensor([[ 0.4693, -1.2657, -0.5899],
        [-0.0104, -0.6261, -2.3330],
        [ 0.0194, -0.2384,  0.5886],
        [-1.1524, -0.0638,  0.6801],
        [-1.2305,  0.6501,  1.8660],
        [-0.5634, -0.3469,  0.7891],
        [-0.1487, -0.1093, -0.6186],
        [ 0.4206, -1.8290,  0.5863],
        [-0.4153,  0.2154, -1.4775],
        [ 0.8700, -0.5762, -0.3122]])

In [43]:
torch.chunk(x, 4, dim = 1)

(tensor([[ 0.4693],
         [-0.0104],
         [ 0.0194],
         [-1.1524],
         [-1.2305],
         [-0.5634],
         [-0.1487],
         [ 0.4206],
         [-0.4153],
         [ 0.8700]]),
 tensor([[-1.2657],
         [-0.6261],
         [-0.2384],
         [-0.0638],
         [ 0.6501],
         [-0.3469],
         [-0.1093],
         [-1.8290],
         [ 0.2154],
         [-0.5762]]),
 tensor([[-0.5899],
         [-2.3330],
         [ 0.5886],
         [ 0.6801],
         [ 1.8660],
         [ 0.7891],
         [-0.6186],
         [ 0.5863],
         [-1.4775],
         [-0.3122]]))

In [53]:
a = []
for i in range(len(torch.chunk(x, 4, dim = 1))):
    a.append(torch.chunk(x, 4, dim = 1))

In [54]:
math.ceil(10/3)

4

In [55]:
a

[]