In [30]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
import sys
import os
current_dir = os.path.dirname(os.path.abspath('test_data_helper.py'))
sys.path.append(os.path.dirname(current_dir))
import data_helper as dh
import pandas as pd
import pdb
import model_helper as mh
import random
import numpy as np
import torch
import matplotlib.pyplot as plt

def set_seed(seed, device):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device != 'cpu':        
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
set_seed(42, 'cpu')
DATA_SOURCES = [{
            "name": "EDT",
            "text_path": "../../data/EDT/evaluate_news.json",
            "ts_path": "../../data/stock_emotions/price/",
            "ts_date_col": 'Date',
            'text_date_col': 'date',
            'text_col': 'text',
            'train_dates': '01/01/2020 - 03/09/2020',
            'test_dates': '04/09/2020 - 31/12/2020'
        },{
            "name": "stock_emotion",
            "text_path": "../../data/stock_emotions/tweet/processed_stockemo.csv",
            "ts_path": "../../data/stock_emotions/price/",
            "ts_date_col": 'Date',
            'text_date_col': 'date',
            'text_col': 'text',
            'train_dates': '01/01/2020 - 03/09/2020',
            'test_dates': '04/09/2020 - 31/12/2020'
        },  {
            "name": "stock_net",
            "text_path": "../../data/stocknet/tweet/organised_tweet.csv",
            "ts_path": "../../data/stocknet/price/raw/",
            "ts_date_col": 'Date',
            'text_date_col': 'created_at',
            'text_col': 'text',
            'train_dates': '01/01/2014 - 01/08/2015',
            'test_dates': '01/08/2015 - 01/01/2016'
        }
]


In [53]:
def plot_ticker_per_day(df):
        ticker_date_count = df.groupby(['ticker', 'date']).size().reset_index(name='count')
        # Create bar plots for each ticker
        unique_tickers = ticker_date_count['ticker'].unique()
        pdb.set_trace()

        for ticker in unique_tickers:
                ticker_data = ticker_date_count[ticker_date_count['ticker'] == ticker]
                plt.figure(figsize=(10, 6))
                plt.bar(ticker_data['date'], ticker_data['count'], color='blue')
                plt.title(f'Number of rows for {ticker} per day')
                plt.xlabel('Date')
                plt.ylabel('Number of rows')
                plt.xticks(rotation=45)
                plt.tight_layout()
                plt.show()

def test_wrangle_data():
    for data_source in DATA_SOURCES:
        text_df, ts_df = dh.wrangle_data(data_source)
        expected_columns_text = {'ticker', 'date', 'text'}
        assert set(expected_columns_text).issubset(set(text_df.columns)), f"The expected columns {expected_columns_text} should exist in the text DataFrame but got {text_df.columns} instead"
        
        expected_columns_ts = {'ticker', 'date', 'Close'}
        assert set(expected_columns_ts).issubset(set(ts_df.columns)), f"The expected columns {expected_columns_ts} should exist in the text DataFrame but got {ts_df.columns} instead"
        
        #text_df = text_df[text_df['ticker'].isin(top_tickers)].reset_index(drop=True)
        ##check the dates are within a reasonable period of each other
        #earliest_text_date = text_df['date'].min()
        #latest_text_date = text_df['date'].max()
        #earliest_ts_date = ts_df['date'].min()
        #latest_ts_date = ts_df['date'].max()
        #reasonable_period = pd.Timedelta(days=30)
#
        #assert (earliest_ts_date >= earliest_text_date - reasonable_period) or 
#
        #assert (earliest_ts_date >= earliest_text_date - reasonable_period) and (latest_ts_date <= latest_text_date + reasonable_period), "The earliest and latest dates in the time series data should be within 30 days of the text data date range"
def is_all_empty_strings(lst):
    return all(s == '' for s in lst)
     
