In [None]:
import pandas as pd
import numpy as np
import time, sys, datetime, os

In [None]:
sys.path.insert(0,"../../python/")
import rg17.text_cleaning as tc

In [None]:
# NOTE(review): `sys` is already imported in the first cell; the repeat is redundant but harmless.
import sys
from datawand.parametrization import ParamHelper
# Parameter accessor for the "TrendApproximation" configuration rooted at "../../".
# The command-line arguments are handed over as well; presumably they can
# override stored values - TODO confirm against the datawand documentation.
ph = ParamHelper("../../", "TrendApproximation", sys.argv)

In [None]:
# Read the run parameters through the ParamHelper instance.
stemmed_tweet_text_file = ph.get("stemmed_tweet_file_path")  # "|"-separated file read with pd.read_csv below
word_corpus_file = ph.get("word_corpus")  # text file read line by line below, one word per line
model_root_folder = ph.get("w2v_root_folder")  # root directory under which the Word2Vec models are saved
w2v_model_dim = ph.get("w2v_model_dim")  # passed to Word2Vec as `size`; formatted with %i, so presumably an int

In [None]:
# Create the root folder for the trained models (including missing parents)
# when it does not exist yet.
if not os.path.exists(model_root_folder):
    os.makedirs(model_root_folder)


# 1. Load stemmed data

In [None]:
# Load the tweet table; fields in the input file are separated by "|".
stemmed_tweets = pd.read_csv(stemmed_tweet_text_file, sep="|")

In [None]:
# Keep only the two columns used further on: "time" (fed to time.gmtime below,
# i.e. seconds since the epoch) and "text".
stemmed_tweets = stemmed_tweets[["time","text"]]

## Convert epoch to GMT time (Dávid uses this format for now)

In [None]:
# Render every epoch timestamp as a "YYYY-MM-DD HH:MM:SS" string in GMT.
stemmed_tweets["gmt_time"] = stemmed_tweets["time"].map(
    lambda epoch_seconds: time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(epoch_seconds))
)

In [None]:
# Hours of the day at which snapshots start (every three hours from 01:00).
# NOTE(review): not referenced by the code visible in this notebook;
# round_hour_value encodes the same grid arithmetically.
snapshot_hours = [1,4,7,10,13,16,19,22]

def round_hour_value(hour_val):
    """Round an hour down to the start of its three-hour snapshot window.

    Windows start at hours whose remainder modulo 3 is 1 (1, 4, ..., 22).

    Args:
        hour_val: Hour of the day (0-23).

    Returns:
        The window's starting hour. Hour 0 yields -2, which the caller must
        interpret as 22:00 of the previous day.
    """
    # How far each remainder lies past the start of its window.
    shift_by_remainder = {2: 1, 0: 2}
    shift = shift_by_remainder.get(hour_val % 3)
    if shift is None:
        # Already on a window boundary.
        return hour_val
    return hour_val - shift

def get_snapshot_id(gmt_time):
    """Map a GMT timestamp string to the identifier of its snapshot window.

    Args:
        gmt_time: Timestamp formatted as "YYYY-MM-DD HH:MM:SS".

    Returns:
        A string "YYYY-MM-DDTHH:00" where HH is the starting hour of the
        three-hour window (see round_hour_value) containing the timestamp.
        Timestamps in hour 0 are assigned to 22:00 of the previous day.
    """
    day_part, clock_part = gmt_time.split()
    snapshot_hour = round_hour_value(int(clock_part.split(":")[0]))
    if snapshot_hour == -2:
        # Hour 0 belongs to the window that started at 22:00 the day before.
        previous_day = datetime.datetime.strptime(day_part, "%Y-%m-%d") - datetime.timedelta(days=1)
        day_part = "%.4i-%.2i-%.2i" % (previous_day.year, previous_day.month, previous_day.day)
        snapshot_hour = 22
    return "%sT%.2i:00" % (day_part, snapshot_hour)
    
# Sanity checks of the snapshot mapping.
# Hour 21 falls into the window starting at 19:00 -> "2017-06-11T19:00".
print(get_snapshot_id("2017-06-11 21:00:59"))
# Hour 0 rolls back to the previous day's 22:00 window -> "2017-06-10T22:00".
print(get_snapshot_id("2017-06-11 00:00:59"))

In [None]:
# Assign every tweet to its three-hour snapshot window ("YYYY-MM-DDTHH:00").
stemmed_tweets["snapshot_id"] = stemmed_tweets["gmt_time"].apply(get_snapshot_id)

In [None]:
# The snapshot's date is the part of the snapshot id before the "T". For tweets
# in hour 0 this is the previous day, because their snapshot starts at 22:00.
stemmed_tweets["date"] = stemmed_tweets["snapshot_id"].apply(lambda x: x.split("T")[0])

