# Workone:实现FashionMnist识别准确率达90%以上

In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm

This project is to classify the FashionMnist Datasets.

I had built three models:
   
   - LeNetv5
   - SimpleNet: base on LeNetv5, I simplify the channels and drop Maxpooling.
   - CNNmodel: base on LeNetv5, I use padding when convolution for keep the origin size,use xavier to initialize the conv weight

Performance:
   - LeNetv5: 90.82%
   - SimpleNet: 88.91%
   - CNNmodel: 90.56%

## Load FashionMnist Dataset

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.2860,],[0.3526,]),
     # transforms.RandomHorizontalFlip()
    ]
)

train_data = datasets.FashionMNIST(
    root='mnist',
    train=True,
    download=False,
    transform=transform
)

test_data = datasets.FashionMNIST(
    root='mnist',
    train=False,
    download=False,
    transform=transform
)


In [None]:
batch_size = 32
train_loader = DataLoader(train_data,batch_size = batch_size,shuffle=False)
test_loader = DataLoader(test_data,batch_size = batch_size,shuffle=False)


In [41]:
def ComputeMeanAndStd(data):
    '''
        Compute mean and std for transforms.Normalize()
    '''
    std = 0.0
    mean = 0.0
    for X,y in data:
        x = X.view(-1,28*28)
        std += x.std().item()
        mean += x.mean().item()

    std /= len(data)
    mean /= len(data)
    return mean,std

In [42]:
# mean,std = ComputeMeanAndStd(train_loader)
# print(mean)
# print(std)

## Build model

In [43]:
class LeNetv5(nn.Module):
    def __init__(self):
        super(LeNetv5, self).__init__()
        self.conv1 = nn.Conv2d(1,64,5)# 28-5+1 = 24
        # max_pool size = 2 # 12
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64,128,5) # 12-5+1 = 8
        self.bn2 = nn.BatchNorm2d(128)
        # max_pool size = 2 # 4
        self.flatten = nn.Flatten() # 16*4*4 = 16
        self.f1 = nn.Linear(128*4*4,120)
        self.f2 = nn.Linear(120,84)
        self.f3 = nn.Linear(84,10)
        self.maxPool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.maxPool(self.relu(self.conv1(x)))
        x = self.bn1(x)
        x = self.maxPool(self.relu(self.conv2(x)))
        x = self.bn2(x)
        x = self.flatten(x)
        x = self.relu(self.f1(x))
        x = self.relu(self.f2(x))
        x = self.f3(x)
        return x


In [44]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1,20,kernel_size=3)
        self.bn1 = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(20,20,kernel_size=3)
        self.bn2 = nn.BatchNorm2d(20)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(20*24*24,128)
        self.fc2 = nn.Linear(128,10)
        self.relu = nn.ReLU()
    def forward(self,x):
        x = self.relu(self.conv1(x))
        x = self.bn1(x)
        x = self.relu(self.conv2(x))
        x = self.bn2(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [45]:
class CNNmodel(nn.Module):
    def __init__(self):
        super(CNNmodel, self).__init__()
        self.conv1 = nn.Conv2d(1,32,5,1,2)
        self.relu1 = nn.ReLU()
        self.norm1 = nn.BatchNorm2d(32)
        nn.init.xavier_uniform_(self.conv1.weight)

        self.maxpool1 = nn.MaxPool2d(2,2)

        self.conv2 = nn.Conv2d(32,64,3,1,2)
        self.relu2 = nn.ReLU()
        self.norm2 = nn.BatchNorm2d(64)
        nn.init.xavier_uniform_(self.conv2.weight)

        self.maxpool2 = nn.MaxPool2d(2,2)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64*8*8,4096)
        self.fcrelu = nn.ReLU()
        self.fc2 = nn.Linear(4096,10)

    def forward(self,x):
        out = self.maxpool1(self.norm1(self.relu1(self.conv1(x))))
        out = self.maxpool2(self.norm2(self.relu2(self.conv2(out))))
        out = self.flatten(out)
        out = self.fcrelu(self.fc1(out))
        out = self.fc2(out)

        return out


## Train

