# Load Data


In [63]:
import numpy as np
import torch
from tqdm import tqdm


from torch.utils.data import DataLoader
from pips.grid_dataset import GridDataset, DatasetType

# Token id used to fill grid cells outside the projected grid.
pad_token_id = 10
batch_size = 32
# Fixed canvas size every grid is projected onto.
H = W = 32

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [52]:
# Train / validation splits of the grid dataset (GridDataset from pips.grid_dataset).
train_dataset = GridDataset(dataset_type=DatasetType.TRAIN)
val_dataset = GridDataset(dataset_type=DatasetType.VAL)

def collate_fn(batch, permute=False):
    """Collate a list of grid objects into one token tensor on ``device``.

    Each grid is optionally transformed with ``g.permute()``, projected onto an
    ``H`` x ``W`` canvas via ``g.project(H, W, pad_token_id)`` and flattened.
    The flattened rows are stacked with ``np.stack`` and converted with
    ``torch.from_numpy``.

    Args:
        batch: Sequence of grid objects exposing ``permute()`` and ``project()``.
        permute: When True, call ``permute()`` on every grid before projecting.

    Returns:
        A tensor with one flattened grid per row, moved to ``device``.
    """
    rows = [
        (g.permute() if permute else g).project(H, W, pad_token_id).flatten()
        for g in batch
    ]
    return torch.from_numpy(np.stack(rows)).to(device)


# Training batches call g.permute() on every grid (presumably a random
# augmentation — confirm in pips.grid_dataset); validation batches do not.
# NOTE(review): shuffle is not set, so DataLoader's default (False) applies and
# training batches arrive in dataset order every epoch — confirm this is intended.
# NOTE(review): lambdas cannot be pickled, so these loaders rely on the default
# num_workers=0; collate_fn also moves tensors to `device` inside the loader.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size, 
                          collate_fn=lambda x: collate_fn(x, permute=True))

val_loader = DataLoader(val_dataset,
                        batch_size=batch_size, 
                        collate_fn=lambda x: collate_fn(x, permute=False))


In [53]:
# Pull a single training batch for quick inspection / shape checks.
batch = next(iter(train_loader))

In [67]:
from typing import Optional, Tuple
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from pips.dvae import AttnCodebook, LatentTransformer, RoPE2D, RotaryPositionalEmbeddings, Transformer


class GridDVAEConfig:
    """Hyper-parameters for GridDVAE.

    The grid height/width are read from the module-level ``H`` and ``W`` when
    the config is constructed; ``n_pos`` is their product.
    """

    def __init__(self):
        # Token vocabulary size (pad_token_id = 10 is the last id in this notebook).
        self.n_vocab = 11
        # Model width and attention heads.
        self.n_dim, self.n_head = 64, 4
        # Depth of the grid-level and latent-level transformer stacks.
        self.n_grid_layer = self.n_latent_layer = 1
        # Number of latent codes and number of codebook entries.
        self.n_codes, self.codebook_size = 64, 512
        # Grid geometry.
        self.height, self.width = H, W
        self.n_pos = self.height * self.width
        # RoPE base frequencies along each grid axis.
        self.rope_base_height = self.rope_base_width = 10000

