In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchaudio
import torchaudio.transforms as T
import numpy as np

# --- 1. Model Definition (Placeholder) ---
class AudioMAE(nn.Module):
    def __init__(self, input_dim, encoder_dim, decoder_dim, num_patches, mask_ratio=0.75):
        super().__init__()
        self.input_dim = input_dim
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.num_patches = num_patches
        self.mask_ratio = mask_ratio

        # Patch embedding (e.g., a Conv1D or Linear layer)
        # For simplicity, let's assume patches are already formed and flattened
        self.patch_embed = nn.Linear(input_dim // num_patches, encoder_dim) # Simplified

        # Encoder (e.g., a Transformer Encoder)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=encoder_dim, nhead=4, dim_feedforward=encoder_dim*2, batch_first=True),
            num_layers=6
        )

        # Decoder (e.g., a Transformer Decoder or MLP)
        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.decoder = nn.TransformerEncoder( # Using Encoder for simplicity, a Decoder would be more standard
            nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=4, dim_feedforward=decoder_dim*2, batch_first=True),
            num_layers=2
        )
        self.decoder_pred = nn.Linear(decoder_dim, input_dim // num_patches) # Predicts patch features

    def forward_encoder(self, x, unmasked_indices):
        # x: (batch_size, num_patches, patch_dim)
        x = self.patch_embed(x) # (batch_size, num_patches, encoder_dim)

        # Apply positional encoding (not shown for brevity, but important)
        # x = x + self.pos_embed

        # Select only unmasked patches for encoder
        x_unmasked = torch.gather(x, dim=1, index=unmasked_indices.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
        encoded_patches = self.encoder(x_unmasked)
        return encoded_patches

    def forward_decoder(self, encoded_patches, unmasked_indices, masked_indices):
        # encoded_patches: (batch_size, num_unmasked_patches, encoder_dim)
        batch_size = encoded_patches.shape[0]
        num_unmasked = unmasked_indices.shape[1]
        num_masked = masked_indices.shape[1]

        encoded_patches = self.decoder_embed(encoded_patches)

        # Create full sequence with mask tokens
        full_sequence = torch.cat([encoded_patches, self.mask_token.expand(batch_size, num_masked, -1)], dim=1)
        # Reconstruct original order (simplified, actual MAE shuffles and unshuffles)
        # For simplicity, we assume unmasked_indices and masked_indices together form the full sequence in order
        # A more robust implementation would involve an unshuffle operation based on original indices.
        # Here, we'll just decode the concatenated sequence.
        # A proper MAE would re-introduce positional embeddings for the decoder.

        decoded_patches = self.decoder(full_sequence)
        return self.decoder_pred(decoded_patches[:, num_unmasked:]) # Predict only masked patches

    def forward_loss(self, x_patches, pred_patches, masked_indices):
        # x_patches: (batch_size, num_patches, patch_dim) - original patches
        # pred_patches: (batch_size, num_masked_patches, patch_dim) - predicted masked patches
        # masked_indices: (batch_size, num_masked_patches)

        # Gather the ground truth masked patches
        target_patches = torch.gather(x_patches, dim=1, index=masked_indices.unsqueeze(-1).expand(-1, -1, x_patches.shape[-1]))

        loss = nn.functional.mse_loss(pred_patches, target_patches)
        return loss

    def forward(self, x_patches):
        # x_patches: (batch_size, num_patches, patch_dim)
        batch_size, num_patches, _ = x_patches.shape
        num_masked = int(self.mask_ratio * num_patches)

        # Generate random mask
        # This is a simplified masking. Real MAE shuffles patches.
        noise = torch.rand(batch_size, num_patches, device=x_patches.device)
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        unmasked_indices = ids_shuffle[:, :num_patches - num_masked]
        masked_indices = ids_shuffle[:, num_patches - num_masked:]


        encoded_patches = self.forward_encoder(x_patches, unmasked_indices)
        pred_masked_patches = self.forward_decoder(encoded_patches, unmasked_indices, masked_indices)
        loss = self.forward_loss(x_patches, pred_masked_patches, masked_indices)
        return loss, pred_masked_patches, masked_indices, unmasked_indices


# --- 2. Dataset and DataLoader (Placeholder) ---
class DummyAudioDataset(Dataset):
    def __init__(self, num_samples=1000, sample_rate=16000, duration=1, num_patches=64, feature_dim=256):
        self.num_samples = num_samples
        self.sample_rate = sample_rate
        self.duration = duration
        self.num_frames = self.sample_rate * self.duration
        self.num_patches = num_patches # Number of patches
        self.patch_len = self.num_frames // self.num_patches # Length of each patch in frames
        self.feature_dim = feature_dim # For mel spectrogram

        # Mel Spectrogram configuration
        self.mel_spectrogram = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=400,
            hop_length=160, # results in 100 frames per second for 16kHz
            n_mels=feature_dim
        )
        # Ensure total frames from mel spec are divisible by num_patches
        # For 1s audio at 16kHz, hop_length 160 -> 100 frames. If num_patches=64, this won't divide.
        # Let's adjust num_frames for mel spectrogram output
        # Effective frames for mel: (self.num_frames - n_fft) // hop_length + 1
        # For simplicity, we'll generate raw audio and then process it.
        # The actual patching strategy needs to be carefully designed.

        # For this dummy dataset, we'll generate random data and "pretend" they are patches
        # A real dataset would load audio, compute spectrograms, and then patch them.
        self.patch_feature_dim = feature_dim # Each patch is a segment of the mel spectrogram
                                            # If we flatten time within a patch, it's patch_len_mel * feature_dim
                                            # Here, we assume patches are (num_patches, feature_dim_per_patch)
                                            # Let's assume each "patch" is a time step of the mel spectrogram
                                            # So, input_dim for AudioMAE will be feature_dim (n_mels)
                                            # And num_patches will be the number of time steps in the mel spectrogram

        # Calculate number of time frames from mel spectrogram
        # For 1s audio, 16000 frames, n_fft=400, hop_length=160 -> (16000-400)//160 + 1 = 97.5 -> 98 frames
        # Let's use a fixed number of frames for simplicity in this dummy example
        self.num_mel_frames = 128 # This will be our num_patches
        self.input_dim_mae = feature_dim # n_mels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate dummy audio waveform
        waveform = torch.randn(self.num_frames)

        # Compute Mel Spectrogram
        mel_spec = self.mel_spectrogram(waveform) # (n_mels, num_time_frames)
        mel_spec = mel_spec.squeeze(0) # Remove channel dim if present
        mel_spec = mel_spec[:, :self.num_mel_frames] # Trim or pad to ensure fixed num_mel_frames

        # Normalize (example)
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-6)

        # Patches are time steps of the mel spectrogram
        # (n_mels, num_mel_frames) -> (num_mel_frames, n_mels) to match (num_patches, patch_dim)
        patches = mel_spec.transpose(0, 1) # (num_mel_frames, n_mels)
        return patches


