In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch import device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)
from dataclasses import dataclass
import math


Using device: cuda


In [2]:
@dataclass
class gpt2config:
    n_vocab: int = 50257
    n_layer: int = 8
    n_embed: int = 64
    n_context: int = 1024
    n_head: int = 8
    n_timesteps: int = 1000


In [3]:
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        
        # Create a causal mask (lower triangular matrix) and register it as a buffer
        # A buffer is not a parameter, but is saved with the model state_dict
        self.register_buffer("bias", torch.tril(torch.ones(config.n_context, config.n_context))
                                     .view(1, 1, config.n_context, config.n_context))

    def forward(self, x):
        B, T, C = x.size()
        
        # Calculate query, key, values for all heads in batch
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        
        # Reshape for multi-head attention: (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
        
        # --- MASKING STARTS HERE ---
        # Apply the causal mask: fill "future" positions with -infinity
        # This makes their softmax probability zero.
        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        # --- MASKING ENDS HERE ---

        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, hs)
        
        # Re-assemble all head outputs side-by-side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        
        # Output projection
        y = self.c_proj(y)
        return y
    
class GPT2MLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embed, 4*config.n_embed)
        self.act = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4*config.n_embed, config.n_embed)

    def forward(self,x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        return x
    

class Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        
        self.ln1 = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        self.attn = GPT2Attention(config)
        self.ln2 = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        self.mlp = GPT2MLP(config)

    def forward(self,x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


In [4]:
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) #
        # TODO: Double check the ordering here
        return embeddings

In [5]:
sine_embeds = SinusoidalPositionEmbeddings(100)
time = 10
time = torch.tensor([time], device=device)
out = sine_embeds(time)
out.shape

torch.Size([1, 100])

In [6]:
class LMEmbedding(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.embed = nn.Embedding(config.n_vocab,config.n_embed)
    
    def forward(self,input_ids):
        x = self.embed(input_ids)
        
        return x
        


In [7]:
class Denoiser(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            # wte = nn.Embedding(config.n_vocab,config.n_embed),
            wpe = nn.Embedding(config.n_context,config.n_embed),
            drop = nn.Dropout(0.1,inplace=False),
            h = nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f = nn.LayerNorm(config.n_embed,eps=1e-5,elementwise_affine=True)
        ))
        
        # self.lm_head = nn.Linear(config.n_embed, config.n_vocab, bias=False)

        self.small_mlp = nn.Linear(config.n_embed, config.n_embed)

        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(config.n_embed),
            nn.Linear(config.n_embed, config.n_embed),
            nn.GELU()
            )

    def forward(self,input_embeddings,time_step, targets=None):
        B,T,C = input_embeddings.size()
        device = input_embeddings.device

        pos = torch.arange(0,T,dtype=torch.long,device=device).unsqueeze(0)  # (1,T)
        x = input_embeddings +  self.transformer.wpe(pos)  # (B,T,C) pytorch does braodcasting for the position embeddingss and adds them to the token embeddings 
        
        time_emb = self.time_embed(time_step) # (B, C)
        x= x + time_emb.unsqueeze(1)  # (B, T, C)
        
        x = self.transformer.drop(x)


        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)  # (B,T,C)
        # logits = self.lm_head(x)  # (B,T,vocab_size) 
        # we don't need the head since we are not doing autoregressive language modeling
        
        # we want to predict the starting sequence before the noising part.
        x = self.small_mlp(x)  # (B,T,C)
        
        return x

In [8]:
class Decoding(nn.Module):
    def __init__(self,config):
        super().__init__()
    # takes x0 (B,T,C) and give a softmax over vocab size           
        self.l1 = nn.Linear(config.n_embed, config.n_vocab, bias=False)
        
        
    def forward(self,x):
        x = self.l1(x)
        # x = F.softmax(x,dim=-1)

        return x

## Tokenizer

In [9]:
import tiktoken

# 1. Load the tokenizer for GPT-4o
tokenizer = tiktoken.get_encoding("r50k_base")
print("vocab:",tokenizer.n_vocab)
# 2. Convert text to tokens
text = "Hello, tiktoken is fast!"
tokens = tokenizer.encode(text)
print(f"Token IDs: {tokens}")
print(f"Token Count: {len(tokens)}")

# 3. Convert back to original text
decoded_text = tokenizer.decode(tokens)
print(f"Decoded: {decoded_text}")


config = gpt2config(n_vocab=tokenizer.n_vocab)
print(config)

