# Libraries

## Python imports

In [4]:
import os
import polars as pl
import time
from pathlib import Path
import re
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
import json
from collections import defaultdict
from tabulate import tabulate
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.preprocessing import RobustScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.losses import Huber
import tensorflow as tf
import time
from datetime import datetime, timedelta



from itables import init_notebook_mode, show

# Initialize itables so tables displayed in this Jupyter notebook render as
# interactive tables. all_interactive=True presumably enables this for every
# displayed table, not only explicit show() calls — see the itables docs.
init_notebook_mode(all_interactive=True)


## Variables


In [5]:
# Token address string for SOL.
# NOTE(review): presumably the wrapped-SOL mint address; confirm where it is used.
SOL_ADDRESS = "So11111111111111111111111111111111111111112"

In [6]:
# Columns to display when inspecting results. cfv_future and averageLiquidity
# are produced by compute_profit_columns; presumably this is used to view its output.
columns_to_view = ["pairAddress", "baseToken.address", "averageLiquidity", "cfv_future", "cumulative_fee_volume", "percentageProfit", "pairAge", "baseToken.name"]
# Columns for summary statistics — TODO confirm where this list is consumed.
columns_to_describe = ["cfv_future", "cumulative_fee_volume", "percentageProfit"]

# Columns to keep, grouped by kind. The trailing comment on each entry records
# the dtype expected for that column.
# Numeric columns:
numeric_fields = [
    "fees24h.m5",           # Float64
    "volume.m5",            # Float64
    "liquidity.meteora",    # Float64
    "volume24h.h24",        # Float64
    "volume.h1",            # Float64
    "apr",                  # Float64
    "txns.m5.sells",        # Int64
    "priceChange.h1",       # Float64
    "priceChange.m5",       # Float64
    "txns.h1.buys",         # Int64
    "txns.h6.sells",        # Int64
    "txns.h1.sells",        # Int64
    "reserve_x_amount",     # Float64
    "volume24h.h1",         # Float64
    "pairCreatedAt",        # Int64
    "farm_apy",             # Int64
    "feeToTvl.max",         # Float64
    "fees_24h",             # Float64
    "feeToTvl.h24",         # Float64
    "txns.h24.sells",       # Int64
    "fees24h.max",          # Float64
    "liquidity.usd",        # Float64
    "fdv",                  # Int64
    "feeToTvl.h6",          # Float64
    "txns.m5.buys",         # Int64
    "cumulative_fee_volume",# Float64
    "marketCap",            # Int64
    "fees24h.h1",           # Float64
    "liquidity.quote",      # Float64
    "fees24h.h6",           # Float64
    "volume24h.h6",         # Float64
    "fees24h.h24",          # Float64
    "volume24h.min",        # Float64
    "volume.h24",           # Float64
    "txns.h6.buys",         # Int64
    "volume.h6",            # Float64
    "base_fee",             # Float64
    "feeToTvl.h1",          # Float64
    "farm_apr",             # Int64
    "priceChange.h6",       # Float64
    "txns.h24.buys",        # Int64
    "apy",                  # Float64
    "volume24h.max",        # Float64
    "feeToTvl.m5",          # Float64
    "bin_step",             # Int64
    "feeToTvl.min",         # Float64
    "fees24h.min",          # Float64
    "liquidity.base",       # Float64
    "trade_volume_24h",     # Float64
    "priceChange.h24",      # Float64
    "today_fees",           # Float64
    "volume24h.m5",         # Float64
    "pairAge",              # Int64
    "percentageProfit"      # Float64
]

# Columns stored as strings (commented-out entries are currently excluded):
string_fields = [
    "cumulative_trade_volume",  # String
    "priceUsd",                 # String
    "liquidity",                # String
    "base_fee_percentage",      # String
    "protocol_fee_percentage",  # String
    "priceNative",              # String
    # "trend",                    # String
    "max_fee_percentage",       # String
    "reserve_y_amount",         # String
]

# Boolean columns (commented-out entries are currently excluded):
bool_fields = [
    # "strict",                   # Boolean
    # "hide",                     # Boolean
    "isPumpToken"               # Boolean
]

# Full selection, in order: string columns, then numeric, then boolean.
columns_to_keep = string_fields + numeric_fields + bool_fields

## Preprocessing util functions

In [7]:
def explode_nested_values(json_data, keys_to_skip=("boosts", "info")):
    """
    Explode nested fields in JSON data, flattening nested dictionaries and lists.

    Nested dict keys are joined with ``.`` and list elements are addressed by
    their position, e.g. ``{"a": {"b": 1}, "c": [2]}`` becomes
    ``{"a.b": 1, "c[0]": 2}``. Empty dicts and empty lists produce no keys.
    If two paths flatten to the same key, the last value wins.

    :param json_data: Parsed JSON data (list of dictionaries).
    :param keys_to_skip: Iterable of key names whose values are dropped
        entirely, at any nesting depth. Matching is on the bare key, not the
        dotted path. Defaults to ``("boosts", "info")``, the fields that were
        previously hard-coded here as a temporary measure.
    :return: List of flattened dictionaries, one per input record.
    """
    # Built once per call instead of once per visited key.
    skip = frozenset(keys_to_skip)

    def flatten(record, parent_key='', sep='.'):
        """
        Flatten one dictionary by recursively expanding nested dicts and lists.

        :param record: The dictionary to flatten.
        :param parent_key: Dotted prefix accumulated by the recursion.
        :param sep: Separator placed between nested keys.
        :return: A flat dictionary.
        """
        flat = {}
        for key, value in record.items():
            if key in skip:
                continue

            new_key = f"{parent_key}{sep}{key}" if parent_key else key

            if isinstance(value, dict):
                flat.update(flatten(value, new_key, sep=sep))
            elif isinstance(value, list):
                # Index each element by position; elements may themselves be
                # dicts/lists and are flattened under "<key>[i]".
                for i, item in enumerate(value):
                    flat.update(flatten({f"{new_key}[{i}]": item}, sep=sep))
            else:
                flat[new_key] = value
        return flat

    return [flatten(record) for record in json_data]

def find_and_update_large_int_fields(json_data):
    """
    Find fields holding integers outside the signed 64-bit range and convert them to floats.

    A field is flagged if any record holds an ``int`` above the Int64 maximum
    or below the Int64 minimum in it. Every ``int`` value in a flagged field
    (not only the oversized ones) is then converted to ``float``, so the
    field ends up with a single numeric type.

    :param json_data: Parsed JSON data (list of flat dictionaries). Modified in place.
    :return: Tuple ``(json_data, large_int_fields)``. ``large_int_fields`` is a
        list of the flagged field names, in no particular order.
    """
    int64_min = -9_223_372_036_854_775_808
    int64_max = 9_223_372_036_854_775_807

    # Pass 1: find fields with out-of-range integers (either sign).
    large_int_fields = set()
    for record in json_data:
        for key, value in record.items():
            if isinstance(value, int) and not int64_min <= value <= int64_max:
                large_int_fields.add(key)

    # Pass 2: convert every int in those fields to float.
    for record in json_data:
        for key in large_int_fields:
            if key in record and isinstance(record[key], int):
                record[key] = float(record[key])

    return json_data, list(large_int_fields)

