In [1]:
import torch
import torchvision
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

In [2]:
my_alexnet = torchvision.models.alexnet(pretrained=True)

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /home/ma-user/.torch/models/alexnet-owt-4df8aa71.pth
100%|██████████| 244418560/244418560 [01:12<00:00, 3361645.44it/s]


In [3]:
import torchvision.transforms as transforms

my_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])])
## 同学们可以针对训练时的transform做一些增广，比如随机旋转等

In [5]:
train_dataset = torchvision.datasets.CIFAR10(root='./cifar10',train=True,transform=my_tf,download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./cifar10',train=False,transform=my_tf,download=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_dataloader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset)

In [7]:
for param in my_alexnet.parameters():
    param.requires_grad=False

In [8]:
my_alexnet

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_feature

In [9]:
in_f = my_alexnet.classifier[6].in_features
my_alexnet.classifier[6] = nn.Linear(in_f,10)

In [10]:
learn_rate = 0.001
num_epoches = 1
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(my_alexnet.classifier[6].parameters(),lr=learn_rate,momentum=0.9)

In [11]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
#train
my_alexnet.to(device)
my_alexnet.train()
for epoch in range(num_epoches):
    print(f"epoch: {epoch+1}")
    for idx,(img,label)in enumerate(train_dataloader):
        images = img.to(device)
        labels = label.to(device)
        output = my_alexnet(images)
        loss = criterion(output,labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if idx%100==0:
            print(f"current loss = {loss.item()}")

epoch: 1
current loss = 2.541965961456299
current loss = 0.8652909398078918
current loss = 1.0931724309921265
current loss = 0.8530042171478271
current loss = 0.7725613117218018
current loss = 0.8463196754455566
current loss = 0.9321938157081604
current loss = 0.6227675676345825
current loss = 0.5539371967315674
current loss = 0.42793968319892883
current loss = 1.0118420124053955
current loss = 0.5574460029602051
current loss = 0.9021467566490173
current loss = 0.8032833933830261
current loss = 0.5987125635147095


KeyboardInterrupt: 

In [22]:
#eval
my_alexnet.to(device)
my_alexnet.eval()
total,correct = 0 , 0
for img,label in test_dataloader:
    images = img.to(device)
    labels = label.to(device)
    print('label:{}'.format(label))
    output = my_alexnet(images)
#     print(torch.max(output))
    _,idx = torch.max(output.data,1)
    print("idx:{}".format(idx))
    total += labels.size(0)
    correct +=(idx==labels).cpu().sum()

print(f"accuracy:{correct.item()/total}")
      

label:tensor([3])
idx:tensor([3], device='cuda:0')
label:tensor([8])
idx:tensor([8], device='cuda:0')
label:tensor([8])
idx:tensor([8], device='cuda:0')
label:tensor([0])
idx:tensor([8], device='cuda:0')
label:tensor([6])
idx:tensor([6], device='cuda:0')
label:tensor([6])
idx:tensor([6], device='cuda:0')
label:tensor([1])
idx:tensor([1], device='cuda:0')
label:tensor([6])
idx:tensor([6], device='cuda:0')
label:tensor([3])
idx:tensor([3], device='cuda:0')
label:tensor([1])
idx:tensor([1], device='cuda:0')
label:tensor([0])
idx:tensor([0], device='cuda:0')
label:tensor([9])
idx:tensor([9], device='cuda:0')
label:tensor([5])
idx:tensor([5], device='cuda:0')
label:tensor([7])
idx:tensor([7], device='cuda:0')
label:tensor([9])
idx:tensor([9], device='cuda:0')
label:tensor([8])
idx:tensor([8], device='cuda:0')
label:tensor([5])
idx:tensor([5], device='cuda:0')
label:tensor([7])
idx:tensor([7], device='cuda:0')
label:tensor([8])
idx:tensor([8], device='cuda:0')
label:tensor([6])
idx:tensor([6