vocab: 50257
Token IDs: [15496, 11, 256, 1134, 30001, 318, 3049, 0]
Token Count: 8
Decoded: Hello, tiktoken is fast!
gpt2config(n_vocab=50257, n_layer=8, n_embed=64, n_context=1024, n_head=8, n_timesteps=1000)


In [10]:
emb_func = LMEmbedding(config).to(device)
model = Denoiser(config).to(device)
decoder = Decoding(config).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
print(f"Embedding parameters: {sum(p.numel() for p in emb_func.parameters())/1e6:.2f}M")
print(f"Decoder parameters: {sum(p.numel() for p in decoder.parameters())/1e6:.2f}M")

Model parameters: 0.47M
Embedding parameters: 3.22M
Decoder parameters: 3.22M


In [11]:
sample_input = "Once upon a time in a land far away, there lived a"
sample_tokens = tokenizer.encode(sample_input)
sample_input_ids = torch.tensor([sample_tokens], device=device)  # (1, sequence_length)
sample_time_step = torch.tensor([10], device=device)  # (1,)

In [12]:
sample_input_ids.shape

torch.Size([1, 13])

In [13]:
sample_output = model(emb_func(sample_input_ids), sample_time_step)  # (1, sequence_length, n_embed)

def finalize_tokens(x0_final, embedding_weights):
    """
    Converts the final denoised latent into discrete token IDs.
    Args:
        x0_final: Tensor of shape (B, T, C)
        embedding_weights: Tensor of shape (Vocab, C)
    """
    # Fix: x2 must be 3D to match x1 (B, T, C)
    # Unsqueeze(0) makes it (1, Vocab, C), and PyTorch broadcasts it to (B, Vocab, C)
    distances = torch.cdist(x0_final, embedding_weights.unsqueeze(0), p=2) #(B,T,Vocab)
    # print("dist shape:", distances.shape) 
    
    # Argmin-rounding: Find the index with the minimum distance among all tokens in vocab
    # Result shape: (B, T)
    
    token_ids = torch.argmin(distances, dim=-1)
    
    return token_ids

token_ids = finalize_tokens(sample_output, emb_func.embed.weight)
decoded_output = tokenizer.decode(token_ids.squeeze(0).tolist())
print("Decoded Text:",decoded_output)



Decoded Text:  covetedalloStack Fiscal gifts Roku fading Walletforcer3 coveted juggling fullest


## Forward Diffusion

In [14]:
def get_alphas(T=2000, s=1e-4):
    """
    Computes the bar_alpha (signal) schedule for Diffusion-LM[cite: 232, 483].
    s: constant determining initial noise level (standard dev = 0.1)[cite: 515].
    """
    t = torch.linspace(0, T, T + 1)
    # Sqrt schedule: alpha_bar = 1 - sqrt(t/T + s) 
    alphas = 1 - torch.sqrt(t / T )
    
    return alphas

In [15]:
def fwd_sample(x0, t, alphas):
    """
    Directly samples x_t from x_0 at a specific timestep[cite: 109, 170].
    
    Args:
        x0: Clean embeddings (B, SeqLen, EmbedDim) [cite: 126]
        t: Timesteps for the batch (B,) 
        alphas: Precomputed signal schedule from get_alphas()
    """
    # Select alpha_bar for each batch item and reshape for broadcasting
    a = alphas[t].view(-1, 1, 1).to(x0.device)
    
    # Sample Gaussian noise with same shape as x0
    noise = torch.randn_like(x0)
    
    # Formula: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise [cite: 169]
    xt = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * noise
    
    return xt

In [16]:
alphas = get_alphas().to(device)

noisy_input = fwd_sample(emb_func.embed(sample_input_ids), torch.tensor([1000], device=device), alphas)

In [17]:

token_ids = finalize_tokens(noisy_input, emb_func.embed.weight)
decoded_output = tokenizer.decode(token_ids.squeeze(0).tolist())
print("Decoded Text:",decoded_output)

Decoded Text: Once upon a justifies Tory a land far away, there commissioner a


## Loading Datasets

