# RNN/LSTM Model Architecture for Music Tagging

This notebook implements recurrent neural network (LSTM/GRU) architectures, plus a hybrid CNN-LSTM, for music genre classification (single-label) and music tagging (multi-label).

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

## LSTM Model for Genre Classification

In [None]:
class MusicLSTM(nn.Module):
    """LSTM-based model for music genre classification.

    Raw audio is converted to a log-mel spectrogram inside the model. The
    spectrogram frames are fed, one per time step, to an (optionally
    bidirectional) LSTM. The top layer's final hidden state is then mapped
    to class logits by a small fully connected head.

    Args:
        n_classes: Number of output logits.
        sample_rate: Sample rate of the input audio, in Hz.
        n_mels: Number of mel bands; also the LSTM input size.
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        n_fft: FFT size used by the mel-spectrogram transform.
        hop_length: Hop, in samples, between spectrogram frames.
        rnn_dropout: Dropout between stacked LSTM layers; only used when
            ``num_layers > 1`` (nn.LSTM applies it between layers only).
        fc_dropout: Dropout applied after the first fully connected layer.
    """

    def __init__(self, n_classes=10, sample_rate=22050, n_mels=128,
                 hidden_size=256, num_layers=2, bidirectional=True,
                 n_fft=2048, hop_length=512, rnn_dropout=0.3, fc_dropout=0.5):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Waveform -> power mel-spectrogram -> decibels.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

        self.lstm = nn.LSTM(
            input_size=n_mels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=rnn_dropout if num_layers > 1 else 0,
        )

        # A bidirectional LSTM's summary concatenates both directions.
        lstm_output_size = hidden_size * (2 if bidirectional else 1)
        self.fc1 = nn.Linear(lstm_output_size, 128)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return class logits shaped (batch, n_classes).

        Args:
            x: Mono waveform batch shaped (batch, 1, time) or (batch, time).

        Raises:
            ValueError: If ``x`` has any other shape (e.g. multi-channel).
        """
        frames = self._log_mel_frames(x)      # (batch, frames, n_mels)
        _, (h_n, _) = self.lstm(frames)
        features = self._final_hidden(h_n)

        x = F.relu(self.fc1(features))
        x = self.dropout(x)
        return self.fc2(x)

    def _log_mel_frames(self, x):
        """Convert a waveform batch to log-mel frames shaped (batch, frames, n_mels)."""
        if x.dim() == 3:
            if x.size(1) != 1:
                raise ValueError(
                    "expected mono audio shaped (batch, 1, time), "
                    f"got {tuple(x.shape)}"
                )
            x = x.squeeze(1)
        elif x.dim() != 2:
            raise ValueError(
                "expected audio shaped (batch, 1, time) or (batch, time), "
                f"got {tuple(x.shape)}"
            )
        spec = self.amplitude_to_db(self.mel_spec(x))   # (batch, n_mels, frames)
        return spec.transpose(1, 2)

    def _final_hidden(self, h_n):
        """Return the top layer's final hidden state, both directions concatenated."""
        # h_n is (num_layers * num_directions, batch, hidden_size). The last
        # row (unidirectional), or the last two rows (forward, backward),
        # belong to the top layer.
        if self.bidirectional:
            return torch.cat((h_n[-2], h_n[-1]), dim=1)
        return h_n[-1]

## GRU Model for Music Tagging

