In [1]:
import os
import sys
import random
import pickle
import numpy as np
import polars as pl
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# deep learning related
import torch
from torch.utils.data import DataLoader

sys.path.append("./")
from leap_feature import LeapData, Feature
from leap_dataset import LeepDataset
from leap_network import LeapNetwork
from leap_graph import graph_plot


# ---- Reproducibility --------------------------------------------------------
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all visible CUDA devices), and pins cuDNN to deterministic kernels with
    auto-tuning disabled so repeated runs pick the same algorithms.

    Args:
        seed: integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting
        it here only affects subprocesses launched afterwards, not the hash
        randomisation of the current kernel.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may choose different (non-deterministic) kernels per run.
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [2]:
class Config:
    """Run-time settings shared by the notebook cells.

    In the visible cells only ``batch_size``, ``n_dataset`` and ``device`` are
    read; the other fields appear to be training-time settings.
    """

    def __init__(self):
        self.batch_size = 100  # DataLoader batch size
        self.num_epochs = 100
        self.n_dataset = 80  # number of data_{i}.npy shards
        self.learning_rate = 0.0005
        self.weight_decay = 0.01
        use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")


cfg = Config()

data_dir = "../../inputs/float32_numpy/"

# Paths of the data_{i}.npy shards, shuffled with the seeded global RNG.
# NOTE(review): data_path and weight are not referenced by the visible
# inference cells — kept for parity with the training notebook.
data_path = [
    os.path.join(data_dir, "data_{}.npy".format(i)) for i in range(cfg.n_dataset)
]
data_path = random.sample(data_path, cfg.n_dataset)

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# Context managers close each file handle instead of leaking it.
with open("./weight.pkl", "rb") as fh:
    weight = pickle.load(fh)

# Pre-load the per-column max / min values from the preprocessing step;
# they are used to clip de-standardised predictions.
with open("../../standarised_preprocess/max_value.pkl", "rb") as fh:
    max_value = pickle.load(fh)
with open("../../standarised_preprocess/min_value.pkl", "rb") as fh:
    min_value = pickle.load(fh)

In [3]:
debug = False

# Input CSVs shared by both loading modes.
submission_csv = "../../inputs/sample_submission.csv"
test_csv = "../../inputs/test.csv"

if not debug:
    # Full run: read with polars, then convert to pandas for the later cells.
    sample_submission = pl.read_csv(submission_csv).to_pandas()
    test = pl.read_csv(test_csv).to_pandas()
    weight_nrows = 10
else:
    # Debug run: only the first `chunk_size` rows of each file.
    chunk_size = 1000
    sample_submission = pd.read_csv(submission_csv, nrows=chunk_size)
    test = pd.read_csv(test_csv, nrows=chunk_size)
    weight_nrows = chunk_size

# First row of sample_submission.csv as a Series keyed by column name; its
# values are used as per-column multipliers in the ptend_q0002 override.
weight_sample = pd.read_csv(submission_csv, nrows=weight_nrows).iloc[0]

In [4]:
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("../../standarised_preprocess/scaler.pkl", "rb") as fh:
    scaler = pickle.load(fh)

# Number of target columns (6 profiles x 60 levels + 8 scalars).
n_targets = 368

# The scaler's statistics also cover the target columns (they are sliced from
# column 556 onward for de-standardisation), so zero placeholders are appended
# for the targets to give the feature array the width the scaler expects.
x = test.drop(columns=["sample_id"]).to_numpy()
pad = np.zeros((x.shape[0], n_targets))
x = np.concatenate([x, pad], axis=1)

leap_data = LeapData(x, scaler=scaler)
test_load_data = LeepDataset(leap_data)
test_dataloader = DataLoader(
    test_load_data,
    batch_size=cfg.batch_size,
    shuffle=False,  # keep test-row order so predictions line up with sample_id
    num_workers=4,
    drop_last=False,
)

model = LeapNetwork()
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host;
# the model is moved to cfg.device right after.
model.load_state_dict(torch.load("./best_model.pth", map_location="cpu"))
model.eval()
model.to(cfg.device)
print("ok")

ok


In [5]:
# `tqdm` is already imported in the first cell; no re-import needed here.

# Manual progress bar over the test batches (advanced once per batch).
tq = tqdm(total=len(test_dataloader))
# Prediction buffer: one row per test sample, 368 target columns.
results = np.zeros(shape=(x.shape[0], 368))
# Next row of `results` to fill; advanced by each batch's size.
current_row = 0


