# Energy-based Diffusion Model for Text Generation

### Step 1: Imports

The working environment is specified in `environment.yml`, although it may include many dependencies that are unrelated to this notebook.

In [None]:
import os
import torch
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from EDLM import EnergyDiffusionModel
from dataset import get_dataloader, decode_tokens
from config import EDLMConfig
from tqdm import tqdm

### Step 2: Initialize the model

In [None]:
# Hyperparameters shared by the tokenizer setup and the model constructor.
config = EDLMConfig()

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Only the tokenizer and the mask-token id are kept here; the first value
# returned by get_dataloader is discarded.
_, tokenizer, mask_token_id = get_dataloader(
    batch_size=1, sequence_length=config.sequence_length
)

# Constructor arguments whose keyword name equals the config attribute name.
shared_fields = (
    "hidden_size",
    "num_layers",
    "num_heads",
    "dropout",
    "num_timesteps",
    "importance_sampling_size",
    "importance_sampling_window",
    "temperature",
)
model_kwargs = {field: getattr(config, field) for field in shared_fields}

# The remaining arguments are derived from the tokenizer or are named
# differently in the config (sequence_length -> max_seq_length).
model = EnergyDiffusionModel(
    vocab_size=len(tokenizer),
    max_seq_length=config.sequence_length,
    mask_token_id=mask_token_id,
    **model_kwargs,
)
model = model.to(device)

Load the trained weights from a checkpoint:

In [None]:
# Path to the saved weights; replace with your own checkpoint path.
checkpoint_path = "checkpoints/last_model.pt"

# The loaded object is passed straight to load_state_dict, so only tensors are
# needed. weights_only=True restricts unpickling to tensors and plain
# containers, so a tampered checkpoint file cannot run arbitrary code on load.
# NOTE(review): assumes the checkpoint is a plain state_dict — confirm; a file
# that pickles custom Python objects will now raise instead of loading.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

### Step 3: Predicting masked tokens

In [None]:
# Fixed length every encoded passage is truncated/padded to.
# NOTE(review): the model was built with max_seq_length=config.sequence_length;
# presumably this value must not exceed it — confirm.
sequence_length = 128

passage = "Today I went to the store and bought some apples. Then I went home and watched TV."

_, tokenizer, mask_token_id = get_dataloader(batch_size=1, sequence_length=sequence_length)

# Truncate before padding: without this, a passage longer than sequence_length
# would get empty padding (negative list multiplier) and produce an over-long
# tensor instead of one of exactly sequence_length tokens.
token_ids = tokenizer.encode(passage)[:sequence_length]
padding = [tokenizer.pad_token_id] * (sequence_length - len(token_ids))

# Keep the list and the tensor under different names so re-running the cell
# never feeds a tensor back into the list concatenation.
sequence = torch.tensor(token_ids + padding, dtype=torch.long)
assert sequence.numel() == sequence_length, "encoded passage has the wrong length"
print(sequence)

print(decode_tokens(sequence, tokenizer))

In [None]:
# Seed the global RNG (presumably so the masking below is reproducible).
torch.manual_seed(1)

# Diffusion timestep for the single sequence in the batch, created on `device`.
t = torch.tensor([100], device=device)
sequence = sequence.to(device)

# Add a leading batch dimension, then run the model's forward (masking) step.
batch = sequence.unsqueeze(0)
masked_sequence, mask = model.forward(batch, t)

# Show the corrupted version of the only sequence in the batch.
first_masked = masked_sequence[0]
print(decode_tokens(first_masked, tokenizer))

**FIXME:** the highest-scoring predictions at the masked positions are almost always `","`, `"the"`, or `"\n"`, which points to a bug somewhere in training.

In [None]:
# This cell only runs inference, so switch the model to evaluation mode
# (modules start in training mode, which keeps dropout active and makes the
# predictions noisy) and skip autograd bookkeeping for the reverse pass.
model.eval()
with torch.no_grad():
    # Keep only the logits at the positions that were masked.
    logits = model.backward(masked_sequence, t)[mask]

# Work on a copy so the masked input stays available for comparison.
denoised_sequence = masked_sequence.clone()
print(decode_tokens(denoised_sequence[0], tokenizer))

# Greedy decoding: fill every masked position with its highest-scoring token.
denoised_sequence[mask] = torch.argmax(logits, dim=-1)
print(decode_tokens(denoised_sequence[0], tokenizer))

In [None]:
# Display the raw logits selected at the masked positions for inspection.
# NOTE(review): this shows the tensor as-is; presumably its shape is
# (num_masked_tokens, vocab_size) — confirm.
logits

### Step 4: Generating text

In [None]:
# Put the model in evaluation mode before sampling.
model.eval()

# Sampling settings: a single sequence, temperature 2.0, on the chosen device.
generation_kwargs = dict(batch_size=1, temperature=2.0, device=device)
generated_tokens = model.generate(**generation_kwargs)

# Decode and print the first sequence of the generated batch.
first_sample = generated_tokens[0]
print(decode_tokens(first_sample, tokenizer))