# To be able to train models, weather information needs to be assigned to tweets
- Here, we determine the location of each tweet and assign the corresponding weather data from records (matched by location and time).
- In addition, we check for outliers and other sources of contamination.

In [None]:
# allows update of external libraries without need to reload package
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import logging

import warnings

logging.basicConfig(level=logging.INFO)
import json
import os
import glob
import functools
import multiprocessing
import datetime
import gc

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors
import tqdm
import xarray

import tweepy
import pyproj
import shapely.geometry

import a2.twitter.locations
import a2.dataset.load_dataset
import a2.dataset.utils_dataset
import a2.utils
import a2.plotting
import a2.preprocess.normalize_text

In [None]:
# NOTE(review): this single-file list is overwritten by the assignment on the next line.
all_files = ["tweets_no_keywords_2020-02-13T00:00:00.000Z_2020-02-14T00:00:00.json"]
# presumably collects all files matching the pattern — confirm against a2.utils.file_handling
all_files = a2.utils.file_handling.get_all_files("tweets_[0-9]*.json")
# directory that the histogram figures in this notebook are written to
figure_path = pathlib.Path("../../figures/data/rain/")

# common prefix of the NetCDF files that are saved and reloaded in the following cells
filename_base = "tweets_no_keywords_2014-2016_locations"

In [None]:
all_files

In [None]:
# load the tweets from the json files and convert the result to an xarray dataset
ds = a2.dataset.load_dataset.load_tweets_dataframe_from_jsons(all_files).to_xarray()
# store `created_at` as datetime values along the "index" dimension
ds["created_at"] = (["index"], pd.to_datetime(ds.created_at).values)

In [None]:
ds

## Analysis of Tweets with tagged location

Many tweets share the same location!!

In [None]:
# keep only tweets whose coordinates field is not nan, count the entries per
# `geo.place_id` and sort the groups by their `id` count
# NOTE(review): `ascending=False` presumably puts the most frequent places first — confirm
places = (
    ds.where(
        ~a2.dataset.utils_dataset.is_nan(ds, "geo.coordinates.coordinates"),
        drop=True,
    )
    .groupby("geo.place_id")
    .count()
    .sortby("id", ascending=False)
)

In [None]:
times_visited = places["id"]
y, x, _ = plt.hist(times_visited, bins=100, log=True);

In [None]:
fig = plt.figure()
ax = plt.gca()
ax.plot(np.cumsum(times_visited))
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("number of locations")
ax.axvline(100)
ax.set_ylabel("cumulative number of tweets covered")
ax.set_xlim([1, None]);

In [None]:
# masks over all tweets: coordinates field is nan / place id is nan
coordinates_missing = a2.dataset.utils_dataset.is_nan(ds, "geo.coordinates.coordinates")
place_id_missing = a2.dataset.utils_dataset.is_nan(ds, "geo.place_id")

# number of tweets remaining after each selection
n_tracked = ds.where(~coordinates_missing, drop=True)["index"].shape[0]
n_tagged = ds.where(coordinates_missing, drop=True)["index"].shape[0]
n_tracked_with_place_id = ds.where(~coordinates_missing & ~place_id_missing, drop=True)["index"].shape[0]

print(
    f"tweets with tracking activated {n_tracked} and {n_tagged} are tagged. "
    f"Places with location and place_id {n_tracked_with_place_id}"
)

In [None]:
ds = a2.twitter.locations.add_locations(
    ds,
    filename_location="locations.json",
    filename_location_not_found="locations_not_found.csv",
    download=True,
    key_place_id="geo.place_id",
    key_coordinates="geo.coordinates.coordinates",
)

In [None]:
ds.created_at.min()

In [None]:
a2.dataset.load_dataset.save_dataset(
    ds,
    f"{filename_base}.nc",
    add_attributes="added locations",
    no_conversion=False,
)

In [None]:
ds_loc = a2.dataset.load_dataset.load_tweets_dataset(
    f"{filename_base}.nc",
    raw=True,
)

In [None]:
ds_lj = pd.read_json("locations.json").to_xarray()

In [None]:
np.sum(a2.dataset.utils_dataset.is_nan(ds_lj, "centroid"))

In [None]:
ds_lj["centroid"] = (
    ["index"],
    np.array([str(x) for x in ds_lj.centroid.values]),
)

In [None]:
ds_lj_g_cen = a2.dataset.utils_dataset.dataset_groupby(ds_lj, "centroid")

