In [1]:
import os
os.environ["OMP_NUM_THREADS"] = "4"

import pandas as pd
import numpy as np
import torch
from functools import partial
import pytorch_lightning as pl
import warnings
import pickle
warnings.filterwarnings("ignore")

from torch.utils.data import DataLoader

from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.nn import TrxEncoder, RnnSeqEncoder
from ptls.frames.coles import CoLESModule
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.frames.coles import ColesIterableDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames import PtlsDataModule
from ptls.preprocessing import PandasDataPreprocessor
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset

from tqdm.auto import tqdm
import lightgbm as ltb



# Part 1

In [6]:
# Load the geo train/test parquet files (paths are relative to the working directory).
geo_train = pd.read_parquet("geo_train.parquet")
geo_test = pd.read_parquet("geo_test.parquet")

In [7]:
# The three geohash columns are declared as categorical features.
geo_category_cols = ["geohash_4", "geohash_5", "geohash_6"]

# Preprocessor keyed by client_id, with event_time as the event-time column
# (converted with the "dt_to_timestamp" transformation).
preprocessor = PandasDataPreprocessor(
    col_id="client_id",
    col_event_time="event_time",
    event_time_transformation="dt_to_timestamp",
    cols_category=geo_category_cols,
    return_records=False,
)

In [None]:
# Fit the preprocessor on the training frame only.
preprocessor = preprocessor.fit(geo_train)

IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out


In [None]:
# Cache the fitted preprocessor on disk; later cells reload it from this file.
with open('geo_preprocessor.pkl', 'wb') as f:
    pickle.dump(preprocessor, f)

In [None]:
# Reload the fitted preprocessor from the cache file.
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open('geo_preprocessor.pkl', 'rb') as f:
    preprocessor = pickle.load(f)

In [None]:
# Apply the fitted preprocessor to both frames (it was fitted on geo_train only).
processed_train = preprocessor.transform(geo_train)
processed_test = preprocessor.transform(geo_test)

IOStream.flush timed out
IOStream.flush timed out


In [None]:
# Persist the preprocessed frames so Part 2 / Part 3 can start from these files.
processed_train.to_pickle('geo_processed_train.pkl')
processed_test.to_pickle('geo_processed_test.pkl')

# Part 2

In [None]:
# Part 2 entry point: reload the preprocessed training frame written in Part 1.
processed_train = pd.read_pickle('geo_processed_train.pkl')

In [None]:
# Reload the preprocessed test frame written in Part 1.
processed_test = pd.read_pickle('geo_processed_test.pkl')

In [None]:
# Reload the fitted preprocessor; its category dictionary sizes are used below
# to size the embedding tables.
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open('geo_preprocessor.pkl', 'rb') as f:
    preprocessor = pickle.load(f)

In [None]:
def _coles_filters():
    """Return a fresh filter chain for the contrastive-training datasets.

    The chain drops the id/target fields, applies SeqLenFilter(min_seq_len=64)
    and ISeqLenLimit(max_seq_len=4096), then converts the record with ToTorch.
    A new list of new filter objects is built on every call so the two
    datasets do not share filter instances.
    """
    return [
        FeatureFilter(drop_feature_names=['client_id', 'target_1', 'target_2', 'target_3', 'target_4']),
        SeqLenFilter(min_seq_len=64),
        ISeqLenLimit(max_seq_len=4096),
        ToTorch(),
    ]


# Both datasets are built from the row records of the preprocessed frames.
train = MemoryMapDataset(data=processed_train.to_dict("records"), i_filters=_coles_filters())
test = MemoryMapDataset(data=processed_test.to_dict("records"), i_filters=_coles_filters())

In [None]:
# Same SampleSlices configuration for both datasets; each gets its own splitter object.
slice_params = dict(split_count=5, cnt_min=32, cnt_max=180)

train_ds = ColesIterableDataset(data=train, splitter=SampleSlices(**slice_params))
valid_ds = ColesIterableDataset(data=test, splitter=SampleSlices(**slice_params))

