In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from ResNet_hc import resnet101
import time
import os
# Restrict this process to GPUs 0 and 1.
# NOTE(review): this only takes effect if it runs before CUDA is first
# initialised in the kernel - confirm no earlier cell touches the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
# Image preprocessing with torchvision, followed by dataset / loader creation.

# Per-channel mean and std handed to Normalize in both pipelines.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, then random crop + random affine as augmentation.
train_steps = [
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomAffine(degrees=15, scale=(0.8, 1.5)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: deterministic resize + centre crop, no augmentation.
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
val_transform = transforms.Compose(val_steps)

# Settings shared by both loaders; only the shuffle flag differs.
loader_kwargs = dict(batch_size=128, num_workers=4)

# ImageFolder derives the class labels from the sub-directory names under root.
trainset = torchvision.datasets.ImageFolder(root='../data/train/', transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)

valset = torchvision.datasets.ImageFolder(root='../data/val/', transform=val_transform)
valloader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_kwargs)

In [3]:
# Number of mini-batches each loader yields per pass (train, then val).
n_train_batches = len(trainloader)
n_val_batches = len(valloader)
print(n_train_batches, n_val_batches)

165 32


In [4]:
# Build the network and load the pretrained weights.
# NOTE(review): the argument 2 is presumably the number of output classes -
# confirm against ResNet_hc.resnet101.
model = resnet101(2)
# Deserialise onto the CPU so the checkpoint loads even on a host without CUDA;
# the model is moved to the target device in a later cell.
pretrained_state = torch.load('pretrained/resnet101-5d3b4d8f.pth', map_location='cpu')
# strict=False: checkpoint keys that do not match the model are skipped rather
# than raising; the returned report (shown as the cell output) lists them.
model.load_state_dict(pretrained_state, strict=False)

In [5]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# With several visible GPUs, replicate the model across them.
if torch.cuda.device_count()>1:
    print('We are using',torch.cuda.device_count(),'GPUs!')
    # Guard against re-running this cell: never wrap an already-wrapped model.
    if not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)
model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

NUM_EPOCHS = 20   # number of passes over the training set
LOG_EVERY = 20    # print running statistics every LOG_EVERY iterations


def _save_curve(values, title, xlabel, ylabel, path):
    """Plot one value per epoch (x = 0..len(values)-1) and save the figure to `path`."""
    fig, ax = plt.subplots(figsize=(11, 8))
    ax.plot(range(len(values)), values)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.savefig(path)
    # Close explicitly so figures do not accumulate across epochs.
    plt.close(fig)


# Per-epoch history: train accuracy (%), train loss, validation accuracy (%).
Accuracy = []
Loss = []
Val_Accuracy = []
BEST_VAL_ACC = 0.
# Training
since = time.time()
for epoch in range(NUM_EPOCHS):
    # Epoch-level accumulators, updated on EVERY iteration so that the
    # trailing iterations after the last log line are not lost.
    train_loss = 0.       # sum of per-sample losses over the epoch
    train_accuracy = 0.   # number of correct predictions over the epoch
    total = 0             # number of training samples seen this epoch
    # Running-window accumulators, reset after each log line.
    run_accuracy = 0.
    run_loss = 0.
    run_total = 0
    model.train()
    for i,data in enumerate(trainloader,0):
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        # Standard step: zero grads, forward, backward, update.
        optimizer.zero_grad()
        outs = model(images)
        loss = criterion(outs, labels)
        loss.backward()
        optimizer.step()
        # Statistics for this batch.
        batch_size = labels.size(0)
        batch_loss = loss.item()
        _,prediction = torch.max(outs,1)
        batch_correct = (prediction == labels).sum().item()

        total += batch_size
        # criterion returns the batch mean, so weight it by the batch size
        # to obtain a per-sample sum that is valid for partial batches too.
        train_loss += batch_loss * batch_size
        train_accuracy += batch_correct

        run_total += batch_size
        run_loss += batch_loss
        run_accuracy += batch_correct
        if i % LOG_EVERY == LOG_EVERY - 1:
            # Window accuracy is divided by the samples actually seen in the window.
            print('epoch {},iter {},train accuracy: {:.4f}%   loss:  {:.4f}'.format(epoch, i+1, 100*run_accuracy/run_total, run_loss/LOG_EVERY))
            run_accuracy, run_loss, run_total = 0., 0., 0
    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    Loss.append(train_loss/max(total, 1))
    Accuracy.append(100*train_accuracy/max(total, 1))
    # Visualise training progress (files are overwritten every epoch).
    _save_curve(Accuracy, "Average trainset accuracy vs epochs",
                "Epoch", "Avg. train. accuracy", 'Train_accuracy_vs_epochs.png')
    _save_curve(Loss, "Average trainset loss vs epochs",
                "Epoch", "Current loss", 'loss_vs_epochs.png')
    # Validation
    acc = 0.
    model.eval()
    print('waitting for Val...')
    with torch.no_grad():
        accuracy = 0.
        total =0
        for data in valloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            out = model(images)
            _, prediction = torch.max(out, 1)
            total += labels.size(0)
            accuracy += (prediction == labels).sum().item()
    if total:
        acc = 100.*accuracy/total
    print('epoch {}  The ValSet accuracy is {:.4f}% \n'.format(epoch, acc))
    Val_Accuracy.append(acc)
    if acc > BEST_VAL_ACC:
        print('Find Better Model and Saving it...')
        os.makedirs('checkpoint', exist_ok=True)
        # NOTE(review): when the model is wrapped in nn.DataParallel the saved
        # keys carry a 'module.' prefix; whoever loads this file must expect that.
        torch.save(model.state_dict(), './checkpoint/ResNet101_Cats_Dogs_hc.pth')
        BEST_VAL_ACC = acc
        print('Saved!')

    _save_curve(Val_Accuracy, "Average Val accuracy vs epochs",
                "Epoch", "Current Val accuracy", 'val_accuracy_vs_epoch.png')
    # Elapsed wall-clock time since training started (printed after every epoch).
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed%60))
    print('Now the best val Acc is {:.4f}%'.format(BEST_VAL_ACC))

