In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import CropModels
import utils
from augmentation import five_crops, HorizontalFlip, make_transforms
from CropDataset import MyDataSet, preprocess, preprocess_hflip, normalize_05, normalize_torch
import random
import numpy as np
import os
import pandas as pd
import torch.nn.functional as F
from scipy.stats.mstats import gmean
# Global configuration: number of output classes and reproducibility seed.
NB_CLASS = 59  # number of output classes every model is built with
SEED = 888

# Restrict PyTorch to one GPU *before* any torch.cuda call. CUDA reads this variable only
# when it is first initialised, so setting it after a CUDA call would silently have no
# effect; ideally it would even precede `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

# Seed every RNG in play: Python, NumPy, PyTorch CPU and all visible GPUs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Let cuDNN autotune convolution algorithms for the fixed input sizes used here.
# Note: this may make GPU results non-deterministic despite the seeds above.
torch.backends.cudnn.benchmark = True

In [2]:

# Data locations and the train/validation annotation tables.
BATCH_SIZE = 32
IMAGE_TRAIN_PRE = '../data/AgriculturalDisease_trainingset/images/'
# "deleteNoise" variant of the annotations (open question: whether the two anomalous classes should be removed).
ANNOTATION_TRAIN = '../data/AgriculturalDisease_trainingset/AgriculturalDisease_train_annotations_deleteNoise.json'
IMAGE_VAL_PRE = '../data/AgriculturalDisease_validationset/images/'
# "deleteNoise" variant of the annotations (open question: whether the two anomalous classes should be removed).
ANNOTATION_VAL = '../data/AgriculturalDisease_validationset/AgriculturalDisease_validation_annotations_deleteNoise.json'
IMAGE_TEST_PRE = '../data/AgriculturalDisease_testA/images/'
ANNOTATION_TEST = '../data/AgriculturalDisease_testA/AgriculturalDisease_test_annotations.json'

# Load the annotation lists (JSON arrays of records) into DataFrames.
# JSON is UTF-8 by specification; be explicit so the platform default encoding
# (e.g. cp1252/gbk on Windows) cannot break decoding.
with open(ANNOTATION_TRAIN, encoding='utf-8') as train_file:
    trainDataFram = pd.read_json(train_file, orient='records')
with open(ANNOTATION_VAL, encoding='utf-8') as val_file:
    validateDataFram = pd.read_json(val_file, orient='records')
def get_model(model_class):
    """Construct a ``model_class`` network sized for ``NB_CLASS`` outputs and move it to the GPU.

    Args:
        model_class: callable invoked as ``model_class(NB_CLASS)``.

    Returns:
        The constructed model, after ``.cuda()`` has been called on it.
    """
    print('[+] loading model... ', end='', flush=True)
    network = model_class(NB_CLASS)
    # For an nn.Module this moves the parameters in place; the return value is not needed.
    network.cuda()
    print('done')
    return network

In [3]:
def predictAll(model_name, model_class, weight_pth, image_size, normalize):
    """Predict the validation set with the full TTA set and save the result.

    TTA views: plain preprocessing, horizontal flip, five crops of an
    ``(image_size + 20)``-square resize, and the same five crops horizontally
    flipped. The total number of views is printed.

    Args:
        model_name: sub-directory of ``../feature`` that receives the output.
        model_class: callable taking the number of classes and returning a model.
        weight_pth: checkpoint path; its ``'state_dict'`` entry is loaded.
        image_size: network input side length.
        normalize: normalisation transform applied after ``ToTensor``.

    Writes ``../feature/<model_name>/val_all_prediction.pth``: a dict with the CPU
    tensors ``'lx'`` and ``'px'`` returned by ``utils.predict_tta`` (presumably
    labels and per-view logits).
    """
    print(f'[+] predict {model_name}')
    model = get_model(model_class)
    # Deserialise onto the CPU: a checkpoint saved on another GPU index would otherwise
    # fail to load when CUDA_VISIBLE_DEVICES exposes a single device.
    # load_state_dict copies the values into the model's (already CUDA) parameters.
    model.load_state_dict(torch.load(weight_pth, map_location='cpu')['state_dict'])
    model.eval()
    print('load state dict done')

    tta_preprocess = [preprocess(normalize, image_size), preprocess_hflip(normalize, image_size)]
    tta_preprocess += make_transforms([transforms.Resize((image_size + 20, image_size + 20))],
                                      [transforms.ToTensor(), normalize],
                                      five_crops(image_size))
    tta_preprocess += make_transforms([transforms.Resize((image_size + 20, image_size + 20))],
                                      [HorizontalFlip(), transforms.ToTensor(), normalize],
                                      five_crops(image_size))
    print(f'[+] tta size: {len(tta_preprocess)}')

    # One loader per TTA view; shuffle=False keeps every loader in annotation order.
    data_loaders = []
    for transform in tta_preprocess:
        test_dataset = MyDataSet(json_Description=ANNOTATION_VAL, transform=transform, path_pre=IMAGE_VAL_PRE)
        data_loaders.append(DataLoader(dataset=test_dataset, num_workers=16,
                                       batch_size=BATCH_SIZE,
                                       shuffle=False))
        print('add transforms')

    # Inference only: no autograd graph is needed (harmless if utils already disables it).
    with torch.no_grad():
        lx, px = utils.predict_tta(model, data_loaders)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    out_dir = os.path.join('../feature', model_name)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data, os.path.join(out_dir, 'val_all_prediction.pth'))
    print('Done')


