**Continuous Pretraining / Domain Adaptation** - This adapts a general-purpose BERT model to understand your specific domain (medical, legal, finance, etc.) better.

### The Process:

1. **Starts with pretrained BERT** - Uses `bert-base-uncased` which already understands general English

2. **Masked Language Modeling (MLM)** - Randomly masks 15% of words in your domain text:
   - "Neural networks use **[MASK]** to update weights" 
   - Model learns to predict the masked word ("backpropagation")

3. **Learns Domain Vocabulary & Context** - By training on your domain-specific texts, BERT learns:
   - Domain-specific terminology
   - How words are used in your context
   - Relationships between domain concepts

### What You Get as Output:

**During Training:**
```
Epoch 1/5: 100%|██████| 3/3 [00:05<00:00]
loss: 3.2451, avg_loss: 3.1892

Epoch 2/5: 100%|██████| 3/3 [00:05<00:00]
loss: 2.8934, avg_loss: 2.7231
...
```

**After Training:**
1. **A fine-tuned BERT model** saved to `./domain_adapted_bert/` folder
   - This model now "speaks" your domain language better
   - Better understanding of domain-specific context

2. **Test Predictions** showing the model understands your domain:
```
Predictions for masked tokens:

Position 4:
  backpropagation: 8.2341
  gradients: 6.5432
  algorithms: 5.8901
  optimization: 5.2341
```

### What Can You Use It For?

After domain adaptation, use this model for:
- **Text Classification** (sentiment, category, intent)
- **Named Entity Recognition** (extract domain entities)
- **Question Answering** (domain-specific Q&A)
- **Semantic Search** (better embeddings for your domain)
- **Any downstream NLP task** in your domain

**The key benefit**: Instead of training BERT from scratch on your domain data (very expensive), you're teaching the pretrained BERT to "specialize" in your domain - much faster and cheaper!

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizer, 
    BertForMaskedLM, 
    BertConfig,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from typing import List, Dict, Tuple
import numpy as np
from tqdm import tqdm
import random

class DomainTextDataset(Dataset):
    """Dataset for domain-specific text with MLM objective"""
    
    def __init__(self, texts: List[str], tokenizer: BertTokenizer, 
                 max_length: int = 512, mlm_prob: float = 0.15):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_prob = mlm_prob
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        
        # Create labels and apply masking
        labels = input_ids.clone()
        masked_input_ids = input_ids.clone()
        
        # Create probability matrix for masking
        prob_matrix = torch.full(labels.shape, self.mlm_prob)
        
        # Don't mask special tokens
        special_tokens_mask = torch.tensor(
            self.tokenizer.get_special_tokens_mask(
                input_ids.tolist(), 
                already_has_special_tokens=True
            )
        )
        prob_matrix.masked_fill_(special_tokens_mask.bool(), value=0.0)
        
        # Get masked indices
        masked_indices = torch.bernoulli(prob_matrix).bool()
        
        # Set labels to -100 for non-masked tokens (ignored in loss)
        labels[~masked_indices] = -100
        
        # Apply masking strategy (80% [MASK], 10% random, 10% original)
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        masked_input_ids[indices_replaced] = self.tokenizer.mask_token_id
        
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        masked_input_ids[indices_random] = random_words[indices_random]
        
        return {
            'input_ids': masked_input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class BERTDomainAdapter:
    """Continuous pretraining wrapper for BERT domain adaptation"""
    
    def __init__(self, model_name: str = 'bert-base-uncased', device: str = None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name).to(self.device)
        
    def prepare_domain_data(self, domain_texts: List[str], 
                           batch_size: int = 16, 
                           max_length: int = 512,
                           mlm_prob: float = 0.15) -> DataLoader:
        """Prepare domain-specific dataset"""
        dataset = DomainTextDataset(
            domain_texts, 
            self.tokenizer, 
            max_length=max_length,
            mlm_prob=mlm_prob
        )
        
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )
        
        return dataloader
    
    def train(self, 
              dataloader: DataLoader,
              epochs: int = 3,
              learning_rate: float = 5e-5,
              warmup_steps: int = 500,
              weight_decay: float = 0.01,
              gradient_accumulation_steps: int = 1,
              max_grad_norm: float = 1.0):
        """Continuous pretraining with MLM objective"""
        
        # Prepare optimizer
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay
            },
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        
        optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
        
        # Calculate total steps
        total_steps = len(dataloader) * epochs // gradient_accumulation_steps
        
        # Scheduler
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )
        
        # Training loop
        self.model.train()
        global_step = 0
        total_loss = 0
        
        print(f"Starting continuous pretraining for {epochs} epochs...")
        print(f"Device: {self.device}")
        
        for epoch in range(epochs):
            epoch_loss = 0
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
            
            for step, batch in enumerate(progress_bar):
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                
                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                
                loss = outputs.loss / gradient_accumulation_steps
                loss.backward()
                
                total_loss += loss.item()
                epoch_loss += loss.item()
                
                # Gradient accumulation
                if (step + 1) % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), 
                        max_grad_norm
                    )
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                
                # Update progress bar
                progress_bar.set_postfix({
                    'loss': f'{loss.item() * gradient_accumulation_steps:.4f}',
                    'avg_loss': f'{total_loss / (step + 1):.4f}'
                })
            
            avg_epoch_loss = epoch_loss / len(dataloader)
            print(f"Epoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}")
        
        print("Training completed!")
        return total_loss / len(dataloader) / epochs
    
    def save_model(self, output_dir: str):
        """Save fine-tuned model"""
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
    
    def load_model(self, model_dir: str):
        """Load fine-tuned model"""
        self.model = BertForMaskedLM.from_pretrained(model_dir).to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        print(f"Model loaded from {model_dir}")
    
    def predict_masked_tokens(self, text: str, top_k: int = 5) -> List[Dict]:
        """Predict masked tokens in text"""
        self.model.eval()
        
        # Tokenize with mask
        inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = outputs.logits
        
        # Get mask token positions
        mask_token_index = torch.where(inputs['input_ids'] == self.tokenizer.mask_token_id)[1]
        
        results = []
        for idx in mask_token_index:
            mask_token_logits = predictions[0, idx, :]
            top_tokens = torch.topk(mask_token_logits, top_k, dim=0)
            
            predictions_list = []
            for token_id, score in zip(top_tokens.indices, top_tokens.values):
                predictions_list.append({
                    'token': self.tokenizer.decode([token_id]),
                    'score': score.item()
                })
            
            results.append({
                'position': idx.item(),
                'predictions': predictions_list
            })
        
        return results


