In [23]:
from pathlib import Path
import torch
import numpy as np
import sys
from sklearn.model_selection import StratifiedGroupKFold
from typing import List, Tuple, Optional

# The notebook is run from <repo>/notebooks, so the repo root is one level above the CWD.
PROJECT_ROOT = Path("..").resolve()

print(f"Project root: {PROJECT_ROOT}")
print(f"CWD: {Path.cwd()}")
# Make the repo's top-level packages (e.g. `data`) importable. Guarded so re-running this
# cell does not stack duplicate entries, and inserted at the front so an installed package
# with the same generic name cannot shadow the project's `data` package.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from data.dataset import image_loader

Project root: /home/hshi/Documents/researchproject/aihab/repo/aihab-clip
CWD: /home/hshi/Documents/researchproject/aihab/repo/aihab-clip/notebooks


In [30]:
# Single source of truth for the training-data folder; the index CSV lives inside it.
DATASET_DIR = PROJECT_ROOT / "data" / "CS_Xplots_2019_2023_train"

cfg = {
    # Image folders handed to image_loader.
    "dataset_paths": [DATASET_DIR],
    # Index CSV(s) handed to image_loader; presumably one per entry in `dataset_paths` — confirm.
    "index_file_names": [DATASET_DIR / "CS_Xplots_2019_23_NEW02OCT24.csv"],
    # `resize` is passed to image_loader as its third positional argument.
    "preprocessing": {"resize": 256},
    # Seeds the train/val split (random_state of StratifiedGroupKFold) and `rng`.
    "seed": 1,
    # Requested validation fraction; only approximate, since it is realised as one k-fold fold.
    "val_split": 0.1,
}
print(cfg["dataset_paths"], "\n", cfg["index_file_names"])


[PosixPath('/home/hshi/Documents/researchproject/aihab/repo/aihab-clip/data/CS_Xplots_2019_2023_train')] 
 [PosixPath('/home/hshi/Documents/researchproject/aihab/repo/aihab-clip/data/CS_Xplots_2019_2023_train/CS_Xplots_2019_23_NEW02OCT24.csv')]


In [15]:
# Load every training image plus its per-image metadata in one pass.
# This is the slow step (about 1.5 min for 4233 files in the saved run).
# NOTE(review): the saved execution count (In [15]) shows this cell ran before the cfg cell
# (In [30]). Re-run top-to-bottom so the loaded data matches the cfg shown above.
(
    images_tr,
    labels_tr,
    plot_word_labels_tr,
    poly_labels_tr,
    poly_word_labels_tr,
    file_names_tr,
    plot_idx_tr,
    src_tr,
) = image_loader(
    cfg["dataset_paths"],
    cfg["index_file_names"],
    cfg["preprocessing"]["resize"],
    verbose=False,
)

Loading images from /home/hshi/Documents/researchproject/aihab/repo/aihab-clip/data/CS_Xplots_2019_2023_train:   0%|          | 0/4233 [00:00<?, ?file/s]

Loading images from /home/hshi/Documents/researchproject/aihab/repo/aihab-clip/data/CS_Xplots_2019_2023_train: 100%|██████████| 4233/4233 [01:24<00:00, 50.01file/s]


In [24]:
def _stratified_group_split_indices(labels: np.ndarray,
                                   groups: np.ndarray,
                                   val_ratio: float,
                                   seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    StratifiedGroup split: preserve class balance while keeping grouped samples together
    (here, group = plot_idx). Uses StratifiedGroupKFold to approximate the requested split.
    """
    labels = np.asarray(labels)
    groups = np.asarray(groups)
    if val_ratio <= 0:
        return np.arange(len(labels), dtype=np.int64), np.array([], dtype=np.int64)

    n_splits = max(2, int(round(1.0 / val_ratio)))
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    train_idx, val_idx = next(sgkf.split(labels, labels, groups=groups))
    return train_idx.astype(np.int64), val_idx.astype(np.int64)

In [31]:
seed = int(cfg.get("seed", 1))
# NOTE(review): `rng` is not used by the split below (which is seeded via `seed`);
# kept in case later cells draw from it.
rng = np.random.RandomState(seed)
val_ratio = float(cfg.get("val_split", 0.1))

train_pool_idx, val_idx = _stratified_group_split_indices(labels_tr, plot_idx_tr, val_ratio, seed)

# Check the split invariants: the index sets are disjoint, together they cover every
# sample, and no plot appears on both sides (the reason a group-aware splitter is used).
n_samples = len(labels_tr)
plot_groups = np.asarray(plot_idx_tr)
assert np.intersect1d(train_pool_idx, val_idx).size == 0, "train/val indices overlap"
assert train_pool_idx.size + val_idx.size == n_samples, "split does not cover all samples"
assert np.intersect1d(plot_groups[train_pool_idx], plot_groups[val_idx]).size == 0, \
    "a plot appears in both train and val"
print(f"train: {train_pool_idx.size}  val: {val_idx.size} "
      f"({val_idx.size / max(n_samples, 1):.1%} of {n_samples})")



In [36]:
# Check that the validation indices address loaded samples. The guard is needed because
# the split helper returns an empty val_idx when val_split <= 0, and np.max raises on an
# empty array.
if val_idx.size:
    max_val_idx = int(np.max(val_idx))
    print(max_val_idx)
    assert max_val_idx < len(labels_tr), "validation index out of range of loaded samples"
else:
    print("validation split is empty")

4187
