In [None]:
import copy
import gc
import os
import pickle
import random
import sys

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# deep learning related
import torch
from torch.utils.data import DataLoader

sys.path.append("./")
from leap_feature import LeapData, Feature
from leap_dataset import LeepDataset
from leap_network import LeapNetwork
from leap_graph import graph_plot
from config import ExpConfig

# Name of the data-split folder, read from the experiment configuration; it is
# used below as a sub-directory when locating train_filenames.pkl.
folder_name = ExpConfig().data_splitfolder_name


# seed related
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and every CUDA device), exports ``PYTHONHASHSEED`` and puts cuDNN
    into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Seed all GPUs, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# Definition of the settings class
class Config:
    """Hyper-parameters and runtime settings for this training run."""

    def __init__(self):
        # Run on the GPU when one is visible, otherwise fall back to the CPU.
        if torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)

        # Data / schedule sizes.
        self.n_dataset = 150
        self.batch_size = 1024
        self.num_epochs = 100

        # AdamW settings.
        self.learning_rate = 5e-4
        self.weight_decay = 1e-2


cfg = Config()


# Per-target weights used by the loss and the R2 helpers below.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("./weight.pkl", "rb") as weight_file:
    # The context manager closes the handle, which the bare open() never did.
    weight = pickle.load(weight_file)

In [None]:
# Load the data

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
train_filenames_path = os.path.join(
    "../../standarised_preprocess/", folder_name, "train_filenames.pkl"
)
with open(train_filenames_path, "rb") as fh:
    train_filenames = pickle.load(fh)
data_path = ["../../inputs/float32_numpy/{}".format(f) for f in train_filenames]

# Read every shard, then stack them along the row axis. Iterating the list
# directly lets tqdm know the total and show real progress.
load_datas = [np.load(path) for path in tqdm(data_path)]

load_data = np.concatenate(load_datas, axis=0)
del load_datas
gc.collect()

train_index, valid_index = train_test_split(
    range(len(load_data)), test_size=0.2, random_state=42
)
# Independent copy of the validation rows (list indexing copies); the training
# cell reads columns 120.. from it for the rule-based q0002 replacement.
tmp_valid_load_data = load_data[valid_index, :]
with open("../../standarised_preprocess/scaler.pkl", "rb") as fh:
    scaler = pickle.load(fh)
# NOTE(review): train=False is passed for the training split as well; presumably
# the flag controls something other than the split role - confirm in LeapData.
train_load_data = LeapData(load_data[train_index, :], train=False, scaler=scaler)
valid_load_data = LeapData(load_data[valid_index, :], train=False, scaler=scaler)

train_load_data = LeepDataset(train_load_data)
valid_load_data = LeepDataset(valid_load_data)

train_loader = DataLoader(
    train_load_data,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=14,
    pin_memory=True,
)
# Validation order must stay fixed: the training cell aligns predictions with
# tmp_valid_load_data by row position.
valid_loader = DataLoader(
    valid_load_data,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=14,
    pin_memory=True,
)

In [None]:
# Model definition
model = LeapNetwork()
model.to(cfg.device)

# Per-column statistics of the loaded scaler; the R2 helpers below use them to
# undo the standardisation before scoring.
means = scaler.mean_
stds = scaler.scale_
# Statistics for columns 556 onward (presumably the target columns - confirm).
target_means = means[556:]
# NOTE(review): predictions are rescaled with sqrt(var_) while labels are
# rescaled with scale_ (see custom_r2 / one_column_r2). If `scaler` is a
# scikit-learn StandardScaler these differ only for zero-variance columns,
# where scale_ is 1 and sqrt(var_) is 0 - confirm this asymmetry is intended.
target_stds = np.sqrt(scaler.var_[556:])

In [None]:
# Optimisation setup: metrics, loss, optimizer and scheduler
from sklearn.metrics import r2_score as r2_metrics
from transformers import get_cosine_schedule_with_warmup