class GridDVAE(nn.Module):
    """Discrete VAE over fixed-size token grids.

    Data flow (see ``forward``): token embedding -> ``grid_encoder``
    (Transformer with 2D RoPE) -> ``latent_encoder`` (LatentTransformer with
    ``config.n_codes`` latents) -> ``codebook`` (AttnCodebook) ->
    ``latent_decoder`` (LatentTransformer with ``config.n_pos`` latents) ->
    ``grid_decoder`` (Transformer with 2D RoPE, ``out_norm=True``) ->
    ``decoder_head`` (linear projection to ``config.n_vocab`` logits).

    Args:
        config: A GridDVAEConfig-like object providing n_vocab, n_dim, n_head,
            n_grid_layer, n_latent_layer, n_codes, codebook_size, height,
            width, n_pos, rope_base_height, rope_base_width and, optionally,
            ``init_mode`` ("xavier" selects Xavier init; anything else, or a
            missing attribute, selects normal(0, 0.02)).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_pos = config.n_pos
        self.embd = nn.Embedding(config.n_vocab, config.n_dim)
        # NOTE: self.apply(self._init_weights) at the end of __init__ re-initialises
        # this weight; this call is kept so the RNG stream (seeded runs) is unchanged.
        nn.init.normal_(self.embd.weight, mean=0.0, std=0.02)

        # Normalisation layout, inspired by Llama: activations flow through the
        # encoder and latent stacks un-normalised (out_norm=False). Only the final
        # grid decoder normalises its output (out_norm=True) before decoder_head
        # is applied.

        # Per-head dimension for the attention RoPEs (64 // 4 = 16 with the default config).
        head_dim = config.n_dim // config.n_head

        rope_2d = RoPE2D(
            dim=head_dim,
            max_height=config.height,
            max_width=config.width,
            base_height=config.rope_base_height,
            base_width=config.rope_base_width,
        )

        # 1D RoPE over up to n_pos positions; it reuses the height base frequency.
        rope_1d = RotaryPositionalEmbeddings(
            dim=head_dim,
            max_seq_len=config.n_pos,
            base=config.rope_base_height,
        )

        # The codebook's RoPE spans the full model width rather than a single head.
        rope_codebook = RotaryPositionalEmbeddings(
            dim=config.n_dim,
            max_seq_len=config.n_pos,
            base=config.rope_base_height,
        )

        # rope_2d is shared by grid_encoder/grid_decoder and rope_1d by
        # latent_encoder/latent_decoder (same instances passed to both).
        self.grid_encoder = Transformer(
            d_model=config.n_dim,
            n_head=config.n_head,
            n_layer=config.n_grid_layer,
            out_norm=False,
            rope=rope_2d
        )

        self.latent_encoder = LatentTransformer(
            n_latent=config.n_codes,
            d_model=config.n_dim,
            n_head=config.n_head,
            n_layer=config.n_latent_layer,
            out_norm=False,
            rope=rope_1d
        )

        self.codebook = AttnCodebook(d_model=config.n_dim,
                                     codebook_size=config.codebook_size,
                                     use_exp_relaxed=False,
                                     rope=rope_codebook,
                                     sampling=False,
                                     normalise_kq=False)

        self.latent_decoder = LatentTransformer(
            n_latent=config.n_pos,
            d_model=config.n_dim,
            n_head=config.n_head,
            n_layer=config.n_latent_layer,
            out_norm=False,
            rope=rope_1d
        )

        self.grid_decoder = Transformer(
            d_model=config.n_dim,
            n_head=config.n_head,
            n_layer=config.n_grid_layer,
            out_norm=True,
            rope=rope_2d
        )

        self.decoder_head = nn.Linear(config.n_dim, config.n_vocab, bias=False)

        # (row, col) index for every cell of the grid in row-major order: shape [1, H*W, 2].
        rows = torch.arange(config.height, dtype=torch.long)
        cols = torch.arange(config.width, dtype=torch.long)
        grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
        grid_pos_indices = torch.stack([grid_y.flatten(), grid_x.flatten()], dim=1).unsqueeze(0)
        # 1D positions 0..n_pos-1: shape [1, n_pos].
        latent_pos_indices = torch.arange(config.n_pos).unsqueeze(0)

        # Non-persistent buffers: follow the module across .to(device) but are not saved in state_dict.
        self.register_buffer("latent_pos_indices", latent_pos_indices, persistent=False)
        self.register_buffer('grid_pos_indices', grid_pos_indices, persistent=False)

        # Apply weight initialization on registered modules.
        self.apply(self._init_weights)
        # Additionally, initialize any raw nn.Parameters.
        self.initialize_all_parameters()

    def _init_weights(self, module):
        """Initialise Linear / Embedding / LayerNorm modules and flag their params.

        Parameters touched here get an ``_initialized`` attribute so that
        ``initialize_all_parameters`` skips them afterwards.
        """
        # Set self.config.init_mode = 'xavier' to use Xavier initialisation;
        # otherwise weights are drawn from normal(0, 0.02).
        init_mode = getattr(self.config, "init_mode", "normal")

        if isinstance(module, nn.Linear):
            if init_mode == "xavier":
                torch.nn.init.xavier_normal_(module.weight)
            else:
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            module.weight._initialized = True
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
                module.bias._initialized = True
        elif isinstance(module, nn.Embedding):
            if init_mode == "xavier":
                torch.nn.init.xavier_normal_(module.weight)
            else:
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            module.weight._initialized = True
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: weight = 1, bias = 0. Both are None when the layer was
            # built with elementwise_affine=False, so guard before touching them.
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)
                module.weight._initialized = True
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
                module.bias._initialized = True

    def initialize_all_parameters(self):
        """
        Initialize all parameters in the model recursively.
        This method is meant to also initialize raw nn.Parameter attributes that are not part
        of a submodule (and hence not handled by self.apply).

        Parameters with ndim < 2 are only flagged, not re-initialised, so they
        keep whatever values their owning module assigned.
        """
        init_mode = getattr(self.config, "init_mode", "normal")
        for name, param in self.named_parameters():
            # If the parameter has already been initialized (flagged via _initialized), skip.
            if hasattr(param, "_initialized"):
                continue
            if param.ndim >= 2:
                if init_mode == "xavier":
                    torch.nn.init.xavier_normal_(param)
                else:
                    torch.nn.init.normal_(param, mean=0.0, std=0.02)
            param._initialized = True

    def encode(self, x: Tensor, grid_pos_indices: Tensor, latent_pos_indices: Tensor) -> Tensor:
        """Embed token ids and run the grid encoder followed by the latent encoder."""
        x_embd = self.embd(x)
        grid_encoded, _ = self.grid_encoder(x_embd, positions=grid_pos_indices)
        latent_encoded, _ = self.latent_encoder(grid_encoded, positions=latent_pos_indices)
        return latent_encoded

    def decode(self, x: Tensor, grid_pos_indices: Tensor, latent_pos_indices: Tensor) -> Tensor:
        """Run the latent decoder and grid decoder, then project to vocabulary logits."""
        latent_decoded, _ = self.latent_decoder(x, positions=latent_pos_indices)
        grid_decoded, _ = self.grid_decoder(latent_decoded, positions=grid_pos_indices)
        grid_decoded_logits = self.decoder_head(grid_decoded)
        return grid_decoded_logits

    def forward(self, x: Tensor, tau: Tensor = torch.tensor(1.0), residual_scaling: Tensor = torch.tensor(0.0)) -> Tuple[Tensor, Tensor]:
        """Encode, quantise through the codebook, and decode a batch of token grids.

        Args:
            x: 2-D tensor of token ids, (batch, sequence); the sequence length
               presumably has to equal config.n_pos (one token per grid cell).
            tau: Temperature forwarded to the codebook.
            residual_scaling: Forwarded to the codebook.

        Returns:
            (decoded_logits, log_alpha): decoder_head logits for every position,
            and the codebook's log_alpha output. The codebook's third output is
            discarded.
        """
        # Unpacking into two names also enforces that x is 2-D.
        B, _ = x.size()
        grid_pos_indices = self.grid_pos_indices.expand(B, -1, -1)
        latent_pos_indices = self.latent_pos_indices.expand(B, -1)

        # NOTE(review): the same n_pos-long latent_pos_indices is passed to the
        # latent encoder (n_codes latents), the codebook and the latent decoder
        # (n_pos latents); confirm how each module consumes positions of that length.
        encoded_logits = self.encode(x, grid_pos_indices, latent_pos_indices)

        quantized, log_alpha, _ = self.codebook(encoded_logits, tau=tau, residual_scaling=residual_scaling, positions=latent_pos_indices)

        decoded_logits = self.decode(quantized, grid_pos_indices, latent_pos_indices)

        return decoded_logits, log_alpha
    

# Build the model with the default config and move parameters/buffers to `device`.
model = GridDVAE(GridDVAEConfig()).to(device)
    

In [76]:
def compute_perplexity(log_alpha, tau):
    """
    Compute the perplexity of the temperature-scaled latent distributions.

    The logits are divided by ``tau`` and softmax-normalised over the last
    dimension. Perplexity is ``2 ** H``, where ``H`` is the entropy in bits.
    It is roughly 1 for a one-hot distribution and roughly ``codebook_size``
    for a uniform one, so it measures how uniform the distribution is.

    Args:
        log_alpha: Logits for the latent distribution [B, N, codebook_size]
        tau: Temperature parameter for scaling logits (scalar or broadcastable tensor)

    Returns:
        Perplexity per sample [B, N] (the last dimension is reduced).
    """
    # Apply temperature to logits and compute probabilities
    probs = F.softmax(log_alpha / tau, dim=-1)

    # The epsilon keeps log2 finite where a probability underflows to exactly 0;
    # the corresponding p * log p term is then 0, its correct limit.
    log_probs = torch.log2(probs + 1e-10)
    entropy = -torch.sum(probs * log_probs, dim=-1)  # entropy in bits, [B, N]

    # Perplexity = 2^(entropy)
    return 2.0 ** entropy

def entropy_loss(log_alpha, tau, reduction="mean"):
    """
    Entropy (in nats) of the temperature-scaled latent distribution.

    The logits are divided by ``tau`` and passed through ``log_softmax`` over
    the last dimension. The entropy is taken over that dimension and then
    reduced.

    Args:
        log_alpha: Logits for the latent distribution, shape [B, N, codebook_size]
        tau: Temperature parameter for scaling logits
        reduction: 'sum', 'mean', 'batchmean' (identical to 'mean' here: the
            average over every leading index), or 'none' (per-position
            entropies, shape [B, N]).

    Returns:
        The reduced entropy tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    # log_softmax is numerically stabler than log(softmax(...)).
    log_p = F.log_softmax(log_alpha / tau, dim=-1)
    per_position = -(log_p.exp() * log_p).sum(dim=-1)  # [B, N]

    if reduction == "sum":
        return per_position.sum()
    if reduction in ("mean", "batchmean"):
        return per_position.mean()
    if reduction == "none":
        return per_position
    raise ValueError(f"Invalid reduction: {reduction}")

