In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import wandb

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
import datetime

date = datetime.datetime.now().strftime("%m%d%H%M")
wandb.init(project='Numerical Method', name=f"PM_SAV_sum_Example_2_{date}")

In [3]:
X_train, Y_train, X_test, Y_test = torch.load('../data/dataset_3_20D.pt', weights_only=True)

# 将数据移动到适当的设备
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

(torch.Size([80000, 21]),
 torch.Size([80000]),
 torch.Size([20000, 21]),
 torch.Size([20000]))

In [4]:
# 使用 DataLoader 进行批处理
l = 256
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=l, shuffle=True)

# 打印第一个批次的大小
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([256, 21]) torch.Size([256])


In [5]:
D = 20
m = 10000

In [6]:
class PM_Euler(nn.Module):
    def __init__(self, input, hidden_layer, output):
        super(PM_Euler, self).__init__()
        self.relu = nn.ReLU()
        self.hidden_dim = hidden_layer
        self.W = nn.Parameter(torch.rand(input, hidden_layer, device=device), requires_grad=True)
        # HE初始化
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')
        self.a = nn.Parameter(torch.rand(hidden_layer, output, device=device), requires_grad=True)
        nn.init.kaiming_normal_(self.a, mode='fan_in', nonlinearity='relu')
        
    def forward(self, x):
        z1 = self.relu(torch.mm(x, self.W))
        z2 = torch.mm(z1, self.a) / self.hidden_dim
        return z2

    def loss(self, y_pred, y_true):
        return (y_pred - y_true.reshape(y_pred.shape)) ** 2

model = PM_Euler(D + 1, m, 1).to(device)

# 计算模型W和a的Norm
def get_norm(model):
    return torch.norm(model.W).item(), torch.norm(model.a).item()

get_norm(model)

(6.485905170440674, 142.09332275390625)

In [7]:
epochs = 10000
lr = 1
C = 100
_lambda = 4
r = [0]

train_losses = []
test_losses = []

In [8]:
with torch.no_grad():
    print(X_train.shape, Y_train.shape)
    print(X_train[0], Y_train[0])
    print(model(X_train))
    print(f'Init Loss : {model.loss(model(X_train), Y_train).mean().item()}')

torch.Size([80000, 21]) torch.Size([80000])
tensor([ 0.2917, -0.3087, -0.3028,  0.0076, -0.4275, -1.0602, -0.1303, -1.3861,
        -0.3681,  0.2687,  0.0687,  0.2104, -0.1937, -0.9206, -0.2700,  0.6400,
        -1.2866,  0.6359, -0.3955,  0.2940,  1.0000], device='cuda:0') tensor(5.5030, device='cuda:0')
tensor([[ 0.0006],
        [ 0.0003],
        [-0.0007],
        ...,
        [ 0.0001],
        [-0.0002],
        [-0.0005]], device='cuda:0')
Init Loss : 0.9661540985107422


In [None]:
wandb.config = {
    'learning_rate': lr,
    'batch_size': l,
    'epochs': epochs,
    'hidden_layer': m,
    'input': D + 1,
    'output': 1,
    'optimizer': 'SAV',
    'C': C,
    '_lambda': _lambda,
    'r': r
}