def flatten_(t):
    """Flatten a 3-D array of stacked targets into one row per sample.

    The first 6 entries along axis 1 are unrolled in full (6 * 60 values per
    row). For the entries from index 6 onward only position 0 along axis 2 is
    kept (8 values per row). The two parts are joined column-wise.

    Args:
        t: NumPy array indexed as ``t[sample, channel, level]``.

    Returns:
        2-D NumPy array with ``6 * 60 + 8`` columns.
    """
    n_profiles, n_levels, n_scalars = 6, 60, 8
    profile_part = t[:, :n_profiles, :].reshape(-1, n_profiles * n_levels)
    scalar_part = t[:, n_profiles:, 0].reshape(-1, n_scalars)
    return np.concatenate([profile_part, scalar_part], axis=1)


def custom_r2(y_pred, y_true):
    """Weighted R2-style score over a whole batch, pooled across all columns.

    Both arrays are first mapped back from standardised units (labels with
    ``stds``/``means`` from column 556 on, predictions with
    ``target_stds``/``target_means``) and multiplied by the module-level
    ``weight``. The total sum of squares uses one global mean over every
    element, not a per-column mean.

    Args:
        y_pred: Standardised predictions, one row per sample.
        y_true: Standardised labels with the same layout.

    Returns:
        ``1 - SS_res / SS_tot`` as a NumPy scalar.
    """
    label_scale = stds[556:].reshape(1, -1)
    label_shift = means[556:].reshape(1, -1)
    pred_scale = target_stds.reshape(1, -1)
    pred_shift = target_means.reshape(1, -1)

    weighted_true = (y_true * label_scale + label_shift) * weight
    weighted_pred = (y_pred * pred_scale + pred_shift) * weight

    residual_ss = np.sum((weighted_true - weighted_pred) ** 2)
    total_ss = np.sum((weighted_true - np.mean(weighted_true)) ** 2)
    return 1 - residual_ss / total_ss


def one_column_r2(y_pred, y_true, i, weight):
    """R2 of a single target column after un-standardising and weighting.

    Args:
        y_pred: Standardised predictions for target column ``i``.
        y_true: Standardised labels for target column ``i``.
        i: Index of the column within the targets (offset by 556 in the scaler).
        weight: Multiplier applied to both labels and predictions.

    Returns:
        The value of ``r2_metrics(labels, predictions)``.
    """
    restored_true = y_true * stds[556 + i] + means[556 + i]
    restored_pred = y_pred * target_stds[i] + target_means[i]

    return r2_metrics(restored_true * weight, restored_pred * weight)


def ignore_one_column_r2(y_pred, y_true, i, weight, ptend_0002=False):
    """R2 of one target column, optionally using the rule-based q0002 value.

    Args:
        y_pred: Standardised predictions for column ``i``; when ``ptend_0002``
            is true it is instead converted with ``-y_pred * weight / 1200.0``
            and no un-standardisation is applied.
        y_true: Standardised labels for column ``i``.
        i: Index of the column within the targets (offset by 556 in the scaler).
        weight: Multiplier applied to labels and predictions.
        ptend_0002: Selects the rule-based conversion of ``y_pred``.

    Returns:
        The value of ``r2_metrics(labels, predictions)``.
    """
    # Labels are always mapped back from standardised units and weighted.
    labels = (y_true * stds[556 + i] + means[556 + i]) * weight

    if not ptend_0002:
        restored = y_pred * target_stds[i] + target_means[i]
        predictions = restored * weight
    else:
        predictions = -y_pred * weight / 1200.0

    return r2_metrics(labels, predictions)


