# Configuration

In [None]:
import os
from pathlib import Path
os.chdir(Path.cwd().parent)   # go one level up
print(os.getcwd())            # check

# pip install xflow-py
from xflow import ConfigManager, SqlProvider, PyTorchPipeline, show_model_info
from xflow.data import build_transforms_from_config
from xflow.utils import load_validated_config, save_image
import xflow.extensions.physics

import torch
import os
import tarfile
from datetime import datetime  
from config_utils import load_config
from utils import *


# Create experiment output directory  (timestamped)
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")  

experiment_name = "CAE_validate_clear"  # TM, SHL_DNN, U_Net, Pix2pix, ERN, CAE, SwinT, CAE_syth
model_name = "CAE"
folder_name = f"{experiment_name}-{timestamp}"  
config_manager = ConfigManager(load_config(f"{experiment_name}.yaml", 
                                           experiment_name=folder_name))
config = config_manager.get()
config_manager.add_files(config["extra_files"])

experiment_output_dir = config["paths"]["output"]
os.makedirs(experiment_output_dir, exist_ok=True)


# New structure, read the database table first, get files from it.
# Extract tar file if needed
if config['file_extract']:
    dataset_tar_file = os.path.join(config["paths"]["dataset"], config["data"]["dataset"])
    dataset_base_dir = os.path.dirname(dataset_tar_file)
    dataset_name = os.path.splitext(config["data"]["dataset"])[0]  # Remove .tar extension
    dataset_extracted_dir = os.path.join(dataset_base_dir, dataset_name)

    # Unzip tar file if not already extracted
    if not os.path.exists(dataset_extracted_dir):
        print(f"Extracting {dataset_tar_file}...")
        with tarfile.open(dataset_tar_file, 'r') as tar:
            tar.extractall(path=dataset_base_dir)
        print(f"Extracted to {dataset_extracted_dir}")
    else:
        print(f"Dataset already extracted at {dataset_extracted_dir}")
        

def make_dataset(provider, transforms):
    pipeline = PyTorchPipeline(provider, transforms)
    dataset = pipeline.to_memory_dataset(config["data"]["dataset_ops"])
    return dataset, pipeline.in_memory_sample_count

/Users/andrewxu/Documents/GitHub/fiber-image-reconstruction
[config_utils] Using machine profile: mac-andrewxu
[config_utils] Using machine profile: mac-andrewxu


# Data preparation

In [None]:
# ==================== 
# Prepare Dataset (Laser scan + YAG screen )
# ====================
train_dir = config["paths"]["training_set"]
# Create SqlProvider to query the database
db_path = f"{train_dir}/db/dataset_meta.db"
query = """ 
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (15)
--LIMIT 300
"""
train_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)

test_dir = config["paths"]["test_set"]
# Create SqlProvider to query the database
db_path = f"{test_dir}/db/dataset_meta.db"
query = """
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (1, 7)
--LIMIT 20
"""
evaluation_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# Swap traing dataset and evaluation dataset
train_dir, test_dir = test_dir, train_dir
train_provider, evaluation_provider = evaluation_provider, train_provider
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# For train dataset
config["data"]["transforms"]["torch"].insert(0, {
    "name": "add_parent_dir",
    "params": {
        "parent_dir": train_dir
    }
})
transforms = build_transforms_from_config(config["data"]["transforms"]["torch"])
train_datase, n1 = make_dataset(train_provider, transforms)

# For test datasets
config["data"]["transforms"]["torch"][0]["params"]["parent_dir"] = test_dir
transforms = build_transforms_from_config(config["data"]["transforms"]["torch"])
val_dataset, n2 = make_dataset(val_provider, transforms)
test_dataset, n3 = make_dataset(test_provider, transforms)

In [None]:
# ==================== 
# Prepare Dataset (YAG screen)
# ====================

