In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import wandb

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook relies on (parameter initialisation, DataLoader
# shuffling) so Restart & Run All reproduces the run, up to nondeterministic
# CUDA kernels. torch.manual_seed seeds the generators of all devices.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device

device(type='cuda')

In [3]:
# Load the pre-split (X_train, Y_train, X_test, Y_test) tuple of tensors.
# weights_only=True restricts unpickling to tensors and primitive containers,
# which is safer than an unrestricted pickle load and avoids the FutureWarning
# recent PyTorch releases emit when the argument is left unspecified.
DATA_PATH = '../data/dataset_3_20D.pt'
X_train, Y_train, X_test, Y_test = torch.load(DATA_PATH, weights_only=True)

# Move every tensor to the selected device.
X_train, Y_train, X_test, Y_test = (t.to(device) for t in (X_train, Y_train, X_test, Y_test))

# Features and targets must have the same number of rows in each split.
assert X_train.shape[0] == Y_train.shape[0], 'train features/targets row mismatch'
assert X_test.shape[0] == Y_test.shape[0], 'test features/targets row mismatch'

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

  X_train, Y_train, X_test, Y_test = torch.load('../data/dataset_3_20D.pt')


(torch.Size([80000, 21]),
 torch.Size([80000]),
 torch.Size([20000, 21]),
 torch.Size([20000]))

In [4]:
# Batch the training set for mini-batch optimisation.
l = 256  # batch size (also logged as config['batch_size'])
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# Peek at the first shuffled batch to confirm the shapes.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape, y.shape)

torch.Size([256, 21]) torch.Size([256])


In [5]:
# Problem size: D is the dimension used to build the model's input width
# (D + 1, matching the 21 feature columns of X_train); m is the hidden width.
D, m = 20, 10000

In [6]:
class PM_Euler(nn.Module):
    """Two-layer ReLU network with mean-field output scaling.

    Computes ``relu(x @ W) @ a / hidden_layer`` with ``W`` of shape
    ``(input, hidden_layer)`` and ``a`` of shape ``(hidden_layer, output)``.
    There are no bias terms.

    Args:
        input: Number of input features (columns of ``x``).
        hidden_layer: Number of hidden ReLU units; the output is also divided by it.
        output: Number of outputs.
        device: Device on which ``W`` and ``a`` are allocated. Defaults to PyTorch's
            default device; move the module afterwards with ``.to(device)`` as usual.
            (Previously the constructor silently read a notebook-global ``device``.)
    """

    def __init__(self, input, hidden_layer, output, device=None):
        super().__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        # Kaiming (He) normal initialisation.
        # NOTE(review): for a 2-D tensor PyTorch takes fan_in from dim 1, i.e. it assumes
        # nn.Linear's (out_features, in_features) layout. W and a are stored as (in, out),
        # so mode='fan_in' yields std = sqrt(2 / hidden_layer) for W and
        # std = sqrt(2 / output) for a (initial norms ~ sqrt(2*input) and
        # sqrt(2*hidden_layer)), not textbook He init. mode='fan_out' would use the true
        # fan-in; the current scale is kept so results stay comparable with earlier runs.
        # The storage is allocated uninitialised because the initialiser overwrites it.
        self.W = nn.Parameter(torch.empty(input, hidden_layer, device=device))
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.empty(hidden_layer, output, device=device))
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return ``relu(x @ W) @ a / hidden_dim`` for a 2-D batch ``x`` of shape (batch, input)."""
        hidden = self.relu(torch.mm(x, self.W))
        return torch.mm(hidden, self.a) / self.hidden_dim

    def loss(self, y_pred, y_true):
        """Element-wise squared error; callers reduce it (e.g. with ``.mean()``).

        ``y_true`` is reshaped to ``y_pred``'s shape so that a ``(batch,)`` target lines
        up with a ``(batch, 1)`` prediction instead of broadcasting to ``(batch, batch)``.
        """
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2
model = PM_Euler(D + 1, m, 1).to(device)


def get_norm(model):
    """Return the Frobenius norms of ``model.W`` and ``model.a``.

    Returns:
        tuple[float, float]: ``(norm_W, norm_a)`` as Python floats.
    """
    return tuple(torch.linalg.norm(param).item() for param in (model.W, model.a))


get_norm(model)

(6.485015392303467, 141.00892639160156)

In [7]:
epochs = 10000
lr = 1
C = 100
_lambda = 4
# Define the relax parameters 
r_wave = 0
r_hat = 0
r = 0
a = 0
b = 0
c = 0
ellipsis_0 = 0
ratio_n = 0.99

train_losses = []
test_losses = []

In [8]:
# Hyper-parameters recorded with the wandb run. Each tunable entry references the
# variable actually used in training (previously 'ratio_n' was hard-coded to 0.99,
# so the logged config would silently drift if ratio_n were changed).
config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    'output': 1,
    'optimizer': 'RelaxedSAV',
    'C': C,
    '_lambda': _lambda,
    'ratio_n': ratio_n,
    'dataset': 'dataset_3_20D'
}

In [9]:
import datetime

# Month-day-hour-minute timestamp distinguishes run names across launches.
date = datetime.datetime.now().strftime("%m%d%H%M")
run_name = f"PM_RelaxedSAV_Example_3_{date}"
wandb.init(
    project='Numerical Method',
    name=run_name,
    config=config,
    notes="first try for example 3",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [10]:
# Relaxed-SAV training loop. Per mini-batch:
# 1. take one SAV parameter step;
# 2. recompute r_hat at the new parameters;
# 3. relax the auxiliary variable r between r_wave and r_hat.
# Once per epoch, evaluate the full train/test sets and log to wandb.
#
# NOTE(review): cnt is reset at the start of every epoch, so r is re-initialised to
# sqrt(loss + C) on the first batch of EACH epoch rather than only once before
# training -- confirm that this per-epoch restart is intended.
for epoch in range(epochs):
    cnt = 0
    for X, Y in train_loader:
        y_pred = model(X)
        loss = model.loss(y_pred, Y).mean()
        # First batch of the epoch: initialise the auxiliary variable r = sqrt(loss + C).
        if cnt == 0:
            with torch.no_grad():
                r = torch.sqrt(loss + C)
                cnt = 1
        loss.backward()
        with torch.no_grad():
            #===============Update the parameters in SAV================
            # Snapshot the current parameters and their gradients.
            theta_a_1 = model.a.clone()
            theta_w_1 = model.W.clone()
            N_a = model.a.grad.clone()
            N_w = model.W.grad.clone()
            # Step per unit of r: -lr * grad / (sqrt(loss + C) * (1 + lr * _lambda)).
            theta_a_2 = -lr * N_a / (torch.sqrt(loss + C) * (1 + lr * _lambda))
            theta_w_2 = -lr * N_w / (torch.sqrt(loss + C) * (1 + lr * _lambda))
            # SAV prediction of r:
            # r_wave = r / (1 + lr * sum(grad**2) / (1 + lr*_lambda) / (2 * (loss + C))).
            r_wave = r / (1 + lr * (torch.sum(N_a * (N_a / (1 + lr * _lambda))) + torch.sum(N_w * (N_w) / (1 + lr * _lambda))) / (2 * (loss + C)))
            # In-place update theta_new = theta_old + r_wave * theta_2.
            model.a += r_wave.item() * theta_a_2
            model.W += r_wave.item() * theta_w_2
            # No optimizer object is used, so gradients are cleared by hand.
            model.a.grad.zero_()
            model.W.grad.zero_()
            # ===============Update r in SAV================
            # r_hat = sqrt(loss + C) re-evaluated on the same batch at the updated parameters.
            tmp_loss = model.loss(model(X), Y).mean()
            r_hat = torch.sqrt(tmp_loss + C)
            # Relaxation: choose xi (ellipsis_0) from the quadratic a*xi^2 + b*xi + c, where
            #   a = (r_wave - r_hat)^2,  b = 2 * r_hat * (r_wave - r_hat),
            #   c = r_hat^2 - r_wave^2 - ratio_n * ||theta_new - theta_old||^2 / lr.
            a = (r_wave - r_hat) ** 2
            b = 2 * r_hat * (r_wave - r_hat)
            c = r_hat ** 2 - r_wave ** 2 -  ratio_n * (torch.norm(model.a - theta_a_1) ** 2 + torch.norm(model.W - theta_w_1) ** 2) / lr
            if a == 0:
                # Why does a == 0 happen? a is exactly zero only when r_wave and r_hat are
                # equal bit-for-bit. Both are ~sqrt(loss + C) ~ sqrt(101) ~ 10 here, where
                # float32 spacing is ~1e-6 (assuming float32 tensors), so a small step can
                # round them to the same value -- likely explanation, confirm by logging
                # r_wave - r_hat. In that case b == 0 too, and r below equals r_hat for
                # any xi, so falling back to xi = 0 does not change the result.
                print('a == 0')
                ellipsis_0 = 0
            elif (b ** 2 - 4 * a * c) < 0:
                # No real root: fall back to xi = 0, i.e. r = r_hat.
                ellipsis_0 = 0
                print('b^2 - 4ac < 0')
            else: 
                # a > 0 here, so this is the smaller root, clamped below at 0.
                # NOTE(review): ellipsis_0 is an int 0 in the branches above but a 0-d
                # tensor (or int 0) here, so the value logged to wandb changes type.
                ellipsis_0 = max((-b - torch.sqrt(b ** 2 - 4 * a * c)) / (2 * a), 0)
            # print(r, r_wave, r_hat, ellipsis_0, a, b, c, (-b - torch.sqrt(b ** 2 - 4 * a * c)) / (2 * a), r - r_hat)
            # Relaxed update: r_next = xi * r_wave + (1 - xi) * r_hat.
            r = ellipsis_0 * r_wave + (1 - ellipsis_0) * r_hat
            # Optional per-step log file (disabled).
            # with open('log.txt', 'a') as f:
            #     f.write(f'epoch {epoch + 1}, loss {loss.item():.6f}, r {r.item():.6f}, r_wave {r_wave.item():.6f}, r_hat {r_hat.item():.6f}, ellipsis_0 {ellipsis_0:.6f},a {a},b {b}, c {c},b^2 - 4ac {b ** 2 - 4 * a * c}\n')
            # Abort as soon as any SAV quantity becomes NaN.
            if torch.isnan(r_wave) or torch.isnan(r_hat) or torch.isnan(a) or torch.isnan(b) or torch.isnan(c):
                raise ValueError('nan')
    # End of epoch: evaluate on the full train and test sets and log.
    with torch.no_grad():
        train_loss = model.loss(model(X_train), Y_train).mean()
        test_loss = model.loss(model(X_test), Y_test).mean()
        # NOTE(review): these lists store 0-d tensors on `device`, not Python floats.
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        wandb.log({'epoch': epoch + 1,
                   'train_loss': train_loss, 
                   'test_loss': test_loss,
                   'norm_W': norm[0],
                   'norm_a': norm[1],
                   # NOTE(review): 1 - squared-error loss is not an accuracy metric.
                   'accuracy': 1 - test_loss,
                   'r': r.item(),
                   'r_wave': r_wave.item(),
                   'r_hat': r_hat.item(),
                   'ellipsis': ellipsis_0})
        print(f'epoch {epoch + 1}, loss {train_loss:.6f}, test loss {test_loss:.6f}')

a == 0
epoch 1, loss 1.009364, test loss 0.961881
a == 0
epoch 2, loss 1.009269, test loss 0.961788
epoch 3, loss 1.009175, test loss 0.961695
a == 0
epoch 4, loss 1.009081, test loss 0.961603
a == 0
epoch 5, loss 1.008987, test loss 0.961510
a == 0
epoch 6, loss 1.008894, test loss 0.961418
a == 0
epoch 7, loss 1.008801, test loss 0.961326
a == 0
epoch 8, loss 1.008708, test loss 0.961234
a == 0
epoch 9, loss 1.008615, test loss 0.961142
a == 0
epoch 10, loss 1.008522, test loss 0.961051
a == 0
epoch 11, loss 1.008430, test loss 0.960959
epoch 12, loss 1.008337, test loss 0.960868
a == 0
epoch 13, loss 1.008245, test loss 0.960777
a == 0
epoch 14, loss 1.008152, test loss 0.960685
a == 0
epoch 15, loss 1.008060, test loss 0.960594
epoch 16, loss 1.007969, test loss 0.960504
a == 0
epoch 17, loss 1.007877, test loss 0.960413
a == 0
epoch 18, loss 1.007786, test loss 0.960323
a == 0
epoch 19, loss 1.007697, test loss 0.960234
a == 0
epoch 20, loss 1.007607, test loss 0.960145
a == 0
epo

KeyboardInterrupt: 

In [12]:
# End the wandb run opened by wandb.init(...) earlier in the notebook.
wandb.finish()