# Example usage
if __name__ == "__main__":
    # Sample domain-specific texts (replace with your domain data)
    domain_texts = [
        "Machine learning models require careful hyperparameter tuning for optimal performance.",
        "Neural networks use backpropagation to update weights during training.",
        "Deep learning architectures like transformers have revolutionized NLP tasks.",
        "The attention mechanism allows models to focus on relevant parts of the input.",
        "Transfer learning enables models to leverage knowledge from pretrained weights.",
        "Regularization techniques like dropout help prevent overfitting in neural networks.",
        "Gradient descent optimizes the loss function by iteratively updating parameters.",
        "Convolutional neural networks excel at processing grid-like data such as images.",
        "Recurrent neural networks are designed to handle sequential data effectively.",
        "Batch normalization stabilizes training by normalizing layer inputs."
    ]
    
    # Initialize adapter
    adapter = BERTDomainAdapter(model_name='bert-base-uncased')
    
    # Prepare data
    dataloader = adapter.prepare_domain_data(
        domain_texts,
        batch_size=4,
        max_length=128,
        mlm_prob=0.15
    )
    
    # Train (continuous pretraining)
    adapter.train(
        dataloader,
        epochs=5,
        learning_rate=5e-5,
        warmup_steps=100
    )
    
    # Save adapted model
    adapter.save_model('./domain_adapted_bert')
    
    # Test prediction
    test_text = "Neural networks use [MASK] to update weights during training."
    predictions = adapter.predict_masked_tokens(test_text)
    
    print("\nPredictions for masked tokens:")
    for pred in predictions:
        print(f"\nPosition {pred['position']}:")
        for p in pred['predictions']:
            print(f"  {p['token']}: {p['score']:.4f}")

2025-10-02 14:10:57.587222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759414257.914841      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759414258.008082      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Starting continuous pretraining for 5 epochs...
Device: cpu


Epoch 1/5: 100%|██████████| 3/3 [00:12<00:00,  4.15s/it, loss=4.4839, avg_loss=3.7544]


Epoch 1 completed. Average loss: 3.7544


Epoch 2/5: 100%|██████████| 3/3 [00:10<00:00,  3.35s/it, loss=4.8484, avg_loss=7.6058]


Epoch 2 completed. Average loss: 3.8515


Epoch 3/5: 100%|██████████| 3/3 [00:09<00:00,  3.27s/it, loss=0.0173, avg_loss=9.5992]


Epoch 3 completed. Average loss: 1.9934


Epoch 4/5: 100%|██████████| 3/3 [00:10<00:00,  3.57s/it, loss=3.2637, avg_loss=12.0446]


Epoch 4 completed. Average loss: 2.4454


Epoch 5/5: 100%|██████████| 3/3 [00:09<00:00,  3.28s/it, loss=4.9802, avg_loss=15.4228]


Epoch 5 completed. Average loss: 3.3781
Training completed!
Model saved to ./domain_adapted_bert

Predictions for masked tokens:

Position 4:
  it: 8.6777
  weights: 8.4960
  this: 8.3737
  algorithms: 8.3467
  them: 8.3423