In [46]:
def train(model,train_loader,loss,optimizer,epochs):
    for epoch in range(epochs):
        print('-'*100)
        print(f"epoch:{epoch}")
        print('-'*100)
        time.sleep(0.25)
        model.train()
        train_size = len(train_loader.dataset)

        for batch, (X, y) in enumerate(tqdm(train_loader)):
            # print(X.requires_grad)
            X, y = X.to(device), y.to(device)
            # print(X.requires_grad)
            pred = model(X)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss = loss.item()
                correct = (pred.argmax(1) == y).type(torch.float).sum().item() / batch_size
                sample = batch * batch_size
                print(f"Loss:{loss:>6f} , Accuary:{correct * 100:>6f}%")


In [47]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Use device: {device}")

Use device: cuda


In [101]:
net = LeNetv5().to(device)#accuary:0.9141
print(net)

LeNetv5(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (f1): Linear(in_features=2048, out_features=120, bias=True)
  (f2): Linear(in_features=120, out_features=84, bias=True)
  (f3): Linear(in_features=84, out_features=10, bias=True)
  (maxPool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
)


In [86]:
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(),lr=0.015,momentum=0.9,nesterov=True)

In [87]:
train(net,train_loader,loss_fn,optimizer,epochs=5)

----------------------------------------------------------------------------------------------------
epoch:0
----------------------------------------------------------------------------------------------------


  0%|▎                                                                                | 7/1875 [00:00<00:30, 61.41it/s]

Loss:2.349850 , Accuary:18.750000%


  6%|████▋                                                                          | 112/1875 [00:01<00:26, 65.88it/s]

Loss:0.630314 , Accuary:75.000000%


 11%|████████▊                                                                      | 210/1875 [00:03<00:25, 66.13it/s]

Loss:0.253223 , Accuary:93.750000%


 16%|█████████████                                                                  | 309/1875 [00:04<00:23, 66.00it/s]

Loss:0.582471 , Accuary:81.250000%


 22%|█████████████████▍                                                             | 414/1875 [00:06<00:22, 65.69it/s]

Loss:0.335582 , Accuary:90.625000%


 27%|█████████████████████▌                                                         | 512/1875 [00:07<00:20, 66.17it/s]

Loss:0.774256 , Accuary:81.250000%


 33%|█████████████████████████▋                                                     | 610/1875 [00:09<00:19, 65.91it/s]

Loss:0.392376 , Accuary:84.375000%


 38%|█████████████████████████████▊                                                 | 709/1875 [00:10<00:17, 66.65it/s]

Loss:0.288655 , Accuary:90.625000%


 43%|██████████████████████████████████                                             | 807/1875 [00:12<00:16, 65.37it/s]

Loss:0.572831 , Accuary:78.125000%


 49%|██████████████████████████████████████▍                                        | 912/1875 [00:13<00:14, 65.90it/s]

Loss:0.419828 , Accuary:81.250000%


 54%|██████████████████████████████████████████                                    | 1010/1875 [00:15<00:13, 66.45it/s]

Loss:0.458288 , Accuary:78.125000%


 59%|██████████████████████████████████████████████▏                               | 1109/1875 [00:16<00:11, 66.47it/s]

Loss:0.488784 , Accuary:84.375000%


 65%|██████████████████████████████████████████████████▌                           | 1214/1875 [00:18<00:10, 65.26it/s]

Loss:0.290647 , Accuary:87.500000%


 70%|██████████████████████████████████████████████████████▌                       | 1312/1875 [00:19<00:08, 64.60it/s]

Loss:0.412397 , Accuary:78.125000%


 75%|██████████████████████████████████████████████████████████▋                   | 1410/1875 [00:21<00:07, 64.03it/s]

Loss:0.440848 , Accuary:87.500000%


 81%|███████████████████████████████████████████████████████████████               | 1515/1875 [00:23<00:05, 64.01it/s]

Loss:0.400076 , Accuary:87.500000%


 86%|███████████████████████████████████████████████████████████████████           | 1613/1875 [00:24<00:04, 63.91it/s]

Loss:0.207205 , Accuary:93.750000%


 91%|███████████████████████████████████████████████████████████████████████▏      | 1711/1875 [00:26<00:02, 63.61it/s]

