In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# Define the neural network
class PolicyNet(nn.Module):
    """Small fully connected classifier: 2 input features -> 3 class logits.

    The network is ``Linear(2, 32) -> ReLU -> Linear(32, 3)``. The output is
    the raw (unnormalised) score per class; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the random number generator in a fixed sequence.
        self.fc1 = nn.Linear(2, 32)
        self.fc2 = nn.Linear(32, 3)

    def forward(self, x):
        """Return the class logits for ``x`` (last dimension must be 2)."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# Build the dataset
def create_dataset(num_samples):
    """Create a synthetic dataset labelled by the sign of the second feature.

    Args:
        num_samples: Number of samples to generate.

    Returns:
        A ``TensorDataset`` of ``(inputs, labels)``. ``inputs`` is a
        ``(num_samples, 2)`` tensor drawn with ``torch.randn``; ``labels`` is
        a ``(num_samples,)`` ``torch.long`` tensor holding 2 where
        ``inputs[:, 1] > 0``, 0 where ``inputs[:, 1] < 0`` and 1 otherwise.

    Note:
        The inputs are samples from a continuous normal distribution, so a
        second feature of exactly 0 is extremely unlikely. In practice label
        1 is (almost) never present in the generated data.
    """
    inputs = torch.randn(num_samples, 2)  # random input features
    second_feature = inputs[:, 1]

    # Start from the "otherwise" label and overwrite by sign with boolean
    # masks; this replaces a per-sample Python loop with tensor operations.
    labels = torch.ones(num_samples, dtype=torch.long)
    labels[second_feature > 0] = 2
    labels[second_feature < 0] = 0

    return TensorDataset(inputs, labels)

# Training function
def train(model, num_epochs, batch_size, num_samples=100000):
    """Train ``model`` on a freshly generated synthetic dataset.

    The data comes from ``create_dataset``; optimisation uses cross-entropy
    loss and Adam with its default hyper-parameters. Every 10th epoch the
    mean loss over that epoch is printed.

    Args:
        model: The ``nn.Module`` to optimise in place.
        num_epochs: Number of full passes over the dataset.
        batch_size: Mini-batch size passed to the ``DataLoader``.
        num_samples: Number of training samples to generate
            (defaults to 100000).

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    dataset = create_dataset(num_samples)  # build the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    # Make sure train-time behaviour is active even if the caller previously
    # switched the model to eval mode.
    model.train()

    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        for inputs, labels in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew
            # the epoch mean.
            running_loss += loss.item() * inputs.size(0)

        if (epoch + 1) % 10 == 0:
            # Report the epoch mean rather than the loss of whichever
            # shuffled mini-batch happened to come last.
            mean_loss = running_loss / len(dataset)
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}')

# Evaluation helper
def evaluate(model):
    """Print the model's prediction for three hand-picked inputs.

    For each input the predicted class (argmax over the model output) is
    printed next to the expected class, which is 2 if the second feature is
    positive, 0 if it is negative and 1 otherwise.

    Inference runs in eval mode with gradient tracking disabled; the model's
    previous training/eval mode is restored before returning.

    Args:
        model: An ``nn.Module`` mapping a ``(1, 2)`` tensor to class scores.
    """
    test_cases = [
        torch.tensor([1.0, 1.0]),   # expected output: 2
        torch.tensor([2.0, 0.0]),   # expected output: 1
        torch.tensor([-1.0, -2.0])  # expected output: 0
    ]

    was_training = model.training
    model.eval()
    try:
        # No autograd graph is needed for pure inference.
        with torch.no_grad():
            for case in test_cases:
                output = model(case.unsqueeze(0))  # add a batch dimension
                pred = output.argmax(dim=1).item()
                if case[1] > 0:
                    expected = 2
                elif case[1] < 0:
                    expected = 0
                else:
                    expected = 1
                print(f"Input: {case}, Prediction: {pred}, Expected: {expected}")
    finally:
        # Leave the model in whichever mode the caller had it.
        model.train(was_training)

# Instantiate the policy network and fit it on the synthetic data
model = PolicyNet()
train(batch_size=64, num_epochs=100, model=model)

# Check the predictions on the hand-picked cases
evaluate(model)


 10%|█         | 10/100 [00:34<05:15,  3.50s/it]

Epoch [10/100], Loss: 0.0000


 20%|██        | 20/100 [01:10<04:39,  3.49s/it]

Epoch [20/100], Loss: 0.0000


 30%|███       | 30/100 [01:44<04:03,  3.47s/it]

Epoch [30/100], Loss: 0.0000


 40%|████      | 40/100 [02:20<03:32,  3.55s/it]

Epoch [40/100], Loss: 0.0000


 50%|█████     | 50/100 [02:55<02:55,  3.50s/it]

Epoch [50/100], Loss: 0.0000


 60%|██████    | 60/100 [03:30<02:22,  3.56s/it]

Epoch [60/100], Loss: 0.0000


 70%|███████   | 70/100 [04:05<01:44,  3.49s/it]

Epoch [70/100], Loss: 0.0000


 80%|████████  | 80/100 [04:40<01:09,  3.47s/it]

Epoch [80/100], Loss: 0.0000


 90%|█████████ | 90/100 [05:15<00:35,  3.52s/it]

Epoch [90/100], Loss: 0.0000


100%|██████████| 100/100 [05:50<00:00,  3.50s/it]

Epoch [100/100], Loss: 0.0000
Input: tensor([1., 1.]), Prediction: 2, Expected: 2
Input: tensor([2., 0.]), Prediction: 2, Expected: 1
Input: tensor([-1., -2.]), Prediction: 0, Expected: 0





In [3]:
# Save the model's state dict
import os

# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/model_trybest.pth')
