In [2]:
import pandas as pd

# Load the CSV file
file_path = './data/NEWS_YAHOO_stock_prediction.csv'
data = pd.read_csv(file_path)

# Display the first few rows of the dataframe
data.head()

Unnamed: 0.1,Unnamed: 0,ticker,Date,category,title,content,Open,High,Low,Close,Adj Close,Volume,label
0,0,AAPL,2020-01-27,opinion,Apple Set To Beat Q1 Earnings Estimates Tech ...,Technology giant Apple NASDAQ AAPL is set ...,77.514999,77.942497,76.220001,77.237503,75.793358,161940000,0
1,1,AAPL,2020-01-27,opinion,Tech Daily Intel Results Netflix Surge Appl...,The top stories in this digest are Intel s N...,77.514999,77.942497,76.220001,77.237503,75.793358,161940000,0
2,2,AAPL,2020-01-27,opinion,7 Monster Stock Market Predictions For The Wee...,S P 500 SPY \nThis week will be packed with e...,77.514999,77.942497,76.220001,77.237503,75.793358,161940000,0
3,3,AAPL,2020-01-27,opinion,Apple Earnings Preview 5G Launch Expanding S...,Reports Q1 2020 results on Tuesday Jan 28 ...,77.514999,77.942497,76.220001,77.237503,75.793358,161940000,0
4,4,AAPL,2020-01-27,opinion,Buy Surging Apple Microsoft Stock Before Qua...,On today s episode of Full Court Finance here ...,77.514999,77.942497,76.220001,77.237503,75.793358,161940000,0


In [3]:
# Step 1: Remove unnecessary column
data.drop(columns=['Unnamed: 0'], inplace=True)

# Step 2: Remove duplicate texts
data.drop_duplicates(subset=['title', 'content'], inplace=True)

# Step 3: Remove rows with large amount of spaces or empty texts in 'title' and 'content'
data = data[~data['title'].str.isspace()]
data = data[~data['content'].str.isspace()]
data.dropna(subset=['title', 'content'], inplace=True)

# Check the dataframe after these preprocessing steps
data.info()

# Step 5: Check for invalid numeric data
numeric_columns = ['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume']
data[numeric_columns].describe()


<class 'pandas.core.frame.DataFrame'>
Index: 15965 entries, 0 to 15974
Data columns (total 12 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   ticker     15965 non-null  object 
 1   Date       15965 non-null  object 
 2   category   15965 non-null  object 
 3   title      15965 non-null  object 
 4   content    15965 non-null  object 
 5   Open       15965 non-null  float64
 6   High       15965 non-null  float64
 7   Low        15965 non-null  float64
 8   Close      15965 non-null  float64
 9   Adj Close  15965 non-null  float64
 10  Volume     15965 non-null  int64  
 11  label      15965 non-null  int64  
dtypes: float64(5), int64(2), object(5)
memory usage: 1.6+ MB


Unnamed: 0,Open,High,Low,Close,Adj Close,Volume
count,15965.0,15965.0,15965.0,15965.0,15965.0,15965.0
mean,40.583061,40.952148,40.241173,40.605005,38.739098,153646300.0
std,11.884583,11.980327,11.799389,11.89182,12.15832,109603300.0
min,13.856071,14.271429,13.753571,13.9475,12.084597,45448000.0
25%,31.522499,31.772499,31.264999,31.475,28.576729,95174000.0
50%,40.9375,41.432499,40.602501,41.0,39.263371,121150800.0
75%,47.125,47.424999,46.695,47.037498,45.263882,169126400.0
max,80.0625,80.832497,79.379997,79.807503,78.315315,1460852000.0


In [4]:
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import pipeline

# Load the FinBERT model and tokenizer
checkpoint = 'yiyanghkust/finbert-tone'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)

# Create a pipeline for sentiment analysis
nlp = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, max_length=512, truncation=True, device=0)

