In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import wandb

In [2]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU; show the choice.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [3]:
# Load the pre-split train/test tensors and move each one onto the selected device.
X_train, Y_train, X_test, Y_test = (
    tensor.to(device)
    for tensor in torch.load('../data/dataset_2_40D.pt', weights_only=True)
)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

(torch.Size([8000, 41]),
 torch.Size([8000]),
 torch.Size([2000, 41]),
 torch.Size([2000]))

In [4]:
# Batch the training set with a DataLoader (shuffle=True reshuffles on every pass).
l = 64  # mini-batch size
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# Print the shapes of the first batch as a sanity check.
x, y = next(iter(train_loader))
print(x.shape, y.shape)

torch.Size([64, 41]) torch.Size([64])


In [5]:
# D: dimension used for the network input (the model takes D + 1 inputs);
# m: number of hidden neurons.
D, m = 40, 100
class PM_Euler(nn.Module):
    """Two-layer ReLU network with mean-field (1/m) output scaling.

    forward(x) = relu(x @ W) @ a / hidden_layer, where W has shape
    (input, hidden_layer) and a has shape (hidden_layer, output).

    Args:
        input: number of input features (columns of x).
        hidden_layer: hidden width; also the divisor applied to the output.
        output: number of output units.
        device: optional device for the parameters. ``None`` uses PyTorch's
            default device; the module can still be moved with ``.to(device)``.
    """

    def __init__(self, input, hidden_layer, output, device=None):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # He (Kaiming) normal initialisation. Storage is allocated with
        # torch.empty because kaiming_normal_ overwrites every entry anyway.
        # NOTE(review): kaiming_normal_ assumes PyTorch's (out, in) weight layout,
        # so with mode='fan_in' the fan it uses is W.size(1) == hidden_layer for W
        # and a.size(1) == output for a (std = sqrt(2) when output == 1). That is
        # not the textbook fan-in of x @ W; mode='fan_out' would give it. Left
        # unchanged so results stay comparable with earlier runs — confirm intent.
        self.W = nn.Parameter(torch.empty(input, hidden_layer, device=device))
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.empty(hidden_layer, output, device=device))
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return relu(x @ W) @ a / hidden_dim for a 2-D input x of shape (batch, input)."""
        hidden = self.relu(torch.mm(x, self.W))
        return torch.mm(hidden, self.a) / self.hidden_dim

    def loss(self, y_pred, y_true):
        """Element-wise squared error (no reduction); y_true is reshaped to y_pred's shape."""
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

# D + 1 inputs (41, matching X_train's columns), m hidden units, scalar output; moved to `device`.
model = PM_Euler(input=D + 1, hidden_layer=m, output=1).to(device)

# Norms of the model's W and a parameters.
def get_norm(model):
    """Return (||W||, ||a||) — Frobenius norms of model.W and model.a as Python floats."""
    norm_W = model.W.norm().item()
    norm_a = model.a.norm().item()
    return norm_W, norm_a

# Display the Frobenius norms of W and a for the freshly constructed model.
get_norm(model)

(8.965682029724121, 13.136861801147461)

In [6]:
# Training schedule: number of epochs and the step size used in the IEQ update.
epochs, lr = 10000, 1

# Per-epoch loss histories, appended to by the training loop.
train_losses, test_losses = [], []

In [7]:
import datetime

# Run configuration; passed to wandb.init when W&B logging is enabled below.
config = dict(
    learning_rate=lr,
    batch_size=l,
    epochs=epochs,
    hidden_layer=m,
    input=D + 1,
    output=1,
    optimizer='IEQ',
)

# Run-name suffix in "%m%d%H%M" form (month, day, hour, minute).
date = datetime.datetime.now().strftime("%m%d%H%M")
# wandb.init(project='Numerical Method', name=f"PM_IEQ_Example_2_{date}", config=config)

In [8]:
def compute_residual_jacobian(model, residual):
    """Stack d(residual_i)/d(theta) row-wise into a (n_residuals, n_params) matrix.

    theta is cat([model.W.flatten(), model.a.flatten()]), the same ordering
    ieq_step uses when writing parameters back. torch.autograd.grad returns the
    gradients directly, so nothing accumulates in .grad, and the rows live on the
    parameters' device (the old CPU-only torch.zeros buffer broke on CUDA).
    """
    rows = []
    for r_i in residual.flatten():
        grad_W, grad_a = torch.autograd.grad(r_i, (model.W, model.a), retain_graph=True)
        rows.append(torch.cat([grad_W.flatten(), grad_a.flatten()]))
    return torch.stack(rows)


def ieq_step(model, X, Y, lr):
    """Apply one semi-implicit IEQ update to model.W and model.a on batch (X, Y).

    With residual U = f(X; theta) - Y, Jacobian J = dU/dtheta and step lr, the
    scheme theta_1 - theta_0 = -2*lr*J^T U_1, U_1 = U_0 + J (theta_1 - theta_0) gives
        theta_1 = theta_0 - 2*lr * (I + 2*lr*J^T J)^{-1} J^T U_0.
    The original code omitted lr inside A (identical while lr == 1). The update is
    evaluated with the push-through identity
        (I + 2*lr*J^T J)^{-1} J^T = J^T (I + 2*lr*J J^T)^{-1},
    so only a (batch x batch) system is solved instead of inverting an
    (n_params x n_params) matrix every batch.
    """
    U = model(X) - Y.reshape(-1, 1)
    J = compute_residual_jacobian(model, U)
    with torch.no_grad():
        U = U.detach()
        theta_0 = torch.cat([model.W.flatten(), model.a.flatten()]).reshape(-1, 1)
        # K = I + 2*lr*J J^T is symmetric positive definite, so solve() is well posed.
        K = torch.eye(J.shape[0], device=J.device, dtype=J.dtype) + 2 * lr * (J @ J.T)
        theta_1 = theta_0 - 2 * lr * (J.T @ torch.linalg.solve(K, U))
        # Write the flat parameter vector back into W and a in place.
        n_W = model.W.numel()
        model.W.copy_(theta_1[:n_W].reshape(model.W.shape))
        model.a.copy_(theta_1[n_W:].reshape(model.a.shape))


for epoch in range(epochs):
    for X, Y in train_loader:
        ieq_step(model, X, Y, lr)

    # Full-dataset mean squared error after each epoch.
    with torch.no_grad():
        train_loss = model.loss(model(X_train), Y_train).mean()
        test_loss = model.loss(model(X_test), Y_test).mean()
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    norm = get_norm(model)
    # Optional W&B logging (enable together with wandb.init in the config cell).
    # 'r' was never defined and 1 - MSE is not an accuracy, so both were dropped.
    # wandb.log({'epoch': epoch + 1,
    #            'train_loss': train_loss,
    #            'test_loss': test_loss,
    #            'norm_W': norm[0],
    #            'norm_a': norm[1]})
    print(f'epoch {epoch + 1}, loss {train_loss:.4f}, test loss {test_loss:.4f}')

KeyboardInterrupt: 

In [None]:
# Close the Weights & Biases run.
# NOTE(review): wandb.init is commented out in the config cell, so presumably no run
# is active when this executes — confirm this call is a harmless no-op in that case.
wandb.finish()