Loss:0.557047 , Accuary:81.250000%


 97%|███████████████████████████████████████████████████████████████████████████▍  | 1814/1875 [00:27<00:00, 62.47it/s]

Loss:0.249628 , Accuary:90.625000%


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 65.16it/s]


----------------------------------------------------------------------------------------------------
epoch:1
----------------------------------------------------------------------------------------------------


  1%|▌                                                                               | 14/1875 [00:00<00:27, 67.92it/s]

Loss:0.184659 , Accuary:90.625000%


  6%|████▌                                                                          | 108/1875 [00:01<00:32, 54.56it/s]

Loss:0.274589 , Accuary:93.750000%


 11%|████████▉                                                                      | 212/1875 [00:03<00:24, 66.61it/s]

Loss:0.143091 , Accuary:96.875000%


 17%|█████████████                                                                  | 310/1875 [00:04<00:23, 65.37it/s]

Loss:0.298136 , Accuary:87.500000%


 22%|█████████████████▏                                                             | 409/1875 [00:06<00:21, 67.67it/s]

Loss:0.239978 , Accuary:87.500000%


 27%|█████████████████████▍                                                         | 508/1875 [00:08<00:23, 57.04it/s]

Loss:0.328653 , Accuary:90.625000%


 33%|█████████████████████████▊                                                     | 613/1875 [00:09<00:20, 62.53it/s]

Loss:0.171054 , Accuary:90.625000%


 38%|█████████████████████████████▉                                                 | 710/1875 [00:11<00:20, 58.13it/s]

Loss:0.325151 , Accuary:93.750000%


 44%|██████████████████████████████████▍                                            | 816/1875 [00:13<00:15, 67.11it/s]

Loss:0.326845 , Accuary:87.500000%


 48%|██████████████████████████████████████▏                                        | 907/1875 [00:14<00:14, 66.16it/s]

Loss:0.203499 , Accuary:87.500000%


 54%|██████████████████████████████████████████                                    | 1012/1875 [00:16<00:12, 67.08it/s]

Loss:0.425458 , Accuary:84.375000%


 59%|██████████████████████████████████████████████▏                               | 1110/1875 [00:17<00:11, 67.14it/s]

Loss:0.204265 , Accuary:87.500000%


 64%|██████████████████████████████████████████████████▎                           | 1209/1875 [00:18<00:09, 67.17it/s]

Loss:0.271540 , Accuary:90.625000%


 70%|██████████████████████████████████████████████████████▎                       | 1307/1875 [00:20<00:08, 66.44it/s]

Loss:0.326742 , Accuary:84.375000%


 75%|██████████████████████████████████████████████████████████▋                   | 1412/1875 [00:22<00:06, 66.73it/s]

Loss:0.276404 , Accuary:87.500000%


 80%|██████████████████████████████████████████████████████████████▊               | 1509/1875 [00:23<00:06, 57.85it/s]

Loss:0.303901 , Accuary:84.375000%


 86%|███████████████████████████████████████████████████████████████████           | 1612/1875 [00:25<00:04, 63.47it/s]

Loss:0.170152 , Accuary:87.500000%


 91%|███████████████████████████████████████████████████████████████████████▏      | 1710/1875 [00:26<00:02, 65.59it/s]

Loss:0.292147 , Accuary:87.500000%


 96%|███████████████████████████████████████████████████████████████████████████▏  | 1808/1875 [00:28<00:01, 54.93it/s]

Loss:0.160229 , Accuary:96.875000%


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:29<00:00, 63.51it/s]


----------------------------------------------------------------------------------------------------
epoch:2
----------------------------------------------------------------------------------------------------


  1%|▌                                                                               | 14/1875 [00:00<00:28, 65.91it/s]

Loss:0.126611 , Accuary:93.750000%


  6%|████▋                                                                          | 112/1875 [00:01<00:26, 66.10it/s]

Loss:0.150013 , Accuary:96.875000%


 11%|████████▊                                                                      | 210/1875 [00:03<00:25, 64.77it/s]

Loss:0.136523 , Accuary:96.875000%


 17%|█████████████▎                                                                 | 316/1875 [00:04<00:23, 67.31it/s]