def test_get_data():
    text_tokenizer = mh.get_text_encoder()
    columns_to_check = ['ticker', 'target_date_text_df', 'end_date', 'text_series', 'text_dates', 'text_id', 'target_date_ts_df', 'time_series', 'ts_past_features', 'ts_id', 'label', 'original_ts_past_features']
    for data_source in DATA_SOURCES:
        df_set1 = dh.get_data(text_tokenizer=text_tokenizer, data_source=data_source, subset_data=True, loaders=False)
        df_set2 = dh.get_data(text_tokenizer=text_tokenizer, data_source=data_source, subset_data=True, loaders=False)
        df_set3 = dh.get_data(text_tokenizer=text_tokenizer, data_source=data_source, subset_data=True, loaders=False, random_state=1)
        
        
        true_cols = []
        for df1, df2, df3 in zip(df_set1, df_set2, df_set3):
            #check there are 3 tickers present
            assert df1['ticker'].nunique() == 3
            assert df2['ticker'].nunique() == 3
            assert df3['ticker'].nunique() == 3
        
            #check if there are any examples of text_series with no text at all    
            assert not df1['text_series'].apply(is_all_empty_strings).any(), "There exists a text series with entirely empty strings"
            assert not df2['text_series'].apply(is_all_empty_strings).any(), "There exists a text series with entirely empty strings"
            assert not df3['text_series'].apply(is_all_empty_strings).any(), "There exists a text series with entirely empty strings"

            #test reproducibility     
            for column in columns_to_check:
                assert df1[column].equals(df2[column]), 'df reproducibility is hindered, dfs are not created the same with the same seed'
                #we can't just assert that df3 and df1 cols are the same col by col; this is because sometimes columns will be the same such as ticker or label
                
                #but we care that the whole df is diff such as text selected, negative pairs etc
                true_cols.append(df3[column].equals(df1[column]))
        assert not all(true_cols), 'df reproducibility is hindered, all dfs are the same'

def test_create_time_series_df():
     for data_source in DATA_SOURCES:
        text_df, ts_df = dh.wrangle_data(data_source)

        text_df, ts_df = dh.subset_data_helper(data_source=data_source, text_df=text_df, ts_df=ts_df)
        text_date_col = data_source['text_date_col']
        ts_date_col = data_source["ts_date_col"]
        text_col = data_source["text_col"]

        text_df[text_date_col] = pd.to_datetime(text_df[text_date_col], utc=True).dt.date
        ts_df[ts_date_col] = pd.to_datetime(ts_df[ts_date_col], utc=True).dt.date

        k = 6
        mode='start'
        created_df = dh.create_time_series_df(df=ts_df, k=k, mode=mode)
        pdb.set_trace()

def test_create_text_series_df():
     pass

def test_create_pairs():
     ts_window = 6
     ts_mode='start'
     text_window=3
     text_selection_method=('TFIDF', 5)
     for data_source in DATA_SOURCES:
        text_df, ts_df = dh.wrangle_data(data_source)

        text_df, ts_df = dh.subset_data_helper(data_source=data_source, text_df=text_df, ts_df=ts_df)
        text_date_col = data_source['text_date_col']
        ts_date_col = data_source["ts_date_col"]
        text_col = data_source["text_col"]

        text_df[text_date_col] = pd.to_datetime(text_df[text_date_col], utc=True).dt.date
        ts_df[ts_date_col] = pd.to_datetime(ts_df[ts_date_col], utc=True).dt.date
        #convert df to id, tickers:[list], start_date, texts:list, time_series:list, past_time_features:[list]
        text_df, ts_df = dh.process_windows(text_df=text_df, 
                                         ts_df=ts_df, 
                                         ts_window=ts_window,
                                         ts_mode=ts_mode, 
                                         text_window=text_window, 
                                         text_selection_method=text_selection_method, 
                                         text_col=text_col, 
                                         text_time_col=text_date_col, 
                                         ts_time_col=ts_date_col)
        df = dh.create_pairs(text_df=text_df, ts_df=ts_df, negatives_creation=('naive', 60))
        #need to assert for each row:
                #no text date in the text_dates is outside of the range expected
                #based on the range from ts_past_features
        pdb.set_trace()

#test_wrangle_data()
#test_get_data()
test_create_pairs()

> [1;32mc:\users\eoinp\appdata\local\temp\ipykernel_96572\4080804378.py[0m(96)[0;36mtest_create_pairs[1;34m()[0m

                                          ids ticker target_date_text_df  \
0                    [152463, 154113, 170317]    BAC          2020-03-17   
1                            [127255, 194636]    BAC          2020-03-24   
2                            [194636, 189151]    BAC          2020-03-25   
3                    [173896, 173064, 138840]    BAC          2020-03-30   
4            [173064, 178277, 173896, 121665]    BAC          2020-03-31   
..                                        ...    ...                 ...   
625      [16637, 1365, 61918, 209744, 252852]    BAC          2021-03-16   
626      [270801, 36200, 43963, 55540, 29453]    JNJ          2021-03-23   
627  [122279, 174834, 198701, 156672, 136235]    JNJ          2020-08-03   
628  [164458, 127782, 126815, 144666, 139022]    JNJ          2020-05-20   
629                  [191762, 150819, 145656]