test_dir = config["paths"]["test_set"]
# Create SqlProvider to query the database
db_path = f"{test_dir}/db/dataset_meta.db"
query = """
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (1, 7)
--LIMIT 20
"""
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
train_provider, evaluation_provider = realbeam_provider.split(ratio=config["data"]["train_val_split"], seed=config["seed"])
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# For train dataset
config["data"]["transforms"]["torch"].insert(0, {
    "name": "add_parent_dir",
    "params": {
        "parent_dir": test_dir
    }
})
transforms = build_transforms_from_config(config["data"]["transforms"]["torch"])
train_dataset, n1 = make_dataset(train_provider, transforms)
val_dataset, n2 = make_dataset(val_provider, transforms)
test_dataset, n3 = make_dataset(test_provider, transforms)

In [2]:
# ==================== 
# Prepare Dataset (Wednesday Chromox)
# ====================

test_dir = config["paths"]["chromox_01"]
# Create SqlProvider to query the database
db_path = f"{test_dir}/db/dataset_meta.db"
query = """
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (10, 11, 12)
--LIMIT 20
"""
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
train_provider, evaluation_provider = realbeam_provider.split(ratio=config["data"]["train_val_split"], seed=config["seed"])
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# For train dataset
config["data"]["transforms"]["torch"].insert(0, {
    "name": "add_parent_dir",
    "params": {
        "parent_dir": test_dir
    }
})
transforms = build_transforms_from_config(config["data"]["transforms"]["torch"])
train_dataset, n1 = make_dataset(train_provider, transforms)
val_dataset, n2 = make_dataset(val_provider, transforms)
test_dataset, n3 = make_dataset(test_provider, transforms)

Loading data into memory:   0%|          | 0/263 [00:00<?, ?it/s]

Failed to preprocess item: Centroid (143.6953582763672, 181.52749633789062) outside rectangle bounds
Failed to preprocess item: Centroid (140.02857971191406, 92.65306091308594) outside rectangle bounds
Failed to preprocess item: Centroid (115.63353729248047, 185.60708618164062) outside rectangle bounds
Failed to preprocess item: Centroid (129.43499755859375, 89.33609008789062) outside rectangle bounds
Failed to preprocess item: Centroid (140.02857971191406, 92.65306091308594) outside rectangle bounds
Failed to preprocess item: Centroid (115.63353729248047, 185.60708618164062) outside rectangle bounds
Failed to preprocess item: Centroid (129.43499755859375, 89.33609008789062) outside rectangle bounds
Failed to preprocess item: Centroid (142.36422729492188, 87.2079849243164) outside rectangle bounds
Failed to preprocess item: Centroid (142.36422729492188, 87.2079849243164) outside rectangle bounds
Failed to preprocess item: Centroid (143.44810485839844, 85.41702270507812) outside rectang

In [None]:
# ==================== 
# Prepare Dataset (Friday + Saturday Chromox)
# ====================

test_dir = config["paths"][""]
# Create SqlProvider to query the database
db_path = f"{test_dir}/db/dataset_meta.db"
query = """
SELECT 
    image_path
FROM mmf_dataset_metadata 
WHERE batch IN (10, 11, 12)
--LIMIT 20
"""
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
train_provider, evaluation_provider = realbeam_provider.split(ratio=config["data"]["train_val_split"], seed=config["seed"])
val_provider, test_provider = evaluation_provider.split(ratio=config["data"]["val_test_split"], seed=config["seed"])

# For train dataset
config["data"]["transforms"]["torch"].insert(0, {
    "name": "add_parent_dir",
    "params": {
        "parent_dir": test_dir
    }
})
transforms = build_transforms_from_config(config["data"]["transforms"]["torch"])
train_dataset, n1 = make_dataset(train_provider, transforms)
val_dataset, n2 = make_dataset(val_provider, transforms)
test_dataset, n3 = make_dataset(test_provider, transforms)

In [3]:
print("Total samples in providers: ",len(train_provider),len(val_provider),len(test_provider))
print("Total samples in datasets:", n1, n2, n3)
print("Batch: ",len(train_dataset),len(val_dataset),len(test_dataset))