2023-11-18 19:42:32.481687: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-18 19:42:32.842423: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Function to apply sentiment analysis to a dataframe
def apply_sentiment_analysis(df, nlp, text_column='content'):
    """
    Apply sentiment analysis to a column in a dataframe.
    
    Args:
    df (pd.DataFrame): Dataframe containing the text data.
    nlp (pipeline): HuggingFace pipeline for sentiment analysis.
    text_column (str): Name of the column containing text data.

    Returns:
    pd.DataFrame: Dataframe with a new column 'sentiment' containing the analysis results.
    """
    # Apply sentiment analysis to each row in the text column
    sentiments = []
    for text in df[text_column]:
        try:
            result = nlp(text)
            sentiments.append(result[0]['label'])
        except Exception as e:
            print(f"Error in processing text: {e}")
            sentiments.append('Error')

    # Add the sentiments as a new column in the dataframe
    df['sentiment'] = sentiments
    return df


In [7]:
from tqdm.auto import tqdm
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def apply_sentiment_analysis_parallel(df, nlp, text_column='content', batch_size=10):
    """
    Apply sentiment analysis in parallel to a column in a dataframe.

    Args:
    df (pd.DataFrame): Dataframe containing the text data.
    nlp (pipeline): HuggingFace pipeline for sentiment analysis.
    text_column (str): Name of the column containing text data.
    batch_size (int): Number of texts to process in parallel.

    Returns:
    pd.DataFrame: Dataframe with a new column 'sentiment' containing the analysis results.
    """
    # Define a function to process a batch of texts
    def process_batch(texts):
        return [nlp(text)[0]['label'] for text in texts]

    # Break the texts into batches
    batches = [df[text_column][i:i + batch_size] for i in range(0, len(df), batch_size)]

    # Process batches in parallel
    sentiments = []
    with ThreadPoolExecutor() as executor:
        for batch_result in tqdm(executor.map(process_batch, batches), total=len(batches)):
            sentiments.extend(batch_result)

    # Add the sentiments as a new column in the dataframe
    df['sentiment'] = sentiments
    return df


In [8]:
# Example usage of the function
# Note: You will run this on your local machine as it requires GPU support
sample_texts = ["I've been waiting for a HuggingFace course my whole life.", "So have I!"]
sample_df = pd.DataFrame(sample_texts, columns=['content'])
apply_sentiment_analysis(sample_df, nlp)



Unnamed: 0,content,sentiment
0,I've been waiting for a HuggingFace course my ...,Neutral
1,So have I!,Neutral


In [9]:
# Assuming the apply_sentiment_analysis function is defined as shown previously

# Step 1: Apply sentiment analysis to the dataset
# This step should be done on your local machine due to the requirement of GPU support
data = apply_sentiment_analysis_parallel(data, nlp)

# Step 2: Prepare data for the prediction model
# Here we'll assume the sentiment analysis has been applied and 'sentiment' column is added to the data

# We might want to convert sentiments to numerical values for model training
sentiment_mapping = {'Positive': 1, 'Neutral': 0, 'Negative': -1}
data['sentiment_numeric'] = data['sentiment'].map(sentiment_mapping)



  0%|          | 0/1597 [00:00<?, ?it/s]

In [10]:
# Example code to save the processed DataFrame to a CSV file
data.to_csv('./data/dataset_with_sentiment.csv', index=False)


                                                 content sentiment
4418   Wall Street declined sharply on Friday as init...  Negative
13909  The market looked dead a few days back  fallin...   Neutral
15520  The week ahead brings a steady stream of earni...  Positive
12430  By Nate Raymond NEW YORK  Reuters    A U S  ap...   Neutral
950     Reuters    Apple Inc  NASDAQ AAPL  on Monday ...  Positive
5428   I hate to say it     and I m probably the last...  Positive
6557   By Paul Sandle LONDON  Reuters    WhatsApp  th...   Neutral
4692   For Immediate ReleaseChicago  IL   November 20...  Positive
2517   Recently Smith   Nephew  LON SN  plc   NYSE SN...  Positive
15014  Investing com   U S  stock futures pointed to ...   Neutral


