In [None]:
import torch

import numpy as np
import pandas as pd
from glob import glob

from model import MLP
from data import scaling, inverse_scaling, preprocess_input_data

import warnings
# NOTE(review): this silences *every* warning for the rest of the session
# (e.g. pandas or torch deprecation notices); consider narrowing it to the
# specific categories that are actually noisy.
warnings.filterwarnings("ignore")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

# Input files read later in this notebook (relative to the working directory):
#   sample_submission.csv
#   final_train.csv
#   final_test.csv

In [None]:
def inference(state_dict: dict, x_test: np.ndarray, x_cols):
    """Evaluate one saved MLP checkpoint on a batch of feature rows.

    A fresh ``MLP`` with ``len(x_cols)`` inputs and 35 outputs is built, loaded
    with ``state_dict``, moved to ``device`` and switched to eval mode before
    the forward pass.

    Args:
        state_dict: weights for ``MLP`` (the ``'model'`` entry of a checkpoint).
        x_test: feature matrix; the visible caller passes a 2-D array of shape
            (n_rows, len(x_cols)). Converted to float32 before the forward pass.
        x_cols: feature column names; only its length is used here.

    Returns:
        The raw model output tensor, on ``device``. It is computed under
        ``torch.no_grad()``, so it carries no autograd graph.
    """
    model = MLP(in_features=len(x_cols), out_features=35)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    inputs = torch.as_tensor(x_test, dtype=torch.float32, device=device)
    # Pure inference: skip autograd bookkeeping instead of building a graph
    # that every caller throws away with .detach().
    with torch.no_grad():
        logits = model(inputs)
    return logits

def mean_stack_logits(fold_num, history_dict, x_test: np.ndarray, x_cols, y_scaler=None):
    """Average the predictions of the first ``fold_num`` checkpoints.

    Args:
        fold_num: number of checkpoints to average; ``history_dict[0]`` through
            ``history_dict[fold_num - 1]`` must exist. Must be at least 1.
        history_dict: indexable collection of state dicts (the notebook passes
            a list), each forwarded to ``inference``.
        x_test: feature matrix passed unchanged to ``inference``.
        x_cols: feature column names passed unchanged to ``inference``.
        y_scaler: optional object with an ``inverse_transform`` method; when
            given, each checkpoint's outputs are inverse-transformed *before*
            averaging.

    Returns:
        A CPU tensor of shape (len(x_test), 35) in torch's default float dtype
        holding the mean of the per-checkpoint outputs.

    Raises:
        ValueError: if ``fold_num`` < 1 (previously this divided by zero and
            silently returned NaN/inf values).
    """
    if fold_num < 1:
        raise ValueError(f"fold_num must be >= 1, got {fold_num}")

    # Accumulate directly on the CPU; allocating on `device` first only to copy
    # the zeros straight back was wasted work.
    stack_logits = torch.zeros(len(x_test), 35)
    for i in range(fold_num):
        # .detach() keeps this safe even if inference() returns a graph-tracking tensor.
        logits = inference(history_dict[i], x_test, x_cols).cpu().detach().numpy()

        if y_scaler is not None:
            logits = y_scaler.inverse_transform(logits)

        stack_logits += torch.as_tensor(logits, dtype=stack_logits.dtype)
    return stack_logits / fold_num

In [None]:
# The 35 target columns, listed as strings in lexicographic order. Model
# outputs are written back to these columns in this exact order.
y_cols = (
    '10 100 1000 101 1020 1040 1100 120 1200 121 140 150 '
    '1510 160 200 201 251 2510 270 300 3000 301 351 352 '
    '370 400 450 4510 500 550 5510 600 6000 650 652'
).split()

In [None]:
# ensemble: load every checkpoint matching ./*week*.pt, then predict the test
# horizon row by row with the averaged ensemble.
N_PRED_STEPS = 168              # test rows predicted, labels 0..N_PRED_STEPS-1
TRAIN_PLACEHOLDER_INDEX = 9999  # label given to every train row so .loc[v] only hits test rows

# glob() returns files in arbitrary order; sort so the ensemble (and its float
# summation order) is reproducible across machines.
files = sorted(glob('./*week*.pt'))
if not files:
    # Without this guard the first use of `load` below fails with a NameError.
    raise FileNotFoundError("no checkpoints matching './*week*.pt' in the working directory")

