# About

An `InChI` string can be split by `/` into several parts (at most 11 in the training data). The first part is the format identifier, which is the same string (`InChI=1S`) for every example in this competition, and the second part is the **chemical formula**, which gives **the number of atoms** of each element in the molecule.

Take `InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12(13)11(4)14/h5-7,9,11,14H,8H2,1-4H3` (`image_id`: `000011a64c74`) as an example.
The second part is `C13H20OS`, which means that the molecule has 13 `Carbon` atoms, 20 `Hydrogen` atoms, one `Oxygen` atom, and one `Sulfur` atom.
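
To make the layer structure concrete, the example string can be split on `/` as described above. This is a minimal sketch using only the standard library; the variable names are mine and are not used elsewhere in the notebook.

```python
# Split the example InChI into its parts (separated by "/").
inchi = "InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12(13)11(4)14/h5-7,9,11,14H,8H2,1-4H3"

parts = inchi.split("/")
version = parts[0]   # format prefix, the same for every example in this competition
formula = parts[1]   # chemical formula
rest = parts[2:]     # remaining parts of the string

print(len(parts))
print(version, formula)
for layer in rest:
    # In this example each later part starts with a one-letter prefix ("c", "h").
    print(layer[0], layer[1:])
```

Only `formula` is used as the regression target in the rest of this notebook.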

All the parts of `InChI` have **variable length**, but the second one (chemical formula) can be encoded at a fixed size because the set of elements in the training data is **limited** to 12 (`B`, `Br`, `C`, `Cl`, `F`, `H`, `I`, `N`, `O`, `P`, `S`, and `Si`). Therefore, we can represent a chemical formula as a **fixed-length** vector of atom counts and treat chemical formula prediction as a **multi-output regression task**.
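
As an illustration of the fixed-length encoding, the sketch below maps a formula string to a 12-element count vector. The helper name `formula_to_vector` is hypothetical; it assumes every symbol in the formula is one of the 12 elements listed above and raises `KeyError` otherwise.

```python
import re

# The 12 element symbols listed above, in alphabetical order.
TARGETS = ["B", "Br", "C", "Cl", "F", "H", "I", "N", "O", "P", "S", "Si"]
ELEM_RE = re.compile(r"([A-Z][a-z]?)(\d*)")

def formula_to_vector(formula):
    """Return atom counts for `formula` as a list ordered like TARGETS."""
    counts = dict.fromkeys(TARGETS, 0)
    for symbol, digits in ELEM_RE.findall(formula):
        counts[symbol] += int(digits) if digits else 1  # a missing count means 1
    return [counts[t] for t in TARGETS]

vec = formula_to_vector("C13H20OS")
print(dict(zip(TARGETS, vec)))
```

Each position of the vector is one regression output, so the model head needs exactly 12 outputs.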


In this notebook, I train a model that predicts the chemical formula by multivariate regression. I think precise prediction of the chemical formula will help predict the other parts of `InChI`.

There are some notebooks that focus on atoms or chemical formulas.

* [Bristol-Myers Squibb_count_atom](https://www.kaggle.com/kalfirst/bristol-myers-squibb-count-atom)
* [Bristol-Myers Squibb: Counting Elements](https://www.kaggle.com/stainsby/bristol-myers-squibb-counting-elements)
* [Step by Step 2: LS dist < 1 chemical formula](https://www.kaggle.com/wineplanetary/step-by-step-2-ls-dist-1-chemical-formula)

The last one already did what I wanted to do (thank you for sharing, @wineplanetary). I noticed that when I was almost done with this notebook😅

To differentiate this notebook, I show the training process here, and I will try solving the competition task with the predicted chemical formula in a separate inference notebook.

# Prepare

## import

In [None]:
import os
import re
import gc
import sys
import yaml
import copy
import random
import shutil
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from sklearn.metrics import mean_squared_error

import cv2
import albumentations
from albumentations.core.transforms_interface import ImageOnlyTransform, DualTransform
from albumentations.pytorch import ToTensorV2

import torch
from torch import nn
from torch.utils import data

sys.path.append("../input/pytorch-pfn-extras/pytorch-pfn-extras-0.3.2")
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions as ppe_exts

sys.path.append("../input/iterative-stratification/iterative-stratification-master")
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

sys.path.append("../input/timm-pytorch-image-models/pytorch-image-models-master")
import timm

In [None]:
ROOT = Path.cwd().parent
INPUT = ROOT / "input"
DATA = INPUT / "bms-molecular-translation"
TRAIN = DATA / "train"
TRAIN_224 = INPUT / "bms-molecular-224px-jpg-padded" / "train"
TEST = DATA / "test"
TEST_224 = INPUT / "bms-molecular-224px-jpg-padded" / "test"

TMP = ROOT / "tmp"
TMP.mkdir(exist_ok=True)

RANDAM_SEED = 1086

FOLDS = [0, 1, 2, 3, 4]
N_FOLD = len(FOLDS)
# FOLDS = [0,]
# N_FOLD = 5

TARGETS = [
    'B', 'Br', 'C', 'Cl',
    'F', 'H', 'I', 'N',
    'O', 'P', 'S', 'Si']

N_TARGETS = len(TARGETS)

## read label data

In [None]:
train = pd.read_csv(DATA / "train_labels.csv")
train.head()

## preprocess

### extract chemical formula from InChI

In [None]:
# # extract chemical formula
train["formula"] = train.InChI.str.extract("InChI=1S/([^/]+)/.+")
train.head()

### counting number of atoms for each example

In [None]:
elem_regex = re.compile(r"[A-Z][a-z]?[0-9]*")
atom_regex = re.compile(r"[A-Z][a-z]?")
dgts_regex = re.compile(r"[0-9]*")

formula_examples = [
    "C23H19ClIN3O",
    "C33H49B2N3O4",
    "C13H12BrF3N4OS",
    "C5H18O2P2Si2"]

for i, fml in enumerate(formula_examples):
    print(f"[example{i + 1}: {fml}]")
    print("\tatom with digits:", elem_regex.findall(fml))
    print("\tatom:", atom_regex.findall(fml))
    print("\tdigits", dgts_regex.findall(fml))

In [None]:
# # example for counting method
for fml in formula_examples:
    atom_dict = dict()
    print(f"[formula: {fml}]")
    for elem in elem_regex.findall(fml):
        atom = dgts_regex.sub("", elem)
        dgts = atom_regex.sub("", elem)
        atom_cnt = int(dgts) if len(dgts) > 0 else 1
        atom_dict[atom] = atom_cnt
        print(f"\t[elem:\t{elem}]\tatom: {atom},\tdgts: {dgts},\tatom_cnt: {atom_cnt}")
    print(f"\tresult: {atom_dict}")

In [None]:
# # apply to all train data
atom_dict_list = []
for fml in tqdm(train["formula"].values):
    atom_dict = dict()
    for elem in elem_regex.findall(fml):
        atom = dgts_regex.sub("", elem)
        dgts = atom_regex.sub("", elem)
        atom_cnt = int(dgts) if len(dgts) > 0 else 1
        atom_dict[atom] = atom_cnt
    atom_dict_list.append(atom_dict)
    
atom_df = pd.DataFrame(
    atom_dict_list).fillna(0).astype(int)
atom_df = atom_df.sort_index(axis="columns")

In [None]:
# # merge
for atom in TARGETS:
    train[atom] = atom_df[atom]
train.head()

In [None]:
del atom_df
del atom_dict
del atom_dict_list
gc.collect()

### check atom distribution

In [None]:
# # total number of each atoms
display(train[TARGETS].sum(axis=0))
_ = train[TARGETS].sum(axis=0).T.plot(kind="bar")

In [None]:
# # distribution of n_atoms for each example
train["n_atoms"] = train[TARGETS].sum(axis=1)
_ = train["n_atoms"].hist(bins=100)

# Training Multi-Output Regression

## summary

* base model: resnet18d
* CV Strategy: Multi-Label Stratified KFold(K=5) by atoms in each chemical formula
* image size: 224x224 (simply resize)
* batch size: 64
* epoch: 15
* optimizer: AdamW
* scheduler: OneCycleLR
* augmentation: HorizontalFlip, VerticalFlip, ShiftScaleRotate, RandomResizedCrop

**NOTE**: I use **only 4%** of the training data (`reduce_div_factor: 25`) for faster training.

## settings

In [None]:
settings = yaml.safe_load("""
globals:
  seed: 1086
  device: cuda
  max_epoch: 15
  patience: -1
  use_amp: True
  reduce_data: True
  reduce_div_factor: 25

dataset:
  name: LabeledImageDataset
  train:
    transform_list:
      # - [Resize, {always_apply: True, height: 224, width: 224}]
      - [HorizontalFlip, {p: 0.5}]
      - [VerticalFlip, {p: 0.5}]
      - [ShiftScaleRotate, {
          p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
          rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}]
      - [RandomResizedCrop, {height: 224, width: 224, scale: [0.9, 1.0]}]
      - [Normalize, {always_apply: True}]
      - [ToTensorV2, {always_apply: True}]
  val:
    transform_list:
      # - [Resize, {always_apply: True, height: 224, width: 224}]
      - [Normalize, {always_apply: True}]
      - [ToTensorV2, {always_apply: True}]

loader:
  train:
    batch_size: 64
    shuffle: True
    num_workers: 4
    pin_memory: True
    drop_last: True
  val:
    batch_size: 128
    shuffle: False
    num_workers: 4
    pin_memory: True
    drop_last: False

model:
  name: BasicImageModel
  params:
    base_name: resnet18d
    dims_head: [null, 12]
    pretrained: True

loss:
  name: MSELoss
  params: {}

eval:
  - {name: MyMSELoss, report_name: loss, params: {}}

optimizer:
  name: AdamW
  params:
    lr: 1.0e-06
    weight_decay: 1.0e-02

scheduler:
  name: OneCycleLR
  params:
    epochs: 15
    max_lr: 1.0e-3
    pct_start: 0.2
    anneal_strategy: cos
    div_factor: 1.0e+3
    final_div_factor: 1.0e+3
""")
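
The scheduler parameters above imply concrete learning-rate endpoints. The sketch below works them out with plain arithmetic, following the relations given in the PyTorch `OneCycleLR` documentation (`initial_lr = max_lr / div_factor`, `min_lr = initial_lr / final_div_factor`). The warm-up length is an approximation that ignores off-by-one step details.

```python
# Learning-rate endpoints implied by the OneCycleLR parameters in the settings.
max_lr = 1.0e-3
div_factor = 1.0e3
final_div_factor = 1.0e3
pct_start = 0.2
epochs = 15

initial_lr = max_lr / div_factor          # LR at the first step
min_lr = initial_lr / final_div_factor    # LR at the last step
warmup_epochs = pct_start * epochs        # part of the run spent increasing the LR

print(f"initial_lr={initial_lr:.1e}, max_lr={max_lr:.1e}, min_lr={min_lr:.1e}")
print(f"warm-up lasts about {warmup_epochs:g} of {epochs} epochs")
```

As far as I understand, `OneCycleLR` sets the optimizer's learning rate itself, so the optimizer's `lr: 1.0e-06` mostly documents the starting value, which happens to match `initial_lr` here.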

## definition

### model

In [None]:
class BasicImageModel(nn.Module):
    
    def __init__(
        self, base_name: str, dims_head: tp.List[tp.Optional[int]], pretrained: bool = False
    ):
        """Initialize"""
        super(BasicImageModel, self).__init__()
        self.base_name = base_name
        
        # # prepare backbone
        if hasattr(timm.models, base_name):
            # # # load base model
            base_model = timm.create_model(base_name, pretrained=pretrained)
            in_features = base_model.num_features
            # # remove head classifier
            base_model.reset_classifier(0)
        else:
            raise NotImplementedError

        self.backbone = base_model
        
        # # prepare head classifier
        dims_head = list(dims_head)  # copy so the caller's list is not mutated
        if dims_head[0] is None:
            dims_head[0] = in_features

        layers_list = []
        for i in range(len(dims_head) - 2):
            in_dim, out_dim = dims_head[i: i + 2]
            layers_list.extend([
                nn.Linear(in_dim, out_dim),
                nn.ReLU(), nn.Dropout(0.5),])
        layers_list.append(
            nn.Linear(dims_head[-2], dims_head[-1]))
        self.head = nn.Sequential(*layers_list)

    def forward(self, x):
        """Forward"""
        h = self.backbone(x)
        h = self.head(h)
        return h
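
The way `dims_head` turns into `nn.Linear` layers can be sketched without PyTorch. The helper `head_layer_dims` below is hypothetical and only mirrors the indexing logic of `BasicImageModel`. The value `in_features=512` is an assumption for `resnet18d`; the model reads the real value from `base_model.num_features`.

```python
def head_layer_dims(dims_head, in_features):
    """Expand a dims_head spec into (in_dim, out_dim) pairs, one per Linear layer."""
    dims = list(dims_head)          # copy so the caller's list is not modified
    if dims[0] is None:
        dims[0] = in_features       # None means "use the backbone's feature size"
    return [(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]

# in_features=512 is an assumed value, see the note above.
single = head_layer_dims([None, 12], in_features=512)
two_layer = head_layer_dims([None, 256, 12], in_features=512)
print(single)
print(two_layer)
```

In the model, every pair except the last is followed by `ReLU` and `Dropout(0.5)`, so the setting `dims_head: [null, 12]` gives a single linear layer with 12 outputs.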

### image dataset

In [None]:
class ImageTransformBase:
    """
    Base Image Transform class.

    Args:
        data_augmentations: List of tuple(method: str, params :dict), each elems pass to albumentations
    """

    def __init__(self, data_augmentations: tp.List[tp.Tuple[str, tp.Dict]]):
        """Initialize."""
        augmentations_list = [
            self._get_augmentation(aug_name)(**params)
            for aug_name, params in data_augmentations]
        self.data_aug = albumentations.Compose(augmentations_list)

    def __call__(self, pair: tp.Tuple[np.ndarray]) -> tp.Tuple[np.ndarray]:
        """You have to implement this by task"""
        raise NotImplementedError

    def _get_augmentation(self, aug_name: str) -> tp.Tuple[ImageOnlyTransform, DualTransform]:
        """Get augmentations from albumentations"""
        if hasattr(albumentations, aug_name):
            return getattr(albumentations, aug_name)
        else:
            return eval(aug_name)


class ImageTransformForCls(ImageTransformBase):
    """Data Augmentor for Classification Task."""

    def __init__(self, data_augmentations: tp.List[tp.Tuple[str, tp.Dict]]):
        """Initialize."""
        super(ImageTransformForCls, self).__init__(data_augmentations)

    def __call__(self, in_arrs: tp.Tuple[np.ndarray]) -> tp.Tuple[np.ndarray]:
        """Apply Transform."""
        img, label = in_arrs
        augmented = self.data_aug(image=img)
        img = augmented["image"]

        return img, label

In [None]:
class LabeledImageDataset(data.Dataset):
    """Dataset class for (image, label) pairs"""

    def __init__(
        self,
        file_list: tp.List[
            tp.Tuple[tp.Union[str, Path], tp.Union[int, float, np.ndarray]]],
        transform_list: tp.List[tp.Dict],
    ):
        """Initialize"""
        self.file_list = file_list
        self.transform = ImageTransformForCls(transform_list)

    def __len__(self):
        """Return Num of Images."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Return transformed image and label for given index."""
        img_path, label = self.file_list[index]
        img = self._read_image_as_array(img_path)

        img, label = self.transform((img, label))
        return img, label

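    def _read_image_as_array(self, path: tp.Union[str, Path]) -> np.ndarray:
        """Read image file and convert into numpy.ndarray"""
        img_arr = cv2.imread(str(path))
        if img_arr is None:
            # # cv2.imread returns None instead of raising when a file cannot be read
            raise FileNotFoundError(f"cannot read image: {path}")
        img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
        return img_arr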

### metric for evaluation

In [None]:
class EvalFuncManager(nn.Module):
    """Manager Class for evaluation at the end of epoch"""

    def __init__(
        self,
        evalfunc_dict: tp.Dict[str, nn.Module],
        iters_per_epoch: int,
        prefix: str = "val"
    ) -> None:
        """Initialize"""
        self.tmp_iter = 0
        self.iters_per_epoch = iters_per_epoch
        self.prefix = prefix
        self.metric_names = []
        super(EvalFuncManager, self).__init__()
        for k, v in evalfunc_dict.items():
            setattr(self, k, v)
            self.metric_names.append(k)
        self.reset()

    def reset(self) -> None:
        """Reset State."""
        self.tmp_iter = 0
        for name in self.metric_names:
            getattr(self, name).reset()

    def __call__(self, y: torch.Tensor, t: torch.Tensor) -> None:
        """Forward."""
        for name in self.metric_names:
            getattr(self, name).update(y, t)
        self.tmp_iter += 1

        if self.tmp_iter == self.iters_per_epoch:
            ppe.reporting.report({
                "{}/{}".format(self.prefix, name): getattr(self, name).compute()
                for name in self.metric_names
            })
            self.reset()
            
            
class MeanLoss(nn.Module):
    
    def __init__(self):
        super(MeanLoss, self).__init__()
        self.loss_sum = 0
        self.n_examples = 0
        
    def forward(self, y: torch.Tensor, t: torch.Tensor):
        """Compute metric at once"""
        return self.loss_func(y, t)

    def reset(self):
        """Reset state"""
        self.loss_sum = 0
        self.n_examples = 0
    
    def update(self, y: torch.Tensor, t: torch.Tensor):
        """Update metric by mini batch"""
        self.loss_sum += self(y, t).item() * y.shape[0]
        self.n_examples += y.shape[0]

    def compute(self):
        """Compute metric for dataset"""
        return self.loss_sum / self.n_examples
    
    
class MyMSELoss(MeanLoss):

    def __init__(self, **params):
        super(MyMSELoss, self).__init__()
        self.loss_func = nn.MSELoss(**params)
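
`MeanLoss.update` multiplies each batch loss by the batch size before accumulating. A small sketch with made-up numbers shows why: when the last batch is smaller (`drop_last: False` for validation), a plain average of batch losses would overweight it.

```python
# Per-batch mean losses and batch sizes (made-up numbers; the last batch is smaller).
batch_losses = [0.50, 0.30, 0.90]
batch_sizes = [128, 128, 44]

loss_sum = 0.0
n_examples = 0
for loss, size in zip(batch_losses, batch_sizes):
    loss_sum += loss * size      # undo the per-batch mean
    n_examples += size

dataset_mean = loss_sum / n_examples                 # what MeanLoss.compute returns
naive_mean = sum(batch_losses) / len(batch_losses)   # unweighted average of batches
print(f"weighted: {dataset_mean:.4f}, unweighted: {naive_mean:.4f}")
```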

### getter functions

In [None]:
def get_file_list(stgs, train_all):
    """Get file path and target info."""
    use_fold = stgs["globals"]["val_fold"]
    
    train_df = train_all[train_all["fold"] != use_fold]
    val_df = train_all[train_all["fold"] == use_fold]
    
#     train_data_dir = TRAIN
    train_data_dir = TRAIN_224
    train_file_list = list(zip([
#         train_data_dir / f"{img_id[0]}/{img_id[1]}/{img_id[2]}/{img_id}.png"
        train_data_dir / f"{img_id}.jpg"
        for img_id in train_df["image_id"].values
    ], train_df[TARGETS].values.astype("f")))

    val_file_list = list(zip([
#         train_data_dir / f"{img_id[0]}/{img_id[1]}/{img_id[2]}/{img_id}.png"
        train_data_dir / f"{img_id}.jpg"
        for img_id in val_df["image_id"].values
    ], val_df[TARGETS].values.astype("f")))

    if stgs["globals"]["reduce_data"]:
        div = stgs["globals"]["reduce_div_factor"]
        # # sample without replacement so that no image is drawn twice
        trn_smpl_idx = np.random.choice(len(train_df), len(train_df) // div, replace=False)
        val_smpl_idx = np.random.choice(len(val_df), len(val_df) // div, replace=False)
        train_file_list = [train_file_list[idx] for idx in trn_smpl_idx]
        val_file_list = [val_file_list[idx] for idx in val_smpl_idx]
        
    return train_file_list, val_file_list


def get_dataloaders(
    stgs: tp.Dict,
    train_file_list: tp.List[tp.List],
    val_file_list: tp.List[tp.List],
    dataset_class: data.Dataset
):
    """Create DataLoader"""
    train_loader = val_loader = None
    if train_file_list is not None:
        train_dataset = dataset_class(
            train_file_list, **stgs["dataset"]["train"])
        train_loader = data.DataLoader(
            train_dataset, **stgs["loader"]["train"])

    if val_file_list is not None:
        val_dataset = dataset_class(
            val_file_list, **stgs["dataset"]["val"])
        val_loader = data.DataLoader(
            val_dataset, **stgs["loader"]["val"])

    return train_loader, val_loader
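
The `reduce_data` branch of `get_file_list` keeps roughly `1 / reduce_div_factor` of each split. Note that `np.random.choice` draws with replacement unless `replace=False` is passed. The sketch below shows the same reduction with distinct indices, using only the standard library; `sample_indices` is a hypothetical helper, not part of the notebook.

```python
import random

def sample_indices(n_items, div, seed=1086):
    """Pick n_items // div distinct indices (sampling without replacement)."""
    rng = random.Random(seed)       # local generator, leaves global state alone
    return rng.sample(range(n_items), n_items // div)

idx = sample_indices(1000, div=25)
print(len(idx), len(set(idx)))  # both counts are equal when there are no duplicates
```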

In [None]:
def get_model(args):
    """Build the model class named in args["name"] with args["params"]."""
    return eval(args["name"])(**args["params"])


def get_optimizer(args, model):
    """Build the optimizer named in args["name"] for the model's parameters."""
    if hasattr(torch.optim, args["name"]):
        opt_class = getattr(torch.optim, args["name"])
    else:
        opt_class = eval(args["name"])

    return opt_class(model.parameters(), **args["params"])


def get_scheduler(args, optimizer, max_epoch, steps_per_epoch):
    """Build the LR scheduler; for OneCycleLR, fill in epochs and steps_per_epoch."""
    if args["name"] == "OneCycleLR":
        args["params"]["epochs"] = max_epoch
        args["params"]["steps_per_epoch"] = steps_per_epoch

    if hasattr(torch.optim.lr_scheduler, args["name"]):
        scdr_class = getattr(torch.optim.lr_scheduler, args["name"])
    else:
        scdr_class = eval(args["name"])

    return scdr_class(optimizer, **args["params"])


def get_loss_function(args):
    """Build the loss function named in args["name"]."""
    if hasattr(nn, args["name"]):
        loss_class = getattr(nn, args["name"])
    else:
        loss_class = eval(args["name"])

    return loss_class(**args["params"])


def get_stepper(manager, stgs, scheduler):
    """Return (per-epoch, per-iteration) step functions for the configured scheduler."""
    if stgs["scheduler"]["name"] == "CosineAnnealingWarmRestarts":
        def step_scheduler_by_epoch():
            pass

        def step_scheduler_by_iter():
            scheduler.step(manager.epoch_detail)

    elif stgs["scheduler"]["name"] == "OneCycleLR":
        def step_scheduler_by_epoch():
            pass

        def step_scheduler_by_iter():
            scheduler.step()

    else:
        def step_scheduler_by_epoch():
            scheduler.step()

        def step_scheduler_by_iter():
            pass

    return step_scheduler_by_epoch, step_scheduler_by_iter

In [None]:
def get_manager(
    stgs, model, device, train_loader, val_loader, optimizer,
    eval_manager, output_path, print_progress: bool = False,
):
    """Create the ExtensionsManager with logging, evaluation, and snapshot extensions."""
    # # initialize manager
    if stgs["globals"]["patience"] > 0:
        trigger = ppe.training.triggers.EarlyStoppingTrigger(
            check_trigger=(1, 'epoch'), monitor='val/loss', mode="min",
            patience=stgs["globals"]["patience"], verbose=True,
            max_trigger=(stgs["globals"]["max_epoch"], 'epoch'))
    else:
        trigger = None    
    manager = ppe.training.ExtensionsManager(
        model, optimizer, stgs["globals"]["max_epoch"],
        iters_per_epoch=len(train_loader), stop_trigger=trigger, out_dir=output_path)

    # # for logging
    eval_names = ["val/{}".format(name) for name in eval_manager.metric_names]    
    log_extentions = [
        ppe_exts.observe_lr(optimizer=optimizer),
        ppe_exts.LogReport(),
        ppe_exts.PlotReport(["train/loss", "val/loss"], 'epoch', filename='loss.png'),
        # ppe_exts.PlotReport(["val/metric"], 'epoch', filename='metric.png'),
        ppe_exts.PlotReport(["lr"], 'epoch', filename='lr.png'),
        ppe_exts.PrintReport([
            "epoch", "iteration", "lr", "train/loss", *eval_names, "elapsed_time"])
    ]
    if print_progress:
        log_extentions.append(ppe_exts.ProgressBar(update_interval=20))

    for ext in log_extentions:
        manager.extend(ext)
        
    # # for evaluation
    def eval_func(*batch):
        return run_eval(stgs, model, device, batch, eval_manager)
    manager.extend(
        ppe_exts.Evaluator(val_loader, model, eval_func=eval_func),
        trigger=(1, "epoch"))
    
    # # for saving snapshot
    manager.extend(
        ppe_exts.snapshot(target=model, filename="snapshot_by_loss_epoch_{.epoch}.pth"),
        trigger=ppe.training.triggers.MinValueTrigger(key="val/loss", trigger=(1, 'epoch')))

    return manager

### training utils

In [None]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Set seeds"""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

In [None]:
def run_train_loop(
    manager, stgs, model, device, train_loader, optimizer, scheduler, loss_func
):
    """Run minibatch training loop"""
    step_scheduler_by_epoch, step_scheduler_by_iter = get_stepper(manager, stgs, scheduler)
 
    use_amp = stgs["globals"]["use_amp"]
    # # create the scaler once so that its loss scale persists across epochs
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    while not manager.stop_trigger:
        model.train()
        for batch in train_loader:
            with manager.run_iteration():
                x = batch[0].to(device)
                t = batch[-1].to(device)
                optimizer.zero_grad()
                if use_amp:
                    with torch.cuda.amp.autocast():
                        y = model(x)
                        loss = loss_func(y, t)    
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    y = model(x)
                    loss = loss_func(y, t)
                    loss.backward()
                    optimizer.step()

                ppe.reporting.report({'train/loss': loss.item()})    
                step_scheduler_by_iter()
        step_scheduler_by_epoch()


def run_eval(stgs, model, device, batch, eval_manager):
    """Run evaluation for val or test. This function is applied to each batch."""
    model.eval()
    x = batch[0].to(device)
    t = batch[-1].to(device)
    if stgs["globals"]["use_amp"]:
        with torch.cuda.amp.autocast(): 
            y = model(x)
            eval_manager(y, t)
    else:
        y = model(x)
        eval_manager(y, t)

In [None]:
def train_one_fold(stgs, train_all, output_path, print_progress=False):
    """train one fold"""
    torch.backends.cudnn.benchmark = True
    set_random_seed(stgs["globals"]["seed"])

    # # prepare train, valid paths
    train_file_list, val_file_list = get_file_list(stgs, train_all)
    print("train: {}, val: {}".format(len(train_file_list), len(val_file_list)))

    device = torch.device(stgs["globals"]["device"])
    # # get data_loader
    train_loader, val_loader = get_dataloaders(
        stgs, train_file_list, val_file_list, LabeledImageDataset)

    # # get model
    model = BasicImageModel(**stgs["model"]["params"])
    model.to(device)

    # # get optimizer
    optimizer = getattr(
        torch.optim, stgs["optimizer"]["name"]
    )(model.parameters(), **stgs["optimizer"]["params"])

    # # get scheduler
    if stgs["scheduler"]["name"] == "OneCycleLR":
        stgs["scheduler"]["params"]["epochs"] = stgs["globals"]["max_epoch"]
        stgs["scheduler"]["params"]["steps_per_epoch"] = len(train_loader)
    scheduler = getattr(
        torch.optim.lr_scheduler, stgs["scheduler"]["name"]
    )(optimizer, **stgs["scheduler"]["params"])

    # # get loss
    if hasattr(nn, stgs["loss"]["name"]):
        loss_func = getattr(nn, stgs["loss"]["name"])(**stgs["loss"]["params"])
    else:
        loss_func = eval(stgs["loss"]["name"])(**stgs["loss"]["params"])
    loss_func.to(device)

    eval_manager = EvalFuncManager(
        {
            metric["report_name"]: eval(metric["name"])(**metric["params"])
            for metric in stgs["eval"]
        }, len(val_loader))
    eval_manager.to(device)

    # # get manager
    manager = get_manager(
        stgs, model, device, train_loader, val_loader,
        optimizer, eval_manager, output_path, print_progress)

    run_train_loop(
        manager, stgs, model, device, train_loader,
        optimizer, scheduler, loss_func)

## train

### split fold

In [None]:
mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=RANDAM_SEED)
train["fold"] = -1
for fold_id, (trn_idx, val_idx) in enumerate(
    mskf.split(train.InChI, (train[TARGETS] > 0).astype(int))
):
    train.loc[val_idx, "fold"] = fold_id

In [None]:
train.groupby("fold")[TARGETS].sum()

In [None]:
train.groupby("fold")["n_atoms"].agg(
    ["sum", "mean", "median", "std", "max", "min"])

In [None]:
(train.InChI.str.len()).groupby(train["fold"]).agg(
    ["sum", "mean", "median", "std", "max", "min"])

### run training

In [None]:
torch.cuda.empty_cache()
gc.collect()

In [None]:
stgs_list = []
for fold_id in FOLDS:
    tmp_stgs = copy.deepcopy(settings)
    tmp_stgs["globals"]["val_fold"] = fold_id
    stgs_list.append(tmp_stgs)

In [None]:
for fold_id, tmp_stgs in zip(FOLDS, stgs_list):
    train_one_fold(tmp_stgs, train, TMP / f"fold{fold_id}", False)
    torch.cuda.empty_cache()
    gc.collect()

## Inference OOF

### copy best model

In [None]:
best_log_list = []
for fold_id, tmp_stgs in zip(FOLDS, stgs_list):
    exp_dir_path = TMP / f"fold{fold_id}"
    log = pd.read_json(exp_dir_path / "log")
    best_log = log.iloc[[log["val/loss"].idxmin()],]
    best_epoch = best_log.epoch.values[0]
    best_log_list.append(best_log)
    
    best_model_path = exp_dir_path / f"snapshot_by_loss_epoch_{best_epoch}.pth"
    copy_to = f"./best_loss_model_fold{fold_id}.pth"
    shutil.copy(best_model_path, copy_to)
    
    for p in exp_dir_path.glob("*.pth"):
        p.unlink()
    
    shutil.copytree(exp_dir_path, f"./fold{fold_id}")
    
    with open(f"./fold{fold_id}/settings.yml", "w") as fw:
        yaml.dump(tmp_stgs, fw)
    
pd.concat(best_log_list, axis=0, ignore_index=True)

### inference

In [None]:
def run_inference_loop(stgs, model, loader, device):
    model.to(device)
    model.eval()
    pred_list = []
    with torch.no_grad():
        for x, t in tqdm(loader):
            if stgs["globals"]["use_amp"]:
                with torch.cuda.amp.autocast():
                    y = model(x.to(device))
            else:
                y = model(x.to(device))
            pred_list.append(y.detach().cpu().numpy())
        
    pred_arr = np.concatenate(pred_list)
    del pred_list
    return pred_arr

In [None]:
oof_pred_arr = np.zeros((len(train), N_TARGETS))
label_arr = train[TARGETS].values
score_list = []

for fold_id in FOLDS:
    print(f"[fold {fold_id}]")
    tmp_dir = Path(f"./fold{fold_id}")
    with open(tmp_dir / "settings.yml", "r") as fr:
        tmp_stgs = yaml.safe_load(fr)
    device = torch.device(tmp_stgs["globals"]["device"])
    val_idx = train.query("fold == @fold_id").index.values
    
    # # get data_loader
    tmp_stgs["globals"]["reduce_data"] = False
    _, val_file_list = get_file_list(tmp_stgs, train)
    _, val_loader = get_dataloaders(
        tmp_stgs, None, val_file_list, LabeledImageDataset)
#     val_idx = val_idx[:len(val_file_list)]
    
    # # get and load model
    model_path = f"./best_loss_model_fold{fold_id}.pth"
    tmp_stgs["model"]["params"]["pretrained"] = False
    model = BasicImageModel(**tmp_stgs["model"]["params"])
    model.load_state_dict(torch.load(model_path, map_location=device))

    val_pred = run_inference_loop(tmp_stgs, model, val_loader, device)
    val_loss = mean_squared_error(label_arr[val_idx], val_pred)
    print(f"[fold {fold_id}] val loss: {val_loss:.5f}")
    oof_pred_arr[val_idx] = val_pred
    score_list.append([fold_id, val_loss])
    
    del model
    del val_pred
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
oof_loss = mean_squared_error(label_arr, oof_pred_arr)
score_list.append(["oof", oof_loss])

In [None]:
pd.DataFrame(score_list, columns=["fold", "mse"])
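
The scores above are a single MSE averaged over all 12 outputs (the default `uniform_average` behaviour of `mean_squared_error`). A per-element breakdown can show which atoms are hard to count. The sketch below uses made-up labels and predictions and plain Python lists; `per_target_mse` is a hypothetical helper.

```python
TARGETS = ["B", "Br", "C"]  # shortened for the example

# Made-up labels and predictions: rows are molecules, columns follow TARGETS.
labels = [[0, 1, 10], [1, 0, 12], [0, 0, 8]]
preds = [[0.1, 0.9, 10.5], [0.8, 0.0, 11.0], [0.0, 0.2, 8.5]]

def per_target_mse(y_true, y_pred):
    """Mean squared error computed separately for each output column."""
    n_rows, n_cols = len(y_true), len(y_true[0])
    return [
        sum((y_true[r][c] - y_pred[r][c]) ** 2 for r in range(n_rows)) / n_rows
        for c in range(n_cols)
    ]

mse_by_target = per_target_mse(labels, preds)
overall = sum(mse_by_target) / len(mse_by_target)  # uniform average over targets
for name, mse in zip(TARGETS, mse_by_target):
    print(f"{name}: {mse:.4f}")
print(f"overall: {overall:.4f}")
```

In this toy data the overall score is dominated by `C`, whose counts are larger; the same effect may apply to `C` and `H` in the real targets.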

In [None]:
oof_pred_arr.shape

In [None]:
oof_df = train.copy()
oof_df[TARGETS] = oof_pred_arr
oof_df.to_csv("./oof_prediction.csv", index=False)

In [None]:
train.to_pickle("./train_formula_mlskf_5fold.pkl")
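
The saved OOF predictions are real-valued, so using them as the formula part of an `InChI` string needs a decoding step. The sketch below is one possible approach, not the method of the inference notebook: it rounds each output to a non-negative integer and writes the symbols in Hill order (carbon, then hydrogen, then the rest alphabetically), which I assume matches the order used in the `InChI` formula part. `vector_to_formula` is a hypothetical helper.

```python
TARGETS = ["B", "Br", "C", "Cl", "F", "H", "I", "N", "O", "P", "S", "Si"]

def vector_to_formula(pred):
    """Round a 12-value prediction to counts and write it in Hill order."""
    counts = {t: max(0, int(round(p))) for t, p in zip(TARGETS, pred)}
    if counts["C"] > 0:
        # Hill order: carbon first, hydrogen second, then the rest alphabetically.
        order = ["C", "H"] + [t for t in TARGETS if t not in ("C", "H")]
    else:
        order = sorted(TARGETS)  # no carbon: strictly alphabetical
    return "".join(
        f"{t}{counts[t] if counts[t] > 1 else ''}" for t in order if counts[t] > 0
    )

noisy_pred = [0.1, -0.2, 12.9, 0.0, 0.0, 20.3, 0.0, 0.0, 1.1, 0.0, 0.8, 0.0]
print(vector_to_formula(noisy_pred))
```

Simple rounding treats every element independently; a chemically impossible combination is not detected by this sketch.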