In [None]:
ds_lj_g_cen

## Analysis of tagged/tracked Tweets
GPS-tracked Tweets have empty strings `''` in the fields that location-tagged Tweets fill in.

In [None]:
# tweets for which both `centroid` and `geo.coordinates.coordinates` are na
ds_tweets_without_location = ds_loc.where(
    (a2.dataset.utils_dataset.is_na(ds_loc, "centroid"))
    & (a2.dataset.utils_dataset.is_na(ds_loc, "geo.coordinates.coordinates")),
    drop=True,
)
n_tweets_without_location = ds_tweets_without_location.index.shape[0]
# report: total count, count with non-na centroid, count with non-na coordinates, count with neither
print(
    f"Total number of Tweets: {ds_loc['index'].shape[0]}, Tweets with tagged location: {ds_loc.where(~a2.dataset.utils_dataset.is_na(ds_loc, 'centroid'), drop=True)['index'].shape[0]}, tweets with GPS tracked location {ds_loc.where(~a2.dataset.utils_dataset.is_na(ds_loc, 'geo.coordinates.coordinates'), drop=True).index.shape[0]} location missing for {n_tweets_without_location}"
)

In [None]:
ds_loc.where(ds_loc.coordinates_estimated == True).groupby("place_type").count().plot.scatter(y="id", x="place_type")
ax = plt.gca()
ax.tick_params(axis="x", labelrotation=0)
ax.set_yscale("log")
ax.set_xlabel("occurence of place_type when no coords given")
ax.set_ylabel("count");

## Computing area of central bounding box for tagged Tweets

In [None]:
# requires conversion of bounding box in str format to dict
ds_loc["bounding_box"] = (
    ["index"],
    a2.dataset.load_dataset.convert_str_to_dict(ds_loc["bounding_box"].values),
)

In [None]:
pd.isnull(ds_loc.bounding_box.values[0])

In [None]:
ds_loc["bounding_box_area"] = (
    ["index"],
    a2.twitter.locations.compute_area_bounding_box(ds_loc.bounding_box.values),
)

In [None]:
ds_loc["bounding_box_area"].plot.hist(bins=np.logspace(np.log10(1e-1), np.log10(1e6), 50))
ax = plt.gca()
ax.set_xscale("log")
ax.set_xlabel("size of bounding box [km$^2$]")
ax.set_ylabel("count of tweets");

In [None]:
# NOTE(review): both counts only include tweets whose coordinates field equals the string
# 'nan', while the message reads "tweets with coordinates or tagged with place" — confirm the
# intended wording. An area of exactly 100 falls into neither count (`< 100` and `> 100`).
print(
    f"tweets with coordinates or tagged with place, which is smaller than 100 km^2: {ds_loc.where((ds_loc.bounding_box_area < 100) & (ds_loc.bounding_box_area > 0) & (ds_loc['geo.coordinates.coordinates'] == 'nan'), drop=True)['index'].shape[0]} and larger for {ds_loc.where((ds_loc.bounding_box_area > 100) & (ds_loc['geo.coordinates.coordinates'] == 'nan'), drop=True)['index'].shape[0]}"
)

In [None]:
a2.dataset.load_dataset.save_dataset(
    ds_loc,
    f"{filename_base}_bba.nc",
    add_attributes=",computed bounding box area",
    no_conversion=False,
)

## Reload saved twitter dataset

In [None]:
ds_twit = a2.dataset.load_dataset.load_tweets_dataset(
    f"{filename_base}_bba.nc",
    raw=True,
)

In [None]:
print(
    f"found no location for {ds_twit.where((a2.dataset.utils_dataset.is_na(ds_twit, 'geo.coordinates.coordinates')) & (a2.dataset.utils_dataset.is_na(ds_twit, 'centroid')), drop=True).index.shape[0]} tweets"
)

In [None]:
# remove tweets without specified location
ds_twit_loc = ds_twit.where(
    (~a2.dataset.utils_dataset.is_na(ds_twit, "geo.coordinates.coordinates"))
    | (~a2.dataset.utils_dataset.is_na(ds_twit, "centroid")),
    drop=True,
)
ds_twit_loc = a2.dataset.load_dataset.reset_index_coordinate(ds_twit_loc)

In [None]:
print(f"removed {ds_twit.index.shape[0] - ds_twit_loc.index.shape[0]} Tweets")

## Convert centroid and coordinates to latitude and longitude

