In [1]:
import warnings  

import os
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from torch.nn.functional import softmax

from efficientnet_pytorch import EfficientNet

from typing import Callable, List, Tuple, Dict
from pathlib import Path

from transformers import AdamW
from transformers import get_cosine_schedule_with_warmup

from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split
from collections import defaultdict, OrderedDict
from tqdm.notebook import tqdm
from torchsummary import summary

import matplotlib
matplotlib.rcParams.update({'figure.figsize': (16, 12), 'font.size': 14})
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import clear_output

In [2]:
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", DeprecationWarning)
warnings.filterwarnings('ignore')
os.environ["PYTHONWARNINGS"] = "ignore"

In [3]:
EXPERIMENT_NAME = "01_efficientnet-b5"

class ConfigExperiment:
    logdir = f"./logs/{EXPERIMENT_NAME}"
    save_dirname = EXPERIMENT_NAME
    submission_file = f"{EXPERIMENT_NAME}.csv"
    seed = 42
    batch_size = 4
    model_name = 'efficientnet-b5'
    size = 512
    num_workers = 20
    root_images = "../../../data/raw/plant-pathology-2020-fgvc7/images/"
    root = "../../../data/raw/plant-pathology-2020-fgvc7/"
    num_classes = 4
    patience = 10
    early_stopping_delta = 1e-4
    num_epochs = 200
    lr = 0.003
    class_names = ["healthy", "multiple_diseases", "rust", "scab"]
    is_fp16_used = False
    
    
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    
config = ConfigExperiment()
set_seed(config.seed)
config.size = EfficientNet.get_image_size(config.model_name)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Create target Directory
    os.mkdir(config.save_dirname)
    print("Directory " , config.save_dirname ,  " Created ") 
except FileExistsError:
    print("Directory " , config.save_dirname ,  " already exists")

Directory  01_efficientnet-b5  already exists


In [4]:
class PlantDataset(Dataset):
    
    def __init__(self, df, config, transforms=None):
    
        self.df = df
        self.images_dir = config.root_images
        self.class_names = config.class_names
        self.transforms=transforms
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        image_src = self.images_dir + self.df.iloc[idx]['image_id'] + '.jpg'
        image = cv2.imread(image_src, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        labels = self.df.iloc[idx][self.class_names].values.astype(np.int8)
        label = torch.argmax(torch.from_numpy(labels))
        
        if self.transforms:
            transformed = self.transforms(image=image)
            image = transformed['image']

        return image, label

In [5]:
def pre_transforms(image_size=224):
    # Convert the image to a square of size image_size x image_size
    # (keeping aspect ratio)
    result = [
        A.LongestMaxSize(max_size=image_size),
        A.PadIfNeeded(image_size, image_size, border_mode=0)
    ]
    
    return result

def hard_transforms():
    result = [
        # Random shifts, stretches and turns with a 50% probability
#         A.RandomResizedCrop(height=config.size, width=config.size, p=1.0),
        A.OneOf([
            A.Rotate(limit=90, p=1),
            A.HorizontalFlip(p=1),
            A.VerticalFlip(p=1),
            A.OpticalDistortion(p=1),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit=0.2, scale_limit=0.1, p=1)
        ], p=0.8),
        
        # Pixels
        A.OneOf([
            A.IAAEmboss(p=1.0),
            A.IAASharpen(p=1.0),
            A.Blur(p=1.0),
        ], p=0.5),

        # Affine
        A.OneOf([
            A.RandomContrast(limit=0.2, p=1),
            A.ElasticTransform(p=1),
            A.IAAPiecewiseAffine(p=1)
        ], p=0.5),
    ]
    
    return result

def post_transforms():
    # we use ImageNet image normalization
    # and convert it to torch.Tensor
    return [A.Normalize(p=1.0), ToTensorV2(p=1.0),]

def compose(transforms_to_compose):
    # combine all augmentations into one single pipeline
    result = A.Compose([item for sublist in transforms_to_compose for item in sublist])
    return result

In [6]:
train_df = pd.read_csv(config.root + 'train.csv')
train, valid = train_test_split(train_df, test_size=0.33, random_state=config.seed, shuffle=True, stratify=train_df[config.class_names])

train_transforms = compose([
    pre_transforms(config.size),
    hard_transforms(), 
    post_transforms()
])
valid_transforms = compose([
    pre_transforms(config.size), 
    post_transforms()
])

show_transforms = compose([
    pre_transforms(config.size),
    hard_transforms()
])

train_dataset = PlantDataset(train, config, train_transforms)
valid_dataset = PlantDataset(valid, config, valid_transforms)

train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

In [7]:
def get_model(model_name: str, num_classes: int, pretrained: str = "imagenet") -> EfficientNet:
    model = EfficientNet.from_pretrained(model_name)
    for param in model.parameters():
        param.requires_grad = False
    num_ftrs = model._fc.in_features
    model._fc = nn.Sequential(nn.Linear(num_ftrs, num_classes, bias = True))
    return model

model = get_model(config.model_name, config.num_classes)

Loaded pretrained weights for efficientnet-b5


In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True, mode="min", factor=0.3)

