In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import pandas as pd
import numpy as np
from fastai.vision import *
from fastai.metrics import error_rate
from sklearn.model_selection import KFold

torch.manual_seed(121)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(121)

In [3]:
df = pd.read_csv('../input/plant-pathology-2020-fgvc7/train.csv')
labels = np.array(df.iloc[:, 1:])
labels = np.where(labels==1)[1]

df['label'] = labels
df = df[['image_id', 'label']]
df.to_csv('./train_data.csv', index=False)

In [4]:
BS = 16
IMG_SZ = 512
# GPU_IDS = [0]
# torch.cuda.set_device(GPU_IDS[0])
# CLASS_WT = torch.FloatTensor([1.0, 1.0, 1.0, 1.0]).to('cuda:'+str(GPU_IDS[0]))

label_dict = {0: 'healthy',
              1: 'multiple_diseases',
              2: 'rust',
              3: 'scab',
              'micro': 'micro',
              'macro': 'macro',
             }

In [5]:
data_tfms = get_transforms(do_flip=True, flip_vert=True, max_rotate=0, max_zoom=1, max_lighting=None,
                          max_warp=0, p_affine=0, p_lighting=None)

def create_databunch(valid_idx):
    sub_csv = pd.read_csv('../input/plant-pathology-2020-fgvc7/sample_submission.csv')
    test = ImageList.from_df(sub_csv, path='../input/plant-pathology-2020-fgvc7', folder='images', suffix='.jpg')
    data = (ImageList.from_df(df, path='../input/plant-pathology-2020-fgvc7', folder='images', suffix='.jpg')
            .split_by_idx(valid_idx)
            .label_from_df()
            .add_test(test)
            .transform(data_tfms, size=IMG_SZ)
            .databunch(path='.', bs=BS)
            .normalize(imagenet_stats)
           )
    return data


In [6]:
from sklearn.metrics import roc_curve, auc, roc_auc_score, f1_score
class AUC(Callback):

    def __init__(self, num_cl, pick='micro'):
        self.id_to_class = label_dict
        self.name = str(self.id_to_class[pick])+'-AUC'
        self.pick = pick
        self.num_cl = num_cl
        
    
    def on_epoch_begin(self, **kwargs):
        self.outputs, self.targets = [], []
    
    def on_batch_end(self, last_output, last_target, **kwargs):
        self.outputs.append(last_output)
        self.targets.append(last_target)
    
    def on_epoch_end(self, last_metrics, **kwargs):
        self.outputs = F.softmax(torch.cat(self.outputs), dim=1).cpu().detach().numpy()
        self.targets = torch.cat(self.targets).cpu().numpy().reshape(-1)
        self.targets = np.eye(self.num_cl)[self.targets]
        fpr, tpr, roc_auc = {}, {}, {'macro':0}
        for i in range(self.num_cl):
            roc_auc[self.id_to_class[i]] = roc_auc_score(self.targets[:, i], self.outputs[:, i])
            roc_auc['macro'] += roc_auc[self.id_to_class[i]]
        roc_auc['micro'] = roc_auc_score(self.targets.ravel(), self.outputs.ravel())
        roc_auc['macro'] = roc_auc['macro']/self.num_cl
        return add_metrics(last_metrics, roc_auc[self.id_to_class[self.pick]])

class F1score(Callback):        
    
    def on_epoch_begin(self, **kwargs):
        self.outputs, self.targets = [], []
    
    def on_batch_end(self, last_output, last_target, **kwargs):
        self.outputs.append(last_output)
        self.targets.append(last_target)
    
    def on_epoch_end(self, last_metrics, **kwargs):
        self.outputs = F.softmax(torch.cat(self.outputs), dim=1).cpu().detach().numpy()
        self.outputs = np.argmax(self.outputs, axis=1)
        self.targets = torch.cat(self.targets).cpu().numpy().reshape(-1)
        
        return add_metrics(last_metrics, f1_score(self.targets, self.outputs, average='macro'))
    

In [7]:
kf = KFold(n_splits=5, random_state=379)
epochs = 6
lr = 1e-3
preds = []
fold_id = 1
for train_idx, valid_idx in kf.split(df):
    data = create_databunch(valid_idx)
    learn = cnn_learner(data, models.densenet121, metrics=[accuracy, F1score(), AUC(num_cl=4, pick=0), AUC(num_cl=4, pick=1), AUC(num_cl=4, pick=2), AUC(num_cl=4, pick=3), AUC(num_cl=4, pick='macro')])
    learn.fit_one_cycle(epochs, slice(lr))
    learn.unfreeze()
    learn.fit_one_cycle(epochs, slice(lr/100, lr/10))
    learn.fit_one_cycle(epochs, slice(lr/1000, lr/100))
    learn.save('densenet-{}'.format(fold_id))
    fold_id += 1
    preds.append(learn.get_preds(ds_type=DatasetType.Test))

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/checkpoints/densenet121-a639ec97.pth


