In [71]:
import torch
import torch.nn as nn
import torchvision
import numpy

In [72]:
# Instantiate torchvision's AlexNet, requesting pretrained weights, and print
# its layer structure for inspection.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of a `weights=` argument — confirm against the installed version.
alexnet = torchvision.models.alexnet(pretrained=True)
print(alexnet)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [73]:
# Freeze the network: turn off gradient tracking on every parameter the model
# currently holds.
for weight in alexnet.parameters():
  weight.requires_grad = False

In [74]:
# Swap the first convolution for one that accepts a single input channel and
# the last classifier layer for a 10-way output. Both layers are created after
# the freeze loop has run, so they are the parts of the network that still
# receive gradients.
alexnet.features[0] = nn.Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
alexnet.classifier[6] = nn.Linear(4096, 10)
# Deliberately no nn.Softmax at the end: the model is trained with
# nn.CrossEntropyLoss, which applies log-softmax internally and expects raw
# logits. Appending Softmax would normalise twice and flatten the gradients.
# Taking the max over logits at evaluation time picks the same class as the
# max over softmax probabilities.

In [75]:
# Print the architecture again to check the layer replacements made above.
print(alexnet)

AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [None]:
# Pick the GPU when one is usable, otherwise fall back to the CPU.
# torch.cuda.is_available is a function and has to be *called*: the bare
# function object is always truthy, which would select 'cuda' on every machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [77]:
# Preprocessing pipeline passed to both the training and the test dataset:
# resize to 227x227, centre-crop to 224x224, convert to a tensor and normalise
# the single channel.
# The random horizontal/vertical flips were removed for two reasons:
#   * this same pipeline is applied to the test set, so random augmentation
#     made the reported accuracy vary from run to run;
#   * mirroring or up-ending a digit can change what it depicts, so the flips
#     are not label-preserving for digit images.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((227, 227)),
    torchvision.transforms.CenterCrop((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485],
                                     std=[0.229])
])

In [78]:
# Training hyperparameters: samples per mini-batch handed to the DataLoaders,
# and the number of passes the training loop makes over the training set.
batch_size, num_epoch = 32, 20

In [79]:
# Training split: downloaded on first use and reshuffled every epoch.
train_data = torchvision.datasets.MNIST(root='/content/sample_data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)

# Test split: iterated in a fixed order. Shuffling adds nothing to an accuracy
# measurement and only makes the batch order non-deterministic.
test_data = torchvision.datasets.MNIST(root='/content/sample_data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [80]:
# Move the model onto the selected device before the optimizer is created.
alexnet.to(device)
# Cross-entropy over class indices; it expects the model to output raw logits.
loss_func = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that still require gradients. The
# frozen backbone weights never receive a gradient, so listing them would only
# make the optimizer track tensors it can never update.
trainable_params = [p for p in alexnet.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)

In [None]:
# Train for `num_epoch` passes over `train_loader`.
# Autograd anomaly detection is intentionally not enabled here: it is a
# debugging aid that adds extra checks to every backward pass and slows
# training down. Re-enable it temporarily only when hunting NaN/inf gradients.
for epoch in range(num_epoch):
  # Put the model in training mode at the start of every epoch so Dropout is
  # active even if an earlier cell switched the model to eval mode.
  alexnet.train()
  for i, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    # Clear gradients left over from the previous step before the new backward.
    optimizer.zero_grad()
    predicted = alexnet(images)
    loss = loss_func(predicted, labels)
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 1000 batches; .item() yields a plain
    # Python float instead of the tensor repr (device / grad_fn noise).
    if (i+1)%1000==0:
      print(f"[{epoch, i+1}], loss : {loss.item()}")

In [None]:
# Measure top-1 accuracy on the test loader.
# Switch to evaluation mode first: the classifier contains Dropout, which would
# otherwise keep randomly zeroing activations and make the score noisy and
# pessimistic.
alexnet.eval()
total = correct = 0
# No gradients are needed for inference; this saves memory and time.
with torch.no_grad():
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    output = alexnet(images)
    # Index of the highest score along the class dimension = predicted class.
    _, predicted = torch.max(output, 1)
    total += labels.size(0)
    correct += (predicted==labels).sum().item()

  print(f"Accuracy is {100*(correct/total)}%")