# save a sample from dataset for debugging
if model_name in REGRESSION:
    for left_parts, params, right_parts in test_dataset:
        print(f"Batch shapes: {left_parts.shape}, {right_parts.shape}")
        save_image(left_parts[0], config["paths"]["output"] + "/input.png")
        save_image(right_parts[0], config["paths"]["output"] + "/output.png")
        break
else:
    for index, sample in enumerate(test_dataset):  # test_dataset
        left_parts, right_parts = sample
        # batch will be a tuple: (right_halves, left_halves) due to split_width
        print(f"Batch shapes: {left_parts.shape}, {right_parts.shape}")
        if model_name in SAMPLE_FLATTENED:
            save_image(left_parts[0].reshape(config['data']['input_shape']), config["paths"]["output"] + f"/input_{index}.png")
            save_image(right_parts[0].reshape(config['data']['output_shape']), config["paths"]["output"] + f"/output_{index}.png")
        else:
            save_image(left_parts[0], config["paths"]["output"] + f"/input_{index}.png")
            save_image(right_parts[0], config["paths"]["output"] + f"/output_{index}.png")
        break

Total samples in providers:  2101 263 263
Total samples in datasets: 1939 245 248
Batch:  61 8 8
Batch shapes: torch.Size([32, 1, 256, 256]), torch.Size([32, 1, 256, 256])


# Construct Model

In [4]:
# ==================== 
# Construct Model
# ====================
if model_name == "CAE":
    from models.CAE import Autoencoder2D
    model = Autoencoder2D(
        in_channels=int(config['model']["in_channels"]),
        encoder=config['model']["encoder"],
        decoder=config['model']["decoder"],
        kernel_size=int(config['model']["kernel_size"]),
        apply_batchnorm=config['model']["apply_batchnorm"],
        apply_dropout=config['model']["apply_dropout"],
        final_activation=str(config['model']["final_activation"]),
    )
elif model_name == "TM":
    from models.TM import TransmissionMatrix
    model = TransmissionMatrix(
        input_height = config["data"]["input_shape"][0],
        input_width = config["data"]["input_shape"][1],
        output_height = config["data"]["output_shape"][0],
        output_width = config["data"]["output_shape"][1],
        initialization = "xavier",
    )
elif model_name == "SHL_DNN":
    from models.SHL_DNN import SHLNeuralNetwork
    model = SHLNeuralNetwork(
        input_size=config['data']['input_shape'][0] * config['data']['input_shape'][1],
        hidden_size=config['model']['hidden_size'], 
        output_size=config['data']['output_shape'][0] * config['data']['output_shape'][1],
        dropout_rate=config['model']['dropout_rate'],
    )
elif model_name == "U_Net":
    from models.U_Net import UNet
    model = UNet(
        in_channels=config["model"]["in_channels"],
        encoder=config["model"]["encoder"],
        decoder=config["model"]["decoder"],
        kernel_size=config["model"]["kernel_size"],
        apply_batchnorm=config["model"]["apply_batchnorm"],
        apply_dropout=config["model"]["apply_dropout"],
        out_channels=config["model"]["out_channels"],
        final_activation=config["model"]["final_activation"],
    )
elif model_name == "SwinT":
    from models.SwinT import SwinUNet, ReconLoss
    model = SwinUNet(
        img_size=config['model']['img_size'],
        in_chans=config['model']['in_chans'],
        out_chans=config['model']['out_chans'],
        embed_dim=config['model']['embed_dim'],
        depths=config['model']['depths'],
        num_heads=config['model']['num_heads'],
        window_size=config['model']['window_size'],
        patch_size=config['model']['patch_size'],
    )
elif model_name == "Pix2pix":
    from models.Pix2pix import Generator, Discriminator, Pix2PixLosses
    G = Generator(channels=config["model"]["channels"])
    D = Discriminator(channels=config["model"]["channels"])
    losses = Pix2PixLosses(lambda_l1=config["model"]["lambda_l1"])
    opt_g = torch.optim.Adam(G.parameters(), lr=config["training"]["learning_rate"], betas=config["training"]["betas"])
    opt_d = torch.optim.Adam(D.parameters(), lr=config["training"]["learning_rate"], betas=config["training"]["betas"])
