In [1]:
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# PyTorch
In this notebook you will gain some hands-on experience with [PyTorch](https://pytorch.org/), one of the major frameworks for deep learning. To install PyTorch. follow [the official installation instructions](https://pytorch.org/get-started/locally/). Make sure that you select the correct OS & select the version with CUDA if your computer supports it.
If you do not have an Nvidia GPU, you can install the CPU version by setting `CUDA` to `None`.
However, in this case we recommend using [Google Colab](https://colab.research.google.com/).
Make sure that you enable GPU acceleration in `Runtime > Change runtime type`.

You will start by re-implementing some common features of deep neural networks (dropout and batch normalization) and then implement a very popular modern architecture for image classification (ResNet) and improve its training loop.

# 1. Dropout
Dropout is a form of regularization for neural networks. It works by randomly setting activations (values) to 0, each one with equal probability `p`. The values are then scaled by a factor $\frac{1}{1-p}$ to conserve their mean.

Dropout effectively trains a pseudo-ensemble of models with stochastic gradient descent. During evaluation we want to use the full ensemble and therefore have to turn off dropout. Use `self.training` to check if the model is in training or evaluation mode.

Do not use any dropout implementation from PyTorch for this!

In [2]:
class Dropout(nn.Module):
    """
    Dropout, as discussed in the lecture and described here:
    https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout
    
    Args:
        p: float, dropout probability
    """
    def __init__(self, p):
        super().__init__()
        self.p = p
        
    def forward(self, input):
        """
        The module's forward pass.
        This has to be implemented for every PyTorch module.
        PyTorch then automatically generates the backward pass
        by dynamically generating the computational graph during
        execution.
        
        Args:
            input: PyTorch tensor, arbitrary shape

        Returns:
            PyTorch tensor, same shape as input
        """
        
        # TODO: Set values randomly to 0.
        input_array = input.numpy()
        input_flatten = input_array.flatten()
        random_num = np.random.choice(a = input_flatten.shape[0] ,size = int(input_flatten.shape[0] * self.p) ,replace = False)
        for num in random_num:
            input_flatten[num] = 0
        input_flatten /= (1-self.p)
        input_flatten.reshape(input_array.shape)
        
        return  torch.from_numpy(input_flatten)  

In [3]:
# Test dropout
test = torch.ones(10000)
dropout = Dropout(0.5)
test_dropped = dropout(test)
#print(test_dropped)

# These assertions can in principle fail due to bad luck, but
# if implemented correctly they should almost always succeed.
assert np.isclose(test_dropped.sum().item(), 10_000, atol=400)
assert np.isclose((test_dropped > 0).sum().item(), 5_000, atol=200)

# 2. Batch normalization
Batch normalization is a trick use to smoothen the loss landscape and improve training. It is defined as the function
$$y = \frac{x - \mu_x}{\sigma_x + \epsilon} \cdot \gamma + \beta$$,
where $\gamma$ and $\beta$ and learnable parameters and $\epsilon$ is a some small number to avoid dividing by zero. The Statistics $\mu_x$ and $\sigma_x$ are taken separately for each feature. In a CNN this means averaging over the batch and all pixels.

Do not use any batch normalization implementation from PyTorch for this!

In [4]:
class BatchNorm(nn.Module):
    """
    Batch normalization, as discussed in the lecture and similar to
    https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d
    
    Only uses batch statistics (no running mean for evaluation).
    Batch statistics are calculated for a single dimension.
    Gamma is initialized as 1, beta as 0.
    
    Args:
        num_features: Number of features to calculate batch statistics for.
    """
    def __init__(self, num_features):
        super().__init__()
        
        # TODO: Initialize the required parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
    def forward(self, input):
        """
        Batch normalization over the dimension C of (N, C, L).
        
        Args:
            input: PyTorch tensor, shape [N, C, L]
            
        Return:
            PyTorch tensor, same shape as input
        """
        eps = 1e-5
        feature_mean,feature_var =  [],[]
        
        for i in range(input.shape[1]):
            feature_mean.append(input[:,i,:].mean().item())
            feature_var.append(input[:,i,:].var().item())
        
        feature_mean_ts = torch.from_numpy(np.array(feature_mean)).float()
        feature_var_ts = torch.from_numpy(np.array(feature_var)).float()
       
        output = ((input -  feature_mean_ts[None,:,None]) / torch.sqrt(feature_var_ts[None,:,None] + eps)) * self.gamma[None,:,None] + self.beta [None,:,None] 
        return output
        
        # TODO: Implement the required transformation

In [5]:
# Tests the batch normalization implementation
torch.random.manual_seed(42)
test = torch.randn(8, 2, 4)

b1 = BatchNorm(2)
test_b1 = b1(test)

b2 = nn.BatchNorm1d(2, affine=False, track_running_stats=False)
test_b2 = b2(test)


assert torch.allclose(test_b1, test_b2, rtol=0.02)

# 3. ResNet
ResNet is the models that first introduced residual connections (a form of skip connections). It is a rather simple, but successful and very popular architecture. In this part of the exercise we will re-implement it step by step.

Note that there is also an [improved version of ResNet](https://arxiv.org/abs/1603.05027) with optimized residual blocks. Here we will implement the [original version](https://arxiv.org/abs/1512.03385) for CIFAR-10. Your dropout and batchnorm implementations won't help you here. Just use PyTorch's own layers.

This is just a convenience function to make e.g. `nn.Sequential` more flexible. It is e.g. useful in combination with `x.squeeze()`.

In [6]:
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

We begin by implementing the residual blocks. The block is illustrated by this sketch:

![Residual connection](img/residual_connection.png)

Note that we use 'SAME' padding, no bias, and batch normalization after each convolution. You do not need `nn.Sequential` here. The skip connection is already implemented as `self.skip`. It can handle different strides and increases in the number of channels.

In [24]:
class ResidualBlock(nn.Module):
    """
    The residual block used by ResNet.
    
    Args:
        in_channels: The number of channels (feature maps) of the incoming embedding
        out_channels: The number of channels after the first convolution
        stride: Stride size of the first convolution, used for downsampling
    """
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()        
        if stride > 1 or in_channels != out_channels:
            # Add strides in the skip connection and zeros for the new channels.
            self.skip = Lambda(lambda x: F.pad(x[:, :, ::stride, ::stride],
                                               (0, 0, 0, 0, 0, out_channels - in_channels),
                                               mode="constant", value=0))
        else:
            self.skip = nn.Sequential()
            
        # TODO: Initialize the required layers
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1,bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1,bias=False)
        self.bn = nn.BatchNorm2d(out_channels)#为什么是out_channels?
        self.relu = nn.ReLU()
        
    def forward(self, input):
        # TODO: Execute the required layers and functions
        residual = input
        x = input
        x = self.relu(self.bn(self.conv1(x)))
        x = self.bn(self.conv2(x))
        x += self.skip(residual)
        output = self.relu(x)
        return output 

Next we implement a stack of residual blocks for convenience. The first layer in the block is the one changing the number of channels and downsampling. You can use `nn.ModuleList` to use a list of child modules.

In [25]:
class ResidualStack(nn.Module):
    """
    A stack of residual blocks.
    
    Args:
        in_channels: The number of channels (feature maps) of the incoming embedding
        out_channels: The number of channels after the first layer
        stride: Stride size of the first layer, used for downsampling
        num_blocks: Number of residual blocks
    """
    
    def __init__(self, in_channels, out_channels, stride, num_blocks):
        super().__init__()
        
        # TODO: Initialize the required layers (blocks)
        if num_blocks == 1 :
            self.modulelist = nn.ModuleList([ResidualBlock(in_channels,out_channels,stride)])
        else :
            self.modulelist = nn.ModuleList([ResidualBlock(in_channels,out_channels,stride)])
            for i in range (num_blocks - 1):
                self.modulelist.append(ResidualBlock(out_channels,out_channels))
        
    def forward(self, input):
        # TODO: Execute the layers (blocks)
        for mod in self.modulelist:
            input = mod(input)
        return input

Now we are finally ready to implement the full model! To do this, use the `nn.Sequential` API and carefully read the following paragraph from the paper (Fig. 3 is not important):

![ResNet CIFAR10 description](img/resnet_cifar10_description.png)

Note that a convolution layer is always convolution + batch norm + activation (ReLU), that each ResidualBlock contains 2 layers, and that you might have to `squeeze` the embedding before the dense (fully-connected) layer.

In [70]:
n = 5
num_classes = 10

# TODO: Implement ResNet via nn.Sequential
resnet = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1,bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    ResidualStack(in_channels=16,out_channels=16,stride=1,num_blocks=n), 
    ResidualStack(in_channels=16,out_channels=32,stride=2,num_blocks=n), 
    ResidualStack(in_channels=32,out_channels=64,stride=2,num_blocks=n),
    nn.AdaptiveAvgPool2d(1),
    Lambda(lambda x: torch.squeeze(x)),
    nn.Linear(64, num_classes),
         
)

Next we need to initialize the weights of our model.

In [71]:
def initialize_weight(module):
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
        
resnet.apply(initialize_weight);

# 4. Training
So now we have a shiny new model, but that doesn't really help when we can't train it. So that's what we do next.

First we need to load the data. Note that we split the official training data into train and validation sets, because you must not look at the test set until you are completely done developing your model and report the final results. Some people don't do this properly, but you should not copy other people's bad habits.

In [72]:
class CIFAR10Subset(torchvision.datasets.CIFAR10):
    """
    Get a subset of the CIFAR10 dataset, according to the passed indices.
    """
    def __init__(self, *args, idx=None, **kwargs):
        super().__init__(*args, **kwargs)
        
        if idx is None:
            return
        
        self.data = self.data[idx]
        targets_np = np.array(self.targets)
        self.targets = targets_np[idx].tolist()

We next define transformations that change the images into PyTorch tensors, standardize the values according to the precomputed mean and standard deviation, and provide data augmentation for the training set.

In [73]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
transform_eval = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

In [74]:
ntrain = 45_000
train_set = CIFAR10Subset(root='./data', train=True, idx=range(ntrain),
                          download=True, transform=transform_train)
val_set = CIFAR10Subset(root='./data', train=True, idx=range(ntrain,50_000),
                        download=True, transform=transform_eval)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform_eval)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [75]:
