# DiffusionBERT Training

This notebook implements the training pipeline for DiffusionBERT: it trains on the LM1B dataset, reading data from and saving checkpoints to Google Drive.

In [None]:
# Install required packages.
# NOTE(review): versions are unpinned, so a fresh Colab runtime installs whatever
# releases are current. `datasets` and `accelerate` are not imported in the
# cells below; presumably the project's dataloader/model code needs them — confirm.
!pip install transformers datasets torch tqdm accelerate

In [None]:
# Mount Google Drive so the /content/drive/MyDrive/DiffusionBERT/... paths used
# by TrainingConfig (training data, word frequencies, checkpoint output) exist.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import sys
import torch
import logging
import random
import numpy as np
from transformers import BertTokenizer, BertConfig
from torch.optim import AdamW
from tqdm.notebook import tqdm
import json
from datetime import datetime

# Root logging for the whole notebook: INFO and above, each record stamped with
# time, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
# Logger used by every helper in this notebook.
logger = logging.getLogger(__name__)

In [None]:
class TrainingConfig:
    """Hyperparameters, paths and device for one DiffusionBERT training run.

    Every setting has a default matching this notebook's Colab / Google Drive
    setup. Any of them can be overridden by keyword, e.g.
    ``TrainingConfig(batch_size=16, output_dir="/tmp/run")``. An unknown name
    raises ``TypeError`` so a typo is never silently ignored.

    Side effect: ``output_dir`` (and its parents) is created on construction,
    after overrides have been applied.
    """

    def __init__(self, **overrides):
        # Model settings
        self.model_name = "bert-base-uncased"
        self.max_seq_length = 128
        # 30522 is the bert-base-uncased WordPiece vocabulary size; keep it in
        # sync with model_name.
        self.vocab_size = 30522

        # Training settings
        self.batch_size = 32
        self.learning_rate = 1e-4
        self.num_train_epochs = 3
        self.warmup_steps = 10000  # counted in optimizer steps
        self.gradient_accumulation_steps = 2
        self.fp16 = True  # mixed precision via torch.cuda.amp; CUDA-only
        self.seed = 42

        # Diffusion settings
        # NOTE(review): none of the training/eval helpers in this notebook read
        # these (or word_freq_path below); confirm the project dataloader/model
        # consumes them.
        self.diffusion_steps = 2000
        self.noise_schedule = "cosine"
        self.word_freq_lambda = 0.3

        # Paths (Google Drive mounted at /content/drive)
        self.output_dir = "/content/drive/MyDrive/DiffusionBERT/training_output"
        self.word_freq_path = "/content/drive/MyDrive/DiffusionBERT/word_freqs/word_freq.pt"
        self.train_data_dir = "/content/drive/MyDrive/DiffusionBERT/data"

        # Device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Apply caller overrides before anything (e.g. makedirs) depends on them.
        known = vars(self)
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(f"TrainingConfig got an unexpected setting {name!r}")
            setattr(self, name, value)
        # Accept "cpu"/"cuda" strings as well as torch.device instances.
        self.device = torch.device(self.device)

        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

# Shared configuration; main() reads this global. Constructing it also creates
# config.output_dir on Drive.
config = TrainingConfig()

In [None]:
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus every CUDA device, if any).

    Args:
        seed: integer seed applied to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

In [None]:
def setup_model_and_tokenizer(config):
    """Load the tokenizer and build the masked-LM model on ``config.device``.

    The tokenizer and the BERT architecture config come from
    ``config.model_name``; the config's ``vocab_size`` is then replaced by
    ``config.vocab_size``. The model class is the project's
    ``models.modeling_bert.BertForMaskedLM``, constructed from that config.
    NOTE(review): only the architecture config is loaded from the hub, not
    pretrained weights — assumes the project class, like the HF one,
    randomly initialises when built from a config; confirm this is intended.

    Returns:
        (model, tokenizer)

    Raises:
        ValueError: if the tokenizer has more tokens than ``config.vocab_size``.
        Any other setup error is logged with its traceback and re-raised.
    """
    try:
        logger.info(f"Loading tokenizer from {config.model_name}")
        tokenizer = BertTokenizer.from_pretrained(config.model_name)

        logger.info("Initializing model configuration")
        model_config = BertConfig.from_pretrained(config.model_name)
        model_config.vocab_size = config.vocab_size
        # Token ids >= vocab_size would index past the embedding table and fail
        # deep inside a forward pass (on CUDA, as an opaque device assert).
        if len(tokenizer) > config.vocab_size:
            raise ValueError(
                f"config.vocab_size={config.vocab_size} is smaller than the "
                f"tokenizer vocabulary ({len(tokenizer)} tokens)"
            )

        logger.info("Initializing model")
        from models.modeling_bert import BertForMaskedLM
        model = BertForMaskedLM(model_config)
        model = model.to(config.device)

        return model, tokenizer

    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"Error setting up model and tokenizer: {e}")
        raise

