In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification, BertModel
from transformers import T5Tokenizer, T5ForConditionalGeneration

# **Part 1: news sentiment classification with Bert model**

Load news data from apple as training set

In [4]:
df=pd.read_csv('stock_news/aapl.csv')
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load pre-trained Bert Model

In [13]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertModel.from_pretrained('bert-base-cased')

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## 2-classification task
Use Bert to classify news to 2 categories: Positive, Negative

In [14]:
from bert_classification import Bert_Classifier
bert_clf = Bert_Classifier(model, tokenizer, device, df, 2)
bert_clf.train()
bert_clf.test()

The accuracy of BERT mode on news 2-classification task is: 0.6427104722792608


## 3-classification task
Use Bert to classify news to 3 categories: Positive, Negative, Neutral

In [15]:
from bert_classification import Bert_Classifier
bert_clf3 = Bert_Classifier(model, tokenizer, device, df, 3)
bert_clf3.train()
bert_clf3.test()

The accuracy of BERT mode on news 3-classification task is: 0.4574948665297741


# **Part 2: news sentiment classification with T5 model**

load pretrained t5 model

In [6]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small',use_cache='False')

## 2-classification task
Use T5 to classify news to 2 categories: Positive, Negative

In [7]:
from T5_classification import T5_Classifier
t5_clf = T5_Classifier(model, tokenizer, device, df, 2)
t5_clf.train()




Training...




  Batch    40  of    609.
  Batch    80  of    609.
  Batch   120  of    609.
  Batch   160  of    609.
  Batch   200  of    609.
  Batch   240  of    609.
  Batch   280  of    609.
  Batch   320  of    609.
  Batch   360  of    609.
  Batch   400  of    609.
  Batch   440  of    609.
  Batch   480  of    609.
  Batch   520  of    609.
  Batch   560  of    609.
  Batch   600  of    609.

summary results
epoch | trn loss | trn time 
    1 | 0.76187 | 0:05:41

Training...
  Batch    40  of    609.
  Batch    80  of    609.
  Batch   120  of    609.
  Batch   160  of    609.
  Batch   200  of    609.
  Batch   240  of    609.
  Batch   280  of    609.
  Batch   320  of    609.
  Batch   360  of    609.
  Batch   400  of    609.
  Batch   440  of    609.
  Batch   480  of    609.
  Batch   520  of    609.
  Batch   560  of    609.
  Batch   600  of    609.

summary results
epoch | trn loss | trn time 
    2 | 0.22212 | 0:05:41

Training...
  Batch    40  of    609.
  Batch    80  of    609

In [8]:
test_stat, test_result = t5_clf.test()


Running Testing...
  Batch    40  of    203.    Elapsed: 0:00:16.
  Batch    80  of    203.    Elapsed: 0:00:31.
  Batch   120  of    203.    Elapsed: 0:00:47.
  Batch   160  of    203.    Elapsed: 0:01:02.
  Batch   200  of    203.    Elapsed: 0:01:18.


In [9]:
print(test_stat)
test_result

[{'Test Loss': 0.2059948617234606, 'Test PPL.': 1.2287468903377718, 'Test Acc.': 0.6645768025078367, 'Test F1': 0.6920763777542512}]


Unnamed: 0,predicted,actual
0,positive,negative
1,positive,negative
2,positive,negative
3,negative,positive
4,negative,positive
...,...,...
2430,negative,negative
2431,positive,positive
2432,positive,positive
2433,positive,positive


## 3-classification task
Use T5 to classify news to 3 categories: Positive, Negative, Neutral

In [10]:
t5_clf3 = T5_Classifier(model, tokenizer, device, df, 3)
t5_clf3.train()




Training...




  Batch    40  of    609.
  Batch    80  of    609.
  Batch   120  of    609.
  Batch   160  of    609.
  Batch   200  of    609.
  Batch   240  of    609.
  Batch   280  of    609.
  Batch   320  of    609.
  Batch   360  of    609.
  Batch   400  of    609.
  Batch   440  of    609.
  Batch   480  of    609.
  Batch   520  of    609.
  Batch   560  of    609.
  Batch   600  of    609.

summary results
epoch | trn loss | trn time 
    1 | 0.40769 | 0:05:41

Training...
  Batch    40  of    609.
  Batch    80  of    609.
  Batch   120  of    609.
  Batch   160  of    609.
  Batch   200  of    609.
  Batch   240  of    609.
  Batch   280  of    609.
  Batch   320  of    609.
  Batch   360  of    609.
  Batch   400  of    609.
  Batch   440  of    609.
  Batch   480  of    609.
  Batch   520  of    609.
  Batch   560  of    609.
  Batch   600  of    609.

summary results
epoch | trn loss | trn time 
    2 | 0.35968 | 0:05:41

Training...
  Batch    40  of    609.
  Batch    80  of    609

In [11]:
test_stat, test_result = t5_clf.test()


Running Testing...
  Batch    40  of    203.    Elapsed: 0:00:16.
  Batch    80  of    203.    Elapsed: 0:00:31.
  Batch   120  of    203.    Elapsed: 0:00:47.
  Batch   160  of    203.    Elapsed: 0:01:02.
  Batch   200  of    203.    Elapsed: 0:01:18.


In [12]:
print(test_stat)
test_result

[{'Test Loss': 0.3839951871063909, 'Test PPL.': 1.4681383756711799, 'Test Acc.': 0.4345051500223913, 'Test F1': 0.3620878628358344}]


Unnamed: 0,predicted,actual
0,negative,positive
1,positive,positive
2,positive,negative
3,negative,negative
4,neutral,positive
...,...,...
2430,neutral,negative
2431,neutral,positive
2432,negative,negative
2433,positive,positive
