# ML4F Semester Project

Current blocks/issues:
- write here!!

In [9]:
import os
import re  
import pandas as pd  
from time import time 
import sys
import numpy as np
import spacy  
import logging  
import nltk
from datetime import datetime
# nltk.download('wordnet')
# nltk.download('omw-1.4')
# nltk.download('stopwords')
# nltk.download('vader_lexicon')
import multiprocessing

RANDOM_SEED = 7
DATA_DIR = "../data/"
wsb_data_path = os.path.join(DATA_DIR, 'wsb_comments/wsb_comments_raw.csv') #note: wsb=wallstreetbets
stock_data_path = os.path.join(DATA_DIR, 'GME')

def load_wsb_data(data_path, nrows=None, cols=['author', 'body','created_utc', 'score', 'link_id', 'is_submitter']):
    "Load wsb data, nrows None indicates all rows, otherwise specified integer of rows"
    return pd.read_csv(wsb_data_path, nrows = nrows, delimiter=',', usecols=cols)

wsb_df = load_wsb_data(wsb_data_path, nrows=1000)
wsb_df['raw'] = wsb_df['body']
display(wsb_df.head())

Unnamed: 0,author,body,created_utc,is_submitter,link_id,score,raw
0,LazyMeal,We’re retarded and claim to be often. If you l...,1585123910,False,t3_fom9g6,1,We’re retarded and claim to be often. If you l...
1,math_salts,Yes,1585123909,False,t3_fod66b,1,Yes
2,Legendary_Squirrel,markets been open for 13 min...,1585123905,False,t3_fod66b,1,markets been open for 13 min...
3,WSBMORONICTRADER,Spy can fuck around all it wants just as long ...,1585123901,False,t3_fod66b,1,Spy can fuck around all it wants just as long ...
4,[deleted],[removed],1585123897,False,t3_fom0hg,1,[removed]