def find_mixed_type_fields(json_data):
    """
    Report fields whose values have more than one Python type across records.

    Type names come from ``type(value).__name__``, so ``None`` is reported as
    ``"NoneType"``. For each field, the first five values are kept as examples.
    Once that limit is reached, the first value of each newly seen type is
    still recorded.

    :param json_data: Parsed JSON data (list of flat dictionaries).
    :return: ``{field: {"types": [type names], "examples": [(type name, value), ...]}}``
        for every field with at least two distinct types.
    """
    seen_types = defaultdict(set)   # field -> set of type names
    samples = defaultdict(list)     # field -> [(type name, value), ...]

    for record in json_data:
        for field, val in record.items():
            tname = type(val).__name__
            unseen = tname not in seen_types[field]
            # Always record a new type; otherwise cap the examples at five.
            if unseen or len(samples[field]) < 5:
                seen_types[field].add(tname)
                samples[field].append((tname, val))

    return {
        field: {"types": list(tset), "examples": samples[field]}
        for field, tset in seen_types.items()
        if len(tset) > 1
    }

def normalize_mixed_types(json_data, mixed_type_fields):
    """
    Normalize fields that hold values of more than one type.

    ``None`` is treated as a missing value rather than as a type. It is left
    untouched and never makes a field "mixed" on its own. For the remaining
    (non-None) types of each field:

    * if ``float`` is among them, every non-float value is converted to float;
    * if only one type remains (e.g. an int field with some nulls), values are
      left as they are;
    * any other combination (e.g. ``int`` and ``str``) is not handled and raises.

    :param json_data: Parsed JSON data (list of flat dictionaries). Modified in place.
    :param mixed_type_fields: Mapping ``field -> {"types": [type names], ...}``
        in the shape returned by ``find_mixed_type_fields``.
    :return: The same ``json_data`` list.
    :raises ValueError: For a field with several non-None types, none of them
        ``float``; or if ``float()`` rejects a value (e.g. a non-numeric string).
    """
    # Ignore NoneType when judging whether a field is really mixed: a nullable
    # int/str column is fine as-is. Previously it fell through to the raise
    # below and caused the whole file to be dropped.
    non_null_types = {
        field: set(info["types"]) - {"NoneType"}
        for field, info in mixed_type_fields.items()
    }

    for record in json_data:
        for field, types in non_null_types.items():
            if field not in record:
                continue
            value = record[field]
            if value is None or isinstance(value, float):
                continue
            if "float" in types:
                record[field] = float(value)  # e.g. int -> float for a uniform column
            elif len(types) > 1:
                raise ValueError(
                    f"Did not handle field {field} with value {value} "
                    f"and types {mixed_type_fields[field]['types']}"
                )
    return json_data

def miscellaneous_cleaning(json_data):
    """
    Force selected fields to a fixed type across all records.

    Currently ``reserve_x_amount`` and ``cumulative_fee_volume`` are converted
    to ``float``. Conversion is best-effort: a value that cannot be converted
    is reported with a print and left unchanged. ``None`` values are left as
    ``None``.

    :param json_data: Parsed JSON data (list of flat dictionaries). Modified in place.
    :return: The same ``json_data`` list.
    """
    # (converter, fields) pairs; built once rather than once per record.
    force_conversions = [(float, ["reserve_x_amount", "cumulative_fee_volume"])]

    for record in json_data:
        for conversion, fields in force_conversions:
            for field in fields:
                value = record.get(field)
                if value is None:
                    # Missing field or null value: nothing to convert. float(None)
                    # raises TypeError, which used to abort the whole file.
                    continue
                try:
                    record[field] = conversion(value)
                except (TypeError, ValueError):
                    print(f"Invalid {conversion} conversion: {value}")
    return json_data

## Main loading functions

In [8]:
def read_file_and_process(file_path, batch_size=1000):
    """
    Load one JSON file and turn it into a cleaned Polars DataFrame.

    Steps: flatten nested fields, float-convert out-of-range integers,
    normalize mixed-type fields, apply forced conversions, then build the
    DataFrame from fixed-size slices of records and concatenate them.

    :param file_path: Path of a JSON file containing a list of records.
    :param batch_size: Number of records per intermediate DataFrame.
    :return: ``(df, json_data)`` on success, where ``json_data`` is the cleaned
        list of records. Returns ``None`` if any step raises; the error is printed.
    """
    try:
        with open(file_path, 'r') as handle:
            records = json.load(handle)

        records = explode_nested_values(records)
        records, _ = find_and_update_large_int_fields(records)

        mixed = find_mixed_type_fields(records)
        if mixed:
            if "timestamp" in mixed:
                print("mixed_type_fields: ", mixed["timestamp"])
            records = normalize_mixed_types(records, mixed)

        records = miscellaneous_cleaning(records)

        # Build the frame in slices: constructing one DataFrame from all records
        # at once hit an error in polars (cause not recorded here).
        frames = [
            pl.DataFrame(records[start:start + batch_size])
            for start in range(0, len(records), batch_size)
        ]
        df = pl.concat(frames, how="vertical")
        return df, records

    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return None

# Matches "<digits>_historical_data_<suffix>.json"; compiled once at import.
_HISTORICAL_FILE_PATTERN = re.compile(r"(\d+)_historical_data_(.+)\.json")


def process_file(file, batch_size=100_000):
    """
    Process a single JSON file: load, parse, and return it with its metadata.

    :param file: ``pathlib.Path`` of the file. Its name must look like
        ``<timestamp>_historical_data_<key_suffix>.json``.
    :param batch_size: Batch size passed through to ``read_file_and_process``.
    :return: Tuple ``(key_suffix, timestamp, subdir, df, json_data)``, where
        ``subdir`` is the name of the file's parent directory. Returns ``None``
        if the name does not match or the file could not be loaded.
    """
    match = _HISTORICAL_FILE_PATTERN.fullmatch(file.name)
    if not match:
        return None

    timestamp = int(match.group(1))
    key_suffix = match.group(2)

    try:
        loaded = read_file_and_process(file, batch_size=batch_size)
    except Exception as e:
        print(f"Error reading {file}: {e}")
        return None

    if loaded is None:
        # read_file_and_process already printed why; unpacking None here used
        # to raise a TypeError and report the same file a second time.
        return None

    df, json_data = loaded
    return key_suffix, timestamp, file.parent.name, df, json_data

