## 1. Instalación y Setup

In [None]:
# Instalar dependencias si es necesario
# !pip install torch torchaudio librosa soundfile numpy scipy matplotlib tqdm

In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa
import soundfile as sf
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print(f"Device: {device}")

## 2. Preprocesamiento de Audio

In [None]:
# Preprocessing settings shared by feature extraction, the dataset and the plots.
CONFIG = dict(
    sample_rate=22050,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mels=80,
    fmin=0,
    fmax=8000,
    segment_size=8192,  # about 0.37 seconds at 22050 Hz
)

def mel_spectrogram(y, sr, n_fft, hop_length, win_length, n_mels, fmin, fmax):
    """Return the mel spectrogram of ``y`` converted to decibels.

    All arguments are forwarded unchanged to
    ``librosa.feature.melspectrogram``; its output is then passed through
    ``librosa.power_to_db`` with ``ref=np.max``.
    """
    stft_args = dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    filter_args = dict(n_mels=n_mels, fmin=fmin, fmax=fmax)
    power_mel = librosa.feature.melspectrogram(y=y, sr=sr, **stft_args, **filter_args)
    return librosa.power_to_db(power_mel, ref=np.max)

def preprocess_audio_files(input_dir, output_dir, config, max_files=None):
    """Convert a folder of ``*.mp3`` files into cached waveforms and mel features.

    Each file is loaded with librosa as mono at ``config['sample_rate']``,
    written to ``<output_dir>/audio/NNNN.wav``, and its mel spectrogram (from
    ``mel_spectrogram``) is saved to ``<output_dir>/mels/NNNN.npy``. A
    ``metadata.json`` describing the processed files and a ``config.json``
    copy of ``config`` are written to ``output_dir``.

    Args:
        input_dir: Directory searched (non-recursively) for ``*.mp3`` files.
        output_dir: Destination directory; created if missing.
        config: Mapping with ``sample_rate``, ``n_fft``, ``hop_length``,
            ``win_length``, ``n_mels``, ``fmin`` and ``fmax``.
        max_files: Optional cap on the number of files processed; ``None``
            means no cap.

    Returns:
        List of metadata dicts (``id``, ``original``, ``audio``, ``mel``,
        ``duration`` in seconds), one per successfully processed file.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    audio_dir = output_dir / 'audio'
    mel_dir = output_dir / 'mels'
    audio_dir.mkdir(parents=True, exist_ok=True)
    mel_dir.mkdir(parents=True, exist_ok=True)

    # Sort so that ids (and any split derived from metadata order) do not
    # depend on the order in which the OS happens to list the directory.
    audio_files = sorted(input_dir.glob('*.mp3'))
    if max_files is not None:
        audio_files = audio_files[:max_files]

    print(f"Procesando {len(audio_files)} archivos...")

    metadata = []
    for idx, audio_path in enumerate(tqdm(audio_files)):
        # Best effort: a file that fails is reported and skipped so one bad
        # track does not abort the whole run.
        try:
            # Load audio (resampled, mono)
            audio, sr = librosa.load(str(audio_path), sr=config['sample_rate'], mono=True)

            # Save the processed waveform
            audio_filename = f"{idx:04d}.wav"
            sf.write(audio_dir / audio_filename, audio, config['sample_rate'])

            # Extract the mel spectrogram
            mel = mel_spectrogram(
                audio, config['sample_rate'], config['n_fft'],
                config['hop_length'], config['win_length'],
                config['n_mels'], config['fmin'], config['fmax']
            )

            mel_filename = f"{idx:04d}.npy"
            np.save(mel_dir / mel_filename, mel)

            metadata.append({
                'id': idx,
                'original': audio_path.name,
                'audio': audio_filename,
                'mel': mel_filename,
                'duration': len(audio) / config['sample_rate']
            })
        except Exception as e:
            print(f"Error en {audio_path.name}: {e}")

    # Save metadata and the configuration used to produce it
    with open(output_dir / 'metadata.json', 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    with open(output_dir / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    print(f"\n✓ Procesados {len(metadata)} archivos")
    return metadata

# Run preprocessing

input_dir = r"C:\Users\carlo\Downloads\Deep_learning_P\jamendo_tracks"
output_dir = Path('data/processed')

audio_files = list(Path(input_dir).glob('*.mp3'))
# Report how many files were found rather than dumping every path.
print(f"Archivos encontrados: {len(audio_files)}")

# metadata.json is written at the end of preprocessing, so its presence is
# used as the "already done" marker and the cached metadata is reused.
if not (output_dir / 'metadata.json').exists():
    metadata = preprocess_audio_files(input_dir, output_dir, CONFIG)
else:
    print("Datos ya preprocesados. Cargando metadata...")
    with open(output_dir / 'metadata.json', 'r') as f:
        metadata = json.load(f)
    print(f"✓ Cargados {len(metadata)} archivos")

## 3. Dataset para HiFi-GAN

In [None]:
class AudioDataset(Dataset):
    """Paired (mel, waveform) segments read from a preprocessed data directory.

    The directory must contain ``metadata.json`` plus the ``audio/`` and
    ``mels/`` files it references. The metadata list is split by position:
    the first ``train_ratio`` share is the ``'train'`` split, the remainder
    is used for any other ``split`` value.

    Args:
        data_dir: Directory holding ``metadata.json``, ``audio/`` and ``mels/``.
        segment_size: Number of waveform samples returned per item.
        hop_length: Waveform samples per mel frame.
        split: ``'train'`` for the leading share of the metadata, anything
            else for the trailing share.
        train_ratio: Fraction of the metadata assigned to the train split.
    """

    def __init__(self, data_dir, segment_size, hop_length, split='train', train_ratio=0.85):
        self.data_dir = Path(data_dir)
        self.segment_size = segment_size
        self.hop_length = hop_length

        # Load metadata
        with open(self.data_dir / 'metadata.json', 'r') as f:
            metadata = json.load(f)

        # Train/validation split by position in the metadata list
        split_idx = int(len(metadata) * train_ratio)
        if split == 'train':
            self.metadata = metadata[:split_idx]
        else:
            self.metadata = metadata[split_idx:]

        print(f"{split.capitalize()} dataset: {len(self.metadata)} archivos")

    def __len__(self):
        """Return the number of files in this split."""
        return len(self.metadata)

    @staticmethod
    def _fit_mel_frames(mel, n_frames):
        """Return ``mel`` with exactly ``n_frames`` columns.

        Extra frames are dropped. Missing frames are filled with the mel's
        own minimum value: the stored mels are dB values relative to their
        peak, so 0 would be the *loudest* level, not silence.
        """
        current = mel.size(1)
        if current > n_frames:
            return mel[:, :n_frames]
        if current < n_frames:
            floor = mel.min().item() if mel.numel() > 0 else 0.0
            return F.pad(mel, (0, n_frames - current), value=floor)
        return mel

    def __getitem__(self, idx):
        """Return a ``(mel, audio)`` pair of fixed size for item ``idx``.

        Clips with at least ``segment_size`` samples are cropped at a random
        position (for every split); shorter clips are zero-padded at the end.
        ``mel`` always has ``segment_size // hop_length`` frames and ``audio``
        gains a leading channel axis via ``unsqueeze(0)``.
        """
        item = self.metadata[idx]

        # Load audio
        audio_path = self.data_dir / 'audio' / item['audio']
        audio, sr = sf.read(audio_path)
        audio = torch.FloatTensor(audio)

        # Load mel
        mel_path = self.data_dir / 'mels' / item['mel']
        mel = np.load(mel_path)
        mel = torch.FloatTensor(mel)

        mel_length = self.segment_size // self.hop_length

        if audio.size(0) >= self.segment_size:
            # Random crop of the waveform
            max_start = audio.size(0) - self.segment_size
            start = np.random.randint(0, max_start + 1)
            audio = audio[start:start + self.segment_size]

            # Matching mel frames for the cropped samples
            mel_start = start // self.hop_length
            mel = mel[:, mel_start:mel_start + mel_length]
        else:
            # Pad with silence when the clip is too short
            audio = F.pad(audio, (0, self.segment_size - audio.size(0)))

        # Guarantee a fixed frame count so items can be stacked into a batch.
        mel = self._fit_mel_frames(mel, mel_length)

        return mel, audio.unsqueeze(0)

# Build the train and validation datasets over the preprocessed directory
# (train first, then val, so the per-split messages keep their order).
train_dataset, val_dataset = (
    AudioDataset(output_dir, CONFIG['segment_size'], CONFIG['hop_length'], split_name)
    for split_name in ('train', 'val')
)

# Data loaders: only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False, num_workers=0)

print(f"\nBatch size: 4")
for label, loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"{label} batches: {len(loader)}")

## 4. Implementación de HiFi-GAN

In [None]:
class ResBlock(nn.Module):
    """Chain of residual units, each a dilated conv followed by a dilation-1 conv.

    One unit is created per entry of ``dilation``; every conv keeps the
    channel count and uses the padding computed by ``get_padding``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def make_conv(rate):
            # Stride-1 conv with ``channels`` in and out at the given dilation.
            return nn.Conv1d(
                channels, channels, kernel_size,
                stride=1, dilation=rate,
                padding=self.get_padding(kernel_size, rate),
            )

        # All dilated convs are created before the dilation-1 convs.
        self.convs1 = nn.ModuleList([make_conv(rate) for rate in dilation])
        self.convs2 = nn.ModuleList([make_conv(1) for _ in dilation])

    def get_padding(self, kernel_size, dilation):
        """Return the per-side padding used for a conv with these settings."""
        return int(dilation * (kernel_size - 1) / 2)

    def forward(self, x):
        """Apply each (dilated, plain) conv pair and add its output back to ``x``."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(F.leaky_relu(x, 0.1))
            branch = plain_conv(F.leaky_relu(branch, 0.1))
            x = branch + x
        return x

class Generator(nn.Module):
    """HiFi-GAN style generator mapping mel frames to a waveform.

    The input passes through an initial conv, then a sequence of
    transposed-conv upsampling stages. After every stage the outputs of
    ``len(resblock_kernel_sizes)`` parallel ``ResBlock`` modules are
    averaged. A final conv and ``tanh`` produce one output channel.

    With the default settings each stage multiplies the time length by its
    rate, so the overall factor is 8 * 8 * 2 * 2 = 256 samples per mel frame.

    Args:
        n_mels: Number of input mel channels.
        upsample_rates: Stride of each transposed conv.
        upsample_kernel_sizes: Kernel size of each transposed conv; must have
            the same length as ``upsample_rates``.
        upsample_initial_channel: Channels after the input conv; halved at
            every upsampling stage.
        resblock_kernel_sizes: Kernel size of each parallel residual block.
        resblock_dilation_sizes: Dilation tuple for each parallel residual
            block; must have the same length as ``resblock_kernel_sizes``.

    Raises:
        ValueError: If the paired sequences differ in length or no residual
            block kernel size is given.
    """

    def __init__(self, n_mels=80,
                 upsample_rates=(8, 8, 2, 2),
                 upsample_kernel_sizes=(16, 16, 4, 4),
                 upsample_initial_channel=256,
                 resblock_kernel_sizes=(3, 7, 11),
                 resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5))):
        super().__init__()

        if len(upsample_rates) != len(upsample_kernel_sizes):
            raise ValueError(
                "upsample_rates and upsample_kernel_sizes must have the same length"
            )
        if len(resblock_kernel_sizes) != len(resblock_dilation_sizes):
            raise ValueError(
                "resblock_kernel_sizes and resblock_dilation_sizes must have the same length"
            )
        if len(resblock_kernel_sizes) == 0:
            raise ValueError("at least one residual block kernel size is required")

        # Derived from the arguments so forward() indexing always matches
        # the modules that were actually built.
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # Input conv
        self.conv_pre = nn.Conv1d(n_mels, upsample_initial_channel, 7, 1, padding=3)

        # Upsampling layers: channel count halves at each stage
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                nn.ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2**(i+1)),
                    k, u, padding=(k-u)//2
                )
            )

        # Residual blocks: num_kernels consecutive blocks per upsampling stage
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2**(i+1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock(ch, k, d))

        # Output conv; channel count computed explicitly instead of relying
        # on the loop variable left over from the block above.
        final_channels = upsample_initial_channel // (2**self.num_upsamples)
        self.conv_post = nn.Conv1d(final_channels, 1, 7, 1, padding=3)

    def forward(self, x):
        """Generate a waveform tensor from a batch of mel spectrograms.

        ``x`` is fed to a ``Conv1d`` with ``n_mels`` input channels; the
        result has a single channel and values in (-1, 1) from ``tanh``.
        """
        x = self.conv_pre(x)

        for i, ups in enumerate(self.ups):
            x = F.leaky_relu(x, 0.1)
            x = ups(x)

            # Average the parallel residual blocks that belong to this stage
            first = i * self.num_kernels
            xs = self.resblocks[first](x)
            for j in range(1, self.num_kernels):
                xs = xs + self.resblocks[first + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

# Initialize the generator
generator = Generator(n_mels=CONFIG['n_mels']).to(device)
print(f"\nGenerador creado")
print(f"Parámetros: {sum(p.numel() for p in generator.parameters()):,}")

# Smoke-test the forward pass. The mel channel count comes from CONFIG so the
# check stays valid if n_mels changes, and no autograd graph is recorded.
test_mel = torch.randn(1, CONFIG['n_mels'], 32).to(device)
with torch.no_grad():
    test_output = generator(test_mel)
print(f"\nTest input: {test_mel.shape}")
print(f"Test output: {test_output.shape}")

## 5. Discriminadores

In [None]:
class PeriodDiscriminator(nn.Module):
    """Sub-discriminator that views the waveform as a (length // period, period) grid.

    The input is right-padded by reflection up to a multiple of ``period``,
    reshaped to 2-D, and passed through five convs with (5, 1) kernels and a
    final (3, 1) conv.
    """

    def __init__(self, period):
        super().__init__()
        self.period = period

        # Four stride-(3, 1) convs with growing width, then one unit-stride conv.
        widths = (1, 32, 128, 512, 1024)
        layers = [
            nn.Conv2d(c_in, c_out, (5, 1), (3, 1), padding=(2, 0))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        layers.append(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0)))
        self.convs = nn.ModuleList(layers)
        self.conv_post = nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))

    def forward(self, x):
        """Return ``(scores, fmap)`` for a 3-D input ``x``.

        ``scores`` is the final conv output flattened from dim 1 onward;
        ``fmap`` lists the activation after every conv, final conv included.
        """
        batch, channels, length = x.shape

        # Pad on the right so the time axis divides evenly by the period.
        remainder = length % self.period
        if remainder != 0:
            extra = self.period - remainder
            x = F.pad(x, (0, extra), "reflect")
            length = length + extra
        x = x.view(batch, channels, length // self.period, self.period)

        fmap = []
        for layer in self.convs:
            x = F.leaky_relu(layer(x), 0.1)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap

class MultiPeriodDiscriminator(nn.Module):
    """Bundle of ``PeriodDiscriminator`` modules with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [PeriodDiscriminator(period) for period in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and then on ``y_hat``.

        Returns four lists with one entry per sub-discriminator: scores for
        ``y``, scores for ``y_hat``, feature maps for ``y`` and feature maps
        for ``y_hat``.
        """
        # Each entry holds ((score, fmap) for y, (score, fmap) for y_hat).
        per_period = [(sub(y), sub(y_hat)) for sub in self.discriminators]

        y_d_rs = [real[0] for real, _ in per_period]
        y_d_gs = [fake[0] for _, fake in per_period]
        fmap_rs = [real[1] for real, _ in per_period]
        fmap_gs = [fake[1] for _, fake in per_period]

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

# Build the multi-period discriminator on the selected device and report its size.
discriminator = MultiPeriodDiscriminator()
discriminator = discriminator.to(device)
print(f"\nDiscriminador creado")
print(f"Parámetros: {sum(map(torch.numel, discriminator.parameters())):,}")

## 6. Funciones de Loss

In [None]:
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss.

    Sums the mean absolute difference of every matching pair of feature maps
    across all discriminators and layers, then doubles the total.
    """
    layer_pairs = (
        (real, fake)
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real, fake in zip(real_maps, fake_maps)
    )
    total = sum((real - fake).abs().mean() for real, fake in layer_pairs)
    return total * 2

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss.

    For each pair of outputs, adds the mean squared distance of the real
    scores from 1 and the mean square of the generated scores.
    """
    return sum(
        (1 - real).pow(2).mean() + fake.pow(2).mean()
        for real, fake in zip(disc_real_outputs, disc_generated_outputs)
    )

def generator_loss(disc_outputs):
    """Least-squares adversarial loss for the generator.

    Sums, over all discriminator outputs, the mean squared distance of the
    generated scores from 1.
    """
    return sum((1 - fake).pow(2).mean() for fake in disc_outputs)

def mel_spectrogram_loss(y, y_g, config):
    """L1 distance between dB mel spectrograms of real and generated audio.

    The spectrograms are computed with torch operations over the whole
    batch, so the result carries gradients with respect to ``y_g``. Each
    spectrogram is a power mel spectrogram in decibels, referenced to its
    own peak and floored 80 dB below it, mirroring what ``mel_spectrogram``
    produces through ``librosa.power_to_db(..., ref=np.max)``.

    Args:
        y: Real waveforms; the channel axis at dim 1 is squeezed away.
        y_g: Generated waveforms with the same layout as ``y``.
        config: Mapping with ``sample_rate``, ``n_fft``, ``hop_length``,
            ``win_length``, ``n_mels``, ``fmin`` and ``fmax``.

    Returns:
        Scalar tensor on the device of ``y_g``.
    """
    n_fft = config['n_fft']
    target = dict(device=y_g.device, dtype=y_g.dtype)

    # Same mel filter bank and Hann window that librosa uses by default.
    mel_basis = torch.from_numpy(
        librosa.filters.mel(
            sr=config['sample_rate'], n_fft=n_fft, n_mels=config['n_mels'],
            fmin=config['fmin'], fmax=config['fmax'],
        )
    ).to(**target)
    window = torch.hann_window(config['win_length'], **target)

    def _mel_db(wav):
        # One dB mel spectrogram per batch item.
        spec = torch.stft(
            wav.squeeze(1), n_fft,
            hop_length=config['hop_length'], win_length=config['win_length'],
            window=window, center=True, pad_mode='constant',
            return_complex=True,
        )
        # Squared magnitude written out explicitly so the gradient stays
        # finite where the spectrum is exactly zero.
        power = spec.real ** 2 + spec.imag ** 2
        mel = torch.matmul(mel_basis, power)
        mel_db = 10.0 * torch.log10(torch.clamp(mel, min=1e-10))
        # Reference each item to its own peak and keep 80 dB of range.
        peak = mel_db.amax(dim=(1, 2), keepdim=True)
        return torch.clamp(mel_db - peak, min=-80.0)

    return F.l1_loss(_mel_db(y_g), _mel_db(y.to(**target)))

print("✓ Funciones de loss definidas")

## 7. Training Loop

In [None]:
# Training hyperparameters
TRAIN_CONFIG = dict(
    epochs=100,
    lr=0.0002,
    betas=(0.8, 0.99),
    lambda_mel=45,  # weight of the mel loss
    lambda_fm=2,    # weight of the feature-matching loss
)

# One AdamW optimizer per network, sharing learning rate and betas.
optim_g = torch.optim.AdamW(
    generator.parameters(), lr=TRAIN_CONFIG['lr'], betas=TRAIN_CONFIG['betas']
)
optim_d = torch.optim.AdamW(
    discriminator.parameters(), lr=TRAIN_CONFIG['lr'], betas=TRAIN_CONFIG['betas']
)

# Exponential learning-rate decay for both optimizers.
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim_g, gamma=0.999)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim_d, gamma=0.999)

print("✓ Optimizers configurados")

In [None]:
# Training function
def train_epoch(epoch):
    """Train the generator and discriminator for one pass over ``train_loader``.

    For every batch the discriminator is updated first, on real audio versus
    detached generator output. The generator is then updated with the
    adversarial loss plus the feature-matching and mel reconstruction terms,
    weighted by ``TRAIN_CONFIG['lambda_fm']`` and ``TRAIN_CONFIG['lambda_mel']``.
    Both learning-rate schedulers are stepped once at the end of the epoch.

    Relies on the module-level ``generator``, ``discriminator``, optimizers,
    schedulers, ``train_loader``, ``device``, ``CONFIG`` and ``TRAIN_CONFIG``.

    Args:
        epoch: Epoch number, used only to label the progress bar.

    Returns:
        Tuple ``(mean generator loss, mean discriminator loss)`` over the
        batches of this epoch; both are 0 if the loader yields no batches.
    """
    generator.train()
    discriminator.train()

    total_loss_g = 0
    total_loss_d = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for mel, audio in pbar:
        mel = mel.to(device)
        audio = audio.to(device)

        # ===== Train Discriminator =====
        optim_d.zero_grad()

        # Generate audio
        audio_g = generator(mel)

        # Detach so this step does not backpropagate into the generator
        y_dr, y_dg, _, _ = discriminator(audio, audio_g.detach())

        loss_d = discriminator_loss(y_dr, y_dg)

        loss_d.backward()
        optim_d.step()

        # ===== Train Generator =====
        optim_g.zero_grad()

        # Re-run the (just updated) discriminator with gradients to the generator
        y_dr, y_dg, fmap_r, fmap_g = discriminator(audio, audio_g)

        # Losses
        loss_fm = feature_loss(fmap_r, fmap_g)
        loss_gen = generator_loss(y_dg)

        # Mel reconstruction term between real and generated audio
        # (previously a zero placeholder, which made lambda_mel a no-op).
        loss_mel = mel_spectrogram_loss(audio, audio_g, CONFIG)

        loss_g = loss_gen + TRAIN_CONFIG['lambda_fm'] * loss_fm + TRAIN_CONFIG['lambda_mel'] * loss_mel

        loss_g.backward()
        optim_g.step()

        # Track
        total_loss_g += loss_g.item()
        total_loss_d += loss_d.item()

        pbar.set_postfix({
            'G': f"{loss_g.item():.4f}",
            'D': f"{loss_d.item():.4f}"
        })

    # Schedulers
    scheduler_g.step()
    scheduler_d.step()

    # Avoid ZeroDivisionError when the loader is empty
    num_batches = max(len(train_loader), 1)
    return total_loss_g / num_batches, total_loss_d / num_batches

print("✓ Training function definida")

## 8. Ejecutar Training

In [None]:
# Training loop
history = {'loss_g': [], 'loss_d': []}

NUM_EPOCHS = 50  # Adjust as needed

checkpoint_dir = Path('models/hifigan_checkpoints')

print(f"\nIniciando training por {NUM_EPOCHS} epochs...")
print("=" * 70)

for epoch in range(1, NUM_EPOCHS + 1):
    loss_g, loss_d = train_epoch(epoch)

    history['loss_g'].append(loss_g)
    history['loss_d'].append(loss_d)

    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    print(f"  Generator Loss: {loss_g:.4f}")
    print(f"  Discriminator Loss: {loss_d:.4f}")

    # Save a checkpoint every 10 epochs
    if epoch % 10 == 0:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Scheduler state and loss history are stored as well, so a resumed
        # run can continue the learning-rate decay and the curves instead of
        # restarting them. The original keys are unchanged.
        torch.save({
            'epoch': epoch,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optim_g': optim_g.state_dict(),
            'optim_d': optim_d.state_dict(),
            'scheduler_g': scheduler_g.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'history': history,
        }, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')

        print(f"  ✓ Checkpoint guardado")

print("\n" + "=" * 70)
print("✓ Training completado!")

## 9. Visualizar Training

In [None]:
# Plot training curves
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

epochs = range(1, len(history['loss_g']) + 1)
ax.plot(epochs, history['loss_g'], 'b-', label='Generator Loss', linewidth=2)
ax.plot(epochs, history['loss_d'], 'r-', label='Discriminator Loss', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('HiFi-GAN Training Curves', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
# The checkpoint directory only exists once a checkpoint has been written
# (every 10th epoch), so create it here to keep savefig from failing.
curves_path = Path('models/hifigan_checkpoints') / 'training_curves.png'
curves_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(curves_path, dpi=150)
plt.show()

print("✓ Gráficas guardadas")

## 10. Generación de Audio

In [None]:
# Generation function
@torch.no_grad()
def generate_from_mel(mel_spec):
    """Run the module-level ``generator`` on a mel spectrogram.

    Accepts a NumPy array or a tensor. A 2-D input gets a leading batch axis
    added. The generator is switched to eval mode, and its output is returned
    as a NumPy array with all size-1 axes squeezed away.
    """
    generator.eval()

    mel_tensor = torch.FloatTensor(mel_spec) if isinstance(mel_spec, np.ndarray) else mel_spec
    if mel_tensor.dim() == 2:
        mel_tensor = mel_tensor.unsqueeze(0)

    waveform = generator(mel_tensor.to(device))
    return waveform.squeeze().cpu().numpy()

# Load one validation mel
val_sample = val_dataset[0]
val_mel, val_audio_gt = val_sample

print(f"Mel shape: {val_mel.shape}")
print(f"Audio GT shape: {val_audio_gt.shape}")

# Generate audio
generated_audio = generate_from_mel(val_mel)
print(f"Generated audio shape: {generated_audio.shape}")

# Compare
# A blank line followed by one separator row ("\n=" * 70 would instead
# print 70 lines that each contain a single "=").
print("\n" + "=" * 70)
print("AUDIO GENERADO")
print("=" * 70)

print("\n1. Ground Truth (Original):")
display(Audio(val_audio_gt.numpy(), rate=CONFIG['sample_rate']))

print("\n2. Generado por HiFi-GAN:")
display(Audio(generated_audio, rate=CONFIG['sample_rate']))

## 11. Visualización de Resultados

In [None]:
# Waveform comparison: ground truth on top, generated audio below.
fig, axes = plt.subplots(2, 1, figsize=(15, 6))

time_gt = np.arange(len(val_audio_gt.squeeze())) / CONFIG['sample_rate']
time_gen = np.arange(len(generated_audio)) / CONFIG['sample_rate']

axes[0].plot(time_gt, val_audio_gt.squeeze().numpy(), linewidth=0.5)
axes[0].set_title('Ground Truth Audio', fontsize=12, fontweight='bold')
axes[1].plot(time_gen, generated_audio, linewidth=0.5, color='orange')
axes[1].set_title('Generated Audio (HiFi-GAN)', fontsize=12, fontweight='bold')

# Axis decoration shared by both waveform panels.
for ax in axes:
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Spectrogram comparison: input mel on the left, mel of the generated audio on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 4))

gen_mel = mel_spectrogram(
    generated_audio,
    *(CONFIG[key] for key in (
        'sample_rate', 'n_fft', 'hop_length', 'win_length',
        'n_mels', 'fmin', 'fmax',
    ))
)

spectrograms = (
    (val_mel.numpy(), 'Input Mel-Spectrogram'),
    (gen_mel, 'Generated Mel-Spectrogram'),
)
for ax, (image, title) in zip(axes, spectrograms):
    ax.imshow(image, aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Time')
    ax.set_ylabel('Mel Bins')

plt.tight_layout()
plt.show()

## 12. Guardar Modelo y Generar Múltiples Samples

In [None]:
# Save the final model
model_dir = Path('models/hifigan_final')
model_dir.mkdir(parents=True, exist_ok=True)

torch.save({
    'generator': generator.state_dict(),
    'config': CONFIG,
    'train_config': TRAIN_CONFIG,
}, model_dir / 'hifigan_generator.pt')

print(f"✓ Modelo guardado en: {model_dir / 'hifigan_generator.pt'}")

# Generate several samples.
# A dedicated name is used for the samples folder so the global `output_dir`
# (the preprocessed-data directory used by earlier cells) is not overwritten.
samples_dir = Path('output/hifigan_samples')
samples_dir.mkdir(parents=True, exist_ok=True)

print("\nGenerando samples...")
for i in range(min(5, len(val_dataset))):
    mel, audio_gt = val_dataset[i]

    # Generate
    audio_gen = generate_from_mel(mel)

    # Save generated audio and its ground truth side by side
    sf.write(
        samples_dir / f'sample_{i}_generated.wav',
        audio_gen,
        CONFIG['sample_rate']
    )
    sf.write(
        samples_dir / f'sample_{i}_groundtruth.wav',
        audio_gt.squeeze().numpy(),
        CONFIG['sample_rate']
    )

    print(f"  ✓ Sample {i} guardado")

print(f"\n✓ Samples guardados en: {samples_dir}")