In [119]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
       
        super(Model, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers


        self.layers = nn.ModuleList() 
        self.layers.append(nn.Linear(input_size*2, hidden_size*2))
        self.layers.append(nn.ReLU())

        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_size*2, hidden_size*2))
            self.layers.append(nn.ReLU())

        self.layers.append(nn.Linear(hidden_size*2, output_size))
        self.layers.append(nn.ReLU())

        self.initialize_weights()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    def initialize_weights(self):
        def initial_weights(dim):
            w = []
            total_dim = 0
            for i in range(0, len(dim) - 1):
                if i < len(dim) - 2:
                    temp = torch.randn(dim[i + 1], dim[i]) / torch.sqrt(dim[i + 1])
                    temp = torch.kron(torch.eye(2, dtype=torch.int), temp)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i] * 4
                else:
                    temp = torch.randn(dim[i + 1], dim[i]) / torch.sqrt(dim[i])
                    temp = torch.kron(torch.tensor([[1, -1]], dtype=torch.float), temp)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i] * 2

            return w, total_dim
        
        def INI(dim):  
            #### initialization
            #### dim consists of (d1, d2,...), where dl = 1 (placeholder, deprecated)
            w = []
            total_dim = 0
            for i in range(0, len(dim) - 1):
                if i < len(dim) - 2:
                    temp = np.random.randn(dim[i + 1], dim[i]) / np.sqrt(dim[i + 1])
                    temp = np.kron(np.eye(2, dtype=int), temp)
                    temp = torch.from_numpy(temp).to(torch.float32)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i] *4
                else:
                    temp = np.random.randn(dim[i + 1], dim[i]) / np.sqrt(dim[i])
                    temp = np.kron([[1, -1]], temp)
                    temp = torch.from_numpy(temp).to(torch.float32)
                    w.append(temp)
                    total_dim += dim[i + 1] * dim[i]*2

            return w, total_dim




        input_size = self.input_size
        hidden_sizes = [self.hidden_size for layer in range(self.num_layers)]
        output_size = self.output_size

        dim_tensor = torch.tensor([input_size] + hidden_sizes + [output_size], dtype=torch.int)  # 将列表转换为Tensor

        print(dim_tensor)
       # w = initial_weights(dim_tensor)
        w,total_dim = INI(dim_tensor)
    
        idx = 0
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                layer.weight.data = w[idx]
                idx += 1


In [120]:
input_size = 5
hidden_size = 10 
output_size = 1
num_layers = 5

model = Model(input_size, hidden_size, output_size, num_layers)
print(model) 

tensor([ 5, 10, 10, 10, 10, 10,  1], dtype=torch.int32)
Model(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=20, bias=True)
    (3): ReLU()
    (4): Linear(in_features=20, out_features=20, bias=True)
    (5): ReLU()
    (6): Linear(in_features=20, out_features=20, bias=True)
    (7): ReLU()
    (8): Linear(in_features=20, out_features=20, bias=True)
    (9): ReLU()
    (10): Linear(in_features=20, out_features=1, bias=True)
    (11): ReLU()
  )
)


In [121]:
params = list(model.parameters())

# 打印每个参数的形状
for param in params:
    print(param.shape)
    print(param)

torch.Size([20, 10])
Parameter containing:
tensor([[ 0.1189, -0.4775, -0.0099, -0.0684,  0.1807,  0.0000, -0.0000, -0.0000,
         -0.0000,  0.0000],
        [-0.1309,  0.2276, -0.3036,  0.4168, -0.0725, -0.0000,  0.0000, -0.0000,
          0.0000, -0.0000],
        [-0.1819,  0.3912,  0.2109,  0.3624, -0.1779, -0.0000,  0.0000,  0.0000,
          0.0000, -0.0000],
        [ 0.4101,  0.0438, -0.0158, -0.4862, -0.1244,  0.0000,  0.0000, -0.0000,
         -0.0000, -0.0000],
        [ 0.3138, -0.0529, -0.2332, -0.1647,  0.5417,  0.0000, -0.0000, -0.0000,
         -0.0000,  0.0000],
        [ 0.2220,  0.0849,  0.1084,  0.0495, -0.1738,  0.0000,  0.0000,  0.0000,
          0.0000, -0.0000],
        [ 0.6981, -0.5076,  0.0119,  0.3173, -0.3090,  0.0000, -0.0000,  0.0000,
          0.0000, -0.0000],
        [-0.1380, -0.4961, -0.8045, -0.0448,  0.1800, -0.0000, -0.0000, -0.0000,
         -0.0000,  0.0000],
        [ 0.2471,  0.7437,  0.2603, -0.1367,  0.0692,  0.0000,  0.0000,  0.0000,
    

In [124]:
a = torch.randn(1,10)
a.dtype

torch.float32

In [125]:
model(a)

tensor([[0.2780]], grad_fn=<ReluBackward0>)

torch.float64