In [None]:
import time
import json
import ast
import os
import datetime
import io
from collections import defaultdict

import imageio
import boto3
import pandas as pd
import numpy as np
import imageio
import matplotlib.pyplot as plt
import seaborn as sns

from brtdevkit.core.db.athena import AthenaClient
from brtdevkit.data import Dataset
from timezonefinder import TimezoneFinderL
import pytz

from aletheia_dataset_creator.dataset_tools.aletheia_dataset_helpers import imageids_to_dataset, imageids_to_dataset_fast
from aletheia_dataset_creator.config.dataset_config import LEFT_CAMERAS, ALL_CAMERA_PAIRS_LIST
%matplotlib inline

In [None]:
# Show up to 500 rows when a DataFrame is displayed.
pd.options.display.max_rows = 500

In [None]:
# Shared clients and filesystem locations used by the rest of the notebook.
home = os.path.expanduser('~')
data_path = '/data/jupiter/alex.li/datasets/'  # parquet caches are written under here

athena = AthenaClient()
s3 = boto3.resource('s3')
tf = TimezoneFinderL()

# Selecting data

In [None]:
HALO_LEFT_CAMERAS = ['T01', 'T02', 'T05', 'T06', 'T09', 'T10', 'T13', 'T14', 'I01', 'I02']

# WHERE clause shared by both Athena pulls below: VD6763 images from the Halo
# left cameras that have a geohash and a GPS/CAN JSON payload.
# NOTE(review): '7zzzz%' geohashes border lat/lon (0, 0); presumably this drops
# images with a default/invalid GPS fix — confirm.
HALO_IMAGE_FILTER_SQL = f"""
    WHERE sensor_type = 'VD6763'
    AND camera_location IN {tuple(HALO_LEFT_CAMERAS)}
    AND geohash IS NOT NULL
    AND geohash NOT LIKE '7zzzz%'
    AND gps_can_data__json IS NOT NULL
"""


def load_cached_query(cache_path, query):
    """Return the DataFrame cached at ``cache_path``, querying Athena on a cache miss.

    On a miss the query result is written to ``cache_path`` as parquet, so later
    runs reuse the same ``ORDER BY RAND()`` sample instead of drawing a new one.

    Args:
        cache_path: parquet file used as the cache.
        query: SQL string passed to ``athena.get_df`` when the cache is missing.

    Returns:
        The cached or freshly queried DataFrame.
    """
    if os.path.exists(cache_path):
        return pd.read_parquet(cache_path)
    print(f'cache miss for {cache_path}; querying Athena')
    result = athena.get_df(query)
    result.to_parquet(cache_path)
    return result


# Random sample of up to 1M images across all robots.
allpath = data_path + "/halo_all.parquet"
query = f"""
    SELECT collected_on, id, robot_name, geohash, camera_location, operation_time, latitude,
        longitude, gps_can_data__json
    FROM image_jupiter
    {HALO_IMAGE_FILTER_SQL}
    ORDER BY RAND()
    LIMIT 1000000
"""
df_all = load_cached_query(allpath, query)

# Random sample of up to 30k images from the single robot halohitchhiker_182.
orangepath = data_path + "/halo_orange_implement.parquet"
query1 = f"""
    SELECT id, robot_name, collected_on, operation_time,
        camera_location, gps_can_data__json, group_id, geohash
    FROM image_jupiter
    {HALO_IMAGE_FILTER_SQL}
    AND image_jupiter.robot_name IN ('halohitchhiker_182')
    ORDER BY RAND()
    LIMIT 30000
"""
df_orange = load_cached_query(orangepath, query1)

In [None]:
def load_cached_dataset(cache_path, dataset_name):
    """Return a brtdevkit Dataset as a DataFrame, caching it as parquet at cache_path.

    A cached file is read if present; otherwise the dataset is retrieved by
    name, written to cache_path, and the retrieved frame is returned.
    """
    if not os.path.exists(cache_path):
        frame = Dataset.retrieve(name=dataset_name).to_dataframe()
        frame.to_parquet(cache_path)
        return frame
    return pd.read_parquet(cache_path)


puddlepath = data_path + "/halo_puddle.parquet"
df_puddle = load_cached_dataset(puddlepath, 'labelbox_import_puddle_slice')

