In [26]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch import device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)
from dataclasses import dataclass
import math


Using device: cuda


In [None]:
@dataclass
class gpt2config:
    n_vocab: int = 50257
    n_layer: int = 12
    n_embed: int = 64
    n_context: int = 1024
    n_head: int = 8
    n_timesteps: int = 1000


In [28]:
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        
        # Create a causal mask (lower triangular matrix) and register it as a buffer
        # A buffer is not a parameter, but is saved with the model state_dict
        self.register_buffer("bias", torch.tril(torch.ones(config.n_context, config.n_context))
                                     .view(1, 1, config.n_context, config.n_context))

    def forward(self, x):
        B, T, C = x.size()
        
        # Calculate query, key, values for all heads in batch
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        
        # Reshape for multi-head attention: (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
        
        # --- MASKING STARTS HERE ---
        # Apply the causal mask: fill "future" positions with -infinity
        # This makes their softmax probability zero.
        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        # --- MASKING ENDS HERE ---

        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, hs)
        
        # Re-assemble all head outputs side-by-side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        
        # Output projection
        y = self.c_proj(y)
        return y
    
class GPT2MLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embed, 4*config.n_embed)
        self.act = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4*config.n_embed, config.n_embed)

    def forward(self,x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        return x
    

class Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        
        self.ln1 = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        self.attn = GPT2Attention(config)
        self.ln2 = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        self.mlp = GPT2MLP(config)

    def forward(self,x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


In [29]:
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) #
        # TODO: Double check the ordering here
        return embeddings

In [30]:
sine_embeds = SinusoidalPositionEmbeddings(100)
time = 10
time = torch.tensor([time], device=device)
out = sine_embeds(time)
out.shape

torch.Size([1, 100])

In [31]:
class LMEmbedding(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.embed = nn.Embedding(config.n_vocab,config.n_embed)
    
    def forward(self,input_ids):
        x = self.embed(input_ids)
        
        return x
        


In [32]:
class Denoiser(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            # wte = nn.Embedding(config.n_vocab,config.n_embed),
            wpe = nn.Embedding(config.n_context,config.n_embed),
            drop = nn.Dropout(0.1,inplace=False),
            h = nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        ))
        
        # self.lm_head = nn.Linear(config.n_embed, config.n_vocab, bias=False)

        self.small_mlp = nn.Linear(config.n_embed, config.n_embed)

        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(config.n_embed),
            nn.Linear(config.n_embed, config.n_embed),
            nn.GELU()
            )

    def forward(self,input_embeddings,time_step, targets=None):
        B,T,C = input_embeddings.size()
        device = input_embeddings.device

        pos = torch.arange(0,T,dtype=torch.long,device=device).unsqueeze(0)  # (1,T)
        x = input_embeddings +  self.transformer.wpe(pos)  # (B,T,C) pytorch does braodcasting for the position embeddingss and adds them to the token embeddings 
        
        time_emb = self.time_embed(time_step) # (B, C)
        x= x + time_emb.unsqueeze(1)  # (B, T, C)
        
        x = self.transformer.drop(x)


        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)  # (B,T,C)
        # logits = self.lm_head(x)  # (B,T,vocab_size) 
        # we don't need the head since we are not doing autoregressive language modeling
        
        # we want to predict the starting sequence before the noising part.
        x = self.small_mlp(x)  # (B,T,C)
        
        return x

In [33]:
class Decoding(nn.Module):
    def __init__(self,config):
        super().__init__()
    # takes x0 (B,T,C) and give a softmax over vocab size           
        self.l1 = nn.Linear(config.n_embed, config.n_vocab, bias=False)
        
        
    def forward(self,x):
        x = self.l1(x)
        # x = F.softmax(x,dim=-1)

        return x

## Tokenizer

In [34]:
import tiktoken