Loss:0.187342 , Accuary:93.750000%


 22%|█████████████████▍                                                             | 414/1875 [00:06<00:21, 66.83it/s]

Loss:0.157891 , Accuary:90.625000%


 27%|█████████████████████▍                                                         | 510/1875 [00:07<00:24, 54.84it/s]

Loss:0.198272 , Accuary:93.750000%


 33%|█████████████████████████▊                                                     | 613/1875 [00:09<00:21, 59.16it/s]

Loss:0.159379 , Accuary:90.625000%


 38%|█████████████████████████████▉                                                 | 711/1875 [00:11<00:17, 67.01it/s]

Loss:0.262361 , Accuary:96.875000%


 43%|██████████████████████████████████                                             | 809/1875 [00:12<00:16, 66.31it/s]

Loss:0.244801 , Accuary:87.500000%


 49%|██████████████████████████████████████▍                                        | 913/1875 [00:14<00:16, 59.37it/s]

Loss:0.106163 , Accuary:100.000000%


 54%|██████████████████████████████████████████                                    | 1011/1875 [00:15<00:14, 59.91it/s]

Loss:0.481616 , Accuary:78.125000%


 59%|██████████████████████████████████████████████▏                               | 1111/1875 [00:17<00:11, 64.97it/s]

Loss:0.152394 , Accuary:93.750000%


 64%|██████████████████████████████████████████████████▎                           | 1209/1875 [00:18<00:10, 65.17it/s]

Loss:0.221345 , Accuary:90.625000%


 70%|██████████████████████████████████████████████████████▋                       | 1314/1875 [00:20<00:08, 66.77it/s]

Loss:0.240465 , Accuary:93.750000%


 75%|██████████████████████████████████████████████████████████▋                   | 1412/1875 [00:21<00:07, 65.82it/s]

Loss:0.179394 , Accuary:93.750000%


 80%|██████████████████████████████████████████████████████████████▊               | 1509/1875 [00:23<00:07, 48.18it/s]

Loss:0.245004 , Accuary:93.750000%


 86%|██████████████████████████████████████████████████████████████████▊           | 1607/1875 [00:25<00:04, 65.12it/s]

Loss:0.120700 , Accuary:96.875000%


 91%|███████████████████████████████████████████████████████████████████████       | 1708/1875 [00:26<00:02, 59.38it/s]

Loss:0.222029 , Accuary:93.750000%


 97%|███████████████████████████████████████████████████████████████████████████▍  | 1813/1875 [00:28<00:01, 59.08it/s]

Loss:0.153431 , Accuary:96.875000%


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:29<00:00, 63.13it/s]


----------------------------------------------------------------------------------------------------
epoch:3
----------------------------------------------------------------------------------------------------


  0%|▎                                                                                | 7/1875 [00:00<00:30, 60.84it/s]

Loss:0.044676 , Accuary:96.875000%


  6%|████▌                                                                          | 108/1875 [00:01<00:29, 59.72it/s]

Loss:0.097861 , Accuary:96.875000%


 11%|████████▉                                                                      | 213/1875 [00:03<00:31, 53.47it/s]

Loss:0.081668 , Accuary:96.875000%


 16%|█████████████                                                                  | 309/1875 [00:05<00:24, 62.94it/s]

Loss:0.149975 , Accuary:93.750000%


 22%|█████████████████▏                                                             | 407/1875 [00:06<00:22, 65.58it/s]

Loss:0.153716 , Accuary:90.625000%


 27%|█████████████████████▍                                                         | 509/1875 [00:08<00:22, 60.54it/s]

Loss:0.229368 , Accuary:93.750000%


 32%|█████████████████████████▌                                                     | 607/1875 [00:10<00:21, 59.84it/s]

Loss:0.121540 , Accuary:93.750000%


 38%|█████████████████████████████▉                                                 | 712/1875 [00:11<00:18, 64.40it/s]

Loss:0.200933 , Accuary:96.875000%


 43%|██████████████████████████████████▏                                            | 810/1875 [00:13<00:16, 65.44it/s]