In [None]:
def setup_data_loader(config, tokenizer, num_workers=4):
    """Build train/validation DataLoaders for LM1B via the project's DiffusionLoader.

    Args:
        config: TrainingConfig; ``batch_size`` and ``device`` are read.
        tokenizer: passed to ``DiffusionLoader``.
        num_workers: worker processes per DataLoader (default 4).

    Returns:
        (train_loader, dev_loader). The train loader shuffles; the dev loader
        uses twice the batch size and keeps the dataset order.

    Raises:
        Any setup error is logged with its traceback and re-raised.
    """
    try:
        from dataloader import DiffusionLoader
        loader = DiffusionLoader(tokenizer)

        logger.info("Loading datasets")
        train_data, dev_data = loader.my_load("lm1b", splits=["train", "validation"])

        # Pinned host memory only speeds up host->GPU copies; on CPU-only
        # runtimes PyTorch typically just warns and ignores the flag.
        pin_memory = torch.device(config.device).type == "cuda"

        # NOTE(review): no collate_fn is supplied, so the default collate must
        # be able to stack each batch — assumes DiffusionLoader yields
        # fixed-length (padded) input_ids/attention_mask tensors; confirm.
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

        dev_loader = torch.utils.data.DataLoader(
            dev_data,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

        return train_loader, dev_loader

    except Exception as e:
        logger.exception(f"Error setting up data loaders: {e}")
        raise

In [None]:
def _optimizer_step(optimizer, scheduler, scaler=None):
    """Apply the accumulated gradients, advance the LR schedule, then clear gradients.

    With a GradScaler, ``scaler.step`` unscales the gradients and skips the
    update when they contain inf/NaN; ``scaler.update`` then adjusts the scale.
    """
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def train_epoch(model, train_loader, optimizer, scheduler, config, epoch, scaler=None):
    """Train ``model`` for one pass over ``train_loader`` and return the mean batch loss.

    Gradients are accumulated over ``config.gradient_accumulation_steps``
    batches before each optimizer + scheduler step. A trailing partial window
    at the end of the epoch is still applied (with the same 1/accum scaling),
    so no gradients leak into the next epoch.

    When ``config.fp16`` is set and the device is CUDA, the forward pass runs
    under autocast and backward/step go through a GradScaler. Pass ``scaler``
    to keep the loss scale across epochs; otherwise a fresh one is created
    for this call.

    A batch that raises is logged with its traceback and skipped. The
    gradients accumulated in its window are discarded, so a failed backward
    cannot contaminate the next update.

    Args:
        epoch: zero-based epoch index (only used in the progress-bar label).
        scaler: optional ``torch.cuda.amp.GradScaler``; ignored unless mixed
            precision is active.

    Returns:
        Mean unscaled loss over the batches that completed successfully.

    Raises:
        RuntimeError: if no batch completed (empty loader or every batch failed).
    """
    device = torch.device(config.device)
    use_amp = bool(config.fp16) and device.type == "cuda"
    accum_steps = max(1, int(config.gradient_accumulation_steps))
    if not use_amp:
        scaler = None
    elif scaler is None:
        scaler = torch.cuda.amp.GradScaler()

    model.train()
    optimizer.zero_grad()  # start the epoch from clean gradients

    total_loss = 0.0
    num_batches = 0  # batches whose loss was back-propagated successfully
    pending = 0      # micro-batches accumulated but not yet applied

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in progress_bar:
        try:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # autocast must wrap the forward pass (not backward) to take effect.
            with torch.autocast(device_type=device.type, enabled=use_amp):
                # NOTE(review): labels are the uncorrupted inputs and no
                # diffusion timestep/noise is applied here, although config
                # defines diffusion_steps / noise_schedule — confirm the
                # project model handles corruption internally.
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids
                )

            batch_loss = outputs.loss
            # Divide so the accumulated gradient is the mean over the window.
            scaled_loss = batch_loss / accum_steps
            if scaler is not None:
                scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            # Track the unscaled loss so the reported value is the real loss.
            total_loss += batch_loss.item()
            num_batches += 1
            pending += 1

            if pending == accum_steps:
                pending = 0
                _optimizer_step(optimizer, scheduler, scaler)

        except Exception as e:
            logger.exception(f"Error in training step: {e}")
            # Drop possibly partial gradients of this window.
            optimizer.zero_grad()
            pending = 0
            continue

        progress_bar.set_postfix({
            'loss': total_loss / num_batches,
            'lr': scheduler.get_last_lr()[0]
        })

    if pending:
        # Flush the trailing partial accumulation window.
        try:
            _optimizer_step(optimizer, scheduler, scaler)
        except Exception as e:
            logger.exception(f"Error in training step: {e}")
            optimizer.zero_grad()

    if num_batches == 0:
        raise RuntimeError(
            "No training batch completed: train_loader was empty or every batch failed"
        )
    return total_loss / num_batches

