In [1]:
!pip install torch python-dotenv boto3 --quiet

[0m

In [2]:
%load_ext dotenv
%dotenv env

In [3]:
import torch
import os
import argparse
import logging
import boto3
from pprint import pformat
from torch import nn
import io
import tqdm
import numpy as np
import pandas as pd

In [4]:
import os


LABEL = [
    "continue_work_session_30_minutes"
]

METADATA = [
    "user_id",
    "session_30_raw",
    "cum_platform_event_raw",
    "cum_platform_time_raw",
    "cum_session_time_raw",
    "glob_session_time_raw", 
    "year",
    "month",
    "day",
    "hour",
    "minute",
    "second"
]

OUT_FEATURE_COLUMNS = [
    "user_count",
    "project_count",
    "country_count", 
    "date_hour_sin", 
    "date_hour_cos",
    "date_minute_sin",
    "date_minute_cos",
    
    "session_30_count",
    "session_5_count",
    "cum_session_event_count",
    "delta_last_event",
    "cum_session_time",
    
    "expanding_click_average",
    "cum_platform_time",
    "cum_platform_events",
    "cum_projects",
    "average_event_time",
    
    "rolling_session_time",
    "rolling_session_events",
    "rolling_session_gap",
    "previous_session_time",
    "previous_session_events",
]

GROUPBY_COLS = ['user_id']

LOAD_COLS = LABEL + METADATA + OUT_FEATURE_COLUMNS

S3_BUCKET = 'dissertation-data-dmiller'
BASE_CHECK_PATH = 'lstm_experiments/checkpoints'

LSTM_CHECKPOINTS = {
    'seq_40': 'lstm_experiments/checkpoints/data_v1/n_files_30/ordinal/sequence_length_40/data_partition_None/2023_06_12_13_05/clickstream-epoch=51-loss_valid=0.59.ckpt'
}

LSTM_CHECKPOINT_EMBEDDING = {
    'embedding_30': 'lstm_experiments/checkpoints/data_v1/n_files_30/embedded_ordinal_heuristic/sequence_length_30/data_partition_None/2023_05_10_14_18/clickstream-epoch=05-loss_valid=0.62.ckpt',
    'embedding_30_heuristic': 'lstm_experiments/checkpoints/data_v1/n_files_30/embedded_ordinal_heuristic/sequence_length_30/data_partition_None/2023_05_10_14_18/clickstream-epoch=05-loss_valid=0.62.ckpt'
}

In [5]:
# %load data_module
import pdb

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import IterableDataset

LABEL_INDEX = 1
TOTAL_EVENTS_INDEX = 2
BATCHES = 1000000

    
class ClickstreamDataset(Dataset):
    def __init__(self, dataset_pointer_list) -> None:
        """
        Yield data in batches of BATCHES
        """
        self.events = dataset_pointer_list
        self.size = self.events[0].shape[0]


    def __getitem__(self, idx):
        events = [np.array([event[idx]]) for event in self.events]
        return np.concatenate(events, axis=1)
       

    def __len__(self):
        return self.size


In [6]:
# %load torch_model_bases
import torch 
from torch import nn
N_FEATURES = 22
class LSTMOrdinal(nn.Module):
    def __init__(self,  hidden_size=32, dropout=0.2) -> None:
        super(LSTMOrdinal, self).__init__()
        
        self.lstm = nn.LSTM(
            input_size=N_FEATURES,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=dropout
        )
        self.output = nn.Linear(
            hidden_size,
            1
        )

    def forward(self, x):


        x, _ = self.lstm(x)
        x = x[:, -1]
        return self.output(x)

In [7]:
# %load npz_extractor
import logging
import os
import zipfile

import boto3
import numpy as np
import torch
import logging

class NPZExtractor:
    logger = logging.getLogger(__name__)
    def __init__(self, input_path, n_files, n_sequences, s3_client, data_partition) -> None:
        self.input_path = input_path
        self.n_files = n_files
        self.n_sequences = n_sequences
        self.s3_client = s3_client
        self.data_partition = data_partition


    def get_dataset_pointer(self):

        read_path = os.path.join(self.input_path, f'files_used_{self.n_files}')
        if not os.path.exists(read_path):
            self.logger.info(f'Creating directory: {read_path}')
            os.makedirs(read_path)


        for _ in range(0, self.n_sequences +1, 10):
            key_zip, key_npy = (
                os.path.join(read_path, f'sequence_index_{_}.npz'),
                os.path.join(read_path, f'sequence_index_{_}')
            )
            
            self.logger.info(f'Loading pointer to dataset: {key_npy}: derived from {key_zip}')


            if not os.path.exists(key_zip):
                self.logger.info(f'Zip file to extract: {key_zip}: npy file to load: {key_npy}')
                self.s3_client.download_file(
                    'dissertation-data-dmiller',
                    key_zip,
                    key_zip
                )
            if not os.path.exists(key_npy):
                self.logger.info(f'Zip file downloaded: {key_zip}: npy file to load: {key_npy}')

                self._zip_extract(key_zip, key_npy)

        lz_concatenated_results = self._lazy_concatenate()

        if self.data_partition:
            return [p[:self.data_partition] for p in lz_concatenated_results]
        else:
            return lz_concatenated_results


    def _zip_extract(self, key_zip, key_npy):
        self.logger.info(f'Extracting file: {key_zip} -> {key_npy}')

        with zipfile.ZipFile(key_zip, 'r') as zip_ref:
            zip_ref.extractall(path=key_npy, members=['arr_0.npy'])

        self.logger.info(f'Zip file exracted: {key_zip} -> {key_npy}/arr_0.npy')

    def _lazy_concatenate(self):
        lz_concat = []
        for _ in range(0, self.n_sequences +1, 10):
            path_to_load = os.path.join(self.input_path, f'files_used_{self.n_files}', f'sequence_index_{_}', f'arr_0.npy')
            self.logger.info(f'Loading: {path_to_load}')
            lz_concat.append(np.load(path_to_load, mmap_mode='r'))
        return lz_concat

In [8]:
# %load likelihood_engagement_cpu
import torch
import os
import argparse
from torch.utils.data import DataLoader, Dataset
import logging
import boto3
from pprint import pformat
from torch import nn
import io
import tqdm
import numpy as np
import pandas as pd



pd.set_option('display.width', 1000)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 500)    
pd.set_option('mode.use_inf_as_na', True)
torch.set_printoptions(sci_mode=False, linewidth=400, precision=2)
np.set_printoptions(suppress=True, precision=4, linewidth=200)
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

