In [12]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

In [2]:
my_alexnet = torchvision.models.alexnet(pretrained=True)

In [4]:
my_fault = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914,0.4882,0.4465],[0.247,0.243,0.261])
])

In [5]:
train_dataset = torchvision.datasets.CIFAR10(root='./cifar10',train=True,transform=my_fault,download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./cifar10',train=False,transform=my_fault,download=True)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
train_dataloader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset)

In [8]:
for param in my_alexnet.parameters():
    param.requires_grad = False

In [9]:
my_alexnet

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_feature

In [11]:
in_f = my_alexnet.classifier[6].in_features
my_alexnet.classifier[6] = nn.Linear(in_f,10)

In [16]:
learn_rate = 0.001
num_epoches = 1
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(my_alexnet.classifier[6].parameters(),lr=learn_rate,momentum=0.9)

In [17]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [18]:
device

device(type='cuda')

In [19]:
# train
my_alexnet.to(device)
my_alexnet.train()
for epoch in range(num_epoches):
    print("epoch:{}".format(epoch+1))
    for idx,(img,label) in enumerate(train_dataloader):
        images = img.to(device)
        labels = label.to(device)
        output = my_alexnet(images)
        loss = criterion(output,labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if idx%50 == 0:
            print("current loss:{}".format(loss.item()))

epoch:1
current loss:2.4996755123138428
current loss:1.0151495933532715
current loss:0.8049212098121643
current loss:1.09517502784729
current loss:0.5423555374145508
current loss:0.4911864101886749
current loss:0.5631389617919922
current loss:0.818423867225647
current loss:0.6078160405158997
current loss:0.6915767788887024
current loss:0.7340267896652222
current loss:1.1257890462875366


KeyboardInterrupt: 

In [20]:
#test
my_alexnet.to(device)
my_alexnet.eval()
correct,total = 0,0
for img,label in test_dataloader:
    images = img.to(device)
    labels = label.to(device)
    output = my_alexnet(images)
    _,idx = torch.max(output.data,1)
    correct += (labels==idx).cpu().sum()
    total += labels.size(0)
    
print("accuracy:{}".format(correct.item()/total))

accuracy:0.7678