model_list = []
checkpoint_weeks = []
for f in files:
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # files. torch>=2.6 defaults to weights_only=True, which may reject a pickled
    # y_scaler; pass weights_only=False explicitly if loading fails there.
    load = torch.load(f, map_location=torch.device('cpu'))
    model = load['model']  # state dict consumed by inference()
    model_list.append(model)
    checkpoint_weeks.append(load['week'])

# Preprocessing uses a single `week` / `y_scaler`, taken from the last checkpoint;
# refuse to continue if the checkpoints disagree on `week`.
week = load['week']
y_scaler = load['y_scaler']
if any(w != week for w in checkpoint_weeks):
    raise ValueError(f"checkpoints disagree on 'week': {checkpoint_weeks}")
# NOTE(review): y_scaler is likewise assumed to be shared by all checkpoints - TODO confirm.

_train = pd.read_csv('./final_train.csv')
# Give every training row the same placeholder label, sized from the data
# rather than a hard-coded row count (was 3480).
s = pd.Series([TRAIN_PLACEHOLDER_INDEX] * len(_train))
_train = _train.set_index(s)
_train = scaling(_train, y_cols, y_scaler)

_test = pd.read_csv('./final_test.csv')
_test = pd.concat([_train, _test], axis=0)
_test, x_cols = preprocess_input_data(_test, y_cols, week=week)

in_test = _test.copy()

# Each iteration reads only row v's features and then overwrites only row v's
# y_cols, so no prediction feeds into a later step's input.
# NOTE(review): if an autoregressive rollout (lags updated from earlier
# predictions) was intended, it is not happening: x_cols are built once by
# preprocess_input_data before this loop. Otherwise this could be one batched
# call instead of N_PRED_STEPS calls that each rebuild every MLP.
# No y_scaler is passed: predictions stay in scaled space and the whole frame is
# inverted by inverse_scaling below.
for v in range(N_PRED_STEPS):
    x_row = in_test.loc[v, x_cols].values.astype('float64').reshape(1, -1)
    test_pred = mean_stack_logits(len(model_list), model_list, x_row, x_cols)
    in_test.loc[v, y_cols] = test_pred.squeeze().detach().numpy()

out_test = in_test.copy()
out_test = inverse_scaling(out_test, y_cols, y_scaler)
# .loc label slices are inclusive, so this keeps labels 0..N_PRED_STEPS-1.
result = out_test[y_cols].loc[0:N_PRED_STEPS - 1]

In [None]:
_result = result.copy()

submission = pd.read_csv('./sample_submission.csv')

# Fill every predicted row in one assignment keyed by _result's own index
# labels, instead of looping over a hard-coded range(168).
missing_rows = _result.index.difference(submission.index)
if len(missing_rows) > 0:
    # Setting .loc with an unseen label would otherwise silently append a row.
    raise KeyError(f"sample_submission.csv has no rows for labels {list(missing_rows)}")

# Cast targets to float up front: pandas >= 2.1 deprecates setting float values
# into an integer column (future versions raise instead of upcasting).
submission[y_cols] = submission[y_cols].astype('float64')
submission.loc[_result.index, y_cols] = _result[y_cols].to_numpy()

submission

