In [89]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
from tqdm import tqdm
import cv2

In [90]:
# Class index -> human-readable label for the 61-way classifier (keys 0-60,
# matching num_classes / the final nn.Linear output size used below).
# Each label is "<English name> <severity>" followed by the Chinese name in
# full-width parentheses; "general"/"serious" (一般/严重) mark severity grades.
# Not referenced by any visible cell; kept for interpreting predicted ids.
# NOTE(review): strings are verbatim (including the "heathy" typo and the
# irregular double spaces) - presumably copied from the dataset's class list;
# do not normalize without checking whatever consumes them.
label_names = {
    0:"apple healthy（苹果健康）",
    1:"apple_Scab general（苹果黑星病一般）",
    2:"apple_Scab serious（苹果黑星病严重）",
    3:"apple Frogeye Spot（苹果灰斑病）",
    4:"Cedar Apple Rust  general（苹果雪松锈病一般）",
    5:"Cedar Apple Rust serious（苹果雪松锈病严重）",
    6:"Cherry healthy（樱桃健康）",
    7:"Cherry_Powdery Mildew  general（樱桃白粉病一般）",
    8:"Cherry_Powdery Mildew  serious（樱桃白粉病严重）",
    9:"Corn healthy（玉米健康）",
    10:"Cercospora zeaemaydis Tehon and Daniels general（玉米灰斑病一般）",
    11:"Cercospora zeaemaydis Tehon and Daniels  serious（玉米灰斑病严重）",
    12:"Puccinia polysora  general（玉米锈病一般）",
    13:"Puccinia polysora serious（玉米锈病严重）",
    14:"Corn Curvularia leaf spot fungus general（玉米叶斑病一般）",
    15:"Corn Curvularia leaf spot fungus  serious（玉米叶斑病严重）",
    16:"Maize dwarf mosaic virus（玉米花叶病毒病）",
    17:"Grape heathy（葡萄健康）",
    18:"Grape Black Rot Fungus general（葡萄黑腐病一般）",
    19:"Grape Black Rot Fungus serious（葡萄黑腐病严重）",
    20:"Grape Black Measles Fungus general（葡萄轮斑病一般）",
    21:"Grape Black Measles Fungus serious（葡萄轮斑病严重）",
    22:"Grape Leaf Blight Fungus general（葡萄褐斑病一般）",
    23:"Grape Leaf Blight Fungus  serious（葡萄褐斑病严重）",
    24:"Citrus healthy（柑桔健康）",
    25:"Citrus Greening June  general（柑桔黄龙病一般）",
    26:"Citrus Greening June  serious（柑桔黄龙病严重）",
    27:"Peach healthy（桃健康）",
    28:"Peach_Bacterial Spot general（桃疮痂病一般）",
    29:"Peach_Bacterial Spot  serious（桃疮痂病严重）",
    30:"Pepper healthy（辣椒健康）",
    31:"Pepper scab general（辣椒疮痂病一般）",
    32:"Pepper scab  serious（辣椒疮痂病严重）",
    33:"Potato healthy（马铃薯健康）",
    34:"Potato_Early Blight Fungus general（马铃薯早疫病一般）",
    35:"Potato_Early Blight Fungus serious（马铃薯早疫病严重）",
    36:"Potato_Late Blight Fungus general（马铃薯晚疫病一般）",
    37:"Potato_Late Blight Fungus  serious（马铃薯晚疫病严重）",
    38:"Strawberry healthy（草莓健康）",
    39:"Strawberry_Scorch general（草莓叶枯病一般）",
    40:"Strawberry_Scorch serious（草莓叶枯病严重）",
    41:"tomato healthy（番茄健康）",
    42:"tomato powdery mildew  general（番茄白粉病一般）",
    43:"tomato powdery mildew  serious（番茄白粉病严重）",
    44:"tomato Bacterial Spot Bacteria general（番茄疮痂病一般）",
    45:"tomato Bacterial Spot Bacteria  serious（番茄疮痂病严重）",
    46:"tomato_Early Blight Fungus general（番茄早疫病一般）",
    47:"tomato_Early Blight Fungus  serious（番茄早疫病严重）",
    48:"tomato_Late Blight Water Mold  general（番茄晚疫病菌一般）",
    49:"tomato_Late Blight Water Mold serious（番茄晚疫病菌严重）",
    50:"tomato_Leaf Mold Fungus general（番茄叶霉病一般）",
    51:"tomato_Leaf Mold Fungus serious（番茄叶霉病严重）",
    52:"tomato Target Spot Bacteria  general（番茄斑点病一般）",
    53:"tomato Target Spot Bacteria  serious（番茄斑点病严重）",
    54:"tomato_Septoria Leaf Spot Fungus  general（番茄斑枯病一般）",
    55:"tomato_Septoria Leaf Spot Fungus  serious（番茄斑枯病严重）",
    56:"tomato Spider Mite Damage general（番茄红蜘蛛损伤一般）",
    57:"tomato Spider Mite Damage serious（番茄红蜘蛛损伤严重）",
    58:"tomato YLCV Virus general（番茄黄化曲叶病毒病一般）",
    59:"tomato YLCV Virus  serious（番茄黄化曲叶病毒病严重）",
    60:"tomato Tomv（番茄花叶病毒病）"
}

