"""
High-Performance DistilBERT Training for Stock Sentiment Analysis
Optimized for small datasets with extreme class imbalance and maximum hardware utilization
"""

# In [None]:
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, f1_score
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    DataCollatorWithPadding
)
from torch.utils.data import Dataset
from pathlib import Path
import logging
from typing import Dict, List, Tuple
from dataclasses import dataclass
import json
import random

# Configure root logging once for the whole script: INFO level, with a
# timestamp and the level name in front of every message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)

@dataclass
class FastTrainingConfig:
    """Hyper-parameters, paths and hardware switches for one training run.

    Every field has a default, so ``FastTrainingConfig()`` is directly usable.
    The values are consumed by ``FastStockSentimentTrainer``.
    """
    # --- Data parameters ---
    # CSV read with pandas; the pipeline maps its 'stock_sentiment' column to labels.
    labeled_file: str = '../Data/NLP/news_dataset_id2_labeled.csv'
    # Tokenizer truncation limit (in tokens); longer inputs are cut off.
    max_sequence_length: int = 128
    # Fraction of all samples held out from training (validation + test together).
    test_size: float = 0.25
    # NOTE: despite the name, this is passed as `test_size` of the second split,
    # i.e. it is the fraction of the held-out samples that becomes the TEST set;
    # the remainder of the held-out samples is the validation set.
    val_split: float = 0.3

    # --- Model parameters ---
    # Checkpoint name handed to `from_pretrained` for both tokenizer and model.
    model_name: str = 'distilbert-base-uncased'
    # Size of the classification head (Negative / Neutral / Positive).
    num_labels: int = 3

    # --- Aggressive training parameters for small datasets ---
    num_epochs: int = 8
    # Per-device training batch size; evaluation uses twice this value.
    batch_size: int = 32                  # Try 32; adjust downward if VRAM OOM occurs
    gradient_accumulation_steps: int = 1  # No accumulation if batch_size fits in VRAM
    learning_rate: float = 3e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    # Number of evaluations without improvement tolerated by EarlyStoppingCallback.
    early_stopping_patience: int = 5

    # --- Performance optimizations ---
    fp16: bool = True                    # Set to False if AMP stalls on your GPU
    dataloader_num_workers: int = 12      # Increase if you have ≥8 CPU cores
    dataloader_pin_memory: bool = True

    # --- Data augmentation ---
    # A minority class is grown to at most `augmentation_factor` times its size
    # (and never beyond half the size of the largest class).
    augmentation_factor: int = 3  # Augment minority classes

    # --- Output parameters ---
    # Trainer working directory (TrainingArguments.output_dir).
    output_dir: str = './Models2/fast_sentiment_model'
    # Final model, tokenizer and results.json are written here.
    model_save_path: str = './Models2/fast_sentiment_distilbert'

# In [None]:
class FocalLoss(torch.nn.Module):
    """Focal loss for classification under heavy class imbalance.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE`` where ``CE`` is
    the cross-entropy and ``p_t = exp(-CE)`` is the probability the model
    assigns to the true class. The mean over the batch is returned.

    Args:
        alpha: Either a scalar applied to every sample, or a sequence/1-D
            tensor holding one weight per class (looked up by target label).
        gamma: Focusing exponent; ``0`` reduces the loss to (weighted)
            cross-entropy.
        num_classes: Number of classes. Used to validate a per-class ``alpha``.

    Raises:
        ValueError: If a per-class ``alpha`` does not have ``num_classes`` entries.
    """
    def __init__(self, alpha=1, gamma=2, num_classes=3):
        super().__init__()
        self.gamma = gamma
        self.num_classes = num_classes

        # A list/tuple or a tensor with at least one dimension is treated as
        # per-class weights; anything else keeps the scalar behavior.
        is_per_class = isinstance(alpha, (list, tuple)) or (
            isinstance(alpha, torch.Tensor) and alpha.dim() > 0
        )
        if is_per_class:
            alpha = torch.as_tensor(alpha, dtype=torch.float32).flatten()
            if alpha.numel() != num_classes:
                raise ValueError(
                    f"alpha has {alpha.numel()} entries but num_classes={num_classes}"
                )
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and class ids ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        alpha = self.alpha
        if isinstance(alpha, torch.Tensor) and alpha.dim() > 0:
            # This module is not necessarily moved together with the model, so
            # align device/dtype here before gathering one weight per sample.
            # Clamp keeps negative "ignore" labels (whose CE is 0) indexable.
            alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)[targets.clamp(min=0)]

        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()


