Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
.env
__pycache__

# Development
/sandbox
/staging
/data
/original
og

# Non-Docker Training Outputs
/src/trainer/outputs
/src/trainer/mlruns

# Temp uv setup
.python-version
pyproject.toml
uv.lock
122 changes: 122 additions & 0 deletions configs/trainer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
data:
manifest_dir: "../../../ad_data/manifests"
dataset_root: "../../../ad_data/data/dataset"
extra_val_file: "rruff.jsonl"
auto_generate_manifests: true
train_ratio: 0.8
val_ratio: 0.1
test_ratio: 0.1
seed: 42

loader:
# --- DataLoader ---
batch_size: 64 # match OG run (64 per process)
num_workers: 8
pin_memory: true
persistent_workers: true
prefetch_factor: 2
train_file: "train.jsonl"
val_file: "val.jsonl"
test_file: "test.jsonl"

preprocessing:
validate_paths: false
extract_labels: true
allow_pickle: true
labels_key_map:
x: "dp"
cs: "cs"
sg: "sg"
lattice_params: null
lp_a: "_cell_length_a"
lp_b: "_cell_length_b"
lp_c: "_cell_length_c"
lp_alpha: "_cell_angle_alpha"
lp_beta: "_cell_angle_beta"
lp_gamma: "_cell_angle_gamma"
dtype: "float32"
mmap_mode: null
floor_at_zero: true
normalize_log1p: false
shift_labels: true

augmentation:
noise_poisson_range: [1.0, 100.0]
noise_gaussian_range: [0.001, 0.1]
standardize_to: [0.0, 100.0]

model:
type: "multiscale"

backbone:
dim_in: 8192
dims: [80, 80, 80]
kernel_sizes: [100, 50, 25]
strides: [5, 5, 5]
dropout_rate: 0.3
layer_scale_init_value: 0.0
drop_path_rate: 0.3
ramped_dropout_rate: false
block_type: "convnext"
pooling_type: "average"
final_pool: true
use_batchnorm: false
activation: "leaky_relu"
output_type: "flatten"

heads:
head_dropout: 0.5
cs_hidden: [2300, 1150]
sg_hidden: [2300, 1150]
lp_hidden: [512, 256]

tasks:
num_cs_classes: 7
num_sg_classes: 230
num_lp_outputs: 6

lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
bound_lp_with_sigmoid: true

loss:
lambda_cs: 1.0
lambda_sg: 1.0
lambda_lp: 1.0

gemd_mu: 0.0
gemd_distance_matrix_path: null

optimizer:
lr: 0.0002
weight_decay: 0.01
use_adamw: true
gradient_clip_val: 1.0
gradient_clip_algorithm: "norm"

trainer:
default_root_dir: "outputs/convnext_paper"
max_epochs: 100
accumulate_grad_batches: 1
precision: "32" # match OG (AMP disabled)
accelerator: "gpu"
devices: 1
log_every_n_steps: 200
deterministic: false
benchmark: true

logging:
logger: "mlflow"
csv_logger_name: "model_logs_convnext_paper"
mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
mlflow_tracking_uri: null
mlflow_run_name: "ConvNeXt_Paper_Run"

checkpointing:
monitor: "val/loss"
mode: "min"
save_top_k: 1
every_n_epochs: 1

resume_from: null
test_after_train: true
39 changes: 39 additions & 0 deletions src/trainer/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Dataset package for training.

Assumptions:
- Import paths use the local 'dataset' package.
- Manifests are JSON Lines and may include an optional first-line meta header:
{"__meta__": {"version": 1, "base_dir": "<path to dataset root>"}}
When present, non-absolute file paths in records are resolved relative to base_dir.
base_dir itself may be relative to the manifest file's directory. This makes manifests
independent of the current working directory.
- Legacy manifests without the meta header remain supported; their file paths are used as-is.

Exports:
- NpyManifestDataset: Map-style dataset loading .npy files listed in JSONL manifests.
- NpyDataModule: LightningDataModule wiring datasets and DataLoaders.
- generate_manifests: Utility to create train/val/test manifests split by material ID.
- ManifestStats: Summary dataclass for manifest generation.
"""

from .dataset import NpyManifestDataset, default_manifest_paths
from .datamodule import NpyDataModule
from .manifest_utils import (
generate_manifests,
ManifestStats,
scan_dataset_root,
split_materials,
write_jsonl_manifest,
)

__all__ = [
"NpyManifestDataset",
"default_manifest_paths",
"NpyDataModule",
"generate_manifests",
"ManifestStats",
"scan_dataset_root",
"split_materials",
"write_jsonl_manifest",
]
Loading