In [1]:
import gc
import os
import cv2
import sys
import json
import time
import timm
import torch
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2

os.environ["CUDA_VISIBLE_DEVICES"]="2"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
!nvidia-smi

Wed Mar 10 11:57:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 460.39       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 207...  Off  | 00000000:04:00.0 Off |                  N/A |
| 20%   31C    P8     8W / 215W |     13MiB /  7982MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3080    Off  | 00000000:09:00.0 Off |                  N/A |
|100%   72C    P2   301W / 340W |   6946MiB / 10014MiB |     90%      Default |
|       

In [3]:
train_metadata = pd.read_csv("/Datasets/DF20/metadata/DanishFungi2020_train_metadata_DEV.csv")
print(len(train_metadata))

test_metadata = pd.read_csv("/Datasets/DF20/metadata/DanishFungi2020_test_metadata_DEV.csv")
print(len(test_metadata))

266344
29594


In [4]:
train_metadata.head()

Unnamed: 0,gbifID,eventDate,year,month,day,countryCode,locality,identifiedBy,taxonID,scientificName,...,image_url,Substrate,rightsHolder,Latitude,Longitude,CoorUncert,Habitat,image_path,class_id,genus_id
0,2238546328,2018-04-16T00:00:00,2018.0,4.0,16.0,DK,Ulvedal Plantage,Ulfva Melchior Hvidegaard,30872.0,Ramalina farinacea (L.) Ach.,...,https://api.gbif.org/v1/image/unsafe/https://s...,bark of living trees,Ulfva Melchior Hvidegaard,56.299706,9.25811,50.0,Mixed woodland (with coniferous and deciduous ...,/Datasets/SvampeAtlas-14.12.2020/2238546328-30...,1273,453
1,2558871973,2020-01-03T00:00:00,2020.0,1.0,3.0,DK,Slotshegn,Thomas Læssøe,15256.0,Hysterium acuminatum Fr.,...,https://api.gbif.org/v1/image/unsafe/https://s...,dead wood (including bark),Ole Martin,55.861899,11.975973,50.0,Deciduous woodland,/Datasets/SvampeAtlas-14.12.2020/2558871973-53...,708,246
2,2238503501,2017-08-22T00:00:00,2017.0,8.0,22.0,DK,Petersborg Strandenge,Per Taudal Poulsen,61200.0,Gliophorus perplexus (A.H.Sm. & Hesler) Kovalenko,...,https://api.gbif.org/v1/image/unsafe/https://s...,soil,Per Taudal Poulsen,56.975158,9.285525,75.0,natural grassland,/Datasets/SvampeAtlas-14.12.2020/2238503501-24...,535,186
3,2446759075,2019-10-26T00:00:00,2019.0,10.0,26.0,DK,Klintebjerg,Susanne Rabenborg,30530.0,Lecidella scabra (Taylor) Hertel & Leuckert,...,https://api.gbif.org/v1/image/unsafe/https://s...,stone,Susanne Rabenborg,55.960242,11.583103,15.0,gravel or clay pit,/Datasets/SvampeAtlas-14.12.2020/2446759075-19...,832,276
4,2238472345,2016-08-21T00:00:00,2016.0,8.0,21.0,DK,Blåbjerg,Tom Smidth,63728.0,"Russula fragilis Fr., 1838",...,https://api.gbif.org/v1/image/unsafe/https://s...,soil,Tom Smidth,55.742985,8.250188,50.0,Mixed woodland (with coniferous and deciduous ...,/Datasets/SvampeAtlas-14.12.2020/2238472345-16...,1338,476


In [5]:
@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')

    
def init_logger(log_file='train.log'):
    from logging import getLogger, DEBUG, FileHandler,  Formatter,  StreamHandler
    
    log_format = '%(asctime)s %(levelname)s %(message)s'
    
    stream_handler = StreamHandler()
    stream_handler.setLevel(DEBUG)
    stream_handler.setFormatter(Formatter(log_format))
    
    file_handler = FileHandler(log_file)
    file_handler.setFormatter(Formatter(log_format))
    
    logger = getLogger('Herbarium')
    logger.setLevel(DEBUG)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    
    return logger

LOG_FILE = '../../logs/DF20/-EfficientNet-B0-224.log'
LOGGER = init_logger(LOG_FILE)


def seed_torch(seed=777):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

SEED = 777
seed_torch(SEED)

In [6]:
class TrainDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]
        image = cv2.imread(file_path)
        
        try:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except:
            print(file_path)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        
        return image, label

In [7]:
HEIGHT = 224
WIDTH = 224


from albumentations import RandomCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast, CenterCrop, PadIfNeeded, RandomResizedCrop