class DataAugmenter:
    """Cheap, dependency-free text perturbations for oversampling."""

    @staticmethod
    def synonym_replacement(text: str, n: int = 2) -> str:
        """Perturb ``text`` by swapping up to ``n`` randomly chosen adjacent word pairs.

        Despite the name this is a placeholder: no synonyms are looked up.
        Texts with fewer than three words are returned untouched. At most
        ``len(words) // 3`` positions are drawn, and a drawn position that
        is the last word is skipped because it has no right-hand neighbour.
        """
        tokens = text.split()
        if len(tokens) < 3:
            return text

        swap_budget = min(n, len(tokens) // 3)
        final_pos = len(tokens) - 1
        for pos in random.sample(range(len(tokens)), swap_budget):
            if pos >= final_pos:
                continue
            tokens[pos], tokens[pos + 1] = tokens[pos + 1], tokens[pos]
        return ' '.join(tokens)

    @staticmethod
    def random_insertion(text: str) -> str:
        """Insert one random finance-related word at a random position.

        Texts of three words or fewer get no insertion; they are only
        whitespace-normalised by the split/join round trip.
        """
        vocabulary = ['stock', 'market', 'price', 'trading', 'investor', 'earnings']
        tokens = text.split()
        if len(tokens) <= 3:
            return ' '.join(tokens)

        # Position may equal len(tokens), which appends at the end.
        position = random.randint(0, len(tokens))
        tokens.insert(position, random.choice(vocabulary))
        return ' '.join(tokens)

    @staticmethod
    def augment_text(text: str, method: str = 'synonym') -> str:
        """Dispatch to the perturbation named by ``method``; unknown names return ``text`` as is."""
        if method == 'insertion':
            return DataAugmenter.random_insertion(text)
        if method == 'synonym':
            return DataAugmenter.synonym_replacement(text)
        return text


class FastDataset(Dataset):
    """
    Map-style dataset over tensors that were tokenized ahead of time.

    Row ``i`` of ``input_ids`` / ``attention_mask`` and entry ``i`` of
    ``labels`` together form sample ``i``.
    """
    _FIELDS = ('input_ids', 'attention_mask', 'labels')

    def __init__(
        self,
        input_ids: torch.Tensor,      # shape: (N, seq_len)
        attention_mask: torch.Tensor, # shape: (N, seq_len)
        labels: List[int]
    ):
        # Labels arrive as plain Python ints; store them as one int64 tensor.
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.attention_mask = attention_mask
        self.input_ids = input_ids

    def __len__(self) -> int:
        """Number of samples, taken from the first dimension of ``input_ids``."""
        return self.input_ids.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return sample ``idx`` as a dict keyed the way the model expects."""
        return {field: getattr(self, field)[idx] for field in self._FIELDS}

# In [None]:
class FocalLossTrainer(Trainer):
    """Trainer subclass that optimizes focal loss instead of the model's built-in loss."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Trainer`` and create the focal-loss criterion."""
        super().__init__(*args, **kwargs)
        self.focal_loss = FocalLoss(alpha=1, gamma=2, num_classes=3)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute focal loss on the model's logits.

        Args:
            model: The model being trained/evaluated.
            inputs: Batch dict; must contain ``"labels"``.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted because recent ``Trainer`` versions
                pass it as a keyword argument; unused, since the focal loss
                is already averaged over the batch.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs["labels"]
        # Build a label-free copy rather than popping from the caller's dict:
        # the model must not compute its own cross-entropy, and the batch the
        # Trainer handed in stays intact.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")
        loss = self.focal_loss(logits, labels)
        return (loss, outputs) if return_outputs else loss


class FastStockSentimentTrainer:
    """Encapsulates data prep, model setup, and training for maximum performance.

    Expected call order::

        df = trainer.load_and_preprocess_data()
        trainer.setup_model_and_tokenizer()
        trainer.prepare_datasets(df)
        trainer.train()

    Minority-class augmentation is applied inside ``prepare_datasets`` to the
    training split only, so validation and test metrics are measured on
    untouched samples.
    """

    def __init__(self, config: "FastTrainingConfig"):
        """Store ``config``, create the output directories and prime CUDA."""
        self.config = config
        self.label_mapping = {'Negative': 0, 'Neutral': 1, 'Positive': 2}
        self.reverse_label_mapping = {v: k for k, v in self.label_mapping.items()}

        # Create output directories if they don't exist
        Path(config.output_dir).mkdir(parents=True, exist_ok=True)
        Path(config.model_save_path).mkdir(parents=True, exist_ok=True)

        # Enable cuDNN benchmark for fixed-size inputs
        torch.backends.cudnn.benchmark = True
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_and_preprocess_data(self) -> pd.DataFrame:
        """Load the CSV, build ``training_text``, drop very short samples and label.

        Returns:
            DataFrame with at least ``training_text`` (str) and ``label`` (int)
            columns. No augmentation happens here; see ``prepare_datasets``.

        Raises:
            FileNotFoundError: If ``config.labeled_file`` does not exist.
        """
        logger.info("Loading and preprocessing data...")
        try:
            df = pd.read_csv(self.config.labeled_file)
        except FileNotFoundError as err:
            # Keep the original exception as the cause for easier debugging.
            raise FileNotFoundError(f"File not found: {self.config.labeled_file}") from err
        logger.info(f"Dataset loaded: {df.shape}")

        # Combine fields into 'training_text'
        df['training_text'] = df.apply(self._prepare_text_fast, axis=1)
        df = df[df['training_text'].str.len() > 10].reset_index(drop=True)

        # Map sentiment strings to ints; rows with unknown sentiment are dropped
        df['label'] = df['stock_sentiment'].map(self.label_mapping)
        df = df.dropna(subset=['label']).reset_index(drop=True)
        df['label'] = df['label'].astype(int)

        self._log_distribution(df)
        return df

    def _prepare_text_fast(self, row: pd.Series) -> str:
        """Join the non-null title/description/content of ``row`` into one string.

        ``content`` is limited to its first 150 whitespace-separated words.
        """
        parts = []
        for field in ['title', 'description', 'content']:
            if field in row and pd.notna(row[field]):
                text = str(row[field]).strip()
                if field == 'content':
                    text = ' '.join(text.split()[:150])  # truncate content
                if text:
                    parts.append(text)
        return ' '.join(parts)

    def _log_distribution(self, df: pd.DataFrame):
        """Log per-class counts and warn when the largest/smallest ratio exceeds 5."""
        dist = df['stock_sentiment'].value_counts()
        logger.info("Class distribution:")
        for label, count in dist.items():
            logger.info(f"  {label}: {count} ({count/len(df)*100:.1f}%)")

        # Check imbalance ratio
        class_counts = df['label'].value_counts()
        if class_counts.empty:
            return
        ratio = class_counts.max() / class_counts.min()
        if ratio > 5:
            logger.warning(f"High class imbalance (ratio: {ratio:.1f}) - using augmentation")

    def _augment_minority_classes(self, df: pd.DataFrame) -> pd.DataFrame:
        """Oversample the Negative and Positive classes with perturbed copies.

        Each of the two classes is grown to
        ``min(largest_class // 2, current * augmentation_factor)`` samples.
        New rows are copies of randomly drawn rows (with replacement) whose
        ``training_text`` went through ``DataAugmenter.augment_text``.

        Must only be called on training data: the copies are near-duplicates
        of their source rows.
        """
        logger.info("Applying data augmentation...")
        class_counts = df['label'].value_counts()
        if class_counts.empty:
            return df
        max_count = class_counts.max()
        augmented_frames = []

        for label in (self.label_mapping['Negative'], self.label_mapping['Positive']):
            class_df = df[df['label'] == label]
            current_count = len(class_df)
            target_count = min(max_count // 2, current_count * self.config.augmentation_factor)
            if target_count <= current_count:
                continue

            augment_needed = int(target_count - current_count)
            logger.info(f"Augmenting class {self.reverse_label_mapping[label]}: {current_count} -> {target_count}")

            # Draw all source rows at once instead of one `sample(1)` per row.
            extra = class_df.sample(n=augment_needed, replace=True).copy()
            extra['training_text'] = [
                DataAugmenter.augment_text(text, 'synonym') for text in extra['training_text']
            ]
            augmented_frames.append(extra)

        if augmented_frames:
            df = pd.concat([df, *augmented_frames], ignore_index=True)
            logger.info(f"Dataset after augmentation: {len(df)} samples")
            self._log_distribution(df)

        return df

    def setup_model_and_tokenizer(self):
        """Load DistilBERT + tokenizer, move model to GPU, enable gradient checkpointing."""
        logger.info("Setting up model and tokenizer...")
        self.tokenizer = DistilBertTokenizer.from_pretrained(self.config.model_name)
        self.model = DistilBertForSequenceClassification.from_pretrained(
            self.config.model_name,
            num_labels=self.config.num_labels,
            problem_type="single_label_classification"
        )
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            logger.info(f"Model on GPU: {torch.cuda.get_device_name()}")
            logger.info(f"GPU Memory (GB): {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}")
        self.model.gradient_checkpointing_enable()

    def tokenize_splits(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Batch-tokenize a list of strings. Returns (input_ids, attention_mask) tensors.

        All texts are encoded in one call, so every row is padded to the
        longest (truncated) sequence of this list, not per training batch.
        """
        encoding = self.tokenizer(
            texts,
            truncation=True,
            padding=True,  # pad to the longest sequence in `texts`
            max_length=self.config.max_sequence_length,
            return_tensors='pt',
            return_attention_mask=True
        )
        return encoding['input_ids'], encoding['attention_mask']

    def prepare_datasets(self, df: pd.DataFrame):
        """
        1) Split df into train/val/test (stratified),
        2) Augment minority classes in the TRAIN split only,
        3) Batch-tokenize each split,
        4) Create FastDataset on the resulting tensors.

        Sets ``self.train_dataset``, ``self.val_dataset`` and ``self.test_dataset``.
        """
        texts = df['training_text'].tolist()
        labels = df['label'].tolist()

        # 1. First split: train vs. temp (val+test)
        train_texts, temp_texts, train_labels, temp_labels = train_test_split(
            texts,
            labels,
            test_size=self.config.test_size,
            random_state=42,
            stratify=labels
        )

        # 2. Second split: val vs. test.
        # `val_split` is used as `test_size` here, i.e. it is the share of the
        # held-out samples that goes to the test set.
        val_texts, test_texts, val_labels, test_labels = train_test_split(
            temp_texts,
            temp_labels,
            test_size=self.config.val_split,
            random_state=42,
            stratify=temp_labels
        )

        # 3. Augment after splitting: augmented rows are near-copies of their
        # source rows, so augmenting earlier would leak training samples into
        # the validation and test sets and inflate the reported metrics.
        train_df = pd.DataFrame({'training_text': train_texts, 'label': train_labels})
        train_df['stock_sentiment'] = train_df['label'].map(self.reverse_label_mapping)
        train_df = self._augment_minority_classes(train_df)
        train_texts = train_df['training_text'].tolist()
        train_labels = train_df['label'].tolist()

        logger.info(f"📊 Final split → Train: {len(train_texts)}, Val: {len(val_texts)}, Test: {len(test_texts)}")

        # 4. Batch-tokenize each split
        logger.info("Batch-tokenizing train split…")
        train_input_ids, train_attention_mask = self.tokenize_splits(train_texts)
        logger.info("Batch-tokenizing val split…")
        val_input_ids, val_attention_mask = self.tokenize_splits(val_texts)
        logger.info("Batch-tokenizing test split…")
        test_input_ids, test_attention_mask = self.tokenize_splits(test_texts)

        # 5. Create FastDataset instances
        self.train_dataset = FastDataset(train_input_ids, train_attention_mask, train_labels)
        self.val_dataset = FastDataset(val_input_ids, val_attention_mask, val_labels)
        self.test_dataset = FastDataset(test_input_ids, test_attention_mask, test_labels)

        logger.info(
            f"Datasets built → Train: {len(self.train_dataset)}, "
            f"Val: {len(self.val_dataset)}, Test: {len(self.test_dataset)}"
        )

    def setup_training_args(self) -> "TrainingArguments":
        """Optimized TrainingArguments for performance on GTX 1650."""
        return TrainingArguments(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size * 2,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            warmup_ratio=self.config.warmup_ratio,
            weight_decay=self.config.weight_decay,

            # Performance optimizations
            fp16=self.config.fp16,
            dataloader_num_workers=self.config.dataloader_num_workers,
            dataloader_pin_memory=self.config.dataloader_pin_memory,
            group_by_length=True,

            # Evaluation & saving once per epoch (must match for best-model loading)
            eval_strategy="epoch",
            save_strategy="epoch",
            logging_steps=10,

            # Required by EarlyStoppingCallback; also makes the saved model
            # the best checkpoint rather than whatever the last epoch produced.
            load_best_model_at_end=True,
            metric_for_best_model="eval_f1",
            greater_is_better=True,

            save_total_limit=2,
            # The string "none" disables reporting integrations; `None` does not.
            report_to="none",
            remove_unused_columns=False,

            # Disable torch.compile on Turing GPU
            torch_compile=False,
            optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
        )

    def compute_metrics(self, eval_pred) -> Dict[str, float]:
        """Compute accuracy, macro-F1 (reported as ``f1``) and weighted-F1."""
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            # Some models return extra outputs; logits come first.
            predictions = predictions[0]
        predictions = np.argmax(predictions, axis=1)
        accuracy = accuracy_score(labels, predictions)
        f1_macro = f1_score(labels, predictions, average='macro')
        f1_weighted = f1_score(labels, predictions, average='weighted')
        return {
            'accuracy': accuracy,
            'f1': f1_macro,
            'f1_weighted': f1_weighted
        }

    def train(self):
        """Train with focal loss and early stopping, save the model, then evaluate on test.

        Requires ``setup_model_and_tokenizer`` and ``prepare_datasets`` to have
        been called. Any exception is logged and re-raised.
        """
        import time  # local import: `time` is not imported at module level

        logger.info("🚀 Starting high-performance training...")
        training_args = self.setup_training_args()
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer, return_tensors="pt")

        trainer = FocalLossTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
            data_collator=data_collator,
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=self.config.early_stopping_patience)]
        )

        start_time = time.time()

        try:
            trainer.train()
            training_time = (time.time() - start_time) / 60
            logger.info(f"✅ Training completed in {training_time:.1f} minutes!")

            # Save model & tokenizer
            trainer.save_model(self.config.model_save_path)
            self.tokenizer.save_pretrained(self.config.model_save_path)
            logger.info(f"Model saved to: {self.config.model_save_path}")

            # Run final evaluation on test set
            self.evaluate_model(trainer)

        except Exception as e:
            # logger.exception keeps the traceback in the log before re-raising.
            logger.exception(f"Training failed: {e}")
            raise

    def evaluate_model(self, trainer):
        """Evaluate on the test set, log key metrics and write ``results.json``.

        Returns:
            Dict with ``test_accuracy``, ``f1_macro`` and the full
            ``classification_report`` (as a nested dict).
        """
        logger.info("📊 Evaluating model...")
        predictions = trainer.predict(self.test_dataset)
        predicted_labels = np.argmax(predictions.predictions, axis=1)
        # Use the labels collected during prediction instead of re-reading
        # the dataset item by item.
        true_labels = np.asarray(predictions.label_ids)

        accuracy = float(accuracy_score(true_labels, predicted_labels))
        f1_macro = float(f1_score(true_labels, predicted_labels, average='macro'))

        logger.info(f"🎯 Test Accuracy: {accuracy:.4f}")
        logger.info(f"🎯 F1-Score (Macro): {f1_macro:.4f}")

        label_names = list(self.label_mapping.keys())
        # Passing `labels` keeps the report aligned with `target_names` even
        # when a class is absent from a small test split.
        report = classification_report(
            true_labels,
            predicted_labels,
            labels=list(self.label_mapping.values()),
            target_names=label_names,
            output_dict=True,
            zero_division=0,
        )

        logger.info("📈 Per-class F1-scores:")
        for label in label_names:
            f1 = report[label]['f1-score']
            logger.info(f"  {label}: {f1:.3f}")

        results = {
            'test_accuracy': accuracy,
            'f1_macro': f1_macro,
            'classification_report': report
        }

        results_path = Path(self.config.model_save_path) / "results.json"
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)

        return results

# In [None]:
# Seed every random number generator used by this script (Python `random`,
# NumPy, PyTorch CPU and all CUDA devices) so runs are repeatable.
# This must run before any data is sampled or split.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# 1. Instantiate configuration (all defaults) and the training pipeline.
#    Constructing the pipeline also creates the output directories.
config = FastTrainingConfig()
trainer = FastStockSentimentTrainer(config)

# 2. Load + preprocess the CSV into a DataFrame with text and integer labels
df = trainer.load_and_preprocess_data()

# 3. Setup model + tokenizer (needed before tokenizing)
trainer.setup_model_and_tokenizer()

# 4. Split & batch-tokenize + build PyTorch datasets
trainer.prepare_datasets(df)

# 5. Train & evaluate (also saves the model, tokenizer and results.json)
trainer.train()