# 1. Load the tokenizer for GPT-4o
tokenizer = tiktoken.get_encoding("r50k_base")
print("vocab:",tokenizer.n_vocab)
# 2. Convert text to tokens
text = "Hello, tiktoken is fast!"
tokens = tokenizer.encode(text)
print(f"Token IDs: {tokens}")
print(f"Token Count: {len(tokens)}")

# 3. Convert back to original text
decoded_text = tokenizer.decode(tokens)
print(f"Decoded: {decoded_text}")


config = gpt2config(n_vocab=tokenizer.n_vocab)
print(config)

vocab: 50257
Token IDs: [15496, 11, 256, 1134, 30001, 318, 3049, 0]
Token Count: 8
Decoded: Hello, tiktoken is fast!
gpt2config(n_vocab=50257, n_layer=8, n_embed=64, n_context=1024, n_head=8, n_timesteps=1000)


In [35]:
emb_func = LMEmbedding(config).to(device)
model = Denoiser(config).to(device)
decoder = Decoding(config).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
print(f"Embedding parameters: {sum(p.numel() for p in emb_func.parameters())/1e6:.2f}M")
print(f"Decoder parameters: {sum(p.numel() for p in decoder.parameters())/1e6:.2f}M")

Model parameters: 0.47M
Embedding parameters: 3.22M
Decoder parameters: 3.22M


In [36]:
sample_input = "Once upon a time in a land far away, there lived a"
sample_tokens = tokenizer.encode(sample_input)
sample_input_ids = torch.tensor([sample_tokens], device=device)  # (1, sequence_length)
sample_time_step = torch.tensor([10], device=device)  # (1,)

In [37]:
sample_input_ids.shape

torch.Size([1, 13])

In [38]:
sample_output = model(emb_func(sample_input_ids), sample_time_step)  # (1, sequence_length, n_embed)

def finalize_tokens(x0_final, embedding_weights):
    """
    Converts the final denoised latent into discrete token IDs.
    Args:
        x0_final: Tensor of shape (B, T, C)
        embedding_weights: Tensor of shape (Vocab, C)
    """
    # Fix: x2 must be 3D to match x1 (B, T, C)
    # Unsqueeze(0) makes it (1, Vocab, C), and PyTorch broadcasts it to (B, Vocab, C)
    distances = torch.cdist(x0_final, embedding_weights.unsqueeze(0), p=2) #(B,T,Vocab)
    # print("dist shape:", distances.shape) 
    
    # Argmin-rounding: Find the index with the minimum distance among all tokens in vocab
    # Result shape: (B, T)
    
    token_ids = torch.argmin(distances, dim=-1)
    
    return token_ids

token_ids = finalize_tokens(sample_output, emb_func.embed.weight)
decoded_output = tokenizer.decode(token_ids.squeeze(0).tolist())
print("Decoded Text:",decoded_output)



Decoded Text:  dives nauseaiotics diveseker Sci Pediatrics249 militantsGray treaties Lebaneseagan


## Forward Diffusion

In [39]:
def get_alphas(T=2000, s=1e-4):
    """
    Computes the bar_alpha (signal) schedule for Diffusion-LM[cite: 232, 483].
    s: constant determining initial noise level (standard dev = 0.1)[cite: 515].
    """
    t = torch.linspace(0, T, T + 1)
    # Sqrt schedule: alpha_bar = 1 - sqrt(t/T + s) 
    alphas = 1 - torch.sqrt(t / T )
    
    return alphas

In [40]:
def fwd_sample(x0, t, alphas):
    """
    Directly samples x_t from x_0 at a specific timestep[cite: 109, 170].
    
    Args:
        x0: Clean embeddings (B, SeqLen, EmbedDim) [cite: 126]
        t: Timesteps for the batch (B,) 
        alphas: Precomputed signal schedule from get_alphas()
    """
    # Select alpha_bar for each batch item and reshape for broadcasting
    a = alphas[t].view(-1, 1, 1).to(x0.device)
    
    # Sample Gaussian noise with same shape as x0
    noise = torch.randn_like(x0)
    
    # Formula: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise [cite: 169]
    xt = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * noise
    
    return xt

