In [26]:
import torch.nn as nn
import torch

In [34]:
def vgg_block(num_convs, in_channels, out_channels):
    """Build a single VGG block.

    The block is ``num_convs`` repetitions of a 3x3 convolution (padding 1,
    so height/width are preserved) followed by ReLU, and then one 2x2
    max-pool with stride 2 that halves the spatial resolution.

    Args:
        num_convs: Number of conv + ReLU pairs in the block.
        in_channels: Channels of the tensor entering the block.
        out_channels: Channels produced by every convolution in the block.

    Returns:
        nn.Sequential: the assembled block.
    """
    # Only the first convolution sees `in_channels`; each later one maps
    # out_channels -> out_channels.
    widths = [in_channels] + [out_channels] * num_convs
    stages = []
    for src, dst in zip(widths[:-1], widths[1:]):
        stages += [nn.Conv2d(src, dst, kernel_size=3, padding=1), nn.ReLU()]
    stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*stages)
    

In [36]:
def vgg(conv_arch, input_size, in_channels=1, num_classes=10):
    """Build a VGG-style classifier from a sequence of conv-block specs.

    Args:
        conv_arch: Iterable of ``(num_convs, out_channels)`` pairs, one per
            VGG block, applied in the given order.
        input_size: Height and width (in pixels) of the square input images.
            It is used to size the first fully connected layer.
        in_channels: Number of channels of the input images. Defaults to 1
            (single-channel images), the value that was previously hard-coded.
        num_classes: Number of output logits. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        nn.Sequential: the conv blocks, ``nn.Flatten``, and three linear
        layers; the first two are each followed by ReLU and Dropout(0.5).
    """
    conv_blks = []
    block_in = in_channels
    for num_convs, out_channels in conv_arch:
        conv_blks.append(vgg_block(num_convs, block_in, out_channels))
        block_in = out_channels

    # Push one dummy image through the conv stack to find out how many
    # features reach the first linear layer. Only the shape matters, so use
    # zeros (leaves the global RNG untouched) and skip autograd bookkeeping.
    with torch.no_grad():
        probe = torch.zeros(1, in_channels, input_size, input_size)
        flatten_len = nn.Sequential(*conv_blks)(probe).numel()

    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # Fully connected classifier head.
        nn.Linear(flatten_len, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, num_classes))
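# vgg() returns an nn.Sequential instance (not a class). nn.Sequential is a subclass of nn.Module and inherits its methods, so the returned object can be called directly on an input tensor to run the forward pass.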


In [38]:
# One (num_convs, out_channels) pair per VGG block: 8 conv layers in total,
# which together with the 3 linear layers added by vgg() matches the VGG-11 layout.
conv_arch = (
    (1, 64),
    (1, 128),
    (2, 256),
    (2, 512),
    (2, 512),
)
# Instantiate the network for 224x224 inputs.
net = vgg(conv_arch, input_size=224)

## Model training

In [41]:
# Load the data
import torch
import torchvision
from torch.utils import data
from torchvision import transforms

# FashionMNIST images are 28x28, but the network above is built with
# input_size=224: its first linear layer is sized for 224x224 inputs, and a
# 28x28 image would shrink to nothing across the five 2x2 max-pools.
# Upscale each image to 224x224 before converting it to a tensor.
trans = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
mnist_train = torchvision.datasets.FashionMNIST(
    root="../torchvision_data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../torchvision_data", train=False, transform=trans, download=True)

In [42]:
from torch.utils.data import DataLoader

# Mini-batch iterators: the training set is reshuffled every epoch, the test
# set is read in a fixed order.
train_dataloader = DataLoader(
    mnist_train,
    batch_size=256,
    shuffle=True,
)
test_dataloader = DataLoader(
    mnist_test,
    batch_size=256,
    shuffle=False,
)
# Number of test batches.
len(test_dataloader)

40