dustpath = data_path + "/halo_dust.parquet"
df_dust = load_cached_dataset(dustpath, 'labelbox_import_dust_slice')

In [None]:
# Geohash lookup table indexed by 6-character geohash, with a 'bucket' column;
# only geohashes in the 'train' bucket are kept by filter_df.
GEOHASH_TABLE_CSV = '/data/jupiter/alex.li/20231213_geohash_table_v6.csv'
geohash_df = pd.read_csv(GEOHASH_TABLE_CSV, index_col="Unnamed: 0")
geohash_train_df = geohash_df.loc[geohash_df['bucket'].eq('train')]
# Accumulates short geohashes seen in the data but absent from geohash_df.
new_geohashes = set()
def extract_speed(record):
    """Return the 'speed' value of one GPS/CAN record, or NaN if unavailable.

    Args:
        record: a JSON string (as stored in ``gps_can_data__json``) or an
            already-decoded mapping (as stored in ``gps_can_data``). Null
            records (None / NaN) are tolerated.

    Returns:
        The record's 'speed' entry, or NaN when the record is null, is not a
        mapping, lacks 'speed', or has a null 'speed'.
    """
    if isinstance(record, (str, bytes)):
        record = json.loads(record)
    getter = getattr(record, 'get', None)
    if getter is None:  # None, NaN, or any other non-mapping value
        return np.nan
    speed = getter('speed', np.nan)
    return np.nan if speed is None else speed


def filter_df(df_orig, min_speed=1, max_speed=30):
    """Keep rows moving within (min_speed, max_speed) whose geohash is in the train bucket.

    Adds ``geohash_short`` (first 6 geohash characters) and, if absent, ``speed``
    (parsed from ``gps_can_data__json`` or ``gps_can_data``). Works on a copy;
    ``df_orig`` is not modified.

    Side effect: short geohashes of at-speed rows that are missing from the
    global ``geohash_df`` are added to the global ``new_geohashes`` set.

    Args:
        df_orig: frame with a ``geohash`` column and one of ``speed``,
            ``gps_can_data__json`` or ``gps_can_data``.
        min_speed: exclusive lower speed bound (default 1).
        max_speed: exclusive upper speed bound (default 30).

    Returns:
        The filtered frame, including the added columns.

    Raises:
        KeyError: if no speed source column is present.
    """
    global new_geohashes
    df = df_orig.copy()
    # .str[:6] yields NaN for None geohashes instead of raising like x[:6]
    df["geohash_short"] = df["geohash"].str[:6]
    if 'speed' not in df.columns:
        if 'gps_can_data__json' in df.columns:
            df["speed"] = df["gps_can_data__json"].apply(extract_speed)
        elif 'gps_can_data' in df.columns:
            df["speed"] = df["gps_can_data"].apply(extract_speed)
        else:
            raise KeyError("need one of 'speed', 'gps_can_data__json' or 'gps_can_data' columns")
    df_atspeed = df[(min_speed < df["speed"]) & (df["speed"] < max_speed)]

    short = df_atspeed["geohash_short"].dropna()
    new_geohashes = new_geohashes.union(short[~short.isin(geohash_df.index)])
    return df_atspeed[df_atspeed['geohash_short'].isin(geohash_train_df.index)]

In [None]:
# Columns kept from every source after speed/geohash filtering.
FILTERED_COLUMNS = ['id', 'camera_location', 'robot_name', 'collected_on', 'speed', 'geohash_short']

df_filt_all, df_filt_orange, df_filt_puddle, df_filt_dust = [
    filter_df(source)[FILTERED_COLUMNS]
    for source in (df_all, df_orange, df_puddle, df_dust)
]
print(len(new_geohashes))

In [None]:
SAMPLE_SEED = 0  # fixed seed so Restart & Run All reproduces the same sample


def sample_at_most(frame, n, seed=SAMPLE_SEED):
    """Print ``len(frame)``, then return a random sample of up to ``n`` rows without replacement.

    ``DataFrame.sample`` raises ValueError when ``n`` exceeds the row count with
    ``replace=False``, so ``n`` is clipped to ``len(frame)``.
    """
    print(len(frame))
    return frame.sample(min(n, len(frame)), replace=False, random_state=seed)


