In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import re
import string
import pickle
import os
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score

# Seed the global torch and NumPy RNGs so weight init, random_split,
# DataLoader shuffling and dropout are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [6]:
# ====================================
# Data Preprocessing
# ====================================

class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids.

    Ids 0-3 are reserved for ``<PAD>``, ``<UNK>``, ``<SOS>`` and ``<EOS>``;
    corpus words get the next free id the first time their count reaches
    ``freq_threshold``.
    """

    def __init__(self, freq_threshold=2):
        self.itos = {0: "<PAD>", 1: "<UNK>", 2: "<SOS>", 3: "<EOS>"}
        self.stoi = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3}
        self.freq_threshold = freq_threshold
        # Raw token frequencies from the most recent build_vocabulary() call.
        self.word_count = {}

    def __len__(self):
        return len(self.itos)

    def build_vocabulary(self, sentence_list):
        """Count tokens in ``sentence_list`` and add every token seen at least
        ``freq_threshold`` times.

        New ids continue after the current highest id, so calling this again
        (or on a loaded vocabulary) never overwrites existing entries.
        """
        frequencies = Counter()
        next_idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1
                # Register the word once, as soon as it reaches the threshold.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

        self.word_count = dict(frequencies)

    def tokenize(self, text):
        """Lower-case ``text`` and split it into word tokens.

        Percentages become single tokens (``26.3%`` -> ``26.3_percent``) and
        decimal points between digits are kept. Every other punctuation mark,
        including sentence-final periods, is removed, so ``"costs."`` and
        ``"costs"`` yield the same token.
        """
        text = text.lower()

        # Protect percentages as single tokens: "26.3%" -> "26.3_percent", "5%" -> "5_percent".
        text = re.sub(r'(\d+(?:\.\d+)?)%', r'\1_percent', text)

        # Drop punctuation except dots; letters, digits, underscores and whitespace survive.
        text = re.sub(r'[^\w\s.]', '', text)

        # Keep a dot only when it is a decimal point (digit on both sides); any other dot
        # would otherwise stick to the word ("costs.") and break phrase matching.
        text = re.sub(r'(?<!\d)\.|\.(?!\d)', '', text)

        return text.split()

    def numericalize(self, text):
        """Convert ``text`` to token ids; tokens not in the vocabulary map to ``<UNK>``."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

    def save_vocab(self, path):
        """Pickle the id mappings, word counts and frequency threshold to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump({
                'itos': self.itos,
                'stoi': self.stoi,
                'word_count': self.word_count,
                'freq_threshold': self.freq_threshold,
            }, f)

    @classmethod
    def load_vocab(cls, path):
        """Load a vocabulary written by ``save_vocab``.

        Security: this uses ``pickle.load``, which can execute arbitrary code;
        only load files from a trusted source.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)
        vocab = cls()
        vocab.itos = data['itos']
        vocab.stoi = data['stoi']
        vocab.word_count = data['word_count']
        # Files saved without a threshold keep the constructor default.
        vocab.freq_threshold = data.get('freq_threshold', vocab.freq_threshold)
        return vocab


class FinancialCausalDataset(Dataset):
    """Token-id sequences with binary cause/effect span masks.

    Each item's ``cause_mask`` / ``effect_mask`` marks the first exact
    occurrence of the tokenized cause/effect phrase inside the tokenized
    text. The mask is all zeros when:
    - the phrase is missing (e.g. a NaN cell);
    - the phrase is empty or not found;
    - its first occurrence does not fit entirely inside the first ``max_len`` tokens.
    """

    def __init__(self, texts, causes, effects, vocab, max_len=512):
        self.texts = texts
        self.causes = causes
        self.effects = effects
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def _span_mask(self, phrase, text_tokens, length):
        """Return a 0/1 list of size ``length`` marking ``phrase``'s first occurrence in ``text_tokens``."""
        mask = [0] * length
        # pandas reads empty CSV cells as float NaN; any non-string label means "no span".
        if not isinstance(phrase, str):
            return mask
        phrase_tokens = self.vocab.tokenize(phrase)
        width = len(phrase_tokens)
        if width == 0:
            return mask
        for start in range(len(text_tokens) - width + 1):
            if text_tokens[start:start + width] == phrase_tokens:
                # Only label the first occurrence, and only if it survived truncation intact.
                if start + width <= length:
                    mask[start:start + width] = [1] * width
                break
        return mask

    def __getitem__(self, index):
        text = self.texts[index]
        cause = self.causes[index]
        effect = self.effects[index]

        # Token ids, truncated to max_len.
        numeric_text = self.vocab.numericalize(text)[:self.max_len]
        text_length = len(numeric_text)

        # numericalize() uses the same tokenizer, so token positions line up with numeric_text.
        text_tokens = self.vocab.tokenize(text)
        cause_mask = self._span_mask(cause, text_tokens, text_length)
        effect_mask = self._span_mask(effect, text_tokens, text_length)

        return {
            # Explicit dtype keeps an empty sequence integer-typed for the embedding layer.
            "text": torch.tensor(numeric_text, dtype=torch.long),
            "text_length": text_length,
            "cause_mask": torch.tensor(cause_mask, dtype=torch.long),
            "effect_mask": torch.tensor(effect_mask, dtype=torch.long),
            "original_text": text,
            "original_cause": cause,
            "original_effect": effect
        }


