In [None]:
import torch
from torch.nn import Module
import torchaudio
import torchaudio.transforms as T
from torchaudio.functional import DB_to_amplitude
from einops import rearrange
from vocos import Vocos
import matplotlib.pyplot as plt
import numpy as np

# --- Helper Functions and Base Class ---

def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy-but-present values such as ``0``, ``""`` or ``[]`` count as
    existing; only ``None`` itself yields False.
    """
    # Compare the argument, not the function object: `exists is not None`
    # is always True because `exists` refers to this function.
    return val is not None

class AudioEncoderDecoder(Module):
    """Marker base class for audio encoder/decoder modules.

    Adds no attributes or methods of its own on top of ``torch.nn.Module``.
    """

# --- SpecVoco Class (modified from MelVoco) ---

class SpecVoco(AudioEncoderDecoder):
    """Encode waveforms into linear-frequency power spectrograms.

    ``encode`` runs an STFT with a Hann window and returns the power
    spectrogram (optionally converted to decibels) laid out as
    ``(batch, frames, n_fft // 2 + 1)``.

    Args:
        log: When true, ``encode`` converts the power spectrogram to dB.
        sampling_rate: Stored on the instance; ``encode``/``decode`` do not
            read it.
        n_fft: FFT size; fixes the number of frequency bins.
        win_length: Analysis window length in samples.
        hop_length: Hop between successive frames in samples.
        pretrained_vocos_path: Identifier handed to ``Vocos.from_pretrained``.
            The loaded vocoder is kept on ``self.vocos`` but the placeholder
            ``decode`` below does not use it.
    """

    def __init__(
        self,
        *,
        log = True,
        sampling_rate = 24000,
        n_fft = 1024,
        win_length = 640,
        hop_length = 160,
        pretrained_vocos_path = 'charactr/vocos-mel-24khz'
    ):
        super().__init__()
        self.log = log
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.sampling_rate = sampling_rate

        self.vocos = Vocos.from_pretrained(pretrained_vocos_path)

    @property
    def downsample_factor(self):
        """Samples consumed per output frame (the STFT hop length)."""
        return self.hop_length

    @property
    def latent_dim(self):
        """Feature size of each encoded frame: ``n_fft // 2 + 1`` frequency bins."""
        return self.n_fft // 2 + 1

    def encode(self, audio):
        """Convert a batch of mono waveforms to spectrogram features.

        Args:
            audio: Tensor shaped ``(batch, time)`` or ``(batch, 1, time)``.

        Returns:
            Tensor shaped ``(batch, frames, latent_dim)`` holding the power
            spectrogram, in dB when ``self.log`` is true.

        Raises:
            ValueError: If ``audio`` is not 2D, or 3D with a single channel.
        """
        if audio.ndim == 2:
            # (B, T) -> (B, 1, T) so both accepted layouts share one path.
            audio = audio.unsqueeze(1)
        if audio.ndim != 3 or audio.shape[1] != 1:
            raise ValueError(
                "audio must have shape (batch, time) or (batch, 1, time), "
                f"got {tuple(audio.shape)}"
            )

        # Reference torchaudio.transforms by its full name rather than the
        # module-level alias `T`: the alias is looked up at call time and
        # breaks if any script code rebinds `T` to something else.
        transforms = torchaudio.transforms

        stft_transform = transforms.Spectrogram(
            n_fft = self.n_fft,
            win_length = self.win_length,
            hop_length = self.hop_length,
            window_fn = torch.hann_window,
            power = 2.0  # power (squared-magnitude) spectrogram
        ).to(audio.device)  # the window is created on CPU; follow the input

        # (B, 1, Freq, Frames) -> (B, Freq, Frames)
        spectrogram = stft_transform(audio).squeeze(1)

        if self.log:
            # AmplitudeToDB with default settings treats its input as power.
            spectrogram = transforms.AmplitudeToDB()(spectrogram)

        # (B, Freq, Frames) -> (B, Frames, Freq)
        return rearrange(spectrogram, 'b d n -> b n d')

    def decode(self, spectrogram_features):
        """Placeholder decoder: only undoes the axis swap done by ``encode``.

        Args:
            spectrogram_features: Tensor shaped ``(batch, frames, dim)``.

        Returns:
            The same features rearranged to ``(batch, dim, frames)``. No
            waveform is synthesised here.
        """
        spectrogram_features = rearrange(spectrogram_features, 'b n d -> b d n')
        # ... (rest of decode logic)
        return spectrogram_features # Returning features just for consistency, not actual audio

# --- Example Usage ---

# 1. Instantiate the model
model = SpecVoco(sampling_rate=16000) # Use 16kHz for common speech audio

# 2. Load a sample audio file
# IMPORTANT: point `audio_path` at an audio file that exists on your machine.
# The literal must be a raw string: in a normal string the `\U` of `C:\Users`
# begins a unicode escape and the whole file fails with a SyntaxError.
audio_path = r'C:\Users\ABHILASH\Desktop\lucid_Rain\out.wav'

try:
    # torchaudio.load returns (waveform, sample_rate); waveform is (Channels, Time).
    # Only the load sits inside `try` so unrelated failures are not masked.
    waveform, sr = torchaudio.load(audio_path)
except (FileNotFoundError, RuntimeError):
    # Several torchaudio backends report a missing/unopenable file as
    # RuntimeError rather than FileNotFoundError, so both are treated alike.
    print("WARNING: Audio file not found. Generating dummy noise for visualization.")
    # Create 3 seconds of dummy noise if file loading fails.
    # Do not call this `T`: that name is the torchaudio.transforms alias.
    num_samples = model.sampling_rate * 3
    audio_input = torch.randn(1, 1, num_samples)
else:
    if sr != model.sampling_rate:
        # Resample when the file's rate differs from the rate the model was built with
        resampler = T.Resample(sr, model.sampling_rate)
        waveform = resampler(waveform)

    # Keep the first channel and add a batch dimension -> (1, 1, Time)
    audio_input = waveform[:1, :].unsqueeze(0)

# 3. Encode to get the Spectrogram features
# Output shape: (Batch, Time, Freq)
spectrogram_features = model.encode(audio_input)
print(f"\nEncoded Spectrogram Shape (B, T, Freq): {spectrogram_features.shape}")

# 4. Prepare data for plotting
# Drop the batch dimension, convert to NumPy, and transpose to (Freq, Time),
# the usual orientation for a spectrogram image.
spectrogram_to_plot = spectrogram_features[0].cpu().numpy().T

# Create the plot
plt.figure(figsize=(10, 4))
plt.imshow(spectrogram_to_plot, aspect='auto', origin='lower',
           interpolation='none', cmap='viridis')

# Set labels for verification
# X-axis: Time (in frames)
plt.xlabel("Time Frame Index")
# Y-axis: Frequency Bins
# The bin count equals n_fft // 2 + 1, i.e. a linear-frequency (non-mel) spectrogram
plt.ylabel(f"Frequency Bins (0 to {model.latent_dim - 1})")
plt.title("Raw Spectrogram (Log-Magnitude)")
plt.colorbar(format="%+2.0f dB")
plt.show()