In [None]:
# Import required libraries
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from src.data_loader import get_dataloader
from src.model import SimpleTTS
from src.utils import text_to_sequence
from src import config

# Fix every RNG the notebook relies on so Restart & Run All reproduces the same
# results (e.g. weight initialisation and default-generator DataLoader shuffling).
SEED = 42
torch.manual_seed(SEED)  # seeds the default torch generator on all devices
np.random.seed(SEED)     # legacy global NumPy RNG

print("PyTorch version:", torch.__version__)
print("Device:", config.DEVICE)


In [None]:
# Load and explore the dataset.
# A small, unshuffled loader is used here purely for inspection; the training
# cell builds its own loader.
dataloader = get_dataloader(config.DATA_PATH, batch_size=4, shuffle=False)

# Peek at the first batch: (padded text sequences, padded mel targets)
sample_batch = next(iter(dataloader))
input_padded, target_padded = sample_batch

print("Dataset loaded successfully!")
print(f"Input shape (text sequences): {input_padded.shape}")
print(f"Target shape (mel-spectrograms): {target_padded.shape}")
print(f"Total batches in dataset: {len(dataloader)}")

# Visualize the first target in the batch.
# NOTE(review): assumes each target is (n_mels, n_frames) so frames run along x
# and mean(axis=0) averages over mel channels — confirm against the data loader.
first_mel = target_padded[0].numpy()
fig, (ax_spec, ax_energy) = plt.subplots(1, 2, figsize=(12, 6))

spec_image = ax_spec.imshow(first_mel, aspect='auto', origin='lower')
ax_spec.set_title('Sample Mel-Spectrogram')
ax_spec.set_xlabel('Time Steps')
ax_spec.set_ylabel('Mel Channels')
fig.colorbar(spec_image, ax=ax_spec)

ax_energy.plot(first_mel.mean(axis=0))
ax_energy.set_title('Average Mel Energy Over Time')
ax_energy.set_xlabel('Time Steps')
ax_energy.set_ylabel('Average Energy')

fig.tight_layout()
plt.show()


In [None]:
# Initialize model and training components
device = torch.device(config.DEVICE)
model = SimpleTTS().to(device)
dataloader = get_dataloader(config.DATA_PATH, config.BATCH_SIZE)

# Loss and optimizer
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# Train for a fixed number of optimizer steps, cycling through the dataset as
# many times as needed. (A fixed `range(100)` epoch cap would silently stop
# short of the target whenever an epoch yields fewer than 2 batches.)
target_iterations = 200
log_every = 50  # print progress every N iterations

losses = []
iterations = 0

print(f"Starting training for {target_iterations} iterations...")
model.train()

while iterations < target_iterations:
    batches_seen = 0
    for batch in dataloader:
        batches_seen += 1

        # Move data to device; each batch is unpacked as (text_padded, mel_target)
        text_padded, mel_target = [d.to(device) for d in batch]

        # Forward pass
        mel_pred = model(text_padded)
        # NOTE(review): the L1 loss covers every frame, including any padding in
        # mel_target — confirm whether padded frames should be masked out.
        loss = criterion(mel_pred, mel_target)

        # Backward pass with gradient-norm clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP_THRESH)
        optimizer.step()

        # Record loss (call .item() once per step and reuse it)
        loss_value = loss.item()
        losses.append(loss_value)
        iterations += 1

        if iterations % log_every == 0:
            print(f"Iteration {iterations}, Loss: {loss_value:.4f}")

        if iterations >= target_iterations:
            break

    # Without this guard an empty dataloader would make the while-loop spin forever.
    if batches_seen == 0:
        raise RuntimeError("Dataloader yielded no batches; cannot train.")

print(f"Training completed after {iterations} iterations")


In [None]:
# Plot loss curves: raw per-iteration L1 loss alongside a moving average.
window_size = 10  # iterations per moving-average window

fig, (ax_raw, ax_smooth) = plt.subplots(1, 2, figsize=(12, 5))

