In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# RQ-VAE Model
class RQVAE(nn.Module):
    """Residual-quantized autoencoder for dense embeddings.

    An MLP encoder maps the input to a latent vector, which is approximated
    by a sum of one vector from each codebook (each codebook quantizes the
    residual left over by the previous ones). An MLP decoder reconstructs
    the input from that sum.

    Args:
        input_dim: Size of the input (and reconstructed) embedding.
        latent_dim: Size of the latent vector and of every codebook vector.
        codebook_size: Number of vectors in each codebook.
        num_codebooks: Number of residual quantization levels.
    """

    def __init__(self, input_dim=768, latent_dim=32, codebook_size=256, num_codebooks=3):
        super().__init__()

        # Encoder: input_dim -> 512 -> 256 -> 128 -> latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )

        # One codebook per residual quantization level.
        self.codebooks = nn.ModuleList([
            nn.Embedding(codebook_size, latent_dim) for _ in range(num_codebooks)
        ])
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        self.latent_dim = latent_dim

        # Decoder: latent_dim -> 128 -> 256 -> 512 -> input_dim
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
        )

    def forward(self, x):
        """Encode, residual-quantize and decode a batch.

        Args:
            x: Tensor of shape (batch_size, input_dim).

        Returns:
            A tuple ``(reconstructed, codes)`` where ``reconstructed`` has
            shape (batch_size, input_dim) and ``codes`` is a list with one
            (batch_size,) index tensor per codebook.
        """
        latent = self.encoder(x)

        quantized = torch.zeros_like(latent)
        residual = latent
        codes = []

        for codebook in self.codebooks:
            # Nearest codebook vector to the current residual.
            distances = torch.cdist(residual.unsqueeze(1), codebook.weight.unsqueeze(0))  # (batch_size, 1, codebook_size)
            indices = torch.argmin(distances, dim=-1).squeeze(1)
            codes.append(indices)

            # Look the vector up once and reuse it for both updates.
            selected = codebook(indices)
            quantized = quantized + selected
            residual = residual - selected

        # Straight-through estimator: the forward value is `quantized`, but the
        # gradient w.r.t. the decoder input is routed to `latent`. argmin is
        # not differentiable, so without this the reconstruction loss would
        # never reach the encoder.
        decoder_input = latent + (quantized - latent).detach()
        reconstructed = self.decoder(decoder_input)

        return reconstructed, codes

# Loss function for RQ-VAE
class RQVAE_Loss(nn.Module):
    """Training objective for the RQ-VAE: reconstruction MSE plus a weighted
    MSE between the latent and its quantized approximation.

    Args:
        beta: Weight applied to the latent/quantized term.
    """

    def __init__(self, beta=0.25):
        super().__init__()
        self.beta = beta

    def forward(self, input, reconstructed, quantized, latent):
        """Return ``mse(reconstructed, input) + beta * mse(latent, quantized)``.

        No stop-gradient is applied here, so the second term back-propagates
        into both ``latent`` and ``quantized``.
        """
        reconstruction_term = F.mse_loss(reconstructed, input)
        quantization_term = F.mse_loss(latent, quantized)
        return reconstruction_term + self.beta * quantization_term

# Training example
if __name__ == "__main__":
    # Hyperparameters
    input_dim = 768
    latent_dim = 32
    codebook_size = 256
    num_codebooks = 3
    batch_size = 1024
    learning_rate = 0.4
    epochs = 20000

    # Create RQ-VAE model
    model = RQVAE(input_dim=input_dim, latent_dim=latent_dim, codebook_size=codebook_size, num_codebooks=num_codebooks)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    criterion = RQVAE_Loss(beta=0.25)

    # Dummy dataset (replace with actual embeddings from Sentence-T5)
    input_embeddings = torch.randn(batch_size, input_dim)  # Random input embeddings

    # The module mode never changes during the loop, so set it once.
    model.train()

    # Training loop (full-batch: every epoch is one optimizer step)
    for epoch in range(epochs):
        optimizer.zero_grad()

        # Forward pass
        reconstructed, codes = model(input_embeddings)

        # forward() returns neither the latent nor the quantized vector, so
        # rebuild both here from the encoder output and the selected codes.
        latent = model.encoder(input_embeddings)
        quantized = torch.zeros_like(latent)
        for codebook, indices in zip(model.codebooks, codes):
            quantized = quantized + codebook(indices)
        loss = criterion(input_embeddings, reconstructed, quantized, latent)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Print loss every 1000 epochs
        if (epoch + 1) % 1000 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans


# RQ-VAE Model
class RQVAE(nn.Module):
    """Residual-quantized autoencoder whose codebooks can be seeded with k-means.

    The encoder maps an input embedding to a latent vector; the latent is
    approximated by the sum of one vector per codebook, each codebook
    quantizing the residual left by the previous ones; the decoder
    reconstructs the input from that sum.

    Args:
        input_dim: Size of the input (and reconstructed) embedding.
        latent_dim: Size of the latent vector and of every codebook vector.
        codebook_size: Number of vectors in each codebook.
        num_codebooks: Number of residual quantization levels.
    """

    def __init__(self, input_dim=768, latent_dim=32, codebook_size=256, num_codebooks=3):
        super().__init__()

        # Encoder: input_dim -> 512 -> 256 -> 128 -> latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )

        # Codebooks start from nn.Embedding's default init; call
        # initialize_codebooks() to overwrite them with k-means centroids.
        self.codebooks = nn.ModuleList([
            nn.Embedding(codebook_size, latent_dim) for _ in range(num_codebooks)
        ])
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        self.latent_dim = latent_dim

        # Decoder: latent_dim -> 128 -> 256 -> 512 -> input_dim
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
        )

    def initialize_codebooks(self, data):
        """Seed every codebook with k-means centroids of the encoder residuals.

        The first codebook is fitted on the encoded ``data``; each following
        codebook is fitted on what remains after subtracting the centroid
        assigned by the previous level.

        Args:
            data (torch.Tensor): Raw inputs of shape (num_samples, input_dim).
                They are passed through the encoder inside this method.
        """
        # Pure initialization: no autograd graph is needed for the encoder
        # pass or for overwriting the codebook weights.
        with torch.no_grad():
            residual = self.encoder(data).cpu().numpy()

            for codebook in self.codebooks:
                kmeans = KMeans(n_clusters=self.codebook_size, random_state=42)
                kmeans.fit(residual)
                centers = kmeans.cluster_centers_
                codebook.weight.copy_(torch.tensor(centers, dtype=torch.float32))
                # Next level quantizes what this level could not explain.
                residual = residual - centers[kmeans.labels_]

    def forward(self, x):
        """Encode, residual-quantize and decode a batch.

        Args:
            x: Tensor of shape (batch_size, input_dim).

        Returns:
            A tuple ``(reconstructed, codes)`` where ``reconstructed`` has
            shape (batch_size, input_dim) and ``codes`` is a list with one
            (batch_size,) index tensor per codebook.
        """
        latent = self.encoder(x)

        quantized = torch.zeros_like(latent)
        residual = latent
        codes = []

        for codebook in self.codebooks:
            # Nearest codebook vector to the current residual.
            distances = torch.cdist(residual.unsqueeze(1), codebook.weight.unsqueeze(0))  # (batch_size, 1, codebook_size)
            indices = torch.argmin(distances, dim=-1).squeeze(1)
            codes.append(indices)

            # Look the vector up once and reuse it for both updates.
            selected = codebook(indices)
            quantized = quantized + selected
            residual = residual - selected

        # Straight-through estimator: forward value equals `quantized`, while
        # the gradient of the decoder input flows to `latent`. Without it the
        # non-differentiable argmin cuts the encoder off from the
        # reconstruction loss.
        decoder_input = latent + (quantized - latent).detach()
        reconstructed = self.decoder(decoder_input)

        return reconstructed, codes

# Loss function for RQ-VAE
class RQVAE_Loss(nn.Module):
    """Combined RQ-VAE loss.

    The value returned by ``forward`` is the mean-squared reconstruction
    error plus ``beta`` times the mean-squared distance between the latent
    and the quantized vector.
    """

    def __init__(self, beta=0.25):
        super().__init__()
        # Weight of the latent-vs-quantized term.
        self.beta = beta

    def forward(self, input, reconstructed, quantized, latent):
        """Compute the scalar loss for one batch.

        Args:
            input: Original embeddings.
            reconstructed: Decoder output, same shape as ``input``.
            quantized: Sum of the selected codebook vectors.
            latent: Encoder output, same shape as ``quantized``.
        """
        mse_reconstruction = F.mse_loss(reconstructed, input)
        mse_quantization = F.mse_loss(latent, quantized)
        weighted_quantization = self.beta * mse_quantization
        return mse_reconstruction + weighted_quantization