In [12]:
# Adjust display settings for better visualization of samples
pd.set_option('display.max_colwidth', 200)  # Adjust the width to fit longer texts

# Display some random samples with formatted output
sample_data = data.sample(n=10)[['content', 'sentiment']]

# Print each sample in a more readable format
for index, row in sample_data.iterrows():
    print(f"Sample {index}:")
    print(f"Content: {row['content']}")
    print(f"Sentiment: {row['sentiment']}\n")

# Note: Since the actual data is not available here, this code is meant to be run on your local machine.


Sample 15309:
Content: Did China just put the last nail in the interest rate coffin  Did its move last week erase the idea that interest rates will ever return to  normal  in our lifetimes 
I have no doubt Janet Yellen s head slammed the table in frustration last week  As China devalued the value of its yuan  Yellen s job got much  much harder 
In fact  we d argue the global economic picture has never been quite so muddled and quite so dangerous  What has us worried is the growing gap between monetary policies across the world 
We re convinced the traditional idea of interest rates  at least in our investing lifetimes  is dead  As much as Yellen and her troops want to nudge rates higher and reload their ammo belt  divergent forces from across the globe are forcing them to move cautiously 
The news out of China last week is a perfect example  As Beijing works to decouple its yuan from the rising dollar  it puts an even brighter spotlight on the rising dollar  The greenback is an oasis i

In [13]:
# Assuming 'data' is your DataFrame with 'sentiment' and 'label' columns
# Calculate the proportion of each sentiment category
sentiment_counts = data['sentiment'].value_counts(normalize=True) * 100

# Calculate the proportion of each label
label_counts = data['label'].value_counts(normalize=True) * 100

# Print the results
print("Sentiment Distribution (%):")
print(sentiment_counts)
print("\nLabel Distribution (%):")
print(label_counts)

# For additional insights, we can also look at the cross-tabulation of sentiment and label
crosstab = pd.crosstab(data['sentiment'], data['label'], normalize='index') * 100
print("\nCross-Tabulation of Sentiment and Label (%):")
print(crosstab)

# Note: Replace 'sentiment' and 'label' with the actual column names if they are different in your dataset.
# The actual execution of these lines should be done on your local machine.


Sentiment Distribution (%):
sentiment
Positive    46.664579
Neutral     33.529596
Negative    19.805825
Name: proportion, dtype: float64

Label Distribution (%):
label
1    55.34607
0    44.65393
Name: proportion, dtype: float64

Cross-Tabulation of Sentiment and Label (%):
label              0          1
sentiment                      
Negative   50.031626  49.968374
Neutral    45.731366  54.268634
Positive   41.597315  58.402685


- 情感与股价变动的关联：存在一定的关联性，尤其是在Positive情感类别中，股价上升的比例高于下降的比例。

- 中性情感的影响：即便是被分类为Neutral的文本，也呈现出股价上升的倾向，这可能表明即使是中立的新闻或评论，也可能含有一些对市场有影响的信息。

- 情感分析的限制：虽然可以观察到一定的关联趋势，但这种关联并不是绝对的。例如，即使是Negative情感的文本，也有接近一半的情况下股价是上升的。

---
- The correlation between emotions and stock price changes: There is a certain degree of correlation, especially in the Positive emotion category, where the proportion of stock price increases is higher than decreases.

- The impact of neutral emotions: Even texts classified as Neutral show a tendency for stock prices to rise. This may indicate that even neutral news or comments may contain information that affects the market.

- Limitations of sentiment analysis: Although some correlation trends can be observed, this correlation is not absolute. For example, even in cases where the text has Negative emotions, there are close to half instances where stock prices still rise.


In [18]:
# Convert the 'Date' column to datetime format and sort the dataframe by 'Date'
data['Date'] = pd.to_datetime(data['Date'])
data_sorted = data.sort_values(by='Date')