In [4]:
def predictFlip(model_name, model_class, weight_pth, image_size, normalize):
    """Predict the validation set with flip TTA (original + horizontal flip) and save it.

    Args:
        model_name: sub-directory of ``../feature`` that receives the output.
        model_class: callable taking the number of classes and returning a model.
        weight_pth: checkpoint path; its ``'state_dict'`` entry is loaded.
        image_size: network input side length.
        normalize: normalisation transform applied after ``ToTensor``.

    Writes ``../feature/<model_name>/val_flip_prediction.pth``: a dict with the CPU
    tensors ``'lx'`` and ``'px'`` returned by ``utils.predict_tta`` (presumably
    labels and per-view logits). Crop-based views are produced by ``predictCrop``
    and ``predictAll`` instead.
    """
    print(f'[+] predict {model_name}')
    model = get_model(model_class)
    # Deserialise onto the CPU: a checkpoint saved on another GPU index would otherwise
    # fail to load when CUDA_VISIBLE_DEVICES exposes a single device.
    model.load_state_dict(torch.load(weight_pth, map_location='cpu')['state_dict'])
    model.eval()
    print('load state dict done')

    tta_preprocess = [preprocess(normalize, image_size), preprocess_hflip(normalize, image_size)]
    print(f'[+] tta size: {len(tta_preprocess)}')

    # One loader per TTA view; shuffle=False keeps every loader in annotation order.
    data_loaders = []
    for transform in tta_preprocess:
        test_dataset = MyDataSet(json_Description=ANNOTATION_VAL, transform=transform, path_pre=IMAGE_VAL_PRE)
        data_loaders.append(DataLoader(dataset=test_dataset, num_workers=16,
                                       batch_size=BATCH_SIZE,
                                       shuffle=False))
        print('add transforms')

    # Inference only: no autograd graph is needed (harmless if utils already disables it).
    with torch.no_grad():
        lx, px = utils.predict_tta(model, data_loaders)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    out_dir = os.path.join('../feature', model_name)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data, os.path.join(out_dir, 'val_flip_prediction.pth'))
    print('Done')

In [5]:
def predictCrop(model_name, model_class, weight_pth, image_size, normalize):
    """Predict the validation set with crop TTA and save the result.

    TTA views: plain preprocessing plus five crops of an ``(image_size + 20)``-square
    resize (the flipped crops are only used by ``predictAll``).

    Args:
        model_name: sub-directory of ``../feature`` that receives the output.
        model_class: callable taking the number of classes and returning a model.
        weight_pth: checkpoint path; its ``'state_dict'`` entry is loaded.
        image_size: network input side length.
        normalize: normalisation transform applied after ``ToTensor``.

    Writes ``../feature/<model_name>/val_crop_prediction.pth``: a dict with the CPU
    tensors ``'lx'`` and ``'px'`` returned by ``utils.predict_tta`` (presumably
    labels and per-view logits).
    """
    print(f'[+] predict {model_name}')
    model = get_model(model_class)
    # Deserialise onto the CPU: a checkpoint saved on another GPU index would otherwise
    # fail to load when CUDA_VISIBLE_DEVICES exposes a single device.
    model.load_state_dict(torch.load(weight_pth, map_location='cpu')['state_dict'])
    model.eval()
    print('load state dict done')

    tta_preprocess = [preprocess(normalize, image_size)]
    tta_preprocess += make_transforms([transforms.Resize((image_size + 20, image_size + 20))],
                                      [transforms.ToTensor(), normalize],
                                      five_crops(image_size))
    print(f'[+] tta size: {len(tta_preprocess)}')

    # One loader per TTA view; shuffle=False keeps every loader in annotation order.
    data_loaders = []
    for transform in tta_preprocess:
        test_dataset = MyDataSet(json_Description=ANNOTATION_VAL, transform=transform, path_pre=IMAGE_VAL_PRE)
        data_loaders.append(DataLoader(dataset=test_dataset, num_workers=16,
                                       batch_size=BATCH_SIZE,
                                       shuffle=False))
        print('add transforms')

    # Inference only: no autograd graph is needed (harmless if utils already disables it).
    with torch.no_grad():
        lx, px = utils.predict_tta(model, data_loaders)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    out_dir = os.path.join('../feature', model_name)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data, os.path.join(out_dir, 'val_crop_prediction.pth'))
    print('Done')