ax_raw.plot(losses)
ax_raw.set_title('Training Loss Over Iterations')
ax_raw.set_xlabel('Iteration')
ax_raw.set_ylabel('L1 Loss')
ax_raw.grid(True)

if len(losses) >= window_size:
    # 'valid' mode yields len(losses) - window_size + 1 points; point k averages
    # losses[k : k + window_size], so it is plotted at that window's last iteration.
    moving_avg = np.convolve(losses, np.ones(window_size) / window_size, mode='valid')
    ax_smooth.plot(range(window_size - 1, len(losses)), moving_avg, 'r-', linewidth=2)
    ax_smooth.set_title(f'Moving Average Loss (window={window_size})')
    ax_smooth.set_xlabel('Iteration')
    ax_smooth.set_ylabel('L1 Loss')
    ax_smooth.grid(True)
else:
    # Too few points for a full window: say so instead of leaving a blank, unlabeled panel.
    ax_smooth.text(0.5, 0.5, f'Need at least {window_size} iterations\nfor a moving average',
                   ha='center', va='center', transform=ax_smooth.transAxes)
    ax_smooth.set_axis_off()

fig.tight_layout()
plt.show()

if not losses:
    print("No losses recorded - run the training cell first.")
else:
    print(f"Initial loss: {losses[0]:.4f}")
    print(f"Final loss: {losses[-1]:.4f}")
    # Guard the percentage against a zero initial loss.
    if losses[0] != 0:
        print(f"Loss reduction: {((losses[0] - losses[-1]) / losses[0] * 100):.1f}%")


In [None]:
# Generate mel-spectrograms from sample texts
model.eval()  # switch to inference mode (affects dropout/batch-norm layers, if any)
sample_texts = [
    "hello world",
    "this is a test",
    "artificial intelligence",
    "text to speech synthesis"
]

generated_mels = []
with torch.no_grad():
    for text in sample_texts:
        # Encode text as int64 token ids on the model's device, then add a leading batch dim
        sequence = torch.tensor(text_to_sequence(text), dtype=torch.long, device=device).unsqueeze(0)

        # Generate mel-spectrogram and drop the batch dimension
        mel_output = model(sequence)
        generated_mels.append(mel_output.squeeze(0).cpu().numpy())

