In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from transformers import BertModel, BertTokenizer

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
# Release any cached GPU memory left over from earlier runs in this kernel
# (a no-op when CUDA was never initialised), then show whether CUDA is usable.
# The availability check must be the LAST expression, otherwise Jupyter
# discards its value instead of displaying it.
torch.cuda.empty_cache()
torch.cuda.is_available()

In [None]:
# --- Configuration shared by the cells below ---
# Hugging Face hub id used for both the BertModel backbone and the BertTokenizer.
BERT_MODEL = 'google-bert/bert-base-uncased'
# Number of articles tokenised and scored per DataLoader batch.
BATCH_SIZE = 16
# Run on the GPU when one is present, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class BERTSentimentClassifier(nn.Module):
    """Pre-trained BERT encoder followed by a small MLP head.

    The encoder's ``pooler_output`` is fed through
    Dropout -> Linear(hidden, 256) -> ReLU -> Dropout -> Linear(256, num_classes)
    to produce one logit per sentiment class.

    Both sub-modules are moved to ``DEVICE`` and put in training mode when the
    object is built; call ``.eval()`` before running inference.

    Args:
        num_classes: number of output logits (default 3).
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = BertModel.from_pretrained(BERT_MODEL)
        self.bert = backbone.train().to(DEVICE)

        # The head's input width follows the backbone config (768 for bert-base).
        hidden_width = self.bert.config.hidden_size
        head_layers = [
            nn.Dropout(p=0.3),
            nn.Linear(hidden_width, 256),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers).train().to(DEVICE)

        # Re-draw the head's weight matrices (biases keep PyTorch's default init).
        BERTSentimentClassifier.initialize_weights(self.classifier)

    def forward(self, inputs):
        """Return ``(batch, num_classes)`` logits.

        Args:
            inputs: mapping of keyword arguments for ``BertModel`` (e.g.
                ``input_ids``, ``token_type_ids``, ``attention_mask``).
        """
        pooled = self.bert(**inputs).pooler_output
        return self.classifier(pooled)

    @staticmethod
    def initialize_weights(model):
        """In place, fill ``weight`` of every Conv2d / ConvTranspose2d /
        BatchNorm2d / Linear module found in ``model.modules()`` with samples
        from N(0, 0.02). Biases and other parameters are left untouched."""
        target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.Linear)
        with torch.no_grad():
            matching = filter(lambda module: isinstance(module, target_types), model.modules())
            for module in matching:
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

In [None]:
class NewsDataset(Dataset):
    """Tokenises the ``summary`` and ``description`` text of each DataFrame row for BERT.

    Item ``i`` is ``(i, summary_inputs, description_inputs)``. Each inputs dict
    holds 1-D ``input_ids`` / ``token_type_ids`` / ``attention_mask`` tensors,
    padded and truncated to ``max_length``.

    Rows are addressed by *position* (``iloc``). The DataFrame may therefore carry
    any index, for example after filtering. The returned position can be used with
    ``iloc`` or numpy indexing to write results back.

    Args:
        dataframe: table with ``summary`` and ``description`` columns.
        max_length: token length every text is padded/truncated to (default 512).
    """

    # Tokenizer outputs that are forwarded to the model.
    _FIELDS = ('input_ids', 'token_type_ids', 'attention_mask')

    def __init__(self, dataframe, max_length=512):
        self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
        self.max_length = max_length
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenise one text and drop the tokenizer's leading batch dimension of size 1."""
        encoded = self.tokenizer(
            str(text),
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )
        # squeeze(0) rather than squeeze(): with max_length=1 the latter would
        # collapse the sequence axis too and yield a 0-d tensor.
        return {field: encoded[field].squeeze(0) for field in self._FIELDS}

    def __getitem__(self, index):
        # Column-first positional lookup: label-based ``df[col][index]`` breaks
        # (KeyError / wrong row) for any non-0..n-1 index, and ``iloc`` on the
        # row would upcast mixed dtypes.
        # NOTE(review): missing text becomes the literal string 'nan' via str().
        summary_text = self.data['summary'].iloc[index]
        description_text = self.data['description'].iloc[index]
        return index, self._encode(summary_text), self._encode(description_text)

    @staticmethod
    def collate_fn(batch):
        """Combine ``__getitem__`` items into ``(indices, summary, description)``.

        ``indices`` is a list of row positions. Each dict value is a
        ``(batch_size, max_length)`` tensor produced by stacking the per-item tensors.
        """
        indices = [item[0] for item in batch]

        def stack(slot):
            # slot 1 -> summary dicts, slot 2 -> description dicts
            return {
                field: torch.stack([item[slot][field] for item in batch])
                for field in NewsDataset._FIELDS
            }

        return indices, stack(1), stack(2)