In [10]:
class PreProcessing: 
    
    def __init__(self, wsb_data, lemmatize=True, lower_case=True, rem_stopwords=True, rem_punctuation=True, tokenize=True):
        """
        Initialise all class parameters
        
        :param data: nonempty pandas dataframe, wsb dataframe 
        :param lemmatize: bool, whether to perform lemmatization
        :param lower_case: bool, whether to lowercase
        :param rem_stopwords: bool, whether to remove stopwords
        :param tokenize: bool, whether to tokenize
        """
        
        self.wsb_data = wsb_data
        self.lemmatize = lemmatize
        self.lower_case = lower_case
        self.rem_stopwords = rem_stopwords
        self.rem_punctuation = rem_punctuation
        self.tokenize = tokenize
        
    ### Ensure Parameter types 
    #K: need to add to this one gradually as we add columns that we use etc.
    @property
    def wsb_data(self):
        return self._wsb_data
    @wsb_data.setter
    def wsb_data(self, wsb_data):
        
        req_columns = ['author','body','created_utc']
        str_columns = ['body']
        date_columns = ['created_utc']

        # Ensure the provided object is a dataframe
        if not isinstance(wsb_data, pd.DataFrame):
            raise Exception("The provided data must be a pandas Dataframe")
        
        # Ensure wsb dataframe is non empty
        if wsb_data.shape[0] == 0: 
            raise Exception("Provided Dataframe is empty")
        
        # Ensure all required columns are provided
        missing_columns = set(req_columns).difference(set(wsb_data.columns.tolist()))
        if len(missing_columns) > 0:
            raise Exception(f"The columns {missing_columns} are missing from the provided dataframe!")
            
        # Ensure all column names don't have unexpected periods
        if '.' in list(''.join(wsb_data.columns.tolist())):
            raise Exception("All Column names must not include periods :'.'")
            
        # Ensure all string columns are strings
        non_str_columns = set(str_columns).difference(set(wsb_data.select_dtypes(include='object')))
        if len(non_str_columns) > 0:
            raise Exception(f'The columns {non_str_columns} are expected as string (pandas object) columns.')
        
        # Ensure dates are interpretable
        for date_col in date_columns: 
            if wsb_data[date_col].apply(lambda x: datetime.utcfromtimestamp(x).strftime('%Y-%m-%d %H:%M:%S')).notnull().all():
                try:
                    # Otherwise Convert date using unixtimestamp to datetime object
                    wsb_data[date_col] = wsb_data[date_col].apply(lambda x: datetime.utcfromtimestamp(x).strftime('%Y-%m-%d %H:%M:%S'))
                except: 
                    raise Exception(f"{date_col} must be a valid unixtimestamp format")
                    
        self._wsb_data = wsb_data
        
    @property
    def lemmatize(self):
        return self._lemmatize
    @lemmatize.setter
    def lemmatize(self, lemmatize):
        if not isinstance(lemmatize, bool): 
            raise Exception('lemmatize must be provided as a boolean parameter (True/False) to the class')
        self._lemmatize = lemmatize
    
    @property
    def lower_case(self):
        return self._lower_case
    @lower_case.setter
    def lower_case(self, lower_case):
        if not isinstance(lower_case, bool): 
            raise Exception('lower_case must be provided as a boolean parameter (True/False) to the class')
        self._lower_case = lower_case
      
    @property
    def rem_stopwords(self):
        return self._rem_stopwords
    @rem_stopwords.setter
    def rem_stopwords(self, rem_stopwords):
        if not isinstance(rem_stopwords, bool): 
            raise Exception('rem_stopwords must be provided as a boolean parameter (True/False) to the class')
        self._rem_stopwords = rem_stopwords
        
    @property
    def rem_punctuation(self):
        return self._rem_punctuation
    @rem_punctuation.setter
    def rem_punctuation(self, rem_punctuation):
        if not isinstance(rem_punctuation, bool): 
            raise Exception('rem_punctuation must be provided as a boolean parameter (True/False) to the class')
        self._rem_punctuation = rem_punctuation
        
    @property
    def tokenize(self):
        return self._tokenize
    @tokenize.setter
    def tokenize(self, tokenize):
        if not isinstance(tokenize, bool): 
            raise Exception('tokenize must be provided as a boolean parameter (True/False) to the class')
        self._tokenize = tokenize
        
    def clean_textual_data(self, textual_columns):
        
        ### Ensure the provided textual columns exist, and if single string column name convert it into a list
        if len(textual_columns)<1:
            raise Exception('The number of textual columns to clean must be greater than 0')
        if isinstance(textual_columns, str):
            textual_columns = [textual_columns]
        missing_columns = set(textual_columns).difference(set(self.wsb_data.columns.tolist()))
        if len(missing_columns) > 0:
            raise Exception(f"The columns {missing_columns} to clean are missing from the wsb dataframe!")

        def lower_case_fn(self, col_name): 
            self.wsb_data[col_name] = self.wsb_data[col_name].str.lower()
            return self.wsb_data

        def lemmatize_fn(self, col_name):
            w_tokenizer = nltk.tokenize.WhitespaceTokenizer()
            lemmatizer = nltk.stem.WordNetLemmatizer()
            self.wsb_data[col_name] = self.wsb_data[col_name].apply(lambda x: [lemmatizer.lemmatize(w) for w in w_tokenizer.tokenize(x)])
            return self.wsb_data

        def stemming_fn(self, col_name):
            w_tokenizer = nltk.tokenize.WhitespaceTokenizer()
            stemmer = nltk.stem.porter.PorterStemmer()
            self.wsb_data[col_name] = self.wsb_data[col_name].apply(lambda x: [stemmer.stem(w) for w in w_tokenizer.tokenize(x)])
            return self.wsb_data

        def tokenize_fn(self, col_name):
            w_tokenizer = nltk.tokenize.WhitespaceTokenizer()
            self.wsb_data[col_name] = self.wsb_data[col_name].apply(lambda x: [w for w in w_tokenizer.tokenize(x)])
            return self.wsb_data

        def rem_punctuation_fn(self, col_name):
            self.wsb_data[col_name] = self.wsb_data[col_name].apply(lambda x: [w for w in x if w.isalnum()])
            return self.wsb_data

        def rem_stopwords_fn(self, col_name):
            "stopwords dictionary considered English, wsb is an english forum"
            remove_elements = set(nltk.corpus.stopwords.words('english'))
            self.wsb_data[col_name] = self.wsb_data[col_name].apply(lambda x: [w for w in x if not w in remove_elements])
            return self.wsb_data

        def remove_tokenization(self, col_name):
            "Necessary as final step to untokenize in case desired, tokenization required for other functions to not break"
            self.wsb_data[col_name] = self.wsb_data[col_name].apply(lambda x: ' '.join(x))
            return self.wsb_data

        for textual_col in textual_columns:

            if self.lower_case:
                lower_case_fn(self, textual_col)

            # lemmatize tokens if true, if false, stem tokens, if None then just tokenize
            if self.lemmatize:
                lemmatize_fn(self, textual_col)
            elif self.lemmatize:
                stemming_fn(self, textual_col)
            else: 
                tokenize_fn(self, textual_col)

            if self.rem_punctuation:
                rem_punctuation_fn(self, textual_col)
            if self.rem_stopwords:
                rem_stopwords_fn(self, textual_col)
            if not self.tokenize:
                remove_tokenization(self, textual_col)

        return self.wsb_data
        
        
    # to later remove: for development
    def output_data(self):
        return self.wsb_data 