In [9]:
class Trainer:
    def __init__(self, model, train_dataloader: DataLoader, valid_dataloader: DataLoader, criterion, optimizer, scheduler, device, config: ConfigExperiment):
        self.model = model
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config
        self.train_metrics = {
            'avg_loss': [],
            'auc/_mean': [],
            'auc/healthy': [],
            'auc/multiple_diseases': [],
            'auc/rust': [],
            'auc/scab': [],
        }
        self.valid_metrics = {
            'avg_loss': [],
            'auc/_mean': [],
            'auc/healthy': [],
            'auc/multiple_diseases': [],
            'auc/rust': [],
            'auc/scab': [],
        }
        self.counter = 0
        self.delta = config.early_stopping_delta
      
    def run(self):
        self.model.to(device)
        best_valid_loss = float('inf')
        best_valid_auc_mean = 0

        try:
            for i_epoch in tqdm(range(self.config.num_epochs), desc='Epochs', total=config.num_epochs, position=1, leave=True):
                start_time = time.time()

                train_loss, train_outputs, train_targets = self._train()
                valid_loss, valid_outputs, valid_targets = self._evaluate()
                    
                self.train_metrics["avg_loss"].append(train_loss)
                self.train_metrics["auc/_mean"].append(self.comp_metric(train_outputs, train_targets))
                self.train_metrics["auc/healthy"].append(self.healthy_roc_auc(train_outputs, train_targets))
                self.train_metrics["auc/multiple_diseases"].append(self.multiple_diseases_roc_auc(train_outputs, train_targets))
                self.train_metrics["auc/rust"].append(self.rust_roc_auc(train_outputs, train_targets))
                self.train_metrics["auc/scab"].append(self.scab_roc_auc(train_outputs, train_targets))
                
                self.valid_metrics["avg_loss"].append(valid_loss)
                self.valid_metrics["auc/_mean"].append(self.comp_metric(valid_outputs, valid_targets))
                self.valid_metrics["auc/healthy"].append(self.healthy_roc_auc(valid_outputs, valid_targets))
                self.valid_metrics["auc/multiple_diseases"].append(self.multiple_diseases_roc_auc(valid_outputs, valid_targets))
                self.valid_metrics["auc/rust"].append(self.rust_roc_auc(valid_outputs, valid_targets))
                self.valid_metrics["auc/scab"].append(self.scab_roc_auc(valid_outputs, valid_targets))
                
                end_time = time.time()
                epoch_mins, epoch_secs = self._epoch_time(start_time, end_time)
                self.print_progress(i_epoch, epoch_mins, epoch_secs)
                
                self.scheduler.step(self.valid_metrics["auc/_mean"][-1])
                
                if valid_loss < best_valid_loss:
                    self.counter = 0
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), f"{config.save_dirname}/best_model_epoch={i_epoch+1}_loss={best_valid_loss}.pth")
                else:
                    self.counter += 1
                    
                if self.valid_metrics["auc/_mean"][-1] > best_valid_auc_mean:
                    best_valid_auc_mean = self.valid_metrics["auc/_mean"][-1]
                    torch.save(model.state_dict(), f"{config.save_dirname}/best_model_epoch={i_epoch+1}_auc_mean={best_valid_auc_mean}.pth")
                    
                if self.counter > self.config.patience:
                    print("EarlyStopping")
                    break
        except KeyboardInterrupt: 
            pass
        
        return self.train_metrics, self.valid_metrics
        
    def _train(self):
        model.train()
        epoch_loss = 0
        epoch_output = None
        epoch_target = None
        for i, (images, labels) in tqdm(enumerate(self.train_dataloader), desc='Train', total=len(self.train_dataloader), position=2, leave=True):
            loss_iten, outputs = self._train_process(images, labels)
            epoch_loss += loss_iten              

            if epoch_output is None:
                epoch_output = outputs.cpu().data
            else:
                epoch_output = torch.cat((epoch_output, outputs.cpu().data))

            if epoch_target is None:
                epoch_target = labels.cpu().data
            else:
                epoch_target = torch.cat((epoch_target, labels.cpu().data))
            
        return epoch_loss / len(self.train_dataloader), epoch_output, epoch_target
    
    def _train_process(self, images, labels):
        images = images.to(self.device)
        labels = labels.to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return loss.item(), outputs
            
    def _evaluate(self):
        model.eval()
        epoch_loss = 0
        epoch_output = None
        epoch_target = None
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.valid_dataloader):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                epoch_loss += loss.item()
                
                if epoch_output is None:
                    epoch_output = outputs.cpu().data
                else:
                    epoch_output = torch.cat((epoch_output, outputs.cpu().data))

                if epoch_target is None:
                    epoch_target = labels.cpu().data
                else:
                    epoch_target = torch.cat((epoch_target, labels.cpu().data))

        return epoch_loss / len(self.valid_dataloader), epoch_output, epoch_target
    
#     def _evaluate_tta(self):
#         n_samples = len(dataset)
#         y_probas_tta = np.zeros((n_samples, config.num_classes, tta_count), dtype=np.float32)

#         model.eval()
#         model.to(device)
#         with torch.no_grad():
#             for i_epoch in tqdm(range(tta_count), desc='TTA', total=tta_count, position=1, leave=True):
#                 for i, (images, idx) in enumerate(dataloader):
#                     images = images.to(device)
#                     outputs = model(images)
#                     y_pred = F.softmax(outputs, dim=1)

#                     tta_index = i_epoch - 1
#                     start_index = (i % len(dataloader)) * config.batch_size
#                     end_index = min(start_index + config.batch_size, n_samples)
#                     batch_y_probas = y_pred.detach().cpu().numpy()                
#                     y_probas_tta[start_index:end_index, :, tta_index] = batch_y_probas

#         return y_probas_tta

#     def show_history(self):
#         fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
#         clear_output(True)
#         ax[0].plot(history, label='train loss')
#         ax[0].set_xlabel('Batch')
#         ax[0].set_title('Train loss')
#         if train_history is not None:
#             ax[1].plot(train_history, label='general train history')
#             ax[1].set_xlabel('Epoch')
#         if valid_history is not None:
#             ax[1].plot(valid_history, label='general valid history')
#         plt.legend()
#         plt.show()

    
    def _epoch_time(self, start_time, end_time):
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

    def print_progress(self, i_epoch, epoch_mins, epoch_secs):
        i_epoch = i_epoch + 1
        print(f"Epoch: {i_epoch:02} | Time: {epoch_mins}m {epoch_secs}s")
        print("Training Results - Average Loss: {:.4f} | auc/_mean: {:.4f} | auc/healthy: {:.4f} | auc/multiple_diseases: {:.4f}, auc/rust: {:.4f}, auc/scab: {:.4f}"
            .format(
                self.train_metrics['avg_loss'][-1], 
                self.train_metrics['auc/_mean'][-1],
                self.train_metrics['auc/healthy'][-1],
                self.train_metrics['auc/multiple_diseases'][-1],
                self.train_metrics['auc/rust'][-1],
                self.train_metrics['auc/scab'][-1],
            ))
        print("Evaluating Results - Average Loss: {:.4f} | auc/_mean: {:.4f} | auc/healthy: {:.4f} | auc/multiple_diseases: {:.4f}, auc/rust: {:.4f}, auc/scab: {:.4f}"
            .format( 
                self.valid_metrics['avg_loss'][-1], 
                self.valid_metrics['auc/_mean'][-1],
                self.valid_metrics['auc/healthy'][-1],
                self.valid_metrics['auc/multiple_diseases'][-1],
                self.valid_metrics['auc/rust'][-1],
                self.valid_metrics['auc/scab'][-1],
            ))
        print()
        
    def comp_metric(self, preds, targs, labels=range(4)):
        preds = torch.sigmoid(preds)
        targs = torch.eye(4)[targs]
        return np.mean([roc_auc_score(targs[:,i], preds[:,i]) for i in labels])

    def healthy_roc_auc(self, *args):
        return self.comp_metric(*args, labels=[0])

    def multiple_diseases_roc_auc(self, *args):
        return self.comp_metric(*args, labels=[1])

    def rust_roc_auc(self, *args):
        return self.comp_metric(*args, labels=[2])

    def scab_roc_auc(self, *args):
        return self.comp_metric(*args, labels=[3])