HBox(children=(FloatProgress(value=0.0, max=32342954.0), HTML(value='')))




epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,1.391829,0.555811,0.805479,0.6759,0.957007,0.706667,0.967318,0.950165,0.895289,05:08
1,0.800948,0.357651,0.882192,0.770056,0.988062,0.847536,0.989614,0.975051,0.950066,05:00
2,0.635037,0.378418,0.884932,0.739569,0.986708,0.805797,0.987502,0.979103,0.939778,04:40
3,0.459415,0.306767,0.893151,0.775276,0.985724,0.871884,0.992072,0.982871,0.958138,04:53
4,0.368197,0.265596,0.89863,0.795981,0.991098,0.905942,0.994114,0.986132,0.969322,04:51
5,0.363224,0.262541,0.906849,0.807141,0.990934,0.899565,0.994114,0.985974,0.967647,04:38


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.328951,0.250756,0.920548,0.820586,0.992534,0.900435,0.994911,0.987114,0.968748,04:57
1,0.314619,0.250806,0.917808,0.821735,0.990893,0.914928,0.996088,0.987905,0.972453,04:36
2,0.279638,0.242213,0.928767,0.820981,0.995487,0.901594,0.996296,0.990217,0.970898,04:43
3,0.222883,0.227538,0.928767,0.825585,0.994667,0.922174,0.997023,0.991135,0.97625,05:11
4,0.202406,0.225829,0.928767,0.815532,0.995692,0.914493,0.996607,0.991546,0.974585,04:52
5,0.197196,0.217362,0.931507,0.829777,0.995569,0.924203,0.996607,0.992623,0.977251,05:16


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.172728,0.22993,0.931507,0.829839,0.995159,0.918696,0.996815,0.99218,0.975712,05:17
1,0.168879,0.220152,0.926027,0.811572,0.996267,0.915942,0.997057,0.991863,0.975282,05:35
2,0.162365,0.24604,0.920548,0.806722,0.995775,0.916957,0.996434,0.991388,0.975138,05:30
3,0.167041,0.226239,0.926027,0.811382,0.996185,0.91942,0.996849,0.992021,0.976119,05:43
4,0.189747,0.228359,0.931507,0.829989,0.995651,0.923333,0.997196,0.991926,0.977027,05:33
5,0.170844,0.225572,0.934247,0.820441,0.995857,0.918551,0.997057,0.992496,0.97599,05:23


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,1.369613,0.487169,0.826923,0.692643,0.97093,0.748876,0.966415,0.969125,0.913837,04:53
1,0.83984,0.327134,0.909341,0.802289,0.985739,0.861272,0.986261,0.985249,0.95463,04:57
2,0.585535,0.313933,0.906593,0.805834,0.989506,0.876846,0.985975,0.988641,0.960242,04:54
3,0.500683,0.265537,0.909341,0.798173,0.994113,0.910726,0.988964,0.989785,0.970897,04:53
4,0.400864,0.264781,0.917582,0.7972,0.994442,0.921323,0.990364,0.989899,0.974007,04:50
5,0.375768,0.250106,0.923077,0.818655,0.994625,0.917148,0.9903,0.990318,0.973098,04:49


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.366415,0.25005,0.92033,0.805459,0.995064,0.915864,0.990618,0.99169,0.973309,04:59
1,0.326874,0.226896,0.925824,0.814461,0.996343,0.918112,0.991826,0.992415,0.974674,04:59
2,0.310625,0.205562,0.936813,0.851274,0.997696,0.920199,0.992749,0.993596,0.97606,05:00
3,0.263966,0.205648,0.942308,0.859581,0.997404,0.918272,0.992971,0.994016,0.975666,05:01
4,0.237191,0.195472,0.950549,0.866319,0.998537,0.915222,0.993766,0.993711,0.975309,05:01
5,0.184303,0.216487,0.936813,0.855037,0.997879,0.911207,0.993417,0.993711,0.974053,04:59


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.163856,0.198766,0.950549,0.862619,0.998281,0.912171,0.993766,0.993863,0.97452,04:58
1,0.209233,0.208795,0.942308,0.856086,0.998208,0.915703,0.993512,0.99413,0.975388,04:59
2,0.204964,0.204544,0.936813,0.839735,0.997769,0.911047,0.993576,0.994473,0.974216,05:01
3,0.174223,0.212282,0.93956,0.857428,0.997916,0.924213,0.993719,0.994244,0.977523,04:59
4,0.192834,0.205352,0.936813,0.830013,0.998062,0.918272,0.993766,0.994435,0.976134,05:04
5,0.195528,0.202161,0.93956,0.864466,0.998281,0.927425,0.993448,0.994397,0.978388,04:59


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,1.322257,0.478953,0.82967,0.716659,0.946687,0.840408,0.97077,0.967307,0.931293,04:54
1,0.723328,0.422313,0.85989,0.769591,0.965437,0.892245,0.98519,0.973371,0.954061,04:55
2,0.578925,0.336195,0.870879,0.75314,0.977732,0.837551,0.990009,0.980502,0.946448,04:51
3,0.492553,0.311051,0.881868,0.74616,0.980055,0.844694,0.99426,0.981087,0.950024,04:54
4,0.402564,0.289989,0.881868,0.726419,0.980635,0.867755,0.994685,0.984015,0.956773,04:54
5,0.332504,0.28575,0.879121,0.741959,0.98306,0.870204,0.994296,0.985118,0.958169,04:55


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.35608,0.251588,0.903846,0.766279,0.985417,0.878163,0.998335,0.985566,0.96187,05:01
1,0.288964,0.236977,0.909341,0.827809,0.989993,0.893061,0.999291,0.984429,0.966694,05:03
2,0.264577,0.209075,0.923077,0.837951,0.99153,0.916122,0.998689,0.988184,0.973631,05:01
3,0.236373,0.192795,0.928571,0.833133,0.993135,0.920816,0.999433,0.990389,0.975943,04:57
4,0.236436,0.195819,0.936813,0.827991,0.993887,0.911633,0.999646,0.991353,0.97413,04:58
5,0.226403,0.191016,0.928571,0.828874,0.993135,0.903265,0.999646,0.990802,0.971712,04:59


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.181027,0.177699,0.942308,0.836499,0.993272,0.917959,0.999575,0.991801,0.975652,04:59
1,0.21302,0.171203,0.93956,0.830075,0.994092,0.914286,0.999752,0.992146,0.975069,04:59
2,0.192544,0.182383,0.928571,0.782688,0.993408,0.920408,0.999504,0.991801,0.97628,05:00
3,0.197552,0.180492,0.934066,0.805242,0.993887,0.917755,0.999433,0.991939,0.975753,05:15
4,0.194397,0.174544,0.93956,0.830075,0.993921,0.917347,0.99961,0.99187,0.975687,05:12
5,0.1962,0.179526,0.931319,0.784726,0.993682,0.920816,0.999469,0.992077,0.976511,05:05


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,1.396424,0.582975,0.791209,0.645093,0.944094,0.659281,0.956733,0.953908,0.878504,04:53
1,0.796987,0.545657,0.835165,0.697611,0.962266,0.815029,0.969561,0.967866,0.928681,04:51
2,0.566087,0.385154,0.881868,0.786349,0.978891,0.850193,0.986394,0.979782,0.948815,04:49
3,0.44076,0.3446,0.898352,0.816111,0.983534,0.872351,0.991426,0.979715,0.956756,04:49
4,0.40593,0.31114,0.909341,0.84485,0.986748,0.867052,0.993015,0.982259,0.957269,04:49
5,0.311984,0.307215,0.909341,0.850374,0.987581,0.864804,0.993015,0.981958,0.95684,04:55


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.348989,0.323957,0.903846,0.813127,0.98893,0.872351,0.994306,0.981556,0.959286,04:50
1,0.300539,0.250449,0.931319,0.861955,0.992303,0.899326,0.995928,0.988318,0.968969,04:53
2,0.280708,0.218604,0.934066,0.8637,0.995278,0.893706,0.997352,0.989423,0.96894,04:52
3,0.218682,0.228565,0.931319,0.850559,0.996151,0.892742,0.996756,0.989891,0.968885,04:52
4,0.195584,0.211272,0.945055,0.883191,0.996588,0.895311,0.997352,0.99036,0.969903,04:55
5,0.225481,0.202296,0.950549,0.877572,0.996945,0.895633,0.997219,0.991096,0.970223,04:53


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.174404,0.211766,0.945055,0.883564,0.996667,0.887604,0.997153,0.990427,0.967963,04:53
1,0.188941,0.196009,0.93956,0.857303,0.996945,0.896596,0.997385,0.991029,0.970489,04:56
2,0.166783,0.216863,0.93956,0.857239,0.996627,0.895472,0.997219,0.990393,0.969928,04:56
3,0.157075,0.198525,0.945055,0.873252,0.996945,0.888568,0.997617,0.99036,0.968372,04:55
4,0.20248,0.197307,0.945055,0.861986,0.996984,0.89483,0.997848,0.990594,0.970064,04:57
5,0.186973,0.215726,0.93956,0.857192,0.996588,0.887925,0.997319,0.990326,0.96804,04:56


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,1.39244,0.542978,0.81044,0.660921,0.96413,0.715674,0.972954,0.947976,0.900184,04:48
1,0.827351,0.447537,0.862637,0.703164,0.98417,0.83729,0.982962,0.976579,0.94525,04:46
2,0.63797,0.314547,0.884615,0.738731,0.994244,0.913647,0.989255,0.981292,0.969609,04:47
3,0.515729,0.27266,0.901099,0.776003,0.993524,0.950993,0.992033,0.984674,0.980306,04:44
4,0.39662,0.266174,0.909341,0.800111,0.994963,0.952659,0.991732,0.986292,0.981412,04:45
5,0.376759,0.249595,0.909341,0.800775,0.994747,0.953353,0.991833,0.986257,0.981547,04:44


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.285206,0.225975,0.917582,0.794117,0.995755,0.96793,0.993238,0.988919,0.986461,04:51
1,0.24224,0.20693,0.928571,0.848984,0.996798,0.973206,0.994075,0.991653,0.988933,04:51
2,0.2303,0.186931,0.928571,0.852658,0.99687,0.982646,0.995849,0.99374,0.992276,04:51
3,0.234192,0.200228,0.928571,0.828059,0.996906,0.977926,0.996921,0.994172,0.991481,04:50
4,0.206629,0.182941,0.931319,0.844218,0.997697,0.980425,0.996653,0.994675,0.992363,04:57
5,0.210012,0.174198,0.942308,0.860389,0.997661,0.979037,0.996519,0.994999,0.992054,04:49


