In [2]:
# Make IPython display the value of every top-level expression in a cell,
# not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [3]:
import pandas as pd
import numpy as np
import utils
import os

import matplotlib.pyplot as plt

from torch.utils.data import Dataset
import lightgbm as lgb

In [4]:
# Directory the feature CSVs and the hydro -> meteo station mapping are read from.
# The trailing slash matters: file names are appended with plain string concatenation.
DATA_DIR = "../working_data/"
# NOTE(review): not referenced anywhere in the visible part of the notebook.
SRC_DATA_DIR = "../datasets/processed_data/"

In [11]:
class DataBuilder:
    """Builds a feature matrix and target series from hydro and meteo tables.

    Both input frames must have a ``date`` column, and their first two columns
    are used as the index. The rest of the class relies on the hydro index
    levels being named ``id`` and ``date`` and the meteo ones ``stationNumber``
    and ``date``. ``s2m_dict`` maps a hydro ``id`` to the ``stationNumber`` of
    the meteo station whose features are attached to it.

    The input frames are not modified: ``prepare_df`` works on a copy, so the
    same frames can be passed to a second ``DataBuilder`` (e.g. when the cell is
    re-run).
    """
    # TODO: add accessor helpers for date and id; right now they live sometimes in
    # the columns and sometimes in the index, which causes seemingly random errors.

    def __init__(self, hydro, meteo, s2m_dict):
        self.hydro = self.prepare_df(hydro)
        self.meteo = self.prepare_df(meteo)
        self.s2m_dict = s2m_dict

    def prepare_df(self, df):
        """Return a copy of ``df`` with ``date`` parsed and its first two columns as index.

        The original implementation assigned the parsed column and called
        ``set_index(..., inplace=True)`` on the caller's frame; a second call on
        the same frame then failed with ``KeyError: 'date'``.
        """
        key_cols = list(df.columns[:2])
        # ``assign`` returns a new frame, leaving the caller's object untouched.
        prepared = df.assign(date=pd.to_datetime(df["date"]))

        return prepared.set_index(key_cols)

    def build(self):
        """Run the whole pipeline and return ``(features, target)``."""
        self.fill_missing_dates()
        self.merged = self.merge_parts()
        self.extract_merged_x_y()

        return self.features, self.target

    def fill_missing_dates(self):
        """Reindex both tables so every id has a row for every day of the common range.

        Rows that did not exist before are filled with NaN.
        """
        min_date, max_date = self.min_max_data_date()

        new_hydro_idx = self.create_all_dates_index(self.hydro, min_date, max_date)
        new_meteo_idx = self.create_all_dates_index(self.meteo, min_date, max_date)

        fill_val = np.nan
        self.hydro = self.hydro.reindex(new_hydro_idx, fill_value=fill_val)
        self.meteo = self.meteo.reindex(new_meteo_idx, fill_value=fill_val)

    def min_max_data_date(self):
        """Return the earliest and the latest date found in either table."""
        dates_hydro = self.hydro.index.get_level_values("date")
        dates_meteo = self.meteo.index.get_level_values("date")

        min_date = min(dates_hydro.min(), dates_meteo.min())
        max_date = max(dates_hydro.max(), dates_meteo.max())

        return min_date, max_date

    def create_all_dates_index(self, df, min_date, max_date):
        """Return the product of ``df``'s first-level ids and the daily date range."""
        id_idxs = df.index.get_level_values(0).unique()
        new_date_index = pd.date_range(min_date, max_date, name="date")

        all_dates_index = pd.MultiIndex.from_product([id_idxs, new_date_index])

        return all_dates_index

    def merge_parts(self):
        """Left-join the meteo features onto the hydro rows and index by ``(id, date)``.

        Each hydro row is matched with the meteo row of the station given by
        ``s2m_dict`` for the same date.
        """
        nearest_meteo_id = self.hydro_to_meteo_map_col()

        hydro = self.hydro.reset_index()
        meteo = self.meteo

        # The left key is an array aligned with the hydro rows; the right keys
        # refer to the meteo index levels by name.
        merged = hydro.merge(meteo, left_on=[nearest_meteo_id, "date"], right_on=["stationNumber", "date"], how="left")
        merged.set_index(["id", "date"], inplace=True)

        return merged

    def hydro_to_meteo_map_col(self):
        """Return, for every hydro row, the meteo station id looked up in ``s2m_dict``."""
        hydro_id = self.hydro.index.get_level_values("id")
        hydro_nearest_meteo = hydro_id.map(self.s2m_dict)

        return hydro_nearest_meteo

    def extract_merged_x_y(self):
        """Split ``self.merged`` into ``self.features`` (all but ``target``) and ``self.target``."""
        feature_cols = list(self.merged.columns)
        feature_cols.remove("target")

        self.features = self.merged[feature_cols]
        self.target = self.merged["target"]

