In [1]:
import CropModels
from CropDataset import MyDataSet,normalize_torch,normalize_05,normalize_dataset,preprocess,preprocess_hflip,preprocess_with_augmentation
import pandas as pd
from torch.utils.data import DataLoader
import torch.nn as nn
from tensorboardX import SummaryWriter
import datetime
import os
import torch
from torch.autograd import Variable
from utils import RunningMean
import utils
from PIL import Image
import numpy as np
import random
NB_CLASS=59
SEED=888
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
torch.backends.cudnn.benchmark = True

In [2]:
BATCH_SIZE=32
IMAGE_SIZE=224    # 不同模型修改不同的Size
IMAGE_TRAIN_PRE='../data/AgriculturalDisease_trainingset/images/'
ANNOTATION_TRAIN='../data/AgriculturalDisease_trainingset/AgriculturalDisease_train_annotations_deleteNoise.json' #是否需要剔除两类异常类
IMAGE_VAL_PRE='../data/AgriculturalDisease_validationset/images/'
ANNOTATION_VAL='../data/AgriculturalDisease_validationset/AgriculturalDisease_validation_annotations_deleteNoise.json' #是否需要剔除两类异常类
date=str(datetime.date.today())
with open(ANNOTATION_TRAIN) as datafile1:
    trainDataFram=pd.read_json(datafile1,orient='records')
with open(ANNOTATION_VAL) as datafile2: #first check if it's a valid json file or not
    validateDataFram =pd.read_json(datafile2,orient='records')    
def getmodel():
    print('[+] loading model... ', end='', flush=True)
    model=CropModels.nasnetmobile(NB_CLASS)
    model.cuda()
    print('Done')
    return model

In [3]:
def train(epochNum):
    writer=SummaryWriter('log/'+date+'/NasnetMobile/') # 创建 /log/日期/InceptionResnet的组织形式  不同模型需要修改不同名称
    train_dataset=MyDataSet(json_Description=ANNOTATION_TRAIN,transform=preprocess_with_augmentation(normalize_05,IMAGE_SIZE),path_pre=IMAGE_TRAIN_PRE)
    val_dataset=MyDataSet(json_Description=ANNOTATION_VAL,transform=preprocess(normalize_05,IMAGE_SIZE),path_pre=IMAGE_VAL_PRE)
    train_dataLoader=DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,num_workers=16,shuffle=True)
    val_dataLoader=DataLoader(dataset=val_dataset,batch_size=BATCH_SIZE,num_workers=1,shuffle=False)
    model=getmodel()
    weight=torch.Tensor([1,3,3,3,3,4,2,3,3,3,3,3,3,3,3,3,2,3,3,3,2,3,4,2,3,1,1,3,2,2,1,3,3,1,3,2,3,3,3,3,2,1,3,2,3,3,3,1,3,3,4,4,3,2,2,3,1,1,3]).cuda()
    criterion=nn.CrossEntropyLoss(weight=weight).cuda()
#     lx, px = utils.predict(model,val_dataLoader)
#     min_loss = criterion(Variable(px), Variable(lx)).item()
    min_loss=4.1
    print('min_loss is :%f'%(min_loss))
    min_acc=0.80
    patience=0
    lr=0.0
    momentum=0.0
    for epoch in range(epochNum):
        print('Epoch {}/{}'.format(epoch, epochNum - 1))
        print('-' * 10)
        if epoch==3 or epoch==4 or epoch==5:
            lr=0.00006
            momentum=0.95
            print('set lr=:%f,momentum=%f'%(lr,momentum))
        if epoch==6:
            lr=1e-4
            momentum=0.9
            print('set lr=:%f,momentum=%f'%(lr,momentum))
        if patience==2:
            patience=0
            model.load_state_dict(torch.load('../model/NasnetMobile/'+date+'_loss_best.pth')['state_dict'])
            lr=lr/10
            print('loss has increased lr divide 10 lr now is :%f'%(lr))
        if epoch==0 or epoch==1 or epoch==2: #第一轮首先训练全连接层
            lr=1e-3
#             optimizer=torch.optim.SGD(params=model.fresh_params(),lr=lr,momentum=0.9)
            optimizer = torch.optim.Adam(model.fresh_params(),lr = lr,amsgrad=True,weight_decay=1e-4)
        else:
            optimizer = torch.optim.Adam(model.parameters(),lr = lr,amsgrad=True,weight_decay=1e-4)
