In [25]:
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle

In [29]:
# Mapping from action (event-type) name -> class index; its length sizes the
# model's output layer in the LSTM class below.
# NOTE(review): pickle.load can execute arbitrary code — only load a trusted file.
with open('action_to_idx.pkl', mode='rb') as vocab_file:
    action_to_idx = pickle.load(vocab_file)

In [30]:
class LSTM(nn.Module):
    """Two-layer residual LSTM that predicts the next action from a feature sequence.

    Each categorical input feature has its own embedding table. The embeddings of
    one time step are concatenated and passed through two stacked LSTM layers with
    residual connections; the hidden state at the final time step is projected to
    one logit per action.

    Args:
        feature_sizes: Vocabulary size of each input feature, in the same order as
            the last axis of the input tensor.
        embedding_dim: Width of every per-feature embedding.
        hidden_size: Hidden width of both LSTM layers.
        dropout: Dropout probability applied after the embeddings and after each
            LSTM block.
        num_actions: Number of output logits. When ``None`` (the default, kept for
            backward compatibility) it falls back to ``len(action_to_idx)`` from
            the module-level mapping loaded earlier in the notebook.
    """

    def __init__(self, feature_sizes, embedding_dim=64, hidden_size=128, dropout=0.5,
                 num_actions=None):
        super().__init__()
        if num_actions is None:
            # Legacy behaviour: size the output layer from the global mapping.
            num_actions = len(action_to_idx)
        concat_dim = embedding_dim * len(feature_sizes)

        # One embedding table per categorical feature
        self.embeddings = nn.ModuleList([
            nn.Embedding(num_embeddings=size, embedding_dim=embedding_dim)
            for size in feature_sizes
        ])

        # Dropout module for regularization (shared by all dropout sites)
        self.dropout = nn.Dropout(dropout)

        # First LSTM layer: input is the concatenated embeddings
        self.lstm1 = nn.LSTM(
            input_size=concat_dim,
            hidden_size=hidden_size,
            batch_first=True
        )
        # Projects the concatenated embeddings to hidden_size so they can be added
        # to lstm1's output (first residual connection).
        self.residual_proj1 = nn.Linear(concat_dim, hidden_size)

        # Second LSTM layer: input and output are both hidden_size, so its
        # residual needs no projection.
        self.lstm2 = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True
        )

        self.activation = nn.ReLU()

        # Final fully-connected layer producing logits over actions
        self.fc = nn.Linear(hidden_size, num_actions)

    def forward(self, x):
        """Return next-action logits for a batch of categorical sequences.

        Args:
            x: Integer tensor of shape ``(batch_size, seq_len, num_features)`` whose
                last axis is ordered like ``feature_sizes``.

        Returns:
            Tensor of shape ``(batch_size, num_actions)`` with unnormalised logits.

        Raises:
            ValueError: if ``x`` is not 3-D or its feature count differs from the
                number of embedding tables.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a (batch_size, seq_len, num_features) tensor, got shape {tuple(x.shape)}"
            )
        num_feats = x.size(2)
        if num_feats != len(self.embeddings):
            raise ValueError(
                f"input has {num_feats} features but the model has "
                f"{len(self.embeddings)} embedding tables"
            )

        # Embed each feature column separately, then concatenate along the last axis.
        x_emb = torch.cat(
            [embedding(x[:, :, i]) for i, embedding in enumerate(self.embeddings)],
            dim=-1,
        )  # (batch_size, seq_len, embedding_dim * num_feats)
        x_emb = self.dropout(x_emb)

        # First LSTM layer + projected residual from the embeddings
        out1, _ = self.lstm1(x_emb)  # (batch_size, seq_len, hidden_size)
        res1 = self.residual_proj1(x_emb)  # (batch_size, seq_len, hidden_size)
        out1 = self.activation(out1 + res1)
        out1 = self.dropout(out1)

        # Second LSTM layer + identity residual from the first block's output
        out2, _ = self.lstm2(out1)  # (batch_size, seq_len, hidden_size)
        out2 = self.activation(out2 + out1)
        out2 = self.dropout(out2)

        # Classify from the hidden state at the final time step
        return self.fc(out2[:, -1, :])  # (batch_size, num_actions)

In [31]:
@st.cache_resource
def load_model():
    """Load the trained LSTM and the vocabularies bundled in ``model.pth``.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    model instead of re-reading the checkpoint on every interaction.

    Returns:
        tuple: ``(model, feature_to_idx, action_to_idx, idx_to_action,
        features_order)``. ``model`` is in eval mode on CUDA when available,
        otherwise CPU; ``idx_to_action`` is the inverse of the checkpoint's
        ``action_to_idx``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The checkpoint stores plain-Python mappings next to the weights, so it is
    # read with full unpickling. Passing weights_only explicitly keeps this
    # behaviour stable across torch versions (the default flipped to True in
    # torch 2.6) and silences the FutureWarning earlier versions emit.
    # NOTE(review): full unpickling can execute arbitrary code — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load('model.pth', map_location=device, weights_only=False)

    # Mappings and architecture parameters saved with the weights
    feature_to_idx = checkpoint['feature_to_idx']
    action_to_idx = checkpoint['action_to_idx']
    idx_to_action = {v: k for k, v in action_to_idx.items()}
    features_order = checkpoint['features_order']
    feature_sizes = checkpoint['feature_sizes']

    # Recreate the architecture and restore the trained weights.
    # NOTE(review): the output width is not taken from this checkpoint's
    # action_to_idx here; if the LSTM class sizes its output layer from a
    # different mapping, load_state_dict raises a size-mismatch error.
    model = LSTM(feature_sizes=feature_sizes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, feature_to_idx, action_to_idx, idx_to_action, features_order

# Load from the checkpoint (rebinds action_to_idx to the checkpoint's mapping)
model, feature_to_idx, action_to_idx, idx_to_action, features_order = load_model()

  checkpoint = torch.load('model.pth', map_location=device)


In [32]:
# Streamlit session state persists across reruns; seed each key only on the
# first run of the session.
# NOTE(review): current_features is initialised here but not read anywhere in
# this notebook — confirm it is still needed.
if "history" not in st.session_state:
    st.session_state["history"] = []
if "current_features" not in st.session_state:
    st.session_state["current_features"] = dict.fromkeys(features_order, 0)


2025-02-23 23:25:39.537 Session state does not function when running a script without `streamlit run`


In [33]:
# Categorical features exposed as sidebar selectboxes, one dropdown each.
DROPDOWN_FEATURES = [
    "city",
    "country",
    "device_family",
    "device_type",
    "language",
    "os_name",
    "user_properties_roles",
]

# Event-type labels rendered as buttons in the action grid (one button per label).
# A click records action_to_idx.get(label, 0), so a label missing from the
# loaded action_to_idx mapping is silently recorded as index 0.
# NOTE(review): this list was previously described as "truncated for display";
# confirm whether it is meant to cover every key of action_to_idx.
event_actions = [
    'session_end', 'application-window-opened', 'session_start', 'agency-dashboard::layout:render', 'agency-dashboard:::view', 'agency-dashboard::widget:render', 'agency-dashboard::configurable-table:render', '::nav-header:user-signed-out', 
    'dashboard:my-book:configurable-table:render', 'dashboard:my-book:widget:render', 'triaged-submission-list:my-book:configurable-table:render', 'triaged-submission-list:my-book::view', 'dashboard:my-book:layout:render', 'dashboard:my-book::view',
    '::nav-header:action-center-click', 'action-center:::view', 'account:::view', 'account-lines:::view', 'account-lines::layout:render', 'account-lines::widget:render', 'account-lines::configurable-table:render', ':all-accounts:configurable-table:render', 
    ':all-accounts:widget:render', ':all-accounts:layout:render', ':all-accounts::view', 'submissions:policy-definition::submit-click', 'submissions:all-policy:configurable-table:render', 'submissions:all-policy::view', 'submissions:triaged_submissions-definition::view',
    'triaged-submission:triaged_submissions-definition:layout:render', 'triaged-submission:triaged_submissions-definition::view', 'triaged-submission:triaged_submissions-definition:widget:render', 'triaged-submission-list:triaged_submissions-definition:configurable-table:render', 
    'triaged-submission-list:triaged_submissions-definition::view', 'submissions:policy-definition::view', 'submissions:policy-definition:configurable-table:render', 'submissions:policy-create::view', 'submissions:policy-create::submit-click', 'account-lines:::change-rating-click', 
    'account-property-rating:perils:configurable-table:render', 'account-property-rating:perils::view', 'action-center:::submit-click', 'action-center:action-details::view', 'action-center:::close-click', 'dashboard:my-book::action-click', 'action-center:action-details:response-form:submit-click', 
    'account-lines::templeton-docs:create-document-click', 'account-property-rating:perils:perils-table:add-click', 'account-property-rating:perils:perils-table:edit-click', 'account-property-rating:perils:perils-table:delete-click', 'dashboard:portfolio-insights:layout:render', 'dashboard:portfolio-insights::view', 
    'dashboard:portfolio-insights:widget:render', 'dashboard:my-book:recent-actions-table:action-click', 'account-auto-rating:::view', 'account-auto-rating::configurable-table:render', 'account-property-rating:perils:layers:add-click', 'account-property-rating:perils:model-request-details:save-click', 'submissions:exposures-create::submit-click', 
    'submissions:all-exposures:configurable-table:render', 'submissions:all-exposures::view', 'submissions:exposures-create::view', 'dashboard:my-book:recent-actions-table:account-click', '::configurable-table:render', '::layout:render', '::widget:render', 'EMPTY', 'submissions:all-account::view', 'submissions:all-account:configurable-table:render', 
    'submissions:account-create::view', 'account-broker-view::layout:render', 'account-broker-view:::view', 'account-broker-view::widget:render', 'agency-account::layout:render', 'agency-account:::view', 'agency-account::widget:render', 'agency-account::configurable-table:render', 
    'account-broker-view::configurable-table:render', 'submissions:all-ingest_policy_through_pd:configurable-table:render', 'submissions:all-ingest_policy_through_pd::view', 'submissions:ingest_policy_through_pd-create::view', '::nav-header:help-menu-opened', 
    'account-lines::duplicate-policy-modal:duplicate-rating', 'account-property-rating::duplicate-policy-modal:duplicate-rating', 
    'account-lines::construction-excess-rater:save-new-quote-click', 'account-lines::construction-excess-rater:create-document-click',
    '::duplicate-policy-modal:duplicate-rating', 'all-accounts:renewals:layout:render', 'all-accounts:renewals::view', 'all-accounts:renewals:configurable-table:render', 
    'all-accounts:renewals:widget:render', 'submissions:all-financial_lines::view', 'dashboard:team-insights:layout:render', 
    'dashboard:team-insights::view', 'dashboard:team-insights:widget:render', 'account-property-rating:pricing-detail:configurable-table:render',
    'account-property-rating:pricing-detail::view', 'account-property-rating:pricing-detail::open-ra-file-click',
    'account-property-rating:building-details:configurable-table:render', 'account-property-rating:building-details::view', 
    'submissions:exposures-definition::view', 'submissions:all-renewal::view', 'submissions:renewal-definition::view', 
    'all-accounts:new-business::view', 'all-accounts:new-business:layout:render', 'submissions:policy-create:configurable-table:render', 
    'submissions:renewal-create::view', 'submissions:renewal-definition::submit-click', 'submissions:renewal-create::submit-click',
    'submissions:all-renewal:configurable-table:render', 'all-accounts:new-business:accounts-table:account-click', 
    'account-lines::construction-excess-rater:modify-existing-quote-click', 'linked-email-thread-attachments:triaged_submissions-definition::document-download-click', 
    'submissions:all-auto::view', 'submissions:all-auto:configurable-table:render', 'account-workers-comp-rating:::view', 
    'account-workers-comp-rating:::change-rating-click', ':all-accounts::advanced-filters-opened', ':all-accounts:accounts-table:account-click',
    'account-broker-readonly-view::layout:render', 'account-broker-readonly-view:::view', 'account-broker-readonly-view::widget:render', 
    'triaged-submission:triaged_submissions-definition::winnability-click', 'triaged-submission:triaged_submissions-definition::appetite-click', 'assigned-email-thread:::email-thread-expansion', 
    'assigned-email-thread:::document-download-click', 'submissions:all-exposure_demo::view', 'submissions:all-exposure_demo:configurable-table:render', 'submissions:all-sashco_submission:configurable-table:render', 
    'submissions:all-sashco_submission::view', 'goals-and-rules:goals:configurable-table:render', 'goals-and-rules:goals::view', 'goals-and-rules:goal-definition::view', 
    'account-broker-readonly-view::configurable-table:render', 'submissions:all-terrorism::view', 'submissions:terrorism-create::view', 'submissions:all-terrorism:configurable-table:render', 'submissions:financial_lines-create::view', 
    'submissions:all-financial_lines:configurable-table:render', 'all-accounts:new-business:configurable-table:render', 'contacts::configurable-table:render', 'brokerage::configurable-table:render', 'brokerage::layout:render', 
    'brokerage:::view', 'brokerage::widget:render', 'complex-rules::configurable-table:render', 'classification-rules::configurable-table:render', 'rule:::view', 'rule::configurable-table:render', 
    'account-lines:::action-center-click', 'account-auto-rating:::change-rating-click', ':::account-click', 'account-property-rating::configurable-table:render', 'carriers::configurable-table:render', 
    'submissions:policy-definition::save-click', 'account-property-rating:perils:layers:delete-click', 'account-auto-rating::duplicate-policy-modal:duplicate-rating',
    'classification-rule:::view', 'classification-rule::configurable-table:render', 'submissions:policy-create::save-click', 'account-property-rating:::change-rating-click', 'goals-and-rules:rules:configurable-table:render', 
    'goals-and-rules:rules::view', 'goals-and-rules:new-rule::view', 'goals-and-rules:new-rule::close-click', 'reinsurance-binders::configurable-table:render', 'reinsurers-on-binders::configurable-table:render', 'reinsurers-on-binders:::view'
]

In [34]:
# Sidebar: one selectbox per categorical feature. current_selections maps each
# feature name to the index of the chosen value (0 if nothing is selected).
with st.sidebar:
    st.header("Feature Selection")
    current_selections = {}

    for feat in DROPDOWN_FEATURES:
        vocab = feature_to_idx[feat]
        # The first vocabulary entry is the default value and is not offered.
        choices = list(vocab)[1:]
        label = feat.replace('_', ' ').title()
        picked = st.selectbox(label, choices, key=feat)
        current_selections[feat] = vocab.get(picked, 0)

# Main interface
st.title("Action Predictor")



DeltaGenerator()

In [35]:
# Event-type buttons laid out round-robin across four columns. A click records
# the current sidebar selections plus the clicked action's index, and the
# history is capped at the 8 most recent steps.
cols = st.columns(4)
for position, action in enumerate(event_actions):
    with cols[position % len(cols)]:
        if not st.button(action):
            continue
        history = st.session_state.history
        history.append(dict(current_selections, event_type=action_to_idx.get(action, 0)))
        if len(history) > 8:
            st.session_state.history = history[-8:]



In [36]:
# Display recent history
st.subheader("Recent Actions")
# Each record stores its event type as an index; map it back to the action
# name for display, falling back to the raw index if it is unknown.
recent_steps = st.session_state.history[-8:]
history_display = [
    f"Step {i + 1}: {idx_to_action.get(h['event_type'], h['event_type'])}"
    for i, h in enumerate(recent_steps)
]
# st.write renders strings as Markdown, which folds single newlines into one
# paragraph; st.text keeps one step per line.
st.text("\n".join(history_display))

# Prediction logic: runs once at least 8 steps have been recorded and uses the
# 8 most recent ones.
if len(st.session_state.history) >= 8:
    # One row per step, one column per feature in the checkpoint's
    # features_order; features a step did not record default to index 0.
    input_data = [
        [step.get(feat, 0) for feat in features_order]
        for step in st.session_state.history[-8:]
    ]
    # Build the batch on the model's own device: load_model() moves the model to
    # CUDA when available, and a CPU input would raise a device-mismatch error.
    model_device = next(model.parameters()).device
    input_tensor = torch.tensor(
        [input_data], dtype=torch.long, device=model_device
    )  # (1, 8, len(features_order))

    # Predict
    with torch.no_grad():
        logits = model(input_tensor)
        pred_idx = int(torch.argmax(logits, dim=-1).item())

    predicted_action = idx_to_action.get(pred_idx, "unknown")
    st.subheader(f"Predicted Next Action: {predicted_action}")

# Reset button
if st.button("Reset Session"):
    st.session_state.history = []
    # st.experimental_rerun is deprecated (and removed in newer Streamlit
    # releases) in favour of st.rerun; prefer the new API, fall back on old ones.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

