In [None]:
from lib.naming_corrections import TABLES_V5_2_V4_RENAME_LEGACY,\
                                   TABLES_COLUMNS_DEFAULT_LEGACY,\
                                   FEATURES_NAMES_FROM_NEW_CACHE,\
                                   FEATURES_NAMES_FROM_PRELOADED_CACHE, FEATURE_COLUMNS_OTHERS

# Module-level naming configuration, initialised to the legacy variants.
# - FEATURE_COLUMNS: the feature names from the preloaded cache followed by
#   the remaining feature columns.
# - TABLES_V5_2_V4_RENAME / TABLES_COLUMNS_DEFAULT: start out as the *_LEGACY
#   objects imported from lib.naming_corrections.
# NOTE: `_load_data` declares these three names `global` and rebinds them to
# the non-legacy variants when it is called with refresh_cache=True.
FEATURE_COLUMNS = FEATURES_NAMES_FROM_PRELOADED_CACHE+FEATURE_COLUMNS_OTHERS
TABLES_V5_2_V4_RENAME = TABLES_V5_2_V4_RENAME_LEGACY
TABLES_COLUMNS_DEFAULT = TABLES_COLUMNS_DEFAULT_LEGACY

In [None]:
def _load_data(
    data_dir: str,
    device,
    feature_column_names: Optional[List[str]] = None,
    debug=False,
    semi_supervised=True,
    semi_supervised_resample_negs=None,  # randomize negs
    semi_supervised_resample_factor=None,  # randomize negs factor
    splits: Optional[List[str]] = None,
    scaler: Optional[StandardScaler] = None,
    rng = np.random.default_rng(seed=1),
    features_dir: Optional[str] = None,  # extra path that we use for duckdb queries
    refresh_cache: bool = False,
):
    """Build per-split graph datasets from the partitioned parquet store.

    The function works in two passes over the requested splits:

    1. For every partition whose cached features/edges parquet is missing
       (or for all of them when ``refresh_cache`` is set), the features and
       edges are regenerated with ``generate_from_queries`` and written to
       the cache folders of ``data_dir``.
    2. Every partition is read back from the cache, scaled, joined with its
       labels and converted into a ``Data`` graph placed on ``device``.

    Args:
        data_dir: Root of the data store holding ``partitions.parquet``,
            ``labels/`` and the ``cache/`` folders. Must not be None.
        device: Torch device the tensors are created on.
        feature_column_names: Columns selected from the cached features
            frame. When omitted, the module-level ``FEATURE_COLUMNS`` is
            used, resolved *after* the ``refresh_cache`` switch so that a
            regenerated cache is read with the new-cache names.
        debug: When true, keep only the first partition row of each split.
        semi_supervised: Forwarded to ``augment_labels``.
        semi_supervised_resample_negs: Forwarded to ``augment_labels``.
        semi_supervised_resample_factor: Forwarded to ``augment_labels``.
        splits: Split labels to load. Defaults to train, val and test.
        scaler: Pre-fitted scaler. When None and the train split is
            requested, one is fitted with ``fit_scaler`` on the train
            partitions.
        rng: Random generator forwarded to ``augment_labels``.
            NOTE(review): the default generator is created once at
            definition time and is therefore shared (and advanced) across
            calls; pass an explicit generator for per-call reproducibility.
        features_dir: Data store location handed to
            ``generate_from_queries``. Must not be None even though the
            signature default is None.
        refresh_cache: When true, the cache folders are reset, the
            module-level ``FEATURE_COLUMNS``, ``TABLES_V5_2_V4_RENAME`` and
            ``TABLES_COLUMNS_DEFAULT`` are rebound to their non-legacy
            variants, and every requested partition is regenerated.

    Returns:
        A tuple ``(graph_data, labelled, scaler, feature_column_names)``:
        ``graph_data[split][partition]`` is the built ``Data`` object (None
        for splits that were not requested), ``labelled[split][partition]``
        holds the positions whose augmented label is not 2, ``scaler`` is
        the scaler that was applied (possibly None) and
        ``feature_column_names`` is the list of columns actually used.

    Raises:
        AssertionError: If a required path is None, ``partitions.parquet``
            or the labels are missing, or a cached partition file is absent
            when it is about to be read.
    """
    from collections import Counter

    global FEATURE_COLUMNS, TABLES_V5_2_V4_RENAME, TABLES_COLUMNS_DEFAULT

    labels_dir = "labels"
    partitions_dir = "partitions.parquet"
    cached_features_dir = "cache/features"
    cached_edges_dir = "cache/edges"

    assert data_dir is not None, ("data path does not exist")
    assert features_dir is not None, ("duckdb path does not exist")

    if splits is None:
        splits = [DATA_LABEL_TRAIN, DATA_LABEL_VAL, DATA_LABEL_TEST]

    datastore = load_local_data_store(data_dir)
    features_datastore = load_local_data_store(features_dir)

    assert datastore.exists(partitions_dir), "partitions.parquet missing"
    assert datastore.exists(labels_dir) or not is_empty(labels_dir, datastore), "labels missing"

    print ("Building dataset")

    # A regenerated cache uses the default (non-legacy) names; an existing
    # cache keeps the legacy names it was written with.
    if refresh_cache:
        reset_cache(cached_features_dir, cached_edges_dir, datastore)
        FEATURE_COLUMNS = FEATURES_NAMES_FROM_NEW_CACHE + FEATURE_COLUMNS_OTHERS
        # Because both names are declared global above, this import rebinds
        # the module-level names to the non-legacy objects.
        from lib.naming_corrections import TABLES_V5_2_V4_RENAME, TABLES_COLUMNS_DEFAULT
        print ("Using DuckDB to generate features to cache")
    else:
        print ("Using existing cache. Verifying...")

    # Resolve the default only now, so it follows the switch made above
    # instead of a value frozen when the function was defined.
    if feature_column_names is None:
        feature_column_names = FEATURE_COLUMNS

    # read the top level partition parquet; the second reset_index exposes
    # the row position as an explicit 'index' column used as partition id
    df_p = pd.read_parquet(
        datastore.open_file(partitions_dir)
    ).reset_index(drop=True).reset_index()

    if debug:
        df_p = df_p.groupby('split').first()

    def partition_paths(p):
        # Relative (labels, cached features, cached edges) paths of partition p.
        return (
            os.path.join(labels_dir, f"labels_{p}.parquet"),
            os.path.join(cached_features_dir, f"features_{p}.parquet"),
            os.path.join(cached_edges_dir, f"edges_{p}.parquet"),
        )

    # we make a copy of the y in its native int
    # - because we need the int version for the metrics (since we want to compute in GPU)
    # - unfortunately the loss function we are using requires the y values
    # to be in float
    def build_data(X, y, edges):
        return Data(
            x=torch.tensor(X.astype(np.float32), device=device),
            edge_index=torch.tensor(edges.T, device=device),
            y=torch.tensor(y, device=device).float(),
            y_i=torch.tensor(y, device=device),
        )

    counters = {}
    graph_data = {}
    labelled = {}
    for sp, A in df_p.groupby('split'):
        partition_ids = sorted(A['index'])
        graph_data[sp] = dict.fromkeys(partition_ids)
        labelled[sp] = dict.fromkeys(partition_ids)
        counters[sp] = Counter()

    # Pass 1: make sure every requested partition has cached features/edges.
    for sp in splits:
        for p in tqdm(graph_data[sp]):
            labels_path, features_path, edges_path = partition_paths(p)

            # Regenerate from duckdb if conditions exist
            if (
                refresh_cache
                or not datastore.exists(features_path)
                or not datastore.exists(edges_path)
            ):
                # it should hit only one entry
                _partition = df_p.query(f'index == {p}').iloc[0]

                # df_l :
                # - txid (in same order as df_f)
                # - label
                # - node (may not really be needed, not sure why another ordering )
                df_l = pd.read_parquet(datastore.open_file(labels_path))

                df_f, df_e = generate_from_queries(df_l, _partition, features_datastore)

                # Stores generated queries to cache
                datastore.to_features_pandas(df_f, features_path)
                datastore.to_features_pandas(df_e, edges_path)

    # If it's a train loop and there is no pre-fitted scaler
    if DATA_LABEL_TRAIN in splits and scaler is None:
        scaler = fit_scaler(
            graph_data[DATA_LABEL_TRAIN],
            cached_features_dir=cached_features_dir,
            datastore=datastore,
        )

    # Pass 2: convert cache into a geometric dataset
    print("Loading Cached Data for Training...")
    for sp in splits:
        for p in tqdm(graph_data[sp]):
            labels_path, features_path, edges_path = partition_paths(p)

            # df_l :
            # - txid (in same order as df_f)
            # - label
            # - node (may not really be needed, not sure why another ordering )
            df_l = pd.read_parquet(datastore.open_file(labels_path))

            assert datastore.exists(features_path), f"Missing Features Dataframe, '{features_path}'"
            assert datastore.exists(edges_path), f"Missing Edges Dataframe, '{edges_path}'"
            df_f, df_e = read_from_cache(features_path, edges_path, datastore)

            X = df_f[feature_column_names].fillna(value=0.).values

            if scaler:
                X = scaler.transform(X)

            # need to ensure ordering is same: the merge keeps the txid order
            # of df_f.
            # NOTE(review): this is an inner merge, so a txid missing from
            # df_l would make y shorter than X - confirm labels are complete.
            y = df_f[['txid']].merge(
                df_l[['txid', 'label']],
            )['label'].values

            # uses a negative sampling strategy
            y = augment_labels(
                y,
                rng,
                semi_supervised=semi_supervised,
                semi_supervised_resample_negs=semi_supervised_resample_negs,
                semi_supervised_resample_factor=semi_supervised_resample_factor,
            )

            # if semi sup is disabled, then this will essentially
            # be all of them
            labelled[sp][p], = np.where(y != 2)

            counters[sp].update(y)
            graph_data[sp][p] = build_data(
                X=X,
                y=y,
                edges=df_e[['from', 'to']].values,
            )

    for sp in splits:
        print (sp, counters[sp])

    return graph_data, labelled, scaler, feature_column_names