In [78]:
# import sys
# sys.path.append('/media/linux-stuff/gpt2-diff/scripts')

import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch import device
from torch.optim.lr_scheduler import LambdaLR

import math
import os
import pandas as pd

from scripts.config import gpt2config
from scripts.model import DiffusionLM, LMEmbedding, Denoiser, Decoding
from scripts.utils import (
    MyTokenizer, 
    get_next_log_filename, 
    save_checkpoint, 
    load_checkpoint,
    posterior_mean,
    rounding_weight,
    get_batch,
    finalize_tokens,
    reverse_diffusion_with_clamping,
    visualize_embeddings_2d,
    fwd_diffusion
)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE: this rebinds `device`, shadowing the `torch.device` name imported above.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device:", device)

# Per the PyTorch docs, 'high' lets float32 matmuls use faster reduced-internal-precision kernels (e.g. TF32).
torch.set_float32_matmul_precision('high')

Using device: cuda


In [79]:
# Round-trip one sentence to eyeball the tokenizer's special tokens and padding.
tokenizer = MyTokenizer(max_len=13)
sample_sentence = "Hello, tiktoken is fast!"
tokenizer.decode(tokenizer.encode(sample_sentence))

'<bos>Hello, tiktoken is fast!<eos><pad><pad><pad>'

In [80]:
# Build the diffusion LM. n_latent is presumably the denoiser's hidden width (confirm in
# scripts.config); NOTE(review): n_embed=128 is not divisible by n_head=12 — fine only if
# attention runs at n_latent width.
config = gpt2config(
    n_vocab=tokenizer.n_vocab,
    n_embed=128,
    n_head=12,
    mlp_expansion=4,
    n_latent=768,
)
model = DiffusionLM(config).to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Total Model parameters: {n_params/1e6:.2f}M")
print(config)

Total Model parameters: 100.46M
gpt2config(n_vocab=50260, n_layer=12, n_embed=128, n_context=1024, n_head=12, n_timesteps=1000, mlp_expansion=4, n_latent=768)


In [81]:
# Load the ROCStories training CSV; each row's 'ref' column holds one story.
# (pandas is already imported as `pd` in the first cell.)
df = pd.read_csv('datasets/ROCStories/rocstories_train.csv')
# All stories joined with spaces; used in this cell to report the corpus size.
text = ' '.join(df['ref'].tolist())

print(f"Dataset length: {len(text)} characters")
print(f"Number of samples: {len(df)}")
print(f"First sample: {df['ref'].iloc[0]}")

Dataset length: 18007897 characters
Number of samples: 78528
First sample: The boy went to a video arcade. He played his favorite machine. His games didn't go very well. He told the owner about his experience. The owner explained that he had made the game settings harder.


In [82]:
# Split into train and test: first 90% of rows for training, last 10% for test (no shuffle).
m, n = (32, 32)  # prefix (conditioning) length, suffix (noised / generated) length
sequence_length = m + n  # every encoded story is padded/truncated to this many tokens
train_size = int(0.9 * len(df))
train_df = df[:train_size].reset_index(drop=True)
test_df = df[train_size:].reset_index(drop=True)

print(f"Train samples: {len(train_df)}, Test samples: {len(test_df)}")

# Pre-encode all sequences once so batching during training is just tensor indexing.
print("\nEncoding training data...")
train_encoded = []
for idx, story in enumerate(train_df['ref']):
    train_encoded.append(tokenizer.encode(story, max_len=sequence_length))
    if (idx + 1) % 5000 == 0:
        print(f"Encoded {idx + 1}/{len(train_df)} train samples")

print("\nEncoding test data...")
# Same fixed length as training (previously hard-coded to 64 independently of m + n).
test_encoded = [tokenizer.encode(story, max_len=sequence_length) for story in test_df['ref']]

# Convert to tensors
train_encoded = torch.tensor(train_encoded, dtype=torch.long)
test_encoded = torch.tensor(test_encoded, dtype=torch.long)

# Both splits must share the sequence length the training loop slices with m / n.
assert train_encoded.shape[1] == test_encoded.shape[1] == sequence_length

print(f"\nTrain encoded shape: {train_encoded.shape}")
print(f"Test encoded shape: {test_encoded.shape}")

Train samples: 70675, Test samples: 7853

Encoding training data...
Encoded 5000/70675 train samples
Encoded 10000/70675 train samples
Encoded 15000/70675 train samples
Encoded 20000/70675 train samples
Encoded 25000/70675 train samples
Encoded 30000/70675 train samples
Encoded 35000/70675 train samples
Encoded 40000/70675 train samples
Encoded 45000/70675 train samples
Encoded 50000/70675 train samples
Encoded 55000/70675 train samples
Encoded 60000/70675 train samples
Encoded 65000/70675 train samples
Encoded 70000/70675 train samples