In [None]:
def evaluate(model, dev_loader, config):
    """Return the mean per-batch loss of ``model`` over ``dev_loader``.

    Runs under ``torch.no_grad()`` and leaves the model in eval mode. A batch
    that raises is logged with its traceback and skipped.

    Returns:
        Unweighted mean over batches: a short final batch counts as much as a
        full one.

    Raises:
        RuntimeError: if no batch could be evaluated (empty loader or every
            batch failed). This replaces an opaque ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    total_steps = 0

    with torch.no_grad():
        for batch in tqdm(dev_loader, desc="Evaluating"):
            try:
                # Move batch to device
                input_ids = batch['input_ids'].to(config.device)
                attention_mask = batch['attention_mask'].to(config.device)

                # Same objective as training: the inputs are also the labels.
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids
                )

                total_loss += outputs.loss.item()
                total_steps += 1

            except Exception as e:
                logger.exception(f"Error in evaluation step: {e}")
                continue

    if total_steps == 0:
        raise RuntimeError(
            "Evaluation produced no loss: dev_loader was empty or every batch failed"
        )
    return total_loss / total_steps

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, loss, config, filename=None):
    """Save model, optimizer and scheduler state (plus the run config) to ``config.output_dir``.

    The checkpoint is first written to ``<path>.tmp`` and then renamed onto
    the final path with ``os.replace``. An interrupted save (e.g. a Colab
    disconnect mid-write) therefore never leaves a truncated file under the
    real name.

    Args:
        epoch: zero-based epoch index, stored as-is. The default file name
            uses ``epoch + 1``.
        loss: value stored under ``'loss'``.
        filename: optional file name inside ``config.output_dir``; defaults
            to ``checkpoint-epoch-{epoch+1}.pt``.

    Returns:
        The path of the written checkpoint.

    Raises:
        Any save error is logged with its traceback and re-raised.
    """
    try:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
            'config': vars(config)
        }

        if filename is None:
            filename = f'checkpoint-epoch-{epoch+1}.pt'
        os.makedirs(config.output_dir, exist_ok=True)
        path = os.path.join(config.output_dir, filename)
        tmp_path = f"{path}.tmp"

        try:
            torch.save(checkpoint, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Don't leave a half-written temp file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

        logger.info(f"Checkpoint saved to {path}")
        return path

    except Exception as e:
        logger.exception(f"Error saving checkpoint: {e}")
        raise

In [None]:
def main():
    """Run the full training loop driven by the module-level ``config``.

    Steps:
        1. Seed the RNGs.
        2. Build the model, tokenizer and data loaders.
        3. For each epoch, train and then evaluate.
        4. Checkpoint whenever the eval loss improves, and on every
           ``save_every_epochs``-th epoch (optional config attribute,
           default 10).

    Any error is logged with its traceback and re-raised.
    """
    try:
        # Set random seed
        set_seed(config.seed)

        # Setup model and tokenizer
        model, tokenizer = setup_model_and_tokenizer(config)

        # Setup data loaders
        train_loader, dev_loader = setup_data_loader(config, tokenizer)

        # Optimizer and LR schedule. The schedule is a linear warmup from
        # ~0 to the peak LR over warmup_steps scheduler steps, then constant.
        # LinearLR's default start_factor (1/3) would instead start "warmup"
        # at a third of the peak LR.
        optimizer = AdamW(model.parameters(), lr=config.learning_rate)
        warmup_steps = max(1, int(config.warmup_steps))
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1.0 / warmup_steps,
            end_factor=1.0,
            total_iters=warmup_steps
        )
        # Mixed precision (config.fp16) is handled by train_epoch; no
        # GradScaler is created here.

        save_every = getattr(config, "save_every_epochs", 10)

        # Training loop
        best_loss = float('inf')

        for epoch in range(config.num_train_epochs):
            # Train
            train_loss = train_epoch(model, train_loader, optimizer, scheduler, config, epoch)
            logger.info(f"Epoch {epoch+1} - Training Loss: {train_loss:.4f}")

            # Evaluate
            eval_loss = evaluate(model, dev_loader, config)
            logger.info(f"Epoch {epoch+1} - Evaluation Loss: {eval_loss:.4f}")

            # Save checkpoint if best model
            saved_this_epoch = False
            if eval_loss < best_loss:
                best_loss = eval_loss
                save_checkpoint(model, optimizer, scheduler, epoch, eval_loss, config)
                saved_this_epoch = True
                logger.info(f"New best model saved with loss: {best_loss:.4f}")

            # Periodic checkpoint. Skipped if this epoch was already saved,
            # since both saves write the same checkpoint-epoch-N.pt file.
            if (not saved_this_epoch and save_every > 0
                    and (epoch + 1) % save_every == 0):
                save_checkpoint(model, optimizer, scheduler, epoch, eval_loss, config)

        logger.info("Training completed!")

    except Exception as e:
        logger.exception(f"Error in main execution: {e}")
        raise

In [None]:
# Inside a notebook kernel __name__ is "__main__", so running this cell starts
# training immediately.
if __name__ == "__main__":
    main()