In [None]:
# Build the architecture, load the fine-tuned weights, and switch to inference mode.
model = BERTSentimentClassifier()

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code on load.
state_dict = torch.load(
    '../Models/sentiment_classification/bert_classifier.model',
    map_location=DEVICE,
    weights_only=True,
)
# Checkpoints saved from an nn.DataParallel/DDP wrapper prefix every key with
# 'module.'; strip it so the keys match this bare model.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates *extra* checkpoint keys (presumably e.g. buffers that
# the installed transformers version no longer registers — TODO confirm).
# A *missing* key would silently leave that layer un-fine-tuned and make
# every prediction meaningless, so fail loudly instead.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(f'Checkpoint is missing weights for: {load_result.missing_keys}')
if load_result.unexpected_keys:
    print(f'Ignoring unexpected checkpoint keys: {load_result.unexpected_keys}')

# Trailing ';' suppresses the very long module repr in the cell output.
model.eval();

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  state_dict = torch.load('bert_classifier.pth', map_location=DEVICE)


BERTSentimentClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [None]:
# Load the scraped market articles. The 'summary' and 'description' text columns
# are what the classifier scores below. (The 'Unnamed: 0' column is presumably
# an index written by an earlier export — verify before relying on it.)
news_df = pd.read_excel(io='market_article_df.xlsx')
news_df.head()

Unnamed: 0,authors,datetime,description,source_url,title,url,summary,summary_vader,description_vader
0,"['Sunil Shankar Matkar', 'Sunil Matkar']",2024-05-02 01:01:05,"A short build-up was seen in 52 stocks, includ...",https://www.moneycontrol.com,Trade setup for Thursday: 15 things to know be...,https://www.moneycontrol.com/news/business/mar...,The market took a U-turn from its lifetime hi...,0.5267,0.0
1,"['Sai Aravindh', 'Live Tv', 'Stock Market', 'L...",2024-05-02 02:07:54,Market analysis suggests a further weakness fr...,https://www.ndtvprofit.com,Trade Setup For May 2: Nifty To See Further We...,https://www.ndtvprofit.com/markets/trade-setup...,The GIFT Nifty was trading 43 points or 0.19%...,0.9068,-0.128
2,['Hormaz Fatakia'],2024-05-02 07:24:08,The Nifty corrected nearly 200 points from rec...,https://www.cnbctv18.com,Trade Setup for May 2: Nifty faces pressure at...,https://www.cnbctv18.com/market/trade-setup-ma...,Nifty faces pressure at higher levels ahead o...,0.7096,0.6597
3,['Asit Manohar'],2024-05-02 07:54:13,Trade setup for Thursday: In the US Fed meetin...,https://www.livemint.com,Stock market today: Trade setup for Nifty 50 a...,https://www.livemint.com/market/stock-market-n...,Trade setup for Nifty 50 after US Fed meeting...,0.872,0.4019
4,['Hormaz Fatakia'],2024-05-02 21:45:22,"While Kotak Bank, ICICI Bank and Axis Bank kep...",https://www.cnbctv18.com,Trade Setup for May 3: Will heavyweight banks ...,https://www.cnbctv18.com/market/trade-setup-ma...,The Reserve Bank of India has lifted the rest...,-0.4019,0.4404


