-
Notifications
You must be signed in to change notification settings - Fork 0
/
search_engine.py
188 lines (156 loc) · 5.8 KB
/
search_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import time
import pandas as pd
from pytrends.request import TrendReq
from tqdm import tqdm
from datetime import date
from testchecks import is_valid_country_query
from typerhints import (
Countries,
CountryList,
HistoryFrame,
HistoryFrameBlocks,
SearchTerm,
)
# Local imports
from utils import get_sampled_countries, is_leap_year
class SearchEngine:
    """Fetch daily Google Trends interest data per country via pytrends.

    Wraps a ``TrendReq`` session and adds country resolution/sampling,
    stitching of the two half-year windows (Google only serves daily
    resolution for short timeframes), and basic rate limiting between
    requests.
    """

    def __init__(
        self,
        pytrends: TrendReq,
        supported_countries: CountryList,
        fetch_interval: int = 1,
    ):
        # pytrends: configured TrendReq session used for all requests.
        # supported_countries: country list accepted by the query helpers
        #   (presumably (geocode, name) pairs — confirm against
        #   get_sampled_countries / load_countries).
        # fetch_interval: seconds slept before each request, to avoid
        #   hammering the unofficial API.
        self.pytrends = pytrends
        self.supported_countries = supported_countries
        self.fetch_interval = fetch_interval

    def get_daily_trends_by_year(
        self, search_terms: SearchTerm, year: int, countries: Countries
    ) -> HistoryFrameBlocks:
        """Return ``[(country, daily_history_frame), ...]`` for one year.

        Raises:
            ValueError: if the country query is illegal, or more than
                five search terms are given (pytrends payload limit).

        Countries whose fetch fails with ``ValueError`` are skipped;
        the error is printed and the remaining countries continue.
        """
        # Check if illegal query before doing any network work.
        if not is_valid_country_query(
            countries=countries, country_names=self.supported_countries
        ):
            raise ValueError("Illegal query")
        if len(search_terms) > 5:
            raise ValueError("Too many search terms: MAX 5")
        # Find countries either by specific countries or random sampling
        # by indices.
        search_countries = get_sampled_countries(
            countries=countries, country_names=self.supported_countries
        )
        trends = []
        for geocode, country in tqdm(
            search_countries,
            desc=f"[{year}] Fetching {search_terms[0]}",
            colour="green",
        ):
            try:
                history = self._get_yearly_interest_by_country(
                    search_terms=search_terms, year=year, geocode=geocode
                )
                trends.append((country, history))
            except ValueError as ve:
                # Best-effort: report and continue with other countries.
                print(ve)
        return trends

    def _get_interest_over_time(
        self, search_terms: SearchTerm, tf: str, geocode: str
    ) -> HistoryFrame:
        """Fetch one timeframe of interest data for one geocode.

        Raises:
            ValueError: when Google returns an empty payload.
        """
        # Rate-limit: don't fetch too often.
        time.sleep(self.fetch_interval)
        # Create the payload for related queries.
        self.pytrends.build_payload(search_terms, timeframe=tf, geo=geocode)
        # Request data as a dataframe.
        payload = self.pytrends.interest_over_time()
        if payload.empty:
            # Build a correctly-shaped all-NaN frame for diagnostics.
            fallback = self.create_empty_payload(tf, search_terms)
            # BUG FIX: the original print lacked the f-prefix, so the
            # literal text "{payload}" was printed instead of the frame.
            print(f"Empty payload {payload}, but created: ", fallback.shape)
            raise ValueError(
                f"[ERROR] Geocode: {geocode}: \
                empty payload for terms {search_terms}"
            )
        return payload.drop(columns="isPartial")

    def create_empty_payload(self, timeframe: str, search_terms: SearchTerm):
        """Build an all-NaN daily DataFrame covering *timeframe*.

        ``timeframe`` format: ``"YYYY-M-D YYYY-M-D"``,
        e.g. ``"2010-7-1 2010-12-31"`` (not zero-padded, so manual
        parsing is used instead of ``date.fromisoformat``).
        """
        first_half, second_half = timeframe.split(" ")
        year, first_month, first_day = (int(p) for p in first_half.split("-"))
        _, last_month, last_day = (int(p) for p in second_half.split("-"))
        first_date = date(year, first_month, first_day)
        last_date = date(year, last_month, last_day)
        # BUG FIX: the original computed first_date - last_date, which
        # yields a negative day count and makes pd.date_range raise.
        n_days = (last_date - first_date).days + 1  # +1 includes last day
        return pd.DataFrame(
            data=None,
            index=pd.date_range(start=first_date, periods=n_days, freq="D"),
            columns=search_terms,
        )

    def _get_yearly_interest_by_country(
        self,
        search_terms: SearchTerm,
        year: int,
        geocode: str,
    ) -> HistoryFrame:
        """Stitch two half-year windows into one full-year daily frame.

        Asserts that the combined frame covers every day of the year.
        """
        expected_days = 366 if is_leap_year(year) else 365
        # Build timeline: the year is fetched in two halves so that
        # Google returns daily (not weekly) resolution.
        time_periods = [
            f"{year}-1-1 {year}-6-30",
            f"{year}-7-1 {year}-12-31",
        ]
        halves = [
            self._get_interest_over_time(search_terms, period, geocode)
            for period in time_periods
        ]
        # Concatenate without rescaling across the two halves.
        combined_history = pd.concat(halves)
        assert (
            combined_history.shape[0] == expected_days
        ), f"Expected {expected_days}, not {combined_history.shape[0]} days in a year!"
        return combined_history

    @staticmethod
    def _get_head_and_tail_values(hist):
        """Return (first-row values, last-row values) per column of *hist*."""
        head_hist = hist.head(1)
        tail_hist = hist.tail(1)
        heads = [head_hist[name].values[0] for name in hist.columns]
        tails = [tail_hist[name].values[0] for name in hist.columns]
        assert len(hist.columns) == len(
            heads
        ), "length of heads doesnt match number of column names"
        assert len(hist.columns) == len(
            tails
        ), "length of tails doesnt match number of column names"
        return heads, tails
if __name__ == "__main__":
    from utils import histories_to_pandas, load_countries, plot_history

    # Demo configuration.
    COUNTRY_DIR = "docs/countries.txt"
    COUNTRY_IGNORE_DIR = "docs/ignore.txt"
    LANGUAGE = "en-US"
    TIME_ZONE = 360
    YEAR = 2010
    SEARCH_TERMS = ["MeToo", "ski"]
    SEARCH_COUNTRIES = ["Norway"]

    # Build the engine against the full supported-country list.
    engine = SearchEngine(
        pytrends=TrendReq(hl=LANGUAGE, tz=TIME_ZONE),
        supported_countries=load_countries(
            filename=COUNTRY_DIR, ignore=COUNTRY_IGNORE_DIR
        ),
        fetch_interval=1,
    )

    # Fetch one year of daily trends and dump the result to CSV.
    history_trends = engine.get_daily_trends_by_year(
        search_terms=SEARCH_TERMS, year=YEAR, countries=SEARCH_COUNTRIES
    )
    frame = histories_to_pandas(history_trends)
    frame.to_csv("baselayer.csv", index=False)
    print(frame)
    print(frame.info())
    print(frame["date"].dtype)