In [11]:
WSBPreProcessing = PreProcessing(wsb_df, lemmatize=True, lower_case=True, rem_stopwords=True, rem_punctuation=True, tokenize=True)
WSBPreProcessing.clean_textual_data('body')
useful_columns = ['author','raw', 'body','created_utc', 'score', 'link_id', 'is_submitter']
WSB_preprocessed_data = WSBPreProcessing.output_data()[useful_columns]

## Data Cleaning
- Get rid of deleted comments and row with empty or nan values
- Add POS, Tag, Dep, Shape
- Remove URL and @

In [15]:
import spacy
nlp = spacy.load("en_core_web_sm") # efficiency and accuracy --> "en_core_web_trf"

WSB_preprocessed_data = WSB_preprocessed_data[WSB_preprocessed_data['author'] != '[deleted]']

raw = []
pos = []
tag = []
dep = []
shape = []

for sentence in WSB_preprocessed_data['raw']:
    sentence = re.sub('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+#]|[!*\(\),]|'\
                       '(?:%[0-9a-fA-F][0-9a-fA-F]))+','', sentence)
    sentence = re.sub("(@[A-Za-z0-9_]+)","", sentence)
    raw.append(sentence)
    
    pos_tmp = []
    tag_tmp = []
    dep_tmp = []
    shape_tmp = []
    for token in nlp(sentence):
        pos_tmp.append(token.pos_)
        tag_tmp.append(token.tag_)
        dep_tmp.append(token.dep_)
        shape_tmp.append(token.shape_)
    pos.append(pos_tmp)
    tag.append(tag_tmp)
    dep.append(dep_tmp)
    shape.append(shape_tmp) 
     
WSB_preprocessed_data['Pos'] = pos
WSB_preprocessed_data['Tag'] = tag
WSB_preprocessed_data['Dep'] = dep
WSB_preprocessed_data['Shape'] = shape    

WSB_preprocessed_data['raw'] = raw

WSB_preprocessed_data.replace('', np.nan, inplace=True)
WSB_preprocessed_data = WSB_preprocessed_data.dropna(how='any',axis=0)
WSB_preprocessed_data = WSB_preprocessed_data[WSB_preprocessed_data['body'].map(lambda d: len(d)) > 0]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user

## Sentiment Analyser
- Flair
- Vader
- Blob

In [16]:
from flair.models import TextClassifier
from flair.data import Sentence
classifier = TextClassifier.load('en-sentiment')
from nltk.sentiment import SentimentIntensityAnalyzer
from textblob import TextBlob

2022-01-02 10:32:54,452 loading file /home/kamran/.flair/models/sentiment-en-mix-distillbert_4.pt


In [37]:
fler = []
vader = []
blob = []

