# About this notebook ...

## Library

In [1]:
import glob
import json
import math
import os
import random
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

In [2]:
warnings.filterwarnings("ignore")

## Config

In [3]:
config_defaults = {
    "seed": 440,
    "data_size": 7_000_000,
    "n_class": 4,
    "n_fold": 10,
    "geese_net_layers": 12,
    "geese_net_filters": 48,
    "gradient_accumulation_steps": 1,
    "max_grad_norm": 1000,
    "num_workers": 4,
    "batch_size": 3200,
    "epochs": 10,
    "scheduler": "CosineAnnealingWarmRestarts",
    "criterion": "CrossEntropyLoss",
    "lr": 1e-3,
    "min_lr": 1e-4,
    "weight_decay": 1e-5,
    "model_name": "geese_net_alpha",
}

In [4]:
if config_defaults["scheduler"] == "CosineAnnealingWarmRestarts":
    config_defaults["T_0"] = config_defaults["epochs"]

elif config_defaults["scheduler"] == "CosineAnnealingLR":
    config_defaults["T_max"] = config_defaults["epochs"]

elif config_defaults["scheduler"] == "ReduceLROnPlateau":
    config_defaults["factor"] = 0.2
    config_defaults["patience"] = 4
    config_defaults["eps"] = 1e-6

In [5]:
class Config:
    pre_train_file = ""
    print_freq = 100
    train = True
    debug = False
    apex = False

In [6]:
if Config.debug:
    wandb.init(project="hungry-geese", config=config_defaults, mode="disabled")
else:
    wandb.init(project="hungry-geese", config=config_defaults)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mimokuri[0m (use `wandb login --relogin` to force relogin)


In [7]:
config = wandb.config

In [8]:
if Config.debug:
    config.update({"epochs": 1, "data_size": 10_000}, allow_val_change=True)

In [9]:
if Config.apex:
    from apex import amp

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

In [11]:
BASE_DIR = "../input/hungrygeeseepisode/hungry-geese-episode/"
OUTPUT_DIR = "pre-models/"

if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

In [12]:
paths = [path for path in glob.glob(BASE_DIR + "*.json") if "info" not in path]
print(len(paths))

29048


## Utils

In [13]:
@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f"[{name}] start")
    yield
    LOGGER.info(f"[{name}] done in {time.time() - t0:.0f} s.")


def init_logger(log_file=OUTPUT_DIR + "train.log"):
    from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger

    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger


LOGGER = init_logger()


