In [None]:
# Install third-party dependencies into the environment of the running kernel.
# `%pip` (rather than `!pip`) guarantees the packages land in the interpreter
# this notebook is executing in, not whichever `pip` is first on PATH.
%pip install -q transformers datasets scikit-learn

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_scheduler
from transformers import AdamW
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
from tqdm import tqdm

In [None]:
# Toy corpus of financial sentences, written as (sentence, sentiment) pairs so
# each text sits next to its label.
samples = [
    ("The company posted a net profit of 10 million dollars.", 'positive'),
    ("Revenue declined by 15% in the last quarter.", 'negative'),
    ("There was a significant increase in operational costs.", 'negative'),
    ("Positive growth forecast for next year.", 'positive'),
    ("Stock prices are expected to rise after merger announcement.", 'positive'),
]

# Column-oriented view of the same pairs, in the original order.
data = {
    'text': [sentence for sentence, _ in samples],
    'label': [sentiment for _, sentiment in samples],
}

df = pd.DataFrame(data)

# Fit the encoder on the string labels, then store the integer ids it assigns.
label_encoder = LabelEncoder()
label_encoder.fit(df['label'])
df['label_enc'] = label_encoder.transform(df['label'])

# Hold out 20% of the rows for validation; the fixed seed makes the split repeatable.
all_texts = df['text'].tolist()
all_label_ids = df['label_enc'].tolist()
train_texts, val_texts, train_labels, val_labels = train_test_split(
    all_texts,
    all_label_ids,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Load the tokenizer published under the "bert-base-uncased" checkpoint name.
# The same `tokenizer` object is used for both training data and inference.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

class FinancialDataset(Dataset):
    """Map-style dataset that tokenizes one text per item, on demand.

    Args:
        texts: Indexable sequence of raw strings.
        labels: Indexable sequence of integer class ids, aligned with ``texts``.
        tokenizer: Callable accepting ``(text, padding=, truncation=,
            max_length=, return_tensors=)`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors that carry a
            leading batch axis of size 1.
        max_len: Length every sequence is padded or truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        # Fail fast: a mismatch would otherwise surface as an IndexError deep
        # inside a DataLoader worker, or silently ignore surplus labels.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors and label.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` (batch axis
            removed) and ``'labels'`` as a ``torch.long`` scalar tensor.
        """
        encoding = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            # squeeze(0) drops only the leading batch axis. A bare squeeze()
            # would also drop the sequence axis when max_len == 1, yielding
            # 0-d tensors that collate into the wrong shape.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Wrap each split in a FinancialDataset that shares the module-level tokenizer.
train_dataset = FinancialDataset(
    texts=train_texts,
    labels=train_labels,
    tokenizer=tokenizer,
)
val_dataset = FinancialDataset(
    texts=val_texts,
    labels=val_labels,
    tokenizer=tokenizer,
)

# Batches of two; only the training loader reshuffles its examples.
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=2)

In [None]:
# Use a CUDA device when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size the classification head from the fitted label encoder instead of a
# hard-coded 2, so the head always matches the label set that was encoded.
num_labels = len(label_encoder.classes_)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)

# Assigning the result (rather than leaving a bare expression as the cell's
# last line) keeps the notebook from echoing the full module repr.
model = model.to(device)

In [None]:
# Number of passes over the training data that the schedule is sized for.
# The training loop must run this many epochs, or the linear decay will reach
# zero too early or never reach it.
NUM_EPOCHS = 3

# torch.optim.AdamW replaces transformers.AdamW, which is deprecated upstream
# and removed from recent transformers releases. eps and weight_decay are
# pinned to the defaults of the old transformers implementation so the
# effective hyper-parameters do not change (torch's own defaults differ).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)

# One scheduler step is taken per optimizer step, i.e. per training batch.
num_training_steps = len(train_loader) * NUM_EPOCHS
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Fine-tune for three epochs and report the mean batch loss after each one.
model.train()

for epoch in range(3):  # 3 epochs
    total_loss = 0
    for batch in tqdm(train_loader):
        # Move every tensor of the batch onto the model's device.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        # The batch carries 'labels', so the model output exposes a loss.
        loss = model(**batch).loss
        total_loss += loss.item()

        # Backward pass, parameter update, LR update, then clear gradients
        # ready for the next batch.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    mean_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} Loss: {mean_loss}")

In [None]:
# Measure accuracy on the validation loader with gradients disabled.
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for batch in val_loader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        logits = model(**batch).logits
        gold = batch['labels']

        # The predicted class is the index of the largest logit per example.
        predictions = logits.argmax(dim=-1)
        correct += predictions.eq(gold).sum().item()
        total += gold.shape[0]

accuracy_pct = correct / total * 100
print(f"Validation Accuracy: {accuracy_pct:.2f}%")

In [None]:
def classify_text(text, max_length=128):
    """Predict the sentiment label of a single piece of text.

    Relies on the module-level ``tokenizer``, ``model``, ``device`` and
    ``label_encoder`` objects. As a side effect the model is switched to
    evaluation mode and left there.

    Args:
        text: The string to classify.
        max_length: Token length the input is padded or truncated to.
            Defaults to 128, the value previously hard-coded here.

    Returns:
        The original (un-encoded) label for the highest-scoring class, as
        produced by ``label_encoder.inverse_transform``.
    """
    model.eval()
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length
    )
    # The inputs must live on the same device as the model's parameters.
    encoding = {k: v.to(device) for k, v in encoding.items()}
    with torch.no_grad():
        outputs = model(**encoding)
        # .item() requires exactly one prediction, i.e. a single input text.
        prediction = torch.argmax(outputs.logits, dim=-1).item()
    return label_encoder.inverse_transform([prediction])[0]

# Example usage: classify one unseen sentence and print the predicted label.
example_sentence = "The company showed a strong increase in sales this quarter."
predicted_label = classify_text(example_sentence)
print(predicted_label)