for i in range(len(WSB_preprocessed_data)):
    raw = WSB_preprocessed_data['raw'].iloc[i]
    sentence = WSB_preprocessed_data['body'].iloc[i]
    
    # FLAIR
    s = Sentence(sentence)
    classifier.predict(s)
    total_sentiment = str(s.labels[0])
    num = float(re.findall(r'\d+\.\d+', total_sentiment)[0])
    if total_sentiment.find('POSITIVE') == -1:
        num = num * -1
    fler.append(num)

    # VADER
    analyzer = SentimentIntensityAnalyzer()
    vs = analyzer.polarity_scores(raw)['compound']
    vader.append(float(vs))

    # BLOB
    _tmp = TextBlob(raw).sentiment.polarity
    blob.append(float(_tmp))
    
WSB_preprocessed_data['FLAIR'] = fler
WSB_preprocessed_data['VADER'] = vader
WSB_preprocessed_data['BLOB'] = blob


In [38]:
WSB_preprocessed_data

Unnamed: 0,author,raw,body,created_utc,score,link_id,is_submitter,Pos,Tag,Dep,Shape,FLAIR,VADER,BLOB
0,LazyMeal,We’re retarded and claim to be often. If you l...,"[retarded, claim, listen, doe, make]",2020-03-25 08:11:50,1,t3_fom9g6,False,"[PRON, VERB, ADJ, CCONJ, VERB, PART, AUX, ADV,...","[PRP, VBZ, JJ, CC, VBP, TO, VB, RB, ., IN, PRP...","[nsubjpass, auxpass, ROOT, cc, conj, aux, xcom...","[Xx, ’xx, xxxx, xxx, xxxx, xx, xx, xxxx, ., Xx...",-0.9713,-0.5719,-0.8000
1,math_salts,Yes,[yes],2020-03-25 08:11:49,1,t3_fod66b,False,[INTJ],[UH],[ROOT],[Xxx],0.9918,0.4019,0.0000
2,Legendary_Squirrel,markets been open for 13 min...,"[market, open, 13]",2020-03-25 08:11:45,1,t3_fod66b,False,"[NOUN, AUX, ADJ, ADP, NUM, NOUN, PUNCT]","[NNS, VBN, JJ, IN, CD, NN, .]","[nsubj, ROOT, acomp, prep, nummod, pobj, punct]","[xxxx, xxxx, xxxx, xxx, dd, xxx, ...]",0.8276,0.0000,0.0000
3,WSBMORONICTRADER,Spy can fuck around all it wants just as long ...,"[spy, fuck, around, want, long, 220, put, prin...",2020-03-25 08:11:41,1,t3_fod66b,False,"[NOUN, VERB, VERB, ADV, DET, PRON, VERB, ADV, ...","[NN, MD, VB, RB, DT, PRP, VBZ, RB, RB, RB, IN,...","[nsubj, aux, ROOT, prep, pobj, nsubj, relcl, a...","[Xxx, xxx, xxxx, xxxx, xxx, xx, xxxx, xxxx, xx...",-0.9948,-0.5423,-0.0625
5,madamlazonga,"you lost me at ""bulls fucked""",[lost],2020-03-25 08:11:36,1,t3_fod66b,False,"[PRON, VERB, PRON, ADP, PUNCT, NOUN, VERB, PUNCT]","[PRP, VBD, PRP, IN, ``, NNS, VBN, '']","[nsubj, ROOT, dobj, prep, punct, nsubj, pcomp,...","[xxx, xxxx, xx, xx, "", xxxx, xxxx, ""]",-0.9932,-0.7717,-0.6000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
993,plimsickins16978,ANAL 12/11 $50,[anal],2020-12-05 20:08:38,4,t3_k6tl0d,False,"[ADJ, NUM, SYM, NUM]","[JJ, CD, $, CD]","[compound, ROOT, nmod, npadvmod]","[XXXX, dd/dd, $, dd]",0.8674,0.0000,0.0000
994,PlentyC,Mmm fuck yeah,"[mmm, fuck, yeah]",2020-12-05 20:08:37,7,t3_k6tl0d,False,"[PROPN, ADJ, INTJ]","[NNP, JJ, UH]","[compound, ROOT, ROOT]","[Xxx, xxxx, xxxx]",0.9925,-0.3182,-0.4000
995,Jujubewise,Haha 😆,[haha],2020-12-05 20:08:36,7,t3_k6tl0d,False,"[PROPN, PROPN]","[NNP, NNP]","[dep, ROOT]","[Xxxx, 😆]",-0.9691,0.4588,0.2000
998,steve_pops_01,Real,[real],2020-12-05 20:08:33,5,t3_k6tl0d,False,[ADJ],[JJ],[ROOT],[Xxxx],0.9993,0.0000,0.2000