def seed_torch(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_torch(seed=config.seed)

In [14]:
def ident(y):
    return y


def reverse_ns(y):
    if y == 0:
        return 1
    if y == 1:
        return 0
    return y


def reverse_we(y):
    if y == 2:
        return 3
    if y == 3:
        return 2
    return y


def reverse_nswe(y):
    return reverse_ns(reverse_we(y))

In [15]:
def no_flip(image):
    return image


def h_flip(image):
    return image[:, :, ::-1]


def v_flip(image):
    return image[:, ::-1, :]


def hv_flip(image):
    return image[:, ::-1, ::-1]

## Observation

In [16]:
next_position_map = {}
for pos in range(77):
    position = []
    position.append((11 * (1 + pos // 11) + pos % 11) % 77)
    position.append((11 * (-1 + pos // 11) + pos % 11) % 77)
    position.append((11 * (pos // 11) + (pos + 1) % 11) % 77)
    position.append((11 * (pos // 11) + (pos - 1) % 11) % 77)
    next_position_map[pos] = set(position)

In [17]:
def make_input(obses):
    b = np.zeros((17, 7 * 11), dtype=np.float32)
    obs = obses[-1]

    for p, pos_list in enumerate(obs["geese"]):
        pid = (p - obs["index"]) % 4

        # head position
        for pos in pos_list[:1]:
            b[0 + pid, pos] = 1
        # tip position
        for pos in pos_list[-1:]:
            b[4 + pid, pos] = 1
        # whole position
        for pos in pos_list:
            b[8 + pid, pos] = 1

    # previous head position
    if len(obses) > 1:
        obs_prev = obses[-2]
        for p, pos_list in enumerate(obs_prev["geese"]):
            for pos in pos_list[:1]:
                b[12 + (p - obs["index"]) % 4, pos] = 1

    # food
    for pos in obs["food"]:
        b[16, pos] = 1

    return b.reshape(-1, 7, 11)

In [18]:
def get_reverse_cube(obses):
    """
    尻尾から順番に 1, 0.9, 0.8, ... という並び
    """
    b = np.zeros((4, 7 * 11), dtype=np.float32)
    obs = obses[-1]

    for p, geese in enumerate(obs["geese"]):
        # whole position reverse
        for num_reverse, pos in enumerate(geese[::-1]):
            b[(p - obs["index"]) % 4, pos] = 1 - num_reverse * 0.1

    return b.reshape(-1, 7, 11)

In [19]:
def get_next_disappear_cube(obses):
    """
    次になくなる場所: 1
    次になくなる可能性のある場所: 0.5
    """
    b = np.zeros((4, 7 * 11), dtype=np.float32)
    obs = obses[-1]
    step = obs["step"]

    # foodを食べる可能性があるか。
    eat_food_possibility = defaultdict(int)
    for p, geese in enumerate(obs["geese"]):
        for pos in geese[:1]:
            if not next_position_map[pos].isdisjoint(obs["food"]):
                eat_food_possibility[p] = 1

    if (step % 40) == 39:  # 1つ短くなる
        for p, geese in enumerate(obs["geese"]):
            if eat_food_possibility[p]:  # 尻尾が1、尻尾の１つ前0.5
                for pos in geese[-1:]:
                    b[(p - obs["index"]) % 4, pos] = 1
                for pos in geese[-2:-1]:
                    b[(p - obs["index"]) % 4, pos] = 0.5
            else:  # 食べる可能性なし -> 尻尾が1, 尻尾の1つ前1
                for pos in geese[-2:]:
                    b[(p - obs["index"]) % 4, pos] = 1
    else:  # 1つ短くならない
        for p, geese in enumerate(obs["geese"]):
            if eat_food_possibility[p]:  # 食べる可能性があり -> 尻尾を0.5
                for pos in geese[-1:]:
                    b[(p - obs["index"]) % 4, pos] = 0.5
            else:  # 食べる可能性なし # 尻尾を1
                for pos in geese[-1:]:
                    b[(p - obs["index"]) % 4, pos] = 1

    return b.reshape(-1, 7, 11)

In [20]:
def get_step_cube_v3(obses):
    b = np.zeros((2, 7, 11), dtype=np.float32)
    obs = obses[-1]
    step = obs["step"]

    b[0, :, :] = (step - 188) / 10 if step > 188 else 0
    b[1, :, :] = (step % 40 - 29) / 10 if step % 40 > 29 else 0

    return b

In [21]:
def get_length_cube_v2(obses):
    b = np.zeros((3, 7, 11), dtype=np.float32)
    obs = obses[-1]

    my_length = len(obs["geese"][obs["index"]])
    o1_length = len(obs["geese"][(obs["index"] + 1) % 4])
    o2_length = len(obs["geese"][(obs["index"] + 2) % 4])
    o3_length = len(obs["geese"][(obs["index"] + 3) % 4])

    b[0, :, :] = max(min((my_length - o1_length) * 0.1 + 0.5, 1.0), -1.0)
    b[1, :, :] = max(min((my_length - o2_length) * 0.1 + 0.5, 1.0), -1.0)
    b[2, :, :] = max(min((my_length - o3_length) * 0.1 + 0.5, 1.0), -1.0)

    return b

## Data

In [22]:
X_train = np.zeros((config.data_size, 30, 7, 11), dtype=np.float32)
y_train = np.zeros((config.data_size,), dtype=np.uint8)

X_count = 0
y_count = 0

In [23]:
def create_dataset_from_json(filepath, json_object=None, standing=0):
    global X_train
    global y_train
    global X_count
    global y_count

    if json_object is None:
        json_open = open(path, "r")
        json_load = json.load(json_open)
    else:
        json_load = json_object

    try:
        winner_index = np.argmax(np.argsort(json_load["rewards"]) == 3 - standing)

        obses = []
        actions = {"NORTH": 0, "SOUTH": 1, "WEST": 2, "EAST": 3}

        for i in range(len(json_load["steps"]) - 1):
            if json_load["steps"][i][winner_index]["status"] == "ACTIVE":
                y_ = json_load["steps"][i + 1][winner_index]["action"]
                if y_ is not None:
                    step = json_load["steps"][i]
                    step[winner_index]["observation"]["geese"] = step[0]["observation"]["geese"]
                    step[winner_index]["observation"]["food"] = step[0]["observation"]["food"]
                    step[winner_index]["observation"]["step"] = step[0]["observation"]["step"]
                    obses.append(step[winner_index]["observation"])

                    for func in [ident, reverse_ns, reverse_we, reverse_nswe]:
                        if y_count >= config.data_size:
                            break

                        y_train[y_count] = func(actions[y_])
                        y_count += 1

                    if y_count >= config.data_size:
                        break

        for j in range(len(obses)):
            # X_ = make_input(obses[: j + 1])

            # 反転可能な特徴量
            X_ = []
            X_.append(make_input(obses[: j + 1]))
            X_.append(get_reverse_cube(obses[: j + 1]))
            X_.append(get_next_disappear_cube(obses[: j + 1]))

            # 反転不可能な特徴量
            X_i = []
            X_i.append(get_step_cube_v3(obses[: j + 1]))
            X_i.append(get_length_cube_v2(obses[: j + 1]))

            X_ = np.concatenate(X_)
            X_i = np.concatenate(X_i)

            for func in [no_flip, v_flip, h_flip, hv_flip]:
                if X_count >= config.data_size:
                    break

                X_train[X_count] = np.concatenate([func(X_), X_i])
                X_count += 1

            if X_count >= config.data_size:
                break

        return
    except Exception as e:
        if Config.debug:
            raise Exception from e
        return

In [24]:
for path in tqdm(paths[::-1]):
    create_dataset_from_json(path, standing=0)  # use only winners' moves
    if X_count >= config.data_size:
        break

print(f"Num episode: {len(X_train):,}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=29048.0), HTML(value='')))


Num episode: 7,000,000


In [25]:
if Config.debug:
    X_train = X_train[:1000]
    y_train = y_train[:1000]

In [26]:
y_df = pd.DataFrame(y_train, dtype=np.uint8)
y_df.columns = ["action"]
y_df

Unnamed: 0,action
0,2
1,2
2,3
3,3
4,0
...,...
6999995,1
6999996,0
6999997,1
6999998,0


## CV Split

In [27]:
folds = y_df.copy()
Fold = StratifiedKFold(n_splits=config.n_fold, shuffle=True, random_state=config.seed)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds["action"])):
    folds.loc[val_index, "fold"] = int(n)
folds["fold"] = folds["fold"].astype(np.uint8)
print(folds.groupby(["fold", "action"]).size())

fold  action
0     0         169034
      1         169034
      2         180966
      3         180966
1     0         169034
      1         169034
      2         180966
      3         180966
2     0         169035
      1         169034
      2         180966
      3         180965
3     0         169035
      1         169034
      2         180966
      3         180965
4     0         169035
      1         169034
      2         180966
      3         180965
5     0         169035
      1         169034
      2         180966
      3         180965
6     0         169034
      1         169035
      2         180965
      3         180966
7     0         169034
      1         169035
      2         180965
      3         180966
8     0         169034
      1         169035
      2         180965
      3         180966
9     0         169034
      1         169035
      2         180965
      3         180966
dtype: int64


## Dataset

In [28]:
class TrainDataset(Dataset):
    def __init__(self, array, label):
        self.array = array
        self.label = label

    def __len__(self):
        return self.array.shape[0]

    def __getitem__(self, idx):
        return self.array[idx], torch.tensor(self.label[idx]).long()


class TestDataset(Dataset):
    def __init__(self, array):
        self.array = array

    def __len__(self):
        return self.array.shape[0]

    def __getitem__(self, idx):
        return self.array[idx]

In [29]:
# Test

if Config.debug or False:
    train_ds = TrainDataset(X_train, y_train)

    for i in range(1):
        obs, action = train_ds[i]
        print(obs.shape, action)

## Model

In [30]:
class TorusConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        h = torch.cat([x[:, :, :, -self.edge_size[1] :], x, x[:, :, :, : self.edge_size[1]]], dim=3)
        h = torch.cat([h[:, :, -self.edge_size[0] :], h, h[:, :, : self.edge_size[0]]], dim=2)
        h = self.conv(h)
        h = self.bn(h) if self.bn is not None else h
        return h

In [32]:
class GeeseNetAlpha(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        hidden = filters // 4
        dim1 = filters * 2
        dim2 = dim1 // 2

        self.conv0 = TorusConv2d(30, filters, (3, 3), True)
        self.cnn_blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])

        self.conv_p = TorusConv2d(filters, filters, (3, 3), True)
        self.conv_v = TorusConv2d(filters, filters, (3, 3), True)

        self.head_p1 = nn.Linear(dim1, dim2, bias=False)
        self.head_p2 = nn.Linear(dim2, 4, bias=False)
        self.head_v1 = nn.Linear(dim1, dim2, bias=False)
        self.head_v2 = nn.Linear(dim2, 1, bias=False)

    def forward(self, x, _=None):
        h = F.relu_(self.conv0(x))
        for cnn in self.cnn_blocks:
            h = F.relu_(h + cnn(h))

        h_p = F.relu_(self.conv_p(h))
        h_head_p = (h_p * x[:, :1]).view(h_p.size(0), h_p.size(1), -1).sum(-1)
        h_avg_p = h_p.view(h_p.size(0), h_p.size(1), -1).mean(-1)

        h_p = F.relu_(self.head_p1(torch.cat([h_head_p, h_avg_p], 1)))
        p = self.head_p2(h_p)

        h_v = F.relu_(self.conv_v(h))
        h_head_v = (h_v * x[:, :1]).view(h_v.size(0), h_v.size(1), -1).sum(-1)
        h_avg_v = h_v.view(h_v.size(0), h_v.size(1), -1).mean(-1)

        h_v = F.relu_(self.head_v1(torch.cat([h_head_v, h_avg_v], 1)))
        v = torch.tanh(self.head_v2(h_v))

        return {"policy": p, "value": v}

In [33]:
# Test

if Config.debug or False:
    model = GeeseNetAlpha()
    # print(model)

    params = sum(p.numel() for p in model.parameters())
    print(f"params: {params:,}")

    train_ds = TrainDataset(X_train, y_train)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    for obs, action in train_loader:
        print(f"input shape: {obs.shape}")
        output = model(obs)
        print(output)
        print(f"{torch.argmax(output['policy'], dim=1)}")
        break

## Loss

## Scoring

In [34]:
def get_score(y_true, y_pred):
    return accuracy_score(y_true, y_pred)

In [35]:
def get_result(result_df):
    preds = result_df["preds"].values
    labels = result_df["action"].values
    score = get_score(labels, preds)
    LOGGER.info(f"Score: {score:<.5f}")
    return score

## Helper functions

In [36]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return "%s (remain %s)" % (asMinutes(s), asMinutes(rs))

In [37]:
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    losses = AverageMeter()

    # switch to train mode
    model.train()
    start = time.time()

    for step, (obs, action) in enumerate(train_loader):
        obs = obs.to(device)
        action = action.to(device)
        batch_size = action.size(0)

        y_preds = model(obs)["policy"]

        loss = criterion(y_preds, action)
        losses.update(loss.item(), batch_size)
        if config.gradient_accumulation_steps > 1:
            loss = loss / config.gradient_accumulation_steps
        if Config.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

        if (step + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        if step % Config.print_freq == 0 or step == (len(train_loader) - 1):
            print(
                f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(train_loader)):s} "
                f"Loss avg.: {losses.avg:.4f} "
                f"Grad: {grad_norm:.4f} "
                f"LR: {scheduler.get_last_lr()[0]:.5f}  "
            )

    return losses.avg

In [38]:
def valid_fn(valid_loader, model, criterion, device):
    losses = AverageMeter()

    # switch to evaluation mode
    model.eval()
    preds = []
    start = time.time()

    for step, (obs, action) in enumerate(valid_loader):
        obs = obs.to(device)
        action = action.to(device)
        batch_size = action.size(0)

        # compute loss
        with torch.no_grad():
            y_preds = model(obs)["policy"]

        loss = criterion(y_preds, action)
        losses.update(loss.item(), batch_size)

        # record accuracy
        preds.append(y_preds.softmax(1).to("cpu").numpy())
        if config.gradient_accumulation_steps > 1:
            loss = loss / config.gradient_accumulation_steps

        if step % Config.print_freq == 0 or step == (len(valid_loader) - 1):
            print(
                f"Eval: [{step}/{len(valid_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(valid_loader)):s} "
                f"Loss avg.: {losses.avg:.4f} "
            )
    predictions = np.concatenate(preds)
    return losses.avg, predictions

## Train loop

In [39]:
def train_loop(folds, fold):

    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # Data Loader
    # ====================================================
    # X_train_folds = X_train[folds["fold"] != fold]
    # X_valid_folds = X_train[folds["fold"] == fold]

    # y_train_folds = y_train[folds["fold"] != fold]
    y_valid_folds = y_train[folds["fold"] == fold]

    # y_df_train_folds = y_df[folds["fold"] != fold]
    y_df_valid_folds = y_df[folds["fold"] == fold]

    # train_dataset = TrainDataset(X_train[folds["fold"] != fold], y_train[folds["fold"] != fold])
    # valid_dataset = TrainDataset(X_train[folds["fold"] == fold], y_valid_folds)

    train_loader = DataLoader(
        TrainDataset(X_train[folds["fold"] != fold], y_train[folds["fold"] != fold]),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = DataLoader(
        TrainDataset(X_train[folds["fold"] == fold], y_valid_folds),
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # ====================================================
    # Scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if config.scheduler == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer, mode="min", factor=config.factor, patience=config.patience, verbose=True, eps=config.eps
            )
        elif config.scheduler == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(optimizer, T_max=config.T_max, eta_min=config.min_lr, last_epoch=-1)
        elif config.scheduler == "CosineAnnealingWarmRestarts":
            scheduler = CosineAnnealingWarmRestarts(
                optimizer, T_0=config.T_0, T_mult=1, eta_min=config.min_lr, last_epoch=-1
            )
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    model = GeeseNetAlpha()
    # try:
    #     model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, Config.pre_train_file)))
    # except:
    #     print(f"Failed to load pre-train weight.")

    # Disable training for value network
    # for param in model.head_v1.parameters():
    #     param.requires_grad = False
    # for param in model.head_v2.parameters():
    #     param.requires_grad = False

    model.to(device)

    # Use multi GPU
    if device == torch.device("cuda") and not Config.apex:
        model = torch.nn.DataParallel(model)  # make parallel

    optimizer = Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # apex
    # ====================================================
    if Config.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    # ====================================================
    # Criterion
    # ====================================================
    def get_criterion():
        if config.criterion == "CrossEntropyLoss":
            criterion = nn.CrossEntropyLoss()
        return criterion

    criterion = get_criterion()

    # ====================================================
    # loop
    # ====================================================
    best_score = 0.0
    best_loss = np.inf
    best_preds = None

    # wandb.watch(model, log_freq=Config.print_freq)

    for epoch in range(config.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, preds = valid_fn(valid_loader, model, criterion, device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score = get_score(y_valid_folds, preds.argmax(1))

        elapsed = time.time() - start_time

        LOGGER.info(
            f"Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s"
        )
        LOGGER.info(f"Epoch {epoch+1} - Accuracy: {score}")

        wandb.log(
            {
                "epoch": epoch + 1,
                f"loss/train_fold{fold}": avg_loss,
                f"loss/val_fold{fold}": avg_val_loss,
                f"accuracy/fold{fold}": score,
            }
        )

        if score > best_score:
            best_score = score
            LOGGER.info(f"Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model")
            torch.save(model.module.state_dict(), OUTPUT_DIR + f"{config.model_name}_fold{fold}_best.pth")
            best_preds = preds

        if epoch == config.epochs - 1:
            LOGGER.info(f"Epoch {epoch+1} - Save final model")
            torch.save(model.module.state_dict(), OUTPUT_DIR + f"{config.model_name}_fold{fold}_final.pth")

    y_df_valid_folds[[str(c) for c in range(config.n_class)]] = best_preds
    y_df_valid_folds["preds"] = best_preds.argmax(1)

    return y_df_valid_folds

## Main


In [40]:
def main():
    if Config.train:
        # train
        oof_df = pd.DataFrame()
        for fold in range(config.n_fold):
            _oof_df = train_loop(folds, fold)
            oof_df = pd.concat([oof_df, _oof_df])
            LOGGER.info(f"========== fold: {fold} result ==========")
            get_result(_oof_df)
            break  # fold 1つだけ
        # CV result
        # LOGGER.info(f"========== CV ==========")
        # get_result(oof_df)
        # save result
        oof_df.to_csv(OUTPUT_DIR + "oof_df.csv", index=False)

In [41]:
if __name__ == "__main__":
    main()



Epoch: [1][0/1968] Elapsed 0m 54s (remain 1790m 47s) Loss avg.: 1.3949 Grad: 0.2840 LR: 0.00100  
Epoch: [1][100/1968] Elapsed 1m 24s (remain 26m 5s) Loss avg.: 0.7426 Grad: 0.7254 LR: 0.00100  
Epoch: [1][200/1968] Elapsed 1m 53s (remain 16m 41s) Loss avg.: 0.6633 Grad: 1.6122 LR: 0.00100  
Epoch: [1][300/1968] Elapsed 2m 23s (remain 13m 13s) Loss avg.: 0.6292 Grad: 0.8556 LR: 0.00100  
Epoch: [1][400/1968] Elapsed 2m 52s (remain 11m 15s) Loss avg.: 0.6073 Grad: 0.7743 LR: 0.00100  
Epoch: [1][500/1968] Elapsed 3m 22s (remain 9m 51s) Loss avg.: 0.5918 Grad: 0.5208 LR: 0.00100  
Epoch: [1][600/1968] Elapsed 3m 51s (remain 8m 46s) Loss avg.: 0.5797 Grad: 0.6178 LR: 0.00100  
Epoch: [1][700/1968] Elapsed 4m 20s (remain 7m 51s) Loss avg.: 0.5697 Grad: 0.8074 LR: 0.00100  
Epoch: [1][800/1968] Elapsed 4m 49s (remain 7m 2s) Loss avg.: 0.5614 Grad: 0.6357 LR: 0.00100  
Epoch: [1][900/1968] Elapsed 5m 19s (remain 6m 18s) Loss avg.: 0.5545 Grad: 0.7044 LR: 0.00100  
Epoch: [1][1000/1968] Elaps

Epoch 1 - avg_train_loss: 0.5164  avg_val_loss: 0.4828  time: 655s
Epoch 1 - Accuracy: 0.7886271428571429
Epoch 1 - Save Best Score: 0.7886 Model


Epoch: [2][0/1968] Elapsed 0m 3s (remain 115m 47s) Loss avg.: 0.4771 Grad: 0.4759 LR: 0.00098  
Epoch: [2][100/1968] Elapsed 0m 32s (remain 10m 6s) Loss avg.: 0.4729 Grad: 0.7491 LR: 0.00098  
Epoch: [2][200/1968] Elapsed 1m 2s (remain 9m 6s) Loss avg.: 0.4731 Grad: 0.4477 LR: 0.00098  
Epoch: [2][300/1968] Elapsed 1m 31s (remain 8m 27s) Loss avg.: 0.4725 Grad: 0.4139 LR: 0.00098  
Epoch: [2][400/1968] Elapsed 2m 0s (remain 7m 52s) Loss avg.: 0.4730 Grad: 0.4912 LR: 0.00098  
Epoch: [2][500/1968] Elapsed 2m 30s (remain 7m 19s) Loss avg.: 0.4724 Grad: 0.4351 LR: 0.00098  
Epoch: [2][600/1968] Elapsed 2m 59s (remain 6m 48s) Loss avg.: 0.4716 Grad: 0.4767 LR: 0.00098  
Epoch: [2][700/1968] Elapsed 3m 28s (remain 6m 17s) Loss avg.: 0.4713 Grad: 0.7635 LR: 0.00098  
Epoch: [2][800/1968] Elapsed 3m 58s (remain 5m 46s) Loss avg.: 0.4706 Grad: 0.4517 LR: 0.00098  
Epoch: [2][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4703 Grad: 0.7394 LR: 0.00098  
Epoch: [2][1000/1968] Elapsed 4m 5

Epoch 2 - avg_train_loss: 0.4668  avg_val_loss: 0.4621  time: 604s
Epoch 2 - Accuracy: 0.80009
Epoch 2 - Save Best Score: 0.8001 Model


Epoch: [3][0/1968] Elapsed 0m 3s (remain 117m 55s) Loss avg.: 0.4621 Grad: 0.4014 LR: 0.00091  
Epoch: [3][100/1968] Elapsed 0m 32s (remain 10m 8s) Loss avg.: 0.4579 Grad: 0.4388 LR: 0.00091  
Epoch: [3][200/1968] Elapsed 1m 2s (remain 9m 8s) Loss avg.: 0.4581 Grad: 0.3429 LR: 0.00091  
Epoch: [3][300/1968] Elapsed 1m 31s (remain 8m 27s) Loss avg.: 0.4589 Grad: 0.3277 LR: 0.00091  
Epoch: [3][400/1968] Elapsed 2m 0s (remain 7m 52s) Loss avg.: 0.4589 Grad: 0.5213 LR: 0.00091  
Epoch: [3][500/1968] Elapsed 2m 30s (remain 7m 19s) Loss avg.: 0.4589 Grad: 0.3505 LR: 0.00091  
Epoch: [3][600/1968] Elapsed 2m 59s (remain 6m 48s) Loss avg.: 0.4587 Grad: 0.3941 LR: 0.00091  
Epoch: [3][700/1968] Elapsed 3m 28s (remain 6m 17s) Loss avg.: 0.4586 Grad: 0.3499 LR: 0.00091  
Epoch: [3][800/1968] Elapsed 3m 58s (remain 5m 46s) Loss avg.: 0.4584 Grad: 0.3223 LR: 0.00091  
Epoch: [3][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4582 Grad: 0.4561 LR: 0.00091  
Epoch: [3][1000/1968] Elapsed 4m 5

Epoch 3 - avg_train_loss: 0.4568  avg_val_loss: 0.4539  time: 605s
Epoch 3 - Accuracy: 0.8047471428571429
Epoch 3 - Save Best Score: 0.8047 Model


Epoch: [4][0/1968] Elapsed 0m 3s (remain 112m 17s) Loss avg.: 0.4634 Grad: 0.3080 LR: 0.00081  
Epoch: [4][100/1968] Elapsed 0m 32s (remain 10m 5s) Loss avg.: 0.4522 Grad: 0.2994 LR: 0.00081  
Epoch: [4][200/1968] Elapsed 1m 2s (remain 9m 7s) Loss avg.: 0.4523 Grad: 0.3357 LR: 0.00081  
Epoch: [4][300/1968] Elapsed 1m 31s (remain 8m 27s) Loss avg.: 0.4512 Grad: 0.3725 LR: 0.00081  
Epoch: [4][400/1968] Elapsed 2m 0s (remain 7m 52s) Loss avg.: 0.4512 Grad: 0.5433 LR: 0.00081  
Epoch: [4][500/1968] Elapsed 2m 30s (remain 7m 19s) Loss avg.: 0.4513 Grad: 0.3716 LR: 0.00081  
Epoch: [4][600/1968] Elapsed 2m 59s (remain 6m 48s) Loss avg.: 0.4514 Grad: 0.4530 LR: 0.00081  
Epoch: [4][700/1968] Elapsed 3m 28s (remain 6m 17s) Loss avg.: 0.4515 Grad: 0.2874 LR: 0.00081  
Epoch: [4][800/1968] Elapsed 3m 58s (remain 5m 46s) Loss avg.: 0.4514 Grad: 0.2707 LR: 0.00081  
Epoch: [4][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4515 Grad: 0.2931 LR: 0.00081  
Epoch: [4][1000/1968] Elapsed 4m 5

Epoch 4 - avg_train_loss: 0.4508  avg_val_loss: 0.4510  time: 605s
Epoch 4 - Accuracy: 0.8060042857142857
Epoch 4 - Save Best Score: 0.8060 Model


Epoch: [5][0/1968] Elapsed 0m 3s (remain 130m 19s) Loss avg.: 0.4241 Grad: 0.3875 LR: 0.00069  
Epoch: [5][100/1968] Elapsed 0m 33s (remain 10m 16s) Loss avg.: 0.4463 Grad: 0.3918 LR: 0.00069  
Epoch: [5][200/1968] Elapsed 1m 2s (remain 9m 12s) Loss avg.: 0.4462 Grad: 0.3042 LR: 0.00069  
Epoch: [5][300/1968] Elapsed 1m 32s (remain 8m 30s) Loss avg.: 0.4459 Grad: 0.3228 LR: 0.00069  
Epoch: [5][400/1968] Elapsed 2m 1s (remain 7m 54s) Loss avg.: 0.4462 Grad: 0.3137 LR: 0.00069  
Epoch: [5][500/1968] Elapsed 2m 30s (remain 7m 21s) Loss avg.: 0.4462 Grad: 0.4773 LR: 0.00069  
Epoch: [5][600/1968] Elapsed 2m 59s (remain 6m 49s) Loss avg.: 0.4459 Grad: 0.3543 LR: 0.00069  
Epoch: [5][700/1968] Elapsed 3m 29s (remain 6m 18s) Loss avg.: 0.4463 Grad: 0.3202 LR: 0.00069  
Epoch: [5][800/1968] Elapsed 3m 58s (remain 5m 47s) Loss avg.: 0.4460 Grad: 0.2503 LR: 0.00069  
Epoch: [5][900/1968] Elapsed 4m 27s (remain 5m 17s) Loss avg.: 0.4460 Grad: 0.3003 LR: 0.00069  
Epoch: [5][1000/1968] Elapsed 4m

Epoch 5 - avg_train_loss: 0.4460  avg_val_loss: 0.4479  time: 603s
Epoch 5 - Accuracy: 0.8066814285714285
Epoch 5 - Save Best Score: 0.8067 Model


Epoch: [6][0/1968] Elapsed 0m 3s (remain 110m 25s) Loss avg.: 0.4517 Grad: 0.2852 LR: 0.00055  
Epoch: [6][100/1968] Elapsed 0m 32s (remain 10m 3s) Loss avg.: 0.4416 Grad: 0.2851 LR: 0.00055  
Epoch: [6][200/1968] Elapsed 1m 2s (remain 9m 6s) Loss avg.: 0.4426 Grad: 0.2434 LR: 0.00055  
Epoch: [6][300/1968] Elapsed 1m 31s (remain 8m 26s) Loss avg.: 0.4422 Grad: 0.3680 LR: 0.00055  
Epoch: [6][400/1968] Elapsed 2m 0s (remain 7m 51s) Loss avg.: 0.4424 Grad: 0.2878 LR: 0.00055  
Epoch: [6][500/1968] Elapsed 2m 30s (remain 7m 19s) Loss avg.: 0.4428 Grad: 0.3051 LR: 0.00055  
Epoch: [6][600/1968] Elapsed 2m 59s (remain 6m 47s) Loss avg.: 0.4427 Grad: 0.3776 LR: 0.00055  
Epoch: [6][700/1968] Elapsed 3m 28s (remain 6m 16s) Loss avg.: 0.4423 Grad: 0.2771 LR: 0.00055  
Epoch: [6][800/1968] Elapsed 3m 57s (remain 5m 46s) Loss avg.: 0.4420 Grad: 0.2834 LR: 0.00055  
Epoch: [6][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4419 Grad: 0.2814 LR: 0.00055  
Epoch: [6][1000/1968] Elapsed 4m 5

Epoch 6 - avg_train_loss: 0.4420  avg_val_loss: 0.4439  time: 604s
Epoch 6 - Accuracy: 0.80921
Epoch 6 - Save Best Score: 0.8092 Model


Epoch: [7][0/1968] Elapsed 0m 3s (remain 108m 56s) Loss avg.: 0.4206 Grad: 0.2226 LR: 0.00041  
Epoch: [7][100/1968] Elapsed 0m 32s (remain 10m 4s) Loss avg.: 0.4381 Grad: 0.2883 LR: 0.00041  
Epoch: [7][200/1968] Elapsed 1m 2s (remain 9m 6s) Loss avg.: 0.4377 Grad: 0.3768 LR: 0.00041  
Epoch: [7][300/1968] Elapsed 1m 31s (remain 8m 26s) Loss avg.: 0.4381 Grad: 0.3899 LR: 0.00041  
Epoch: [7][400/1968] Elapsed 2m 0s (remain 7m 51s) Loss avg.: 0.4381 Grad: 0.2856 LR: 0.00041  
Epoch: [7][500/1968] Elapsed 2m 29s (remain 7m 19s) Loss avg.: 0.4381 Grad: 0.3066 LR: 0.00041  
Epoch: [7][600/1968] Elapsed 2m 59s (remain 6m 47s) Loss avg.: 0.4379 Grad: 0.2621 LR: 0.00041  
Epoch: [7][700/1968] Elapsed 3m 28s (remain 6m 16s) Loss avg.: 0.4382 Grad: 0.3011 LR: 0.00041  
Epoch: [7][800/1968] Elapsed 3m 57s (remain 5m 46s) Loss avg.: 0.4385 Grad: 0.3024 LR: 0.00041  
Epoch: [7][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4386 Grad: 0.2453 LR: 0.00041  
Epoch: [7][1000/1968] Elapsed 4m 5

Epoch 7 - avg_train_loss: 0.4386  avg_val_loss: 0.4431  time: 604s
Epoch 7 - Accuracy: 0.8093828571428572
Epoch 7 - Save Best Score: 0.8094 Model


Epoch: [8][0/1968] Elapsed 0m 3s (remain 109m 54s) Loss avg.: 0.4385 Grad: 0.3875 LR: 0.00029  
Epoch: [8][100/1968] Elapsed 0m 32s (remain 10m 4s) Loss avg.: 0.4358 Grad: 0.2992 LR: 0.00029  
Epoch: [8][200/1968] Elapsed 1m 2s (remain 9m 6s) Loss avg.: 0.4353 Grad: 0.3822 LR: 0.00029  
Epoch: [8][300/1968] Elapsed 1m 31s (remain 8m 26s) Loss avg.: 0.4351 Grad: 0.2532 LR: 0.00029  
Epoch: [8][400/1968] Elapsed 2m 0s (remain 7m 51s) Loss avg.: 0.4352 Grad: 0.2536 LR: 0.00029  
Epoch: [8][500/1968] Elapsed 2m 29s (remain 7m 19s) Loss avg.: 0.4351 Grad: 0.2722 LR: 0.00029  
Epoch: [8][600/1968] Elapsed 2m 59s (remain 6m 47s) Loss avg.: 0.4350 Grad: 0.2613 LR: 0.00029  
Epoch: [8][700/1968] Elapsed 3m 28s (remain 6m 16s) Loss avg.: 0.4352 Grad: 0.2542 LR: 0.00029  
Epoch: [8][800/1968] Elapsed 3m 57s (remain 5m 46s) Loss avg.: 0.4352 Grad: 0.3421 LR: 0.00029  
Epoch: [8][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4351 Grad: 0.3088 LR: 0.00029  
Epoch: [8][1000/1968] Elapsed 4m 5

Epoch 8 - avg_train_loss: 0.4355  avg_val_loss: 0.4402  time: 604s
Epoch 8 - Accuracy: 0.8108542857142857
Epoch 8 - Save Best Score: 0.8109 Model


Epoch: [9][0/1968] Elapsed 0m 3s (remain 110m 20s) Loss avg.: 0.4455 Grad: 0.2769 LR: 0.00019  
Epoch: [9][100/1968] Elapsed 0m 32s (remain 10m 6s) Loss avg.: 0.4325 Grad: 0.2671 LR: 0.00019  
Epoch: [9][200/1968] Elapsed 1m 2s (remain 9m 7s) Loss avg.: 0.4324 Grad: 0.2460 LR: 0.00019  
Epoch: [9][300/1968] Elapsed 1m 31s (remain 8m 27s) Loss avg.: 0.4325 Grad: 0.2733 LR: 0.00019  
Epoch: [9][400/1968] Elapsed 2m 0s (remain 7m 52s) Loss avg.: 0.4323 Grad: 0.3128 LR: 0.00019  
Epoch: [9][500/1968] Elapsed 2m 30s (remain 7m 20s) Loss avg.: 0.4324 Grad: 0.2499 LR: 0.00019  
Epoch: [9][600/1968] Elapsed 2m 59s (remain 6m 48s) Loss avg.: 0.4325 Grad: 0.2605 LR: 0.00019  
Epoch: [9][700/1968] Elapsed 3m 29s (remain 6m 17s) Loss avg.: 0.4327 Grad: 0.3019 LR: 0.00019  
Epoch: [9][800/1968] Elapsed 3m 58s (remain 5m 47s) Loss avg.: 0.4328 Grad: 0.2625 LR: 0.00019  
Epoch: [9][900/1968] Elapsed 4m 27s (remain 5m 17s) Loss avg.: 0.4329 Grad: 0.3372 LR: 0.00019  
Epoch: [9][1000/1968] Elapsed 4m 5

Epoch 9 - avg_train_loss: 0.4330  avg_val_loss: 0.4387  time: 605s
Epoch 9 - Accuracy: 0.8115785714285715
Epoch 9 - Save Best Score: 0.8116 Model


Epoch: [10][0/1968] Elapsed 0m 3s (remain 109m 57s) Loss avg.: 0.4164 Grad: 0.3117 LR: 0.00012  
Epoch: [10][100/1968] Elapsed 0m 32s (remain 10m 5s) Loss avg.: 0.4290 Grad: 0.2728 LR: 0.00012  
Epoch: [10][200/1968] Elapsed 1m 2s (remain 9m 7s) Loss avg.: 0.4305 Grad: 0.3094 LR: 0.00012  
Epoch: [10][300/1968] Elapsed 1m 31s (remain 8m 26s) Loss avg.: 0.4305 Grad: 0.2593 LR: 0.00012  
Epoch: [10][400/1968] Elapsed 2m 0s (remain 7m 51s) Loss avg.: 0.4306 Grad: 0.3155 LR: 0.00012  
Epoch: [10][500/1968] Elapsed 2m 30s (remain 7m 19s) Loss avg.: 0.4305 Grad: 0.2449 LR: 0.00012  
Epoch: [10][600/1968] Elapsed 2m 59s (remain 6m 47s) Loss avg.: 0.4307 Grad: 0.2464 LR: 0.00012  
Epoch: [10][700/1968] Elapsed 3m 28s (remain 6m 16s) Loss avg.: 0.4310 Grad: 0.2926 LR: 0.00012  
Epoch: [10][800/1968] Elapsed 3m 57s (remain 5m 46s) Loss avg.: 0.4309 Grad: 0.2898 LR: 0.00012  
Epoch: [10][900/1968] Elapsed 4m 27s (remain 5m 16s) Loss avg.: 0.4312 Grad: 0.3259 LR: 0.00012  
Epoch: [10][1000/1968] E

Epoch 10 - avg_train_loss: 0.4312  avg_val_loss: 0.4378  time: 602s
Epoch 10 - Accuracy: 0.812
Epoch 10 - Save Best Score: 0.8120 Model
Epoch 10 - Save final model
Score: 0.81200
