In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import wandb

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
X_train, Y_train, X_test, Y_test = torch.load('./data/dataset_2_40D.pt')

# 将数据移动到适当的设备
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

(torch.Size([8000, 41]),
 torch.Size([8000]),
 torch.Size([2000, 41]),
 torch.Size([2000]))

In [4]:
# 使用 DataLoader 进行批处理
l = 64
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# 打印第一个批次的大小
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([64, 41]) torch.Size([64])


In [5]:
D = 40
m = 100

In [7]:
class Model(nn.Module):
    def __init__(self, input, hidden_layer, output):
        super(Model, self).__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        self.W = nn.Parameter(torch.rand(input, hidden_layer, device=device), requires_grad=True)
        # HE初始化
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.rand(hidden_layer, output, device=device), requires_grad=True)
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')
        
    def forward(self, x):
        # print(x.shape)
        z1 = self.relu(torch.mm(x, self.W))
        # print(z1.shape)
        z2 = torch.mm(z1, self.a) / self.hidden_dim
        return z2

    def loss(self, y_pred, y_true):
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

model = Model(D + 1, m, 1).to(device)

# 计算模型W和a的Norm
def get_norm(model):
    return torch.norm(model.W).item(), torch.norm(model.a).item()

get_norm(model)

(8.976402282714844, 14.475980758666992)

In [9]:
epochs = 10000
lr = 1
C = 1
_lambda = 4
r = 0
epsilon = 1e-8
beta_1 = 0.9
beta_2 = 0.999
m = 0
v = 0
J = 10
h = 0.01

train_losses = []
test_losses = []

In [10]:
config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    'output': 1,
    'optimizer': 'Adam_SAV',
    'Approx': 'SPM',
    'C': C,
    '_lambda': _lambda,
    'r': r,
    'epsilon': epsilon,
    'J': J,
    'h': h
}

In [11]:
import datetime

date = datetime.datetime.now().strftime("%m%d%H%M")
wandb.init(project='Numerical Method',
           name=f"SPM_A_SAV_Example_2_{date}",
           config=config,
           notes="Fig9 with learning rate 1"
          )

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112984776910808, max=1.0…

In [12]:
for epoch in range(epochs):
    cnt = 0
    for X, Y in train_loader:
        loss = 0
        for j in range(J):
            original_params = [model.W.clone(), model.a.clone()]
            for param in model.parameters():
                param.data += h * torch.randn_like(param)
            loss += model.loss(model(X), Y).sum()
            model.W.data, model.a.data = original_params
        loss /= J 
        if cnt == 0:
            r = torch.sqrt(torch.tensor(loss + C, device=device))
            cnt = 1
        loss.backward()
        with torch.no_grad():
            #=========Nonlinear Term==========
            N_a_init = model.a.grad
            N_w_init = model.W.grad
            m_a = beta_1 * m + (1 - beta_1) * N_a_init
            m_w = beta_1 * m + (1 - beta_1) * N_w_init
            v_a = beta_2 * v + (1 - beta_2) * torch.norm(N_a_init) ** 2
            v_w = beta_2 * v + (1 - beta_2) * torch.norm(N_w_init) ** 2
            m_a_hat = m_a / (1 - beta_1 ** (epoch + 1))
            m_w_hat = m_w / (1 - beta_1 ** (epoch + 1))
            v_a_hat = v_a / (1 - beta_2 ** (epoch + 1))
            v_w_hat = v_w / (1 - beta_2 ** (epoch + 1))
            N_a = m_a_hat
            N_w = m_w_hat
            #=========Time Step Update========
            adaptive_lr = lr / (torch.sqrt(v_a_hat) + epsilon)
            #=========SAV Update========== 
            theta_a_2 = adaptive_lr * N_a / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            theta_w_2 = adaptive_lr * N_w / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            r = r / (1 + adaptive_lr * (torch.sum(N_a * (N_a / (1 + adaptive_lr * _lambda))) + torch.sum(N_w * (N_w) / (1 + adaptive_lr * _lambda))) / (2 * (loss + C)))
            model.a -= r.item() * theta_a_2
            model.W -= r.item() * theta_w_2
            model.a.grad.zero_()
            model.W.grad.zero_()
    with torch.no_grad():
        train_loss = model.loss(model(X_train), Y_train).mean()
        test_loss = model.loss(model(X_test), Y_test).mean()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss, 
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   'accuracy': 1 - test_loss,
                   'r': r.item(),
                   'adaptive_lr': adaptive_lr.item()})
        print(f'epoch {epoch + 1}, loss {train_loss:.6f}, test loss {test_loss:.6f}')

  r = torch.sqrt(torch.tensor(loss + C, device=device))


epoch 1, loss 0.243721, test loss 0.246770
epoch 2, loss 0.069172, test loss 0.071887
epoch 3, loss 0.062567, test loss 0.066732
epoch 4, loss 0.059862, test loss 0.064797
epoch 5, loss 0.057484, test loss 0.063140
epoch 6, loss 0.055285, test loss 0.061525
epoch 7, loss 0.053306, test loss 0.059738
epoch 8, loss 0.051385, test loss 0.058131
epoch 9, loss 0.049309, test loss 0.056177
epoch 10, loss 0.047189, test loss 0.054835
epoch 11, loss 0.045073, test loss 0.052901
epoch 12, loss 0.043131, test loss 0.050722
epoch 13, loss 0.041320, test loss 0.049230
epoch 14, loss 0.039679, test loss 0.047517
epoch 15, loss 0.038054, test loss 0.045767
epoch 16, loss 0.036494, test loss 0.044183
epoch 17, loss 0.035017, test loss 0.042670
epoch 18, loss 0.033640, test loss 0.041438
epoch 19, loss 0.032286, test loss 0.040091
epoch 20, loss 0.030965, test loss 0.038357
epoch 21, loss 0.029790, test loss 0.037344
epoch 22, loss 0.028673, test loss 0.036224
epoch 23, loss 0.027588, test loss 0.0350

In [13]:
wandb.finish()

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁▅███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
adaptive_lr,▁▂▃▃▃▅▆▆▆█▄▆▄▆▆▆▆▇▄▆▄▆▃▄▄▄▅▄▅▄▅▆▇▅▇▆▆▆▅▅
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
norm_W,▁▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████
norm_a,▇█████▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁
r,▁▃▅▅▆▆▆▆▇▆▆▆▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇█▇▇▇█▇█▇▇▇▇▇█
test_loss,█▄▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
train_loss,█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.99229
adaptive_lr,31.21263
epoch,10000.0
norm_W,39.12162
norm_a,21.27081
r,0.99592
test_loss,0.00771
train_loss,0.00055
