In [1]:
import torch
from modules import Simple_Perceptron
from data import grip_data

In [16]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda')
model = Simple_Perceptron.Simple_Perceptron(41, 100, 1)
model = model.to(device)
train_loader, X_train, Y_train, X_test, Y_test, D = grip_data.load_data(device=device)

In [17]:
lr = 0.01

In [18]:
import time
from tqdm import tqdm

epochs = 10

In [19]:
for epoch in tqdm(range(epochs), desc="Training Epochs"):
    start_time = time.time()  # 开始计时
    flag = True
    for X, Y in train_loader:
        if flag:
            flag = False
            U_W = model.forward(X) - Y.reshape(-1, 1)
            U_a = U_W.clone()

        # 处理 W 的梯度和更新
        theta_W_0 = model.W.flatten().reshape(-1, 1)
        J_W = torch.zeros(U_W.shape[0], theta_W_0.numel(), device=device)
        for i in range(U_W.shape[0]):
            U_W[i].backward(retain_graph=True)
            J_W[i] = model.W.grad.flatten()
            model.W.grad.zero_()
        
        with torch.no_grad():
            J_W_T = J_W.T.to(device)
            A_W = torch.eye(theta_W_0.numel(), device=device) + 2 * torch.mm(J_W_T, J_W)
            L_W = torch.linalg.cholesky(A_W)
            L_inv_W = torch.inverse(L_W)
            A_inv_W = torch.mm(L_inv_W.T, L_inv_W)
            theta_W_1 = theta_W_0 - 2 * lr * torch.mm(torch.mm(A_inv_W, J_W_T), U_W)
            model.W.data = theta_W_1.reshape(model.W.shape)
        
        # 处理 a 的梯度和更新
        theta_a_0 = model.a.flatten().reshape(-1, 1)
        J_a = torch.zeros(U_a.shape[0], theta_a_0.numel(), device=device)
        for i in range(U_a.shape[0]):
            U_a[i].backward(retain_graph=True)
            J_a[i] = model.a.grad.flatten()
            model.a.grad.zero_()
        
        with torch.no_grad():
            J_a_T = J_a.T.to(device)
            A_a = torch.eye(theta_a_0.numel(), device=device) + 2 * torch.mm(J_a_T, J_a)
            L_a = torch.linalg.cholesky(A_a)
            L_inv_a = torch.inverse(L_a)
            A_inv_a = torch.mm(L_inv_a.T, L_inv_a)
            theta_a_1 = theta_a_0 - 2 * lr * torch.mm(torch.mm(A_inv_a, J_a_T), U_a)
            model.a.data = theta_a_1.reshape(model.a.shape)
        
        # 使用递推公式更新 U_W 和 U_a
        U_W = U_W - 2 * lr * torch.mm(J_W, torch.mm(A_inv_W, torch.mm(J_W_T, U_W)))
        U_a = U_a - 2 * lr * torch.mm(J_a, torch.mm(A_inv_a, torch.mm(J_a_T, U_a)))

        # 清零梯度
        model.W.grad.zero_()
        model.a.grad.zero_()

    end_time = time.time()  # 结束计时
    epoch_duration = end_time - start_time  # 计算该 epoch 花费的时间
    print(f"Epoch {epoch+1}/{epochs} completed in {epoch_duration:.2f} seconds")

Training Epochs:   0%|          | 0/10 [00:00<?, ?it/s]


NotImplementedError: The operator 'aten::linalg_cholesky_ex.L' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [4]:
lr = 1
for X, Y in train_loader:
    U = (model(X) - Y.reshape(-1, 1))
    # 增广模型中的参数
    theta_0 = torch.cat([model.W.flatten(), model.a.flatten()]).reshape(-1, 1)
    # 计算雅可比矩阵
    J = torch.zeros(U.shape[0], theta_0.numel())
    for i in range(U.shape[0]):
        U[i].backward(retain_graph=True)
        J[i] = torch.cat([model.W.grad.flatten(), model.a.grad.flatten()])
        model.W.grad.zero_()
        model.a.grad.zero_()
    with torch.no_grad():
        J_T = J.T
        # 计算量A，A=I + 2(J^T)J
        A = torch.eye(theta_0.numel()) + 2 * torch.mm(J_T, J)
        A_inv = torch.inverse(A)
        theta_1 = theta_0 - 2 * lr * torch.mm(torch.mm(A_inv, J_T), U)
        update_U = torch.eye(U.shape[0]) - 2 * lr * torch.mm(J, torch.mm(A_inv, J_T))
        U = torch.mm(update_U, U)
        # 更新参数
        model.W.data = theta_1[:model.W.numel()].reshape(model.W.shape)
        model.a.data = theta_1[model.W.numel():].reshape(model.a.shape)
        model.W.grad.zero_()
        model.a.grad.zero_() 
    break 

torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64, 1])
torch.Size([4200, 1])
torch.Size([64, 64]) torch.Size([6

KeyboardInterrupt: 