In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import wandb

In [2]:
# Select the compute device: the first visible CUDA GPU if there is one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
# Load the pre-split dataset (unpacked as X_train, Y_train, X_test, Y_test) and
# move every tensor onto the selected device.
# NOTE(review): torch.load unpickles the file -- only load trusted data files.
# The printed shapes below show X with 41 columns and 1-D targets
# (presumably D = 40 features plus one extra column -- TODO confirm).
X_train, Y_train, X_test, Y_test = (
    tensor.to(device) for tensor in torch.load('./data/dataset_2_40D.pt')
)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

(torch.Size([8000, 41]),
 torch.Size([8000]),
 torch.Size([2000, 41]),
 torch.Size([2000]))

In [4]:
# Mini-batch the training split; `l` is the batch size (also recorded in the wandb config).
l = 64
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=l)

# Sanity check: show the shapes of a single mini-batch, then stop.
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([64, 41]) torch.Size([64])


In [5]:
# D: number of input features (the model is built with D + 1 inputs);
# m: hidden-layer width of the two-layer network.
D, m = 40, 100

In [6]:
class Model(nn.Module):
    """Two-layer ReLU network with a 1/width output scaling.

    ``forward`` computes ``relu(x @ W) @ a / hidden_layer``.

    Args:
        input: Number of input features (columns of ``x``). The name shadows the
            builtin but is kept so keyword callers keep working.
        hidden_layer: Hidden-layer width; also the divisor applied to the output.
        output: Number of output units.
        device: Optional device for the parameters. Defaults to ``None``
            (PyTorch's default device). Move the module afterwards with
            ``.to(device)`` as usual. The class no longer reads a notebook-global
            ``device`` variable.
    """

    def __init__(self, input, hidden_layer, output, device=None):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # Weights are stored in (in, out) layout so forward can use x @ W.
        # torch.empty suffices: kaiming_normal_ overwrites every entry in place.
        # nn.Parameter already sets requires_grad=True.
        self.W = nn.Parameter(torch.empty(input, hidden_layer, device=device))
        # He (Kaiming) normal init: std = sqrt(2 / fan). NOTE: for a 2-D tensor
        # PyTorch takes fan_in = tensor.size(1), i.e. it assumes nn.Linear's
        # (out, in) layout. With this (in, out) layout, mode='fan_in' therefore
        # uses fan = hidden_layer for W and fan = output for a. This is kept
        # deliberately to preserve the original initialisation scale.
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.empty(hidden_layer, output, device=device))
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return ``relu(x @ W) @ a / hidden_dim``; ``x`` must be 2-D (batch, input) for ``torch.mm``."""
        hidden = self.relu(torch.mm(x, self.W))
        return torch.mm(hidden, self.a) / self.hidden_dim

    def loss(self, y_pred, y_true):
        """Element-wise squared error (no reduction); ``y_true`` is reshaped to ``y_pred``'s shape."""
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

model = Model(input=D + 1, hidden_layer=m, output=1).to(device)


def get_norm(model):
    """Return ``(||W||, ||a||)``: the 2-norms over all entries of the model's
    hidden weights ``W`` and output weights ``a``, as Python floats."""
    w_norm = model.W.norm().item()
    a_norm = model.a.norm().item()
    return w_norm, a_norm


get_norm(model)

(9.051899909973145, 13.263046264648438)

In [7]:
epochs = 10000
lr = 1
C = 1
_lambda = 4
r = 0
epsilon = 1e-8
beta_1 = 0.9
beta_2 = 0.999
J = 10
h = 0.0001
# Define the relax parameters 
r_wave = 0
r_hat = 0
r = 0
a = 0
b = 0
c = 0
ellipsis_0 = 0
ratio_n = 0.99


train_losses = []
test_losses = []

In [8]:
# Run configuration recorded with the wandb run. Every value is read from the
# variables defined in the earlier cells, so the logged config cannot drift from
# what the training loop actually uses. ('ratio_n' was previously a hard-coded
# literal.) Note that 'r' records r's initial value, before training.
config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    'output': 1,
    'J': J,
    'h': h,
    'optimizer': 'Adam_Relax_SAV',
    'Approx Method': 'SPM',
    'C': C,
    '_lambda': _lambda,
    'r': r,
    'epsilon': epsilon,
    'ratio_n': ratio_n
}