#             optimizer=torch.optim.SGD(params=model.parameters(),lr=lr,momentum=momentum)
        running_loss = RunningMean()
        running_corrects = RunningMean()
        for batch_idx, (inputs, labels) in enumerate(train_dataLoader):
            model.train(True)
            n_batchsize=inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            if isinstance(outputs,tuple):
                loss=sum((criterion(o,labels)) for o in outputs)
            else:
                loss = criterion(outputs, labels)
            running_loss.update(loss.item(),1)
            running_corrects.update(torch.sum(preds == labels.data).data,n_batchsize)
            loss.backward()
            optimizer.step()
            if batch_idx%30==29:
                print('[epoch:%d,batch:%d]:acc: %f,loss:%f'%(epoch,batch_idx,running_corrects.value,running_loss.value))
                if batch_idx%300==299: 
                    niter = epoch * len(train_dataset)/BATCH_SIZE + batch_idx
                    writer.add_scalar('Train/Acc',running_corrects.value,niter)
                    writer.add_scalar('Train/Loss',running_loss.value,niter)
                    lx,px=utils.predict(model,val_dataLoader)
                    log_loss = criterion(Variable(px), Variable(lx))
                    log_loss = log_loss.item()
                    _, preds = torch.max(px, dim=1)
                    accuracy = torch.mean((preds == lx).float())
                    writer.add_scalar('Val/Acc',accuracy,niter)
                    writer.add_scalar('Val/Loss',log_loss,niter)
                    print('[epoch:%d,batch:%d]: val_loss:%f,val_acc:%f,val_total:%d'%(epoch,batch_idx,log_loss,accuracy,len(val_dataset)))
        print('[epoch:%d] :acc: %f,loss:%f,lr:%f,patience:%d'%(epoch,running_corrects.value,running_loss.value,lr,patience))       
        lx,px=utils.predict(model,val_dataLoader)
        log_loss = criterion(Variable(px), Variable(lx))
        log_loss = log_loss.item()
        _, preds = torch.max(px, dim=1)
        accuracy = torch.mean((preds == lx).float())
        writer.add_scalar('Val/Acc',accuracy,(epoch+1) * len(train_dataset)/BATCH_SIZE)
        writer.add_scalar('Val/Loss',log_loss,(epoch+1) * len(train_dataset)/BATCH_SIZE)
        print('[epoch:%d]: val_loss:%f,val_acc:%f,'%(epoch,log_loss,accuracy))
        if  log_loss < min_loss:
            utils.snapshot('../model/', 'NasnetMobile', {
                   'epoch': epoch + 1,
                   'state_dict': model.state_dict(),
                   'optimizer': optimizer.state_dict(),
                   'val_loss': log_loss,
                   'val_correct':accuracy })          
            patience = 0
            min_loss=log_loss
            print('save new model loss,now loss is ',min_loss)
        else:
            patience += 1
        if accuracy>min_acc:
            utils.snapshot('../model/', 'NasnetMobile', {
                   'epoch': epoch + 1,
                   'state_dict': model.state_dict(),
                   'optimizer': optimizer.state_dict(),
                   'val_loss': log_loss,
                   'val_correct':accuracy },key='acc') 
            min_acc=accuracy
            print('save new model acc,now acc is ',min_acc)

In [None]:
train(60)

[+] loading model... Done
min_loss is :4.100000
Epoch 0/59
----------
[epoch:0,batch:29]:acc: 0.134375,loss:3.926591
[epoch:0,batch:59]:acc: 0.206771,loss:3.678492
[epoch:0,batch:89]:acc: 0.263542,loss:3.434125
[epoch:0,batch:119]:acc: 0.303906,loss:3.256706
[epoch:0,batch:149]:acc: 0.340833,loss:3.095744
[epoch:0,batch:179]:acc: 0.365625,loss:2.965924
[epoch:0,batch:209]:acc: 0.383036,loss:2.855124
[epoch:0,batch:239]:acc: 0.400000,loss:2.763887
[epoch:0,batch:269]:acc: 0.416204,loss:2.673652
[epoch:0,batch:299]:acc: 0.429896,loss:2.590477
[epoch:0,batch:299]: val_loss:1.777935,val_acc:0.599912,val_total:4539
[epoch:0,batch:329]:acc: 0.441477,loss:2.518498
[epoch:0,batch:359]:acc: 0.447049,loss:2.464793
[epoch:0,batch:389]:acc: 0.456250,loss:2.411397
[epoch:0,batch:419]:acc: 0.465253,loss:2.358708
[epoch:0,batch:449]:acc: 0.471875,loss:2.315168
[epoch:0,batch:479]:acc: 0.478125,loss:2.270495
[epoch:0,batch:509]:acc: 0.484069,loss:2.229955
[epoch:0,batch:539]:acc: 0.488831,loss:2.19539

[epoch:4,batch:89]:acc: 0.794444,loss:0.588114
[epoch:4,batch:119]:acc: 0.792969,loss:0.580286
[epoch:4,batch:149]:acc: 0.795000,loss:0.568641
[epoch:4,batch:179]:acc: 0.797222,loss:0.562046
[epoch:4,batch:209]:acc: 0.798810,loss:0.558643
[epoch:4,batch:239]:acc: 0.801172,loss:0.558706
[epoch:4,batch:269]:acc: 0.800926,loss:0.550815
[epoch:4,batch:299]:acc: 0.799375,loss:0.547913
[epoch:4,batch:299]: val_loss:0.485960,val_acc:0.813175,val_total:4539
[epoch:4,batch:329]:acc: 0.800189,loss:0.543621
[epoch:4,batch:359]:acc: 0.801736,loss:0.537692
[epoch:4,batch:389]:acc: 0.801603,loss:0.533546
[epoch:4,batch:419]:acc: 0.801637,loss:0.530913
[epoch:4,batch:449]:acc: 0.801736,loss:0.533578
[epoch:4,batch:479]:acc: 0.803125,loss:0.527971
[epoch:4,batch:509]:acc: 0.802696,loss:0.527473
[epoch:4,batch:539]:acc: 0.804167,loss:0.525547
[epoch:4,batch:569]:acc: 0.804715,loss:0.526363
[epoch:4,batch:599]:acc: 0.805000,loss:0.526511
[epoch:4,batch:599]: val_loss:0.453628,val_acc:0.829037,val_total:

[epoch:8,batch:149]:acc: 0.861250,loss:0.346700
[epoch:8,batch:179]:acc: 0.860243,loss:0.349774
[epoch:8,batch:209]:acc: 0.861161,loss:0.346463
[epoch:8,batch:239]:acc: 0.861328,loss:0.346013
[epoch:8,batch:269]:acc: 0.859259,loss:0.347743
[epoch:8,batch:299]:acc: 0.859479,loss:0.345000
[epoch:8,batch:299]: val_loss:0.436724,val_acc:0.845340,val_total:4539
[epoch:8,batch:329]:acc: 0.858239,loss:0.348898
[epoch:8,batch:359]:acc: 0.858247,loss:0.348623
[epoch:8,batch:389]:acc: 0.858894,loss:0.350182
[epoch:8,batch:419]:acc: 0.858631,loss:0.350581
[epoch:8,batch:449]:acc: 0.857847,loss:0.351519
[epoch:8,batch:479]:acc: 0.857552,loss:0.352957
[epoch:8,batch:509]:acc: 0.857782,loss:0.351632
[epoch:8,batch:539]:acc: 0.856366,loss:0.356330
[epoch:8,batch:569]:acc: 0.855702,loss:0.358551
[epoch:8,batch:599]:acc: 0.855573,loss:0.359303
[epoch:8,batch:599]: val_loss:0.403691,val_acc:0.844459,val_total:4539
[epoch:8,batch:629]:acc: 0.856151,loss:0.359774
[epoch:8,batch:659]:acc: 0.856013,loss:0.3

[epoch:12,batch:239]:acc: 0.862109,loss:0.314709
[epoch:12,batch:269]:acc: 0.865046,loss:0.310415
[epoch:12,batch:299]:acc: 0.865729,loss:0.312620
[epoch:12,batch:299]: val_loss:0.403073,val_acc:0.848645,val_total:4539
[epoch:12,batch:329]:acc: 0.866288,loss:0.313496
[epoch:12,batch:359]:acc: 0.865191,loss:0.314718
[epoch:12,batch:389]:acc: 0.864824,loss:0.317274
[epoch:12,batch:419]:acc: 0.865923,loss:0.316628
[epoch:12,batch:449]:acc: 0.866042,loss:0.314818
[epoch:12,batch:479]:acc: 0.866471,loss:0.314110
[epoch:12,batch:509]:acc: 0.866973,loss:0.312220
[epoch:12,batch:539]:acc: 0.867477,loss:0.310318
[epoch:12,batch:569]:acc: 0.866338,loss:0.314283
[epoch:12,batch:599]:acc: 0.866719,loss:0.313692
[epoch:12,batch:599]: val_loss:0.398726,val_acc:0.851950,val_total:4539
[epoch:12,batch:629]:acc: 0.867361,loss:0.312367
[epoch:12,batch:659]:acc: 0.867898,loss:0.312291
[epoch:12,batch:689]:acc: 0.867980,loss:0.311529
[epoch:12,batch:719]:acc: 0.868186,loss:0.311129
[epoch:12,batch:749]:ac

[epoch:16,batch:299]: val_loss:0.393488,val_acc:0.853933,val_total:4539
[epoch:16,batch:329]:acc: 0.869129,loss:0.307170
[epoch:16,batch:359]:acc: 0.869358,loss:0.306453
[epoch:16,batch:389]:acc: 0.869151,loss:0.305719
[epoch:16,batch:419]:acc: 0.870238,loss:0.305391
[epoch:16,batch:449]:acc: 0.870278,loss:0.305622
[epoch:16,batch:479]:acc: 0.870443,loss:0.305089
[epoch:16,batch:509]:acc: 0.870772,loss:0.304500
[epoch:16,batch:539]:acc: 0.870313,loss:0.306419
[epoch:16,batch:569]:acc: 0.870504,loss:0.305496
[epoch:16,batch:599]:acc: 0.870781,loss:0.303067
[epoch:16,batch:599]: val_loss:0.394981,val_acc:0.852390,val_total:4539
[epoch:16,batch:629]:acc: 0.871181,loss:0.302404
[epoch:16,batch:659]:acc: 0.872064,loss:0.301988
[epoch:16,batch:689]:acc: 0.871558,loss:0.302911
[epoch:16,batch:719]:acc: 0.871094,loss:0.303921
[epoch:16,batch:749]:acc: 0.870792,loss:0.305364
[epoch:16,batch:779]:acc: 0.870713,loss:0.303909
[epoch:16,batch:809]:acc: 0.871026,loss:0.303249
[epoch:16,batch:839]:ac

[epoch:20,batch:359]:acc: 0.875608,loss:0.300039
[epoch:20,batch:389]:acc: 0.874119,loss:0.302881
[epoch:20,batch:419]:acc: 0.875149,loss:0.300281
[epoch:20,batch:449]:acc: 0.874375,loss:0.301546
[epoch:20,batch:479]:acc: 0.875195,loss:0.300841
[epoch:20,batch:509]:acc: 0.875735,loss:0.298938
[epoch:20,batch:539]:acc: 0.876157,loss:0.299359
[epoch:20,batch:569]:acc: 0.876151,loss:0.299832
[epoch:20,batch:599]:acc: 0.876406,loss:0.300149
[epoch:20,batch:599]: val_loss:0.396479,val_acc:0.856356,val_total:4539
[epoch:20,batch:629]:acc: 0.876587,loss:0.299436
[epoch:20,batch:659]:acc: 0.876894,loss:0.299132
[epoch:20,batch:689]:acc: 0.877400,loss:0.298977
[epoch:20,batch:719]:acc: 0.876519,loss:0.301486
[epoch:20,batch:749]:acc: 0.876708,loss:0.301160
[epoch:20,batch:779]:acc: 0.877163,loss:0.300169
[epoch:20,batch:809]:acc: 0.877083,loss:0.300136
[epoch:20,batch:839]:acc: 0.876823,loss:0.299767
[epoch:20,batch:869]:acc: 0.876545,loss:0.300162
[epoch:20,batch:899]:acc: 0.876667,loss:0.2990

