In [10]:
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, concatenate
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences


In [29]:
import pandas as pd
data = pd.read_csv("Stock Market Prediction Analysis/Combined_News_DJIA(train).csv")

print(data['Label'].value_counts())
data.head()


Label
1    1065
0     924
Name: count, dtype: int64


Unnamed: 0,Date,Label,Top1,Top2,Top3,Top4,Top5,Top6,Top7,Top8,...,Top16,Top17,Top18,Top19,Top20,Top21,Top22,Top23,Top24,Top25
0,2008-08-08,0,"b""Georgia 'downs two Russian warplanes' as cou...",b'BREAKING: Musharraf to be impeached.',b'Russia Today: Columns of troops roll into So...,b'Russian tanks are moving towards the capital...,"b""Afghan children raped with 'impunity,' U.N. ...",b'150 Russian tanks have entered South Ossetia...,"b""Breaking: Georgia invades South Ossetia, Rus...","b""The 'enemy combatent' trials are nothing but...",...,b'Georgia Invades South Ossetia - if Russia ge...,b'Al-Qaeda Faces Islamist Backlash',"b'Condoleezza Rice: ""The US would not act to p...",b'This is a busy day: The European Union has ...,"b""Georgia will withdraw 1,000 soldiers from Ir...",b'Why the Pentagon Thinks Attacking Iran is a ...,b'Caucasus in crisis: Georgia invades South Os...,b'Indian shoe manufactory - And again in a se...,b'Visitors Suffering from Mental Illnesses Ban...,"b""No Help for Mexico's Kidnapping Surge"""
1,2008-08-11,1,b'Why wont America and Nato help us? If they w...,b'Bush puts foot down on Georgian conflict',"b""Jewish Georgian minister: Thanks to Israeli ...",b'Georgian army flees in disarray as Russians ...,"b""Olympic opening ceremony fireworks 'faked'""",b'What were the Mossad with fraudulent New Zea...,b'Russia angered by Israeli military sale to G...,b'An American citizen living in S.Ossetia blam...,...,b'Israel and the US behind the Georgian aggres...,"b'""Do not believe TV, neither Russian nor Geor...",b'Riots are still going on in Montreal (Canada...,b'China to overtake US as largest manufacturer',b'War in South Ossetia [PICS]',b'Israeli Physicians Group Condemns State Tort...,b' Russia has just beaten the United States ov...,b'Perhaps *the* question about the Georgia - R...,b'Russia is so much better at war',"b""So this is what it's come to: trading sex fo..."
2,2008-08-12,0,b'Remember that adorable 9-year-old who sang a...,"b""Russia 'ends Georgia operation'""","b'""If we had no sexual harassment we would hav...","b""Al-Qa'eda is losing support in Iraq because ...",b'Ceasefire in Georgia: Putin Outmaneuvers the...,b'Why Microsoft and Intel tried to kill the XO...,b'Stratfor: The Russo-Georgian War and the Bal...,"b""I'm Trying to Get a Sense of This Whole Geor...",...,b'U.S. troops still in Georgia (did you know t...,b'Why Russias response to Georgia was right',"b'Gorbachev accuses U.S. of making a ""serious ...","b'Russia, Georgia, and NATO: Cold War Two'",b'Remember that adorable 62-year-old who led y...,b'War in Georgia: The Israeli connection',b'All signs point to the US encouraging Georgi...,b'Christopher King argues that the US and NATO...,b'America: The New Mexico?',"b""BBC NEWS | Asia-Pacific | Extinction 'by man..."
3,2008-08-13,0,b' U.S. refuses Israel weapons to attack Iran:...,"b""When the president ordered to attack Tskhinv...",b' Israel clears troops who killed Reuters cam...,b'Britain\'s policy of being tough on drugs is...,b'Body of 14 year old found in trunk; Latest (...,b'China has moved 10 *million* quake survivors...,"b""Bush announces Operation Get All Up In Russi...",b'Russian forces sink Georgian ships ',...,b'Elephants extinct by 2020?',b'US humanitarian missions soon in Georgia - i...,"b""Georgia's DDOS came from US sources""","b'Russian convoy heads into Georgia, violating...",b'Israeli defence minister: US against strike ...,b'Gorbachev: We Had No Choice',b'Witness: Russian forces head towards Tbilisi...,b' Quarter of Russians blame U.S. for conflict...,b'Georgian president says US military will ta...,b'2006: Nobel laureate Aleksander Solzhenitsyn...
4,2008-08-14,1,b'All the experts admit that we should legalis...,b'War in South Osetia - 89 pictures made by a ...,b'Swedish wrestler Ara Abrahamian throws away ...,b'Russia exaggerated the death toll in South O...,b'Missile That Killed 9 Inside Pakistan May Ha...,"b""Rushdie Condemns Random House's Refusal to P...",b'Poland and US agree to missle defense deal. ...,"b'Will the Russians conquer Tblisi? Bet on it,...",...,b'Bank analyst forecast Georgian crisis 2 days...,"b""Georgia confict could set back Russia's US r...",b'War in the Caucasus is as much the product o...,"b'""Non-media"" photos of South Ossetia/Georgia ...",b'Georgian TV reporter shot by Russian sniper ...,b'Saudi Arabia: Mother moves to block child ma...,b'Taliban wages war on humanitarian aid workers',"b'Russia: World ""can forget about"" Georgia\'s...",b'Darfur rebels accuse Sudan of mounting major...,b'Philippines : Peace Advocate say Muslims nee...


In [12]:
# 将所有新闻合并为一列
data['combined_news'] = data.iloc[:, 2:27].apply(lambda row: ''.join(str(row.values)), axis=1).replace(['\n', '\r','b\'', 'b"', '\"'], '', regex=True)

train_df, test_df = train_test_split(data, test_size=0.2, random_state=42)


In [30]:
# 建構分類器
# 加載 BERT 模型與分詞器
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from torch.utils.data import Dataset
import torch
# 初始化分詞器和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

class NewsDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx): 
        text = self.df.iloc[idx]['Top1']
        label = self.df.iloc[idx]['Label']

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        ) 

        return {
            'text': text,
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'label': torch.tensor(label, dtype=torch.long)
        }
    

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [31]:
# 訓練模型
from transformers import Trainer, TrainingArguments

