In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import Resize

from AlexNet import AlexNet

In [2]:
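# Commented-out reference layer layout; the notebook actually trains the `AlexNet` imported from AlexNet.py.
# net=nn.Sequential(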
#     nn.Conv2d(1,96,kernel_size=11,stride=4,padding=1),
#     nn.ReLU(),
#     nn.MaxPool2d(kernel_size=3,stride=2),
#
#     nn.Conv2d(96, 256, kernel_size=5,  padding=2),
#     nn.ReLU(),
#     nn.MaxPool2d(kernel_size=3, stride=2),
#
#     nn.Conv2d(256, 384, kernel_size=3, padding=1),
#     nn.ReLU(),
#
#     nn.Conv2d(384, 384, kernel_size=3, padding=1),
#     nn.ReLU(),
#
#     nn.Conv2d(384, 256, kernel_size=3, padding=1),
#     nn.ReLU(),
#     nn.MaxPool2d(kernel_size=3,stride=2),
#
#     nn.Flatten(),
#
#     nn.Linear(256*5*5,4096),
#     nn.ReLU(),
#     nn.Dropout(p=0.5),
#
#     nn.Linear(4096,4096),
#     nn.ReLU(),
#     nn.Dropout(p=0.5),
#
#     nn.Linear(4096,10)
# )

In [3]:
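# Shape check for the commented-out `net` above (prints each layer's output shape); uncomment together with it.
# X=torch.randn(1,1,224,224)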
# for layer in net:
#     X=layer(X)
#     print(layer.__class__.__name__,"output shape:\t",X.shape)

In [4]:
# Upsample the 28x28 MNIST digits to 224x224 (the input size the reference layout
# above was shape-checked with), then convert the PIL images to float tensors in [0, 1].
transform = transforms.Compose([
    Resize([224, 224]),
    ToTensor(),
])

DATA_ROOT = "../data"

# download=True fetches MNIST into DATA_ROOT on the first run and is a no-op when the
# files are already present, so the notebook survives Restart & Run All on a fresh machine.
mnist_training = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    transform=transform,
    download=True,
)

mnist_test = datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    transform=transform,
    download=True,
)

In [5]:
# Show both dataset summaries; a blank line between them keeps the two reprs from running together.
print(mnist_test, mnist_training, sep="\n\n")

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           ) Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )


In [6]:
# Release unoccupied cached CUDA memory held by the caching allocator
# (e.g. left over from an earlier run in this kernel).
torch.cuda.empty_cache()

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# --- Training hyper-parameters ---
BATCH_SIZE = 256  # samples per mini-batch (train and test)
lr = 0.01         # SGD learning rate
epochs = 20       # full passes over the training set

# Reshuffle training batches every epoch; keep the test order fixed so evaluation is stable.
train_dataloader = DataLoader(dataset=mnist_training, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=mnist_test, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
def init_weights(m):
    """Xavier-uniform initialise the weights of every Linear and Conv2d layer.

    Intended for ``nn.Module.apply``; all other module types are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)


def train_one_epoch(model, dataloader, loss_fn, optimizer, device, log_every=10):
    """Run one optimisation pass over ``dataloader``, printing the batch loss every ``log_every`` batches.

    Puts ``model`` in training mode first, because ``evaluate`` leaves it in eval mode.
    """
    model.train()
    n_samples = len(dataloader.dataset)  # avoid hardcoding the MNIST train-set size
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if batch % log_every == 0:
            current = batch * len(inputs)
            print(f"loss:{loss.item():>7f} [{current:>5d}/{n_samples:>6d}]")


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

    Switches ``model`` to eval mode so dropout (present in the reference layout at the top
    of the notebook, and presumably in AlexNet.py too) is disabled while measuring.
    Returns NaN for an empty dataloader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()  # accumulate as a Python int, not a device tensor
        total += labels.size(0)
    return 100.0 * correct / total if total else float("nan")


# Seed the global CPU/CUDA generators, which by default drive the weight init, dropout masks
# and DataLoader shuffling. cuDNN kernels may still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)

net = AlexNet().to(device)
net.apply(init_weights)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)

for epoch in range(epochs):
    print(f"epoch {epoch} \n---------------------")
    train_one_epoch(net, train_dataloader, loss_fn, optimizer, device)
    print(f"test: acc {evaluate(net, test_dataloader, device):.2f}")


epoch 0 
---------------------
loss:2.309132 [    0/ 60000]
loss:2.296461 [ 2560/ 60000]
loss:2.275908 [ 5120/ 60000]
loss:2.206079 [ 7680/ 60000]
loss:1.562011 [10240/ 60000]
loss:1.082462 [12800/ 60000]
loss:0.446683 [15360/ 60000]
loss:0.363827 [17920/ 60000]
loss:0.233036 [20480/ 60000]
loss:0.254791 [23040/ 60000]
loss:0.201992 [25600/ 60000]
loss:0.255451 [28160/ 60000]
loss:0.180451 [30720/ 60000]
loss:0.102079 [33280/ 60000]
loss:0.107142 [35840/ 60000]
loss:0.125804 [38400/ 60000]
loss:0.104566 [40960/ 60000]
loss:0.123582 [43520/ 60000]
loss:0.084459 [46080/ 60000]
loss:0.113914 [48640/ 60000]
loss:0.151553 [51200/ 60000]
loss:0.101548 [53760/ 60000]
loss:0.102417 [56320/ 60000]
loss:0.086216 [58880/ 60000]
test: acc 97.63999938964844
epoch 1 
---------------------
loss:0.067266 [    0/ 60000]
loss:0.113348 [ 2560/ 60000]
loss:0.123024 [ 5120/ 60000]
loss:0.104231 [ 7680/ 60000]
loss:0.094467 [10240/ 60000]
loss:0.055352 [12800/ 60000]
loss:0.101548 [15360/ 60000]
loss:0.0518


KeyboardInterrupt

