In [None]:
import torch
import torch.nn as nn 
import sys
sys.path.append('..')
import myd2l

In [2]:
def conv_block(in_channels, n_channels):
    """Build one pre-activation convolution unit: BatchNorm -> ReLU -> 3x3 Conv.

    Args:
        in_channels: Number of channels of the incoming feature map.
        n_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order. The
        convolution uses ``kernel_size=3`` with ``padding=1`` (default stride),
        so height and width are left unchanged.
    """
    layers = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, n_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)

In [3]:
# Dense block

class DenseBlock(nn.Module):
    """A chain of conv blocks whose outputs are concatenated along the channel axis.

    Every conv block receives the block input concatenated with the outputs of
    all conv blocks before it, and contributes ``n_channels`` new channels.
    The final output therefore has ``in_channels + n_convs * n_channels``
    channels.

    Args:
        n_convs: Number of conv blocks in this dense block.
        in_channels: Channel count of the tensor passed to ``forward``.
        n_channels: Channels added by each conv block (the growth rate).
    """

    def __init__(self, n_convs, in_channels, n_channels):
        super().__init__()

        # Conv block number `idx` sees the original input plus the outputs of
        # the `idx` conv blocks that precede it.
        stages = [
            conv_block(in_channels + idx * n_channels, n_channels)
            for idx in range(n_convs)
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, X):
        """Run every conv block, concatenating each output onto the running tensor.

        Args:
            X: Input tensor; concatenation happens on ``dim=1``.

        Returns:
            The input with every conv block's output appended along ``dim=1``.
        """
        features = X
        for stage in self.net:
            # Append the new feature maps instead of replacing the old ones.
            features = torch.cat((features, stage(features)), dim=1)
        return features

In [4]:
# Shape check: 2 conv blocks adding 10 channels each on a 3-channel input
# should give 3 + 2 * 10 = 23 output channels at the same spatial size.
blk = DenseBlock(n_convs=2, in_channels=3, n_channels=10)
X = torch.randn(4, 3, 8, 8)
# Y stays a notebook-level variable: the transition-layer check reuses it.
Y = blk(X)
print(Y.shape)

torch.Size([4, 23, 8, 8])


In [5]:
# Transition layer

def transition_block(in_channels, n_channels):
    """Build a transition layer: BatchNorm -> ReLU -> 1x1 Conv -> 2x2 AvgPool.

    Args:
        in_channels: Number of channels of the incoming feature map.
        n_channels: Number of channels after the 1x1 convolution.

    Returns:
        An ``nn.Sequential`` of the four layers. The 1x1 convolution changes
        the channel count and the stride-2 average pooling halves height and
        width.
    """
    layers = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, n_channels, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)

In [6]:
# Shape check: pass the dense-block output Y (23 channels) through a
# transition layer; expect 10 channels and halved height/width.
blk = transition_block(in_channels=23, n_channels=10)
print(blk(Y).shape)

torch.Size([4, 10, 4, 4])


In [8]:
# Stem: 7x7 stride-2 convolution on a 1-channel input, then BatchNorm, ReLU
# and a 3x3 stride-2 max pool.
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# n_channels tracks the channel count flowing into the next stage; it starts
# at the stem's 64 output channels.
n_channels = 64
growth_rate = 32
n_convs_in_dense_blocks = [4, 4, 4, 4]

blks = []
for i, n_convs in enumerate(n_convs_in_dense_blocks):
    blks.append(DenseBlock(n_convs, n_channels, growth_rate))
    # Each conv block inside the dense block adds `growth_rate` channels.
    n_channels += n_convs * growth_rate

    # Between dense blocks (not after the last one) a transition layer halves
    # the channel count.
    if i < len(n_convs_in_dense_blocks) - 1:
        blks.append(transition_block(n_channels, n_channels // 2))
        n_channels //= 2

# Classification head: normalize, global-average-pool to 1x1, flatten, then
# map to 10 outputs. Splatted flat into `net` so module indices stay linear.
head = [
    nn.BatchNorm2d(n_channels),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(n_channels, 10),
]
net = nn.Sequential(b1, *blks, *head)

In [None]:
# Data: batch size and the side length passed to the loader as `resize`.
# NOTE(review): `myd2l` is a project-local helper; presumably this returns
# train/test iterators over Fashion-MNIST with images resized to
# `resize` x `resize` -- confirm against myd2l.
batch_size, resize = 256, 96
train_iter, test_iter = myd2l.load_data_fashion_mnist(batch_size, resize)

n_epochs, lr = 10, 0.1
# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell still runs on machines without CUDA instead of raising.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
myd2l.train_clf(net, n_epochs, lr, train_iter, test_iter, device)