Encoding test data...

Train encoded shape: torch.Size([70675, 64])
Test encoded shape: torch.Size([7853, 64])


In [83]:
# Training configuration
max_iters = 100000  # total optimizer steps; the LR decays linearly to 0 over these
learning_rate = 3e-3
eval_iters = 100  # print progress and importance-sampling stats every `eval_iters` iterations
batch_size = 16
T = 1000  # number of diffusion timesteps (the model config reports n_timesteps=1000)
num_timestep_samples = 4  # timesteps sampled per iteration; their denoising losses are averaged
m, n = (32, 32)  # prefix / suffix lengths — must match the values used when encoding the data
sequence_length = m + n



In [84]:
# Fixed noise schedule: alpha_bar(t) = 1 - sqrt(t / T), for t = 0..T.
t = torch.arange(0, T+1, device=device, dtype=torch.float32)
alpha_bars = 1 - torch.sqrt(t / T)  # decreases from 1 at t=0 to 0 at t=T
# Clamp into [0.001, 0.999] so no timestep is exactly noise-free (1) or pure noise (0).
alpha_bars = torch.clamp(alpha_bars, min=0.001, max=0.999)
# Per-step alphas: alpha_t = alpha_bar_t / alpha_bar_{t-1}, with alpha_0 = alpha_bar_0.
alphas = torch.zeros(T+1, device=device) #alpha_0 to alpha_T
alphas[0] = alpha_bars[0]
alphas[1:] = alpha_bars[1:] / alpha_bars[:-1]
# NOTE(review): clamping alphas (e.g. ratios of 1.0 in the clamped tail become 0.999)
# means cumprod(alphas) can differ slightly from alpha_bars; anything that uses both
# (e.g. posterior_mean) sees a slightly inconsistent schedule.
alphas = torch.clamp(alphas, min=0.001, max=0.999)

# Precompute sqrt terms used by the forward process x_t = sqrt_ab * x_0 + sqrt_1mab * eps.
sqrt_ab = torch.sqrt(alpha_bars)
sqrt_1mab = torch.sqrt(1 - alpha_bars)

print(f"Alpha bars range: [{alpha_bars.min():.4f}, {alpha_bars.max():.4f}]")
print(f"Alphas range: [{alphas.min():.4f}, {alphas.max():.4f}]")


Alpha bars range: [0.0010, 0.9990]
Alphas range: [0.6665, 0.9990]


In [85]:
# AdamW with no weight decay; LambdaLR scales the base learning rate by lr_lambda(step).
optimizer_model = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.0,
)


def lr_lambda(step):
    """LR multiplier: 1.0 at step 0, decreasing linearly to 0.0 at step == max_iters."""
    return 1.0 - step / float(max_iters)


scheduler_model = LambdaLR(optimizer_model, lr_lambda=lr_lambda)

In [86]:
log_file = get_next_log_filename('logs')
print(f"Logging to: {log_file}")

# The log starts with a title line followed by a CSV header. The columns match the row the
# training loop appends each iteration: iteration, total, denoising and anchor loss.
with open(log_file, 'w') as f:
    f.write("Training the SEQ2SEQ Diffusion Language Model\n")
    f.write("Iteration,Total_Loss,Denoising_Loss,Anchor_Loss\n")

checkpoint_counter = 0

# Importance sampling setup
loss_buffer_size = 10
# Ring buffer of recent denoising losses per timestep, rows 0..T: shape (T + 1, buffer_size).
# Initialized with ones, matching the 1.0 default RMS given to unvisited timesteps.
loss_buffer = torch.ones((T + 1, loss_buffer_size), device=device)
# Number of losses recorded per timestep; modulo buffer_size it is the next write slot.
buffer_counts = torch.zeros(T + 1, device=device, dtype=torch.long)

# Iterations over which timestep sampling blends linearly from uniform to importance-based.
importance_warmup_iters = 5000

