In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
def vgg_block(num_convs, input_channels, output_channels):
    """Assemble a single VGG block as an ``nn.Sequential``.

    The block is ``num_convs`` repetitions of (3x3 conv, padding 1 -> ReLU),
    followed by one 2x2 max-pool with stride 2. The padded 3x3 convolutions
    leave height/width unchanged; the pool halves them (rounding down).

    Args:
        num_convs: how many conv+ReLU pairs the block contains (may be 0).
        input_channels: channel count fed into the first convolution.
        output_channels: channel count produced by every convolution.

    Returns:
        nn.Sequential containing the layers in the order described above.
    """
    modules = []
    channels_in = input_channels
    for _ in range(num_convs):
        modules += [
            nn.Conv2d(channels_in, output_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # After the first conv, every later conv maps output->output channels.
        channels_in = output_channels
    modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*modules)

In [3]:
# One (num_convs, output_channels) pair per VGG block. 1+1+2+2+2 = 8 conv
# layers in total; together with the 3 fully connected layers added by `vgg`
# that gives 11 weight layers, i.e. the VGG-11 configuration.
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
def vgg(conv_arch):
    conv_blks = []
    input_channel = 1
    for (num_conv, output_channel) in conv_arch:
        conv_blks.append(
            vgg_block(num_conv, input_channel, output_channel)
        )
        input_channel = output_channel
    vgg_net = nn.Sequential(
        *conv_blks,
        nn.Flatten(),
        nn.Linear(output_channel * 7 * 7, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096, 10)
    )
    return vgg_net
# Instantiate the full-width network described by `conv_arch`.
net = vgg(conv_arch=conv_arch)

In [4]:
# Push one random 1x1x224x224 tensor through the network stage by stage and
# print each stage's output shape. This is only a shape check, so autograd is
# disabled: otherwise the final `x` would keep the whole graph (and every saved
# activation) alive in the kernel's globals.
x = torch.randn(size=(1, 1, 224, 224))
with torch.no_grad():
    for blk in net:
        x = blk(x)
        print(blk.__class__.__name__, 'Output shape: \t', x.shape)

Sequential Output shape: 	 torch.Size([1, 64, 112, 112])
Sequential Output shape: 	 torch.Size([1, 128, 56, 56])
Sequential Output shape: 	 torch.Size([1, 256, 28, 28])
Sequential Output shape: 	 torch.Size([1, 512, 14, 14])
Sequential Output shape: 	 torch.Size([1, 512, 7, 7])
Flatten Output shape: 	 torch.Size([1, 25088])
Linear Output shape: 	 torch.Size([1, 4096])
ReLU Output shape: 	 torch.Size([1, 4096])
Dropout Output shape: 	 torch.Size([1, 4096])
Linear Output shape: 	 torch.Size([1, 4096])
ReLU Output shape: 	 torch.Size([1, 4096])
Dropout Output shape: 	 torch.Size([1, 4096])
Linear Output shape: 	 torch.Size([1, 10])


In [5]:
# Divide every block's channel count by `ratio` to shrink the model's size
# and compute; the number of conv layers per block stays the same.
# Note: this rebinds `net`, replacing the full-width model built earlier.
ratio = 4
small_conv_arch = [
    (num_convs, out_channels // ratio)
    for num_convs, out_channels in conv_arch
]
net = vgg(small_conv_arch)

In [6]:
# Shape check for the reduced-width network: feed one random 1x1x224x224
# tensor through each stage and print the resulting shape. Run without
# autograd so the leftover `x` does not retain a graph of saved activations.
x = torch.randn(size=(1, 1, 224, 224))
with torch.no_grad():
    for blk in net:
        x = blk(x)
        print(blk.__class__.__name__, 'Output shape: \t', x.shape)

Sequential Output shape: 	 torch.Size([1, 16, 112, 112])
Sequential Output shape: 	 torch.Size([1, 32, 56, 56])
Sequential Output shape: 	 torch.Size([1, 64, 28, 28])
Sequential Output shape: 	 torch.Size([1, 128, 14, 14])
Sequential Output shape: 	 torch.Size([1, 128, 7, 7])
Flatten Output shape: 	 torch.Size([1, 6272])
Linear Output shape: 	 torch.Size([1, 4096])
ReLU Output shape: 	 torch.Size([1, 4096])
Dropout Output shape: 	 torch.Size([1, 4096])
Linear Output shape: 	 torch.Size([1, 4096])
ReLU Output shape: 	 torch.Size([1, 4096])
Dropout Output shape: 	 torch.Size([1, 4096])
Linear Output shape: 	 torch.Size([1, 10])


In [7]:
lr, num_epochs, batch_size = 0.05, 10, 128
# Loading Fashion-MNIST resized to 224x224 and training for `num_epochs` is
# expensive, so it is opt-in instead of left as commented-out code.
# Set to True to download the data and train `net`.
RUN_TRAINING = False
if RUN_TRAINING:
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
    d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:52<00:00, 499667.99it/s] 


Extracting ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 69146.79it/s]


Extracting ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:09<00:00, 458530.61it/s]


Extracting ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 2551675.37it/s]

Extracting ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw




