In [None]:
import pandas as pd
import numpy as np

In [None]:
%cd baseline
train_df = pd.read_csv("train_tiger.csv")
train_df.head()

# Обучение модели

In [None]:
from pathlib import Path
from pprint import pprint
from typing import Tuple

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from oml.const import TCfg
from oml.datasets.images import get_retrieval_images_datasets
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
from oml.lightning.pipelines.parser import (
    check_is_config_for_ddp,
    parse_logger_from_config,
    parse_ckpt_callback_from_config,
    parse_engine_params_from_config,
    parse_sampler_from_config,
    parse_scheduler_from_config,
)
from oml.metrics.embeddings import EmbeddingMetrics
from oml.registry.losses import get_criterion_by_cfg
from oml.registry.models import get_extractor_by_cfg
from oml.registry.optimizers import get_optimizer_by_cfg
from oml.registry.transforms import TRANSFORMS_REGISTRY, get_transforms_by_cfg
from oml.utils.misc import dictconfig_to_dict, set_global_seed
from torch.utils.data import DataLoader

import torch

import albumentations as albu
import cv2
from albumentations.pytorch import ToTensorV2
from oml.const import MEAN, PAD_COLOR, STD, TNormParam

In [None]:
from datetime import datetime

# Suffix appended to the run folder name so runs are easy to tell apart.
postfix = "metric_learning"

# Snapshot of "now"; its components form a unique name for this run's log folder.
current_dateTime = datetime.now()
y = current_dateTime.year
month = current_dateTime.month
d = current_dateTime.day
hour = current_dateTime.hour
minute = current_dateTime.minute
s = current_dateTime.second
ms = current_dateTime.microsecond

# Single configuration dictionary consumed by the training pipeline cells below.
cfg: TCfg = {
    # --- run identity / reproducibility ---
    "postfix": postfix,
    "seed": 42,

    # --- hardware ---
    "accelerator": "gpu",
    "devices": 1,

    # --- data ---
    "image_size": 640,  # side length handed to get_transforms()
    "dataframe_name": "train_tiger.csv",
    "dataset_root": "./",
    "num_workers": 4,
    "cache_size": 0,
    "sampler": None,
    "bs_train": 32,
    "bs_val": 64,

    # --- logging / output locations ---
    "logs_root": "logs/",
    "logs_folder": f"{y}-{month}-{d}-{hour}-{minute}-{s}-{ms}_{postfix}",
    "save_dir": ".",

    # --- training schedule ---
    "max_epochs": 5,  # number of epochs to train
    "valid_period": 2,  # validation runs every `valid_period` epochs

    # --- validation metrics (forwarded as keyword arguments to EmbeddingMetrics) ---
    "metric_args": {
        "metrics_to_exclude_from_visualization": ["cmc"],
        "map_top_k": [1, 3, 5],
        "return_only_overall_category": False,
        "visualize_only_overall_category": False
    },

    "log_images": True,
    # presumably must match one of the metric names produced from `metric_args` - verify
    "metric_for_checkpointing": "OVERALL/map/5",

    # --- model ---
    "extractor":{
        "name": "resnet",
        "args":{
            "arch": "resnet50",
            "gem_p": 1.0,
            "remove_fc": True,
            "normalise_features": False,
            "weights": None,
        },
    },

    # --- loss ---
    # NOTE(review): confirm that "triplet" (rather than "triplet_with_miner") matches the
    # way ExtractorModule invokes the criterion.
    "criterion": {
        "name": "triplet",
        "args":{
          # Must be a real None: this is a Python dict, not YAML, so the string "null"
          # would stay a str instead of meaning "no margin".
          "margin": None,
          "reduction": "mean"
        }
    },

    # --- optimisation ---
    "optimizer":{
        "name": "adam",
        "args":{
            "lr": 1e-5,
        },
    },

    "scheduling": None,

    # --- experiment logger ---
    "logger":{
        "name": "tensorboard",
        "args":{
            "save_dir": "."
        }
    }
}


In [None]:
def get_transforms(im_size: int, mean: TNormParam = MEAN, std: TNormParam = STD) -> albu.Compose:
    """
    Build the image preprocessing pipeline (no HorizontalFlip is included).

    Steps, in order: rescale so the longest side equals ``im_size``, pad with a
    constant ``PAD_COLOR`` border up to ``im_size`` in height and width,
    normalise, and convert to a tensor.

    :param im_size: Target side length used for both resizing and padding.
    :param mean: Mean forwarded to ``albu.Normalize``.
    :param std: Standard deviation forwarded to ``albu.Normalize``.
    :return: The composed albumentations transform.
    """
    resize = albu.LongestMaxSize(max_size=im_size)
    pad_to_square = albu.PadIfNeeded(
        min_height=im_size,
        min_width=im_size,
        border_mode=cv2.BORDER_CONSTANT,
        value=PAD_COLOR,
    )
    normalize = albu.Normalize(mean=mean, std=std)
    to_tensor = ToTensorV2()

    return albu.Compose([resize, pad_to_square, normalize, to_tensor])

In [None]:
def get_retrieval_loaders(cfg: TCfg) -> Tuple[DataLoader, DataLoader]:
    """
    Create the train and validation data loaders described by ``cfg``.

    Both datasets use the preprocessing returned by ``get_transforms``. The
    train loader shuffles and drops the last incomplete batch; the validation
    loader keeps the ``DataLoader`` defaults for both.

    :param cfg: Config with the keys ``dataset_root``, ``image_size``,
        ``dataframe_name``, ``cache_size``, ``num_workers``, ``bs_train`` and
        ``bs_val``; ``show_dataset_warnings`` is optional (defaults to True).
    :return: ``(train_loader, val_loader)``.
    """
    train_set, val_set = get_retrieval_images_datasets(
        dataset_root=Path(cfg['dataset_root']),
        transforms_train=get_transforms(cfg['image_size']),
        transforms_val=get_transforms(cfg['image_size']),
        dataframe_name=cfg['dataframe_name'],
        cache_size=cfg['cache_size'],
        verbose=cfg.get('show_dataset_warnings', True),
    )

    train_loader = DataLoader(
        train_set,
        num_workers=cfg['num_workers'],
        batch_size=cfg['bs_train'],
        drop_last=True,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=cfg['bs_val'],
        num_workers=cfg['num_workers'],
    )

    return train_loader, val_loader


In [None]:
def extractor_training_pipeline(cfg: TCfg) -> None:
    """
    Train a feature extractor end to end according to ``cfg``.

    The function seeds the run, builds the logger, data loaders, extractor,
    criterion and optimizer from the config, wraps them in an
    ``ExtractorModule`` and fits it with a Lightning ``Trainer`` that carries a
    metric-validation callback and a checkpoint callback.

    :param cfg: Pipeline configuration. Required keys read here: ``seed``,
        ``extractor``, ``criterion``, ``optimizer`` (with ``args.lr``),
        ``max_epochs`` and ``valid_period``, plus whatever the ``parse_*``
        helpers and ``get_retrieval_loaders`` need. Optional keys:
        ``freeze_n_epochs`` (default 0), ``metric_args`` (default ``{}``),
        ``log_images`` (default False), ``precision`` (default 16),
        ``accelerator`` / ``devices`` (forwarded to the trainer when set) and
        ``lightning_trainer_extra_args`` (default ``{}``; overrides any
        trainer argument set by this function).
    :return: None. Training artefacts are produced by the logger and callbacks.
    """
    set_global_seed(cfg['seed'])

    cfg = dictconfig_to_dict(cfg)
    # The config is nested, so pretty-print it for readability.
    pprint(cfg)

    logger = parse_logger_from_config(cfg)
    logger.log_pipeline_info(cfg)

    loader_train, loader_val = get_retrieval_loaders(cfg)
    extractor = get_extractor_by_cfg(cfg['extractor'])
    criterion = get_criterion_by_cfg(
        cfg['criterion'],
        label2category=loader_train.dataset.get_label2category(),
    )

    # The extractor and the criterion are optimised as separate parameter
    # groups that share the configured learning rate.
    lr = cfg['optimizer']['args']['lr']
    optimizable_parameters = [
        {'lr': lr, 'params': extractor.parameters()},
        {'lr': lr, 'params': criterion.parameters()},
    ]
    optimizer = get_optimizer_by_cfg(cfg['optimizer'], params=optimizable_parameters)  # type: ignore

    pl_module = ExtractorModule(
        extractor=extractor,
        criterion=criterion,
        optimizer=optimizer,
        input_tensors_key=loader_train.dataset.input_tensors_key,
        labels_key=loader_train.dataset.labels_key,
        freeze_n_epochs=cfg.get('freeze_n_epochs', 0),
        **parse_scheduler_from_config(cfg, optimizer=optimizer),
    )

    metrics_calc = EmbeddingMetrics(
        dataset=loader_val.dataset,
        **cfg.get('metric_args', {}),
    )
    metrics_clb = MetricValCallback(
        metric=metrics_calc,
        log_images=cfg.get('log_images', False),
    )

    trainer_kwargs = {
        'max_epochs': cfg['max_epochs'],
        'num_sanity_val_steps': 0,  # skip Lightning's pre-training sanity validation
        'check_val_every_n_epoch': cfg['valid_period'],
        'default_root_dir': str(Path.cwd()),
        'enable_checkpointing': True,
        'enable_progress_bar': True,
        'enable_model_summary': True,
        'callbacks': [metrics_clb, parse_ckpt_callback_from_config(cfg)],
        'logger': logger,
        # Mixed precision stays the default, but can now be changed from the config.
        'precision': cfg.get('precision', 16),
    }

    # Honour the hardware selection from the config. Keys that are missing or
    # None are left out so the Trainer keeps its own defaults.
    for engine_key in ('accelerator', 'devices'):
        if cfg.get(engine_key) is not None:
            trainer_kwargs[engine_key] = cfg[engine_key]

    # Explicit extra arguments win over everything assembled above.
    trainer_kwargs.update(cfg.get('lightning_trainer_extra_args', {}))

    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model=pl_module, train_dataloaders=loader_train, val_dataloaders=loader_val)


In [None]:
extractor_training_pipeline(cfg)