In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# 產生資料
X = torch.cat((torch.normal(2*torch.ones(100, 2), 1), torch.normal(-2*torch.ones(100, 2), 1)), 0)
Y = torch.cat((torch.zeros(100), torch.ones(100)), 0).long()

# 分割資料
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
train_loader = DataLoader(TensorDataset(X_train, Y_train), batch_size=16, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, Y_test), batch_size=16)

# 定義模型
class FNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 10),
            nn.ReLU(),
            nn.Linear(10, 2)
        )
    def forward(self, x):
        return self.net(x)

model = FNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 訓練
for epoch in range(100):
    for xb, yb in train_loader:
        out = model(xb)
        loss = loss_fn(out, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
            print(f'epoch: {epoch+1}, loss = {loss.item(): .4f}')

# 測試
correct = 0
total = 0
with torch.no_grad():
    for xb, yb in test_loader:
        pred = model(xb).argmax(1)
        correct += (pred == yb).sum().item()
        total += yb.size(0)
print(f"Test Accuracy: {correct / total:.2f}")

ModuleNotFoundError: No module named 'sklearn'

In [3]:
import numpy as np
import matplotlib.pyplot as plt

# 建立網格點 (meshgrid) 做為輸入測試資料
h = 0.02  # 網格密度
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
with torch.no_grad():
    Z = model(grid)
    Z = Z.argmax(1).numpy().reshape(xx.shape)

# 畫出分類區域
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, Z, cmap='RdYlBu', alpha=0.6)

# 畫出原始資料點
plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolor='k', cmap='RdYlBu', s=60)
plt.title("FNN", fontsize=16)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.grid(True)
plt.savefig("Classfication2.png", dpi = 300)
plt.show()


NameError: name 'X' is not defined