In [None]:
import torch
import torch.nn as nn
import math

class BalancedBERT(nn.Module):
    """
    Compact BERT-style encoder with attention pooling for sequence classification.

    Optimized for balanced dataset of ~12K samples
    Larger capacity than previous models but carefully regularized

    Pipeline: token + position + token-type embeddings -> LayerNorm -> dropout
    -> ``num_layers`` BalancedTransformerLayer blocks -> MultiHeadAttentionPooling
    -> 3-layer MLP classifier producing ``num_classes`` logits.

    Args:
        vocab_size: Number of rows in the token embedding table. Token id 0 is
            the padding id (``padding_idx=0``).
        num_classes: Size of the output logit vector.
        max_len: Number of learned position embeddings, i.e. the longest
            sequence ``forward`` accepts.
        hidden_size: Width of embeddings and encoder layers.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        intermediate_size: Feed-forward width passed to each encoder layer.
        dropout: Dropout probability for embeddings, encoder layers and the
            first classifier layer (the second classifier layer uses 0.8x).
    """
    def __init__(self, vocab_size, num_classes=5, max_len=128,
                 hidden_size=256, num_layers=4, num_heads=8,
                 intermediate_size=512, dropout=0.2):
        super().__init__()

        # Embeddings (slightly larger for balanced data)
        self.embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_len, hidden_size)
        self.token_type_embeddings = nn.Embedding(2, hidden_size)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

        # Stack of encoder layers
        self.encoder_layers = nn.ModuleList([
            BalancedTransformerLayer(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])

        # Multi-head attention pooling over all tokens (instead of a CLS token)
        self.attention_pool = MultiHeadAttentionPooling(hidden_size, num_heads=4)

        # Classifier head: plain 3-layer MLP (no residual connections)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout * 0.8),  # Slightly less dropout in later layers
            nn.Linear(hidden_size // 2, num_classes)
        )

        self._init_weights()
        print(f"BalancedBERT parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")

    def _init_weights(self):
        """Initialise every nn.Linear (Xavier, gain 0.7) and nn.Embedding (N(0, 0.02))."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.7)  # Lower gain
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                # normal_ overwrites the all-zero padding row that nn.Embedding
                # set up at construction; restore it so padding tokens embed
                # to the zero vector again.
                if module.padding_idx is not None:
                    with torch.no_grad():
                        module.weight[module.padding_idx].zero_()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        Compute class logits for a batch of token id sequences.

        Args:
            input_ids: Integer tensor indexed as (batch, seq); id 0 is padding.
            token_type_ids: Optional segment ids (0 or 1), same shape as
                ``input_ids``. Defaults to all zeros.
            attention_mask: Optional (batch, seq) mask, non-zero for real
                tokens and zero for padding. Defaults to ``input_ids != 0``.

        Returns:
            Logits tensor of shape (batch, num_classes).

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup further down.
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_len={max_positions}"
            )
        position_ids = torch.arange(seq_length, dtype=torch.long,
                                    device=input_ids.device).unsqueeze(0)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Embeddings
        words_embeddings = self.embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        # Prepare attention mask
        if attention_mask is None:
            attention_mask = (input_ids != 0)

        # Boolean keep-mask of shape (batch, 1, 1, seq): True = real token.
        # The encoder layers treat zero entries of the mask as padding, so the
        # mask must be zero at padded positions. An additive mask (0 for real
        # tokens, -10000 for padding) has the opposite polarity and would hide
        # the real tokens.
        extended_attention_mask = (attention_mask != 0).unsqueeze(1).unsqueeze(2)

        # Encoder layers
        hidden_states = embeddings
        for layer in self.encoder_layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        # Attention pooling over all tokens
        pooled_output = self.attention_pool(hidden_states, attention_mask)

        # Classification
        logits = self.classifier(pooled_output)

        return logits

class BalancedTransformerLayer(nn.Module):
    """
    Pre-norm transformer encoder layer with a GLU feed-forward block.

    Each sub-block is ``x + gamma * dropout(sublayer(LayerNorm(x)))`` where
    ``gamma`` is a learnable per-channel scale initialised to ones.

    Args:
        hidden_size: Model width (input and output feature size).
        num_heads: Number of self-attention heads.
        intermediate_size: Width of the feed-forward block after the GLU gate
            (the first linear layer projects to twice this size).
        dropout: Dropout probability used inside attention, inside the
            feed-forward block and on both residual branches.
    """
    def __init__(self, hidden_size, num_heads, intermediate_size, dropout=0.2):
        super().__init__()

        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )

        # Feed-forward network with gated linear unit: the first projection is
        # twice as wide because GLU halves the last dimension.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size * 2),
            nn.GLU(dim=-1),
            nn.Dropout(dropout),
            nn.Linear(intermediate_size, hidden_size)
        )

        # Layer normalization (applied before each sub-block)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        # Dropout on the residual branches
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Learnable scaling factors for the residual branches
        self.gamma1 = nn.Parameter(torch.ones(hidden_size))
        self.gamma2 = nn.Parameter(torch.ones(hidden_size))

    @staticmethod
    def _key_padding_mask(attention_mask, batch_size):
        """Convert ``attention_mask`` into a (batch, seq) bool mask, True = padding."""
        if attention_mask is None:
            return None
        # Accept both (batch, seq) and the broadcast form (batch, 1, 1, seq).
        flat_mask = attention_mask.reshape(batch_size, -1)
        if flat_mask.is_floating_point():
            # Additive mask: 0 for real tokens, large negative for padding.
            # Comparing with `== 0` here would mark the *real* tokens as
            # padding, and an unpadded row would mask every key -> NaN.
            return flat_mask < 0
        # Boolean / integer keep-mask: non-zero for real tokens, 0 for padding.
        return flat_mask == 0

    def forward(self, x, attention_mask):
        """
        Apply self-attention and feed-forward sub-blocks.

        Args:
            x: Input tensor of shape (batch, seq, hidden_size).
            attention_mask: Padding information, shaped (batch, seq) or
                (batch, 1, 1, seq). Floating-point masks are interpreted as
                additive (negative = padding); bool/integer masks as
                keep-masks (0 = padding). ``None`` disables masking.

        Returns:
            Tensor with the same shape as ``x``.
        """
        key_padding_mask = self._key_padding_mask(attention_mask, x.size(0))

        # Pre-norm self-attention
        x_norm = self.norm1(x)
        attn_output, _ = self.self_attn(
            x_norm, x_norm, x_norm,
            key_padding_mask=key_padding_mask
        )
        x = x + self.dropout1(attn_output) * self.gamma1

        # Pre-norm feed-forward
        x_norm = self.norm2(x)
        ffn_output = self.ffn(x_norm)
        x = x + self.dropout2(ffn_output) * self.gamma2

        return x

class MultiHeadAttentionPooling(nn.Module):
    """
    Pool a token sequence into one vector with multi-head attention.

    A single learnable query attends over all token states; the attended
    context is passed through an output projection.

    Args:
        hidden_size: Feature size of the incoming token states and of the
            pooled output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """
    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Otherwise the per-head reshape in forward() cannot succeed.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Learnable query
        self.query = nn.Parameter(torch.randn(1, 1, hidden_size))

        # Linear projections
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, tensor, batch_size):
        """Reshape (batch, len, hidden) -> (batch, heads, len, head_dim)."""
        return tensor.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask):
        """
        Pool ``hidden_states`` into one vector per batch element.

        Args:
            hidden_states: Tensor of shape (batch, seq, hidden_size).
            attention_mask: Keep-mask shaped (batch, seq) or (batch, 1, 1, seq),
                non-zero for real tokens and zero for padding; ``None``
                attends over every position.

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        batch_size = hidden_states.size(0)

        # Expand learnable query
        query = self.query.expand(batch_size, -1, -1)

        # Project and reshape for multi-head
        Q = self._split_heads(self.q_linear(query), batch_size)
        K = self._split_heads(self.k_linear(hidden_states), batch_size)
        V = self._split_heads(self.v_linear(hidden_states), batch_size)

        # Scaled dot-product attention: scores are (batch, heads, 1, seq)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply attention mask
        if attention_mask is not None:
            # reshape (rather than chained squeeze) keeps the sequence axis
            # intact when seq == 1 and accepts both supported mask layouts.
            padding = attention_mask.reshape(batch_size, 1, 1, -1) == 0
            # The most negative finite value gives padded positions zero
            # weight after softmax, and keeps a fully padded row finite
            # instead of NaN.
            scores = scores.masked_fill(padding, torch.finfo(scores.dtype).min)

        attn_weights = torch.softmax(scores, dim=-1)

        # Apply attention
        context = torch.matmul(attn_weights, V)

        # Reshape back to (batch, 1, hidden_size)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_heads * self.head_dim
        )

        # Output projection, then drop the length-1 query axis
        pooled = self.out_proj(context).squeeze(1)

        return pooled