class R2Score(torch.nn.Module):
    """Module wrapper around ``custom_r2`` so it can be used like a criterion.

    Calling an instance with ``(y_pred, y_true)`` returns
    ``custom_r2(y_pred, y_true)``.
    """

    def __init__(self):
        # Without this the instance lacks Module's internal state, so repr(),
        # .parameters() and .to() raise AttributeError.
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return ``custom_r2(y_pred, y_true)``.

        ``torch.nn.Module.__call__`` dispatches here, so no ``__call__``
        override is needed.
        """
        return custom_r2(y_pred, y_true)


class CustomeWeightedLoss(torch.nn.Module):
    """MSE loss that zeroes the prediction for channel 2, levels 0-27.

    Buffers:
        mask: ``(1, 14, 60)`` tensor of ones, with zeros at ``[:, 2, :28]``.
        weight: ``(1, 14, 60)`` 0/1 tensor; 0 where the supplied weight is
            exactly 0 for the first 6 channels, 1 elsewhere. It is built but
            not applied in ``forward``.

    Args:
        weight: 2-D NumPy array; all columns except the last 8 must reshape to
            ``(1, 6, 60)``.
        device: Device for the buffers. Defaults to CUDA when available and
            CPU otherwise, instead of requiring CUDA unconditionally.
    """

    def __init__(self, weight, device=None) -> None:
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.criterion = torch.nn.MSELoss()

        mask = torch.ones(1, 14, 60)
        mask[:, 2, :28] = 0.0
        # Buffers follow the module on .to()/.cuda(); non-persistent because
        # they are fully determined by the constructor arguments.
        self.register_buffer("mask", mask.to(device), persistent=False)

        profile_weight = torch.tensor(weight[:, :-8].astype(np.float32))
        profile_weight = profile_weight.reshape(1, 6, 60)
        full_weight = torch.cat([profile_weight, torch.ones(1, 8, 60)], dim=1)
        # Binarise: 0 where the weight is exactly 0, otherwise 1.
        binary_weight = (full_weight != 0).to(torch.float32).reshape(1, 14, 60)
        self.register_buffer("weight", binary_weight.to(device), persistent=False)

    def forward(self, y_pred, y_true):
        """Return the mean squared error between masked ``y_pred`` and ``y_true``.

        NOTE(review): only the prediction is masked, so masked positions still
        contribute ``y_true ** 2`` to the loss value (with zero gradient).
        """
        y_pred = y_pred * self.mask
        # y_pred = y_pred * self.weight

        return self.criterion(y_pred, y_true)


# Training loss (masked MSE) and the batch-level R2 used for progress reporting.
criterion = CustomeWeightedLoss(weight=weight)
eval_criterion = R2Score()
optimizer = torch.optim.AdamW(
    model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay
)
# The training loop calls scheduler.step() once per batch, so both step counts
# below are measured in batches: 10 warmup batches, then cosine decay over the
# remaining batches of all epochs.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=cfg.num_epochs * len(train_loader),
)

In [None]:
target_columns = Feature.target

best_r2 = -np.inf
best_model = None

# All-ones row shaped like the weights without their last 8 columns, on the
# training device. Not referenced in this cell; kept for other cells.
criterion_weight = torch.ones_like(
    torch.from_numpy(weight[0, :-8].astype(np.float32)).to(torch.float32)
).reshape(1, -1)
criterion_weight = criterion_weight.to(cfg.device).to(torch.float32)

train_r2_scores = {}
valid_r2_scores = {}

# Batch keys fed to the network, in positional-argument order.
_INPUT_KEYS = ("g0", "g1", "g2", "g3", "g4", "g5", "g6", "g7", "g8", "g_else")

# Target columns whose prediction is replaced by a rule-based value.
q0002_cols = [f"ptend_q0002_{level}" for level in range(28)]


def _prepare_batch(data, device):
    """Move one loader batch to ``device`` and assemble inputs and target.

    Args:
        data: Mapping with the keys in ``_INPUT_KEYS`` plus ``"y0"``..``"y6"``.
        device: Device the tensors are moved to.

    Returns:
        ``(inputs, y)``: ``inputs`` is the list of float32 network inputs in
        ``_INPUT_KEYS`` order; ``y`` is the float32 target built by stacking
        ``y0``..``y5`` along dim 1 followed by ``y6`` repeated 60 times along
        a new last dim.
    """
    inputs = [data[key].to(device).to(torch.float32) for key in _INPUT_KEYS]
    profile_targets = [data[f"y{k}"].to(device).unsqueeze(1) for k in range(6)]
    scalar_targets = data["y6"].to(device).unsqueeze(-1).repeat(1, 1, 60)
    y = torch.cat(profile_targets + [scalar_targets], dim=1).to(torch.float32)
    return inputs, y


for epoch in range(cfg.num_epochs):
    # ---- training ---------------------------------------------------------
    model.train()

    ignore_index = []
    valid_preds = []
    valid_labels = []

    tq = tqdm(total=len(train_loader))
    tq.set_description(f"Epoch {epoch}")

    current_loss = 0.0
    r2_score = 0.0

    for i, data in enumerate(train_loader):
        inputs, y = _prepare_batch(data, cfg.device)

        optimizer.zero_grad()

        unet_output = model(*inputs)

        loss = criterion(unet_output, y)
        loss.backward()
        optimizer.step()
        # The schedule is defined in batches, so it advances every step.
        scheduler.step()

        current_loss += loss.item()

        pred = flatten_(unet_output.detach().cpu().numpy())
        label = flatten_(y.detach().cpu().numpy())
        r2_score += eval_criterion(pred, label).item()

        tq.set_description(
            f"Epoch {epoch} loss: {current_loss / (i + 1):.4f}, r2: {r2_score / (i + 1):.4f}"
        )

        tq.update(1)
    tq.close()

    # ---- validation -------------------------------------------------------
    model.eval()

    current_loss = 0.0
    r2_score = 0.0

    tq = tqdm(total=len(valid_loader))

    # g1 inputs of every validation batch. Not used in this cell; kept in case
    # later cells inspect it.
    group2_stack = []

    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            inputs, y = _prepare_batch(data, cfg.device)

            unet_output = model(*inputs)
            loss = criterion(unet_output, y)

            # Flatten once and reuse for both the running R2 and the
            # epoch-level per-column scores.
            pred = flatten_(unet_output.detach().cpu().numpy())
            label = flatten_(y.detach().cpu().numpy())

            valid_preds.append(pred)
            valid_labels.append(label)
            group2_stack.append(inputs[1].detach().cpu().numpy())

            current_loss += loss.item()
            r2_score += eval_criterion(pred, label).item()

            tq.set_description(
                f"Epoch {epoch} loss: {current_loss / (i + 1):.4f}, r2: {r2_score / (i + 1):.4f}"
            )

            tq.update(1)
    tq.close()

    valid_preds = np.concatenate(valid_preds, axis=0)
    valid_labels = np.concatenate(valid_labels, axis=0)
    group2_stack = np.concatenate(group2_stack, axis=0)

    # ---- per-column scores from the raw network output --------------------
    for i, column in enumerate(target_columns):
        valid_r2_scores[column] = one_column_r2(
            valid_preds[:, i],
            valid_labels[:, i],
            i,
            weight[0, i],
        )

    # First pass: score the predictions and collect the indices of columns
    # whose R2 is negative.
    mean_score = np.mean(list(valid_r2_scores.values()))
    for i, (key, value) in enumerate(valid_r2_scores.items()):
        if value < 0:
            print(f"{i}, {key}, {value}")
            ignore_index.append(i)

    # ---- second pass: rule-based replacement for the q0002 columns --------
    # For each q0002 column the prediction is overwritten with the raw
    # validation value at column 120 + k (k-th q0002 column encountered);
    # ignore_one_column_r2 then converts it with -value * weight / 1200.
    q0002_count = 0
    for i, column in enumerate(target_columns):
        is_q0002 = column in q0002_cols
        if is_q0002:
            valid_preds[:, i] = tmp_valid_load_data[
                : len(valid_preds), 120 + q0002_count
            ]
            q0002_count += 1

        valid_r2_scores[column] = ignore_one_column_r2(
            valid_preds[:, i],
            valid_labels[:, i],
            i,
            weight[0, i],
            ptend_0002=is_q0002,
        )

    ignore_mean_score = np.mean(list(valid_r2_scores.values()))
    print(f"mean_score: {mean_score}, ignore_mean_score: {ignore_mean_score}")

    if ignore_mean_score > best_r2:
        best_r2 = ignore_mean_score
        # state_dict() returns references to the live parameters; snapshot
        # them so best_model does not drift as training continues.
        best_model = copy.deepcopy(model.state_dict())
        torch.save(best_model, "./best_model.pth")
        with open("ignore_index.pkl", "wb") as ignore_file:
            pickle.dump(ignore_index, ignore_file)

    graph_plot(epoch, valid_r2_scores, weight[0])

    print("####################\n")