def collate_batch(batch):
    """Collate dataset items into a zero-padded batch, longest sequence first.

    Items are ordered by ``text_length`` in descending order, keeping the input
    order for ties. This ordering is what ``pack_padded_sequence(...,
    enforce_sorted=True)`` requires. The caller's ``batch`` list is not modified.

    Returns:
        dict with ``text`` / ``cause_masks`` / ``effect_masks`` padded with 0 to
        the longest item, ``text_lengths`` as a long tensor, and the original
        strings as lists in the same (sorted) order.
    """
    items = sorted(batch, key=lambda item: item["text_length"], reverse=True)

    def pad(key):
        return pad_sequence([item[key] for item in items], batch_first=True, padding_value=0)

    return {
        "text": pad("text"),
        "text_lengths": torch.tensor([item["text_length"] for item in items], dtype=torch.long),
        "cause_masks": pad("cause_mask"),
        "effect_masks": pad("effect_mask"),
        "original_texts": [item["original_text"] for item in items],
        "original_causes": [item["original_cause"] for item in items],
        "original_effects": [item["original_effect"] for item in items]
    }


def prepare_data(csv_path, test_size=0.2, vocab_save_path="financial_vocab.pkl",
                 max_rows=20000, random_state=42):
    """Load the causal CSV, split it, build and save the vocabulary, and wrap both splits as datasets.

    Args:
        csv_path: CSV with at least ``Text``, ``Cause`` and ``Effect`` columns.
        test_size: fraction of rows held out for the test split.
        vocab_save_path: where the pickled vocabulary is written.
        max_rows: only the first ``max_rows`` rows of the CSV are used
            (``None`` uses all rows).
        random_state: seed for the train/test split.

    Returns:
        ``(train_dataset, test_dataset, vocab)``.

    Raises:
        ValueError: if a required column is missing or no usable rows remain.
    """
    df = pd.read_csv(csv_path)
    required_columns = ["Text", "Cause", "Effect"]
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(f"Required column '{col}' not found in the CSV file")

    if max_rows is not None:
        df = df.head(max_rows)

    # Missing cells are read as NaN and would crash tokenization; blank texts would
    # produce zero-length sequences that the packed LSTM cannot accept.
    df = df.dropna(subset=required_columns)
    df = df[df["Text"].astype(str).str.strip() != ""]
    if df.empty:
        raise ValueError("No rows with non-empty Text, Cause and Effect found in the CSV file")

    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)

    # Build the vocabulary on the training split only, to avoid test-set leakage.
    vocab = Vocabulary(freq_threshold=2)
    vocab.build_vocabulary(train_df["Text"].tolist())
    vocab.save_vocab(vocab_save_path)

    def to_dataset(split_df):
        return FinancialCausalDataset(
            texts=split_df["Text"].tolist(),
            causes=split_df["Cause"].tolist(),
            effects=split_df["Effect"].tolist(),
            vocab=vocab
        )

    return to_dataset(train_df), to_dataset(test_df), vocab


# ====================================
# Model Architecture
# ====================================

class CausalAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections."""

    def __init__(self, hidden_dim):
        super(CausalAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        # A Python float, not a tensor attribute: an unregistered tensor is not moved by
        # module.to(device) and would cause a device mismatch on GPU.
        self.scale = float(hidden_dim) ** 0.5

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key`` / ``value``.

        Args:
            query, key, value: ``[batch_size, seq_len, hidden_dim]`` tensors.
            mask: optional tensor broadcastable to ``[batch_size, q_len, k_len]``;
                positions where it equals 0 receive zero attention weight.

        Returns:
            ``(output, attention)`` of shapes ``[batch_size, q_len, hidden_dim]``
            and ``[batch_size, q_len, k_len]``.
        """
        Q = self.query(query)
        K = self.key(key)
        V = self.value(value)

        # Q @ K^T / sqrt(hidden_dim)
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # Most negative finite value of the dtype: also valid in fp16/bf16,
            # where a literal -1e10 cannot be represented.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, V)

        return x, attention


