In [1]:
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Select the compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [3]:
# Load the pre-generated train/test split. The file stores only tensors, so
# weights_only=True is sufficient; it restricts unpickling to tensors and
# plain containers instead of executing arbitrary pickled code.
X_train, Y_train, X_test, Y_test = torch.load('../data/dataset_3_20D.pt', weights_only=True)

# Move the data to the appropriate device
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

# Display the tensor shapes as a quick sanity check
X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

  X_train, Y_train, X_test, Y_test = torch.load('../data/dataset_3_20D.pt')


(torch.Size([80000, 21]),
 torch.Size([80000]),
 torch.Size([20000, 21]),
 torch.Size([20000]))

In [4]:
# 使用 DataLoader 进行批处理
l = 256
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# 打印第一个批次的大小
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([256, 21]) torch.Size([256])


In [5]:
# Problem size: the model is built with D + 1 inputs and m hidden units.
D, m = 20, 10000

In [6]:
class Model(nn.Module):
    """Two-layer ReLU network whose output is divided by the hidden width.

    The forward pass computes ``relu(x @ W) @ a / hidden_layer`` using two
    explicit weight matrices and no bias terms.

    Args:
        input: Number of input features (columns of ``x``).
        hidden_layer: Number of hidden units.
        output: Number of output units.
        device: Optional device on which the parameters are allocated. When
            omitted the parameters are created on the default device and can
            be moved afterwards with ``.to(...)``; the class no longer reads a
            module-level ``device`` variable.
    """

    def __init__(self, input, hidden_layer, output, device=None):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # He (Kaiming) initialisation. The storage is allocated uninitialised
        # because kaiming_normal_ overwrites every entry in place.
        # NOTE(review): kaiming_normal_ takes fan_in from tensor.size(1). With
        # the (in_features, out_features) layout used here, W is scaled by
        # hidden_layer and a by output rather than by their true fan-in.
        # The scale is kept unchanged; confirm that this is intended.
        self.W = nn.Parameter(torch.empty(input, hidden_layer, device=device))
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.empty(hidden_layer, output, device=device))
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return ``relu(x @ W) @ a / hidden_dim`` for a 2-D input ``x``."""
        z1 = self.relu(torch.mm(x, self.W))
        z2 = torch.mm(z1, self.a) / self.hidden_dim
        return z2

    def loss(self, y_pred, y_true):
        """Return the element-wise squared error (no reduction is applied).

        ``y_true`` is reshaped to the shape of ``y_pred`` before subtracting,
        so the result has the same shape as ``y_pred``.
        """
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

# Build the network with D + 1 inputs, m hidden units and a single output,
# then place it on the selected device.
model = Model(input=D + 1, hidden_layer=m, output=1)
model = model.to(device)

# Compute the norms of the model's W and a parameters
def get_norm(model):
    """Return ``(norm of model.W, norm of model.a)`` as Python floats."""
    norm_w = torch.norm(model.W).item()
    norm_a = torch.norm(model.a).item()
    return norm_w, norm_a

# Display the norms of W and a right after initialisation
get_norm(model)

(6.495305061340332, 141.2490997314453)

In [7]:
# ---- Optimisation hyper-parameters ----
epochs = 10000    # number of passes over train_loader
lr = 1            # base step size; each step uses lr / (sqrt(v_hat) + epsilon)
C = 100           # constant added to the loss under the square root: sqrt(loss + C)
_lambda = 4       # coefficient in the 1 / (1 + adaptive_lr * _lambda) factors
epsilon = 1e-8    # added to sqrt(v_hat) so the adaptive step never divides by zero
beta_1 = 0.9      # decay rate of the running average of the gradients
beta_2 = 0.999    # decay rate of the running average of the squared gradient norm

# ---- State of the relaxed update (all overwritten inside the training loop) ----
r = 0             # auxiliary variable, initialised from sqrt(loss + C) in the loop
r_wave = 0        # r after the SAV step, before relaxation
r_hat = 0         # sqrt(loss + C) re-evaluated after the parameter update
a = 0             # coefficients of the quadratic solved for the relaxation weight
b = 0
c = 0
ellipsis_0 = 0    # relaxation weight blending r_wave and r_hat
ratio_n = 0.99    # factor applied to the squared parameter change in coefficient c


# Per-epoch loss history
train_losses = []
test_losses = []

In [8]:
# Run configuration recorded with the wandb run. Every value is read from the
# variables defined above so the logged configuration cannot drift from the
# values actually used (ratio_n was previously a hard-coded 0.99).
config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    'output': 1,
    'optimizer': 'Adam_Relax_SAV',
    'Approx Method': 'PM',
    'C': C,
    '_lambda': _lambda,
    'r': r,
    'epsilon': epsilon,
    'ratio_n': ratio_n,
    'beta_1': beta_1,
    'beta_2': beta_2,
}

In [9]:
# Start the wandb run. The month/day/hour/minute timestamp is appended to the
# run name so repeated runs can be told apart. (`datetime` is imported in the
# import cell at the top of the notebook.)
date = datetime.datetime.now().strftime("%m%d%H%M")
wandb.init(project='Numerical Method', name=f"PM_A_RelaxSAV_Example_3_{date}", config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpheonizard[0m ([33mpheonizard-university-of-nottingham[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
for epoch in range(epochs):
    # The step counter and the moment accumulators are reset at the start of
    # every epoch, so the bias correction below restarts each epoch as well.
    cnt = 0
    m_a, m_w, v = 0, 0, 0
    for X, Y in train_loader:
        loss = model.loss(model(X), Y).mean()
        if cnt == 0:
            # First batch of the epoch: (re)initialise the auxiliary variable
            # r = sqrt(loss + C). detach() drops the autograd graph; wrapping
            # the tensor in torch.tensor(...) made a copy and raised a
            # UserWarning on every epoch.
            r = torch.sqrt(loss.detach() + C)
        loss.backward()
        with torch.no_grad():
            #=========Nonlinear Term==========
            # Bias-corrected running averages of the gradients (first moment)
            # and of the total squared gradient norm (one scalar second moment
            # shared by both parameters).
            N_a_init = model.a.grad
            N_w_init = model.W.grad
            m_a = beta_1 * m_a + (1 - beta_1) * N_a_init
            m_w = beta_1 * m_w + (1 - beta_1) * N_w_init
            v = beta_2 * v + (1 - beta_2) * (torch.norm(N_a_init) ** 2 + torch.norm(N_w_init) ** 2)
            m_a_hat = m_a / (1 - beta_1 ** (cnt + 1))
            m_w_hat = m_w / (1 - beta_1 ** (cnt + 1))
            v_hat = v / (1 - beta_2 ** (cnt + 1))
            N_a = m_a_hat
            N_w = m_w_hat
            #=========Time Step Update========
            adaptive_lr = lr / (torch.sqrt(v_hat) + epsilon)
            #=========SAV Update==========
            # Keep copies of the current parameters; the relaxation step needs
            # the size of the parameter change.
            theta_a_1 = model.a.clone()
            theta_w_1 = model.W.clone()
            theta_a_2 = - adaptive_lr * N_a / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            theta_w_2 = - adaptive_lr * N_w / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            r_wave = r / (1 + adaptive_lr * (torch.sum(N_a * (N_a / (1 + adaptive_lr * _lambda))) + torch.sum(N_w * (N_w) / (1 + adaptive_lr * _lambda))) / (2 * (loss + C)))
            model.a += r_wave.item() * theta_a_2
            model.W += r_wave.item() * theta_w_2
            model.a.grad.zero_()
            model.W.grad.zero_()
            #=========Relax Update==========
            # Re-evaluate sqrt(loss + C) at the updated parameters and blend it
            # with r_wave using the weight ellipsis_0 obtained from the
            # quadratic a * x^2 + b * x + c.
            tmp_loss = model.loss(model(X), Y).mean()
            r_hat = torch.sqrt(tmp_loss + C)
            a = (r_wave - r_hat) ** 2
            b = 2 * r_hat * (r_wave - r_hat)
            c = r_hat ** 2 - r_wave ** 2 - ratio_n * (torch.norm(model.a - theta_a_1) ** 2 + torch.norm(model.W - theta_w_1) ** 2) / adaptive_lr
            if a == 0:
                # a is the squared difference of r_wave and r_hat, so a == 0
                # means they coincide and any weight yields the same r.
                ellipsis_0 = 0
            elif (b ** 2 - 4 * a * c) < 0:
                # No real root: fall back to r = r_hat.
                ellipsis_0 = 0
                print('b^2 - 4ac < 0')
            else:
                # Smaller root of the quadratic, clipped below at zero.
                ellipsis_0 = max((-b - torch.sqrt(b ** 2 - 4 * a * c)) / (2 * a), 0)
            r = ellipsis_0 * r_wave + (1 - ellipsis_0) * r_hat
            if torch.isnan(r_wave) or torch.isnan(r_hat) or torch.isnan(a) or torch.isnan(b) or torch.isnan(c):
                raise ValueError('nan')
            cnt += 1

    # End-of-epoch evaluation on the full train and test sets.
    with torch.no_grad():
        # .item() stores plain floats so the history lists and the wandb
        # payload do not hold on to CUDA tensors.
        train_loss = model.loss(model(X_train), Y_train).mean().item()
        test_loss = model.loss(model(X_test), Y_test).mean().item()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss,
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   'accuracy': 1 - test_loss,
                   'r': r.item(),
                   'r_wave': r_wave.item(),
                   'r_hat': r_hat.item(),
                   # ellipsis_0 is either the int 0 or a 0-dim tensor
                   'ellipsis': float(ellipsis_0),
                   'adaptive_lr': adaptive_lr.item()})
        print(f'epoch {epoch + 1}, loss {train_loss:.6f}, test loss {test_loss:.6f}')

  r = torch.sqrt(torch.tensor(loss + C, device=device))


epoch 1, loss 1.009310, test loss 0.961819
epoch 2, loss 1.009192, test loss 0.961703
epoch 3, loss 1.009075, test loss 0.961586
epoch 4, loss 1.008958, test loss 0.961471
epoch 5, loss 1.008841, test loss 0.961355
epoch 6, loss 1.008724, test loss 0.961240
epoch 7, loss 1.008608, test loss 0.961126
epoch 8, loss 1.008492, test loss 0.961011
epoch 9, loss 1.008376, test loss 0.960897
epoch 10, loss 1.008260, test loss 0.960782
epoch 11, loss 1.008145, test loss 0.960668
epoch 12, loss 1.008029, test loss 0.960553
epoch 13, loss 1.007915, test loss 0.960440
epoch 14, loss 1.007800, test loss 0.960326
epoch 15, loss 1.007687, test loss 0.960214
epoch 16, loss 1.007575, test loss 0.960103
epoch 17, loss 1.007465, test loss 0.959993
epoch 18, loss 1.007356, test loss 0.959885
epoch 19, loss 1.007248, test loss 0.959778
epoch 20, loss 1.007143, test loss 0.959673
epoch 21, loss 1.007039, test loss 0.959571
epoch 22, loss 1.006938, test loss 0.959470
epoch 23, loss 1.006838, test loss 0.9593

KeyboardInterrupt: 

In [None]:
# Close the wandb run started by wandb.init above
wandb.finish()