In [None]:
import sys
import os
import warnings
import pandas as pd
# Make modules in the parent directory (presumably daphmeIO / constants) importable here.
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.insert(0, _parent_dir)

In [None]:
import pytest
import warnings
import pandas as pd
# A relative import (`from .. import daphmeIO`) raises ImportError in a notebook, which
# runs as a top-level script with no parent package. The first cell already puts the
# parent directory on sys.path (the same mechanism `import constants` relies on), so
# import the module directly.
import daphmeIO

In [None]:
# Create the project's data loader; the next cells call loader.load_spark(...) and read loader.df.
loader = daphmeIO.DataLoader()

In [None]:
# Location of the raw ping data, loaded through Spark by the project loader.
PINGS_S3_PATH = "s3://phl-pings/gravy/"
loader.load_spark(PINGS_S3_PATH)

In [None]:
# Grab the frame populated by loader.load_spark(...). The .select/.distinct/.show calls
# below suggest it is a Spark DataFrame — NOTE(review): confirm.
df = loader.df

In [None]:
# Show each distinct user id present in the loaded pings.
distinct_uids = df.select('uid').distinct()
distinct_uids.show()

In [None]:
# Alias the `grid` column as `uid` into a NEW frame. Reassigning `df` here made the cell
# non-idempotent (a second run fails because `grid` no longer exists) and discarded the
# frame inspected in the previous cell.
grid_df = df.selectExpr("grid as uid")
grid_df.show()

# Testing file formats and folder structures

## Generate test dataset

In [None]:
SPARSE_TRAJ_S3_PATH = "s3://synthetic-raw-data/100-agents/sparse_trajectories.parquet"
# Quick pandas preview; the same file is read again through pyarrow.dataset below.
pd.read_parquet(SPARSE_TRAJ_S3_PATH)

In [None]:
import pyarrow.parquet as pq
import pyarrow.fs as fs
import pyarrow.dataset as ds
import s3fs as fs3
import pyarrow as pa
import constants


In [None]:

# pyarrow.fs.S3FileSystem seemed to break this read, so an s3fs filesystem is passed instead.
s3 = fs3.S3FileSystem()

S3_SCHEME = "s3://"
path = "s3://synthetic-raw-data/100-agents/sparse_trajectories.parquet"
# With an explicit filesystem, pyarrow takes a path relative to it, so drop the scheme.
bucket_key = path[len(S3_SCHEME):]
dataset = ds.dataset(bucket_key, format="parquet", filesystem=s3)
df = dataset.to_table().to_pandas()

In [None]:
# Derive a calendar-date column from the epoch-second `timestamp` column.
event_times = pd.to_datetime(df['timestamp'], unit='s')
df['date'] = event_times.dt.date

In [None]:
# Convert the pandas frame (with the derived 'date' column) into an Arrow table.
# NOTE(review): `table` is not used by any later cell; the write step that would produce
# ../data/sample3/ (read back below) appears to be missing — TODO confirm.
table = pa.Table.from_pandas(df)

In [None]:
# Read the hive-partitioned parquet sample back into pandas (partition keys become columns).
# NOTE(review): later cells use '../../data/sample3' — confirm which relative path is right.
path = "../data/sample3/"
dataset = ds.dataset(path, partitioning='hive', format="parquet")
sample3_table = dataset.to_table()
df = sample3_table.to_pandas()

## Import the partitioned dataset (`sample3`)

In [None]:
SAMPLE2_CSV = '../../data/sample2/sample2.csv'
# nrows=0 parses only the header row, a cheap way to list the CSV's columns.
header_frame = pd.read_csv(SAMPLE2_CSV, nrows=0)
column_names = header_frame.columns
column_names

In [None]:
# Candidate loader configuration for the sample3 dataset. Mapping keys must be schema
# names -- the helper functions below look up 'datetime' / 'timestamp', never 'time' --
# and values are the column names in the data. The same key=value pairs can also be
# passed to from_object as **kwargs.
traj_cols = {'user_id': 'uid',
             'latitude': 'latitude',
             'longitude': 'longitude',
             'timestamp': 'timestamp'}