class CausalContextLayer(nn.Module):
    """Multi-head attention sub-layer.

    Applies Q/K/V projections, scaled dot-product attention with dropout on the
    attention weights, and an output projection. ``hidden_dim`` must be
    divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super(CausalContextLayer, self).__init__()
        assert hidden_dim % num_heads == 0

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)

        self.fc_o = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

        # A Python float needs no device handling, so forward() does not have to
        # move or reassign it.
        self.scale = float(self.head_dim) ** 0.5

    def _split_heads(self, x, batch_size):
        """Reshape ``[batch, seq, hidden]`` to ``[batch, heads, seq, head_dim]``."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Multi-head attention from ``query`` over ``key`` / ``value``.

        Args:
            query, key, value: ``[batch_size, seq_len, hidden_dim]`` tensors.
            mask: optional ``[batch_size, k_len]`` tensor; key positions where it
                equals 0 (e.g. padding) get zero attention weight.

        Returns:
            ``[batch_size, q_len, hidden_dim]`` tensor.
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [batch, heads, q_len, k_len]

        if mask is not None:
            key_mask = mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, k_len]
            # Dtype-aware fill value, so this also works in fp16/bf16.
            energy = energy.masked_fill(key_mask == 0, torch.finfo(energy.dtype).min)

        attention = self.dropout(torch.softmax(energy, dim=-1))

        x = torch.matmul(attention, V)  # [batch, heads, q_len, head_dim]
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hidden_dim)

        return self.fc_o(x)


