In [1]:
from pathlib import Path
import os, json, time, gc
import numpy as np
import pandas as pd
import soundfile as sf
import librosa

# ---- TensorFlow / TF-Hub setup (stable settings) ----
# os.environ["CUDA_VISIBLE_DEVICES"] = ""  # uncomment for CPU-only
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import tensorflow as tf
import tensorflow_hub as hub
tf.config.optimizer.set_jit(False)

print("Available GPUs:", tf.config.list_physical_devices('GPU'))
print("TF version:", tf.__version__)
print("Librosa:", librosa.__version__)

# === Absolute roots ===
HOME = Path(os.environ["HOME"])
REPO_ROOT = HOME / "Uni-stuff/semester-2/applied_Ml/reef_zmsc"

# === Paths ===
IN_MANIFEST = REPO_ROOT / "data/manifests/sample_50k_stratified.parquet"
OUT_ROOT    = REPO_ROOT / "data/features/embeds_yamnet_50k"   # output folder

# === Embedding parameters ===
CLIP_SECONDS = 10.0      # must match manifest
HOP_SECONDS  = 10.0
RESAMPLE_HZ  = 16000
BATCH_SIZE   = 64        # lower if you hit memory issues
DTYPE_OUT    = "float16" # store embeddings compactly

# === Pilot controls ===
LIMIT_WINDOWS = None      # set e.g. 20000 for a pilot
DRY_RUN = False           # True → skip saving; False → write output

print("\nConfig loaded ✅")
print(f"Manifest: {IN_MANIFEST}")
print(f"Output root: {OUT_ROOT}")
print(f"Limit windows: {LIMIT_WINDOWS}")
print(f"Dry run: {DRY_RUN}\n")



2025-10-27 06:39:59.578273: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from pkg_resources import parse_version


Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
TF version: 2.20.0
Librosa: 0.11.0

Config loaded ✅
Manifest: /home/sparch/Uni-stuff/semester-2/applied_Ml/reef_zmsc/data/manifests/sample_50k_stratified.parquet
Output root: /home/sparch/Uni-stuff/semester-2/applied_Ml/reef_zmsc/data/features/embeds_yamnet_50k
Limit windows: None
Dry run: False



## Load Manifest & Estimate Disk/Time

In [2]:
# === Load manifest ===
IN_MANIFEST = Path(IN_MANIFEST)
assert IN_MANIFEST.exists(), f"Manifest not found: {IN_MANIFEST}"

df = pd.read_parquet(IN_MANIFEST) if IN_MANIFEST.suffix == ".parquet" else pd.read_csv(IN_MANIFEST)

# Column that points to WAVs: your sampler wrote 'filepath'
PATH_COL = "filepath" if "filepath" in df.columns else ("wav_path" if "wav_path" in df.columns else None)
assert PATH_COL is not None, "Manifest needs a 'filepath' (or 'wav_path') column."

# ---- Optional subsample for pilot ----
if LIMIT_WINDOWS:
    df = df.sample(n=min(LIMIT_WINDOWS, len(df)), random_state=42).reset_index(drop=True)

num_rows = len(df)
est_total_mb = (num_rows * 1024 * 2 * 1.2) / (1024**2)  # rough estimate
print(f"⚡ Using {num_rows:,} clips")
print(f"Estimated output size: ~{est_total_mb:.1f} MB\n")

⚡ Using 50,000 clips
Estimated output size: ~117.2 MB



## Load YaMNet Model (TF Hub)

In [3]:
# ---- Load YAMNet model ----
YAMNET_HANDLE = "https://tfhub.dev/google/yamnet/1"
yamnet = hub.load(YAMNET_HANDLE)
print("YAMNet model loaded ✅")

I0000 00:00:1761509408.734622  228644 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3497 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


YAMNet model loaded ✅


## Helper functions - audio loading , batching , shard writing

