### Train your dataset

Here is a detailed example of model training for `mobilenet_v2`, `test_data` of ants and bees and default imbalanced tool (actually, here is a content of `model_trainin.py`). All training constants are placed into `config.py`. 

In more comfortable and faster way the training can also be launched by running `model_trainin.py` with desired configuration.

In [1]:
import os

os.chdir(os.path.join('..', 'src'))

import cv2
import torch
import torchvision

import hashlib
import numpy as np

from PIL import Image
from torch.utils.data import DataLoader

from pandas.io.json._normalize import nested_to_record

from utils import nn_utils, image_utils, augmentation
from utils.model_training_utils import get_computing_device, load_pytorch_model, get_optimizer, \
                                       PytorchDataset, SaveBestModelCallback, model_train

import resnet_selector
import mobilenet_selector

import neptune_logger
import config

In [2]:
# Dict for imbalanced params setting
imbalanced_dict = {'train_sampler': {'default': False,
                                     'weighted_loss': False,
                                     'train_sampler': True
                                     },
                   'weighted_loss': {'default': False,
                                     'weighted_loss': True,
                                     'train_sampler': False
                                     },
                   'tag': {'default': 'default',
                           'weighted_loss': 'weighted_loss',
                           'train_sampler': 'custom_batch_sampler'
                           }
                   }
imbalanced_tool = 'default'

In [3]:
CONFIG = {'model': {'model_name': 'mobilenet_v2',           
                        'optimizer': config.MODEL_OPTIMIZER,
                        'pretrained': config.PRETRAINED,
                        'freeze_conv': config.FREEZE_CONV,
                        'epochs': config.EPOCHS,
                        'batch_size': config.BATCH_SIZE,
                        },

              'data': {'root_dir': os.getcwd().replace('src', 'datasets'),
                       'dataset_name': 'test_data',
                       'resize_img': config.RESIZE_IMG,
                       'num_workers': config.NUM_WORKERS,
                       'pytorch_aug': config.PYTORCH_AUG,
                       'save_aspect_ratio': config.SAVE_ASPECT_RATIO,
                       'imgaug_aug': config.IMGAUG_AUG,
                       'img_normalize': config.IMG_NORMALIZE,
                       },

              'imbalanced_tools': {'train_sampler': imbalanced_dict['train_sampler'][imbalanced_tool],
                                   'weighted_loss': imbalanced_dict['weighted_loss'][imbalanced_tool],
                                   }
              }

In [4]:
NEPTUNE_EXPERIMENT_TAG_LIST = [CONFIG['data']['dataset_name'],
                               CONFIG['model']['model_name'],
                               'imbalanced_tools',
                               imbalanced_dict['tag'][imbalanced_tool],
                               ]

DEVICE = get_computing_device()
CONFIG.update({'seed': config.SEED})
np.random.seed(config.SEED)
torch.manual_seed(config.SEED)

FULL_DATASET_DIR = os.path.join(CONFIG['data']['root_dir'],
                                    CONFIG['data']['dataset_name'])
CLASS_DICT = image_utils.make_class_dict(os.path.join(FULL_DATASET_DIR, 'df_img_meta.csv'))

print('Workers:', CONFIG['data']['num_workers'])
if CONFIG['data']['num_workers'] > 0:
    cv2.setNumThreads(0)

CUDA available: True
Workers: 8


In [5]:
# Setting image transformations
_pytorch_aug = CONFIG['data']['pytorch_aug']
_img_normalize = CONFIG['data']['img_normalize']
_save_aspect_ratio = CONFIG['data']['save_aspect_ratio']

default_transforms_list = [augmentation.ImgAugTransform(with_aug=False),
                           image_utils.CustomResize((CONFIG['data']['resize_img'],
                                         CONFIG['data']['resize_img']),
                                        _save_aspect_ratio
                                        ),
                           torchvision.transforms.ToTensor()
                           ]

if _pytorch_aug:
    # Pytorch default & augmentation transform
    transform_list = [augmentation.ImgAugTransform(with_aug=CONFIG['data']['imgaug_aug']),
                      image_utils.CustomResize((CONFIG['data']['resize_img'],
                                    CONFIG['data']['resize_img']),
                                   _save_aspect_ratio
                                   ),
                      torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
                      torchvision.transforms.RandomHorizontalFlip(),
                      torchvision.transforms.RandomRotation(20, resample=Image.BILINEAR),
                      torchvision.transforms.ToTensor()
                      ]
else:
    transform_list = [augmentation.ImgAugTransform(with_aug=CONFIG['data']['imgaug_aug']),
                      image_utils.CustomResize((CONFIG['data']['resize_img'],
                                    CONFIG['data']['resize_img']),
                                   _save_aspect_ratio
                                   ),
                      torchvision.transforms.ToTensor()
                      ]

if _img_normalize:
    # Parameters were taken from Pytorch example for Imagenet
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L197-L198
    transform_list.append(torchvision.transforms.Normalize(mean=config.IMG_NORMALIZE_MEAN,
                                                           std=config.IMG_NORMALIZE_STD)
                          )
    default_transforms_list.append(torchvision.transforms.Normalize(mean=config.IMG_NORMALIZE_MEAN,
                                                                    std=config.IMG_NORMALIZE_STD)
                                  )

In [6]:
train_transforms = torchvision.transforms.Compose(transform_list)
val_trainsforms = torchvision.transforms.Compose(default_transforms_list)

In [7]:
# Set dictionaries for training, validation and testing
TRAIN_DATASET_KWARGS = {'class_dict': CLASS_DICT,
                        'dataset_rootdir': FULL_DATASET_DIR,
                        'val_size': config.TESTVAL_SIZE,
                        'test_size': config.TEST_SIZE_FROM_TESTVAL,   # Доля от val_size
                        'seed': config.SEED,
                        'torch_transform': train_transforms,
                        }

TESTVAL_DATASET_KWARGS = {'class_dict': CLASS_DICT,
                          'dataset_rootdir': FULL_DATASET_DIR,
                          'val_size': config.TESTVAL_SIZE,
                          'test_size': config.TEST_SIZE_FROM_TESTVAL,  # Доля от val_size
                          'seed': config.SEED,
                          'torch_transform': val_trainsforms,
                          }

In [8]:
train_dataset = PytorchDataset('train', TRAIN_DATASET_KWARGS)
val_dataset = PytorchDataset('val', TESTVAL_DATASET_KWARGS)

class_dict: {'ants': 0, 'bees': 1}
dataset_rootdir: /home/alexander/Documents/py_projects/bitbucket/rsf_insects_recognition/datasets/test_data
val_size: 0.3
test_size: 0.5
seed: 42
torch_transform: Compose(
    <utils.augmentation.ImgAugTransform object at 0x7f84e8bbe790>
    <utils.image_utils.CustomResize object at 0x7f84e8bbe990>
    ColorJitter(brightness=None, contrast=None, saturation=[0.95, 1.05], hue=[-0.05, 0.05])
    RandomHorizontalFlip(p=0.5)
    RandomRotation(degrees=(-20, 20), resample=2, expand=False)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

Unique classes in /home/alexander/Documents/py_projects/bitbucket/rsf_insects_recognition/datasets/test_data:
['bees', 'ants']
Mapping class dictionary: {'ants': 0, 'bees': 1}
Number of observations for TRAIN: 170

TRAIN classes:  2
0    86
1    84
Name: class, dtype: int64
VAL classes:  2
1    18
0    18
Name: class, dtype: int64
TEST classes:  2
1    19
0    19
Name: class, dtype: int6

In [9]:
# What sampler for batching to choose
if CONFIG['imbalanced_tools']['train_sampler']:

    _train_shuffle = False
    _train_sampler = nn_utils.ImbalancedDatasetSampler(dataset=train_dataset,
                                                       class_dict=CLASS_DICT,
                                                       random_state=config.SEED)
else:
    _train_shuffle = True
    _train_sampler = None

In [10]:
train_data_loader = DataLoader(train_dataset,
                               batch_size=CONFIG['model']['batch_size'],
                               shuffle=_train_shuffle, 
                               sampler=_train_sampler,
                               num_workers=CONFIG['data']['num_workers'])
val_data_loader = DataLoader(val_dataset, 
                             batch_size=CONFIG['model']['batch_size'],
                             shuffle=True, 
                             num_workers=CONFIG['data']['num_workers'])

In [11]:
# Optimizers and their parameters
# https://pytorch.org/docs/stable/optim.html
OPTIM_KWARGS = {'optimizer': {'adadelta': {'lr': 1.0, 'rho': 0.9, 'eps': 1e-06, 'weight_decay': 0},

                              'adagrad': {'lr': 0.01, 'lr_decay': 0, 'weight_decay': 0,
                                          'initial_accumulator_value': 0, 'eps': 1e-10},

                              'adam': {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08,
                                       'weight_decay': 0, 'amsgrad': False},

                              'adamw': {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08,

                                        'weight_decay': 0.01, 'amsgrad': False},

                              'adamax': {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08,
                                         'weight_decay': 0},

                              'rms_prop': {'lr': 0.01, 'alpha': 0.99, 'eps': 1e-08, 'weight_decay': 0,
                                           'momentum': 0, 'centered': False}
                              },

                'lr_scheduler': {'mode': 'min', 'factor': 0.95, 'patience': 10, 'verbose': False,
                                 'threshold': 0.0001, 'threshold_mode': 'rel', 'cooldown': 0, 'min_lr': 1e-9,
                                 'eps': 1e-08}
                }

In [12]:
# Select and configure model to be trained
# Number of classes
n_classes = len(TRAIN_DATASET_KWARGS['class_dict'].keys())

if 'mobilenet' in CONFIG['model']['model_name']:
    nnet = mobilenet_selector.get_mobilenet(version=CONFIG['model']['model_name'],
                                            class_number=n_classes,
                                            pretrained=CONFIG['model']['pretrained'],
                                            freeze_conv=CONFIG['model']['freeze_conv'])
else:
    nnet = resnet_selector.get_resnet(version=CONFIG['model']['model_name'],
                                      class_number=n_classes,
                                      pretrained=CONFIG['model']['pretrained'],
                                      freeze_conv=CONFIG['model']['freeze_conv'])

Loaded: mobilenet_v2


In [13]:
# Whether to use weighted loss or not
if CONFIG['imbalanced_tools']['weighted_loss']:
    _label_to_count = train_dataset.df_metadata['class'].value_counts().to_dict()
    _unnorm_weights = {k: 1 / v for k, v in _label_to_count.items()}
    _normed_weights = {k: v / sum(_unnorm_weights.values()) for k, v in _unnorm_weights.items()}

    _loss_weight = [v for k, v in sorted(_normed_weights.items(), key=lambda x: x[0])]
    _loss_weight = torch.FloatTensor(_loss_weight).to(DEVICE)

else:
    _loss_weight = None

# Set training keyword argumetns
criterion = torch.nn.CrossEntropyLoss(weight=_loss_weight)
optimizer = get_optimizer(CONFIG['model']['optimizer'],
                          nnet,
                          OPTIM_KWARGS['optimizer'])

In [14]:
# Train dataset img hashes to dataset hash. Universal for every machine.
train_data_hash = hashlib.sha1(
    '_'.join(sorted(train_dataset.df_metadata['sha1'].values.tolist())).encode()).hexdigest()

training_kwargs = {'device': DEVICE,

                   'criterion': criterion,
                   'optimizer': optimizer,
                   'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                              **OPTIM_KWARGS['lr_scheduler']),

                   'data_loaders': {'train_loader': train_data_loader,
                                    'val_loader': val_data_loader},
                   'total_epochs': CONFIG['model']['epochs'],
                   'batch_size': CONFIG['model']['batch_size'],
                   'target_names': list(TRAIN_DATASET_KWARGS['class_dict'].keys()),

                   'model_name': f"""{CONFIG['model']['model_name']}_{train_data_hash}""",
                   'class_dict': CLASS_DICT,
                   'save_path': os.path.join(os.getcwd(), 'weights'),
                   'general_config': CONFIG
                   }

weights_folder = os.path.join(os.getcwd(), 'weights',
                                  f"""pytorch_{CONFIG['model']['model_name']}_{train_data_hash}""")

In [15]:
if config.USE_NEPTUNE:
    print('Neptune logger is ON.')
    # Сбор параметров для логгирования
    TOTAL_PARAMS = {'config': CONFIG,
                    'training_kwargs': training_kwargs,
                    'train_dataset_kwargs': TRAIN_DATASET_KWARGS,
                    'testval_dataset_kwargs': TESTVAL_DATASET_KWARGS,
                    'weights_folder': weights_folder,
                    'training_data_sha1': train_data_hash,
                    }
    neptune_kwargs = {'params': nested_to_record(TOTAL_PARAMS, sep='_'),
                      'artifact_path': os.path.join(os.getcwd().replace('src', 'log_artifacts'), 'artifacts'),
                      'image_path': os.path.join(os.getcwd().replace('src', 'log_artifacts'), 'images'),
                      'tag_list': NEPTUNE_EXPERIMENT_TAG_LIST,
                      'training_data_sha1': train_data_hash
                      }
    neptune_class = neptune_logger.NeptuneLogger(neptune_kwargs)
else:
    print('Neptune logger is OFF.')
    neptune_class = None

Neptune logger is OFF.


In [16]:
# Training
model_train(nnet, training_kwargs, neptune_class=neptune_class)

100%|██████████| 11/11 [00:01<00:00,  5.69it/s]
100%|██████████| 3/3 [00:00<00:00,  9.66it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[SaveBestModelCallback] val_loss was improved: inf -> 0.48677828907966614. Model was saved.
[2.276771 sec.][Epoch 1] train_loss: 6.610351487994194, val_loss: 0.48677828907966614, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       1.00      0.83      0.91        18
        bees       0.86      1.00      0.92        18

    accuracy                           0.92        36
   macro avg       0.93      0.92      0.92        36
weighted avg       0.93      0.92      0.92        36



100%|██████████| 11/11 [00:01<00:00,  5.61it/s]
100%|██████████| 3/3 [00:00<00:00, 10.22it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.270029 sec.][Epoch 2] train_loss: 4.160225257277489, val_loss: 0.6025669574737549, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.94      0.94      0.94        18
        bees       0.94      0.94      0.94        18

    accuracy                           0.94        36
   macro avg       0.94      0.94      0.94        36
weighted avg       0.94      0.94      0.94        36



100%|██████████| 11/11 [00:02<00:00,  4.59it/s]
100%|██████████| 3/3 [00:00<00:00, 10.06it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.709707 sec.][Epoch 3] train_loss: 3.536779396235943, val_loss: 0.5929201245307922, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.93      0.78      0.85        18
        bees       0.81      0.94      0.87        18

    accuracy                           0.86        36
   macro avg       0.87      0.86      0.86        36
weighted avg       0.87      0.86      0.86        36



100%|██████████| 11/11 [00:02<00:00,  5.03it/s]
100%|██████████| 3/3 [00:00<00:00,  9.80it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.508828 sec.][Epoch 4] train_loss: 3.308390859514475, val_loss: 1.2561622858047485, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.88      0.78      0.82        18
        bees       0.80      0.89      0.84        18

    accuracy                           0.83        36
   macro avg       0.84      0.83      0.83        36
weighted avg       0.84      0.83      0.83        36



100%|██████████| 11/11 [00:02<00:00,  5.24it/s]
100%|██████████| 3/3 [00:00<00:00,  9.79it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.419902 sec.][Epoch 5] train_loss: 3.0779288709163666, val_loss: 2.321338653564453, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.93      0.78      0.85        18
        bees       0.81      0.94      0.87        18

    accuracy                           0.86        36
   macro avg       0.87      0.86      0.86        36
weighted avg       0.87      0.86      0.86        36



100%|██████████| 11/11 [00:01<00:00,  5.77it/s]
100%|██████████| 3/3 [00:00<00:00, 10.06it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.22081 sec.][Epoch 6] train_loss: 2.3602486550807953, val_loss: 1.289846420288086, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.93      0.72      0.81        18
        bees       0.77      0.94      0.85        18

    accuracy                           0.83        36
   macro avg       0.85      0.83      0.83        36
weighted avg       0.85      0.83      0.83        36



100%|██████████| 11/11 [00:01<00:00,  5.50it/s]
100%|██████████| 3/3 [00:00<00:00,  9.74it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.322329 sec.][Epoch 7] train_loss: 1.5529964994639158, val_loss: 0.6721146106719971, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       1.00      0.78      0.88        18
        bees       0.82      1.00      0.90        18

    accuracy                           0.89        36
   macro avg       0.91      0.89      0.89        36
weighted avg       0.91      0.89      0.89        36



100%|██████████| 11/11 [00:01<00:00,  5.53it/s]
100%|██████████| 3/3 [00:00<00:00, 10.09it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.300741 sec.][Epoch 8] train_loss: 3.0569884292781353, val_loss: 0.8855840563774109, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.93      0.78      0.85        18
        bees       0.81      0.94      0.87        18

    accuracy                           0.86        36
   macro avg       0.87      0.86      0.86        36
weighted avg       0.87      0.86      0.86        36



100%|██████████| 11/11 [00:01<00:00,  5.89it/s]
100%|██████████| 3/3 [00:00<00:00,  9.92it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.185996 sec.][Epoch 9] train_loss: 2.208288636058569, val_loss: 1.5977210998535156, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.89      0.89      0.89        18
        bees       0.89      0.89      0.89        18

    accuracy                           0.89        36
   macro avg       0.89      0.89      0.89        36
weighted avg       0.89      0.89      0.89        36



100%|██████████| 11/11 [00:01<00:00,  5.97it/s]
100%|██████████| 3/3 [00:00<00:00, 10.14it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.154213 sec.][Epoch 10] train_loss: 2.0382276698946953, val_loss: 0.9035669565200806, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.93      0.78      0.85        18
        bees       0.81      0.94      0.87        18

    accuracy                           0.86        36
   macro avg       0.87      0.86      0.86        36
weighted avg       0.87      0.86      0.86        36



100%|██████████| 11/11 [00:02<00:00,  5.41it/s]
100%|██████████| 3/3 [00:00<00:00,  9.87it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.353556 sec.][Epoch 11] train_loss: 1.3232023641467094, val_loss: 0.6963770985603333, learning_rate: 0.001.
              precision    recall  f1-score   support

        ants       0.93      0.78      0.85        18
        bees       0.81      0.94      0.87        18

    accuracy                           0.86        36
   macro avg       0.87      0.86      0.86        36
weighted avg       0.87      0.86      0.86        36



100%|██████████| 11/11 [00:02<00:00,  4.18it/s]
100%|██████████| 3/3 [00:00<00:00, 10.27it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.939624 sec.][Epoch 12] train_loss: 2.660275097936392, val_loss: 1.0027539730072021, learning_rate: 0.00095.
              precision    recall  f1-score   support

        ants       0.93      0.78      0.85        18
        bees       0.81      0.94      0.87        18

    accuracy                           0.86        36
   macro avg       0.87      0.86      0.86        36
weighted avg       0.87      0.86      0.86        36



100%|██████████| 11/11 [00:01<00:00,  5.64it/s]
100%|██████████| 3/3 [00:00<00:00,  9.66it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

[2.275374 sec.][Epoch 13] train_loss: 2.6412623301148415, val_loss: 0.749087393283844, learning_rate: 0.00095.
              precision    recall  f1-score   support

        ants       0.94      0.83      0.88        18
        bees       0.85      0.94      0.89        18

    accuracy                           0.89        36
   macro avg       0.89      0.89      0.89        36
weighted avg       0.89      0.89      0.89        36



 55%|█████▍    | 6/11 [00:02<00:01,  2.87it/s]


KeyboardInterrupt: 