In [6]:
def predict_raw(model_name, model_class, weight_pth, image_size, normalize):
    """Predict the validation set with a single un-augmented view and save the result.

    Args:
        model_name: sub-directory of ``../feature`` that receives the output.
        model_class: callable taking the number of classes and returning a model.
        weight_pth: checkpoint path; its ``'state_dict'`` entry is loaded.
        image_size: network input side length.
        normalize: normalisation transform applied after ``ToTensor``.

    Writes ``../feature/<model_name>/val_raw_prediction.pth``: a dict with the CPU
    tensors ``'lx'`` and ``'px'`` returned by ``utils.predict`` (presumably labels
    and logits).
    """
    print(f'[+] predict {model_name}')
    model = get_model(model_class)
    # Deserialise onto the CPU: a checkpoint saved on another GPU index would otherwise
    # fail to load when CUDA_VISIBLE_DEVICES exposes a single device.
    model.load_state_dict(torch.load(weight_pth, map_location='cpu')['state_dict'])
    model.eval()
    print('load state dict done')

    test_dataset = MyDataSet(json_Description=ANNOTATION_VAL,
                             transform=preprocess(normalize, image_size),
                             path_pre=IMAGE_VAL_PRE)
    # shuffle=False keeps predictions in annotation order, aligned with the label table.
    data_loader = DataLoader(dataset=test_dataset, num_workers=16,
                             batch_size=BATCH_SIZE,
                             shuffle=False)
    # Inference only: no autograd graph is needed (harmless if utils already disables it).
    with torch.no_grad():
        lx, px = utils.predict(model, data_loader)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    out_dir = os.path.join('../feature', model_name)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data, os.path.join(out_dir, 'val_raw_prediction.pth'))
    print('Done')

    

In [7]:
def predict_all(model_names=('DesNet121',)):
    """Run raw, full-TTA, flip and crop validation predictions for the selected models.

    For every selected model this calls ``predict_raw``, ``predictAll``,
    ``predictFlip`` and ``predictCrop``, which write their feature files under
    ``../feature/<name>/``, then prints ``'<name> done'``.

    Args:
        model_names: keys of ``registry`` below to process, in order. The default
            ``('DesNet121',)`` is the only model the hand-written version ran; the
            other entries reproduce its previously commented-out calls.

    Raises:
        KeyError: if a name is not in the registry.
    """
    # name -> (attribute of CropModels, checkpoint, input size, normalisation).
    # Classes are stored by attribute name and resolved only for selected models, so
    # entries for architectures that are not run are never touched.
    registry = {
        'Resnet50': ('resnet50_finetune', '../model/ResNet50/2018-11-01_loss_best.pth', 224, normalize_torch),
        'Resnet101': ('resnet101_finetune', '../model/ResNet101/2018-11-04_acc_best.pth', 224, normalize_torch),
        'DesNet121': ('densenet121_finetune', '../model/DesNet121/2018-11-04_acc_best.pth', 224, normalize_torch),
        'Resnet152': ('resnet152_finetune', '../model/ResNet/2018-10-31_acc_best.pth', 224, normalize_torch),
        'DesNet201': ('densenet201_finetune', '../model/DesNet201/2018-11-01_acc_best.pth', 224, normalize_torch),
        'DesNet161': ('densenet161_finetune', '../model/DesNet161/2018-11-02_acc_best.pth', 224, normalize_torch),
        'Nasnetmobile': ('nasnetmobile', '../model/NasnetMobile/2018-11-02_acc_best.pth', 224, normalize_05),
        'InceptionV4': ('inceptionv4_finetune', '../model/InceptionV4/2018-11-01_acc_best.pth', 299, normalize_05),
        'InceptionV3': ('InceptionV3Finetune', '../model/InceptionV3/2018-11-03_acc_best.pth', 299, normalize_05),
        'Xception': ('xception_finetune', '../model/Xception/2018-11-03_acc_best.pth', 299, normalize_torch),
        'InceptionResnet': ('inceptionresnetv2_finetune', '../model/Inception_Resnet/2018-11-02_acc_best.pth', 299, normalize_05),
    }
    # Models whose raw (single-view) pass used a different checkpoint than the TTA passes.
    raw_weight_overrides = {
        'Resnet50': '../model/ResNet50/2018-11-01_acc_best.pth',
    }

    for name in model_names:
        if name not in registry:
            raise KeyError(f'unknown model {name!r}; choose from {sorted(registry)}')
        class_attr, weight_pth, image_size, normalize = registry[name]
        model_class = getattr(CropModels, class_attr)
        raw_weight_pth = raw_weight_overrides.get(name, weight_pth)

        predict_raw(name, model_class, raw_weight_pth, image_size, normalize)
        predictAll(name, model_class, weight_pth, image_size, normalize)
        predictFlip(name, model_class, weight_pth, image_size, normalize)
        predictCrop(name, model_class, weight_pth, image_size, normalize)
        print(f'{name} done')

