## Preparing a shared tissue-level data split (two 80/20 folds) for 4 experiments

In [4]:
# Setup cell: silence a noisy torch warning, set a Ray environment flag, make the
# project root importable, and import everything later cells use.
import sys
import os
import warnings
warnings.filterwarnings("ignore", message=".*weights_only=False.*") # ignore warning from torch for loading models
# NOTE(review): Ray setting, presumably allowing Ray's object store to use slower
# storage — confirm it is still needed (Ray is not used in the visible cells).
os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1" 
# Get the path to the directory above the current one (i.e., the project root).
# Assumes the notebook runs from a direct subdirectory of the project root
# (e.g. '.../project_root/notebooks'); '..' steps up to '.../project_root'.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Put the project root first on sys.path so `config` and `src.*` resolve below.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
# NOTE(review): wildcard import from the project's config module; presumably it
# supplies `load_marker_embeddings`, which later cells use without an explicit
# import. Prefer importing the needed names explicitly.
from config import *
print(project_root)

# NOTE(review): F, optim, Dataset, DataLoader, plt, Patch, rearrange, tqdm and
# importlib are not used in the visible cells of this notebook.
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import pickle as pkl
from omegaconf import OmegaConf
from einops import rearrange
from tqdm import tqdm
import importlib
import random

from src.dataset.datasets.mm_base import build_mm_datasets

/data/code/jon/project-2-gradient_tri_scent


In [5]:
SEED = 42
TEST_FRAC = 0.2  # 80/20 train/test split at the tissue level
OUT_NAME = f"two_folds_seed{SEED}_tissue_split.pkl"  # file name of the saved folds


