In [3]:
import re
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Gist Token extraction function
def extract_gist(text, date):
    # Define regex patterns for dates, company names, and events
    date_pattern = r"\b\d{4}-\d{2}-\d{2}\b"  # Matches YYYY-MM-DD
    company_pattern = r"\b[A-Z][a-zA-Z]+(?:\s[A-Z][a-zA-Z]+)*\b"  # Matches proper nouns (e.g., "Apple Inc.")
    event_keywords = ["earnings", "merger", "report", "acquisition", "profit", "loss", "growth", "decline"]

    # Extract entities
    companies = re.findall(company_pattern, text)
    events = [word for word in event_keywords if word in text.lower()]

    # Create gist token
    gist = f"[GIST] {' '.join(companies)} {' '.join(events)} [DATE] {date}"
    return gist

# Define custom dataset with Gist Tokens
class GistTokenDataset(Dataset):
    def __init__(self, texts, dates, labels, tokenizer, max_len=512):
        self.texts = texts
        self.dates = dates
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        date = self.dates[idx]
        label = self.labels[idx]

        gist = extract_gist(text, date)
        encoding = self.tokenizer.encode_plus(
            gist + " " + text,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [4]:
# Define model with altered attention mechanism and frozen layers
class GistEnhancedTransformer(nn.Module):
    def __init__(self, base_model_name="bert-base-uncased"):
        super(GistEnhancedTransformer, self).__init__()
        self.bert = BertModel.from_pretrained(base_model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Freeze embedding and encoder layers
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False
        for param in self.bert.encoder.parameters():
            param.requires_grad = False

        # Custom attention mechanism
        self.attention_layer = nn.Linear(self.hidden_size, 1)
        self.output_layer = nn.Linear(self.hidden_size, 2)  # Binary classification (increase or decrease)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state  # Shape: (batch_size, seq_len, hidden_dim)
        attention_scores = self.attention_layer(hidden_states).squeeze(-1)  # Shape: (batch_size, seq_len)

        # Apply softmax to create attention scores
        attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)

        # Compute weighted sum of hidden states
        context_vector = torch.matmul(attention_weights.unsqueeze(1), hidden_states).squeeze(1)

        # Pass the context vector through output layer for classification
        logits = self.output_layer(context_vector)
        return logits, attention_weights


In [6]:
import pandas as pd
# Load dataset
train_data = pd.read_csv("train_stock_news.csv")
test_data = pd.read_csv("test_stock_news.csv")

# Prepare labels and text
train_labels = (train_data["Close"].diff().fillna(0) > 0).astype(int).tolist()
test_labels = (test_data["Close"].diff().fillna(0) > 0).astype(int).tolist()
train_texts = train_data["Text"].tolist()
train_dates = train_data["Date"].tolist()
test_texts = test_data["Text"].tolist()
test_dates = test_data["Date"].tolist()

# Tokenizer and dataset preparation
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
train_dataset = GistTokenDataset(train_texts, train_dates, train_labels, tokenizer)
test_dataset = GistTokenDataset(test_texts, test_dates, test_labels, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)


In [8]:
# Initialize model
model = GistEnhancedTransformer().to(device)

# Optimizer and loss function
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# Training loop
epochs = 3
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits, _ = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")


Epoch 1/3: 100%|██████████| 161/161 [40:30<00:00, 15.10s/it]


Epoch 1, Loss: 0.5870


Epoch 2/3: 100%|██████████| 161/161 [40:23<00:00, 15.05s/it]


Epoch 2, Loss: 0.5094


Epoch 3/3: 100%|██████████| 161/161 [40:27<00:00, 15.08s/it]

Epoch 3, Loss: 0.4966





In [10]:
# Save the trained model
model_save_path = "saved_modified_attention_model.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to modified_attention_model.pth


In [9]:
# Evaluate the model
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits, _ = model(input_ids, attention_mask)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Calculate F1 Score and Classification Report
f1 = f1_score(all_labels, all_preds, average="weighted")
print(f"F1 Score: {f1:.4f}")
print(classification_report(all_labels, all_preds, target_names=["Down/No Change", "Up"]))

F1 Score: 0.7259
                precision    recall  f1-score   support

Down/No Change       0.81      1.00      0.90       895
            Up       0.00      0.00      0.00       209

      accuracy                           0.81      1104
     macro avg       0.41      0.50      0.45      1104
  weighted avg       0.66      0.81      0.73      1104



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