In [41]:
alphas = get_alphas().to(device)

noisy_input = fwd_sample(emb_func.embed(sample_input_ids), torch.tensor([1000], device=device), alphas)

In [42]:

token_ids = finalize_tokens(noisy_input, emb_func.embed.weight)
decoded_output = tokenizer.decode(token_ids.squeeze(0).tolist())
print("Decoded Text:",decoded_output)

Decoded Text: Once ecstasy a time wrongly a land Eden away, there lived Bucc


## Loading Datasets

In [43]:
# Load tiny shakespeare dataset
with open('datasets/ROCStories_train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f"Dataset length: {len(text)} characters")
print(f"First 100 characters:\n{text[0:100]}")

Dataset length: 18007898 characters
First 100 characters:
The boy went to a video arcade. He played his favorite machine. His games didn't go very well. He to


In [44]:
# Encode the entire dataset
data = tokenizer.encode(text)
print(f"Encoded length: {len(data)} tokens")

# Split into train and validation
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

print(f"Train tokens: {len(train_data)}, Val tokens: {len(val_data)}")

Encoded length: 4111142 tokens
Train tokens: 3700027, Val tokens: 411115


In [45]:
# Data loader function
def get_batch(split, batch_size=8, block_size=256):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    w_stack = torch.stack([torch.tensor(data[i:i+block_size]) for i in ix])
    # y = torch.stack([torch.tensor(data[i+1:i+block_size+1]) for i in ix])
    w_stack = w_stack.to(device)
    return w_stack

# Test batch
w_stack = get_batch('train')
print(f"Batch shape: {w_stack.shape}")
print(w_stack)

Batch shape: torch.Size([8, 256])
tensor([[ 5322,   663,  8971,  ..., 39517,  8848,    13],
        [  679,  1043,  5762,  ...,  3072,    13,   198],
        [  616,  3435,   717,  ...,   428,   257,  4171],
        ...,
        [ 4203, 30285,   355,  ...,  6949,   284,   787],
        [ 1097,   373,  5017,  ...,   502, 20202,  6642],
        [ 5500,   282,   318,  ...,  2968,   477,   780]], device='cuda:0')


## Training Loop


In [46]:
# Training configuration
max_iters = 10000
eval_interval = 10  # Evaluate less frequently
learning_rate = 3e-4
eval_iters = 20  # Much fewer eval iterations (was 200!)
batch_size = 16  # Larger batch for better GPU utilization
T = 100
alphas = get_alphas(T=T,s=1e-4).to(device)

In [47]:
@torch.no_grad()
def val_loss():
    pass

In [48]:
import time 

In [49]:
optimizer_model = torch.optim.AdamW(model.parameters(), lr=learning_rate)
optimizer_model_decoder = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)
optimizer_emb = torch.optim.AdamW(emb_func.parameters(), lr=learning_rate)