In [10]:
trainer = Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, device, config)
trainer.run();

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=200.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 01 | Time: 1m 10s
Training Results - Average Loss: 0.9506 | auc/_mean: 0.7662 | auc/healthy: 0.8205 | auc/multiple_diseases: 0.5440, auc/rust: 0.8694, auc/scab: 0.8310
Evaluating Results - Average Loss: 1.0029 | auc/_mean: 0.8128 | auc/healthy: 0.9153 | auc/multiple_diseases: 0.5506, auc/rust: 0.9437, auc/scab: 0.8415



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 02 | Time: 1m 7s
Training Results - Average Loss: 0.6917 | auc/_mean: 0.8454 | auc/healthy: 0.8920 | auc/multiple_diseases: 0.6473, auc/rust: 0.9251, auc/scab: 0.9172
Evaluating Results - Average Loss: 0.7047 | auc/_mean: 0.8430 | auc/healthy: 0.9448 | auc/multiple_diseases: 0.5588, auc/rust: 0.9588, auc/scab: 0.9095



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 03 | Time: 1m 8s
Training Results - Average Loss: 0.6591 | auc/_mean: 0.8660 | auc/healthy: 0.8991 | auc/multiple_diseases: 0.7262, auc/rust: 0.9325, auc/scab: 0.9060
Evaluating Results - Average Loss: 0.5344 | auc/_mean: 0.8909 | auc/healthy: 0.9671 | auc/multiple_diseases: 0.6834, auc/rust: 0.9710, auc/scab: 0.9422



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 04 | Time: 1m 8s
Training Results - Average Loss: 0.6198 | auc/_mean: 0.8792 | auc/healthy: 0.9112 | auc/multiple_diseases: 0.7417, auc/rust: 0.9426, auc/scab: 0.9212
Evaluating Results - Average Loss: 0.4177 | auc/_mean: 0.9094 | auc/healthy: 0.9705 | auc/multiple_diseases: 0.7268, auc/rust: 0.9769, auc/scab: 0.9635



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 05 | Time: 1m 8s
Training Results - Average Loss: 0.5905 | auc/_mean: 0.8931 | auc/healthy: 0.9232 | auc/multiple_diseases: 0.7830, auc/rust: 0.9424, auc/scab: 0.9237
Evaluating Results - Average Loss: 0.3956 | auc/_mean: 0.9125 | auc/healthy: 0.9706 | auc/multiple_diseases: 0.7325, auc/rust: 0.9800, auc/scab: 0.9670

Epoch     5: reducing learning rate of group 0 to 9.0000e-04.


HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 06 | Time: 1m 8s
Training Results - Average Loss: 0.5703 | auc/_mean: 0.8940 | auc/healthy: 0.9217 | auc/multiple_diseases: 0.7840, auc/rust: 0.9410, auc/scab: 0.9293
Evaluating Results - Average Loss: 0.3927 | auc/_mean: 0.9116 | auc/healthy: 0.9694 | auc/multiple_diseases: 0.7301, auc/rust: 0.9791, auc/scab: 0.9677



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 07 | Time: 1m 8s
Training Results - Average Loss: 0.5739 | auc/_mean: 0.8905 | auc/healthy: 0.9215 | auc/multiple_diseases: 0.7667, auc/rust: 0.9353, auc/scab: 0.9383
Evaluating Results - Average Loss: 0.3919 | auc/_mean: 0.9127 | auc/healthy: 0.9670 | auc/multiple_diseases: 0.7375, auc/rust: 0.9789, auc/scab: 0.9675



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 08 | Time: 1m 8s
Training Results - Average Loss: 0.5599 | auc/_mean: 0.8966 | auc/healthy: 0.9234 | auc/multiple_diseases: 0.7746, auc/rust: 0.9489, auc/scab: 0.9394
Evaluating Results - Average Loss: 0.3909 | auc/_mean: 0.9160 | auc/healthy: 0.9666 | auc/multiple_diseases: 0.7518, auc/rust: 0.9781, auc/scab: 0.9676



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 09 | Time: 1m 8s
Training Results - Average Loss: 0.5346 | auc/_mean: 0.9011 | auc/healthy: 0.9263 | auc/multiple_diseases: 0.8023, auc/rust: 0.9523, auc/scab: 0.9237
Evaluating Results - Average Loss: 0.3943 | auc/_mean: 0.9178 | auc/healthy: 0.9659 | auc/multiple_diseases: 0.7608, auc/rust: 0.9765, auc/scab: 0.9680

Epoch     9: reducing learning rate of group 0 to 2.7000e-04.


HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 10 | Time: 1m 10s
Training Results - Average Loss: 0.5308 | auc/_mean: 0.9035 | auc/healthy: 0.9283 | auc/multiple_diseases: 0.8033, auc/rust: 0.9502, auc/scab: 0.9323
Evaluating Results - Average Loss: 0.3962 | auc/_mean: 0.9183 | auc/healthy: 0.9653 | auc/multiple_diseases: 0.7641, auc/rust: 0.9762, auc/scab: 0.9678



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 11 | Time: 1m 8s
Training Results - Average Loss: 0.5366 | auc/_mean: 0.9110 | auc/healthy: 0.9314 | auc/multiple_diseases: 0.8299, auc/rust: 0.9468, auc/scab: 0.9361
Evaluating Results - Average Loss: 0.3978 | auc/_mean: 0.9180 | auc/healthy: 0.9649 | auc/multiple_diseases: 0.7640, auc/rust: 0.9756, auc/scab: 0.9674



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 12 | Time: 1m 9s
Training Results - Average Loss: 0.5530 | auc/_mean: 0.8972 | auc/healthy: 0.9229 | auc/multiple_diseases: 0.7851, auc/rust: 0.9489, auc/scab: 0.9319
Evaluating Results - Average Loss: 0.3980 | auc/_mean: 0.9176 | auc/healthy: 0.9646 | auc/multiple_diseases: 0.7623, auc/rust: 0.9757, auc/scab: 0.9677



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 13 | Time: 1m 8s
Training Results - Average Loss: 0.5487 | auc/_mean: 0.9109 | auc/healthy: 0.9148 | auc/multiple_diseases: 0.8376, auc/rust: 0.9485, auc/scab: 0.9427
Evaluating Results - Average Loss: 0.3986 | auc/_mean: 0.9175 | auc/healthy: 0.9649 | auc/multiple_diseases: 0.7619, auc/rust: 0.9755, auc/scab: 0.9675

