In [11]:
import torch 
from torch import nn
from d2l import torch as d2l

In [15]:
# VGG block
def vgg_block(num_convs, input_channels, out_channels):
    """Build one VGG block: ``num_convs`` conv+ReLU pairs followed by a max-pool.

    Args:
        num_convs: Number of 3x3 convolution layers in the block.
        input_channels: Channel count fed to the first convolution.
        out_channels: Channel count produced by every convolution in the block.

    Returns:
        An ``nn.Sequential`` holding the conv/ReLU pairs and a final
        ``nn.MaxPool2d(kernel_size=2, stride=2)``.
    """
    # Channel width seen at each conv boundary: only the first conv changes
    # the width, every later conv maps out_channels -> out_channels.
    widths = [input_channels] + [out_channels] * num_convs

    stages = []
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        # kernel_size=3 with padding=1 leaves height and width unchanged.
        stages += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]

    # The 2x2 pool with stride 2 closes the block.
    stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*stages)


In [17]:
# VGG network configuration: one (num_convs, out_channels) pair per block.
conv_arch = (
    (1, 64),
    (1, 128),
    (2, 256),
    (2, 512),
    (2, 512),
)


def vgg(conv_arch, in_channels=1, num_classes=10, feature_size=7):
    """Assemble a VGG-style network: conv blocks, then a 3-layer MLP classifier.

    Args:
        conv_arch: Iterable of ``(num_convs, out_channels)`` pairs, one per
            VGG block. May be empty, in which case the classifier is applied
            directly to the flattened input.
        in_channels: Channel count of the input images (default 1).
        num_classes: Width of the final linear layer (default 10).
        feature_size: Side length of the square feature map that reaches the
            classifier. Each block built by ``vgg_block`` halves height and
            width, so this is ``input_side // 2 ** len(conv_arch)``; the
            default 7 corresponds to 224x224 inputs and five blocks.

    Returns:
        An ``nn.Sequential`` whose top-level children are the conv blocks,
        ``nn.Flatten`` and the fully connected layers.
    """
    conv_blocks = []
    for num_convs, out_channels in conv_arch:
        conv_blocks.append(vgg_block(num_convs, in_channels, out_channels))
        in_channels = out_channels

    # `in_channels` now holds the channel count leaving the last block (or the
    # raw input channels when conv_arch is empty), so the classifier width is
    # always defined -- reading the loop variable here would fail on an empty
    # architecture.
    flat_features = in_channels * feature_size * feature_size

    return nn.Sequential(*conv_blocks, nn.Flatten(),
                         nn.Linear(flat_features, 4096), nn.ReLU(), nn.Dropout(0.5),
                         nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
                         nn.Linear(4096, num_classes))

# Instantiate the network described by conv_arch.
net = vgg(conv_arch)

In [24]:
# Feed one random single-channel 224x224 input through the network, one
# top-level module at a time, printing the output shape after each module.
X = torch.rand(1, 1, 224, 224)
for layer in net:
    X = layer(X)
    print(type(layer).__name__, 'output shape:    \t', X.shape)

Sequential output shape:    	 torch.Size([1, 64, 112, 112])
Sequential output shape:    	 torch.Size([1, 128, 56, 56])
Sequential output shape:    	 torch.Size([1, 256, 28, 28])
Sequential output shape:    	 torch.Size([1, 512, 14, 14])
Sequential output shape:    	 torch.Size([1, 512, 7, 7])
Flatten output shape:    	 torch.Size([1, 25088])
Linear output shape:    	 torch.Size([1, 4096])
ReLU output shape:    	 torch.Size([1, 4096])
Dropout output shape:    	 torch.Size([1, 4096])
Linear output shape:    	 torch.Size([1, 4096])
ReLU output shape:    	 torch.Size([1, 4096])
Dropout output shape:    	 torch.Size([1, 4096])
Linear output shape:    	 torch.Size([1, 10])


In [25]:
# Divide every block's output-channel count by `ratio` to get a narrower
# network, then rebuild `net` from that reduced configuration.
ratio = 4
small_conv_arch = [(num_convs, out_channels // ratio)
                   for num_convs, out_channels in conv_arch]
net = vgg(small_conv_arch)

In [None]:
# Training hyperparameters.
lr = 0.05
num_epochs = 10
batch_sizes = 128  # a single mini-batch size, despite the plural name

# resize=224 is forwarded to the loader; presumably it rescales the images to
# 224x224, the input size used in the shape check earlier -- confirm against d2l.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_sizes, resize=224)

d2l.train_ch6(
    net,
    train_iter,
    test_iter,
    num_epochs,
    lr,
    d2l.try_gpu(),
)

![20250114142958](https://raw.githubusercontent.com/Rainbow452/image/main/img/20250114142958.png)