In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchaudio
import torchaudio.transforms as T
from synth import Synth, Wave
from synth_generator import WaveIterableDataset

# Seed PyTorch's global RNG so runs started from a fresh kernel are repeatable.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# batch_size=None switches off the DataLoader's own batching, so every item
# the dataset yields is forwarded unchanged.
# NOTE(review): presumably WaveIterableDataset already yields whole batches of
# 100 two-second clips at 48 kHz -- confirm against synth_generator.
dataset = WaveIterableDataset(sample_rate=48000, duration=2.0, batch_size=100)
dataloader = DataLoader(dataset=dataset, batch_size=None)

class AudioFeatureExtractor(nn.Module):
    """Turn raw audio into a mel spectrogram via ``T.MelSpectrogram``.

    Args:
        sample_rate: Sample rate of the incoming audio, in Hz.
        n_fft: FFT size handed to the mel-spectrogram transform.
        n_mels: Number of mel filter banks.
        hop_length: Hop between successive STFT frames, in samples. It used
            to be fixed at 512; that value is kept as the default so existing
            callers get the same features.
    """

    def __init__(self, sample_rate=48000, n_fft=2048, n_mels=128, hop_length=512):
        super().__init__()
        self.mel_spectrogram = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            normalized=True
        )

    def forward(self, x):
        """Return the mel spectrogram of ``x``.

        NOTE(review): ``x`` is assumed to be a waveform tensor whose last
        dimension is time -- confirm against the dataset's output.
        """
        return self.mel_spectrogram(x)


class SynthParameterPredictor(nn.Module):
    """Small CNN that regresses synth parameters from a one-channel 2-D input.

    The network is three Conv2d -> ReLU -> MaxPool2d(2) stages (16, 32 and 64
    output channels) followed by a three-layer fully connected head with
    dropout. The first fully connected layer is an ``nn.LazyLinear``, so the
    flattened feature size is inferred from the first input instead of being
    fixed in the constructor.

    Args:
        input_dim: Accepted for interface compatibility; the current
            implementation does not read it.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            is ``hidden_dim // 2`` wide.
        output_dim: Number of values predicted per example.
    """

    def __init__(self, input_dim=128, hidden_dim=256, output_dim=3):
        super().__init__()

        # Assemble the convolutional stages in order so the module indices
        # inside the Sequential come out as 0..8.
        conv_stack = []
        in_channels = 1
        for out_channels in (16, 32, 64):
            conv_stack += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*conv_stack)

        half_hidden = hidden_dim // 2
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(half_hidden, output_dim),
        )

    def forward(self, x):
        """Apply the convolutional stack, then the fully connected head."""
        conv_features = self.conv_layers(x)
        return self.fc_layers(conv_features)


# Build the feature extractor and the predictor, then move each to `device`.
feature_extractor = AudioFeatureExtractor()
feature_extractor = feature_extractor.to(device)
model = SynthParameterPredictor()
model = model.to(device)

# Mean-squared error loss; only the predictor's parameters are given to Adam.
# NOTE(review): the predictor contains an nn.LazyLinear, so the optimizer is
# created before that layer has seen an input. PyTorch's lazy-module docs
# recommend a dry-run forward pass first -- confirm this order is intended.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Bookkeeping for the training loop.
num_batches = 5
total_loss = 0  # not read anywhere in the visible cells
batch_losses = []

# Value range of each synth parameter. The normalize/denormalize helpers
# carry the same numbers as literals and do not read this dict.
param_ranges = dict(
    frequency=(110, 880),  # A2 to A5
    phase=(0, 1),
    volume=(0.2, 1.0),
)

