In [1]:
import os
from pathlib import Path
# Work from the directory one level above the notebook's starting working
# directory (the project root, from which `config_utils` and `utils` are imported
# below). The target is captured only once per kernel session: blindly re-running
# `os.chdir(Path.cwd().parent)` would climb one more level on every execution of
# this cell and break the imports / relative config paths that follow.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parent
os.chdir(PROJECT_ROOT)
print(os.getcwd())  # sanity check: should print the project root

# pip install xflow-py
from xflow import ConfigManager, SqlProvider, PyTorchPipeline, show_model_info
from xflow.data import build_transforms_from_config
from xflow.utils import load_validated_config, save_image
import xflow.extensions.physics

import torch
import os
import tarfile
from datetime import datetime  
from config_utils import load_config
from utils import *

# Experiment selection. Alternative experiment names noted by the author:
# TM, SHL_DNN, U_Net, Pix2pix, ERN, CAE, SwinT, CAE_syth.
experiment_name = "CLEAR_make_dataset"

# Unique run name "<experiment>-<YYYYmmddHHMMSS>", handed to `load_config` as the
# experiment name (presumably used to name the run's output directory — TODO
# confirm in config_utils; this cell itself creates nothing on disk).
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
folder_name = "-".join((experiment_name, timestamp))

config_manager = ConfigManager(
    load_config(experiment_name + ".yaml", experiment_name=folder_name)
)
config = config_manager.get()
# Hand the files listed under `extra_files` in the YAML to the config manager
# (presumably so they are saved alongside the run — verify in xflow).
config_manager.add_files(config["extra_files"])
experiment_output_dir = config["paths"]["output"]
        
def make_dataset(provider, transforms, dataset_ops=None):
    """Materialise ``provider`` into an in-memory dataset via ``PyTorchPipeline``.

    Args:
        provider: data provider (an ``SqlProvider`` or one of its splits/merges)
            supplying the items to load.
        transforms: transform pipeline built with ``build_transforms_from_config``.
        dataset_ops: operations forwarded to ``PyTorchPipeline.to_memory_dataset``.
            When ``None`` (the default) they are read from the notebook's global
            ``config["data"]["dataset_ops"]`` at call time, which preserves the
            original behavior; pass them explicitly to avoid relying on that global.

    Returns:
        ``(dataset, sample_count)`` where ``sample_count`` is the pipeline's
        ``in_memory_sample_count`` read after materialisation.
    """
    pipeline = PyTorchPipeline(provider, transforms)
    if dataset_ops is None:
        dataset_ops = config["data"]["dataset_ops"]
    dataset = pipeline.to_memory_dataset(dataset_ops)
    return dataset, pipeline.in_memory_sample_count

/Users/andrewxu/Documents/GitHub/fiber-image-reconstruction
[config_utils] Using machine profile: mac-andrewxu


In [None]:
# ====================
# Prepare Dataset (Wednesday Chromox)
# ====================
# Loads batches 10-12 of the Wednesday Chromox session and splits them into
# train / val / test providers using the seeded ratios from the config.

test_dir = config["paths"]["chromox_01"]
# Create SqlProvider to query the database
db_path = f"{test_dir}/db/dataset_meta.db"
query = """
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (10, 11, 12)
--LIMIT 20
"""
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
train_provider, evaluation_provider = realbeam_provider.split(ratio=config["data"]["train_val_split"], seed=config["seed"])
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# `image_path` values are relative to `test_dir`, so an `add_parent_dir` transform
# is prepended (the same transforms are used for all three splits).
# Build a NEW list rather than calling `.insert(0, ...)` on the shared config: the
# in-place insert stacked one more `add_parent_dir` on every re-run and leaked into
# the other dataset cells, prefixing image paths more than once. Any copy already
# left in the config by an earlier run is filtered out for the same reason.
base_torch_transforms = [
    t for t in config["data"]["transforms"]["torch"]
    if not (isinstance(t, dict) and t.get("name") == "add_parent_dir")
]
transforms = build_transforms_from_config(
    [{"name": "add_parent_dir", "params": {"parent_dir": test_dir}}] + base_torch_transforms
)
train_dataset, n1 = make_dataset(train_provider, transforms)
val_dataset, n2 = make_dataset(val_provider, transforms)
test_dataset, n3 = make_dataset(test_provider, transforms)