In [12]:
# Load the hydro and meteo feature tables and pass each through the project's
# memory-reduction helper (presumably a dtype downcast - see utils.reduce_memory_usage).
hydro = utils.reduce_memory_usage(pd.read_csv(DATA_DIR + "hydro_features.csv"))
meteo = utils.reduce_memory_usage(pd.read_csv(DATA_DIR + "meteo_features.csv"))

# Mapping from the first CSV column (used as index) to "meteo_id";
# DataBuilder uses it to look up a meteo station for every hydro id.
s2m = pd.read_csv(DATA_DIR + "handmade_s2m.csv", index_col=0)
s2m_dict = s2m["meteo_id"].to_dict()

In [13]:
# Merge hydro and meteo data into an (id, date)-indexed feature frame and target series.
builder = DataBuilder(hydro, meteo, s2m_dict)
features, target = builder.build()

In [14]:
# Drop the raw frames and the builder; the visible cells below only use `features` and `target`.
del hydro, meteo, builder

In [27]:
# Keep only the rows whose target is known.
# NOTE: despite its name, the mask is True where the target is present (not NaN).
target_nan_mask = target.notna()
features = features.loc[target_nan_mask]
target = target.loc[target_nan_mask]

In [28]:
# class DatasetRetriever(Dataset):
#     def __init__(self, features, target):
#         super().__init__()
#         self.features = features
#         self.target = target
        
#     def __getitem__(self, index: int):
#         obj_features = self.features.iloc[index]
#         obj_target = self.target.iloc[index]
        
#         return obj_features, obj_target
        
#     def __len__(self):
#         return len(self.features)
    
# data_retr = DatasetRetriever(features, target)

In [31]:
features

Unnamed: 0_level_0,Unnamed: 1_level_0,stationNumber,max_level_lag_1,max_level_lag_2,max_level_lag_3,max_level_lag_4,max_level_lag_5,max_level_lag_6,max_level_lag_7,max_level_nanmean_1_7,max_level_nanmean_1_30,...,windAngleY,cloudCoverTotal_diff,windSpeed_diff,totalAccumulatedPrecipitation_diff,soilTemperature_diff,airTemperature_diff,relativeHumidity_diff,pressureReducedToMeanSeaLevel_diff,windAngleX_diff,windAngleY_diff
id,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
5001,1984-01-01,31707,,,,,,,,258.000000,258.000000,...,0.881546,1.000,-1.333333,0.00,1.000,1.666667,-4.666667,-0.299988,-0.121696,-0.077529
5001,1984-01-02,31707,258.0,,,,,,,256.500000,256.500000,...,0.176777,-0.250,-0.125000,0.00,-0.250,-0.675000,1.625000,1.075005,0.088388,-0.088388
5001,1984-01-03,31707,255.0,258.0,,,,,,255.000000,255.000000,...,0.353553,0.125,0.375000,0.00,0.000,0.412500,-1.000000,-0.275009,-0.088388,0.088388
5001,1984-01-04,31707,252.0,255.0,258.0,,,,,253.250000,253.250000,...,0.673982,-0.500,0.000000,0.00,0.125,0.187500,-0.250000,-0.287491,-0.029073,-0.045636
5001,1984-01-05,31707,248.0,252.0,255.0,258.0,,,,251.399994,251.399994,...,0.411700,0.000,-0.375000,0.00,0.000,-0.050000,-0.250000,-0.250000,0.117462,-0.042753
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6574,2018-12-27,31594,21.0,21.0,23.0,23.0,27.0,26.0,26.0,23.142857,27.366667,...,0.422554,0.000,0.125000,-0.05,0.475,0.237500,-3.125000,0.099998,0.001564,0.004073
6574,2018-12-28,31594,21.0,21.0,21.0,23.0,23.0,27.0,26.0,22.428572,27.066668,...,0.530488,-1.250,-0.250000,0.00,-1.100,-0.637500,2.250000,0.537506,0.115898,0.078174
6574,2018-12-29,31594,21.0,21.0,21.0,21.0,23.0,23.0,27.0,21.571428,26.666666,...,0.487316,1.250,0.250000,0.00,1.200,0.825000,-1.000000,0.937500,-0.107146,-0.060620
6574,2018-12-30,31594,21.0,21.0,21.0,21.0,21.0,23.0,23.0,21.285715,26.466667,...,0.373578,0.000,-0.250000,0.00,-0.450,-0.412500,0.625000,1.074997,0.026797,0.031376