def get_importance_probs(loss_buffer, buffer_counts, iteration, warmup_iters, T):
    """Compute timestep sampling probabilities from the RMS of recent losses.

    Args:
        loss_buffer: tensor of shape (T + 1, buffer_size); row t holds recent losses for
            timestep t, filled ring-buffer style (row 0 is ignored).
        buffer_counts: integer tensor of shape (T + 1,); number of losses ever written
            to each row. Only the first min(count, buffer_size) slots of a row are used.
        iteration: current training iteration (drives the uniform -> importance blend).
        warmup_iters: iterations over which the blend goes linearly from uniform (0) to
            fully importance-based (1). Values <= 0 mean "importance-based immediately".
        T: number of timesteps; probabilities are returned for t = 1..T.

    Returns:
        Tensor of shape (T,) on loss_buffer's device, where entry i is the probability of
        timestep i + 1. Timesteps never visited get an RMS of 1.0.
    """
    device = loss_buffer.device
    buffer_size = loss_buffer.shape[1]  # use the buffer itself, not a notebook global

    # Vectorized RMS over the valid prefix of each row (no per-timestep .item() syncs).
    rows = loss_buffer[1:T + 1]                                       # (T, buffer_size)
    counts = buffer_counts[1:T + 1].to(device).clamp(max=buffer_size)  # (T,)
    slot_ids = torch.arange(buffer_size, device=device)
    valid = slot_ids.unsqueeze(0) < counts.unsqueeze(1)               # (T, buffer_size)
    squared = torch.where(valid, rows ** 2, torch.zeros_like(rows))
    mean_sq = squared.sum(dim=1) / counts.clamp(min=1)
    rms_losses = torch.where(counts > 0, torch.sqrt(mean_sq), torch.ones_like(mean_sq))

    # Normalize RMS values into an importance distribution over t in [1, T].
    importance_probs = rms_losses / (rms_losses.sum() + 1e-8)

    # Uniform distribution
    uniform_probs = torch.ones(T, device=device) / T

    # Blend: gradually shift from uniform to importance-based (guard against warmup_iters <= 0).
    if warmup_iters <= 0:
        blend_factor = 1.0
    else:
        blend_factor = min(1.0, iteration / warmup_iters)
    final_probs = (1 - blend_factor) * uniform_probs + blend_factor * importance_probs

    # Ensure valid probability distribution
    return final_probs / (final_probs.sum() + 1e-8)

