# In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# ------------------------
# 1. Fake dataframe
# ------------------------
# Five synthetic metadata rows. "cog" and "fri_dur" are re-drawn at random on
# every run (no seed is set); every other column is fixed.
df = pd.DataFrame(
    dict(
        uid=[f"id{i}" for i in range(5)],
        path=[f"fake_path_{i}.wav" for i in range(5)],
        cog=np.random.randint(3000, 7000, size=5),
        fri_dur=np.random.randint(50, 200, size=5),
        word=[f"word{i}" for i in range(5)],
        consonant=list("stkpm"),
        vowel=list("aiueo"),
        train=[True, True, False, False, True],
        label=list("ABCDE"),
        label_idx=list(range(5)),
    )
)

# ------------------------
# 2. Dataset
# ------------------------
class FakeDataset(Dataset):
    """Map-style dataset pairing a random feature vector with one metadata row.

    Every item is a ``(data_tensor, info)`` tuple: ``data_tensor`` is a freshly
    sampled 10-element ``torch.randn`` vector standing in for real features,
    and ``info`` is a dict with the metadata columns of the selected row.
    """

    # Columns copied from the dataframe row into ``info``, in output order.
    _INFO_COLUMNS = (
        "uid",
        "path",
        "cog",
        "fri_dur",
        "word",
        "consonant",
        "vowel",
        "train",
        "label",
        "label_idx",
    )

    def __init__(self, df):
        """Store a reference to the metadata dataframe (no copy is made)."""
        self.df = df

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(data_tensor, info)`` for the row at position ``idx``."""
        record = self.df.iloc[idx]
        # Placeholder features: a new random 10-dim vector on every access.
        features = torch.randn(10)
        info = {column: record[column] for column in self._INFO_COLUMNS}
        return features, info

# ------------------------
# 3. Collate function
# ------------------------
def collate_with_info(batch):
    """Collate ``(tensor, info_dict)`` samples into a single batch dict.

    Args:
        batch: Non-empty sequence of ``(tensor, info)`` pairs. The tensors
            must be stackable with ``torch.stack`` and every ``info`` dict
            must provide the keys found in the first sample's ``info``.

    Returns:
        A dict whose ``"input"`` entry holds the tensors stacked along a new
        leading dimension, followed by one entry per metadata key containing
        the per-sample values as a list in batch order.

    Raises:
        ValueError: If ``batch`` is empty, or if the metadata uses the key
            ``"input"`` (it would silently replace the stacked tensor).
        KeyError: If a sample's ``info`` lacks a key that the first sample's
            ``info`` has.
    """
    if not batch:
        raise ValueError("collate_with_info received an empty batch")

    xs, infos = zip(*batch)

    # The first sample defines which metadata keys are collected.
    keys = list(infos[0])
    if "input" in keys:
        # Without this guard the metadata list would overwrite the tensor.
        raise ValueError(
            'metadata key "input" is reserved for the stacked tensor'
        )

    out = {"input": torch.stack(xs, dim=0)}  # [B, feat_dim]
    for k in keys:
        out[k] = [info[k] for info in infos]
    return out

# ------------------------
# 4. Dummy model
# ------------------------
class DummyModel(torch.nn.Module):
    """Stand-in model consisting of a single linear layer.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 10.
        out_features: Size of the last dimension of the output. Defaults to 4.
    """

    def __init__(self, in_features: int = 10, out_features: int = 4):
        super().__init__()
        # Defaults reproduce the original fixed 10-dim -> 4-dim mapping.
        self.fc = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc(x)

# ------------------------
# 5. Loader + run collection
# ------------------------
# Wrap the fake dataframe in the dataset and batch it two samples at a time,
# merging each batch with the custom collate function.
dataset = FakeDataset(df=df)
loader = DataLoader(
    dataset=dataset,
    batch_size=2,
    collate_fn=collate_with_info,
)

# model = DummyModel()
# device = "cpu"

# from pathlib import Path
# save_dir = Path("./test_outputs")
# save_dir.mkdir(exist_ok=True)

# # use the collection function we wrote earlier
# from pathlib import Path
# collect_to_npy_and_csv(
#     model,
#     loader,
#     device,
#     npy_path=save_dir / "vecs.npy",
#     csv_path=save_dir / "meta.csv"
# )