class FinancialCausalDetector(nn.Module):
    """BiLSTM + self-attention encoder that scores cause/effect span boundaries.

    Pipeline:
    - embedding;
    - packed bidirectional LSTM;
    - ``num_layers`` pre-norm residual blocks (multi-head self-attention, then a
      feed-forward layer);
    - four per-token linear scorers (cause/effect x start/end) and a 3-way
      relation classifier.

    Args:
        vocab_size: number of token ids; id 0 is padding.
        embedding_dim: token embedding size.
        hidden_dim: encoder width. Each LSTM direction gets ``hidden_dim // 2``;
            it must be divisible by ``num_heads``.
        num_layers: LSTM depth and number of attention/feed-forward blocks.
        num_heads: attention heads per block.
        dropout: dropout rate used throughout.
        use_glove, glove_path: accepted but currently ignored; embeddings are
            randomly initialised.
    """

    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=768, num_layers=3,
                 num_heads=8, dropout=0.3, use_glove=False, glove_path=None):
        super(FinancialCausalDetector, self).__init__()

        # Token embeddings; padding id 0 is excluded from gradient updates.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Bidirectional LSTM encoder
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim // 2,  # Bidirectional, so each direction gets half
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        # Self-attention sub-layers
        self.context_layers = nn.ModuleList([
            CausalContextLayer(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Pre-attention normalization
        self.context_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim)
            for _ in range(num_layers)
        ])

        # Position-wise feed-forward sub-layers
        self.ff_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim * 4, hidden_dim),
                nn.Dropout(dropout)
            )
            for _ in range(num_layers)
        ])

        # Pre-feed-forward normalization
        self.ff_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim)
            for _ in range(num_layers)
        ])

        # Per-token span boundary scorers
        self.cause_start_classifier = nn.Linear(hidden_dim, 1)
        self.cause_end_classifier = nn.Linear(hidden_dim, 1)
        self.effect_start_classifier = nn.Linear(hidden_dim, 1)
        self.effect_end_classifier = nn.Linear(hidden_dim, 1)

        # Causal relation classification
        self.relation_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 3)  # No relation, Cause->Effect, Effect->Cause
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) with zero biases and a
        zero padding row, and LayerNorm to the identity (weight 1, bias 0)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def create_padding_mask(self, x, pad_idx=0):
        """Return a float mask of ``x``'s shape: 1.0 for real tokens, 0.0 for ``pad_idx``."""
        return (x != pad_idx).float()

    def forward(self, text, text_lengths):
        """Score every token of a padded batch.

        Args:
            text: ``[batch_size, seq_len]`` token ids, padded with 0.
            text_lengths: ``[batch_size]`` tensor of true lengths (each >= 1), in any order.

        Returns:
            dict with:
            - ``cause_start_logits``, ``cause_end_logits``, ``effect_start_logits``,
              ``effect_end_logits``: each ``[batch_size, seq_len]``, with padding
              positions set to -1e10;
            - ``relation_logits``: ``[batch_size, 3]``.
        """
        seq_len = text.shape[1]

        # 1.0 for real tokens, 0.0 for padding: [batch_size, seq_len]
        padding_mask = self.create_padding_mask(text)

        embedded = self.embedding(text)  # [batch_size, seq_len, embedding_dim]

        # enforce_sorted=False accepts batches in any length order; PyTorch sorts
        # internally and pad_packed_sequence restores the caller's order.
        packed_embedded = pack_padded_sequence(
            embedded, text_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_outputs, _ = self.lstm(packed_embedded)
        x, _ = pad_packed_sequence(packed_outputs, batch_first=True, total_length=seq_len)
        # x: [batch_size, seq_len, hidden_dim]

        # Pre-norm residual blocks. The padding mask keeps tokens from attending to padding.
        for context_layer, context_norm, ff_layer, ff_norm in zip(
            self.context_layers, self.context_norms, self.ff_layers, self.ff_norms
        ):
            context_input = context_norm(x)
            x = x + context_layer(context_input, context_input, context_input, padding_mask)
            x = x + ff_layer(ff_norm(x))

        is_pad = padding_mask == 0

        def span_logits(classifier):
            # [batch_size, seq_len]; padding forced to -1e10 so it scores ~0 under sigmoid/softmax.
            return classifier(x).squeeze(-1).masked_fill(is_pad, -1e10)

        cause_start_logits = span_logits(self.cause_start_classifier)
        cause_end_logits = span_logits(self.cause_end_classifier)
        effect_start_logits = span_logits(self.effect_start_classifier)
        effect_end_logits = span_logits(self.effect_end_classifier)

        # Relation features: token representations pooled with softmax(start logits)
        # over the sequence as weights.
        cause_probs = torch.softmax(cause_start_logits, dim=-1).unsqueeze(-1)    # [batch_size, seq_len, 1]
        effect_probs = torch.softmax(effect_start_logits, dim=-1).unsqueeze(-1)  # [batch_size, seq_len, 1]
        cause_rep = torch.sum(x * cause_probs, dim=1)    # [batch_size, hidden_dim]
        effect_rep = torch.sum(x * effect_probs, dim=1)  # [batch_size, hidden_dim]

        relation_input = torch.cat([cause_rep, effect_rep], dim=-1)  # [batch_size, hidden_dim*2]
        relation_logits = self.relation_classifier(relation_input)   # [batch_size, 3]

        return {
            "cause_start_logits": cause_start_logits,
            "cause_end_logits": cause_end_logits,
            "effect_start_logits": effect_start_logits,
            "effect_end_logits": effect_end_logits,
            "relation_logits": relation_logits
        }


# ====================================
# Training Functions
# ====================================

def _span_loss(model, batch, criterion, device):
    """Sum of ``criterion`` over the four span heads for one collated batch."""
    texts = batch["text"].to(device)
    cause_masks = batch["cause_masks"].float().to(device)
    effect_masks = batch["effect_masks"].float().to(device)
    outputs = model(texts, batch["text_lengths"])
    return (criterion(outputs["cause_start_logits"], cause_masks)
            + criterion(outputs["cause_end_logits"], cause_masks)
            + criterion(outputs["effect_start_logits"], effect_masks)
            + criterion(outputs["effect_end_logits"], effect_masks))


def train_model(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=10):
    """Train the span heads and return the model restored to its best-validation weights.

    The loss is the sum of ``criterion`` over the four span-logit heads
    (cause/effect x start/end) against the token-level cause/effect masks; the
    relation head is not trained here. After every epoch the validation loss is
    computed, and the weights with the lowest loss are snapshotted.

    Edge cases:
    - if ``val_loader`` yields no batches, the training loss is used for model selection;
    - if no epoch improves on infinity (e.g. ``num_epochs <= 0`` or NaN losses),
      the model is returned as-is.
    """
    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            loss = _span_loss(model, batch, criterion, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= max(len(train_loader), 1)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                val_loss += _span_loss(model, batch, criterion, device).item()
        if len(val_loader) > 0:
            val_loss /= len(val_loader)
        else:
            val_loss = train_loss

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns tensors sharing storage with the live parameters,
            # which the optimizer keeps updating in place; clone them so the snapshot
            # really holds this epoch's weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model


def extract_spans(logits, text_tokens, threshold=0.5):
    """Group consecutive positions whose sigmoid probability exceeds ``threshold``.

    Args:
        logits: 1-D tensor of per-token logits.
        text_tokens: unused; kept for call-site compatibility.
        threshold: probability cut-off (strictly greater than).

    Returns:
        List of inclusive ``(start, end)`` index pairs in left-to-right order.
    """
    probabilities = torch.sigmoid(logits).cpu().numpy()
    spans = []
    run_start = None

    for position, probability in enumerate(probabilities):
        if probability > threshold:
            if run_start is None:
                run_start = position
            continue
        # Below threshold: close any open run at the previous position.
        if run_start is not None:
            spans.append((run_start, position - 1))
            run_start = None

    # A run still open at the end extends to the last position.
    if run_start is not None:
        spans.append((run_start, len(probabilities) - 1))

    return spans


def evaluate_model(model, test_loader, device, threshold=0.5):
    """Collect predicted and gold spans for every example in ``test_loader``.

    Predictions come from the start-logit heads only. Each head's probabilities
    are thresholded with ``threshold`` and grouped into inclusive
    ``(start, end)`` token spans.

    Returns:
        dict of four lists aligned by example:
        - ``cause_predictions`` / ``effect_predictions``: lists of spans;
        - ``true_causes`` / ``true_effects``: ``(original_string, gold_mask)`` pairs.
          Each mask is trimmed to the example's true length, so it lines up
          with the scored positions.
    """
    model.eval()

    cause_predictions = []
    effect_predictions = []
    true_causes = []
    true_effects = []

    with torch.no_grad():
        for batch in test_loader:
            text_lengths = batch["text_lengths"]
            outputs = model(batch["text"].to(device), text_lengths)

            for i, original_text in enumerate(batch["original_texts"]):
                text_len = text_lengths[i].item()

                cause_predictions.append(extract_spans(
                    outputs["cause_start_logits"][i][:text_len], original_text, threshold))
                effect_predictions.append(extract_spans(
                    outputs["effect_start_logits"][i][:text_len], original_text, threshold))

                # Drop batch padding so gold masks match the prediction length.
                true_causes.append((batch["original_causes"][i], batch["cause_masks"][i][:text_len]))
                true_effects.append((batch["original_effects"][i], batch["effect_masks"][i][:text_len]))

    return {
        "cause_predictions": cause_predictions,
        "effect_predictions": effect_predictions,
        "true_causes": true_causes,
        "true_effects": true_effects
    }


def predict_causal_relation(model, text, vocab, device, threshold=0.5):
    """Predict cause/effect phrases and the relation type for a single text.

    Spans come from the start-logit heads: consecutive tokens whose sigmoid
    probability exceeds ``threshold`` are joined back into phrases using
    ``vocab.tokenize``. Each confidence is the maximum start probability over
    the text.

    A text that yields no tokens is not passed to the model, because a packed
    LSTM cannot take a zero-length sequence. It is returned with empty phrase
    lists, ``"No relation"`` and zero confidences.
    """
    relation_types = ["No relation", "Cause->Effect", "Effect->Cause"]
    model.eval()

    text_tokens = vocab.tokenize(text)
    numeric_text = vocab.numericalize(text)
    if not numeric_text:
        return {
            "text": text,
            "causes": [],
            "effects": [],
            "relation_type": relation_types[0],
            "cause_confidence": 0,
            "effect_confidence": 0
        }

    text_tensor = torch.tensor(numeric_text).unsqueeze(0).to(device)  # [1, seq_len]
    text_length = torch.tensor([len(numeric_text)])

    with torch.no_grad():
        outputs = model(text_tensor, text_length)

    cause_start_logits = outputs["cause_start_logits"][0]
    effect_start_logits = outputs["effect_start_logits"][0]

    def spans_to_phrases(spans):
        # Map token-index spans back to words, skipping any out of range.
        return [" ".join(text_tokens[start:end + 1])
                for start, end in spans
                if start < len(text_tokens) and end < len(text_tokens)]

    causes = spans_to_phrases(extract_spans(cause_start_logits, text, threshold))
    effects = spans_to_phrases(extract_spans(effect_start_logits, text, threshold))

    relation_type = torch.argmax(outputs["relation_logits"][0]).item()
    cause_start_probs = torch.sigmoid(cause_start_logits).cpu().numpy()
    effect_start_probs = torch.sigmoid(effect_start_logits).cpu().numpy()

    return {
        "text": text,
        "causes": causes,
        "effects": effects,
        "relation_type": relation_types[relation_type],
        "cause_confidence": cause_start_probs.max(),
        "effect_confidence": effect_start_probs.max()
    }


# ====================================
# Main Function
# ====================================

def main(csv_path, save_dir="financial_model", test_sentences=None, device=None):
    """Train, save and evaluate the causal detector end to end.

    Steps:
    - build the datasets and vocabulary from ``csv_path``, saving the vocabulary
      into ``save_dir``;
    - hold out 10% of the training split for validation;
    - train for 20 epochs and save the checkpoint to ``save_dir``;
    - evaluate on the test split;
    - optionally print predictions for ``test_sentences``.

    Args:
        device: torch device to use; defaults to CUDA when available, else CPU.

    Returns:
        ``(model, vocab)``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Preparing data
    vocab_path = os.path.join(save_dir, "financial_vocab.pkl")
    train_dataset, test_dataset, vocab = prepare_data(csv_path, test_size=0.2, vocab_save_path=vocab_path)

    # Hold out 10% of the training split for validation / model selection.
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=16, shuffle=shuffle, collate_fn=collate_batch)

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Single source of truth for the architecture: used to build the model and
    # stored in the checkpoint so load_model() can rebuild it.
    model_config = {
        'vocab_size': len(vocab),
        'embedding_dim': 300,
        'hidden_dim': 768,
        'num_layers': 3,
        'num_heads': 8,
        'dropout': 0.3
    }
    model = FinancialCausalDetector(**model_config).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    criterion = nn.BCEWithLogitsLoss()

    model = train_model(
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        device,
        num_epochs=20
    )

    # Saving the trained model
    model_path = os.path.join(save_dir, "financial_causal_model.pt")
    torch.save({'model_state_dict': model.state_dict(), **model_config}, model_path)

    print(f"Model saved to {model_path}")

    # NOTE(review): evaluation results are computed but neither reported nor returned.
    eval_results = evaluate_model(model, test_loader, device)

    if test_sentences:
        print("\nMaking predictions on test sentences:")
        for sentence in test_sentences:
            prediction = predict_causal_relation(model, sentence, vocab, device)
            print(f"\nText: {prediction['text']}")
            print(f"Detected causes: {prediction['causes']}")
            print(f"Detected effects: {prediction['effects']}")
            print(f"Relation type: {prediction['relation_type']}")
            print(f"Cause confidence: {prediction['cause_confidence']:.4f}")
            print(f"Effect confidence: {prediction['effect_confidence']:.4f}")

    return model, vocab


# ====================================
# Model Loading and Inference
# ====================================

def load_model(model_path, vocab_path, device=None):
    """Rebuild a trained detector and its vocabulary for inference.

    Args:
        model_path: checkpoint holding ``model_state_dict`` plus the
            constructor arguments (vocab_size, embedding_dim, hidden_dim,
            num_layers, num_heads, dropout).
        vocab_path: pickled vocabulary readable by ``Vocabulary.load_vocab``.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(model, vocab)``, with the model in eval mode.

    Security: the vocabulary is read with ``pickle`` and the checkpoint with
    ``torch.load``; depending on the PyTorch version, ``torch.load`` may also
    unpickle arbitrary objects. Only load files from a trusted source.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocab = Vocabulary.load_vocab(vocab_path)

    checkpoint = torch.load(model_path, map_location=device)

    # Rebuild the model with the configuration stored alongside the weights.
    model = FinancialCausalDetector(
        vocab_size=checkpoint['vocab_size'],
        embedding_dim=checkpoint['embedding_dim'],
        hidden_dim=checkpoint['hidden_dim'],
        num_layers=checkpoint['num_layers'],
        num_heads=checkpoint['num_heads'],
        dropout=checkpoint['dropout']
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference use: disable dropout so repeated predictions are deterministic.
    model.eval()

    return model, vocab


def batch_predict(model, texts, vocab, device, batch_size=16):
    """Predict cause/effect phrases and relation types for many texts.

    Texts are processed ``batch_size`` at a time. Within each batch, sequences
    are fed to the model longest-first, which makes the batch valid for
    ``pack_padded_sequence(..., enforce_sorted=True)``. Results are mapped back,
    so the returned list stays aligned with ``texts``. A text that yields no
    tokens gets empty phrase lists and ``"No relation"``, because zero-length
    sequences cannot be packed.

    Returns:
        List of dicts with ``text``, ``causes``, ``effects`` and ``relation_type``.
    """
    relation_types = ["No relation", "Cause->Effect", "Effect->Cause"]
    model.eval()
    results = []

    def spans_to_phrases(spans, tokens):
        # Join each in-range (start, end) token span into a phrase.
        return [" ".join(tokens[start:end + 1])
                for start, end in spans
                if start < len(tokens) and end < len(tokens)]

    for offset in range(0, len(texts), batch_size):
        batch_texts = texts[offset:offset + batch_size]
        batch_numeric = [vocab.numericalize(text) for text in batch_texts]

        # Non-empty texts only, longest first (ties keep input order).
        order = sorted(
            (j for j, ids in enumerate(batch_numeric) if ids),
            key=lambda j: len(batch_numeric[j]),
            reverse=True
        )

        # Maps batch position -> (cause_spans, effect_spans, relation_index).
        predictions = {}
        if order:
            padded_texts = pad_sequence(
                [torch.tensor(batch_numeric[j]) for j in order], batch_first=True, padding_value=0
            )
            text_lengths = torch.tensor([len(batch_numeric[j]) for j in order])

            with torch.no_grad():
                outputs = model(padded_texts.to(device), text_lengths)

            for row, j in enumerate(order):
                text_len = len(batch_numeric[j])
                predictions[j] = (
                    extract_spans(outputs["cause_start_logits"][row][:text_len], batch_texts[j]),
                    extract_spans(outputs["effect_start_logits"][row][:text_len], batch_texts[j]),
                    torch.argmax(outputs["relation_logits"][row]).item()
                )

        for j, text in enumerate(batch_texts):
            cause_spans, effect_spans, relation_type = predictions.get(j, ([], [], 0))
            text_tokens = vocab.tokenize(text)
            results.append({
                "text": text,
                "causes": spans_to_phrases(cause_spans, text_tokens),
                "effects": spans_to_phrases(effect_spans, text_tokens),
                "relation_type": relation_types[relation_type]
            })

    return results


# ====================================
# Context Processing
# ====================================

class FinancialContextProcessor:
    """Process financial text to enhance causal detection with domain knowledge.

    Wraps a trained causal-relation model with lightweight, rule-based domain
    processing: abbreviation expansion, tagging of financial terms and causal
    connectors, and reporting which of them occur in a text.

    Args:
        model: Model passed through to ``predict_causal_relation``.
        vocab: Vocabulary passed through to ``predict_causal_relation``.
        device: Torch device passed through to ``predict_causal_relation``.
    """

    # Abbreviation -> expansion applied by ``preprocess_text``.
    _ABBREVIATIONS = {
        'q1': 'quarter one',
        'q2': 'quarter two',
        'q3': 'quarter three',
        'q4': 'quarter four',
        'fy': 'fiscal year',
        'yoy': 'year over year',
        'mom': 'month over month',
        'qoq': 'quarter over quarter',
    }
    # An abbreviation only counts when it is not glued to other letters, so
    # "amplify" / "momentum" are left alone while "fy2023" still expands.
    _ABBREVIATION_RE = re.compile(
        r'(?<![a-z])(' + '|'.join(map(re.escape, _ABBREVIATIONS)) + r')(?![a-z])'
    )

    def __init__(self, model, vocab, device):
        self.model = model
        self.vocab = vocab
        self.device = device

        # Financial domain-specific patterns. They are matched as word
        # prefixes, so "earning" also matches "earnings" and "investor"
        # matches "investors".
        self.financial_keywords = [
            'revenue', 'profit', 'loss', 'growth', 'decline', 'increase', 'decrease',
            'market', 'stock', 'share', 'price', 'investment', 'investor', 'dividend',
            'earning', 'quarter', 'fiscal', 'report', 'forecast', 'outlook', 'guidance',
            'volatility', 'inflation', 'interest rate', 'debt', 'asset', 'liability',
            'acquisition', 'merger', 'bankruptcy', 'default', 'credit', 'loan'
        ]

        # Phrases that commonly signal a causal relation (same prefix matching).
        self.causal_connectors = [
            'because', 'due to', 'as a result of', 'therefore', 'consequently',
            'hence', 'thus', 'leads to', 'causes', 'results in', 'affects',
            'influences', 'impacts', 'drives', 'triggered by', 'stemming from'
        ]

    @staticmethod
    def _term_pattern(terms):
        """Compile a case-insensitive regex matching any of ``terms``.

        A term must not be preceded by a letter, but it may be followed by a
        suffix. So "thus" does not fire inside "enthusiasm", while "earning"
        still matches "earnings". Longer terms are tried first.

        Returns:
            A compiled pattern, or None when ``terms`` is empty. An empty
            alternation would match the empty string at every position.
        """
        if not terms:
            return None
        alternatives = sorted(terms, key=len, reverse=True)
        return re.compile(
            r'(?<![a-z])(?:' + '|'.join(map(re.escape, alternatives)) + ')',
            re.IGNORECASE,
        )

    @classmethod
    def _find_terms(cls, text, terms):
        """Return the entries of ``terms`` (in list order) that occur in ``text``."""
        return [term for term in terms if cls._term_pattern([term]).search(text)]

    @classmethod
    def _tag_terms(cls, text, terms, tag):
        """Wrap every occurrence of any of ``terms`` in ``<tag>...</tag>``.

        The text is scanned in a single pass, so tags inserted here are
        never re-scanned. The matched text keeps its original casing.
        """
        pattern = cls._term_pattern(terms)
        if pattern is None:
            return text
        return pattern.sub(lambda m: f"<{tag}>{m.group(0)}</{tag}>", text)

    def preprocess_text(self, text):
        """Lowercase ``text`` and expand common financial abbreviations.

        For example, "Q1" becomes "quarter one" and "YoY" becomes
        "year over year". Abbreviations embedded in longer words
        (e.g. "simplify", "momentum") are not touched.
        """
        text = text.lower()
        return self._ABBREVIATION_RE.sub(
            lambda m: self._ABBREVIATIONS[m.group(1)], text
        )

    def enhance_context(self, text):
        """Mark financial terms with <FIN> tags and causal connectors with <CAUSE> tags.

        Matching is case-insensitive. In a real implementation this could
        involve more sophisticated techniques.
        """
        enhanced_text = self._tag_terms(text, self.financial_keywords, "FIN")
        return self._tag_terms(enhanced_text, self.causal_connectors, "CAUSE")

    def analyze_text(self, text):
        """Analyze text for financial causal relations with enhanced context.

        The text is preprocessed once. The model prediction, the tagged text
        and the detected term lists are all derived from that preprocessed
        form, so they agree with each other.

        Returns:
            The object returned by ``predict_causal_relation`` (defined
            elsewhere in this notebook; it must support ``.update``),
            extended with "financial_terms", "causal_markers" and
            "enhanced_text".
        """
        preprocessed_text = self.preprocess_text(text)

        # Enhanced context (this step would integrate with the model in a
        # real implementation).
        enhanced_text = self.enhance_context(preprocessed_text)

        # Standard model prediction.
        prediction = predict_causal_relation(self.model, preprocessed_text, self.vocab, self.device)

        # Add domain-specific analysis.
        prediction.update({
            "financial_terms": self._find_terms(preprocessed_text, self.financial_keywords),
            "causal_markers": self._find_terms(preprocessed_text, self.causal_connectors),
            "enhanced_text": enhanced_text
        })

        return prediction


# ====================================
# Usage
# ====================================

if __name__ == "__main__":
    # Training data and output directory, both handed to main().
    csv_path = "updated_final_dataset.csv"
    save_dir = "financial_model"

    # Sample sentences passed to main(). They mix explicit connectors
    # ("because", "due to"), implicit ones ("led to", "prompted", "when") and
    # numeric/percentage expressions.
    test_sentences = [
        "The company's stock price fell sharply because of the disappointing earnings report.",
        "Due to the increase in interest rates, mortgage applications decreased by 10% last month.",
        "Revenue growth in the technology sector accelerated as a result of increased consumer spending on electronics.",
        "The market volatility was triggered by concerns about inflation.",
        "The corporate tax cut led to higher profit margins for many companies.",
        "Due to the increase in tax rates, the company's profit margin decreased by 5 percent.",
        "The stock price dropped 10 percent when the CEO resigned unexpectedly.",
        "As interest rates fell, mortgage applications increased by 20 percent.",
        "The merger with ABC Corp resulted in a 15 percent increase in market share.",
        "The company's revenue grew by 8 percent because they expanded into international markets.",
        "Because Oshrad is a privately held business, it was liable for a 26.3% company tax on the sale.",
        "The Israel Tax Authority claims that when a company controlled by the Brosh brothers sold a 14% stake in Oshrad Natural Gas for 8.6 million shekels in 2017, they should have paid 4.3 million shekels in tax.",
        "The collapse of Silicon Bank led to widespread investor panic and a temporary dip in the tech sector.",
        "Following the announcement of record-high inflation, consumer confidence fell to its lowest level in a decade.",
        "Strong quarterly earnings prompted a surge in the company's stock price.",
        "The central bank’s decision to cut interest rates boosted lending activity across the housing market.",
        "The introduction of new tariffs on Chinese goods caused import costs to rise significantly for U.S. retailers.",
        "A downgrade in the country’s credit rating triggered a sell-off in government bonds.",
        "Due to a cybersecurity breach, the fintech company lost $10 million in market value overnight.",
        "The announcement of a new product line caused investor optimism and lifted share prices by 12%.",
        "Because of reduced oil supply from OPEC countries, global fuel prices surged to a five-year high.",
        "Rising raw material costs led manufacturers to increase prices across multiple sectors."
    ]

    # Train and save the model (main() is defined elsewhere in this notebook).
    model, vocab = main(csv_path, save_dir, test_sentences)

    # Build the insights analyzer on the GPU when one is available.
    # FinancialCausalInsights is defined elsewhere in this notebook.
    analysis_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    insights_analyzer = FinancialCausalInsights(model, vocab, analysis_device)

    # Run the analyzer on a single multi-step causal sentence.
    financial_text = "The Federal Reserve's decision to raise interest rates caused significant declines in growth stocks, leading many investors to shift their portfolios to value stocks."
    insights = insights_analyzer.analyze_financial_impact(financial_text)

    # Report header.
    print("\nFinancial Insights:")
    print(f"Text: {insights['text']}")
    print(f"Overall Sentiment: {insights['overall_sentiment']}")

    # One line per cause -> effect link, each side annotated with (category, impact).
    print("\nCausal Chains:")
    for chain in insights['causal_chains']:
        cause_desc = f"{chain['cause']} ({chain['cause_category']}, {chain['cause_impact']})"
        effect_desc = f"{chain['effect']} ({chain['effect_category']}, {chain['effect_impact']})"
        print(f"  {cause_desc} → {effect_desc}")

    # One line per entity: name - category - impact.
    print("\nEntities:")
    for entity in insights['entities']:
        print("  " + " - ".join(f"{entity[field]}" for field in ("name", "category", "impact")))