In [None]:
# Demo of the BERT-insert augmentation helper on two sampled texts.
# Guarded by the run_examples flag because running it downloads models;
# flip the flag to False to skip the demo.
run_examples = True

if not run_examples:
    print("Set run_examples=True to run augmentation demos (will download models).")
else:
    # Fixed random_state so the same two rows are picked on every run.
    sample_texts = df['text'].dropna().sample(2, random_state=0).tolist()

    print("BERT insert augmentation:")
    bert_aug = helper.augment_with_bert_insert(
        sample_texts, model_path='bert-base-uncased', n=1, aug_p=0.2, action='insert'
    )
    # Show each source text next to its augmented counterpart.
    for original, augmented in zip(sample_texts, bert_aug):
        print("- Original:", original)
        print("  Augmented:", augmented)

In [None]:
# # Check if dataset is large enough for training embeddings from scratch
# # Requirement: >50k samples for training embeddings from scratch

# THRESHOLD = 50000
# dataset_size = len(df)

# print("=" * 60)
# print("DATASET SIZE ASSESSMENT FOR EMBEDDING TRAINING")
# print("=" * 60)
# print(f"\nCurrent dataset size: {dataset_size:,} samples")
# print(f"Required threshold: {THRESHOLD:,} samples")
# print(f"\nDifference: {dataset_size - THRESHOLD:,} samples")
# print(f"Percentage of threshold: {(dataset_size / THRESHOLD) * 100:.2f}%")