Unnamed: 0,timestamp,10,100,101,120,121,140,150,160,200,201,251,270,300,301,351,352,370,400,450,500,550,600,650,652,1000,1020,1040,1100,1200,1510,2510,3000,4510,5510,6000
0,20200525_0,83316.773251,15094.093880,1322.879036,3677.769521,981.522676,919.242996,31792.457803,1066.936107,2711.004347,1543.045437,12194.972754,3115.772702,6641.365730,2943.435541,6064.341987,11429.987957,2850.331256,5697.447146,19661.752165,19829.164654,10537.966755,3032.611216,2092.052049,1228.385553,29764.023032,697.602763,2775.701975,5098.376359,4476.565974,1307.429893,3546.229032,277.810655,2236.955043,1581.602986,2503.082792
1,20200525_1,49501.148656,9223.088758,850.342093,2408.432688,687.439567,635.351968,20617.091072,637.364835,1811.189549,1140.139533,7814.304493,2358.507704,4104.765331,1899.284837,4153.390012,6692.677550,1684.817925,3824.797035,11821.944607,12959.296645,6741.921010,1924.904452,1364.727186,737.202321,18248.306952,396.252303,1726.312517,3487.056346,2762.866699,1043.829581,2243.189458,196.390595,1391.299278,1026.441028,1631.187564
2,20200525_2,37608.957973,7978.528278,712.487820,2155.880320,492.625784,581.698263,15878.256836,482.932148,1529.465225,1182.135987,6542.775677,2257.170043,3470.298794,1538.377829,3746.726200,5231.972977,1202.407218,3577.681701,9632.312910,10414.344206,5434.304829,1378.885552,1065.603106,624.643393,12888.653927,304.329442,1392.978889,2736.749583,2049.761772,910.795429,1935.470067,167.711456,1208.784145,875.275948,1421.385609
3,20200525_3,44023.649312,10457.539469,860.625198,2701.215164,588.804912,720.075882,19104.211493,556.428005,1735.128986,1598.246988,7630.028996,2618.471720,3994.762169,1780.503792,4397.374282,6663.462912,1592.683212,4726.974702,10430.718323,13150.257770,6539.023796,1538.412148,1232.926338,903.049506,14893.854704,353.989574,1644.513322,3591.906713,2052.313422,1017.299973,2384.334235,268.120432,1496.929417,1024.064081,1872.128516
4,20200525_4,81026.609391,20147.867811,1467.308665,5160.368952,1054.166421,1185.624938,37254.242718,972.665912,2623.935366,2222.956062,12086.119108,3902.451330,6736.429017,3098.373372,6775.570534,13900.180569,3660.834021,9660.117319,16871.069031,26556.271085,11903.358808,2263.004877,1903.112998,1657.241330,34618.117596,708.641403,2959.363283,7188.051232,4102.926353,1569.874132,4315.465078,492.338736,3072.878693,1926.094226,3727.259421
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
163,20200531_19,311633.134471,92049.647440,9622.440205,24633.282849,8163.533746,6937.252029,145939.593648,6430.782972,20193.136934,8158.467892,69631.028831,15794.238627,29904.068762,20406.898576,39298.086010,65513.825095,13609.862836,34678.302918,71708.071023,106371.536137,73938.612957,22356.899299,13759.171807,9530.235605,138840.835664,7124.292876,14824.922150,25950.546749,14836.922552,8357.002542,23802.897647,2844.535224,14213.986266,10467.071212,19901.755895
164,20200531_20,298497.618132,83672.861265,8431.342139,20926.700477,7380.794043,6005.869903,142862.388795,5774.539337,16802.320681,7594.513572,64516.105526,14455.591318,26691.553570,18740.149394,34829.186800,60755.089665,13324.197878,31146.545931,65953.891296,101909.849304,66720.997455,20942.045258,12294.587248,8057.324104,145747.651878,5320.812949,13853.745537,27019.117836,15095.364132,7384.884536,20690.652291,2455.653367,12723.863728,9575.013481,17670.696522
165,20200531_21,290345.959792,68958.715621,6293.361972,16306.795621,5251.303461,4327.063030,133165.706264,4683.815947,12188.513798,6472.563148,53808.811845,11992.644714,23401.167436,15687.462509,28258.513659,54525.742860,11829.243104,25662.192099,59506.558498,91353.125765,53840.878194,17259.085152,8994.386844,5978.433392,140147.932235,3938.702264,11962.568827,25833.087676,14362.686717,6100.125564,17408.499115,1872.421021,9947.765398,7737.055953,13751.934982
166,20200531_22,241468.751894,47914.719240,3954.834105,11340.116207,3474.089891,2817.133161,103298.649225,3381.114404,8420.065278,4701.654723,38579.628855,8959.882710,18599.574164,10854.381384,19440.243424,41011.287347,9114.444719,19074.984434,47425.171608,67355.417552,35988.382297,11616.287991,6117.543316,4172.552019,106452.683729,2479.886607,8613.143431,19534.377699,11402.499345,4467.819014,12490.520783,1295.447135,7039.959037,5382.581272,9188.997116


In [None]:
# Persist the filled template; index=False keeps pandas' row index out of the file.
SUBMISSION_PATH = '_submission.csv'
submission.to_csv(SUBMISSION_PATH, index=False)