train_dataset = NewsDataset(train_df, tokenizer, max_len=512)
test_dataset = NewsDataset(test_df, tokenizer, max_len=512)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    eval_strategy='steps',  
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

trainer.train()

 76%|███████▌  | 451/597 [1:08:10<1:31:41, 37.68s/it]

In [25]:
# 模型評估
r = trainer.evaluate()
print(r)

26it [01:57,  4.52s/it]                        

{'eval_loss': 0.6880764365196228, 'eval_runtime': 108.973, 'eval_samples_per_second': 3.652, 'eval_steps_per_second': 0.229, 'epoch': 3.0}





In [18]:
# 保存模型
# model.save_pretrained('./news_model')
# tokenizer.save_pretrained('./news_model')
model.save_pretrained('./news_top1_model')
tokenizer.save_pretrained('./news_top1_model')


100%|██████████| 25/25 [01:34<00:00,  3.80s/it]


('./news_model/tokenizer_config.json',
 './news_model/special_tokens_map.json',
 './news_model/vocab.txt',
 './news_model/added_tokens.json')

In [19]:
from transformers import pipeline
from transformers import BertTokenizer, BertForSequenceClassification

# 呼叫已保存的模型
model = BertForSequenceClassification.from_pretrained('./news_model')
tokenizer = BertTokenizer.from_pretrained('./news_model')
classifier = pipeline('text-classification', model=model, tokenizer=tokenizer, device=0)

