### Prediction with NVIDIA Parakeet TDT 0.6B model

With VAD (voice activity detection) preprocessing: the audio is split into short speech chunks so that transcription does not run out of CUDA memory.

In [None]:
import json
import logging
import os
import tempfile
from google.cloud import storage

# Module-level Google Cloud Storage client, created once and reused by
# download_gcs_file. Credentials are resolved by the client library
# (presumably Application Default Credentials in this environment — TODO confirm).
storage_client = storage.Client()

def download_gcs_file(gcs_uri: str, dest_path=None) -> str:
    """Download a single GCS object to local disk and return the local path.

    Args:
        gcs_uri: Object URI of the form ``gs://<bucket>/<object>``.
        dest_path: Optional local destination path. When omitted, the object
            is written to a new temporary file (keeping the object's
            extension) that the caller is responsible for deleting.

    Returns:
        The local path of the downloaded file.

    Raises:
        ValueError: If ``gcs_uri`` does not start with ``gs://`` or lacks a
            bucket or object name.
        Exception: Any error from the GCS client is logged and re-raised.
    """
    prefix = "gs://"
    try:
        if not gcs_uri.startswith(prefix):
            raise ValueError("URI non valido, deve iniziare con 'gs://'")
        # Strip only the leading scheme; str.replace would also remove any
        # "gs://" that happens to appear inside the object name.
        bucket_name, _, blob_name = gcs_uri[len(prefix):].partition("/")
        if not bucket_name or not blob_name:
            raise ValueError(f"URI non valido, mancano bucket o nome oggetto: {gcs_uri}")

        blob = storage_client.bucket(bucket_name).blob(blob_name)

        if dest_path:
            blob.download_to_filename(dest_path)
            return dest_path

        suffix = os.path.splitext(blob_name)[1]
        # Reserve a unique temp path, then close our handle *before*
        # downloading: download_to_filename opens the path itself, and writing
        # to a file that is still open here fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_path = tmp_file.name
        try:
            blob.download_to_filename(tmp_path)
        except Exception:
            # Do not leave an empty/partial temp file behind on failure.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        print(f"File scaricato da {gcs_uri} a {tmp_path}")
        return tmp_path
    except Exception as e:
        logging.error(f"Errore durante il download da GCS {gcs_uri}: {e}")
        raise

In [7]:
import torch
# Release cached CUDA memory and reset the peak-memory counters before a run.
# Guarded because reset_peak_memory_stats raises when CUDA is unavailable,
# which would break this cell on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
import os
import json
import time
import math
import shutil
import tempfile
from typing import List, Tuple

import torch
import torchaudio
from pyannote.audio import Pipeline
import nemo.collections.asr as nemo_asr

# Pipeline per fare predizione con il modello Nvidia Parakeet 0.6 B con VAD preprocessing e gestione resiliente

def _deterministic_cache_dir(audio_path: str, cache_root: str = None) -> str:
    """Crea una cartella cache deterministica basata sul nome file audio.
    Esempio: /tmp/vad_cache/<basename_senza_ext>"""
    base = os.path.splitext(os.path.basename(audio_path))[0]
    cache_root = cache_root or os.path.join(tempfile.gettempdir(), "vad_cache")
    os.makedirs(cache_root, exist_ok=True)
    cache_dir = os.path.join(cache_root, base)
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _merge_segments(
    segments: List[Tuple[float, float]],
    max_duration: float,
) -> List[Tuple[float, float]]:
    """Group VAD segments into windows no longer than ``max_duration`` seconds.

    Segments are processed in start order. Consecutive segments are merged into
    a single window (from the first start to the last end, silence in between
    included) while the window stays within ``max_duration``. A single segment
    longer than ``max_duration`` is cut into consecutive pieces of at most
    ``max_duration`` (such a cut may fall mid-word). Segments whose end is not
    after their start are dropped.

    Raises:
        ValueError: if ``max_duration`` is not positive.
    """
    if max_duration <= 0:
        raise ValueError(f"max_duration must be positive, got {max_duration}")

    merged: List[Tuple[float, float]] = []
    win_start = win_end = None
    for start, end in sorted(segments):
        if end <= start:
            continue
        # Cut over-long speech regions so no chunk exceeds the ASR budget.
        while end - start > max_duration:
            if win_start is not None:
                merged.append((win_start, win_end))
                win_start = win_end = None
            merged.append((start, start + max_duration))
            start += max_duration
        if win_start is None:
            win_start, win_end = start, end
        elif end - win_start <= max_duration:
            win_end = max(win_end, end)
        else:
            merged.append((win_start, win_end))
            win_start, win_end = start, end
    if win_start is not None:
        merged.append((win_start, win_end))
    return merged