# Group by 'Date' and take the last entry of the day as it contains the closing information
# This will be used to represent the trading information of the day
daily_data = data_sorted.groupby('Date').last().reset_index()

# Shift the 'Open' column to get the next day's opening price
daily_data['Next_Open'] = daily_data['Open'].shift(-1)

# Drop the last row as it will not have a 'Next_Open' value
daily_data = daily_data[:-1]

# Display the first few rows of the adjusted dataframe
daily_data.head(20)

Unnamed: 0,Date,ticker,category,title,content,Open,High,Low,Close,Adj Close,Volume,label,sentiment,sentiment_numeric,Next_Open
0,2012-07-23,AAPL,opinion,Summer Heat Scorches Europe And U S,Europe flares as summer heat continues Summer heat covers the nation as Europe s debt crisis flares again On My Wall Street Radar In the chart of the S P 500 above NYSEARCA SPY we can see how ...,21.228571,21.639286,20.989643,21.565357,18.413187,487975600,1,Negative,-1,21.692142
1,2012-07-24,AAPL,opinion,Market Bait And Switch,That is the sound we are going to hear soon from the BTFD crowd So far in 2012 we have been conditioned to BTFD Our brain has been wired to the safety net named Bernanke That if the stock market...,21.692142,21.774286,21.375357,21.46143,18.324455,565132400,0,Neutral,0,20.536072
2,2012-07-27,AAPL,opinion,Will AAPL Fall From The Tree,Apple s AAPL sales for the third quarter missed estimates due to weakness in European economy and a pause in iPhone sales ahead of the release of a new version Following the dissapointing third...,20.536072,20.922501,20.413929,20.898571,17.84387,403936400,1,Negative,-1,21.104286
3,2012-07-30,AAPL,opinion,Bulls Snatch Victory From Jaws of Defeat,Last week the bulls pulled another save out of their hat turning what initially looked like a losing week into a respectable advance to a new 12 week high The broad market rebound came on the...,21.104286,21.408571,20.99357,21.251072,18.144846,379142400,1,Negative,-1,21.543928
4,2012-07-31,AAPL,opinion,What s Driving China s Real Estate Rally Part 3,In the preceding posts I examined the first two out of five basic theories that might explain the latest bump in China s property sales numbers and whether they portend a genuine turn around in ...,21.543928,21.84643,21.525715,21.812857,18.624512,462327600,1,Neutral,0,23.102858
5,2012-08-10,AAPL,opinion,Good Knight Public Markets,With all of the recent financial scandals and trading losses the last thing Wall Street needed was another humiliation But those black eyes just keep coming for public trading markets Last week ...,23.102858,23.127144,22.718214,22.791786,19.5439,637994000,0,Negative,-1,22.566786
6,2012-08-14,AAPL,opinion,VIX Is Under 14 Now What,OK trader buddies I bet this is something you didn t think would happen so soon after the euro Italian Spanish etc crises VIX slid below 14 briefly today and what s coming next could be a big ...,22.566786,22.807501,22.5075,22.560356,19.345457,340169200,0,Negative,-1,22.546429
7,2012-08-15,AAPL,opinion,Why Bearish Short Term Still Buying Stocks Long Term,Some of you do not understand how I can be bearish on the markets for the rest of this year yet continue to recommend going long equities The answer is that there are three time horizons for inve...,22.546429,22.642857,22.419643,22.529642,19.319118,257342400,0,Neutral,0,22.543215
8,2012-08-16,AAPL,opinion,AAPL Trade Considerations For Option Expiration Week,Option trading is not just the same as stocks It turns on the three primal forces ruling an options trader s world time to expiration price of the underlying and implied volatility As expira...,22.543215,22.741428,22.517857,22.726429,19.487867,254534000,1,Neutral,0,22.857143
9,2012-08-17,AAPL,opinion,Pressure Released Into A Breakout Apple On The Edge,T2108 Status 73 6 second day of fifth overbought period since June 29 VIX Status 14 3General Short term Trading Call Hold Commentary finally got released today As I mentioned earlier desp...,22.857143,23.149643,22.814644,23.146786,19.848316,442761200,1,Neutral,0,23.957857