# 模擬的新聞標題
input_texts = [
    "France Cracks Down on Factory Farms - A viral video campaign has moved the govt to act. In footage shared widely online, animals writhe in pain as they bleed to death or are dismembered, in violation of rules requiring they be rendered unconscious before slaughter.",
    "The company's revenue has seen a significant increase of 20% this quarter, reflecting strong market demand and effective management strategies.",
    'Chancellor Angela Merkel would rather see the UK exit from the European Union than compromise over the principle of free movement of workers.',
    "Europe has 421 million fewer birds than 30 years ago - Study finds about 90% of a decline in the most common bird species, including grey partridges, skylarks, sparrows & starlings.",
    "A town in Portugal makes Mandarin Chinese compulsory for 8- and 9-year old school children for future competition in the thriving Chinese market.",
    "Portugal Sees Chinese Do 90% of Bids at Property Auction.",
    "Monument To Apple's Jobs Removed In Russia After CEO Comes Out.",
    "Dr. Godfrey George, medical superintendent of Kambia Government Hospital in northern Sierra Leone, has died from Ebola.",
    "Muslims in Canada mark holy day with anti-terror march.",
    "Canadian warplanes drop first bombs against Islamic State in Iraq.",
    "Suicide blast kills 52 near India - Pakistan Border (Wagah Border).",
    "Singapore joins coalition against ISIS.",
    "South Australia completes largest wind farm to date: generates 1,350 GWh and offsets nearly 1 million tonnes of carbon annually.",
    "ISIS destroying Iraq's cultural heritage: UNESCO chief.",
    "Argentina bans Procter & Gamble.",
    "106 retired Israeli generals, spy chiefs urge Netanyahu to push for peace - Diplomacy and Defense Israel News.",
    "Police are using loopholes in UK surveillance laws to gain access to peoples voicemails, texts and emails, according to an investigation by The Times.",
    "Haaretz Refuses to Back Down in Storm Over Cartoon Depicting Netanyahu as 9/11 Hijacker.",
    "Virgin Galactic will continue work on 2nd rocket plane despite crash.",
    "India to develop an API for entire government.",
    "Eleven arrested in eastern China for allegedly stealing women's corpses for use in 'ghost marriages'.",
    "France finally upgrades animals from 'furniture' status.",
    "ISIS leader Abu-Bakr al-Baghdadi orders bodies of all Kurdish fighters to be burned.",
    "Russian supply underpins global oil glut.",
    "Politicians and industry rise up against 'unjustified' metadata bill.",
    "After rapidly intensifying, Super Typhoon Nuri may be planet's strongest storm of 2014.",
    "Public opposition has cost tar sands industry $17bn, says report."
]

predictions = classifier(input_texts)

print(predictions)



[{'label': 'LABEL_1', 'score': 0.5249744057655334}, {'label': 'LABEL_1', 'score': 0.522462785243988}, {'label': 'LABEL_1', 'score': 0.5193727612495422}, {'label': 'LABEL_1', 'score': 0.5258693099021912}, {'label': 'LABEL_1', 'score': 0.526710033416748}, {'label': 'LABEL_1', 'score': 0.5315632224082947}, {'label': 'LABEL_1', 'score': 0.5147374868392944}, {'label': 'LABEL_1', 'score': 0.5319657921791077}, {'label': 'LABEL_1', 'score': 0.5278193354606628}, {'label': 'LABEL_1', 'score': 0.5274370908737183}, {'label': 'LABEL_1', 'score': 0.5340280532836914}, {'label': 'LABEL_1', 'score': 0.5257196426391602}, {'label': 'LABEL_1', 'score': 0.5319598317146301}, {'label': 'LABEL_1', 'score': 0.5247516632080078}, {'label': 'LABEL_1', 'score': 0.5193204283714294}, {'label': 'LABEL_1', 'score': 0.524553656578064}, {'label': 'LABEL_1', 'score': 0.5216874480247498}, {'label': 'LABEL_1', 'score': 0.5271680951118469}, {'label': 'LABEL_1', 'score': 0.5253353118896484}, {'label': 'LABEL_1', 'score': 0.5

In [22]:
results = pd.DataFrame(predictions)
results['text'] = input_texts
results['label'] = results['label'].map({'LABEL_0': 'down_or_no_change', 'LABEL_1': 'up'})
results = results[['text', 'label']]
print(results)
print(results['text'][0])


                                                 text label
0   France Cracks Down on Factory Farms - A viral ...    up
1   The company's revenue has seen a significant i...    up
2   Chancellor Angela Merkel would rather see the ...    up
3   Europe has 421 million fewer birds than 30 yea...    up
4   A town in Portugal makes Mandarin Chinese comp...    up
5   Portugal Sees Chinese Do 90% of Bids at Proper...    up
6   Monument To Apple's Jobs Removed In Russia Aft...    up
7   Dr. Godfrey George, medical superintendent of ...    up
8   Muslims in Canada mark holy day with anti-terr...    up
9   Canadian warplanes drop first bombs against Is...    up
10  Suicide blast kills 52 near India - Pakistan B...    up
11            Singapore joins coalition against ISIS.    up
12  South Australia completes largest wind farm to...    up
13  ISIS destroying Iraq's cultural heritage: UNES...    up
14                   Argentina bans Procter & Gamble.    up
15  106 retired Israeli generals, spy ch