In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import wandb

In [11]:
# Pick the compute device once: CUDA when PyTorch can see a usable GPU,
# otherwise the CPU. Every tensor and the model are moved here later on.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [12]:
# Load the pre-generated dataset; the file is expected to unpack into four
# tensors (train inputs/targets, test inputs/targets).
# `map_location=device` remaps the stored tensors while loading, so a file
# saved from a CUDA session can still be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load datasets from a trusted source.
X_train, Y_train, X_test, Y_test = torch.load('../data/dataset_2_40D.pt', map_location=device)

# Make sure every tensor lives on the selected device (a no-op when
# map_location has already placed it there).
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

# Serve the training set in shuffled mini-batches through a DataLoader.
l = 64  # mini-batch size (also recorded as 'batch_size' in the run config)
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# D + 1 is passed as the network's input size and m as its hidden width.
# NOTE(review): presumably X has D features plus one extra column - confirm.
D = 40
m = 100

class PM_Euler(nn.Module):
    """Bias-free two-layer ReLU network whose output is averaged over the hidden units.

    For a 2-D input ``x`` the model evaluates ``relu(x @ W) @ a / hidden_layer``.

    Args:
        input: number of input features (rows of ``W``).
        hidden_layer: number of hidden units (columns of ``W``, rows of ``a``).
        output: number of outputs (columns of ``a``).

    The parameters are created on the module-level ``device``, which must be
    defined before the class is instantiated.
    """

    def __init__(self, input, hidden_layer, output):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer

        # The uniform draw only provides storage; the He (Kaiming) normal
        # initialisation right below overwrites every entry.
        # NOTE(review): PyTorch derives fan_in from dim 1 of a 2-D tensor, i.e.
        # `hidden_layer` for W and `output` for a, which is not the number of
        # inputs under the `x @ W` layout used in forward - confirm intended.
        w_storage = torch.rand(input, hidden_layer, device=device)
        self.W = nn.Parameter(w_storage, requires_grad=True)
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')

        a_storage = torch.rand(hidden_layer, output, device=device)
        self.a = nn.Parameter(a_storage, requires_grad=True)
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return ``relu(x @ W) @ a / hidden_dim`` for a 2-D batch ``x``."""
        hidden = self.relu(x.mm(self.W))
        return hidden.mm(self.a) / self.hidden_dim

    def loss(self, y_pred, y_true):
        """Return the element-wise squared error (not reduced).

        ``y_true`` is reshaped to the shape of ``y_pred`` before subtracting.
        """
        residual = y_pred - y_true.reshape(y_pred.shape)
        return residual.pow(2)

# Build the network with D + 1 inputs, m hidden units and a single output,
# then move it to the selected device.
model = PM_Euler(input=D + 1, hidden_layer=m, output=1).to(device)

# Compute the norms of the model's W and a parameters.
def get_norm(model):
    """Return the 2-norms of ``model.W`` and ``model.a`` as Python floats.

    Each parameter is treated as a flat vector, which for a matrix is its
    Frobenius norm. ``torch.linalg.vector_norm`` is used instead of
    ``torch.norm`` because the latter is documented as deprecated.

    Args:
        model: object exposing tensor attributes ``W`` and ``a``.

    Returns:
        Tuple ``(norm_W, norm_a)`` of floats.
    """
    w_norm = torch.linalg.vector_norm(model.W).item()
    a_norm = torch.linalg.vector_norm(model.a).item()
    return w_norm, a_norm

# Display the (W, a) parameter norms of the freshly constructed model.
get_norm(model)

(9.075422286987305, 13.460860252380371)

In [9]:
# --- Training hyperparameters -------------------------------------------
# NOTE: this cell also reads `l`, `m` and `D`, so the data/model cell must
# have been executed first.
epochs = 10000  # number of passes over train_loader
lr = 0.01       # step size of the manual parameter update
_lambda = 0     # coefficient in the (1 + lr * _lambda) divisor of the update; 0 disables it

# Per-epoch loss history, appended to by the training loop.
train_losses = []
test_losses = []

import datetime

# Timestamp (month-day-hour-minute) used to make the run name unique.
date = datetime.datetime.now().strftime("%m%d%H%M")

# The hyperparameters go through `config=`. Assigning a plain dict to
# `wandb.config` before `wandb.init()` does not work, because init rebinds
# `wandb.config` to the new run's config object and the dict is never recorded.
wandb.init(
    project='Numerical Method',
    name=f"PM_ESAV_Example_2_{date}",
    config={
        'learning_rate': lr,
        'batch_size': l,
        'epochs': epochs,
        'hidden_layer': m,
        'input': D + 1,
        '_lambda': _lambda,
        'output': 1,
        'optimizer': 'ESAV'
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpheonizard[0m ([33mpheonizard-university-of-nottingham[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [14]:
# Manual training loop (labelled 'ESAV' in the run config): plain gradient
# steps rescaled by an auxiliary scalar r, with no torch.optim optimizer.
for epoch in range(epochs):
    # cnt is a per-epoch flag: it is cleared here, so r is re-initialised to
    # exp(loss) on the first mini-batch of every epoch.
    # NOTE(review): confirm this per-epoch restart of r is intended rather
    # than a single initialisation before training starts.
    cnt = 0
    for X, Y in train_loader:
        # Mean squared error over the mini-batch (model.loss is element-wise).
        loss = model.loss(model(X), Y).mean()
        if cnt == 0:
            # NOTE(review): r is created from the graph-attached loss here;
            # it is only replaced by a no-grad tensor in the update below.
            r = torch.exp(loss)
            cnt = 1
        loss.backward()
        with torch.no_grad():
            # Handles to the pre-update parameter values. The assignments to
            # `.data` below bind freshly computed tensors instead of writing
            # in place, so these are expected to keep the old values for the
            # r update.
            theta_w_1 = model.W.data
            theta_a_1 = model.a.data
            # Raw gradients of the mini-batch loss.
            N_a = model.a.grad.clone()
            N_w = model.W.grad.clone()
            # Gradients rescaled by r / exp(loss).
            b_a = r * N_a / torch.exp(loss)
            b_w = r * N_w / torch.exp(loss)
            # theta_new = (theta_old - lr * b) / (1 + lr * _lambda)
            model.a.data = (theta_a_1 - lr * b_a) / (1 + lr * _lambda)
            model.W.data = (theta_w_1 - lr * b_w) / (1 + lr * _lambda)
            # Update the auxiliary scalar r:
            #   log r_new = log r_old + <b_a, a_new - a_old> + <b_w, W_new - W_old>
            r = torch.exp(torch.log(r) + torch.sum(b_a * (model.a - theta_a_1)) + torch.sum(b_w * (model.W - theta_w_1)))
            # No optimizer is used, so gradients must be cleared by hand.
            model.a.grad.zero_()
            model.W.grad.zero_()
    # End-of-epoch evaluation on the full train and test sets.
    with torch.no_grad():
        train_loss = model.loss(model(X_train), Y_train).mean()
        test_loss = model.loss(model(X_test), Y_test).mean()
        # NOTE(review): these lists collect 0-dim tensors (on `device`), not
        # Python floats - convert before plotting if the device is CUDA.
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        # NOTE(review): 'accuracy' is logged as 1 - test MSE, not a
        # classification accuracy. `r` is undefined here if train_loader
        # yielded no batches.
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss, 
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   'accuracy': 1 - test_loss,
                   'r': r.item()})
        print(f'epoch {epoch + 1}, loss {train_loss:.8f}, test loss {test_loss:.8f}')

epoch 1, loss 1.0229, test loss 1.0452
epoch 2, loss 1.0039, test loss 1.0259
epoch 3, loss 0.9872, test loss 1.0090
epoch 4, loss 0.9700, test loss 0.9916
epoch 5, loss 0.9489, test loss 0.9703
epoch 6, loss 0.9280, test loss 0.9491
epoch 7, loss 0.9140, test loss 0.9350
epoch 8, loss 0.8969, test loss 0.9175
epoch 9, loss 0.8710, test loss 0.8912
epoch 10, loss 0.8526, test loss 0.8725
epoch 11, loss 0.8289, test loss 0.8484
epoch 12, loss 0.8108, test loss 0.8301
epoch 13, loss 0.7958, test loss 0.8148
epoch 14, loss 0.7791, test loss 0.7978
epoch 15, loss 0.7600, test loss 0.7784
epoch 16, loss 0.7478, test loss 0.7661
epoch 17, loss 0.7330, test loss 0.7509
epoch 18, loss 0.7186, test loss 0.7363
epoch 19, loss 0.7015, test loss 0.7189
epoch 20, loss 0.6863, test loss 0.7034
epoch 21, loss 0.6685, test loss 0.6853
epoch 22, loss 0.6533, test loss 0.6699
epoch 23, loss 0.6329, test loss 0.6491
epoch 24, loss 0.6160, test loss 0.6319
epoch 25, loss 0.5966, test loss 0.6122
epoch 26,

KeyboardInterrupt: 

In [15]:
# End the W&B run opened by wandb.init(); the run history and summary tables
# shown below are printed by this call.
wandb.finish()

0,1
accuracy,▁▂▃▄▄▅▆▆▇▇▇█████████████████████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
norm_W,▁▁▁▂▂▃▃▄▄▅▅▆▆▆▇▇▇▇▇▇████████████████████
norm_a,▁▁▁▂▂▃▃▄▄▅▅▆▆▆▇▇▇▇▇▇████████████████████
r,██▃▆▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss,█▇▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.9332
epoch,264.0
norm_W,13.3976
norm_a,16.68236
r,1.05752
test_loss,0.0668
train_loss,0.06221
