```
# One-time shell integration for conda (a new shell session may be needed before `conda activate` works).
conda init
# Create and enter a dedicated Python 3.10 environment named get-train-test-splits.
conda create -n get-train-test-splits python=3.10 -y
conda activate get-train-test-splits
# Pinned +cu118 builds of torch/torchaudio, resolved via the extra PyTorch wheel index.
pip install torch==2.6.0+cu118 torchaudio==2.6.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
# Remaining Python dependencies (unpinned).
pip install scikit-learn pandas ipykernel tqdm
```

In [1]:
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from pathlib import Path

In [2]:
# Dataset root; every input and output location below is derived from it.
DATA_DIR = Path("../VoxCeleb2")

# Inputs
JPG_EMB_PATH = DATA_DIR.joinpath("full/jpg_embeddings.pt")
JPG_CSV_PATH = DATA_DIR.joinpath("full/jpg_metadata.csv")
WAV_EMB_PATH = DATA_DIR.joinpath("full/wav_embeddings.pt")
WAV_GPT_PATH = DATA_DIR.joinpath("full/wav_gpt_latents.pt")
WAV_CSV_PATH = DATA_DIR.joinpath("full/wav_metadata.csv")

# Outputs
TRAIN_DIR = DATA_DIR.joinpath("train")
TEST_DIR = DATA_DIR.joinpath("test")
for split_dir in (TRAIN_DIR, TEST_DIR):
    # exist_ok=True keeps the cell re-runnable; DATA_DIR itself must already exist.
    split_dir.mkdir(exist_ok=True)

In [3]:
# Load metadata and embeddings
# Metadata tables are read first (JPG, then WAV), matching the embedding order below.
jpg_df, wav_df = (
    pd.read_csv(csv_path) for csv_path in (JPG_CSV_PATH, WAV_CSV_PATH)
)

# Serialized objects are loaded in the order: JPG embeddings, WAV embeddings, WAV GPT latents.
jpg_emb, wav_emb, wav_gpt = (
    torch.load(pt_path) for pt_path in (JPG_EMB_PATH, WAV_EMB_PATH, WAV_GPT_PATH)
)

In [4]:
# Speaker id split
# Collect speaker ids from BOTH metadata tables. Taking them from the JPG table
# alone would leave a speaker that only appears in the WAV table out of both
# train_ids and test_ids, so its WAV rows would silently vanish from both splits.
# JPG ids come first and unique() keeps first-appearance order, so when the two
# tables contain the same speakers this array (and therefore the seeded split)
# is identical to the JPG-only version.
unique_ids = pd.concat([jpg_df["id"], wav_df["id"]], ignore_index=True).unique()

# Splitting on speaker ids (not rows) keeps every speaker entirely in one split.
TEST_SIZE = 0.2   # fraction of speakers held out for testing
SPLIT_SEED = 42   # fixed seed so the shuffle is reproducible

train_ids, test_ids = train_test_split(
    unique_ids,
    test_size=TEST_SIZE,
    random_state=SPLIT_SEED,
    shuffle=True,
)

In [5]:
def split_df(df, emb_tensor, train_ids, test_ids):
    """Partition a metadata table and its row-aligned embeddings by speaker id.

    Args:
        df: DataFrame with an "id" column.
        emb_tensor: Object indexable with a boolean mask of length ``len(df)``
            along its first axis (presumably a tensor with one entry per row
            of ``df`` -- confirm against the saved files).
        train_ids: Collection of ids whose rows belong to the train split.
        test_ids: Collection of ids whose rows belong to the test split.

    Returns:
        Tuple ``(train_df, train_emb, test_df, test_emb)``. The DataFrames have
        a fresh 0..n-1 index; rows whose id is in neither collection are
        omitted from both splits.
    """

    def _select(ids):
        # Boolean row mask as a plain array so it can index both df and emb_tensor.
        keep = df["id"].isin(ids).to_numpy()
        return df[keep].reset_index(drop=True), emb_tensor[keep]

    train_df, train_emb = _select(train_ids)
    test_df, test_emb = _select(test_ids)
    return train_df, train_emb, test_df, test_emb

In [6]:
# JPG split
# Run the speaker-based partition once, then name the four returned pieces.
jpg_parts = split_df(jpg_df, jpg_emb, train_ids, test_ids)
jpg_train_df, jpg_train_emb, jpg_test_df, jpg_test_emb = jpg_parts