In [91]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# 1. Seed every RNG in play: Python, NumPy, torch CPU and all CUDA devices.
import random
# `seed` may already have been set by another cell (In[58]) before this one is
# re-run for a further ensemble member. On a fresh kernel it does not exist yet,
# so fall back to the same default instead of raising NameError.
seed = globals().get('seed', 250)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Ask cuDNN for deterministic kernels. Benchmark mode must then be OFF: its
# autotuner may choose different (non-deterministic) algorithms between runs,
# which defeats the point of `deterministic = True`.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [92]:
# Inference settings: run on CUDA (the model is wrapped in DataParallel in a
# later cell) and predict one of 61 classes (keys 0-60 of label_names).
use_gpu, num_classes = True, 61

# DenseNet-201 architecture only; the fine-tuned weights are loaded from the
# checkpoint in a later cell.
model = models.densenet201()



In [93]:
# Freeze the whole network, then re-enable gradients for the tail of the
# feature extractor: dense block 3, transition 3, dense block 4 and the final
# BatchNorm (norm5). requires_grad does not change forward outputs, so this
# has no effect on the predictions made in this notebook.
for param in model.parameters():
    param.requires_grad = False

for trainable in (model.features.denseblock3,
                  model.features.transition3,
                  model.features.denseblock4,
                  model.features.norm5):
    for param in trainable.parameters():
        param.requires_grad = True

In [94]:
# Replace the ImageNet head with dropout + a linear layer over our classes.
# Input width is read from the final BatchNorm (norm5, 1920 channels for
# DenseNet-201) instead of a hard-coded literal. Unlike
# `model.classifier.in_features`, norm5 survives re-running this cell.
# Output width uses num_classes from the config cell instead of a second
# literal 61.
model.classifier = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.features.norm5.num_features, num_classes),
)

In [95]:
# Use up to two GPUs as before, but never list ids that do not exist:
# DataParallel rejects unknown device ids, so [0, 1] broke single-GPU machines.
device_ids = [0, 1][:max(torch.cuda.device_count(), 1)]

if use_gpu:
    model = model.cuda(device_ids[0])
    model = nn.DataParallel(model, device_ids=device_ids)

# On the GPU path the checkpoint loads straight into the DataParallel wrapper,
# so its keys presumably carry the wrapper's "module." prefix. When the model
# is not wrapped, strip that prefix. map_location lets a checkpoint saved on
# GPU be loaded on CPU; None keeps the original GPU behaviour.
state_dict = torch.load('tuned-densenet.pth', map_location=None if use_gpu else 'cpu')
if not isinstance(model, nn.DataParallel):
    prefix = 'module.'
    state_dict = type(state_dict)(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
model.load_state_dict(state_dict)

In [97]:
path = './val/val/images/'

# Per-channel normalization statistics shared by both pipelines.
NORM_MEAN = [0.47954108864506007, 0.5295650244021952, 0.39169756009537665]
NORM_STD = [0.21481591229053462, 0.20095268035289796, 0.24845895286079178]

# Training-time augmentation (not used by any cell visible in this notebook).
trans_train = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Deterministic evaluation pipeline used for the test images.
trans_valid = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])



