In [2]:
import torch
from torch import Tensor
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

In [None]:
def seed_everything(seed: int) -> None:
    """Seed PyTorch's random number generators so notebook runs are reproducible.

    ``torch.seed()`` and ``torch.cuda.seed_all()`` take *no* argument: they draw
    a fresh non-deterministic seed, and passing ``seed`` to them raises
    ``TypeError``. ``torch.manual_seed`` / ``torch.cuda.manual_seed_all`` are
    the functions that accept an explicit seed.

    Args:
        seed: value used for the CPU generator and for every CUDA device's
            generator. ``torch.cuda.manual_seed_all`` is silently ignored when
            CUDA is not available, so this is safe on CPU-only machines.
    """
    torch.manual_seed(seed)
    # torch.manual_seed already seeds CUDA as well; the explicit call documents intent.
    torch.cuda.manual_seed_all(seed)


seed_everything(42)

# 0/1 二分类

输入 x 小于等于 0 时标签为 0，否则标签为 1。

# model

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
# 1 input feature -> 3 hidden ReLU units -> 2 logits (one per class)
model = nn.Sequential(nn.Linear(1, 3), nn.ReLU(), nn.Linear(3, 2))
model = model.to(device)
model

Sequential(
  (0): Linear(in_features=1, out_features=3, bias=True)
  (1): ReLU()
  (2): Linear(in_features=3, out_features=2, bias=True)
)

In [5]:
# Plain SGD (no momentum, no weight decay) over every model parameter
optimizer = optim.SGD(params=model.parameters(), lr=0.1)
optimizer

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.1
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

In [6]:
# Cross-entropy on raw logits, averaged over the batch
loss_fn = nn.CrossEntropyLoss(reduction="mean")
loss_fn

CrossEntropyLoss()

# data

In [7]:
# 10,000 training inputs drawn from a standard normal, one feature each
train_x = torch.randn(size=(10000, 1))
train_x

tensor([[ 0.2403],
        [-0.0928],
        [ 0.0116],
        ...,
        [ 0.3054],
        [ 2.5811],
        [-2.4365]])

In [8]:
# Label rule: x > 0 -> class 1, x <= 0 -> class 0; flattened to a 1-D label vector
train_y = train_x.gt(0).long().reshape(-1)
train_y

tensor([1, 0, 1,  ..., 1, 1, 0])

In [9]:
# 10,000 validation inputs, uniform on [-1, 1)
val_x = 2 * torch.rand(10000, 1) - 1
val_x

tensor([[ 0.2275],
        [-0.2959],
        [-0.9763],
        ...,
        [-0.4862],
        [-0.1284],
        [-0.1013]])

In [10]:
# Same labelling rule as the training set: positive inputs are class 1
val_y = val_x.gt(0).long().reshape(-1)
val_y

tensor([1, 0, 0,  ..., 0, 0, 0])

In [11]:
class Data(Dataset):
    """Map-style dataset that pairs each input ``X[i]`` with its label ``y[i]``.

    Args:
        X: indexable, sized collection of inputs (in this notebook a float
            tensor of shape ``(N, 1)``).
        y: indexable, sized collection of labels; must have the same length
            as ``X``.

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths. Without this
            check a mismatch would only surface later, as an ``IndexError``
            mid-epoch or as trailing labels that are silently never used.
    """

    def __init__(self, X, y):
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __getitem__(self, index):
        """Return the ``(input, label)`` pair at ``index``."""
        return self.X[index], self.y[index]

    def __len__(self):
        """Number of samples, taken from ``X``."""
        return len(self.X)

In [12]:
# Reshuffled every epoch, 1000 samples per mini-batch
train_dataloader = DataLoader(
    dataset=Data(train_x, train_y),
    batch_size=1000,
    shuffle=True,
)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x28ebb237490>

In [13]:
# Fixed order for evaluation, same batch size as training
val_dataloader = DataLoader(
    dataset=Data(val_x, val_y),
    batch_size=1000,
    shuffle=False,
)
val_dataloader

