In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import wandb

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
X_train, Y_train, X_test, Y_test = torch.load('./data/dataset_2_40D.pt')

# 将数据移动到适当的设备
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

(torch.Size([8000, 41]),
 torch.Size([8000]),
 torch.Size([2000, 41]),
 torch.Size([2000]))

In [4]:
# 使用 DataLoader 进行批处理
l = 64
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# 打印第一个批次的大小
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([64, 41]) torch.Size([64])


In [5]:
D = 40
m = 100

In [6]:
class Model(nn.Module):
    def __init__(self, input, hidden_layer, output):
        super(Model, self).__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        self.W = nn.Parameter(torch.rand(input, hidden_layer, device=device), requires_grad=True)
        # HE初始化
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.rand(hidden_layer, output, device=device), requires_grad=True)
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')
        
    def forward(self, x):
        # print(x.shape)
        z1 = self.relu(torch.mm(x, self.W))
        # print(z1.shape)
        z2 = torch.mm(z1, self.a) / self.hidden_dim
        return z2

    def loss(self, y_pred, y_true):
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

model = Model(D + 1, m, 1).to(device)

# 计算模型W和a的Norm
def get_norm(model):
    return torch.norm(model.W).item(), torch.norm(model.a).item()

get_norm(model)

(9.150127410888672, 12.962539672851562)

In [7]:
epochs = 10000
lr = 1
C = 1
_lambda = 4
r = 0
epsilon = 1e-8
beta_1 = 0.9
beta_2 = 0.999
m_a, m_w = 0, 0
v_a, v_w = 0, 0

train_losses = []
test_losses = []

In [8]:
config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    'output': 1,
    'optimizer': 'Adam_SAV',
    'Approx Method': 'PM',
    'C': C,
    '_lambda': _lambda,
    'r': r,
    'epsilon': epsilon,
    
}

In [9]:
import datetime

date = datetime.datetime.now().strftime("%m%d%H%M")
wandb.init(project='Numerical Method', name=f"PM_A_SAV_Example_2_{date}", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mpheonizard[0m ([33mpheonizard-university-of-nottingham[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
for epoch in range(epochs):
    cnt = 0
    m_a, m_w, v_a, v_w = 0, 0, 0, 0
    for X, Y in train_loader:
        loss = model.loss(model(X), Y).sum()
        if cnt == 0:
            r = torch.sqrt(torch.tensor(loss + C, device=device))
            cnt = 1
        loss.backward()
        with torch.no_grad():
            #=========Nonlinear Term==========
            N_a_init = model.a.grad
            N_w_init = model.W.grad
            m_a = beta_1 * m_a + (1 - beta_1) * N_a_init
            m_w = beta_1 * m_w + (1 - beta_1) * N_w_init
            v_a = beta_2 * v_a + (1 - beta_2) * torch.norm(N_a_init) ** 2
            v_w = beta_2 * v_w + (1 - beta_2) * torch.norm(N_w_init) ** 2
            m_a_hat = m_a / (1 - beta_1 ** (epoch + 1))
            m_w_hat = m_w / (1 - beta_1 ** (epoch + 1))
            v_a_hat = v_a / (1 - beta_2 ** (epoch + 1))
            v_w_hat = v_w / (1 - beta_2 ** (epoch + 1))
            N_a = m_a_hat
            N_w = m_w_hat
            #=========Time Step Update========
            adaptive_lr = lr / (torch.sqrt(v_a_hat) + epsilon)
            #=========SAV Update========== 
            theta_a_2 = adaptive_lr * N_a / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            theta_w_2 = adaptive_lr * N_w / (torch.sqrt(loss + C) * (1 + adaptive_lr * _lambda))
            r = r / (1 + adaptive_lr * (torch.sum(N_a * (N_a / (1 + adaptive_lr * _lambda))) + torch.sum(N_w * (N_w) / (1 + adaptive_lr * _lambda))) / (2 * (loss + C)))
            model.a -= r.item() * theta_a_2
            model.W -= r.item() * theta_w_2
            model.a.grad.zero_()
            model.W.grad.zero_()
    with torch.no_grad():
        train_loss = model.loss(model(X_train), Y_train).mean()
        test_loss = model.loss(model(X_test), Y_test).mean()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss, 
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   'accuracy': 1 - test_loss,
                   'r': r.item(),
                   'adaptive_lr': adaptive_lr.item()})
        print(f'epoch {epoch + 1}, loss {train_loss:.7f}, test loss {test_loss:.7f}')

  r = torch.sqrt(torch.tensor(loss + C, device=device))


epoch 1, loss 0.5940118, test loss 0.6061187
epoch 2, loss 0.2240370, test loss 0.2315919
epoch 3, loss 0.0757441, test loss 0.0805644
epoch 4, loss 0.0632788, test loss 0.0667678
epoch 5, loss 0.0582640, test loss 0.0641414
epoch 6, loss 0.0531468, test loss 0.0600450
epoch 7, loss 0.0479347, test loss 0.0545462
epoch 8, loss 0.0419256, test loss 0.0488774
epoch 9, loss 0.0366453, test loss 0.0437580
epoch 10, loss 0.0321184, test loss 0.0390667
epoch 11, loss 0.0277753, test loss 0.0345587
epoch 12, loss 0.0239672, test loss 0.0308546
epoch 13, loss 0.0206889, test loss 0.0275357
epoch 14, loss 0.0180221, test loss 0.0246675
epoch 15, loss 0.0155224, test loss 0.0220297
epoch 16, loss 0.0133276, test loss 0.0196039
epoch 17, loss 0.0117458, test loss 0.0177950
epoch 18, loss 0.0103854, test loss 0.0162446
epoch 19, loss 0.0092366, test loss 0.0147726
epoch 20, loss 0.0082876, test loss 0.0136398
epoch 21, loss 0.0076040, test loss 0.0129708
epoch 22, loss 0.0067772, test loss 0.01210

In [12]:
wandb.finish()