df_filt_all = sample_at_most(df_filt_all, 30000)
# Assign back to df_filt_orange (not df_filt_all) so the random Athena sample above is kept.
df_filt_orange = sample_at_most(df_filt_orange, 5000)
df_filt_puddle = sample_at_most(df_filt_puddle, 10000)
df_filt_dust = sample_at_most(df_filt_dust, 10000)

In [None]:
# Stack all sampled sources, keep Halo left cameras only, and parse timestamps.
df = (
    pd.concat([df_filt_all, df_filt_orange, df_filt_puddle, df_filt_dust])
    .loc[lambda frame: frame['camera_location'].isin(HALO_LEFT_CAMERAS)]
    .assign(collected_on=lambda frame: pd.to_datetime(frame['collected_on']))
)
print(len(df))

In [None]:
# Non-null counts of every column, per camera location.
df.groupby('camera_location').agg('count')

In [None]:
# df = pd.read_csv('/data/jupiter/alex.li/wrong_label.csv')
# Dataset.create(name='halo_v61_to_relabel', description='images with incorrect label from v61 train set', kind=Dataset.KIND_IMAGE, image_ids=list(df['id']))

In [None]:
def make_dataset_slow(from_df, name, description) -> None:
    """Save ``from_df`` locally and register its unique image ids as an image dataset.

    Writes ``from_df`` to ``{data_path}/{name}.parquet``. It then passes the unique
    ids to ``imageids_to_dataset`` with ``dataset_kind='image'``.

    Args:
        from_df: DataFrame with an ``id`` column of image ids.
        name: dataset name; also used as the parquet file stem.
        description: dataset description; the unique image count is appended.
    """
    # drop_duplicates keeps first-occurrence order, so the id list is deterministic
    # (list(set(...)) is not).
    imids = from_df['id'].drop_duplicates().tolist()
    # Report the number of ids actually registered, not the row count:
    # concatenated sources can repeat an id.
    desc = f"{description} ({len(imids)} images)"
    print(len(imids))
    from_df.to_parquet(data_path + f'/{name}.parquet', index=False)
    imageids_to_dataset(imids, name, dataset_kind='image',
                        dataset_description=desc)
# make_dataset_slow(df, 'halo_images_for_train_implement_dust_puddle_small', 'training images for halo, chosen based on recent fps. Needs to be filtered further...')

In [None]:
# Model-positive images (with similarity cluster ids) exported for the dataset built above.
MODEL_POSITIVES_CSV = (
    '/mnt/sandbox1/alex.li/model_positives/'
    'halo_images_for_train_implement_dust_puddle_small_repro_bug/'
    'image_similarity_reduced_1.csv'
)
model_positive_df = pd.read_csv(MODEL_POSITIVES_CSV)

In [None]:
print(model_positive_df.shape[0])
# Keep only the first row for each cluster_id.
model_positive_df = model_positive_df.drop_duplicates(subset=['cluster_id'])
print(model_positive_df.shape[0])

In [None]:
model_positive_df = model_positive_df.drop_duplicates(subset=['id'])
imids = model_positive_df['id']
# How many model positives come from each source, in order: all, orange, puddle, dust.
for source_df in (df_filt_all, df_filt_orange, df_filt_puddle, df_filt_dust):
    print(imids.isin(source_df['id']).sum())
print(len(imids))

In [None]:
# Register the deduplicated model-positive ids (imids) as an image Dataset.
# NOTE(review): the per-source counts in the description are hard-coded. They are presumably
# copied from the isin() counts printed by the previous cell, and they will not update if
# upstream sampling or filtering changes. Re-check them before re-creating the dataset.
# NOTE(review): image_ids receives a pandas Series here, while the commented-out
# Dataset.create call earlier passes list(df['id']). Confirm that Series input is accepted.
Dataset.create(name='model_positives_labelbox_search', 
            description="""Images to label. Model positives on images from a few sources.
            129: randomly sampled from athena
            4443: sampled from athena on rear camera of halohitchhiker_182
            121: sampled from labelbox, have puddles and tire tracks
            28: sampled from labelbox, dusty images""",
            kind=Dataset.KIND_IMAGE,
            image_ids=imids,
)