In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.tensorboard import SummaryWriter

from LeNet import LeNet1

In [2]:
# LeNet = nn.Sequential(
#     nn.Conv2d(1, 6, kernel_size=(5, 5), padding=2),
#     nn.ReLU(),
#     # nn.Sigmoid(),
#     # nn.AvgPool2d(kernel_size=(2, 2), stride=2),
#     nn.MaxPool2d(kernel_size=2,stride=2),
#
#     nn.Conv2d(6, 16, kernel_size=(5, 5)),
#     # nn.Sigmoid(),
#     nn.ReLU(),
#     # nn.AvgPool2d(kernel_size=(2, 2), stride=2),
#     nn.MaxPool2d(kernel_size=2,stride=2),
#
#     nn.Flatten(),
#
#     nn.Linear(16 * 5 * 5, 120),
#     # nn.Sigmoid(),
#     nn.ReLU(),
#
#     nn.Linear(120, 84),
#     # nn.Sigmoid(),
#     nn.ReLU(),
#
#     nn.Linear(84, 10)
# )
writer=SummaryWriter("logs")
torch.cuda.empty_cache ()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mnist_training = datasets.MNIST(
    root="./data",
    train=True,
    transform=ToTensor(),
    download=False
)

mnist_test = datasets.MNIST(
    root="./data",
    train=False,
    transform=ToTensor(),
    download=False
)

img,label=mnist_test[0]
print(img.shape,label)
img.show()

torch.Size([1, 28, 28]) 7


In [3]:
BATCH_SIZE = 256
lr = 0.1
epochs = 20

train_dataloader = DataLoader(mnist_training, batch_size=BATCH_SIZE, shuffle=True)

test_dataloader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
net = LeNet1().to(device)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)

for epoch in range(epochs):
    print(
        f"epoch {epoch} \n---------------------"
    )

    for batch, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(inputs)
            print(f"loss:{loss:>7f} [{current:>5d}/ 60000]")

    with torch.no_grad():
        acc = 0
        total = 0
        for (image, label) in test_dataloader:
            image, label = image.to(device), label.to(device)
            output = net(image)
            _, pred = torch.max(output.data, 1)
            total += label.size(0)
            acc += (pred == label).sum()

        print(f"test: acc {100 * acc / total}")
        
        writer.add_scalar("acc",acc/total,epoch)
        
writer.close()


epoch 0 
---------------------
loss:2.586294 [    0/ 60000]
loss:2.332217 [ 2560/ 60000]
loss:2.328431 [ 5120/ 60000]
loss:2.346096 [ 7680/ 60000]
loss:2.320854 [10240/ 60000]
loss:2.303083 [12800/ 60000]
loss:2.307433 [15360/ 60000]
loss:2.310144 [17920/ 60000]
loss:2.314940 [20480/ 60000]
loss:2.311887 [23040/ 60000]
loss:2.296582 [25600/ 60000]
loss:2.317631 [28160/ 60000]
loss:2.314351 [30720/ 60000]
loss:2.313234 [33280/ 60000]
loss:2.330926 [35840/ 60000]
loss:2.305202 [38400/ 60000]
loss:2.294136 [40960/ 60000]
loss:2.298667 [43520/ 60000]
loss:2.318080 [46080/ 60000]
loss:2.314773 [48640/ 60000]
loss:2.298622 [51200/ 60000]
loss:2.303575 [53760/ 60000]
loss:2.301598 [56320/ 60000]
loss:2.315878 [58880/ 60000]
test: acc 20.829999923706055
epoch 1 
---------------------
loss:2.305477 [    0/ 60000]
loss:2.317616 [ 2560/ 60000]
loss:2.299921 [ 5120/ 60000]
loss:2.310353 [ 7680/ 60000]
loss:2.294155 [10240/ 60000]
loss:2.307668 [12800/ 60000]
loss:2.300088 [15360/ 60000]
loss:2.302