# Self-supervised gap-filling data loader

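This notebook locates the precomputed similarity-tensor directory (the
`data/precomputed_L*` folder with the largest `L`, unless one is set explicitly),
reads the manifest produced by the previous pipeline step, and builds train/eval
DataLoaders over a seeded train/eval split. Each sample's `.pt` tensor is loaded
from disk on access.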


In [1]:
from __future__ import annotations

from pathlib import Path
import csv
import random

try:
    import torch
    from torch.utils.data import Dataset, DataLoader, Subset
except ImportError as exc:
    raise ImportError(
        "This notebook requires torch. Install it before running."
    ) from exc

# --- Locations -------------------------------------------------------------
# Root folder that contains the precomputed_L<N> directories.
DATA_ROOT = Path("data")
# None -> auto-detect the precomputed_L<N> directory with the largest N under DATA_ROOT.
PRECOMPUTED_DIR = None  # set Path("data/precomputed_L100") if needed

# --- Split & batching ------------------------------------------------------
RANDOM_SEED = 42       # seeds the train/eval index shuffle
TRAIN_FRACTION = 0.8   # share of samples assigned to the training split
BATCH_SIZE = 64
NUM_WORKERS = 0        # DataLoader worker processes (0 = load in the main process)

# --- Per-sample options passed to SimilarityDataset -------------------------
RETURN_MASK = False    # also yield a boolean validity mask per sample
ADD_CHANNEL = False    # prepend a size-1 dimension to each tensor


In [2]:
def _extract_len(path: Path) -> int:
    prefix = "precomputed_L"
    if not path.name.startswith(prefix):
        return -1
    try:
        return int(path.name[len(prefix) :])
    except ValueError:
        return -1


def _find_precomputed(root: Path) -> Path:
    """Return the ``precomputed_L<N>`` directory directly under *root* with the largest N.

    Only directories whose suffix parses as a non-negative integer are
    considered. Stray files or oddly named entries matched by the glob
    (e.g. ``precomputed_L100.tar``) can therefore never be selected. Ties on N
    are broken by name, so the result does not depend on filesystem listing
    order.

    Raises:
        FileNotFoundError: if *root* contains no such directory.
    """
    candidates = [
        path
        for path in root.glob("precomputed_L*")
        if path.is_dir() and _extract_len(path) >= 0
    ]
    if not candidates:
        raise FileNotFoundError(f"No precomputed_L* directory found in {root}.")
    return max(candidates, key=lambda path: (_extract_len(path), path.name))


# Resolve the precomputed directory: an explicit config value wins; otherwise
# pick the precomputed_L<N> directory with the largest N under DATA_ROOT.
if PRECOMPUTED_DIR is None:
    PRECOMPUTED_DIR = _find_precomputed(DATA_ROOT)
# Normalise to Path so a plain string in the config cell also supports "/".
PRECOMPUTED_DIR = Path(PRECOMPUTED_DIR)

MANIFEST_PATH = PRECOMPUTED_DIR / "manifest.csv"
if not MANIFEST_PATH.exists():
    raise FileNotFoundError(f"Missing manifest: {MANIFEST_PATH}")

print(f"Using precomputed dir: {PRECOMPUTED_DIR}")
print(f"Manifest: {MANIFEST_PATH}")

with MANIFEST_PATH.open("r", newline="") as handle:
    reader = csv.DictReader(handle)
    manifest_records = list(reader)
    # fieldnames is None for an empty file.
    manifest_columns = set(reader.fieldnames or ())

# Fail fast with a clear message, instead of a KeyError in the filter below or
# inside SimilarityDataset.__getitem__ (possibly in a DataLoader worker).
missing_columns = {"status", "sim_path"} - manifest_columns
if missing_columns:
    raise ValueError(
        f"Manifest {MANIFEST_PATH} is missing required column(s): {sorted(missing_columns)}"
    )

# Keep only rows whose status is "ok". csv.DictReader fills missing trailing
# fields of short rows with None, hence the `or ""` before stripping.
ok_records = [r for r in manifest_records if (r["status"] or "").strip() == "ok"]
print(f"Total records: {len(manifest_records)}")
print(f"OK records: {len(ok_records)}")


Using precomputed dir: data/precomputed_L100
Manifest: data/precomputed_L100/manifest.csv
Total records: 6631
OK records: 6631


In [3]:
class SimilarityDataset(Dataset):
    """Map-style dataset over manifest records that loads one saved tensor per item.

    Each record must provide a ``"sim_path"`` entry pointing at a file readable
    by ``torch.load``. The file is loaded on every ``__getitem__`` call; no
    tensors are cached in memory.

    Args:
        records: iterable of dict-like manifest rows; copied into a list.
        return_mask: if True, ``__getitem__`` returns ``(sim, mask)``. ``mask``
            is a bool tensor shaped like ``sim`` that is True on the leading
            ``n_used x n_used`` block of the last two dimensions. ``n_used``
            comes from the record's optional ``"n_used"`` entry; a blank or
            missing value means "use the full last dimension".
        add_channel: if True, prepend a size-1 dimension to each loaded tensor.

    Raises:
        ValueError: if ``records`` is empty.
    """

    def __init__(self, records, return_mask: bool = False, add_channel: bool = False):
        self.records = list(records)
        self.return_mask = return_mask
        self.add_channel = add_channel

        if not self.records:
            raise ValueError("No records provided to the dataset.")

    def __len__(self):
        return len(self.records)

    @staticmethod
    def _parse_n_used(raw, default: int) -> int:
        """Convert a manifest ``n_used`` value to an int, or return *default* if blank/missing.

        Raises:
            ValueError: if the value is not an integer or is negative.
        """
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            return default
        n_used = int(raw)
        if n_used < 0:
            # A negative slice bound would silently mark the wrong region valid.
            raise ValueError(f"n_used must be non-negative, got {n_used}")
        return n_used

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        # map_location="cpu": tensors saved from a GPU would otherwise be restored
        # onto that device, which fails on CPU-only machines and in DataLoader
        # worker processes. NOTE: torch.load unpickles the file, so only load
        # trusted .pt files.
        sim = torch.load(rec["sim_path"], map_location="cpu")
        if self.add_channel:
            sim = sim.unsqueeze(0)

        if self.return_mask:
            n_used = self._parse_n_used(rec.get("n_used"), sim.shape[-1])
            mask = torch.zeros_like(sim, dtype=torch.bool)
            mask[..., :n_used, :n_used] = True
            return sim, mask
        return sim


dataset = SimilarityDataset(ok_records, return_mask=RETURN_MASK, add_channel=ADD_CHANNEL)

# Reproducible train/eval split: shuffle the indices with a seeded RNG, then cut.
indices = list(range(len(dataset)))
random.Random(RANDOM_SEED).shuffle(indices)

split = int(TRAIN_FRACTION * len(indices))
train_idx, eval_idx = indices[:split], indices[split:]
if not train_idx:
    # Otherwise the batch peek below dies with an opaque StopIteration.
    raise ValueError(
        f"TRAIN_FRACTION={TRAIN_FRACTION} leaves no training samples "
        f"out of {len(indices)}."
    )

train_ds = Subset(dataset, train_idx)
eval_ds = Subset(dataset, eval_idx)

# Without an explicit generator, shuffle=True draws from torch's global RNG,
# which RANDOM_SEED never seeds; a dedicated seeded generator makes the
# per-epoch training order reproducible across Restart & Run All.
train_generator = torch.Generator().manual_seed(RANDOM_SEED)
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=train_generator,
)
eval_loader = DataLoader(
    eval_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

print(f"Train samples: {len(train_ds)}")
print(f"Eval samples: {len(eval_ds)}")

# Peek at one training batch to sanity-check shapes. Iterating the loader
# draws from train_generator, so this peek is part of the seeded sequence.
batch = next(iter(train_loader))
if RETURN_MASK:
    sim_batch, mask_batch = batch
    print("Batch shapes:", sim_batch.shape, mask_batch.shape)
else:
    print("Batch shape:", batch.shape)


Train samples: 5304
Eval samples: 1327
Batch shape: torch.Size([64, 100, 100])
