In [12]:
#%% Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [13]:
#%% Hyperparameters (Paper Table 3)
class Config:
    # Dataset parameters
    dataset = 'cifar10'  # mnist, cifar10, cifar100
    batch_size = 128
    image_size = 32
    
    # Model architecture
    embed_dim = 20
    hidden_dim = 256
    num_timesteps = 10
    num_classes = 10  # Will be set automatically
    
    # Training parameters
    epochs = 150
    learning_rate = 1e-3
    weight_decay = 1e-3
    eta = 0.1
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = Config()

In [14]:
#%% Helper Functions
def cosine_noise_schedule(t, T=config.num_timesteps):
    """Cosine noise schedule from paper"""
    s = 0.008
    f_t = torch.cos((t/T + s)/(1 + s) * torch.pi/2).clamp(min=1e-5)
    alpha_bar = f_t / f_t[0]
    return alpha_bar

class EMA:
    """Exponential Moving Average for model stability"""
    def __init__(self, beta=0.995):
        self.beta = beta
        self.step = 0

    def update(self, model, ema_model):
        with torch.no_grad():
            for param, ema_param in zip(model.parameters(), ema_model.parameters()):
                ema_param.data = self.beta * ema_param.data + (1 - self.beta) * param.data
        self.step += 1

In [15]:
#%% Fixed Model Architecture for Batched Timesteps
class NoPropBlock(nn.Module):
    """Diffusion dynamics block with batched timestep handling"""
    def __init__(self, embed_dim, hidden_dim, num_timesteps=10, image_channels=3):
        super().__init__()
        # Timestep embedding for multiple steps
        self.t_embed = nn.Embedding(num_timesteps, embed_dim)
        
        # Image processing branch
        self.image_embed = nn.Sequential(
            nn.Conv2d(image_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32*(config.image_size//4)**2, hidden_dim)
        )
        
        # Noise processing branch
        self.noise_embed = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Combined processing
        self.combine = nn.Sequential(
            nn.Linear(2*hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, z, x, t):
        # t is a tensor of shape [batch_size]
        t_emb = self.t_embed(t)  # Shape: [batch_size, embed_dim]
        
        # Add timestep embedding to noise
        z = z + t_emb
        
        # Process inputs
        x_embed = self.image_embed(x)  # Shape: [batch_size, hidden_dim]
        z_embed = self.noise_embed(z)  # Shape: [batch_size, hidden_dim]
        
        # Combine features
        combined = torch.cat([x_embed, z_embed], dim=1)
        return self.combine(combined)

class NoPropModel(nn.Module):
    def __init__(self, num_timesteps=10):
        super().__init__()
        self.embed = nn.Embedding(config.num_classes, config.embed_dim)
        self.block = NoPropBlock(
            config.embed_dim,
            config.hidden_dim,
            num_timesteps=num_timesteps
        )
        self.final_layer = nn.Linear(config.embed_dim, config.num_classes)
        
    def forward(self, z, x, t):
        # t should be a tensor of timestep indices
        return self.block(z, x, t)

In [16]:
#%% Fixed Dataset Loading with Hugging Face
# Define collate function at top level
def no_prop_collate_fn(batch):
    images = torch.stack([item['image'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])
    return images, labels

def get_dataset(config):
    # Map dataset names to Hugging Face paths
    dataset_map = {
        'mnist': 'mnist',
        'cifar10': 'cifar10',
        'cifar100': 'cifar100'
    }
    
    # Load dataset from Hugging Face
    dataset = load_dataset(dataset_map[config.dataset.lower()])
    
    # Set number of classes
    if config.dataset.lower() == 'cifar100':
        config.num_classes = 100
    else:
        config.num_classes = 10

    # Define transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Function to apply transforms and convert to tensors
    def apply_transforms(examples):
        # Convert images to RGB (CIFAR) or grayscale (MNIST)
        if config.dataset.lower() == 'mnist':
            examples['image'] = [transform(image.convert('L')) for image in examples['img']]
        else:
            examples['image'] = [transform(image.convert('RGB')) for image in examples['img']]
            
        # Convert labels to tensors
        examples['label'] = [torch.tensor(label) for label in examples['label']]
        return examples

    # Process datasets with proper tensor conversion
    dataset = dataset.map(
        apply_transforms,
        batched=True,
        batch_size=config.batch_size,
        remove_columns=['img']  # Remove original image column
    )

    # Convert to PyTorch format
    train_dataset = dataset['train'].with_format('torch', 
        columns=['image', 'label'],
        output_all_columns=False
    )
    test_dataset = dataset['test'].with_format('torch', 
        columns=['image', 'label'],
        output_all_columns=False
    )

    # Create DataLoaders with proper collation
    def collate_fn(batch):
        images = torch.stack([item['image'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        return images, labels

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=no_prop_collate_fn,
        pin_memory=False,  # Disable pin_memory for stability
        num_workers=0,  # Disable multiprocessing
        persistent_workers=False
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=no_prop_collate_fn,
        pin_memory=False,
        num_workers=0,
        persistent_workers=False
    )
    
    return train_loader, test_loader

# Initialize datasets
trainloader, testloader = get_dataset(config)

In [17]:
#%% Training Loop (Algorithm 1)
model = NoPropModel().to(config.device)
ema_model = NoPropModel().to(config.device)
ema_model.load_state_dict(model.state_dict())
ema = EMA()

optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, 
                       weight_decay=config.weight_decay)
criterion = nn.CrossEntropyLoss()

# Initialize noise schedule
timesteps = torch.arange(config.num_timesteps)
alpha_bar = cosine_noise_schedule(timesteps)

for epoch in range(config.epochs):
    model.train()
    progress_bar = tqdm(trainloader)
    
    for images, labels in progress_bar:
        images = images.to(config.device)
        labels = labels.to(config.device)
        batch_size = images.size(0)
        
        # Convert labels to embeddings
        u_y = model.embed(labels)
        
        # Sample random timestep
        t = torch.randint(1, config.num_timesteps, (batch_size,)).to(config.device)
        alpha_bar_t = alpha_bar[t].view(-1, 1).to(config.device)
        
        # Forward process (add noise)
        epsilon = torch.randn_like(u_y)
        z_t = torch.sqrt(alpha_bar_t) * u_y + torch.sqrt(1 - alpha_bar_t) * epsilon
        
        # Sample random timesteps as indices (shape [batch_size])
        t = torch.randint(1, config.num_timesteps, (batch_size,)).to(config.device)

        # Forward pass with batched timesteps
        u_pred = model(z_t, images, t-1)  # t-1 for 0-based indexing
        
        # Compute loss (Equation 8)
        snr_t = alpha_bar_t / (1 - alpha_bar_t)
        snr_t_prev = alpha_bar[t-1].view(-1, 1) / (1 - alpha_bar[t-1].view(-1, 1))
        loss_denoise = torch.mean((snr_t - snr_t_prev) * torch.norm(u_pred - u_y, dim=1)**2)
        
        # Final classification loss
        logits = model.final_layer(z_t)
        loss_cls = criterion(logits, labels)
        
        # Total loss
        total_loss = loss_cls + config.eta * loss_denoise
        
        # Backprop and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        ema.update(model, ema_model)
        
        progress_bar.set_description(f"Epoch {epoch+1} Loss: {total_loss.item():.4f}")

Epoch 1 Loss: nan: 100%|██████████| 391/391 [02:00<00:00,  3.24it/s]
Epoch 2 Loss: nan: 100%|██████████| 391/391 [01:53<00:00,  3.44it/s]
Epoch 3 Loss: nan: 100%|██████████| 391/391 [01:49<00:00,  3.56it/s]
Epoch 4 Loss: nan: 100%|██████████| 391/391 [01:50<00:00,  3.53it/s]
Epoch 5 Loss: nan: 100%|██████████| 391/391 [01:49<00:00,  3.57it/s]
Epoch 6 Loss: nan: 100%|██████████| 391/391 [01:48<00:00,  3.61it/s]
Epoch 7 Loss: nan: 100%|██████████| 391/391 [01:51<00:00,  3.52it/s]
Epoch 8 Loss: nan: 100%|██████████| 391/391 [01:50<00:00,  3.53it/s]
Epoch 9 Loss: nan: 100%|██████████| 391/391 [01:53<00:00,  3.44it/s]
Epoch 10 Loss: nan: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]
Epoch 11 Loss: nan: 100%|██████████| 391/391 [01:54<00:00,  3.43it/s]
Epoch 12 Loss: nan: 100%|██████████| 391/391 [01:52<00:00,  3.49it/s]
Epoch 13 Loss: nan: 100%|██████████| 391/391 [01:51<00:00,  3.52it/s]
Epoch 14 Loss: nan: 100%|██████████| 391/391 [01:54<00:00,  3.41it/s]
Epoch 15 Loss: nan: 100%|████

In [18]:
#%% Evaluation (Fixed Scope Issue)
@torch.no_grad()
def evaluate(model, testloader):
    model.eval()
    correct = 0
    total = 0
    
    # Initialize noise schedule locally with proper device
    timesteps = torch.arange(config.num_timesteps, device=config.device)
    alpha_bar = cosine_noise_schedule(timesteps)
    
    for images, labels in tqdm(testloader):
        images = images.to(config.device)
        labels = labels.to(config.device)
        
        # Inference process
        z = torch.randn(len(images), config.embed_dim).to(config.device)
        
        # Iterative denoising
        for t_step in reversed(range(config.num_timesteps)):
            t_batch = torch.full((len(z),), t_step, 
                               dtype=torch.long, 
                               device=config.device)
            
            u_pred = model(z, images, t_batch)
            alpha_bar_t = alpha_bar[t_step]
            
            # Update rule (keep as tensor operations)
            z = (z - (1 - alpha_bar_t) * u_pred) / torch.sqrt(alpha_bar_t)
            z += torch.sqrt(1 - alpha_bar_t) * torch.randn_like(z)
        
        # Final classification
        logits = model.final_layer(z)
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    return 100 * correct / total

accuracy = evaluate(ema_model, testloader)
print(f"Test Accuracy: {accuracy:.2f}%")

100%|██████████| 79/79 [00:32<00:00,  2.41it/s]

Test Accuracy: 10.00%





In [19]:
#%% Continuous-Time Variant (Algorithm 2)
class NoPropCT(nn.Module):
    """Continuous-time NoProp with neural ODE"""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(config.num_classes, config.embed_dim)
        self.block = NoPropBlock(config.embed_dim, config.hidden_dim)
        self.time_embed = nn.Embedding(100, config.embed_dim)
        self.final_layer = nn.Linear(config.embed_dim, config.num_classes)
        
    def forward(self, z, x, t):
        t_embed = self.time_embed(t)
        return self.block(z + t_embed, x)

# Training loop for continuous-time would involve:
# - Random time sampling
# - Neural ODE solver integration
# - Different noise schedule handling
# (Implementation similar to discrete-time but with continuous components)

In [20]:
#%% Flow Matching Variant (Algorithm 3)
class NoPropFM(nn.Module):
    """Flow matching variant"""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(config.num_classes, config.embed_dim)
        self.block = NoPropBlock(config.embed_dim, config.hidden_dim)
        self.time_embed = nn.Embedding(100, config.embed_dim)
        self.final_layer = nn.Linear(config.embed_dim, config.num_classes)
        
    def forward(self, z, x, t):
        t_embed = self.time_embed(t)
        return self.block(z + t_embed, x)

# Training loop for flow matching would involve:
# - Linear interpolation between noise and target
# - Vector field prediction
# - Anchor loss implementation