# print("\n" + "=" * 60)
# if dataset_size >= THRESHOLD:
#     print("✅ SUFFICIENT: Dataset meets the requirement for training embeddings from scratch")
# else:
#     print("❌ INSUFFICIENT: Dataset is below the recommended threshold")
#     print(f"   You need {THRESHOLD - dataset_size:,} more samples to meet the requirement")
#     print("\n   Recommendations:")
#     print("   - Consider data augmentation techniques")
#     print("   - Look for additional data sources")
#     print("   - Use smaller embedding dimensions if training anyway")
#     print("   - Consider using pre-trained embeddings (if allowed by project constraints)")
# print("=" * 60)


In [None]:
# print("Dataset Size Information:")
# print(f"Shape (rows, columns): {df.shape}")
# print(f"Number of rows: {df.shape[0]:,}")
# print(f"Number of columns: {df.shape[1]}")
# print(f"\nColumn names: {list(df.columns)}")
# print(f"\nMemory usage:")
# print(df.memory_usage(deep=True))
# print(f"\nTotal memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

In [None]:
# # Calculate dictionary/vocabulary size (number of unique words)
# # This is crucial for embedding training: vocab_size × embedding_dim = embedding matrix size

# from collections import Counter
# import re

# # Use cleaned text if available, otherwise use original text
# text_column = 'text_no_urls' if 'text_no_urls' in df.columns else 'text'