In [7]:
# WAV split (speaker embedding)
# Same partition as for JPG, applied to the WAV metadata and its embeddings.
wav_parts = split_df(wav_df, wav_emb, train_ids, test_ids)
wav_train_df, wav_train_emb, wav_test_df, wav_test_emb = wav_parts

In [8]:
# WAV split (gpt latents)
# Build the row masks from the WAV metadata, then apply them to the latents.
wav_is_train = wav_df["id"].isin(train_ids).to_numpy()
wav_is_test = wav_df["id"].isin(test_ids).to_numpy()

wav_train_gpt = wav_gpt[wav_is_train]
wav_test_gpt = wav_gpt[wav_is_test]

In [9]:
def pct(a, b):
    """Return ``a`` as a percentage of ``b``.

    The denominator is clamped to at least 1, so a zero (or smaller-than-one)
    ``b`` never causes a division by zero.
    """
    denominator = b if b > 1 else 1
    return 100 * a / denominator

In [11]:
# Unique speaker counts
num_speakers, num_train_speakers, num_test_speakers = map(
    len, (unique_ids, train_ids, test_ids)
)

# One print call per report section; sep="\n" puts each entry on its own line.
print(
    f"Speakers total: {num_speakers}",
    f"Train speakers: {num_train_speakers} ({pct(num_train_speakers, num_speakers):.2f}%)",
    f"Test speakers : {num_test_speakers}  ({pct(num_test_speakers, num_speakers):.2f}%)",
    sep="\n",
)

# Data point counts
jpg_total, jpg_train_n, jpg_test_n = map(len, (jpg_df, jpg_train_df, jpg_test_df))
wav_total, wav_train_n, wav_test_n = map(len, (wav_df, wav_train_df, wav_test_df))

print(
    "\n--- JPG Data Points ---",
    f"Total JPG: {jpg_total}",
    f"Train JPG: {jpg_train_n} ({pct(jpg_train_n, jpg_total):.2f}%)",
    f"Test JPG : {jpg_test_n}  ({pct(jpg_test_n, jpg_total):.2f}%)",
    sep="\n",
)

print(
    "\n--- WAV Data Points ---",
    f"Total WAV: {wav_total}",
    f"Train WAV: {wav_train_n} ({pct(wav_train_n, wav_total):.2f}%)",
    f"Test WAV : {wav_test_n}  ({pct(wav_test_n, wav_total):.2f}%)",
    sep="\n",
)

Speakers total: 118
Train speakers: 94 (79.66%)
Test speakers : 24  (20.34%)

--- JPG Data Points ---
Total JPG: 36237
Train JPG: 29498 (81.40%)
Test JPG : 6739  (18.60%)

--- WAV Data Points ---
Total WAV: 36237
Train WAV: 29498 (81.40%)
Test WAV : 6739  (18.60%)


In [12]:
# Save Train Splits
# Metadata tables go to CSV without the DataFrame index column.
jpg_train_df.to_csv(TRAIN_DIR / "jpg_train_metadata.csv", index=False)
wav_train_df.to_csv(TRAIN_DIR / "wav_train_metadata.csv", index=False)

# Embeddings / latents are serialized with torch.save next to their metadata.
for train_obj, train_file in (
    (jpg_train_emb, "jpg_train_embeddings.pt"),
    (wav_train_emb, "wav_train_embeddings.pt"),
    (wav_train_gpt, "wav_train_gpt_latents.pt"),
):
    torch.save(train_obj, TRAIN_DIR / train_file)

In [13]:
# Save Test Splits
# Metadata tables go to CSV without the DataFrame index column.
jpg_test_df.to_csv(TEST_DIR / "jpg_test_metadata.csv", index=False)
wav_test_df.to_csv(TEST_DIR / "wav_test_metadata.csv", index=False)

# Embeddings / latents are serialized with torch.save next to their metadata.
for test_obj, test_file in (
    (jpg_test_emb, "jpg_test_embeddings.pt"),
    (wav_test_emb, "wav_test_embeddings.pt"),
    (wav_test_gpt, "wav_test_gpt_latents.pt"),
):
    torch.save(test_obj, TEST_DIR / test_file)