In [50]:
for iter in range(max_iters):
    # model.eval()
    # decoder.eval()
    # emb_func.eval()
    w = get_batch('train', batch_size)  # already on device
    w_emb = emb_func(w)  # (B, T, C)
    x0 = w_emb + 0.2 * torch.randn_like(w_emb)  # use randn_like for Gaussian noise
    
    total_loss = 0
    mu_T = fwd_sample(x0,torch.tensor([T]*batch_size,device=device),alphas)
    total_loss += torch.mean(mu_T**2)
    # for t_step in range(T + 1):
        # t_tensor = torch.tensor([t_step] * batch_size, device=device)
        
    t_tensor = torch.randint(1, T + 1, (batch_size,), device=device)
    xt = fwd_sample(x0, t_tensor, alphas)
    x0_cap = model(xt, t_tensor)
    total_loss += F.mse_loss(x0_cap, x0)
    
    # Final step loss
    t_one = torch.tensor([1] * batch_size, device=device)
    total_loss += F.mse_loss(model(fwd_sample(x0, t_one, alphas), t_one), w_emb)

    # Decoder cross-entropy loss
    logits = decoder(x0)  # (B, T, V)
    V = config.n_vocab
    logits_flat = logits.view(-1, V)  # (B*T, V)
    targets_flat = w.view(-1)  # (B*T,)
    total_loss += F.cross_entropy(logits_flat, targets_flat)
    # print(log_loss)

    optimizer_model.zero_grad(set_to_none=True)
    optimizer_model_decoder.zero_grad(set_to_none=True)
    optimizer_emb.zero_grad(set_to_none=True)
    total_loss.backward()
    optimizer_model.step()
    optimizer_model_decoder.step()
    optimizer_emb.step()
    
    if iter%100 == 0:
        print(f"Iter {iter}: Loss {total_loss.item()}")

Iter 0: Loss 14.679441452026367
Iter 100: Loss 12.138429641723633
Iter 100: Loss 12.138429641723633
Iter 200: Loss 10.456873893737793
Iter 200: Loss 10.456873893737793
Iter 300: Loss 8.942022323608398
Iter 300: Loss 8.942022323608398
Iter 400: Loss 7.252367973327637
Iter 400: Loss 7.252367973327637
Iter 500: Loss 5.751605987548828
Iter 500: Loss 5.751605987548828
Iter 600: Loss 4.887882232666016
Iter 600: Loss 4.887882232666016
Iter 700: Loss 4.176258563995361
Iter 700: Loss 4.176258563995361
Iter 800: Loss 3.919649124145508
Iter 800: Loss 3.919649124145508
Iter 900: Loss 3.6510186195373535
Iter 900: Loss 3.6510186195373535
Iter 1000: Loss 3.4616472721099854
Iter 1000: Loss 3.4616472721099854
Iter 1100: Loss 3.1809492111206055
Iter 1100: Loss 3.1809492111206055
Iter 1200: Loss 3.0598251819610596
Iter 1200: Loss 3.0598251819610596
Iter 1300: Loss 3.0905849933624268
Iter 1300: Loss 3.0905849933624268
Iter 1400: Loss 2.691953659057617
Iter 1400: Loss 2.691953659057617
Iter 1500: Loss 2.70

## Inference

In [55]:
def reverse_diffusion_with_clamping(model, emb_func, alphas, T, context_length=50, batch_size=1):
    """
    Performs reverse diffusion with clamping trick from Diffusion-LM.
    At each step, clamps the predicted x0 to nearest word embedding.
    
    Formula: x_{t-1} = sqrt(alpha_{t-1}) * Clamp(f_theta(x_t, t)) + sqrt(1 - alpha_{t-1}) * epsilon
    
    Args:
        model: Trained Denoiser model
        emb_func: Trained embedding function (for clamping to nearest word)
        alphas: Alpha_bar schedule tensor on device
        T: Number of diffusion timesteps
        context_length: Length of sequence to generate
        batch_size: Number of sequences to generate
    
    Returns:
        generated_tokens: Token IDs of generated sequences (B, T)
        generated_text: Decoded text strings
    """
    model.eval()
    emb_func.eval()
    
    # Start from pure noise: x_T ~ N(0, I)
    x_t = torch.randn(batch_size, context_length, config.n_embed, device=device)
    
    with torch.no_grad():
        # Reverse diffusion: t = T, T-1, ..., 1, 0
        for t_step in reversed(range(T + 1)):
            # print(f"Denoising step {t_step}/{T}")
            
            if t_step == 0:
                # Final step: just clamp to get x_0
                x_0 = x_t
                break
            
            # Create timestep tensor for batch
            t_tensor = torch.tensor([t_step] * batch_size, device=device)
            
            # Predict x_0 from x_t using the denoiser
            x0_pred = model(x_t, t_tensor)
            
            # CLAMPING TRICK: Map predicted x_0 to nearest word embedding
            # This forces intermediate predictions to be valid words
            x0_clamped_tokens = finalize_tokens(x0_pred, emb_func.embed.weight)
            x0_clamped = emb_func(x0_clamped_tokens)  # (B, T, C)
            
            # Compute x_{t-1} using the formula:
            # x_{t-1} = sqrt(alpha_{t-1}) * x0_clamped + sqrt(1 - alpha_{t-1}) * epsilon
            
            alpha_t_prev = alphas[t_step - 1] if t_step > 0 else alphas[0]
            
            # Sample fresh noise
            epsilon = torch.randn_like(x_t)
            
            # Update: x_{t-1} = sqrt(alpha_{t-1}) * x0_clamped + sqrt(1 - alpha_{t-1}) * epsilon
            x_t = torch.sqrt(alpha_t_prev) * x0_clamped + torch.sqrt(1 - alpha_t_prev) * epsilon
    
    # Final denoised embeddings: x_0
    x0_final = x_t
    
    # Convert to tokens using argmin rounding with learned embeddings
    generated_tokens = finalize_tokens(x0_final, emb_func.embed.weight)
    
    # Decode to text
    generated_text = []
    for i in range(batch_size):
        text = tokenizer.decode(generated_tokens[i].tolist())
        generated_text.append(text)
    
    return generated_tokens, generated_text