def load_and_group_json_data_by_suffix_parallel(base_dir, max_files=None):
    """
    Load and group JSON data by key suffix using parallel processing.

    Files are read from the ``dex_screener_pairs``, ``enriched_data`` and
    ``meteora_pairs`` subdirectories of ``base_dir``. They are parsed in
    worker processes with ``process_file`` and grouped by the suffix in their
    file name.

    :param base_dir: Base directory containing the subdirectories with JSON files.
    :param max_files: Optional cap on the number of files taken from each
        subdirectory. ``None`` means no cap. Files are sorted by name before
        the cap is applied, so the same files are selected on every run.
    :return: Dict ``key_suffix -> {"files": [...], "group_timestamp": int}``,
        ordered by ``group_timestamp`` (the earliest file timestamp in the
        group). Each file entry is ``{"timestamp", "directory", "data"}``.
    """
    subdirs = ["dex_screener_pairs", "enriched_data", "meteora_pairs"]

    files_to_process = []
    for subdir in subdirs:
        full_path = Path(base_dir) / subdir
        if not full_path.is_dir():
            print(f"Skipping non-existent directory: {full_path}")
            continue
        # glob() order is filesystem-dependent; sort so max_files is reproducible.
        json_files = sorted(full_path.glob("*.json"))
        if max_files is not None:  # 0 now means "none", not "all"
            json_files = json_files[:max_files]
        files_to_process.extend(json_files)

    # Process files in parallel with a progress bar
    with ProcessPoolExecutor() as executor:
        results = list(
            tqdm(
                executor.map(process_file, files_to_process),
                total=len(files_to_process),
                desc="Processing files",
                unit="file"
            )
        )

    # Group results by key suffix; failed/skipped files come back as None.
    grouped_data = {}
    for result in results:
        if result is None:
            continue
        key_suffix, timestamp, subdir, df, _ = result
        group = grouped_data.setdefault(
            key_suffix, {"files": [], "group_timestamp": None}
        )
        group["files"].append({
            "timestamp": timestamp,
            "directory": subdir,
            "data": df,
        })

    # A group's timestamp is the earliest timestamp among its files.
    for group in grouped_data.values():
        group["group_timestamp"] = min(f["timestamp"] for f in group["files"])

    print("grouped_data time: ", [g["group_timestamp"] for g in list(grouped_data.values())[:5]])

    return dict(
        sorted(grouped_data.items(), key=lambda item: item[1]["group_timestamp"])
    )


## Aggregation functions

In [9]:
def process_new_files(groups):
    """
    Flatten grouped files into per-directory lists of DataFrames.

    Every DataFrame gets a literal ``group_timestamp`` column taken from its
    group. An ``enriched_data`` frame is also left-joined
    (``pairAddress`` == ``address``) with selected columns of the group's first
    ``meteora_pairs`` frame, when the group has one. Frames from directories
    other than the three known ones are dropped.

    :param groups: Mapping ``key -> {"group_timestamp": ..., "files": [{"directory": str, "data": DataFrame}, ...]}``.
    :return: ``{"dex_screener_pairs": [...], "meteora_pairs": [...], "enriched_data": [...]}``.
    """
    meteora_columns = [
        "address",
        "apr",
        "max_fee_percentage",
        "cumulative_trade_volume",
        "cumulative_fee_volume",
        "reserve_x_amount",
        "apy",
        "base_fee_percentage",
        "liquidity",
        "hide",
        "trade_volume_24h",
        "farm_apy",
        "reserve_y_amount",
        "protocol_fee_percentage",
        "today_fees",
        "farm_apr",
        "fees_24h",
    ]
    new_files = {"dex_screener_pairs": [], "meteora_pairs": [], "enriched_data": []}

    for group in groups.values():
        group_ts = group["group_timestamp"]
        files = group["files"]

        # First meteora_pairs frame in the group (if any), stamped like the rest.
        meteora_df = next(
            (
                entry["data"].clone().with_columns(pl.lit(group_ts).alias("group_timestamp"))
                for entry in files
                if entry["directory"] == "meteora_pairs"
            ),
            None,
        )

        for entry in files:
            source_dir = entry["directory"]
            frame = entry["data"].clone().with_columns(pl.lit(group_ts).alias("group_timestamp"))

            if source_dir == "enriched_data" and meteora_df is not None:
                frame = frame.join(
                    meteora_df.select(meteora_columns),
                    left_on="pairAddress",
                    right_on="address",
                    how="left",  # keep every enriched_data row
                )

            if source_dir in new_files:
                new_files[source_dir].append(frame)

    return new_files


def aggregate_new_files_with_logging(new_files):
    """
    Aggregate new_files into one Polars DataFrame per directory.

    Only the columns present in every DataFrame of a directory are kept. They
    appear in the order of the directory's first DataFrame. Columns dropped
    from individual frames are printed for debugging.

    :param new_files: Mapping ``directory -> list of DataFrames``.
    :return: Dict ``"aggregated_<directory>" -> DataFrame``. Directories with
        no DataFrames are omitted.
    """
    data_book = {}

    for directory, dfs in new_files.items():
        if not dfs:
            continue

        column_sets = [set(df.columns) for df in dfs]
        common_columns = set.intersection(*column_sets)

        # Log columns that will be dropped from each frame.
        for i, columns in enumerate(column_sets):
            extra = columns - common_columns
            if extra:
                print(f"  DataFrame {i} has additional columns: {extra}")

        # Take the order from the first frame. Ordering by iterating a set of
        # strings changes between interpreter runs (hash randomization), which
        # made the aggregated column order non-deterministic.
        ordered_common = [col for col in dfs[0].columns if col in common_columns]

        data_book[f"aggregated_{directory}"] = pl.concat(
            [df.select(ordered_common) for df in dfs]
        )

    return data_book



## Data diagnostics


In [10]:

def compare_dataframe_fields(data_book):
    """
    Print the columns of the source dataframes that are missing from the enriched dataset.

    Columns of ``aggregated_dex_screener_pairs`` and ``aggregated_meteora_pairs``
    that are absent from ``aggregated_enriched_data`` are listed with their
    dtype and source, in a grid table. Rows are grouped by source (dex screener
    first) and sorted by field name, so the output is stable between runs.

    If ``aggregated_enriched_data`` is not in ``data_book``, a message is
    printed and nothing is compared. A missing source dataframe is skipped with
    a message.

    usage:
        compare_dataframe_fields(data_book)

    output:
        | Missing Field | Type | Source Dataframe |
        | ------------- | ---- | --------------- |
        | info.websites[3].url | str | aggregated_dex_screener_pairs |
    """
    enriched = data_book.get("aggregated_enriched_data")
    if enriched is None:
        print("aggregated_enriched_data not found in data_book; nothing to compare.")
        return
    enriched_columns = set(enriched.columns)

    result = []
    for source_name in ("aggregated_dex_screener_pairs", "aggregated_meteora_pairs"):
        source_df = data_book.get(source_name)
        if source_df is None:
            print(f"{source_name} not found in data_book; skipping.")
            continue
        for field in sorted(set(source_df.columns) - enriched_columns):
            result.append([field, source_df[field].dtype, source_name])

    headers = ["Missing Field", "Type", "Source Dataframe"]
    print(tabulate(result, headers=headers, tablefmt="grid"))