# --- 3. Training Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 16
learning_rate = 1e-4
num_epochs = 50
num_patches_dataset = 128 # This should match DummyAudioDataset.num_mel_frames
feature_dim_dataset = 64  # This should match DummyAudioDataset.feature_dim (n_mels)

# Instantiate Dataset and DataLoader
# Note: The input_dim for AudioMAE should be feature_dim_dataset
# The num_patches for AudioMAE should be num_patches_dataset
dummy_dataset = DummyAudioDataset(num_samples=200, num_patches=num_patches_dataset, feature_dim=feature_dim_dataset)
train_loader = DataLoader(dummy_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Instantiate Model
# The input_dim to the patch_embed in AudioMAE should be feature_dim_dataset
# The num_patches in AudioMAE should be num_patches_dataset
model = AudioMAE(
    input_dim=feature_dim_dataset * num_patches_dataset, # This is if we flatten all patches
                                                        # If patch_embed takes individual patch features:
                                                        # input_dim should be feature_dim_dataset
                                                        # And the model's patch_embed input should be feature_dim_dataset
    encoder_dim=256,
    decoder_dim=128,
    num_patches=num_patches_dataset, # This is the sequence length for the transformer
    mask_ratio=0.75
).to(device)

# Adjusting model init based on how DummyAudioDataset provides patches:
# DummyAudioDataset provides (num_mel_frames, n_mels) which is (num_patches, feature_dim_per_patch)
# So, AudioMAE's patch_embed should take feature_dim_per_patch as input.
# The input_dim in AudioMAE's __init__ was for the *total* flattened input, let's adjust.
# The patch_embed in AudioMAE is nn.Linear(input_dim // num_patches, encoder_dim)
# So, input_dim // num_patches should be feature_dim_dataset.
# Thus, the first argument to AudioMAE (input_dim) should be feature_dim_dataset * num_patches_dataset.
# This seems correct if the model internally flattens the (num_patches, patch_feature_dim) input before patch_embed.
# Let's refine the AudioMAE's patch_embed to directly take patch_feature_dim.

class AudioMAERefined(nn.Module):
    def __init__(self, patch_feature_dim, encoder_dim, decoder_dim, num_patches, mask_ratio=0.75):
        super().__init__()
        self.patch_feature_dim = patch_feature_dim # e.g., n_mels
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.num_patches = num_patches # e.g., number of time frames in mel spectrogram
        self.mask_ratio = mask_ratio

        self.patch_embed = nn.Linear(patch_feature_dim, encoder_dim)

        # Positional Encoding (learnable or fixed)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, encoder_dim))

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=encoder_dim, nhead=4, dim_feedforward=encoder_dim*2, batch_first=True, dropout=0.1),
            num_layers=6
        )

        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        # Positional Encoding for decoder (can be shared or separate)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_dim))


        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=4, dim_feedforward=decoder_dim*2, batch_first=True, dropout=0.1),
            num_layers=2
        )
        self.decoder_pred = nn.Linear(decoder_dim, patch_feature_dim) # Predicts original patch features

    def _generate_random_mask(self, x):
        batch_size, num_patches, _ = x.shape
        num_unmasked = int(num_patches * (1 - self.mask_ratio))

        noise = torch.rand(batch_size, num_patches, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        unmasked_indices = ids_shuffle[:, :num_unmasked]
        masked_indices = ids_shuffle[:, num_unmasked:]
        return unmasked_indices, masked_indices, ids_restore

    def forward_encoder(self, x_patches, unmasked_indices):
        # x_patches: (batch_size, num_patches, patch_feature_dim)
        x = self.patch_embed(x_patches) # (batch_size, num_patches, encoder_dim)
        x = x + self.pos_embed # Add positional encoding

        # Gather unmasked patches
        x_unmasked = torch.gather(x, dim=1, index=unmasked_indices.unsqueeze(-1).expand(-1, -1, self.encoder_dim))
        encoded_patches = self.encoder(x_unmasked) # (batch_size, num_unmasked_patches, encoder_dim)
        return encoded_patches

    def forward_decoder(self, encoded_unmasked_patches, unmasked_indices, masked_indices, ids_restore):
        # encoded_unmasked_patches: (batch_size, num_unmasked_patches, encoder_dim)
        batch_size = encoded_unmasked_patches.shape[0]
        num_total_patches = self.num_patches
        num_masked_patches = masked_indices.shape[1]

        # Embed encoded patches to decoder dimension
        decoder_embedded_patches = self.decoder_embed(encoded_unmasked_patches) # (B, num_unmasked, decoder_dim)

        # Create mask tokens for masked patches
        mask_tokens = self.mask_token.expand(batch_size, num_masked_patches, -1) # (B, num_masked, decoder_dim)

        # Concatenate unmasked patch embeddings and mask tokens
        # We need to place them in the original shuffled order before unmasking
        # This is a simplification: MAE typically appends mask tokens and then unshuffles.
        # For a more accurate MAE:
        # 1. Create a full sequence placeholder for the decoder.
        # 2. Place decoder_embedded_patches into this placeholder at unmasked_indices (after shuffling).
        # 3. Place mask_tokens into this placeholder at masked_indices (after shuffling).
        # 4. Add decoder_pos_embed to this full sequence.
        # 5. Unshuffle the sequence using ids_restore.
        # 6. Pass to decoder.

        # Simplified approach (less faithful to original MAE but easier to start):
        # Assume unmasked_indices and masked_indices are sorted for concatenation
        # This is not what MAE does. MAE restores order.

        # Let's try to be more faithful:
        # Create full sequence for decoder input
        decoder_input_full = torch.zeros(batch_size, num_total_patches, self.decoder_dim, device=encoded_unmasked_patches.device)

        # Scatter unmasked patches
        # ids_shuffle gives the shuffled indices. unmasked_indices are the first N of these.
        # We need to place decoder_embedded_patches at their *original* positions in the shuffled sequence
        # This part is tricky without the exact MAE unshuffle logic.
        # For now, let's use a simplified placeholder for decoder input construction.

        # A common MAE approach:
        # x_full for decoder: (batch_size, num_patches, decoder_dim)
        # Initialize with mask tokens, then fill in the encoded unmasked patches
        x_full_decoder = self.mask_token.expand(batch_size, num_total_patches, -1).clone()
        # Scatter the embedded unmasked patches to their positions in the *shuffled* sequence
        # unmasked_indices are like [idx_val1, idx_val2, ...] where idx_val is the original patch index
        # We need to map these original patch indices to their positions in the *shuffled* sequence.
        # This is what ids_shuffle and ids_restore are for.

        # Let `h_encoder` be `decoder_embedded_patches`
        # Let `ids_unmasked` be `unmasked_indices` (these are indices of original patches that are kept)
        # Let `ids_masked` be `masked_indices` (indices of original patches that are masked)

        # Create the full sequence for the decoder
        # The length is num_total_patches.
        # The first `num_unmasked` positions (in the shuffled sense) get `decoder_embedded_patches`.
        # The remaining `num_masked` positions (in the shuffled sense) get `mask_tokens`.
        # Then, add decoder positional embeddings and unshuffle.

        # Step 1: Concatenate visible tokens and mask tokens
        # decoder_embedded_patches corresponds to the first `num_unmasked` elements of the shuffled sequence.
        # mask_tokens corresponds to the last `num_masked` elements of the shuffled sequence.
        x_shuffled_for_decoder = torch.cat([decoder_embedded_patches, mask_tokens], dim=1)

        # Step 2: Unshuffle the sequence
        # ids_restore will map from the shuffled order back to the original patch order.
        x_unshuffled_for_decoder = torch.gather(x_shuffled_for_decoder, dim=1,
                                             index=ids_restore.unsqueeze(-1).expand(-1, -1, self.decoder_dim))

        # Step 3: Add decoder positional embedding
        x_unshuffled_for_decoder = x_unshuffled_for_decoder + self.decoder_pos_embed

        # Step 4: Pass through decoder
        decoded_full_sequence = self.decoder(x_unshuffled_for_decoder) # (B, num_total_patches, decoder_dim)

        # Step 5: Predict only the masked patches
        # We need to gather the outputs corresponding to the original masked_indices
        # decoded_full_sequence is in original patch order.
        # masked_indices contains the *original* indices of the patches that were masked.
        pred_masked_patches = torch.gather(decoded_full_sequence, dim=1,
                                           index=masked_indices.unsqueeze(-1).expand(-1, -1, self.decoder_dim))
        pred_masked_patches = self.decoder_pred(pred_masked_patches) # (B, num_masked_patches, patch_feature_dim)
        return pred_masked_patches


    def forward_loss(self, original_patches, pred_masked_patches, masked_indices):
        # original_patches: (batch_size, num_total_patches, patch_feature_dim)
        # pred_masked_patches: (batch_size, num_masked_patches, patch_feature_dim)
        # masked_indices: (batch_size, num_masked_patches) - original indices of masked patches

        # Gather the ground truth for the masked patches from the original input
        target_masked_patches = torch.gather(original_patches, dim=1,
                                          index=masked_indices.unsqueeze(-1).expand(-1, -1, self.patch_feature_dim))

        loss = nn.functional.mse_loss(pred_masked_patches, target_masked_patches)
        return loss

    def forward(self, x_patches):
        # x_patches: (batch_size, num_patches, patch_feature_dim)
        unmasked_indices, masked_indices, ids_restore = self._generate_random_mask(x_patches)

        encoded_unmasked_patches = self.forward_encoder(x_patches, unmasked_indices)
        pred_masked_patches = self.forward_decoder(encoded_unmasked_patches, unmasked_indices, masked_indices, ids_restore)
        loss = self.forward_loss(x_patches, pred_masked_patches, masked_indices)

        return loss, pred_masked_patches, masked_indices # ids_restore could also be returned for viz


model = AudioMAERefined(
    patch_feature_dim=feature_dim_dataset, # n_mels
    encoder_dim=256, # Latent dimension of encoder
    decoder_dim=128, # Latent dimension of decoder
    num_patches=num_patches_dataset, # Number of time frames / patches
    mask_ratio=0.75
).to(device)


# Optimizer and Loss Function
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
# Loss is calculated inside the model's forward pass for MAE

# --- 4. Training Loop ---
print("Starting training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, patches_batch in enumerate(train_loader):
        patches_batch = patches_batch.to(device) # (batch_size, num_patches, patch_feature_dim)

        optimizer.zero_grad()

        loss, _, _ = model(patches_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 2 == 0: # Print every few batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}")

print("Training finished.")

# --- 5. Evaluation / Inference (Placeholder) ---
# model.eval()
# with torch.no_grad():
#     # Perform inference, e.g., reconstruct masked audio or extract features
#     for patches_batch in train_loader: # Using train_loader for example
#         patches_batch = patches_batch.to(device)
#         loss, reconstructed_patches, masked_indices = model(patches_batch)
#         print(f"Evaluation Loss: {loss.item():.4f}")
#         # Here you could visualize the original vs reconstructed masked patches
#         break # Just one batch for example


Using device: cuda
Starting training...


RuntimeError: The size of tensor a (101) must match the size of tensor b (128) at non-singleton dimension 1