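# Cross Validation

Average the out-of-fold (OOF) predictions of the pretrained runs listed below and compute the ensemble's cross-validation score.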

In [1]:
import sys

# Extend the module search path with the parent directory (presumably where the
# project's `src` package lives) and the competition inputs directory.
sys.path.extend(["..", "../../inputs"])

In [2]:
import logging

# Send INFO-and-above records to stdout so they appear in the cell output.
# To log to a file instead, pass `filename=...` in place of `stream`
# (basicConfig rejects both at once).
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(module)s] %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)

log = logging.getLogger(__name__)

In [3]:
from omegaconf import OmegaConf

# Start from the shared training config and override it for this notebook.
c = OmegaConf.load("../config/main.yaml")

# OmegaConf child nodes are live views, so writes through these aliases land in `c`.
settings_cfg = c.settings
settings_cfg.debug = False
settings_cfg.dirs.working = ".."  # paths are relative to this notebook's directory
settings_cfg.dirs.input = "../../inputs/"

wandb_cfg = c.wandb
wandb_cfg.enabled = False  # W&B is switched off for this evaluation
wandb_cfg.group = "LB"
wandb_cfg.dir = "../../cache/"

# Root directory that holds the training run folders (name used by later cells).
pretraind_dir = "../../datasets/trainings"

In [4]:
# Timestamped run directories under `pretraind_dir`, listed in fold order:
# entry i is the run trained for fold i (fold0 .. fold14).
RUN_TIMESTAMPS = [
    "2022-02-04_23-43-25",
    "2022-02-04_23-43-27",
    "2022-02-04_23-43-29",
    "2022-02-04_23-43-31",
    "2022-02-04_23-43-33",
    "2022-02-05_08-05-00",
    "2022-02-05_08-05-02",
    "2022-02-05_08-05-04",
    "2022-02-05_08-05-06",
    "2022-02-05_08-05-08",
    "2022-02-05_20-09-55",
    "2022-02-05_20-09-57",
    "2022-02-05_20-09-59",
    "2022-02-05_20-10-01",
    "2022-02-05_20-10-04",
]

# Runs whose OOF predictions are ensembled: one `ump_1` run per fold. Deriving the
# fold index from the list position keeps the directory name and position in sync.
pretrained = [
    {"dir": f"{pretraind_dir}/{stamp}/fold{fold}/", "model": "ump_1"}
    for fold, stamp in enumerate(RUN_TIMESTAMPS)
]

# Alternative set (currently unused): a single LightGBM run. Pass it to
# OmegaConf.create below instead of `pretrained` to score that model.
_pretrained = [
    {"dir": f"{pretraind_dir}/2022-02-08_16-16-05/fold0/", "model": "lightgbm"},
]

c.params.pretrained = OmegaConf.create(pretrained)

In [5]:
# Record the effective configuration (after the notebook overrides) in the log.
log.info("%s", OmegaConf.to_yaml(c))

2022-02-09 16:37:21,830 [INFO] [3244290467] defaults:
- _self_
hydra:
  run:
    dir: ../outputs/${now:%Y-%m-%d_%H-%M-%S}
  job_logging:
    formatters:
      simple:
        format: '%(asctime)s [%(levelname)s][%(module)s] %(message)s'
wandb:
  enabled: false
  entity: imokuri
  project: ump
  dir: ../../cache/
  group: LB
settings:
  print_freq: 100
  gpus: 6,7
  dirs:
    working: ..
    input: ../../inputs/
    feature: ${settings.dirs.input}features/
    preprocess: ${settings.dirs.input}preprocess/
  inputs:
  - train.csv
  - example_test.csv
  - example_sample_submission.csv
  debug: false
  n_debug_data: 100000
  amp: true
  multi_gpu: true
  training_method: nn
params:
  seed: 440
  n_class: 1
  preprocess: false
  n_fold: 5
  skip_training: false
  epoch: 20
  es_patience: 0
  batch_size: 640
  gradient_acc_step: 1
  max_grad_norm: 1000
  fold: simple_cpcv
  group_name: investment_id
  time_name: time_id
  label_name: target
  use_feature: true
  feature_set:
  - f000
  datas

In [6]:
import os

import pandas as pd
import src.utils as utils
from src.get_score import record_result
from tqdm.notebook import tqdm

In [7]:
# Initialise Weights & Biases through the project helper. `c.wandb.enabled` is set
# to False in the config cell, so presumably no remote run is created — TODO confirm
# in src.utils. `run` is not used again in this notebook.
run = utils.setup_wandb(c)

In [8]:
# Load only the columns needed for scoring, and index by `row_id` so that the
# per-run OOF predictions loaded next align on it. Reading a column subset avoids
# materialising the full training frame just to drop most of it.
train_cols = ["row_id", "time_id", "target"]
train = (
    pd.read_feather(os.path.join(c.settings.dirs.input, "train.f"), columns=train_cols)
    .loc[:, train_cols]  # pin the column order explicitly
    .set_index("row_id")
)

In [9]:
# Merge the out-of-fold (OOF) predictions of every pretrained run into `train`,
# one `preds{n}` column per run, then average them row by row.
preds_col = []

# Assignment aligns on the `row_id` index, so rows that a run did not predict
# stay NaN in that run's column.
for n, training in tqdm(enumerate(c.params.pretrained), total=len(c.params.pretrained)):
    col = f"preds{n}"
    preds_col.append(col)
    # `training.dir` looks like ".../<run>/foldK/"; oof_df.f sits in the <run> directory.
    run_dir = training.dir.rsplit("/", 2)[0]
    oof_df = pd.read_feather(os.path.join(run_dir, "oof_df.f")).set_index("row_id")

    if training.model == "lightgbm":
        # Sum duplicate row_ids (presumably a LightGBM OOF file can repeat a row — TODO confirm).
        train[col] = oof_df["preds"].groupby("row_id").sum()
    else:
        train[col] = oof_df["preds"]

# Count how many runs predicted each row. Only the prediction columns are
# inspected, so a NaN in `time_id`/`target` cannot lower the count.
train["count_oof"] = train[preds_col].notna().sum(axis=1)

# Drop rows that no run predicted. Filtering on the count (rather than on a zero
# prediction sum) keeps rows whose predictions cancel to exactly 0.0, and it rules
# out dividing by zero below. `.copy()` avoids assigning into a filtered view.
train = train[train["count_oof"] > 0].copy()

# Missing runs show as 0 in their column and contribute nothing to the sum.
train = train.fillna({col: 0 for col in preds_col})

# Average over the runs that actually predicted each row. Accumulate in float64
# because the per-run columns may be low precision (e.g. float16).
train["preds"] = train[preds_col].astype("float64").sum(axis=1) / train["count_oof"]

  0%|          | 0/15 [00:00<?, ?it/s]

In [10]:
# Inspect the merged frame: the per-run OOF columns, how many runs predicted each
# row (`count_oof`), and their average (`preds`).
train

Unnamed: 0_level_0,time_id,target,preds0,preds1,preds2,preds3,preds4,preds5,preds6,preds7,preds8,preds9,preds10,preds11,preds12,preds13,preds14,count_oof,preds
row_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
0_1,0,-0.300875,0.086914,0.148071,0.126465,0.085754,0.060791,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,5,0.101599
0_2,0,-0.231040,-0.026031,0.000782,-0.011414,-0.036316,-0.036072,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,5,-0.021810
0_6,0,0.568807,0.061188,0.118774,0.101624,0.077271,0.022461,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,5,0.076263
0_7,0,-1.064780,-0.040497,-0.025497,-0.046356,-0.087952,-0.143311,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,5,-0.068723
0_8,0,-0.531940,0.014252,0.027084,0.070129,0.102966,0.011040,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,5,0.045094
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1219_3768,1219,0.033600,0.000000,0.000000,0.000000,0.000000,-0.045837,0.0,0.0,0.0,-0.047180,0.0,0.0,-0.065125,0.0,-0.038483,-0.047699,5,-0.048865
1219_3769,1219,-0.223264,0.000000,0.000000,0.000000,0.000000,-0.047729,0.0,0.0,0.0,-0.058228,0.0,0.0,-0.071167,0.0,-0.052917,-0.026688,5,-0.051346
1219_3770,1219,-0.559415,0.000000,0.000000,0.000000,0.000000,0.068298,0.0,0.0,0.0,0.034485,0.0,0.0,0.073364,0.0,0.083984,0.035339,5,0.059094
1219_3772,1219,0.009599,0.000000,0.000000,0.000000,0.000000,-0.001007,0.0,0.0,0.0,0.001371,0.0,0.0,-0.003759,0.0,0.008049,-0.046967,5,-0.008463


In [11]:
# Sanity check: the distribution of how many runs predicted each row. A single
# value would mean every row was validated the same number of times — NOTE(review):
# confirm that this is what the fold scheme should produce.
train["count_oof"].value_counts()

5    3141390
Name: count_oof, dtype: int64

In [12]:
# Score the averaged OOF predictions with the project helper (logs and returns the score).
# NOTE(review): `c.params.n_fold` is 5 while 15 runs are ensembled above — confirm
# what `record_result` uses this argument for.
record_result(c, train, c.params.n_fold)

2022-02-09 16:37:47,352 [INFO] [get_score] Score: 0.12424


0.12424088145380552