## Posts Cluster
- By Cluster
- By Time 

In [39]:
searchfor = ['aapl', 'apple', 'Apple', 'AAPL'] #add ISIN, Permco, etc..
ticker_posts = WSB_preprocessed_data[WSB_preprocessed_data['raw'].str.contains('|'.join(searchfor))]
ticker_posts.sort_values(by=['created_utc'], inplace=True, ascending=False)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


## Individual Posts

In [40]:
ticker_posts

Unnamed: 0,author,raw,body,created_utc,score,link_id,is_submitter,Pos,Tag,Dep,Shape,FLAIR,VADER,BLOB
880,LegNest,oof aapl **109.26**,"[oof, aapl]",2020-10-30 08:06:39,1,t3_jkhkw8,False,"[PROPN, PROPN, SPACE, PUNCT, PUNCT, NUM, PUNCT...","[NNP, NNP, _SP, NFP, NFP, CD, NFP, NFP]","[amod, ROOT, , dep, dep, ROOT, dep, ROOT]","[xxx, xxxx, , *, *, ddd.dd, *, *]",-0.9887,0.0,0.0
532,Agent248,Holding apple calls is making me gay from last...,"[holding, apple, call, making, gay, last, 2, day]",2020-09-23 20:11:29,2,t3_iyhxsc,False,"[VERB, NOUN, NOUN, AUX, VERB, PRON, ADJ, ADP, ...","[VBG, NN, NNS, VBZ, VBG, PRP, JJ, IN, JJ, CD, ...","[csubj, compound, dobj, aux, ROOT, nsubj, ccom...","[Xxxxx, xxxx, xxxx, xx, xxxx, xx, xxx, xxxx, x...",-0.9977,0.0,0.208333
568,sandawg_,Hey guys! 😃 AAPL gang checking in 🙂. Today was...,"[hey, aapl, gang, checking, today, wa, tough, ...",2020-09-23 20:11:05,-1,t3_iyhxsc,False,"[INTJ, NOUN, PUNCT, PROPN, PROPN, NOUN, VERB, ...","[UH, NNS, ., NNP, NNP, NN, VBG, IN, NNP, ., NN...","[intj, ROOT, punct, compound, compound, nsubj,...","[Xxx, xxxx, !, 😃, XXXX, xxxx, xxxx, xx, 🙂, ., ...",-0.9709,0.643,0.136508
457,Graxiano,"AAPL is so last week, hype from split has died...","[aapl, last, hype, split, ha, died, tsla, new]",2020-08-18 08:06:48,3,t3_iblrnn,False,"[PROPN, AUX, ADV, ADJ, NOUN, PUNCT, NOUN, ADP,...","[NNP, VBZ, RB, JJ, NN, ,, NN, IN, NN, VBZ, VBN...","[nsubj, ccomp, advmod, amod, npadvmod, punct, ...","[XXXX, xx, xx, xxxx, xxxx, ,, xxxx, xxxx, xxxx...",-0.9999,-0.5574,-0.006397
498,Throwawayeconboi,AAPL is as FLAT as a white woman’s ass. \n\n\n...,"[aapl, flat, white, fuuuuuuuuuuuuuuuuuu]",2020-08-18 08:03:50,6,t3_iblrnn,False,"[PROPN, AUX, ADV, ADJ, SCONJ, DET, ADJ, NOUN, ...","[NNP, VBZ, RB, JJ, IN, DT, JJ, NN, POS, NN, .,...","[nsubj, ROOT, advmod, acomp, prep, det, amod, ...","[XXXX, xx, xx, XXXX, xx, x, xxxx, xxxx, ’x, xx...",-0.9963,-0.5423,-0.0125
387,strawberry-jam-boy,Spce tsla se lulu plug wmt upwk amd aapl msft ...,"[spce, tsla, se, lulu, plug, wmt, upwk, amd, a...",2020-07-12 20:09:20,6,t3_hq11ao,False,"[NOUN, NOUN, VERB, NOUN, PROPN, PROPN, PROPN, ...","[NN, NNS, VBP, NN, NNP, NNP, NNP, NNP, NNP, NN...","[compound, compound, compound, compound, compo...","[Xxxx, xxxx, xx, xxxx, xxxx, xxx, xxxx, xxx, x...",0.9968,0.3612,0.0
107,philmacrack123,"Y'all see amazon, if apple is bad, we will sel...","[see, apple, gonna, need, appl, easily, carry,...",2020-04-30 20:11:50,6,t3_gaszeo,False,"[NOUN, VERB, PROPN, PUNCT, SCONJ, NOUN, AUX, A...","[NN, VB, NNP, ,, IN, NN, VBZ, JJ, ,, PRP, MD, ...","[nsubj, ccomp, dobj, punct, mark, nsubj, advcl...","[X'xxx, xxx, xxxx, ,, xx, xxxx, xx, xxx, ,, xx...",-0.9823,-0.0534,0.038294
112,originalmuggins,Lol Apple with the slow roll. You gotta respec...,"[lol, apple, slow, gotta, respect]",2020-04-30 20:11:46,2,t3_gaszeo,False,"[PROPN, PROPN, ADP, DET, ADJ, NOUN, PUNCT, PRO...","[NNP, NNP, IN, DT, JJ, NN, ., PRP, VBD, TO, VB...","[compound, ROOT, prep, det, amod, pobj, punct,...","[Xxx, Xxxxx, xxxx, xxx, xxxx, xxxx, ., Xxx, xx...",0.9679,0.7096,0.25
121,ch1p_skylark,"AAPL at 4:30, UAL 4:15 EST","[aapl, ual, est]",2020-04-30 20:11:42,1,t3_gaszeo,False,"[VERB, ADP, NUM, PUNCT, PROPN, NUM, NOUN]","[VBD, IN, CD, ,, NNP, CD, NN]","[ROOT, prep, pobj, punct, npadvmod, nummod, ap...","[XXXX, xx, d:dd, ,, XXX, d:dd, XXX]",0.7741,0.0,0.0
196,facehuggerpoop,bought one apple call before close.\n\nI think...,"[bought, one, apple, call, think, going, lose,...",2020-04-30 20:11:08,1,t3_gaszeo,False,"[VERB, NUM, NOUN, NOUN, ADP, NOUN, PUNCT, SPAC...","[VBD, CD, NN, NN, IN, NN, ., _SP, PRP, VBP, PR...","[ROOT, nummod, compound, dobj, advmod, advmod,...","[xxxx, xxx, xxxx, xxxx, xxxx, xxxx, ., \n\n, X...",-0.9975,-0.4019,0.285714