# Main training loop. Each iteration embeds a batch, noises only the suffix at a few
# importance-sampled timesteps, and optimizes denoising + anchor + embedding-norm losses.
# The CSV log is opened once in append mode (instead of re-opened every iteration); each row
# is "{it},{total},{denoising},{anchor}".
with open(log_file, 'a') as log_f:
    for it in range(0, max_iters):

        w = get_batch('train', batch_size, sequence_length, train_encoded=train_encoded, test_encoded=test_encoded, device=device)
        w_emb = model.embedding(w)

        # Split embeddings: first m tokens (prefix) and last n tokens (suffix to be noised)
        w_emb_prefix = w_emb[:, :m, :]  # (batch, m, embed_dim) - stays unchanged
        w_emb_suffix = w_emb[:, m:, :]  # (batch, n, embed_dim) - will be noised

        # Only the suffix gets the small initial perturbation; this perturbed suffix is the
        # denoising target below.
        x0_suffix = w_emb_suffix + 0.1 * torch.randn_like(w_emb_suffix)
        total_loss = 0.0

        # One Gaussian noise draw for the suffix, shared by all sampled timesteps.
        eps = torch.randn_like(x0_suffix)
        denoising_loss = 0.0

        # Importance-based sampling probabilities over t in [1, T]
        sampling_probs = get_importance_probs(loss_buffer, buffer_counts, it, importance_warmup_iters, T)

        # Sample timesteps; all batch items share the same timestep for a given sample.
        # NOTE(review): the per-timestep losses are not reweighted by 1 / p(t), so once the
        # blend leaves uniform this shifts the effective objective toward high-loss timesteps
        # rather than only reducing variance — confirm this is intended.
        sampled_timesteps = torch.multinomial(sampling_probs, num_timestep_samples, replacement=True) + 1  # +1 because probs are for t in [1,T]

        for t_sample in sampled_timesteps:
            t_random = t_sample.expand(batch_size)  # Same timestep for all batch items
            sqrt_ab_t = sqrt_ab[t_random].view(batch_size, 1, 1)
            sqrt_1mab_t = sqrt_1mab[t_random].view(batch_size, 1, 1)

            # Forward diffusion only on the suffix, then re-attach the clean prefix.
            xt_suffix = sqrt_ab_t * x0_suffix + sqrt_1mab_t * eps
            xt = torch.cat([w_emb_prefix, xt_suffix], dim=1)
            x0_hat = model.denoiser(xt, t_random)
            # Only compute denoising loss on suffix part
            x0_hat_suffix = x0_hat[:, m:, :]

            timestep_loss = F.mse_loss(x0_hat_suffix, x0_suffix)
            denoising_loss += timestep_loss

            # Record this loss in the timestep's ring buffer (feeds importance sampling).
            t_val = t_sample.item()
            buffer_idx = int(buffer_counts[t_val].item()) % loss_buffer.shape[1]
            loss_buffer[t_val, buffer_idx] = timestep_loss.detach()
            buffer_counts[t_val] += 1

        x0_target = torch.cat([w_emb_prefix, x0_suffix], dim=1)
        denoising_loss = denoising_loss / num_timestep_samples
        total_loss += denoising_loss

        # Anchor loss at t=1: noise only the suffix, with Gaussian noise (torch.randn_like) to
        # match the forward process used above — uniform torch.rand_like noise would not.
        xt_1_suffix = sqrt_ab[1] * x0_suffix + sqrt_1mab[1] * torch.randn_like(x0_suffix)
        xt_1 = torch.cat([w_emb_prefix, xt_1_suffix], dim=1)
        # Integer timesteps, same dtype as the multinomial-sampled timesteps above.
        t_one = torch.ones(batch_size, device=device, dtype=torch.long)
        x0_hat_1 = model.denoiser(xt_1, t_one)
        # Anchor loss on full sequence (prefix should reconstruct prefix, suffix should reconstruct suffix)
        anchor_loss = F.mse_loss(x0_hat_1, w_emb)
        total_loss += anchor_loss

        # Penalize the mean squared magnitude of the target embeddings.
        reg_loss = torch.mean(x0_target**2)
        total_loss += reg_loss

        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print(f"\n{'='*70}")
            print(f"TRAINING STOPPED: NaN/Inf detected at iteration {it}")
            print(f"{'='*70}")
            print(f"Loss Diagnostics:")
            print(f"  Total Loss:     {total_loss.item() if not torch.isnan(total_loss) else 'NaN'}")
            print(f"  Denoising:      {denoising_loss.item()}")
            print(f"  Anchor:         {anchor_loss.item()}")
            print(f"  Regularizing_loss:       {reg_loss.item()}")
            print(f"\nModel Output Statistics:")
            print(f"  x0_hat range:   [{x0_hat.min().item():.2f}, {x0_hat.max().item():.2f}]")
            # Gradients here are still those from the previous iteration's backward pass.
            print(f"\nGradient Statistics:")
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** 0.5
            print(f"  Total grad norm: {total_norm:.4f}")
            print(f"{'='*70}\n")
            break

        optimizer_model.zero_grad(set_to_none=True)
        total_loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer_model.step()
        scheduler_model.step()

        log_f.write(f"{it},{total_loss.item():.6f},{denoising_loss.item():.6f},{anchor_loss.item():.6f}\n")

        if it % eval_iters == 0:
            log_f.flush()  # keep the on-disk log current at each progress report
            # Show sampling distribution stats
            probs = get_importance_probs(loss_buffer, buffer_counts, it, importance_warmup_iters, T)
            top_probs, top_t = torch.topk(probs, 5)
            print(f"Iter {it}: loss = {total_loss.item():.4f}, denoising = {denoising_loss.item():.4f}, anchor = {anchor_loss.item():.4f}")
            print(f"  Top-5 sampled timesteps: {(top_t + 1).tolist()} with probs: {[f'{p:.4f}' for p in top_probs.tolist()]}")

        # if it % 5000 == 0 and it > 0:
        #     checkpoint_name = f"training_ckpt_{checkpoint_counter % 2}"
        #     save_checkpoint(model, config, alpha_bars, T, checkpoint_name, save_individual=False)
        #     checkpoint_counter += 1

print(f"\nTraining complete! Logs saved to: {log_file}")

Logging to: logs/log_22.txt
Iter 0: loss = 3.7886, denoising = 1.4091, anchor = 1.3669
  Top-5 sampled timesteps: [2, 1, 3, 5, 4] with probs: ['0.0010', '0.0010', '0.0010', '0.0010', '0.0010']
Iter 100: loss = 1.8429, denoising = 0.4812, anchor = 0.5981
  Top-5 sampled timesteps: [958, 880, 851, 146, 396] with probs: ['0.0010', '0.0010', '0.0010', '0.0010', '0.0010']


KeyboardInterrupt: 

