In [1]:
import torch

print("="*80)
print("CUDA DIAGNOSTICS")
print("="*80)

# Check CUDA availability
print(f"\nCUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    
    for i in range(torch.cuda.device_count()):
        print(f"\n--- GPU {i} ---")
        print(f"Name: {torch.cuda.get_device_name(i)}")
        print(f"Capability: {torch.cuda.get_device_capability(i)}")
        
        # Memory info
        props = torch.cuda.get_device_properties(i)
        print(f"Total Memory: {props.total_memory / 1024**3:.2f} GB")
        print(f"Available Memory: {torch.cuda.mem_get_info(i)[0] / 1024**3:.2f} GB")
        print(f"Allocated Memory: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB")
        
        # Test tensor creation
        try:
            test_tensor = torch.randn(100, 100).cuda(i)
            print(f"✓ Can create tensors on GPU {i}")
        except Exception as e:
            print(f"✗ Error creating tensor: {e}")
else:
    print("\n✗ CUDA is NOT available!")


CUDA DIAGNOSTICS

CUDA Available: True
CUDA Version: 12.6
PyTorch Version: 2.9.1+cu126
Number of GPUs: 1

--- GPU 0 ---
Name: NVIDIA RTX A1000
Capability: (8, 6)
Total Memory: 8.00 GB
Available Memory: 7.03 GB
Allocated Memory: 0.00 GB
✓ Can create tensors on GPU 0


In [41]:
import torch
import gc

# Kill all CUDA processes
torch.cuda.empty_cache()
gc.collect()

# Reset peak memory stats
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()

print(f"Available Memory: {torch.cuda.mem_get_info(0)[0] / 1024**3:.2f} GB")

Available Memory: 5.88 GB


In [2]:
# Core dependencies
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

!pip install torchaudio transformers einops tqdm descript-audio-codec
#audiotools

Defaulting to user installation because normal site-packages is not writeable
Collecting torchaudio
  Downloading torchaudio-2.9.1-cp312-cp312-win_amd64.whl.metadata (6.9 kB)
Collecting transformers
  Using cached transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting einops
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting descript-audio-codec
  Downloading descript_audio_codec-1.0.0-py3-none-any.whl.metadata (7.8 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Using cached huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2025.11.3-cp312-cp312-win_amd64.whl.metadata (41 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Using cached tokenizers-0.22.1-cp39-abi3-win_amd64.whl.metadata (6.9 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Using cached safetensors-0.7


[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
!pip install --retries 10 --timeout 30 descript-audio-codec
!pip install --retries 10 --timeout 30 git+https://github.com/descriptinc/audiotools

Collecting descript-audio-codec
  Using cached descript_audio_codec-1.0.0-py3-none-any.whl.metadata (7.8 kB)
Collecting argbind>=0.3.7 (from descript-audio-codec)
  Downloading argbind-0.3.9.tar.gz (17 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Collecting descript-audiotools>=0.7.2 (from descript-audio-codec)
  Downloading descript_audiotools-0.7.2-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting docstring-parser (from argbind>=0.3.7->descript-audio-codec)
  Downloading docstring_parser-0.17.0-py3-none-any.whl.metadata (3.5 kB)
Collecting pyloudnorm (from descript-audiotools>=0.7.2->descript-audio-codec)
  Downloading pyloudnorm-0.1.1-py3-none-any.whl.metadata (5.6 kB)
Collecting importlib-resources

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
databricks-sdk 0.70.0 requires protobuf<7.0,>=4.21.0, but you have protobuf 3.19.6 which is incompatible.
opentelemetry-proto 1.38.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.19.6 which is incompatible.
ray 2.52.1 requires protobuf>=3.20.3, but you have protobuf 3.19.6 which is incompatible.
tensorflow 2.20.0 requires protobuf>=5.28.0, but you have protobuf 3.19.6 which is incompatible.


Collecting git+https://github.com/descriptinc/audiotools
  Cloning https://github.com/descriptinc/audiotools to c:\users\user\appdata\local\temp\pip-req-build-5_d7i_qg


  Running command git clone --filter=blob:none --quiet https://github.com/descriptinc/audiotools 'C:\Users\user\AppData\Local\Temp\pip-req-build-5_d7i_qg'
  fatal: unable to access 'https://github.com/descriptinc/audiotools/': Could not resolve host: github.com
  error: subprocess-exited-with-error
  
  git clone --filter=blob:none --quiet https://github.com/descriptinc/audiotools 'C:\Users\user\AppData\Local\Temp\pip-req-build-5_d7i_qg' did not run successfully.
  exit code: 128
  
  No available output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed to build 'git+https://github.com/descriptinc/audiotools' when git clone --filter=blob:none --quiet https://github.com/descriptinc/audiotools 'c:\users\user\appdata\local\temp\pip-req-build-5_d7i_qg'


In [1]:
!pip install torchcodec

Defaulting to user installation because normal site-packages is not writeable
Collecting torchcodec
  Downloading torchcodec-0.9.1-cp312-cp312-win_amd64.whl.metadata (11 kB)
Downloading torchcodec-0.9.1-cp312-cp312-win_amd64.whl (2.2 MB)
   ---------------------------------------- 0.0/2.2 MB ? eta -:--:--
   ---------------------------------------- 2.2/2.2 MB 31.0 MB/s eta 0:00:00
Installing collected packages: torchcodec
Successfully installed torchcodec-0.9.1



[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
# In notebook:
!pip install "C:/Users/user/Downloads/protobuf-5.28.3-cp310-abi3-win_amd64.whl" --force-reinstal
# manual download of protobuf because of DNS issues

Defaulting to user installation because normal site-packages is not writeable
Processing c:\users\user\downloads\protobuf-5.28.3-cp310-abi3-win_amd64.whl
Installing collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.19.6
    Uninstalling protobuf-3.19.6:
      Successfully uninstalled protobuf-3.19.6
Successfully installed protobuf-5.28.3


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
descript-audiotools 0.7.2 requires protobuf<3.20,>=3.9.2, but you have protobuf 5.28.3 which is incompatible.

[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
import os

dac_model_path = "C:/Users/user/Downloads/weights_44khz_16kbps.pth"
# manual download of DAC because of DNS issues from official github repo releases 

if os.path.exists(dac_model_path):
    print("✅ DAC model found! You can train!")
    print(f"   Location: {dac_model_path}")
    file_size = os.path.getsize(dac_model_path) / (1024 * 1024)
    print(f"   Size: {file_size:.2f} MB")
else:
    print("❌ DAC model NOT found!")
    print(f"   Expected location: {dac_model_path}")
    print("\n📥 You need to download it first!")

✅ DAC model found! You can train!
   Location: C:/Users/user/Downloads/weights_44khz_16kbps.pth
   Size: 245.08 MB


In [None]:
!pip install soundfile
# use soundfile instead to avoid audiotools issues

Defaulting to user installation because normal site-packages is not writeable



[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
"""
Production-Ready Audio Effect Generator using DAC-VAE (No AudioTools)
NO AUDIOTOOLS DEPENDENCY - Uses DAC encoder/decoder directly!
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pandas as pd
import os
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from einops import rearrange
import matplotlib.pyplot as plt
import json
import numpy as np
from tqdm import tqdm
import torchaudio
import soundfile  # This import confirms it's installed
print("✓ soundfile available")

# DAC import (NO audiotools needed!)
try:
    import dac
    print("✓ DAC library imported successfully")
except ImportError:
    print("❌ DAC not installed. Run: pip install descript-audio-codec")
    exit(1)

#############################################
#                 CONFIG
#############################################

class CFG:
    # Paths - KEEP YOUR PATHS
    csv_path = "C:/zahra/EchoMind/data/5k_datapoints_1_prompt.csv"
    base_path = "C:/zahra/EchoMind/data"
    checkpoint_path = f"{base_path}/result/model.pt"
    best_model_path = f"{base_path}/result/model_best.pt"
    plot_path = f"{base_path}/result/training_curves.png"
    
    # Create result directory if needed
    os.makedirs(f"{base_path}/result", exist_ok=True)
    
    # Columns - SAME AS YOURS
    audio_col_in = "input_audio_path"
    audio_col_out = "output_audio_path"
    text_col = "prompt"
    
    # Audio settings
    sample_rate = 44100  # DAC uses 44.1kHz (better quality than 24kHz)
    max_audio_length = 5 * 44100  # 5 seconds at 44.1kHz
    
    # DAC Model settings
    dac_model_path = "44khz"  # Options: "16khz", "24khz", "44khz"
    
    # Training - OPTIMIZED FOR DAC
    batch_size = 4  # Start small due to 44kHz
    accumulation_steps = 4  # Effective batch = 16
    epochs = 100  # More epochs for production quality
    
    # Learning rates - TUNED FOR DAC
    lr_unet = 2e-5 #1e-5  # UNet learning rate
    lr_text = 5e-7  # Text encoder learning rate (frozen mostly)
    weight_decay = 0.01
    grad_clip = 1.0
    
    # Loss weights
    audio_loss_weight = 1.5 #1.0  # Waveform reconstruction
    latent_loss_weight = 0.12 #0.1  # Latent space matching
    
    # Mixed precision
    use_amp = True
    
    # Logging
    log_interval = 100  # Print every 100 steps
    
    # UNet architecture - OPTIMIZED FOR DAC LATENTS
    unet_channels = [64, 128, 256, 512]  # Deeper for better quality
    text_dim = 768  # BERT hidden size
    
    # Data splits
    train_ratio = 0.7
    val_ratio = 0.15
    test_ratio = 0.15
    
    # Freezing options
    freeze_text_encoder = True  # Set False to fine-tune BERT
    freeze_dac = True  # ALWAYS keep True (don't touch DAC)
    
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Num workers
    num_workers = 0  # Windows compatibility

cfg = CFG()

print("="*60)
print("DAC-VAE AUDIO EFFECT GENERATOR (No AudioTools)")
print("="*60)
print(f"Device: {cfg.device}")
print(f"Sample Rate: {cfg.sample_rate} Hz")
print(f"Max Length: {cfg.max_audio_length / cfg.sample_rate:.1f} seconds")
print(f"Batch Size: {cfg.batch_size} x {cfg.accumulation_steps} = {cfg.batch_size * cfg.accumulation_steps}")
print("="*60 + "\n")

#############################################
#          LOAD DAC MODEL
#############################################

print("Loading DAC model...")

# Load from manually downloaded file (for offline use)
dac_model_path = "C:/Users/user/Downloads/weights_44khz_16kbps.pth"

if not os.path.exists(dac_model_path):
    print(f"\n❌ DAC model not found at: {dac_model_path}")
    print("\n" + "="*60)
    print("MANUAL DOWNLOAD REQUIRED")
    print("="*60)
    print("\n📥 Download Instructions:")
    print("\n1. Go to: https://github.com/descriptinc/descript-audio-codec/releases/tag/1.0.0")
    print("2. Download: weights_44khz_16kbps.pth (245 MB)")
    print(f"3. Save to: {dac_model_path}")
    print("\n💡 TIP: Use mobile hotspot if you have network/DNS issues!")
    print("="*60 + "\n")
    exit(1)

print(f"✓ Loading from: {dac_model_path}")
dac_model = dac.DAC.load(dac_model_path)
dac_model = dac_model.to(cfg.device)
dac_model.eval()

# Freeze DAC encoder and decoder (we only train UNet)
for param in dac_model.parameters():
    param.requires_grad = False

print("✓ DAC model loaded and frozen")

# Get DAC latent dimensions
with torch.no_grad():
    dummy_audio = torch.randn(1, 1, cfg.sample_rate).to(cfg.device)
    # Use encoder directly (no audiotools needed!)
    z = dac_model.encoder(dummy_audio)
    latent_channels = z.shape[1]
    latent_time_reduction = dummy_audio.shape[-1] // z.shape[-1]
    
print(f"✓ DAC Latent Channels: {latent_channels}")
print(f"✓ Time Reduction Factor: {latent_time_reduction}x")
print()

#############################################
#      DATASET LOADING & PREPARATION
#############################################

print("="*60)
print("LOADING DATASET")
print("="*60)

df = pd.read_csv(cfg.csv_path)
print(f"Original dataset: {len(df)} samples")

# Fix paths - SAME AS YOUR CODE
for col in [cfg.audio_col_in, cfg.audio_col_out]:
    df[col] = df[col].apply(lambda p: os.path.join(cfg.base_path, str(p).replace('\\', '/')))

# Validate files exist
print("\nValidating files...")
valid_indices = []
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Validating"):
    if os.path.exists(row[cfg.audio_col_in]) and os.path.exists(row[cfg.audio_col_out]):
        valid_indices.append(idx)

df = df.iloc[valid_indices].reset_index(drop=True)
print(f"✓ Valid samples: {len(df)}")

# Split dataset - SAME AS YOURS
train_df, temp_df = train_test_split(df, test_size=(cfg.val_ratio + cfg.test_ratio), random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=cfg.test_ratio/(cfg.val_ratio + cfg.test_ratio), random_state=42)

print(f"\nDataset splits:")
print(f"  Train:      {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"  Validation: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
print(f"  Test:       {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")
print()

#############################################
#      TOKENIZER
#############################################

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print("✓ Tokenizer loaded\n")

#############################################
#      DATASET CLASS
#############################################

class AudioEffectDataset(Dataset):
    """Dataset with soundfile (no torchcodec issues!)"""
    
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
    
    def __len__(self):
        return len(self.df)
    
    def _load_and_process(self, path):
        """Load audio using soundfile"""
        import soundfile as sf
        
        # Load with soundfile (NOT torchaudio!)
        wav, sr = sf.read(path)
        wav = torch.from_numpy(wav).float()
        
        # Ensure correct shape: (channels, samples)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        elif wav.dim() == 2 and wav.size(0) > wav.size(1):
            wav = wav.t()
        
        # Resample if needed
        if sr != cfg.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, cfg.sample_rate)
        
        # Convert to mono
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        
        # Pad/trim
        if wav.size(1) > cfg.max_audio_length:
            wav = wav[:, :cfg.max_audio_length]
        elif wav.size(1) < cfg.max_audio_length:
            wav = F.pad(wav, (0, cfg.max_audio_length - wav.size(1)))
        
        return wav
    
    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            wav_in = self._load_and_process(row[cfg.audio_col_in])
            wav_out = self._load_and_process(row[cfg.audio_col_out])
            text = row[cfg.text_col]
            
            return wav_in, wav_out, text
        
        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            return (
                torch.zeros(1, cfg.max_audio_length),
                torch.zeros(1, cfg.max_audio_length),
                "error loading audio"
            )

def collate_fn(batch):
    """Collate function for batching"""
    wav_in, wav_out, texts = zip(*batch)
    
    # Stack waveforms (already same length from dataset)
    wav_in = torch.stack(wav_in)
    wav_out = torch.stack(wav_out)
    
    # Tokenize texts
    tokens = tokenizer(
        list(texts),
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )
    
    return wav_in, wav_out, tokens.input_ids, tokens.attention_mask

#############################################
#      CREATE DATALOADERS
#############################################

print("="*60)
print("CREATING DATALOADERS")
print("="*60)

train_ds = AudioEffectDataset(train_df)
val_ds = AudioEffectDataset(val_df)
test_ds = AudioEffectDataset(test_df)

train_dl = DataLoader(
    train_ds,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
    collate_fn=collate_fn,
    pin_memory=True
)

val_dl = DataLoader(
    val_ds,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
    collate_fn=collate_fn,
    pin_memory=True
)

test_dl = DataLoader(
    test_ds,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
    collate_fn=collate_fn,
    pin_memory=True
)

print(f"Batches per epoch:")
print(f"  Train: {len(train_dl)} batches")
print(f"  Val:   {len(val_dl)} batches")
print(f"  Test:  {len(test_dl)} batches")
print()

#############################################
#      MODEL ARCHITECTURE
#############################################

class CrossAttention(nn.Module):
    """Cross-attention between audio latents and text embeddings"""
    
    def __init__(self, audio_dim, text_dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.scale = (audio_dim // n_heads) ** -0.5
        
        self.to_q = nn.Linear(audio_dim, audio_dim)
        self.to_k = nn.Linear(text_dim, audio_dim)
        self.to_v = nn.Linear(text_dim, audio_dim)
        self.to_out = nn.Linear(audio_dim, audio_dim)
        
    def forward(self, x, context):
        """
        x: (B, C, T) - audio features
        context: (B, S, D) - text embeddings
        """
        B, C, T = x.shape
        x_flat = rearrange(x, 'b c t -> b t c')
        
        q = self.to_q(x_flat)
        k = self.to_k(context)
        v = self.to_v(context)
        
        q = rearrange(q, 'b t (h d) -> b h t d', h=self.n_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.n_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.n_heads)
        
        attn = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        
        out = torch.einsum('bhqk,bhvd->bhqd', attn, v)
        out = rearrange(out, 'b h t d -> b t (h d)')
        out = self.to_out(out)
        
        return rearrange(out, 'b t c -> b c t')

class ResidualBlock(nn.Module):
    """Residual block with group normalization"""
    
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv1d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, channels)
        self.norm2 = nn.GroupNorm(8, channels)
        self.act = nn.SiLU()
        
    def forward(self, x):
        residual = x
        x = self.act(self.norm1(self.conv1(x)))
        x = self.act(self.norm2(self.conv2(x)))
        return x + residual

class DownBlock(nn.Module):
    """Downsampling block with optional cross-attention"""
    
    def __init__(self, in_c, out_c, text_dim=768, use_attn=False):
        super().__init__()
        self.use_attn = use_attn
        
        self.conv = nn.Conv1d(in_c, out_c, 3, padding=1)
        self.res1 = ResidualBlock(out_c)
        self.res2 = ResidualBlock(out_c)
        
        if use_attn:
            self.attn = CrossAttention(out_c, text_dim)
        
        self.downsample = nn.Conv1d(out_c, out_c, 4, stride=2, padding=1)
        
    def forward(self, x, text_emb=None):
        x = self.conv(x)
        x = self.res1(x)
        x = self.res2(x)
        
        if self.use_attn and text_emb is not None:
            x = x + self.attn(x, text_emb)
        
        skip = x
        x = self.downsample(x)
        return x, skip

class UpBlock(nn.Module):
    """Upsampling block with skip connections and optional cross-attention"""
    
    def __init__(self, in_c, out_c, skip_c, text_dim=768, use_attn=False):
        super().__init__()
        self.use_attn = use_attn
        
        self.upsample = nn.ConvTranspose1d(in_c, out_c, 4, stride=2, padding=1)
        self.conv = nn.Conv1d(out_c + skip_c, out_c, 3, padding=1)
        self.res1 = ResidualBlock(out_c)
        self.res2 = ResidualBlock(out_c)
        
        if use_attn:
            self.attn = CrossAttention(out_c, text_dim)
        
    def forward(self, x, skip, text_emb=None):
        x = self.upsample(x)
        
        # Match temporal dimensions
        if x.size(-1) != skip.size(-1):
            x = F.interpolate(x, size=skip.size(-1), mode='linear', align_corners=False)
        
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        x = self.res1(x)
        x = self.res2(x)
        
        if self.use_attn and text_emb is not None:
            x = x + self.attn(x, text_emb)
        
        return x

class LatentUNet(nn.Module):
    """UNet for manipulating DAC latent space"""
    
    def __init__(self, latent_channels, channels, text_dim=768):
        super().__init__()
        
        # Input projection
        self.input_conv = nn.Conv1d(latent_channels, channels[0], 7, padding=3)
        
        # Encoder
        self.down_blocks = nn.ModuleList()
        for i in range(len(channels) - 1):
            use_attn = i >= 2  # Add attention in deeper layers
            self.down_blocks.append(
                DownBlock(channels[i], channels[i+1], text_dim, use_attn)
            )
        
        # Bottleneck
        self.mid_block1 = ResidualBlock(channels[-1])
        self.mid_attn = CrossAttention(channels[-1], text_dim)
        self.mid_block2 = ResidualBlock(channels[-1])
        
        # Decoder
        self.up_blocks = nn.ModuleList()
        for i in range(len(channels) - 1, 0, -1):
            use_attn = i >= 2
            self.up_blocks.append(
                UpBlock(
                    in_c=channels[i],
                    out_c=channels[i-1],
                    skip_c=channels[i],
                    text_dim=text_dim,
                    use_attn=use_attn
                )
            )
        
        # Output projection
        self.output_conv = nn.Conv1d(channels[0], latent_channels, 7, padding=3)
        
    def forward(self, z, text_emb):
        """
        z: (B, latent_channels, T) - DAC latents
        text_emb: (B, S, text_dim) - text embeddings
        """
        original_length = z.size(-1)
        
        x = self.input_conv(z)
        
        # Encoder path
        skips = []
        for down in self.down_blocks:
            x, skip = down(x, text_emb)
            skips.append(skip)
        
        # Bottleneck
        x = self.mid_block1(x)
        x = x + self.mid_attn(x, text_emb)
        x = self.mid_block2(x)
        
        # Decoder path
        for up in self.up_blocks:
            skip = skips.pop()
            x = up(x, skip, text_emb)
        
        # Output
        x = self.output_conv(x)
        
        # Ensure output matches input length
        if x.size(-1) != original_length:
            x = F.interpolate(x, size=original_length, mode='linear', align_corners=False)
        
        return x

class AudioEffectModel(nn.Module):
    """Complete model: Text Encoder + UNet + DAC (no audiotools!)"""
    
    def __init__(self, dac_model, latent_channels, unet_channels, text_dim):
        super().__init__()
        
        # Text encoder (BERT)
        self.text_encoder = AutoModel.from_pretrained("bert-base-uncased")
        
        # Freeze/unfreeze text encoder based on config
        if cfg.freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
            print("Text encoder: FROZEN ❄️")
        else:
            for param in self.text_encoder.parameters():
                param.requires_grad = True
            print("Text encoder: TRAINABLE 🔥 (fine-tuning enabled)")
        
        # DAC model (frozen)
        self.dac = dac_model
        
        # UNet (trainable)
        self.unet = LatentUNet(latent_channels, unet_channels, text_dim)
        
    def forward(self, wav_in, wav_out, input_ids, attention_mask):
        """
        Forward pass with input validation (NO AUDIOTOOLS!)
        
        Returns:
            wav_pred: Predicted output waveform
            z_pred: Predicted latent
            z_target: Target latent
        """
        # Check for NaN in inputs
        if torch.isnan(wav_in).any() or torch.isnan(wav_out).any():
            return None, None, None
        
        # Encode text
        text_output = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        text_emb = text_output.last_hidden_state  # (B, S, 768)
        
        # Encode audio to latents using DAC (NO AUDIOTOOLS - direct encoder call!)
        with torch.no_grad():
            z_in = self.dac.encoder(wav_in)
            z_target = self.dac.encoder(wav_out)
        
        # Check for NaN in latents
        if torch.isnan(z_in).any() or torch.isnan(z_target).any():
            print("⚠️ NaN detected in DAC encoding")
            return None, None, None
        
        # Process with UNet
        z_pred = self.unet(z_in, text_emb)
        
        # Check for NaN in prediction
        if torch.isnan(z_pred).any():
            print("⚠️ NaN detected in UNet output")
            return None, None, None
        
        # Decode latents to waveform (NO AUDIOTOOLS - direct decoder call!)
        with torch.no_grad():
            # Decode directly - decoder only needs the latents!
            wav_pred = self.dac.decoder(z_pred)
        
        # Check for NaN in decoded audio
        if torch.isnan(wav_pred).any():
            print("⚠️ NaN detected in DAC decoding")
            return None, None, None
        
        return wav_pred, z_pred, z_target

def init_weights(m):
    """Initialize weights with small values for stability"""
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        m.weight.data *= 0.1  # Scale down for stability
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.GroupNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

#############################################
#     MODEL INITIALIZATION
#############################################

print("="*60)
print("INITIALIZING MODEL")
print("="*60)

model = AudioEffectModel(
    dac_model=dac_model,
    latent_channels=latent_channels,
    unet_channels=cfg.unet_channels,
    text_dim=cfg.text_dim
).to(cfg.device)

# Initialize UNet weights
model.unet.apply(init_weights)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
print(f"UNet channels: {cfg.unet_channels}")
print(f"Latent channels: {latent_channels}")
print()

#############################################
#     OPTIMIZER & LOSS
#############################################

# Optimizer based on freeze_text_encoder setting
if cfg.freeze_text_encoder:
    # Only optimize UNet
    optimizer = torch.optim.AdamW(
        model.unet.parameters(),
        lr=cfg.lr_unet,
        weight_decay=cfg.weight_decay
    )
else:
    # Optimize UNet + Text Encoder
    optimizer = torch.optim.AdamW([
        {"params": model.unet.parameters(), "lr": cfg.lr_unet},
        {"params": model.text_encoder.parameters(), "lr": cfg.lr_text},
    ], weight_decay=cfg.weight_decay)

# Cosine annealing scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=cfg.epochs * len(train_dl),
    eta_min=cfg.lr_unet * 0.1
)

# Loss functions
criterion_audio = nn.L1Loss()
criterion_latent = nn.MSELoss()

# Mixed precision scaler
scaler = torch.amp.GradScaler('cuda', enabled=cfg.use_amp)

print("="*60)
print("TRAINING SETUP")
print("="*60)
print(f"Optimizer: AdamW")
print(f"Learning rate: {cfg.lr_unet}")
print(f"Scheduler: CosineAnnealingLR")
print(f"Loss: L1 (audio) + MSE (latent)")
print(f"Mixed precision: {cfg.use_amp}")
print()

#############################################
#     TRAINING & VALIDATION FUNCTIONS
#############################################

def train_epoch(model, dataloader, optimizer, scheduler, scaler, epoch):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    total_audio_loss = 0
    total_latent_loss = 0
    nan_count = 0
    
    optimizer.zero_grad()
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
    
    for step, (wav_in, wav_out, ids, mask) in enumerate(pbar):
        wav_in = wav_in.to(cfg.device)
        wav_out = wav_out.to(cfg.device)
        ids = ids.to(cfg.device)
        mask = mask.to(cfg.device)
        
        # Check input
        if torch.isnan(wav_in).any() or torch.isnan(wav_out).any():
            print(f"⚠️ NaN in input at step {step}, skipping...")
            nan_count += 1
            continue
        
        with torch.amp.autocast('cuda', enabled=cfg.use_amp):
            # Forward pass
            wav_pred, z_pred, z_target = model(wav_in, wav_out, ids, mask)
            
            # Check for None (indicates NaN in forward pass)
            if wav_pred is None:
                nan_count += 1
                continue
            
            # Match lengths
            if wav_pred.size(-1) != wav_out.size(-1):
                min_len = min(wav_pred.size(-1), wav_out.size(-1))
                wav_pred = wav_pred[..., :min_len]
                wav_out = wav_out[..., :min_len]
            
            if z_pred.size(-1) != z_target.size(-1):
                min_len = min(z_pred.size(-1), z_target.size(-1))
                z_pred = z_pred[..., :min_len]
                z_target = z_target[..., :min_len]
            
            # Compute losses
            audio_loss = criterion_audio(wav_pred, wav_out)
            latent_loss = criterion_latent(z_pred, z_target)
            
            loss = (cfg.audio_loss_weight * audio_loss + 
                   cfg.latent_loss_weight * latent_loss)
            
            # Scale for gradient accumulation
            loss = loss / cfg.accumulation_steps
        
        # Check loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"⚠️ NaN/Inf loss at step {step}, skipping...")
            nan_count += 1
            continue
        
        # Backward pass
        scaler.scale(loss).backward()
        
        # Optimizer step (with gradient accumulation)
        if (step + 1) % cfg.accumulation_steps == 0:
            # Unscale gradients
            scaler.unscale_(optimizer)
            
            # Clip gradients
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                cfg.grad_clip
            )
            
            # Check gradient norm
            if torch.isnan(grad_norm) or torch.isinf(grad_norm) or grad_norm > 100:
                print(f"⚠️ Bad gradient (norm={grad_norm:.2f}) at step {step}, skipping...")
                optimizer.zero_grad()
                scaler.update()
                nan_count += 1
                continue
            
            # Update weights
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
        
        # Accumulate losses
        total_loss += loss.item() * cfg.accumulation_steps
        total_audio_loss += audio_loss.item()
        total_latent_loss += latent_loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss.item() * cfg.accumulation_steps:.4f}',
            'audio': f'{audio_loss.item():.4f}',
            'latent': f'{latent_loss.item():.4f}',
            'nans': nan_count
        })
        
        # Log every N steps
        if (step + 1) % cfg.log_interval == 0:
            avg_loss = total_loss / (step + 1)
            print(f"\n  Step {step+1}/{len(dataloader)} | "
                  f"Loss: {avg_loss:.6f} | "
                  f"Audio: {total_audio_loss/(step+1):.6f} | "
                  f"Latent: {total_latent_loss/(step+1):.6f} | "
                  f"NaNs: {nan_count}")
    
    if nan_count > 0:
        print(f"\n⚠️ Epoch had {nan_count} NaN occurrences")
    
    avg_loss = total_loss / len(dataloader)
    avg_audio_loss = total_audio_loss / len(dataloader)
    avg_latent_loss = total_latent_loss / len(dataloader)
    
    return avg_loss, avg_audio_loss, avg_latent_loss

@torch.no_grad()
def validate_epoch(model, dataloader, epoch):
    """Validate for one epoch"""
    model.eval()
    total_loss = 0
    total_audio_loss = 0
    total_latent_loss = 0
    
    pbar = tqdm(dataloader, desc=f"Validation {epoch+1}/{cfg.epochs}")
    
    for wav_in, wav_out, ids, mask in pbar:
        wav_in = wav_in.to(cfg.device)
        wav_out = wav_out.to(cfg.device)
        ids = ids.to(cfg.device)
        mask = mask.to(cfg.device)
        
        with torch.amp.autocast('cuda', enabled=cfg.use_amp):
            wav_pred, z_pred, z_target = model(wav_in, wav_out, ids, mask)
            
            if wav_pred is None:
                continue
            
            # Match lengths
            if wav_pred.size(-1) != wav_out.size(-1):
                min_len = min(wav_pred.size(-1), wav_out.size(-1))
                wav_pred = wav_pred[..., :min_len]
                wav_out = wav_out[..., :min_len]
            
            if z_pred.size(-1) != z_target.size(-1):
                min_len = min(z_pred.size(-1), z_target.size(-1))
                z_pred = z_pred[..., :min_len]
                z_target = z_target[..., :min_len]
            
            audio_loss = criterion_audio(wav_pred, wav_out)
            latent_loss = criterion_latent(z_pred, z_target)
            
            loss = (cfg.audio_loss_weight * audio_loss + 
                   cfg.latent_loss_weight * latent_loss)
        
        total_loss += loss.item()
        total_audio_loss += audio_loss.item()
        total_latent_loss += latent_loss.item()
        
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'audio': f'{audio_loss.item():.4f}',
            'latent': f'{latent_loss.item():.4f}'
        })
    
    avg_loss = total_loss / len(dataloader)
    avg_audio_loss = total_audio_loss / len(dataloader)
    avg_latent_loss = total_latent_loss / len(dataloader)
    
    return avg_loss, avg_audio_loss, avg_latent_loss

#############################################
#     TRAINING LOOP
#############################################

print("="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Total epochs: {cfg.epochs}")
print(f"Steps per epoch: {len(train_dl)}")
print(f"Validation every epoch")
print("="*60 + "\n")

# Training history
train_losses = []
val_losses = []
train_audio_losses = []
train_latent_losses = []
val_audio_losses = []
val_latent_losses = []

best_val_loss = float('inf')
start_epoch = 0

# Resume from checkpoint if exists
if os.path.exists(cfg.checkpoint_path):
    print("Loading checkpoint...")
    ckpt = torch.load(cfg.checkpoint_path, map_location=cfg.device)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    scaler.load_state_dict(ckpt['scaler'])
    start_epoch = ckpt['epoch'] + 1
    train_losses = ckpt.get('train_losses', [])
    val_losses = ckpt.get('val_losses', [])
    train_audio_losses = ckpt.get('train_audio_losses', [])
    train_latent_losses = ckpt.get('train_latent_losses', [])
    val_audio_losses = ckpt.get('val_audio_losses', [])
    val_latent_losses = ckpt.get('val_latent_losses', [])
    best_val_loss = ckpt.get('best_val_loss', float('inf'))
    print(f"✓ Resumed from epoch {start_epoch}")
    print()

# Training loop
for epoch in range(start_epoch, cfg.epochs):
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1}/{cfg.epochs}")
    print(f"{'='*60}\n")
    
    # Train
    train_loss, train_audio, train_latent = train_epoch(
        model, train_dl, optimizer, scheduler, scaler, epoch
    )
    
    # Validate
    val_loss, val_audio, val_latent = validate_epoch(
        model, val_dl, epoch
    )
    
    # Store losses
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_audio_losses.append(train_audio)
    train_latent_losses.append(train_latent)
    val_audio_losses.append(val_audio)
    val_latent_losses.append(val_latent)
    
    # Print summary
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1}/{cfg.epochs} SUMMARY")
    print(f"{'='*60}")
    print(f"Train Loss:  {train_loss:.6f} (Audio: {train_audio:.6f}, Latent: {train_latent:.6f})")
    print(f"Val Loss:    {val_loss:.6f} (Audio: {val_audio:.6f}, Latent: {val_latent:.6f})")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]:.2e}")
    print(f"{'='*60}\n")
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_audio_losses': train_audio_losses,
        'train_latent_losses': train_latent_losses,
        'val_audio_losses': val_audio_losses,
        'val_latent_losses': val_latent_losses,
        'best_val_loss': best_val_loss,
        'config': {
            'latent_channels': latent_channels,
            'unet_channels': cfg.unet_channels,
            'text_dim': cfg.text_dim,
            'sample_rate': cfg.sample_rate
        }
    }
    torch.save(checkpoint, cfg.checkpoint_path)
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'val_loss': best_val_loss,
            'config': checkpoint['config']
        }, cfg.best_model_path)
        print(f"✅ NEW BEST MODEL! Val Loss: {best_val_loss:.6f}\n")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60 + "\n")

#############################################
#     TEST SET EVALUATION
#############################################

print("="*60)
print("TESTING BEST MODEL")
print("="*60 + "\n")

# Load best model
best_ckpt = torch.load(cfg.best_model_path, map_location=cfg.device)
model.load_state_dict(best_ckpt['model'])
print(f"Loaded best model from epoch {best_ckpt['epoch']}")

# Test
test_loss, test_audio, test_latent = validate_epoch(model, test_dl, cfg.epochs)

print(f"\n{'='*60}")
print("FINAL TEST RESULTS")
print(f"{'='*60}")
print(f"Test Loss:  {test_loss:.6f}")
print(f"  Audio Loss:  {test_audio:.6f}")
print(f"  Latent Loss: {test_latent:.6f}")
print(f"{'='*60}\n")

#############################################
#     PLOT TRAINING CURVES
#############################################

print("Generating training curves...")

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Total loss
ax = axes[0, 0]
epochs_range = range(len(train_losses))
ax.plot(epochs_range, train_losses, 'b-', label='Train', linewidth=2, marker='o', markersize=4)
ax.plot(epochs_range, val_losses, 'r-', label='Val', linewidth=2, marker='s', markersize=4)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Total Loss', fontsize=12)
ax.set_title('Total Loss (Audio + Latent)', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Plot 2: Audio loss
ax = axes[0, 1]
ax.plot(epochs_range, train_audio_losses, 'b-', label='Train', linewidth=2, marker='o', markersize=4)
ax.plot(epochs_range, val_audio_losses, 'r-', label='Val', linewidth=2, marker='s', markersize=4)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Audio Loss (L1)', fontsize=12)
ax.set_title('Audio Reconstruction Loss', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Plot 3: Latent loss
ax = axes[1, 0]
ax.plot(epochs_range, train_latent_losses, 'b-', label='Train', linewidth=2, marker='o', markersize=4)
ax.plot(epochs_range, val_latent_losses, 'r-', label='Val', linewidth=2, marker='s', markersize=4)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Latent Loss (MSE)', fontsize=12)
ax.set_title('Latent Space Loss', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Plot 4: Generalization gap
ax = axes[1, 1]
gap = [v - t for t, v in zip(train_losses, val_losses)]
ax.plot(epochs_range, gap, 'g-', label='Val - Train', linewidth=2, marker='d', markersize=4)
ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss Gap', fontsize=12)
ax.set_title('Generalization Gap (Val - Train)', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(cfg.plot_path, dpi=300, bbox_inches='tight')
print(f"✓ Plot saved to: {cfg.plot_path}")
plt.close()

#############################################
#     SAVE SUMMARY
#############################################

summary = {
    'dataset': {
        'total_samples': len(df),
        'train_samples': len(train_df),
        'val_samples': len(val_df),
        'test_samples': len(test_df)
    },
    'training': {
        'epochs': cfg.epochs,
        'batch_size': cfg.batch_size,
        'accumulation_steps': cfg.accumulation_steps,
        'effective_batch_size': cfg.batch_size * cfg.accumulation_steps
    },
    'model': {
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'latent_channels': latent_channels,
        'unet_channels': cfg.unet_channels
    },
    'results': {
        'best_train_loss': float(min(train_losses)),
        'best_val_loss': float(best_val_loss),
        'test_loss': float(test_loss),
        'test_audio_loss': float(test_audio),
        'test_latent_loss': float(test_latent)
    },
    'config': {
        'sample_rate': cfg.sample_rate,
        'max_audio_length': cfg.max_audio_length,
        'lr_unet': cfg.lr_unet,
        'audio_loss_weight': cfg.audio_loss_weight,
        'latent_loss_weight': cfg.latent_loss_weight
    }
}

summary_path = f"{cfg.base_path}/result_DAC/training_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"✓ Summary saved to: {summary_path}")

print("\n" + "="*60)
print("ALL FILES SAVED")
print("="*60)
print(f"✓ Best model: {cfg.best_model_path}")
print(f"✓ Checkpoint: {cfg.checkpoint_path}")
print(f"✓ Training curves: {cfg.plot_path}")
print(f"✓ Summary: {summary_path}")
print("="*60)

print("\n🎉 TRAINING PIPELINE COMPLETE! 🎉\n")
print("Next steps:")
print("1. Check training curves for convergence")
print("2. Use inference script to test on new audio")
print("3. Fine-tune hyperparameters if needed")
print("\nGood luck with your production model! 🚀")

  from .autonotebook import tqdm as notebook_tqdm


✓ soundfile available
✓ DAC library imported successfully
DAC-VAE AUDIO EFFECT GENERATOR (No AudioTools)
Device: cuda
Sample Rate: 44100 Hz
Max Length: 5.0 seconds
Batch Size: 4 x 4 = 16

Loading DAC model...
✓ Loading from: C:/Users/user/Downloads/weights_44khz_16kbps.pth


  WeightNorm.apply(module, name, dim)


✓ DAC model loaded and frozen
✓ DAC Latent Channels: 128
✓ Time Reduction Factor: 512x

LOADING DATASET
Original dataset: 5000 samples

Validating files...


Validating: 100%|██████| 5000/5000 [00:01<00:00, 3309.64it/s]


✓ Valid samples: 5000

Dataset splits:
  Train:      3500 samples (70.0%)
  Validation: 750 samples (15.0%)
  Test:       750 samples (15.0%)

Loading tokenizer...
✓ Tokenizer loaded

CREATING DATALOADERS
Batches per epoch:
  Train: 875 batches
  Val:   188 batches
  Test:  188 batches

INITIALIZING MODEL
Text encoder: FROZEN ❄️
Total parameters: 188,877,378
Trainable parameters: 15,181,888
Frozen parameters: 173,695,490
UNet channels: [64, 128, 256, 512]
Latent channels: 128

TRAINING SETUP
Optimizer: AdamW
Learning rate: 2e-05
Scheduler: CosineAnnealingLR
Loss: L1 (audio) + MSE (latent)
Mixed precision: True

STARTING TRAINING
Total epochs: 100
Steps per epoch: 875
Validation every epoch

Loading checkpoint...
✓ Resumed from epoch 100


TRAINING COMPLETE!

TESTING BEST MODEL

Loaded best model from epoch 99


Validation 101/100: 100%|█| 188/188 [03:13<00:00,  1.03s/it, 



FINAL TEST RESULTS
Test Loss:  1.264591
  Audio Loss:  0.110320
  Latent Loss: 9.159253

Generating training curves...
✓ Plot saved to: C:/zahra/EchoMind/data/result/training_curves.png


FileNotFoundError: [Errno 2] No such file or directory: 'C:/zahra/EchoMind/data/result_DAC/training_summary.json'

# Inference

In [2]:
"""
Inference Script for DAC-VAE Audio Effect Generator (NO AUDIOTOOLS)
Test your trained model on new audio files

FIXED VERSION:
- No audiotools dependency
- Correct decoder API (simple call)
- Offline DAC model loading
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
from transformers import AutoTokenizer, AutoModel
from einops import rearrange
import argparse

# DAC import (NO audiotools!)
try:
    import dac
    print("✓ DAC library imported successfully")
except ImportError:
    print("❌ DAC not installed. Run: pip install descript-audio-codec")
    exit(1)

#############################################
#     MODEL ARCHITECTURE (SAME AS TRAINING)
#############################################

class CrossAttention(nn.Module):
    """Cross-attention between audio latents and text embeddings"""
    
    def __init__(self, audio_dim, text_dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.scale = (audio_dim // n_heads) ** -0.5
        self.to_q = nn.Linear(audio_dim, audio_dim)
        self.to_k = nn.Linear(text_dim, audio_dim)
        self.to_v = nn.Linear(text_dim, audio_dim)
        self.to_out = nn.Linear(audio_dim, audio_dim)
        
    def forward(self, x, context):
        B, C, T = x.shape
        x_flat = rearrange(x, 'b c t -> b t c')
        q = self.to_q(x_flat)
        k = self.to_k(context)
        v = self.to_v(context)
        q = rearrange(q, 'b t (h d) -> b h t d', h=self.n_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.n_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.n_heads)
        attn = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum('bhqk,bhvd->bhqd', attn, v)
        out = rearrange(out, 'b h t d -> b t (h d)')
        out = self.to_out(out)
        return rearrange(out, 'b t c -> b c t')

class ResidualBlock(nn.Module):
    """Residual block with group normalization"""
    
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv1d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, channels)
        self.norm2 = nn.GroupNorm(8, channels)
        self.act = nn.SiLU()
        
    def forward(self, x):
        residual = x
        x = self.act(self.norm1(self.conv1(x)))
        x = self.act(self.norm2(self.conv2(x)))
        return x + residual

class DownBlock(nn.Module):
    """Downsampling block with optional cross-attention"""
    
    def __init__(self, in_c, out_c, text_dim=768, use_attn=False):
        super().__init__()
        self.use_attn = use_attn
        self.conv = nn.Conv1d(in_c, out_c, 3, padding=1)
        self.res1 = ResidualBlock(out_c)
        self.res2 = ResidualBlock(out_c)
        if use_attn:
            self.attn = CrossAttention(out_c, text_dim)
        self.downsample = nn.Conv1d(out_c, out_c, 4, stride=2, padding=1)
        
    def forward(self, x, text_emb=None):
        x = self.conv(x)
        x = self.res1(x)
        x = self.res2(x)
        if self.use_attn and text_emb is not None:
            x = x + self.attn(x, text_emb)
        skip = x
        x = self.downsample(x)
        return x, skip

class UpBlock(nn.Module):
    """Upsampling block with skip connections and optional cross-attention"""
    
    def __init__(self, in_c, out_c, skip_c, text_dim=768, use_attn=False):
        super().__init__()
        self.use_attn = use_attn
        self.upsample = nn.ConvTranspose1d(in_c, out_c, 4, stride=2, padding=1)
        self.conv = nn.Conv1d(out_c + skip_c, out_c, 3, padding=1)
        self.res1 = ResidualBlock(out_c)
        self.res2 = ResidualBlock(out_c)
        if use_attn:
            self.attn = CrossAttention(out_c, text_dim)
        
    def forward(self, x, skip, text_emb=None):
        x = self.upsample(x)
        if x.size(-1) != skip.size(-1):
            x = F.interpolate(x, size=skip.size(-1), mode='linear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        x = self.res1(x)
        x = self.res2(x)
        if self.use_attn and text_emb is not None:
            x = x + self.attn(x, text_emb)
        return x

class LatentUNet(nn.Module):
    """UNet for manipulating DAC latent space"""
    
    def __init__(self, latent_channels, channels, text_dim=768):
        super().__init__()
        self.input_conv = nn.Conv1d(latent_channels, channels[0], 7, padding=3)
        
        self.down_blocks = nn.ModuleList()
        for i in range(len(channels) - 1):
            use_attn = i >= 2
            self.down_blocks.append(DownBlock(channels[i], channels[i+1], text_dim, use_attn))
        
        self.mid_block1 = ResidualBlock(channels[-1])
        self.mid_attn = CrossAttention(channels[-1], text_dim)
        self.mid_block2 = ResidualBlock(channels[-1])
        
        self.up_blocks = nn.ModuleList()
        for i in range(len(channels) - 1, 0, -1):
            use_attn = i >= 2
            self.up_blocks.append(
                UpBlock(channels[i], channels[i-1], channels[i], text_dim, use_attn)
            )
        
        self.output_conv = nn.Conv1d(channels[0], latent_channels, 7, padding=3)
        
    def forward(self, z, text_emb):
        original_length = z.size(-1)
        x = self.input_conv(z)
        
        skips = []
        for down in self.down_blocks:
            x, skip = down(x, text_emb)
            skips.append(skip)
        
        x = self.mid_block1(x)
        x = x + self.mid_attn(x, text_emb)
        x = self.mid_block2(x)
        
        for up in self.up_blocks:
            skip = skips.pop()
            x = up(x, skip, text_emb)
        
        x = self.output_conv(x)
        
        if x.size(-1) != original_length:
            x = F.interpolate(x, size=original_length, mode='linear', align_corners=False)
        
        return x

class AudioEffectModel(nn.Module):
    """Complete model: Text Encoder + UNet + DAC (NO AUDIOTOOLS!)"""
    
    def __init__(self, dac_model, latent_channels, unet_channels, text_dim):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained("bert-base-uncased")
        self.dac = dac_model
        self.unet = LatentUNet(latent_channels, unet_channels, text_dim)
        
    @torch.no_grad()
    def generate(self, wav_in, prompt, sample_rate):
        """
        Generate audio with effect applied (NO AUDIOTOOLS!)
        
        Args:
            wav_in: Input waveform (1, 1, T) or (1, T)
            prompt: Text description of effect (string)
            sample_rate: Sample rate of input audio
            
        Returns:
            wav_out: Output waveform with effect applied
        """
        self.eval()
        
        # Ensure correct shape
        if wav_in.dim() == 2:
            wav_in = wav_in.unsqueeze(1)
        
        # Tokenize prompt
        tokens = tokenizer(
            [prompt],
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        ).to(wav_in.device)
        
        # Encode text
        text_output = self.text_encoder(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask
        )
        text_emb = text_output.last_hidden_state
        
        # Encode audio to latents (NO AUDIOTOOLS - direct encoder!)
        z_in = self.dac.encoder(wav_in)
        
        # Process with UNet
        z_out = self.unet(z_in, text_emb)
        
        # Decode to waveform (FIXED - simple decoder call!)
        wav_out = self.dac.decoder(z_out)
        
        return wav_out

#############################################
#     INFERENCE CLASS
#############################################

#############################################
#     INFERENCE CLASS (FIXED FOR WINDOWS)
#############################################

class AudioEffectInference:
    def __init__(self, model_path, dac_model_path=None, device='cuda'):
        """
        Initialize inference pipeline
        
        Args:
            model_path: Path to trained model checkpoint (.pt file)
            dac_model_path: Path to DAC weights (optional, defaults to cache location)
            device: 'cuda' or 'cpu'
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        
        print("="*60)
        print("LOADING MODEL FOR INFERENCE (NO AUDIOTOOLS)")
        print("="*60)
        
        # Load checkpoint
        print(f"Loading checkpoint from: {model_path}")
        ckpt = torch.load(model_path, map_location=self.device)
        
        # Get config
        config = ckpt['config']
        self.sample_rate = config['sample_rate']
        latent_channels = config['latent_channels']
        unet_channels = config['unet_channels']
        text_dim = config['text_dim']
        
        print(f"✓ Sample rate: {self.sample_rate} Hz")
        print(f"✓ Latent channels: {latent_channels}")
        print(f"✓ UNet channels: {unet_channels}")
        
        # Load DAC model from local file
        if dac_model_path is None:
            dac_model_path = "C:/Users/user/Downloads/weights_44khz_16kbps.pth"
        
        print(f"Loading DAC model from: {dac_model_path}")
        
        if not os.path.exists(dac_model_path):
            print(f"\n❌ DAC model not found at: {dac_model_path}")
            print("Please download it first!")
            exit(1)
        
        self.dac_model = dac.DAC.load(dac_model_path)
        self.dac_model = self.dac_model.to(self.device)
        self.dac_model.eval()
        print("✓ DAC model loaded")
        
        # Create model
        self.model = AudioEffectModel(
            dac_model=self.dac_model,
            latent_channels=latent_channels,
            unet_channels=unet_channels,
            text_dim=text_dim
        ).to(self.device)
        
        # Load weights
        self.model.load_state_dict(ckpt['model'])
        self.model.eval()
        print("✓ Model weights loaded")
        
        # Load tokenizer
        global tokenizer
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        print("✓ Tokenizer loaded")
        
        print(f"✓ Device: {self.device}")
        print("="*60 + "\n")
    
    def process_audio(self, input_path, output_path, prompt):
        """
        Process audio file with effect (FIXED for Windows - uses soundfile!)
        
        Args:
            input_path: Path to input audio file
            output_path: Path to save output audio
            prompt: Text description of effect to apply
        """
        import soundfile as sf
        
        print(f"Processing: {input_path}")
        print(f"Effect: '{prompt}'")
        
        # Load audio using soundfile (NOT torchaudio!)
        wav, sr = sf.read(input_path)
        wav = torch.from_numpy(wav).float()
        
        # Ensure correct shape: (channels, samples)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)  # (samples,) -> (1, samples)
        elif wav.dim() == 2 and wav.size(0) > wav.size(1):
            wav = wav.t()  # (samples, channels) -> (channels, samples)
        
        original_length = wav.size(-1)
        
        # Resample if needed
        if sr != self.sample_rate:
            print(f"Resampling from {sr} Hz to {self.sample_rate} Hz")
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        
        # Convert to mono
        if wav.size(0) > 1:
            print("Converting to mono")
            wav = wav.mean(dim=0, keepdim=True)
        
        # Add batch dimension and move to device
        wav = wav.unsqueeze(0).to(self.device)  # (1, 1, samples)
        
        print(f"Input shape: {wav.shape}")
        print("Generating...")
        
        # Generate
        with torch.no_grad():
            wav_out = self.model.generate(wav, prompt, self.sample_rate)
        
        # Move to CPU and remove batch dimension
        wav_out = wav_out.squeeze(0).cpu()  # (1, samples)
        
        # Match original length (approximately)
        current_length = wav_out.size(-1)
        target_length = wav.squeeze(0).size(-1)
        
        if current_length != target_length:
            print(f"Adjusting length: {current_length} -> {target_length}")
            if current_length > target_length:
                wav_out = wav_out[..., :target_length]
            else:
                wav_out = F.pad(wav_out, (0, target_length - current_length))
        
        print(f"Output shape: {wav_out.shape}")
        
        # Save using soundfile (NOT torchaudio.save!)
        wav_out_np = wav_out.squeeze(0).numpy()  # (samples,)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        sf.write(output_path, wav_out_np, self.sample_rate)
        
        print(f"✓ Saved to: {output_path}\n")
    
    def batch_process(self, input_dir, output_dir, prompt):
        """
        Process all audio files in a directory
        
        Args:
            input_dir: Directory containing input audio files
            output_dir: Directory to save output audio files
            prompt: Text description of effect to apply
        """
        os.makedirs(output_dir, exist_ok=True)
        
        # Get all audio files
        audio_extensions = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
        audio_files = [
            f for f in os.listdir(input_dir)
            if os.path.splitext(f)[1].lower() in audio_extensions
        ]
        
        if not audio_files:
            print(f"❌ No audio files found in {input_dir}")
            return
        
        print(f"Found {len(audio_files)} audio files")
        print(f"Effect: '{prompt}'\n")
        
        for i, filename in enumerate(audio_files, 1):
            print(f"[{i}/{len(audio_files)}] Processing: {filename}")
            
            input_path = os.path.join(input_dir, filename)
            # Keep original extension
            name, ext = os.path.splitext(filename)
            output_filename = f"{name}_processed{ext}"
            output_path = os.path.join(output_dir, output_filename)
            
            try:
                self.process_audio(input_path, output_path, prompt)
            except Exception as e:
                print(f"❌ Error processing {filename}: {e}\n")
        
        print(f"\n✅ Batch processing complete!")
        print(f"   Processed: {len(audio_files)} files")
        print(f"   Output directory: {output_dir}")
       

#############################################
#     MAIN FUNCTION
#############################################

def main():
    parser = argparse.ArgumentParser(
        description='Audio Effect Generator Inference (DAC-VAE)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process single file
  python dac_vae_inference_FIXED.py --model model_best.pt --input audio.wav --output result.wav --prompt "add rain sounds"
  
  # Batch process directory
  python dac_vae_inference_FIXED.py --model model_best.pt --input input_folder/ --output output_folder/ --prompt "add birds chirping"
        """
    )
    
    parser.add_argument('--model', type=str, required=True,
                       help='Path to trained model checkpoint (.pt file)')
    parser.add_argument('--input', type=str, required=True,
                       help='Input audio file or directory')
    parser.add_argument('--output', type=str, required=True,
                       help='Output audio file or directory')
    parser.add_argument('--prompt', type=str, required=True,
                       help='Effect description (e.g., "add rain sounds")')
    parser.add_argument('--dac-model', type=str, default=None,
                       help='Path to DAC model weights (default: C:/Users/user/Downloads/weights_44khz_16kbps.pth)')
    parser.add_argument('--device', type=str, default='cuda',
                       help='Device to use (cuda or cpu)')
    
    args = parser.parse_args()
    
    # Initialize inference
    try:
        inference = AudioEffectInference(args.model, args.dac_model, args.device)
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return
    
    # Check if input is file or directory
    if os.path.isfile(args.input):
        # Single file
        inference.process_audio(args.input, args.output, args.prompt)
    elif os.path.isdir(args.input):
        # Batch processing
        inference.batch_process(args.input, args.output, args.prompt)
    else:
        print(f"❌ Error: {args.input} is not a valid file or directory")


  from .autonotebook import tqdm as notebook_tqdm


✓ DAC library imported successfully


In [4]:
if __name__ == '__main__':
    # Example usage (can be run directly or via command line)
    #C:/Users/user/Downloads/weights_44khz_16kbps.pth
    # For direct usage in script:
    
    model_path = "C:/zahra/EchoMind/data/result/model_best.pt"
    input_audio = "C:/zahra/EchoMind/data/audios/input_audios/22.wav"
    output_audio = "C:/zahra/EchoMind/data/result/inference100/22_lightning.wav"
    effect_prompt = "add lightning sounds"
    
    inference = AudioEffectInference(model_path, dac_model_path="C:/Users/user/Downloads/weights_44khz_16kbps.pth", device='cuda')
    inference.process_audio(input_audio, output_audio, effect_prompt)
    
    # For command line usage:
    main()

LOADING MODEL FOR INFERENCE (NO AUDIOTOOLS)
Loading checkpoint from: C:/zahra/EchoMind/data/result/model_best.pt
✓ Sample rate: 44100 Hz
✓ Latent channels: 128
✓ UNet channels: [64, 128, 256, 512]
Loading DAC model from: C:/Users/user/Downloads/weights_44khz_16kbps.pth
✓ DAC model loaded
✓ Model weights loaded
✓ Tokenizer loaded
✓ Device: cuda

Processing: C:/zahra/EchoMind/data/audios/input_audios/22.wav
Effect: 'add lightning sounds'
Resampling from 22050 Hz to 44100 Hz
Input shape: torch.Size([1, 1, 220500])
Generating...
Adjusting length: 220160 -> 220500
Output shape: torch.Size([1, 220500])
✓ Saved to: C:/zahra/EchoMind/data/result/inference100/22_lightning.wav



usage: ipykernel_launcher.py [-h] --model MODEL --input
                             INPUT --output OUTPUT --prompt
                             PROMPT [--dac-model DAC_MODEL]
                             [--device DEVICE]
ipykernel_launcher.py: error: the following arguments are required: --model, --input, --output, --prompt


SystemExit: 2

In [9]:
if __name__ == '__main__':
    # Example usage (can be run directly or via command line)
    #C:/Users/user/Downloads/weights_44khz_16kbps.pth
    # For direct usage in script:
    
    model_path = "C:/zahra/EchoMind/data/result/model_best.pt"
    input_audio = "C:/zahra/EchoMind/data/result/inference/english_OJCIvTNk.wav"
    output_audio = "C:/zahra/EchoMind/data/result/inference100/english_OJCIvTNk_lightning.wav"
    effect_prompt = "add lightning sounds"
    
    inference = AudioEffectInference(model_path, dac_model_path="C:/Users/user/Downloads/weights_44khz_16kbps.pth", device='cuda')
    inference.process_audio(input_audio, output_audio, effect_prompt)
    
    # For command line usage:
    main()

LOADING MODEL FOR INFERENCE (NO AUDIOTOOLS)
Loading checkpoint from: C:/zahra/EchoMind/data/result/model_best.pt
✓ Sample rate: 44100 Hz
✓ Latent channels: 128
✓ UNet channels: [64, 128, 256, 512]
Loading DAC model from: C:/Users/user/Downloads/weights_44khz_16kbps.pth
✓ DAC model loaded
✓ Model weights loaded
✓ Tokenizer loaded
✓ Device: cuda

Processing: C:/zahra/EchoMind/data/result/inference/english_OJCIvTNk.wav
Effect: 'add lightning sounds'
Resampling from 48000 Hz to 44100 Hz
Converting to mono
Input shape: torch.Size([1, 1, 213003])
Generating...
Adjusting length: 212992 -> 213003
Output shape: torch.Size([1, 213003])
✓ Saved to: C:/zahra/EchoMind/data/result/inference100/english_OJCIvTNk_lightning.wav



usage: ipykernel_launcher.py [-h] --model MODEL --input
                             INPUT --output OUTPUT --prompt
                             PROMPT [--dac-model DAC_MODEL]
                             [--device DEVICE]
ipykernel_launcher.py: error: the following arguments are required: --model, --input, --output, --prompt


SystemExit: 2

In [14]:
if __name__ == '__main__':
    # Example usage (can be run directly or via command line)
    #C:/Users/user/Downloads/weights_44khz_16kbps.pth
    # For direct usage in script:
    
    model_path = "C:/zahra/EchoMind/data/result/model_best.pt"
    input_audio = "C:/zahra/EchoMind/data/result/inference/arabic_XBmfzfHL.wav"
    output_audio = "C:/zahra/EchoMind/data/result/inference100/arabic_XBmfzfHL_cats.wav"
    effect_prompt = "add cats sounds"
    
    inference = AudioEffectInference(model_path, dac_model_path="C:/Users/user/Downloads/weights_44khz_16kbps.pth", device='cuda')
    inference.process_audio(input_audio, output_audio, effect_prompt)
    
    # For command line usage:
    main()

LOADING MODEL FOR INFERENCE (NO AUDIOTOOLS)
Loading checkpoint from: C:/zahra/EchoMind/data/result/model_best.pt
✓ Sample rate: 44100 Hz
✓ Latent channels: 128
✓ UNet channels: [64, 128, 256, 512]
Loading DAC model from: C:/Users/user/Downloads/weights_44khz_16kbps.pth
✓ DAC model loaded
✓ Model weights loaded
✓ Tokenizer loaded
✓ Device: cuda

Processing: C:/zahra/EchoMind/data/result/inference/arabic_XBmfzfHL.wav
Effect: 'add cats sounds'
Resampling from 48000 Hz to 44100 Hz
Converting to mono
Input shape: torch.Size([1, 1, 219618])
Generating...
Adjusting length: 219136 -> 219618
Output shape: torch.Size([1, 219618])
✓ Saved to: C:/zahra/EchoMind/data/result/inference100/arabic_XBmfzfHL_cats.wav



usage: ipykernel_launcher.py [-h] --model MODEL --input
                             INPUT --output OUTPUT --prompt
                             PROMPT [--dac-model DAC_MODEL]
                             [--device DEVICE]
ipykernel_launcher.py: error: the following arguments are required: --model, --input, --output, --prompt


SystemExit: 2

In [19]:
if __name__ == '__main__':
    # Example usage (can be run directly or via command line)
    #C:/Users/user/Downloads/weights_44khz_16kbps.pth
    # For direct usage in script:
    
    model_path = "C:/zahra/EchoMind/data/result/model_best.pt"
    input_audio = "C:/zahra/EchoMind/data/result/inference/bouchra.ogg"
    output_audio = "C:/zahra/EchoMind/data/result/inference100/bouchra_lightning.wav"
    effect_prompt = "add lightning sounds"
    
    inference = AudioEffectInference(model_path, dac_model_path="C:/Users/user/Downloads/weights_44khz_16kbps.pth", device='cuda')
    inference.process_audio(input_audio, output_audio, effect_prompt)
    
    # For command line usage:
    main()

LOADING MODEL FOR INFERENCE (NO AUDIOTOOLS)
Loading checkpoint from: C:/zahra/EchoMind/data/result/model_best.pt
✓ Sample rate: 44100 Hz
✓ Latent channels: 128
✓ UNet channels: [64, 128, 256, 512]
Loading DAC model from: C:/Users/user/Downloads/weights_44khz_16kbps.pth
✓ DAC model loaded
✓ Model weights loaded
✓ Tokenizer loaded
✓ Device: cuda

Processing: C:/zahra/EchoMind/data/result/inference/bouchra.ogg
Effect: 'add lightning sounds'
Resampling from 16000 Hz to 44100 Hz
Input shape: torch.Size([1, 1, 209820])
Generating...
Adjusting length: 209408 -> 209820
Output shape: torch.Size([1, 209820])
✓ Saved to: C:/zahra/EchoMind/data/result/inference100/bouchra_lightning.wav



usage: ipykernel_launcher.py [-h] --model MODEL --input
                             INPUT --output OUTPUT --prompt
                             PROMPT [--dac-model DAC_MODEL]
                             [--device DEVICE]
ipykernel_launcher.py: error: the following arguments are required: --model, --input, --output, --prompt


SystemExit: 2