In [None]:
ds_twit_loc = a2.twitter.locations.convert_coordinates_to_lat_long(
    ds_twit_loc,
    key_coordinates="centroid",
    prefix_lat_long="centroid_",
    overwrite=True,
)

In [None]:
ds_twit_loc = a2.twitter.locations.convert_coordinates_to_lat_long(
    ds_twit_loc,
    key_coordinates="geo.coordinates.coordinates",
    prefix_lat_long="",
    overwrite=True,
)

In [None]:
# tweets for which longitude or latitude is nan
mask = a2.dataset.utils_dataset.is_nan(ds_twit_loc, "longitude") | a2.dataset.utils_dataset.is_nan(
    ds_twit_loc, "latitude"
)
# fill these entries with the latitude/longitude derived from the centroid
ds_twit_loc["latitude"].loc[mask] = ds_twit_loc["centroid_latitude"].loc[mask]
ds_twit_loc["longitude"].loc[mask] = ds_twit_loc["centroid_longitude"].loc[mask]

In [None]:
# Count tweets outside the half-open interval [2014-01-01, 2017-01-01). The upper bound uses
# `>=` so that a tweet created exactly at 2017-01-01T00:00 is reported as outside 2014-2016.
print(
    f"tweets not created in years 2014-2016: {ds_twit_loc.where((ds_twit_loc.created_at < pd.to_datetime(['2014-01-01']).values) | (ds_twit_loc.created_at >= pd.to_datetime(['2017-01-01']).values), drop=True)['index'].shape[0]}"
)

In [None]:
print(
    f"tweet without specified latitude or longitude: {ds_twit_loc.where(a2.dataset.utils_dataset.is_na(ds_twit_loc, 'longitude') | a2.dataset.utils_dataset.is_na(ds_twit_loc, 'latitude'), drop=True).index.shape[0]}"
)

In [None]:
# add latitude/longitude rounded to one decimal as new variables along "index"
for coordinate_name in ("latitude", "longitude"):
    rounded_values = np.round(ds_twit_loc[coordinate_name].astype(float).values, decimals=1)
    ds_twit_loc[f"{coordinate_name}_rounded"] = ("index", rounded_values)

In [None]:
# creation time of every tweet rounded to the hour, stored along "index"
ds_twit_loc["created_at_h"] = (
    "index",
    pd.to_datetime(ds_twit_loc.created_at).round("1h").values,
)

In [None]:
ds_twit_loc

In [None]:
a2.dataset.load_dataset.save_dataset(
    ds_twit_loc,
    f"{filename_base}_locations_bba_prepTp.nc",
    add_attributes=", prepared for combining with tp",
)

## Convert weather data from cumulative to precipitation per hour

In [None]:
# ds_tp_raw = a2.dataset.load_dataset.load_tweets_dataset(
#     "../../data/precipitation/ds_prec_era5_uk_2017-2020_RAW.nc", raw=True
# )
ds_tp_raw = a2.dataset.load_dataset.load_tweets_dataset("reanalysis-era5-land-2014-2016_RAW.nc", raw=True)

In [None]:
ds_tp_raw

In [None]:
def create_uncumulative_dataset(
    ds,
    key="tp",
    time="time",
    key_new="tp_h",
    time_new="time_half",
    skip=1,
    skip_n=24,
    dim=None,
):
    """Convert an accumulated field into per-time-step increments.

    The first ``skip`` time steps are dropped. Every remaining step is replaced
    by its difference to the preceding step, except every ``skip_n``-th step
    (counted from the first kept step), which keeps its original value.
    NOTE(review): presumably the accumulation restarts at those steps (e.g.
    once per day for hourly data) — confirm against the data source.

    Parameters
    ----------
    ds :
        Dataset providing ``ds[key]``, ``ds[time]`` and the coordinates named
        in ``dim[1:]``. The time axis of ``ds[key]`` must come first.
    key : str
        Name of the accumulated variable.
    time : str
        Name of the time coordinate.
    key_new : str
        Name of the variable holding the increments in the returned dataset.
    time_new : str
        Name of the time dimension/coordinate in the returned dataset.
    skip : int
        Number of leading time steps to drop.
    skip_n : int
        Stride of the steps that keep their original value.
    dim : list of str, optional
        Dimension names; only ``["time", "latitude", "longitude"]`` (the
        default) is supported.

    Returns
    -------
    xarray.Dataset
        Dataset with variable ``key_new`` on dimensions ``[time_new] + dim[1:]``;
        the ``time_new`` values are the original time stamps minus 30 minutes.

    Raises
    ------
    NotImplementedError
        If ``dim`` differs from ``["time", "latitude", "longitude"]``.
    """
    supported_dim = ["time", "latitude", "longitude"]
    dim = supported_dim if dim is None else dim
    if dim != supported_dim:
        raise NotImplementedError(f"{dim} != ['time', 'latitude', 'longitude']")

    # assumes time index comes first
    accumulated = ds[key].values[skip:, :, :]
    time_stamps = ds[time].values[skip:]
    # shift the time stamps back by 30 minutes
    shifted_time_stamps = time_stamps - np.timedelta64(30, "m")

    # first kept step as is, afterwards differences between consecutive steps
    first_step = accumulated[0, :, :][np.newaxis, :, :]
    increments = np.concatenate((first_step, np.diff(accumulated, axis=0)))

    # every `skip_n`-th step keeps its original (non-differenced) value
    keep_original = slice(0, None, skip_n)
    increments[keep_original, :, :] = accumulated[keep_original, :, :]

    spatial_dims = dim[1:]
    coords = {time_new: shifted_time_stamps, **{name: ds[name].values for name in spatial_dims}}
    data_vars = {key_new: ([time_new] + spatial_dims, increments)}
    return xarray.Dataset(data_vars=data_vars, coords=coords)


