## Patient split

In [None]:
import os
import pickle
from collections import defaultdict

In [None]:
# Root directory of the tokenised dataset. Each split is read from the
# subdirectory of this path that carries the split's name.
DATA_ROOT = "/data/scratch/qc25022/pancreas/tokenised_data_word_level/cprd_upgi"
# Names of the splits whose subject IDs are compared against each other.
SPLITS = ["train", "tuning", "held_out"]  # add any extra splits you used

def collect_subjects(split, data_root=None):
    """Return the set of subject IDs found in one split directory.

    Every ``*.pkl`` file directly inside ``<data_root>/<split>`` is loaded;
    each file must unpickle to an iterable of records that can be indexed
    with ``"subject_id"``. Files with any other extension are ignored.

    Args:
        split: Name of the split subdirectory to scan.
        data_root: Directory containing the split subdirectories. Defaults
            to the module-level ``DATA_ROOT`` when omitted.

    Returns:
        A ``set`` with the ``subject_id`` of every record in the split.

    Raises:
        FileNotFoundError: If the split directory does not exist.
        KeyError: If a mapping record has no ``"subject_id"`` key.
    """
    root = DATA_ROOT if data_root is None else data_root
    split_dir = os.path.join(root, split)
    subjects = set()
    for fname in os.listdir(split_dir):
        if not fname.endswith(".pkl"):
            continue
        with open(os.path.join(split_dir, fname), "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at files that come from a trusted source.
            records = pickle.load(f)
        # The file is closed before iterating; the records are in memory.
        subjects.update(record["subject_id"] for record in records)
    return subjects

# Gather the subject IDs of every split once, keyed by split name.
split_subjects = {split: collect_subjects(split) for split in SPLITS}

# Pairwise intersection report: every unordered pair of splits is compared
# exactly once, in SPLITS order. A non-zero count means the same subject
# appears in both splits.
for i, a in enumerate(SPLITS):
    for b in SPLITS[i + 1:]:
        overlap = split_subjects[a] & split_subjects[b]
        # \u2229 is the set-intersection sign, written as an escape so the
        # source stays ASCII and cannot be mangled by a re-encoding.
        print(f"{a} \u2229 {b}: {len(overlap)}")
        if overlap:
            print(sorted(overlap)[:20], "...")  # sample IDs if debugging

## Trajectory Length

import numpy as np
from src.data.unified_dataset import UnifiedEHRDataset

In [None]:
def describe_lengths(split, cutoff):
    """Summarise the text length of every usable item in one dataset split.

    Builds a ``UnifiedEHRDataset`` in ``"text"`` format for ``split`` with
    no sequence-length cap, then measures each item's ``"text"`` field in
    characters and in whitespace-separated words. Items that come back as
    ``None`` are skipped.

    The names ``VOCAB``, ``LABELS``, ``MEDICAL``, ``LAB``, ``REGION`` and
    ``TIME`` are read from the enclosing namespace and must already be
    defined when this function is called.

    Args:
        split: Split name forwarded to ``UnifiedEHRDataset``.
        cutoff: Value forwarded as ``cutoff_months``.

    Returns:
        A ``(char_stats, token_stats)`` tuple of dicts, each with the keys
        ``count``, ``mean``, ``p95`` and ``max``. If no item is usable,
        ``count`` is 0 and the remaining values are NaN.
    """

    def summarise(values):
        # An empty list would make max() raise ValueError, so report an
        # explicit "no data" summary instead of crashing the whole run.
        if not values:
            nan = float("nan")
            return dict(count=0, mean=nan, p95=nan, max=nan)
        return dict(
            count=len(values),
            mean=np.mean(values),
            p95=np.percentile(values, 95),
            max=max(values),
        )

    dataset = UnifiedEHRDataset(
        data_dir=DATA_ROOT,
        vocab_file=VOCAB,
        labels_file=LABELS,
        medical_lookup_file=MEDICAL,
        lab_lookup_file=LAB,
        region_lookup_file=REGION,
        time_lookup_file=TIME,
        cutoff_months=cutoff,
        format="text",
        split=split,
        max_sequence_length=None,
    )
    lengths_chars = []
    lengths_tokens = []
    for item in dataset:
        if item is None:
            continue
        text = item["text"]
        lengths_chars.append(len(text))
        # Crude word count; swap with tokenizer.encode if desired.
        lengths_tokens.append(len(text.split()))
    return summarise(lengths_chars), summarise(lengths_tokens)

# Print the character-level and word-level length summaries for each split.
# The cutoff of 12 is forwarded by describe_lengths as cutoff_months.
for split in ["train", "tuning", "held_out"]:
    char_stats, token_stats = describe_lengths(split, cutoff=12)
    print(f"{split} char stats: {char_stats}")
    print(f"{split} token stats: {token_stats}")