epoch,train_loss,valid_loss,accuracy,f1score,healthy-AUC,multiple_diseases-AUC,rust-AUC,scab-AUC,macro-AUC,time
0,0.186954,0.172837,0.931319,0.841052,0.997769,0.981813,0.996854,0.994927,0.992841,04:46
1,0.204784,0.186746,0.92033,0.806962,0.998273,0.979453,0.997121,0.995179,0.992507,04:48
2,0.196746,0.18206,0.928571,0.841978,0.998093,0.98334,0.996854,0.994603,0.993223,04:50
3,0.198505,0.183048,0.923077,0.811305,0.998273,0.983063,0.997356,0.994783,0.993369,04:51
4,0.185971,0.175876,0.92033,0.806606,0.998273,0.983063,0.997289,0.995611,0.993559,04:49
5,0.182004,0.188173,0.914835,0.776188,0.998273,0.982507,0.997255,0.994963,0.99325,04:50


In [8]:
predictions = preds[0][0]
for i in range(1, len(preds)):
    print(i)
    predictions += preds[i][0]
outputs = predictions.cpu().numpy()

sub = pd.read_csv('../input/plant-pathology-2020-fgvc7/sample_submission.csv')
for lbl in range(4): 
    sub[[label_dict[lbl]]] = outputs[:, lbl]
print(sub.head)
sub.to_csv('submission.csv', index=False)

1
2
3
4
<bound method NDFrame.head of        image_id       healthy  multiple_diseases      rust          scab
0        Test_0  5.349843e-04       2.017066e-03  4.997447  1.000129e-06
1        Test_1  5.787787e-05       1.124309e-01  4.884352  3.158933e-03
2        Test_2  2.224101e-02       6.141452e-03  0.000002  4.971615e+00
3        Test_3  4.999962e+00       8.470369e-07  0.000006  3.056160e-05
4        Test_4  1.036413e-08       9.008597e-04  4.999099  4.117399e-09
...         ...           ...                ...       ...           ...
1816  Test_1816  1.034877e-06       4.846213e-04  4.999514  9.354980e-08
1817  Test_1817  1.249826e-01       1.720678e+00  0.125231  3.029108e+00
1818  Test_1818  1.010568e-06       3.583019e-03  4.996414  1.808642e-06
1819  Test_1819  4.996661e+00       1.013101e-04  0.000132  3.105818e-03
1820  Test_1820  1.101254e-01       8.880148e-01  0.001729  4.000131e+00

[1821 rows x 5 columns]>
