In [1]:
!ls ../input/plant-pathology-2020-baseline-train/baseline-512/model.best

../input/plant-pathology-2020-baseline-train/baseline-512/model.best


In [2]:
import os
import torch

class Hparams():
    def __init__(self):

        self.cuda = True if torch.cuda.is_available() else False

        """
        Data Parameters
        """

        # os.makedirs('../input', exist_ok=True)
        os.makedirs('../model', exist_ok=True)
        os.makedirs('../results/', exist_ok=True)

        self.train_csv = '../input/plant-pathology-2020-fgvc7/train.csv'
        self.test_csv = '../input/plant-pathology-2020-fgvc7/test.csv'
        self.valid_csv = '../input/plant-pathology-2020-fgvc7/valid.csv'
        self.valid1_csv = '../input/plant-pathology-2020-fgvc7/valid1.csv'

        self.train_dir = '../input/plant-pathology-2020-fgvc7/images/'
        self.test_dir = '../input/plant-pathology-2020-fgvc7/images/'
        self.valid_dir = '../input/plant-pathology-2020-fgvc7/images/'

        """
        Model Parameters
        """

        os.makedirs('../model/', exist_ok=True)

        self.image_shape = (512, 512)
        self.num_channel = 3
        self.num_classes = 4

        self.id_to_class = {
            0: 'healthy',
            1: 'multiple_diseases',
            2: 'rust',
            3: 'scab',
        }

        self.weights = [1.0, 4.0, 1.0, 1.0]

        """
        Training parameters
        """

        self.gpu_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device_ids = [0]

        self.pretrained = True

        self.thresh = 0.5
        self.repeat_infer = 1

        self.num_epochs = 10
        self.batch_size = 2

        self.learning_rate = 0.0001

        self.momentum1 = 0.9
        self.momentum2 = 0.999

        self.avg_mode = 'macro'

        self.print_interval = 1000

        ################################################################################################################################################
        self.exp_name = 'baseline-512/'
        ################################################################################################################################################

        self.result_dir = '../results/'+self.exp_name
        os.makedirs(self.result_dir, exist_ok=True)

        self.model_dir = '../model/' + self.exp_name
        os.makedirs(self.model_dir, exist_ok=True)

        self.model = self.model_dir + 'model'


hparams = Hparams()


In [3]:
from __future__ import print_function, division
import os
import json
import csv
import torch
import random
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from PIL import ImageFilter

import code

class ChestData(Dataset):

  def __init__(self, data_csv, data_dir, transform=None, image_shape=hparams.image_shape, pre_process=None, ds_type=''):
        'Initialization'
        self.data_csv = data_csv
        self.data_dir = data_dir
        self.image_shape = hparams.image_shape
        self.ds_type = ds_type
        self.transform = transform
        self.pre_process = pre_process
        self.data_frame = pd.read_csv(data_csv)

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.data_frame)

  def __getitem__(self, index):
        'Generates one sample of data'

        img_name = os.path.join(self.data_dir,
                                self.data_frame.iloc[index, 0])+'.jpg'

        image = Image.open(img_name)
        
        if image.size[0] > 2000 and image.size[1] > 1000:
          image = transforms.CenterCrop((3*image.size[1]//4, 2*image.size[0]//3))(image)

        if self.transform:
            image = self.transform(image)

        return (image, self.data_frame.iloc[index, 0])


In [4]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from collections import OrderedDict
import torchvision

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = models.densenet121(pretrained=False)
        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
                                    nn.Linear(num_ftrs, hparams.num_classes),
                                    nn.Softmax(dim=1))

    def forward(self, x):
        x = self.model(x)
        return x


In [5]:
import time
import code
import os, torch, sys
import torch
import csv
from tqdm import tqdm
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from torchvision.utils import save_image
from torch.autograd import Variable
from torch import optim
from skimage.util import random_noise
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd

import sklearn.metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

epsilon = 0.0000000001

plt.switch_backend('agg')

def submit(model_path, data=(hparams.test_csv, hparams.test_dir)):

    test_dataset = ChestData(data_csv=data[0], data_dir=data[1],
                        transform=transforms.Compose([
                            transforms.Resize(hparams.image_shape),
                            transforms.ToTensor(),
                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                        ]))

    test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size,
                            shuffle=False, num_workers=4)


    discriminator = Discriminator().to(hparams.gpu_device)
    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids)
    checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    discriminator = discriminator.eval()
    print('Model loaded')

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits_list = []
        img_names_list = []
#         for _ in range(hparams.repeat_infer):
        for (img, img_names) in tqdm(test_loader):
            img = Variable(img.float(), requires_grad=False)
            img = img.to(hparams.gpu_device)
            pred_logits = discriminator(img)

            pred_logits_list.append(pred_logits)
            img_names_list += img_names

        pred_logits = torch.cat(pred_logits_list, dim=0)
    if hparams.cuda:
        pred_logits = pred_logits.cpu()
    
    df = pd.DataFrame({**{'image_id': img_names_list}, **{hparams.id_to_class[idx]: pred_logits[:, idx] for idx in range(4)}})
    df.to_csv('submission.csv', index=False)

In [6]:
submit('../input/plant-pathology-2020-cropped-center/center-crop/model.best')

  0%|          | 0/911 [00:00<?, ?it/s]

Model loaded
Testing model on 1821 examples. 


100%|██████████| 911/911 [01:35<00:00,  9.55it/s]


In [7]:
!cat submission.csv

image_id,healthy,multiple_diseases,rust,scab
Test_0,0.00016570599,0.055574305,0.9439996,0.00026036688
Test_1,8.844735e-07,0.0032374412,0.9964341,0.00032764857
Test_2,0.0007177882,0.00019906924,2.5448346e-05,0.9990577
Test_3,0.9969302,1.1741677e-06,0.0030076026,6.115764e-05
Test_4,9.293739e-10,1.2507239e-05,0.99998736,8.848574e-08
Test_5,0.9970565,0.0009404536,0.00017306085,0.0018299985
Test_6,0.99731064,5.889628e-05,0.00073140464,0.0018990996
Test_7,8.018397e-10,1.5756881e-06,2.9691396e-08,0.99999845
Test_8,9.420543e-07,0.7741122,0.00025661523,0.22563027
Test_9,0.00014481905,0.028478263,0.97084326,0.0005336532
Test_10,0.0016229881,0.013851241,0.9840688,0.0004569931
Test_11,0.9908987,0.00071853265,0.0015620596,0.0068206876
Test_12,4.715609e-05,0.0015939514,5.9048745e-07,0.9983583
Test_13,0.9963541,0.0018854673,0.000612389,0.0011480171
Test_14,0.004921194,0.04341824,0.9510176,0.0006429873
Test_15,0.019119682,0.015506332,0.9634154,0.0019586077
Test_16,0.98634386,0.0011456