In [None]:
# Wrap the articles for batched tokenisation. Batches are produced in row order,
# and collate_fn returns each row's dataset index alongside the tensors so
# predictions can be written back to the right rows.
dataset = NewsDataset(dataframe=news_df)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)

In [None]:
# Score every article's summary and description; each prediction is the argmax
# class id of the model's logits.
# - inference_mode() stops autograd from recording graphs (model.eval() alone
#   does not), which keeps GPU memory flat without per-batch empty_cache().
# - Mixed precision is only enabled on CUDA; on CPU the model runs in fp32.
# - Predictions are collected by dataset position and assigned once, giving
#   integer columns instead of per-row .loc writes that upcast to float.
summary_preds = np.full(len(news_df), -1, dtype=np.int64)
description_preds = np.full(len(news_df), -1, dtype=np.int64)
use_amp = DEVICE.type == 'cuda'

with torch.inference_mode():
    for indices, summary, description in tqdm(loader, desc='scoring'):
        summary = {k: v.to(DEVICE) for k, v in summary.items()}
        description = {k: v.to(DEVICE) for k, v in description.items()}
        with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
            summary_logits = model(summary)
            description_logits = model(description)
        summary_preds[indices] = summary_logits.argmax(dim=-1).cpu().numpy()
        description_preds[indices] = description_logits.argmax(dim=-1).cpu().numpy()

news_df['summary_sentiment'] = summary_preds
news_df['description_sentiment'] = description_preds

  0%|          | 0/102 [00:00<?, ?it/s]

In [None]:
# Spot-check the first rows, including the new *_sentiment prediction columns.
news_df.head(n=5)

Unnamed: 0,authors,datetime,description,source_url,title,url,summary,summary_vader,description_vader,summary_sentiment,description_sentiment
0,"['Sunil Shankar Matkar', 'Sunil Matkar']",2024-05-02 01:01:05,"A short build-up was seen in 52 stocks, includ...",https://www.moneycontrol.com,Trade setup for Thursday: 15 things to know be...,https://www.moneycontrol.com/news/business/mar...,The market took a U-turn from its lifetime hi...,0.5267,0.0,1.0,1.0
1,"['Sai Aravindh', 'Live Tv', 'Stock Market', 'L...",2024-05-02 02:07:54,Market analysis suggests a further weakness fr...,https://www.ndtvprofit.com,Trade Setup For May 2: Nifty To See Further We...,https://www.ndtvprofit.com/markets/trade-setup...,The GIFT Nifty was trading 43 points or 0.19%...,0.9068,-0.128,1.0,0.0
2,['Hormaz Fatakia'],2024-05-02 07:24:08,The Nifty corrected nearly 200 points from rec...,https://www.cnbctv18.com,Trade Setup for May 2: Nifty faces pressure at...,https://www.cnbctv18.com/market/trade-setup-ma...,Nifty faces pressure at higher levels ahead o...,0.7096,0.6597,1.0,0.0
3,['Asit Manohar'],2024-05-02 07:54:13,Trade setup for Thursday: In the US Fed meetin...,https://www.livemint.com,Stock market today: Trade setup for Nifty 50 a...,https://www.livemint.com/market/stock-market-n...,Trade setup for Nifty 50 after US Fed meeting...,0.872,0.4019,2.0,1.0
4,['Hormaz Fatakia'],2024-05-02 21:45:22,"While Kotak Bank, ICICI Bank and Axis Bank kep...",https://www.cnbctv18.com,Trade Setup for May 3: Will heavyweight banks ...,https://www.cnbctv18.com/market/trade-setup-ma...,The Reserve Bank of India has lifted the rest...,-0.4019,0.4404,1.0,1.0


In [None]:
# Persist the scored articles. Create the output folder first so a fresh
# checkout does not fail with FileNotFoundError.
ratings_path = Path('../Dataset/news_ratings/market_article_df_with_ratings.xlsx')
ratings_path.parent.mkdir(parents=True, exist_ok=True)
news_df.to_excel(ratings_path, index=False)