# About this notebook ...

## Library

In [1]:
import glob
import json
import math
import os
import random
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

In [2]:
warnings.filterwarnings("ignore")

## Config

In [3]:
config_defaults = {
    "seed": 440,
    "data_size": 8_000_000,
    "n_class": 4,
    "n_fold": 10,
    "geese_net_layers": 12,
    "geese_net_filters": 48,
    "gradient_accumulation_steps": 1,
    "max_grad_norm": 1000,
    "num_workers": 4,
    "batch_size": 3200,
    "epochs": 20,
    "scheduler": "CosineAnnealingWarmRestarts",
    "criterion": "CrossEntropyLoss",
    "lr": 1e-3,
    "min_lr": 1e-4,
    "weight_decay": 1e-5,
    "model_name": "geese_net_alpha",
}

In [4]:
if config_defaults["scheduler"] == "CosineAnnealingWarmRestarts":
    config_defaults["T_0"] = config_defaults["epochs"]

elif config_defaults["scheduler"] == "CosineAnnealingLR":
    config_defaults["T_max"] = config_defaults["epochs"]

elif config_defaults["scheduler"] == "ReduceLROnPlateau":
    config_defaults["factor"] = 0.2
    config_defaults["patience"] = 4
    config_defaults["eps"] = 1e-6

In [5]:
class Config:
    pre_train_file = ""
    print_freq = 100
    train = True
    debug = False
    apex = False

In [6]:
if Config.debug:
    wandb.init(project="hungry-geese", config=config_defaults, mode="disabled")
else:
    wandb.init(project="hungry-geese", config=config_defaults)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mimokuri[0m (use `wandb login --relogin` to force relogin)


In [7]:
config = wandb.config

In [8]:
if Config.debug:
    config.update({"epochs": 1, "data_size": 10_000}, allow_val_change=True)

In [9]:
if Config.apex:
    from apex import amp

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

In [11]:
BASE_DIR = "../input/hungrygeeseepisode/hungry-geese-episode/"
OUTPUT_DIR = "pre-models/"

if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

In [12]:
paths = [path for path in glob.glob(BASE_DIR + "*.json") if "info" not in path]
print(len(paths))

29048


## Utils

In [13]:
@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f"[{name}] start")
    yield
    LOGGER.info(f"[{name}] done in {time.time() - t0:.0f} s.")


def init_logger(log_file=OUTPUT_DIR + "train.log"):
    from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger

    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger


LOGGER = init_logger()