## By frequency appareance

In [43]:
# K: Not sure what the issue is: TypeError: can only concatenate str (not "Timedelta") to str

ticker_cluster =[]

for date in ticker_posts['created_utc']:
    for step in range(7200): #5days
        next_date = str(date) + pd.Timedelta(step, unit='m')
    
        a = WSB_preprocessed_data[WSB_preprocessed_data['created_utc'].between(date, next_date, inclusive=True)]
        b = ticker_posts[ticker_posts['created_utc'].between(date, next_date, inclusive=True)]  
    
        if a * 0.9 > b:
            ticker_cluster.append([date, next_date])
            ticker_posts = ticker_posts[ticker_posts['created_utc']>next_date]
            break

TypeError: can only concatenate str (not "Timedelta") to str

## By Post

In [44]:
ticker_cluster =[]

for post in ticker_posts['link_id'].unique():
    if len(ticker_posts[ticker_posts['link_id']==post]) >= len(WSB_preprocessed_data[WSB_preprocessed_data['link_id']==post]) * 0.7:
        ticker_cluster.append(post)

## Fetch Ticker Data 
- Ticker
- Frequency

In [None]:
# # Not needed? I already added the data at the top, scraped from InteractiveBrokers
# import yfinance as yf
# data = yf.download(ticker, interval=frequency, start="2012-01-31", end="2021-02-16")