# # Combine all text and convert to lowercase
# all_text = ' '.join(df[text_column].astype(str))

# # Tokenize: split by whitespace and remove punctuation
# words = re.findall(r'\b\w+\b', all_text.lower())

# # Count unique words (dictionary/vocabulary size)
# unique_words = set(words)
# vocab_size = len(unique_words)
# word_counts = Counter(words)

# print("=" * 60)
# print("DICTIONARY/VOCABULARY SIZE ANALYSIS")
# print("=" * 60)
# print(f"\n📚 Dictionary Size (Vocabulary Size): {vocab_size:,} unique words")
# print(f"📝 Total word tokens: {len(words):,}")
# print(f"📊 Average words per review: {len(words) / len(df):.2f}")
# print(f"📈 Vocabulary coverage: {(vocab_size / len(words)) * 100:.4f}% (unique/total)")

# print(f"\n🔝 Most frequent words (top 20):")
# for i, (word, count) in enumerate(word_counts.most_common(20), 1):
#     percentage = (count / len(words)) * 100
#     print(f"  {i:2d}. {word:15s} : {count:8,} occurrences ({percentage:5.2f}%)")

# print(f"\n💡 Embedding Matrix Size Estimation:")
# print(f"   For embedding_dim = 100: {vocab_size:,} × 100 = {vocab_size * 100:,} parameters")
# print(f"   For embedding_dim = 200: {vocab_size:,} × 200 = {vocab_size * 200:,} parameters")
# print(f"   For embedding_dim = 300: {vocab_size:,} × 300 = {vocab_size * 300:,} parameters")

# print("=" * 60)

In [3]:
import helper

# Single-sentence check of the BERT insert augmenter: request three variants
# of one pangram and print them next to the original.
text = "The quick brown fox jumps over the lazy dog."

# Use BERT word-level insert (fast and avoids XLNet tokenization bug)
augmented_texts = helper.augment_with_bert_insert(
    [text], model_path="bert-base-uncased", n=3, aug_p=0.2, action="insert"
)

# One print call with a newline separator emits the same three lines as
# three separate prints would.
print("Original:", text, "Augmented Texts:", sep="\n")
for i, aug in enumerate(augmented_texts, start=1):
    print(f"{i}: {aug}")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

The following layers were not sharded: bert.encoder.layer.*.attention.self.value.weight, bert.encoder.layer.*.attention.output.LayerNorm.bias, cls.predictions.transform.dense.weight, bert.encoder.layer.*.intermediate.dense.weight, bert.encoder.layer.*.output.dense.bias, cls.predictions.transform.LayerNorm.bias, cls.predictions.transform.LayerNorm.weight, bert.encoder.layer.*.attention.self.key.bias, bert.embeddings.position_embeddings.weight, bert.encoder.layer.*.attention.self.value.bias, bert.encoder.layer.*.attention.self.key.weight, bert.embeddings.LayerNorm.bias, bert.encoder.layer.*.attention.output.LayerNorm.weight, cls.predictions.transform.dense.bias, bert.encoder.layer.*.attention.self.query.weight, cls.predictions.decoder.bias, bert.embeddings.LayerNorm.weight, bert.embeddings.token_type_embeddings.weight, bert.encoder.layer.*.attention.self.query.bias, cls.predictions.bias, bert.encoder.layer.*.attention.output.dense.bias, bert.encoder.layer.*.output.dense.weight, bert.enco

Original:
The quick brown fox jumps over the lazy dog.
Augmented Texts:
1: the big quick talking brown fox jumps over the lazy dog.
2: the quick and brown white fox jumps over the lazy dog.
3: lucky the cute quick brown fox jumps over the lazy dog.
