In [4]:
import sys
import os
import warnings
import pandas as pd
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))

In [12]:
import pytest
import pyarrow.parquet as pq
import pyarrow.fs as fs
import pyarrow.compute as pc
import pyarrow.dataset as ds
import s3fs as fs3
import pyarrow as pa

In [13]:
import daphmeIO
import constants

In [7]:
import pytest
import warnings
import pandas as pd

# Test formats and folder structures

## Generate test dataset

## Import partitioned dataset (data 3)

In [None]:
# Read only the header row (nrows=0) to see which columns sample2 provides.
column_names = pd.read_csv(
    '../../data/sample2/sample2.csv',
    nrows=0,
).columns
column_names

In [None]:
# Define `path` here: otherwise it is only assigned much later in the notebook,
# so this cell fails with NameError under Restart Kernel -> Run All.
# NOTE(review): confirm sample1 is the partitioned dataset the heading above
# ("data 3") refers to.
path = '../../data/sample1/'
dataset = ds.dataset(path, format="parquet", partitioning="hive")
df = dataset.to_table().to_pandas()
df.head()

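## Functions that load data and check trajectory integrity

These helpers map user-supplied column names onto a default schema, check that the required spatial and temporal columns exist, and cast those columns to canonical dtypes. The design is loosely based on scikit-mobility.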

In [36]:
def _update_schema(original, new_labels):
    """Return a new schema dict: ``original`` overlaid with recognised ``new_labels``.

    A key of ``new_labels`` is applied only when it also appears in
    ``constants.DEFAULT_SCHEMA``; unrecognised keys are dropped silently.
    Neither input mapping is modified.
    """
    merged = dict(original)
    merged.update(
        {key: value for key, value in new_labels.items() if key in constants.DEFAULT_SCHEMA}
    )
    return merged

def _has_traj_cols(df, traj_cols):
    """Check that ``df`` has enough spatial and temporal columns to be a trajectory.

    ``traj_cols`` maps canonical names (``'latitude'``, ``'timestamp'``, ...)
    to column names in ``df``. The frame qualifies when it has at least one
    spatial set -- ``('latitude', 'longitude')``, ``('x', 'y')`` or
    ``'geohash'`` -- and at least one of ``'datetime'`` / ``'timestamp'``.

    Returns:
        True when both requirements hold (it never returns False).

    Raises:
        ValueError: if the spatial requirement fails (reported first), or
            else if the temporal requirement fails.
    """
    def _mapped(*keys):
        # All keys must be mapped before any target column is looked up in
        # df -- the same short-circuit order as a chained ``and``.
        return all(k in traj_cols for k in keys) and all(traj_cols[k] in df for k in keys)

    has_spatial = _mapped('latitude', 'longitude') or _mapped('x', 'y') or _mapped('geohash')
    has_temporal = _mapped('datetime') or _mapped('timestamp')

    if not has_spatial:
        raise ValueError(
            "Could not find required spatial columns in {}. The dataframe columns must contain or map to at least one of the following sets: "
            "('latitude', 'longitude'), ('x', 'y'), or 'geohash'.".format(df.columns.tolist())
        )
    if not has_temporal:
        raise ValueError(
            "Could not find required temporal columns in {}. The dataframe columns must contain or map to either 'datetime' or 'timestamp'.".format(df.columns.tolist())
        )
    return True


def _cast_traj_cols(df, traj_cols):
    """Coerce the mapped trajectory columns of ``df`` to canonical dtypes.

    Only columns that are both named in ``traj_cols`` and present in ``df``
    are touched, and only when their dtype is not already acceptable:

    * ``datetime``                             -> ``pd.to_datetime``
    * ``timestamp``                            -> ``int64``
    * ``latitude``/``longitude``/``x``/``y``   -> ``float``
    * ``user_id``/``geohash``                  -> ``str``

    Args:
        df: DataFrame to convert. It is modified in place.
        traj_cols: Mapping from canonical names to column names in ``df``.

    Returns:
        The same ``df`` object, after conversion.
    """
    # Public dtype-introspection API (pd.core.dtypes.common is internal).
    ptypes = pd.api.types

    # (canonical keys, "already OK" predicate, converter), applied in this order.
    conversions = [
        (('datetime',), ptypes.is_datetime64_any_dtype, pd.to_datetime),
        # Explicit int64: plain `int` maps to int32 on Windows with NumPy < 2,
        # which would overflow millisecond-resolution timestamps.
        (('timestamp',), ptypes.is_integer_dtype, lambda s: s.astype("int64")),
        (('latitude', 'longitude', 'x', 'y'), ptypes.is_float_dtype, lambda s: s.astype("float")),
        (('user_id', 'geohash'), ptypes.is_string_dtype, lambda s: s.astype("str")),
    ]

    for keys, is_ok, convert in conversions:
        for key in keys:
            if key in traj_cols and traj_cols[key] in df:
                col = traj_cols[key]
                if not is_ok(df[col].dtype):
                    df[col] = convert(df[col])

    return df

def _is_traj_df(df, traj_cols = None, **kwargs):
    """Return whether ``df`` is a trajectory frame whose columns already have canonical dtypes.

    This is a pure predicate: it never casts and never raises for bad data.

    Args:
        df: Object to test. Anything that is not a ``pandas.DataFrame``
            yields False (GeoDataFrames are accepted, being DataFrame subclasses).
        traj_cols: Mapping from canonical names to column names in ``df``.
            When falsy, the mapping is built from ``**kwargs`` instead.
            Canonical names it omits fall back to ``constants.DEFAULT_SCHEMA``.
        **kwargs: Canonical-name=column-name pairs; ignored if ``traj_cols``
            is given.

    Returns:
        True if the required spatial/temporal columns exist and every mapped
        column present in ``df`` has the expected dtype; False otherwise.
    """
    # geopandas.GeoDataFrame subclasses pandas.DataFrame, so one isinstance
    # check covers both (``gpd`` is not imported in this notebook, so naming
    # it raised NameError for non-DataFrame input).
    if not isinstance(df, pd.DataFrame):
        return False

    if not traj_cols:
        traj_cols = _update_schema({}, kwargs)  # kwargs ignored if traj_cols

    traj_cols = _update_schema(constants.DEFAULT_SCHEMA, traj_cols)

    # _has_traj_cols raises ValueError (rather than returning False) when the
    # required columns are absent; a predicate should report that as False.
    try:
        if not _has_traj_cols(df, traj_cols):
            return False
    except ValueError:
        return False

    ptypes = pd.api.types
    # The canonical dtypes that _cast_traj_cols produces. Every mapped column
    # present in df is checked, including 'timestamp' when 'datetime' also exists.
    expected_dtypes = [
        ('datetime', ptypes.is_datetime64_any_dtype),
        ('timestamp', ptypes.is_integer_dtype),
        ('latitude', ptypes.is_float_dtype),
        ('longitude', ptypes.is_float_dtype),
        ('x', ptypes.is_float_dtype),
        ('y', ptypes.is_float_dtype),
        ('user_id', ptypes.is_string_dtype),
        ('geohash', ptypes.is_string_dtype),
    ]
    for key, has_expected_dtype in expected_dtypes:
        if key in traj_cols and traj_cols[key] in df:
            if not has_expected_dtype(df[traj_cols[key]].dtype):
                return False

    return True


def from_object(df, traj_cols = None, spark_enabled=False, **kwargs):
    """Validate an in-memory DataFrame as trajectory data and cast its columns.

    Args:
        df: pandas DataFrame or GeoPandas GeoDataFrame (a DataFrame subclass).
        traj_cols: Mapping from canonical names (e.g. ``'user_id'``,
            ``'latitude'``) to column names in ``df``. When falsy, it is built
            from ``**kwargs``. Canonical names it omits fall back to
            ``constants.DEFAULT_SCHEMA``.
        spark_enabled: Accepted but not used by this implementation.
        **kwargs: Canonical-name=column-name pairs, e.g. ``user_id='uid'``;
            ignored when ``traj_cols`` is given.

    Returns:
        ``df`` itself, with its trajectory columns cast in place.

    Raises:
        TypeError: if ``df`` is not a DataFrame.
        ValueError: if the required spatial or temporal columns are missing.

    Warns:
        UserWarning for each explicitly mapped column that is not in ``df``.
    """
    # GeoDataFrame subclasses DataFrame, so this accepts GeoDataFrames too
    # without referring to ``gpd`` (which this notebook never imports).
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            "Expected the data argument to be either a pandas DataFrame or a GeoPandas GeoDataFrame."
        )

    # valid trajectory column names passed to **kwargs collected
    if not traj_cols:
        traj_cols = _update_schema({}, kwargs)  # kwargs ignored if traj_cols

    # Warn only about the caller's explicit mapping, before defaults are merged in.
    for key, value in traj_cols.items():
        if value not in df:
            warnings.warn(f"Trajectory column '{value}' specified for '{key}' not found in df.")

    # include defaults when missing
    traj_cols = _update_schema(constants.DEFAULT_SCHEMA, traj_cols)

    # Raises ValueError when required columns are absent; it never returns False.
    _has_traj_cols(df, traj_cols)
    return _cast_traj_cols(df, traj_cols)


def from_file(filepath, format="csv", traj_cols=None, **kwargs):
    """Load trajectory data from disk and validate it with ``from_object``.

    Args:
        filepath: For ``'parquet'``, a path read as a hive-partitioned
            ``pyarrow.dataset``. For ``'csv'``, either a single file (read with
            ``pd.read_csv``) or a directory / list of files (read as a
            hive-partitioned ``pyarrow.dataset``).
        format: ``'csv'`` or ``'parquet'``.
        traj_cols: Column mapping forwarded to ``from_object``.
        **kwargs: Canonical-name=column-name pairs (e.g. ``user_id='uid'``)
            forwarded to ``from_object`` for every input format.

    Returns:
        The validated DataFrame returned by ``from_object``.
    """
    assert format in ["csv", "parquet"]

    if format == 'parquet':
        dataset = ds.dataset(filepath, format="parquet", partitioning="hive")
        df = dataset.to_table().to_pandas()
    elif format == 'csv':
        # Test for a list first: os.path.isdir may raise TypeError on a list.
        if isinstance(filepath, list) or os.path.isdir(filepath):
            dataset = ds.dataset(filepath, format="csv", partitioning="hive")
            df = dataset.to_table().to_pandas()
        else:
            df = pd.read_csv(filepath)
    else:
        # Only reachable when assertions are disabled (python -O).
        return None

    # kwargs are forwarded for CSV as well, so column mappings such as
    # user_id='uid' are no longer silently dropped for CSV input.
    return from_object(df, traj_cols=traj_cols, **kwargs)

def sample_users(filepath, format='csv', frac_users=1.0, traj_cols=None, random_state=None, **kwargs):
    """Return the distinct user IDs in a dataset, optionally as a random fraction.

    Args:
        filepath: For ``'parquet'``, a path read as a hive-partitioned
            ``pyarrow.dataset``. For ``'csv'``, a single file, or a directory /
            list of files read as a hive-partitioned ``pyarrow.dataset``.
        format: ``'csv'`` or ``'parquet'``.
        frac_users: Fraction of distinct users to keep; values ``>= 1.0``
            return all of them.
        traj_cols: Mapping from canonical names to column names. Only its
            ``'user_id'`` entry is used; when falsy it is built from ``**kwargs``.
            If it has no ``'user_id'``, ``constants.DEFAULT_SCHEMA`` supplies it.
        random_state: Forwarded to ``Series.sample`` to make the draw
            reproducible. ``None`` (the default) draws without a fixed seed.
        **kwargs: Canonical-name=column-name pairs, e.g. ``user_id='uid'``;
            ignored if ``traj_cols`` is given.

    Returns:
        pandas Series of unique user IDs (sampled when ``frac_users < 1.0``).

    Raises:
        ValueError: if the user ID column is not present in the data.
    """
    assert format in ['csv', 'parquet']

    if not traj_cols:
        traj_cols = _update_schema({}, kwargs)

    # Fall back to the default schema so a mapping without 'user_id' does not KeyError.
    if 'user_id' in traj_cols:
        uid_col = traj_cols['user_id']
    else:
        uid_col = constants.DEFAULT_SCHEMA['user_id']

    def _missing_uid(columns):
        # Build the shared "user ID column not found" error for every input format.
        return ValueError(
            "Could not find required user ID column in {}. The columns must contain or map to '{}'.".format(
                columns, uid_col)
        )

    if format == 'parquet':
        dataset = ds.dataset(filepath, format="parquet", partitioning="hive")
    elif isinstance(filepath, list) or os.path.isdir(filepath):
        # List tested first: os.path.isdir may raise TypeError on a list.
        dataset = ds.dataset(filepath, format="csv", partitioning="hive")
    else:
        dataset = None

    if dataset is not None:
        if uid_col not in dataset.schema.names:
            raise _missing_uid(dataset.schema.names)
        user_ids = pc.unique(dataset.to_table(columns=[uid_col])[uid_col]).to_pandas()
    else:
        # Check the header first: read_csv(usecols=...) would otherwise raise
        # its own ValueError before our descriptive check could run.
        header = pd.read_csv(filepath, nrows=0).columns.tolist()
        if uid_col not in header:
            raise _missing_uid(header)
        # Wrap in a Series: the ndarray from .unique() has no .sample().
        user_ids = pd.Series(pd.read_csv(filepath, usecols=[uid_col])[uid_col].unique())

    if frac_users < 1.0:
        return user_ids.sample(frac=frac_users, random_state=random_state)
    return user_ids



def sample_from_file(filepath, users, format="csv", traj_cols=None, **kwargs):
    """Load only the rows belonging to ``users`` and validate them with ``from_object``.

    Args:
        filepath: For ``'parquet'``, a path read as a hive-partitioned
            ``pyarrow.dataset``. For ``'csv'``, a single file, or a directory /
            list of files read as a hive-partitioned ``pyarrow.dataset``.
        users: Iterable of user IDs to keep (e.g. the output of ``sample_users``).
        format: ``'csv'`` or ``'parquet'``.
        traj_cols: Column mapping forwarded to ``from_object``; when falsy it is
            built from ``**kwargs``. If it has no ``'user_id'``,
            ``constants.DEFAULT_SCHEMA`` supplies the user ID column name.
        **kwargs: Canonical-name=column-name pairs forwarded to ``from_object``.

    Returns:
        The validated DataFrame returned by ``from_object``.

    Raises:
        ValueError: if the user ID column is not present in the data.
    """
    assert format in ["csv", "parquet"]

    if not traj_cols:
        traj_cols = _update_schema({}, kwargs)

    # Look up the default without merging it into traj_cols, so from_object
    # only warns about the columns the caller actually mapped.
    if 'user_id' in traj_cols:
        uid_col = traj_cols['user_id']
    else:
        uid_col = constants.DEFAULT_SCHEMA['user_id']

    def _missing_uid(columns):
        # Build the shared "user ID column not found" error for every input format.
        return ValueError(
            "Could not find required user ID column in {}. The columns must contain or map to '{}'.".format(
                columns, uid_col)
        )

    if format == 'parquet':
        dataset = ds.dataset(filepath, format="parquet", partitioning="hive")
    elif isinstance(filepath, list) or os.path.isdir(filepath):
        # List tested first: os.path.isdir may raise TypeError on a list.
        dataset = ds.dataset(filepath, format="csv", partitioning="hive")
    else:
        dataset = None

    if dataset is not None:
        if uid_col not in dataset.schema.names:
            raise _missing_uid(dataset.schema.names)
        df = dataset.to_table(
            filter=ds.field(uid_col).isin(list(users))
        ).to_pandas()
    else:
        df = pd.read_csv(filepath)
        if uid_col not in df.columns:
            raise _missing_uid(df.columns.tolist())
        # .copy(): from_object assigns cast columns back into df, which on a
        # boolean-filtered slice triggers SettingWithCopyWarning.
        df = df[df[uid_col].isin(users)].copy()

    return from_object(df, traj_cols=traj_cols, **kwargs)

In [32]:
# Hive-partitioned parquet sample, loaded through the from_file helper.
path = '../../data/sample1/'
df = from_file(
    path,
    format='parquet',
)

In [38]:
# Draw a random 10% of the distinct user IDs; no seed is passed, so the
# selection can change from run to run.
u_sample = sample_users(
    path,
    format='parquet',
    frac_users=0.1,
    user_id='uid',
)

In [39]:
# Load only the rows of the sampled users (the frame is the cell's output).
sample_from_file(
    path,
    users=u_sample,
    format='parquet',
    user_id='uid',
)

Unnamed: 0,uid,timestamp,latitude,longitude
0,wonderful_swirles,1704121560,38.321017,-36.667869
1,wonderful_swirles,1704178440,38.320851,-36.667484
2,wonderful_swirles,1704178800,38.320852,-36.667470
3,wonderful_swirles,1704179820,38.320832,-36.667579
4,wonderful_swirles,1704180960,38.320834,-36.667461
...,...,...,...,...
176,happy_feynman,1705183920,38.320609,-36.666785
177,happy_feynman,1705249440,38.321699,-36.667564
178,happy_feynman,1705249500,38.321702,-36.667532
179,happy_feynman,1705250880,38.321722,-36.667483


In [41]:
# Four GPS fixes for a single user; timestamps are kept as plain strings.
data = pd.DataFrame(
    {
        'uid': [1, 1, 1, 1],
        'latitude': [39.984094, 39.984198, 39.984224, 39.984211],
        'longitude': [116.319236, 116.319322, 116.319402, 116.319389],
        'time': [
            '2008-10-23 13:53:05',
            '2008-10-23 13:53:06',
            '2008-10-23 13:53:11',
            '2008-10-23 13:53:16',
        ],
    }
)

In [42]:
# Map canonical trajectory names to this frame's column names.
traj_cols = {
    'user_id': 'uid',
    'latitude': 'latitude',
    'longitude': 'longitude',
    'datetime': 'time',
}
# NOTE: from_object casts columns in place and returns its input, so after
# this cell `df` and `data` are the same object.
df = from_object(data, traj_cols)

In [None]:
# Check whether the frame produced by from_object passes the dtype checks.
_is_traj_df(df, traj_cols=traj_cols)