In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import numpy as np
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import glob
import math
import urllib.request
import tarfile
import zipfile
from tqdm import tqdm
import subprocess
import sys
import shutil # Added for file copying

# --- Check and Install required packages ---
# pesq and pystoi supply the PESQ / STOI metrics called by the monitoring
# code. If importing them fails, both are installed with pip through the
# interpreter that is running this script and the imports are retried.
try:
    from pesq import pesq
    from pystoi import stoi
except ModuleNotFoundError:
    print("Installing required packages: pesq and pystoi...")
    _pip_install_cmd = [sys.executable, "-m", "pip", "install", "pesq", "pystoi"]
    subprocess.check_call(_pip_install_cmd)
    from pesq import pesq
    from pystoi import stoi
    print("Installation complete. Continuing script execution.")

# --- Configuration for Causal VQ-Codec ---
SR = 16000  # Sample rate in Hz
CHANNELS = 1
LATENT_DIM = 64
BLOCKS = 4
HEADS = 4
KERNEL_SIZE = 3
STRIDES = [2, 2, 4, 2]
# Total encoder downsampling (product of the strides); cast to a plain int so
# it formats and serializes like any other Python integer.
DOWN_FACTOR = int(np.prod(STRIDES))
HOP_SIZE_MS = 20
CHUNK_DURATION = 0.04  # Seconds of audio per training chunk
WINDOW_SAMPLES = int(CHUNK_DURATION * SR)  # 640 samples
HOP_SAMPLES = int(HOP_SIZE_MS * SR / 1000)
NUM_CODEBOOKS = 2
CODEBOOK_SIZE = 512
# kbps = latent frames per second * bits per code * number of codebooks / 1000
BITRATE_TARGET = (SR / DOWN_FACTOR * math.log2(CODEBOOK_SIZE) * NUM_CODEBOOKS) / 1000

# Training Hyperparameters
SPECTRAL_LOSS_WEIGHT = 1.0
VQ_LOSS_WEIGHT = 0.1
COMMITMENT_COST = 1.0
TRANSFORMER_BLOCKS = 3
GRADIENT_ACCUMULATION_STEPS = 2
MONITOR_FREQUENCY = 10

# Paths
BASE_DIR = './TinyCodec_Training'
CHECKPOINT_PATH = f'{BASE_DIR}/checkpoint_full.pth'
BEST_MODEL_PATH = f'{BASE_DIR}/best_model.pth'
DATA_DIR = f'{BASE_DIR}/dataset'

# --- GOOGLE DRIVE PATH (New Addition) ---
# NOTE: This path assumes Google Drive is mounted at /content/drive/MyDrive/
GDRIVE_SAVE_DIR = '/content/drive/MyDrive/TinyCodec_Checkpoints'

LEARNING_RATE = 3e-4
EPOCHS = 50
BATCH_SIZE = 32
GRADIENT_CLIPPING_NORM = 1.0

# Create directories
os.makedirs(BASE_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)
# Drive copies are best-effort everywhere else in this script, so an
# uncreatable Drive folder (e.g. running outside Colab) must not abort here.
try:
    os.makedirs(GDRIVE_SAVE_DIR, exist_ok=True)  # Create GDrive save directory
except OSError as e:
    print(f"WARNING: Could not create GDrive save directory {GDRIVE_SAVE_DIR}: {e}")

print(f"Target Bitrate (Achieved): {BITRATE_TARGET:.2f} kbps")
print(f"Target Hop Latency: {HOP_SIZE_MS} ms")
print(f"Latent Frame Rate: {SR / DOWN_FACTOR} Hz")
print(f"Window Samples: {WINDOW_SAMPLES}")

# --- Dataset Download Function ---
def download_dataset():
    """Download and extract the LJSpeech dataset if it is not already present.

    The archive is fetched into DATA_DIR, extracted there, and then deleted.

    Returns:
        Path to the extracted ``LJSpeech-1.1/wavs`` directory.
    """
    dataset_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
    dataset_path = os.path.join(DATA_DIR, "LJSpeech-1.1.tar.bz2")
    extract_path = os.path.join(DATA_DIR, "LJSpeech-1.1")
    wavs_path = os.path.join(extract_path, "wavs")

    if os.path.exists(wavs_path):
        print("Dataset already exists. Skipping download.")
        return wavs_path

    def download_with_progress(url, filepath):
        def download_hook(block_num, block_size, total_size):
            downloaded = block_num * block_size
            mb_downloaded = downloaded / 1024 / 1024
            if total_size > 0:
                percent = min(downloaded * 100 / total_size, 100)
                mb_total = total_size / 1024 / 1024
                print(f"Downloading: {percent:.1f}% ({mb_downloaded:.1f}/{mb_total:.1f} MB)", end='\r')
            else:
                # The server did not report a size; show progress without a
                # percentage instead of dividing by a non-positive total.
                print(f"Downloading: {mb_downloaded:.1f} MB", end='\r')

        print(f"Downloading LJSpeech dataset from {url}")
        # Download to a temporary name and rename on success, so that an
        # interrupted transfer never leaves a truncated file at `filepath`
        # that a later run would mistake for a complete archive.
        partial_path = filepath + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path, reporthook=download_hook)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, filepath)
        print("\nDownload complete!")

    if not os.path.exists(dataset_path):
        download_with_progress(dataset_url, dataset_path)

    print("Extracting dataset...")
    with tarfile.open(dataset_path, 'r:bz2') as tar:
        # Use the 'data' extraction filter where the running Python provides
        # it, so archive members cannot write outside DATA_DIR.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(DATA_DIR, filter='data')
        else:
            tar.extractall(DATA_DIR)
    print("Extraction complete!")

    os.remove(dataset_path)

    return wavs_path