dataloaders = {}
dataloaders['train'] = torch.utils.data.DataLoader(train_set, batch_size=128,
                                                   shuffle=True, num_workers=0,#每个epoch执行之前重新划分minibatch
                                                   pin_memory=False)
dataloaders['val'] = torch.utils.data.DataLoader(val_set, batch_size=128,
                                                 shuffle=False, num_workers=0,
                                                 pin_memory=False)
dataloaders['test'] = torch.utils.data.DataLoader(test_set, batch_size=128,
                                                  shuffle=False, num_workers=0,
                                                  pin_memory=False)

Next we push the model to our GPU (if there is one).

In [76]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
resnet.to(device);
print(torch.cuda.is_available())

False


Next we define a helper method that does one epoch of training or evaluation. We have only defined training here, so you need to implement the necessary changes for evaluation!

In [77]:
def run_epoch(model, optimizer, dataloader, train):
    """
    Run one epoch of training or evaluation.
    
    Args:
        model: The model used for prediction
        optimizer: Optimization algorithm for the model
        dataloader: Dataloader providing the data to run our model on
        train: Whether this epoch is used for training or evaluation
        
    Returns:
        Loss and accuracy in this epoch.
    """
    # TODO: Change the necessary parts to work correctly during evaluation (train=False)
    #print("runepoch running")
    device = next(model.parameters()).device
    epoch_loss = 0.0
    epoch_acc = 0.0
   
    #print("going to enter train mode")
    if train :
        model.train() # Set model to training mode (for e.g. batch normalization, dropout)
        #print("runepoch model.train done")
        i = 0
        for xb, yb in dataloader:#一个xb是一个minibatch
            #print(xb.shape)
            #print(yb.shaoe)
            #print("going to transfer data to cpu")
            xb, yb = xb.to(device), yb.to(device)
            #print("transfer done")
            #with torch.set_grad_enabled(True):
            pred = model(xb)
            loss = F.cross_entropy(pred, yb)
            top1 = torch.argmax(pred, dim=1)
            ncorrect = torch.sum(top1 == yb)
            #print("current epoch forward done")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step() 
            #print("current epoch backward done")
                
            epoch_loss += loss.item()
            epoch_acc += ncorrect.item()
            print("loss and accuracy at minibatch {}!".format(i))
            print(epoch_loss)
            print(epoch_acc)
            i += 1
                
            
    else :
        model.eval()
        for xb, yb in dataloader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss = F.cross_entropy(pred, yb)
            top1 = torch.argmax(pred, dim=1)
            ncorrect = torch.sum(top1 == yb)
            
            epoch_loss += loss.item()
            epoch_acc += ncorrect.item() 
     
    
    epoch_loss /= len(dataloader.dataset)
    epoch_acc /= len(dataloader.dataset)
    
    return epoch_loss, epoch_acc     
     