In [13]:

# Validation accuracy of DesNet121's raw (single-view) predictions, before and after
# calibration with utils.calibrate_probs.
predict = torch.load('../feature/DesNet121/val_raw_prediction.pth')
val_Predict = predict['px']
predictSoftMax = F.softmax(val_Predict, dim=1)
_, px = torch.max(predictSoftMax, dim=1)
print(torch.mean((predict['lx'] == px).float()))
# Use the configured class count instead of a hard-coded 59.
scoreAfter = utils.calibrate_probs(trainDataFram, validateDataFram, predictSoftMax, NB_CLASS)
val_preAfter = np.argmax(scoreAfter, axis=1)
accurarysAfterCal = np.mean(val_preAfter == predict['lx'].numpy())
print('accurarysAfterCal is', accurarysAfterCal)



tensor(0.8733)
accurarysAfterCal is 0.8730998017184401


In [16]:
import torch.nn.functional as F
from scipy.stats.mstats import gmean
# Validation accuracy of DesNet121 with flip TTA: softmax per view, geometric mean over
# the views, then calibration with utils.calibrate_probs2.
predict = torch.load('../feature/DesNet121/val_flip_prediction.pth')
predictScore = predict['px']  # presumably (N, NB_CLASS, n_views) — TODO confirm
predictSoftMax = F.softmax(predictScore, dim=1)
# Hand scipy a NumPy array explicitly rather than relying on tensor->array coercion.
# (np.mean over axis 2, i.e. an arithmetic mean, is the alternative that was tried.)
val_prob = gmean(predictSoftMax.numpy(), axis=2)
val_pred = np.argmax(val_prob, axis=1)
print(np.mean(val_pred == predict['lx'].numpy()))

# Use the configured class count instead of a hard-coded 59.
scoreAfter = utils.calibrate_probs2(trainDataFram, validateDataFram, val_prob, NB_CLASS)
val_preAfter = np.argmax(scoreAfter, axis=1)
accurarysAfterCal = np.mean(val_preAfter == predict['lx'].numpy())
print('accurarysAfterCal is', accurarysAfterCal)

0.8735404274069178
accurarysAfterCal is 0.8735404274069178


In [92]:
# Validation results of the multi-model ensemble.
label = validateDataFram['disease_class'].values

# (feature sub-directory, prediction file, logit multiplier), in concatenation order.
# NOTE(review): the multiplier scales the *logits* before the softmax, which sharpens that
# model's probabilities (temperature 1/3) rather than giving it 3x weight in the mean;
# if a 3x ensemble weight was intended, weight the probabilities instead.
# NOTE(review): DesNet201 contributes its crop-TTA file, which has several view columns
# (it was never unsqueezed), so the mean over axis 2 weights it proportionally more than
# the single-view raw models — confirm this is intended.
# InceptionResnet and Xception raw predictions exist but were not part of this ensemble.
VAL_ENSEMBLE = [
    ('Resnet50', 'val_raw_prediction.pth', 1),
    ('Resnet101', 'val_raw_prediction.pth', 3),
    ('Resnet152', 'val_raw_prediction.pth', 1),
    ('DesNet121', 'val_raw_prediction.pth', 1),
    ('DesNet161', 'val_raw_prediction.pth', 3),
    ('DesNet201', 'val_crop_prediction.pth', 1),
    ('Nasnetmobile', 'val_raw_prediction.pth', 1),
    ('InceptionV3', 'val_raw_prediction.pth', 1),
    ('InceptionV4', 'val_raw_prediction.pth', 1),
]