def normalize_params(params):
    """Map raw synth parameters onto the [0, 1] range.

    The last dimension of ``params`` holds ``(frequency, phase, volume)``.
    Frequency is rescaled from [110, 880] Hz and volume from [0.2, 1.0];
    phase is copied through unchanged. Values outside those ranges are not
    clamped, so they land outside [0, 1].

    Args:
        params: Tensor of shape ``(..., 3)`` -- a batch ``(N, 3)`` or a single
            ``(3,)`` vector both work. Integer tensors are promoted to the
            default float dtype first, because writing fractional results
            into an integer tensor would truncate them.

    Returns:
        A new floating-point tensor with the same shape as ``params``; the
        input is not modified.
    """
    freq_min, freq_max = 110, 880
    vol_min, vol_max = 0.2, 1.0

    if not torch.is_floating_point(params):
        params = params.to(torch.get_default_dtype())

    normalized = torch.zeros_like(params)
    # Ellipsis indexing addresses the last axis whatever the leading shape is.
    normalized[..., 0] = (params[..., 0] - freq_min) / (freq_max - freq_min)
    normalized[..., 1] = params[..., 1]
    normalized[..., 2] = (params[..., 2] - vol_min) / (vol_max - vol_min)
    return normalized

def denormalize_params(norm_params):
    """Map normalized synth parameters back to their original ranges.

    The last dimension of ``norm_params`` holds normalized
    ``(frequency, phase, volume)``. Frequency is stretched to [110, 880] Hz
    and volume to [0.2, 1.0]; phase is copied through unchanged. Inputs
    outside [0, 1] are not clamped.

    Args:
        norm_params: Tensor of shape ``(..., 3)`` -- a batch ``(N, 3)`` or a
            single ``(3,)`` vector both work. Integer tensors are promoted to
            the default float dtype first, because writing fractional results
            into an integer tensor would truncate them.

    Returns:
        A new floating-point tensor with the same shape as ``norm_params``;
        the input is not modified.
    """
    freq_min, freq_max = 110, 880
    vol_min, vol_max = 0.2, 1.0

    if not torch.is_floating_point(norm_params):
        norm_params = norm_params.to(torch.get_default_dtype())

    denorm = torch.zeros_like(norm_params)
    # Ellipsis indexing addresses the last axis whatever the leading shape is.
    denorm[..., 0] = norm_params[..., 0] * (freq_max - freq_min) + freq_min
    denorm[..., 1] = norm_params[..., 1]
    denorm[..., 2] = norm_params[..., 2] * (vol_max - vol_min) + vol_min
    return denorm


# Train on at most `num_batches` batches from the dataloader, logging the MSE
# loss of each batch, then spot-check a few predictions on the final batch.
model.train()
for batch_idx, (audio_batch, params_batch) in enumerate(dataloader):
    if batch_idx >= num_batches:
        break

    audio_batch = audio_batch.to(device)
    params_batch = params_batch.to(device)

    # The network regresses the parameters in normalized space.
    normalized_params = normalize_params(params_batch)

    # Only `model`'s parameters are handed to the optimizer, so no gradient is
    # needed through feature extraction.
    with torch.no_grad():
        features = feature_extractor(audio_batch).unsqueeze(1)  # add channel dimension

    optimizer.zero_grad()
    predictions = model(features)
    loss = criterion(predictions, normalized_params)

    loss.backward()
    optimizer.step()

    batch_loss = loss.item()
    batch_losses.append(batch_loss)

    print(f"Batch {batch_idx+1}/{num_batches}, Loss: {batch_loss:.6f}")

    if batch_idx == num_batches - 1:  # last batch
        # Switch to eval mode (disables dropout) and run the spot check
        # without autograd: this forward pass is inference only, so recording
        # a graph through the model would just waste memory.
        model.eval()
        with torch.no_grad():
            sample_indices = torch.randint(0, len(audio_batch), (5,))
            sample_audio = audio_batch[sample_indices]
            sample_params = params_batch[sample_indices]

            sample_features = feature_extractor(sample_audio).unsqueeze(1)
            sample_predictions = model(sample_features)

            # Predictions are in normalized space; convert back for display.
            denorm_predictions = denormalize_params(sample_predictions)

        print("\nSample Predictions:")
        print("Index | Parameter | True Value | Predicted Value")
        print("-" * 50)

        param_names = ['Frequency', 'Phase', 'Volume']
        for i, (true, pred) in enumerate(zip(sample_params, denorm_predictions)):
            print(f"Sample {i+1}:")
            for j, name in enumerate(param_names):
                print(f"  {name}: {true[j]:.4f} | {pred[j]:.4f}")


