In [7]:
import torch 
from torch import nn
from d2l import torch as d2l
# 参数分别对应：卷积层数量  输入通道数量 输出通道数量
def vgg_block(num_convs,in_channels,out_channels):
    layers = []
    for _ in range(num_convs):
        # 添加一个卷积层
        layers.append(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1))
        
        # 添加激活函数
        layers.append(nn.ReLU())
        # 因为添加多个卷积层 上一层的输出通道 对应下一层的数通道
        in_channels = out_channels

    # 添加一个最大池化层
    layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
    return nn.Sequential(*layers)

In [8]:
conv_arch = ((1,64),(1,128),(2,256),(2,512),(2,512))
print(type(conv_arch))

<class 'tuple'>


In [11]:
def vgg(conv_arch):
    conv_blks = []
    in_channels = 1  # 初始化输入通道数 在一个卷积层添加之后 需要更新输入通道数

    # 卷积层部分
    for(num_convs,out_channels) in conv_arch:
        # 遍历元组
        # 初始化VGG 块  然后添加进去
        conv_blks.append(vgg_block(num_convs=num_convs,in_channels=in_channels,out_channels=out_channels))


        # 更新输入通道数
        in_channels = out_channels

    # 全连接层  经过一系列卷积层之后 需要将4D向量 转换为2D向量
    return nn.Sequential(
        *conv_blks,nn.Flatten(),
        # 全练级曾部分
        nn.Linear(out_channels * 7 * 7,4096),nn.ReLU(),nn.Dropout(0.5),
        nn.Linear(4096,4096),nn.ReLU(),nn.Dropout(0.5),
        nn.Linear(4096,10)
    )



In [14]:
net = vgg(conv_arch=conv_arch)

X = torch.randn(size=(1,1,224,224))

for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)

Sequential output shape:	 torch.Size([1, 64, 112, 112])
Sequential output shape:	 torch.Size([1, 128, 56, 56])
Sequential output shape:	 torch.Size([1, 256, 28, 28])
Sequential output shape:	 torch.Size([1, 512, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
Flatten output shape:	 torch.Size([1, 25088])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])


In [15]:
ratio = 4
small_conv_arch = [(pair[0],pair[1] // ratio) for pair in conv_arch]
net = vgg(small_conv_arch)

In [18]:
lr , num_epochs,batch_size = 0.05,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size,resize=224)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [02:06<00:00, 208574.80it/s] 


Extracting ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:04<00:00, 6079.70it/s]


Extracting ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:02<00:00, 1587149.09it/s]


Extracting ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5016792.98it/s]


Extracting ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw



In [20]:
d2l.train_ch6(net,train_iter,test_iter,num_epochs=num_epochs,lr=lr,d2l.try_gpu())

SyntaxError: positional argument follows keyword argument (1518756369.py, line 1)