### Training a model for a single station

In [18]:
class TimeSeriesValFold:
    """Date-based train/validation splitter for time-indexed data.

    The unique values of the ``date`` index level of ``features`` are sorted
    and cut into ``nfolds`` folds. In every fold the training window starts at
    the first date and ends right before the validation window, which spans
    ``val_width`` consecutive unique dates. Later folds have later (and longer)
    training windows.

    Iterating over the object yields
    ``(train_features, train_labels, val_features, val_labels)`` per fold.

    NOTE(review): with the current arithmetic the last ``val_width + 1`` unique
    dates never appear in any fold - confirm whether this hold-out is intended.

    Raises:
        ValueError: if ``nfolds`` or ``val_width`` is not positive, or if there
            are too few unique dates to give the first fold a non-empty
            training window.
    """

    def __init__(self, features, labels, nfolds=12, val_width=30):
        self.features = features
        self.labels = labels

        self.dates = self.features.index.get_level_values("date")
        self.uniq_dates = sorted(self.dates.unique())
        self.unique_dates_num = len(self.uniq_dates)

        self.nfolds = nfolds
        self.val_width = val_width

        self.set_folds_periods()

    def set_folds_periods(self):
        """Compute boolean row masks of the train and validation part of every fold."""
        if self.nfolds < 1 or self.val_width < 1:
            raise ValueError("nfolds and val_width must be positive")

        train_start = 0
        last_idx = self.unique_dates_num - 1

        # Without this check a non-positive `train_end` would be used as a slice
        # bound below: negative values wrap around and silently put validation
        # (and later) dates into the training window, zero gives an empty one.
        first_train_end = last_idx - (self.nfolds + 1) * self.val_width
        if first_train_end < 1:
            min_dates = (self.nfolds + 1) * self.val_width + 2
            raise ValueError(
                f"not enough unique dates for {self.nfolds} folds of width "
                f"{self.val_width}: got {self.unique_dates_num}, need at least {min_dates}"
            )

        self.train_masks = []
        self.val_masks = []

        for fold_idx in range(self.nfolds):
            folds_till_end = self.nfolds - fold_idx + 1
            train_end = last_idx - folds_till_end * self.val_width

            val_start = train_end
            val_end = val_start + self.val_width

            train_dates = self.uniq_dates[train_start: train_end]
            val_dates = self.uniq_dates[val_start: val_end]

            # Masks are row-level: every row whose date falls in the window is selected.
            self.train_masks.append(self.dates.isin(train_dates))
            self.val_masks.append(self.dates.isin(val_dates))

    def __iter__(self):
        for train_period, val_period in zip(self.train_masks, self.val_masks):
            train_features, train_labels = self.features[train_period], self.labels[train_period]
            val_features, val_labels = self.features[val_period], self.labels[val_period]

            yield train_features, train_labels, val_features, val_labels

In [23]:
# Validation folds over the whole dataset with the default nfolds=12 and val_width=30.
ts_fold = TimeSeriesValFold(features, target)

In [25]:
# Smoke test: walk through every fold and unpack it into train/validation parts.
# The commented-out prints below show the date boundaries of each fold.
for i in ts_fold:
    t_x, t_y, v_x, v_y = i
#     print(t_x.reset_index()["date"].max())
#     print("v", v_x.reset_index()["date"].min())
#     print("v max", v_x.reset_index()["date"].max())

In [13]:
# Baseline: a single LightGBM regressor trained on the rows of all stations at once.
lgb_param = dict(objective="regression")

train_dataset = lgb.Dataset(features.values, target)
model = lgb.train(lgb_param, train_dataset)