<torch.utils.data.dataloader.DataLoader at 0x28ebe443890>

# train

In [None]:
# Sanity check: class probabilities for four inputs just either side of 0
with torch.inference_mode():
    x = torch.tensor([[-0.01], [-0.001], [0.001], [0.01]]).to(device)
    print(torch.softmax(model(x).cpu(), dim=1))

tensor([[0.5583, 0.4417],
        [0.5161, 0.4839],
        [0.5067, 0.4933],
        [0.4642, 0.5358]])


In [14]:
epochs: int = 100  # number of full passes over train_dataloader

In [15]:
def _run_epoch(dataloader: DataLoader, train: bool) -> tuple:
    """Make one pass over ``dataloader`` and return ``(accuracy, mean_loss)``.

    Uses the notebook-level ``model``, ``loss_fn``, ``optimizer`` and ``device``.

    Args:
        dataloader: yields ``(x, y)`` batches, ``y`` holding integer class labels.
        train: if True, put the model in train mode and take an optimizer step
            after every batch (the reported metrics then reflect the weights
            *before* each step, i.e. a running training metric). If False, put
            the model in eval mode and disable autograd for the whole pass.

    Returns:
        ``(accuracy, mean_loss)`` as Python floats. ``mean_loss`` is weighted by
        batch size so a short final batch is not over-counted; this relies on
        ``loss_fn`` averaging over the batch (``reduction="mean"``, the default).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train(train)  # train(False) is equivalent to eval()
    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    # Autograd is only needed when backward() will be called; in validation
    # this also covers the loss computation, not just the forward pass.
    with torch.set_grad_enabled(train):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            y_pred: Tensor = model(x)
            loss: Tensor = loss_fn(y_pred, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = y.size(0)
            n_correct += (y_pred.argmax(dim=1) == y).sum().item()
            n_seen += batch_size
            loss_sum += loss.item() * batch_size

    if n_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return n_correct / n_seen, loss_sum / n_seen


for epoch in range(1, epochs + 1):
    acc, loss_mean = _run_epoch(train_dataloader, train=True)
    print(f"epoch: {epoch}, train, acc: {acc:.6f}, loss: {loss_mean:.6f}")

    acc, loss_mean = _run_epoch(val_dataloader, train=False)
    print(f"epoch: {epoch}, val, acc: {acc:.6f}, loss: {loss_mean:.6f}")

epoch: 1, train, acc: 0.502000, loss: 0.689267
epoch: 1, val, acc: 0.500100, loss: 0.653449
epoch: 2, train, acc: 0.679600, loss: 0.604887
epoch: 2, val, acc: 0.834900, loss: 0.595620
epoch: 3, train, acc: 0.909800, loss: 0.522756
epoch: 3, val, acc: 0.929900, loss: 0.527265
epoch: 4, train, acc: 0.956900, loss: 0.435517
epoch: 4, val, acc: 0.959900, loss: 0.452944
epoch: 5, train, acc: 0.975000, loss: 0.356489
epoch: 5, val, acc: 0.973700, loss: 0.384562
epoch: 6, train, acc: 0.981500, loss: 0.293868
epoch: 6, val, acc: 0.981800, loss: 0.328405
epoch: 7, train, acc: 0.985800, loss: 0.247217
epoch: 7, val, acc: 0.985900, loss: 0.284430
epoch: 8, train, acc: 0.989200, loss: 0.212807
epoch: 8, val, acc: 0.987300, loss: 0.250203
epoch: 9, train, acc: 0.990300, loss: 0.186901
epoch: 9, val, acc: 0.988200, loss: 0.223130
epoch: 10, train, acc: 0.991600, loss: 0.166760
epoch: 10, val, acc: 0.989400, loss: 0.201136
epoch: 11, train, acc: 0.992500, loss: 0.150551
epoch: 11, val, acc: 0.991800,

# test

In [16]:
# 100 positive magnitudes spaced evenly in log10 from 1e-4 up to 1
test_x = torch.logspace(start=-4, end=0, steps=100, base=10.0)
test_x

tensor([1.0000e-04, 1.0975e-04, 1.2045e-04, 1.3219e-04, 1.4508e-04, 1.5923e-04,
        1.7475e-04, 1.9179e-04, 2.1049e-04, 2.3101e-04, 2.5354e-04, 2.7826e-04,
        3.0539e-04, 3.3516e-04, 3.6784e-04, 4.0370e-04, 4.4306e-04, 4.8626e-04,
        5.3367e-04, 5.8570e-04, 6.4281e-04, 7.0548e-04, 7.7426e-04, 8.4975e-04,
        9.3260e-04, 1.0235e-03, 1.1233e-03, 1.2328e-03, 1.3530e-03, 1.4850e-03,
        1.6298e-03, 1.7886e-03, 1.9630e-03, 2.1544e-03, 2.3645e-03, 2.5950e-03,
        2.8480e-03, 3.1257e-03, 3.4305e-03, 3.7649e-03, 4.1320e-03, 4.5349e-03,
        4.9770e-03, 5.4623e-03, 5.9948e-03, 6.5793e-03, 7.2208e-03, 7.9248e-03,
        8.6975e-03, 9.5455e-03, 1.0476e-02, 1.1498e-02, 1.2619e-02, 1.3849e-02,
        1.5199e-02, 1.6681e-02, 1.8307e-02, 2.0092e-02, 2.2051e-02, 2.4201e-02,
        2.6561e-02, 2.9151e-02, 3.1993e-02, 3.5112e-02, 3.8535e-02, 4.2292e-02,
        4.6416e-02, 5.0941e-02, 5.5908e-02, 6.1359e-02, 6.7342e-02, 7.3907e-02,
        8.1113e-02, 8.9022e-02, 9.7701e-

In [17]:
# Mirror the positive magnitudes onto the negative half-axis and reshape into an
# (N, 1) column, the input shape nn.Linear(1, ...) expects. The 1-D guard makes
# the cell idempotent: unguarded, every re-run would double test_x again.
if test_x.dim() == 1:
    test_x = torch.cat([test_x, -test_x]).reshape(-1, 1)
test_x.shape

torch.Size([200, 1])

In [18]:
# Ground-truth labels for the test points using the same x > 0 rule
test_y = test_x.gt(0).long().reshape(-1)
test_y

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [19]:
# Predicted class per test point (argmax over the two logits), moved back to CPU
model.eval()
with torch.inference_mode():
    y_pred = torch.argmax(model(test_x.to(device)).cpu(), dim=-1)
y_pred

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [20]:
# Accuracy: fraction of test points whose prediction matches the label
test_y.eq(y_pred).float().mean().item()

0.824999988079071

In [21]:
# Boolean mask over the test points: True where the prediction is wrong
error_index = test_y.ne(y_pred)
error_index

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [22]:
# The misclassified inputs themselves, as a flat 1-D tensor
error_x = test_x[error_index, 0]
error_x

tensor([1.0000e-04, 1.0975e-04, 1.2045e-04, 1.3219e-04, 1.4508e-04, 1.5923e-04,
        1.7475e-04, 1.9179e-04, 2.1049e-04, 2.3101e-04, 2.5354e-04, 2.7826e-04,
        3.0539e-04, 3.3516e-04, 3.6784e-04, 4.0370e-04, 4.4306e-04, 4.8626e-04,
        5.3367e-04, 5.8570e-04, 6.4281e-04, 7.0548e-04, 7.7426e-04, 8.4975e-04,
        9.3260e-04, 1.0235e-03, 1.1233e-03, 1.2328e-03, 1.3530e-03, 1.4850e-03,
        1.6298e-03, 1.7886e-03, 1.9630e-03, 2.1544e-03, 2.3645e-03])