In [9]:
for epoch in range(epochs):
    cnt = 0
    for X, Y in train_loader:
        loss = model.loss(model(X), Y).sum()
        if cnt == 0:
            r = torch.sqrt(torch.tensor(loss + C, device=device))
            cnt = 1
        loss.backward()
        with torch.no_grad():
            N_a = model.a.grad.clone()
            N_w = model.W.grad.clone()
            theta_a_2 = lr * N_a / (torch.sqrt(loss + C) * (1 + lr * _lambda))
            theta_w_2 = lr * N_w / (torch.sqrt(loss + C) * (1 + lr * _lambda))
            r = r / (1 + lr * (torch.sum(N_a * (N_a / (1 + lr * _lambda))) + torch.sum(N_w * (N_w) / (1 + lr * _lambda))) / (2 * (loss + C)))
            model.a -= r.item() * theta_a_2
            model.W -= r.item() * theta_w_2
            model.a.grad.zero_()
            model.W.grad.zero_()
    with torch.no_grad():
        train_loss = model.loss(model(X_train), Y_train).mean()
        test_loss = model.loss(model(X_test), Y_test).mean()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        norm = get_norm(model)
        # wandb.log({'epoch': epoch + 1,
        #            'train_loss': train_loss, 
        #            'test_loss': test_loss,
        #            'norm_W': norm[0],
        #            'norm_a': norm[1],
        #            'accuracy': 1 - test_loss,
        #            'r': r.item()})
        print(f'epoch {epoch + 1}, loss {train_loss:.4f}, test loss {test_loss:.4f}')

  r = torch.sqrt(torch.tensor(loss + C, device=device))


epoch 1, loss 0.9549, test loss 1.1249
epoch 2, loss 0.9469, test loss 1.1169
epoch 3, loss 0.9381, test loss 1.1080
epoch 4, loss 0.9292, test loss 1.0989
epoch 5, loss 0.9199, test loss 1.0895
epoch 6, loss 0.9110, test loss 1.0806
epoch 7, loss 0.9007, test loss 1.0704
epoch 8, loss 0.8907, test loss 1.0602
epoch 9, loss 0.8804, test loss 1.0499
epoch 10, loss 0.8680, test loss 1.0372
epoch 11, loss 0.8536, test loss 1.0226
epoch 12, loss 0.8390, test loss 1.0077
epoch 13, loss 0.8231, test loss 0.9914
epoch 14, loss 0.8074, test loss 0.9754
epoch 15, loss 0.7827, test loss 0.9494
epoch 16, loss 0.7633, test loss 0.9293
epoch 17, loss 0.7481, test loss 0.9134
epoch 18, loss 0.7273, test loss 0.8913
epoch 19, loss 0.7012, test loss 0.8640
epoch 20, loss 0.6854, test loss 0.8471
epoch 21, loss 0.6730, test loss 0.8338
epoch 22, loss 0.6596, test loss 0.8191
epoch 23, loss 0.6431, test loss 0.8012
epoch 24, loss 0.6282, test loss 0.7851
epoch 25, loss 0.6179, test loss 0.7737
epoch 26,

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7efc290cba40>>
Traceback (most recent call last):
  File "/root/miniconda3/envs/Numerical_Method/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 790, in _clean_thread_parent_frames
    active_threads = {thread.ident for thread in threading.enumerate()}
                      ^^^^^^^^^^^^
  File "/root/miniconda3/envs/Numerical_Method/lib/python3.12/threading.py", line 1196, in ident
    @property

KeyboardInterrupt: 


epoch 27, loss 0.6002, test loss 0.7539
epoch 28, loss 0.5943, test loss 0.7475
epoch 29, loss 0.5873, test loss 0.7395
epoch 30, loss 0.5826, test loss 0.7344
epoch 31, loss 0.5786, test loss 0.7297
epoch 32, loss 0.5755, test loss 0.7263
epoch 33, loss 0.5726, test loss 0.7227
epoch 34, loss 0.5696, test loss 0.7191
epoch 35, loss 0.5676, test loss 0.7169
epoch 36, loss 0.5653, test loss 0.7144
epoch 37, loss 0.5641, test loss 0.7131
epoch 38, loss 0.5624, test loss 0.7109
epoch 39, loss 0.5615, test loss 0.7098
epoch 40, loss 0.5605, test loss 0.7090
epoch 41, loss 0.5597, test loss 0.7079
epoch 42, loss 0.5591, test loss 0.7075
epoch 43, loss 0.5585, test loss 0.7069
epoch 44, loss 0.5579, test loss 0.7063


In [None]:
wandb.finish()