def _run_vad_and_save_manifest(
    audio_path: str,
    cache_dir: str,
    sample_rate: int = 16000,
    max_duration: float = 30.0,
    hf_token: str | None = None,
) -> List[Tuple[float, float]]:
    """Run pyannote VAD, merge segments up to ``max_duration`` and save a manifest.

    The manifest ``{cache_dir}/segments.json`` stores the merged segments under
    ``"segments"`` (what later stages consume) and the unmerged VAD output under
    ``"raw_segments"``, so a later run can skip the VAD.

    Args:
        audio_path: Local audio file to analyse.
        cache_dir: Existing folder where the manifest is written.
        sample_rate: Only recorded in the manifest.
        max_duration: Maximum length, in seconds, of each returned segment.
        hf_token: Hugging Face token passed to ``Pipeline.from_pretrained``.

    Returns:
        The merged segments as ``[(start, end), ...]`` in seconds.

    Raises:
        ValueError: if ``max_duration`` is not positive.
        RuntimeError: if the VAD pipeline could not be loaded.
    """
    # Validate before paying for model loading and inference.
    if max_duration <= 0:
        raise ValueError(f"max_duration must be positive, got {max_duration}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("[INFO] Caricamento modello VAD pyannote.audio...")
    pipeline = Pipeline.from_pretrained(
        "pyannote/voice-activity-detection",
        use_auth_token=hf_token,
    )
    # NOTE(review): some pyannote versions return None instead of raising when
    # the gated model cannot be fetched (e.g. missing/invalid token).
    if pipeline is None:
        raise RuntimeError("Impossibile caricare la pipeline VAD pyannote (token HF?)")
    pipeline.to(device)

    print(f"[INFO] Analisi VAD in corso su {audio_path} ...")
    vad_output = pipeline(audio_path)
    raw_segments = [(speech.start, speech.end) for speech in vad_output.get_timeline()]

    # Drop the VAD model before the ASR model is loaded, to keep GPU memory low.
    del pipeline, vad_output
    if device.type == "cuda":
        torch.cuda.empty_cache()

    segments = _merge_segments(raw_segments, max_duration)

    # Save the manifest so a resumed run does not repeat the VAD.
    manifest_path = os.path.join(cache_dir, "segments.json")
    meta = {
        "audio_path": audio_path,
        "sample_rate": sample_rate,
        "max_duration": max_duration,
        "segments": segments,
        "raw_segments": raw_segments,
        "created_at": time.time(),
    }
    # Write via temp file + rename so an interruption never leaves a truncated
    # segments.json that the next run would try to load.
    tmp_path = f"{manifest_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, manifest_path)

    print(
        f"[SAVE] Manifest VAD: {manifest_path} "
        f"({len(segments)} segmenti, {len(raw_segments)} grezzi)"
    )
    return segments


def prepare_vad_cache(
    audio_path: str,
    cache_root: str | None = None,
    sample_rate: int = 16000,
    max_duration: float = 30.0,
    hf_token: str | None = None,
) -> Tuple[List[Tuple[float, float]], str]:
    """Prepare (or reload) the VAD segments for an audio file.

    - If ``segments.json`` exists in the file's cache folder and is readable,
      its segments are returned and the VAD is *not* re-run.
    - Otherwise (missing, or unreadable e.g. after an interrupted write) the
      VAD is run, the manifest is saved and its segments are returned.

    Args:
        audio_path: Local audio file.
        cache_root: Root of the cache folders (see _deterministic_cache_dir).
        sample_rate: Forwarded to the VAD step (recorded in the manifest).
        max_duration: Forwarded to the VAD step.
        hf_token: Hugging Face token for the VAD model; when empty/None the
            ``HF_TOKEN`` environment variable is used instead.

    Returns:
        ``(segments, cache_dir)``.
    """
    cache_dir = _deterministic_cache_dir(audio_path, cache_root)
    manifest_path = os.path.join(cache_dir, "segments.json")

    if os.path.exists(manifest_path):
        try:
            with open(manifest_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            segments = [tuple(seg) for seg in meta["segments"]]
        except (json.JSONDecodeError, KeyError, TypeError) as exc:
            # A corrupt manifest would otherwise break every future run.
            print(f"[WARN] Manifest VAD illeggibile ({exc}), ripeto la VAD: {manifest_path}")
        else:
            print(f"[INFO] Manifest VAD trovato: {manifest_path} ({len(segments)} segmenti)")
            return segments, cache_dir

    # Token: explicit argument first, otherwise from the environment.
    token = hf_token or os.environ.get("HF_TOKEN")
    segments = _run_vad_and_save_manifest(
        audio_path,
        cache_dir,
        sample_rate=sample_rate,
        max_duration=max_duration,
        hf_token=token,
    )
    return segments, cache_dir


def _export_chunks_for_indices(
    audio_path: str,
    segments: List[Tuple[float, float]],
    indices: List[int],
    cache_dir: str,
    sample_rate: int = 16000,
) -> List[str]:
    """Write WAV chunks for the selected segment indices and return their paths.

    The source audio is loaded once per call and downmixed to mono by averaging
    channels. Each selected ``(start, end)`` span (seconds) is sliced at the
    source sample rate, resampled to ``sample_rate`` when the rates differ, and
    saved as ``{cache_dir}/chunk_NNN.wav`` (1-based index, zero-padded to three
    digits). Paths come back in the same order as ``indices``.
    """
    audio, src_rate = torchaudio.load(audio_path)
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    resample = (
        torchaudio.transforms.Resample(src_rate, sample_rate)
        if src_rate != sample_rate
        else None
    )

    def _write_chunk(seg_idx: int) -> str:
        # Slice [start, end) in source-rate samples, then optionally resample.
        seg_start, seg_end = segments[seg_idx]
        piece = audio[:, int(seg_start * src_rate):int(seg_end * src_rate)]
        if resample is not None:
            piece = resample(piece)
        target = os.path.join(cache_dir, f"chunk_{seg_idx+1:03d}.wav")
        torchaudio.save(target, piece, sample_rate)
        return target

    return [_write_chunk(seg_idx) for seg_idx in indices]


def _delete_files(paths: List[str]) -> None:
    """Best-effort removal of the given files.

    Paths that do not exist are ignored. Any error while checking or removing
    a path is printed as a warning and the remaining paths are still processed.
    """
    def _remove_one(target: str) -> None:
        try:
            if not os.path.exists(target):
                return
            os.remove(target)
        except Exception as err:
            print(f"[WARN] Impossibile rimuovere {target}: {err}")

    for target in paths:
        _remove_one(target)


def _cleanup_cache_dir(cache_dir: str) -> None:
    """Delete the whole cache folder (chunks and manifests) if it exists.

    Does nothing for a missing folder; failures are printed as warnings
    instead of being raised.
    """
    try:
        if not os.path.exists(cache_dir):
            return
        shutil.rmtree(cache_dir)
        print(f"[CLEAN] Cache rimossa: {cache_dir}")
    except Exception as err:
        print(f"[WARN] Cleanup incompleto {cache_dir}: {err}")

def _read_json(path: str, default):
    """Load JSON from ``path``; return ``default`` when the file does not exist."""
    if not os.path.exists(path):
        return default
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json_atomic(obj, path: str, **dump_kwargs) -> None:
    """Dump ``obj`` as JSON to ``path`` through a temp file and ``os.replace``.

    An interruption mid-write leaves at most a stale ``<path>.tmp``; ``path``
    itself holds either the previous or the new complete content.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, **dump_kwargs)
    os.replace(tmp_path, path)


def transcribe_chunks_resumable(
    audio_path: str,
    segments: List[Tuple[float, float]],
    cache_dir: str,
    model_name: str,
    output_segments_json: str,
    output_words_json: str,
    batch_size: int = 2,
    sample_rate: int = 16000,
    cleanup_on_success: bool = True,
) -> Tuple[list, list, float]:
    """Transcribe VAD chunks with a NeMo ASR model, checkpointing after every batch.

    Resume and cleanup behaviour:
    - ``{cache_dir}/progress.json`` records how many chunks are done, how many
      entries each output list had at that checkpoint, and the accumulated ASR
      time. Output files already on disk are reused only when it exists;
      otherwise the run starts from scratch, so re-running a file whose cache
      was cleaned up does not append duplicates to the old outputs.
    - On resume, outputs are truncated to the checkpointed counts, so a batch
      whose results were written but not checkpointed is not duplicated.
    - Only the current batch's WAV chunks are exported; they are deleted right
      after the batch's results are saved.
    - ``batch_size`` sets export/checkpoint granularity; chunks are still sent
      to the model one at a time, so one failure skips only that chunk.
    - A failing chunk is skipped and its 0-based index is recorded in
      ``{cache_dir}/skipped_chunks.json``.
    - On full success the cache folder is removed if ``cleanup_on_success``.

    Args:
        audio_path: Local audio file the segments refer to.
        segments: ``[(start, end), ...]`` in seconds, one entry per chunk.
        cache_dir: Folder holding chunks and resume state.
        model_name: NeMo model name, loaded as ``nvidia/<model_name>``.
        output_segments_json: Path for segment-level results.
        output_words_json: Path for word-level results.
        batch_size: Chunks exported/checkpointed together (>= 1).
        sample_rate: Sample rate of the exported chunks.
        cleanup_on_success: Remove ``cache_dir`` once everything succeeded.

    Returns:
        ``(segments_result, words_result, elapsed_time)``: entries of the form
        ``(start, end, text)`` / ``(start, end, word)`` with times in seconds
        relative to the original audio, and the seconds spent in successful
        chunk transcriptions, accumulated across resumed runs.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size deve essere >= 1, ricevuto {batch_size}")

    progress_path = os.path.join(cache_dir, "progress.json")
    skipped_path = os.path.join(cache_dir, "skipped_chunks.json")

    progress = _read_json(progress_path, None)
    if progress is None:
        # Fresh run: any existing output files belong to an earlier run.
        processed_chunks = 0
        elapsed_time = 0.0
        segments_result, words_result, skipped_chunks = [], [], []
    else:
        processed_chunks = int(progress.get("processed_chunks", 0))
        elapsed_time = float(progress.get("elapsed_time", 0.0))
        segments_result = _read_json(output_segments_json, [])
        words_result = _read_json(output_words_json, [])
        skipped_chunks = _read_json(skipped_path, [])
        # Drop entries written after the last checkpoint. Progress files from
        # older versions lack these counters, in which case nothing is cut.
        segments_result = segments_result[: progress.get("n_segments", len(segments_result))]
        words_result = words_result[: progress.get("n_words", len(words_result))]
        skipped_chunks = skipped_chunks[: progress.get("n_skipped", len(skipped_chunks))]

    total_chunks = len(segments)
    asr_model = None
    success = False

    try:
        # Load the model only if there is still work to do.
        if processed_chunks < total_chunks:
            print(f"[INFO] Caricamento modello {model_name}...")
            asr_model = nemo_asr.models.ASRModel.from_pretrained(f"nvidia/{model_name}")

        for start_idx in range(processed_chunks, total_chunks, batch_size):
            end_idx = min(start_idx + batch_size, total_chunks)
            batch_indices = list(range(start_idx, end_idx))

            # Export chunks for this batch only.
            batch_paths = _export_chunks_for_indices(
                audio_path, segments, batch_indices, cache_dir, sample_rate=sample_rate
            )

            print(f"[ASR] Trascrizione batch {start_idx+1}-{end_idx}/{total_chunks}")

            for chunk_idx, chunk_path in zip(batch_indices, batch_paths):
                offset = segments[chunk_idx][0]
                t0 = time.time()
                try:
                    output = asr_model.transcribe([chunk_path], timestamps=True)[0]
                    chunk_time = time.time() - t0
                    # Build this chunk's results first, so a parsing error
                    # cannot leave half of them in the outputs.
                    chunk_segments = [
                        (seg["start"] + offset, seg["end"] + offset, seg["segment"])
                        for seg in output.timestamp.get("segment", [])
                    ]
                    chunk_words = [
                        (w["start"] + offset, w["end"] + offset, w["word"])
                        for w in output.timestamp.get("word", [])
                    ]
                except Exception as e:
                    print(f"[WARN] Trascrizione chunk {chunk_idx+1} fallita, skip: {e}")
                    skipped_chunks.append(chunk_idx)
                    continue
                elapsed_time += chunk_time
                segments_result.extend(chunk_segments)
                words_result.extend(chunk_words)

            # Incremental save, then checkpoint *after* the outputs are on disk.
            _write_json_atomic(segments_result, output_segments_json, ensure_ascii=False, indent=2)
            _write_json_atomic(words_result, output_words_json, ensure_ascii=False, indent=2)
            _write_json_atomic(skipped_chunks, skipped_path, ensure_ascii=False, indent=2)

            processed_chunks = end_idx
            _write_json_atomic(
                {
                    "processed_chunks": processed_chunks,
                    "n_segments": len(segments_result),
                    "n_words": len(words_result),
                    "n_skipped": len(skipped_chunks),
                    "elapsed_time": elapsed_time,
                },
                progress_path,
            )
            print(f"[SAVE] Salvati progressi fino al chunk {end_idx}/{total_chunks}")

            # Remove this batch's WAV files right away.
            _delete_files(batch_paths)

        # Ensure the outputs exist and match the returned results even when no
        # batch ran in this call (e.g. a file with zero VAD segments).
        _write_json_atomic(segments_result, output_segments_json, ensure_ascii=False, indent=2)
        _write_json_atomic(words_result, output_words_json, ensure_ascii=False, indent=2)

        success = True
        print("\n[INFO] Trascrizione completata.")
        print(f"  Segmenti totali: {len(segments_result)}")
        print(f"  Parole totali: {len(words_result)}")
        if skipped_chunks:
            print(f"  Chunk saltati: {skipped_chunks}")
        return segments_result, words_result, elapsed_time

    finally:
        # Drop our reference to the ASR model so its GPU memory can be
        # reclaimed before the next file is processed.
        asr_model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if success and cleanup_on_success:
            _cleanup_cache_dir(cache_dir)
        elif success:
            print(f"[INFO] Trascrizione completata: cache preservata in {cache_dir}")
        else:
            print(f"[INFO] Interruzione o errore: cache preservata in {cache_dir}")

In [None]:
def process_files(
    files: List[str],
    gcs_prefix: str = "gs://bucket-stagisti/lucca_folder/audio",
    hf_token: str | None = None,
):
    """Download, VAD-segment and transcribe each audio file in ``files``.

    For every name ``file``:
      1. download ``{gcs_prefix}/{file}`` to the local path ``file`` unless it
         already exists;
      2. prepare (or reload) the VAD segments;
      3. transcribe with ``parakeet-tdt-0.6b-v3``, writing
         ``{file}_segments.json`` and ``{file}_words.json``;
      4. append ``"{file} : {elapsed_time}"`` to ``times.txt``.

    Args:
        files: Object names relative to ``gcs_prefix``.
        gcs_prefix: GCS folder containing the audio files.
        hf_token: Hugging Face token for the VAD model; when None it is read
            from the ``HF_TOKEN`` environment variable (never hard-code it).
    """
    token = hf_token or os.environ.get("HF_TOKEN")
    base_uri = gcs_prefix.rstrip("/")

    for file in files:
        print(f"\n========== PROCESSING FILE {file} ==========")
        uri = f"{base_uri}/{file}"
        dest_path = f"{file}"
        if os.path.exists(dest_path):
            local_uri = dest_path
        else:
            parent_dir = os.path.dirname(dest_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            # Download to a side file and rename only on success: a partial
            # file at dest_path would be taken as complete by the check above.
            part_path = f"{dest_path}.part"
            download_gcs_file(uri, dest_path=part_path)
            os.replace(part_path, dest_path)
            local_uri = dest_path

        # 1) Prepare (or reload) the VAD segments without repeating the VAD.
        segments, cache_dir = prepare_vad_cache(
            local_uri,
            cache_root=None,            # default <tmp>/vad_cache
            sample_rate=16000,
            max_duration=30.0,
            hf_token=token,
        )

        # 2) Resumable transcription with automatic chunk cleanup.
        segments_result, words_result, elapsed_time = transcribe_chunks_resumable(
            audio_path=local_uri,
            segments=segments,
            cache_dir=cache_dir,
            model_name="parakeet-tdt-0.6b-v3",
            output_segments_json=f"{file}_segments.json",
            output_words_json=f"{file}_words.json",
            batch_size=8,
            sample_rate=16000,
            cleanup_on_success=True,   # remove cache only once the job completes
        )

        with open("times.txt", "a", encoding="utf-8") as f:
            f.write(f"{file} : {elapsed_time}\n")


In [None]:
from utils.names import get_file_names

# Names of the audio files to process; process_files resolves each one as an
# object under its GCS audio folder. get_file_names is a project helper
# (utils.names) whose source and ordering are not visible here.
files = get_file_names()


In [None]:
# Run VAD + transcription for every file. Per file this writes
# <file>_segments.json and <file>_words.json and appends the ASR time to times.txt.
process_files(files)


### Merging segments and words into a single JSON file

In [None]:
def build_transcript_json(segments_result, words_result):
    """Build a structured transcript: one entry per segment, with its words.

    A word is attached to a segment when it lies entirely inside it
    (``word_start >= seg_start`` and ``word_end <= seg_end``); words that
    straddle a segment boundary are not attached to any segment. Assumes
    ``start <= end`` for every word.

    Args:
        segments_result (list): ``(start, end, text)`` items (tuples, or lists
            when loaded back from JSON).
        words_result (list): ``(start, end, word, ...)`` items; any fields
            after the word (e.g. a score) are ignored.

    Returns:
        list: one dict per segment, in input order, with keys ``start``,
        ``end``, ``text`` and ``words`` (a list of ``{"word", "start", "end"}``
        dicts in start-time order).
    """
    from bisect import bisect_left

    # Sort once and binary-search each segment's first candidate word instead
    # of scanning every word for every segment. sorted() is stable, so
    # already time-ordered input keeps its order.
    ordered_words = sorted(words_result, key=lambda w: w[0])
    word_starts = [w[0] for w in ordered_words]
    n_words = len(ordered_words)

    transcript = []
    for seg_start, seg_end, seg_text in segments_result:
        segment_words = []
        i = bisect_left(word_starts, seg_start)
        # Candidates start inside [seg_start, seg_end]; keep those ending in it.
        while i < n_words and word_starts[i] <= seg_end:
            w_start, w_end, w_text = ordered_words[i][:3]
            if w_end <= seg_end:
                segment_words.append({
                    "word": w_text,
                    "start": w_start,
                    "end": w_end,
                })
            i += 1

        transcript.append({
            "start": seg_start,
            "end": seg_end,
            "text": seg_text,
            "words": segment_words,
        })

    return transcript

In [None]:
# Merge each file's segment-level and word-level outputs into a single <file>.json.
for file in files:
    # Must match the output_segments_json name used by process_files.
    with open(f"{file}_segments.json", "r", encoding="utf-8") as f:
        segments_result = json.load(f)

    with open(f"{file}_words.json", "r", encoding="utf-8") as f:
        words_result = json.load(f)

    final_json = build_transcript_json(segments_result, words_result)

    with open(f"{file}.json", "w", encoding="utf-8") as f:
        json.dump(final_json, f, ensure_ascii=False, indent=2)
                                       
                        