def investigate_schema_mismatches_with_examples(new_files, directories=("meteora_pairs",)):
    """
    Report columns whose dtype differs across the DataFrames of a directory.

    For each inspected directory, every column's dtypes are collected across
    its DataFrames, together with one non-null example value per dtype. Any
    column seen with more than one dtype is then printed.

    :param new_files: Mapping ``directory -> list of DataFrames``.
    :param directories: Collection of directory names to inspect. Defaults to
        ``("meteora_pairs",)``, the only directory this function used to check.
        Pass ``None`` to inspect every directory.

    example:
        usage:
            investigate_schema_mismatches_with_examples(new_files)

        output:
            Mismatched column types with examples:
            Column 'info.websites[3].url' has multiple types:
                Type: str, Example: https://www.google.com
    """
    for directory, dfs in new_files.items():
        if directories is not None and directory not in directories:
            continue

        if not dfs:
            print(f"\nDirectory: {directory} has no DataFrames.")
            continue

        print(f"\nDirectory: {directory}")
        schema_summary = {}  # column -> {dtype: example value}

        for df in dfs:
            for col, dtype in df.schema.items():
                examples_by_type = schema_summary.setdefault(col, {})
                if dtype in examples_by_type:
                    continue
                try:
                    example_value = df.select(col).filter(pl.col(col).is_not_null()).head(1).to_numpy()[0, 0]
                except Exception:
                    # e.g. an all-null column yields an empty array -> IndexError
                    example_value = None
                examples_by_type[dtype] = example_value

        print("\n  Mismatched column types with examples:")
        for col, types_with_examples in schema_summary.items():
            if len(types_with_examples) > 1:
                print(f"    Column '{col}' has multiple types:")
                for dtype, example in types_with_examples.items():
                    print(f"      Type: {dtype}, Example: {example}")



## Visualization

In [11]:

def print_side_by_side_tables_from_grouped_data(grouped_data, key, file_index_1, file_index_2, row_index=0):
    """
    Compare the same row of two files from one group, side by side.

    :param grouped_data: Mapping where ``grouped_data[key]["files"][i]["data"]``
        is a DataFrame.
    :param key: Group key (presumably the file-name suffix).
    :param file_index_1: Index of the first file within the group.
    :param file_index_2: Index of the second file within the group.
    :param row_index: Row to take from each file's data.
    :return: The result of ``print_side_by_side_tables``, which prints the table.
    """
    files = grouped_data[key]["files"]
    left_row = files[file_index_1]["data"][row_index]
    right_row = files[file_index_2]["data"][row_index]
    return print_side_by_side_tables(left_row, right_row)
    
def print_side_by_side_tables(raw_data_1, raw_data_2):
    """
    Print two records as adjacent (Column, Value) tables for easy comparison.

    Each input is turned into a dict with ``to_dict()``. Any value that has a
    ``to_list()`` method (e.g. a Series) is converted to a plain list. The
    shorter table is padded with empty rows, and the two tables are printed
    in one ``tabulate`` grid, separated by a ``|`` column.

    Parameters:
    - raw_data_1: Object whose ``to_dict()`` returns ``{column: value}``.
    - raw_data_2: Same kind of object as ``raw_data_1``.

    Example:
    --------
    raw_data_1:
    | Column | Value |
    |--------|-------|
    | A      | [1, 2]|
    | B      | [3, 4]|

    raw_data_2:
    | Column | Value |
    |--------|-------|
    | A      | [5, 6]|
    | C      | [7, 8]|

    Output:
    +---------+---------+---+---------+---------+
    | Column  | Value   | | | Column  | Value   |
    +=========+=========+===+=========+=========+
    | A       | [1, 2]  | | | A       | [5, 6]  |
    +---------+---------+---+---------+---------+
    | B       | [3, 4]  | | | C       | [7, 8]  |
    +---------+---------+---+---------+---------+
    """
    def as_rows(raw):
        # Series-like values become plain lists so tabulate prints them readably.
        plain = {
            name: (val.to_list() if hasattr(val, "to_list") else val)
            for name, val in raw.to_dict().items()
        }
        return [[name, val] for name, val in plain.items()]

    left = as_rows(raw_data_1)
    right = as_rows(raw_data_2)

    # Pad the shorter side with blank rows so the two tables line up.
    height = max(len(left), len(right))
    left += [["", ""] for _ in range(height - len(left))]
    right += [["", ""] for _ in range(height - len(right))]

    combined_table = [lhs + ["|"] + rhs for lhs, rhs in zip(left, right)]
    header_row = ["Column", "Value", "|", "Column", "Value"]
    print(tabulate(combined_table, headers=header_row, tablefmt="grid"))

def pretty_print_schema(df):
    """
    Print the column names and dtypes of ``df`` as a table.

    NOTE(review): a second ``pretty_print_schema`` is defined further down in
    this file and replaces this one once both have run; consider keeping only one.
    """
    schema = df.schema
    
    # One (column, dtype) row per column, in schema order
    schema_list = [(col, dtype) for col, dtype in schema.items()]
    
    # Pretty-print using tabulate
    print(tabulate(schema_list, headers=["Column", "Data Type"], tablefmt="pretty"))

def visual_analyze(df):
    """
    Show a bar chart counting ``time_diff_minutes`` values in fixed ranges.

    Rows with a null ``time_diff_minutes`` are dropped first. Each range is
    half-open [low, high), with a final open-ended "3000+" range. Values below
    0 fall into no range and are not counted.

    NOTE(review): a second ``visual_analyze`` is defined further down in this
    file and replaces this one once both have run; consider keeping only one.
    """
    # Drop rows where time_diff_minutes is null
    df_filtered = df.filter(
        pl.col("time_diff_minutes").is_not_null()
    )
    # Convert Polars Series to a list for plotting
    time_diff = df_filtered["time_diff_minutes"].to_list()
    
    # # Create a boxplot
    # plt.boxplot(time_diff, vert=False, patch_artist=True, boxprops=dict(facecolor="lightblue"))

    # # Create a histogram with density
    # sns.histplot(time_diff, kde=True, color="skyblue", bins=10, edgecolor="black")
    # plt.title("Histogram with Density of Time Differences")
    # plt.ylabel("Density")

    # plt.title("Boxplot of Time Differences")

    # plt.scatter(range(len(time_diff)), time_diff, color="red", alpha=0.7)
    # plt.title("Scatter Plot of Time Differences")
    # plt.xlabel("Index")
    # plt.ylabel("Time Difference")
    # plt.show()
    # Example data (replace with your actual data)
    # time_diff = [10, 20, 30, 40, 50, 1000, 2000, 10, 20, 30, 40, 50, 1000, 2000, 10, 20, 30, 40, 50]
    
    # # Create a histogram
    # plt.figure(figsize=(8, 6))
    # sns.histplot(time_diff, bins=10, color="skyblue", edgecolor="black", kde=False)
    # plt.title("Histogram of Time Differences")
    # plt.xlabel("Time Difference")
    # plt.ylabel("Frequency (Number of Values)")
    # plt.show()

    # Bucket labels and the number of values in each half-open range
    ranges = ["0-40", "40-70", "70-1000","1000-2000", "2000-3000", "3000+"]
    counts = [
        sum(0 <= x < 40 for x in time_diff),
        sum(40 <= x < 70 for x in time_diff),
        sum(70 <= x < 1000 for x in time_diff),
        sum(1000 <= x < 2000 for x in time_diff),
        sum(2000 <= x < 3000 for x in time_diff),
        sum(x >= 3000 for x in time_diff),
    ]
    
    # Create a bar plot
    plt.figure(figsize=(8, 6))
    plt.bar(ranges, counts, color="orange", edgecolor="black")
    plt.title("Bar Plot of Time Differences")
    plt.xlabel("Time Difference Range")
    plt.ylabel("Frequency (Number of Values)")
    plt.show()