In [98]:
def get_prediction(model, loader, valid=False):
    """Run `model` over every batch in `loader`; return the argmax class per sample.

    Args:
        model: torch module (optionally wrapped in nn.DataParallel) returning
            per-class scores with classes on dim 1.
        loader: iterable of batches. Each batch is an input tensor, or an
            (inputs, labels) pair when `valid` is True.
        valid: True when the loader yields (inputs, labels); labels are ignored.

    Returns:
        1-D float64 numpy array of predicted class indices in loader order.
        The float dtype matches the previous np.append-based implementation;
        callers cast with `.astype(int)`.
    """
    # eval() recurses into child modules, so it also covers the module wrapped
    # by nn.DataParallel. The old `model.module.eval()` failed on a bare model.
    model.eval()
    # Send inputs to wherever the weights live (cuda:0 for DataParallel, which
    # expects inputs on device_ids[0]). Use CPU if the model has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    batch_predictions = []
    # Inference only: skip autograd bookkeeping. Several layers have
    # requires_grad=True, so without this every batch built a graph.
    with torch.no_grad():
        for data in loader:
            if valid:
                inputs, _ = data
            else:
                inputs = data
            print('.', end='')  # one progress dot per batch
            outputs = model(inputs.to(device))
            pred = torch.argmax(outputs, dim=1)
            batch_predictions.append(pred.cpu().numpy().reshape(-1))

    # Concatenate once instead of np.append per batch (which is quadratic).
    if not batch_predictions:
        return np.array([])
    return np.concatenate(batch_predictions).astype(np.float64)

In [103]:
class TestDataset(Dataset):
    """Unlabeled image dataset yielding one transformed image per file in `data_dir`.

    `image_names` lists the files in the order the dataset yields them, so
    predictions can be paired with file names by position.
    """

    def __init__(self, data_dir = './', transform=None):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform
        # Sorted: os.listdir order is arbitrary. A stable order keeps the
        # submission rows, and the index-aligned vote merging across runs,
        # reproducible.
        self.image_names = sorted(os.listdir(data_dir))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        img_name = self.image_names[index]
        img_path = os.path.join(self.data_dir, img_name)
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError('could not read image: {}'.format(img_path))
        # NOTE(review): cv2 loads channels in BGR order while the Normalize
        # statistics look RGB-ordered. Left as-is to match how the checkpoint
        # was presumably trained; confirm against the training pipeline.
        image = cv2.resize(image, (224, 224))
        if self.transform is not None:
            image = self.transform(image)
        return image

# Predict every unlabeled image under `path` and pair each prediction with
# its file name (one row per image).
dataset_test = TestDataset(transform=trans_valid, data_dir=path)
loader_test = DataLoader(dataset_test, batch_size=64, shuffle=False, num_workers=0)
test_prediction = get_prediction(model, loader_test)

predicted_classes = test_prediction.astype(int)
rows = list(zip(dataset_test.image_names, predicted_classes))
sub = pd.DataFrame(rows, columns=['image_id', 'disease_class'])

.

  segments_z = grid_z[slices]
  segments_y = grid_y[slices]
  segments_x = grid_x[slices]


......................................................................

In [58]:
# Seed read by the seeding cell when it is re-run for another ensemble member
# (this cell's execution count, In[58], precedes the seeding cell's In[91]).
seed = 250

In [57]:
# Collect this run's predictions for the majority vote. `sublist` is not
# defined in any visible cell, so create it on a fresh kernel instead of
# raising NameError. Entries from earlier runs are kept otherwise.
sublist = globals().get('sublist', [])
sublist.append(sub)

In [104]:
# Turn each row into a space-terminated vote string: this run's prediction
# followed by the prediction of every run stored in `sublist` (rows aligned
# by index), e.g. "3 3 7 ".
sub['disease_class'] = sub['disease_class'].map(str) + ' '
for previous in sublist:
    sub['disease_class'] = sub['disease_class'] + (previous['disease_class'].map(str) + ' ')

In [110]:
from collections import Counter


def _majority_vote(votes):
    """Return the most frequent class token in a space-terminated vote string.

    `votes` looks like "3 3 7 " (built by the previous cell). Ties go to the
    token seen first, i.e. the current run's prediction, so the result is
    reproducible. The old set-based loop broke ties in str-hash order, which
    changes between interpreter sessions.

    Raises:
        ValueError: if the string holds no votes.
    """
    ballots = votes.split(' ')[:-1]  # drop the empty item after the trailing space
    if not ballots:
        raise ValueError('empty vote string: {!r}'.format(votes))
    return Counter(ballots).most_common(1)[0][0]


# Collapse each row's votes to its majority class in a single column
# assignment. The former per-row `sub['disease_class'][i] = c` was chained
# assignment: it triggers SettingWithCopyWarning and is a silent no-op under
# pandas copy-on-write.
sub['disease_class'] = [int(_majority_vote(v)) for v in tqdm(sub['disease_class'])]

100%|██████████| 4540/4540 [00:00<00:00, 6399.82it/s]


In [113]:
# Write one JSON record per image: {"image_id": ..., "disease_class": ...}.
# force_ascii=False keeps any non-ASCII file names human-readable.
sub.to_json(path_or_buf='val5.json', force_ascii=False, orient='records')