ds_clean = create_uncumulative_dataset(ds_tp_raw)

In [None]:
ds_clean

In [None]:
# reanalysis-era5-land-2014-2016_RAW.nc
a2.dataset.load_dataset.save_dataset(
    ds_clean,
    "ds_prec_era5_uk_2014-2016_decum.nc",
    no_conversion=True,
)

In [None]:
ds_clean = a2.dataset.load_dataset.load_tweets_dataset(
    "ds_prec_era5_uk_2014-2016_decum.nc",
)

In [None]:
# the name `matplotlib` is used for `matplotlib.colors.LogNorm` in a following cell
import matplotlib

# bare expression without assignment; it has no effect beyond being displayed as cell output
matplotlib.colors.Normalize

In [None]:
ds_neg = ds_clean.where(ds_clean.tp_h < 0, drop=True)

In [None]:
ds_neg_sel = ds_neg.sel(time_half=np.datetime64("2016-12-31T14:30:00.000000000"))

In [None]:
norm = matplotlib.colors.LogNorm(vmin=1e-7, vmax=1e-3)
cmap = "magma"

In [None]:
ds_clean.sel(time_half=np.datetime64("2016-12-31T14:30:00.000000000")).tp_h.plot(norm=norm, cmap=cmap)

In [None]:
ds_tp_raw.sel(time=np.datetime64("2016-12-31T14:00:00.000000000")).tp.plot(norm=norm, cmap=cmap)

In [None]:
a2.plotting.histograms.plot_histogram(
    ds_clean.tp_h.values * 1e3,
    log=["symlog", "log"],
    linear_thresh=1e-9,
    n_bins=100,
    xlim=[-1e3, 1e3],
    label_x="tp",
    filename=figure_path / "tp_histogram_ds_prec_era5_uk_2014-2016_decum.pdf",
)

## Check weather data

In [None]:
weather_files = ["ds_prec_era5_uk_2014-2016_decum.nc"]

In [None]:
# Sanity check of every precipitation file: warn if there is a time index at which no grid
# cell holds a positive value and save a histogram of all values.
# `ds_tp` and `filename_stem` of the last file are reused by later cells, so they are kept.
for f in weather_files:
    filename_stem = pathlib.Path(f).stem
    print(f"... working on {filename_stem}")
    ds_tp = a2.dataset.load_dataset.load_tweets_dataset(f, raw=True)

    tp = ds_tp.tp_h.values[:]
    # replace nan by -1 so that the per-time-index maxima below are never nan
    mask = ~np.isnan(tp)
    tp_non_nan = np.where(mask, tp, -1)

    # largest value per time index (time is the first of three axes)
    time_index_max = tp_non_nan.max(axis=(1, 2))
    minimum_of_maximum_tp_per_time_index = np.min(time_index_max)

    if minimum_of_maximum_tp_per_time_index <= 0:
        # list the time indices with the same `<= 0` criterion that triggers the warning
        warnings.warn(
            f"all values of tp equal/below 0 for a time index: {minimum_of_maximum_tp_per_time_index}! at time indices {ds_tp.time_half.values[time_index_max <= 0]}"
        )
    a2.plotting.histograms.plot_histogram(
        tp_non_nan,
        log=["symlog", "log"],
        linear_thresh=1e-16,
        n_bins=100,
        # NOTE(review): the other tp histograms in this notebook use xlim=[-1e3, 1e3] — confirm 1e-3
        xlim=[-1e3, 1e-3],
        label_x="tp",
        filename=figure_path / f"tp_histogram_{filename_stem}.pdf",
    )