# df_train, df_test = train_test_split(
#   df,
#   test_size=0.1,
#   random_state=RANDOM_SEED)

In [None]:
from sklearn.model_selection import TimeSeriesSplit
import seaborn as sns
import matplotlib.pyplot as plt

# Need to time series split to preserve time series
tscv = TimeSeriesSplit(gap=0, max_train_size=0.8, n_splits=3, test_size=0.1)

## Model Training

In [None]:
from models import build_tcn, build_lstm

class PerformTraining(PreProcessing):
    
    """
    This class performs the training of the desired model

    ...

    Attributes
    ----------
    seed : int
        the integer of the seed utilised for reproducibility 
    DATA_DIR : str
        a string indicating the directory containing the raw stock and WSB data
    INTERM_DATA_DIR : str
        a string indicating the directory containing intermediate computed data
    MODEL_DIR : str
        a string indicating the directory containing created models
    preprocessing_params : dict
        a dictionary of preprocessing parameters
    saved_dataset_pkl : str
        a string indicating the directory containing the previously preprocess-compiled dataset
    saved_dataset_params_pkl : str
        a string indicating the directory containing the previously saved variance thresholded parameters
    data : (ndarray, ndarray, ndarray, ndarray, ndarray, ndarray)    
        a tuple of either 6, if tuning is True, or 2 ndarrays,
        (X_train, y_train, X_val, y_val, X_test, y_test) or (X_train, y_train)
        
    Methods
    -------
    reproducible_results()
        Sets seed and ensures all deterministic operations are reproducible
    retrieve_data()
        Retrieves the data given the data directory and folders
    prepare_data(preprocessing_params, tuning=True):
        Combines the preprocessing methods and splits the data for training 
    """

    def __init__(self, DATA_DIR, INTERM_DATA_DIR, MODEL_DIR, DATA_SRCS, preprocessing_params, model_type, tuning=True):
        
        """
        Parameters
        ----------
        DATA_DIR : str
            a string indicating the directory containing the raw stock and WSB data
        INTERM_DATA_DIR : str
            a string indicating the directory containing intermediate computed data
        MODEL_DIR : str
            a string indicating the directory containing created models
        preprocessing_params : dict
            a dictionary containing the following keys, 
                ... to fill
        data_aug_params : dict
            a dictionary containing the following keys, 
                ... to fill
        tuning : bool
            a boolean indicating whether this is for tuning or the final model
            for tuning, will prepare a threeway train, validation test split
            if false, will not split the data
        """

        self.seed = 7

        self.DATA_DIR = DATA_DIR
        self.INTERM_DATA_DIR = INTERM_DATA_DIR
        self.MODEL_DIR = MODEL_DIR
        self.DATA_SRCS = DATA_SRCS
        self.preprocessing_params = preprocessing_params
        self.tuning = tuning
        self.max_epochs = 100
        
        # Define pickle file where compiled dataset is saved
        # This is used to reduce unnecessary recomputations of the same preprocessing 
        self.saved_dataset_pkl = os.path.join(self.INTERM_DATA_DIR, 'dataset_var_thresh_preprocessed.pkl')
        self.saved_dataset_params_pkl = os.path.join(self.INTERM_DATA_DIR, 'dataset_params.pkl')
        
        self.data = self.prepare_data()
        
        if model_type == 'tcn':
            perform_cv_gs_training(build_tcn, ...)
            
        if model_type == 'lstm':
            perform_cv_gs_training(build_lstm, ...)

        def reproducible_results(self):
        
            """Obtain reproducible results with keras, source: https://stackoverflow.com/a/52897216"""

            # 1. Set the `PYTHONHASHSEED` environment variable at a fixed value
            os.environ['PYTHONHASHSEED'] = str(self.seed)

            # 2. Set the `python` built-in pseudo-random generator at a fixed value
            random.seed(self.seed)

            # 3. Set the `numpy` pseudo-random generator at a fixed value
            np.random.seed(self.seed)

            # 4. Set the `tensorflow` pseudo-random generator at a fixed value
            tf.compat.v1.set_random_seed(self.seed)

            # 5. Configure a new global `tensorflow` session
            session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
            sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
            K.set_session(sess)

        def prepare_data(self):
            
            """Readies the data for training by splitting into training, validation and testing sets and launching the preprocessing
            
            """
            pass
        
        def perform_cv_gs_training(self, model_fn, checkpoint_filepath, flush, save_plot):

            """Performs cross-validated grid search training for selected model function

            Parameters
            ----------
            model_fn : function
                a function which creates a keras compiled model
            checkpoint_filepath : str
                a string indicating where to save the plots and grid search results 
                of the cross-validated grid search
            flush : bool
                a boolean indicating whether to flush the indicated checkpoint_filepath
            save_plot : bool
                a boolean indicating whether to save the plots depicting performance over the folds

            Returns
            -------
            pandas.DataFrame
                a pandas dataframe containing the results of the grid search, 
                specifically average performance for given hyperparameters
            """  

            pkl_name = os.path.join(checkpoint_filepath, 'gs_res.pkl')
            if os.path.isfile(pkl_name) and flush==False:
                with open(pkl_name, 'rb') as f:
                    gs_res = pkl.load(f)
            else: 
                gs_res = []

            X_train, y_train, X_val, y_val, X_test, y_test = self.data

            # Merge inputs and targets for K-fold cross validation
            inputs = np.concatenate((X_train, X_val), axis=0)
            targets = np.concatenate((y_train, y_val), axis=0)

            for idx, params in enumerate(self.model_gs_params): 

                print("=================================================")
                print("Presenting Results for: %s/%s Hyperparameter Combination" % (idx+1, len(self.model_gs_params)))

                model_params = self.model_param_setup(params)
                print(model_params)

                batch_size = model_params['batch_size']

                # Create backlog for accuracy in each fold
                val_fold_accuracy = []
                test_fold_accuracy = []

                # Define the K-fold Cross Validator
                kfold = model_selection.KFold(n_splits=self.k_folds, shuffle=True)

                # K-fold Cross Validation model evaluation
                fold_no = 1
                for train, val in kfold.split(inputs, targets):

                    try:     
                        # Increase fold number
                        print("Performing fold: %s/%s" % (fold_no, self.k_folds))
                        fold_no = fold_no + 1

                        X_iter_train, X_iter_val = inputs[train], inputs[val]
                        y_iter_train, y_iter_val = targets[train], targets[val]

                        # Prepare the training dataset
                        train_dataset = tf.data.Dataset.from_tensor_slices((X_iter_train, y_iter_train))
                        train_dataset = train_dataset.shuffle(buffer_size = 1024).batch(batch_size)

                        # Prepare the validation dataset
                        val_dataset = tf.data.Dataset.from_tensor_slices((X_iter_val, y_iter_val))
                        val_dataset = val_dataset.shuffle(buffer_size = 1024).batch(batch_size)

                        model = model_fn(X_train, **model_params)

                        # Create Tensorboard
                        logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
                        tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir, update_freq='epoch', profile_batch=0)
                        # Model Checkpoint Callback
                        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(checkpoint_filepath,'checkpoint'), save_weights_only=True, monitor='val_loss', mode='min', save_best_only=True)
                        # Early Stopping Callback
                        early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience = 2)

                        # Train the model
                        training_history = model.fit(X_iter_train, y_iter_train, batch_size=batch_size, validation_data=(X_iter_val, y_iter_val), 
                                                     callbacks = [tensorboard_callback,
                                                                  early_stopping_callback,
                                                                  checkpoint_callback],
                                                     epochs=self.max_epochs, verbose=1)

            return 

