In [21]:

class RevSTFT(nn.Module):
    """STFT / inverse-STFT pair that packs magnitude and phase into two channels.

    ``transform`` turns a batch of waveforms into a real tensor of shape
    ``[batch, 2, freq_bins, time_steps]`` (channel 0 = magnitude, channel 1 =
    phase in radians); ``inverse`` maps such a tensor back to waveforms.

    Args:
        config: Mapping with the keys ``'n_fft'`` (FFT size), ``'hop_length'``
            (hop between frames) and ``'win_len'`` (window length).
    """

    def __init__(self, config):
        super().__init__()
        self.filter_length = config['n_fft']
        self.hop_length = config['hop_length']
        self.win_length = config['win_len']

        # Sample count of the last signal seen by transform(). It starts as
        # None so that inverse() is usable on a fresh instance (torch.istft
        # then derives the length from the number of frames) instead of
        # failing with AttributeError.
        self.original_length = None

        # Registered as a buffer so the window follows the module across
        # .to(device) / .cuda() calls and is saved in the state dict.
        self.register_buffer(
            'window',
            torch.from_numpy(get_window('hann', self.win_length, fftbins=True).astype(np.float32))
        )

    def transform(self, input_data):
        """Compute the two-channel (magnitude, phase) spectrogram.

        Args:
            input_data: Waveforms of shape ``[batch, samples]``; a
                ``[batch, 1, samples]`` tensor is accepted and its singleton
                channel axis is dropped.

        Returns:
            Real tensor of shape ``[batch, 2, freq_bins, time_steps]``.

        Side effects:
            Stores the input's sample count in ``self.original_length`` so a
            later ``inverse`` call can trim its output to the same length.
        """
        if input_data.dim() == 3 and input_data.size(1) == 1:
            input_data = input_data.squeeze(1)

        self.original_length = input_data.size(-1)

        forward_transform = torch.stft(
            input_data,
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            normalized=False,
            return_complex=True
        )

        # Get magnitude and phase
        magnitudes = torch.abs(forward_transform)
        phases = torch.angle(forward_transform)

        # Stack magnitude and phase along new channel dimension
        # Shape will be [batch, 2, freq_bins, time_steps]
        merged_spec = torch.stack([magnitudes, phases], dim=1)

        return merged_spec

    def inverse(self, merged_spec, length=None):
        """Reconstruct waveforms from a two-channel (magnitude, phase) tensor.

        Args:
            merged_spec: Tensor of shape ``[batch, 2, freq_bins, time_steps]``
                as produced by ``transform``.
            length: Optional number of output samples. When omitted, the
                length recorded by the most recent ``transform`` call is used;
                if ``transform`` has never been called, ``torch.istft`` picks
                the length implied by the frame count.

        Returns:
            Tensor of shape ``[batch, samples]``.
        """
        if length is None:
            length = self.original_length

        # Split back into magnitude and phase
        magnitudes, phases = merged_spec.chunk(2, dim=1)
        # Remove the channel dimension
        magnitudes = magnitudes.squeeze(1)
        phases = phases.squeeze(1)

        # Reconstruct complex spectrogram
        complex_spec = magnitudes * torch.exp(1j * phases)

        inverse_transform = torch.istft(
            complex_spec,
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            normalized=False,
            length=length
        )

        return inverse_transform

    def forward(self, input_data):
        """Round-trip ``input_data`` through ``transform`` and ``inverse``."""
        merged_spec = self.transform(input_data)
        reconstruction = self.inverse(merged_spec)
        return reconstruction