In [None]:
class MusicGRU(nn.Module):
    """GRU-based model for music tagging (multi-label).

    Raw audio is converted to a log-mel spectrogram inside the model. The
    spectrogram frames are fed, one per time step, to an (optionally
    bidirectional) GRU. The top layer's final hidden state then goes through
    a three-layer fully connected head. The head returns raw logits, one per
    tag; no sigmoid is applied here.

    Args:
        n_classes: Number of output logits (tags).
        sample_rate: Sample rate of the input audio, in Hz.
        n_mels: Number of mel bands; also the GRU input size.
        hidden_size: Hidden units per GRU direction.
        num_layers: Number of stacked GRU layers.
        bidirectional: Whether the GRU runs in both directions.
        n_fft: FFT size used by the mel-spectrogram transform.
        hop_length: Hop, in samples, between spectrogram frames.
        rnn_dropout: Dropout between stacked GRU layers; only used when
            ``num_layers > 1`` (nn.GRU applies it between layers only).
        fc_dropout: Dropout applied after each hidden fully connected layer.
    """

    def __init__(self, n_classes=50, sample_rate=22050, n_mels=128,
                 hidden_size=256, num_layers=2, bidirectional=True,
                 n_fft=2048, hop_length=512, rnn_dropout=0.3, fc_dropout=0.5):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Waveform -> power mel-spectrogram -> decibels.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

        self.gru = nn.GRU(
            input_size=n_mels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=rnn_dropout if num_layers > 1 else 0,
        )

        # A bidirectional GRU's summary concatenates both directions.
        gru_output_size = hidden_size * (2 if bidirectional else 1)
        self.fc1 = nn.Linear(gru_output_size, 256)
        self.dropout1 = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(fc_dropout)
        self.fc3 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return per-tag logits shaped (batch, n_classes).

        Args:
            x: Mono waveform batch shaped (batch, 1, time) or (batch, time).

        Raises:
            ValueError: If ``x`` has any other shape (e.g. multi-channel).
        """
        frames = self._log_mel_frames(x)      # (batch, frames, n_mels)
        _, h_n = self.gru(frames)
        features = self._final_hidden(h_n)

        x = F.relu(self.fc1(features))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return self.fc3(x)

    def _log_mel_frames(self, x):
        """Convert a waveform batch to log-mel frames shaped (batch, frames, n_mels)."""
        if x.dim() == 3:
            if x.size(1) != 1:
                raise ValueError(
                    "expected mono audio shaped (batch, 1, time), "
                    f"got {tuple(x.shape)}"
                )
            x = x.squeeze(1)
        elif x.dim() != 2:
            raise ValueError(
                "expected audio shaped (batch, 1, time) or (batch, time), "
                f"got {tuple(x.shape)}"
            )
        spec = self.amplitude_to_db(self.mel_spec(x))   # (batch, n_mels, frames)
        return spec.transpose(1, 2)

    def _final_hidden(self, h_n):
        """Return the top layer's final hidden state, both directions concatenated."""
        # h_n is (num_layers * num_directions, batch, hidden_size). The last
        # row (unidirectional), or the last two rows (forward, backward),
        # belong to the top layer.
        if self.bidirectional:
            return torch.cat((h_n[-2], h_n[-1]), dim=1)
        return h_n[-1]

## CNN-LSTM Hybrid Model

