In [1]:
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
# Seed the global RNG so parameter initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (torch.manual_seed seeds all devices).
SEED = 0
torch.manual_seed(SEED)

X_train, Y_train, X_test, Y_test = torch.load('../data/dataset_2_40D.pt', weights_only=True)

# Move the data to the selected device
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

# Mini-batch the training set with a DataLoader. `l` is the batch size (the
# name is kept because the config cell logs it as 'batch_size').
l = 64
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

D = 40   # problem dimension (the "40D" in the file name); the network takes D + 1 features
m = 100  # number of hidden units

# The network multiplies inputs by a (D + 1, m) weight matrix, so every sample
# must have exactly D + 1 features; fail here rather than deep inside training.
assert X_train.shape[-1] == D + 1 and X_test.shape[-1] == D + 1, (X_train.shape, X_test.shape)
assert len(X_test) == len(Y_test), (X_test.shape, Y_test.shape)

class PM_Euler(nn.Module):
    """Two-layer ReLU network whose output is divided by the hidden width.

    Computes ``relu(x @ W) @ a / hidden_layer`` where ``W`` has shape
    ``(input, hidden_layer)`` and ``a`` has shape ``(hidden_layer, output)``.
    There are no bias terms; ``W`` and ``a`` are plain parameter matrices so a
    hand-written optimiser can update them directly through their ``.grad``.

    Args:
        input: Number of input features (size of the last dimension of ``x``).
        hidden_layer: Number of hidden ReLU units; also the output divisor.
        output: Number of output features.
        device: Device on which the parameters are allocated. When ``None``
            (the default) the notebook-level ``device`` global is used if it
            exists (the previous, implicit behaviour), otherwise PyTorch's
            default device.
    """

    def __init__(self, input, hidden_layer, output, device=None):
        super().__init__()
        if device is None:
            # Backward compatibility: this class used to read the
            # notebook-global `device` implicitly.
            device = globals().get('device')
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # Uninitialised storage: the values are overwritten by the Kaiming
        # init below, so drawing uniform random numbers first was wasted work.
        self.W = nn.Parameter(torch.empty(input, hidden_layer, device=device))
        self.a = nn.Parameter(torch.empty(hidden_layer, output, device=device))
        # He (Kaiming) initialisation.
        # NOTE(review): kaiming_normal_ takes fan_in = tensor.size(1), the
        # nn.Linear (out, in) convention. W and a are stored as (in, out), so
        # mode='fan_in' scales W by hidden_layer and a by `output`, not by their
        # true fan-in. Kept unchanged to preserve the reported results; use
        # mode='fan_out' for textbook He scaling.
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return ``relu(x @ W) @ a / hidden_dim``.

        ``x`` carries ``input`` features in its last dimension. Leading batch
        dimensions broadcast through ``@``; the former ``torch.mm`` only
        accepted 2-D input, for which the result is identical.
        """
        hidden = self.relu(x @ self.W)
        return hidden @ self.a / self.hidden_dim

    def loss(self, y_pred, y_true):
        """Element-wise squared error, with no reduction applied.

        ``y_true`` is reshaped to ``y_pred``'s shape first, so e.g. targets of
        shape ``(N,)`` can be compared against predictions of shape ``(N, 1)``.
        """
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

    def get_norm(self):
        """Return the Frobenius norms ``[||W||, ||a||]`` as Python floats."""
        return [torch.norm(self.W).item(), torch.norm(self.a).item()]


# (D + 1)-input network with m hidden ReLU units and a single output, placed on `device`.
model = PM_Euler(input=D + 1, hidden_layer=m, output=1).to(device)

In [4]:
# ESAV hyper-parameters.
epochs = 10000
lr = 1        # time step (dt) of the ESAV scheme; logged as 'learning_rate'
_lambda = 4   # stabiliser: gradients are scaled by 1 / (1 + _lambda * lr)

# Per-epoch full-dataset losses, appended by the training cell.
train_losses = []
test_losses = []

config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    '_lambda': _lambda,
    'output': 1,
    'optimizer': 'ESAV'
}

# MMDDHHMM timestamp keeps each wandb run name unique.
date = datetime.datetime.now().strftime("%m%d%H%M")
# notes (translated): "Try merging the W and a updates, using the ESAV method"
wandb.init(project='Numerical Method', name=f"PM_ESAV_Example_2_{date}", config=config,
           notes="尝试合并W和a的更新，使用ESAV方法")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpheonizard[0m ([33mpheonizard-university-of-nottingham[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# ESAV (exponential scalar auxiliary variable) training loop.
#
# W and a are treated as one flat parameter vector theta. Per mini-batch, with
# dt = lr:
#   g      = grad L(theta)                        (N_theta)
#   g_lin  = g / (1 + _lambda * dt)               (linear_N_theta)
#   r     <- r / (1 + dt * <g, g_lin>)
#   theta <- theta - r_new * dt * g_lin / exp(L(theta))
# r is re-initialised to exp(L) on the first mini-batch of every epoch.
for epoch in tqdm(range(epochs)):
    for batch_idx, (X, Y) in enumerate(train_loader):
        loss = model.loss(model(X), Y).mean()
        loss.backward()
        with torch.no_grad():
            if batch_idx == 0:
                # Per-epoch restart of the auxiliary variable: r = exp(L).
                # Computed under no_grad so r does not hold on to the autograd graph.
                r = torch.exp(loss)
            N_theta = torch.cat((model.W.grad.flatten(), model.a.grad.flatten()))
            linear_N_theta = N_theta / (1 + _lambda * lr)
            theta_2 = -linear_N_theta * lr / torch.exp(loss)
            # ========= Update SAV variable r =========
            r = r / (1 + lr * torch.sum(N_theta * linear_N_theta))
            # ========= Update parameters =========
            theta_2 = r.item() * theta_2
            # Split the flat step by the parameters' own sizes and shapes rather
            # than the globals D, m and a hard-coded output size of 1.
            step_W, step_a = theta_2.split((model.W.numel(), model.a.numel()))
            model.W += step_W.reshape_as(model.W)
            model.a += step_a.reshape_as(model.a)
            model.a.grad.zero_()
            model.W.grad.zero_()
            wandb.log({'r': r.item()})
    with torch.no_grad():
        # .item() keeps plain floats instead of accumulating device tensors.
        train_loss = model.loss(model(X_train), Y_train).mean().item()
        test_loss = model.loss(model(X_test), Y_test).mean().item()
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    norm_W, norm_a = model.get_norm()
    # NOTE: 'accuracy' is 1 - test MSE, not a classification accuracy.
    wandb.log({'epoch': epoch + 1,
               'train_loss': train_loss,
               'test_loss': test_loss,
               'norm_W': norm_W,
               'norm_a': norm_a,
               'accuracy': 1 - test_loss})

100%|██████████| 10000/10000 [26:07<00:00,  6.38it/s]


In [6]:
# End the wandb run opened by `wandb.init` earlier in the notebook; the run's
# history and summary are shown in the output below.
wandb.finish()

0,1
accuracy,▁▅▇█████████████████████████████████████
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
norm_W,▁▃▅▆▇▇▇▇▇▇▇▇▇▇▇█████████████████████████
norm_a,▁▃▅▆▇▇▇▇▇▇▇▇▇▇▇█████████████████████████
r,█▄▂▁▁▁▁▂▂▂▂▂▁▁▁▂▂▂▁▁▁▁▂▂▂▂▂▁▁▂▂▂▂▂▁▁▂▂▂▂
test_loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.99229
epoch,10000.0
norm_W,27.66876
norm_a,28.9782
r,0.99816
test_loss,0.00771
train_loss,0.00096