def codebook_diversity_loss(log_alpha, tau, reduction="mean"):
    """
    Negative entropy (in bits) of batch-averaged codebook usage.

    Probabilities come from ``softmax(log_alpha / tau)`` over the codebook
    dimension and are averaged over the batch dimension (dim 0). The entropy is
    then computed per latent position. Minimising the returned value maximises
    that entropy, pushing different samples toward different codebook entries.

    Args:
        log_alpha: Logits for the latent distribution, shape [B, N, codebook_size]
        tau: Temperature parameter for scaling logits
        reduction: 'sum', 'mean', 'batchmean' (identical to 'mean' here: the
            average over latent positions), or 'none' (per-position values,
            shape [N]).

    Returns:
        The reduced diversity loss tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    # Average usage of each codebook entry, per latent position: [N, codebook_size].
    usage = F.softmax(log_alpha / tau, dim=-1).mean(dim=0)

    # sum(p * log2 p) is the negated entropy, i.e. directly the loss to minimise.
    # The 1e-10 keeps log2 finite for entries that are never used.
    neg_entropy = (usage * torch.log2(usage + 1e-10)).sum(dim=-1)  # [N]

    if reduction == "sum":
        return neg_entropy.sum()
    if reduction in ("mean", "batchmean"):
        return neg_entropy.mean()
    if reduction == "none":
        return neg_entropy
    raise ValueError(f"Invalid reduction: {reduction}")

In [79]:
import torch.optim as optim

learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

num_steps = 1000
diversity_weight = 0.1
max_entropy_weight = 0.1
min_entropy_weight = 0.0
entropy_anneal_steps = 1000

model.train()

tau = torch.tensor(1.0).to(device)
residual_scaling = torch.tensor(0.0).to(device)
criterion = nn.CrossEntropyLoss(reduction='mean')


def entropy_schedule(step):
    """Entropy-loss weight: linear ramp from min to max over entropy_anneal_steps, then held at max."""
    progress = min(1.0, step / entropy_anneal_steps)
    return min(max_entropy_weight,
               min_entropy_weight + (max_entropy_weight - min_entropy_weight) * progress)


# Run exactly num_steps optimisation steps, re-iterating train_loader as many
# times as needed. The context manager closes the progress bar even if the
# loop is interrupted.
step = 0
with tqdm(total=num_steps) as progress_bar:
    while step < num_steps:
        batches_this_epoch = 0
        for batch in train_loader:
            batches_this_epoch += 1

            optimizer.zero_grad()

            logits, log_alpha = model(batch, tau=tau, residual_scaling=residual_scaling)
            recon_loss = criterion(logits.view(-1, logits.size(-1)), batch.view(-1))
            ent_loss = entropy_loss(log_alpha, tau=tau, reduction="mean")
            div_loss = codebook_diversity_loss(log_alpha, tau=tau, reduction="mean")
            entropy_weight = entropy_schedule(step)

            loss = recon_loss + entropy_weight * ent_loss + diversity_weight * div_loss

            loss.backward()
            optimizer.step()

            progress_bar.update(1)
            step += 1

            progress_bar.set_postfix({
                'loss': loss.item(),
                'recon': recon_loss.item(),
                'entropy': ent_loss.item(),
                'diversity': div_loss.item() if diversity_weight > 0 else 0.0,
                'temp': tau.item(),
                'entropy_weight': entropy_weight
            })

            # Compare the step counter (not the builtin `iter`) and stop mid-epoch
            # once the budget is reached; the while condition then ends the outer loop.
            if step >= num_steps:
                break

        # Guard against spinning forever on a loader that yields nothing.
        if batches_this_epoch == 0:
            raise RuntimeError("train_loader yielded no batches")

model(batch)[0].size()

  1%|          | 12/1000 [00:11<15:36,  1.06it/s, loss=0.00558, recon=0.899, entropy=6.24, diversity=-9, temp=1, entropy_weight=0.0011]
 18%|█▊        | 181/1000 [01:20<06:32,  2.09it/s, loss=-0.231, recon=0.556, entropy=6.24, diversity=-9, temp=1, entropy_weight=0.018]   

KeyboardInterrupt: 