Epoch    13: reducing learning rate of group 0 to 8.1000e-05.


HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 14 | Time: 1m 9s
Training Results - Average Loss: 0.5572 | auc/_mean: 0.8987 | auc/healthy: 0.9182 | auc/multiple_diseases: 0.8087, auc/rust: 0.9380, auc/scab: 0.9299
Evaluating Results - Average Loss: 0.3984 | auc/_mean: 0.9177 | auc/healthy: 0.9648 | auc/multiple_diseases: 0.7631, auc/rust: 0.9757, auc/scab: 0.9672



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 15 | Time: 1m 9s
Training Results - Average Loss: 0.5119 | auc/_mean: 0.9100 | auc/healthy: 0.9295 | auc/multiple_diseases: 0.8199, auc/rust: 0.9537, auc/scab: 0.9371
Evaluating Results - Average Loss: 0.4002 | auc/_mean: 0.9179 | auc/healthy: 0.9644 | auc/multiple_diseases: 0.7643, auc/rust: 0.9757, auc/scab: 0.9672



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 16 | Time: 1m 10s
Training Results - Average Loss: 0.5216 | auc/_mean: 0.9112 | auc/healthy: 0.9298 | auc/multiple_diseases: 0.8190, auc/rust: 0.9547, auc/scab: 0.9411
Evaluating Results - Average Loss: 0.4002 | auc/_mean: 0.9178 | auc/healthy: 0.9645 | auc/multiple_diseases: 0.7643, auc/rust: 0.9756, auc/scab: 0.9670



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 17 | Time: 1m 8s
Training Results - Average Loss: 0.5199 | auc/_mean: 0.9054 | auc/healthy: 0.9397 | auc/multiple_diseases: 0.7868, auc/rust: 0.9597, auc/scab: 0.9356
Evaluating Results - Average Loss: 0.3989 | auc/_mean: 0.9181 | auc/healthy: 0.9644 | auc/multiple_diseases: 0.7652, auc/rust: 0.9758, auc/scab: 0.9672

Epoch    17: reducing learning rate of group 0 to 2.4300e-05.


HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 18 | Time: 1m 9s
Training Results - Average Loss: 0.5232 | auc/_mean: 0.9061 | auc/healthy: 0.9283 | auc/multiple_diseases: 0.8056, auc/rust: 0.9507, auc/scab: 0.9400
Evaluating Results - Average Loss: 0.3995 | auc/_mean: 0.9182 | auc/healthy: 0.9645 | auc/multiple_diseases: 0.7656, auc/rust: 0.9753, auc/scab: 0.9675