In [None]:
# Despite the name, `train_dl` is a PtlsDataModule carrying both the training
# and the validation data with identical loader settings.
train_dl = PtlsDataModule(
    train_data=train_ds,
    valid_data=valid_ds,
    train_batch_size=256,
    valid_batch_size=256,
    train_num_workers=8,
    valid_num_workers=8,
)

In [None]:
# Look the category dictionary sizes up once instead of re-querying the
# preprocessor for every feature.
category_sizes = preprocessor.get_category_dictionary_sizes()

# One embedding table per geohash column: input size taken from the fitted
# preprocessor, output dimension 24.
trx_encoder_params = dict(
    embeddings_noise=0.003,
    embeddings={
        name: {'in': category_sizes[name], 'out': 24}
        for name in ('geohash_4', 'geohash_5', 'geohash_6')
    },
)

In [None]:
# GRU sequence encoder with hidden size 64 on top of the event-level encoder.
trx_encoder = TrxEncoder(**trx_encoder_params)
seq_encoder = RnnSeqEncoder(trx_encoder=trx_encoder, hidden_size=64, type='gru')

In [None]:
# Adam at lr=0.001; StepLR multiplies the learning rate by 0.9025 every 3 scheduler steps.
adam_partial = partial(torch.optim.Adam, lr=0.001)
step_lr_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=3, gamma=0.9025)

model = CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=adam_partial,
    lr_scheduler_partial=step_lr_partial,
)

In [None]:
# TensorBoard logs are written under ./logdir/geo_result.
tb_logger = pl.loggers.TensorBoardLogger(save_dir='./logdir', name='geo_result')

trainer_callbacks = [
    pl.callbacks.LearningRateMonitor(logging_interval='step'),
    # Keep every checkpoint (save_top_k=-1), written every 5000 training steps.
    pl.callbacks.ModelCheckpoint(every_n_train_steps=5000, save_top_k=-1),
    # Two early-stopping criteria: validation recall (maximised, patience 5)
    # and the logged loss (minimised, patience 3).
    pl.callbacks.EarlyStopping(monitor="valid/recall_top_k", mode="max", patience=5),
    pl.callbacks.EarlyStopping(monitor="loss", mode="min", patience=3),
]