Next we implement a method for fitting (training) our model. For many models early stopping can save a lot of training time. Your task is to add early stopping to the loop (based on validation accuracy). Early stopping usually means exiting the training loop if the validation accuracy hasn't improved for `patience` number of steps. Don't forget to save the best model parameters according to validation accuracy. You will need `copy.deepcopy` and the `state_dict` for this.

In [78]:
def fit(model, optimizer, lr_scheduler, dataloaders, max_epochs, patience):
    """
    Fit the given model on the dataset.
    
    Args:
        model: The model used for prediction
        optimizer: Optimization algorithm for the model
        lr_scheduler: Learning rate scheduler that improves training
                      in late epochs with learning rate decay
        dataloaders: Dataloaders for training and validation
        max_epochs: Maximum number of epochs for training
        patience: Number of epochs to wait with early stopping the
                  training if validation loss has decreased
                  
    Returns:
        Loss and accuracy in this epoch.
    """
    
    best_acc = 0
    curr_patience = 0
    #print("fit running")
    
    for epoch in range(max_epochs):
        #print("going to run epoch")
        train_loss, train_acc = run_epoch(model, optimizer, dataloaders['train'], train=True)
        lr_scheduler.step()
        print(f"Epoch {epoch + 1: >3}/{max_epochs}, train loss: {train_loss:.2e}, accuracy: {train_acc * 100:.2f}%")
        
        val_loss, val_acc = run_epoch(model, None, dataloaders['val'], train=False)
        print(f"Epoch {epoch + 1: >3}/{max_epochs}, val loss: {val_loss:.2e}, accuracy: {val_acc * 100:.2f}%")
        
        # TODO: Add early stopping and save the best weights (in best_model_weights)
        
        if best_acc == 0 or val_acc > best_acc :
            val_best_acc = val_acc
            curr_patience = 0
            best_model_weights = copy.deepcopy(model.state_dict)
        else :
            current_patience += 1
            if current_patience >= patience :
                print("Stopping early at epoch {}!".format(epoch+1))
                break
    
    model.load_state_dict(best_model_weights)