# Training example
if __name__ == "__main__":
    # Hyperparameters
    input_dim = 768
    latent_dim = 32
    codebook_size = 256
    num_codebooks = 3
    batch_size = 1024
    learning_rate = 0.4
    epochs = 20000

    # Create RQ-VAE model
    model = RQVAE(input_dim=input_dim, latent_dim=latent_dim, codebook_size=codebook_size, num_codebooks=num_codebooks)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    criterion = RQVAE_Loss(beta=0.25)

    # Dummy dataset (replace with actual embeddings from Sentence-T5)
    input_embeddings = torch.randn(batch_size, input_dim)  # Random input embeddings

    # Seed the codebooks with k-means centroids before any gradient step.
    model.initialize_codebooks(input_embeddings)

    # The module mode never changes during the loop, so set it once.
    model.train()

    # Training loop (full-batch: every epoch is one optimizer step)
    for epoch in range(epochs):
        optimizer.zero_grad()

        # Forward pass
        reconstructed, codes = model(input_embeddings)

        # forward() returns neither the latent nor the quantized vector, so
        # rebuild both here from the encoder output and the selected codes.
        latent = model.encoder(input_embeddings)
        quantized = torch.zeros_like(latent)
        for codebook, indices in zip(model.codebooks, codes):
            quantized = quantized + codebook(indices)
        loss = criterion(input_embeddings, reconstructed, quantized, latent)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Print loss every 1000 epochs
        if (epoch + 1) % 1000 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

