## Loss functions

In [None]:
import torch 
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [None]:
# Tiny hand-checkable tensors laid out as (N, C, H, W) = (1, 1, 1, 3).
inputs, targets = (
    torch.tensor(values, dtype=torch.float32).reshape(1, 1, 1, 3)
    for values in ([1, 2, 3], [1, 2, 5])
)

# L1 loss summed over all elements: |1-1| + |2-2| + |3-5| = 2.
loss = nn.L1Loss(reduction='sum')
result = loss(inputs, targets)

# MSE with the default 'mean' reduction: (0 + 0 + (3-5)**2) / 3.
loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs, targets)

# Cross-entropy for one sample with scores for 3 classes; the target is
# class index 1. x has shape (1, 3): (batch, num_classes).
x = torch.tensor([[0.1, 0.2, 0.3]])
y = torch.tensor([1])
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)

# Print the three results in the order they were introduced.
for value in (result, result_mse, result_cross):
    print(value)

## Example: a CIFAR-10 model

In [None]:
class Foo(nn.Module):
    """Small CNN classifier producing 10 output scores per image.

    Three (Conv2d 5x5, padding=2) -> MaxPool2d(2) stages, then Flatten and
    two Linear layers. The Linear input size 64*4*4 only matches inputs of
    shape (N, 3, 32, 32): each 5x5/padding-2 conv keeps the spatial size and
    each pooling halves it (32 -> 16 -> 8 -> 4). There are no activation
    functions between layers.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # Conv/pool pairs, created in the same order as listed so parameter
        # initialisation consumes the RNG identically.
        for in_channels, out_channels in ((3, 32), (32, 32), (32, 64)):
            stages.append(nn.Conv2d(in_channels, out_channels, 5, padding=2))
            stages.append(nn.MaxPool2d(2))
        stages.extend([
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ])
        # Attribute name `model1` is kept: parameter / state_dict keys
        # are `model1.<index>.weight` etc.
        self.model1 = nn.Sequential(*stages)

    def forward(self, x):
        """Return a (N, 10) tensor of class scores for a batch of images x."""
        return self.model1(x)

from itertools import islice

# CIFAR-10 test split (train=False). download=True fetches the data into
# ./_data on the first run instead of failing with "Dataset not found";
# when the files are already present and verified the download is skipped.
dataset = torchvision.datasets.CIFAR10(
    "./_data", train=False, transform=ToTensor(), download=True
)
dataloader = DataLoader(dataset, batch_size=4)

# Sanity check: push one batch through an untrained model and show the
# output shape.
imgs, labels = next(iter(dataloader))
foo = Foo()
print(foo(imgs).shape)

# Call backward() on the first 10 batches WITHOUT zeroing gradients in
# between, to demonstrate that .grad accumulates across backward() calls.
loss = nn.CrossEntropyLoss()  # build the criterion once, not per batch
for imgs, labels in islice(dataloader, 10):
    preds = foo(imgs)
    result_loss = loss(preds, labels)
    result_loss.backward()

# The gradients have now accumulated over all 10 batches, e.g.:
# print(foo.model1[0].weight.grad)

## Optimizer (优化器)

In [None]:
# Train a freshly initialised model for 20 epochs with plain SGD and print
# the summed batch loss after each epoch.
# NOTE(review): the `dataloader` built earlier in this file wraps the
# CIFAR-10 test split (train=False) without shuffling. That is fine for a
# demo, but a real run should train on train=True data with shuffle=True.
foo = Foo()

loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(foo.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for imgs, labels in dataloader:
        outputs = foo(imgs)
        result_loss = loss(outputs, labels)
        optim.zero_grad()  # clear gradients left on every parameter by the previous step
        result_loss.backward()
        optim.step()
        # .item() converts to a Python float. Summing the loss tensor itself
        # would chain every batch's autograd graph together and keep them
        # all alive until the end of the epoch.
        running_loss += result_loss.item()
    print(f"running loss is {running_loss}")