In most cases you should just use the Adam optimizer for training, because it works well out of the box. However, a well-tuned SGD (with momentum) will in most cases outperform Adam. And since the original paper gives us a well-tuned SGD we will just use that.

In [79]:
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

# Fit model
print("start")
fit(resnet, optimizer, lr_scheduler, dataloaders, max_epochs=200, patience=50)

start
loss and accuracy at minibatch 0!
3.866533041000366
12.0
loss and accuracy at minibatch 1!
6.891152381896973
28.0
loss and accuracy at minibatch 2!
10.597957372665405
51.0
loss and accuracy at minibatch 3!
14.207668781280518
63.0
loss and accuracy at minibatch 4!
18.11723303794861
79.0
loss and accuracy at minibatch 5!
21.551703453063965
104.0
loss and accuracy at minibatch 6!
25.885797023773193
123.0
loss and accuracy at minibatch 7!
29.56934380531311
137.0
loss and accuracy at minibatch 8!
32.759204149246216
153.0
loss and accuracy at minibatch 9!
35.691073417663574
177.0
loss and accuracy at minibatch 10!
39.27406287193298
206.0
loss and accuracy at minibatch 11!
42.279764890670776
227.0
loss and accuracy at minibatch 12!
45.0328094959259
253.0
loss and accuracy at minibatch 13!
47.97999691963196
278.0
loss and accuracy at minibatch 14!
50.452752351760864
304.0
loss and accuracy at minibatch 15!
53.3086142539978
333.0
loss and accuracy at minibatch 16!
55.899131536483765
360.0

