In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchaudio
import torchaudio.transforms as T
from synth import Synth, Wave
from synth_generator import WaveIterableDataset

# Seed PyTorch's global random number generator.
torch.manual_seed(42)

# Use the CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
	device = torch.device("cuda")
else:
	device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
import multiprocessing

# Synthetic waveform source used for training; duration=2.0 and
# sample_rate=48000 are presumably seconds and Hz (defined in synth_generator).
dataset = WaveIterableDataset(sample_rate=48000, duration=2.0)

# Batches of 100 examples produced by 4 worker processes that are started
# with the 'spawn' multiprocessing context.
dataloader = DataLoader(
	dataset,
	num_workers=4,
	batch_size=100,
	multiprocessing_context=multiprocessing.get_context('spawn'),
)

class AudioFeatureExtractor(nn.Module):
	"""Convert raw waveforms into mel spectrograms.

	Thin ``nn.Module`` wrapper around ``torchaudio.transforms.MelSpectrogram``
	(created with ``normalized=True``) so the transform can be moved between
	devices with ``.to(device)`` like any other module.

	Args:
		sample_rate: Sample rate of the incoming audio.
		n_fft: FFT size used by the spectrogram.
		n_mels: Number of mel filter banks.
		hop_length: Hop between successive STFT frames. Defaults to 512, the
			value that used to be hard-coded, so existing callers are
			unaffected.
	"""

	def __init__(self, sample_rate=48000, n_fft=2048, n_mels=128, hop_length=512):
		super().__init__()
		self.mel_spectrogram = T.MelSpectrogram(
			sample_rate=sample_rate,
			n_fft=n_fft,
			hop_length=hop_length,
			n_mels=n_mels,
			normalized=True
		)

	def forward(self, x):
		"""Return the mel spectrogram of ``x``.

		``x`` is passed to the transform unchanged; per the torchaudio
		documentation a ``(..., time)`` waveform yields ``(..., n_mels, frames)``.
		"""
		return self.mel_spectrogram(x)


class SynthParameterPredictor(nn.Module):
	"""CNN regressor from a single-channel 2-D feature map to synth parameters.

	Three ``Conv2d -> ReLU -> MaxPool2d`` stages (16, 32 and 64 channels) are
	followed by a flatten and a three-layer dense head with dropout. The first
	dense layer is lazy, so its input width is inferred on the first forward
	pass.

	Args:
		input_dim: Accepted for interface compatibility; not used by this class.
		hidden_dim: Width of the first dense layer (the second is half of it).
		output_dim: Number of predicted values.

	Note:
		Every layer, dense ones included, lives in ``self.conv_layers``. The
		attribute name and the layer order determine the ``state_dict`` keys,
		so they must stay as they are for saved checkpoints to load.
	"""

	def __init__(self, input_dim=128, hidden_dim=256, output_dim=3):
		super().__init__()

		# Convolutional front end: each stage halves both spatial dimensions.
		channel_plan = [1, 16, 32, 64]
		stack = []
		for in_channels, out_channels in zip(channel_plan, channel_plan[1:]):
			stack.extend([
				nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
				nn.ReLU(),
				nn.MaxPool2d(2),
			])

		# Dense regression head.
		half_hidden = hidden_dim // 2
		stack.extend([
			nn.Flatten(),
			nn.LazyLinear(hidden_dim),
			nn.ReLU(),
			nn.Dropout(0.3),
			nn.Linear(hidden_dim, half_hidden),
			nn.ReLU(),
			nn.Dropout(0.3),
			nn.Linear(half_hidden, output_dim),
		])

		self.conv_layers = nn.Sequential(*stack)

	def forward(self, x):
		"""Run ``x`` through the whole stack and return the raw predictions."""
		predictions = self.conv_layers(x)
		return predictions


# Feature extractor and predictor, both placed on the selected device.
feature_extractor = AudioFeatureExtractor().to(device)
model = SynthParameterPredictor().to(device)