path = '../../data/sample3'
single_user = False
# NOTE(review): sample3 is read below as parquet from a local `path`, while these say
# "csv" and point at the S3 filesystem `s3` -- they look inconsistent; confirm.
file_format = "csv"
partitioning = "hive"
filesystem = s3

# TODO: decide whether reading goes through pandas, pyarrow, or pyspark, then check the
# result has the trajectory columns -- possibly validating on a single row first.

In [None]:
# Inspect the default mapping that user-supplied trajectory columns are merged over.
constants.DEFAULT_SCHEMA

In [None]:
# Load sample3 (hive-partitioned parquet at `path`) into pandas and display it.
dataset = ds.dataset(path, partitioning="hive", format="parquet")
sample_table = dataset.to_table()
df = sample_table.to_pandas()
df

## Functions that check trajectory-data integrity

Loosely based on scikit-mobility's approach to validating trajectory DataFrames.

In [None]:
def _update_schema(original, new_labels):
    updated_schema = dict(original)
    for label in new_labels:
        if label in constants.SCHEMA_NAMES:
            updated_schema[label] = new_labels[label]
    return updated_schema

def _has_traj_cols(df, traj_cols):
    
    # Check for sufficient spatial columns
    spatial_exists = (
        ('latitude' in traj_cols and 'longitude' in traj_cols and 
         traj_cols['latitude'] in df and traj_cols['longitude'] in df) or
        ('x' in traj_cols and 'y' in traj_cols and 
         traj_cols['x'] in df and traj_cols['y'] in df) or
        ('geohash' in traj_cols and traj_cols['geohash'] in df)
    )
    
    # Check for sufficient temporal columns
    temporal_exists = (
        ('datetime' in traj_cols and traj_cols['datetime'] in df) or
        ('timestamp' in traj_cols and traj_cols['timestamp'] in df)
    )
    
    if not spatial_exists:
        raise ValueError(
            "Missing required spatial columns. The dataframe must contain at least one of the following sets: "
            "('latitude', 'longitude'), ('x', 'y'), or 'geohash'."
        )
        
    if not temporal_exists:
        raise ValueError(
            "Missing required temporal column. The dataframe must contain either 'datetime' or 'timestamp'."
        )
    
    return spatial_exists and temporal_exists


def _cast_traj_cols(df, traj_cols):
    if 'datetime' in traj_cols and traj_cols['datetime'] in df:
        if not pd.core.dtypes.common.is_datetime64_any_dtype(df[traj_cols['datetime']].dtype):
            df[traj_cols['datetime']] = pd.to_datetime(df[traj_cols['datetime']])
    if 'timestamp' in traj_cols and traj_cols['timestamp'] in df:
        # Coerce to integer if it's not already
        if not pd.core.dtypes.common.is_integer_dtype(df[traj_cols['timestamp']].dtype):
            df[traj_cols['timestamp']] = df[traj_cols['timestamp']].astype(int)

    float_cols = ['latitude', 'longitude', 'x', 'y']
    for col in float_cols:
        if col in traj_cols and traj_cols[col] in df:
            if not pd.core.dtypes.common.is_float_dtype(df[traj_cols[col]].dtype):
                df[traj_cols[col]] = df[traj_cols[col]].astype("float")

    string_cols = ['user_id', 'geohash']
    for col in string_cols:
        if col in traj_cols and traj_cols[col] in df:
            if not pd.core.dtypes.common.is_string_dtype(df[traj_cols[col]].dtype):
                df[traj_cols[col]] = df[traj_cols[col]].astype("str")

    return df