In [20]:
# Convert the 'Date' column to datetime format and sort the dataframe by 'Date'
data['Date'] = pd.to_datetime(data['Date'])
data_sorted = data.sort_values(by='Date')

# Group by 'Date' and 'category', then calculate the sum of sentiment scores for each category
sentiment_sum_per_day_category = data_sorted.groupby(['Date', 'category'])['sentiment_numeric'].sum().unstack().reset_index()

# Merge this sentiment data with the daily_data
daily_data_merged = pd.merge(daily_data, sentiment_sum_per_day_category, on='Date', how='left')

# Display the first few rows of the adjusted dataframe
daily_data_merged.tail()


Unnamed: 0,Date,ticker,category,title,content,Open,High,Low,Close,Adj Close,Volume,label,sentiment,sentiment_numeric,Next_Open,news,opinion
0,2012-07-23,AAPL,opinion,Summer Heat Scorches Europe And U S,Europe flares as summer heat continues Summer heat covers the nation as Europe s debt crisis flares again On My Wall Street Radar In the chart of the S P 500 above NYSEARCA SPY we can see how ...,21.228571,21.639286,20.989643,21.565357,18.413187,487975600,1,Negative,-1,21.692142,,-1.0
1,2012-07-24,AAPL,opinion,Market Bait And Switch,That is the sound we are going to hear soon from the BTFD crowd So far in 2012 we have been conditioned to BTFD Our brain has been wired to the safety net named Bernanke That if the stock market...,21.692142,21.774286,21.375357,21.46143,18.324455,565132400,0,Neutral,0,20.536072,,0.0
2,2012-07-27,AAPL,opinion,Will AAPL Fall From The Tree,Apple s AAPL sales for the third quarter missed estimates due to weakness in European economy and a pause in iPhone sales ahead of the release of a new version Following the dissapointing third...,20.536072,20.922501,20.413929,20.898571,17.84387,403936400,1,Negative,-1,21.104286,,-1.0
3,2012-07-30,AAPL,opinion,Bulls Snatch Victory From Jaws of Defeat,Last week the bulls pulled another save out of their hat turning what initially looked like a losing week into a respectable advance to a new 12 week high The broad market rebound came on the...,21.104286,21.408571,20.99357,21.251072,18.144846,379142400,1,Negative,-1,21.543928,,-1.0
4,2012-07-31,AAPL,opinion,What s Driving China s Real Estate Rally Part 3,In the preceding posts I examined the first two out of five basic theories that might explain the latest bump in China s property sales numbers and whether they portend a genuine turn around in ...,21.543928,21.84643,21.525715,21.812857,18.624512,462327600,1,Neutral,0,23.102858,,1.0


In [28]:
daily_data_merged.head(10)

