In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import wandb
from tqdm import tqdm

In [2]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [3]:
# Load the four pre-split tensors and move each of them to the selected device.
X_train, Y_train, X_test, Y_test = (
    tensor.to(device)
    for tensor in torch.load('../data/dataset_2_40D.pt', weights_only=True)
)

# Mini-batch the training split with a shuffling DataLoader.
l = 64  # batch size (also recorded in the run config)
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(
    train_dataset,
    batch_size=l,
    shuffle=True,
)

# Print the shapes of the first batch only.
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([64, 41]) torch.Size([64])


In [4]:
class PM_Euler(nn.Module):
    """Two-layer ReLU network whose output is divided by the hidden width.

    ``forward`` computes ``relu(x @ W) @ a / hidden_layer`` with
    ``W`` of shape ``(input, hidden_layer)`` and ``a`` of shape
    ``(hidden_layer, output)``.

    The parameters are created on PyTorch's default device rather than on a
    notebook-level ``device`` global, so the class can be constructed on its
    own; move the module afterwards with ``.to(device)``.

    Args:
        input: Number of input features (columns of ``x``).
        hidden_layer: Number of hidden units; also the divisor applied to the
            network output in ``forward``.
        output: Number of output units.
    """

    def __init__(self, input, hidden_layer, output):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # He (Kaiming) normal initialisation. ``torch.empty`` is sufficient
        # because ``kaiming_normal_`` overwrites every entry in place.
        # NOTE(review): for a 2-D tensor ``kaiming_normal_`` takes fan_in from
        # ``size(1)``, which here is ``hidden_layer`` for W and ``output`` for
        # a, not the number of inputs each matrix multiplies in ``forward``.
        # The numeric behaviour is left as it was; confirm this scale is the
        # intended one (otherwise initialise the transposed tensors).
        self.W = nn.Parameter(torch.empty(input, hidden_layer))
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.empty(hidden_layer, output))
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return ``relu(x @ W) @ a / hidden_dim`` for a 2-D batch ``x``."""
        z1 = self.relu(torch.mm(x, self.W))
        z2 = torch.mm(z1, self.a) / self.hidden_dim
        return z2

    def loss(self, y_pred, y_true):
        """Return the element-wise squared error (no reduction).

        ``y_true`` is reshaped to ``y_pred.shape`` before subtracting, so a
        flat target vector does not broadcast against a column of predictions.
        """
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

    def get_norm(self):
        """Return ``[norm(W), norm(a)]`` as Python floats."""
        return [torch.norm(self.W).item(), torch.norm(self.a).item()]

In [5]:
# Hyperparameters read by the training cell.
epochs = 10000   # number of passes over the training loader
lr = 1           # step size used in the parameter / r updates
C = 1            # constant added to the loss before taking the square root
_lambda = 4      # coefficient appearing in the (1 + lr * _lambda) factor
D = 40           # the model is built with D + 1 input features
m = 100          # hidden-layer width

# Run configuration handed to wandb.init.
config = dict(
    learning_rate=lr,
    batch_size=l,
    epochs=epochs,
    hidden_layer=m,
    input=D + 1,
    output=1,
    optimizer='SAV',
    C=C,
    _lambda=_lambda,
)

# Per-epoch loss histories, filled in by the training loop.
train_losses, test_losses = [], []

In [6]:
import datetime

# Month/day/hour/minute stamp appended to the run name.
date = datetime.datetime.now().strftime("%m%d%H%M")
wandb.init(
    project='Numerical Method',
    name=f"PM_SAV_Example_2_{date}",
    config=config,
    notes="尝试合并theta更新",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpheonizard[0m ([33mpheonizard-university-of-nottingham[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
model = PM_Euler(D + 1, m, 1).to(device)

# Training loop for the update recorded as optimizer 'SAV' in the run config.
# For each mini-batch, with N the concatenated gradient of the batch loss:
#   theta_2 = lr * N / (sqrt(loss + C) * (1 + lr * _lambda))
#   r       = r / (1 + lr * <N, N / (1 + lr * _lambda)> / (2 * (loss + C)))
#   theta  -= r * theta_2
for epoch in tqdm(range(epochs)):
    # r is re-initialised from the first mini-batch of every epoch.
    flag = True
    for X, Y in train_loader:
        loss = model.loss(model(X), Y).mean()
        if flag:
            # Detach so r does not stay attached to this batch's autograd graph.
            r = torch.sqrt(loss.detach() + C)
            flag = False
        loss.backward()
        with torch.no_grad():
            grad_w, grad_a = model.W.grad, model.a.grad
            # torch.cat copies, so N is unaffected by zeroing the grads below.
            N = torch.cat([grad_w.flatten(), grad_a.flatten()])
            shifted_loss = loss + C
            theta_2 = lr * N / (torch.sqrt(shifted_loss) * (1 + lr * _lambda))
            r = r / (1 + lr * torch.sum(N * (N / (1 + lr * _lambda))) / (2 * shifted_loss))
            theta_2 = theta_2 * r.item()
            # Split the merged step back into the W part and the a part using
            # the gradients' own sizes instead of hard-coded dimensions.
            step_w, step_a = torch.split(theta_2, [grad_w.numel(), grad_a.numel()])
            model.W -= step_w.reshape(model.W.shape)
            model.a -= step_a.reshape(model.a.shape)
            model.a.grad.zero_()
            model.W.grad.zero_()
    with torch.no_grad():
        # Store plain Python floats so the histories do not hold device tensors.
        train_loss = model.loss(model(X_train), Y_train).mean().item()
        test_loss = model.loss(model(X_test), Y_test).mean().item()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = model.get_norm()
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss,
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   'accuracy': 1 - test_loss,
                   'r': r.item()})

100%|██████████| 10000/10000 [24:34<00:00,  6.78it/s]


In [8]:
# Mark the current wandb run as finished.
wandb.finish()

0,1
accuracy,▁▅▇█████████████████████████████████████
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
norm_W,▁▃▅▆▇▇▇▇▇▇▇▇▇▇██████████████████████████
norm_a,▁▃▅▆▇▇▇▇▇▇▇▇▇███████████████████████████
r,▄▁▅▆▆▇▇▇▇▇▇██▇█▇▇███████████████████████
test_loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.99241
epoch,10000.0
norm_W,27.54631
norm_a,28.53018
r,0.9992
test_loss,0.00759
train_loss,0.00095