# def print_side_by_side_tables(raw_data_1, raw_data_2):
#     data_1 = raw_data_1.to_dict()
#     data_2 = raw_data_2.to_dict()
    
#     # Convert all values to plain Python types
#     def convert_to_plain_types(data):
#         return {k: (v.to_list() if hasattr(v, "to_list") else v) for k, v in data.items()}

#     data_1 = convert_to_plain_types(data_1)
#     data_2 = convert_to_plain_types(data_2)

#     # Get the keys (columns) and values for each table
#     table_1 = [[key, value] for key, value in data_1.items()]
#     table_2 = [[key, value] for key, value in data_2.items()]

#     # Ensure both tables have the same number of rows by padding with empty strings
#     max_rows = max(len(table_1), len(table_2))
#     while len(table_1) < max_rows:
#         table_1.append(["", ""])
#     while len(table_2) < max_rows:
#         table_2.append(["", ""])

#     # Combine the tables side by side
#     combined_table = []
#     for row_1, row_2 in zip(table_1, table_2):
#         combined_table.append(row_1 + ["|"] + row_2)

#     # Define headers for the tables
#     headers = ["Column", "Value", "|", "Column", "Value"]

#     # Print the combined table using tabulate
#     print(tabulate(combined_table, headers=headers, tablefmt="grid"))

def pretty_print_schema(df):
    """Print each column of ``df`` with its dtype as a two-column table."""
    rows = list(df.schema.items())  # [(column name, dtype), ...]
    print(tabulate(rows, headers=["Column", "Data Type"], tablefmt="pretty"))

def visual_analyze(df, column="time_diff_minutes"):
    """
    Show a bar chart of how many ``column`` values fall into fixed ranges.

    Rows where ``column`` is null are dropped first. Each range is half-open
    ``[low, high)``; the final ``"3000+"`` range has no upper bound. Values
    below 0 (and NaN) fall into no range and are not counted.

    :param df: Polars DataFrame containing ``column``.
    :param column: Numeric column to bucket. Defaults to ``"time_diff_minutes"``
        (the minutes between snapshots computed in compute_profit_columns),
        which was previously hard-coded.
    """
    values = df.filter(pl.col(column).is_not_null())[column].to_list()

    # (label, inclusive lower bound, exclusive upper bound or None if open-ended)
    buckets = [
        ("0-40", 0, 40),
        ("40-70", 40, 70),
        ("70-1000", 70, 1000),
        ("1000-2000", 1000, 2000),
        ("2000-3000", 2000, 3000),
        ("3000+", 3000, None),
    ]
    ranges = [label for label, _, _ in buckets]
    counts = [
        sum(1 for x in values if low <= x and (high is None or x < high))
        for _, low, high in buckets
    ]

    plt.figure(figsize=(8, 6))
    plt.bar(ranges, counts, color="orange", edgecolor="black")
    plt.title("Bar Plot of Time Differences")
    plt.xlabel("Time Difference Range")
    plt.ylabel("Frequency (Number of Values)")
    plt.show()


## Misc

In [12]:
def statistical_analyze(df, column="time_diff_minutes"):
    """
    Add a ``z_score`` column with the standard score of ``column``.

    By default ``column`` is ``time_diff_minutes``, the time difference between
    ``group_timestamp`` and ``group_timestamp_future``. The z-score is
    ``(value - mean) / std``, using the column's mean and standard deviation.

    :param df: Polars DataFrame containing ``column``.
    :param column: Numeric column to score. Defaults to ``"time_diff_minutes"``,
        which was previously hard-coded.
    :return: ``df`` with an added Float64 ``z_score`` column. If the standard
        deviation is ``None`` (e.g. no values) or 0, the z-score is undefined,
        and ``z_score`` is null instead of NaN/inf.
    """
    mean = df[column].mean()
    std = df[column].std()

    if std is None or std == 0:
        z_score = pl.lit(None, dtype=pl.Float64)
    else:
        z_score = (pl.col(column) - mean) / std

    # Outliers (|z| > 3) can be selected by the caller with
    # df.filter(pl.col("z_score").abs() > 3); the previous version built that
    # frame here and then discarded it.
    return df.with_columns(z_score.alias("z_score"))
    
def compute_profit_columns(df, min_minutes=40, max_minutes=70):
    """
    Pair each snapshot with the next snapshot of the same pool and derive fee profit.

    Each row is matched, via a forward as-of join on ``group_timestamp`` grouped by
    ``pairAddress``, with the first row of the same pool whose timestamp is strictly
    greater. Rows whose gap to that next snapshot falls outside
    ``[min_minutes, max_minutes)`` are dropped, as are rows whose average liquidity
    is not positive.

    NOTE: only the immediate next snapshot is considered. A row whose next snapshot
    arrives sooner than ``min_minutes`` is dropped even if a later snapshot would
    fall inside the window.

    Args:
        df: Polars DataFrame with ``pairAddress``, ``baseToken.address``,
            ``group_timestamp``, ``pairCreatedAt``, ``liquidity.meteora`` and
            ``cumulative_fee_volume``. Timestamps are treated as milliseconds
            (differences are divided by 1000 to obtain seconds).
        min_minutes: Inclusive lower bound on the gap to the next snapshot.
        max_minutes: Exclusive upper bound on the gap to the next snapshot.

    Returns:
        The joined frame. It adds ``pairAge``, ``isPumpToken``,
        ``group_timestamp_future``, ``liquidity_future``, ``cfv_future``,
        ``time_diff_seconds``, ``time_diff_minutes``, ``averageLiquidity`` and
        ``percentageProfit``. The last one is the fee-volume growth between the two
        snapshots divided by the average liquidity, times 100.
    """
    # 1) Pool age at snapshot time, and whether the base token address ends in "pump".
    df = df.with_columns([
        (pl.col("group_timestamp") - pl.col("pairCreatedAt")).alias("pairAge"),
        pl.col("baseToken.address").str.ends_with("pump").alias("isPumpToken"),
    ])

    # 2) Right-hand side of the self-join: a "future" view of every snapshot.
    df_future = df.select([
        "pairAddress",
        pl.col("group_timestamp").alias("group_timestamp_future"),
        pl.col("liquidity.meteora").alias("liquidity_future"),
        pl.col("cumulative_fee_volume").alias("cfv_future"),
    ])

    # 3) join_asof needs both sides sorted by the join key within each `by` group.
    df = df.sort(["pairAddress", "group_timestamp"])
    df_future = df_future.sort(["pairAddress", "group_timestamp_future"])

    # Attach the next (strictly later) snapshot of the same pool.
    joined = df.join_asof(
        df_future,
        left_on="group_timestamp",
        right_on="group_timestamp_future",
        by="pairAddress",
        strategy="forward",
        allow_exact_matches=False,
    )

    # Gap to the next snapshot; the timestamps are milliseconds.
    gap_ms = pl.col("group_timestamp_future") - pl.col("group_timestamp")
    joined = joined.with_columns(
        (gap_ms / 1000).alias("time_diff_seconds"),
        (gap_ms / 1000 / 60).alias("time_diff_minutes"),
    )

    # 4) Keep rows whose next snapshot lies in [min_minutes, max_minutes).
    # Rows without a later snapshot have a null gap and are dropped here as well.
    joined = joined.filter(
        (pl.col("time_diff_minutes") >= min_minutes)
        & (pl.col("time_diff_minutes") < max_minutes)
    )

    # Cast the fee volumes to Float64 before the arithmetic below.
    joined = joined.with_columns([
        pl.col("cumulative_fee_volume").cast(pl.Float64).alias("cumulative_fee_volume"),
        pl.col("cfv_future").cast(pl.Float64).alias("cfv_future"),
    ])

    # Mean liquidity over the interval is the denominator of the profit ratio.
    joined = joined.with_columns(
        ((pl.col("liquidity.meteora") + pl.col("liquidity_future")) / 2).alias("averageLiquidity")
    )

    # Drop non-positive liquidity to avoid dividing by zero / negative values.
    joined = joined.filter(pl.col("averageLiquidity") > 0)
    joined = joined.with_columns(
        ((pl.col("cfv_future") - pl.col("cumulative_fee_volume")) / pl.col("averageLiquidity") * 100).alias("percentageProfit")
    )

    return joined