Unnamed: 0,Date,ticker,category,title,content,Open,High,Low,Close,Adj Close,Volume,label,Next_Open,news,opinion
0,2012-07-23,AAPL,opinion,Summer Heat Scorches Europe And U S,Europe flares as summer heat continues Summer heat covers the nation as Europe s debt crisis flares again On My Wall Street Radar In the chart of the S P 500 above NYSEARCA SPY we can see how ...,21.228571,21.639286,20.989643,21.565357,18.413187,487975600,1,21.692142,0.0,-1.0
1,2012-07-24,AAPL,opinion,Market Bait And Switch,That is the sound we are going to hear soon from the BTFD crowd So far in 2012 we have been conditioned to BTFD Our brain has been wired to the safety net named Bernanke That if the stock market...,21.692142,21.774286,21.375357,21.46143,18.324455,565132400,0,20.536072,0.0,0.0
2,2012-07-27,AAPL,opinion,Will AAPL Fall From The Tree,Apple s AAPL sales for the third quarter missed estimates due to weakness in European economy and a pause in iPhone sales ahead of the release of a new version Following the dissapointing third...,20.536072,20.922501,20.413929,20.898571,17.84387,403936400,1,21.104286,0.0,-1.0
3,2012-07-30,AAPL,opinion,Bulls Snatch Victory From Jaws of Defeat,Last week the bulls pulled another save out of their hat turning what initially looked like a losing week into a respectable advance to a new 12 week high The broad market rebound came on the...,21.104286,21.408571,20.99357,21.251072,18.144846,379142400,1,21.543928,0.0,-1.0
4,2012-07-31,AAPL,opinion,What s Driving China s Real Estate Rally Part 3,In the preceding posts I examined the first two out of five basic theories that might explain the latest bump in China s property sales numbers and whether they portend a genuine turn around in ...,21.543928,21.84643,21.525715,21.812857,18.624512,462327600,1,23.102858,0.0,1.0
5,2012-08-10,AAPL,opinion,Good Knight Public Markets,With all of the recent financial scandals and trading losses the last thing Wall Street needed was another humiliation But those black eyes just keep coming for public trading markets Last week ...,23.102858,23.127144,22.718214,22.791786,19.5439,637994000,0,22.566786,0.0,-1.0
6,2012-08-14,AAPL,opinion,VIX Is Under 14 Now What,OK trader buddies I bet this is something you didn t think would happen so soon after the euro Italian Spanish etc crises VIX slid below 14 briefly today and what s coming next could be a big ...,22.566786,22.807501,22.5075,22.560356,19.345457,340169200,0,22.546429,0.0,-2.0
7,2012-08-15,AAPL,opinion,Why Bearish Short Term Still Buying Stocks Long Term,Some of you do not understand how I can be bearish on the markets for the rest of this year yet continue to recommend going long equities The answer is that there are three time horizons for inve...,22.546429,22.642857,22.419643,22.529642,19.319118,257342400,0,22.543215,0.0,0.0
8,2012-08-16,AAPL,opinion,AAPL Trade Considerations For Option Expiration Week,Option trading is not just the same as stocks It turns on the three primal forces ruling an options trader s world time to expiration price of the underlying and implied volatility As expira...,22.543215,22.741428,22.517857,22.726429,19.487867,254534000,1,22.857143,0.0,0.0
9,2012-08-17,AAPL,opinion,Pressure Released Into A Breakout Apple On The Edge,T2108 Status 73 6 second day of fifth overbought period since June 29 VIX Status 14 3General Short term Trading Call Hold Commentary finally got released today As I mentioned earlier desp...,22.857143,23.149643,22.814644,23.146786,19.848316,442761200,1,23.957857,0.0,0.0


In [None]:
# Fill NaN values in 'news' and 'opinions' columns with 0
daily_data_merged['news'].fillna(0, inplace=True)
daily_data_merged['opinion'].fillna(0, inplace=True)

# Optionally, if you want to use average values instead of 0, you can use the following:
# mean_news = daily_data_merged['news'].mean()
# mean_opinions = daily_data_merged['opinions'].mean()
# daily_data_merged['news'].fillna(mean_news, inplace=True)
# daily_data_merged['opinions'].fillna(mean_opinions, inplace=True)

# Remove some columns that are no longer needed
if 'sentiment' in daily_data_merged.columns and 'sentiment_numeric' in daily_data_merged.columns:
    daily_data_merged.drop(columns=['sentiment', 'sentiment_numeric'], inplace=True)

if 'category' in daily_data_merged.columns and 'title' in daily_data_merged.columns and 'content' in daily_data_merged.columns:
    daily_data_merged.drop(columns=['category', 'title', 'content'], inplace=True)

In [35]:
# Display the first few rows of the adjusted dataframe
daily_data_merged.head(10)