In [None]:
def rev_s2s_diffusion(model, config, tokenizer, input_text, alpha_bars, T, m=32, n=32,
                      batch_size=1, clamping_start=0.4, skip_step=1, display_at_steps=None, device='cuda'):
    """
    Reverse diffusion for seq2seq: condition on the first m tokens of `input_text`,
    generate the following n tokens.

    At every visited timestep the denoiser predicts the clean sequence from the current
    state. Below `clamping_start * T` that prediction is passed through `finalize_tokens`
    and re-embedded before being re-noised to the next visited timestep. The clean prefix
    is re-inserted after every step, so only the suffix evolves.

    Args:
        model: model exposing `.embedding` (with `.embed.weight`) and `.denoiser`. It is put
            in eval mode for the call and returned to train mode afterwards (even on error).
        config: model config; only `config.n_embed` is read.
        tokenizer: provides `encode(text, raw=True)`, `decode(ids)` and `clean_text(text)`.
        input_text: conditioning text; its first m tokens form the prefix.
        alpha_bars: cumulative schedule indexed 0..T.
        T: starting timestep; timesteps T, T - skip_step, ... down to >= 1 are visited.
        m: maximum number of prefix tokens (conditioning). If the input encodes to fewer
            tokens, the actual shorter prefix length is used.
        n: number of suffix tokens (to generate).
        batch_size: number of independent samples for the same prefix.
        clamping_start: fraction of T below which predictions are clamped to tokens.
        skip_step: stride between visited timesteps; each step re-noises the prediction to
            the next visited timestep.
        display_at_steps: optional collection of timesteps at which to print the decode.
        device: device for created tensors.

    Returns:
        (z_t, final_tokens): the final latent of shape (batch_size, prefix_len + n, n_embed)
        and the token ids `finalize_tokens` produces from it.
    """
    model.eval()
    try:
        # Inference only: no autograd graph, including for the prefix embedding.
        with torch.no_grad():
            # Encode input and get prefix embeddings
            input_tokens = tokenizer.encode(input_text, raw=True)
            x_prefix = model.embedding(torch.tensor(input_tokens, device=device).unsqueeze(0))

            # Take (at most) the first m tokens as the conditioning prefix.
            x_prefix = x_prefix[:, :m, :]  # (1, prefix_len, n_embed)
            # Split on the actual prefix length; for inputs shorter than m, slicing at m
            # would cut into the generated suffix.
            # NOTE(review): a prefix shorter than m also shortens the sequence relative to
            # the m + n length seen in training.
            prefix_len = x_prefix.shape[1]
            x_prefix = x_prefix.repeat(batch_size, 1, 1)  # (batch_size, prefix_len, n_embed)

            # Initialize suffix with random noise (this is what we'll denoise)
            y_t = torch.randn(batch_size, n, config.n_embed, device=device)

            # Concatenate: [prefix (clean) | suffix (noisy)]
            z_t = torch.cat([x_prefix, y_t], dim=1)

            # Show initial random state
            initial_tokens = finalize_tokens(z_t, model.embedding.embed.weight)
            initial_text = tokenizer.decode(initial_tokens[0].tolist())
            initial_text_clean = tokenizer.clean_text(initial_text)
            print(f"Initial (noisy): {initial_text_clean}")
            print(f"{'-'*70}\n")

            sqrt_ab = torch.sqrt(alpha_bars)
            sqrt_1mab = torch.sqrt(1 - alpha_bars)

            for t_step in range(T, 0, -skip_step):
                t_tensor = torch.tensor([t_step] * batch_size, device=device)

                # Predict x0 from current z_t
                z0_hat = model.denoiser(z_t, t_tensor)

                if t_step < clamping_start * T:
                    z0_clamped = finalize_tokens(z0_hat, model.embedding.embed.weight)
                    z0_clamped = model.embedding(z0_clamped)
                else:
                    z0_clamped = z0_hat

                epsilon = torch.randn_like(z_t)

                # Re-noise to the timestep the next iteration actually evaluates
                # (t_step - skip_step, not always t_step - 1); at or below 0 the clamped
                # prediction is the final state. Identical to t_step - 1 when skip_step == 1.
                t_next = t_step - skip_step
                if t_next > 0:
                    z_t = sqrt_ab[t_next] * z0_clamped + sqrt_1mab[t_next] * epsilon
                else:
                    z_t = z0_clamped

                # Only update the suffix part (last n tokens)
                z_t = torch.cat([x_prefix, z_t[:, prefix_len:, :]], dim=1)

                # Display at specified steps
                if display_at_steps and t_step in display_at_steps:
                    tokens = finalize_tokens(z_t, model.embedding.embed.weight)
                    text = tokenizer.decode(tokens[0].tolist())
                    text_clean = tokenizer.clean_text(text)
                    print(f"Step {t_step}: {text_clean}")

            # Final output
            final_tokens = finalize_tokens(z_t, model.embedding.embed.weight)
            final_text = tokenizer.decode(final_tokens[0].tolist())
            final_text_clean = tokenizer.clean_text(final_text)

        print(f"\n{'-'*70}")
        print(f"Final output: {final_text_clean}")
    finally:
        model.train()
    return z_t, final_tokens