In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import Resize

from VGG import VGG

In [2]:
# VGG backbone spec: one (num_convs, out_channels) pair per VGG block.
# 8 conv layers in total (1+1+2+2+2), i.e. the conv part of the VGG-11 layout.
_num_convs_per_block = (1, 1, 2, 2, 2)
_out_channels_per_block = (64, 128, 256, 512, 512)
conv_arch = tuple(zip(_num_convs_per_block, _out_channels_per_block))

In [3]:
# Shrink every block's channel count by `ratio` so the model trains quickly on MNIST;
# the number of conv layers per block is unchanged.
ratio = 4
small_conv_arch = [
    (num_convs, out_channels // ratio)
    for num_convs, out_channels in conv_arch
]
small_conv_arch

[(1, 16), (1, 32), (2, 64), (2, 128), (2, 128)]

In [4]:
# Optional sanity check (disabled): push a dummy 1x1x224x224 input through each
# top-level block of `net` and print every block's output shape.
# NOTE(review): `net` is only created in a later cell (In [8]); if this is
# uncommented it must be run after that cell, otherwise it raises NameError.
# X=torch.randn(size=(1,1,224,224))
# for blk in net:
#     X=blk(X)
#     print(blk.__class__.__name__,'output shape:\t',X.shape)

In [5]:
# Upsample 28x28 MNIST digits to the 224x224 input size VGG expects, then convert to tensors.
transform = transforms.Compose([Resize([224, 224]), ToTensor()])


def _load_mnist(train):
    """Return the MNIST train (``train=True``) or test split from ../data; never downloads."""
    return datasets.MNIST(
        root="../data",
        train=train,
        transform=transform,
        download=False,
    )


mnist_training = _load_mnist(train=True)
mnist_test = _load_mnist(train=False)

In [6]:
# Release cached GPU memory left over from earlier runs, then pick the compute device.
torch.cuda.empty_cache()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Seed the global torch RNG so shuffle order and all later weight initialisation
# (VGG construction and init_weights) are reproducible across Restart & Run All.
# NOTE: some CUDA/cuDNN kernels can still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

# Training hyperparameters.
BATCH_SIZE = 128
lr = 0.01
epochs = 20

# Shuffle only the training split; keep test order fixed for stable evaluation.
train_dataloader = DataLoader(mnist_training, batch_size=BATCH_SIZE, shuffle=True)

test_dataloader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Build the channel-reduced VGG from the project-local `VGG` module using the
# scaled-down block spec; the bare `net` line displays the layer stack below.
net=VGG(small_conv_arch)
net

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (4): Seque

In [9]:
net = net.to(device)  # nn.Module.to moves parameters in place and returns the same module

def init_weights(m):
    """Xavier-uniform initialise the weight of a Linear or Conv2d layer in place.

    Intended to be passed to ``net.apply``. Any other module type is left untouched,
    and biases keep PyTorch's default initialisation.

    Args:
        m: the ``nn.Module`` currently visited by ``Module.apply``.
    """
    # isinstance (rather than exact type equality) also covers subclasses of these layers.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)  # re-initialise every Linear/Conv2d layer before training

# SGD with momentum over all model parameters; cross-entropy on raw logits.
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

def _train_one_epoch(model, dataloader, loss_fn, optimizer, device, log_every=10):
    """Run one optimisation pass over ``dataloader``, printing the loss every ``log_every`` batches."""
    # Training mode: enables dropout / batch-norm updates if the model contains such layers.
    model.train()
    n_samples = len(dataloader.dataset)
    seen = 0  # samples processed before the current batch
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        if batch % log_every == 0:
            # n_samples is derived from the dataset instead of a hardcoded 60000.
            print(f"loss:{loss.item():>7f} [{seen:>5d}/{n_samples:>6d}]")
        seen += len(inputs)


def _evaluate(model, dataloader, device):
    """Return top-1 accuracy of ``model`` on ``dataloader`` as a Python float percentage."""
    # Eval mode: disables dropout / uses running batch-norm stats if present. The VGG
    # repr above is truncated, so whether such layers exist is not visible here.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            # .item() keeps the counters as Python ints instead of device tensors.
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total if total else float("nan")


for epoch in range(epochs):
    print(
        f"epoch {epoch} \n---------------------"
    )
    _train_one_epoch(net, train_dataloader, loss_fn, optimizer, device)
    test_acc = _evaluate(net, test_dataloader, device)
    print(f"test: acc {test_acc:.2f}")

epoch 0 
---------------------
loss:2.301156 [    0/ 60000]
loss:2.305832 [ 1280/ 60000]
loss:2.296857 [ 2560/ 60000]
loss:2.292382 [ 3840/ 60000]
loss:2.287944 [ 5120/ 60000]
loss:2.267069 [ 6400/ 60000]
loss:2.118567 [ 7680/ 60000]
loss:1.240553 [ 8960/ 60000]
loss:0.611455 [10240/ 60000]
loss:0.658462 [11520/ 60000]
loss:0.460682 [12800/ 60000]
loss:0.242844 [14080/ 60000]
loss:0.177017 [15360/ 60000]
loss:0.355205 [16640/ 60000]
loss:0.213276 [17920/ 60000]
loss:0.168332 [19200/ 60000]
loss:0.160125 [20480/ 60000]
loss:0.121496 [21760/ 60000]
loss:0.278805 [23040/ 60000]
loss:0.179440 [24320/ 60000]
loss:0.137964 [25600/ 60000]
loss:0.183054 [26880/ 60000]
loss:0.099813 [28160/ 60000]
loss:0.080487 [29440/ 60000]
loss:0.098885 [30720/ 60000]
loss:0.245040 [32000/ 60000]
loss:0.109399 [33280/ 60000]
loss:0.112814 [34560/ 60000]
loss:0.053773 [35840/ 60000]
loss:0.126897 [37120/ 60000]
loss:0.098267 [38400/ 60000]
loss:0.057795 [39680/ 60000]
loss:0.085521 [40960/ 60000]
loss:0.03729

KeyboardInterrupt: 