loss and accuracy at minibatch 136!
299.342728972435
4143.0
loss and accuracy at minibatch 137!
301.07266116142273
4183.0
loss and accuracy at minibatch 138!
302.90768480300903
4222.0
loss and accuracy at minibatch 139!
304.74943137168884
4261.0
loss and accuracy at minibatch 140!
306.5044445991516
4298.0
loss and accuracy at minibatch 141!
308.1657314300537
4343.0
loss and accuracy at minibatch 142!
310.0850336551666
4375.0
loss and accuracy at minibatch 143!
311.8767387866974
4417.0
loss and accuracy at minibatch 144!
313.5237271785736
4461.0
loss and accuracy at minibatch 145!
315.2228432893753
4502.0
loss and accuracy at minibatch 146!
316.95178496837616
4544.0
loss and accuracy at minibatch 147!
318.6015702486038
4587.0
loss and accuracy at minibatch 148!
320.35326731204987
4633.0
loss and accuracy at minibatch 149!
322.3445975780487
4669.0
loss and accuracy at minibatch 150!
323.97682082653046
4714.0
loss and accuracy at minibatch 151!
325.70947432518005
4751.0
loss and accuracy 

loss and accuracy at minibatch 270!
528.2391542196274
10276.0
loss and accuracy at minibatch 271!
529.7082674503326
10334.0
loss and accuracy at minibatch 272!
531.1488308906555
10396.0
loss and accuracy at minibatch 273!
532.9072924852371
10440.0
loss and accuracy at minibatch 274!
534.4933171272278
10484.0
loss and accuracy at minibatch 275!
535.9869232177734
10540.0
loss and accuracy at minibatch 276!
537.6736345291138
10586.0
loss and accuracy at minibatch 277!
539.2855883836746
10641.0
loss and accuracy at minibatch 278!
540.7895117998123
10695.0
loss and accuracy at minibatch 279!
542.3762695789337
10745.0
loss and accuracy at minibatch 280!
543.8896112442017
10799.0
loss and accuracy at minibatch 281!
545.3737144470215
10849.0
loss and accuracy at minibatch 282!
546.9414956569672
10907.0
loss and accuracy at minibatch 283!
548.5580812692642
10959.0
loss and accuracy at minibatch 284!
549.999298453331
11021.0
loss and accuracy at minibatch 285!
551.587813615799
11074.0
loss and a

loss and accuracy at minibatch 51!
74.53070759773254
3110.0
loss and accuracy at minibatch 52!
75.91436445713043
3171.0
loss and accuracy at minibatch 53!
77.31065583229065
3234.0
loss and accuracy at minibatch 54!
78.72625815868378
3295.0
loss and accuracy at minibatch 55!
80.33274972438812
3356.0
loss and accuracy at minibatch 56!
81.8355530500412
3411.0
loss and accuracy at minibatch 57!
83.24436700344086
3474.0
loss and accuracy at minibatch 58!
84.70971131324768
3539.0
loss and accuracy at minibatch 59!
86.0656772851944
3608.0
loss and accuracy at minibatch 60!
87.46995210647583
3676.0
loss and accuracy at minibatch 61!
88.75389468669891
3745.0
loss and accuracy at minibatch 62!
90.14055454730988
3803.0
loss and accuracy at minibatch 63!
91.53064000606537
3864.0
loss and accuracy at minibatch 64!
92.89462554454803
3927.0
loss and accuracy at minibatch 65!
94.17910873889923
3996.0
loss and accuracy at minibatch 66!
95.47874331474304
4068.0
loss and accuracy at minibatch 67!
97.0045

loss and accuracy at minibatch 185!
254.3712249994278
11969.0
loss and accuracy at minibatch 186!
255.6444352865219
12036.0
loss and accuracy at minibatch 187!
256.895902633667
12107.0
loss and accuracy at minibatch 188!
258.2631764411926
12170.0
loss and accuracy at minibatch 189!
259.51147079467773
12244.0
loss and accuracy at minibatch 190!
260.70962953567505
12314.0
loss and accuracy at minibatch 191!
261.89272463321686
12392.0
loss and accuracy at minibatch 192!
263.18753921985626
12464.0
loss and accuracy at minibatch 193!
264.4552649259567
12536.0
loss and accuracy at minibatch 194!
265.7135854959488
12599.0
loss and accuracy at minibatch 195!
266.8944387435913
12674.0
loss and accuracy at minibatch 196!
268.33355939388275
12735.0
loss and accuracy at minibatch 197!
269.59553480148315
12807.0
loss and accuracy at minibatch 198!
270.7243592739105
12891.0
loss and accuracy at minibatch 199!
271.9173879623413
12968.0
loss and accuracy at minibatch 200!
273.1167846918106
13046.0
los

loss and accuracy at minibatch 317!
414.6287052631378
21544.0
loss and accuracy at minibatch 318!
415.93772530555725
21612.0
loss and accuracy at minibatch 319!
417.039888381958
21688.0
loss and accuracy at minibatch 320!
418.3107134103775
21759.0
loss and accuracy at minibatch 321!
419.37689876556396
21839.0
loss and accuracy at minibatch 322!
420.4750579595566
21918.0
loss and accuracy at minibatch 323!
421.4284802079201
21998.0
loss and accuracy at minibatch 324!
422.5476250052452
22083.0
loss and accuracy at minibatch 325!
423.87533539533615
22150.0
loss and accuracy at minibatch 326!
425.01105231046677
22219.0
loss and accuracy at minibatch 327!
426.16000813245773
22291.0
loss and accuracy at minibatch 328!
427.10552990436554
22376.0
loss and accuracy at minibatch 329!
428.33682239055634
22450.0
loss and accuracy at minibatch 330!
429.5208646059036
22531.0
loss and accuracy at minibatch 331!
430.5800074338913
22609.0
loss and accuracy at minibatch 332!
431.6504135131836
22684.0
lo

loss and accuracy at minibatch 99!
107.45623028278351
7884.0
loss and accuracy at minibatch 100!
108.4834463596344
7963.0
loss and accuracy at minibatch 101!
109.37286394834518
8059.0
loss and accuracy at minibatch 102!
110.54686731100082
8134.0


KeyboardInterrupt: 

Once the model is trained we run it on the test set to obtain our final accuracy.
Note that we can only look at the test set once, everything else would lead to overfitting. So you _must_ ignore the test set while developing your model!

In [19]:
test_loss, test_acc = run_epoch(resnet, None, dataloaders['test'], train=False)
print(f"Test loss: {test_loss:.1e}, accuracy: {test_acc * 100:.2f}%")

Test loss: 2.8e-03, accuracy: 92.07%


That's almost what was reported in the paper (92.49%) and we didn't even train on the full training set.

# Optional task: Squeeze out all the juice!

Can you do even better? Have a look at [A Recipe for Training Neural Networks](https://karpathy.github.io/2019/04/25/recipe/) and some state-of-the-art architectures such as [EfficientNet architecture](https://ai.googleblog.com/2019/05/efficientnet-improving-accuracy-and.html). Play around with the possibilities PyTorch offers you and see how close you can get to the [state of the art on CIFAR-10](https://paperswithcode.com/sota/image-classification-on-cifar-10).