trainer = pl.Trainer(
    max_epochs=30,
    limit_val_batches=5000,
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    logger=tb_logger,
    callbacks=trainer_callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train the CoLES model; `train_dl` is the PtlsDataModule holding train and validation data.
trainer.fit(model, train_dl)

You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2024-06-16 04:46:32.414118: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type            | Params | Mode 
---------------------------------------------------------------
0 | _loss              | ContrastiveLoss | 0      | train
1 | _seq_encoder       | RnnSeqEncoder   | 47.7 M | train
2 | _validation_metric | BatchRecallTopK | 0 

Epoch 0: 100%|██████████| 2086/2086 [02:01<00:00, 17.12it/s, v_num=1, seq_len=104.0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/555 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/555 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 1/555 [00:00<00:24, 22.59it/s][A
Validation DataLoader 0:   0%|          | 2/555 [00:00<00:26, 20.78it/s][A
Validation DataLoader 0:   1%|          | 3/555 [00:00<00:26, 20.92it/s][A
Validation DataLoader 0:   1%|          | 4/555 [00:00<00:26, 20.64it/s][A
Validation DataLoader 0:   1%|          | 5/555 [00:00<00:26, 20.93it/s][A
Validation DataLoader 0:   1%|          | 6/555 [00:00<00:26, 20.56it/s][A
Validation DataLoader 0:   1%|▏         | 7/555 [00:00<00:26, 20.69it/s][A
Validation DataLoader 0:   1%|▏         | 8/555 [00:00<00:26, 20.82it/s][A
Validation DataLoader 0:   2%|▏         | 9/555 [00:00<00:26, 20.96it/s][A
Validation DataLoader 0:   2%|▏         | 10/555 [00:00<00:

In [None]:
# Save only the weights (state dict); Part 3 rebuilds the architecture before loading them.
torch.save(model.state_dict(), './geo_emb64_model.pt')

# Part 3

In [None]:
# Part 3 entry point: reload the preprocessed training frame written in Part 1.
processed_train = pd.read_pickle('geo_processed_train.pkl')

In [10]:
# Reload the preprocessed test frame written in Part 1.
processed_test = pd.read_pickle('geo_processed_test.pkl')

In [None]:
# Target tables; their `mon` column supplies the month list passed to the
# embedding helpers further down.
# processed_target = pd.read_pickle('processed_target.pkl')
target_train = pd.read_parquet("train_target.parquet")
target_test = pd.read_parquet("test_target_b.parquet")

In [None]:
# Reload the fitted preprocessor; its category dictionary sizes are needed to
# rebuild the encoder with the same embedding table shapes.
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open('geo_preprocessor.pkl', 'rb') as f:
    preprocessor = pickle.load(f)

In [None]:
# Force a garbage-collection pass before building the model.
import gc
gc.collect()

In [None]:
# Rebuild exactly the architecture used for training so the saved state dict
# matches parameter-for-parameter.
category_sizes = preprocessor.get_category_dictionary_sizes()
trx_encoder_params = dict(
    embeddings_noise=0.003,
    embeddings={
        name: {'in': category_sizes[name], 'out': 24}
        for name in ('geohash_4', 'geohash_5', 'geohash_6')
    },
)
seq_encoder = RnnSeqEncoder(
    trx_encoder=TrxEncoder(**trx_encoder_params),
    hidden_size=64,
    type='gru',
)

model = CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=partial(torch.optim.Adam, lr=0.001),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=3, gamma=0.9025)
)
# map_location='cpu' makes the checkpoint loadable on a machine without the
# GPU it was saved from; the model is moved to its target device separately.
model.load_state_dict(torch.load('./geo_emb64_model.pt', map_location='cpu'))
# Switch to evaluation mode for inference.
model.eval()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# inference cells still run on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
from pandas.tseries.offsets import MonthBegin


class GetSplit(IterableProcessingDataset):
    """Expand every source record into one record per entry of ``months``.

    For each month a shallow copy of the record is produced in which:

    * every feature whose key starts with ``'target'`` is replaced by its
      element at the month's position in ``months`` (converted with
      ``.tolist()``);
    * every other feature except the id is filtered with a boolean mask that
      keeps events whose ``col_time`` value is strictly below the month cutoff;
    * ``'__' + str(month)`` is appended to the id.

    The cutoff is ``month`` parsed by ``pd.to_datetime``, minus
    ``MonthBegin(1)``, expressed in seconds (nanoseconds / 1e9).

    Args:
        months: months to split by; anything ``pd.to_datetime`` accepts.
        col_id: key of the record id field.
        col_time: key of the event-time feature the mask is computed from.
    """

    def __init__(
        self,
        months,
        col_id='client_id',
        col_time='event_time'
    ):
        super().__init__()
        self.months = months
        self._col_id = col_id
        self._col_time = col_time

    @staticmethod
    def _month_cutoff(month):
        """Return the cutoff for ``month`` in seconds (see the class docstring)."""
        return int((pd.to_datetime(month, yearfirst=True, dayfirst=False) - MonthBegin(1)).to_datetime64()) / 1e9

    def __iter__(self):
        # The cutoffs depend only on the months, so compute them once per pass
        # instead of once per record and month. Materialising the months also
        # keeps a one-shot iterable from being exhausted after the first record.
        months = list(self.months)
        cutoffs = [self._month_cutoff(month) for month in months]

        for rec in self._src:
            # Records may arrive as (features, ...) tuples; only the feature dict is used.
            source = rec[0] if type(rec) is tuple else rec

            for i, (month, month_event_time) in enumerate(zip(months, cutoffs)):
                # Shallow copy: values are replaced, never modified in place.
                features = source.copy()
                mask = features[self._col_time] < month_event_time

                for key, tensor in features.items():
                    if key.startswith('target'):
                        features[key] = tensor[i].tolist()
                    elif key != self._col_id:
                        features[key] = tensor[mask]

                features[self._col_id] += '__' + str(month)

                yield features
                

from datetime import datetime


def collate_feature_dict_with_target(batch, col_id='client_id', targets=False):
    """Collate a list of sample dicts, returning the ids (and targets) separately.

    The id field, and with ``targets=True`` also ``target_1`` .. ``target_4``,
    is left out of what is passed to ``collate_feature_dict``. The sample dicts
    in ``batch`` are not modified: deleting keys from them would break a later
    pass whenever the dataset hands out the same dict objects again.

    Args:
        batch: list of per-sample feature dicts.
        col_id: key of the id field in every sample.
        targets: when True, also extract the four target fields.

    Returns:
        ``(padded_batch, batch_ids)``, or
        ``(padded_batch, batch_ids, target_cols)`` when ``targets`` is True,
        where ``target_cols`` holds one ``[target_1, ..., target_4]`` list per sample.

    Raises:
        KeyError: if a sample lacks ``col_id`` (or a target field when
            ``targets`` is True).
    """
    target_names = [f'target_{i}' for i in range(1, 5)]
    excluded = {col_id, *target_names} if targets else {col_id}

    batch_ids = [sample[col_id] for sample in batch]
    target_cols = (
        [[sample[name] for name in target_names] for sample in batch]
        if targets else []
    )

    # Filtered copies keep the remaining keys in their original order.
    features = [
        {key: value for key, value in sample.items() if key not in excluded}
        for sample in batch
    ]

    padded_batch = collate_feature_dict(features)
    if targets:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

def to_pandas(x):
    """Convert a dict of per-row values into a single DataFrame.

    Args:
        x: dict mapping a name to a list, a torch tensor or a numpy array.

    Returns:
        DataFrame in which lists and 1-D values become one column each, 2-D
        values are expanded into columns ``<name>_0000``, ``<name>_0001``, ...
        (placed after the one-column features), and values of any other rank
        become a column of ``None``.
    """
    scalar_features = {}
    matrix_features = {}
    for name, value in x.items():
        if isinstance(value, torch.Tensor):
            # detach() makes the conversion safe for tensors that require grad.
            value = value.detach().cpu().numpy()
        if isinstance(value, list) or len(value.shape) == 1:
            scalar_features[name] = value
        elif len(value.shape) == 2:
            # Keep the converted array, so 2-D numpy inputs work as well as tensors.
            matrix_features[name] = value
        else:
            scalar_features[name] = None

    dataframes = [pd.DataFrame(scalar_features)]
    for name, matrix in matrix_features.items():
        columns = [f'{name}_{i:04d}' for i in range(matrix.shape[1])]
        dataframes.append(pd.DataFrame(matrix, columns=columns))
    return pd.concat(dataframes, axis=1)

In [21]:
from tqdm import tqdm

def make_prediction(model, inference_dl):
    """Run ``model`` over every batch of ``inference_dl`` and tabulate the outputs.

    Each batch is either ``(x, batch_ids)`` or ``(x, batch_ids, target_cols)``.
    ``x`` is moved to the device of the model's parameters and passed through
    the model; the output is stored under ``'emb'`` next to ``'client_id'``
    (and ``'target_1'`` .. ``'target_4'`` when targets are present) and
    converted with ``to_pandas``.

    The forward pass runs under ``torch.no_grad()``, so no autograd graph is
    recorded. The model's train/eval mode is not changed here; the caller is
    expected to have called ``model.eval()``.

    Args:
        model: a torch module with at least one parameter.
        inference_dl: iterable of batches as described above.

    Returns:
        The per-batch frames concatenated row-wise (the index is not reset).

    Raises:
        ValueError: from ``pd.concat`` when ``inference_dl`` yields no batches.
    """
    # Send inputs to wherever the model's weights live rather than relying on
    # a notebook-global `device`.
    model_device = next(model.parameters()).device

    dfs = []
    with torch.no_grad():
        for batch in tqdm(inference_dl):
            has_targets = len(batch) == 3
            if has_targets:
                x, batch_ids, target_cols = batch
            else:
                x, batch_ids = batch

            out = model(x.to(model_device))

            x_out = {'client_id': batch_ids}
            if has_targets:
                target_cols = torch.tensor(target_cols)
                for pos in range(4):
                    x_out[f'target_{pos + 1}'] = target_cols[:, pos]
            x_out['emb'] = out

            torch.cuda.empty_cache()
            dfs.append(to_pandas(x_out))
    return pd.concat(dfs, axis='rows')


def _embed_dataset(processed_data, model, months=None, batch_size=256, max_seq_len=4096):
    """Shared implementation of the three ``get_*_dataset`` helpers.

    Builds a MemoryMapDataset from the row records of ``processed_data``
    (ISeqLenLimit -> FeatureFilter keeping the id and target fields ->
    optional GetSplit -> ToTorch), feeds it through an unshuffled,
    single-process DataLoader and returns the frame produced by
    ``make_prediction``.

    When ``months`` is given, every record is expanded per month by GetSplit
    and the resulting ``'<client_id>__<month>'`` ids are split back into
    ``client_id`` and ``month`` columns.
    """
    i_filters = [
        ISeqLenLimit(max_seq_len=max_seq_len),
        FeatureFilter(keep_feature_names=['client_id', 'target_1', 'target_2', 'target_3', 'target_4']),
    ]
    if months is not None:
        i_filters.append(GetSplit(months=months))
    i_filters.append(ToTorch())

    dataset = MemoryMapDataset(
        data=processed_data.to_dict("records"),
        i_filters=i_filters,
    )
    inference_dl = DataLoader(
        dataset=dataset,
        collate_fn=collate_feature_dict_with_target,
        shuffle=False,
        num_workers=0,
        batch_size=batch_size,
    )

    emb_df = make_prediction(model, inference_dl)
    if months is not None:
        # The month is appended last, so split from the right: a client id that
        # itself contains '__' then stays intact.
        emb_df[['client_id', 'month']] = emb_df['client_id'].str.rsplit('__', n=1, expand=True)
    return emb_df


def get_train_dataset(processed_data, model, months, batch_size=256, max_seq_len=4096):
    """Return per-(client, month) embeddings for a training frame.

    Args:
        processed_data: preprocessed frame with one row per client.
        model: model passed on to ``make_prediction``.
        months: months to expand every client by (see ``GetSplit``).
        batch_size: inference batch size.
        max_seq_len: limit passed to ``ISeqLenLimit``.

    Returns:
        DataFrame with ``client_id``, the embedding columns and ``month``.
    """
    return _embed_dataset(processed_data, model, months=months,
                          batch_size=batch_size, max_seq_len=max_seq_len)


def get_val_dataset(processed_data, model, months, batch_size=256, max_seq_len=4096):
    """Return per-(client, month) embeddings for a validation frame.

    Same pipeline and arguments as ``get_train_dataset``.
    """
    return _embed_dataset(processed_data, model, months=months,
                          batch_size=batch_size, max_seq_len=max_seq_len)


def get_test_dataset(processed_data, model, batch_size=256, max_seq_len=4096):
    """Return one embedding row per client, without the per-month expansion.

    Args:
        processed_data: preprocessed frame with one row per client.
        model: model passed on to ``make_prediction``.
        batch_size: inference batch size.
        max_seq_len: limit passed to ``ISeqLenLimit``.

    Returns:
        DataFrame with ``client_id`` and the embedding columns.
    """
    return _embed_dataset(processed_data, model, months=None,
                          batch_size=batch_size, max_seq_len=max_seq_len)

In [43]:
# Release the previous chunk and its embeddings before processing the next chunk.
# NOTE(review): `part3` and `train_emb_df_part3` are not assigned in any visible
# cell, so this raises NameError on a fresh kernel - confirm and restore or remove.
del part3, train_emb_df_part3

In [44]:
# Force a garbage-collection pass after dropping the previous chunk.
import gc
gc.collect()

115

In [45]:
# Last chunk of the training frame: rows from position 450,000 onward.
part4 = processed_train[450_000:]

In [None]:
# Embed the chunk once per distinct month found in target_train.mon (sorted ascending).
train_emb_df_part4 = get_train_dataset(part4, model, sorted(target_train.mon.sort_values().unique()))

100%|██████████| 8077/8077 [17:11<00:00,  7.83it/s]


In [50]:
# Display the chunk's embeddings: client_id, emb_0000..emb_0063 and month.
train_emb_df_part4

Unnamed: 0,client_id,emb_0000,emb_0001,emb_0002,emb_0003,emb_0004,emb_0005,emb_0006,emb_0007,emb_0008,emb_0009,emb_0010,emb_0011,emb_0012,emb_0013,emb_0014,emb_0015,emb_0016,emb_0017,emb_0018,emb_0019,emb_0020,emb_0021,emb_0022,emb_0023,emb_0024,emb_0025,emb_0026,emb_0027,emb_0028,emb_0029,emb_0030,emb_0031,emb_0032,emb_0033,emb_0034,emb_0035,emb_0036,emb_0037,emb_0038,emb_0039,emb_0040,emb_0041,emb_0042,emb_0043,emb_0044,emb_0045,emb_0046,emb_0047,emb_0048,emb_0049,emb_0050,emb_0051,emb_0052,emb_0053,emb_0054,emb_0055,emb_0056,emb_0057,emb_0058,emb_0059,emb_0060,emb_0061,emb_0062,emb_0063,month
0,b91d8218da3a9fe60e663ad7093e52eba853a6112a2f97...,0.530980,-0.984420,-0.971033,0.986405,0.960868,0.954608,-0.686292,-0.214617,-0.843814,0.923254,0.642353,0.858727,-0.996571,0.960354,0.987376,0.912106,0.764827,0.917465,0.656390,0.503387,0.997832,-0.998006,0.744212,-0.989362,-0.417862,0.966574,-0.863185,0.844486,0.273897,0.994828,-0.595478,-0.674601,0.896613,-0.933203,-0.930911,0.457315,-0.949447,0.776534,-0.997098,-0.894956,-0.417569,0.557773,0.927011,0.948354,-0.783789,-0.790093,-0.951561,0.882480,0.661425,-0.735958,0.943946,0.459107,0.895740,0.928397,0.939273,-0.926222,-0.522056,-0.848748,-0.997565,0.974599,0.373775,0.266752,0.995215,0.998462,2022-02-28
1,b91d8218da3a9fe60e663ad7093e52eba853a6112a2f97...,0.530980,-0.984420,-0.971033,0.986405,0.960868,0.954608,-0.686292,-0.214617,-0.843814,0.923254,0.642353,0.858727,-0.996571,0.960354,0.987376,0.912106,0.764827,0.917465,0.656390,0.503387,0.997832,-0.998006,0.744212,-0.989362,-0.417862,0.966574,-0.863185,0.844486,0.273897,0.994828,-0.595478,-0.674601,0.896613,-0.933203,-0.930911,0.457315,-0.949447,0.776534,-0.997098,-0.894956,-0.417569,0.557773,0.927011,0.948354,-0.783789,-0.790093,-0.951561,0.882480,0.661425,-0.735958,0.943946,0.459107,0.895740,0.928397,0.939273,-0.926222,-0.522056,-0.848748,-0.997565,0.974599,0.373775,0.266752,0.995215,0.998462,2022-03-31
2,b91d8218da3a9fe60e663ad7093e52eba853a6112a2f97...,0.530980,-0.984420,-0.971033,0.986405,0.960868,0.954608,-0.686292,-0.214617,-0.843814,0.923254,0.642353,0.858727,-0.996571,0.960354,0.987376,0.912106,0.764827,0.917465,0.656390,0.503387,0.997832,-0.998006,0.744212,-0.989362,-0.417862,0.966574,-0.863185,0.844486,0.273897,0.994828,-0.595478,-0.674601,0.896613,-0.933203,-0.930911,0.457315,-0.949447,0.776534,-0.997098,-0.894956,-0.417569,0.557773,0.927011,0.948354,-0.783789,-0.790093,-0.951561,0.882480,0.661425,-0.735958,0.943946,0.459107,0.895740,0.928397,0.939273,-0.926222,-0.522056,-0.848748,-0.997565,0.974599,0.373775,0.266752,0.995215,0.998462,2022-04-30
3,b91d8218da3a9fe60e663ad7093e52eba853a6112a2f97...,0.530980,-0.984420,-0.971033,0.986405,0.960868,0.954608,-0.686292,-0.214617,-0.843814,0.923254,0.642353,0.858727,-0.996571,0.960354,0.987376,0.912106,0.764827,0.917465,0.656390,0.503387,0.997832,-0.998006,0.744212,-0.989362,-0.417862,0.966574,-0.863185,0.844486,0.273897,0.994828,-0.595478,-0.674601,0.896613,-0.933203,-0.930911,0.457315,-0.949447,0.776534,-0.997098,-0.894956,-0.417569,0.557773,0.927011,0.948354,-0.783789,-0.790093,-0.951561,0.882480,0.661425,-0.735958,0.943946,0.459107,0.895740,0.928397,0.939273,-0.926222,-0.522056,-0.848748,-0.997565,0.974599,0.373775,0.266752,0.995215,0.998462,2022-05-31
4,b91d8218da3a9fe60e663ad7093e52eba853a6112a2f97...,0.530980,-0.984420,-0.971033,0.986405,0.960868,0.954608,-0.686292,-0.214617,-0.843814,0.923254,0.642353,0.858727,-0.996571,0.960354,0.987376,0.912106,0.764827,0.917465,0.656390,0.503387,0.997832,-0.998006,0.744212,-0.989362,-0.417862,0.966574,-0.863185,0.844486,0.273897,0.994828,-0.595478,-0.674601,0.896613,-0.933203,-0.930911,0.457315,-0.949447,0.776534,-0.997098,-0.894956,-0.417569,0.557773,0.927011,0.948354,-0.783789,-0.790093,-0.951561,0.882480,0.661425,-0.735958,0.943946,0.459107,0.895740,0.928397,0.939273,-0.926222,-0.522056,-0.848748,-0.997565,0.974599,0.373775,0.266752,0.995215,0.998462,2022-06-30
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
187,fffff598cd1a947b8ce0b86d56fd356729ec7bacb7053a...,-0.067350,-0.935076,-0.939799,0.944806,0.596933,0.985579,-0.688099,0.152316,-0.838738,0.874432,0.362869,0.909280,-0.991268,0.841990,0.763775,0.990757,0.589735,0.967755,0.750433,0.409629,0.992645,-0.995973,0.817319,-0.968961,0.636623,0.993538,-0.658791,0.385049,0.363789,0.975056,0.023072,-0.219486,0.954808,-0.952260,-0.801674,0.893073,-0.850678,0.893551,-0.991724,-0.882788,0.503352,0.801012,0.991226,0.980636,-0.641031,-0.017088,-0.977291,0.800889,0.822028,-0.676543,0.987370,0.682270,0.200354,0.759405,-0.048285,-0.784575,-0.847842,-0.802447,-0.996074,0.943942,-0.447762,0.304014,0.984917,0.993055,2022-09-30
188,fffff598cd1a947b8ce0b86d56fd356729ec7bacb7053a...,-0.064424,-0.938533,-0.939866,0.943526,0.602098,0.985547,-0.682055,0.153996,-0.844163,0.872999,0.365596,0.907011,-0.991683,0.845713,0.760345,0.991530,0.579726,0.970450,0.758239,0.411320,0.993222,-0.995970,0.814965,-0.968199,0.637410,0.993421,-0.660824,0.353980,0.365894,0.973281,0.022603,-0.214611,0.958182,-0.956069,-0.802016,0.891400,-0.853218,0.896189,-0.992052,-0.876859,0.504892,0.799586,0.991420,0.980332,-0.633755,-0.024806,-0.980781,0.805924,0.828922,-0.670500,0.988874,0.678223,0.208953,0.756120,-0.046721,-0.782844,-0.850287,-0.792060,-0.995699,0.929453,-0.447933,0.302650,0.985130,0.992595,2022-10-31
189,fffff598cd1a947b8ce0b86d56fd356729ec7bacb7053a...,-0.064424,-0.938533,-0.939866,0.943526,0.602098,0.985547,-0.682055,0.153996,-0.844163,0.872999,0.365596,0.907011,-0.991683,0.845713,0.760345,0.991530,0.579726,0.970450,0.758239,0.411320,0.993222,-0.995970,0.814965,-0.968199,0.637410,0.993421,-0.660824,0.353980,0.365894,0.973281,0.022603,-0.214611,0.958182,-0.956069,-0.802016,0.891400,-0.853218,0.896189,-0.992052,-0.876859,0.504892,0.799586,0.991420,0.980332,-0.633755,-0.024806,-0.980781,0.805924,0.828922,-0.670500,0.988874,0.678223,0.208953,0.756120,-0.046721,-0.782844,-0.850287,-0.792060,-0.995699,0.929453,-0.447933,0.302650,0.985130,0.992595,2022-11-30
190,fffff598cd1a947b8ce0b86d56fd356729ec7bacb7053a...,-0.064424,-0.938533,-0.939866,0.943526,0.602098,0.985547,-0.682055,0.153996,-0.844163,0.872999,0.365596,0.907011,-0.991683,0.845713,0.760345,0.991530,0.579726,0.970450,0.758239,0.411320,0.993222,-0.995970,0.814965,-0.968199,0.637410,0.993421,-0.660824,0.353980,0.365894,0.973281,0.022603,-0.214611,0.958182,-0.956069,-0.802016,0.891400,-0.853218,0.896189,-0.992052,-0.876859,0.504892,0.799586,0.991420,0.980332,-0.633755,-0.024806,-0.980781,0.805924,0.828922,-0.670500,0.988874,0.678223,0.208953,0.756120,-0.046721,-0.782844,-0.850287,-0.792060,-0.995699,0.929453,-0.447933,0.302650,0.985130,0.992595,2022-12-31


In [None]:
# Write the chunk's embeddings to CSV without the row index.
train_emb_df_part4.to_csv('train4_geo_emb_v2.csv', index=False)

In [49]:
# Scratch cell: a bare expression with no effect on the pipeline.
1

1

In [None]:
# Write the third chunk's embeddings to CSV without the row index.
# NOTE(review): `train_emb_df_part3` is not assigned in any visible cell and is
# deleted in an earlier cell - confirm where it comes from.
train_emb_df_part3.to_csv('train3_geo_emb_v2.csv', index=False)

In [32]:
# Write the second chunk's embeddings to CSV without the row index.
# NOTE(review): `train_emb_df_part2` is not assigned in any visible cell - confirm
# where it comes from.
train_emb_df_part2.to_csv('train2_geo_emb_v2.csv', index=False)

In [26]:
# Write the first chunk's embeddings to CSV without the row index.
# NOTE(review): `train_emb_df_part1` is not assigned in any visible cell - confirm
# where it comes from.
train_emb_df_part1.to_csv('train1_geo_emb_v2.csv', index=False)

In [28]:
# Embed the preprocessed test frame once per distinct month found in target_test.mon.
val_emb_df = get_val_dataset(processed_test, model, sorted(target_test.mon.sort_values().unique()))

100%|██████████| 7165/7165 [15:26<00:00,  7.74it/s]


In [31]:
# Write the per-month validation embeddings to CSV without the row index.
val_emb_df.to_csv('val_geo_emb_v2.csv', index=False)

In [32]:
# Drop the reference to the validation embeddings once they are written to disk.
del val_emb_df

In [13]:
# Embed the preprocessed test frame without the per-month split: one row per client.
# NOTE(review): the execution counter here is lower than in the cells above, so
# this cell was run out of order - confirm the notebook survives Restart & Run All.
test_emb_df = get_test_dataset(processed_test, model)

100%|██████████| 652/652 [23:09<00:00,  2.13s/it]


In [15]:
# Write the per-client test embeddings to CSV without the row index.
test_emb_df.to_csv('test_geo_emb_v2.csv', index=False)

In [26]:
# Drop the reference to the test embeddings once they are written to disk.
del test_emb_df

In [33]:
# Drop the reference to the preprocessed test frame; it is not used after this point.
del processed_test