# COMP34212 Summative Lab Task2


In [1]:
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim

from tqdm import tqdm
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms

import os
device = 'cuda' if torch.cuda.is_available() else 'cpu' # Check if GPU is available
print('Using {} device'.format(device))

folder_name = 'checkpoint'
if not os.path.exists(folder_name):
        os.makedirs(folder_name)

Using cuda device


In [2]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

# TODO: change variable names
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [3]:
class BasicBlock(nn.Module):
    """
    A basic block for ResNet18 and ResNet34

    Args:
    inplanes: int, number of input channels
    planes: int, number of output channels
    stride: int, stride of the first convolutional layer
    downsample: nn.Module, downsample layer
    groups: int, number of groups for the convolutional layers
    base_width: int, base width of the convolutional layers
    dilation: int, dilation rate of the convolutional layers
    norm_layer: nn.Module, normalization layer

    Returns:
    out: tensor, output of the basic block
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))

        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

In [4]:
class Bottleneck(nn.Module):
    """
    For deeper models, we use the "bottleneck" building block to reduce the number of parameters.
    It is based on the following principle:
    1. A 1x1 convolution reduces the dimensionality of the input to a bottleneck representation.
    2. A 3x3 convolution is applied to the bottleneck representation.
    3. A 1x1 convolution increases the dimensionality of the representation back to the original.
    4. The output is added to the original input.
    5. The result is passed through a ReLU activation function.

    The bottleneck block is used in the ResNet-50, ResNet-101, and ResNet-152 architectures.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


In [5]:
class ResNet(nn.Module):
    """
    The ResNet architecture is based on the following principles:
    1. The input is passed through a 7x7 convolutional layer with stride 2.
    2. The output is passed through a 3x3 max pooling layer with stride 2.
    3. The output is passed through a series of residual blocks.
    4. The output is passed through a global average pooling layer.
    5. The output is passed through a fully connected layer with softmax activation.

    The ResNet architecture is defined by the number of layers and the type of residual block used.
    The ResNet-18 and ResNet-34 architectures use the "basic block" residual block.
    The ResNet-50, ResNet-101, and ResNet-152 architectures use the "bottleneck" residual block.
    """
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        #self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)    # delete maxpool layer
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """
        Create a layer of residual blocks

        Args:
        block: nn.Module, type of residual block: BasicBlock or Bottleneck
        planes: int, number of output channels
        blocks: int, number of blocks
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        #x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(block, layers, **kwargs):
    model = ResNet(block, layers, **kwargs)
    return model


def ResNet18(**kwargs):
    return _resnet(BasicBlock, [2, 2, 2, 2],**kwargs)

def ResNet34(**kwargs):
    return _resnet(BasicBlock, [3, 4, 6, 3],**kwargs)


def ResNet50(**kwargs):
    return _resnet(Bottleneck, [3, 4, 6, 3],**kwargs)


def ResNet101(**kwargs):
    return _resnet(Bottleneck, [3, 4, 23, 3],**kwargs)


def ResNet152(**kwargs):
    return _resnet(Bottleneck, [3, 8, 36, 3],**kwargs)

In [6]:
class Cutout(object):
    """
    Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
        	# (x,y) makes the center of the hole
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img

In [7]:
num_workers = 0 # number of subprocesses to use for data loading
batch_size = 16 # how many samples per batch to load
valid_size = 0.2    # percentage of training set to use as validation