CHECKPOINT_DIR='lstm_experiments/checkpoints/data_v1/n_files_30/ordinal/sequence_length_10/data_partition_None/2023_03_30_07_54/clickstream-epoch=51-loss_valid=0.59.ckpt'
METADATA_INDEX = 13

logger = logging.getLogger('likelihood_engagement')

def parse_args():
    parser = argparse.ArgumentParser()
    
    parser.add_argument('--n_files', type=int, default=2)
    parser.add_argument('--n_sequences', type=int, default=20)
    parser.add_argument('--file_path', type=str, default='datasets/torch_ready_data')
    parser.add_argument('--checkpoint_dir', type=str, default=CHECKPOINT_DIR)
    parser.add_argument('--write_path', type=str, default='datasets/rl_ready_data')
    parser.add_argument('--model_type', type=str, default='ordinal')
    args = parser.parse_args()
    return args

def _extract_features(tensor, n_sequences, n_features):
    
    features_dict = {}
    tensor = tensor.squeeze()
    metadata = tensor[:, :METADATA_INDEX]
    features = tensor[:, METADATA_INDEX:]
                
    features = torch.flip(
        torch.reshape(features, (features.shape[0], 41, 22)),
        dims=[1]
    )
    
    features_dict['features_40'] = features
    features_dict['last_sequence'] = features[:, -1, :]

    return metadata, features_dict



def get_models(checkpoints: dict, s3_client, device):
    """_summary_
    Downloads models from s3 and loads them into memory.
    """
    models = {}
    for name, checkpoint in checkpoints.items():
        logger.info(f'Downloading model: {name}')
        response = s3_client.get_object(
            Bucket=S3_BUCKET,
            Key=checkpoint
        )
        buffer = io.BytesIO(response['Body'].read())
        state = torch.load(buffer, map_location=torch.device(device), )
        model = LSTMOrdinal()
        model.load_state_dict(state['state_dict'])
        model.to(device)
        models[name] = model
    return models