We are using 2 GPUs!
epoch 0,iter 20,train accuracy: 42.1484%   loss:  0.7372
epoch 0,iter 40,train accuracy: 69.4141%   loss:  0.6270
epoch 0,iter 60,train accuracy: 81.4844%   loss:  0.5404
epoch 0,iter 80,train accuracy: 87.0703%   loss:  0.4702
epoch 0,iter 100,train accuracy: 88.3984%   loss:  0.4224
epoch 0,iter 120,train accuracy: 88.7109%   loss:  0.3820
epoch 0,iter 140,train accuracy: 89.2969%   loss:  0.3575
epoch 0,iter 160,train accuracy: 89.7656%   loss:  0.3369
waitting for Val...
epoch 0  The ValSet accuracy is 97.3750% 

Find Better Model and Saving it...
Saved!
Training complete in 2m 30s
Now the best val Acc is 97.3750%
epoch 1,iter 20,train accuracy: 91.0547%   loss:  0.2939
epoch 1,iter 40,train accuracy: 91.3672%   loss:  0.2861
epoch 1,iter 60,train accuracy: 90.7031%   loss:  0.2738
epoch 1,iter 80,train accuracy: 90.7812%   loss:  0.2635
epoch 1,iter 100,train accuracy: 91.2891%   loss:  0.2533
epoch 1,iter 120,train accuracy: 93.0469%   loss:  0.2327
epoch 1,i

epoch 13,iter 80,train accuracy: 95.3516%   loss:  0.1122
epoch 13,iter 100,train accuracy: 95.3125%   loss:  0.1139
epoch 13,iter 120,train accuracy: 94.8438%   loss:  0.1235
epoch 13,iter 140,train accuracy: 95.0000%   loss:  0.1124
epoch 13,iter 160,train accuracy: 95.1172%   loss:  0.1112
waitting for Val...
epoch 13  The ValSet accuracy is 98.9750% 

Training complete in 36m 47s
Now the best val Acc is 99.0000%
epoch 14,iter 20,train accuracy: 94.9219%   loss:  0.1152
epoch 14,iter 40,train accuracy: 94.8828%   loss:  0.1179
epoch 14,iter 60,train accuracy: 95.4297%   loss:  0.1136
epoch 14,iter 80,train accuracy: 95.5859%   loss:  0.1086
epoch 14,iter 100,train accuracy: 94.9609%   loss:  0.1201
epoch 14,iter 120,train accuracy: 94.6094%   loss:  0.1230
epoch 14,iter 140,train accuracy: 94.6094%   loss:  0.1161
epoch 14,iter 160,train accuracy: 95.6250%   loss:  0.0998
waitting for Val...
epoch 14  The ValSet accuracy is 98.9000% 

Training complete in 39m 18s
Now the best val Ac