In [None]:
class CNNLSTM(nn.Module):
    """Hybrid CNN-LSTM model for music classification.

    A log-mel spectrogram is computed inside the model. It passes through
    three Conv-BN-ReLU-MaxPool stages, each 2x2 pool halving both the mel
    and time axes. At each remaining time step, all channel x mel
    activations are flattened into one feature vector for a bidirectional
    LSTM. The top layer's final forward and backward hidden states are
    concatenated and mapped to class logits by a small fully connected head.

    Args:
        n_classes: Number of output logits.
        sample_rate: Sample rate of the input audio, in Hz.
        n_mels: Number of mel bands; must be at least 8 so that three 2x
            poolings leave at least one mel bin.
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        n_fft: FFT size used by the mel-spectrogram transform.
        hop_length: Hop, in samples, between spectrogram frames.
        rnn_dropout: Dropout between stacked LSTM layers; only used when
            ``num_layers > 1`` (nn.LSTM applies it between layers only).
        fc_dropout: Dropout applied after the first fully connected layer.

    Raises:
        ValueError: If ``n_mels < 8``.
    """

    # Three 2x max-pools shrink both spectrogram axes by this factor.
    _POOL_FACTOR = 8

    def __init__(self, n_classes=10, sample_rate=22050, n_mels=128,
                 hidden_size=256, num_layers=2, n_fft=2048, hop_length=512,
                 rnn_dropout=0.3, fc_dropout=0.5):
        super().__init__()

        if n_mels < self._POOL_FACTOR:
            raise ValueError(
                f"n_mels must be >= {self._POOL_FACTOR} for three 2x poolings, "
                f"got {n_mels}"
            )

        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Waveform -> power mel-spectrogram -> decibels.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

        # CNN feature extraction; padding=1 keeps sizes, each pool halves them.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d((2, 2))

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d((2, 2))

        # Flooring by 2 three times equals flooring by 8, so the mel axis
        # ends up with n_mels // 8 bins of 128 channels each.
        cnn_output_dim = 128 * (n_mels // self._POOL_FACTOR)

        self.lstm = nn.LSTM(
            input_size=cnn_output_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=rnn_dropout if num_layers > 1 else 0,
        )

        # The bidirectional summary concatenates both directions.
        self.fc1 = nn.Linear(hidden_size * 2, 128)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return class logits shaped (batch, n_classes).

        Args:
            x: Mono waveform batch shaped (batch, 1, time) or (batch, time).

        Raises:
            ValueError: If ``x`` has any other shape, or is too short to
                yield at least 8 spectrogram frames.
        """
        x = self._log_mel(x)                  # (batch, 1, n_mels, frames)

        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        # x: (batch, 128, n_mels // 8, frames // 8)

        # One LSTM step per pooled time index, with features = channels x mels.
        batch_size, channels, freq, n_steps = x.shape
        x = x.permute(0, 3, 1, 2).reshape(batch_size, n_steps, channels * freq)

        _, (h_n, _) = self.lstm(x)
        # Last two rows of h_n: the top layer's forward and backward states.
        features = torch.cat((h_n[-2], h_n[-1]), dim=1)

        x = F.relu(self.fc1(features))
        x = self.dropout(x)
        return self.fc2(x)

    def _log_mel(self, x):
        """Return a log-mel batch shaped (batch, 1, n_mels, frames) for the CNN."""
        if x.dim() == 2:
            x = x.unsqueeze(1)                # (batch, time) -> (batch, 1, time)
        elif x.dim() != 3 or x.size(1) != 1:
            raise ValueError(
                "expected mono audio shaped (batch, 1, time) or (batch, time), "
                f"got {tuple(x.shape)}"
            )
        spec = self.amplitude_to_db(self.mel_spec(x))
        n_frames = spec.size(-1)
        if n_frames < self._POOL_FACTOR:
            # Otherwise pooling leaves zero time steps and the LSTM fails.
            raise ValueError(
                f"input too short: got {n_frames} spectrogram frames, need at "
                f"least {self._POOL_FACTOR} (hop_length={self.hop_length})"
            )
        return spec

## Test Model Creation

In [None]:
# Test MusicLSTM
model_lstm = MusicLSTM(n_classes=10)
print("MusicLSTM:")
print(model_lstm)
print(f"\nNumber of parameters: {sum(p.numel() for p in model_lstm.parameters())}")

# Smoke-test with random audio: batch of 30 s mono clips shaped (batch, 1, time).
# `x` and `batch_size` are reused by the following test cells.
batch_size = 4
sample_rate = 22050
duration = 30
x = torch.randn(batch_size, 1, sample_rate * duration)
# Only the output shape is checked, so skip building the autograd graph.
with torch.no_grad():
    output = model_lstm(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == (batch_size, 10), output.shape

In [None]:
# Test MusicGRU
model_gru = MusicGRU(n_classes=50)
print("MusicGRU:")
print(model_gru)
print(f"\nNumber of parameters: {sum(p.numel() for p in model_gru.parameters())}")

# Smoke-test with the random audio batch `x` from the MusicLSTM cell.
# Only the output shape is checked, so skip building the autograd graph.
with torch.no_grad():
    output = model_gru(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == (x.shape[0], 50), output.shape

In [None]:
# Test CNN-LSTM
model_hybrid = CNNLSTM(n_classes=10)
print("CNN-LSTM:")
print(model_hybrid)
print(f"\nNumber of parameters: {sum(p.numel() for p in model_hybrid.parameters())}")

# Smoke-test with the random audio batch `x` from the MusicLSTM cell.
# Only the output shape is checked, so skip building the autograd graph.
with torch.no_grad():
    output = model_hybrid(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == (x.shape[0], 10), output.shape

## Save and Load Models

In [None]:
def save_model(model, path):
    """Save ``model.state_dict()`` to ``path``.

    Missing parent directories are created first, so a path such as
    ``'../models/music_lstm.pth'`` works even if that directory does not
    exist yet.

    Args:
        model: The ``nn.Module`` whose parameters and buffers are saved.
        path: Destination file path (str or path-like).
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(model, path, map_location="cpu"):
    """Load a state dict saved by ``save_model`` into ``model``.

    Args:
        model: An ``nn.Module`` built with the same architecture and
            hyper-parameters as the one that was saved.
        path: File written by ``save_model``.
        map_location: Device the checkpoint tensors are first loaded onto.
            Defaults to ``"cpu"`` so checkpoints saved on a GPU can be read
            on CPU-only machines. ``load_state_dict`` then copies the values
            into ``model``'s existing parameters on whatever device they are.

    Returns:
        ``model``, updated in place.
    """
    # A state dict only needs tensors and plain containers, so restrict
    # unpickling to those rather than allowing arbitrary pickled objects.
    state_dict = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {path}")
    return model

# Example usage:
# save_model(model_lstm, '../models/music_lstm.pth')
# model_lstm = load_model(MusicLSTM(n_classes=10), '../models/music_lstm.pth')