In [37]:
class LgbModel:
    """Controls the process of training a LightGBM model on data from a single station.

    ``model_config`` is passed to ``lgb.train`` as the parameter dict. The last
    model trained by ``fit`` is kept in ``self.model`` and used by ``predict``
    and ``validate``.
    """

    def __init__(self, model_config):
        self.lgb_param = model_config
        self.model = None

    def cross_val_score(self, data_loader):
        """Fit on every fold of ``data_loader`` and return the mean validation score.

        ``data_loader`` must yield ``(train_x, train_y, val_x, val_y)`` tuples.
        """
        val_scores = []

        for train_x, train_y, val_x, val_y in data_loader:
            # `fit` stores the model, so `validate` scores the one just trained.
            self.fit(train_x, train_y)
            val_scores.append(self.validate(val_x, val_y))

        return np.mean(val_scores)

    def fit(self, x, y):
        """Train a model on ``(x, y)``, remember it in ``self.model`` and return it."""
        dataset = lgb.Dataset(x, y)
        self.model = lgb.train(self.lgb_param, dataset)
        return self.model

    def validate(self, features, target):
        """Return the RMSE of the fitted model's predictions against ``target``.

        NOTE(review): the original method was an empty stub; RMSE was picked as
        the metric here - confirm it is the one the project wants to report.
        """
        errors = np.asarray(self.predict(features), dtype=float) - np.asarray(target, dtype=float)
        return float(np.sqrt(np.mean(errors ** 2)))

    def predict(self, features):
        """Return the predictions of the fitted model for ``features``.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("model is not fitted: call fit() before predict()")
        return self.model.predict(features)

In [38]:
class StationModelsManager:
    """Keeps one model per station and trains, scores and queries them together.

    For every id in ``station_ids`` a separate ``model_fitter_class(model_config)``
    instance is created. The model objects are expected to provide ``fit``,
    ``predict``, ``validate`` and ``cross_val_score``. ``data_loader_class`` is
    called as ``data_loader_class(station_features, station_target)`` to build
    the fold iterator passed to a model's ``cross_val_score``.

    The station id of a row is read from the ``id`` column obtained after
    ``features.reset_index()``.
    """

    def __init__(self, station_ids, model_fitter_class, model_config, data_loader_class):
        self.station_ids = station_ids
        self.data_loader_class = data_loader_class
        self.models = {
            id_stat: model_fitter_class(model_config)
            for id_stat in self.station_ids
        }

    @staticmethod
    def _station_ids(features):
        """Return the station id of every row of ``features`` as a plain array."""
        # A plain array (not a Series) is used for masks on purpose: a Series
        # coming out of `reset_index()` carries a RangeIndex and cannot be
        # aligned with the (id, date) index of `features`.
        return features.reset_index()["id"].values

    def cross_val_score(self, features: pd.DataFrame, target: pd.DataFrame):
        """Return the mean of the per-station cross-validation scores."""
        self.check_input_ids(features)

        scores = []

        for station_id, model in self.models.items():
            data_loader = self.get_station_data_loader(features, target, station_id)
            scores.append(model.cross_val_score(data_loader))

        return np.mean(scores)

    def fit(self, features: pd.DataFrame, target: pd.DataFrame):
        """Fit every station model on the rows of its own station."""
        self.check_input_ids(features)

        for station_id, model in self.models.items():
            station_features, station_target = self.get_station_data(features, target, station_id)
            model.fit(station_features, station_target)

    def validate(self, features: pd.DataFrame, target: pd.DataFrame):
        """Return the mean of the per-station ``validate`` scores.

        NOTE(review): the original method had no body; this mirrors
        ``cross_val_score`` - confirm the intended aggregation.
        """
        self.check_input_ids(features)

        scores = []

        for station_id, model in self.models.items():
            station_features, station_target = self.get_station_data(features, target, station_id)
            scores.append(model.validate(station_features, station_target))

        return np.mean(scores)

    def get_station_data_loader(self, features, target, station_id):
        """Build a ``data_loader_class`` instance over the rows of one station."""
        station_features, station_target = self.get_station_data(features, target, station_id)
        data_loader = self.data_loader_class(station_features, station_target)
        return data_loader

    def get_station_data(self, features, target, station_id):
        """Return the rows of ``features`` and ``target`` that belong to ``station_id``."""
        station_mask = self._station_ids(features) == station_id
        station_features = features[station_mask]
        station_target = target[station_mask]

        return station_features, station_target

    def predict(self, features: pd.DataFrame):
        """Return a float array with one prediction per row, in the row order of ``features``.

        Each row is predicted by the model of its own station.
        """
        self.check_input_ids(features)

        predicted_ids = self._station_ids(features)
        # Float buffer: an integer one would silently truncate the predictions.
        answers = np.full(len(features), np.nan, dtype=float)

        for station_id in pd.unique(predicted_ids):
            curr_station_mask = predicted_ids == station_id
            station_features = features[curr_station_mask]

            station_model = self.models[station_id]
            answers[curr_station_mask] = station_model.predict(station_features)

        return answers

    def check_input_ids(self, features):
        """Raise ``ValueError`` if ``features`` has rows of a station without a model."""
        in_white_list = np.isin(self._station_ids(features), list(self.models))

        if not in_white_list.all():
            raise ValueError("input features contain data from unknown stations")