In [None]:
# Cell 1: Install dependencies
!apt-get update -qq
!apt-get install -y espeak-ng
!pip install -q TTS librosa soundfile viphoneme

import torch
import warnings
warnings.filterwarnings('ignore')

print(f"‚úÖ PyTorch: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úÖ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Cell 2 (FIXED): VietSpeech Streaming Dataset - avoids OOM (Tránh OOM)
import torch
from torch.utils.data import IterableDataset, DataLoader
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download
import soundfile as sf
from io import BytesIO
import os
import numpy as np

class VietSpeechStreamingDataset(IterableDataset):
    """
    Streaming dataset over the ~100GB NhutP/VietSpeech corpus.

    - Reads each parquet shard in small record batches instead of loading it whole (avoids OOM).
    - Pipeline per shard: download -> process batch by batch -> yield samples -> delete the
      cached shard (both the snapshot link and the blob it points to).
    - Under a multi-worker DataLoader, shards are split across workers so each shard is
      downloaded, read and deleted by exactly one worker.

    Each yielded sample is a dict:
        'audio'         -- 1-D, peak-normalized waveform at ``target_sr``
        'text'          -- stripped transcript (5-200 characters)
        'sampling_rate' -- ``target_sr``
        'duration'      -- clip length in seconds (between 1.0 and ``max_duration``)
    """

    # Number of train shards in the repo (data/train-XXXXX-of-00027.parquet).
    TOTAL_FILES = 27
    # Print a progress line roughly every this many rows.
    PROGRESS_EVERY_ROWS = 5000
    # Candidate transcript columns, checked in this order.
    TEXT_COLUMNS = ('transcription', 'text')

    def __init__(self, num_files=27, max_duration=10.0, target_sr=22050,
                 parquet_batch_size=512):
        """
        Args:
            num_files: number of shards to stream (values above TOTAL_FILES are capped).
            max_duration: clips longer than this many seconds are dropped
                (clips shorter than 1 s are always dropped).
            target_sr: output sampling rate; audio is resampled when it differs.
            parquet_batch_size: rows materialized per parquet read (bounds peak RAM).
        """
        self.num_files = num_files
        self.max_duration = max_duration
        self.target_sr = target_sr
        self.parquet_batch_size = parquet_batch_size  # rows read from parquet per batch
        self.repo_id = "NhutP/VietSpeech"

    def __iter__(self):
        import gc
        import traceback

        import librosa
        from torch.utils.data import get_worker_info

        num_files = min(self.num_files, self.TOTAL_FILES)
        worker = get_worker_info()
        if worker is None:
            file_indices = range(num_files)
        else:
            # Every DataLoader worker iterates its own copy of this dataset. Without
            # sharding, each worker would stream every file (duplicated samples) and
            # workers would race to delete the same cached file.
            file_indices = range(worker.id, num_files, worker.num_workers)

        for file_idx in file_indices:
            filename = f"data/train-{file_idx:05d}-of-{self.TOTAL_FILES:05d}.parquet"
            parquet_file = None

            try:
                print(f"\n{'='*60}")
                print(f"📥 File {file_idx+1}/{num_files}: {filename}")
                print(f"{'='*60}")

                # Download one shard (~4GB).
                parquet_file = hf_hub_download(
                    repo_id=self.repo_id,
                    filename=filename,
                    repo_type="dataset"
                )
                print(f"  ✅ Downloaded: {os.path.basename(parquet_file)}")

                yield from self._iter_parquet(parquet_file, librosa)

            except Exception as e:
                print(f"  ❌ Error: {e}")
                traceback.print_exc()

            finally:
                # Delete the cached shard even if processing failed or the consumer
                # stopped early (closing the generator also runs this block).
                if parquet_file:
                    self._delete_cached_file(parquet_file)
                gc.collect()

    def _iter_parquet(self, parquet_file, librosa):
        """Yield valid samples from one downloaded shard, ``parquet_batch_size`` rows at a time."""
        # pq.read_table() would load the whole shard (~4.8GB) into RAM;
        # iter_batches() only materializes one small record batch at a time.
        parquet_file_obj = pq.ParquetFile(parquet_file)
        try:
            total_rows = parquet_file_obj.metadata.num_rows
            print(f"  📊 Total rows: {total_rows:,}")
            print(f"  📦 Reading in batches of {self.parquet_batch_size}")

            valid_count = 0
            error_count = 0
            processed_rows = 0
            next_report = self.PROGRESS_EVERY_ROWS

            for batch in parquet_file_obj.iter_batches(batch_size=self.parquet_batch_size):
                batch_dict = batch.to_pydict()
                batch_size = len(batch_dict['audio'])

                for i in range(batch_size):
                    try:
                        sample = self._build_sample(batch_dict, i, librosa)
                    except Exception:
                        # Undecodable / malformed row: skip it, but count it so that
                        # systematic failures show up in the per-file summary.
                        error_count += 1
                        continue
                    if sample is not None:
                        valid_count += 1
                        yield sample

                processed_rows += batch_size

                # Threshold-based reporting: `processed_rows % 5000 == 0` almost never
                # holds when rows arrive in multiples of the batch size (e.g. 512).
                if processed_rows >= next_report:
                    print(f"  ⏳ Processed {processed_rows:,}/{total_rows:,} rows "
                          f"({processed_rows/total_rows*100:.1f}%) | "
                          f"Valid: {valid_count:,}")
                    next_report = ((processed_rows // self.PROGRESS_EVERY_ROWS) + 1) * self.PROGRESS_EVERY_ROWS

            print(f"  ✅ File complete: {valid_count:,} valid samples "
                  f"({error_count:,} unreadable rows skipped)")
        finally:
            # Release the file handle before the shard is deleted
            # (ParquetFile.close() may not exist in older pyarrow releases).
            close = getattr(parquet_file_obj, 'close', None)
            if close is not None:
                close()

    def _build_sample(self, batch_dict, i, librosa):
        """Decode, filter and normalize row ``i`` of a batch; return a sample dict or None if filtered out."""
        audio_bytes = batch_dict['audio'][i]['bytes']
        array, sr = sf.read(BytesIO(audio_bytes))

        # soundfile returns (frames, channels) for multi-channel files. librosa.resample
        # works on the last axis, so downmix to mono first to resample along time.
        if array.ndim > 1:
            array = array.mean(axis=1)

        # Keep clips between 1 s and max_duration.
        duration = len(array) / sr
        if duration < 1.0 or duration > self.max_duration:
            return None

        # Filter on text before resampling so rejected rows skip the expensive resample.
        text = self._extract_text(batch_dict, i)
        if len(text) < 5 or len(text) > 200:
            return None

        if sr != self.target_sr:
            array = librosa.resample(array, orig_sr=sr, target_sr=self.target_sr)

        # Peak-normalize to [-1, 1]; all-zero clips are left unchanged.
        max_val = np.max(np.abs(array))
        if max_val > 0:
            array = array / max_val

        return {
            'audio': array,
            'text': text,
            'sampling_rate': self.target_sr,
            'duration': duration,
        }

    @classmethod
    def _extract_text(cls, batch_dict, i):
        """Return the stripped transcript of row ``i``, or "" if no transcript column provides one."""
        raw = None
        for column_name in cls.TEXT_COLUMNS:
            column = batch_dict.get(column_name)
            # Only index columns that exist: indexing a `[None]` placeholder with
            # i > 0 raises IndexError and would silently drop the row.
            if column is not None and column[i]:
                raw = column[i]
                break
        if isinstance(raw, list):
            raw = raw[0] if raw else ""
        return raw.strip() if isinstance(raw, str) else ""

    @staticmethod
    def _delete_cached_file(path):
        """Delete a file returned by hf_hub_download, including the blob behind a cache symlink."""
        # In the huggingface_hub cache the returned path is normally a symlink under
        # snapshots/ pointing at the real data in blobs/. Removing only the link would
        # leave the multi-GB blob on disk, so remove the resolved target as well.
        targets = [path]
        if os.path.islink(path):
            targets.append(os.path.realpath(path))

        deleted = False
        for target in targets:
            try:
                if os.path.lexists(target):
                    os.remove(target)
                    deleted = True
            except OSError as e:
                print(f"  ⚠️ Cannot delete: {e}")
        if deleted:
            print("  🗑️ Deleted cache")

# Summary of the dataset's memory-related design (the RAM figure is an estimate, not measured here).
print("✅ VietSpeechStreamingDataset created!")
print("\n📊 Memory-efficient features:")
print("  ✓ Reads parquet in batches (512 rows/time)")
print("  ✓ Max RAM usage: ~50-100 MB per batch")
print("  ✓ Deletes cache after each file")
print("  ✓ Suitable for Kaggle (30GB RAM limit)")

In [None]:
# Cell 3: Load pretrained VITS from Coqui
from TTS.api import TTS
import torch

# Choose the device first, then move the whole TTS wrapper onto it. Passing `gpu=`
# to TTS() is deprecated in recent Coqui TTS releases in favour of `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("📥 Loading pretrained VITS...\n")

# NOTE(review): this is the English LJSpeech VITS checkpoint; presumably it is only the
# starting point for fine-tuning on Vietnamese — confirm.
tts = TTS(
    model_name="tts_models/en/ljspeech/vits",
    progress_bar=True,
).to(device)

# The underlying VITS nn.Module wrapped by the synthesizer.
vits_model = tts.synthesizer.tts_model

print(f"✅ VITS loaded: {type(vits_model).__name__}")
print(f"✅ Device: {device}")