HBox(children=(FloatProgress(value=0.0, description='Train', max=39.0, style=ProgressStyle(description_width='…


Epoch: 19 | Time: 1m 8s
Training Results - Average Loss: 0.5383 | auc/_mean: 0.9015 | auc/healthy: 0.9399 | auc/multiple_diseases: 0.7889, auc/rust: 0.9440, auc/scab: 0.9332
Evaluating Results - Average Loss: 0.4001 | auc/_mean: 0.9184 | auc/healthy: 0.9643 | auc/multiple_diseases: 0.7671, auc/rust: 0.9750, auc/scab: 0.9675

EarlyStopping


In [10]:
model = get_model(config.model_name, config.num_classes)
model.load_state_dict(torch.load(f"{config.save_dirname}/best_model_epoch=19_auc_mean=0.9184434924272492.pth"))
for param in model.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=config.patience, verbose=True, mode="min", factor=0.3)

Loaded pretrained weights for efficientnet-b5


In [11]:
trainer = Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, device, config)
trainer.run();

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=200.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 01 | Time: 3m 6s
Training Results - Average Loss: 1.2575 | auc/_mean: 0.6917 | auc/healthy: 0.7145 | auc/multiple_diseases: 0.5344, auc/rust: 0.8093, auc/scab: 0.7085
Evaluating Results - Average Loss: 0.6857 | auc/_mean: 0.8446 | auc/healthy: 0.9020 | auc/multiple_diseases: 0.5550, auc/rust: 0.9581, auc/scab: 0.9631



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 02 | Time: 3m 13s
Training Results - Average Loss: 0.8562 | auc/_mean: 0.7834 | auc/healthy: 0.8187 | auc/multiple_diseases: 0.5656, auc/rust: 0.9034, auc/scab: 0.8461
Evaluating Results - Average Loss: 0.4532 | auc/_mean: 0.8533 | auc/healthy: 0.9706 | auc/multiple_diseases: 0.4970, auc/rust: 0.9895, auc/scab: 0.9561



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 03 | Time: 3m 12s
Training Results - Average Loss: 0.6776 | auc/_mean: 0.8406 | auc/healthy: 0.8902 | auc/multiple_diseases: 0.6236, auc/rust: 0.9270, auc/scab: 0.9216
Evaluating Results - Average Loss: 0.4472 | auc/_mean: 0.8814 | auc/healthy: 0.9745 | auc/multiple_diseases: 0.5761, auc/rust: 0.9899, auc/scab: 0.9851



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 04 | Time: 3m 9s
Training Results - Average Loss: 0.6283 | auc/_mean: 0.8570 | auc/healthy: 0.9169 | auc/multiple_diseases: 0.6426, auc/rust: 0.9346, auc/scab: 0.9339
Evaluating Results - Average Loss: 0.4567 | auc/_mean: 0.8890 | auc/healthy: 0.9864 | auc/multiple_diseases: 0.5879, auc/rust: 0.9944, auc/scab: 0.9874



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 05 | Time: 3m 5s
Training Results - Average Loss: 0.5535 | auc/_mean: 0.8709 | auc/healthy: 0.9314 | auc/multiple_diseases: 0.6518, auc/rust: 0.9636, auc/scab: 0.9367
Evaluating Results - Average Loss: 1.2239 | auc/_mean: 0.8437 | auc/healthy: 0.9318 | auc/multiple_diseases: 0.5135, auc/rust: 0.9452, auc/scab: 0.9844



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 06 | Time: 3m 9s
Training Results - Average Loss: 0.5203 | auc/_mean: 0.8616 | auc/healthy: 0.9517 | auc/multiple_diseases: 0.5897, auc/rust: 0.9673, auc/scab: 0.9378
Evaluating Results - Average Loss: 0.2823 | auc/_mean: 0.8933 | auc/healthy: 0.9904 | auc/multiple_diseases: 0.6008, auc/rust: 0.9922, auc/scab: 0.9898



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 07 | Time: 3m 15s
Training Results - Average Loss: 0.4317 | auc/_mean: 0.8843 | auc/healthy: 0.9624 | auc/multiple_diseases: 0.6312, auc/rust: 0.9785, auc/scab: 0.9651
Evaluating Results - Average Loss: 0.3000 | auc/_mean: 0.8695 | auc/healthy: 0.9919 | auc/multiple_diseases: 0.5065, auc/rust: 0.9926, auc/scab: 0.9869



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 08 | Time: 3m 12s
Training Results - Average Loss: 0.3907 | auc/_mean: 0.8875 | auc/healthy: 0.9719 | auc/multiple_diseases: 0.6310, auc/rust: 0.9794, auc/scab: 0.9678
Evaluating Results - Average Loss: 0.2665 | auc/_mean: 0.8819 | auc/healthy: 0.9884 | auc/multiple_diseases: 0.5574, auc/rust: 0.9925, auc/scab: 0.9892



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 09 | Time: 3m 6s
Training Results - Average Loss: 0.4331 | auc/_mean: 0.8740 | auc/healthy: 0.9678 | auc/multiple_diseases: 0.5839, auc/rust: 0.9772, auc/scab: 0.9672
Evaluating Results - Average Loss: 0.3062 | auc/_mean: 0.8836 | auc/healthy: 0.9904 | auc/multiple_diseases: 0.5619, auc/rust: 0.9942, auc/scab: 0.9878



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 10 | Time: 3m 6s
Training Results - Average Loss: 0.3656 | auc/_mean: 0.8899 | auc/healthy: 0.9807 | auc/multiple_diseases: 0.6218, auc/rust: 0.9773, auc/scab: 0.9798
Evaluating Results - Average Loss: 0.4788 | auc/_mean: 0.9285 | auc/healthy: 0.9907 | auc/multiple_diseases: 0.7431, auc/rust: 0.9948, auc/scab: 0.9856



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 11 | Time: 3m 7s
Training Results - Average Loss: 0.3752 | auc/_mean: 0.8962 | auc/healthy: 0.9760 | auc/multiple_diseases: 0.6533, auc/rust: 0.9807, auc/scab: 0.9748
Evaluating Results - Average Loss: 0.3759 | auc/_mean: 0.8861 | auc/healthy: 0.9952 | auc/multiple_diseases: 0.5615, auc/rust: 0.9945, auc/scab: 0.9931



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 12 | Time: 3m 7s
Training Results - Average Loss: 0.3570 | auc/_mean: 0.8790 | auc/healthy: 0.9782 | auc/multiple_diseases: 0.5733, auc/rust: 0.9834, auc/scab: 0.9812
Evaluating Results - Average Loss: 0.3771 | auc/_mean: 0.8686 | auc/healthy: 0.9867 | auc/multiple_diseases: 0.5060, auc/rust: 0.9919, auc/scab: 0.9899



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 13 | Time: 3m 5s
Training Results - Average Loss: 0.3618 | auc/_mean: 0.8936 | auc/healthy: 0.9802 | auc/multiple_diseases: 0.6337, auc/rust: 0.9811, auc/scab: 0.9793
Evaluating Results - Average Loss: 0.2967 | auc/_mean: 0.8873 | auc/healthy: 0.9912 | auc/multiple_diseases: 0.5798, auc/rust: 0.9917, auc/scab: 0.9866



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 14 | Time: 3m 5s
Training Results - Average Loss: 0.3783 | auc/_mean: 0.8875 | auc/healthy: 0.9751 | auc/multiple_diseases: 0.6193, auc/rust: 0.9778, auc/scab: 0.9777
Evaluating Results - Average Loss: 0.3487 | auc/_mean: 0.8525 | auc/healthy: 0.9881 | auc/multiple_diseases: 0.4423, auc/rust: 0.9898, auc/scab: 0.9899



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 15 | Time: 3m 4s
Training Results - Average Loss: 0.3062 | auc/_mean: 0.9010 | auc/healthy: 0.9856 | auc/multiple_diseases: 0.6489, auc/rust: 0.9854, auc/scab: 0.9841
Evaluating Results - Average Loss: 0.2721 | auc/_mean: 0.8830 | auc/healthy: 0.9951 | auc/multiple_diseases: 0.5489, auc/rust: 0.9964, auc/scab: 0.9917



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 16 | Time: 3m 4s
Training Results - Average Loss: 0.3336 | auc/_mean: 0.8957 | auc/healthy: 0.9807 | auc/multiple_diseases: 0.6353, auc/rust: 0.9838, auc/scab: 0.9830
Evaluating Results - Average Loss: 0.2113 | auc/_mean: 0.9050 | auc/healthy: 0.9972 | auc/multiple_diseases: 0.6353, auc/rust: 0.9945, auc/scab: 0.9930

Epoch    16: reducing learning rate of group 0 to 9.0000e-04.


HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 17 | Time: 3m 4s
Training Results - Average Loss: 0.2109 | auc/_mean: 0.8988 | auc/healthy: 0.9948 | auc/multiple_diseases: 0.6145, auc/rust: 0.9927, auc/scab: 0.9933
Evaluating Results - Average Loss: 0.1642 | auc/_mean: 0.8970 | auc/healthy: 0.9969 | auc/multiple_diseases: 0.5966, auc/rust: 0.9982, auc/scab: 0.9961



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 18 | Time: 3m 5s
Training Results - Average Loss: 0.1813 | auc/_mean: 0.9086 | auc/healthy: 0.9975 | auc/multiple_diseases: 0.6495, auc/rust: 0.9940, auc/scab: 0.9935
Evaluating Results - Average Loss: 0.1835 | auc/_mean: 0.8803 | auc/healthy: 0.9965 | auc/multiple_diseases: 0.5336, auc/rust: 0.9963, auc/scab: 0.9948



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 19 | Time: 3m 5s
Training Results - Average Loss: 0.1701 | auc/_mean: 0.9155 | auc/healthy: 0.9966 | auc/multiple_diseases: 0.6760, auc/rust: 0.9940, auc/scab: 0.9953
Evaluating Results - Average Loss: 0.1862 | auc/_mean: 0.9053 | auc/healthy: 0.9970 | auc/multiple_diseases: 0.6320, auc/rust: 0.9952, auc/scab: 0.9968



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 20 | Time: 3m 4s
Training Results - Average Loss: 0.1773 | auc/_mean: 0.9081 | auc/healthy: 0.9965 | auc/multiple_diseases: 0.6473, auc/rust: 0.9945, auc/scab: 0.9942
Evaluating Results - Average Loss: 0.1574 | auc/_mean: 0.8897 | auc/healthy: 0.9973 | auc/multiple_diseases: 0.5665, auc/rust: 0.9979, auc/scab: 0.9971



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 21 | Time: 3m 4s
Training Results - Average Loss: 0.1541 | auc/_mean: 0.9071 | auc/healthy: 0.9968 | auc/multiple_diseases: 0.6411, auc/rust: 0.9951, auc/scab: 0.9954
Evaluating Results - Average Loss: 0.1458 | auc/_mean: 0.8942 | auc/healthy: 0.9982 | auc/multiple_diseases: 0.5841, auc/rust: 0.9969, auc/scab: 0.9978



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 22 | Time: 3m 6s
Training Results - Average Loss: 0.1510 | auc/_mean: 0.9195 | auc/healthy: 0.9972 | auc/multiple_diseases: 0.6891, auc/rust: 0.9946, auc/scab: 0.9969
Evaluating Results - Average Loss: 0.2067 | auc/_mean: 0.9053 | auc/healthy: 0.9973 | auc/multiple_diseases: 0.6302, auc/rust: 0.9965, auc/scab: 0.9971



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 23 | Time: 3m 5s
Training Results - Average Loss: 0.1698 | auc/_mean: 0.9127 | auc/healthy: 0.9966 | auc/multiple_diseases: 0.6652, auc/rust: 0.9939, auc/scab: 0.9950
Evaluating Results - Average Loss: 0.1442 | auc/_mean: 0.9225 | auc/healthy: 0.9979 | auc/multiple_diseases: 0.6977, auc/rust: 0.9987, auc/scab: 0.9958



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 24 | Time: 3m 4s
Training Results - Average Loss: 0.1504 | auc/_mean: 0.9153 | auc/healthy: 0.9984 | auc/multiple_diseases: 0.6726, auc/rust: 0.9958, auc/scab: 0.9943
Evaluating Results - Average Loss: 0.1831 | auc/_mean: 0.8952 | auc/healthy: 0.9947 | auc/multiple_diseases: 0.5917, auc/rust: 0.9967, auc/scab: 0.9977



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 25 | Time: 3m 5s
Training Results - Average Loss: 0.1652 | auc/_mean: 0.9081 | auc/healthy: 0.9979 | auc/multiple_diseases: 0.6461, auc/rust: 0.9936, auc/scab: 0.9951
Evaluating Results - Average Loss: 0.1690 | auc/_mean: 0.8987 | auc/healthy: 0.9955 | auc/multiple_diseases: 0.6048, auc/rust: 0.9985, auc/scab: 0.9960



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 26 | Time: 3m 5s
Training Results - Average Loss: 0.1525 | auc/_mean: 0.9193 | auc/healthy: 0.9985 | auc/multiple_diseases: 0.6892, auc/rust: 0.9949, auc/scab: 0.9945
Evaluating Results - Average Loss: 0.1716 | auc/_mean: 0.9031 | auc/healthy: 0.9972 | auc/multiple_diseases: 0.6211, auc/rust: 0.9984, auc/scab: 0.9956



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 27 | Time: 3m 5s
Training Results - Average Loss: 0.1528 | auc/_mean: 0.9254 | auc/healthy: 0.9979 | auc/multiple_diseases: 0.7141, auc/rust: 0.9946, auc/scab: 0.9948
Evaluating Results - Average Loss: 0.1953 | auc/_mean: 0.9082 | auc/healthy: 0.9972 | auc/multiple_diseases: 0.6428, auc/rust: 0.9962, auc/scab: 0.9967

Epoch    27: reducing learning rate of group 0 to 2.7000e-04.


HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 28 | Time: 3m 4s
Training Results - Average Loss: 0.1312 | auc/_mean: 0.9276 | auc/healthy: 0.9988 | auc/multiple_diseases: 0.7188, auc/rust: 0.9956, auc/scab: 0.9972
Evaluating Results - Average Loss: 0.1689 | auc/_mean: 0.9161 | auc/healthy: 0.9974 | auc/multiple_diseases: 0.6720, auc/rust: 0.9975, auc/scab: 0.9974



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 29 | Time: 3m 5s
Training Results - Average Loss: 0.0971 | auc/_mean: 0.9392 | auc/healthy: 0.9990 | auc/multiple_diseases: 0.7645, auc/rust: 0.9971, auc/scab: 0.9962
Evaluating Results - Average Loss: 0.1692 | auc/_mean: 0.9227 | auc/healthy: 0.9978 | auc/multiple_diseases: 0.6975, auc/rust: 0.9981, auc/scab: 0.9972



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 30 | Time: 3m 5s
Training Results - Average Loss: 0.0983 | auc/_mean: 0.9384 | auc/healthy: 0.9993 | auc/multiple_diseases: 0.7603, auc/rust: 0.9960, auc/scab: 0.9980
Evaluating Results - Average Loss: 0.1941 | auc/_mean: 0.9273 | auc/healthy: 0.9979 | auc/multiple_diseases: 0.7164, auc/rust: 0.9976, auc/scab: 0.9972



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 31 | Time: 3m 4s
Training Results - Average Loss: 0.1142 | auc/_mean: 0.9407 | auc/healthy: 0.9989 | auc/multiple_diseases: 0.7699, auc/rust: 0.9972, auc/scab: 0.9966
Evaluating Results - Average Loss: 0.1758 | auc/_mean: 0.9112 | auc/healthy: 0.9974 | auc/multiple_diseases: 0.6521, auc/rust: 0.9983, auc/scab: 0.9972



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 32 | Time: 3m 4s
Training Results - Average Loss: 0.0846 | auc/_mean: 0.9456 | auc/healthy: 0.9994 | auc/multiple_diseases: 0.7864, auc/rust: 0.9978, auc/scab: 0.9985
Evaluating Results - Average Loss: 0.1683 | auc/_mean: 0.9198 | auc/healthy: 0.9981 | auc/multiple_diseases: 0.6856, auc/rust: 0.9984, auc/scab: 0.9970



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 33 | Time: 3m 4s
Training Results - Average Loss: 0.0898 | auc/_mean: 0.9434 | auc/healthy: 0.9994 | auc/multiple_diseases: 0.7779, auc/rust: 0.9977, auc/scab: 0.9985
Evaluating Results - Average Loss: 0.1545 | auc/_mean: 0.9207 | auc/healthy: 0.9981 | auc/multiple_diseases: 0.6887, auc/rust: 0.9984, auc/scab: 0.9977



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 34 | Time: 3m 5s
Training Results - Average Loss: 0.0961 | auc/_mean: 0.9389 | auc/healthy: 0.9984 | auc/multiple_diseases: 0.7606, auc/rust: 0.9984, auc/scab: 0.9982
Evaluating Results - Average Loss: 0.1567 | auc/_mean: 0.9293 | auc/healthy: 0.9966 | auc/multiple_diseases: 0.7245, auc/rust: 0.9985, auc/scab: 0.9977

EarlyStopping


In [22]:
config.lr

0.003

In [23]:
model = get_model(config.model_name, config.num_classes)
model.load_state_dict(torch.load(f"{config.save_dirname}/best_model_epoch=34_auc_mean=0.9293390594149122.pth"))
for param in model.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=config.patience, verbose=True, mode="min", factor=0.3)

Loaded pretrained weights for efficientnet-b5


In [24]:
trainer = Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, device, config)
trainer.run();

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=200.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 01 | Time: 3m 5s
Training Results - Average Loss: 0.0915 | auc/_mean: 0.9456 | auc/healthy: 0.9992 | auc/multiple_diseases: 0.7870, auc/rust: 0.9983, auc/scab: 0.9978
Evaluating Results - Average Loss: 0.1565 | auc/_mean: 0.9173 | auc/healthy: 0.9981 | auc/multiple_diseases: 0.6757, auc/rust: 0.9985, auc/scab: 0.9968



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 02 | Time: 3m 5s
Training Results - Average Loss: 0.0975 | auc/_mean: 0.9414 | auc/healthy: 0.9994 | auc/multiple_diseases: 0.7712, auc/rust: 0.9974, auc/scab: 0.9976
Evaluating Results - Average Loss: 0.1457 | auc/_mean: 0.9182 | auc/healthy: 0.9980 | auc/multiple_diseases: 0.6781, auc/rust: 0.9990, auc/scab: 0.9979



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 03 | Time: 3m 5s
Training Results - Average Loss: 0.1036 | auc/_mean: 0.9362 | auc/healthy: 0.9992 | auc/multiple_diseases: 0.7506, auc/rust: 0.9974, auc/scab: 0.9975
Evaluating Results - Average Loss: 0.2115 | auc/_mean: 0.9195 | auc/healthy: 0.9972 | auc/multiple_diseases: 0.6843, auc/rust: 0.9989, auc/scab: 0.9977



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 04 | Time: 3m 5s
Training Results - Average Loss: 0.0848 | auc/_mean: 0.9441 | auc/healthy: 0.9995 | auc/multiple_diseases: 0.7813, auc/rust: 0.9989, auc/scab: 0.9967
Evaluating Results - Average Loss: 0.1601 | auc/_mean: 0.9230 | auc/healthy: 0.9978 | auc/multiple_diseases: 0.6984, auc/rust: 0.9982, auc/scab: 0.9977



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 05 | Time: 3m 5s
Training Results - Average Loss: 0.1025 | auc/_mean: 0.9434 | auc/healthy: 0.9992 | auc/multiple_diseases: 0.7785, auc/rust: 0.9973, auc/scab: 0.9987
Evaluating Results - Average Loss: 0.1769 | auc/_mean: 0.9172 | auc/healthy: 0.9979 | auc/multiple_diseases: 0.6747, auc/rust: 0.9986, auc/scab: 0.9974



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 06 | Time: 3m 10s
Training Results - Average Loss: 0.0865 | auc/_mean: 0.9415 | auc/healthy: 0.9990 | auc/multiple_diseases: 0.7704, auc/rust: 0.9984, auc/scab: 0.9983
Evaluating Results - Average Loss: 0.2078 | auc/_mean: 0.9187 | auc/healthy: 0.9973 | auc/multiple_diseases: 0.6825, auc/rust: 0.9982, auc/scab: 0.9967



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 07 | Time: 3m 6s
Training Results - Average Loss: 0.0699 | auc/_mean: 0.9458 | auc/healthy: 0.9996 | auc/multiple_diseases: 0.7857, auc/rust: 0.9993, auc/scab: 0.9988
Evaluating Results - Average Loss: 0.1769 | auc/_mean: 0.9110 | auc/healthy: 0.9968 | auc/multiple_diseases: 0.6518, auc/rust: 0.9991, auc/scab: 0.9962



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 08 | Time: 3m 6s
Training Results - Average Loss: 0.0793 | auc/_mean: 0.9432 | auc/healthy: 0.9992 | auc/multiple_diseases: 0.7761, auc/rust: 0.9990, auc/scab: 0.9984
Evaluating Results - Average Loss: 0.1635 | auc/_mean: 0.9189 | auc/healthy: 0.9974 | auc/multiple_diseases: 0.6814, auc/rust: 0.9988, auc/scab: 0.9981



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 09 | Time: 3m 4s
Training Results - Average Loss: 0.0906 | auc/_mean: 0.9426 | auc/healthy: 0.9987 | auc/multiple_diseases: 0.7755, auc/rust: 0.9977, auc/scab: 0.9987
Evaluating Results - Average Loss: 0.1796 | auc/_mean: 0.9143 | auc/healthy: 0.9968 | auc/multiple_diseases: 0.6644, auc/rust: 0.9990, auc/scab: 0.9970



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 10 | Time: 3m 4s
Training Results - Average Loss: 0.0808 | auc/_mean: 0.9473 | auc/healthy: 0.9996 | auc/multiple_diseases: 0.7931, auc/rust: 0.9986, auc/scab: 0.9978
Evaluating Results - Average Loss: 0.1886 | auc/_mean: 0.9232 | auc/healthy: 0.9967 | auc/multiple_diseases: 0.7009, auc/rust: 0.9981, auc/scab: 0.9970



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 11 | Time: 3m 4s
Training Results - Average Loss: 0.0764 | auc/_mean: 0.9554 | auc/healthy: 0.9995 | auc/multiple_diseases: 0.8270, auc/rust: 0.9978, auc/scab: 0.9975
Evaluating Results - Average Loss: 0.1781 | auc/_mean: 0.9309 | auc/healthy: 0.9974 | auc/multiple_diseases: 0.7304, auc/rust: 0.9982, auc/scab: 0.9977



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 12 | Time: 3m 5s
Training Results - Average Loss: 0.0780 | auc/_mean: 0.9479 | auc/healthy: 0.9990 | auc/multiple_diseases: 0.7945, auc/rust: 0.9989, auc/scab: 0.9991
Evaluating Results - Average Loss: 0.1755 | auc/_mean: 0.9337 | auc/healthy: 0.9980 | auc/multiple_diseases: 0.7406, auc/rust: 0.9987, auc/scab: 0.9975