def seed_torch(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_torch(seed=config.seed)

In [14]:
def ident(y):
    return y


def reverse_ns(y):
    if y == 0:
        return 1
    if y == 1:
        return 0
    return y


def reverse_we(y):
    if y == 2:
        return 3
    if y == 3:
        return 2
    return y


def reverse_nswe(y):
    return reverse_ns(reverse_we(y))

In [15]:
def no_flip(image):
    return image


def h_flip(image):
    return image[:, :, ::-1]


def v_flip(image):
    return image[:, ::-1, :]


def hv_flip(image):
    return image[:, ::-1, ::-1]

## Observation

In [16]:
next_position_map = {}
for pos in range(77):
    position = []
    position.append((11 * (1 + pos // 11) + pos % 11) % 77)
    position.append((11 * (-1 + pos // 11) + pos % 11) % 77)
    position.append((11 * (pos // 11) + (pos + 1) % 11) % 77)
    position.append((11 * (pos // 11) + (pos - 1) % 11) % 77)
    next_position_map[pos] = set(position)

In [17]:
def make_input(obses):
    b = np.zeros((17, 7 * 11), dtype=np.float32)
    obs = obses[-1]

    for p, pos_list in enumerate(obs["geese"]):
        pid = (p - obs["index"]) % 4

        # head position
        for pos in pos_list[:1]:
            b[0 + pid, pos] = 1
        # tip position
        for pos in pos_list[-1:]:
            b[4 + pid, pos] = 1
        # whole position
        for pos in pos_list:
            b[8 + pid, pos] = 1

    # previous head position
    if len(obses) > 1:
        obs_prev = obses[-2]
        for p, pos_list in enumerate(obs_prev["geese"]):
            for pos in pos_list[:1]:
                b[12 + (p - obs["index"]) % 4, pos] = 1

    # food
    for pos in obs["food"]:
        b[16, pos] = 1

    return b.reshape(-1, 7, 11)

In [18]:
def get_reverse_cube(obses):
    """
    尻尾から順番に 1, 0.9, 0.8, ... という並び
    """
    b = np.zeros((4, 7 * 11), dtype=np.float32)
    obs = obses[-1]

    for p, geese in enumerate(obs["geese"]):
        # whole position reverse
        for num_reverse, pos in enumerate(geese[::-1]):
            b[(p - obs["index"]) % 4, pos] = 1 - num_reverse * 0.1

    return b.reshape(-1, 7, 11)

In [19]:
def get_next_disappear_cube(obses):
    """
    次になくなる場所: 1
    次になくなる可能性のある場所: 0.5
    """
    b = np.zeros((4, 7 * 11), dtype=np.float32)
    obs = obses[-1]
    step = obs["step"]

    # foodを食べる可能性があるか。
    eat_food_possibility = defaultdict(int)
    for p, geese in enumerate(obs["geese"]):
        for pos in geese[:1]:
            if not next_position_map[pos].isdisjoint(obs["food"]):
                eat_food_possibility[p] = 1

    if (step % 40) == 39:  # 1つ短くなる
        for p, geese in enumerate(obs["geese"]):
            if eat_food_possibility[p]:  # 尻尾が1、尻尾の１つ前0.5
                for pos in geese[-1:]:
                    b[(p - obs["index"]) % 4, pos] = 1
                for pos in geese[-2:-1]:
                    b[(p - obs["index"]) % 4, pos] = 0.5
            else:  # 食べる可能性なし -> 尻尾が1, 尻尾の1つ前1
                for pos in geese[-2:]:
                    b[(p - obs["index"]) % 4, pos] = 1
    else:  # 1つ短くならない
        for p, geese in enumerate(obs["geese"]):
            if eat_food_possibility[p]:  # 食べる可能性があり -> 尻尾を0.5
                for pos in geese[-1:]:
                    b[(p - obs["index"]) % 4, pos] = 0.5
            else:  # 食べる可能性なし # 尻尾を1
                for pos in geese[-1:]:
                    b[(p - obs["index"]) % 4, pos] = 1

    return b.reshape(-1, 7, 11)

In [20]:
def get_step_cube_v2(obses):
    """
    step0: 0, step199: 1
    step0: 0, step39 + 40n: 1
    """
    b = np.zeros((1, 7, 11), dtype=np.float32)
    obs = obses[-1]
    step = obs["step"]

    b[:, :, :5] = (step % 200) / 199
    b[:, :, 5:] = (step % 40) / 39

    return b

In [21]:
def get_length_cube(obses):
    b = np.zeros((2, 7, 11), dtype=np.float32)
    obs = obses[-1]

    my_length = len(obs["geese"][obs["index"]])
    opposite1_length = len(obs["geese"][(obs["index"] + 1) % 4])
    opposite2_length = len(obs["geese"][(obs["index"] + 2) % 4])
    opposite3_length = len(obs["geese"][(obs["index"] + 3) % 4])

    b[0] = my_length / 10
    max_opposite_length = max(opposite1_length, opposite2_length, opposite3_length)
    b[1, :, 0:2] = (my_length - max_opposite_length) / 10
    b[1, :, 2:5] = (my_length - opposite1_length) / 10
    b[1, :, 5:8] = (my_length - opposite2_length) / 10
    b[1, :, 8:11] = (my_length - opposite3_length) / 10

    return b

In [22]:
def get_features(obses):
    b = np.zeros((7 * 11), dtype=np.float32)
    obs = obses[-1]
    step = obs["step"]

    my_goose = obs["geese"][obs["index"]]
    my_length = len(my_goose)

    # num step
    b[0] = (step - 194) if step >= 195 else 0
    b[1] = (step % 40 - 35) if step % 40 > 35 else 0

    """
    2-4: difference between my_length and opponent length (-3 to 3)
    """
    for p, pos_list in enumerate(obs["geese"]):
        pid = (p - obs["index"]) % 4
        p_length = len(pos_list)

        if pid == 0:
            continue

        b[1 + pid] = max(min(my_length - p_length, 3), -3) + 3

    """
    5-7: difference between my head position and opponent one
    """
    if my_length != 0:

        for p, pos_list in enumerate(obs["geese"]):
            pid = (p - obs["index"]) % 4

            if pid == 0 or len(pos_list) == 0:
                continue

            diff = abs(my_goose[0] - pos_list[0])
            x_ = diff % 11
            x = min(x_, 11 - x_)
            y_ = diff // 11
            y = min(y_, 7 - y_)
            b[4 + pid] = x + y

    return b.reshape(1, 7, 11)

## Data

In [23]:
X_train = np.zeros((config.data_size, 26, 7, 11), dtype=np.float32)
y_train = np.zeros((config.data_size,), dtype=np.uint8)

X_count = 0
y_count = 0

In [24]:
def create_dataset_from_json(filepath, json_object=None, standing=0):
    global X_train
    global y_train
    global X_count
    global y_count

    if json_object is None:
        json_open = open(path, "r")
        json_load = json.load(json_open)
    else:
        json_load = json_object

    try:
        winner_index = np.argmax(np.argsort(json_load["rewards"]) == 3 - standing)

        obses = []
        actions = {"NORTH": 0, "SOUTH": 1, "WEST": 2, "EAST": 3}

        for i in range(len(json_load["steps"]) - 1):
            if json_load["steps"][i][winner_index]["status"] == "ACTIVE":
                y_ = json_load["steps"][i + 1][winner_index]["action"]
                if y_ is not None:
                    step = json_load["steps"][i]
                    step[winner_index]["observation"]["geese"] = step[0]["observation"]["geese"]
                    step[winner_index]["observation"]["food"] = step[0]["observation"]["food"]
                    step[winner_index]["observation"]["step"] = step[0]["observation"]["step"]
                    obses.append(step[winner_index]["observation"])

                    for func in [ident, reverse_ns, reverse_we, reverse_nswe]:
                        if y_count >= config.data_size:
                            break

                        y_train[y_count] = func(actions[y_])
                        y_count += 1

                    if y_count >= config.data_size:
                        break

        for j in range(len(obses)):
            # X_ = make_input(obses[: j + 1])

            # 反転可能な特徴量
            X_ = []
            X_.append(make_input(obses[: j + 1]))
            X_.append(get_reverse_cube(obses[: j + 1]))
            X_.append(get_next_disappear_cube(obses[: j + 1]))

            # 反転不可能な特徴量
            X_i = []
            # X_i.append(get_step_cube_v2(obses[: j + 1]))
            # X_i.append(get_length_cube(obses[: j + 1]))
            X_i.append(get_features(obses[: j + 1]))

            X_ = np.concatenate(X_)
            X_i = np.concatenate(X_i)

            for func in [no_flip, v_flip, h_flip, hv_flip]:
                if X_count >= config.data_size:
                    break

                X_train[X_count] = np.concatenate([func(X_), X_i])
                X_count += 1

            if X_count >= config.data_size:
                break

        return
    except Exception as e:
        if Config.debug:
            raise Exception from e
        return

In [25]:
for path in tqdm(paths[::-1]):
    create_dataset_from_json(path, standing=0)  # use only winners' moves
    if X_count >= config.data_size:
        break

print(f"Num episode: {len(X_train):,}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=29048.0), HTML(value='')))


Num episode: 8,000,000


In [26]:
if Config.debug:
    X_train = X_train[:1000]
    y_train = y_train[:1000]

In [27]:
y_df = pd.DataFrame(y_train, dtype=np.uint8)
y_df.columns = ["action"]
y_df

Unnamed: 0,action
0,2
1,2
2,3
3,3
4,0
...,...
7999995,0
7999996,1
7999997,0
7999998,1


## CV Split

In [28]:
folds = y_df.copy()
Fold = StratifiedKFold(n_splits=config.n_fold, shuffle=True, random_state=config.seed)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds["action"])):
    folds.loc[val_index, "fold"] = int(n)
folds["fold"] = folds["fold"].astype(np.uint8)
print(folds.groupby(["fold", "action"]).size())

fold  action
0     0         193103
      1         193103
      2         206897
      3         206897
1     0         193103
      1         193103
      2         206897
      3         206897
2     0         193104
      1         193103
      2         206897
      3         206896
3     0         193104
      1         193103
      2         206897
      3         206896
4     0         193104
      1         193103
      2         206897
      3         206896
5     0         193104
      1         193103
      2         206897
      3         206896
6     0         193103
      1         193104
      2         206896
      3         206897
7     0         193103
      1         193104
      2         206896
      3         206897
8     0         193103
      1         193104
      2         206896
      3         206897
9     0         193103
      1         193104
      2         206896
      3         206897
dtype: int64


## Dataset

In [29]:
class TrainDataset(Dataset):
    def __init__(self, array, label):
        self.array = array
        self.label = label

    def __len__(self):
        return self.array.shape[0]

    def __getitem__(self, idx):
        return self.array[idx], torch.tensor(self.label[idx]).long()


class TestDataset(Dataset):
    def __init__(self, array):
        self.array = array

    def __len__(self):
        return self.array.shape[0]

    def __getitem__(self, idx):
        return self.array[idx]

In [30]:
# Test

if Config.debug or False:
    train_ds = TrainDataset(X_train, y_train)

    for i in range(1):
        obs, action = train_ds[i]
        print(obs.shape, action)

## Model

In [31]:
class TorusConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, do=False, bn=True):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.do = nn.Dropout2d(p=0.1) if do else None
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        h = torch.cat([x[:, :, :, -self.edge_size[1] :], x, x[:, :, :, : self.edge_size[1]]], dim=3)
        h = torch.cat([h[:, :, -self.edge_size[0] :], h, h[:, :, : self.edge_size[0]]], dim=2)
        h = self.conv(h)
        h = self.do(h) if self.do is not None else h
        h = self.bn(h) if self.bn is not None else h
        return h

In [32]:
class GeeseNetAlpha(nn.Module):
    def __init__(self):
        super().__init__()

        layers = config.geese_net_layers
        filters = config.geese_net_filters
        dim = filters * 5 + 30

        self.embed_step = nn.Embedding(5, 3)
        self.embed_hunger = nn.Embedding(5, 3)
        self.embed_diff_len = nn.Embedding(7, 4)
        self.embed_diff_head = nn.Embedding(9, 4)

        self.conv0 = TorusConv2d(25, filters, (3, 3))
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3)) for _ in range(layers)])
        self.conv1 = TorusConv2d(filters, filters, (5, 5))

        # self.attention = nn.MultiheadAttention(dim, 1)

        self.head_p1 = nn.Linear(dim, dim // 2, bias=True)
        self.head_p2 = nn.Linear(dim // 2, 4, bias=False)
        self.head_v1 = nn.Linear(dim, dim // 2, bias=True)
        self.head_v2 = nn.Linear(dim // 2, 1, bias=False)

        self.bn_p1 = nn.BatchNorm1d(dim // 2)
        self.bn_v1 = nn.BatchNorm1d(dim // 2)

    def forward(self, x, _=None):
        x_feats = x[:, -1].view(x.size(0), -1).long()

        # Embedding for features
        e_step = self.embed_step(x_feats[:, 0])
        e_hung = self.embed_hunger(x_feats[:, 1])
        e_diff_l = self.embed_diff_len(x_feats[:, 2:5]).view(x.size(0), -1)
        e_diff_h = self.embed_diff_head(x_feats[:, 5:8]).view(x.size(0), -1)

        x = x[:, :-1].float()

        # CNN for observation
        h = F.relu_(self.conv0(x))

        for block in self.blocks:
            h = F.relu_(h + block(h))

        h = F.relu_(h + self.conv1(h))

        # Extract head position
        h_head = (h * x[:, :1]).view(h.size(0), h.size(1), -1).sum(-1)
        h_head2 = (h * x[:, 1:2]).view(h.size(0), h.size(1), -1).sum(-1)
        h_head3 = (h * x[:, 2:3]).view(h.size(0), h.size(1), -1).sum(-1)
        h_head4 = (h * x[:, 3:4]).view(h.size(0), h.size(1), -1).sum(-1)
        h_avg = h.view(h.size(0), h.size(1), -1).mean(-1)

        # Merge features
        h = torch.cat(
            [
                h_head,
                h_head2,
                h_head3,
                h_head4,
                h_avg,
                e_step,
                e_hung,
                e_diff_l,
                e_diff_h,
            ],
            1,
        ).view(1, h.size(0), -1)

        # h, _ = self.attention(h, h, h)

        h_p = F.relu_(self.bn_p1(self.head_p1(h.view(x.size(0), -1))))
        p = self.head_p2(h_p)

        h_v = F.relu_(self.bn_v1(self.head_v1(h.view(x.size(0), -1))))
        v = torch.tanh(self.head_v2(h_v))

        return {"policy": p, "value": v}

In [33]:
# Test

if Config.debug or False:
    model = GeeseNetAlpha()
    # print(model)

    params = sum(p.numel() for p in model.parameters())
    print(f"params: {params:,}")

    train_ds = TrainDataset(X_train, y_train)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    for obs, action in train_loader:
        print(f"input shape: {obs.shape}")
        output = model(obs)
        print(output)
        print(f"{torch.argmax(output['policy'], dim=1)}")
        break

## Loss

## Scoring

In [34]:
def get_score(y_true, y_pred):
    return accuracy_score(y_true, y_pred)

In [35]:
def get_result(result_df):
    preds = result_df["preds"].values
    labels = result_df["action"].values
    score = get_score(labels, preds)
    LOGGER.info(f"Score: {score:<.5f}")
    return score

## Helper functions

In [36]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return "%s (remain %s)" % (asMinutes(s), asMinutes(rs))

In [37]:
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    losses = AverageMeter()

    # switch to train mode
    model.train()
    start = time.time()

    for step, (obs, action) in enumerate(train_loader):
        obs = obs.to(device)
        action = action.to(device)
        batch_size = action.size(0)

        y_preds = model(obs)["policy"]

        loss = criterion(y_preds, action)
        losses.update(loss.item(), batch_size)
        if config.gradient_accumulation_steps > 1:
            loss = loss / config.gradient_accumulation_steps
        if Config.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

        if (step + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        if step % Config.print_freq == 0 or step == (len(train_loader) - 1):
            print(
                f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(train_loader)):s} "
                f"Loss avg.: {losses.avg:.4f} "
                f"Grad: {grad_norm:.4f} "
                f"LR: {scheduler.get_last_lr()[0]:.5f}  "
            )

    return losses.avg

In [38]:
def valid_fn(valid_loader, model, criterion, device):
    losses = AverageMeter()

    # switch to evaluation mode
    model.eval()
    preds = []
    start = time.time()

    for step, (obs, action) in enumerate(valid_loader):
        obs = obs.to(device)
        action = action.to(device)
        batch_size = action.size(0)

        # compute loss
        with torch.no_grad():
            y_preds = model(obs)["policy"]

        loss = criterion(y_preds, action)
        losses.update(loss.item(), batch_size)

        # record accuracy
        preds.append(y_preds.softmax(1).to("cpu").numpy())
        if config.gradient_accumulation_steps > 1:
            loss = loss / config.gradient_accumulation_steps

        if step % Config.print_freq == 0 or step == (len(valid_loader) - 1):
            print(
                f"Eval: [{step}/{len(valid_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(valid_loader)):s} "
                f"Loss avg.: {losses.avg:.4f} "
            )
    predictions = np.concatenate(preds)
    return losses.avg, predictions

## Train loop

In [39]:
def train_loop(folds, fold):

    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # Data Loader
    # ====================================================
    # X_train_folds = X_train[folds["fold"] != fold]
    # X_valid_folds = X_train[folds["fold"] == fold]

    # y_train_folds = y_train[folds["fold"] != fold]
    y_valid_folds = y_train[folds["fold"] == fold]

    # y_df_train_folds = y_df[folds["fold"] != fold]
    y_df_valid_folds = y_df[folds["fold"] == fold]

    # train_dataset = TrainDataset(X_train[folds["fold"] != fold], y_train[folds["fold"] != fold])
    # valid_dataset = TrainDataset(X_train[folds["fold"] == fold], y_valid_folds)

    train_loader = DataLoader(
        TrainDataset(X_train[folds["fold"] != fold], y_train[folds["fold"] != fold]),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = DataLoader(
        TrainDataset(X_train[folds["fold"] == fold], y_valid_folds),
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # ====================================================
    # Scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if config.scheduler == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer, mode="min", factor=config.factor, patience=config.patience, verbose=True, eps=config.eps
            )
        elif config.scheduler == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(optimizer, T_max=config.T_max, eta_min=config.min_lr, last_epoch=-1)
        elif config.scheduler == "CosineAnnealingWarmRestarts":
            scheduler = CosineAnnealingWarmRestarts(
                optimizer, T_0=config.T_0, T_mult=1, eta_min=config.min_lr, last_epoch=-1
            )
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    model = GeeseNetAlpha()
    # try:
    #     model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, Config.pre_train_file)))
    # except:
    #     print(f"Failed to load pre-train weight.")

    # Disable training for value network
    # for param in model.head_v1.parameters():
    #     param.requires_grad = False
    # for param in model.head_v2.parameters():
    #     param.requires_grad = False

    model.to(device)

    # Use multi GPU
    if device == torch.device("cuda") and not Config.apex:
        model = torch.nn.DataParallel(model)  # make parallel

    optimizer = Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # apex
    # ====================================================
    if Config.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    # ====================================================
    # Criterion
    # ====================================================
    def get_criterion():
        if config.criterion == "CrossEntropyLoss":
            criterion = nn.CrossEntropyLoss()
        return criterion

    criterion = get_criterion()

    # ====================================================
    # loop
    # ====================================================
    best_score = 0.0
    best_loss = np.inf
    best_preds = None

    # wandb.watch(model, log_freq=Config.print_freq)

    for epoch in range(config.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, preds = valid_fn(valid_loader, model, criterion, device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score = get_score(y_valid_folds, preds.argmax(1))

        elapsed = time.time() - start_time

        LOGGER.info(
            f"Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s"
        )
        LOGGER.info(f"Epoch {epoch+1} - Accuracy: {score}")

        wandb.log({"epoch": epoch + 1, "train_loss": avg_loss, "val_loss": avg_val_loss, "accuracy": score})

        if score > best_score:
            best_score = score
            LOGGER.info(f"Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model")
            torch.save(model.module.state_dict(), OUTPUT_DIR + f"{config.model_name}_fold{fold}_best.pth")
            best_preds = preds

        if epoch == config.epochs - 1:
            LOGGER.info(f"Epoch {epoch+1} - Save final model")
            torch.save(model.module.state_dict(), OUTPUT_DIR + f"{config.model_name}_fold{fold}_final.pth")

    y_df_valid_folds[[str(c) for c in range(config.n_class)]] = best_preds
    y_df_valid_folds["preds"] = best_preds.argmax(1)

    return y_df_valid_folds

## Main


In [40]:
def main():
    if Config.train:
        # train
        oof_df = pd.DataFrame()
        for fold in range(config.n_fold):
            _oof_df = train_loop(folds, fold)
            oof_df = pd.concat([oof_df, _oof_df])
            LOGGER.info(f"========== fold: {fold} result ==========")
            get_result(_oof_df)
            break  # fold 1つだけ
        # CV result
        # LOGGER.info(f"========== CV ==========")
        # get_result(oof_df)
        # save result
        oof_df.to_csv(OUTPUT_DIR + "oof_df.csv", index=False)

In [41]:
if __name__ == "__main__":
    main()



Epoch: [1][0/2250] Elapsed 0m 5s (remain 221m 50s) Loss avg.: 1.4240 Grad: 0.9038 LR: 0.00100  
Epoch: [1][100/2250] Elapsed 1m 6s (remain 23m 30s) Loss avg.: 0.6920 Grad: 0.5552 LR: 0.00100  
Epoch: [1][200/2250] Elapsed 2m 6s (remain 21m 32s) Loss avg.: 0.6289 Grad: 0.4413 LR: 0.00100  
Epoch: [1][300/2250] Elapsed 3m 7s (remain 20m 12s) Loss avg.: 0.5979 Grad: 0.3612 LR: 0.00100  
Epoch: [1][400/2250] Elapsed 4m 8s (remain 19m 3s) Loss avg.: 0.5771 Grad: 0.3872 LR: 0.00100  
Epoch: [1][500/2250] Elapsed 5m 8s (remain 17m 56s) Loss avg.: 0.5631 Grad: 0.4723 LR: 0.00100  
Epoch: [1][600/2250] Elapsed 6m 8s (remain 16m 51s) Loss avg.: 0.5522 Grad: 0.3364 LR: 0.00100  
Epoch: [1][700/2250] Elapsed 7m 9s (remain 15m 48s) Loss avg.: 0.5440 Grad: 0.3549 LR: 0.00100  
Epoch: [1][800/2250] Elapsed 8m 9s (remain 14m 45s) Loss avg.: 0.5370 Grad: 0.3054 LR: 0.00100  
Epoch: [1][900/2250] Elapsed 9m 9s (remain 13m 43s) Loss avg.: 0.5312 Grad: 0.3589 LR: 0.00100  
Epoch: [1][1000/2250] Elapsed 10

Epoch 1 - avg_train_loss: 0.4954  avg_val_loss: 0.4658  time: 1412s
Epoch 1 - Accuracy: 0.7981775
Epoch 1 - Save Best Score: 0.7982 Model


Epoch: [2][0/2250] Elapsed 0m 3s (remain 141m 54s) Loss avg.: 0.4493 Grad: 0.2547 LR: 0.00099  
Epoch: [2][100/2250] Elapsed 1m 4s (remain 22m 51s) Loss avg.: 0.4621 Grad: 0.2467 LR: 0.00099  
Epoch: [2][200/2250] Elapsed 2m 5s (remain 21m 15s) Loss avg.: 0.4605 Grad: 0.2824 LR: 0.00099  
Epoch: [2][300/2250] Elapsed 3m 5s (remain 20m 2s) Loss avg.: 0.4606 Grad: 0.2625 LR: 0.00099  
Epoch: [2][400/2250] Elapsed 4m 6s (remain 18m 56s) Loss avg.: 0.4598 Grad: 0.2411 LR: 0.00099  
Epoch: [2][500/2250] Elapsed 5m 7s (remain 17m 51s) Loss avg.: 0.4596 Grad: 0.2651 LR: 0.00099  
Epoch: [2][600/2250] Elapsed 6m 7s (remain 16m 48s) Loss avg.: 0.4592 Grad: 0.2486 LR: 0.00099  
Epoch: [2][700/2250] Elapsed 7m 8s (remain 15m 46s) Loss avg.: 0.4589 Grad: 0.2560 LR: 0.00099  
Epoch: [2][800/2250] Elapsed 8m 8s (remain 14m 44s) Loss avg.: 0.4592 Grad: 0.2649 LR: 0.00099  
Epoch: [2][900/2250] Elapsed 9m 9s (remain 13m 42s) Loss avg.: 0.4590 Grad: 0.2120 LR: 0.00099  
Epoch: [2][1000/2250] Elapsed 10

Epoch 2 - avg_train_loss: 0.4561  avg_val_loss: 0.4537  time: 1410s
Epoch 2 - Accuracy: 0.80415
Epoch 2 - Save Best Score: 0.8042 Model


Epoch: [3][0/2250] Elapsed 0m 3s (remain 134m 27s) Loss avg.: 0.4435 Grad: 0.1923 LR: 0.00098  
Epoch: [3][100/2250] Elapsed 1m 4s (remain 22m 47s) Loss avg.: 0.4495 Grad: 0.2073 LR: 0.00098  
Epoch: [3][200/2250] Elapsed 2m 4s (remain 21m 11s) Loss avg.: 0.4498 Grad: 0.2100 LR: 0.00098  
Epoch: [3][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4493 Grad: 0.1912 LR: 0.00098  
Epoch: [3][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4488 Grad: 0.2054 LR: 0.00098  
Epoch: [3][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4489 Grad: 0.2218 LR: 0.00098  
Epoch: [3][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4486 Grad: 0.1934 LR: 0.00098  
Epoch: [3][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4487 Grad: 0.2038 LR: 0.00098  
Epoch: [3][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4489 Grad: 0.2074 LR: 0.00098  
Epoch: [3][900/2250] Elapsed 9m 8s (remain 13m 40s) Loss avg.: 0.4488 Grad: 0.2097 LR: 0.00098  
Epoch: [3][1000/2250] Elapsed 10

Epoch 3 - avg_train_loss: 0.4477  avg_val_loss: 0.4527  time: 1410s
Epoch 3 - Accuracy: 0.8046075
Epoch 3 - Save Best Score: 0.8046 Model


Epoch: [4][0/2250] Elapsed 0m 3s (remain 126m 14s) Loss avg.: 0.4615 Grad: 0.2314 LR: 0.00095  
Epoch: [4][100/2250] Elapsed 1m 3s (remain 22m 34s) Loss avg.: 0.4427 Grad: 0.1717 LR: 0.00095  
Epoch: [4][200/2250] Elapsed 2m 4s (remain 21m 4s) Loss avg.: 0.4429 Grad: 0.1976 LR: 0.00095  
Epoch: [4][300/2250] Elapsed 3m 4s (remain 19m 54s) Loss avg.: 0.4431 Grad: 0.1905 LR: 0.00095  
Epoch: [4][400/2250] Elapsed 4m 4s (remain 18m 49s) Loss avg.: 0.4429 Grad: 0.1670 LR: 0.00095  
Epoch: [4][500/2250] Elapsed 5m 5s (remain 17m 46s) Loss avg.: 0.4433 Grad: 0.1904 LR: 0.00095  
Epoch: [4][600/2250] Elapsed 6m 5s (remain 16m 43s) Loss avg.: 0.4431 Grad: 0.2006 LR: 0.00095  
Epoch: [4][700/2250] Elapsed 7m 6s (remain 15m 42s) Loss avg.: 0.4428 Grad: 0.1965 LR: 0.00095  
Epoch: [4][800/2250] Elapsed 8m 7s (remain 14m 41s) Loss avg.: 0.4427 Grad: 0.1840 LR: 0.00095  
Epoch: [4][900/2250] Elapsed 9m 7s (remain 13m 39s) Loss avg.: 0.4426 Grad: 0.1657 LR: 0.00095  
Epoch: [4][1000/2250] Elapsed 10

Epoch 4 - avg_train_loss: 0.4427  avg_val_loss: 0.4461  time: 1410s
Epoch 4 - Accuracy: 0.8078125
Epoch 4 - Save Best Score: 0.8078 Model


Epoch: [5][0/2250] Elapsed 0m 3s (remain 140m 25s) Loss avg.: 0.4408 Grad: 0.1721 LR: 0.00091  
Epoch: [5][100/2250] Elapsed 1m 4s (remain 22m 48s) Loss avg.: 0.4374 Grad: 0.1821 LR: 0.00091  
Epoch: [5][200/2250] Elapsed 2m 4s (remain 21m 12s) Loss avg.: 0.4380 Grad: 0.1721 LR: 0.00091  
Epoch: [5][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4383 Grad: 0.1757 LR: 0.00091  
Epoch: [5][400/2250] Elapsed 4m 6s (remain 18m 54s) Loss avg.: 0.4386 Grad: 0.1692 LR: 0.00091  
Epoch: [5][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4387 Grad: 0.2064 LR: 0.00091  
Epoch: [5][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4390 Grad: 0.1519 LR: 0.00091  
Epoch: [5][700/2250] Elapsed 7m 7s (remain 15m 43s) Loss avg.: 0.4390 Grad: 0.1703 LR: 0.00091  
Epoch: [5][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4395 Grad: 0.1715 LR: 0.00091  
Epoch: [5][900/2250] Elapsed 9m 7s (remain 13m 40s) Loss avg.: 0.4395 Grad: 0.1565 LR: 0.00091  
Epoch: [5][1000/2250] Elapsed 10

Epoch 5 - avg_train_loss: 0.4393  avg_val_loss: 0.4446  time: 1409s
Epoch 5 - Accuracy: 0.8086375
Epoch 5 - Save Best Score: 0.8086 Model


Epoch: [6][0/2250] Elapsed 0m 3s (remain 141m 17s) Loss avg.: 0.4408 Grad: 0.1602 LR: 0.00087  
Epoch: [6][100/2250] Elapsed 1m 4s (remain 22m 53s) Loss avg.: 0.4352 Grad: 0.1616 LR: 0.00087  
Epoch: [6][200/2250] Elapsed 2m 5s (remain 21m 14s) Loss avg.: 0.4359 Grad: 0.1503 LR: 0.00087  
Epoch: [6][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4358 Grad: 0.1924 LR: 0.00087  
Epoch: [6][400/2250] Elapsed 4m 5s (remain 18m 54s) Loss avg.: 0.4358 Grad: 0.1736 LR: 0.00087  
Epoch: [6][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4359 Grad: 0.1614 LR: 0.00087  
Epoch: [6][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4360 Grad: 0.1606 LR: 0.00087  
Epoch: [6][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4363 Grad: 0.1488 LR: 0.00087  
Epoch: [6][800/2250] Elapsed 8m 8s (remain 14m 42s) Loss avg.: 0.4363 Grad: 0.1572 LR: 0.00087  
Epoch: [6][900/2250] Elapsed 9m 8s (remain 13m 41s) Loss avg.: 0.4365 Grad: 0.1567 LR: 0.00087  
Epoch: [6][1000/2250] Elapsed 10

Epoch 6 - avg_train_loss: 0.4364  avg_val_loss: 0.4457  time: 1411s
Epoch 6 - Accuracy: 0.8078325


Epoch: [7][0/2250] Elapsed 0m 3s (remain 141m 22s) Loss avg.: 0.4317 Grad: 0.1617 LR: 0.00081  
Epoch: [7][100/2250] Elapsed 1m 4s (remain 22m 46s) Loss avg.: 0.4315 Grad: 0.1487 LR: 0.00081  
Epoch: [7][200/2250] Elapsed 2m 4s (remain 21m 9s) Loss avg.: 0.4320 Grad: 0.1683 LR: 0.00081  
Epoch: [7][300/2250] Elapsed 3m 4s (remain 19m 57s) Loss avg.: 0.4326 Grad: 0.1551 LR: 0.00081  
Epoch: [7][400/2250] Elapsed 4m 5s (remain 18m 50s) Loss avg.: 0.4329 Grad: 0.1594 LR: 0.00081  
Epoch: [7][500/2250] Elapsed 5m 5s (remain 17m 47s) Loss avg.: 0.4330 Grad: 0.1436 LR: 0.00081  
Epoch: [7][600/2250] Elapsed 6m 6s (remain 16m 44s) Loss avg.: 0.4333 Grad: 0.1672 LR: 0.00081  
Epoch: [7][700/2250] Elapsed 7m 6s (remain 15m 42s) Loss avg.: 0.4333 Grad: 0.1473 LR: 0.00081  
Epoch: [7][800/2250] Elapsed 8m 6s (remain 14m 40s) Loss avg.: 0.4333 Grad: 0.1509 LR: 0.00081  
Epoch: [7][900/2250] Elapsed 9m 7s (remain 13m 39s) Loss avg.: 0.4333 Grad: 0.1424 LR: 0.00081  
Epoch: [7][1000/2250] Elapsed 10

Epoch 7 - avg_train_loss: 0.4341  avg_val_loss: 0.4416  time: 1407s
Epoch 7 - Accuracy: 0.80953125
Epoch 7 - Save Best Score: 0.8095 Model


Epoch: [8][0/2250] Elapsed 0m 3s (remain 132m 15s) Loss avg.: 0.4399 Grad: 0.1614 LR: 0.00075  
Epoch: [8][100/2250] Elapsed 1m 4s (remain 22m 45s) Loss avg.: 0.4315 Grad: 0.1740 LR: 0.00075  
Epoch: [8][200/2250] Elapsed 2m 4s (remain 21m 12s) Loss avg.: 0.4314 Grad: 0.1620 LR: 0.00075  
Epoch: [8][300/2250] Elapsed 3m 5s (remain 19m 59s) Loss avg.: 0.4314 Grad: 0.1585 LR: 0.00075  
Epoch: [8][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4308 Grad: 0.1557 LR: 0.00075  
Epoch: [8][500/2250] Elapsed 5m 6s (remain 17m 48s) Loss avg.: 0.4308 Grad: 0.1580 LR: 0.00075  
Epoch: [8][600/2250] Elapsed 6m 6s (remain 16m 45s) Loss avg.: 0.4307 Grad: 0.1555 LR: 0.00075  
Epoch: [8][700/2250] Elapsed 7m 7s (remain 15m 43s) Loss avg.: 0.4311 Grad: 0.1391 LR: 0.00075  
Epoch: [8][800/2250] Elapsed 8m 7s (remain 14m 41s) Loss avg.: 0.4311 Grad: 0.1644 LR: 0.00075  
Epoch: [8][900/2250] Elapsed 9m 8s (remain 13m 40s) Loss avg.: 0.4310 Grad: 0.1518 LR: 0.00075  
Epoch: [8][1000/2250] Elapsed 1

Epoch 8 - avg_train_loss: 0.4320  avg_val_loss: 0.4417  time: 1410s
Epoch 8 - Accuracy: 0.80963125
Epoch 8 - Save Best Score: 0.8096 Model


Epoch: [9][0/2250] Elapsed 0m 3s (remain 140m 37s) Loss avg.: 0.4363 Grad: 0.1802 LR: 0.00069  
Epoch: [9][100/2250] Elapsed 1m 4s (remain 22m 47s) Loss avg.: 0.4250 Grad: 0.1561 LR: 0.00069  
Epoch: [9][200/2250] Elapsed 2m 4s (remain 21m 12s) Loss avg.: 0.4279 Grad: 0.1457 LR: 0.00069  
Epoch: [9][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4280 Grad: 0.1523 LR: 0.00069  
Epoch: [9][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4280 Grad: 0.1748 LR: 0.00069  
Epoch: [9][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4284 Grad: 0.1475 LR: 0.00069  
Epoch: [9][600/2250] Elapsed 6m 7s (remain 16m 47s) Loss avg.: 0.4286 Grad: 0.1777 LR: 0.00069  
Epoch: [9][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4286 Grad: 0.1616 LR: 0.00069  
Epoch: [9][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4286 Grad: 0.1653 LR: 0.00069  
Epoch: [9][900/2250] Elapsed 9m 8s (remain 13m 40s) Loss avg.: 0.4288 Grad: 0.1566 LR: 0.00069  
Epoch: [9][1000/2250] Elapsed 10

Epoch 9 - avg_train_loss: 0.4299  avg_val_loss: 0.4412  time: 1410s
Epoch 9 - Accuracy: 0.8096275


Epoch: [10][0/2250] Elapsed 0m 3s (remain 135m 13s) Loss avg.: 0.4543 Grad: 0.1587 LR: 0.00062  
Epoch: [10][100/2250] Elapsed 1m 4s (remain 22m 45s) Loss avg.: 0.4267 Grad: 0.1494 LR: 0.00062  
Epoch: [10][200/2250] Elapsed 2m 4s (remain 21m 10s) Loss avg.: 0.4262 Grad: 0.1460 LR: 0.00062  
Epoch: [10][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4258 Grad: 0.1660 LR: 0.00062  
Epoch: [10][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4258 Grad: 0.1594 LR: 0.00062  
Epoch: [10][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4261 Grad: 0.1465 LR: 0.00062  
Epoch: [10][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4264 Grad: 0.1535 LR: 0.00062  
Epoch: [10][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4266 Grad: 0.1593 LR: 0.00062  
Epoch: [10][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4267 Grad: 0.1575 LR: 0.00062  
Epoch: [10][900/2250] Elapsed 9m 8s (remain 13m 41s) Loss avg.: 0.4270 Grad: 0.1674 LR: 0.00062  
Epoch: [10][1000/2250]

Epoch 10 - avg_train_loss: 0.4278  avg_val_loss: 0.4393  time: 1410s
Epoch 10 - Accuracy: 0.810645
Epoch 10 - Save Best Score: 0.8106 Model


Epoch: [11][0/2250] Elapsed 0m 3s (remain 142m 15s) Loss avg.: 0.4573 Grad: 0.1619 LR: 0.00055  
Epoch: [11][100/2250] Elapsed 1m 4s (remain 22m 46s) Loss avg.: 0.4240 Grad: 0.1554 LR: 0.00055  
Epoch: [11][200/2250] Elapsed 2m 4s (remain 21m 11s) Loss avg.: 0.4243 Grad: 0.1543 LR: 0.00055  
Epoch: [11][300/2250] Elapsed 3m 5s (remain 19m 59s) Loss avg.: 0.4247 Grad: 0.1650 LR: 0.00055  
Epoch: [11][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4250 Grad: 0.1691 LR: 0.00055  
Epoch: [11][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4249 Grad: 0.1957 LR: 0.00055  
Epoch: [11][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4251 Grad: 0.1690 LR: 0.00055  
Epoch: [11][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4253 Grad: 0.1719 LR: 0.00055  
Epoch: [11][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4254 Grad: 0.1623 LR: 0.00055  
Epoch: [11][900/2250] Elapsed 9m 8s (remain 13m 40s) Loss avg.: 0.4254 Grad: 0.1625 LR: 0.00055  
Epoch: [11][1000/2250

Epoch 11 - avg_train_loss: 0.4258  avg_val_loss: 0.4390  time: 1409s
Epoch 11 - Accuracy: 0.81070125
Epoch 11 - Save Best Score: 0.8107 Model


Epoch: [12][0/2250] Elapsed 0m 3s (remain 135m 57s) Loss avg.: 0.4294 Grad: 0.1625 LR: 0.00048  
Epoch: [12][100/2250] Elapsed 1m 4s (remain 22m 42s) Loss avg.: 0.4213 Grad: 0.1572 LR: 0.00048  
Epoch: [12][200/2250] Elapsed 2m 4s (remain 21m 8s) Loss avg.: 0.4215 Grad: 0.1787 LR: 0.00048  
Epoch: [12][300/2250] Elapsed 3m 4s (remain 19m 57s) Loss avg.: 0.4220 Grad: 0.1671 LR: 0.00048  
Epoch: [12][400/2250] Elapsed 4m 5s (remain 18m 52s) Loss avg.: 0.4221 Grad: 0.1823 LR: 0.00048  
Epoch: [12][500/2250] Elapsed 5m 5s (remain 17m 48s) Loss avg.: 0.4220 Grad: 0.1581 LR: 0.00048  
Epoch: [12][600/2250] Elapsed 6m 6s (remain 16m 45s) Loss avg.: 0.4221 Grad: 0.1668 LR: 0.00048  
Epoch: [12][700/2250] Elapsed 7m 6s (remain 15m 43s) Loss avg.: 0.4224 Grad: 0.1614 LR: 0.00048  
Epoch: [12][800/2250] Elapsed 8m 7s (remain 14m 41s) Loss avg.: 0.4225 Grad: 0.1697 LR: 0.00048  
Epoch: [12][900/2250] Elapsed 9m 7s (remain 13m 40s) Loss avg.: 0.4225 Grad: 0.1722 LR: 0.00048  
Epoch: [12][1000/2250]

Epoch 12 - avg_train_loss: 0.4237  avg_val_loss: 0.4382  time: 1410s
Epoch 12 - Accuracy: 0.8111075
Epoch 12 - Save Best Score: 0.8111 Model


Epoch: [13][0/2250] Elapsed 0m 3s (remain 142m 30s) Loss avg.: 0.4400 Grad: 0.1736 LR: 0.00041  
Epoch: [13][100/2250] Elapsed 1m 4s (remain 22m 48s) Loss avg.: 0.4174 Grad: 0.1667 LR: 0.00041  
Epoch: [13][200/2250] Elapsed 2m 4s (remain 21m 12s) Loss avg.: 0.4185 Grad: 0.1681 LR: 0.00041  
Epoch: [13][300/2250] Elapsed 3m 5s (remain 19m 59s) Loss avg.: 0.4197 Grad: 0.1767 LR: 0.00041  
Epoch: [13][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4197 Grad: 0.1793 LR: 0.00041  
Epoch: [13][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4199 Grad: 0.1799 LR: 0.00041  
Epoch: [13][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4201 Grad: 0.1741 LR: 0.00041  
Epoch: [13][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4202 Grad: 0.1672 LR: 0.00041  
Epoch: [13][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4201 Grad: 0.1712 LR: 0.00041  
Epoch: [13][900/2250] Elapsed 9m 8s (remain 13m 40s) Loss avg.: 0.4204 Grad: 0.1795 LR: 0.00041  
Epoch: [13][1000/2250

Epoch 13 - avg_train_loss: 0.4216  avg_val_loss: 0.4380  time: 1409s
Epoch 13 - Accuracy: 0.81130625
Epoch 13 - Save Best Score: 0.8113 Model


Epoch: [14][0/2250] Elapsed 0m 3s (remain 134m 39s) Loss avg.: 0.4220 Grad: 0.1607 LR: 0.00035  
Epoch: [14][100/2250] Elapsed 1m 4s (remain 22m 42s) Loss avg.: 0.4173 Grad: 0.1694 LR: 0.00035  
Epoch: [14][200/2250] Elapsed 2m 4s (remain 21m 8s) Loss avg.: 0.4174 Grad: 0.1780 LR: 0.00035  
Epoch: [14][300/2250] Elapsed 3m 4s (remain 19m 57s) Loss avg.: 0.4173 Grad: 0.1807 LR: 0.00035  
Epoch: [14][400/2250] Elapsed 4m 5s (remain 18m 52s) Loss avg.: 0.4179 Grad: 0.1686 LR: 0.00035  
Epoch: [14][500/2250] Elapsed 5m 6s (remain 17m 49s) Loss avg.: 0.4179 Grad: 0.1700 LR: 0.00035  
Epoch: [14][600/2250] Elapsed 6m 6s (remain 16m 46s) Loss avg.: 0.4183 Grad: 0.1770 LR: 0.00035  
Epoch: [14][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4185 Grad: 0.1771 LR: 0.00035  
Epoch: [14][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4185 Grad: 0.1762 LR: 0.00035  
Epoch: [14][900/2250] Elapsed 9m 8s (remain 13m 40s) Loss avg.: 0.4187 Grad: 0.1668 LR: 0.00035  
Epoch: [14][1000/2250]

Epoch 14 - avg_train_loss: 0.4194  avg_val_loss: 0.4374  time: 1411s
Epoch 14 - Accuracy: 0.8117925
Epoch 14 - Save Best Score: 0.8118 Model


Epoch: [15][0/2250] Elapsed 0m 3s (remain 145m 16s) Loss avg.: 0.4403 Grad: 0.1854 LR: 0.00029  
Epoch: [15][100/2250] Elapsed 1m 4s (remain 22m 56s) Loss avg.: 0.4160 Grad: 0.1824 LR: 0.00029  
Epoch: [15][200/2250] Elapsed 2m 4s (remain 21m 14s) Loss avg.: 0.4152 Grad: 0.1994 LR: 0.00029  
Epoch: [15][300/2250] Elapsed 3m 5s (remain 19m 59s) Loss avg.: 0.4157 Grad: 0.1893 LR: 0.00029  
Epoch: [15][400/2250] Elapsed 4m 5s (remain 18m 52s) Loss avg.: 0.4155 Grad: 0.1712 LR: 0.00029  
Epoch: [15][500/2250] Elapsed 5m 5s (remain 17m 47s) Loss avg.: 0.4154 Grad: 0.1889 LR: 0.00029  
Epoch: [15][600/2250] Elapsed 6m 6s (remain 16m 44s) Loss avg.: 0.4156 Grad: 0.2044 LR: 0.00029  
Epoch: [15][700/2250] Elapsed 7m 6s (remain 15m 42s) Loss avg.: 0.4156 Grad: 0.1835 LR: 0.00029  
Epoch: [15][800/2250] Elapsed 8m 6s (remain 14m 40s) Loss avg.: 0.4159 Grad: 0.1901 LR: 0.00029  
Epoch: [15][900/2250] Elapsed 9m 7s (remain 13m 39s) Loss avg.: 0.4161 Grad: 0.1768 LR: 0.00029  
Epoch: [15][1000/2250

Epoch 15 - avg_train_loss: 0.4172  avg_val_loss: 0.4374  time: 1406s
Epoch 15 - Accuracy: 0.81164625


Epoch: [16][0/2250] Elapsed 0m 3s (remain 125m 44s) Loss avg.: 0.4204 Grad: 0.1802 LR: 0.00023  
Epoch: [16][100/2250] Elapsed 1m 3s (remain 22m 38s) Loss avg.: 0.4126 Grad: 0.1855 LR: 0.00023  
Epoch: [16][200/2250] Elapsed 2m 4s (remain 21m 6s) Loss avg.: 0.4132 Grad: 0.1893 LR: 0.00023  
Epoch: [16][300/2250] Elapsed 3m 4s (remain 19m 55s) Loss avg.: 0.4132 Grad: 0.1902 LR: 0.00023  
Epoch: [16][400/2250] Elapsed 4m 4s (remain 18m 49s) Loss avg.: 0.4138 Grad: 0.1866 LR: 0.00023  
Epoch: [16][500/2250] Elapsed 5m 5s (remain 17m 46s) Loss avg.: 0.4137 Grad: 0.1924 LR: 0.00023  
Epoch: [16][600/2250] Elapsed 6m 6s (remain 16m 44s) Loss avg.: 0.4140 Grad: 0.1952 LR: 0.00023  
Epoch: [16][700/2250] Elapsed 7m 6s (remain 15m 43s) Loss avg.: 0.4140 Grad: 0.1857 LR: 0.00023  
Epoch: [16][800/2250] Elapsed 8m 7s (remain 14m 41s) Loss avg.: 0.4143 Grad: 0.1867 LR: 0.00023  
Epoch: [16][900/2250] Elapsed 9m 7s (remain 13m 40s) Loss avg.: 0.4142 Grad: 0.1990 LR: 0.00023  
Epoch: [16][1000/2250]

Epoch 16 - avg_train_loss: 0.4152  avg_val_loss: 0.4375  time: 1410s
Epoch 16 - Accuracy: 0.81193875
Epoch 16 - Save Best Score: 0.8119 Model


Epoch: [17][0/2250] Elapsed 0m 3s (remain 141m 38s) Loss avg.: 0.4126 Grad: 0.1877 LR: 0.00019  
Epoch: [17][100/2250] Elapsed 1m 4s (remain 22m 49s) Loss avg.: 0.4119 Grad: 0.1986 LR: 0.00019  
Epoch: [17][200/2250] Elapsed 2m 4s (remain 21m 11s) Loss avg.: 0.4117 Grad: 0.1940 LR: 0.00019  
Epoch: [17][300/2250] Elapsed 3m 5s (remain 19m 59s) Loss avg.: 0.4122 Grad: 0.2035 LR: 0.00019  
Epoch: [17][400/2250] Elapsed 4m 5s (remain 18m 52s) Loss avg.: 0.4124 Grad: 0.2065 LR: 0.00019  
Epoch: [17][500/2250] Elapsed 5m 5s (remain 17m 47s) Loss avg.: 0.4124 Grad: 0.2161 LR: 0.00019  
Epoch: [17][600/2250] Elapsed 6m 6s (remain 16m 44s) Loss avg.: 0.4121 Grad: 0.1916 LR: 0.00019  
Epoch: [17][700/2250] Elapsed 7m 6s (remain 15m 42s) Loss avg.: 0.4124 Grad: 0.2108 LR: 0.00019  
Epoch: [17][800/2250] Elapsed 8m 6s (remain 14m 40s) Loss avg.: 0.4122 Grad: 0.2012 LR: 0.00019  
Epoch: [17][900/2250] Elapsed 9m 7s (remain 13m 39s) Loss avg.: 0.4122 Grad: 0.2097 LR: 0.00019  
Epoch: [17][1000/2250

Epoch 17 - avg_train_loss: 0.4131  avg_val_loss: 0.4375  time: 1409s
Epoch 17 - Accuracy: 0.81197125
Epoch 17 - Save Best Score: 0.8120 Model


Epoch: [18][0/2250] Elapsed 0m 3s (remain 141m 55s) Loss avg.: 0.4237 Grad: 0.2053 LR: 0.00015  
Epoch: [18][100/2250] Elapsed 1m 4s (remain 22m 45s) Loss avg.: 0.4104 Grad: 0.2027 LR: 0.00015  
Epoch: [18][200/2250] Elapsed 2m 4s (remain 21m 11s) Loss avg.: 0.4103 Grad: 0.2164 LR: 0.00015  
Epoch: [18][300/2250] Elapsed 3m 5s (remain 19m 58s) Loss avg.: 0.4110 Grad: 0.2098 LR: 0.00015  
Epoch: [18][400/2250] Elapsed 4m 5s (remain 18m 51s) Loss avg.: 0.4118 Grad: 0.2072 LR: 0.00015  
Epoch: [18][500/2250] Elapsed 5m 5s (remain 17m 47s) Loss avg.: 0.4115 Grad: 0.2193 LR: 0.00015  
Epoch: [18][600/2250] Elapsed 6m 6s (remain 16m 44s) Loss avg.: 0.4114 Grad: 0.2124 LR: 0.00015  
Epoch: [18][700/2250] Elapsed 7m 6s (remain 15m 42s) Loss avg.: 0.4114 Grad: 0.2026 LR: 0.00015  
Epoch: [18][800/2250] Elapsed 8m 6s (remain 14m 40s) Loss avg.: 0.4111 Grad: 0.2061 LR: 0.00015  
Epoch: [18][900/2250] Elapsed 9m 7s (remain 13m 39s) Loss avg.: 0.4111 Grad: 0.2137 LR: 0.00015  
Epoch: [18][1000/2250

Epoch 18 - avg_train_loss: 0.4113  avg_val_loss: 0.4379  time: 1409s
Epoch 18 - Accuracy: 0.8118425


Epoch: [19][0/2250] Elapsed 0m 3s (remain 134m 3s) Loss avg.: 0.3888 Grad: 0.2091 LR: 0.00012  
Epoch: [19][100/2250] Elapsed 1m 4s (remain 22m 47s) Loss avg.: 0.4071 Grad: 0.2123 LR: 0.00012  
Epoch: [19][200/2250] Elapsed 2m 4s (remain 21m 12s) Loss avg.: 0.4090 Grad: 0.2131 LR: 0.00012  
Epoch: [19][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4085 Grad: 0.2022 LR: 0.00012  
Epoch: [19][400/2250] Elapsed 4m 6s (remain 18m 54s) Loss avg.: 0.4084 Grad: 0.2123 LR: 0.00012  
Epoch: [19][500/2250] Elapsed 5m 6s (remain 17m 50s) Loss avg.: 0.4084 Grad: 0.2149 LR: 0.00012  
Epoch: [19][600/2250] Elapsed 6m 7s (remain 16m 47s) Loss avg.: 0.4084 Grad: 0.2196 LR: 0.00012  
Epoch: [19][700/2250] Elapsed 7m 7s (remain 15m 44s) Loss avg.: 0.4084 Grad: 0.2186 LR: 0.00012  
Epoch: [19][800/2250] Elapsed 8m 8s (remain 14m 43s) Loss avg.: 0.4085 Grad: 0.2144 LR: 0.00012  
Epoch: [19][900/2250] Elapsed 9m 8s (remain 13m 41s) Loss avg.: 0.4087 Grad: 0.2129 LR: 0.00012  
Epoch: [19][1000/2250] 

Epoch 19 - avg_train_loss: 0.4098  avg_val_loss: 0.4380  time: 1411s
Epoch 19 - Accuracy: 0.811815


Epoch: [20][0/2250] Elapsed 0m 3s (remain 143m 47s) Loss avg.: 0.4043 Grad: 0.2078 LR: 0.00011  
Epoch: [20][100/2250] Elapsed 1m 4s (remain 22m 55s) Loss avg.: 0.4083 Grad: 0.2299 LR: 0.00011  
Epoch: [20][200/2250] Elapsed 2m 5s (remain 21m 14s) Loss avg.: 0.4078 Grad: 0.2153 LR: 0.00011  
Epoch: [20][300/2250] Elapsed 3m 5s (remain 20m 0s) Loss avg.: 0.4076 Grad: 0.2206 LR: 0.00011  
Epoch: [20][400/2250] Elapsed 4m 5s (remain 18m 53s) Loss avg.: 0.4079 Grad: 0.2227 LR: 0.00011  
Epoch: [20][500/2250] Elapsed 5m 6s (remain 17m 48s) Loss avg.: 0.4080 Grad: 0.2220 LR: 0.00011  
Epoch: [20][600/2250] Elapsed 6m 6s (remain 16m 45s) Loss avg.: 0.4083 Grad: 0.2166 LR: 0.00011  
Epoch: [20][700/2250] Elapsed 7m 7s (remain 15m 43s) Loss avg.: 0.4081 Grad: 0.2197 LR: 0.00011  
Epoch: [20][800/2250] Elapsed 8m 7s (remain 14m 42s) Loss avg.: 0.4081 Grad: 0.2214 LR: 0.00011  
Epoch: [20][900/2250] Elapsed 9m 7s (remain 13m 40s) Loss avg.: 0.4079 Grad: 0.2233 LR: 0.00011  
Epoch: [20][1000/2250]

Epoch 20 - avg_train_loss: 0.4087  avg_val_loss: 0.4388  time: 1409s
Epoch 20 - Accuracy: 0.8116875
Epoch 20 - Save final model
Score: 0.81197
