In [1]:
from src.docstore import Docstore

from pathlib import Path
import random

import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The two commented lines below appear to record how the saved test artifact was
# produced (paths there have no leading ".."); they are kept for reference only.
# docstore = Docstore.build(Path("data", "mock"))
# docstore.save(Path("tests", "test_artifacts", "test_docstore"))

# Load the previously saved docstore from the parent directory's tests folder.
docstore = Docstore.load(
    Path("..").joinpath("tests", "test_artifacts", "test_docstore")
)


In [3]:
# Draw 5 row positions from the underlying dataset to use as example queries.
# A dedicated, seeded generator makes the selection identical on every
# Restart & Run All instead of changing with each execution.
# Note: choices() samples WITH replacement, so a position can be drawn twice.
query_rng = random.Random(0)
query_idx = query_rng.choices(
    range(len(docstore._huggingface_dataset)),
    k=5,
)

In [4]:
# Look up the sampled positions in the dataset, then keep only the
# "Embeddings" field of the result as the query vectors.
selected_rows = docstore._huggingface_dataset[query_idx]
queries = selected_rows["Embeddings"]

In [5]:
# Materialise the query embeddings as a float32 NumPy array (always a fresh copy).
queries = np.array(queries, dtype=np.float32)

In [6]:
# Retrieve the 5 nearest stored examples for each query vector.
# The first element of the returned pair is not needed here; it is bound to a
# named throwaway rather than `_`, because assigning to `_` in IPython stops
# the kernel from updating its "last output" variable.
_unused, examples = docstore.get_nearest_examples(queries, 5)

In [7]:
# Inspect the neighbours returned for the first query. The recorded output is a
# dict mapping each column name to a list with one entry per neighbour.
examples[0]