@torch.no_grad()
def generate_static_predictions(args):
    
    user_metadata_container = []
    client = boto3.client('s3')
    
    logger.info('Generating static prediction likelihoods for experiment')
    npz_extractor = NPZExtractor(
        args.file_path,
        args.n_files,
        args.n_sequences,
        client,
        None
           
    )
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'   
    logger.info(f'Setting device to {device}')
    
    logger.info('generating dataset pointer')
    dataset = npz_extractor.get_dataset_pointer()
    
    logger.info('Downloading model checkpoint')
    
    write_path = os.path.join(args.write_path, f'files_used_{args.n_files}')
    if not os.path.exists(write_path):
        logger.info(f'Creating directory: {write_path}')
        os.makedirs(write_path)
    
    client = boto3.client('s3')
    
    logger.info(f'Downloading models from checkpoints {LSTM_CHECKPOINTS.keys()}')
    
    models = get_models(LSTM_CHECKPOINTS, client, device)
    
    dataset = ClickstreamDataset(dataset)
    loader = DataLoader(dataset, batch_size=2048*8, shuffle=False, num_workers=8, pin_memory=True)
    activation = nn.Sigmoid()
    

    p_bar = tqdm.tqdm(loader, total=len(loader))
    
    for indx, data in enumerate(p_bar):
        p_bar.set_description(f'Processing batch: {indx}')
        data = data.to(device)
        
        metadata, features_dict = _extract_features(data, args.n_sequences + 1, 20)
        # logger.info(f'running inference on batch: {indx}')
        try:
            # preds_1 = activation(models['seq_1'](features_dict['last_sequence']))
            # preds_10 = activation(models['seq_10'](features_dict['features_10']))
            # preds_20 = activation(models['seq_20'](features_dict['features_20']))
            # preds_30 = activation(models['seq_30'](features_dict['features_30']))
            preds_40 = activation(models['seq_40'](features_dict['features_40']))
        except:
            raise Exception(f'Error processing batch: {indx}')
        
        # heuristic_data = featuRes_dict['features_30'].clone()
        # preds_30_heuristic = activation(models['seq_30_heuristic'](heuristic_data))
        # heuristic_scalar = torch.where(metadata[:, 4] < 25, torch.tensor(0.0), torch.tensor(1.0))
        # preds_30_heuristic = preds_30_heuristic * heuristic_scalar.unsqueeze(1)

        
        # logger.info(f'Concatenating metadata and predictions')
        user_metadata = torch.cat([metadata, features_dict['last_sequence'].squeeze(), preds_40], dim=1)
        user_metadata_container.append(user_metadata)

   
    predicted_data = torch.cat(user_metadata_container, dim=0)
    logger.info(f'Predicted data shape: {predicted_data.shape}: generating df')
    predicted_data = pd.DataFrame(predicted_data.cpu().numpy(), columns=LABEL + METADATA + OUT_FEATURE_COLUMNS + ['seq_40'])
   
    logger.info(f'Decoding date time data: merging to date time: {predicted_data.shape}')
    predicted_data['date_time'] = pd.to_datetime(predicted_data[['year', 'month', 'day', 'hour', 'minute', 'second']], errors='coerce').dropna()
    logger.info(f'Date time decoded and errors coerced: {predicted_data.shape}')
    write_path = os.path.join(write_path, 'predicted_data.parquet')
    logger.info(f'Writing to parquet: {write_path}')
    predicted_data.to_parquet(write_path, index=False)
    return predicted_data
    

In [9]:
class Arguments:
    n_files = 30
    n_sequences = 40
    file_path = 'torch_ready_data'
    write_path = 'rl_ready_data'

In [10]:
generate_static_predictions(Arguments)

2023-06-13 09:03:43,079 Found credentials in environment variables.
2023-06-13 09:03:43,169 Generating static prediction likelihoods for experiment
2023-06-13 09:03:43,176 Setting device to cuda
2023-06-13 09:03:43,177 generating dataset pointer
2023-06-13 09:03:43,178 Loading pointer to dataset: torch_ready_data/files_used_30/sequence_index_0: derived from torch_ready_data/files_used_30/sequence_index_0.npz
2023-06-13 09:03:43,178 Loading pointer to dataset: torch_ready_data/files_used_30/sequence_index_10: derived from torch_ready_data/files_used_30/sequence_index_10.npz
2023-06-13 09:03:43,179 Loading pointer to dataset: torch_ready_data/files_used_30/sequence_index_20: derived from torch_ready_data/files_used_30/sequence_index_20.npz
2023-06-13 09:03:43,179 Loading pointer to dataset: torch_ready_data/files_used_30/sequence_index_30: derived from torch_ready_data/files_used_30/sequence_index_30.npz
2023-06-13 09:03:43,180 Loading pointer to dataset: torch_ready_data/files_used_30/s