def _is_traj_df(df, traj_cols = None, **kwargs):
    
    if not (isinstance(df, pd.DataFrame) or isinstance(df, gpd.GeoDataFrame)):
        return False
    
    if not traj_cols:
        traj_cols = {}
        traj_cols = _update_schema(traj_cols, kwargs) #kwargs ignored if traj_cols
        
    traj_cols = _update_schema(constants.DEFAULT_SCHEMA, traj_cols)
    
    if not _has_traj_cols(df, traj_cols):
        return False
    
    if 'datetime' in traj_cols and traj_cols['datetime'] in df:
        if not pd.core.dtypes.common.is_datetime64_any_dtype(df[traj_cols['datetime']].dtype):
            return False
    elif 'timestamp' in traj_cols and traj_cols['timestamp'] in df:
        if not pd.core.dtypes.common.is_integer_dtype(df[traj_cols['timestamp']].dtype):
            return False

    float_cols = ['latitude', 'longitude', 'x', 'y']
    for col in float_cols:
        if col in traj_cols and traj_cols[col] in df:
            if not pd.core.dtypes.common.is_float_dtype(df[traj_cols[col]].dtype):
                return False

    string_cols = ['user_id', 'geohash']
    for col in string_cols:
        if col in traj_cols and traj_cols[col] in df:
            if not pd.core.dtypes.common.is_string_dtype(df[traj_cols[col]].dtype):
                return False

    return True


def from_object(df, traj_cols = None, spark_enabled=False, **kwargs):
    """Validate an in-memory trajectory table and cast its columns to canonical dtypes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame (a GeoPandas ``GeoDataFrame`` qualifies, being a DataFrame subclass).
    traj_cols : dict, optional
        Schema key -> column name mapping. When falsy, the schema keys found in
        ``kwargs`` are used instead; otherwise ``kwargs`` is ignored.
    spark_enabled : bool, default False
        Currently unused by this function.
    **kwargs
        Alternative way to pass the mapping (e.g. ``latitude='lat'``); keys not in
        ``constants.SCHEMA_NAMES`` are dropped.

    Returns
    -------
    The result of ``_cast_traj_cols`` on ``df``, using the mapping merged over
    ``constants.DEFAULT_SCHEMA``.

    Raises
    ------
    TypeError
        If ``df`` is not a pandas DataFrame.
    ValueError
        From ``_has_traj_cols`` when the required spatial/temporal columns are missing.

    Warns
    -----
    UserWarning
        For each user-supplied mapping whose column is not found in ``df``.
    """
    # `gpd` is never imported in this notebook, so the previous GeoDataFrame check raised
    # NameError (not TypeError) for non-DataFrame input; GeoDataFrame subclasses DataFrame.
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            "Expected the data argument to be either a pandas DataFrame or a GeoPandas GeoDataFrame."
        )

    # valid trajectory column names passed to **kwargs collected
    if not traj_cols:
        traj_cols = _update_schema({}, kwargs)  # kwargs ignored if traj_cols

    for key, value in traj_cols.items():
        if value not in df:
            warnings.warn(
                f"Trajectory column '{value}' specified for '{key}' not found in df.",
                stacklevel=2,  # attribute the warning to the caller's line
            )

    # include defaults when missing
    traj_cols = _update_schema(constants.DEFAULT_SCHEMA, traj_cols)

    # _has_traj_cols either returns True or raises ValueError, so reaching the cast
    # means the required columns are present.
    _has_traj_cols(df, traj_cols)
    return _cast_traj_cols(df, traj_cols)
    

In [None]:
# Four fixes for a single user; `time` stays as strings so from_object's datetime cast is exercised.
data = pd.DataFrame({
    'uid': [1, 1, 1, 1],
    'latitude': [39.984094, 39.984198, 39.984224, 39.984211],
    'longitude': [116.319236, 116.319322, 116.319402, 116.319389],
    'time': ['2008-10-23 13:53:05', '2008-10-23 13:53:06',
             '2008-10-23 13:53:11', '2008-10-23 13:53:16'],
})

In [None]:
# Map schema keys to this sample's column names, then validate and cast it.
traj_cols = dict(
    user_id='uid',
    latitude='latitude',
    longitude='longitude',
    datetime='time',
)
df = from_object(data, traj_cols=traj_cols)

In [None]:
# Round-trip check: the frame returned by from_object should satisfy _is_traj_df.
# Asserting (not just displaying) makes Restart & Run All fail loudly on a regression.
is_valid_traj = _is_traj_df(df, traj_cols)
assert is_valid_traj, "from_object output failed _is_traj_df validation"
is_valid_traj