def read_dataset(batch_size=16,valid_size=0.2,num_workers=0,pic_path='dataset'):
    """
    batch_size: Number of loaded drawings per batch
    valid_size: Percentage of training set to use as validation
    num_workers: Number of subprocesses to use for data loading
    pic_path: The path of the pictrues
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # first, we perform a random crop of the image to 32x32 using padding of 4 pixels
        transforms.RandomHorizontalFlip(),  # randomly flip the image horizontally, probability of 0.5
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), # normalize the image based on the mean and standard deviation from the ImageNet dataset
        Cutout(n_holes=1, length=16),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
    ])


    # Define the data loaders
    train_data = datasets.CIFAR10(pic_path, train=True,
                                download=True, transform=transform_train)
    valid_data = datasets.CIFAR10(pic_path, train=True,
                                download=True, transform=transform_test)
    test_data = datasets.CIFAR10(pic_path, train=False,
                                download=True, transform=transform_test)
        

    # obtain training indices that will be used for validation
    num_train = len(train_data)
    indices = list(range(num_train))
    # random indices
    np.random.shuffle(indices)
    # the ratio of split
    split = int(np.floor(valid_size * num_train))
    # divide data to radin_data and valid_data
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    # 无放回地按照给定的索引列表采样样本元素
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # prepare data loaders (combine dataset and sampler)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
        sampler=train_sampler, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, 
        sampler=valid_sampler, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, 
        num_workers=num_workers)

    return train_loader,valid_loader,test_loader

# TRAINING

In [15]:
batch_size = 128
train_loader,valid_loader,test_loader = read_dataset(batch_size=batch_size,pic_path='dataset')
n_class = 10
model = ResNet50()


"""
ResNet18网络的7x7降采样卷积和池化操作容易丢失一部分信息,
所以在实验中我们将7x7的降采样层和最大池化层去掉,替换为一个3x3的降采样卷积,
同时减小该卷积层的步长和填充大小
"""
model.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
# model.fc = torch.nn.Linear(512, n_class) # modify the last layer
model.fc = torch.nn.Linear(2048, n_class) # modify the last layer

checkpoint_path = 'checkpoint/resnet50_cifar10.pt'
if os.path.exists(checkpoint_path):
    # 加载模型参数
    model_state_dict = torch.load(checkpoint_path)
    # 将参数加载到模型中
    model.load_state_dict(model_state_dict)
    print("Loaded model parameters from disk.")


model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)    # Use cross entropy loss function
n_epochs = 250
lr = 0.1

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Loaded model parameters from disk.


In [17]:
early_stop_threshold = 50
valid_loss_min = np.Inf # track change in validation loss
accuracy = []

counter = 0
for epoch in tqdm(range(50, n_epochs+1)):

    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    total_sample = 0
    right_sample = 0
    
    # Dynamic learning rate
    if counter/10 ==1:
        counter = 0
        lr = lr*0.5
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    ###################
    # 训练集的模型 #
    ###################
    model.train() #作用是启用batch normalization和drop out
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        # clear the gradients of all optimized variables（清除梯度）
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data).to(device)  #（等价于output = model.forward(data).to(device) ）
        # calculate the batch loss
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss（
        train_loss += loss.item()*data.size(0)
        
    ######################    
    # 验证集的模型#
    ######################

    model.eval()  # 验证模型
    for data, target in valid_loader:
        data = data.to(device)
        target = target.to(device)
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data).to(device)
        # calculate the batch loss
        loss = criterion(output, target)
        # update average validation loss 
        valid_loss += loss.item()*data.size(0)
        # convert output probabilities to predicted class(将输出概率转换为预测类)
        _, pred = torch.max(output, 1)    
        # compare predictions to true label(将预测与真实标签进行比较)
        correct_tensor = pred.eq(target.data.view_as(pred))
        # correct = np.squeeze(correct_tensor.to(device).numpy())
        total_sample += batch_size
        for i in correct_tensor:
            if i:
                right_sample += 1
    print("Accuracy:",100*right_sample/total_sample,"%")
    accuracy.append(right_sample/total_sample)
 
    # 计算平均损失
    train_loss = train_loss/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)
        
    # 显示训练集与验证集的损失函数 
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, valid_loss))
    
    # 如果验证集损失函数减少，就保存模型。
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,valid_loss))
        torch.save(model.state_dict(), 'checkpoint/resnet50_cifar10.pt')
        valid_loss_min = valid_loss
        counter = 0
    else:
        counter += 1
        if counter >= early_stop_threshold:
            print("Early stopping")
            break

  0%|          | 1/201 [00:54<3:00:02, 54.01s/it]

Accuracy: 80.14240506329114 %
Epoch: 50 	Training Loss: 0.632336 	Validation Loss: 0.570307
Validation loss decreased (inf --> 0.570307).  Saving model ...


  1%|          | 2/201 [01:46<2:55:36, 52.95s/it]

Accuracy: 73.22982594936708 %
Epoch: 51 	Training Loss: 0.634021 	Validation Loss: 0.870101


  1%|▏         | 3/201 [02:40<2:56:21, 53.44s/it]

Accuracy: 79.74683544303798 %
Epoch: 52 	Training Loss: 0.622338 	Validation Loss: 0.568731
Validation loss decreased (0.570307 --> 0.568731).  Saving model ...


  2%|▏         | 4/201 [03:32<2:53:46, 52.93s/it]

Accuracy: 80.65664556962025 %
Epoch: 53 	Training Loss: 0.626238 	Validation Loss: 0.525927
Validation loss decreased (0.568731 --> 0.525927).  Saving model ...


  2%|▏         | 5/201 [04:24<2:52:24, 52.78s/it]

Accuracy: 75.75158227848101 %
Epoch: 54 	Training Loss: 0.613401 	Validation Loss: 0.698334


  3%|▎         | 6/201 [05:17<2:51:07, 52.65s/it]

Accuracy: 77.87776898734177 %
Epoch: 55 	Training Loss: 0.614722 	Validation Loss: 0.657110


  3%|▎         | 7/201 [06:09<2:50:02, 52.59s/it]

Accuracy: 76.02848101265823 %
Epoch: 56 	Training Loss: 0.611066 	Validation Loss: 0.689456


  4%|▍         | 8/201 [07:02<2:48:58, 52.53s/it]

Accuracy: 74.95055379746836 %
Epoch: 57 	Training Loss: 0.608751 	Validation Loss: 0.746983


  4%|▍         | 9/201 [07:54<2:47:36, 52.38s/it]

Accuracy: 78.30300632911393 %
Epoch: 58 	Training Loss: 0.601011 	Validation Loss: 0.656484


  5%|▍         | 10/201 [08:46<2:46:35, 52.33s/it]

Accuracy: 77.4129746835443 %
Epoch: 59 	Training Loss: 0.602077 	Validation Loss: 0.677794


  5%|▌         | 11/201 [09:40<2:47:01, 52.74s/it]

Accuracy: 71.50909810126582 %
Epoch: 60 	Training Loss: 0.591525 	Validation Loss: 0.883486


  6%|▌         | 12/201 [10:33<2:46:47, 52.95s/it]

Accuracy: 80.49841772151899 %
Epoch: 61 	Training Loss: 0.590865 	Validation Loss: 0.547771


  6%|▋         | 13/201 [11:25<2:45:22, 52.78s/it]

Accuracy: 68.90822784810126 %
Epoch: 62 	Training Loss: 0.589075 	Validation Loss: 1.063761


  7%|▋         | 14/201 [12:18<2:44:10, 52.68s/it]

Accuracy: 76.13726265822785 %
Epoch: 63 	Training Loss: 0.582849 	Validation Loss: 0.752781


  7%|▋         | 15/201 [13:10<2:43:00, 52.59s/it]

Accuracy: 85.82871835443038 %
Epoch: 64 	Training Loss: 0.455419 	Validation Loss: 0.403172
Validation loss decreased (0.525927 --> 0.403172).  Saving model ...


  8%|▊         | 16/201 [14:03<2:42:00, 52.54s/it]

Accuracy: 83.1190664556962 %
Epoch: 65 	Training Loss: 0.431469 	Validation Loss: 0.481843


  8%|▊         | 17/201 [14:55<2:41:17, 52.60s/it]

Accuracy: 85.07713607594937 %
Epoch: 66 	Training Loss: 0.435992 	Validation Loss: 0.431960


  9%|▉         | 18/201 [15:48<2:40:06, 52.50s/it]

Accuracy: 81.4873417721519 %
Epoch: 67 	Training Loss: 0.440487 	Validation Loss: 0.578799


  9%|▉         | 19/201 [16:40<2:38:44, 52.33s/it]

Accuracy: 87.05498417721519 %
Epoch: 68 	Training Loss: 0.442674 	Validation Loss: 0.356488
Validation loss decreased (0.403172 --> 0.356488).  Saving model ...


 10%|▉         | 20/201 [17:31<2:37:14, 52.13s/it]

Accuracy: 85.7001582278481 %
Epoch: 69 	Training Loss: 0.443254 	Validation Loss: 0.409506


 10%|█         | 21/201 [18:24<2:36:30, 52.17s/it]

Accuracy: 81.52689873417721 %
Epoch: 70 	Training Loss: 0.438804 	Validation Loss: 0.532061


 11%|█         | 22/201 [19:16<2:35:43, 52.20s/it]

Accuracy: 81.92246835443038 %
Epoch: 71 	Training Loss: 0.439964 	Validation Loss: 0.563821


 11%|█▏        | 23/201 [20:08<2:35:01, 52.25s/it]

Accuracy: 83.03006329113924 %
Epoch: 72 	Training Loss: 0.441117 	Validation Loss: 0.486080


 12%|█▏        | 24/201 [21:00<2:33:50, 52.15s/it]

Accuracy: 82.80261075949367 %
Epoch: 73 	Training Loss: 0.436825 	Validation Loss: 0.495567


 12%|█▏        | 25/201 [21:52<2:32:53, 52.12s/it]

Accuracy: 83.72231012658227 %
Epoch: 74 	Training Loss: 0.433041 	Validation Loss: 0.468130


 13%|█▎        | 26/201 [22:44<2:32:03, 52.14s/it]

Accuracy: 79.70727848101266 %
Epoch: 75 	Training Loss: 0.435366 	Validation Loss: 0.648990


 13%|█▎        | 27/201 [23:36<2:31:05, 52.10s/it]

Accuracy: 82.90150316455696 %
Epoch: 76 	Training Loss: 0.435011 	Validation Loss: 0.482146


 14%|█▍        | 28/201 [24:28<2:30:06, 52.06s/it]

Accuracy: 84.6815664556962 %
Epoch: 77 	Training Loss: 0.428483 	Validation Loss: 0.453471


 14%|█▍        | 29/201 [25:20<2:29:12, 52.05s/it]

Accuracy: 81.06210443037975 %
Epoch: 78 	Training Loss: 0.427465 	Validation Loss: 0.573208


 15%|█▍        | 30/201 [26:12<2:28:22, 52.06s/it]

Accuracy: 89.84375 %
Epoch: 79 	Training Loss: 0.327670 	Validation Loss: 0.273553
Validation loss decreased (0.356488 --> 0.273553).  Saving model ...


 15%|█▌        | 31/201 [27:05<2:27:37, 52.10s/it]

Accuracy: 89.9426424050633 %
Epoch: 80 	Training Loss: 0.302730 	Validation Loss: 0.268266
Validation loss decreased (0.273553 --> 0.268266).  Saving model ...


 16%|█▌        | 32/201 [27:57<2:26:46, 52.11s/it]

Accuracy: 89.2998417721519 %
Epoch: 81 	Training Loss: 0.303498 	Validation Loss: 0.293514


 16%|█▋        | 33/201 [28:49<2:26:04, 52.17s/it]

Accuracy: 88.33069620253164 %
Epoch: 82 	Training Loss: 0.302829 	Validation Loss: 0.318960


 17%|█▋        | 34/201 [29:41<2:25:05, 52.13s/it]

Accuracy: 85.94738924050633 %
Epoch: 83 	Training Loss: 0.299301 	Validation Loss: 0.407845


 17%|█▋        | 35/201 [30:33<2:23:58, 52.04s/it]

Accuracy: 88.15268987341773 %
Epoch: 84 	Training Loss: 0.300875 	Validation Loss: 0.327559


 18%|█▊        | 36/201 [31:25<2:23:05, 52.03s/it]

Accuracy: 88.14280063291139 %
Epoch: 85 	Training Loss: 0.300089 	Validation Loss: 0.332321


 18%|█▊        | 37/201 [32:17<2:22:08, 52.00s/it]

Accuracy: 89.31962025316456 %
Epoch: 86 	Training Loss: 0.302136 	Validation Loss: 0.300809


 19%|█▉        | 38/201 [33:09<2:21:19, 52.02s/it]

Accuracy: 88.81526898734177 %
Epoch: 87 	Training Loss: 0.303365 	Validation Loss: 0.328184


 19%|█▉        | 39/201 [34:01<2:20:52, 52.18s/it]

Accuracy: 88.39003164556962 %
Epoch: 88 	Training Loss: 0.301836 	Validation Loss: 0.316858


 20%|█▉        | 40/201 [34:54<2:20:05, 52.21s/it]

Accuracy: 87.25276898734177 %
Epoch: 89 	Training Loss: 0.310333 	Validation Loss: 0.351705


 20%|██        | 41/201 [35:46<2:19:11, 52.20s/it]

Accuracy: 89.01305379746836 %
Epoch: 90 	Training Loss: 0.296114 	Validation Loss: 0.312075


 21%|██        | 42/201 [36:38<2:18:17, 52.18s/it]

Accuracy: 91.80181962025317 %
Epoch: 91 	Training Loss: 0.230466 	Validation Loss: 0.214685
Validation loss decreased (0.268266 --> 0.214685).  Saving model ...


 21%|██▏       | 43/201 [37:30<2:17:26, 52.19s/it]

Accuracy: 91.80181962025317 %
Epoch: 92 	Training Loss: 0.208800 	Validation Loss: 0.213123
Validation loss decreased (0.214685 --> 0.213123).  Saving model ...


 22%|██▏       | 44/201 [38:23<2:16:36, 52.21s/it]

Accuracy: 91.43591772151899 %
Epoch: 93 	Training Loss: 0.202236 	Validation Loss: 0.240314


 22%|██▏       | 45/201 [39:15<2:15:48, 52.23s/it]

Accuracy: 91.13924050632912 %
Epoch: 94 	Training Loss: 0.208529 	Validation Loss: 0.248212


 23%|██▎       | 46/201 [40:07<2:14:51, 52.21s/it]

Accuracy: 91.70292721518987 %
Epoch: 95 	Training Loss: 0.196702 	Validation Loss: 0.233871


 23%|██▎       | 47/201 [40:59<2:14:02, 52.23s/it]

Accuracy: 91.40625 %
Epoch: 96 	Training Loss: 0.200884 	Validation Loss: 0.231084


 24%|██▍       | 48/201 [41:51<2:13:08, 52.21s/it]

Accuracy: 91.48536392405063 %
Epoch: 97 	Training Loss: 0.196145 	Validation Loss: 0.235855


 24%|██▍       | 49/201 [42:44<2:12:21, 52.25s/it]

Accuracy: 91.28757911392405 %
Epoch: 98 	Training Loss: 0.200598 	Validation Loss: 0.241200


 25%|██▍       | 50/201 [43:36<2:11:46, 52.36s/it]

Accuracy: 91.0304588607595 %
Epoch: 99 	Training Loss: 0.194930 	Validation Loss: 0.244980


 25%|██▌       | 51/201 [44:30<2:11:58, 52.79s/it]

Accuracy: 91.35680379746836 %
Epoch: 100 	Training Loss: 0.192746 	Validation Loss: 0.238175


 26%|██▌       | 52/201 [45:24<2:11:42, 53.03s/it]

Accuracy: 91.5051424050633 %
Epoch: 101 	Training Loss: 0.196765 	Validation Loss: 0.236556


 26%|██▋       | 53/201 [46:16<2:10:27, 52.89s/it]

Accuracy: 91.58425632911393 %
Epoch: 102 	Training Loss: 0.193287 	Validation Loss: 0.242441


 27%|██▋       | 54/201 [47:09<2:09:29, 52.86s/it]

Accuracy: 92.51384493670886 %
Epoch: 103 	Training Loss: 0.151584 	Validation Loss: 0.199042
Validation loss decreased (0.213123 --> 0.199042).  Saving model ...


 27%|██▋       | 55/201 [48:02<2:08:20, 52.74s/it]

Accuracy: 92.93908227848101 %
Epoch: 104 	Training Loss: 0.136558 	Validation Loss: 0.193475
Validation loss decreased (0.199042 --> 0.193475).  Saving model ...


 28%|██▊       | 56/201 [48:54<2:07:16, 52.67s/it]

Accuracy: 92.79074367088607 %
Epoch: 105 	Training Loss: 0.130709 	Validation Loss: 0.204655


 28%|██▊       | 57/201 [49:46<2:06:12, 52.58s/it]

Accuracy: 92.74129746835443 %
Epoch: 106 	Training Loss: 0.130568 	Validation Loss: 0.207581


 29%|██▉       | 58/201 [50:39<2:05:15, 52.56s/it]

Accuracy: 93.0676424050633 %
Epoch: 107 	Training Loss: 0.124107 	Validation Loss: 0.198809


 29%|██▉       | 59/201 [51:32<2:04:25, 52.57s/it]

Accuracy: 92.75118670886076 %
Epoch: 108 	Training Loss: 0.127247 	Validation Loss: 0.204052


 30%|██▉       | 60/201 [52:24<2:03:33, 52.58s/it]

Accuracy: 92.50395569620254 %
Epoch: 109 	Training Loss: 0.122646 	Validation Loss: 0.216213


 30%|███       | 61/201 [53:17<2:02:43, 52.60s/it]

Accuracy: 92.84018987341773 %
Epoch: 110 	Training Loss: 0.122787 	Validation Loss: 0.214490


 31%|███       | 62/201 [54:09<2:01:41, 52.53s/it]

Accuracy: 92.6621835443038 %
Epoch: 111 	Training Loss: 0.123992 	Validation Loss: 0.207953


 31%|███▏      | 63/201 [55:02<2:00:45, 52.51s/it]

Accuracy: 92.97863924050633 %
Epoch: 112 	Training Loss: 0.120921 	Validation Loss: 0.200244


 32%|███▏      | 64/201 [55:54<1:59:54, 52.51s/it]

Accuracy: 92.8006329113924 %
Epoch: 113 	Training Loss: 0.124189 	Validation Loss: 0.194914


 32%|███▏      | 65/201 [56:47<1:59:08, 52.56s/it]

Accuracy: 92.5632911392405 %
Epoch: 114 	Training Loss: 0.123601 	Validation Loss: 0.224019


 33%|███▎      | 66/201 [57:40<1:58:21, 52.60s/it]

Accuracy: 93.32476265822785 %
Epoch: 115 	Training Loss: 0.099143 	Validation Loss: 0.197105


 33%|███▎      | 67/201 [58:32<1:57:28, 52.60s/it]

Accuracy: 93.41376582278481 %
Epoch: 116 	Training Loss: 0.085808 	Validation Loss: 0.194309


 34%|███▍      | 68/201 [59:25<1:56:37, 52.61s/it]

Accuracy: 93.57199367088607 %
Epoch: 117 	Training Loss: 0.087482 	Validation Loss: 0.184658
Validation loss decreased (0.193475 --> 0.184658).  Saving model ...


 34%|███▍      | 69/201 [1:00:17<1:55:47, 52.63s/it]

Accuracy: 93.49287974683544 %
Epoch: 118 	Training Loss: 0.083790 	Validation Loss: 0.183780
Validation loss decreased (0.184658 --> 0.183780).  Saving model ...


 35%|███▍      | 70/201 [1:01:10<1:54:47, 52.57s/it]

Accuracy: 93.66099683544304 %
Epoch: 119 	Training Loss: 0.082551 	Validation Loss: 0.183531
Validation loss decreased (0.183780 --> 0.183531).  Saving model ...


 35%|███▌      | 71/201 [1:02:02<1:53:54, 52.57s/it]

Accuracy: 93.53243670886076 %
Epoch: 120 	Training Loss: 0.081689 	Validation Loss: 0.188392


 36%|███▌      | 72/201 [1:02:55<1:53:00, 52.56s/it]

Accuracy: 93.69066455696202 %
Epoch: 121 	Training Loss: 0.079248 	Validation Loss: 0.191811


 36%|███▋      | 73/201 [1:03:48<1:52:22, 52.68s/it]

Accuracy: 93.69066455696202 %
Epoch: 122 	Training Loss: 0.079680 	Validation Loss: 0.186388


 37%|███▋      | 74/201 [1:04:40<1:51:08, 52.51s/it]

Accuracy: 93.43354430379746 %
Epoch: 123 	Training Loss: 0.079437 	Validation Loss: 0.196307


 37%|███▋      | 75/201 [1:05:33<1:50:17, 52.52s/it]

Accuracy: 93.45332278481013 %
Epoch: 124 	Training Loss: 0.074719 	Validation Loss: 0.186454


 38%|███▊      | 76/201 [1:06:25<1:49:30, 52.57s/it]

Accuracy: 93.45332278481013 %
Epoch: 125 	Training Loss: 0.073976 	Validation Loss: 0.192169


 38%|███▊      | 77/201 [1:07:17<1:48:14, 52.38s/it]

Accuracy: 93.55221518987342 %
Epoch: 126 	Training Loss: 0.075614 	Validation Loss: 0.194275


 39%|███▉      | 78/201 [1:08:09<1:47:04, 52.23s/it]

Accuracy: 93.62143987341773 %
Epoch: 127 	Training Loss: 0.076406 	Validation Loss: 0.193529


 39%|███▉      | 79/201 [1:09:01<1:46:12, 52.23s/it]

Accuracy: 93.50276898734177 %
Epoch: 128 	Training Loss: 0.072240 	Validation Loss: 0.195588


 40%|███▉      | 80/201 [1:09:54<1:45:22, 52.25s/it]

Accuracy: 93.38409810126582 %
Epoch: 129 	Training Loss: 0.071518 	Validation Loss: 0.212627


 40%|████      | 81/201 [1:10:46<1:44:27, 52.23s/it]

Accuracy: 93.76977848101266 %
Epoch: 130 	Training Loss: 0.066039 	Validation Loss: 0.186389


 41%|████      | 82/201 [1:11:38<1:43:31, 52.20s/it]

Accuracy: 93.87856012658227 %
Epoch: 131 	Training Loss: 0.061045 	Validation Loss: 0.186369


 41%|████▏     | 83/201 [1:12:30<1:42:27, 52.10s/it]

Accuracy: 93.96756329113924 %
Epoch: 132 	Training Loss: 0.056554 	Validation Loss: 0.186337


 42%|████▏     | 84/201 [1:13:22<1:41:31, 52.06s/it]

Accuracy: 93.88844936708861 %
Epoch: 133 	Training Loss: 0.056791 	Validation Loss: 0.185710


 42%|████▏     | 85/201 [1:14:14<1:40:36, 52.04s/it]

Accuracy: 93.82911392405063 %
Epoch: 134 	Training Loss: 0.054879 	Validation Loss: 0.195017


 43%|████▎     | 86/201 [1:15:06<1:39:53, 52.12s/it]

Accuracy: 93.87856012658227 %
Epoch: 135 	Training Loss: 0.054938 	Validation Loss: 0.187765


 43%|████▎     | 87/201 [1:15:58<1:38:56, 52.08s/it]

Accuracy: 93.73022151898734 %
Epoch: 136 	Training Loss: 0.054468 	Validation Loss: 0.190775


 44%|████▍     | 88/201 [1:16:50<1:38:07, 52.10s/it]

Accuracy: 93.8192246835443 %
Epoch: 137 	Training Loss: 0.049170 	Validation Loss: 0.186482


 44%|████▍     | 89/201 [1:17:42<1:37:10, 52.06s/it]

Accuracy: 93.79944620253164 %
Epoch: 138 	Training Loss: 0.051817 	Validation Loss: 0.186555


 45%|████▍     | 90/201 [1:18:34<1:36:17, 52.05s/it]

Accuracy: 94.04667721518987 %
Epoch: 139 	Training Loss: 0.050190 	Validation Loss: 0.184117


 45%|████▌     | 91/201 [1:19:26<1:35:28, 52.08s/it]

Accuracy: 94.06645569620254 %
Epoch: 140 	Training Loss: 0.046881 	Validation Loss: 0.182184
Validation loss decreased (0.183531 --> 0.182184).  Saving model ...


 46%|████▌     | 92/201 [1:20:18<1:34:37, 52.08s/it]

Accuracy: 94.00712025316456 %
Epoch: 141 	Training Loss: 0.045628 	Validation Loss: 0.179460
Validation loss decreased (0.182184 --> 0.179460).  Saving model ...


 46%|████▋     | 93/201 [1:21:11<1:33:45, 52.09s/it]

Accuracy: 94.08623417721519 %
Epoch: 142 	Training Loss: 0.045308 	Validation Loss: 0.180949


 47%|████▋     | 94/201 [1:22:03<1:33:00, 52.15s/it]

Accuracy: 93.99723101265823 %
Epoch: 143 	Training Loss: 0.043415 	Validation Loss: 0.184480


 47%|████▋     | 95/201 [1:22:55<1:32:07, 52.15s/it]

Accuracy: 94.10601265822785 %
Epoch: 144 	Training Loss: 0.044149 	Validation Loss: 0.188592


 48%|████▊     | 96/201 [1:23:47<1:31:19, 52.18s/it]

Accuracy: 94.09612341772151 %
Epoch: 145 	Training Loss: 0.045149 	Validation Loss: 0.186523


 48%|████▊     | 97/201 [1:24:39<1:30:21, 52.13s/it]

Accuracy: 94.07634493670886 %
Epoch: 146 	Training Loss: 0.042344 	Validation Loss: 0.183628


 49%|████▉     | 98/201 [1:25:31<1:29:21, 52.05s/it]

Accuracy: 94.11590189873418 %
Epoch: 147 	Training Loss: 0.041451 	Validation Loss: 0.185923


 49%|████▉     | 99/201 [1:26:24<1:28:44, 52.20s/it]

Accuracy: 94.08623417721519 %
Epoch: 148 	Training Loss: 0.040387 	Validation Loss: 0.185570


 50%|████▉     | 100/201 [1:27:17<1:28:20, 52.48s/it]

Accuracy: 94.09612341772151 %
Epoch: 149 	Training Loss: 0.039521 	Validation Loss: 0.184303


 50%|█████     | 101/201 [1:28:09<1:27:23, 52.44s/it]

Accuracy: 94.01700949367088 %
Epoch: 150 	Training Loss: 0.039039 	Validation Loss: 0.187996


 51%|█████     | 102/201 [1:29:02<1:26:31, 52.44s/it]

Accuracy: 94.10601265822785 %
Epoch: 151 	Training Loss: 0.043351 	Validation Loss: 0.182211


 51%|█████     | 103/201 [1:29:54<1:25:26, 52.31s/it]

Accuracy: 94.17523734177215 %
Epoch: 152 	Training Loss: 0.038334 	Validation Loss: 0.182334


 52%|█████▏    | 104/201 [1:30:46<1:24:24, 52.21s/it]

Accuracy: 94.08623417721519 %
Epoch: 153 	Training Loss: 0.040360 	Validation Loss: 0.185629


 52%|█████▏    | 105/201 [1:31:38<1:23:24, 52.13s/it]

Accuracy: 94.19501582278481 %
Epoch: 154 	Training Loss: 0.039648 	Validation Loss: 0.185323


 52%|█████▏    | 105/201 [1:31:56<1:24:03, 52.53s/it]


KeyboardInterrupt: 

# TEST


In [18]:
n_class = 10
batch_size = 100
train_loader,valid_loader,test_loader = read_dataset(batch_size=batch_size,pic_path='dataset')
model = ResNet50() # 得到预训练模型
model.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
# model.fc = torch.nn.Linear(512, n_class) # 将最后的全连接层修改
model.fc = torch.nn.Linear(2048, n_class) # modify the last layer

# 载入权重
model.load_state_dict(torch.load('checkpoint/resnet50_cifar10.pt'))
model = model.to(device)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [20]:

total_sample = 0
right_sample = 0
model.eval()  # 验证模型
for data, target in test_loader:
    data = data.to(device)
    target = target.to(device)
    # forward pass: compute predicted outputs by passing inputs to the model
    output = model(data).to(device)
    # convert output probabilities to predicted class(将输出概率转换为预测类)
    _, pred = torch.max(output, 1)    
    # compare predictions to true label(将预测与真实标签进行比较)
    correct_tensor = pred.eq(target.data.view_as(pred))
    # correct = np.squeeze(correct_tensor.to(device).numpy())
    total_sample += batch_size
    for i in correct_tensor:
        if i:
            right_sample += 1
print("Accuracy:",100*right_sample/total_sample,"%")

Accuracy: 94.91 %