In [9]:
import datetime

# Month-day-hour-minute stamp that makes each run name unique.
date = datetime.datetime.now().strftime("%m%d%H%M")
# notes (zh): "Fixed most of the issues; hoping it runs normally."
wandb.init(
    config=config,
    project='Numerical Method',
    notes="修正了大多问题，希望能够正常运行",
    name=f"SPM_A_RelaxSAV_Example_2_{date}",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpheonizard[0m ([33mpheonizard-university-of-nottingham[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112421192228794, max=1.0…

In [None]:
# Training: perturbation-averaged loss, Adam-style moments, SAV step, then relaxation of r.
for epoch in range(epochs):
    cnt = 0  # mini-batch index within the current epoch
    # NOTE(review): the moment estimates (and, via cnt, their bias correction)
    # are reset at the start of every epoch -- confirm this is intended.
    m_a, m_w, v = 0, 0, 0
    for X, Y in train_loader:
        #=========Perturbed (SPM) loss==========
        # Average the mini-batch MSE over J random perturbations of the
        # parameters. The perturbation and the restore both go through .data,
        # so autograd does not record them.
        loss = 0
        for j in range(J):
            # Value-only snapshot; detach() keeps the clone out of the autograd graph.
            original_params = [model.W.detach().clone(), model.a.detach().clone()]
            for param in model.parameters():
                param.data += h * torch.randn_like(param)
            loss += model.loss(model(X), Y).mean()
            model.W.data, model.a.data = original_params
        loss /= J
        if cnt == 0:
            # (Re)initialise the SAV variable r = sqrt(loss + C) on the first batch
            # of every epoch. loss.detach() yields the same value as the former
            # torch.tensor(loss + C, ...), which warned about copy-constructing
            # from a tensor.
            r = torch.sqrt(loss.detach() + C)
        loss.backward()
        with torch.no_grad():
            #=========Nonlinear Term==========
            N_a_init = model.a.grad
            N_w_init = model.W.grad
            # Exponential moving averages: a per-parameter first moment, and one
            # scalar second moment of the squared gradient norm over both tensors.
            m_a = beta_1 * m_a + (1 - beta_1) * N_a_init
            m_w = beta_1 * m_w + (1 - beta_1) * N_w_init
            v = beta_2 * v + (1 - beta_2) * (torch.norm(N_a_init) ** 2 + torch.norm(N_w_init) ** 2)
            m_a_hat = m_a / (1 - beta_1 ** (cnt + 1))
            m_w_hat = m_w / (1 - beta_1 ** (cnt + 1))
            v_hat = v / (1 - beta_2 ** (cnt + 1))
            N_a = m_a_hat
            N_w = m_w_hat
            #=========Time Step Update========
            adaptive_lr = lr / (torch.sqrt(v_hat) + epsilon)
            #=========SAV Update==========
            # Pre-step parameters, needed by the relaxation coefficient c below.
            theta_a_1 = model.a.clone()
            theta_w_1 = model.W.clone()
            theta_a_2 = - adaptive_lr * N_a / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            theta_w_2 = - adaptive_lr * N_w / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            r_wave = r / (1 + adaptive_lr * (torch.sum(N_a * (N_a / (1 + adaptive_lr * _lambda))) + torch.sum(N_w * (N_w) / (1 + adaptive_lr * _lambda))) / (2 * (loss + C)))
            model.a += r_wave.item() * theta_a_2
            model.W += r_wave.item() * theta_w_2
            model.a.grad.zero_()
            model.W.grad.zero_()
            #=========Relax Update==========
            # r becomes a convex combination of r_wave and r_hat = sqrt(loss + C)
            # at the new parameters. The weight ellipsis_0 is the smaller root of
            # a*x^2 + b*x + c, clamped at 0.
            tmp_loss = model.loss(model(X), Y).mean()
            r_hat = torch.sqrt(tmp_loss + C)
            a = (r_wave - r_hat) ** 2
            b = 2 * r_hat * (r_wave - r_hat)
            c = r_hat ** 2 - r_wave ** 2 - ratio_n * (torch.norm(model.a - theta_a_1) ** 2 + torch.norm(model.W - theta_w_1) ** 2) / adaptive_lr
            if a == 0:
                # a == 0 means r_wave == r_hat exactly, so the quadratic
                # degenerates; fall back to ellipsis_0 = 1 (r = r_wave).
                print('a == 0')
                ellipsis_0 = 1
            elif (b ** 2 - 4 * a * c) < 0:
                ellipsis_0 = 1
                print('b^2 - 4ac < 0')
            else:
                ellipsis_0 = max((-b - torch.sqrt(b ** 2 - 4 * a * c)) / (2 * a), 0)
            r = ellipsis_0 * r_wave + (1 - ellipsis_0) * r_hat
            if torch.isnan(r_wave) or torch.isnan(r_hat) or torch.isnan(a) or torch.isnan(b) or torch.isnan(c):
                raise ValueError('nan')
            cnt += 1

    # End-of-epoch evaluation on the full train/test sets, logged to wandb.
    with torch.no_grad():
        # .item(): keep plain floats in the history lists so they do not pin
        # device tensors and can be plotted directly.
        train_loss = model.loss(model(X_train), Y_train).mean().item()
        test_loss = model.loss(model(X_test), Y_test).mean().item()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss,
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   # NOTE: logged as 'accuracy', but this is 1 - test MSE (regression).
                   'accuracy': 1 - test_loss,
                   'r': r.item(),
                   'adaptive_lr': adaptive_lr.item(),
                   'r_wave': r_wave.item(),
                   'r_hat': r_hat.item(),
                   'a': a.item(),
                   'b': b.item(),
                   'c': c.item(),
                   # ellipsis_0 is an int or a 0-d tensor; log it as a float.
                   'ellipsis_0': float(ellipsis_0)})
        print(f'epoch {epoch + 1}, loss {train_loss:.6f}, test loss {test_loss:.6f}')

  r = torch.sqrt(torch.tensor(loss + C, device=device))


epoch 1, loss 0.570264, test loss 0.587395
epoch 2, loss 0.266627, test loss 0.278476
epoch 3, loss 0.121670, test loss 0.130181
epoch 4, loss 0.078118, test loss 0.085130
epoch 5, loss 0.067929, test loss 0.074281
epoch 6, loss 0.065329, test loss 0.071407
epoch 7, loss 0.064236, test loss 0.070134
epoch 8, loss 0.063487, test loss 0.069319
epoch 9, loss 0.062878, test loss 0.068669
epoch 10, loss 0.062307, test loss 0.068108
epoch 11, loss 0.061783, test loss 0.067522
epoch 12, loss 0.061287, test loss 0.067057
epoch 13, loss 0.060817, test loss 0.066604
epoch 14, loss 0.060368, test loss 0.066274
epoch 15, loss 0.059923, test loss 0.065862
epoch 16, loss 0.059490, test loss 0.065445
epoch 17, loss 0.059061, test loss 0.065073
epoch 18, loss 0.058637, test loss 0.064705
epoch 19, loss 0.058218, test loss 0.064308
epoch 20, loss 0.057806, test loss 0.063939
epoch 21, loss 0.057394, test loss 0.063567
epoch 22, loss 0.056996, test loss 0.063150
epoch 23, loss 0.056590, test loss 0.0628

KeyboardInterrupt: 

In [None]:
# Finish the active wandb run; the output below shows the remaining data being uploaded.
wandb.finish()

VBox(children=(Label(value='0.012 MB of 0.022 MB uploaded\r'), FloatProgress(value=0.5339011925042589, max=1.0…

0,1
a,█▄▃▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▅▇█████████████████████████████████████
adaptive_lr,▁▂▃▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████
b,▁▄▅▆▆▇▇▆▇▇█▇▇▇█▇███▇████▇█▇▇▇▇█▇████████
c,█▅▄▃▃▂▂▃▂▂▁▂▂▂▁▂▁▁▁▂▁▁▁▁▂▁▂▂▂▂▁▂▁▁▁▁▁▁▁▁
ellipsis_0,▁▁▆▇██████████▇█▇███▇██▇████████████▇███
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
norm_W,▁▃▄▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
norm_a,▁▃▄▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
r,█▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
a,4e-05
accuracy,0.94474
adaptive_lr,20.85179
b,-0.01259
c,0.01255
ellipsis_0,0.99998
epoch,46.0
norm_W,14.23516
norm_a,17.85828
r,1.01368