Loss:0.263462 , Accuary:84.375000%


 49%|██████████████████████████████████████▍                                        | 912/1875 [00:14<00:16, 59.84it/s]

Loss:0.022553 , Accuary:100.000000%


 54%|██████████████████████████████████████████▏                                   | 1013/1875 [00:16<00:13, 64.41it/s]

Loss:0.561362 , Accuary:84.375000%


 59%|██████████████████████████████████████████████▏                               | 1111/1875 [00:18<00:11, 66.57it/s]

Loss:0.142340 , Accuary:90.625000%


 64%|██████████████████████████████████████████████████▎                           | 1209/1875 [00:19<00:09, 66.94it/s]

Loss:0.112528 , Accuary:96.875000%


 70%|██████████████████████████████████████████████████████▋                       | 1314/1875 [00:21<00:08, 67.00it/s]

Loss:0.179275 , Accuary:93.750000%


 75%|██████████████████████████████████████████████████████████▋                   | 1412/1875 [00:22<00:06, 66.28it/s]

Loss:0.173558 , Accuary:93.750000%


 81%|██████████████████████████████████████████████████████████████▊               | 1510/1875 [00:24<00:05, 66.54it/s]

Loss:0.211320 , Accuary:93.750000%


 86%|██████████████████████████████████████████████████████████████████▉           | 1609/1875 [00:25<00:03, 67.08it/s]

Loss:0.072419 , Accuary:96.875000%


 91%|███████████████████████████████████████████████████████████████████████▎      | 1714/1875 [00:27<00:02, 66.46it/s]

Loss:0.090194 , Accuary:96.875000%


 97%|███████████████████████████████████████████████████████████████████████████▍  | 1812/1875 [00:28<00:00, 66.14it/s]

Loss:0.078481 , Accuary:100.000000%


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:29<00:00, 63.62it/s]


----------------------------------------------------------------------------------------------------
epoch:4
----------------------------------------------------------------------------------------------------


  1%|▌                                                                               | 14/1875 [00:00<00:27, 68.16it/s]

Loss:0.019383 , Accuary:100.000000%


  6%|████▋                                                                          | 112/1875 [00:01<00:26, 67.01it/s]

Loss:0.094487 , Accuary:93.750000%


 11%|████████▊                                                                      | 210/1875 [00:03<00:24, 66.82it/s]

Loss:0.022963 , Accuary:100.000000%


 16%|████████████▊                                                                  | 305/1875 [00:04<00:36, 42.67it/s]

Loss:0.103384 , Accuary:93.750000%


 22%|█████████████████                                                              | 406/1875 [00:07<00:42, 34.92it/s]

Loss:0.026213 , Accuary:100.000000%


 27%|█████████████████████▍                                                         | 508/1875 [00:09<00:28, 48.50it/s]

Loss:0.141880 , Accuary:93.750000%


 33%|█████████████████████████▊                                                     | 612/1875 [00:11<00:20, 63.01it/s]

Loss:0.181456 , Accuary:90.625000%


 38%|█████████████████████████████▉                                                 | 710/1875 [00:12<00:18, 63.63it/s]

Loss:0.117695 , Accuary:96.875000%


 43%|██████████████████████████████████                                             | 809/1875 [00:14<00:15, 66.96it/s]

Loss:0.117766 , Accuary:93.750000%


 48%|██████████████████████████████████████▏                                        | 907/1875 [00:15<00:15, 62.91it/s]

Loss:0.024771 , Accuary:100.000000%


 54%|██████████████████████████████████████████                                    | 1011/1875 [00:17<00:14, 59.51it/s]

Loss:0.256163 , Accuary:90.625000%


 59%|██████████████████████████████████████████████▎                               | 1114/1875 [00:19<00:13, 58.06it/s]

Loss:0.056604 , Accuary:100.000000%


 65%|██████████████████████████████████████████████████▍                           | 1212/1875 [00:20<00:10, 61.51it/s]

Loss:0.180839 , Accuary:90.625000%


 70%|██████████████████████████████████████████████████████▍                       | 1310/1875 [00:22<00:08, 64.46it/s]

Loss:0.099086 , Accuary:93.750000%


 75%|██████████████████████████████████████████████████████████▌                   | 1409/1875 [00:23<00:06, 67.14it/s]