val_logits = []
for model_dir, file_name, logit_scale in VAL_ENSEMBLE:
    model_px = torch.load(os.path.join('../feature', model_dir, file_name))['px']
    if model_px.dim() == 2:
        # Raw predictions have no view axis; add a singleton one so all models concatenate.
        model_px = torch.unsqueeze(model_px, 2)
    if logit_scale != 1:
        model_px = logit_scale * model_px
    val_logits.append(model_px)

# Softmax over classes per view, then arithmetic mean over all views of all models.
val_prob = F.softmax(torch.cat(val_logits, dim=2), dim=1).numpy()
val_prob = np.mean(val_prob, axis=2)
val_predict = np.argmax(val_prob, axis=1)
print(np.mean(val_predict == label))
val_prob_after = utils.calibrate_probs2(trainDataFram, validateDataFram, val_prob, NB_CLASS)
val_predict_after = np.argmax(val_prob_after, axis=1)
print(np.mean(val_predict_after == label))

0.8830138797091871
0.8821326283322317


In [20]:
import json
def getPredictJson(image_dir='../data/AgriculturalDisease_testA/images/',
                   out_path='../data/AgriculturalDisease_testA/AgriculturalDisease_test_annotations.json'):
    """Write a placeholder annotation file listing every test image.

    Each entry of ``image_dir`` gets a record ``{"image_id": <file name>,
    "disease_class": -1}`` (-1 as a placeholder label), in ``os.listdir`` order.
    Presumably this lets the test images be loaded through the same
    annotation-driven dataset as train/val.

    Args:
        image_dir: directory whose entries are listed (default: the testA image
            directory, same path as IMAGE_TEST_PRE).
        out_path: JSON file to write (default: same path as ANNOTATION_TEST).
    """
    result = [{"image_id": img_name, "disease_class": -1} for img_name in os.listdir(image_dir)]
    # ensure_ascii=False writes non-ASCII file names verbatim, so pin the encoding instead
    # of depending on the platform default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False)

In [None]:
# Generate the submission file (result.json) from the test-set predictions.
import json

# (feature sub-directory, prediction file) for each model, in concatenation order.
# *_all / *_crop files are TTA outputs with a view axis; *_raw files are 2-D and get a
# singleton view axis below.
TEST_ENSEMBLE = [
    ('Resnet50', 'test_all_prediction.pth'),
    ('Resnet152', 'test_all_prediction.pth'),
    ('DesNet201', 'test_crop_prediction.pth'),
    ('DesNet161', 'test_all_prediction.pth'),
    ('InceptionResnet', 'test_raw_prediction.pth'),
    ('InceptionV3', 'test_raw_prediction.pth'),
    ('Xception', 'test_raw_prediction.pth'),
]

test_logits = []
for model_dir, file_name in TEST_ENSEMBLE:
    model_px = torch.load(os.path.join('../feature', model_dir, file_name))['px']
    if model_px.dim() == 2:
        model_px = torch.unsqueeze(model_px, 2)
    test_logits.append(model_px)

# Softmax over classes per view, then geometric mean over all views of all models.
test_prob = F.softmax(torch.cat(test_logits, dim=2), dim=1).numpy()
test_prob = gmean(test_prob, axis=2)
test_predict = np.argmax(test_prob, axis=1)
print(len(test_predict))

# Image ids in the order the predictions were made. NOTE(review): assumes the test
# feature files came from a MyDataSet built on ANNOTATION_TEST (as the validation ones
# are built on ANNOTATION_VAL); re-listing the image directory would instead rely on
# os.listdir returning the same, OS-defined order twice.
with open(ANNOTATION_TEST, encoding='utf-8') as f:
    test_image_ids = [record['image_id'] for record in json.load(f)]
assert len(test_image_ids) == len(test_predict), (
    f'{len(test_image_ids)} test images but {len(test_predict)} predictions')

# The models output NB_CLASS labels; indices >= 44 are shifted up by 2, presumably
# restoring the original label ids after the two noise classes (44, 45) were removed.
REMOVED_CLASS_START, N_REMOVED_CLASSES = 44, 2
submission_classes = np.where(test_predict >= REMOVED_CLASS_START,
                              test_predict + N_REMOVED_CLASSES,
                              test_predict)

result = [{"image_id": img_name, "disease_class": int(cls)}
          for img_name, cls in zip(test_image_ids, submission_classes)]
with open("result.json", 'w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False)