def download_mini_dataset():
    """Synthesize 100 short harmonic test tones and save them as WAV files.

    Each file is a 2-second signal made of a randomly chosen fundamental plus
    its 2nd and 3rd harmonics, shaped by a decaying, modulated envelope, with
    a little noise added and peak-normalized.

    Returns:
        Path to the ``mini_wavs`` directory inside DATA_DIR.
    """
    print("Creating mini speech dataset for testing...")
    mini_dir = os.path.join(DATA_DIR, "mini_wavs")
    os.makedirs(mini_dir, exist_ok=True)

    duration = 2.0
    for file_idx in range(100):
        t = torch.linspace(0, duration, int(SR * duration))

        # Fundamental in [150, 250) Hz.
        frequency = 200 + np.random.randint(-50, 50)
        omega = 2 * np.pi * frequency

        # Fundamental plus two weaker harmonics.
        waveform = torch.sin(omega * t)
        for harmonic, gain in ((2, 0.5), (3, 0.3)):
            waveform = waveform + gain * torch.sin(omega * harmonic * t)

        # Exponential decay with a 3 Hz amplitude modulation.
        envelope = torch.exp(-t * 0.5) * (1 + 0.5 * torch.sin(2 * np.pi * 3 * t))
        waveform = waveform * envelope

        waveform = waveform + 0.01 * torch.randn_like(waveform)
        waveform = waveform / torch.max(torch.abs(waveform))

        filepath = os.path.join(mini_dir, f"audio_{file_idx:04d}.wav")
        torchaudio.save(filepath, waveform.unsqueeze(0), SR)

    print(f"Created 100 sample audio files in {mini_dir}")
    return mini_dir

# --- Model Components ---
class CausalConvBlock(nn.Module):
    """Left-padded 1-D convolution followed by GroupNorm and ReLU.

    The input is zero-padded with ``kernel_size - 1`` samples on the left only,
    so the convolution itself never reads samples to the right of the current
    position.

    NOTE: ``GroupNorm(1, C)`` computes its statistics over all channels and
    all time steps of a sample, so the block as a whole is not strictly
    causal even though the convolution is.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.stride = stride
        # Amount of left-only padding applied before the unpadded convolution.
        self.padding_amount = kernel_size - 1
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=0)
        self.norm = nn.GroupNorm(num_groups=1, num_channels=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply pad -> conv -> norm -> ReLU to ``x`` (channels on dim 1, time last)."""
        left_padded = F.pad(x, (self.padding_amount, 0), mode='constant', value=0)
        features = self.conv(left_padded)
        normalized = self.norm(features)
        return self.relu(normalized)