Loss:0.070665 , Accuary:93.750000%


 80%|██████████████████████████████████████████████████████████████▋               | 1507/1875 [00:25<00:05, 66.43it/s]

Loss:0.222081 , Accuary:87.500000%


 86%|███████████████████████████████████████████████████████████████████           | 1612/1875 [00:26<00:04, 65.36it/s]

Loss:0.080569 , Accuary:93.750000%


 91%|███████████████████████████████████████████████████████████████████████▏      | 1710/1875 [00:28<00:02, 66.42it/s]

Loss:0.054739 , Accuary:100.000000%


 96%|███████████████████████████████████████████████████████████████████████████▎  | 1809/1875 [00:29<00:00, 66.02it/s]

Loss:0.086199 , Accuary:93.750000%


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:30<00:00, 61.03it/s]


## Test

In [93]:
def test(model,test_loader,loss_fn):
    test_size = len(test_loader.dataset)
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for batch,(X,y) in enumerate(test_loader):
            X,y = X.to(device),y.to(device)
            pred = net(X)
            loss = loss_fn(pred,y).item()
            total_loss += loss
            correct += (pred.argmax(1)==y).type(torch.float).sum().item()
            if batch % 20 == 0:
                print(f"Test -- process:[{batch*batch_size:>6d}/{test_size:<6d}] loss:{loss:>4f} , accuary:{(correct/((batch+1)*batch_size))*100:>7f}%")
        print(f"Total -- total_loss:{total_loss/batch_size} , accuary:{correct*100/test_size}%")


In [94]:
test(net,test_loader,loss_fn)

Test -- process:[     0/10000 ] loss:0.203414 , accuary:93.750000%
Test -- process:[   640/10000 ] loss:0.325253 , accuary:91.815476%
Test -- process:[  1280/10000 ] loss:0.514904 , accuary:90.777439%
Test -- process:[  1920/10000 ] loss:0.285554 , accuary:91.290984%
Test -- process:[  2560/10000 ] loss:0.301694 , accuary:91.010802%
Test -- process:[  3200/10000 ] loss:0.343973 , accuary:90.501238%
Test -- process:[  3840/10000 ] loss:0.534036 , accuary:90.134298%
Test -- process:[  4480/10000 ] loss:0.069310 , accuary:90.203901%
Test -- process:[  5120/10000 ] loss:0.176400 , accuary:90.159161%
Test -- process:[  5760/10000 ] loss:0.186525 , accuary:90.003453%
Test -- process:[  6400/10000 ] loss:0.301326 , accuary:90.158582%
Test -- process:[  7040/10000 ] loss:0.117950 , accuary:90.101810%
Test -- process:[  7680/10000 ] loss:0.035395 , accuary:90.326763%
Test -- process:[  8320/10000 ] loss:0.127389 , accuary:90.445402%
Test -- process:[  8960/10000 ] loss:0.019366 , accuary:90.491

## Save Model

In [95]:
def save_model(model,name):
    import os
    if not os.path.exists("./model"):
        os.makedirs("./model")
    if not os.path.exists("./model/"+name):
        torch.save(model.state_dict(),'model/'+name)
        print(f"Successfully save the model:{name}!")
    else:
        print(f"the model:[{name}] already exits, please delete first before you save new one!")
    

In [96]:
save_model(net,"cnnmodel.pth")

the model:[cnnmodel.pth] already exits, please delete first before you save new one!


## Load Model

In [103]:
def load_model(model,model_pth):
    model = model
    model.load_state_dict(torch.load(model_pth))
    dataloader = DataLoader(test_data,batch_size=batch_size,shuffle=True)
    model.to(device)
    model.eval()
    correct = 0.0
    with torch.no_grad():
        for X,y in dataloader:
            X,y = X.to(device),y.to(device)
            pred = model(X)
            correct += (pred.argmax(1)==y).type(torch.float).sum().item()
        size = len(dataloader.dataset)
        print(f"Accuary:{correct*100.0/size:>4f}%")



In [107]:
load_model(CNNmodel(),"model/cnnmodel.pth")

Accuary:90.560000%
