Model Architecture

In [None]:
# Encoder Module


In [None]:
import torch
from torch import nn

# VAE class shared by all layers
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps ``input_dim`` features through 512 and 256 hidden units
    to ``2 * latent_dim`` values (mean followed by log-variance). The decoder
    mirrors it (256, then 512 hidden units) and ends in a sigmoid, so every
    reconstructed value lies in (0, 1).

    Args:
        input_dim: Width of the feature vectors fed to the encoder.
        latent_dim: Width of the latent code.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder stack; the final layer emits mean and log-variance together.
        encoder_layers = [
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder stack; the sigmoid bounds the reconstruction to (0, 1).
        decoder_layers = [
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``std`` is ``exp(0.5 * logvar)``. Sampling this way keeps the result
        differentiable with respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A tuple ``(reconstructed_x, mu, logvar)``. The encoder output is
            reshaped to ``(-1, 2, latent_dim)``, so ``mu`` and ``logvar`` each
            have shape ``(-1, latent_dim)``.
        """
        # Slot 0 along dim 1 holds the mean, slot 1 the log-variance.
        stats = self.encoder(x).view(-1, 2, self.latent_dim)
        mu, logvar = stats.unbind(dim=1)

        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

# Layer 1: Precursor Speaker Representation
class PrecursorSpeakerLayer(nn.Module):
    """Layer 1: run the raw features through a VAE.

    The VAE's latent mean is what later stages use as the first ("precursor")
    speaker estimate.

    Args:
        input_dim: Width of the input feature vectors.
        latent_dim: Width of the VAE latent code.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.vae = VAE(input_dim, latent_dim)

    def forward(self, x):
        """Return the VAE's ``(reconstructed_x, mu, logvar)`` for ``x``."""
        return self.vae(x)

# Layer 2: Disentangled Content Representation
class DisentangledContentLayer(nn.Module):
    """Layer 2: encode what is left of the input once the speaker estimate is removed.

    ``speaker_mu`` is subtracted from ``x`` and the residual is passed through
    this layer's own VAE.

    A speaker mean that comes from a VAE latent has width ``latent_dim``,
    while ``x`` has width ``input_dim``. When those widths do not broadcast
    (``latent_dim`` is neither ``1`` nor ``input_dim``), a plain subtraction
    raises a shape error. In that case only, the layer owns a learned linear
    map from latent space to feature space and applies it before subtracting.

    Args:
        input_dim: Width of the input feature vectors.
        latent_dim: Width of the VAE latent code (and of a latent-sized
            ``speaker_mu``).
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.vae = VAE(input_dim, latent_dim)

        # Only create the projection when it is needed, so configurations
        # that already broadcast gain no extra (unused) parameters.
        if latent_dim in (1, input_dim):
            self.speaker_proj = None
        else:
            self.speaker_proj = nn.Linear(latent_dim, input_dim)

    def forward(self, x, speaker_mu):
        """Subtract the speaker estimate from ``x`` and encode the residual.

        Args:
            x: Input features whose last dimension is ``input_dim``.
            speaker_mu: Speaker estimate. Either broadcastable against ``x``
                (used as is), or latent-sized with last dimension
                ``latent_dim`` (projected to ``input_dim`` first).

        Returns:
            A tuple ``(reconstructed_x, content_mu, content_logvar)`` from
            this layer's VAE.
        """
        proj = self.speaker_proj
        # Project only latent-sized inputs; anything else (already
        # feature-sized, width 1, scalar) goes through the plain
        # broadcasting subtraction unchanged.
        if (
            proj is not None
            and speaker_mu.dim() > 0
            and speaker_mu.shape[-1] == proj.in_features
        ):
            speaker_mu = proj(speaker_mu)

        content_input = x - speaker_mu

        reconstructed_x, content_mu, content_logvar = self.vae(content_input)
        return reconstructed_x, content_mu, content_logvar

# Layer 3: Final Disentangled Speaker Representation
class FinalSpeakerLayer(nn.Module):
    """Layer 3: encode what is left of the input once the content estimate is removed.

    ``content_mu`` is subtracted from ``x`` and the residual is passed through
    this layer's own VAE, whose latent mean is the final speaker embedding.

    A content mean that comes from a VAE latent has width ``latent_dim``,
    while ``x`` has width ``input_dim``. When those widths do not broadcast
    (``latent_dim`` is neither ``1`` nor ``input_dim``), a plain subtraction
    raises a shape error. In that case only, the layer owns a learned linear
    map from latent space to feature space and applies it before subtracting.

    Args:
        input_dim: Width of the input feature vectors.
        latent_dim: Width of the VAE latent code (and of a latent-sized
            ``content_mu``).
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.vae = VAE(input_dim, latent_dim)

        # Only create the projection when it is needed, so configurations
        # that already broadcast gain no extra (unused) parameters.
        if latent_dim in (1, input_dim):
            self.content_proj = None
        else:
            self.content_proj = nn.Linear(latent_dim, input_dim)

    def forward(self, x, content_mu):
        """Subtract the content estimate from ``x`` and encode the residual.

        Args:
            x: Input features whose last dimension is ``input_dim``.
            content_mu: Content estimate. Either broadcastable against ``x``
                (used as is), or latent-sized with last dimension
                ``latent_dim`` (projected to ``input_dim`` first).

        Returns:
            A tuple ``(reconstructed_x, final_speaker_mu,
            final_speaker_logvar)`` from this layer's VAE.
        """
        proj = self.content_proj
        # Project only latent-sized inputs; anything else (already
        # feature-sized, width 1, scalar) goes through the plain
        # broadcasting subtraction unchanged.
        if (
            proj is not None
            and content_mu.dim() > 0
            and content_mu.shape[-1] == proj.in_features
        ):
            content_mu = proj(content_mu)

        speaker_input = x - content_mu

        reconstructed_x, final_speaker_mu, final_speaker_logvar = self.vae(speaker_input)
        return reconstructed_x, final_speaker_mu, final_speaker_logvar

# Full Temporal Aggregation Module (Layer 1 -> Layer 2 -> Layer 3)
class TemporalAggregation(nn.Module):
    """Chain the three layers: precursor speaker -> content -> final speaker.

    Every layer receives the same input ``x``. Layer 2 also receives layer 1's
    latent mean, and layer 3 also receives layer 2's latent mean.

    Args:
        input_dim: Width of the input feature vectors, shared by all layers.
        latent_dim: Width of the latent code, shared by all layers.
    """

    # Result keys, grouped per layer as (reconstruction, mean, log-variance).
    _OUTPUT_KEYS = (
        'precursor_output', 'speaker_mu', 'speaker_logvar',
        'content_output', 'content_mu', 'content_logvar',
        'final_speaker_output', 'final_speaker_mu', 'final_speaker_logvar',
    )

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.layer1 = PrecursorSpeakerLayer(input_dim, latent_dim)
        self.layer2 = DisentangledContentLayer(input_dim, latent_dim)
        self.layer3 = FinalSpeakerLayer(input_dim, latent_dim)

    def forward(self, x):
        """Run all three layers on ``x`` and collect their outputs.

        Returns:
            A dict holding, for each layer, its reconstruction, latent mean
            and latent log-variance under the keys ``precursor_output`` /
            ``speaker_mu`` / ``speaker_logvar``, ``content_output`` /
            ``content_mu`` / ``content_logvar`` and ``final_speaker_output`` /
            ``final_speaker_mu`` / ``final_speaker_logvar``.
        """
        # Each stage yields (reconstruction, mu, logvar); index 1 is the mean
        # handed to the next stage.
        precursor = self.layer1(x)
        content = self.layer2(x, precursor[1])
        final_speaker = self.layer3(x, content[1])

        return dict(zip(self._OUTPUT_KEYS, (*precursor, *content, *final_speaker)))

# Example of running the complete Temporal Aggregation module
if __name__ == "__main__":
    # Demo configuration
    input_dim = 40   # e.g. 40-dimensional MFCC features
    latent_dim = 16  # width of each latent code
    batch_size = 32

    # Random stand-in for a batch of encoder outputs (MFCC-like vectors)
    sample_input = torch.randn(batch_size, input_dim)

    # Build the three-layer module and run the batch through it
    temporal_aggregation = TemporalAggregation(input_dim, latent_dim)
    output = temporal_aggregation(sample_input)

    # Report the shape of each layer's reconstruction and latent mean
    for _label, _key in (
        ("Precursor Output Shape:", 'precursor_output'),
        ("Speaker Mu Shape:", 'speaker_mu'),
        ("Content Output Shape:", 'content_output'),
        ("Content Mu Shape:", 'content_mu'),
        ("Final Speaker Output Shape:", 'final_speaker_output'),
        ("Final Speaker Mu Shape:", 'final_speaker_mu'),
    ):
        print(_label, output[_key].shape)