## Model training

In [13]:
"""
LSTM Model for Liquidity Pool Profitability Prediction
Steps: Data Loading → Preprocessing → Feature Engineering → Sequence Creation → Model Training → Evaluation
"""

class LiquidityPoolPredictor:
    """
    Static helpers for LSTM-based liquidity pool profitability prediction.

    Intended order: ``preprocess_data`` -> ``engineer_features`` ->
    ``normalize_data`` -> ``create_sequences`` -> ``build_lstm_model`` -> fit ->
    ``evaluate_model``.

    NOTE(review): ``preprocess_data`` drops ``pairAddress`` and sorts by
    ``group_timestamp`` alone. When the input holds several pools, their rows are
    interleaved, so the lag/rolling features and the sequence windows span
    different pools. Per-pool grouping (e.g. ``.over("pairAddress")``) would be
    needed to keep them within a single pool.
    """

    @staticmethod
    def preprocess_data(df: pl.DataFrame) -> pl.DataFrame:
        """
        Select the modelling columns, median-impute nulls and sort by time.

        Args:
            df: Output of ``compute_profit_columns``; must contain every column in
                ``keep_cols``.

        Returns:
            A frame restricted to ``keep_cols``. Nulls in Float64/Int64 columns are
            replaced by that column's median, and the frame is sorted by
            ``group_timestamp``.
        """
        # Feature selection (based on domain knowledge)
        keep_cols = [
            'volume.m5', 'volume.h1', 'volume.h6', 'volume.h24',
            'liquidity.meteora', 'liquidity.usd', 'apr', 'farm_apy',
            'priceChange.h1', 'priceChange.m5', 'priceChange.h24',
            'txns.m5.sells', 'txns.h1.buys', 'percentageProfit',
            'group_timestamp', 'pairCreatedAt'
        ]
        # keep_cols = [
        #     'volume.m5', 'volume.h1', 'volume.h6', 'volume.h24',
        #     'liquidity.meteora', 'liquidity.usd', 'apr', 'apy', 'farm_apy',
        #     'priceChange.h1', 'priceChange.m5', 'priceChange.h24',
        #     'txns.m5.sells', 'txns.h1.buys', 'time_diff_seconds',
        #     'time_diff_minutes', 'group_timestamp', 'percentageProfit',
        #     'pairCreatedAt'
        # ]
        df = df.select(keep_cols)

        # Median-impute numeric columns.
        # NOTE(review): the median is taken over the whole frame, including rows
        # that later end up in the validation/test splits.
        numerical_cols = [col for col in df.columns if df[col].dtype in [pl.Float64, pl.Int64]]
        df = df.with_columns(
            [pl.col(col).fill_null(pl.col(col).median()) for col in numerical_cols]
        )

        # Chronological order is required before building lags and sequences.
        return df.sort('group_timestamp')

    @staticmethod
    def engineer_features(df: pl.DataFrame) -> pl.DataFrame:
        """
        Add lag, rolling-window and calendar features.

        Adds the following columns:
        - ``<col>_lag<k>`` shifts.
        - ``<col>_rolling_mean_<w>`` / ``<col>_rolling_std_<w>`` windows.
        - ``pool_age`` (``group_timestamp - pairCreatedAt``).
        - ``hour_of_day`` and ``day_of_week``.

        Leading rows whose lag or rolling window is still incomplete (null) are
        dropped, so no NaN reaches the scaler or the LSTM.

        Args:
            df: Output of ``preprocess_data`` (sorted by ``group_timestamp``).

        Returns:
            The frame with the engineered columns and without incomplete-window rows.
        """
        # Lag features (autoregressive signal)
        lag_config = {
            'volume.m5': [1, 3, 6],
            'priceChange.h1': [12, 24],
            'txns.m5.sells': [1, 2]
        }
        # Rolling statistics (regime changes)
        window_sizes = {'liquidity.usd': 12, 'apr': 24}

        engineered_cols = []
        for col, lags in lag_config.items():
            for lag in lags:
                name = f"{col}_lag{lag}"
                df = df.with_columns(pl.col(col).shift(lag).alias(name))
                engineered_cols.append(name)

        for col, window in window_sizes.items():
            mean_name = f"{col}_rolling_mean_{window}"
            std_name = f"{col}_rolling_std_{window}"
            df = df.with_columns(
                pl.col(col).rolling_mean(window).alias(mean_name),
                pl.col(col).rolling_std(window).alias(std_name)
            )
            engineered_cols.extend([mean_name, std_name])

        # group_timestamp is an epoch in milliseconds (compute_profit_columns divides
        # its differences by 1000 to get seconds). A bare cast to pl.Datetime uses the
        # default microsecond unit and would place every snapshot in January 1970.
        timestamp = pl.col('group_timestamp').cast(pl.Datetime(time_unit='ms'))
        df = df.with_columns(
            (pl.col('group_timestamp') - pl.col('pairCreatedAt')).alias('pool_age'),
            timestamp.dt.hour().alias('hour_of_day'),
            timestamp.dt.weekday().alias('day_of_week')
        )

        # shift/rolling leave nulls in the first rows. These would become NaN in
        # normalize_data (RobustScaler passes NaN through) and make the loss NaN.
        return df.drop_nulls(subset=engineered_cols)

    @staticmethod
    def normalize_data(df: pl.DataFrame) -> tuple:
        """
        Split off the target and robust-scale the remaining features.

        Args:
            df: Engineered frame containing ``percentageProfit`` plus numeric features.

        Returns:
            A ``(scaled_features, target)`` pair:
            - ``scaled_features``: a 2-D array of all columns except
              ``percentageProfit`` (in frame order), scaled with ``RobustScaler``.
            - ``target``: the unscaled ``percentageProfit`` values.

        NOTE(review): the scaler is fitted on the full frame before the train/test
        split, so test-set statistics influence the scaling.
        """
        # Separate target variable
        target = df['percentageProfit'].to_numpy()
        features = df.drop('percentageProfit').to_numpy()

        # Robust scaling (median/IQR) is less sensitive to extreme values
        scaler = RobustScaler()
        scaled_features = scaler.fit_transform(features)

        # # MinMax scaling for percentage-based features
        # pct_cols = [df.columns.index(col) for col in []]
        # # pct_cols = [df.columns.index(col) for col in ['apr', 'apy', 'farm_apy']]
        # minmax_scaler = MinMaxScaler(feature_range=(0, 1))
        # scaled_features[:, pct_cols] = minmax_scaler.fit_transform(scaled_features[:, pct_cols])

        return scaled_features, target

    @staticmethod
    def create_sequences(features: np.ndarray, target: np.ndarray, seq_length: int = 12) -> tuple:
        """
        Slide a fixed-length window over the rows to build LSTM inputs.

        Args:
            features: 2-D array of rows x features.
            target: 1-D array aligned with ``features``.
            seq_length: Number of consecutive rows per input window.

        Returns:
            ``(X, y)`` where ``X[i] = features[i:i + seq_length]`` and
            ``y[i] = target[i + seq_length]`` (the row right after the window).
            Both are empty when there are not more than ``seq_length`` rows.
        """
        X, y = [], []
        for i in range(len(features) - seq_length):
            X.append(features[i:i + seq_length])
            y.append(target[i + seq_length])

        return np.array(X), np.array(y)

    @staticmethod
    def build_lstm_model(input_shape: tuple) -> Sequential:
        """
        Build and compile a two-layer LSTM regressor.

        Args:
            input_shape: ``(seq_length, n_features)`` of one input window.

        Returns:
            A compiled model (AdamW, Huber loss with delta 0.5, metrics mae/mse).
            ``model.evaluate`` therefore returns ``[loss, mae, mse]``.
        """
        model = Sequential([
            LSTM(64,
                 input_shape=input_shape,
                 return_sequences=True,
                 dropout=0.2,
                 recurrent_dropout=0.1,
                 kernel_regularizer=l2(0.01)),
            LSTM(32,
                 dropout=0.2,
                 recurrent_dropout=0.1,
                 kernel_regularizer=l2(0.01)),
            Dense(16, activation='relu', kernel_regularizer=l2(0.01)),
            Dense(1)
        ])

        optimizer = AdamW(learning_rate=0.001, weight_decay=0.005)
        model.compile(optimizer=optimizer,
                      loss=Huber(delta=0.5),
                      metrics=['mae', 'mse'])

        return model

    @staticmethod
    def evaluate_model(model: Sequential, X_test: np.ndarray, y_test: np.ndarray):
        """
        Evaluate the model, plot predictions, and compare a sign-following strategy.

        Args:
            model: Model compiled by ``build_lstm_model``.
            X_test: Test windows.
            y_test: Test targets, i.e. ``percentageProfit`` values in percent.

        Returns:
            The test mean absolute error (in the same percent units as ``y_test``).
        """
        # model.evaluate returns [loss, mae, mse] for the compiled metrics
        _, test_mae, _ = model.evaluate(X_test, y_test, verbose=0)

        # Prediction visualization
        predictions = model.predict(X_test)

        plt.figure(figsize=(12, 6))
        plt.plot(y_test, label='Actual Returns')
        plt.plot(predictions, label='Predicted Returns', alpha=0.7)
        plt.title('Liquidity Pool Return Predictions vs Actual')
        plt.xlabel('Time Steps')
        plt.ylabel('Percentage Profit')
        plt.legend()
        plt.show()

        # percentageProfit is in percent (compute_profit_columns multiplies by 100).
        # Convert to fractional returns before compounding with cumprod(1 + r).
        period_returns = y_test / 100.0
        strategy_returns = np.sign(predictions.flatten()) * period_returns
        cumulative_returns = np.cumprod(1 + strategy_returns)

        plt.figure(figsize=(12, 6))
        plt.plot(cumulative_returns, label='Strategy Returns')
        plt.plot(np.cumprod(1 + period_returns), label='Buy & Hold')
        plt.title('Cumulative Returns Comparison')
        plt.xlabel('Time Steps')
        plt.ylabel('Cumulative Return')
        plt.legend()
        plt.show()

        return test_mae