[epoch:24,batch:509]:acc: 0.878002,loss:0.297491
[epoch:24,batch:539]:acc: 0.877546,loss:0.296728
[epoch:24,batch:569]:acc: 0.878289,loss:0.296444
[epoch:24,batch:599]:acc: 0.878542,loss:0.296121
[epoch:24,batch:599]: val_loss:0.393390,val_acc:0.855475,val_total:4539
[epoch:24,batch:629]:acc: 0.877927,loss:0.295959
[epoch:24,batch:659]:acc: 0.877178,loss:0.296124
[epoch:24,batch:689]:acc: 0.876585,loss:0.296722
[epoch:24,batch:719]:acc: 0.875694,loss:0.299095
[epoch:24,batch:749]:acc: 0.876375,loss:0.297487
[epoch:24,batch:779]:acc: 0.876482,loss:0.297372
[epoch:24,batch:809]:acc: 0.876505,loss:0.297667
[epoch:24,batch:839]:acc: 0.876079,loss:0.297696
[epoch:24,batch:869]:acc: 0.876185,loss:0.297456
[epoch:24,batch:899]:acc: 0.876146,loss:0.297403
[epoch:24,batch:899]: val_loss:0.395079,val_acc:0.854153,val_total:4539
[epoch:24,batch:929]:acc: 0.876008,loss:0.296833
[epoch:24,batch:959]:acc: 0.875814,loss:0.296337
[epoch:24,batch:989]:acc: 0.875600,loss:0.296273
[epoch:24] :acc: 0.8756