## Filter for evaluation timeframes (June 6 - June 11; the filter below keeps snapshots dated May 28 - June 11) - for faster preprocessing before W2V

In [None]:
# Row count before restricting the date range.
print(len(stemmed_tweets))
# ISO "YYYY-MM-DD" strings sort chronologically, so plain string comparison
# selects the range; both bounds are inclusive.
# .copy() detaches the filtered rows from the original frame: new columns are
# assigned to this frame later on, which would otherwise raise pandas'
# SettingWithCopyWarning.
stemmed_tweets = stemmed_tweets[(stemmed_tweets["date"] >= "2017-05-28") & (stemmed_tweets["date"] <= "2017-06-11")].copy()
# Row count after filtering.
print(len(stemmed_tweets))

# 2. Load selected word corpus

In [None]:
# Build the vocabulary: one entry per line of the corpus file, with trailing
# whitespace stripped and duplicates collapsed.
with open(word_corpus_file) as corpus_handle:
    selected_words = {raw_line.rstrip() for raw_line in corpus_handle}

In [None]:
# Show the number of distinct selected words as the cell output.
len(selected_words)

# 3. Filter tweet texts for relevant words

## a.) Cleaning based on regexp

In [None]:
# Apply the project's text cleaner to every tweet text.
# NOTE(review): the exact rules live in rg17.text_cleaning.clean_text (not visible here).
stemmed_tweets["cleaned_text"] = stemmed_tweets["text"].apply(tc.clean_text)

## b.) Cleaning based on word length

In [None]:
# Split the cleaned text into words, filtering by length with a limit of 2.
# NOTE(review): presumably keeps words longer than 2 characters - confirm whether
# the limit is inclusive in rg17.text_cleaning.get_words_above_size_limit.
stemmed_tweets["word_list"] = stemmed_tweets["cleaned_text"].apply(lambda x: tc.get_words_above_size_limit(x,2))

## c.) Leaving only the selected words

In [None]:
def _select_words_in_order(words):
    """Return the distinct words of `words` that belong to `selected_words`.

    The result has the same members as set(words) & selected_words, but keeps
    the order of first occurrence in the tweet. Going through a set instead
    would yield an arbitrary order that can change between interpreter runs
    (string hash randomization), so the Word2Vec training sentences would not
    be reproducible and the word order inside each tweet would be lost.
    """
    already_kept = set()
    kept_words = []
    for word in words:
        if word in selected_words and word not in already_kept:
            already_kept.add(word)
            kept_words.append(word)
    return kept_words

stemmed_tweets["selected_word_list"] = stemmed_tweets["word_list"].apply(_select_words_in_order)

# 4. Training Word2Vec models

In [None]:
import gensim, logging

In [None]:
# Emit INFO-level log records (timestamp, level, message) so that progress
# messages logged during training become visible.
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

In [None]:
def train_w2v_model(snapshot_id, dump_model=False):
    """Train a Word2Vec model on the tweets of a single snapshot.

    Reads the notebook-level `stemmed_tweets` frame, `w2v_model_dim` and
    `model_root_folder`.

    Args:
        snapshot_id: Snapshot identifier to select rows by ("snapshot_id" column).
        dump_model: When True, save the model as
            <model_root_folder>/dim_<w2v_model_dim>/<snapshot_id>.w2v,
            creating the directory if needed.

    Returns:
        Tuple (snapshot_id, trained gensim Word2Vec model).
    """
    in_snapshot = stemmed_tweets["snapshot_id"] == snapshot_id
    snapshot_tweets = stemmed_tweets[in_snapshot]
    # Keep only tweets with more than two selected words.
    # NOTE(review): this drops one- AND two-word tweets - confirm the threshold is intended.
    has_enough_words = snapshot_tweets["selected_word_list"].apply(lambda words: len(words) > 2)
    training_tweets = snapshot_tweets[has_enough_words]
    print("Number of tweets for training: %i" % len(training_tweets))
    sentences = list(training_tweets["selected_word_list"])
    model = gensim.models.Word2Vec(sentences, min_count=5, size=w2v_model_dim, batch_words=100)
    if not dump_model:
        return (snapshot_id, model)
    # One sub-folder per embedding dimension.
    output_dir = "%s/dim_%i/" % (model_root_folder, w2v_model_dim)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model.save(fname_or_handle="%s/%s.w2v" % (output_dir, snapshot_id))
    return (snapshot_id, model)

## Training + Exporting models for all snapshots

In [None]:
# Train and save one model per snapshot, going through the snapshots in
# ascending id order; each finished snapshot id is printed.
snapshot_ids = list(stemmed_tweets["snapshot_id"].unique())
for snapshot_id in sorted(snapshot_ids):
    trained_id = train_w2v_model(snapshot_id, dump_model=True)[0]
    print(trained_id)