# Run inference
print("Starting reverse diffusion inference with clamping...")
context_length = 256
generated_tokens, generated_text = reverse_diffusion_with_clamping(
    model=model,
    emb_func=emb_func,
    alphas=alphas,
    T=T,
    context_length=context_length,
    batch_size=1
)

print("\n" + "="*50)
print("GENERATED TEXT:")
print("="*50)
print(generated_text[0])
print("="*50)

Starting reverse diffusion inference with clamping...

GENERATED TEXT:
 home was asked a too his to At. the't day too go a, the home a a started so on decided a do to something the go to get been always store. She were all out to my of were butI the get told that friends the back took day house day for However the when. made like the by's so and.My. told. but were. the. The a day would week. was to, a people an found was money. an. do lot she She's one thought of I decided store the He So at
 got. the with He Now found for. told his. school school to in got. him more them she he themI I with. She to The's It could with She. happy play didn that water broke, decided. made When told
 the so into his way of a. everyone dog made were of When She a birthday she Finally was decided a many store for, of. house
 few asked The I she out gave them what home did She to gave
 Her around I would One needed.'s When of it like When. excited wanted house him in wanted. his took He. off did. no she wat

In [52]:
embs = emb_func(torch.tensor([tokenizer.encode("!")], device=device))

In [53]:
embs

tensor([[[-0.2571, -0.0819,  1.0992, -0.5341, -0.3827, -0.3746,  0.7433,
           0.8786,  0.9243,  0.8089, -0.4689, -0.2725, -0.3584,  0.5885,
          -0.3957, -0.1734, -0.2059, -0.2432,  0.2523,  1.1037, -0.3995,
           0.5379, -0.6829,  0.4967,  0.0608,  0.0281,  0.4738, -0.0913,
          -0.8046, -0.0025, -0.9382, -0.5513,  0.2202, -0.1311, -0.0457,
           0.6176, -0.7737,  0.1629,  0.4038,  0.0137,  0.6967, -0.6890,
          -0.2077, -0.2974,  0.0478,  0.2556,  0.6905, -0.0358,  0.4670,
           0.6208,  0.5352, -0.8628, -0.3029,  0.0031, -0.7207, -0.7787,
           0.6197, -0.0059,  0.5145, -0.4796,  0.1727,  0.0291, -0.6105,
          -0.5873]]], device='cuda:0', grad_fn=<EmbeddingBackward0>)