# 🎓 WeDLM Training Tutorial

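Fine-tune a pretrained autoregressive (AR) model into WeDLM using **Causal Masked Language Modeling (CMLM)**. Here $\mathcal{M}$ is the set of masked positions. Each masked token is predicted from the unmasked tokens that precede it:

$$\mathcal{L}_{\text{CMLM}} = -\mathbb{E}_{x} \left[\sum_{j \in \mathcal{M}} \log P_\theta\big(x_j \mid \{x_i : i < j,\ i \notin \mathcal{M}\}\big)\right]$$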

## 📋 Contents
1. Environment Setup
2. Configuration  
3. Data Preparation with Span Masking
4. Model Loading
5. Training Loop
6. Test Generation
7. Save Model


In [None]:
# 1️⃣ Install Dependencies & Clone Repo
# (IPython cell: lines starting with "!" are executed as shell commands.)
!pip install -q torch transformers datasets tqdm accelerate

# Clone the repository (for wedlm package); skipped when the folder already exists
import os
if not os.path.exists('05_WeDLM_Reconciling_Diffusion_with_Causal_Attention'):
    !git clone https://github.com/Gaurav14cs17/05_WeDLM_Reconciling_Diffusion_with_Causal_Attention.git
    
# Add to Python path at index 0, so the cloned package takes precedence over other copies
import sys
sys.path.insert(0, '05_WeDLM_Reconciling_Diffusion_with_Causal_Attention')

# Verify wedlm import. A failure is only reported, not raised: the training cells
# shown in this notebook do not use LLM / SamplingParams.
try:
    from wedlm import LLM, SamplingParams
    print("‚úÖ wedlm package imported successfully!")
except ImportError as e:
    print(f"‚ö†Ô∏è wedlm import failed: {e}")
    print("Continuing with standalone implementation...")

# Show the visible NVIDIA GPU(s)
!nvidia-smi


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import numpy as np
from tqdm.auto import tqdm

# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Configuration for the compact training run
MODEL = "Qwen/Qwen2.5-0.5B"   # pretrained AR checkpoint to adapt
MAX_LEN = 256                 # tokens per example (truncate / pad to this length)
BATCH = 4                     # examples per optimizer step
EPOCHS = 2
LR = 2e-5
MASK_TOKEN = "<|mask|>"       # special token written into masked positions


In [None]:
# Span Masking Function
def span_mask(ids, mask_id, ratio=0.3, span_p=0.3):
    """Mask about ``ratio`` of ``ids`` using contiguous spans of geometric length.

    Args:
        ids: 1-D token-id tensor. It is not modified.
        mask_id: id written into every masked position.
        ratio: fraction of positions to mask. The target count is
            ``int(len(ids) * ratio)``, capped at ``len(ids)``.
        span_p: success probability of the geometric span-length
            distribution (mean span length ``1 / span_p``).

    Returns:
        ``(masked, ids, flags)``:
        - ``masked``: a masked copy of ``ids``.
        - ``ids``: the original tensor, used as labels.
        - ``flags``: a bool tensor that is True at masked positions.
    """
    n = len(ids)
    masked = ids.clone()
    flags = torch.zeros(n, dtype=torch.bool)
    target = min(int(n * ratio), n)
    count = 0
    # Bound the number of span draws so the loop always terminates, even for
    # tiny sequences or ratios close to 1.
    for _ in range(10 * n):
        if count >= target:
            break
        # Cap at n (not n - 1) so a single-token sequence can still be masked.
        span = min(np.random.geometric(span_p), n)
        start = np.random.randint(0, n - span + 1)
        for i in range(start, start + span):
            if not flags[i]:
                masked[i] = mask_id
                flags[i] = True
                count += 1
                if count >= target:
                    break
    return masked, ids, flags

# Dataset
class MLMDataset(Dataset):
    """Tokenized texts with random span masking applied to real (non-padding) tokens.

    Each item is a dict of 1-D tensors of length ``max_len``:
      input_ids      -- token ids with masked positions replaced by ``mask_id``
      attention_mask -- tokenizer attention mask (1 = real token, 0 = padding)
      labels         -- the original, unmasked token ids
      mask_flags     -- True where a real token was masked (the loss targets)
    """

    def __init__(self, tok, texts, max_len, mask_id, mask_ratio_range=(0.1, 0.5)):
        self.tok, self.texts, self.max_len, self.mask_id = tok, texts, max_len, mask_id
        # Each example's masking ratio is drawn uniformly from this (low, high) range.
        self.mask_ratio_range = mask_ratio_range

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        enc = self.tok(self.texts[i], truncation=True, max_length=self.max_len,
                       padding="max_length", return_tensors="pt")
        ids = enc["input_ids"].squeeze(0)
        attn = enc["attention_mask"].squeeze(0)
        ratio = np.random.uniform(*self.mask_ratio_range)

        # Mask only real tokens. Masking the padding as well would do two bad things:
        # - spend most of the mask budget on pad positions for short texts;
        # - train the model to predict pad ids there.
        # With single-sequence padding the real tokens form one contiguous run,
        # so masking that subsequence still gives contiguous spans.
        real = attn.bool()
        real_masked, _, real_flags = span_mask(ids[real], self.mask_id, ratio)

        masked = ids.clone()
        masked[real] = real_masked
        flags = torch.zeros_like(real)
        flags[real] = real_flags
        return {"input_ids": masked, "attention_mask": attn,
                "labels": ids, "mask_flags": flags}


In [None]:
# Load Model & Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
# Register the mask token as a special token unless the vocabulary already has it.
if MASK_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"additional_special_tokens": [MASK_TOKEN]})
mask_id = tokenizer.convert_tokens_to_ids(MASK_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Keep the master weights in float32 on every device. The training cell combines
# torch.amp.autocast (fp16 compute on GPU) with GradScaler. GradScaler.step()
# raises "Attempting to unscale FP16 gradients." when the parameters are fp16.
model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True,
    torch_dtype=torch.float32)
# Resize only when the new token does not fit. Some checkpoints ship an embedding
# matrix larger than len(tokenizer) (padded/reserved rows); shrinking it to
# len(tokenizer) would just discard those rows.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
print("✅ Model loaded")


In [None]:
# Load Data: the first 500 wikitext-2 training lines whose stripped text exceeds 50 chars.
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
texts = [line for line in data["text"] if len(line.strip()) > 50]
texts = texts[:500]
dataset = MLMDataset(tokenizer, texts, MAX_LEN, mask_id)
loader = DataLoader(dataset, batch_size=BATCH, shuffle=True)
print(f"‚úÖ {len(texts)} training samples")


In [None]:
# Training Loop
# Mixed precision on GPU: autocast runs the forward pass in fp16, and GradScaler
# scales the loss to avoid fp16 gradient underflow. On CPU both are disabled.
opt = AdamW(model.parameters(), lr=LR)
use_amp = torch.cuda.is_available()
scaler = torch.amp.GradScaler('cuda') if use_amp else None

for epoch in range(EPOCHS):
    model.train()
    total = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    # Count batches ourselves: tqdm refreshes `pbar.n` only periodically, so it is
    # not a reliable count of processed batches inside the loop body.
    for step, batch in enumerate(pbar):
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        flags = batch["mask_flags"].to(device)

        # Logits at position t score token t+1, so the targets are labels[:, 1:].
        # Only positions that were masked in the input contribute to the loss.
        shift_flags = flags[:, 1:].reshape(-1)
        if shift_flags.any():
            with torch.amp.autocast('cuda', enabled=use_amp):
                logits = model(ids, attention_mask=mask).logits
                shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
                shift_labels = labels[:, 1:].reshape(-1)
                loss = F.cross_entropy(shift_logits[shift_flags], shift_labels[shift_flags])

            opt.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                opt.step()
            total += loss.item()
        # A batch with no masked targets has nothing to learn from. Skip the update
        # (GradScaler.step asserts when no gradients were recorded) and count it
        # as zero loss in the averages.

        pbar.set_postfix({"loss": f"{total/(step+1):.4f}"})

    print(f"Epoch {epoch+1}: Loss = {total/len(loader):.4f}")

print("\n✅ Training Complete!")


In [None]:
# Save Model: the weights/config and the tokenizer files share one directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./wedlm_model")
print("‚úÖ Model saved to ./wedlm_model")


---

## 📚 Detailed Version (Alternative Implementation)

The cells above provide a compact implementation. Below is a more detailed version with extensive comments.


In [None]:
# 1️⃣ Environment Setup (Detailed Version)
# (IPython cell: lines starting with "!" are executed as shell commands.)
!pip install -q torch transformers datasets accelerate tqdm

# Clone repo and setup wedlm import (if not already done); the clone is skipped
# when the folder exists, and the path goes to index 0 so this copy takes precedence.
import os, sys
if not os.path.exists('05_WeDLM_Reconciling_Diffusion_with_Causal_Attention'):
    !git clone https://github.com/Gaurav14cs17/05_WeDLM_Reconciling_Diffusion_with_Causal_Attention.git
sys.path.insert(0, '05_WeDLM_Reconciling_Diffusion_with_Causal_Attention')

# Import wedlm. A failure is only reported: the cells shown in this notebook
# do not use LLM / SamplingParams.
try:
    from wedlm import LLM, SamplingParams
    print("‚úÖ wedlm package imported!")
except ImportError:
    print("‚ö†Ô∏è Using standalone implementation")

# Show the visible NVIDIA GPU(s)
!nvidia-smi


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import numpy as np
from tqdm.auto import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 2️⃣ Configuration
class Config:
    """Hyperparameters for the detailed training run, read through the `config` instance."""

    # Model / data
    model_name = "Qwen/Qwen2.5-0.5B"  # small model for the demo; use a larger one for production
    max_seq_len = 256
    mask_token = "<|mask|>"

    # Optimization
    batch_size = 4
    gradient_accumulation_steps = 4
    learning_rate = 2e-5
    num_epochs = 2

    # Masking: each example's ratio is drawn uniformly from [min, max]
    mask_ratio_min = 0.1
    mask_ratio_max = 0.5


config = Config()


In [None]:
# 3️⃣ Random Span Masking Function
def random_span_masking(input_ids, mask_token_id, mask_ratio=0.3, span_mean=3):
    """Replace random contiguous spans of ``input_ids`` with ``mask_token_id``.

    Example:
        Original:  [The] [quick] [brown] [fox] [jumps]
        Masked:    [The] [MASK] [MASK] [fox] [jumps]

    Args:
        input_ids: 1-D token-id tensor (left untouched).
        mask_token_id: id written into masked positions.
        mask_ratio: fraction of positions to mask; the budget is int(len * ratio).
        span_mean: mean of the geometric span-length distribution (p = 1 / span_mean).

    Returns:
        ``(masked_ids, input_ids, mask_flags)``:
        - ``masked_ids``: the masked copy.
        - ``input_ids``: the original ids, used as labels.
        - ``mask_flags``: a bool tensor that is True at masked positions.

    At most ``10 * len(input_ids)`` spans are drawn, so fewer positions than
    the budget may end up masked.
    """
    n_tokens = len(input_ids)
    budget = int(n_tokens * mask_ratio)
    max_attempts = n_tokens * 10

    masked_ids = input_ids.clone()
    mask_flags = torch.zeros(n_tokens, dtype=torch.bool)

    n_masked = 0
    for _ in range(max_attempts):
        if n_masked >= budget:
            break
        span_len = min(np.random.geometric(p=1 / span_mean), n_tokens - 1)
        span_start = np.random.randint(0, n_tokens - span_len + 1)
        for pos in range(span_start, span_start + span_len):
            if mask_flags[pos]:
                continue
            masked_ids[pos] = mask_token_id
            mask_flags[pos] = True
            n_masked += 1
            if n_masked >= budget:
                break

    return masked_ids, input_ids, mask_flags

# 4️⃣ Dataset Class
class CausalMLMDataset(Dataset):
    """Tokenized texts with random span masking applied to real (non-padding) tokens.

    Each item is a dict of 1-D tensors of length ``max_length``:
      input_ids      -- token ids with masked positions replaced by ``mask_token_id``
      attention_mask -- tokenizer attention mask (1 = real token, 0 = padding)
      labels         -- the original, unmasked token ids
      mask_flags     -- True where a real token was masked (the loss targets)

    ``mask_ratio_range`` is an optional ``(low, high)`` pair. When it is None, the
    bounds are read from ``config.mask_ratio_min`` / ``config.mask_ratio_max``
    each time an item is drawn.
    """

    def __init__(self, tokenizer, texts, max_length, mask_token_id, mask_ratio_range=None):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length
        self.mask_token_id = mask_token_id
        self.mask_ratio_range = mask_ratio_range

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx], truncation=True, max_length=self.max_length,
            padding="max_length", return_tensors="pt"
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.mask_ratio_range is None:
            low, high = config.mask_ratio_min, config.mask_ratio_max
        else:
            low, high = self.mask_ratio_range
        mask_ratio = np.random.uniform(low, high)

        # Mask only real tokens. Masking padding would waste the mask budget on
        # pad positions for short texts and train the model to predict pad ids.
        # With single-sequence padding the real tokens are one contiguous run,
        # so spans drawn over this subsequence stay contiguous.
        real = attention_mask.bool()
        real_masked, _, real_flags = random_span_masking(
            input_ids[real], self.mask_token_id, mask_ratio
        )

        masked_ids = input_ids.clone()
        masked_ids[real] = real_masked
        mask_flags = torch.zeros_like(real)
        mask_flags[real] = real_flags

        return {"input_ids": masked_ids, "attention_mask": attention_mask,
                "labels": input_ids, "mask_flags": mask_flags}


In [None]:
# 5️⃣ Load Tokenizer and Data
print(f"Loading tokenizer: {config.model_name}")
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)

# Register the mask token as a special token unless the vocabulary already has it.
if config.mask_token not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"additional_special_tokens": [config.mask_token]})
mask_token_id = tokenizer.convert_tokens_to_ids(config.mask_token)
print(f"Mask token ID: {mask_token_id}")

# Reuse EOS as the pad token when none is defined (the dataset pads to max_length).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Corpus: wikitext-2 train split, keeping the first 500 lines longer than 50 chars.
print("Loading dataset...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
texts = [line for line in dataset["text"] if len(line.strip()) > 50]
texts = texts[:500]  # subset for the demo
print(f"Using {len(texts)} training examples")

train_dataset = CausalMLMDataset(tokenizer, texts, config.max_seq_len, mask_token_id)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)


In [None]:
# 6️⃣ Load Model
print(f"Loading model: {config.model_name}")
# Keep the master weights in float32 on every device. The training loop combines
# torch.amp.autocast (fp16 compute on GPU) with GradScaler. GradScaler.step()
# raises "Attempting to unscale FP16 gradients." when the parameters are fp16.
model = AutoModelForCausalLM.from_pretrained(
    config.model_name, trust_remote_code=True,
    torch_dtype=torch.float32
)
# Resize only when the mask token does not fit. Some checkpoints ship an embedding
# matrix larger than len(tokenizer) (padded/reserved rows); shrinking it would
# just discard those rows.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
print(f"Model loaded on {device}")

# 7️⃣ Loss Function
def compute_cmlm_loss(model, batch, mask_token_id):
    """Compute the causal-MLM loss on masked positions only.

    Args:
        model: causal LM called as ``model(input_ids=..., attention_mask=...)`` and
            returning an object with ``.logits`` of shape (batch, seq, vocab).
        batch: dict with "input_ids", "attention_mask", "labels" and "mask_flags"
            tensors, as produced by ``CausalMLMDataset``.
        mask_token_id: id of the mask token. Unused here, because the masked inputs
            are already in ``batch["input_ids"]``. Kept for call compatibility.

    Returns:
        Scalar mean cross-entropy over masked target positions. If no target is
        masked, returns an exact zero that is still attached to the autograd graph,
        so ``loss.backward()`` succeeds and produces zero gradients.
    """
    # Follow the model's device instead of relying on a notebook-global `device`.
    model_device = next(model.parameters()).device
    input_ids = batch["input_ids"].to(model_device)
    attention_mask = batch["attention_mask"].to(model_device)
    labels = batch["labels"].to(model_device)
    mask_flags = batch["mask_flags"].to(model_device)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    # Shift for next-token prediction: logits at position t score token t+1.
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    shift_mask = mask_flags[:, 1:].reshape(-1)

    if not shift_mask.any():
        # Sum over an empty selection: exactly 0.0 but still part of the graph.
        # A fresh torch.tensor(0.0) has no grad_fn, so backward() on it would raise.
        return shift_logits[shift_mask].sum()

    # Loss only on masked positions
    return F.cross_entropy(shift_logits[shift_mask], shift_labels[shift_mask])


In [None]:
# 8️⃣ Training Loop
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
use_amp = torch.cuda.is_available()
scaler = torch.amp.GradScaler('cuda') if use_amp else None
accum_steps = config.gradient_accumulation_steps

print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)

for epoch in range(config.num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    num_batches = len(train_loader)
    optimizer.zero_grad()
    has_grads = False  # True once backward() has accumulated gradients since the last step

    for step, batch in enumerate(progress_bar):
        # Divide by accum_steps so the accumulated gradient averages the micro-batches.
        # autocast is a no-op when use_amp is False (CPU).
        with torch.amp.autocast('cuda', enabled=use_amp):
            loss = compute_cmlm_loss(model, batch, mask_token_id) / accum_steps

        # A batch without masked targets may yield a constant loss with no graph;
        # backward() on it would raise, so skip it.
        if loss.requires_grad:
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            has_grads = True

        # Step every `accum_steps` batches, and also on the epoch's last batch. Otherwise
        # a trailing partial accumulation would be wiped by the next epoch's zero_grad()
        # (or never applied after the final epoch).
        is_boundary = (step + 1) % accum_steps == 0 or step + 1 == num_batches
        if is_boundary and has_grads:
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
            has_grads = False

        total_loss += loss.item() * accum_steps
        progress_bar.set_postfix({"loss": f"{total_loss/(step+1):.4f}"})

    print(f"Epoch {epoch+1} - Avg Loss: {total_loss/len(train_loader):.4f}")

print("\n✅ TRAINING COMPLETE!")


In [None]:
# 9️⃣ Test Generation (Simplified WeDLM Decoding)
@torch.no_grad()
def wedlm_generate(model, tokenizer, prompt, max_tokens=30, window_size=8,
                   entropy_threshold=0.5, mask_id=None):
    """Simplified WeDLM decoding demo: parallel filling of a sliding mask window.

    Each iteration does the following:
    1. Append ``window_size`` window slots after the committed text.
    2. Run one forward pass.
    3. Fill every slot whose predictive entropy is below ``entropy_threshold``.
       The first slot is always filled.
    4. Commit the longest filled prefix of the window.

    Filled-but-uncommitted slots stay in the window and may be overwritten
    in later iterations.

    Args:
        model: causal LM whose ``model(input_ids).logits`` is (1, seq, vocab).
        tokenizer: provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt: text to continue.
        max_tokens: maximum number of new tokens appended after the prompt.
        window_size: number of slots decoded in parallel per step (>= 1).
        entropy_threshold: slots whose entropy (in nats) is below this get filled.
        mask_id: mask-token id. Defaults to the notebook-global ``mask_token_id``.

    Returns:
        The decoded prompt plus generated text, with special tokens skipped.
        Generation stops after ``max_tokens`` new tokens, or right after the
        first committed EOS token.

    Raises:
        ValueError: if ``window_size < 1`` or the prompt encodes to no tokens.
    """
    if mask_id is None:
        mask_id = mask_token_id
    if window_size < 1:
        raise ValueError("window_size must be >= 1")
    model.eval()
    dev = next(model.parameters()).device

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(dev)
    generated = prompt_ids[0].tolist()
    if not generated:
        raise ValueError("prompt must encode to at least one token")
    prompt_len = len(generated)
    eos_id = tokenizer.eos_token_id
    window = [mask_id] * window_size

    # Every iteration commits at least one token, so max_tokens iterations suffice.
    for _ in range(max_tokens):
        n_new = len(generated) - prompt_len
        if n_new >= max_tokens:
            break

        input_ids = torch.tensor([generated + window], device=dev)
        logits = model(input_ids).logits[0]

        # Output at position t predicts token t+1, so window slot i is read from
        # position len(generated) - 1 + i. Use float32 for a stable softmax.
        start = len(generated) - 1
        window_logits = logits[start:start + window_size].float()
        if 0 <= mask_id < window_logits.size(-1):
            # Never "predict" the placeholder itself: it would look like an unfilled slot.
            window_logits[:, mask_id] = float("-inf")
        probs = F.softmax(window_logits, dim=-1)
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)

        # Fill low-entropy positions (at least the first one, to guarantee progress).
        predicted = probs.argmax(dim=-1)
        fill_mask = entropy < entropy_threshold
        if not fill_mask.any():
            fill_mask[0] = True
        for i in range(window_size):
            if fill_mask[i]:
                window[i] = predicted[i].item()

        # Commit the filled prefix: at least 1 token, at most the remaining budget.
        commit = 0
        while commit < window_size and window[commit] != mask_id:
            commit += 1
        commit = min(max(commit, 1), max_tokens - n_new)

        committed = window[:commit]
        if eos_id is not None and eos_id in committed:
            # Keep the EOS itself but nothing after it.
            generated.extend(committed[:committed.index(eos_id) + 1])
            break
        generated.extend(committed)
        window = window[commit:] + [mask_id] * commit

    return tokenizer.decode(generated, skip_special_tokens=True)

# Test: run the simplified decoder on two sample prompts.
print("=" * 60, "TESTING GENERATION", "=" * 60, sep="\n")
for prompt in ("The quick brown", "Machine learning is"):
    print(f"\nPrompt: {prompt}")
    print(f"Output: {wedlm_generate(model, tokenizer, prompt)}")


In [None]:
# 🔟 Save Model: weights/config and tokenizer files go to the same directory.
output_dir = "./wedlm_finetuned"
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)
print(f"\n‚úÖ Model saved to {output_dir}")
print("\nüìö Training complete! Next: Use the Inference notebook for optimized generation.")