In [None]:
class Codebook(nn.Module):
    """A learnable table of vectors with nearest-neighbour lookup.

    Args:
        num_embeddings: Number of vectors in the table.
        embedding_dim: Dimensionality of each vector.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        # Codebook entries, drawn from a standard normal at construction.
        initial_entries = torch.randn(num_embeddings, embedding_dim)
        self.weight = nn.Parameter(initial_entries)

    def forward(self, residual):
        """Return the closest codebook vector, and its index, for every row.

        Args:
            residual: Tensor of shape (batch_size, embedding_dim).

        Returns:
            ``(selected_vectors, indices)`` with shapes
            (batch_size, embedding_dim) and (batch_size,).
        """
        queries = residual[:, None, :]   # (batch_size, 1, embedding_dim)
        entries = self.weight[None]      # (1, num_embeddings, embedding_dim)
        # Pairwise distances: (batch_size, 1, num_embeddings)
        pairwise = torch.cdist(queries, entries)
        nearest = pairwise.argmin(dim=-1)    # (batch_size, 1)
        indices = nearest[:, 0]              # (batch_size,)
        selected_vectors = self.weight.index_select(0, indices)
        return selected_vectors, indices
    


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel

# 1. Transformer Encoder-Decoder Model
class TransformerSeq2Seq(nn.Module):
    """Encoder-decoder Transformer over semantic-ID tokens, conditioned on a user token.

    The encoder sees the user embedding followed by the embedded input
    sequence; the decoder attends to that memory and, under a causal mask,
    to the embedded target sequence.

    Args:
        vocab_size: Number of semantic-ID tokens (embedding rows and output classes).
        user_vocab_size: Number of user-ID tokens.
        embed_dim: Model width.
        num_heads: Attention heads per layer. ``nn.MultiheadAttention``
            requires ``embed_dim`` to be divisible by ``num_heads``, so the
            default pair (128, 6) is rejected at construction time; pass a
            compatible combination (e.g. ``num_heads=8``).
        num_layers: Number of layers in both the encoder and the decoder.
        mlp_dim: Hidden size of the feed-forward sublayers.
        dropout: Dropout probability inside the Transformer layers.
    """

    def __init__(self, vocab_size, user_vocab_size, embed_dim=128, num_heads=6, num_layers=4, mlp_dim=1024, dropout=0.1):
        super().__init__()

        # Embedding tables for semantic-ID tokens and user-ID tokens
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.user_embedding = nn.Embedding(user_vocab_size, embed_dim)

        # Learned positional table, initialised to zero; supports up to 500 positions
        self.positional_encoding = nn.Parameter(torch.zeros(500, embed_dim))

        # Encoder stack
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="relu"
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Decoder stack
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="relu"
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output projection to token logits
        self.output_projection = nn.Linear(embed_dim, vocab_size)

    def forward(self, user_ids, input_ids, target_ids):
        """
        Args:
            user_ids (torch.Tensor): User ID tokens, shape (batch_size, 1)
            input_ids (torch.Tensor): Input token IDs, shape (batch_size, seq_len)
            target_ids (torch.Tensor): Decoder input token IDs, shape (batch_size, target_len)
        Returns:
            logits (torch.Tensor): Predicted logits, shape (batch_size, target_len, vocab_size).
                Position ``t`` depends only on ``target_ids[:, :t + 1]`` and the encoder memory.
        """
        seq_len = input_ids.shape[1]
        target_len = target_ids.shape[1]

        # User ID embedding
        user_embedded = self.user_embedding(user_ids)  # (batch_size, 1, embed_dim)

        # Semantic ID embedding with positional encoding
        input_embedded = self.token_embedding(input_ids) + self.positional_encoding[:seq_len, :]  # (batch_size, seq_len, embed_dim)

        # Prepend the user token (it carries no positional term)
        input_embedded = torch.cat([user_embedded, input_embedded], dim=1)  # (batch_size, seq_len+1, embed_dim)

        # The layers are sequence-first (batch_first is not set), hence the permutes.
        memory = self.encoder(input_embedded.permute(1, 0, 2))  # (seq_len+1, batch_size, embed_dim)

        # Target embedding with positional encoding
        target_embedded = self.token_embedding(target_ids) + self.positional_encoding[:target_len, :]  # (batch_size, target_len, embed_dim)

        # Causal mask: -inf strictly above the diagonal stops position t from
        # attending to positions > t, i.e. to tokens it has not generated yet.
        causal_mask = torch.triu(
            torch.full(
                (target_len, target_len),
                float("-inf"),
                dtype=target_embedded.dtype,
                device=target_embedded.device,
            ),
            diagonal=1,
        )

        # Decode
        output = self.decoder(
            tgt=target_embedded.permute(1, 0, 2),  # (target_len, batch_size, embed_dim)
            memory=memory,                         # (seq_len+1, batch_size, embed_dim)
            tgt_mask=causal_mask,
        )  # (target_len, batch_size, embed_dim)

        # Project to vocab size
        logits = self.output_projection(output.permute(1, 0, 2))  # (batch_size, target_len, vocab_size)

        return logits

# 2. Training Configuration
def train_model():
    """Train ``TransformerSeq2Seq`` on random dummy data and print the loss every 1000 steps.

    A single fixed batch of random user IDs, input sequences and target
    sequences is generated once and reused for every step. The learning
    rate follows a warm-up / inverse-square-root schedule applied on top of
    the base rate.
    """
    # Hyperparameters
    num_semantic_tokens = 1024  # Semantic ID tokens
    bos_token_id = num_semantic_tokens  # Extra id reserved as the decoder start token
    vocab_size = num_semantic_tokens + 1  # Semantic tokens + BOS
    user_vocab_size = 2000  # User ID tokens
    embed_dim = 128
    # nn.MultiheadAttention requires embed_dim % num_heads == 0; the previous
    # value of 6 does not divide 128 and made model construction fail.
    num_heads = 8
    num_layers = 4
    mlp_dim = 1024
    dropout = 0.1
    batch_size = 256
    seq_len = 50  # Max sequence length
    target_len = 3  # Semantic ID length (m=3)
    learning_rate = 0.01
    num_steps = 200000

    # Model
    model = TransformerSeq2Seq(vocab_size, user_vocab_size, embed_dim, num_heads, num_layers, mlp_dim, dropout)
    model.train()

    # Optimizer and Loss
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Multiplier on the base LR: linear warm-up for 10k steps, then 1/sqrt(step).
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: min((step + 1) ** -0.5, (step + 1) * (10000 ** -1.5))
    )
    criterion = nn.CrossEntropyLoss()

    # Dummy data (replace with real data)
    user_ids = torch.randint(0, user_vocab_size, (batch_size, 1))  # Random user IDs
    input_ids = torch.randint(0, num_semantic_tokens, (batch_size, seq_len))  # Random input sequences
    target_ids = torch.randint(0, num_semantic_tokens, (batch_size, target_len))  # Random target sequences

    # Teacher forcing: the decoder input is the target shifted right behind a
    # BOS token, so position t is never fed the token it has to predict.
    bos_column = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
    decoder_input_ids = torch.cat([bos_column, target_ids[:, :-1]], dim=1)

    # Training loop
    for step in range(num_steps):
        optimizer.zero_grad()

        # Forward pass
        logits = model(user_ids, input_ids, decoder_input_ids)

        # Compute loss against the unshifted targets
        loss = criterion(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log progress
        if (step + 1) % 1000 == 0:
            print(f"Step [{step + 1}/{num_steps}], Loss: {loss.item():.4f}")

# Entry point: start training only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    train_model()