In [18]:
# Load tiny shakespeare dataset
with open('ROCStories_train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f"Dataset length: {len(text)} characters")
print(f"First 100 characters:\n{text[0:100]}")

Dataset length: 18007898 characters
First 100 characters:
The boy went to a video arcade. He played his favorite machine. His games didn't go very well. He to


In [19]:
# Encode the entire dataset
data = tokenizer.encode(text)
print(f"Encoded length: {len(data)} tokens")

# Split into train and validation
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

print(f"Train tokens: {len(train_data)}, Val tokens: {len(val_data)}")

Encoded length: 4111142 tokens
Train tokens: 3700027, Val tokens: 411115


In [20]:
# Data loader function
def get_batch(split, batch_size=8, block_size=256):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    w_stack = torch.stack([torch.tensor(data[i:i+block_size]) for i in ix])
    # y = torch.stack([torch.tensor(data[i+1:i+block_size+1]) for i in ix])
    w_stack = w_stack.to(device)
    return w_stack

# Test batch
w_stack = get_batch('train')
print(f"Batch shape: {w_stack.shape}")
print(w_stack)

Batch shape: torch.Size([8, 256])
tensor([[ 1332,    13,   679,  ...,   750,   407,   423],
        [   13,  1881,  1110,  ...,   339,   714,    13],
        [ 4094,  5935,   550,  ...,  1816,   736,   284],
        ...,
        [  284,   257,   649,  ...,  4876,    13,  8616],
        [  465, 23916,    13,  ...,    13,  1375,  1234],
        [ 1965,   683,   284,  ..., 35903,  1422,   470]], device='cuda:0')


## Training Loop


In [21]:
# Training configuration
max_iters = 1000
eval_interval = 10  # Evaluate less frequently
learning_rate = 3e-4
eval_iters = 20  # Much fewer eval iterations (was 200!)
batch_size = 8  # Larger batch for better GPU utilization
T = 20


In [22]:
@torch.no_grad()
def val_loss():
    pass

In [23]:
import time 

In [24]:
optimizer_model = torch.optim.AdamW(model.parameters(), lr=learning_rate)
optimizer_model_decoder = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)
optimizer_emb = torch.optim.AdamW(emb_func.parameters(), lr=learning_rate)

In [None]:


for iter in range(max_iters):
    # model.eval()
    # decoder.eval()
    # emb_func.eval()
    w = get_batch('train', batch_size)  # already on device
    w_emb = emb_func(w)  # (B, T, C)
    x0 = w_emb + 1.2 * torch.randn_like(w_emb)  # use randn_like for Gaussian noise
    
    total_loss = 0
    for t_step in range(T + 1):
        t_tensor = torch.tensor([t_step] * batch_size, device=device)
        xt = fwd_sample(x0, t_tensor, alphas)
        x0_cap = model(xt, t_tensor)
        total_loss += F.mse_loss(x0_cap, x0)
    
    # Final step loss
    t_one = torch.tensor([1] * batch_size, device=device)
    total_loss += F.mse_loss(model(fwd_sample(x0, t_one, alphas), t_one), w_emb)

    # Decoder cross-entropy loss
    logits = decoder(x0)  # (B, T, V)
    V = config.n_vocab
    logits_flat = logits.view(-1, V)  # (B*T, V)
    targets_flat = w.view(-1)  # (B*T,)
    total_loss += F.cross_entropy(logits_flat, targets_flat)

    optimizer_model.zero_grad(set_to_none=True)
    optimizer_model_decoder.zero_grad(set_to_none=True)
    optimizer_emb.zero_grad(set_to_none=True)
    total_loss.backward()
    optimizer_model.step()
    optimizer_model_decoder.step()
    optimizer_emb.step()
    
    print(f"Iter {iter}: Loss {total_loss.item() / (T + 2)}")

Iter 0: Loss 3.2485892555930396
Iter 1: Loss 3.2029602744362573
Iter 1: Loss 3.2029602744362573
Iter 2: Loss 3.1829147338867188
Iter 2: Loss 3.1829147338867188
Iter 3: Loss 3.1562409834428267
Iter 3: Loss 3.1562409834428267
Iter 4: Loss 3.103327664462003
Iter 4: Loss 3.103327664462003
Iter 5: Loss 3.049136768687855
Iter 5: Loss 3.049136768687855
Iter 6: Loss 3.0433436307040127
Iter 6: Loss 3.0433436307040127
Iter 7: Loss 3.0179082697088067
Iter 7: Loss 3.0179082697088067
Iter 8: Loss 3.004383087158203
Iter 8: Loss 3.004383087158203
Iter 9: Loss 2.9724877097389917
Iter 9: Loss 2.9724877097389917
Iter 10: Loss 2.954865889115767
Iter 10: Loss 2.954865889115767
Iter 11: Loss 2.936160694469105
Iter 11: Loss 2.936160694469105
Iter 12: Loss 2.919105876575817
Iter 12: Loss 2.919105876575817
Iter 13: Loss 2.8960399627685547
Iter 13: Loss 2.8960399627685547
Iter 14: Loss 2.8674484599720347
Iter 14: Loss 2.8674484599720347
Iter 15: Loss 2.845590764825994
Iter 15: Loss 2.845590764825994
Iter 16: L

KeyboardInterrupt: 