<a href="https://colab.research.google.com/github/God-Orcale/AI_Test/blob/main/optimizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
# 从前面的数据集和数据加载器部分加载代码，并构建模型。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Fashion-MNIST training split. Files are stored under ./data (download=True
# fetches them when needed) and every image is converted with ToTensor().
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())

# The matching test split (train=False), loaded with the same settings.
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

# Wrap both splits so they are served in mini-batches of 64 samples.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 logits per input sample.

    Each sample is flattened and passed through two hidden layers of 512
    units with ReLU activations, followed by a 10-unit linear output layer.
    The first linear layer expects 28 * 28 = 784 features per sample.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Layers are listed in input-to-output order; nn.Sequential applies
        # them in exactly this order.
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw, unnormalized class scores (logits) for batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))


# Module-level model instance used by the optimizer and training cells below.
model = NeuralNetwork()

Number of Epochs - 迭代数据集的次数

Batch Size - 在更新参数之前通过网络传播的数据样本数

学习率 - 在每个批次/epoch 更新模型参数的量。较小的值会导致学习速度变慢，而较大的值可能会导致训练期间出现不可预知的行为。

In [12]:
# Hyperparameters controlling the optimization process.
learning_rate = 0.001  # step size used for each parameter update
batch_size = 64  # number of samples propagated before parameters are updated
epochs = 5  # number of full passes over the dataset

# Each epoch consists of two parts:
# - the training loop, which iterates over the training dataset and tries to
#   converge to optimal parameters;
# - the validation/test loop, which iterates over the test dataset to check
#   whether model performance is improving.

In [13]:
# Common loss functions include nn.MSELoss (mean squared error) for regression
# tasks and nn.NLLLoss (negative log likelihood) for classification.
# nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.
# We pass the model's output logits to nn.CrossEntropyLoss, which normalizes
# the logits and computes the prediction error.
loss_fn = nn.CrossEntropyLoss()

In [14]:
# Initialize the optimizer by registering the model parameters that need to be
# trained and passing in the learning-rate hyperparameter.
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate)
# Inside the training loop, optimization happens in three steps:
# - optimizer.zero_grad() resets the gradients of the model parameters.
#   Gradients accumulate by default, so they are zeroed explicitly on every
#   iteration to avoid double-counting.
# - loss.backward() backpropagates the prediction loss; PyTorch stores the
#   gradient of the loss with respect to each parameter.
# - optimizer.step() adjusts the parameters using the gradients collected in
#   the backward pass.

In [17]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one full pass over ``dataloader``.

    For every batch the loss is computed, backpropagated, and applied with
    ``optimizer.step()``; gradients are then cleared for the next batch.
    On every 100th batch (starting with the first) the current loss and the
    number of samples processed so far are printed.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``dataloader.dataset``
            must support ``len()``.
        model: Callable module mapping a batch ``X`` to predictions.
        loss_fn: Callable taking ``(pred, y)`` and returning a loss that
            supports ``backward()`` and ``item()``.
        optimizer: Optimizer holding the model parameters to update.
    """
    size = len(dataloader.dataset)
    # Count samples directly rather than deriving the total from a global
    # ``batch_size``, so progress is correct for any batch size the
    # dataloader uses (including a smaller final batch).
    samples_seen = 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        samples_seen += len(X)
        if batch % 100 == 0:
            loss_value = loss.item()
            print(f"loss:{loss_value:>7f} [{samples_seen:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and report accuracy and mean loss.

    The model is switched to evaluation mode and all batches are processed
    under ``torch.no_grad()``. A summary line is printed and the same
    metrics are returned so callers can track them across epochs.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; both ``dataloader`` and
            ``dataloader.dataset`` must support ``len()``.
        model: Callable module mapping a batch ``X`` to per-class scores;
            the predicted class is taken as the argmax along dimension 1.
        loss_fn: Callable taking ``(pred, y)`` and returning a loss that
            supports ``item()``.

    Returns:
        A tuple ``(accuracy, avg_loss)``: the fraction of correctly
        classified samples and the loss averaged over batches.

    Raises:
        ValueError: If the dataloader yields no samples, since the averages
            would otherwise be a division by zero.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if size == 0 or num_batches == 0:
        raise ValueError("test_loop requires a non-empty dataloader")

    model.eval()
    test_loss, correct = 0.0, 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()

    test_loss /= num_batches
    accuracy = correct / size
    print(f"Test Error: \n  Accuracy:{(100*accuracy):>0.1f}%,Avg loss:{test_loss:>8f} \n")
    return accuracy, test_loss


In [18]:
# Initialize the loss function and the optimizer, then pass them to
# train_loop and test_loop.
# Feel free to increase the number of epochs to track the model's improving
# performance.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Overrides the value of `epochs` assigned in the hyperparameter cell above.
epochs = 10
for t in range(epochs):
  # Epochs are reported 1-based.
  print(f"Epoch {t+1}\n")
  # One pass over the training data, followed by one evaluation pass over the
  # test data.
  train_loop(train_dataloader,model,loss_fn,optimizer)
  test_loop(test_dataloader,model,loss_fn)
print("DONE!")

Epoch 1

loss:0.787227 [   64/60000]
loss:0.877057 [ 6464/60000]
loss:0.632036 [12864/60000]
loss:0.837792 [19264/60000]
loss:0.739338 [25664/60000]
loss:0.733042 [32064/60000]
loss:0.811461 [38464/60000]
loss:0.787427 [44864/60000]
loss:0.793446 [51264/60000]
loss:0.745739 [57664/60000]
Test Error: 
  Accuracy:72.4%,Avg loss:0.756086 

Epoch 2

loss:0.748837 [   64/60000]
loss:0.845274 [ 6464/60000]
loss:0.598966 [12864/60000]
loss:0.812405 [19264/60000]
loss:0.717305 [25664/60000]
loss:0.706948 [32064/60000]
loss:0.785652 [38464/60000]
loss:0.769672 [44864/60000]
loss:0.769983 [51264/60000]
loss:0.723540 [57664/60000]
Test Error: 
  Accuracy:73.4%,Avg loss:0.732871 

Epoch 3

loss:0.715676 [   64/60000]
loss:0.816643 [ 6464/60000]
loss:0.570996 [12864/60000]
loss:0.791216 [19264/60000]
loss:0.698269 [25664/60000]
loss:0.685380 [32064/60000]
loss:0.762036 [38464/60000]
loss:0.754089 [44864/60000]
loss:0.750105 [51264/60000]
loss:0.703859 [57664/60000]
Test Error: 
  Accuracy:74.2%,Avg