In [5]:
# ---------- AUDIO LOADING ----------
def read_window(wav_path: str, start_s: float, end_s: float, target_sr: int = 16000) -> np.ndarray:
    """Reads a specific time window from a WAV file and resamples to target_sr."""
    try:
        with sf.SoundFile(wav_path, "r") as rf:
            sr = rf.samplerate
            start_frame = max(0, int(round(start_s * sr)))
            n_frames = max(0, int(round((end_s - start_s) * sr)))
            rf.seek(start_frame)
            y = rf.read(frames=n_frames, dtype="float32", always_2d=False)
    except Exception as e:
        print(f"⚠️ Error reading {wav_path}: {e}")
        return np.zeros(int(target_sr * 0.5), dtype=np.float32)

    if y.size == 0:
        y = np.zeros(int(target_sr * 0.5), dtype=np.float32)

    # convert to mono if multi-channel
    if y.ndim > 1:
        y = np.mean(y, axis=1).astype(np.float32)

    # resample to 16 kHz
    if sr != target_sr:
        try:
            y = librosa.resample(y, orig_sr=sr, target_sr=target_sr, res_type="kaiser_fast")
        except Exception as e:
            print(f"⚠️ Resample failed for {wav_path}: {e}")
            y = np.zeros(int(target_sr * 0.5), dtype=np.float32)

    return np.clip(y, -1.0, 1.0)

# ---------- PATH HELPERS ----------
def shard_path(out_root: Path, logger: str, date: str) -> Path:
    """Determines which Parquet shard to write this clip's embedding to."""
    return out_root / "PAPCA" / (logger or "unknown") / (date or "unknown") / "features.parquet"

# ---------- IMPROVED PATH HELPERS ----------
def detect_logger_date(wav_path: str):
    """
    Extract logger ID and date from path.
    Expected formats: 
      - .../PAPCA/<logger>/<YYYYMMDD>/...
      - .../PAPCA_test/<logger>/<YYYYMMDD>/...
    """
    try:
        # Convert to Path object and get parts
        parts = Path(wav_path).parts
        
        # Find PAPCA or PAPCA_test in the path
        papca_idx = -1
        for idx, part in enumerate(parts):
            if part == "PAPCA" or part == "PAPCA_test":
                papca_idx = idx
                break
        
        if papca_idx >= 0:
            # Extract logger and date (next two parts after PAPCA/PAPCA_test)
            if papca_idx + 2 < len(parts):
                logger = parts[papca_idx + 1]
                date = parts[papca_idx + 2]
                
                # Validate they're not empty
                if logger and date and logger != "" and date != "":
                    return logger, date
        
        # Fallback: try to extract from string path with regex-like pattern
        path_str = str(wav_path)
        
        # Try both PAPCA and PAPCA_test
        for papca_variant in ["/PAPCA_test/", "/PAPCA/"]:
            if papca_variant in path_str:
                after_papca = path_str.split(papca_variant)[1]
                parts_after = after_papca.split("/")
                if len(parts_after) >= 2:
                    logger = parts_after[0]
                    date = parts_after[1]
                    if logger and date:
                        return logger, date
        
    except Exception as e:
        print(f"⚠️ Path parsing error for {wav_path}: {e}")
    
    return "unknown", "unknown"

# ---------- EMBEDDING UTILS ----------
MIN_SEC = 0.5  # pad short/empty windows

def _pad_min_length(y, target_sr, min_sec=MIN_SEC):
    min_len = int(target_sr * min_sec)
    if y.size < min_len:
        y = np.pad(y.astype(np.float32), (0, min_len - y.size))
    return y

def yamnet_embeddings(batch_waveforms, target_sr=16000):
    """
    Runs YAMNet on a batch of waveforms (each 1D np.float32).
    Returns: np.ndarray [B, 1024]
    """
    outs = []
    for w in batch_waveforms:
        try:
            w = _pad_min_length(np.asarray(w, dtype=np.float32), target_sr)
            wf = tf.convert_to_tensor(w, dtype=tf.float32)
            with tf.device('/GPU:0'):
                scores, embeddings, _ = yamnet(wf)
                emb = tf.reduce_mean(embeddings, axis=0)  # [1024]
            outs.append(emb.numpy())
        except Exception as e:
            # CPU fallback
            wf = tf.convert_to_tensor(w, dtype=tf.float32)
            with tf.device('/CPU:0'):
                scores, embeddings, _ = yamnet(wf)
                outs.append(tf.reduce_mean(embeddings, axis=0).numpy())
    return np.stack(outs, axis=0)