# Mean-squared-error regression loss, optimized with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Restore a previous run when a checkpoint exists; `loadedModel` tells the
# training cell whether it still has work to do.
checkpoint_path = 'synth_parameter_predictor.pth'
import os
if os.path.exists(checkpoint_path):
	# map_location remaps the stored tensors onto the current device, so a
	# checkpoint written on a GPU still loads on a CPU-only machine.
	# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
	checkpoint = torch.load(checkpoint_path, map_location=device)
	model.load_state_dict(checkpoint['model_state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	model.eval()
	print("Checkpoint loaded. Skipping training.")
	loadedModel = True
else:
	print("Checkpoint not found. Proceeding with training.")
	loadedModel = False

Checkpoint loaded. Skipping training.


In [None]:
# Number of training batches and the per-batch loss history filled in by the
# training loop. `total_loss` is not read anywhere in the visible cells.
num_batches = 200
total_loss = 0
batch_losses = []

# (min, max) of every synth parameter. Insertion order defines the layout of
# the last axis of a parameter tensor: index 0 frequency, 1 phase, 2 volume.
param_ranges = {
	'frequency': (110, 880),    #A2 to A5
	'phase': (0, 1),
	'volume': (0.2, 1.0)
}

def _param_bounds(params):
	"""Return ``(params, lows, spans)`` ready for broadcasting.

	``lows`` and ``spans`` are built from ``param_ranges`` on the device and
	dtype of ``params``. Non-floating inputs are promoted to the default float
	dtype first so the arithmetic is not truncated to integers.
	"""
	if not params.is_floating_point():
		params = params.to(torch.get_default_dtype())
	if params.shape[-1:] != (len(param_ranges),):
		raise ValueError(
			f"expected {len(param_ranges)} values on the last axis, "
			f"got shape {tuple(params.shape)}"
		)
	lows = torch.tensor(
		[low for low, _ in param_ranges.values()],
		dtype=params.dtype, device=params.device,
	)
	spans = torch.tensor(
		[high - low for low, high in param_ranges.values()],
		dtype=params.dtype, device=params.device,
	)
	return params, lows, spans

def _as_tensor_batch(batch):
	"""Accept either a tensor or an iterable of per-example tensors."""
	if torch.is_tensor(batch):
		return batch
	return torch.stack(list(batch))

def normalize_params(params):
	"""Normalize parameters to [0, 1] range.

	Works on a single ``(3,)`` vector or on any tensor whose last axis holds
	the parameters in ``param_ranges`` order. Returns a new tensor.
	"""
	params, lows, spans = _param_bounds(params)
	return (params - lows) / spans

def normalize_batch(batch):
	"""Normalize a whole batch at once; an empty batch yields an empty tensor."""
	return normalize_params(_as_tensor_batch(batch))

def denormalize_params(norm_params):
	"""Convert normalized parameters back to original range.

	Inverse of ``normalize_params``; accepts the same shapes.
	"""
	norm_params, lows, spans = _param_bounds(norm_params)
	return norm_params * spans + lows

def denormalize_batch(batch):
	"""Denormalize a whole batch at once; an empty batch yields an empty tensor."""
	return denormalize_params(_as_tensor_batch(batch))

# Train the predictor for `num_batches` batches unless a checkpoint was loaded,
# then plot the loss curve and write a checkpoint.
if not loadedModel:
	model.train()

	# Create the DataLoader iterator once. Calling iter(dataloader) for every
	# batch builds a fresh iterator (and a new set of worker processes) each
	# time and consumes only the first batch it produces.
	batch_iter = iter(dataloader)

	for batch_idx in range(num_batches):
		try:
			audio_batch, params_batch = next(batch_iter)
		except StopIteration:
			# The dataset ran out: start another pass instead of aborting.
			batch_iter = iter(dataloader)
			audio_batch, params_batch = next(batch_iter)

		audio_batch = audio_batch.to(device)
		params_batch = params_batch.to(device)

		# Regression targets are the parameters scaled by normalize_batch.
		normalized_params = normalize_batch(params_batch)

		# The feature extractor is not trained, so no graph is needed for it.
		with torch.no_grad():
			features = feature_extractor(audio_batch)

		optimizer.zero_grad()
		# unsqueeze(1) adds the channel axis expected by the Conv2d layers.
		predictions = model(features.unsqueeze(1))
		loss = criterion(predictions, normalized_params)

		loss.backward()
		optimizer.step()

		batch_loss = loss.item()
		batch_losses.append(batch_loss)

		print(f"Batch {batch_idx+1}/{num_batches}, Loss: {batch_loss:.6f}")

		if batch_idx == num_batches - 1:  #Last batch
			# Spot-check five random examples from the final batch.
			model.eval()

			with torch.no_grad():
				sample_indices = torch.randint(0, len(audio_batch), (5,))
				sample_audio = audio_batch[sample_indices]
				sample_params = params_batch[sample_indices]

				sample_features = feature_extractor(sample_audio).unsqueeze(1)
				sample_predictions = model(sample_features)

				print("Shape of audio:", sample_audio.shape)
				print("Shape of features:", sample_features.shape)

				denorm_predictions = denormalize_batch(sample_predictions).detach().cpu().numpy()

				print("\nSample Predictions:")
				print("Index | Parameter | True Value | Predicted Value")
				print("-" * 50)

				param_names = ['Frequency', 'Phase', 'Volume']
				for i, (true, pred) in enumerate(zip(sample_params, denorm_predictions)):
					print(f"Sample {i+1}:")
					for j, name in enumerate(param_names):
						print(f"  {name}: {true[j]:.4f} | {pred[j]:.4f}")

	# Drop the iterator so the DataLoader can shut its worker processes down.
	del batch_iter

	plt.figure(figsize=(10, 5))
	plt.plot(range(1, len(batch_losses) + 1), batch_losses, marker='o')
	plt.title('Training Loss')
	plt.xlabel('Batch')
	plt.ylabel('Loss')
	plt.grid(True)
	plt.show()

	# Save to the same path the checkpoint-loading cell looks for.
	torch.save({
		'model_state_dict': model.state_dict(),
		'optimizer_state_dict': optimizer.state_dict(),
	}, checkpoint_path)
	print(f"Model saved to '{checkpoint_path}'")

In [None]:
# Reinforcement-learning stage: train a SynthRL model, then use it on the
# CNN's estimate for one example. SynthRL is defined in reinforcement_learner
# and is not visible in this notebook.
from reinforcement_learner import SynthRL

# Keep a handle on the CNN predictor, because `model` is rebound below.
trash_model = model

# NOTE(review): this rebinds the notebook-level names `dataset` and `model`.
# Re-running the earlier checkpoint/training cells after this one would act on
# the SynthRL object and this dataset instead of the CNN and the original one.
# NOTE(review): the dataset is built with default arguments here, unlike the
# duration=2.0 / sample_rate=48000 used for the DataLoader - confirm they match.
dataset = WaveIterableDataset()
model = SynthRL(device, dataset)

# SynthRL's own train method (not nn.Module.train); it takes the counts below.
model.train(num_epochs=100, samples_per_epoch=5, steps_per_sample=10)

# Take one (audio, parameters) example directly from the dataset.
test_audio, true_params = next(iter(dataset))
test_audio = test_audio.to(device)
true_params = true_params.to(device)

# Initial estimate from the CNN. The two unsqueeze(0) calls add the batch and
# channel axes in front of the spectrogram; squeeze(0) removes the batch axis
# again before denormalizing. This runs outside torch.no_grad().
predicted_params = denormalize_params(
	trash_model(
		feature_extractor(test_audio).unsqueeze(0).unsqueeze(0)
	).squeeze(0)
)

# Hand the CNN estimate to SynthRL.predict as its `params` argument -
# presumably as a starting point to refine; verify against SynthRL.
# NOTE(review): in the saved output the printed prediction equals the
# "Before parameters" line, which suggests the estimate came back unchanged.
predicted_params = model.predict(
	test_audio,
	params=predicted_params
)

# Print plain decimals instead of scientific notation.
np.set_printoptions(suppress=True)
print("True parameters:", true_params.cpu().numpy())
print("Predicted parameters:", predicted_params.cpu().numpy())

Epoch 0: Avg Loss = 0.316668, Param MSE = 1426835.500000
Epoch 10: Avg Loss = 0.322952, Param MSE = 1024361.500000
Epoch 20: Avg Loss = 0.431762, Param MSE = 2630281.000000
Epoch 30: Avg Loss = 0.421652, Param MSE = 230952.921875
Epoch 40: Avg Loss = 0.370940, Param MSE = 1667228.500000
Epoch 50: Avg Loss = 0.346912, Param MSE = 1175408.500000
Epoch 60: Avg Loss = 0.277563, Param MSE = 417643.718750
Epoch 70: Avg Loss = 0.537395, Param MSE = 1429459.250000
Epoch 80: Avg Loss = 0.434282, Param MSE = 2373.382080
Epoch 90: Avg Loss = 0.477019, Param MSE = 36297.312500
Training completed!
Before parameters: tensor([203.4914,   0.4103,   0.6460], device='cuda:0')
True parameters: [258.7986       0.8389284    0.80834633]
Predicted parameters: [203.49144      0.41026366   0.6459962 ]