[epoch:28,batch:629]:acc: 0.873859,loss:0.302438
[epoch:28,batch:659]:acc: 0.873532,loss:0.301850
[epoch:28,batch:689]:acc: 0.873687,loss:0.301464
[epoch:28,batch:719]:acc: 0.873177,loss:0.302801
[epoch:28,batch:749]:acc: 0.872958,loss:0.302610
[epoch:28,batch:779]:acc: 0.873518,loss:0.302228
[epoch:28,batch:809]:acc: 0.873727,loss:0.301122
[epoch:28,batch:839]:acc: 0.873437,loss:0.301860
[epoch:28,batch:869]:acc: 0.874425,loss:0.299664
[epoch:28,batch:899]:acc: 0.874340,loss:0.299193
[epoch:28,batch:899]: val_loss:0.392853,val_acc:0.853051,val_total:4539
[epoch:28,batch:929]:acc: 0.873992,loss:0.299724
[epoch:28,batch:959]:acc: 0.874251,loss:0.299140
[epoch:28,batch:989]:acc: 0.874337,loss:0.299108
[epoch:28] :acc: 0.874263,loss:0.299027,lr:0.000000,patience:1
[epoch:28]: val_loss:0.394043,val_acc:0.856576,
Epoch 29/59
----------
loss has increased lr divide 10 lr now is :0.000000
[epoch:29,batch:29]:acc: 0.886458,loss:0.251116
[epoch:29,batch:59]:acc: 0.881250,loss:0.271795
[epoch:29

[epoch:32,batch:779]:acc: 0.874679,loss:0.301308
[epoch:32,batch:809]:acc: 0.874769,loss:0.301808
[epoch:32,batch:839]:acc: 0.874888,loss:0.300712
[epoch:32,batch:869]:acc: 0.874856,loss:0.299855
[epoch:32,batch:899]:acc: 0.874931,loss:0.300211
[epoch:32,batch:899]: val_loss:0.392517,val_acc:0.854594,val_total:4539
[epoch:32,batch:929]:acc: 0.875134,loss:0.299426
[epoch:32,batch:959]:acc: 0.875000,loss:0.299929
[epoch:32,batch:989]:acc: 0.874684,loss:0.300497
[epoch:32] :acc: 0.874736,loss:0.300230,lr:0.000000,patience:1
[epoch:32]: val_loss:0.398267,val_acc:0.856356,
Epoch 33/59
----------
loss has increased lr divide 10 lr now is :0.000000
[epoch:33,batch:29]:acc: 0.869792,loss:0.283873
[epoch:33,batch:59]:acc: 0.866146,loss:0.289259
[epoch:33,batch:89]:acc: 0.875694,loss:0.288436
[epoch:33,batch:119]:acc: 0.876563,loss:0.289902
[epoch:33,batch:149]:acc: 0.876875,loss:0.288348
[epoch:33,batch:179]:acc: 0.873958,loss:0.291939
[epoch:33,batch:209]:acc: 0.876488,loss:0.286429
[epoch:33,

[epoch:36,batch:899]: val_loss:0.393467,val_acc:0.854814,val_total:4539
[epoch:36,batch:929]:acc: 0.874933,loss:0.297351
[epoch:36,batch:959]:acc: 0.874740,loss:0.297938
[epoch:36,batch:989]:acc: 0.874400,loss:0.297969
[epoch:36] :acc: 0.874452,loss:0.297923,lr:0.000000,patience:1
[epoch:36]: val_loss:0.398257,val_acc:0.854153,
Epoch 37/59
----------
loss has increased lr divide 10 lr now is :0.000000
[epoch:37,batch:29]:acc: 0.864583,loss:0.304333
[epoch:37,batch:59]:acc: 0.873437,loss:0.300373
[epoch:37,batch:89]:acc: 0.875000,loss:0.302600
[epoch:37,batch:119]:acc: 0.877344,loss:0.293932
[epoch:37,batch:149]:acc: 0.878958,loss:0.286497
[epoch:37,batch:179]:acc: 0.879514,loss:0.286098
[epoch:37,batch:209]:acc: 0.874702,loss:0.294263
[epoch:37,batch:239]:acc: 0.875260,loss:0.296539
[epoch:37,batch:269]:acc: 0.876157,loss:0.295788
[epoch:37,batch:299]:acc: 0.875625,loss:0.297429
[epoch:37,batch:299]: val_loss:0.392455,val_acc:0.853051,val_total:4539
[epoch:37,batch:329]:acc: 0.875095,l

[epoch:40]: val_loss:0.403068,val_acc:0.851068,
Epoch 41/59
----------
loss has increased lr divide 10 lr now is :0.000000
[epoch:41,batch:29]:acc: 0.882292,loss:0.289592
[epoch:41,batch:59]:acc: 0.876042,loss:0.288929
[epoch:41,batch:89]:acc: 0.879514,loss:0.285246
[epoch:41,batch:119]:acc: 0.878125,loss:0.292522
[epoch:41,batch:149]:acc: 0.878750,loss:0.285880
[epoch:41,batch:179]:acc: 0.875000,loss:0.293945
[epoch:41,batch:209]:acc: 0.874107,loss:0.298841
[epoch:41,batch:239]:acc: 0.872917,loss:0.299276
[epoch:41,batch:269]:acc: 0.874769,loss:0.293739
[epoch:41,batch:299]:acc: 0.874479,loss:0.298060
[epoch:41,batch:299]: val_loss:0.398792,val_acc:0.852831,val_total:4539
[epoch:41,batch:329]:acc: 0.874811,loss:0.295539
[epoch:41,batch:359]:acc: 0.874392,loss:0.297707
[epoch:41,batch:389]:acc: 0.873558,loss:0.297534
[epoch:41,batch:419]:acc: 0.875000,loss:0.294194
[epoch:41,batch:449]:acc: 0.874931,loss:0.296600
[epoch:41,batch:479]:acc: 0.874089,loss:0.298039
[epoch:41,batch:509]:acc

[epoch:45,batch:119]:acc: 0.877344,loss:0.292713
[epoch:45,batch:149]:acc: 0.877500,loss:0.290506
[epoch:45,batch:179]:acc: 0.875000,loss:0.291757
[epoch:45,batch:209]:acc: 0.872917,loss:0.296476
[epoch:45,batch:239]:acc: 0.871484,loss:0.298545
[epoch:45,batch:269]:acc: 0.872454,loss:0.298477
[epoch:45,batch:299]:acc: 0.873333,loss:0.295479
[epoch:45,batch:299]: val_loss:0.392838,val_acc:0.855915,val_total:4539
[epoch:45,batch:329]:acc: 0.873106,loss:0.294339
[epoch:45,batch:359]:acc: 0.872830,loss:0.296468
[epoch:45,batch:389]:acc: 0.873237,loss:0.295244
[epoch:45,batch:419]:acc: 0.874107,loss:0.293255
[epoch:45,batch:449]:acc: 0.874722,loss:0.292793
[epoch:45,batch:479]:acc: 0.874154,loss:0.294420
[epoch:45,batch:509]:acc: 0.874081,loss:0.293899
[epoch:45,batch:539]:acc: 0.873900,loss:0.294273
[epoch:45,batch:569]:acc: 0.874452,loss:0.293418
[epoch:45,batch:599]:acc: 0.873646,loss:0.294485
[epoch:45,batch:599]: val_loss:0.396038,val_acc:0.854594,val_total:4539
[epoch:45,batch:629]:ac

[epoch:49,batch:269]:acc: 0.870486,loss:0.302523
[epoch:49,batch:299]:acc: 0.871146,loss:0.298996
[epoch:49,batch:299]: val_loss:0.393752,val_acc:0.856356,val_total:4539
[epoch:49,batch:329]:acc: 0.871970,loss:0.299936
[epoch:49,batch:359]:acc: 0.873090,loss:0.298314
[epoch:49,batch:389]:acc: 0.872596,loss:0.300399
[epoch:49,batch:419]:acc: 0.874033,loss:0.297539
[epoch:49,batch:449]:acc: 0.875139,loss:0.295629
[epoch:49,batch:479]:acc: 0.875977,loss:0.294378
[epoch:49,batch:509]:acc: 0.875368,loss:0.297276
[epoch:49,batch:539]:acc: 0.874826,loss:0.298328
[epoch:49,batch:569]:acc: 0.873958,loss:0.297869
[epoch:49,batch:599]:acc: 0.873958,loss:0.297282
[epoch:49,batch:599]: val_loss:0.396131,val_acc:0.854153,val_total:4539
[epoch:49,batch:629]:acc: 0.874157,loss:0.297462
[epoch:49,batch:659]:acc: 0.874479,loss:0.295932
[epoch:49,batch:689]:acc: 0.874321,loss:0.296342
[epoch:49,batch:719]:acc: 0.874653,loss:0.296263
[epoch:49,batch:749]:acc: 0.874167,loss:0.296586
[epoch:49,batch:779]:ac

[epoch:53,batch:389]:acc: 0.872035,loss:0.300698
[epoch:53,batch:419]:acc: 0.871280,loss:0.300635
[epoch:53,batch:449]:acc: 0.872500,loss:0.298464
[epoch:53,batch:479]:acc: 0.872461,loss:0.298583
[epoch:53,batch:509]:acc: 0.871630,loss:0.300871
[epoch:53,batch:539]:acc: 0.871644,loss:0.302019
[epoch:53,batch:569]:acc: 0.870833,loss:0.304022
[epoch:53,batch:599]:acc: 0.870365,loss:0.304438
[epoch:53,batch:599]: val_loss:0.396760,val_acc:0.851509,val_total:4539
[epoch:53,batch:629]:acc: 0.871032,loss:0.303842
[epoch:53,batch:659]:acc: 0.871591,loss:0.303812
[epoch:53,batch:689]:acc: 0.871105,loss:0.304322
[epoch:53,batch:719]:acc: 0.871571,loss:0.303794
[epoch:53,batch:749]:acc: 0.871583,loss:0.302981
[epoch:53,batch:779]:acc: 0.871354,loss:0.302054
[epoch:53,batch:809]:acc: 0.871258,loss:0.301678
[epoch:53,batch:839]:acc: 0.872173,loss:0.300188
[epoch:53,batch:869]:acc: 0.872306,loss:0.300039
[epoch:53,batch:899]:acc: 0.872153,loss:0.300075
[epoch:53,batch:899]: val_loss:0.397002,val_ac

[epoch:57,batch:539]:acc: 0.872917,loss:0.298441
[epoch:57,batch:569]:acc: 0.871820,loss:0.299315
[epoch:57,batch:599]:acc: 0.871563,loss:0.299192
[epoch:57,batch:599]: val_loss:0.392487,val_acc:0.855475,val_total:4539
[epoch:57,batch:629]:acc: 0.871825,loss:0.299508
[epoch:57,batch:659]:acc: 0.872917,loss:0.299224
[epoch:57,batch:689]:acc: 0.873777,loss:0.298790
[epoch:57,batch:719]:acc: 0.873568,loss:0.298488
[epoch:57,batch:749]:acc: 0.873500,loss:0.298565
[epoch:57,batch:779]:acc: 0.873958,loss:0.297621
[epoch:57,batch:809]:acc: 0.873727,loss:0.296862
[epoch:57,batch:839]:acc: 0.874107,loss:0.296311
[epoch:57,batch:869]:acc: 0.874569,loss:0.295949
[epoch:57,batch:899]:acc: 0.874479,loss:0.296634
[epoch:57,batch:899]: val_loss:0.397065,val_acc:0.851509,val_total:4539
[epoch:57,batch:929]:acc: 0.874798,loss:0.296551
[epoch:57,batch:959]:acc: 0.874382,loss:0.296968
[epoch:57,batch:989]:acc: 0.874463,loss:0.296957
[epoch:57] :acc: 0.874421,loss:0.297086,lr:0.000000,patience:1
[epoch:57

In [4]:
def TrainWithRawData(path,epochNum):
    train_dataset=MyDataSet(json_Description=ANNOTATION_TRAIN,transform=preprocess(normalize_05,IMAGE_SIZE),path_pre=IMAGE_TRAIN_PRE)
    val_dataset=MyDataSet(json_Description=ANNOTATION_VAL,transform=preprocess(normalize_05,IMAGE_SIZE),path_pre=IMAGE_VAL_PRE)
    train_dataLoader=DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,num_workers=16,shuffle=True)
    val_dataLoader=DataLoader(dataset=val_dataset,batch_size=BATCH_SIZE,num_workers=1,shuffle=False)
    model=getmodel()
    criterion=nn.CrossEntropyLoss().cuda()
    modelParams=torch.load(path)
    model.load_state_dict(modelParams['state_dict'])
    min_loss=modelParams['val_loss']
    print('min_loss is :%f'%(min_loss))
    print('val_correct is %f'%(modelParams['val_correct']))
    min_acc=max(modelParams['val_correct'],0.81)
    optinizerSave=modelParams['optimizer']
    patience=0
    lr=1e-4
    momentum=0.9
    beginepoch=modelParams['epoch']
    for epoch in range(beginepoch,epochNum):
        print('Epoch {}/{}'.format(epoch, epochNum - 1))
        print('-' * 10)
        if patience==3:
            patience=0
            model.load_state_dict(torch.load('../model/NasnetMobile/'+date+'_loss_best.pth')['state_dict'])
            lr=lr/5
            print('lr desencd')
        if epoch==beginepoch:
            optimizer=torch.optim.SGD(params=model.parameters(),lr=lr,momentum=momentum)
#             optimizer.load_state_dict(optinizerSave)
#             lr=optimizer['lr']
#             momentum=optimizer['momentum']
            print('begin lr is ',lr)
            
        else:
            optimizer=torch.optim.SGD(params=model.parameters(),lr=lr,momentum=momentum)
                   
        running_loss = RunningMean()
        running_corrects = RunningMean()
        for batch_idx, (inputs, labels) in enumerate(train_dataLoader):
            model.train(True)
            n_batchsize=inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            if isinstance(outputs,tuple):
                loss=sum((criterion(o,labels)) for o in outputs)
            else:
                loss = criterion(outputs, labels)
            running_loss.update(loss.item(),1)
            running_corrects.update(torch.sum(preds == labels.data).data,n_batchsize)
            loss.backward()
            optimizer.step()
            if batch_idx%30==29:
                print('[epoch:%d,batch:%d]:acc: %f,loss:%f'%(epoch,batch_idx,running_corrects.value,running_loss.value))
                if batch_idx%300==299: 
                    niter = epoch * len(train_dataset)/BATCH_SIZE + batch_idx
                    lx,px=utils.predict(model,val_dataLoader)
                    log_loss = criterion(Variable(px), Variable(lx))
                    log_loss = log_loss.item()
                    _, preds = torch.max(px, dim=1)
                    accuracy = torch.mean((preds == lx).float())
                    print('[epoch:%d,batch:%d]: val_loss:%f,val_acc:%f,val_total:%d'%(epoch,batch_idx,log_loss,accuracy,len(val_dataset)))
                    if  log_loss < min_loss:
                        utils.snapshot('../model/', 'NasnetMobile', {
                               'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optimizer': optimizer.state_dict(),
                               'val_loss': log_loss,
                               'val_correct':accuracy })          

                        min_loss=log_loss
                        patience=0
                        print('save new model loss,now loss is ',min_loss)

                    if accuracy>min_acc:
                        utils.snapshot('../model/', 'NasnetMobile', {
                               'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optimizer': optimizer.state_dict(),
                               'val_loss': log_loss,
                               'val_correct':accuracy },key='acc') 
                        min_acc=accuracy
                        print('save new model acc,now acc is ',min_acc)
        print('[epoch:%d] :acc: %f,loss:%f,lr:%f,patience:%d'%(epoch,running_corrects.value,running_loss.value,lr,patience))         
        lx,px=utils.predict(model,val_dataLoader)
        log_loss = criterion(Variable(px), Variable(lx))
        log_loss = log_loss.item()
        _, preds = torch.max(px, dim=1)
        accuracy = torch.mean((preds == lx).float())
        print('[epoch:%d]: val_loss:%f,val_acc:%f,'%(epoch,log_loss,accuracy))
        if  log_loss < min_loss:
            utils.snapshot('../model/', 'NasnetMobile', {
                   'epoch': epoch + 1,
                   'state_dict': model.state_dict(),
                   'optimizer': optimizer.state_dict(),
                   'val_loss': log_loss,
                   'val_correct':accuracy })          
            patience = 0
            min_loss=log_loss
            print('save new model loss,now loss is ',min_loss)
        else:
            patience += 1
        if accuracy>min_acc:
            utils.snapshot('../model/', 'NasnetMobile', {
                   'epoch': epoch + 1,
                   'state_dict': model.state_dict(),
                   'optimizer': optimizer.state_dict(),
                   'val_loss': log_loss,
                   'val_correct':accuracy },key='acc') 
            min_acc=accuracy
            print('save new model acc,now acc is ',min_acc)

In [None]:
TrainWithRawData('../model/NasnetMobile/2018-11-01_acc_best.pth',60)

[+] loading model... Done
min_loss is :0.395555
val_correct is 0.857678
Epoch 19/59
----------
begin lr is  0.0001
[epoch:19,batch:29]:acc: 0.865625,loss:0.311833
[epoch:19,batch:59]:acc: 0.877604,loss:0.296242
[epoch:19,batch:89]:acc: 0.878819,loss:0.290369
[epoch:19,batch:119]:acc: 0.876302,loss:0.290033
[epoch:19,batch:149]:acc: 0.876042,loss:0.288834
[epoch:19,batch:179]:acc: 0.877431,loss:0.288586
[epoch:19,batch:209]:acc: 0.879762,loss:0.282727
[epoch:19,batch:239]:acc: 0.878255,loss:0.283751
[epoch:19,batch:269]:acc: 0.877083,loss:0.284995
[epoch:19,batch:299]:acc: 0.875625,loss:0.286338
[epoch:19,batch:299]: val_loss:0.370539,val_acc:0.855695,val_total:4539
save new model loss,now loss is  0.3705390393733978
[epoch:19,batch:329]:acc: 0.876042,loss:0.285321
[epoch:19,batch:359]:acc: 0.876389,loss:0.284622
[epoch:19,batch:389]:acc: 0.877404,loss:0.283591
[epoch:19,batch:419]:acc: 0.877902,loss:0.281815
[epoch:19,batch:449]:acc: 0.877986,loss:0.281527
[epoch:19,batch:479]:acc: 0.8

[epoch:22,batch:899]:acc: 0.895382,loss:0.247747
[epoch:22,batch:899]: val_loss:0.360367,val_acc:0.863186,val_total:4539
save new model acc,now acc is  tensor(0.8632, device='cuda:0')
[epoch:22,batch:929]:acc: 0.895262,loss:0.248014
[epoch:22,batch:959]:acc: 0.895052,loss:0.247616
[epoch:22,batch:989]:acc: 0.894729,loss:0.248430
[epoch:22] :acc: 0.894725,loss:0.249884,lr:0.000100,patience:0
[epoch:22]: val_loss:0.363611,val_acc:0.862965,
Epoch 23/59
----------
[epoch:23,batch:29]:acc: 0.909375,loss:0.232563
[epoch:23,batch:59]:acc: 0.902604,loss:0.237394
[epoch:23,batch:89]:acc: 0.899653,loss:0.246408
[epoch:23,batch:119]:acc: 0.898177,loss:0.241960
[epoch:23,batch:149]:acc: 0.897917,loss:0.242914
[epoch:23,batch:179]:acc: 0.898438,loss:0.244486
[epoch:23,batch:209]:acc: 0.898661,loss:0.243714
[epoch:23,batch:239]:acc: 0.899089,loss:0.244187
[epoch:23,batch:269]:acc: 0.899306,loss:0.241106
[epoch:23,batch:299]:acc: 0.898125,loss:0.241046
[epoch:23,batch:299]: val_loss:0.362028,val_acc:

[epoch:26,batch:959]:acc: 0.890755,loss:0.253670
[epoch:26,batch:989]:acc: 0.890751,loss:0.253879
[epoch:26] :acc: 0.890753,loss:0.255464,lr:0.000020,patience:0
[epoch:26]: val_loss:0.369189,val_acc:0.860101,
Epoch 27/59
----------
[epoch:27,batch:29]:acc: 0.895833,loss:0.234334
[epoch:27,batch:59]:acc: 0.896875,loss:0.235567
[epoch:27,batch:89]:acc: 0.891667,loss:0.246647
[epoch:27,batch:119]:acc: 0.891667,loss:0.249157
[epoch:27,batch:149]:acc: 0.892292,loss:0.249597
[epoch:27,batch:179]:acc: 0.892535,loss:0.250566
[epoch:27,batch:209]:acc: 0.891667,loss:0.250767
[epoch:27,batch:239]:acc: 0.890755,loss:0.252361
[epoch:27,batch:269]:acc: 0.890162,loss:0.252783
[epoch:27,batch:299]:acc: 0.890312,loss:0.252026
[epoch:27,batch:299]: val_loss:0.361916,val_acc:0.862304,val_total:4539
[epoch:27,batch:329]:acc: 0.890057,loss:0.253942
[epoch:27,batch:359]:acc: 0.891406,loss:0.251570
[epoch:27,batch:389]:acc: 0.891827,loss:0.251546
[epoch:27,batch:419]:acc: 0.891443,loss:0.252218
[epoch:27,bat

[epoch:31,batch:59]:acc: 0.886458,loss:0.258226
[epoch:31,batch:89]:acc: 0.883681,loss:0.266835
[epoch:31,batch:119]:acc: 0.889583,loss:0.259451
[epoch:31,batch:149]:acc: 0.892500,loss:0.253351
[epoch:31,batch:179]:acc: 0.891146,loss:0.254871
[epoch:31,batch:209]:acc: 0.890179,loss:0.254385
[epoch:31,batch:239]:acc: 0.891797,loss:0.251262
[epoch:31,batch:269]:acc: 0.893056,loss:0.248413
[epoch:31,batch:299]:acc: 0.891458,loss:0.251650
[epoch:31,batch:299]: val_loss:0.361629,val_acc:0.864948,val_total:4539
[epoch:31,batch:329]:acc: 0.892235,loss:0.250076
[epoch:31,batch:359]:acc: 0.891233,loss:0.251308
[epoch:31,batch:389]:acc: 0.889984,loss:0.253836
[epoch:31,batch:419]:acc: 0.890923,loss:0.252119
[epoch:31,batch:449]:acc: 0.891528,loss:0.250367
[epoch:31,batch:479]:acc: 0.891536,loss:0.249474
[epoch:31,batch:509]:acc: 0.893076,loss:0.246739
[epoch:31,batch:539]:acc: 0.893519,loss:0.246331
[epoch:31,batch:569]:acc: 0.892050,loss:0.249392
[epoch:31,batch:599]:acc: 0.892344,loss:0.249423

[epoch:35,batch:239]:acc: 0.890885,loss:0.254772
[epoch:35,batch:269]:acc: 0.890278,loss:0.256499
[epoch:35,batch:299]:acc: 0.890521,loss:0.254855
[epoch:35,batch:299]: val_loss:0.361265,val_acc:0.862965,val_total:4539
[epoch:35,batch:329]:acc: 0.890625,loss:0.254049
[epoch:35,batch:359]:acc: 0.891406,loss:0.252584
[epoch:35,batch:389]:acc: 0.891587,loss:0.252822
[epoch:35,batch:419]:acc: 0.891815,loss:0.252713
[epoch:35,batch:449]:acc: 0.891667,loss:0.252892
[epoch:35,batch:479]:acc: 0.890495,loss:0.254079
[epoch:35,batch:509]:acc: 0.890564,loss:0.253017
[epoch:35,batch:539]:acc: 0.889525,loss:0.254952
[epoch:35,batch:569]:acc: 0.889090,loss:0.255538
[epoch:35,batch:599]:acc: 0.889010,loss:0.254361
[epoch:35,batch:599]: val_loss:0.360334,val_acc:0.864067,val_total:4539
[epoch:35,batch:629]:acc: 0.888542,loss:0.254251
[epoch:35,batch:659]:acc: 0.888400,loss:0.254292
[epoch:35,batch:689]:acc: 0.888089,loss:0.255411
[epoch:35,batch:719]:acc: 0.888238,loss:0.255583
[epoch:35,batch:749]:ac