# ---------- SAVE TO PARQUET ----------
def append_parquet_rows(path: Path, rows):
    """
    Append rows to a Parquet shard as a new row-group (no full-file read).
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    if not rows:
        return

    # Convert to pylist with float16 list for 'yamnet_1024'
    recs = []
    for r in rows:
        rec = dict(r)
        v = rec.get("yamnet_1024", None)
        if v is not None:
            rec["yamnet_1024"] = np.asarray(v, dtype=np.float16).tolist()
        recs.append(rec)

    table = pa.Table.from_pylist(recs)
    path.parent.mkdir(parents=True, exist_ok=True)

    if path.exists():
        # Append a new row group
        with pq.ParquetWriter(path, table.schema, compression="snappy", use_dictionary=True, write_statistics=True, append=True) as writer:
            writer.write_table(table)
    else:
        pq.write_table(table, path, compression="snappy", use_dictionary=True, write_statistics=True)

## Tiny sample (dry run) to verify shapes & speed

In [6]:
if len(df) == 0:
    print("❌ Manifest is empty — nothing to process.")
else:
    sample_df = df.sample(n=min(8, len(df)), random_state=42).reset_index(drop=True)
    print(f"🎧 Testing YAMNet on {len(sample_df)} random clips...")

    batch = []
    for _, r in sample_df.iterrows():
        try:
            start_s = float(r["start_s"])
            end_val = r.get("end_s", np.nan)  # r is a Series; .get is safe
            if pd.notna(end_val):
                end_s = float(end_val)
            else:
                dur = float(r.get("duration_s", CLIP_SECONDS))
                end_s = start_s + dur

            w = read_window(r[PATH_COL], start_s, end_s, target_sr=RESAMPLE_HZ)
            batch.append(w)
        except Exception as e:
            print(f"⚠️ Error reading {r[PATH_COL]}: {e}")


🎧 Testing YAMNet on 8 random clips...


## Full pass , we write per day parquet shards (float16) to save memory

In [7]:
# ---------- IMPROVED SAVE TO PARQUET ----------
def append_parquet_rows(path: Path, rows):
    """
    Efficiently append rows to a Parquet file using ParquetWriter properly.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    if not rows:
        return

    # Convert to pylist with float16 list for 'yamnet_1024'
    recs = []
    for r in rows:
        rec = dict(r)
        v = rec.get("yamnet_1024", None)
        if v is not None:
            rec["yamnet_1024"] = np.asarray(v, dtype=np.float16).tolist()
        recs.append(rec)

    new_table = pa.Table.from_pylist(recs)
    path.parent.mkdir(parents=True, exist_ok=True)

    if path.exists():
        # Read and combine approach (simpler, but reads entire file)
        existing_table = pq.read_table(path)
        combined_table = pa.concat_tables([existing_table, new_table])
        pq.write_table(combined_table, path, compression="snappy", 
                      use_dictionary=True, write_statistics=True)
    else:
        # Create new file
        pq.write_table(new_table, path, compression="snappy", 
                      use_dictionary=True, write_statistics=True)