## Load twitter data

In [None]:
ds_twit = a2.dataset.load_dataset.load_tweets_dataset(
    f"{filename_base}_locations_bba_prepTp.nc",
    raw=True,
)

## Check for time/location outliers
- 'Louth, England' erroneously shifted 5 degrees west
- remove outlier(s) (Philippines, ...)
- remove tweets created at rounded time `2021-01-01T00:00:00.000000000`

In [None]:
# locations outside rough borders of UK
ds_twit.where(
    (
        (ds_twit.latitude_rounded < 49)
        | (ds_twit.latitude_rounded > 61)
        | (ds_twit.longitude_rounded < -9)
        | (ds_twit.longitude_rounded > 3)
    ),
    drop=True,
)

In [None]:
# # set city Louth to correct center by hand
# mask = ds_twit.full_name == "Louth, England"
# ds_twit["longitude"].loc[mask] = 0.0061
# ds_twit["longitude_rounded"].loc[mask] = 0

In [None]:
# # set 40FT Brewery to correct center by hand
# mask = ds_twit.full_name == "40FT Brewery"
# ds_twit["longitude"].loc[mask] = -0.073762
# ds_twit["longitude_rounded"].loc[mask] = -0.1

In [None]:
# locations outside rough borders of UK
# NOTE(review): the bounds used here (62, -10, 5) are wider than the ones used when the
# outliers are removed (61, -9, 3) — confirm that this is intended
ds_twit.where(
    (
        (ds_twit.latitude_rounded < 49)
        | (ds_twit.latitude_rounded > 62)
        | (ds_twit.longitude_rounded < -10)
        | (ds_twit.longitude_rounded > 5)
    ),
    drop=True,
)

In [None]:
# remove remaining outliers: drop tweets whose rounded latitude is below 49 or above 61
# or whose rounded longitude is below -9 or above 3
ds_twit_uk = ds_twit.where(
    ~(
        (ds_twit.latitude_rounded < 49)
        | (ds_twit.latitude_rounded > 61)
        | (ds_twit.longitude_rounded < -9)
        | (ds_twit.longitude_rounded > 3)
    ),
    drop=True,
)

In [None]:
ds_tp

In [None]:
ds_tweets_tp = a2.dataset.utils_dataset.add_precipitation_to_tweets(
    ds_tweets=ds_twit_uk,
    ds_precipitation=ds_tp,
    key_precipitation_precipitation="tp_h",
    key_precipitation_tweets="tp_h_m",
)

In [None]:
a2.dataset.utils_dataset.is_nan(ds_tweets_tp, "tp_h_m").sum()

In [None]:
# precipitation assigned to the tweets, scaled by 1e3 (plotted with label "tp_h_mm")
tp = ds_tweets_tp.tp_h_m.values * 1e3
a2.plotting.histograms.plot_histogram(
    tp,
    log=["symlog", False],
    linear_thresh=1e-9,
    n_bins=100,
    xlim=[-1e3, 1e4],
    label_x="tp_h_mm",
    # distinct file name: `tp_histogram_{filename_stem}.pdf` is already written by the
    # weather-data check and would otherwise be overwritten by this tweet histogram
    filename=figure_path / f"tp_histogram_{filename_stem}_tweets_mm.pdf",
)

In [None]:
a2.plotting.histograms.plot_histogram_2d(
    ds_twit_uk["longitude_rounded"].values,
    ds_twit_uk["latitude_rounded"].values,
    norm="log",
    n_bins=[111, 121],
    xlim=[-9, 2],
    ylim=[49, 61],
)

In [None]:
ds_tweets_tp_nan = ds_tweets_tp.where(a2.dataset.utils_dataset.is_nan(ds_tweets_tp, "tp_h_m"), drop=True)

In [None]:
a2.plotting.histograms.plot_histogram_2d(
    ds_tweets_tp_nan["longitude_rounded"].values,
    ds_tweets_tp_nan["latitude_rounded"].values,
    norm="log",
    n_bins=[111, 121],
    xlim=[-9, 2],
    ylim=[49, 61],
)

