In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import tqdm

In [77]:
class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        #ResNet18所使用的基本殘差單元，每個單元由兩個3x3卷積層組成，中間有一個BN層和一個ReLU激活函數
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out = out + self.shortcut(x)
        out = F.relu(out)

        return out

# 自訂一個神經網絡，使用nn.model，透過__init__初始化每一層神經網路。
# 使用forward連接數據
class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=100):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    # make_layers函數重複殘差塊，以及shortcut部分
    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [78]:
def ResNet18():
    return ResNet(ResidualBlock)

In [79]:
#Use the ResNet18 on Cifar-100
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os

#check gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#設置超參數
EPOCH = 1         #設置訓練次數
pre_epoch = 0
BATCH_SIZE = 128
LR = 0.01

#prepare dataset and preprocessing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

#labels in CIFAR100
classes = ('beaver', 'dolphin', 'otter', 'seal', 'whale',
      'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',
      'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',
      'bottles', 'bowls', 'cans', 'cups', 'plates',
      'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',
      'clock', 'computer keyboard', 'lamp', 'telephone', 'television',
      'bed', 'chair', 'couch', 'table', 'wardrobe',
      'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',
      'bear', 'leopard', 'lion', 'tiger', 'wolf',
      'bridge', 'castle', 'house', 'road', 'skyscraper',
      'cloud', 'forest', 'mountain', 'plain', 'sea',
      'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',
      'fox', 'porcupine', 'possum', 'raccoon', 'skunk',
      'crab', 'lobster', 'snail', 'spider', 'worm',
      'baby', 'boy', 'girl', 'man', 'woman',
      'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',
      'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',
      'maple', 'oak', 'palm', 'pine', 'willow',
      'bicycle', 'bus', 'motorcycle', 'pickup' 'truck', 'train',
      'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor')

#Make model，使用cpu
#define ResNet18
net = ResNet18().to(device)

#define loss funtion & optimizer
criterion = nn.CrossEntropyLoss()         #交叉熵損失函數
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)       #優化器隨機梯度下降

Files already downloaded and verified
Files already downloaded and verified


In [80]:
def acc_():
  data_loader = tqdm(testloader, desc=f"test ", leave=True)

  with torch.no_grad():    #在該模組下，所有計算得出的tensor的requires_grad都自動設置為False
    correct=0
    total=0
    for data in data_loader:
      net.eval()
      images, labels = data[0].to(device), data[1].to(device)
      outputs = net(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum()

  return (100 * correct / total)      #計算ac

In [81]:
#train
for epoch in range(pre_epoch, EPOCH):
    print('\n==========================Epoch %d==========================' % (epoch + 1))
    net.train()       #啟用batch normalization 和 dropout
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    data_loader = tqdm(trainloader, desc=f"Epoch {epoch+1}", leave=True)

    for i, data in enumerate(data_loader, 0):
      #prepare dataset
      length = len(trainloader)
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      optimizer.zero_grad()

      #forward & backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()


      #print ac & loss in each batch
      sum_loss += loss.item()
      data_loader.set_postfix({'Loss': sum_loss / (i + 1)})

      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += predicted.eq(labels.data).cpu().sum()
      print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% '
            % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))
    print('==========================================================================')

print("============ Training END ================")





Epoch 1:   0%|          | 0/391 [00:00<?, ?it/s]

[epoch:1, iter:1] Loss: 4.709 | Acc: 0.781% 
[epoch:1, iter:2] Loss: 4.699 | Acc: 0.781% 
[epoch:1, iter:3] Loss: 4.681 | Acc: 1.302% 
[epoch:1, iter:4] Loss: 4.688 | Acc: 1.367% 
[epoch:1, iter:5] Loss: 4.673 | Acc: 1.406% 
[epoch:1, iter:6] Loss: 4.661 | Acc: 1.432% 
[epoch:1, iter:7] Loss: 4.663 | Acc: 1.228% 
[epoch:1, iter:8] Loss: 4.655 | Acc: 1.465% 
[epoch:1, iter:9] Loss: 4.649 | Acc: 1.302% 
[epoch:1, iter:10] Loss: 4.643 | Acc: 1.250% 
[epoch:1, iter:11] Loss: 4.629 | Acc: 1.420% 
[epoch:1, iter:12] Loss: 4.624 | Acc: 1.367% 
[epoch:1, iter:13] Loss: 4.622 | Acc: 1.322% 
[epoch:1, iter:14] Loss: 4.621 | Acc: 1.339% 
[epoch:1, iter:15] Loss: 4.612 | Acc: 1.510% 
[epoch:1, iter:16] Loss: 4.605 | Acc: 1.611% 
[epoch:1, iter:17] Loss: 4.599 | Acc: 1.838% 
[epoch:1, iter:18] Loss: 4.593 | Acc: 1.953% 
[epoch:1, iter:19] Loss: 4.590 | Acc: 2.097% 
[epoch:1, iter:20] Loss: 4.581 | Acc: 2.148% 
[epoch:1, iter:21] Loss: 4.584 | Acc: 2.232% 
[epoch:1, iter:22] Loss: 4.573 | Acc: 2.273

In [82]:
# test network on test set
accuracy_= acc_()

print('\nTest\'s Accuracy: %d %%' % accuracy_)
print("====== Test END ==========")

test :   0%|          | 0/100 [00:00<?, ?it/s]


Test's Accuracy: 15 %


In [83]:
torch.save(net, './model.pt')