def set_seed(seed: int = 42):
    """Seed the global RNGs of Python ``random``, NumPy and PyTorch.

    ``make_two_folds`` shuffles with its own ``random.Random(seed)`` instance, so the
    tissue split does not depend on this global seeding. This call makes any other
    stochastic work run in the notebook repeatable.

    Args:
        seed: value used to seed every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch is imported by the notebook but was previously left unseeded;
    # manual_seed covers the CPU generator and all CUDA devices.
    torch.manual_seed(seed)

In [6]:
def make_two_folds(tissue_ids, seed=42, test_frac=0.2):
    """
    Make 2 deterministic tissue-level train/test folds of equal size.

    The IDs are copied and shuffled with a private ``random.Random(seed)``, so the
    global RNG state is neither read nor modified. Then:

    * fold1 uses the LAST ``n_test`` shuffled IDs as its test set;
    * fold2 rotates the shuffled list left by ``n_test``. Its test set is therefore
      the FIRST ``n_test`` shuffled IDs.

    The two test sets are disjoint whenever ``n_test <= n / 2`` (roughly
    ``test_frac <= 0.5``). For larger fractions they partially overlap.

    The split depends on the ORDER of ``tissue_ids`` as well as on ``seed``. Pass
    the IDs in a stable order to reproduce a previously saved split.

    Args:
        tissue_ids: iterable of unique, hashable tissue identifiers (not mutated).
        seed: seed for the shuffle.
        test_frac: fraction of tissues held out for testing, in (0, 1).
            ``n_test = round(test_frac * n)``, clamped to [1, n - 1] so that
            both train and test are non-empty.

    Returns:
        dict with keys "seed", "test_frac", "n_total", "fold1", "fold2". Each
        fold is ``{"train": [...], "test": [...]}``.

    Raises:
        ValueError: if ``test_frac`` is not in (0, 1), fewer than 2 tissues are
            given, or ``tissue_ids`` contains duplicates. Duplicates would place
            the same tissue in both train and test.
    """
    tissue_ids = list(tissue_ids)
    n = len(tissue_ids)

    if not 0.0 < test_frac < 1.0:
        raise ValueError(f"test_frac must be in (0, 1), got {test_frac!r}")
    if n < 2:
        raise ValueError(f"need at least 2 tissues to make a train/test split, got {n}")
    if len(set(tissue_ids)) != n:
        raise ValueError("tissue_ids contains duplicates; a tissue-level split requires unique IDs")

    rnd = random.Random(seed)
    rnd.shuffle(tissue_ids)

    # Keep at least one tissue in test AND at least one in train.
    n_test = min(n - 1, max(1, int(round(test_frac * n))))
    n_train = n - n_test

    # Fold 1: last chunk is test
    fold1_train = tissue_ids[:n_train]
    fold1_test = tissue_ids[n_train:]

    # Fold 2: rotate left by the test-chunk size, so the test chunk becomes the
    # first n_test shuffled IDs.
    rotated = tissue_ids[n_test:] + tissue_ids[:n_test]
    fold2_train = rotated[:n_train]
    fold2_test = rotated[n_train:]

    folds = {
        "seed": seed,
        "test_frac": test_frac,
        "n_total": n,
        "fold1": {"train": fold1_train, "test": fold1_test},
        "fold2": {"train": fold2_train, "test": fold2_test},
    }
    return folds

In [None]:
set_seed(SEED)

# ----------------------------
# Build dataset configs exactly like you do
# ----------------------------
# Paths are anchored at project_root (the parent of the notebook's working
# directory, computed in the setup cell). They resolve to the same files as the
# former "../src/..." paths, but keep working if the cwd changes later.
src_dataset_dir = os.path.join(project_root, "src", "dataset")
base_cfg = OmegaConf.load(os.path.join(src_dataset_dir, "configs", "base_config.yaml"))
base_cfg.marker_embedding_dir = os.path.join(src_dataset_dir, "esm2_t30_150M_UR50D")
# NOTE(review): marker_embeddings is not used in the visible cells; kept in case
# later cells rely on it.
marker_embeddings = load_marker_embeddings(base_cfg.marker_embedding_dir)
orion_subset_cfg = OmegaConf.load(os.path.join(src_dataset_dir, "configs", "orion_subset.yaml"))
ds_cfg = OmegaConf.merge(base_cfg, orion_subset_cfg)

ds = build_mm_datasets(ds_cfg)

tissue_ids = ds[0].unimodal_datasets["cycif"].get_tissue_ids()
print(f"Total tissues: {len(tissue_ids)}")

folds = make_two_folds(tissue_ids, seed=SEED, test_frac=TEST_FRAC)

# ----------------------------
# Sanity checks: no overlap, full coverage
# ----------------------------
all_tissues = set(tissue_ids)
f1_tr = set(folds["fold1"]["train"])
f1_te = set(folds["fold1"]["test"])
f2_tr = set(folds["fold2"]["train"])
f2_te = set(folds["fold2"]["test"])

print("\nSanity checks:")
print("Fold1 sizes:", len(f1_tr), len(f1_te), "overlap:", len(f1_tr & f1_te))
print("Fold2 sizes:", len(f2_tr), len(f2_te), "overlap:", len(f2_tr & f2_te))

# optional: how different are the two test sets?
print("Fold1 test ∩ Fold2 test:", len(f1_te & f2_te))

# Fail loudly instead of only printing, so a leaking or incomplete split is
# never written to disk.
assert not (f1_tr & f1_te), "fold1: train and test share tissues"
assert not (f2_tr & f2_te), "fold2: train and test share tissues"
assert f1_tr | f1_te == all_tissues, "fold1 does not cover every tissue"
assert f2_tr | f2_te == all_tissues, "fold2 does not cover every tissue"

# ----------------------------
# Save into checkpoints
# ----------------------------
checkpoint_dir = os.path.join(project_root, "notebooks", "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

out_path = os.path.join(checkpoint_dir, OUT_NAME)
with open(out_path, "wb") as f:
    pkl.dump(folds, f)

print(f"\nSaved folds to: {out_path}")

[32m2025-12-17 21:16:47.997[0m | [34m[1mDEBUG   [0m | [36msrc.dataset.datasets.mm_base[0m:[36m__init__[0m:[36m44[0m - [34m[1mLoading dataset orion from /data/virtues_orion_dataset/virtues_example/orion_subset[0m
[32m2025-12-17 21:16:48.001[0m | [1mINFO    [0m | [36msrc.dataset.datasets.multiplex_base[0m:[36m__init__[0m:[36m30[0m - [1mMultiplex Normalization metadata: QuantileMultiplexNormalizeMetadata(normalizer_name='q_99', rnd_crop_folder_name='random_crops_256_no_log', channel_file_name='channels', mean_name='mean', std_name='std', quantile_path='quantiles/q99.csv')[0m
[32m2025-12-17 21:16:48.036[0m | [1mINFO    [0m | [36msrc.dataset.datasets.multiplex_base[0m:[36m__init__[0m:[36m40[0m - [1mCrop folder /data/virtues_orion_dataset/virtues_example/orion_subset/cycif/random_crops_256_no_log exists[0m
  self.channels_per_image.fillna(1, inplace=True)
  self.channels_per_image = self.channels_per_image.replace(1, True)
[32m2025-12-17 21:16:48.053[0

Total tissues: 35

Sanity checks:
Fold1 sizes: 28 7 overlap: 0
Fold2 sizes: 28 7 overlap: 0
Fold1 test ∩ Fold2 test: 0

Saved folds to: /data/code/jon/project-2-gradient_tri_scent/notebooks/checkpoints/two_folds_seed42_tissue_split.pkl