elif model_name == "ERN":
    from models.ERN import EncoderRegressor
    model = EncoderRegressor(
            in_channels=config['model']['in_channels'],
            kernel_size=config['model']['kernel_size'],
            encoder=config['model']['encoder'],
            decoder=config['model']['decoder'],
            final_activation=config['model']['final_activation'],  
        )
elif model_name == "CAE_syth":
    from models.CAE import Autoencoder2D
    model = Autoencoder2D(
        in_channels=int(config['model']["in_channels"]),
        encoder=config['model']["encoder"],
        decoder=config['model']["decoder"],
        kernel_size=int(config['model']["kernel_size"]),
        apply_batchnorm=config['model']["apply_batchnorm"],
        apply_dropout=config['model']["apply_dropout"],
        final_activation=str(config['model']["final_activation"]),
    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if model_name == "Pix2pix":
    G = G.to(device)
    D = D.to(device)
    show_model_info(G)
    show_model_info(D)
elif model_name == "SwinT":
    from torch.optim.lr_scheduler import LambdaLR
    total_steps = config['training']['epochs'] * len(train_dataset)
    warmup_steps = int(config['training']['warmup_ratio'] * total_steps)
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    model = model.to(device)
    criterion = ReconLoss(w_l1=config['training']['w_l1'], w_ssim=config['training']['w_ssim']) # Loss: L1 + 0.3*SSIM

    # Optimizer: AdamW with recommended params
    base_lr = 4e-4 if config['training']['batch_size'] >= 64 else 2e-4
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=base_lr, betas=config['training']['betas'],
        eps=config['training']['eps'], weight_decay=config['training']['weight_decay']
    )
    scheduler = LambdaLR(optimizer, lr_lambda)
    show_model_info(model)
else:
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['training']['learning_rate'])
    show_model_info(model)

Detected framework: PyTorch
Framework:           PyTorch
Model:               Autoencoder2D
Device / dtype:      unavailable / N/A
Parameters:          18,620,544 total
                     18,620,544 trainable
                     0 non-trainable
Size:                71.06 MB
Sub-modules:         67


In [5]:
# ==================== 
# Training
# ====================
from functools import partial
import torch.nn as nn

from xflow import TorchTrainer, TorchGANTrainer
from xflow.trainers import build_callbacks_from_config
from xflow.extensions.physics.beam import extract_beam_parameters

# 1) loss/optimizer
# criterion = torch.nn.MSELoss()

# loss function with weighting on high-intensity regions
class WeightedMSELoss(nn.Module):
    def __init__(self, alpha=5.0, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred, gt):
        weight = 1.0 + self.alpha * (gt.clamp(min=0.) ** self.beta)
        return (weight * (pred - gt) ** 2).mean()

criterion = WeightedMSELoss(alpha=5.0, beta=0.5)

# 2) callbacks (unchanged) + any custom wiring
callbacks = build_callbacks_from_config(
    config=config["callbacks"],
    framework=config["framework"],  
) # keep dataset closure for last callback, sequence hardcoded
callbacks[-1].set_dataset(test_dataset)

# Extract beam parameters closure (return as dict)
if model_name in SAMPLE_FLATTENED:
    extract_beam_parameters_dict = partial(extract_beam_parameters_flat, as_array=False)
    beam_param_metric = make_beam_param_metric(extract_beam_parameters_dict)
elif model_name in REGRESSION:   # e.g., "ERN"
    beam_param_metric = make_param_metric()
else:
    extract_beam_parameters_dict = partial(extract_beam_parameters, as_array=False)
    beam_param_metric = make_beam_param_metric(extract_beam_parameters_dict)

# 3) run training
if model_name in GAN:
    trainer = TorchGANTrainer(
        generator=G,
        discriminator=D,
        optimizer_g=opt_g,
        optimizer_d=opt_d,
        losses=losses,
        device=device,
        callbacks=callbacks,
        output_dir=config["paths"]["output"],
        data_pipeline=train_dataset,
        val_metrics=[beam_param_metric],
    )