def flatten_(t, n_profiles=6):
    """Flatten a (batch, channels, levels) model output into submission rows.

    The first ``n_profiles`` channels are kept at every level (row-major per
    channel); each remaining channel contributes only its value at level
    index 0.

    Args:
        t: NumPy array of shape (batch, channels, levels). For the expected
            model output (batch, 14, 60) the result has 368 columns.
        n_profiles: number of leading channels kept at full vertical
            resolution (default 6).

    Returns:
        Array of shape (batch, n_profiles * levels + channels - n_profiles).
    """
    batch = t.shape[0]
    # Reshape with the explicit batch size so samples never get folded across
    # rows; a level/channel mismatch shows up as a wrong column count instead.
    profiles = t[:, :n_profiles, :].reshape(batch, -1)
    scalars = t[:, n_profiles:, 0].reshape(batch, -1)
    return np.concatenate((profiles, scalars), axis=1)


# Batch keys passed positionally to the model, in order g0..g8, g_else.
input_keys = [f"g{i}" for i in range(9)] + ["g_else"]

with torch.no_grad():
    for data in test_dataloader:
        # Move to the target device and cast to float32 in a single call.
        inputs = [data[key].to(cfg.device, dtype=torch.float32) for key in input_keys]

        # Under no_grad the output carries no graph, so no .detach() is needed.
        unet_output = model(*inputs).cpu().numpy()
        output = flatten_(unet_output)

        pred_length = output.shape[0]
        results[current_row : current_row + pred_length, :] = output
        current_row += pred_length

        tq.update()
    tq.close()

# Every test row must have received a prediction.
assert current_row == results.shape[0], (current_row, results.shape[0])

100%|██████████| 6250/6250 [02:27<00:00, 42.51it/s]


In [6]:
# Scaler statistics from column 556 onward belong to the target columns
# (presumably the scaler was fitted on [inputs | targets] — confirm).
target_means = scaler.mean_[556:]
# Targets with zero variance get std 0, so de-standardising maps them to
# their mean regardless of the model's prediction.
target_stds = np.sqrt(scaler.var_[556:])

# NOTE(review): ignore_index is loaded but not referenced by any visible cell.
with open("./ignore_index.pkl", "rb") as fh:
    ignore_index = pickle.load(fh)

In [7]:
# Undo the standardisation per target column: y = z * std + mean.
results = results * target_stds[np.newaxis, :] + target_means[np.newaxis, :]

# Clip with the per-column min / max values (target columns start at 556).
results = np.clip(results, a_min=min_value[556:], a_max=max_value[556:])

In [8]:
# `results` is in test-row order (the loader does not shuffle), so
# sample_submission must list the same samples in the same order.
assert len(sample_submission) == results.shape[0], "row count mismatch"
assert (
    sample_submission["sample_id"].to_numpy() == test["sample_id"].to_numpy()
).all(), "sample_submission and test rows are not aligned"

# Multiply the predictions into the values already stored in sample_submission
# (used as per-target weights).
# NOTE: in-place — re-running this cell multiplies the predictions in again.
sample_submission.iloc[:, 1:] *= results
sample_submission