# Plot the per-batch training loss. The x values are derived from the number
# of losses actually recorded rather than from `num_batches`, so the plot
# still works if the dataloader produced fewer batches than requested
# (mismatched x/y lengths make `plot` raise).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(batch_losses) + 1), batch_losses, marker='o')
ax.set_title('Training Loss')
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
ax.grid(True)
plt.show()

In [None]:
# Write a checkpoint holding both the model weights and the optimizer state.
checkpoint = {}
checkpoint['model_state_dict'] = model.state_dict()
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
torch.save(checkpoint, 'synth_parameter_predictor.pth')

print("Model saved to 'synth_parameter_predictor.pth'")

In [None]:
# Function to visualize spectrograms and predictions
def visualize_prediction(audio, true_params, pred_params):
    """Plot one example's mel spectrogram and its true vs. predicted parameters.

    The top panel shows the output of the module-level ``feature_extractor``
    for ``audio`` on a decibel scale. The bottom row has one panel per
    parameter so that frequency (hundreds of Hz) does not flatten the phase
    and volume bars, which live in roughly [0, 1].

    Args:
        audio: Waveform tensor for a single example; a batch dimension is
            added before feature extraction.
        true_params: Tensor holding the true (frequency, phase, volume).
        pred_params: Tensor holding the predicted values in the same order
            and units as ``true_params``.
    """
    param_names = ['Frequency (Hz)', 'Phase (0-1)', 'Volume (0-1)']

    with torch.no_grad():
        spec = feature_extractor(audio.unsqueeze(0)).squeeze().cpu().numpy()
    # The colorbar is labelled in dB, so convert before plotting.
    # NOTE(review): 10*log10 assumes the extractor returns a power
    # spectrogram (torchaudio's MelSpectrogram default) -- confirm.
    # The floor avoids log10(0) on silent bins.
    spec_db = 10.0 * np.log10(np.maximum(spec, 1e-10))

    # detach() lets this work even when called outside a no_grad block.
    true_values = true_params.detach().cpu().numpy()
    pred_values = pred_params.detach().cpu().numpy()

    fig = plt.figure(figsize=(12, 8))
    grid = fig.add_gridspec(2, len(param_names))

    # Spectrogram spans the whole top row.
    spec_ax = fig.add_subplot(grid[0, :])
    image = spec_ax.imshow(spec_db, aspect='auto', origin='lower', cmap='viridis')
    fig.colorbar(image, ax=spec_ax, format='%+2.0f dB')
    spec_ax.set_title('Mel Spectrogram')
    spec_ax.set_xlabel('Time Frames')
    spec_ax.set_ylabel('Mel Frequency Bins')

    # One panel per parameter, each with its own y-scale.
    for col, name in enumerate(param_names):
        param_ax = fig.add_subplot(grid[1, col])
        param_ax.bar(
            ['True', 'Predicted'],
            [float(true_values[col]), float(pred_values[col])],
            color=['C0', 'C1'],
        )
        param_ax.set_title(name)
        param_ax.set_ylabel('Parameter Value')

    fig.tight_layout()
    plt.show()

# Run the model on a freshly drawn batch and plot two randomly chosen
# examples (indices are drawn with replacement, so they may coincide).
model.eval()
with torch.no_grad():
    # Pull one new batch and move both tensors to the compute device.
    audio_batch, params_batch = (
        tensor.to(device) for tensor in next(iter(dataloader))
    )

    # Features -> normalized predictions -> predictions in original units.
    features = feature_extractor(audio_batch).unsqueeze(1)
    normalized_predictions = model(features)
    predictions = denormalize_params(normalized_predictions)

    sample_indices = torch.randint(low=0, high=audio_batch.shape[0], size=(2,))
    for idx in sample_indices.tolist():
        visualize_prediction(audio_batch[idx], params_batch[idx], predictions[idx])