In [1]:
from typing import List
import os, json
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import re

In [2]:
# Clone the koa-fin/sn2 repository. The DataLoader cell below reads price and
# tweet files from /content/sn2/..., so this assumes the notebook's working
# directory is /content (e.g. Colab) — NOTE(review): confirm.
# NOTE(review): re-running this cell once sn2/ exists will likely fail.
!git clone https://github.com/koa-fin/sn2.git

Cloning into 'sn2'...
remote: Enumerating objects: 81318, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 81318 (delta 0), reused 2 (delta 0), pack-reused 81315[K
Receiving objects: 100% (81318/81318), 147.74 MiB | 14.26 MiB/s, done.
Resolving deltas: 100% (12217/12217), done.
Updating files: 100% (81242/81242), done.


In [4]:

class Summarizer:
    """Turn a day's raw tweets into one text block and screen out empty summaries.

    No model is involved: the "summary" is simply the tweets joined by newlines.
    """

    # Phrases marking a summary as carrying no usable content. ``match`` anchors
    # at the start of the string and ``.`` does not cross newlines, so only the
    # FIRST line of a multi-line summary is actually screened.
    _NON_INFORMATIVE = re.compile(
        r'.*[nN]o.*information.*|.*[nN]o.*facts.*|.*[nN]o.*mention.*|.*[nN]o.*tweets.*|.*do not contain.*'
    )

    def get_summary(self, ticker, tweets):
        """Join ``tweets`` into a single newline-separated string.

        Args:
            ticker: not used by this implementation; kept for interface compatibility.
            tweets: list of tweet texts.

        Returns:
            The joined text, or None when ``tweets`` is an empty list.
        """
        if tweets == []:
            return None
        return "\n".join(tweets)

    def is_informative(self, summary):
        """Return False when the first line of ``summary`` matches a "no information" phrase."""
        return self._NON_INFORMATIVE.match(summary) is None


In [5]:
import os
import json
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

class DataLoader:
    """Build per-ticker CSVs pairing recent tweets with a price-move label.

    For every trading day in the selected split of a ticker's price file, the
    tweets from the ``seq_len`` calendar days *before* that day are collected
    (via ``Summarizer``, defined in an earlier cell) and paired with a
    "Positive"/"Negative" label derived from column 1 of that day's price row.
    """

    def __init__(self, price_dir="/content/sn2/price/preprocessed",
                 tweet_dir="/content/sn2/tweet/raw", seq_len=5):
        """
        Args:
            price_dir: directory of per-ticker price files named ``<ticker>.<ext>``.
            tweet_dir: directory with one sub-directory per ticker, holding one
                file per date (``YYYY-MM-DD``) of JSON-encoded tweets, one per line.
            seq_len: number of calendar days of tweets preceding each target day.
        """
        # Defaults assume the repository was cloned into /content (e.g. Colab).
        self.price_dir = price_dir
        self.tweet_dir = tweet_dir
        self.seq_len = seq_len
        self.summarizer = Summarizer()

    def daterange(self, start_date, end_date):
        """Yield each day from ``start_date`` (inclusive) to ``end_date`` (exclusive)."""
        for n in range((end_date - start_date).days):
            yield start_date + timedelta(days=n)

    @staticmethod
    def _read_price_file(price_path):
        """Read a whitespace-delimited price file as a 2-D array of strings."""
        # atleast_2d keeps a single-row file indexable as [row, col].
        return np.atleast_2d(np.genfromtxt(price_path, dtype=str))

    @staticmethod
    def _label(price_chg):
        """Map a price change to the target label; zero counts as "Negative"."""
        return "Positive" if price_chg > 0.0 else "Negative"

    def get_sentiment(self, date_str, price_path):
        """Return the label for ``date_str`` from the first matching row of ``price_path``.

        Raises:
            IndexError: if ``date_str`` does not appear in column 0 of the file.
        """
        price_data = self._read_price_file(price_path)
        price_chg = float(price_data[price_data[:, 0] == date_str][0, 1])
        return self._label(price_chg)

    def get_tweets(self, ticker, date_str):
        """Return the ``text`` field of every tweet stored for ``ticker`` on ``date_str``.

        Returns an empty list when no tweet file exists for that day.
        """
        tweet_path = os.path.join(self.tweet_dir, ticker, date_str)
        if not os.path.exists(tweet_path):
            return []
        tweets = []
        # One JSON object per line; decode as UTF-8 explicitly (not the platform
        # default) and skip blank lines, which json.loads would reject.
        with open(tweet_path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    tweets.append(json.loads(line)["text"])
        return tweets

    def _build_context(self, ticker, end_date_str):
        """Concatenate informative daily summaries for the window before ``end_date_str``.

        Each contributing day appears as "<date>\\n<summary>", separated by a blank
        line. Returns "" when no day in the window yields an informative summary.
        """
        end_date = datetime.strptime(end_date_str, "%Y-%m-%d")
        start_date = end_date - timedelta(days=self.seq_len)
        parts = []
        for seq_date in self.daterange(start_date, end_date):
            seq_date_str = seq_date.strftime("%Y-%m-%d")
            summary = self.summarizer.get_summary(ticker, self.get_tweets(ticker, seq_date_str))
            if summary and self.summarizer.is_informative(summary):
                parts.append(f"{seq_date_str}\n{summary}")
        return "\n\n".join(parts).rstrip()

    def load(self, flag, output_dir="."):
        """Write ``<output_dir>/<ticker>.csv`` (columns ticker, summary, target) per ticker.

        Args:
            flag: "train" selects the first 80% of each ticker's (reversed) price
                rows; any other value selects the remaining 20%.
            output_dir: destination directory, created if missing. Defaults to the
                working directory. NOTE: the file name does not include ``flag``,
                so writing "train" and "test" to the same directory overwrites one
                split with the other — pass a different ``output_dir`` per split.

        Tickers that yield no rows produce no file.
        """
        os.makedirs(output_dir, exist_ok=True)
        for file in sorted(os.listdir(self.price_dir)):
            price_path = os.path.join(self.price_dir, file)
            ticker = os.path.splitext(file)[0]
            # Rows are reversed so index order runs oldest -> newest
            # (presumably files are stored newest-first — verify).
            ordered_price_data = np.flip(self._read_price_file(price_path), 0)
            if ordered_price_data.size == 0:
                continue

            n_days = len(ordered_price_data)
            split_idx = round(n_days * 0.8)
            data_range = range(split_idx) if flag == "train" else range(split_idx, n_days)

            # Collect plain records and build the frame once (concat-in-a-loop is quadratic).
            rows = []
            for idx in data_range:
                context = self._build_context(ticker, ordered_price_data[idx, 0])
                if context:
                    # Label straight from the loaded row instead of re-reading the file.
                    target = self._label(float(ordered_price_data[idx, 1]))
                    rows.append({'ticker': ticker, 'summary': context, 'target': target})

            if rows:
                pd.DataFrame(rows).to_csv(os.path.join(output_dir, f"{ticker}.csv"), index=False)

# Example usage: writes <ticker>.csv into the working directory for every
# ticker that yields at least one row.
# NOTE: "train" and "test" produce identically named files, so running both
# splits into the same directory overwrites the first split's output.
data_loader = DataLoader()
data_loader.load("train")  # or data_loader.load("test") depending on the flag
