In [6]:
# pip install xflow-py
from xflow import ConfigManager, FileProvider, PyTorchPipeline, show_model_info
from xflow.data import build_transforms_from_config
from xflow.utils import load_validated_config, save_image
import xflow.extensions.physics

import torch
import os

# Work from the parent of the notebook's starting directory.
# NOTE(review): presumably this is the project root that holds config_utils.py,
# utils.py and the relative ./results paths — confirm.
#
# A bare os.chdir('..') is relative to the *current* directory, so re-running
# this cell in the same kernel would climb one more level each time. The
# starting directory is therefore remembered on first execution and the target
# is always derived from it, which makes the cell safe to re-run.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))
from datetime import datetime  
from config_utils import load_config
from utils import *


# ==================== 
# Configuration
# ==================== 

# Experiment selection and run timestamp.
# NOTE(review): `timestamp` is not consumed anywhere in this cell; the cell
# output shows timestamp-suffixed output paths, which are presumably generated
# inside load_config — confirm whether this variable is still needed.
experiment_name = "CAE"  # TM, SHL_DNN, U_Net, Pix2pix, ERN, CAE, SwinT, CAE_syth
_started_at = datetime.now()
timestamp = _started_at.strftime("%Y%m%d%H%M%S")

# Load "<experiment_name>.yaml" through the project loader, wrap the result in
# a ConfigManager, and pull out the config object used by the following cells.
_raw_config = load_config(f"{experiment_name}.yaml")
config_manager = ConfigManager(_raw_config)
config = config_manager.get()

# Bare trailing expression: show the resolved config as the cell output.
config

[config_utils] Using machine profile: mac-andrewxu


{'paths': {'project_root': '.',
  'datasets': {'mmf': '/Users/andrewxu/Documents/DataHub/local_images/MMF',
   'syns': '/Users/andrewxu/Documents/DataHub/datasets/2024-08-15/dataset/1'},
  'output_root': './results/CAE-20251023120443',
  'output': './results/CAE-20251023120443',
  'dataset': '/Users/andrewxu/Documents/DataHub/local_images/MMF'},
 'seed': 500,
 'name': 'CAE.yaml',
 'framework': 'pytorch',
 'model': {'in_channels': 1,
  'out_channels': 1,
  'kernel_size': 4,
  'encoder': [64, 128, 128, 256, 512, 512],
  'decoder': [512, 512, 256, 128, 128, 64],
  'apply_batchnorm': [0, 1, 1, 1, 1, 1],
  'apply_dropout': [1, 1, 0, 0, 0, 0],
  'final_activation': 'sigmoid'},
 'training': {'epochs': 100, 'learning_rate': 0.0001, 'batch_size': 32},
 'callbacks': [{'name': 'torch_batch_progress_bar',
   'params': {'only_keys': ['train_loss', 'val_loss']}},
  {'name': 'torch_early_stopping',
   'params': {'monitor': 'val_loss', 'patience': 20}},
  {'name': 'torch_image_reconstruction_callback'

In [7]:
# ==================== 
# Prepare Dataset
# ====================
# Cell-local shortcuts for the two config sections read repeatedly below.
_dataset_root = config["paths"]["dataset"]
_data_cfg = config["data"]

# Training / evaluation folders live beneath the configured dataset root.
training_folder = os.path.join(_dataset_root, _data_cfg["training_set"])
evaluation_folder = os.path.join(_dataset_root, _data_cfg["evaluation_set"])

# Both providers are subsampled with the same fraction and the same seed.
train_provider = FileProvider(training_folder).subsample(
    fraction=_data_cfg["subsample_fraction"],
    seed=config["seed"],
)
evaluation_provider = FileProvider(evaluation_folder).subsample(
    fraction=_data_cfg["subsample_fraction"],
    seed=config["seed"],
)

# The evaluation provider is split into the validation and test providers.
val_provider, test_provider = evaluation_provider.split(
    ratio=_data_cfg["val_test_split"],
    seed=config["seed"],
)

# Transform chain built from the "torch" entry of the data config.
transforms = build_transforms_from_config(_data_cfg["transforms"]["torch"])
def make_dataset(provider):
    """Build an in-memory dataset for ``provider``.

    Wraps ``provider`` together with the notebook-level ``transforms`` in a
    ``PyTorchPipeline`` and converts it via ``to_memory_dataset``, passing
    ``config["data"]["dataset_ops"]``.

    Args:
        provider: Data provider handed to ``PyTorchPipeline``.

    Returns:
        Whatever ``PyTorchPipeline.to_memory_dataset`` returns.
    """
    pipeline = PyTorchPipeline(provider, transforms)
    dataset_ops = config["data"]["dataset_ops"]
    return pipeline.to_memory_dataset(dataset_ops)

# Materialise one dataset per split, in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    make_dataset(split_provider)
    for split_provider in (train_provider, val_provider, test_provider)
)

# Report sizes: len() of each provider, then len() of each dataset.
print("Samples: ", *map(len, (train_provider, val_provider, test_provider)))
print("Batch: ", *map(len, (train_dataset, val_dataset, test_dataset)))

Samples:  5200 549 549
Batch:  163 18 18
Batch shapes: torch.Size([32, 1, 256, 256]), torch.Size([32, 1, 256, 256])