## Cell timer

In [14]:
# Epoch seconds of the timer cell's previous run; stays None until it has run once.
previous_timestamp: "float | None" = None

In [15]:

# Capture "now" as epoch seconds and render it as local wall-clock time.
current_timestamp = time.time()
formatted_current_time = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
print(f"Current Timestamp: {formatted_current_time}")
print("--------------------------------")

# From the second run on, also show the previous run's time and the elapsed interval.
if previous_timestamp is not None:
    time_difference = current_timestamp - previous_timestamp
    formatted_difference = str(timedelta(seconds=time_difference))
    print(f"Previous Timestamp: {datetime.fromtimestamp(previous_timestamp).strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Difference: {formatted_difference}")

# Remember this run so the next execution can report the gap.
previous_timestamp = current_timestamp

# Execution

## Load data from FS

In [16]:
# Root folder whose sub-directories hold the mined JSON files.
base_directory = "/home/dev/mined-data"

# Wall-clock bracket around the parallel load-and-group-by-suffix call (max_files=1).
start_time = time.time()
grouped_data = load_and_group_json_data_by_suffix_parallel(
    base_directory,
    max_files=1,
)
end_time = time.time()


## Aggregate data across files

In [20]:
# Determine which grouped files are new, then aggregate them into one frame per key.
new_files = process_new_files(grouped_data)
data_book = aggregate_new_files_with_logging(new_files)

# Report the shape of every aggregated frame.
for key, df in data_book.items():
    print(f"{key}: {df.shape}")


## Filter enriched data

In [16]:
# Start from the aggregated enriched snapshots; fail with a clear message if the
# aggregation cell has not produced them.
aggregated_enriched_data = data_book.get("aggregated_enriched_data")
if aggregated_enriched_data is None:
    raise KeyError("data_book has no 'aggregated_enriched_data' entry; run the aggregation cell first")