# === Full YAMNet embedding pass ===
if not DRY_RUN and len(df) > 0:
    import time, gc

    total_rows = len(df)
    df_iter = df.itertuples(index=False)
    rows_buf = {}  # {Path -> [records]}
    processed, errors = 0, 0
    t0 = time.time()

    print(f"\n▶ Starting YAMNet embedding pass")
    print(f"   Total windows: {total_rows:,}")
    print(f"   Output directory: {OUT_ROOT}\n")

    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    try:
        while processed < total_rows:
            batch_waves, batch_meta = [], []

            # === Load one batch of windows ===
            for _ in range(min(BATCH_SIZE, total_rows - processed)):
                try:
                    r = next(df_iter)
                except StopIteration:
                    break
                try:
                    wav_path = getattr(r, PATH_COL)
                    start_s = float(r.start_s)
                    # compute end_s if missing in sampled parquet
                    end_s_attr = getattr(r, 'end_s', None)
                    if end_s_attr is None:
                        dur = float(getattr(r, 'duration_s', CLIP_SECONDS))
                        end_s = start_s + dur
                    else:
                        end_s = float(end_s_attr)

                    y = read_window(wav_path, start_s, end_s, target_sr=RESAMPLE_HZ)
                    batch_waves.append(y)
                    batch_meta.append((wav_path, start_s, end_s))
                except Exception as e:
                    errors += 1
                    if errors <= 5:
                        print(f"⚠️ Read error: {e}")
                    continue

            if not batch_waves:
                break

            # === Compute embeddings ===
            embs = yamnet_embeddings(batch_waves)  # [B, 1024]

            # === Accumulate rows ===
            for (wav_path, start_s, end_s), emb in zip(batch_meta, embs):
                logger, date = detect_logger_date(wav_path)
                rec = {
                    "filepath": wav_path,
                    "start_s": start_s,
                    "end_s": end_s,
                    "logger": logger,
                    "date": date,
                    "yamnet_1024": emb.astype(DTYPE_OUT),  # float16
                }
                shard = shard_path(OUT_ROOT, logger, date)
                rows_buf.setdefault(shard, []).append(rec)

            processed += len(batch_waves)

            # === Flush periodically (reduced frequency to minimize file operations) ===
            if sum(len(v) for v in rows_buf.values()) >= 5000:  # Increased from 2000
                print(f"💾 Flushing buffer at {processed:,} clips...")
                for spath, rows in list(rows_buf.items()):
                    if rows:
                        try:
                            append_parquet_rows(spath, rows)
                            rows_buf[spath] = []
                        except Exception as e:
                            print(f"⚠️ Error writing to {spath}: {e}")
                gc.collect()

            # === Progress log ===
            if processed % (BATCH_SIZE * 10) == 0 or processed >= total_rows:
                dt = time.time() - t0
                rps = processed / max(1e-6, dt)
                print(f"✅ {processed:,}/{total_rows:,} done ({processed/total_rows*100:.1f}%) | "
                      f"{rps:.1f} clips/s | Errors: {errors}")

        # === Final flush ===
        print("\n💾 Final flush...")
        for spath, rows in list(rows_buf.items()):
            if rows:
                try:
                    append_parquet_rows(spath, rows)
                except Exception as e:
                    print(f"⚠️ Error in final flush for {spath}: {e}")
        print("✅ All embeddings written successfully!")

    except KeyboardInterrupt:
        print("\n🟡 Interrupted — flushing buffers before exit...")
        for spath, rows in list(rows_buf.items()):
            if rows:
                try:
                    append_parquet_rows(spath, rows)
                except Exception as e:
                    print(f"⚠️ Error flushing {spath}: {e}")
        print("   Buffers flushed. Partial progress saved.")

    # === Summary ===
    dt = time.time() - t0
    print(f"\n🏁 Done: {processed:,} processed | {errors:,} errors | "
          f"{processed / max(1, dt):.1f} clips/s | Total time: {dt / 60:.1f} min")
else:
    print("DRY_RUN=True → Skipping embedding write.")


▶ Starting YAMNet embedding pass
   Total windows: 50,000
   Output directory: /home/sparch/Uni-stuff/semester-2/applied_Ml/reef_zmsc/data/features/embeds_yamnet_50k



2025-10-27 06:41:20.131795: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91002


✅ 640/50,000 done (1.3%) | 33.7 clips/s | Errors: 0
✅ 1,280/50,000 done (2.6%) | 34.4 clips/s | Errors: 0
✅ 1,920/50,000 done (3.8%) | 34.6 clips/s | Errors: 0
✅ 2,560/50,000 done (5.1%) | 35.0 clips/s | Errors: 0
✅ 3,200/50,000 done (6.4%) | 35.2 clips/s | Errors: 0
✅ 3,840/50,000 done (7.7%) | 35.6 clips/s | Errors: 0
✅ 4,480/50,000 done (9.0%) | 35.6 clips/s | Errors: 0
💾 Flushing buffer at 5,056 clips...
✅ 5,120/50,000 done (10.2%) | 35.2 clips/s | Errors: 0
✅ 5,760/50,000 done (11.5%) | 35.2 clips/s | Errors: 0
✅ 6,400/50,000 done (12.8%) | 35.2 clips/s | Errors: 0
✅ 7,040/50,000 done (14.1%) | 35.2 clips/s | Errors: 0
✅ 7,680/50,000 done (15.4%) | 35.3 clips/s | Errors: 0
✅ 8,320/50,000 done (16.6%) | 35.3 clips/s | Errors: 0
✅ 8,960/50,000 done (17.9%) | 35.3 clips/s | Errors: 0
✅ 9,600/50,000 done (19.2%) | 35.3 clips/s | Errors: 0
💾 Flushing buffer at 10,112 clips...
✅ 10,240/50,000 done (20.5%) | 35.3 clips/s | Errors: 0
✅ 10,880/50,000 done (21.8%) | 35.3 clips/s | Errors: 0