def get_transforms(*, data):
    assert data in ('train', 'valid')

    if data == 'train':
        return Compose([
            RandomResizedCrop(WIDTH, HEIGHT, scale=(0.8, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])

    elif data == 'valid':
        return Compose([
            Resize(WIDTH, HEIGHT),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])

In [8]:
N_CLASSES = len(train_metadata['class_id'].unique())

train_dataset = TrainDataset(train_metadata, transform=get_transforms(data='train'))
valid_dataset = TrainDataset(test_metadata, transform=get_transforms(data='valid'))

In [11]:
# Adjust BATCH_SIZE and ACCUMULATION_STEPS to values that if multiplied results in 64 !!!!!1
BATCH_SIZE = 32
ACCUMULATION_STEPS = 2
EPOCHS = 100
WORKERS = 8

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)

In [12]:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

model._fc = nn.Linear(model._fc.in_features, N_CLASSES)

Loaded pretrained weights for efficientnet-b0


In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score
import tqdm


with timer('Train model'):
    accumulation_steps = ACCUMULATION_STEPS
    n_epochs = EPOCHS
    lr = 0.01
    
    model.to(device)
    
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=1, verbose=True, eps=1e-6)
    
    criterion = nn.CrossEntropyLoss()
    best_score = 0.
    best_loss = np.inf
    
    for epoch in range(n_epochs):
        
        start_time = time.time()

        model.train()
        avg_loss = 0.

        optimizer.zero_grad()

        for i, (images, labels) in tqdm.tqdm(enumerate(train_loader)):

            images = images.to(device)
            labels = labels.to(device)

            y_preds = model(images)
            loss = criterion(y_preds, labels)

            # Scale the loss to the mean of the accumulated batch size
            loss = loss / accumulation_steps
            loss.backward()
            if (i - 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

                avg_loss += loss.item() / len(train_loader)

        model.eval()
        avg_val_loss = 0.
        preds = np.zeros((len(valid_dataset)))
        preds_raw = []

        for i, (images, labels) in enumerate(valid_loader):
            
            images = images.to(device)
            labels = labels.to(device)
            
            with torch.no_grad():
                y_preds = model(images)
            
            preds[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = y_preds.argmax(1).to('cpu').numpy()
            preds_raw.extend(y_preds.to('cpu').numpy())

            loss = criterion(y_preds, labels)
            avg_val_loss += loss.item() / len(valid_loader)
        
        scheduler.step(avg_val_loss)
            
        score = f1_score(test_metadata['class_id'], preds, average='macro')
        accuracy = accuracy_score(test_metadata['class_id'], preds)
        recall_3 = top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3)

        elapsed = time.time() - start_time

        LOGGER.debug(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f} F1: {score:.6f}  Accuracy: {accuracy:.6f} Recall@3: {recall_3:.6f} time: {elapsed:.0f}s')

        if accuracy>best_score:
            best_score = accuracy
            LOGGER.debug(f'  Epoch {epoch+1} - Save Best Accuracy: {best_score:.6f} Model')
            torch.save(model.state_dict(), f'DF20-EfficientNet-B0-224_best_accuracy.pth')

        if avg_val_loss<best_loss:
            best_loss = avg_val_loss
            LOGGER.debug(f'  Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
            torch.save(model.state_dict(), f'DF20-EfficientNet-B0-224_best_loss.pth')

2021-03-10 11:57:51,871 INFO [Train model] start
8324it [42:28,  3.27it/s]
2021-03-10 12:45:10,237 DEBUG   Epoch 1 - avg_train_loss: 0.9758  avg_val_loss: 2.6483 F1: 0.215129  Accuracy: 0.413090 Recall@3: 0.600662 time: 2837s
2021-03-10 12:45:10,238 DEBUG   Epoch 1 - Save Best Accuracy: 0.413090 Model
2021-03-10 12:45:10,316 DEBUG   Epoch 1 - Save Best Loss: 2.6483 Model
8324it [42:36,  3.26it/s]
2021-03-10 13:32:24,438 DEBUG   Epoch 2 - avg_train_loss: 0.5952  avg_val_loss: 2.0367 F1: 0.352075  Accuracy: 0.522403 Recall@3: 0.709840 time: 2834s
2021-03-10 13:32:24,439 DEBUG   Epoch 2 - Save Best Accuracy: 0.522403 Model
2021-03-10 13:32:24,500 DEBUG   Epoch 2 - Save Best Loss: 2.0367 Model
8324it [42:33,  3.26it/s]
2021-03-10 14:19:38,109 DEBUG   Epoch 3 - avg_train_loss: 0.4927  avg_val_loss: 1.8585 F1: 0.410297  Accuracy: 0.558593 Recall@3: 0.744577 time: 2834s
2021-03-10 14:19:38,110 DEBUG   Epoch 3 - Save Best Accuracy: 0.558593 Model
2021-03-10 14:19:38,169 DEBUG   Epoch 3 - Save 

In [None]:
torch.save(model.state_dict(), f'DF20-EfficientNet-B0-224-100E.pth')