Unnamed: 0,Date,ticker,Open,High,Low,Close,Adj Close,Volume,label,Next_Open,news,opinion
0,2012-07-23,AAPL,21.228571,21.639286,20.989643,21.565357,18.413187,487975600,1,21.692142,0.0,-1.0
1,2012-07-24,AAPL,21.692142,21.774286,21.375357,21.46143,18.324455,565132400,0,20.536072,0.0,0.0
2,2012-07-27,AAPL,20.536072,20.922501,20.413929,20.898571,17.84387,403936400,1,21.104286,0.0,-1.0
3,2012-07-30,AAPL,21.104286,21.408571,20.99357,21.251072,18.144846,379142400,1,21.543928,0.0,-1.0
4,2012-07-31,AAPL,21.543928,21.84643,21.525715,21.812857,18.624512,462327600,1,23.102858,0.0,1.0
5,2012-08-10,AAPL,23.102858,23.127144,22.718214,22.791786,19.5439,637994000,0,22.566786,0.0,-1.0
6,2012-08-14,AAPL,22.566786,22.807501,22.5075,22.560356,19.345457,340169200,0,22.546429,0.0,-2.0
7,2012-08-15,AAPL,22.546429,22.642857,22.419643,22.529642,19.319118,257342400,0,22.543215,0.0,0.0
8,2012-08-16,AAPL,22.543215,22.741428,22.517857,22.726429,19.487867,254534000,1,22.857143,0.0,0.0
9,2012-08-17,AAPL,22.857143,23.149643,22.814644,23.146786,19.848316,442761200,1,23.957857,0.0,0.0


In [34]:
daily_data_merged.tail(20)

Unnamed: 0,Date,ticker,Open,High,Low,Close,Adj Close,Volume,label,Next_Open,news,opinion
1633,2019-12-16,AAPL,69.25,70.197502,69.245003,69.964996,68.656837,128186000,1,69.892502,2.0,5.0
1634,2019-12-17,AAPL,69.892502,70.442497,69.699997,70.102501,68.791771,114158400,1,69.949997,1.0,6.0
1635,2019-12-18,AAPL,69.949997,70.474998,69.779999,69.934998,68.627396,116028400,0,69.875,2.0,3.0
1636,2019-12-19,AAPL,69.875,70.294998,69.737503,70.004997,68.696083,98369200,1,70.557503,0.0,33.0
1637,2019-12-20,AAPL,70.557503,70.662498,69.639999,69.860001,68.553818,275978000,0,70.1325,5.0,13.0
1638,2019-12-23,AAPL,70.1325,71.0625,70.092499,71.0,69.672501,98572000,1,71.172501,5.0,6.0
1639,2019-12-24,AAPL,71.172501,71.222504,70.730003,71.067497,69.738731,48478800,0,71.205002,2.0,1.0
1640,2019-12-26,AAPL,71.205002,72.495003,71.175003,72.477501,71.12236,93121200,1,72.779999,1.0,6.0
1641,2019-12-27,AAPL,72.779999,73.4925,72.029999,72.449997,71.095383,146266000,0,72.364998,2.0,4.0
1642,2019-12-30,AAPL,72.364998,73.172501,71.305,72.879997,71.517342,144114400,1,72.482498,6.0,3.0


In [32]:
# Select the features - note that we are not using the 'Open' column as a feature
features = daily_data_merged[['Open', 'High', 'Low', 'Close', 'Volume', 'news', 'opinion']]
target = daily_data_merged['Next_Open']

# normalization
from sklearn.preprocessing import MinMaxScaler

# Apply the MinMaxScaler to the features and target
scaler_features = MinMaxScaler()
scaler_target = MinMaxScaler()

scaled_features = scaler_features.fit_transform(features)
scaled_target = scaler_target.fit_transform(target.values.reshape(-1, 1))

# Create new DataFrames with the scaled features and target
scaled_features_df = pd.DataFrame(scaled_features, columns=features.columns)
scaled_target_df = pd.DataFrame(scaled_target, columns=['Next_Open'])

# Display the first few rows of the scaled dataframe
scaled_features_df.head(), scaled_target_df.head()

