In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
       
        super(Model, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers


        self.layers = nn.ModuleList() 
        self.layers.append(nn.Linear(input_size*2, hidden_size*2))
        self.layers.append(nn.ReLU())

        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_size*2, hidden_size*2))
            self.layers.append(nn.ReLU())

        self.layers.append(nn.Linear(hidden_size*2, output_size))
        self.layers.append(nn.ReLU())

        self.initialize_weights()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # x multiply by output_size
        x = x * self.output_size
        return x
    
    def initialize_weights(self):
        def initial_weights(dim):
            w = []
            total_dim = 0
            for i in range(0, len(dim) - 1):
                if i < len(dim) - 2:
                    temp = torch.randn(dim[i + 1], dim[i]) / torch.sqrt(dim[i + 1])
                    temp = torch.kron(torch.eye(2, dtype=torch.int), temp)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i] * 4
                else:
                    temp = torch.randn(dim[i + 1], dim[i]) / torch.sqrt(dim[i])
                    temp = torch.kron(torch.tensor([[1, -1]], dtype=torch.float), temp)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i] * 2

            return w, total_dim
        
        def INI(dim):  
            #### initialization
            #### dim consists of (d1, d2,...), where dl = 1 (placeholder, deprecated)
            w = []
            total_dim = 0
            for i in range(0, len(dim) - 1):
                if i < len(dim) - 2:
                    temp = np.random.randn(dim[i + 1], dim[i]) / np.sqrt(dim[i + 1])
                    temp = np.kron(np.eye(2, dtype=int), temp)
                    temp = torch.from_numpy(temp).to(torch.float32)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i] *4
                else:
                    temp = np.random.randn(dim[i + 1], dim[i]) / np.sqrt(dim[i])
                    temp = np.kron([[1, -1]], temp)
                    temp = torch.from_numpy(temp).to(torch.float32)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i]*2

            return w, total_dim




        input_size = self.input_size
        hidden_sizes = [self.hidden_size for layer in range(self.num_layers)]
        output_size = self.output_size

        dim_tensor = torch.tensor([input_size] + hidden_sizes + [output_size], dtype=torch.int)  # 将列表转换为Tensor

        print(dim_tensor)
       # w = initial_weights(dim_tensor)
        w,total_dim = INI(dim_tensor)
    
        idx = 0
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                layer.weight.data = w[idx]
                idx += 1


In [7]:
input_size = 5
hidden_size = 10 
output_size = 5
num_layers = 5

model = Model(input_size, hidden_size, output_size, num_layers)
print(model) 

tensor([ 5, 10, 10, 10, 10, 10,  5], dtype=torch.int32)
Model(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=20, bias=True)
    (3): ReLU()
    (4): Linear(in_features=20, out_features=20, bias=True)
    (5): ReLU()
    (6): Linear(in_features=20, out_features=20, bias=True)
    (7): ReLU()
    (8): Linear(in_features=20, out_features=20, bias=True)
    (9): ReLU()
    (10): Linear(in_features=20, out_features=5, bias=True)
    (11): ReLU()
  )
)


In [8]:
params = list(model.parameters())

# 打印每个参数的形状
for param in params:
    print(param.shape)
    print(param)

torch.Size([20, 10])
Parameter containing:
tensor([[-0.2481,  0.3775, -0.0269, -0.1910, -0.3675, -0.0000,  0.0000, -0.0000,
         -0.0000, -0.0000],
        [ 0.3673, -0.0804, -0.2812,  0.1010,  0.2147,  0.0000, -0.0000, -0.0000,
          0.0000,  0.0000],
        [-0.1829,  0.3992,  0.0248, -0.0607,  0.2709, -0.0000,  0.0000,  0.0000,
         -0.0000,  0.0000],
        [-0.3759, -0.2806,  0.2293,  0.2149,  0.0905, -0.0000, -0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0430,  0.0586,  0.0880, -0.3302,  0.0563,  0.0000,  0.0000,  0.0000,
         -0.0000,  0.0000],
        [-0.4990,  0.2182, -0.5370, -0.5060,  0.5337, -0.0000,  0.0000, -0.0000,
         -0.0000,  0.0000],
        [ 1.0201,  0.5120,  0.5113, -0.0718,  0.1633,  0.0000,  0.0000,  0.0000,
         -0.0000,  0.0000],
        [-0.3178,  0.0042,  0.3415,  0.4703, -0.0895, -0.0000,  0.0000,  0.0000,
          0.0000, -0.0000],
        [ 0.0783,  0.2729,  0.4997,  0.1500,  0.5315,  0.0000,  0.0000,  0.0000,
    

In [9]:
a = torch.randn(1,10)
a.dtype

torch.float32

In [10]:
model(a)

tensor([[1.0443, 0.0000, 0.7857, 0.0000, 0.0000]], grad_fn=<MulBackward0>)