HBox(children=(FloatProgress(value=0.0, description='Train', max=305.0, style=ProgressStyle(description_width=…


Epoch: 13 | Time: 3m 4s
Training Results - Average Loss: 0.0808 | auc/_mean: 0.9512 | auc/healthy: 0.9995 | auc/multiple_diseases: 0.8085, auc/rust: 0.9980, auc/scab: 0.9987
Evaluating Results - Average Loss: 0.1804 | auc/_mean: 0.9307 | auc/healthy: 0.9973 | auc/multiple_diseases: 0.7289, auc/rust: 0.9988, auc/scab: 0.9977

EarlyStopping


In [25]:

class PlantDatasetTest(Dataset):
    
    def __init__(self, df, config, transforms=None):
    
        self.df = df
        self.images_dir = config.root_images
        self.class_names = config.class_names
        self.transforms=transforms
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        image_src = self.images_dir + self.df.iloc[idx]['image_id'] + '.jpg'
        image = cv2.imread(image_src, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transforms:
            transformed = self.transforms(image=image)
            image = transformed['image']

        return image, idx

test_df = pd.read_csv(config.root + 'test.csv')
test_dataset = PlantDatasetTest(test_df, config, train_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

In [26]:
def predict(model, dataloader: DataLoader, device, dataset: Dataset, tta_count: int, config: ConfigExperiment):
    n_samples = len(dataset)
    y_probas_tta = np.zeros((n_samples, config.num_classes, tta_count), dtype=np.float32)
    
    model.eval()
    model.to(device)
    with torch.no_grad():
        for i_epoch in tqdm(range(tta_count), desc='TTA', total=tta_count, position=1, leave=True):
            for i, (images, idx) in enumerate(dataloader):
                images = images.to(device)
                outputs = model(images)
                y_pred = F.softmax(outputs, dim=1)
                
                tta_index = i_epoch - 1
                start_index = (i % len(dataloader)) * config.batch_size
                end_index = min(start_index + config.batch_size, n_samples)
                batch_y_probas = y_pred.detach().cpu().numpy()                
                y_probas_tta[start_index:end_index, :, tta_index] = batch_y_probas

    return y_probas_tta

In [27]:
model = get_model(config.model_name, config.num_classes)
model.load_state_dict(torch.load(f"{config.save_dirname}/best_model_epoch=12_auc_mean=0.9337178937045071.pth"))

Loaded pretrained weights for efficientnet-b5


<All keys matched successfully>

In [28]:
result = predict(model, test_dataloader, device, test_dataset, tta_count=5, config=config)

HBox(children=(FloatProgress(value=0.0, description='TTA', max=5.0, style=ProgressStyle(description_width='ini…




In [30]:
y_probas = np.mean(result, axis=-1)

In [31]:
y_probas[:10]

array([[2.47132152e-06, 9.37707070e-03, 9.90620434e-01, 5.30164748e-08],
       [1.23882239e-06, 5.10935672e-04, 9.99487877e-01, 8.49835868e-10],
       [1.51702261e-05, 1.70200574e-03, 1.83298141e-08, 9.98282790e-01],
       [9.98886466e-01, 3.00522341e-04, 1.93380238e-05, 7.93727289e-04],
       [5.18367324e-06, 1.06555652e-02, 9.89339173e-01, 9.46044878e-08],
       [5.90220928e-01, 2.36532129e-02, 8.82808599e-05, 3.86037588e-01],
       [9.99259174e-01, 2.16139160e-04, 1.68442239e-05, 5.07874298e-04],
       [1.64555913e-05, 7.36436963e-01, 8.93729739e-04, 2.62652904e-01],
       [4.84844844e-04, 2.05751527e-02, 1.05288045e-05, 9.78929520e-01],
       [7.15310443e-06, 6.55038399e-04, 9.99337792e-01, 1.73163581e-08]],
      dtype=float32)

In [32]:
test_df = pd.read_csv(config.root + 'test.csv')
test_df["healthy"] = 0
test_df["multiple_diseases"] = 0
test_df["rust"] = 0
test_df["scab"] = 0
test_df[['healthy', 'multiple_diseases', 'rust', 'scab']] = y_probas
test_df.to_csv(config.submission_file, index=False)
test_df.head()

Unnamed: 0,image_id,healthy,multiple_diseases,rust,scab
0,Test_0,2e-06,0.009377,0.9906204,5.301647e-08
1,Test_1,1e-06,0.000511,0.9994879,8.498359e-10
2,Test_2,1.5e-05,0.001702,1.832981e-08,0.9982828
3,Test_3,0.998886,0.000301,1.933802e-05,0.0007937273
4,Test_4,5e-06,0.010656,0.9893392,9.460449e-08


In [33]:
config.submission_file

'01_efficientnet-b5.csv'