else:
    trainer = TorchTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        callbacks=callbacks,
        output_dir=config["paths"]["output"],
        data_pipeline=train_dataset,
        val_metrics=[beam_param_metric],
        scheduler= scheduler if model_name == "SwinT" else None, 
        scheduler_step_per_batch=True,
    )

history = trainer.fit(
    train_loader=train_dataset, 
    val_loader=val_dataset,
    epochs=config['training']['epochs'],
)

# 4) save results
trainer.save_history(f"{config['paths']['output']}/history.json")
trainer.save_model(config["paths"]["output"])  # uses model.save_model(...) if available
config_manager.save(output_dir=config["paths"]["output"], config_filename=config["name"])

print("Training ALL complete.")  

Starting Training
Total epochs: 100

Epoch 1/100 - 61 batches
input image max pixel: 0.1098, ground truth image max pixel: 0.2063, reconstructed image max pixel: 0.5000
input image max pixel: 0.1098, ground truth image max pixel: 0.2063, reconstructed image max pixel: 0.5000
Epoch 1 completed in 561.89s - train_loss: 0.0868 - val_loss: 0.0198

Epoch 2/100 - 61 batches

Epoch 1 completed in 561.89s - train_loss: 0.0868 - val_loss: 0.0198

Epoch 2/100 - 61 batches
input image max pixel: 0.4579, ground truth image max pixel: 0.1765, reconstructed image max pixel: 0.4237
input image max pixel: 0.4579, ground truth image max pixel: 0.1765, reconstructed image max pixel: 0.4237
Epoch 2 completed in 557.56s - train_loss: 0.0103 - val_loss: 0.0068

Epoch 3/100 - 61 batches

Epoch 2 completed in 557.56s - train_loss: 0.0103 - val_loss: 0.0068

Epoch 3/100 - 61 batches
input image max pixel: 0.1373, ground truth image max pixel: 0.8568, reconstructed image max pixel: 0.3343
input image max pixel

# Data Analysis

In [7]:
# ...existing code...

from pathlib import Path

save_dir = Path("/Users/andrewxu/Desktop/untitled folder 2")
save_dir.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

def _save_triplet(idx, inp_tensor, gt_tensor, pred_tensor):
    torch.save(
        {"input": inp_tensor, "ground_truth": gt_tensor, "prediction": pred_tensor},
        save_dir / f"sample_{idx:05d}.pt",
    )
    save_image(inp_tensor[0], str(save_dir / f"input_{idx:05d}.png"))
    save_image(gt_tensor[0], str(save_dir / f"ground_truth_{idx:05d}.png"))
    save_image(pred_tensor[0], str(save_dir / f"prediction_{idx:05d}.png"))

with torch.no_grad():
    if model_name in REGRESSION:
        for idx, (left_parts, params, right_parts) in enumerate(test_dataset):
            inputs = left_parts.to(device)
            targets = right_parts.to(device)
            params = params.to(device) if torch.is_tensor(params) else params
            preds = model(inputs, params).cpu()

            _save_triplet(idx, inputs.cpu(), targets.cpu(), preds)
    else:
        for idx, (left_parts, right_parts) in enumerate(test_dataset):
            inputs = left_parts.to(device)
            targets = right_parts.to(device)
            preds = model(inputs).cpu()

            if model_name in SAMPLE_FLATTENED:
                inp = inputs.cpu()
                tgt = targets.cpu()
                pred = preds
                inp[0] = inp[0].reshape(*config["data"]["input_shape"])
                tgt[0] = tgt[0].reshape(*config["data"]["output_shape"])
                pred[0] = pred[0].reshape(*config["data"]["output_shape"])
                _save_triplet(idx, inp, tgt, pred)
            else:
                _save_triplet(idx, inputs.cpu(), targets.cpu(), preds)

print(f"Saved predictions to {save_dir}")
# ...existing code...

Saved predictions to /Users/andrewxu/Desktop/untitled folder 2


In [8]:
len(test_dataset)

8