{'Description': ['Netherlands - Foreign Direct Investment - Net Inflows',
  'Ukraine - GDP Growth Rate',
  'Serbia - Exports',
  'Costa Rica - Government Debt to Gdp',
  'Costa Rica - GDP per capita PPP'],
 'Units': ['EUR Million',
  'percent SA',
  'USD Million',
  'percent of GDP',
  'USD Constant 2017 International Dollars'],
 'Source': ['De Nederlandsche Bank',
  'State Statistics Service Of Ukraine',
  'Statistical Office of the Republic of Serbia',
  'Consejo Monetario Centroamericano',
  'World Bank'],
 'Start_Date': [datetime.date(2003, 6, 30),
  datetime.date(2010, 6, 30),
  datetime.date(2001, 8, 31),
  datetime.date(1991, 12, 31),
  datetime.date(1990, 12, 31)],
 'End_Date': [datetime.date(2021, 12, 31),
  datetime.date(2021, 12, 31),
  datetime.date(2022, 2, 28),
  datetime.date(2020, 12, 31),
  datetime.date(2020, 12, 31)],
 'Publisher': ['SGE', 'SGE', 'SGE', 'SGE', 'SGE'],
 '1960-12-31': [None, None, None, None, None],
 '1961-12-31': [None, None, None, None, None],
 '1962

In [8]:
# Settings for the date-aligned nearest-neighbour lookup. They are handed to
# date_aligned_knn positionally, in exactly the order they are defined here.
n_neighbors = 7
search_date = "1/1/2022"
num_samples = 100
resample_freq = "M"
aggregation_fn = np.mean
dropna_threshold = 10
metadata_field = "Embeddings"

_dict = docstore.date_aligned_knn(
    queries,
    n_neighbors,
    search_date,
    num_samples,
    resample_freq,
    aggregation_fn,
    dropna_threshold,
    metadata_field,
)


  1960-12-31 1961-12-31 1962-12-31 1963-12-31 1964-12-31 1965-12-31  \
0        NaN        NaN        NaN        NaN        NaN        NaN   
1        NaN        NaN        NaN        NaN        NaN        NaN   
2        NaN        NaN        NaN        NaN        NaN        NaN   
3        NaN        NaN        NaN        NaN        NaN        NaN   
4        NaN        NaN        NaN        NaN        NaN        NaN   
5        NaN        NaN        NaN        NaN        NaN        NaN   
6  2856.5269  2729.1516  2847.1514  2879.0193  2896.0085   3075.448   
0       None       None       None       None       None       None   
1       None       None       None       None       None       None   
2       None       None       None       None       None       None   
3       None       None       None       None       None       None   
4       None       None       None       None       None       None   
5       None       None       None       None       None       None   
6     

In [11]:
# Shape of the metadata returned by the lookup. The recorded output is
# torch.Size([5, 7, 768]): 5 matches the number of queries and 7 matches
# n_neighbors; 768 is presumably the embedding width (TODO confirm).
_dict["metadata"].shape

torch.Size([5, 7, 768])

In [27]:
# NOTE(review): `x` is not assigned in any visible cell of this notebook, so this
# cell only works with leftover kernel state; define `x` (a NumPy array) in an
# earlier cell before relying on Restart & Run All.
# Wrap the NumPy array as a tensor and display its shape
# (recorded output: torch.Size([7, 100])).
y = torch.from_numpy(x)
y.shape

torch.Size([7, 100])

In [36]:
# Pad the tensor built from `x`: the pad tuple is read from the last dimension
# backwards, so (0, 0) leaves the last dimension alone and (0, 2) appends two
# constant-filled entries at the end of the second-to-last dimension.
# NOTE(review): `x` is not assigned in any visible cell - confirm where it comes from.
z = F.pad(
    torch.from_numpy(x),
    pad=(0, 0, 0, 2),
    mode="constant",
)

In [40]:
# Does row 2 of frame 4 hold fewer than `dropna_threshold` observed values?
# Series.count() returns the number of non-null entries directly, so the total
# and the null count can never come from two different frames.
_dict["data"][4].loc[2].count() < dropna_threshold

False

In [32]:
# Turn every frame in _dict["data"] into a 2-D float tensor (one tensor row per
# frame row), then stack those along a new leading dimension and show the shape.
data_tensors = []
for frame in _dict["data"]:
    # Read the rows of the frame for THIS iteration; indexing one fixed frame
    # here would stack identical copies of it. The loop variable is not called
    # `df`, so the `df` DataFrame used by other cells is left untouched.
    row_tensors = [torch.FloatTensor(row) for row in frame.values]
    data_tensors.append(torch.stack(row_tensors))

# torch.stack requires equally shaped inputs, so this also fails loudly if the
# frames do not all share one shape.
torch.stack(data_tensors).shape

torch.Size([5, 7, 100])

In [13]:
# Number of columns in the first returned data frame.
_dict["data"][0].shape[1]

79

In [9]:
# One frame per query's neighbour set, restricted to the docstore's time-index
# columns, then stacked row-wise. Row labels are not reset, so they repeat
# across neighbour sets in the combined frame.
dfs = [
    pd.DataFrame(neighbours)[docstore.time_index]
    for neighbours in examples
]

df = pd.concat(dfs, axis=0)

In [10]:
# Parse the YYYY-MM-DD column labels into timestamps (this relabels `df` in
# place), then display the frame.
parsed_columns = pd.to_datetime(df.columns, format="%Y-%m-%d")
df.columns = parsed_columns
df

Unnamed: 0,1960-12-31,1961-12-31,1962-12-31,1963-12-31,1964-12-31,1965-12-31,1966-12-31,1967-12-31,1968-12-31,1969-12-31,...,2011-12-31,2012-12-31,2013-12-31,2014-12-31,2015-12-31,2016-12-31,2017-12-31,2018-12-31,2019-12-31,2020-12-31
0,,,,,,,,,,,...,102.59,106.35,105.86,106.03,106.36,104.3,102.8,107.0,110.8,108.7
1,,,,,,,,,,,...,101.9,104.1,104.8,105.5,105.8,101.1,101.5,102.6,102.4,105.5
2,,,,,,,,,,,...,1.8,-1.8,4.0,0.9,0.7,0.7,6.5,-2.0,-5.4,0.9
3,,,,,,,,,,,...,109.2,113.9,111.1,103.6,99.5,108.8,113.3,118.3,118.2,115.7
4,,,,,,,,,,,...,134.9,140.9,141.4,143.6,146.7,155.3,158.4,160.0,163.0,165.2
0,,,,,,,,,,,...,5.5,5.1,9.01,10.4,9.272,15.036,16.351,13.619,14.526,15.2
1,,,,,,,,,,,...,0.6,-0.8,2.4,-4.1,1.4,1.9,0.5,1.1,0.0,0.8
2,,,,,,,,,,,...,3.5,3.75,,4.75,4.25,,5.0,,,6.25
3,,,,,,,,,,,...,10.0,11.0,9.8,,,,,,,
4,,,,,,,,,,,...,,7612.39,9552.16,9805.55,10743.01,,,,,


In [55]:
# Resample along the columns into month-end bins, average within each bin, and
# carry the last value forward across each row to fill the gaps.
# NOTE(review): newer pandas releases deprecate both `axis=1` on resample and
# the "M" alias - confirm against the pinned pandas version.
df_resample = (
    df.resample("M", axis=1)
    .agg("mean")
    .ffill(axis=1)
)
df_resample

Unnamed: 0,1960-12-31,1961-01-31,1961-02-28,1961-03-31,1961-04-30,1961-05-31,1961-06-30,1961-07-31,1961-08-31,1961-09-30,...,2020-03-31,2020-04-30,2020-05-31,2020-06-30,2020-07-31,2020-08-31,2020-09-30,2020-10-31,2020-11-30,2020-12-31
0,,,,,,,,,,,...,69.9,69.9,69.9,69.9,69.9,69.9,69.9,69.9,69.9,69.6
1,,,,,,,,,,,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.8
2,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,...,660.07,660.07,660.07,660.07,660.07,660.07,660.07,660.07,660.07,635.12
3,,,,,,,,,,,...,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0
4,,,,,,,,,,,...,-48079.2,-48079.2,-48079.2,-48079.2,-48079.2,-48079.2,-48079.2,-48079.2,-48079.2,-105906.6
0,,,,,,,,,,,...,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,-137.9
1,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,...,660.07,660.07,660.07,660.07,660.07,660.07,660.07,660.07,660.07,635.12
2,,,,,,,,,,,...,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,6.25
3,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,...,15.94,15.94,15.94,15.94,15.94,15.94,15.94,15.94,15.94,16.42
4,,,,,,,,,,,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.8


In [None]:
# Normalise missing entries to a real float NaN. Filling with the string "NaN"
# would put text into the numeric columns (turning them into object dtype) and
# break later numeric steps such as averaging.
df = df.fillna(value=np.nan)

In [58]:
# Keep every row and all columns whose label falls on or before 1962-03-15,
# then display the result.
slice_ = df_resample.loc[:, slice(None, "1962-03-15")]
slice_

Unnamed: 0,1960-12-31,1961-01-31,1961-02-28,1961-03-31,1961-04-30,1961-05-31,1961-06-30,1961-07-31,1961-08-31,1961-09-30,1961-10-31,1961-11-30,1961-12-31,1962-01-31,1962-02-28
0,,,,,,,,,,,,,,,
1,,,,,,,,,,,,,,,
2,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,692.5,692.5,692.5
3,,,,,,,,,,,,,,,
4,,,,,,,,,,,,,,,
0,,,,,,,,,,,,,,,
1,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,696.3,692.5,692.5,692.5
2,,,,,,,,,,,,,,,
3,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.06,3.06,3.06
4,,,,,,,,,,,,,,,


In [82]:
# Show the last five columns of the first five rows. A negative start counts
# from the end and, unlike `len(columns) - 5`, still returns every column when
# the frame has fewer than five of them.
slice_.iloc[:, -5:].head(5)

Unnamed: 0,1961-10-31,1961-11-30,1961-12-31,1962-01-31,1962-02-28
0,,,,,
1,,,,,
2,696.3,696.3,692.5,692.5,692.5
3,,,,,
4,,,,,