In [None]:
# ====================
# Prepare Dataset (Friday + Saturday Chromox)
# ====================
# Samples 20 random images from each of the two sessions (`chromox_02`,
# `chromox_03`) and merges them into one provider. The parent directory is
# prepended inside the SQL because the sessions live in different folders, so a
# single `add_parent_dir` transform could not cover both.
# NOTE: SQLite's RANDOM() cannot be seeded, so the sample changes on every run.
realbeam_provider = None
for session_key in ("chromox_02", "chromox_03"):
    test_dir = config["paths"][session_key]
    db_path = f"{test_dir}/db/dataset_meta.db"
    # Double any single quotes so the directory is a valid SQL string literal.
    sql_test_dir = str(test_dir).replace("'", "''")
    query = f"""
SELECT
    '{sql_test_dir}' || '/' || image_path AS image_path
FROM mmf_dataset_metadata
ORDER BY RANDOM()
LIMIT 20
"""
    session_provider = SqlProvider(
        sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
    )
    if realbeam_provider is None:
        realbeam_provider = session_provider
    else:
        realbeam_provider = realbeam_provider.merge(session_provider)

# Paths already include their session directory; drop any `add_parent_dir` that
# another cell may have inserted into the shared transform config, otherwise the
# directory would be prepended twice.
base_torch_transforms = [
    t for t in config["data"]["transforms"]["torch"]
    if not (isinstance(t, dict) and t.get("name") == "add_parent_dir")
]
transforms = build_transforms_from_config(base_torch_transforms)
# Keep provider and dataset under separate names, and do not overwrite `n1`
# (the train-split count reported by the summary cell).
realbeam_dataset, n_realbeam = make_dataset(realbeam_provider, transforms)

In [None]:
# ====================
# Prepare Dataset (Wednesday DMD)
# ====================
# Samples 20 random images from the Wednesday DMD session as a single (unsplit)
# dataset. The SQL prepends `test_dir` to each `image_path`.
# NOTE: SQLite's RANDOM() cannot be seeded, so the sample changes on every run.
test_dir = config["paths"]["dmd_01"]
db_path = f"{test_dir}/db/dataset_meta.db"
# Double any single quotes so the directory is a valid SQL string literal.
sql_test_dir = str(test_dir).replace("'", "''")
query = f"""
SELECT
    '{sql_test_dir}' || '/' || image_path AS image_path
FROM mmf_dataset_metadata
ORDER BY RANDOM()
LIMIT 20
"""
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)

# Paths already include `test_dir`; drop any `add_parent_dir` that another cell
# may have inserted into the shared transform config, otherwise the directory
# would be prepended twice.
base_torch_transforms = [
    t for t in config["data"]["transforms"]["torch"]
    if not (isinstance(t, dict) and t.get("name") == "add_parent_dir")
]
transforms = build_transforms_from_config(base_torch_transforms)
# Keep provider and dataset under separate names, and do not overwrite `n1`
# (the train-split count reported by the summary cell).
realbeam_dataset, n_realbeam = make_dataset(realbeam_provider, transforms)

In [None]:
# ====================
# Prepare Dataset (YAG screen)
# ====================
# Loads batches 1 and 7 of the YAG-screen test set and splits them into
# train / val / test providers using the seeded ratios from the config.

test_dir = config["paths"]["test_set"]
# Create SqlProvider to query the database
db_path = f"{test_dir}/db/dataset_meta.db"
query = """
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (1, 7)
--LIMIT 20
"""
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
train_provider, evaluation_provider = realbeam_provider.split(ratio=config["data"]["train_val_split"], seed=config["seed"])
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# `image_path` values are relative to `test_dir`, so an `add_parent_dir` transform
# is prepended (the same transforms are used for all three splits).
# Build a NEW list rather than calling `.insert(0, ...)` on the shared config: the
# in-place insert stacked one more `add_parent_dir` on every re-run and leaked into
# the other dataset cells, prefixing image paths more than once. Any copy already
# left in the config by an earlier run is filtered out for the same reason.
base_torch_transforms = [
    t for t in config["data"]["transforms"]["torch"]
    if not (isinstance(t, dict) and t.get("name") == "add_parent_dir")
]
transforms = build_transforms_from_config(
    [{"name": "add_parent_dir", "params": {"parent_dir": test_dir}}] + base_torch_transforms
)
train_dataset, n1 = make_dataset(train_provider, transforms)
val_dataset, n2 = make_dataset(val_provider, transforms)
test_dataset, n3 = make_dataset(test_provider, transforms)

In [None]:
# Sanity check for the train / val / test splits: provider sizes, in-memory sample
# counts returned by make_dataset, and the lengths of the resulting datasets.
print("Total samples in providers: ", *map(len, (train_provider, val_provider, test_provider)))
print("Total samples in datasets:", n1, n2, n3)
print("Batch: ", *map(len, (train_dataset, val_dataset, test_dataset)))