# Diffusion Language Model Benchmark
Fine-tuning small diffusion models on E2E and ROCStories datasets

In [1]:
import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
import os
from pathlib import Path

## 1. Configuration and Setup

In [2]:
# Model and dataset configuration
MODEL_NAME = 'distilgpt2'  # ~82M parameters
MAX_PARAMS = 200_000_000
DATASETS_DIR = './data'
OUTPUT_DIR = './dlm_checkpoint'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyperparameters
BATCH_SIZE = 8
LEARNING_RATE = 5e-5
EPOCHS = 3
MAX_LENGTH = 256

print(f'Device: {DEVICE}')
print(f'Model: {MODEL_NAME}')
print(f'Max Parameters: {MAX_PARAMS:,}')

Device: cuda
Model: distilgpt2
Max Parameters: 200,000,000


## 2. Load Tokenizer and Model

In [3]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Load model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'Total Parameters: {total_params:,}')
assert total_params <= MAX_PARAMS, f'Model exceeds {MAX_PARAMS:,} parameters'

Total Parameters: 81,912,576


## 3. Custom Dataset Class

In [4]:
class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': encoding['input_ids'].squeeze()
        }

## 4. Data Loading Functions

In [5]:
def load_e2e_dataset(data_dir):
    """Load E2E dataset from local directory"""
    e2e_path = Path(data_dir) / 'e2e'
    texts = []
    
    if e2e_path.exists():
        for file in e2e_path.glob('*.txt'):
            with open(file, 'r', encoding='utf-8') as f:
                texts.extend([line.strip() for line in f if line.strip()])
    
    return texts

def load_rocstories_dataset(data_dir):
    """Load ROCStories dataset from local directory"""
    rocstories_path = Path(data_dir) / 'rocstories'
    texts = []
    
    if rocstories_path.exists():
        for file in rocstories_path.glob('*.txt'):
            with open(file, 'r', encoding='utf-8') as f:
                texts.extend([line.strip() for line in f if line.strip()])
    
    return texts

def combine_datasets(e2e_texts, rocstories_texts):
    """Combine datasets and shuffle"""
    all_texts = e2e_texts + rocstories_texts
    np.random.shuffle(all_texts)
    return all_texts

## 5. Load and Prepare Datasets

In [6]:
# Create data directory if needed
os.makedirs(DATASETS_DIR, exist_ok=True)

# Load datasets
print('Loading E2E dataset...')
e2e_texts = load_e2e_dataset(DATASETS_DIR)
print(f'E2E samples: {len(e2e_texts)}')

print('Loading ROCStories dataset...')
rocstories_texts = load_rocstories_dataset(DATASETS_DIR)
print(f'ROCStories samples: {len(rocstories_texts)}')

# Combine datasets
all_texts = combine_datasets(e2e_texts, rocstories_texts)
print(f'Total samples: {len(all_texts)}')

# Split into train and validation
split_idx = int(0.9 * len(all_texts))
train_texts = all_texts[:split_idx]
val_texts = all_texts[split_idx:]

print(f'Train samples: {len(train_texts)}')
print(f'Validation samples: {len(val_texts)}')

Loading E2E dataset...
E2E samples: 0
Loading ROCStories dataset...
ROCStories samples: 0
Total samples: 0
Train samples: 0
Validation samples: 0


## 6. Create Data Loaders

In [7]:
# Create datasets
train_dataset = TextDataset(train_texts, tokenizer, max_length=MAX_LENGTH)
val_dataset = TextDataset(val_texts, tokenizer, max_length=MAX_LENGTH)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f'Train batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')

ValueError: num_samples should be a positive integer value, but got num_samples=0

## 7. Training Setup

In [None]:
# Define training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    logging_steps=100,
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

print('Trainer initialized')

## 8. Fine-tune Model

In [None]:
# Train the model
print('Starting fine-tuning...')
train_result = trainer.train()
print(f'Training completed. Final loss: {train_result.training_loss:.4f}')

## 9. Inference Functions

In [None]:
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.7):
    """Generate text using fine-tuned model"""
    model.eval()
    
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=max_length,
            num_return_sequences=num_samples,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return texts

def calculate_perplexity(model, data_loader):
    """Calculate perplexity on validation set"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            
            total_loss += loss.item() * input_ids.size(0)
            total_tokens += input_ids.size(0)
    
    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    return perplexity.item()

## 10. Evaluate Model

In [None]:
# Calculate validation perplexity
print('Calculating validation perplexity...')
val_perplexity = calculate_perplexity(model, val_loader)
print(f'Validation Perplexity: {val_perplexity:.4f}')

## 11. Generate Samples

In [None]:
# Generate text samples
prompts = [
    'The story begins',
    'Once upon a time',
    'A person walks into'
]

for prompt in prompts:
    print(f'\nPrompt: {prompt}')
    samples = generate_text(prompt, max_length=100, num_samples=2)
    for i, sample in enumerate(samples, 1):
        print(f'Sample {i}: {sample}\n')

## 12. Save Model

In [None]:
# Save the fine-tuned model
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f'Model saved to {OUTPUT_DIR}')