Unnamed: 0,continue_work_session_30_minutes,user_id,session_30_raw,cum_platform_event_raw,cum_platform_time_raw,cum_session_time_raw,glob_session_time_raw,year,month,day,hour,minute,second,user_count,project_count,country_count,date_hour_sin,date_hour_cos,date_minute_sin,date_minute_cos,session_30_count,session_5_count,cum_session_event_count,delta_last_event,cum_session_time,expanding_click_average,cum_platform_time,cum_platform_events,cum_projects,average_event_time,rolling_session_time,rolling_session_events,rolling_session_gap,previous_session_time,previous_session_events,seq_40,date_time
0,0.0,0.0,1.0,1.0,0.000000,0.000000,126762.929688,2021.0,10.0,19.0,8.0,40.0,37.0,-0.987115,-0.818056,1.0,0.866025,-0.500000,-0.866025,-0.500000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,0.328273,2021-10-20 08:40:37
1,0.0,1.0,1.0,1.0,0.000000,0.000000,154.899994,2021.0,10.0,19.0,8.0,40.0,38.0,-0.999728,1.000000,1.0,0.866025,-0.500000,-0.866025,-0.500000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,0.349654,2021-10-20 08:40:38
2,0.0,0.0,1.0,2.0,0.033333,0.033333,126762.929688,2021.0,10.0,19.0,8.0,40.0,39.0,-0.987115,-0.818056,1.0,0.866025,-0.500000,-0.866025,-0.500000,-1.000000,-1.000000,-0.999839,-0.997777,-0.999940,-0.998480,-1.000000,-0.999994,-1.000000,-0.998322,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,0.342207,2021-10-20 08:40:39
3,1.0,2.0,1.0,1.0,0.000000,0.000000,687603.625000,2021.0,10.0,19.0,8.0,40.0,39.0,-0.952089,-0.036256,1.0,0.866025,-0.500000,-0.866025,-0.500000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,0.308344,2021-10-20 08:40:39
4,0.0,0.0,1.0,3.0,0.100000,0.066667,126762.929688,2021.0,10.0,19.0,8.0,40.0,41.0,-0.987115,-0.818056,1.0,0.866025,-0.500000,-0.866025,-0.500000,-1.000000,-1.000000,-0.999679,-0.997777,-0.999880,-0.997974,-1.000000,-0.999987,-1.000000,-0.997762,-1.000000,-1.000000,-1.000000,-1.000000,-1.000000,0.294768,2021-10-20 08:40:41
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
38500985,0.0,34793668.0,18.0,844.0,16293.116211,41.200001,16293.116211,2022.0,8.0,14.0,5.0,13.0,25.0,-0.996299,-0.839513,1.0,0.965926,0.258819,0.978148,0.207912,-0.966270,-0.988875,-0.971268,-0.987771,-0.925596,-0.986172,-0.999088,-0.994518,-0.999016,-0.957931,-0.940610,-0.977270,-0.968728,-0.958735,-0.986678,0.331056,2022-08-14 05:13:25
38500986,0.0,231560.0,89.0,7063.0,108009.000000,22.366667,108031.382812,2022.0,8.0,14.0,5.0,13.0,26.0,-0.968992,0.944018,1.0,0.965926,0.258819,0.978148,0.207912,-0.825397,-0.934260,-0.955859,-0.998888,-0.959608,-0.996961,-0.993955,-0.954080,-0.999754,-0.987982,-0.947780,-0.965470,-0.950459,-0.997050,-0.997593,0.184052,2022-08-14 05:13:26
38500987,0.0,38358912.0,4.0,578.0,16940.632812,54.049999,16940.632812,2022.0,8.0,14.0,5.0,13.0,26.0,-0.997467,0.195245,1.0,0.965926,0.258819,0.978148,0.207912,-0.994048,-0.997472,-0.950080,-0.986659,-0.902390,-0.979638,-0.999052,-0.996248,-0.999016,-0.965205,-0.831934,-0.959448,-0.996146,-0.869462,-0.976567,0.370998,2022-08-14 05:13:26
38500988,0.0,231560.0,89.0,7064.0,108031.382812,22.383333,108031.382812,2022.0,8.0,14.0,5.0,13.0,27.0,-0.968992,0.944018,1.0,0.965926,0.258819,0.978148,0.207912,-0.825397,-0.934260,-0.955698,-0.998888,-0.959577,-0.997569,-0.993953,-0.954074,-0.999754,-0.988007,-0.947780,-0.965470,-0.950459,-0.997050,-0.997593,0.215153,2022-08-14 05:13:27


: 