class CausalTransformerBlock(nn.Module):
    """Post-norm transformer layer with causally masked self-attention.

    Takes and returns tensors laid out as (batch, channels, time); attention
    and the feed-forward network run on the transposed (batch, time, channels)
    view.
    """

    def __init__(self, dim, heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Position-wise feed-forward network with a 4x hidden expansion.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """Run masked self-attention and the FFN, each with a residual + LayerNorm."""
        _, _, seq_len = x.shape
        tokens = x.permute(0, 2, 1)

        # True above the diagonal: position t may not attend to positions > t.
        future_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=x.device
        ).triu(diagonal=1)

        attended, _ = self.attn(tokens, tokens, tokens, attn_mask=future_mask, is_causal=False)
        tokens = self.norm1(tokens + attended)

        tokens = self.norm2(tokens + self.ffn(tokens))

        return tokens.permute(0, 2, 1)

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through gradient.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Dimensionality of each codebook entry.
        commitment_cost: Weight applied to the commitment term of the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=COMMITMENT_COST):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        init_bound = 1.0 / self.num_embeddings
        self.embedding.weight.data.uniform_(-init_bound, init_bound)

    def forward(self, inputs):
        """Quantize ``inputs`` of shape (batch, embedding_dim, time).

        Returns:
            Tuple ``(quantized, loss, encoding_indices)``: the quantized tensor
            in the input layout (gradients pass straight through to
            ``inputs``), the scalar VQ loss, and the selected code indices with
            shape (batch * time, 1).
        """
        codebook = self.embedding.weight

        # One row per (batch, time) position.
        flat_input = inputs.transpose(1, 2).contiguous().view(-1, self.embedding_dim)

        # Squared Euclidean distance expanded as |x|^2 + |e|^2 - 2 x.e
        input_sq = torch.sum(flat_input**2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook**2, dim=1)
        cross_term = torch.matmul(flat_input, codebook.t())
        distances = input_sq + code_sq - 2 * cross_term

        encoding_indices = torch.argmin(distances, dim=1, keepdim=True)

        # One-hot selection matrix; multiplying by the codebook picks the rows.
        one_hot = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
        one_hot.scatter_(1, encoding_indices, 1)

        selected = torch.matmul(one_hot, codebook)
        quantized = selected.view(inputs.shape[0], inputs.shape[2], -1).transpose(1, 2)

        # Commitment term moves the encoder output toward the codes; the
        # codebook term moves the codes toward the encoder output.
        commitment_term = F.mse_loss(quantized.detach(), inputs)
        codebook_term = F.mse_loss(quantized, inputs.detach())
        loss = codebook_term + self.commitment_cost * commitment_term

        # Straight-through estimator: forward value is the code, gradient is
        # that of the unquantized input.
        quantized = inputs + (quantized - inputs).detach()

        return quantized, loss, encoding_indices