In [None]:
a2.plotting.histograms.plot_histogram(
    ds_tp.tp_h.values,
    log=["symlog", "log"],
    linear_thresh=1e-9,
    n_bins=100,
    xlim=[-1e3, 1e3],
    label_x="tp",
    filename=figure_path / f"tp_histogram_ds_prec_era5_uk_2014-2016_decum.pdf",
)

In [None]:
a2.plotting.histograms.plot_histogram(
    ds_tweets_tp.tp_h_m.values,
    log=["symlog", "log"],
    linear_thresh=1e-9,
    n_bins=100,
    xlim=[-1e3, 1e3],
    label_x="tp",
    filename=figure_path / f"tp_histogram_ds_prec_era5_uk_2014-2016_decum_tweets.pdf",
)

In [None]:
ds_tweets_tp_clean = ds_tweets_tp.where(~a2.dataset.utils_dataset.is_nan(ds_tweets_tp, "tp_h_m"), drop=True)

In [None]:
filename_base

In [None]:
a2.dataset.load_dataset.save_dataset(
    ds_tweets_tp_clean,
    f"{filename_base}_bba_Tp_era5.nc",
)

## Assign precipitation value to every Tweet

In [None]:
ds_twit = a2.dataset.load_dataset.load_tweets_dataset(
    "2017_2020_tweets_rain_sun_vocab_emojis_locations_bba_before_combine.nc",
    raw=True,
)

In [None]:
weather_files = ["../../data/precipitation/ds_prec_era5_uk_2017-2020_decum.nc"]

ds_tweets_precipitation = a2.dataset.utils_dataset.add_precipitation_memory_efficient(
    ds_tweets=ds_twit,
    ds_weather_filenames=weather_files,
    key_time_precipitation="time_half",
    key_precipitation_precipitation="tp_h",
    key_precipitation_tweets="tp_era5",
)

In [None]:
ds_tweets_precipitation

### Check for nan values in tp

In [None]:
ds_tweets_precipitation

In [None]:
# build the nan mask from the dataset that is filtered, so that mask and data are aligned
# and the `tp_era5` values checked are the ones returned by the precipitation assignment
ds_tweets_precipitation.where(a2.dataset.utils_dataset.is_nan(ds_tweets_precipitation, "tp_era5"), drop=True)

## Total precipitation below 1e-8 m appears to be 0
- seems to be raining roughly half of the time

In [None]:
filename = figure_path / f"tweets_2017-2020_tp_histogram.pdf"
a2.plotting.histograms.plot_histogram(
    ds_twit.tp_era5.values,
    log=["symlog", "log"],
    linear_thresh=1e-9,
    n_bins=100,
    label_x="tp_era5",
    filename=filename,
)

In [None]:
ds_twit["raining_old"] = ("index", np.abs(ds_twit["tp_cum"].values) > 1e-8)

In [None]:
ds_twit["raining"] = ("index", np.abs(ds_twit["tp_era5"].values) > 1e-8)

In [None]:
# tweets whose `tp_cum` value is not nan
# NOTE(review): this rebinds `ds`, the name used for the raw tweet dataset at the top of the notebook
ds = ds_twit.where(~a2.dataset.utils_dataset.is_nan(ds_twit, "tp_cum"), drop=True)
field_x = "tp"
field_y = "tp_cum"
filename = figure_path / f"tweets_2017-2020_tp_cum_vs_tp_histogram.pdf"
# NOTE(review): other cells in this notebook call `plot_histogram_2d`; confirm that
# `plot_2d_histogram` exists in a2.plotting.histograms
a2.plotting.histograms.plot_2d_histogram(
    ds[field_x].values,
    ds[field_y].values,
    log=["symlog", "symlog"],
    linear_thresh=1e-9,
    n_bins=100,
    label_x=field_x,
    label_y=field_y,
    filename=filename,
    norm="log",
)

In [None]:
ds_twit["raining_old"].astype(int).plot.hist(bins=np.linspace(0, 1));

In [None]:
ds_twit["raining"].astype(int).plot.hist(bins=np.linspace(0, 1));

## Save resulting dataset

In [None]:
a2.dataset.load_dataset.save_dataset(
    ds_twit,
    "tweets_no_keywords_2020-02-13T00:00:00.000Z_2020-02-14T00:00:00_locations_bba_era5.nc",
    add_attributes="added era5 tp",
)