Unnamed: 0,sample_id,ptend_t_0,ptend_t_1,ptend_t_2,ptend_t_3,ptend_t_4,ptend_t_5,ptend_t_6,ptend_t_7,ptend_t_8,...,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD
0,test_169651,-0.419316,-0.576058,-0.265361,-0.630269,-0.897386,-1.002050,-1.027672,-1.089551,-1.088531,...,-0.191093,0.161974,0.000000,5.344158,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1,test_524862,-0.284003,-0.476794,-0.342407,-0.669730,-0.947179,-1.039662,-1.040192,-1.075055,-1.058014,...,-0.476176,0.760597,0.000000,5.208979,0.000000,0.059297,0.000000,0.000000,0.000000,0.001331
2,test_634129,-0.319397,-1.052044,-0.640844,-0.605541,-0.768845,-0.917986,-0.975336,-0.979691,-0.888787,...,-0.010123,-0.080900,0.000000,5.719446,0.000000,0.000000,0.000000,0.000000,0.000000,0.005377
3,test_403572,-0.468336,-0.903143,-0.520188,-0.583597,-0.816809,-1.001381,-1.065786,-1.052867,-0.956422,...,-0.135924,-0.212694,0.000000,5.719747,0.000000,0.252216,0.000000,0.000000,0.000000,0.007719
4,test_484578,-0.255026,-0.242345,-0.342255,-0.760192,-1.018077,-1.079235,-1.079577,-1.071151,-1.045887,...,0.949473,-0.039441,0.000000,5.131378,0.000000,0.059952,0.000000,0.000000,0.000000,0.002336
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
624995,test_578220,1.076137,0.652727,1.089709,1.051947,1.334756,1.416589,1.345372,1.255149,1.197816,...,0.146148,0.120034,2.096562,5.477231,0.000000,0.000000,1.950296,1.948255,1.578414,1.012330
624996,test_395695,1.261486,0.543370,0.829838,1.189030,1.177692,1.325159,1.216146,0.945109,1.170962,...,-0.151021,-0.072445,1.544295,5.851239,0.000135,0.070587,1.049664,1.018175,2.310576,2.127312
624997,test_88942,1.312129,0.137116,0.656292,1.182820,1.035582,0.972158,0.822844,0.626005,0.798432,...,-0.163876,-0.119152,1.120017,5.887046,0.000000,0.003514,0.871828,0.880567,1.471584,1.148277
624998,test_79382,1.335894,0.223849,0.883944,0.984375,1.238477,1.144399,0.927326,0.855531,0.896217,...,0.091642,-0.107075,1.553271,5.424863,0.000000,0.000000,1.464764,1.536430,1.222430,0.639848


In [9]:
# For levels 0-27 of ptend_q0002, replace the network output with
# -state_q0002 * weight / 1200.0. This is the tendency that drives the state
# to zero within one 1200-unit step.
# NOTE(review): 1200.0 is presumably the model time step in seconds — confirm.
use_cols = [f"ptend_q0002_{level}" for level in range(28)]

for col in use_cols:
    state_col = col.replace("ptend", "state")
    sample_submission[col] = -test[state_col] * weight_sample[col] / 1200.0

In [10]:
# Convert the pandas submission frame to polars and write it out as CSV.
output_csv = "submission.csv"
sample_submission = pl.from_pandas(sample_submission)
sample_submission.write_csv(output_csv)

In [11]:
# Inspect the overridden ptend_q0002 columns of the (polars) submission frame.
sample_submission.select(use_cols)

ptend_q0002_0,ptend_q0002_1,ptend_q0002_2,ptend_q0002_3,ptend_q0002_4,ptend_q0002_5,ptend_q0002_6,ptend_q0002_7,ptend_q0002_8,ptend_q0002_9,ptend_q0002_10,ptend_q0002_11,ptend_q0002_12,ptend_q0002_13,ptend_q0002_14,ptend_q0002_15,ptend_q0002_16,ptend_q0002_17,ptend_q0002_18,ptend_q0002_19,ptend_q0002_20,ptend_q0002_21,ptend_q0002_22,ptend_q0002_23,ptend_q0002_24,ptend_q0002_25,ptend_q0002_26,ptend_q0002_27
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-1.0207e-37,-0.0,-0.0,-0.0,-0.0,-0.0,-1.9760e-22
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-3.7435e-28,-2.8370e-23,-1.0433e-18,-0.0,-1.9760e-22
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-8.6025e-41,-6.3145e-36,-1.3878e-30,-2.4033e-25,-1.3172e-20,-4.6590e-16,-3.6985e-15,-2.4972e-11
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-2.2201e-41,-1.8663e-36,-5.7886e-32,-0.0,-0.0,-0.0,-2.4746e-16,-1.9241e-12
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-6.4843e-28,-1.0308e-23,-1.0400e-20
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-5.5032e-44,-4.7731e-39,-2.8176e-34,-1.1656e-29,-3.6349e-25,-7.8666e-21,-1.0308e-23,-1.0400e-20
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-1.4334e-41,-1.2442e-36,-7.3528e-32,-3.0805e-27,-1.2008e-22,-4.3715e-18,-4.3138e-18,-8.1202e-15
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-6.4843e-28,-1.6378e-17,-9.1858e-14
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-3.7953e-45,-9.1068e-42,-6.2250e-37,-4.1499e-32,-0.0,-2.2765e-25,-4.6957e-23,-5.3730e-20
-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-2.4342e-36,-9.3591e-32,-5.3684e-27,-4.1423e-22,-5.5680e-20,-2.5991e-16
