In [1]:
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Select the compute device: CUDA when PyTorch reports an available GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Bare expression so the notebook echoes the chosen device as the cell output.
device

device(type='cuda')

In [3]:
# Load the four tensors stored in the dataset file and move each one onto the
# selected device in a single pass.
X_train, Y_train, X_test, Y_test = (
    tensor.to(device)
    for tensor in torch.load('../data/dataset_2_40D.pt', weights_only=True)
)

# Batch the training pairs with a DataLoader (shuffle=True reshuffles on every
# pass over the loader).
l = 64  # mini-batch size
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=l, shuffle=True)

D = 40   # the network below is built with input width D + 1
m = 100  # hidden-layer width passed to the network below

class PM_Euler(nn.Module):
    """Two-layer ReLU network whose output is divided by the hidden width.

    The forward pass computes ``relu(x @ W) @ a / hidden_layer``.

    Args:
        input: Width of the input features (number of rows of ``W``).
        hidden_layer: Number of hidden units; also the divisor applied to
            the output in ``forward``.
        output: Width of the output (number of columns of ``a``).
        device: Optional device on which to allocate the parameters. When
            omitted, PyTorch's default device is used and the module can be
            relocated afterwards with ``.to(...)``.
    """

    def __init__(self, input, hidden_layer, output, device=None):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # Allocate uninitialised storage; kaiming_normal_ below overwrites
        # every entry, so drawing random values first would be wasted work.
        self.W = nn.Parameter(torch.empty(input, hidden_layer, device=device))
        # He (Kaiming) initialisation.
        # NOTE(review): kaiming_normal_ reads fan_in from dim 1, so with these
        # (in, out)-shaped matrices the std is sqrt(2 / hidden_layer) for W
        # and sqrt(2 / output) for a -- confirm this scaling is intended.
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.empty(hidden_layer, output, device=device))
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Map inputs of shape ``(..., input)`` to ``(..., output)``.

        ``torch.matmul`` is used so that 1-D and batched inputs work in
        addition to the plain 2-D ``(batch, input)`` case.
        """
        hidden = self.relu(torch.matmul(x, self.W))
        return torch.matmul(hidden, self.a) / self.hidden_dim

    def loss(self, y_pred, y_true):
        """Return the element-wise squared error (no reduction).

        ``y_true`` is reshaped to ``y_pred``'s shape before subtracting, so a
        flat target vector does not broadcast against a column of predictions.
        """
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

# Seed PyTorch's global RNG before any random draw so the weight
# initialisation (and the DataLoader shuffling, which is seeded from the same
# generator) is repeatable after Restart & Run All.
torch.manual_seed(0)

# Build the network with input width D + 1, m hidden units and a single
# output, then place it on the selected device.
model = PM_Euler(D + 1, m, 1).to(device)

def get_norm(model):
    """Compute the norms of the model's ``W`` and ``a`` parameters.

    Returns:
        A ``(norm_W, norm_a)`` tuple of Python floats.
    """
    w_norm = torch.norm(model.W)
    a_norm = torch.norm(model.a)
    return w_norm.item(), a_norm.item()

# Show the (W, a) parameter norms of the freshly built model as the cell output.
get_norm(model)

(9.027321815490723, 14.600614547729492)

In [4]:
# ---- Training hyperparameters ----
epochs = 1000  # number of passes over the training loader
lr = 1         # step size used in the parameter and r updates
_lambda = 4    # coefficient in the (1 + _lambda * lr) divisor applied to the gradients

# Per-epoch loss histories, appended to by the training loop.
train_losses = []
test_losses = []

# Run configuration; handed to wandb.init when logging is enabled.
config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    '_lambda': _lambda,
    'output': 1,
    'optimizer': 'ESAV'
}

# Timestamp (month, day, hour, minute) used to label the run.
# `datetime` is imported in the notebook's top import cell.
date = datetime.datetime.now().strftime("%m%d%H%M")
# wandb.init(project='Numerical Method', name=f"PM_ESAV_Example_2_{date}", config=config)

In [5]:
# Training loop: every mini-batch step scales a damped gradient step by the
# scalar auxiliary variable r, and r itself is updated from the gradients.
step_damping = 1 + _lambda * lr  # loop-invariant divisor applied to the gradients

for epoch in range(epochs):
    # NOTE(review): r is re-initialised from the first batch of every epoch,
    # exactly as before -- confirm this per-epoch restart is intended.
    r = None
    for X, Y in train_loader:
        loss = model.loss(model(X), Y).mean()
        if r is None:
            # detach(): r is plain scalar state and must not hold on to the
            # autograd graph of this batch.
            r = torch.exp(loss.detach())
        loss.backward()
        with torch.no_grad():
            grad_a = model.a.grad
            grad_w = model.W.grad
            exp_loss = torch.exp(loss)  # computed once, shared by both updates
            damped_a = grad_a / step_damping
            damped_w = grad_w / step_damping
            # ========= Update SAV r =========
            r = r / (1 + lr * (torch.sum(grad_a * damped_a) + torch.sum(grad_w * damped_w)))
            # ========= Update parameters =========
            r_value = r.item()  # one host read, reused for both parameters
            model.a += r_value * (-damped_a * lr / exp_loss)
            model.W += r_value * (-damped_w * lr / exp_loss)
            model.a.grad.zero_()
            model.W.grad.zero_()
    with torch.no_grad():
        # .item() stores plain Python floats, so the histories do not keep
        # device tensors alive and can be plotted or serialised directly.
        train_loss = model.loss(model(X_train), Y_train).mean().item()
        test_loss = model.loss(model(X_test), Y_test).mean().item()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        # wandb.log({'epoch': epoch + 1,
        #            'train_loss': train_loss, 
        #            'test_loss': test_loss,
        #            'norm_W': norm[0],
        #            'norm_a': norm[1],
        #            'accuracy': 1 - test_loss,
        #            'r': r.item()})
        print(f'epoch {epoch + 1}, loss {train_loss:.8f}, test loss {test_loss:.8f}')

epoch 1, loss 0.73483843, test loss 0.74773395
epoch 2, loss 0.46026531, test loss 0.47123617
epoch 3, loss 0.27537027, test loss 0.28367016
epoch 4, loss 0.14782394, test loss 0.15416951
epoch 5, loss 0.09553069, test loss 0.10093677
epoch 6, loss 0.07622597, test loss 0.08116935
epoch 7, loss 0.06980070, test loss 0.07458044
epoch 8, loss 0.06748733, test loss 0.07220708
epoch 9, loss 0.06639946, test loss 0.07109726
epoch 10, loss 0.06568919, test loss 0.07034674
epoch 11, loss 0.06512337, test loss 0.06979825
epoch 12, loss 0.06460775, test loss 0.06929831
epoch 13, loss 0.06412086, test loss 0.06881406
epoch 14, loss 0.06365983, test loss 0.06840075
epoch 15, loss 0.06322233, test loss 0.06800325
epoch 16, loss 0.06280925, test loss 0.06764218
epoch 17, loss 0.06240448, test loss 0.06725618
epoch 18, loss 0.06200529, test loss 0.06691082
epoch 19, loss 0.06161251, test loss 0.06652095
epoch 20, loss 0.06122556, test loss 0.06620709
epoch 21, loss 0.06084026, test loss 0.06588604
e

In [6]:
# wandb.finish()