# Same mint as SOL_ADDRESS in the Variables section; reuse it instead of a second literal copy.
SOL_TOKEN = SOL_ADDRESS

# Keep SOL-quoted pairs with positive volume.h6. Combining both predicates in one
# filter is equivalent to filtering twice (null predicates drop the row either way).
aggregated_enriched_data = aggregated_enriched_data.filter(
    (pl.col("quoteToken.address") == SOL_TOKEN) & (pl.col("volume.h6") > 0)
)
print(len(aggregated_enriched_data))

# joined.sort(["percentageProfit"], descending=True)


# Action space

In [25]:
# Join each snapshot with its successor and derive the percentageProfit target.
aggregated_enriched_data = compute_profit_columns(aggregated_enriched_data)

# Fix the RNG seeds so repeated runs are reproducible.
np.random.seed(42)
tf.random.set_seed(42)

# Feature pipeline: select/impute -> engineer -> scale -> window into sequences.
df = aggregated_enriched_data
processed_df = LiquidityPoolPredictor.preprocess_data(df)
engineered_df = LiquidityPoolPredictor.engineer_features(processed_df)
scaled_features, target = LiquidityPoolPredictor.normalize_data(engineered_df)
X, y = LiquidityPoolPredictor.create_sequences(scaled_features, target)

# Chronological split: first 70% train, next 20% validation, remainder test.
train_size = int(0.7 * len(X))
val_size = int(0.2 * len(X))

X_train, X_val, X_test = X[:train_size], X[train_size:train_size + val_size], X[train_size + val_size:]
y_train, y_val, y_test = y[:train_size], y[train_size:train_size + val_size], y[train_size + val_size:]

# # Build and train model
# model = LiquidityPoolPredictor.build_lstm_model((X_train.shape[1], X_train.shape[2]))

# early_stopping = tf.keras.callbacks.EarlyStopping(
#     monitor='val_loss',
#     patience=10,
#     restore_best_weights=True
# )

# history = model.fit(
#     X_train, y_train,
#     epochs=100,
#     batch_size=64,
#     validation_data=(X_val, y_val),
#     callbacks=[early_stopping],
#     verbose=1
# )

# # Evaluate model
# test_mae = LiquidityPoolPredictor.evaluate_model(model, X_test, y_test)
# print(f"Test MAE: {test_mae:.4f}")

# # Save model
# model.save("lstm_liquidity_pool_predictor.keras")

## Scratchpad


In [36]:
import os
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def process_file(enriched_file, enriched_data_dir, meteora_pairs_dir, output_dir, meteora_files, processed_files):
    """
    Attach matching Meteora pair records to one enriched DexScreener file.

    The file's key is the text after its last ``_`` and before the first ``.``.
    The first name in ``meteora_files`` containing that key is used as the Meteora
    source. Each enriched pair whose ``pairAddress`` equals a Meteora record's
    ``address`` gets that record under the ``"meteora"`` key. The result is written
    to ``output_dir`` under the same file name.

    Args:
        enriched_file: File name (not a path) inside ``enriched_data_dir``.
        enriched_data_dir: Directory with enriched DexScreener JSON arrays.
        meteora_pairs_dir: Directory with Meteora pair JSON arrays.
        output_dir: Destination directory for merged files.
        meteora_files: Candidate file names inside ``meteora_pairs_dir``.
        processed_files: Names already present in ``output_dir``; these are skipped.

    Returns:
        None. Nothing is read or written when the file is already processed or no
        Meteora file matches its key.
    """
    if enriched_file in processed_files:
        return

    key = enriched_file.split('_')[-1].split('.')[0]
    # NOTE(review): substring match - a short key could match an unrelated name
    # (e.g. "12" in "..._123.json"); confirm the naming scheme keeps keys unique.
    meteora_file = next((f for f in meteora_files if key in f), None)
    if meteora_file is None:
        # Nothing to merge, so don't bother parsing the enriched file.
        return

    with open(os.path.join(enriched_data_dir, enriched_file), 'r', encoding='utf-8') as ef:
        enriched_data = json.load(ef)
    with open(os.path.join(meteora_pairs_dir, meteora_file), 'r', encoding='utf-8') as mf:
        meteora_data = json.load(mf)

    meteora_data_dict = {m['address']: m for m in meteora_data}
    for dex_screener_pair in enriched_data:
        meteora_pair = meteora_data_dict.get(dex_screener_pair['pairAddress'])
        if meteora_pair:
            dex_screener_pair['meteora'] = meteora_pair

    # Write to a temp file and atomically rename it into place. Any name already
    # in output_dir counts as processed, so a half-written file left by a crash
    # or interrupt would otherwise be skipped on every later run.
    output_file_path = os.path.join(output_dir, enriched_file)
    tmp_path = output_file_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as out:
            json.dump(enriched_data, out, indent=2)
        os.replace(tmp_path, output_file_path)
    except BaseException:
        # Remove the partial temp file, then propagate the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def align_files(enriched_data_dir, meteora_pairs_dir, output_dir):
    """
    Merge Meteora pair data into every enriched file not yet present in ``output_dir``.

    Files are processed concurrently with ``process_file`` on a thread pool, with a
    tqdm progress bar.

    Args:
        enriched_data_dir: Directory with enriched DexScreener JSON files.
        meteora_pairs_dir: Directory with Meteora pair JSON files.
        output_dir: Destination directory; created if it does not exist.
    """
    # os.listdir(output_dir) below would raise on a first run without this.
    os.makedirs(output_dir, exist_ok=True)

    enriched_files = os.listdir(enriched_data_dir)
    meteora_files = os.listdir(meteora_pairs_dir)
    processed_files = set(os.listdir(output_dir))

    def _process_one(enriched_file):
        # Bind the shared directory/listing arguments for executor.map.
        return process_file(enriched_file, enriched_data_dir, meteora_pairs_dir, output_dir, meteora_files, processed_files)

    with ThreadPoolExecutor() as executor:
        # list() drains the lazy map so worker exceptions are raised here.
        list(tqdm(executor.map(_process_one, enriched_files), total=len(enriched_files), desc="Processing files"))

align_files(
    '/home/dev/mined-data/enriched_data',       # enriched DexScreener snapshots
    '/home/dev/mined-data/meteora_pairs',       # Meteora pair snapshots
    '/home/dev/mined-data/enriched_dex_pairs',  # merged output
)


In [74]:
!dokku ps:scale meteora-data-miner web=0

In [17]:
# Report library versions to help reproduce the environment.
print(f"tf version:  {tf.__version__}")
print(f"keras version:  {tf.keras.__version__}")
# print("scipy version: ", scipy.__version__)
print(f"numpy version:  {np.__version__}")
print(f"polars version:  {pl.__version__}")


# Notes

- The biggest source of schema misalignment in the dex_screener_pairs data is these nested fields:
        boosts
            .active
        
        info
            .websites[]
              .label
              .url
  They may be expanded into additional features in the future; for now they are removed.
- Pin SciPy with `uv add scipy==1.12` to work around the "numpy.char module not found" error.