# Test function
def test_merged_spectrogram():
    """Smoke-test RevSTFT: shapes, value ranges, round-trip error and gradients.

    Prints a short report and returns a dict with the spectrogram shape, the
    reconstruction shape and whether a gradient reached the spectrogram.
    """
    stft = RevSTFT(dict(n_fft=2048, hop_length=512, win_len=2048))

    # Random batch of two signals, 44100 samples each (one second at 44.1 kHz)
    signal = torch.randn(2, 44100)

    # Forward transform and per-channel value ranges
    spec = stft.transform(signal)
    print("\nSpectrogram shape:", spec.shape)
    channel_labels = (
        "Channel 0 (magnitude) range:",
        "Channel 1 (phase) range:",
    )
    for index, label in enumerate(channel_labels):
        channel = spec[:, index]
        print(label, channel.min().item(), "to", channel.max().item())

    # Round trip back to the time domain
    recon = stft.inverse(spec)
    print("\nInput shape:", signal.shape)
    print("Output shape:", recon.shape)
    mse = ((signal - recon) ** 2).mean()
    print("Reconstruction error:", mse.item())

    # Check that a scalar loss on the waveform back-propagates to the spectrogram
    spec.requires_grad_(True)
    recon = stft.inverse(spec)
    recon.mean().backward()
    has_gradient = spec.grad is not None
    print("\nGradient exists:", has_gradient)

    return {
        'spectrogram_shape': spec.shape,
        'reconstruction_shape': recon.shape,
        'has_gradient': has_gradient,
    }

# Run the smoke test when this code executes as the main module.
if __name__ == "__main__":
    test_merged_spectrogram()

class YourNetwork(nn.Module):
    """Example model: waveform -> (mag, phase) spectrogram -> small CNN -> waveform.

    Args:
        config: STFT settings forwarded unchanged to ``RevSTFT``.
    """

    def __init__(self, config):
        super().__init__()
        self.stft = RevSTFT(config)

        # Two-channel in, two-channel out; 3x3 kernels with padding=1 keep the
        # frequency/time dimensions unchanged.
        layers = [
            nn.Conv2d(2, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 2, 3, padding=1),
        ]
        self.process = nn.Sequential(*layers)

    def forward(self, x):
        """Transform ``x`` to a spectrogram, run the conv stack, and invert."""
        return self.stft.inverse(self.process(self.stft.transform(x)))


Spectrogram shape: torch.Size([2, 2, 1025, 87])
Channel 0 (magnitude) range: 0.00040642524254508317 to 111.00408172607422
Channel 1 (phase) range: -3.141592502593994 to 3.141592502593994

Input shape: torch.Size([2, 44100])
Output shape: torch.Size([2, 44100])
Reconstruction error: 1.4275606995887806e-14

Gradient exists: True


In [22]:
import torch
import torch.optim as optim

# Overfit YourNetwork on one constant signal: each step minimises the summed
# squared difference between the model output and its own input.
# NOTE(review): the output recorded for this cell is an AssertionError about an
# embedding dimension, but the YourNetwork defined in this notebook contains
# only Conv2d/ReLU layers. The kernel may have been holding a different
# definition when this ran; confirm with Restart Kernel -> Run All.

# Define the model
model = YourNetwork(config={
    'n_fft': 512,
    'win_len': 512,
    'hop_length': 256
})

# Sample input data (batch size 1, length 2^12), all ones
d = torch.ones((1, 2**12))

# Create an optimizer (e.g., Adam optimizer)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Learning rate of 0.001

# Set number of epochs for the loop
epochs = 1000  # Number of optimisation steps on the single fixed input

# Training loop with loss printing and optimization step
for epoch in range(epochs):
    # Zero the gradients before the backward pass
    optimizer.zero_grad()
    
    # Forward pass: model output
    f = model(d)
    
    # Calculate the loss (squared error between output and input, summed)
    bv = (f - d) ** 2
    loss = bv.sum()  # You can also use .mean() if you want average loss

    # Backward pass: compute gradients
    loss.backward()

    # Perform an optimization step to update the model parameters
    optimizer.step()

    # Print the loss for the current epoch (one line per epoch, so 1000 lines)
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')


AssertionError: was expecting embedding dimension of 64, but got 514