# Visualize generated mel-spectrograms on a grid sized to the number of texts
# (a fixed 2x2 grid raises IndexError for more than 4 texts and leaves empty panels for fewer).
n_cols = 2
n_rows = max(1, -(-len(sample_texts) // n_cols))  # ceil division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, text, mel in zip(axes, sample_texts, generated_mels):
    # NOTE(review): assumes mel is (n_mels, n_frames) so frames run along x — confirm model output layout.
    im = ax.imshow(mel, aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(f'Generated Mel: "{text}"')
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Mel Channels')
    fig.colorbar(im, ax=ax)

# Hide any unused panels in the last row
for ax in axes[len(generated_mels):]:
    ax.set_axis_off()

fig.tight_layout()
plt.show()


In [None]:
# Compare generated vs ground truth.
# Set inference mode here as well: this cell may run straight after training
# (model still in train mode) without the generation cell being run first.
model.eval()

# Get a sample from the dataset
sample_batch = next(iter(dataloader))
input_padded, target_padded = sample_batch

# Generate prediction for the first sample
with torch.no_grad():
    predicted_mel = model(input_padded[:1].to(device)).squeeze(0).cpu().numpy()

ground_truth_mel = target_padded[0].cpu().numpy()

# Fail loudly instead of letting NumPy broadcast mismatched shapes into a bogus difference.
if ground_truth_mel.shape != predicted_mel.shape:
    raise ValueError(
        f"Shape mismatch: ground truth {ground_truth_mel.shape} vs prediction {predicted_mel.shape}"
    )

difference = np.abs(ground_truth_mel - predicted_mel)

# Plot comparison: ground truth, prediction, and their absolute difference
panels = [
    (ground_truth_mel, 'Ground Truth Mel-Spectrogram', 'viridis'),
    (predicted_mel, 'Generated Mel-Spectrogram', 'viridis'),
    (difference, 'Absolute Difference', 'hot'),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (image, title, cmap) in zip(axes, panels):
    im = ax.imshow(image, aspect='auto', origin='lower', cmap=cmap)
    ax.set_title(title)
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Mel Channels')
    fig.colorbar(im, ax=ax)

fig.tight_layout()
plt.show()

# Calculate metrics (over every frame of the padded target, including any padding)
mae = np.mean(difference)
mse = np.mean(difference ** 2)
print(f"Mean Absolute Error: {mae:.4f}")
print(f"Mean Squared Error: {mse:.4f}")


In [None]:
# Import audio processing libraries
import librosa
import soundfile as sf
from IPython.display import Audio, display

def save_wav(mel_spectrogram, path, sr=None, n_fft=None, hop_length=None, random_state=None):
    """Convert a mel-spectrogram to audio with Griffin-Lim, write it as WAV, and return it.

    Args:
        mel_spectrogram: array-like mel spectrogram in the (..., n_mels, n_frames)
            layout that ``librosa.feature.inverse.mel_to_stft`` expects.
            NOTE(review): mel_to_stft (default power=2.0) expects a linear-scale
            power mel spectrogram; if the model predicts log/dB mels they must be
            converted back first — confirm against the data pipeline.
        path: destination file path passed to ``soundfile.write``.
        sr: sampling rate; defaults to ``config.SAMPLING_RATE``.
        n_fft: FFT size used to invert the mel filterbank; defaults to ``config.N_FFT``.
        hop_length: hop length used by Griffin-Lim; defaults to ``config.HOP_LENGTH``.
        random_state: forwarded to ``librosa.griffinlim`` to make the phase
            initialisation reproducible; ``None`` keeps librosa's default.

    Returns:
        The waveform returned by ``librosa.griffinlim`` (the same data written to ``path``).
    """
    sr = config.SAMPLING_RATE if sr is None else sr
    n_fft = config.N_FFT if n_fft is None else n_fft
    hop_length = config.HOP_LENGTH if hop_length is None else hop_length

    # Approximately invert the mel filterbank back to a linear-frequency STFT magnitude.
    stft_matrix = librosa.feature.inverse.mel_to_stft(
        np.asarray(mel_spectrogram),
        sr=sr,
        n_fft=n_fft,
    )
    # Iteratively estimate phase to recover a time-domain signal.
    audio = librosa.griffinlim(stft_matrix, hop_length=hop_length, random_state=random_state)
    sf.write(path, audio, sr)
    return audio

# Generate audio from the first generated mel-spectrogram.
# Take the label from sample_texts so it cannot drift from the mel it describes.
test_text = sample_texts[0]
test_mel = generated_mels[0]  # produced by the generation cell above

print(f"Converting mel-spectrogram to audio for: '{test_text}'")
audio_path = "demo_output.wav"
audio_waveform = save_wav(test_mel, audio_path)

# Display audio player
print(f"Audio saved to: {audio_path}")
display(Audio(audio_waveform, rate=config.SAMPLING_RATE))

# Plot the waveform. Sample k sits at k / sr seconds; np.linspace(0, duration, n)
# would stretch the axis so the last sample lands at `duration` instead of (n-1)/sr.
n_samples = len(audio_waveform)
duration_s = n_samples / config.SAMPLING_RATE
time_axis = np.arange(n_samples) / config.SAMPLING_RATE

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(time_axis, audio_waveform)
ax.set_title(f'Generated Audio Waveform: "{test_text}"')
ax.set_xlabel('Time (seconds)')
ax.set_ylabel('Amplitude')
ax.grid(True)
plt.show()

print(f"Audio duration: {duration_s:.2f} seconds")
print(f"Sample rate: {config.SAMPLING_RATE} Hz")
