In [12]:
from dask.distributed import Client, LocalCluster
import dask
# Point dask's temporary directory at a fixed location, then start a local
# cluster (8 workers x 2 threads each) and attach a client to it.
dask.config.set(temporary_directory='/var/cache/spark')
cluster = LocalCluster(threads_per_worker=2, n_workers=8)
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 45735 instead


In [2]:
import numpy as np
import pyarrow
import pyarrow.parquet as pq
from pathlib import Path
from functools import partial
import tqdm
import shutil
import pandas as pd
import dask
import dask.array as da
import dask.dataframe as dd
import math
pd.set_option('max_colwidth', 0)

# Engagement types and the matching label columns ("has_<type>").
classes = ['retweet', 'reply', 'like', 'retweet_with_comment']
label_cols = [f'has_{name}' for name in classes]

# Scalar feature columns read alongside the embeddings.
feature_cols = [
    "author_follower_count", "author_following_count",
    "user_follower_count", "user_following_count",
    "num_hashtags", "num_media", "num_links", "num_domains",
    "author_is_verified", "user_is_verified",
    "follows",
]

# 768 embedding columns e000..e767, plus the "_u"-suffixed variant.
emb_cols = [f'e{idx:03d}' for idx in range(768)]
user_emb_cols = [f'e{idx:03d}_u' for idx in range(768)]


In [3]:
# pathlib does not expand "~" on its own; expand it explicitly so every
# consumer of these paths receives a real absolute location instead of
# relying on the filesystem layer to interpret the tilde.
data_dir = Path('~/recsys2020').expanduser()
ds_name = 'user_sampled'

# Output locations for the three splits written later in the notebook.
train_file = data_dir / "train.parquet/"
val_file = data_dir / "val.parquet/"
test_file = data_dir / "test.parquet/"

input_file_embeddings = data_dir / f"{ds_name}_embeddings.parquet/"

# Lazily read only the columns needed for splitting and index by user so the
# per-user groupby below operates on the index.
df = dd.read_parquet(
    str(input_file_embeddings),
    columns=['user_id', 'tweet_id', 'tweet_timestamp'] + feature_cols + emb_cols + label_cols,
).set_index('user_id')

In [5]:
# Output schema handed to dask for the groupby-apply: the two columns that
# apply_group prepends, followed by every existing column of df with its dtype.
meta = dd.utils.make_meta(
    [('user_id', object), ('ds_type', np.uint8)]
    + list(zip(df.columns, df.dtypes))
)

In [35]:
def apply_group(grp: pd.DataFrame):
    """Tag one user's rows as train / validation / test, ordered by time.

    Args:
        grp: Rows belonging to a single group, indexed by the group key. Must
            contain ``tweet_timestamp`` and the four ``has_*`` label columns.

    Returns:
        ``None`` when fewer than three rows carry at least one truthy label.
        Otherwise a copy of ``grp`` sorted by ``tweet_timestamp`` with two
        columns prepended: ``user_id`` (taken from the index) and ``ds_type``
        (uint8), where 0 marks the leading rows, 1 the second-to-last slice
        and 2 the final slice. Both trailing slices hold
        ``max(1, floor(0.05 * len(grp)))`` rows.
    """
    label_columns = ['has_retweet', 'has_like', 'has_reply', 'has_retweet_with_comment']
    num_labeled = grp[label_columns].any(axis=1).sum(axis=0)
    if num_labeled < 3:
        return None

    # sort_values returns a new frame, so the caller's group is not mutated.
    grp = grp.sort_values('tweet_timestamp')

    # At least 3 rows are present here, so the two trailing slices of
    # test_val_len rows each still leave at least one row tagged 0.
    test_val_len = max(1, math.floor(grp.shape[0] * 0.05))

    # Build the tag column once as uint8 so its dtype matches the declared
    # schema; assigning a plain scalar afterwards would replace the column
    # with an int64 one.
    ds_type = np.zeros(grp.shape[0], dtype=np.uint8)
    ds_type[(-2 * test_val_len):(-1 * test_val_len)] = 1
    ds_type[(-1 * test_val_len):] = 2

    grp.insert(0, 'ds_type', ds_type)
    grp.insert(0, column='user_id', value=grp.index.values)
    return grp
    
# Tag every user's rows via apply_group, discard the resulting index,
# repartition with a 100MB partition size and persist the result to disk.
grouped = (
    df.groupby('user_id')
    .apply(apply_group, meta=meta)
    .reset_index(drop=True)
    .repartition(partition_size='100MB')
)
grouped.to_parquet(str(data_dir.joinpath('grouped.parquet')))

In [36]:
# Re-open the tagged table from the parquet dataset written above.
grouped = dd.read_parquet(str(data_dir.joinpath('grouped.parquet')))

In [37]:
# Training split: rows tagged ds_type == 0, written without the tag column.
(
    grouped[grouped['ds_type'] == 0]
    .loc[:, grouped.columns != 'ds_type']
    .to_parquet(str(train_file))
)

In [38]:
# Validation split: rows tagged ds_type == 1, written without the tag column.
(
    grouped[grouped['ds_type'] == 1]
    .loc[:, grouped.columns != 'ds_type']
    .to_parquet(str(val_file))
)

In [39]:
# Test split: rows tagged ds_type == 2, written without the tag column.
(
    grouped[grouped['ds_type'] == 2]
    .loc[:, grouped.columns != 'ds_type']
    .to_parquet(str(test_file))
)

In [40]:
# Per-column non-null counts of the test split as written to disk.
print(
    "Test samples:",
    dd.read_parquet(str(test_file)).count().compute(),
)

Test samples: user_id                     72595
tweet_id                    72595
tweet_timestamp             72595
author_follower_count       72595
author_following_count      72595
                            ...  
e767                        72595
has_retweet                 72595
has_reply                   72595
has_like                    72595
has_retweet_with_comment    72595
Length: 786, dtype: int64


In [41]:
# Per-column non-null counts of the training split as written to disk.
print(
    "Train samples:",
    dd.read_parquet(str(train_file)).count().compute(),
)

Train samples: user_id                     560195
tweet_id                    560195
tweet_timestamp             560195
author_follower_count       560195
author_following_count      560195
                             ...  
e767                        560195
has_retweet                 560195
has_reply                   560195
has_like                    560195
has_retweet_with_comment    560195
Length: 786, dtype: int64


In [42]:
# Per-column non-null counts of the validation split as written to disk.
print(
    "Val samples:",
    dd.read_parquet(str(val_file)).count().compute(),
)

Val samples: user_id                     72595
tweet_id                    72595
tweet_timestamp             72595
author_follower_count       72595
author_following_count      72595
                            ...  
e767                        72595
has_retweet                 72595
has_reply                   72595
has_like                    72595
has_retweet_with_comment    72595
Length: 786, dtype: int64