class TinyTransformerCodec(nn.Module):
    """Convolutional encoder / VQ / transformer / transposed-conv decoder codec.

    ``encode`` downsamples the waveform with strided CausalConvBlocks, splits
    the latent into NUM_CODEBOOKS groups that are quantized independently, and
    runs the concatenated quantized latent through the transformer stack.
    ``decode`` upsamples with transposed convolutions and, when encoder
    activations are supplied, fuses them in through 1x1 skip convolutions.

    NOTE: the skip path feeds encoder activations directly to the decoder, so
    reconstruction with skips needs more than the quantized codes.

    Args:
        latent_dim: Channels per codebook group in the latent.
        blocks: Number of encoder (and decoder) stages; must not exceed
            ``len(STRIDES)``.
        heads: Attention heads in each transformer block.
        sr: Sample rate, stored on the module for reference.
    """

    def __init__(self, latent_dim=LATENT_DIM, blocks=BLOCKS, heads=HEADS, sr=SR):
        super().__init__()
        self.latent_dim = latent_dim
        self.sr = sr
        self.downsampling_factor = DOWN_FACTOR
        self.num_codebooks = NUM_CODEBOOKS

        # Encoder
        self.encoder_convs = nn.ModuleList()
        in_c = CHANNELS
        encoder_channels = []

        # Define encoder channel progression
        for i in range(blocks):
            out_c = min(latent_dim, 8 * (2**i))  # 8, 16, 32, 64 with the defaults
            encoder_channels.append(out_c)
            self.encoder_convs.append(
                CausalConvBlock(in_c, out_c, KERNEL_SIZE, STRIDES[i])
            )
            in_c = out_c

        # Use the constructor's latent_dim (not the global default) so the
        # pre-quant width always matches the quantizers and the transformer.
        self.pre_quant = CausalConvBlock(in_c, latent_dim * NUM_CODEBOOKS, KERNEL_SIZE, 1)

        # Vector Quantization
        self.quantizers = nn.ModuleList([
            VectorQuantizer(CODEBOOK_SIZE, latent_dim, commitment_cost=COMMITMENT_COST)
            for _ in range(NUM_CODEBOOKS)
        ])

        # Transformer
        self.transformer = nn.Sequential(*[
            CausalTransformerBlock(latent_dim * NUM_CODEBOOKS, heads)
            for _ in range(TRANSFORMER_BLOCKS)
        ])
        self.post_transformer = nn.Conv1d(latent_dim * NUM_CODEBOOKS, latent_dim * NUM_CODEBOOKS, 1)

        # Decoder - process in reverse order
        self.decoder_tconvs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        in_c = latent_dim * NUM_CODEBOOKS

        # Build decoder layers in reverse order of encoder
        for i in range(blocks):
            idx = blocks - 1 - i  # Reverse index
            stride = STRIDES[idx]

            # Determine output channels
            if idx > 0:
                out_c = encoder_channels[idx - 1]
            else:
                out_c = 16  # Base channel count before final layer

            # Transposed convolution for upsampling. Without output_padding an
            # odd kernel with padding=kernel//2 yields (L - 1) * stride + 1
            # frames, so the decoder came up short (609 of 640 samples for the
            # default config) and the tail was zero-filled. output_padding of
            # stride - 1 makes each stage produce exactly L * stride frames.
            self.decoder_tconvs.append(
                nn.ConvTranspose1d(
                    in_c, out_c, KERNEL_SIZE, stride,
                    padding=KERNEL_SIZE // 2,
                    output_padding=stride - 1,
                )
            )

            # Skip connections (except for last decoder layer)
            if idx > 0:
                skip_in_channels = encoder_channels[idx - 1]
                # Skip conv: concatenated channels -> output channels
                self.skip_convs.append(
                    nn.Conv1d(out_c + skip_in_channels, out_c, kernel_size=1)
                )

            in_c = out_c

        # Final output layer
        self.post_decoder_final = nn.Conv1d(in_c, CHANNELS, 1)

    def encode(self, x):
        """Encode a batch of waveforms.

        Args:
            x: Audio tensor whose first dimension is the batch; it is reshaped
                to (batch, CHANNELS, samples).

        Returns:
            Tuple ``(codes, vq_loss_total, input_length, indices_list,
            encoder_outputs)``: the transformer-processed quantized latent
            (continuous values, not indices), the summed VQ loss over all
            codebooks, the number of input samples, the per-codebook index
            tensors, and the list of per-stage encoder activations.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), CHANNELS, -1)
        input_length = x.shape[-1]

        encoder_outputs = []

        # Encoder
        for layer in self.encoder_convs:
            x = layer(x)
            encoder_outputs.append(x)

        # Pre-quantization
        z_e = self.pre_quant(x)

        # Vector Quantization: one channel group per codebook
        z_q_list = []
        vq_loss_total = 0.0
        indices_list = []

        z_e_split = z_e.chunk(self.num_codebooks, dim=1)

        for quantizer, z_e_group in zip(self.quantizers, z_e_split):
            z_q, vq_loss, indices = quantizer(z_e_group)
            z_q_list.append(z_q)
            vq_loss_total += vq_loss
            indices_list.append(indices)

        z_q_concat = torch.cat(z_q_list, dim=1)

        # Transformer
        codes = self.transformer(z_q_concat)
        codes = self.post_transformer(codes)

        return codes, vq_loss_total, input_length, indices_list, encoder_outputs

    def decode(self, codes, input_length=None, encoder_outputs=None):
        """Decode latent ``codes`` back to a waveform in [-1, 1].

        Args:
            codes: Latent tensor as returned by ``encode``.
            input_length: If given, the output is trimmed or zero-padded on
                the right to exactly this many samples.
            encoder_outputs: Optional per-stage encoder activations from
                ``encode``; when provided they are fused through the skip
                convolutions.

        Returns:
            Tensor of shape (batch, CHANNELS, samples).
        """
        x = codes

        # Decoder with skip connections
        for i, tconv in enumerate(self.decoder_tconvs):
            x = F.relu(tconv(x))

            # Apply skip connection if available
            if encoder_outputs and i < len(self.skip_convs):
                # Decoder layer i produces the resolution and channel count of
                # encoder layer (blocks - 2 - i).
                encoder_idx = len(self.encoder_convs) - 2 - i

                if 0 <= encoder_idx < len(encoder_outputs):
                    skip_features = encoder_outputs[encoder_idx]

                    # Match temporal dimensions (they can differ when the
                    # input length is not a multiple of the strides).
                    min_len = min(skip_features.shape[-1], x.shape[-1])
                    skip_features = skip_features[..., :min_len]
                    x_trim = x[..., :min_len]

                    # Concatenate and apply skip conv
                    x_cat = torch.cat([x_trim, skip_features], dim=1)
                    x_processed = self.skip_convs[i](x_cat)

                    # Keep any decoder frames beyond the skip's length.
                    if x.shape[-1] > min_len:
                        x = torch.cat([x_processed, x[..., min_len:]], dim=-1)
                    else:
                        x = x_processed

        # Final output
        x = torch.tanh(self.post_decoder_final(x))

        # Match input length
        if input_length is not None:
            if x.shape[-1] > input_length:
                x = x[..., :input_length]
            elif x.shape[-1] < input_length:
                x = F.pad(x, (0, input_length - x.shape[-1]))

        return x.reshape(x.size(0), CHANNELS, -1)

# --- Loss Functions ---
class SpectralLoss(nn.Module):
    """Multi-resolution magnitude-spectrogram L1 loss.

    Three spectrogram transforms (FFT sizes 128/256/512 with hops 32/64/128)
    are applied to both signals; the L1 distances between the resulting
    magnitude spectrograms are averaged.
    """

    def __init__(self):
        super().__init__()
        # (n_fft, hop_length) pairs. Every FFT size is below the 640-sample
        # training window, and center=False means the signal is not padded.
        resolutions = ((128, 32), (256, 64), (512, 128))

        self.stft_losses = nn.ModuleList([
            torchaudio.transforms.Spectrogram(
                n_fft=n_fft,
                hop_length=hop_length,
                power=1,
                normalized=False,
                center=False,
            )
            for n_fft, hop_length in resolutions
        ])

    def forward(self, x, y):
        """Return the mean over resolutions of the spectrogram L1 distance."""
        per_resolution = [
            F.l1_loss(spectrogram(x), spectrogram(y))
            for spectrogram in self.stft_losses
        ]
        # Average across different resolutions
        return sum(per_resolution) / len(self.stft_losses)

# --- Dataset Class ---
class AudioFileDataset(Dataset):
    """In-memory dataset of fixed-length, non-overlapping mono audio chunks.

    All ``*.wav`` files in ``data_dir`` are loaded eagerly, resampled to ``sr``
    when needed, reduced to their first channel, and cut into consecutive
    chunks of ``duration`` seconds. Trailing audio shorter than one chunk is
    dropped. Files that fail to load are reported and skipped.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.wav`` files.
        sr: Target sample rate in Hz.
        duration: Chunk length in seconds.
        max_files: If truthy, only the first ``max_files`` files are used.

    Raises:
        ValueError: If no ``.wav`` files are found in ``data_dir``.
    """

    def __init__(self, data_dir, sr=SR, duration=CHUNK_DURATION, max_files=None):
        self.sr = sr
        self.chunk_len = int(duration * sr)

        # glob returns files in arbitrary order; sorting makes the max_files
        # subset and any seeded split of the chunks reproducible across runs.
        self.file_list = sorted(glob.glob(os.path.join(data_dir, '*.wav')))
        if max_files:
            self.file_list = self.file_list[:max_files]

        if not self.file_list:
            raise ValueError(f"No .wav files found in {data_dir}")

        print(f"Found {len(self.file_list)} audio files")

        self.audio_chunks = []
        # One resampler per distinct source rate instead of one per file.
        resamplers = {}

        for fpath in tqdm(self.file_list, desc="Loading audio files"):
            try:
                wav, sample_rate = torchaudio.load(fpath)

                if sample_rate != self.sr:
                    if sample_rate not in resamplers:
                        resamplers[sample_rate] = torchaudio.transforms.Resample(
                            orig_freq=sample_rate, new_freq=self.sr
                        )
                    wav = resamplers[sample_rate](wav)

                # Keep only the first channel.
                wav = wav[0:1, :]

                # The range stops early enough that every slice is full length.
                last_start = wav.shape[-1] - self.chunk_len
                for start in range(0, last_start + 1, self.chunk_len):
                    self.audio_chunks.append(wav[:, start:start + self.chunk_len])

            except Exception as e:
                # Best effort: one unreadable file must not abort loading.
                print(f"Skipping file {fpath}: {e}")

        print(f"Total audio chunks: {len(self.audio_chunks)}")

    def __len__(self):
        return len(self.audio_chunks)

    def __getitem__(self, idx):
        """Return chunk ``idx`` with samples clamped to [-1, 1]."""
        return self.audio_chunks[idx].clamp(-1.0, 1.0)

# --- Monitoring Functions ---
def run_quality_metrics(model, val_loader, device, sr=SR, max_batches=5):
    """Compute average wide-band PESQ and STOI on a few validation samples.

    Up to ``max_batches`` batches are reconstructed by the model; the first
    (at most) five samples are then scored. A metric that fails on a sample is
    left out of its average, and the failures are reported.

    Args:
        model: Codec exposing ``encode`` and ``decode``.
        val_loader: Iterable of audio batches.
        device: torch.device the model lives on.
        sr: Sample rate passed to the metrics.
        max_batches: Maximum number of batches to reconstruct.

    Returns:
        Tuple ``(avg_pesq, avg_stoi)`` of Python floats; a metric with no
        successful samples is reported as 0.0.
    """
    was_training = model.training
    model.eval()

    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'
    all_original = []
    all_reconstructed = []

    try:
        with torch.no_grad():
            num_batches = min(max_batches, len(val_loader))
            for _, batch in zip(range(num_batches), val_loader):
                audio = batch.to(device)
                with torch.amp.autocast(device_type=autocast_device, dtype=torch.float16):
                    codes, _, input_length, _, encoder_outputs = model.encode(audio)
                    reconstructed_audio = model.decode(codes, input_length, encoder_outputs)

                # Autocast can yield half-precision output; hand float32
                # arrays to the metric libraries.
                all_original.append(batch.float().cpu().numpy())
                all_reconstructed.append(reconstructed_audio.float().cpu().numpy())

        if not all_original:
            return 0.0, 0.0

        original_wavs = np.concatenate(all_original, axis=0)
        reconstructed_wavs = np.concatenate(all_reconstructed, axis=0)

        pesq_scores = []
        stoi_scores = []
        pesq_errors = []
        stoi_errors = []

        for i in range(min(5, original_wavs.shape[0])):
            original = original_wavs[i, 0]
            reconstructed = reconstructed_wavs[i, 0]

            min_len = min(len(original), len(reconstructed))
            original, reconstructed = original[:min_len], reconstructed[:min_len]

            # Metrics are best-effort (they can reject very short or silent
            # clips), but failures are recorded instead of silently dropped.
            try:
                pesq_scores.append(pesq(sr, original, reconstructed, 'wb'))
            except Exception as e:
                pesq_errors.append(e)

            try:
                stoi_scores.append(stoi(original, reconstructed, sr, extended=False))
            except Exception as e:
                stoi_errors.append(e)

        if pesq_errors:
            print(f"  WARNING: PESQ failed on {len(pesq_errors)} sample(s); last error: {pesq_errors[-1]!r}")
        if stoi_errors:
            print(f"  WARNING: STOI failed on {len(stoi_errors)} sample(s); last error: {stoi_errors[-1]!r}")

        # Plain floats so callers can store them without NumPy scalar types.
        avg_pesq = float(np.mean(pesq_scores)) if pesq_scores else 0.0
        avg_stoi = float(np.mean(stoi_scores)) if stoi_scores else 0.0
        return avg_pesq, avg_stoi
    finally:
        # Restore whichever mode the caller had, even if evaluation raised.
        model.train(was_training)

def monitor_codebook_usage(model, val_loader, device, max_batches=5):
    """Print, per codebook, how many distinct codes were selected.

    Encodes up to ``max_batches`` validation batches, marks every code index
    that appears, and reports the used fraction of each codebook. The model is
    put in eval mode for the pass and switched to train mode afterwards.
    """
    model.eval()

    # One counter vector per codebook; an entry > 0 means the code was seen.
    usage_counts = []
    for _ in range(NUM_CODEBOOKS):
        usage_counts.append(torch.zeros(CODEBOOK_SIZE, device=device))

    batch_iter = iter(val_loader)

    with torch.no_grad():
        num_batches = min(max_batches, len(val_loader))
        for _, batch in zip(range(num_batches), batch_iter):
            encoded = model.encode(batch.to(device))
            indices_list = encoded[3]

            for counts, indices in zip(usage_counts, indices_list):
                # Count each code at most once per batch.
                counts[torch.unique(indices)] += 1

    print("\n--- Codebook Utilization Check ---")
    for cb_idx, counts in enumerate(usage_counts):
        used_codes = (counts > 0).sum().item()
        utilization = (used_codes / CODEBOOK_SIZE) * 100
        print(f"Codebook {cb_idx}: {utilization:.1f}% ({used_codes}/{CODEBOOK_SIZE} codes)")

    model.train()

def validate(model, val_loader, criterion, spectral_criterion, device):
    """Return the mean combined loss over all batches of ``val_loader``.

    The per-batch loss is the waveform loss plus the weighted spectral and VQ
    losses, evaluated without gradients under autocast. The model is set to
    eval mode for the pass and to train mode before returning. An empty
    loader yields 0.0.
    """
    model.eval()
    amp_device = 'cuda' if device.type == 'cuda' else 'cpu'
    running_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            audio = batch.to(device)

            with torch.amp.autocast(device_type=amp_device, dtype=torch.float16):
                codes, vq_loss, input_length, _, encoder_outputs = model.encode(audio)
                reconstructed_audio = model.decode(codes, input_length, encoder_outputs)

                waveform_term = criterion(reconstructed_audio, audio)
                spectral_term = spectral_criterion(reconstructed_audio, audio)

                batch_loss = waveform_term + SPECTRAL_LOSS_WEIGHT * spectral_term + VQ_LOSS_WEIGHT * vq_loss

            running_loss += batch_loss.item()

    model.train()

    num_batches = len(val_loader)
    if num_batches > 0:
        return running_loss / num_batches
    return 0.0

# --- Main Training Function ---
def train_model(use_mini_dataset=False):
    """Train the codec end to end, resuming from CHECKPOINT_PATH if present.

    Prepares the dataset, builds the model, optimizer, losses and a one-cycle
    LR schedule, then trains for EPOCHS epochs with mixed precision and
    gradient accumulation. After every epoch it validates, periodically
    reports PESQ/STOI and codebook usage, writes a full checkpoint, and keeps
    the model with the lowest validation loss. Checkpoints are also copied to
    GDRIVE_SAVE_DIR on a best-effort basis.

    Args:
        use_mini_dataset: If True, train on the synthetic mini dataset instead
            of downloading LJSpeech (where at most 1000 files are used).
    """
    print("\n" + "="*60)
    print("TINY TRANSFORMER CODEC TRAINING")
    print("="*60)

    # Step 1: Dataset Preparation
    print("\n--- Step 1: Dataset Preparation ---")
    if use_mini_dataset:
        audio_dir = download_mini_dataset()
        max_files = None
    else:
        audio_dir = download_dataset()
        max_files = 1000

    # Step 2: Setup Device
    print("\n--- Step 2: Device Setup ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'
    # Loss scaling only applies on CUDA; disabling it explicitly elsewhere
    # avoids the "CUDA is not available" warning.
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    # Step 3: Initialize Model
    print("\n--- Step 3: Model Initialization ---")
    model = TinyTransformerCodec().to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params:,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.L1Loss()
    spectral_criterion = SpectralLoss().to(device)
    best_val_loss = float('inf')

    # Step 4: Load Dataset
    print("\n--- Step 4: Loading Dataset ---")
    full_dataset = AudioFileDataset(data_dir=audio_dir, max_files=max_files)
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Steps per epoch: {len(train_loader)}")

    # Step 5: Setup Scheduler
    # The scheduler is stepped once per optimizer update, i.e. once per
    # accumulation group (plus the final partial group), not once per batch.
    # Sizing the cycle by batches would leave the schedule only partly run.
    optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRADIENT_ACCUMULATION_STEPS)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=optimizer_steps_per_epoch,
        pct_start=0.1,
        anneal_strategy='cos'
    )

    start_epoch = 0

    # Step 6: Load Checkpoint if exists
    if os.path.exists(CHECKPOINT_PATH):
        try:
            print("\n--- Loading Checkpoint ---")
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            print(f"Resumed from epoch {start_epoch}, Best val loss: {best_val_loss:.6f}")
        except Exception as e:
            print(f"Could not load checkpoint: {e}")
            print("Starting fresh training...")

    # Step 7: Training Loop
    print("\n--- Step 5: Starting Training ---")
    print(f"Training for {EPOCHS} epochs")
    print("="*60)

    model.train()

    for epoch in range(start_epoch, EPOCHS):
        epoch_start_time = time.time()
        total_loss = 0
        reconstruction_loss_accum = 0
        spectral_loss_accum = 0
        vq_loss_accum = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

        for i, batch in enumerate(progress_bar):
            audio = batch.to(device)

            with torch.amp.autocast(device_type=autocast_device, dtype=torch.float16):
                codes, vq_loss, input_length, _, encoder_outputs = model.encode(audio)
                reconstructed_audio = model.decode(codes, input_length, encoder_outputs)

                reconstruction_loss = criterion(reconstructed_audio, audio)
                spectral_loss = spectral_criterion(reconstructed_audio, audio)

                # Divide so gradients summed over the accumulation group
                # correspond to the mean loss.
                loss = (reconstruction_loss + SPECTRAL_LOSS_WEIGHT * spectral_loss + VQ_LOSS_WEIGHT * vq_loss) / GRADIENT_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            total_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
            reconstruction_loss_accum += reconstruction_loss.item()
            spectral_loss_accum += spectral_loss.item()
            vq_loss_accum += vq_loss.item()

            # Update on every full accumulation group and on the last batch.
            if (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (i + 1) == len(train_loader):
                # Unscale first so clipping sees the true gradient norms.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIPPING_NORM)

                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                optimizer.zero_grad()

            progress_bar.set_postfix({
                'Loss': f'{total_loss/(i+1):.4f}',
                'Recon': f'{reconstruction_loss_accum/(i+1):.4f}',
                'Spec': f'{spectral_loss_accum/(i+1):.4f}',
                'VQ': f'{vq_loss_accum/(i+1):.4f}'
            })

        avg_train_loss = total_loss / len(train_loader)

        # Validation
        print("\nRunning validation...")
        avg_val_loss = validate(model, val_loader, criterion, spectral_criterion, device)

        # Quality Metrics
        avg_pesq, avg_stoi = 0.0, 0.0
        if (epoch + 1) % 5 == 0:
            avg_pesq, avg_stoi = run_quality_metrics(model, val_loader, device)

        # Codebook Monitoring
        if (epoch + 1) % MONITOR_FREQUENCY == 0:
            monitor_codebook_usage(model, val_loader, device)

        epoch_time = time.time() - epoch_start_time

        # Print epoch summary
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{EPOCHS} Summary:")
        print(f"  Train Loss: {avg_train_loss:.6f}")
        print(f"  Val Loss: {avg_val_loss:.6f}")
        if (epoch + 1) % 5 == 0:
            print(f"  PESQ: {avg_pesq:.4f}, STOI: {avg_stoi:.4f}")
        print(f"  Time: {epoch_time:.1f}s")
        print(f"  LR: {scheduler.get_last_lr()[0]:.6f}")

        # Decide on "best" before writing the checkpoint so the stored
        # best_val_loss already reflects this epoch; otherwise a resumed run
        # would compare against a stale value.
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss

        # Save checkpoint (Local). Scalars are stored as plain floats so the
        # file contains no NumPy objects and stays loadable when torch.load
        # restricts unpickling to tensors and primitive types.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': float(avg_train_loss),
            'best_val_loss': float(best_val_loss),
            'pesq': float(avg_pesq),
            'stoi': float(avg_stoi)
        }, CHECKPOINT_PATH)

        # Save checkpoint (Google Drive)
        gdrive_checkpoint_path = os.path.join(GDRIVE_SAVE_DIR, os.path.basename(CHECKPOINT_PATH))
        try:
            shutil.copyfile(CHECKPOINT_PATH, gdrive_checkpoint_path)
            print(f"  Full checkpoint saved to GDrive: {gdrive_checkpoint_path}")
        except Exception as e:
            print(f"  WARNING: Could not save full checkpoint to GDrive: {e}")

        # Save best model
        if is_best:
            # Save best model (Local)
            torch.save({
                'model_state_dict': model.state_dict(),
                'val_loss': float(best_val_loss),
            }, BEST_MODEL_PATH)
            print(f"  *** New best model saved locally! ***")

            # Save best model (Google Drive)
            gdrive_best_model_path = os.path.join(GDRIVE_SAVE_DIR, os.path.basename(BEST_MODEL_PATH))
            try:
                shutil.copyfile(BEST_MODEL_PATH, gdrive_best_model_path)
                print(f"  *** New best model saved to GDrive: {gdrive_best_model_path} ***")
            except Exception as e:
                print(f"  WARNING: Could not save best model to GDrive: {e}")

        print(f"{'='*60}\n")

    print("\n" + "="*60)
    print("TRAINING COMPLETE!")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"Model saved locally to: {BEST_MODEL_PATH}")
    print(f"Model saved to GDrive folder: {GDRIVE_SAVE_DIR}")
    print("="*60)

# --- Inference Function ---
def test_inference(model_path=BEST_MODEL_PATH):
    """Load a trained model and report reconstruction stats on a test tone.

    A 2-second, 440 Hz sine of amplitude 0.5 is encoded and decoded; the
    tensor shapes, the time-axis compression ratio, and the reconstruction
    SNR are printed. Returns early with a message if ``model_path`` does not
    exist.

    Args:
        model_path: Path to a file holding a ``'model_state_dict'`` entry.
    """
    print("\n--- Testing Inference ---")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyTransformerCodec().to(device)

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded from {model_path}")
    else:
        print("No trained model found. Train first!")
        return

    model.eval()

    test_duration = 2.0
    t = torch.linspace(0, test_duration, int(SR * test_duration))
    test_audio = torch.sin(2 * np.pi * 440 * t) * 0.5
    test_audio = test_audio.unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        codes, _, input_length, indices_list, encoder_outputs = model.encode(test_audio)
        reconstructed = model.decode(codes, input_length, encoder_outputs)

    print(f"Input shape: {test_audio.shape}")
    print(f"Codes shape: {codes.shape}")
    print(f"Reconstructed shape: {reconstructed.shape}")
    print(f"Compression ratio: {test_audio.shape[-1] / codes.shape[-1]:.1f}x")

    # SNR is signal power over error power. The test tone's power is well
    # below 1.0, so using a fixed 1.0 numerator would overstate the SNR.
    noise_power = F.mse_loss(reconstructed, test_audio).item()
    signal_power = test_audio.pow(2).mean().item()
    snr = 10 * np.log10(signal_power / noise_power) if noise_power > 0 else float('inf')
    print(f"Reconstruction SNR: {snr:.2f} dB")

# --- MAIN EXECUTION ---
if __name__ == '__main__':
    print("""
    ╔══════════════════════════════════════════════════════════╗
    ║     TINY TRANSFORMER VQ-CODEC - AUTOMATIC TRAINING       ║
    ╠══════════════════════════════════════════════════════════╣
    ║  This script will:                                       ║
    ║  1. Automatically download the LJSpeech dataset.         ║
    ║  2. Train the Tiny Transformer Codec model.              ║
    ║  3. Save checkpoints to Google Drive.                    ║
    ║                                                          ║
    ║  To use a small test dataset, change 'use_mini = False'  ║
    ║  to 'use_mini = True' below.                             ║
    ╚══════════════════════════════════════════════════════════╝
    """)

    # Set to True to train on the small synthetic dataset instead of LJSpeech.
    use_mini = False

    train_model(use_mini_dataset=use_mini)
    # A stray bare name used to follow the call above and raised NameError
    # as soon as training finished; it has been removed.

    # Optional: uncomment to run a quick reconstruction check afterwards.
    # test_inference()