(       Open      High       Low     Close    Volume      news   opinion
 0  0.111356  0.110693  0.110262  0.115667  0.312651  0.736842  0.114583
 1  0.118358  0.112721  0.116139  0.114089  0.367163  0.736842  0.125000
 2  0.100897  0.099924  0.101489  0.105543  0.253276  0.736842  0.114583
 3  0.109479  0.107227  0.110321  0.110895  0.235759  0.736842  0.114583
 4  0.116119  0.113805  0.118430  0.119425  0.294530  0.736842  0.135417,
    Next_Open
 0   0.118358
 1   0.100897
 2   0.109479
 3   0.116119
 4   0.139666)

In [33]:
# Example code to save the processed DataFrame to a CSV file
data.to_csv('./data/dataset_for_model.csv', index=False)

In [38]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import TimeSeriesTransformerForPrediction
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# 将数据转换为Tensor
features_tensor = torch.tensor(scaled_features, dtype=torch.float32)
target_tensor = torch.tensor(scaled_target, dtype=torch.float32).unsqueeze(-1)

past_observed_mask = torch.ones_like(features_tensor, dtype=torch.bool)

# 划分训练集和测试集
train_features, test_features, train_target, test_target = train_test_split(features_tensor, target_tensor, test_size=0.2, random_state=42)

# 创建TensorDataset
train_dataset = TensorDataset(train_features, train_target)
test_dataset = TensorDataset(test_features, test_target)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [39]:
# 创建模型
model = TimeSeriesTransformerForPrediction.from_pretrained("huggingface/time-series-transformer-tourism-monthly", num_labels=1)

Downloading config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/151k [00:00<?, ?B/s]

In [41]:
# 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

# 设置超参数
num_epochs = 10  # 定义训练的迭代次数

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        outputs = model(
        past_values=batch["past_values"],
        past_time_features=batch["past_time_features"],
        past_observed_mask=batch["past_observed_mask"],
        static_categorical_features=batch["static_categorical_features"],
        static_real_features=batch["static_real_features"],
        future_values=batch["future_values"],
        future_time_features=batch["future_time_features"],
        )

        loss = outputs.loss
        loss.backward()

        # # 分离特征和标签
        # inputs, labels = batch

        # # 根据inputs准备past_observed_mask
        # # 如果没有缺失值，可以使用全为True的张量
        # past_observed_mask = torch.ones_like(inputs, dtype=torch.bool)

        # # 重置梯度
        # optimizer.zero_grad()
        
        # # 前向传播，确保使用正确的参数
        # outputs = model(inputs, past_observed_mask)

        # # 计算损失
        # loss = criterion(outputs, labels)

        # # 后向传播和优化
        # loss.backward()
        optimizer.step()

        # 累计损失
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# 保存模型
torch.save(model.state_dict(), "model.pth")


TypeError: TimeSeriesTransformerForPrediction.forward() missing 1 required positional argument: 'past_observed_mask'

In [None]:
# 预测和评估
model.eval()
predictions, actuals = [], []
with torch.no_grad():
    for batch in test_loader:
        inputs, labels = batch
        outputs = model(inputs)
        predictions.extend(outputs.numpy())
        actuals.extend(labels.numpy())

# 假设 predictions 和 actuals 是模型的预测结果和实际目标值
predictions_tensor = torch.tensor(predictions, dtype=torch.float32)
actuals_tensor = torch.tensor(actuals, dtype=torch.float32)

# 将预测结果和实际值转换回原始尺度
predicted_prices = scaler_target.inverse_transform(predictions_tensor.numpy())
actual_prices = scaler_target.inverse_transform(actuals_tensor.numpy())

# 可视化预测结果
plt.figure(figsize=(10,6))
plt.plot(actuals, label='Actual')
plt.plot(predictions, label='Predicted')
plt.title('Time Series Prediction')
plt.xlabel('Time')
plt.ylabel('Normalized Price')
plt.legend()
plt.show()