diff --git a/.gitignore b/.gitignore index e73e203..6e54e0a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,18 @@ .env +__pycache__ # Development /sandbox /staging -/data \ No newline at end of file +/data +/original +og + +# Non-Docker Training Outputs +/src/trainer/outputs +/src/trainer/mlruns + +# Temp uv setup +.python-version +pyproject.toml +uv.lock \ No newline at end of file diff --git a/configs/trainer.yaml b/configs/trainer.yaml new file mode 100644 index 0000000..a3113f1 --- /dev/null +++ b/configs/trainer.yaml @@ -0,0 +1,122 @@ +data: + manifest_dir: "../../../ad_data/manifests" + dataset_root: "../../../ad_data/data/dataset" + extra_val_file: "rruff.jsonl" + auto_generate_manifests: true + train_ratio: 0.8 + val_ratio: 0.1 + test_ratio: 0.1 + seed: 42 + + loader: + # --- DataLoader --- + batch_size: 64 # match OG run (64 per process) + num_workers: 8 + pin_memory: true + persistent_workers: true + prefetch_factor: 2 + train_file: "train.jsonl" + val_file: "val.jsonl" + test_file: "test.jsonl" + + preprocessing: + validate_paths: false + extract_labels: true + allow_pickle: true + labels_key_map: + x: "dp" + cs: "cs" + sg: "sg" + lattice_params: null + lp_a: "_cell_length_a" + lp_b: "_cell_length_b" + lp_c: "_cell_length_c" + lp_alpha: "_cell_angle_alpha" + lp_beta: "_cell_angle_beta" + lp_gamma: "_cell_angle_gamma" + dtype: "float32" + mmap_mode: null + floor_at_zero: true + normalize_log1p: False + shift_labels: true + + augmentation: + noise_poisson_range: [1.0, 100.0] + noise_gaussian_range: [0.001, 0.1] + standardize_to: [0.0, 100.0] + +model: + type: "multiscale" + + backbone: + dim_in: 8192 + dims: [80, 80, 80] + kernel_sizes: [100, 50, 25] + strides: [5, 5, 5] + dropout_rate: 0.3 + layer_scale_init_value: 0.0 + drop_path_rate: 0.3 + ramped_dropout_rate: false + block_type: "convnext" + pooling_type: "average" + final_pool: true + use_batchnorm: false + activation: "leaky_relu" + output_type: "flatten" + + heads: + head_dropout: 0.5 + 
cs_hidden: [2300, 1150] + sg_hidden: [2300, 1150] + lp_hidden: [512, 256] + + tasks: + num_cs_classes: 7 + num_sg_classes: 230 + num_lp_outputs: 6 + + lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0] + bound_lp_with_sigmoid: true + + loss: + lambda_cs: 1.0 + lambda_sg: 1.0 + lambda_lp: 1.0 + + gemd_mu: 0.0 + gemd_distance_matrix_path: null + +optimizer: + lr: 0.0002 + weight_decay: 0.01 + use_adamw: true + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" + +trainer: + default_root_dir: "outputs/convnext_paper" + max_epochs: 100 + accumulate_grad_batches: 1 + precision: "32" # match OG (AMP disabled) + accelerator: "gpu" + devices: 1 + log_every_n_steps: 200 + deterministic: false + benchmark: true + +logging: + logger: "mlflow" + csv_logger_name: "model_logs_convnext_paper" + mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt" + mlflow_tracking_uri: null + mlflow_run_name: "ConvNeXt_Paper_Run" + +checkpointing: + monitor: "val/loss" + mode: "min" + save_top_k: 1 + every_n_epochs: 1 + + resume_from: null + test_after_train: true diff --git a/src/trainer/dataset/__init__.py b/src/trainer/dataset/__init__.py new file mode 100644 index 0000000..890c077 --- /dev/null +++ b/src/trainer/dataset/__init__.py @@ -0,0 +1,39 @@ +""" +Dataset package for training. + +Assumptions: +- Import paths use the local 'dataset' package. +- Manifests are JSON Lines and may include an optional first-line meta header: + {"__meta__": {"version": 1, "base_dir": ""}} + When present, non-absolute file paths in records are resolved relative to base_dir. + base_dir itself may be relative to the manifest file's directory. This makes manifests + independent of the current working directory. +- Legacy manifests without the meta header remain supported; their file paths are used as-is. + +Exports: +- NpyManifestDataset: Map-style dataset loading .npy files listed in JSONL manifests. 
+- NpyDataModule: LightningDataModule wiring datasets and DataLoaders. +- generate_manifests: Utility to create train/val/test manifests split by material ID. +- ManifestStats: Summary dataclass for manifest generation. +""" + +from .dataset import NpyManifestDataset, default_manifest_paths +from .datamodule import NpyDataModule +from .manifest_utils import ( + generate_manifests, + ManifestStats, + scan_dataset_root, + split_materials, + write_jsonl_manifest, +) + +__all__ = [ + "NpyManifestDataset", + "default_manifest_paths", + "NpyDataModule", + "generate_manifests", + "ManifestStats", + "scan_dataset_root", + "split_materials", + "write_jsonl_manifest", +] diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py new file mode 100644 index 0000000..4c28106 --- /dev/null +++ b/src/trainer/dataset/datamodule.py @@ -0,0 +1,289 @@ +""" +LightningDataModule for .npy datasets using JSONL manifests. + +This module wires NpyManifestDataset to PyTorch Lightning and can optionally +auto-generate manifests from a dataset root if they are missing. + +Notes: +- Splitting is performed by material ID when generating manifests (never per-file). +- Manifests avoid scanning the entire dataset during training. +- Manifest format is JSON Lines with optional meta header: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} + {"material_id": "...", "files": ["rel/path/to/file.npy", ...]} +""" + +from __future__ import annotations + +import os +from typing import Any, Callable, Dict, Optional, Tuple + +# Resolve base for relative paths as the current working directory (runtime CWD) +CWD_BASE = os.getcwd() + +import torch +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate +try: + # Prefer importing Lightning; if not installed, this file will error at import time. + import pytorch_lightning as pl +except Exception as e: + raise RuntimeError( + "pytorch_lightning must be installed to use NpyDataModule. 
def make_standardize_transform(std_range):
    """
    Create a per-sample min-max scaling transform.

    The returned callable rescales each input tensor so its values span exactly
    [std_range[0], std_range[1]]. A constant (zero-span) input maps to the lower
    bound everywhere. The input's shape is preserved.
    """
    lo = float(std_range[0])
    hi = float(std_range[1])

    def _scale(sample: torch.Tensor) -> torch.Tensor:
        shape = sample.shape
        flat = sample.reshape(-1)
        lo_v = torch.min(flat)
        hi_v = torch.max(flat)
        span = hi_v - lo_v
        if span > 0:
            unit = (flat - lo_v) / span
        else:
            # Degenerate (constant) sample: map everything to the lower bound.
            unit = torch.zeros_like(flat)
        rescaled = unit * (hi - lo) + lo
        return rescaled.reshape(shape)

    return _scale
+ """ + + def __init__( + self, + manifest_dir: str, + batch_size: int, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, + prefetch_factor: Optional[int], + train_file: str, + val_file: str, + test_file: str, + collate_fn: Optional[Callable] = None, + dataset_cls: type = NpyManifestDataset, + dataset_kwargs: Optional[Dict[str, Any]] = None, + # Optional manifest auto-generation + dataset_root: Optional[str] = None, + auto_generate_manifests: bool = False, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int = 42, + # Optional: add a second validation manifest file (e.g., "rruff.jsonl") + extra_val_file: Optional[str] = None, + # Optional noise augmentation for training split only + noise_poisson_range: Optional[Tuple[float, float]] = None, + noise_gaussian_range: Optional[Tuple[float, float]] = None, + # Optional post-noise standardization to a fixed range (per-sample) + standardize_to: Optional[Tuple[float, float]] = None, + ) -> None: + super().__init__() + self.manifest_dir = manifest_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.prefetch_factor = prefetch_factor + self.collate_fn = collate_fn + + self.dataset_cls = dataset_cls + self.dataset_kwargs = dataset_kwargs or {} + + self.dataset_root = dataset_root + self.auto_generate_manifests = auto_generate_manifests + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = test_ratio + self.seed = seed + + self.train_file = train_file + self.val_file = val_file + self.test_file = test_file + self.extra_val_file = extra_val_file + + # Optional noise augmentation params + self.noise_poisson_range = noise_poisson_range + self.noise_gaussian_range = noise_gaussian_range + self.standardize_to = standardize_to + + # Internal datasets + self.train_ds = None + self.val_ds = None + self.test_ds = None + self.extra_val_ds = None + + def 
_manifest_paths(self) -> Dict[str, str]: + # Allow overriding filenames while still respecting manifest_dir + base = self.manifest_dir + if not os.path.isabs(base): + base = os.path.join(CWD_BASE, base) + paths = { + "train": os.path.join(base, self.train_file), + "val": os.path.join(base, self.val_file), + "test": os.path.join(base, self.test_file), + } + return paths + + def prepare_data(self) -> None: + """ + Single-process hook. Optionally generate manifests if they are missing. + Do NOT assign state like self.train_ds here. + """ + if not self.auto_generate_manifests: + return + + paths = self._manifest_paths() + all_exist = all(os.path.isfile(p) for p in paths.values()) + if all_exist: + return + + if self.dataset_root is None: + raise RuntimeError( + "auto_generate_manifests=True but dataset_root was not provided. " + "Provide dataset_root to scan and create manifests." + ) + + # Resolve absolute paths for dataset root and manifest dir + dataset_root_abs = self.dataset_root + if not os.path.isabs(dataset_root_abs): + dataset_root_abs = os.path.join(CWD_BASE, dataset_root_abs) + manifest_dir_abs = os.path.dirname(paths["train"]) + + # Generate manifests in a reproducible way by material ID. + _stats = generate_manifests( + dataset_root=dataset_root_abs, + manifest_dir=manifest_dir_abs, + train_ratio=self.train_ratio, + val_ratio=self.val_ratio, + test_ratio=self.test_ratio, + seed=self.seed, + ) + # No state assignment here; stats can be logged by the caller if desired. + + def setup(self, stage: Optional[str] = None) -> None: + """ + Per-process hook. Instantiate datasets from manifest files. 
+ """ + paths = self._manifest_paths() + if self.train_ds is None: + # Compose transforms for training split: noise first, then optional standardization + train_kwargs = dict(self.dataset_kwargs) + transforms = [] + if self.noise_poisson_range is not None and self.noise_gaussian_range is not None: + transforms.append( + make_poisson_gaussian_noise_transform(self.noise_poisson_range, self.noise_gaussian_range) + ) + if self.standardize_to is not None: + transforms.append(make_standardize_transform(self.standardize_to)) + + if len(transforms) > 0: + train_kwargs = dict(train_kwargs) # shallow copy + def _compose(ts): + def _f(x: torch.Tensor) -> torch.Tensor: + for t in ts: + x = t(x) + return x + return _f + train_kwargs["transform"] = _compose(transforms) + + self.train_ds = self.dataset_cls(paths["train"], **train_kwargs) + if self.val_ds is None: + val_kwargs = dict(self.dataset_kwargs) + if self.standardize_to is not None: + val_kwargs = dict(val_kwargs) + val_kwargs["transform"] = make_standardize_transform(self.standardize_to) + self.val_ds = self.dataset_cls(paths["val"], **val_kwargs) + # Extra validation dataset (e.g., rruff) + if self.extra_val_ds is None and self.extra_val_file is not None: + manifest_dir_abs = os.path.dirname(paths["val"]) + extra_path = os.path.join(manifest_dir_abs, self.extra_val_file) + extra_kwargs = dict(self.dataset_kwargs) + if self.standardize_to is not None: + extra_kwargs = dict(extra_kwargs) + extra_kwargs["transform"] = make_standardize_transform(self.standardize_to) + self.extra_val_ds = self.dataset_cls(extra_path, **extra_kwargs) + if self.test_ds is None: + test_kwargs = dict(self.dataset_kwargs) + if self.standardize_to is not None: + test_kwargs = dict(test_kwargs) + test_kwargs["transform"] = make_standardize_transform(self.standardize_to) + self.test_ds = self.dataset_cls(paths["test"], **test_kwargs) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_ds, + batch_size=self.batch_size, + 
num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4), + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self) -> DataLoader: + val_loader = DataLoader( + self.val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4), + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=False, + collate_fn=self.collate_fn, + ) + if getattr(self, "extra_val_ds", None) is not None: + extra_loader = DataLoader( + self.extra_val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4), + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=False, + collate_fn=self.collate_fn, + ) + return [val_loader, extra_loader] + return val_loader + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4), + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=False, + collate_fn=self.collate_fn, + ) diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py new file mode 100644 index 0000000..f830210 --- /dev/null +++ b/src/trainer/dataset/dataset.py @@ -0,0 +1,355 @@ +""" +Dataset definition for loading .npy samples using precomputed JSONL manifests. 
+ +Manifest format (JSON Lines): +- Optional first line meta header: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} +- Then one record per material ID: + {"material_id": "mp-4002", "files": ["mp-4002/mp-4002-7.npy", "mp-4002/mp-4002-8.npy", ...]} + +Path resolution: +- If meta header is present, non-absolute file paths are resolved relative to base_dir. + base_dir itself may be relative to the manifest file directory. +- If no meta header (legacy manifests), file paths are used as-is for backward compatibility. + +Notes: +- Splits are performed by material ID when manifests are generated (not at file level). +- This dataset reads the manifest and yields one sample per .npy file, with 'material_id' and 'path' metadata. +- Avoids scanning large trees by leveraging precomputed manifests under data/manifests. +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + +# Noise transform builder following the paper's two-step model +from typing import Tuple, Callable + +def make_poisson_gaussian_noise_transform(poisson_range: Tuple[float, float], gaussian_range: Tuple[float, float]) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Build a transform that applies the paper's two-step noise model: + 1) Poisson noise with mean lambda sampled uniformly from poisson_range, applied to dp/max(dp), + then rescaled by max(dp)/lambda. + 2) Gaussian noise with std sigma sampled uniformly from gaussian_range, applied after normalizing to [0,1]. + Finally renormalize by the new min/max and rescale back to the Poisson-perturbed range, clamp to non-negative. 
+ """ + lam_lo, lam_hi = float(poisson_range[0]), float(poisson_range[1]) + sig_lo, sig_hi = float(gaussian_range[0]), float(gaussian_range[1]) + + def _transform(x: torch.Tensor) -> torch.Tensor: + # Preserve shape + orig_shape = x.shape + dp = x.reshape(-1) + # If all nonpositive, return as-is (clamped) + dp_max = torch.max(dp) + if dp_max <= 0: + return torch.clamp(x, min=0) + + # Sample parameters + lam = torch.empty(1).uniform_(lam_lo, lam_hi).item() + sigma = torch.empty(1).uniform_(sig_lo, sig_hi).item() + + # Poisson step + rate = torch.clamp(dp, min=0) / dp_max * lam + dp_pois = torch.poisson(rate) + dp_pois = dp_pois * dp_max / lam + + # Gaussian step on normalized intensity + dp_min = torch.min(dp_pois) + dp_max2 = torch.max(dp_pois) + denom = dp_max2 - dp_min + if denom > 0: + norm = (dp_pois - dp_min) / denom + else: + norm = torch.zeros_like(dp_pois) + dp_gauss = norm + torch.randn_like(norm) * sigma + + # Final renormalize and rescale back to original Poisson range + g_min = torch.min(dp_gauss) + g_max = torch.max(dp_gauss) + g_denom = g_max - g_min + if g_denom > 0: + norm2 = (dp_gauss - g_min) / g_denom + else: + norm2 = torch.zeros_like(dp_gauss) + dp_noisy = norm2 * denom + dp_min + + dp_noisy = torch.clamp(dp_noisy, min=0.0) + return dp_noisy.reshape(orig_shape) + + return _transform + + +def _read_jsonl_manifest(manifest_path: str) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + """ + Read a JSONL manifest file. + + Returns: + (records, meta): + records: [{"material_id": str, "files": [str, ...]}, ...] 
+ meta: Optional dict from a header line like {"__meta__": {"version": 1, "base_dir": "..."}} + """ + if not os.path.isfile(manifest_path): + raise FileNotFoundError(f"Manifest not found: {manifest_path}") + records: List[Dict[str, Any]] = [] + meta: Optional[Dict[str, Any]] = None + first_record_processed = False + with open(manifest_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rec = json.loads(line) + # Detect optional meta header on the first non-empty line + if not first_record_processed and isinstance(rec, dict) and "__meta__" in rec: + meta = rec["__meta__"] + first_record_processed = True + continue + + first_record_processed = True + # Basic validation + if "material_id" not in rec or "files" not in rec: + raise ValueError(f"Invalid manifest record without material_id/files: {rec}") + if not isinstance(rec["files"], list): + raise ValueError(f"Manifest 'files' must be a list: {rec}") + records.append(rec) + if not records: + raise RuntimeError(f"No records in manifest: {manifest_path}") + return records, meta + + +class NpyManifestDataset(Dataset): + """ + Map-style dataset that loads individual .npy files listed in a JSONL manifest. + + One sample per file. Metadata includes 'material_id' and 'path' to the .npy file. + + Args: + manifest_path: Path to JSONL manifest (e.g., data/manifests/train.jsonl). + transform: Optional callable applied to the loaded torch.Tensor. + dtype: Torch dtype to convert the loaded array into (default: torch.float32). + mmap_mode: NumPy memmap mode for np.load (default: 'r'). Set to None to disable memmap. + return_meta: If True, returns a dict with keys {'x', 'material_id', 'path'}. + If False, returns only the tensor. + validate_paths: If True, validate that file paths exist when building the index. 
+ + Returns: + If return_meta: + Dict[str, Any]: {'x': Tensor, 'material_id': str, 'path': str} + Else: + Tensor + """ + + def __init__( + self, + manifest_path: str, + dtype: torch.dtype, + mmap_mode: Optional[str], + return_meta: bool, + validate_paths: bool, + extract_labels: bool, + allow_pickle: bool, + floor_at_zero: bool, + normalize_log1p: bool, + shift_labels: bool, + transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + labels_key_map: Optional[Dict[str, List[str]]] = None, + ) -> None: + super().__init__() + self.manifest_path = manifest_path + self.transform = transform + self.dtype = dtype + self.mmap_mode = mmap_mode + self.return_meta = return_meta + self.extract_labels = extract_labels + self.allow_pickle = allow_pickle + self.floor_at_zero = floor_at_zero + self.normalize_log1p = normalize_log1p + self.shift_labels = shift_labels + # Default key mapping for extracting fields from embedded containers + # Simplified: single string keys, no search lists + self.labels_key_map = labels_key_map or { + "x": "dp", + "cs": "cs", + "sg": "sg", + "lattice_params": None, + "lp_a": "_cell_length_a", + "lp_b": "_cell_length_b", + "lp_c": "_cell_length_c", + "lp_alpha": "_cell_angle_alpha", + "lp_beta": "_cell_angle_beta", + "lp_gamma": "_cell_angle_gamma", + } + # Cache exact keys + self.k_x = self.labels_key_map.get("x") + self.k_cs = self.labels_key_map.get("cs") + self.k_sg = self.labels_key_map.get("sg") + self.k_lp = self.labels_key_map.get("lattice_params") + self.k_lp_a = self.labels_key_map.get("lp_a") + self.k_lp_b = self.labels_key_map.get("lp_b") + self.k_lp_c = self.labels_key_map.get("lp_c") + self.k_lp_alpha = self.labels_key_map.get("lp_alpha") + self.k_lp_beta = self.labels_key_map.get("lp_beta") + self.k_lp_gamma = self.labels_key_map.get("lp_gamma") + + # Build index: flatten per-material records into (resolved_file_path, material_id) pairs + records, meta = _read_jsonl_manifest(manifest_path) + + # Resolve base_dir from meta 
header if present + base_dir_abs: Optional[str] = None + if meta is not None and isinstance(meta, dict): + base_dir_val = meta.get("base_dir") + if isinstance(base_dir_val, str): + manifest_dir_abs = os.path.abspath(os.path.dirname(manifest_path)) + base_dir_abs = ( + base_dir_val + if os.path.isabs(base_dir_val) + else os.path.normpath(os.path.join(manifest_dir_abs, base_dir_val)) + ) + + items: List[Tuple[str, str]] = [] + for rec in records: + mid = rec["material_id"] + for fpath in rec["files"]: + # Resolve relative paths against base_dir_abs if provided; otherwise use as-is + resolved = ( + fpath + if os.path.isabs(fpath) or base_dir_abs is None + else os.path.normpath(os.path.join(base_dir_abs, fpath)) + ) + if validate_paths and not os.path.isfile(resolved): + # You may choose to ignore missing files instead of raising. + raise FileNotFoundError(f"Missing file listed in manifest: {resolved}") + items.append((resolved, mid)) + + if not items: + raise RuntimeError(f"No file entries found in manifest: {manifest_path}") + + self._items = items + # Materials set for quick stats + self._materials = sorted({mid for _, mid in items}) + + def __len__(self) -> int: + return len(self._items) + + def __getitem__(self, idx: int): + fpath, mid = self._items[idx] + # Load numpy/npz; allow_pickle supports dict-like payloads saved in .npy + # Load with pickle-friendly mode; memmap disabled for object payloads + arr = np.load(fpath, mmap_mode=None, allow_pickle=self.allow_pickle) + + x_tensor: Optional[torch.Tensor] = None + y_cs_t: Optional[torch.Tensor] = None + y_sg_t: Optional[torch.Tensor] = None + y_lp_t: Optional[torch.Tensor] = None + + # Helper to resolve a key from a container using a list of candidate names + def _get_exact(container, key: str): + if key is None: + return None + if isinstance(container, dict): + return container.get(key) + return None + + # Simplified: only support .npy files with object dtype containing a dict payload + if isinstance(arr, 
np.ndarray) and arr.dtype == object: + payload = arr.item() + if not isinstance(payload, dict): + raise TypeError(f"Expected dict payload in object dtype .npy, got {type(payload)} for {fpath}") + x_np = _get_exact(payload, self.k_x) + if x_np is None: + raise KeyError(f"No data array found in npy {fpath} using key '{self.k_x}'") + x_np = np.asarray(x_np) + x_tensor = torch.from_numpy(x_np) + + if self.extract_labels: + y_cs = _get_exact(payload, self.k_cs) + y_sg = _get_exact(payload, self.k_sg) + y_lp = _get_exact(payload, self.k_lp) + if y_cs is not None: + y_cs_t = torch.as_tensor(np.asarray(y_cs)) + if y_sg is not None: + y_sg_t = torch.as_tensor(np.asarray(y_sg)) + if y_lp is not None: + y_lp_t = torch.as_tensor(np.asarray(y_lp)) + # Assemble lattice params from fixed keys if consolidated field not present + if y_lp_t is None: + a = _get_exact(payload, self.k_lp_a) + b = _get_exact(payload, self.k_lp_b) + c = _get_exact(payload, self.k_lp_c) + alpha = _get_exact(payload, self.k_lp_alpha) + beta = _get_exact(payload, self.k_lp_beta) + gamma = _get_exact(payload, self.k_lp_gamma) + if None not in (a, b, c, alpha, beta, gamma): + lp_np = np.asarray([a, b, c, alpha, beta, gamma], dtype=np.float64) + y_lp_t = torch.as_tensor(lp_np) + else: + raise TypeError(f"Unsupported file format: expected object dtype .npy, got dtype={getattr(arr, 'dtype', '?')} for {fpath}") + + # Cast dtype and apply preprocessing/transform to x only + if self.dtype is not None and x_tensor is not None: + x_tensor = x_tensor.to(self.dtype) + # Ensure non-negative counts (needed before Poisson rate computation) + if x_tensor is not None and self.floor_at_zero: + x_tensor = torch.clamp(x_tensor, min=0) + # Apply noise/augmentation transform following the paper (before any log transforms) + if self.transform is not None and x_tensor is not None: + x_tensor = self.transform(x_tensor) + # Optional log1p normalization to compress peak variance (applied after noise) + if x_tensor is not None and 
self.normalize_log1p: + x_tensor = torch.log1p(x_tensor) + + if not self.return_meta: + # Backward-compat: return only x + return x_tensor + + sample = {"x": x_tensor, "material_id": mid, "path": fpath} + # Attach labels if present/extracted + if self.extract_labels: + if y_cs_t is not None: + if self.shift_labels: + y_cs_t = y_cs_t - 1 + sample["cs"] = y_cs_t + if y_sg_t is not None: + if self.shift_labels: + y_sg_t = y_sg_t - 1 + sample["sg"] = y_sg_t + if y_lp_t is not None: + sample["lattice_params"] = y_lp_t + return sample + + @property + def materials(self) -> List[str]: + """Sorted list of unique material IDs present in this dataset.""" + return self._materials + + def count_by_material(self) -> Dict[str, int]: + """Return a mapping from material_id to number of files for that material in this dataset.""" + counts: Dict[str, int] = {} + for _, mid in self._items: + counts[mid] = counts.get(mid, 0) + 1 + return counts + + +# Optional convenience to locate standard split manifests under data/manifests +def default_manifest_paths(manifest_dir: str = "data/manifests") -> Dict[str, str]: + """ + Returns standard manifest paths for train/val/test within a manifest directory. + + Example: + paths = default_manifest_paths("data/manifests") + train_ds = NpyManifestDataset(paths["train"]) + val_ds = NpyManifestDataset(paths["val"]) + test_ds = NpyManifestDataset(paths["test"]) + """ + return { + "train": os.path.join(manifest_dir, "train.jsonl"), + "val": os.path.join(manifest_dir, "val.jsonl"), + "test": os.path.join(manifest_dir, "test.jsonl"), + } diff --git a/src/trainer/dataset/manifest_utils.py b/src/trainer/dataset/manifest_utils.py new file mode 100644 index 0000000..4a2339f --- /dev/null +++ b/src/trainer/dataset/manifest_utils.py @@ -0,0 +1,223 @@ +""" +Utilities for generating and loading manifest files for .npy datasets. 
"""
Utilities for generating and loading manifest files for .npy datasets.

Manifest format (JSON Lines):
- Optional first line with manifest metadata:
  {"__meta__": {"version": 1, "base_dir": "../dataset"}}
- Then one record per material ID:
  {"material_id": "mp-4002", "files": ["mp-4002/mp-4002-7.npy", "mp-4002/mp-4002-8.npy", ...]}

Notes:
- When the meta header with base_dir is present, file paths in subsequent records are
  interpreted relative to base_dir. Absolute file paths are used as-is.
- Legacy manifests without the header keep file paths as-is (backward compatible).
- Storing manifests under data/manifests allows fast dataset construction without
  scanning millions of files.
"""

import json
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class ManifestStats:
    """Aggregate counts produced by generate_manifests()."""
    train_materials: int
    val_materials: int
    test_materials: int
    train_files: int
    val_files: int
    test_files: int


def _is_material_dir(path: str) -> bool:
    """Heuristic: a material dir contains at least one .npy file. Unreadable dirs count as no."""
    if not os.path.isdir(path):
        return False
    try:
        for entry in os.listdir(path):
            if entry.endswith(".npy"):
                return True
    except Exception:
        # Best-effort scan: treat unreadable directories as non-material dirs.
        return False
    return False


def scan_dataset_root(dataset_root: str) -> Dict[str, List[str]]:
    """
    Scan dataset_root for material directories and collect their .npy file paths.

    Returns a mapping from material_id (dir name) to list of file paths.

    Raises:
        FileNotFoundError: dataset_root does not exist.
        RuntimeError: no .npy files were found anywhere under dataset_root.
    """
    if not os.path.isdir(dataset_root):
        raise FileNotFoundError(f"Dataset root not found: {dataset_root}")

    materials: Dict[str, List[str]] = {}
    for entry in sorted(os.listdir(dataset_root)):
        material_dir = os.path.join(dataset_root, entry)
        if not _is_material_dir(material_dir):
            continue
        material_id = os.path.basename(material_dir)
        files: List[str] = []
        try:
            for f in sorted(os.listdir(material_dir)):
                if f.endswith(".npy"):
                    files.append(os.path.join(material_dir, f))
        except Exception:
            # Skip unreadable directories
            continue
        if files:
            materials[material_id] = files

    if not materials:
        raise RuntimeError(f"No .npy files discovered under {dataset_root}")
    return materials


def split_materials(
    materials: Dict[str, List[str]],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Split by material ID into train/val/test sets.

    - Ratios must sum to 1.0 within a small tolerance.
    - Splitting is performed at the material level (directory name), never per-file.
    - Deterministic for a given seed: membership AND dict iteration order are now
      reproducible (previously sets made the dict order hash-dependent).
    """
    total = train_ratio + val_ratio + test_ratio
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {train_ratio}+{val_ratio}+{test_ratio}={total}")

    ids = list(materials.keys())
    random.Random(seed).shuffle(ids)

    n = len(ids)
    n_train = int(round(train_ratio * n))
    n_val = int(round(val_ratio * n))
    # Remaining materials all go to test, so every ID is assigned exactly once.
    train_ids = ids[:n_train]
    val_ids = ids[n_train : n_train + n_val]
    test_ids = ids[n_train + n_val :]

    train = {mid: materials[mid] for mid in train_ids}
    val = {mid: materials[mid] for mid in val_ids}
    test = {mid: materials[mid] for mid in test_ids}
    return train, val, test


def ensure_dir(path: str) -> None:
    """Create path (and parents) if missing; a no-op for an empty string.

    FIX: os.makedirs("") raises FileNotFoundError, which broke
    write_jsonl_manifest() for bare filenames (os.path.dirname -> "").
    """
    if path:
        os.makedirs(path, exist_ok=True)


def write_jsonl_manifest(
    manifest_path: str,
    materials: Dict[str, List[str]],
    base_dir: Optional[str] = None,
) -> int:
    """
    Write a JSONL manifest.

    If base_dir is provided, a meta header is written as the first line:
        {"__meta__": {"version": 1, "base_dir": base_dir}}
    and file paths are written relative to base_dir (when possible).

    Args:
        manifest_path: Destination JSONL file.
        materials: Mapping from material_id to list of file paths (absolute or relative).
        base_dir: Base directory to which file paths will be made relative. This string is
            stored in the meta header. If it is relative, it is interpreted relative to
            the manifest file's directory when reading.

    Returns:
        Number of material entries written (excluding meta header).
    """
    ensure_dir(os.path.dirname(manifest_path))

    manifest_dir_abs = os.path.abspath(os.path.dirname(manifest_path))
    base_dir_abs = None
    if base_dir is not None:
        base_dir_abs = (
            base_dir if os.path.isabs(base_dir) else os.path.normpath(os.path.join(manifest_dir_abs, base_dir))
        )

    count = 0
    with open(manifest_path, "w", encoding="utf-8") as f:
        # Optional meta header
        if base_dir is not None:
            meta = {"__meta__": {"version": 1, "base_dir": base_dir}}
            f.write(json.dumps(meta) + "\n")

        for material_id, files in sorted(materials.items()):
            files_out: List[str] = []
            for p in files:
                if base_dir_abs is None:
                    files_out.append(p)
                else:
                    try:
                        # Prefer paths relative to base_dir_abs. If outside, relpath may
                        # contain .. which is still valid.
                        files_out.append(os.path.relpath(p, start=base_dir_abs))
                    except Exception:
                        files_out.append(p)
            rec = {"material_id": material_id, "files": files_out}
            f.write(json.dumps(rec) + "\n")
            count += 1
    return count


def generate_manifests(
    dataset_root: str,
    manifest_dir: str = "data/manifests",
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> ManifestStats:
    """
    Generate train/val/test manifests from dataset_root, splitting by material ID.

    Writes:
        - {manifest_dir}/train.jsonl
        - {manifest_dir}/val.jsonl
        - {manifest_dir}/test.jsonl

    Returns aggregate stats.
    """
    materials = scan_dataset_root(dataset_root)
    train, val, test = split_materials(materials, train_ratio, val_ratio, test_ratio, seed=seed)

    train_path = os.path.join(manifest_dir, "train.jsonl")
    val_path = os.path.join(manifest_dir, "val.jsonl")
    test_path = os.path.join(manifest_dir, "test.jsonl")

    # Store base_dir in header relative to manifest_dir for portability
    manifest_dir_abs = os.path.abspath(manifest_dir)
    dataset_root_abs = os.path.abspath(dataset_root)
    base_dir_header = os.path.relpath(dataset_root_abs, start=manifest_dir_abs)

    _ = write_jsonl_manifest(train_path, train, base_dir=base_dir_header)
    _ = write_jsonl_manifest(val_path, val, base_dir=base_dir_header)
    _ = write_jsonl_manifest(test_path, test, base_dir=base_dir_header)

    def _file_count(m: Dict[str, List[str]]) -> int:
        return sum(len(v) for v in m.values())

    return ManifestStats(
        train_materials=len(train),
        val_materials=len(val),
        test_materials=len(test),
        train_files=_file_count(train),
        val_files=_file_count(val),
        test_files=_file_count(test),
    )
+ +Manifest format (JSON Lines): +- Optional first line meta header: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} +- Then one record per material ID: + {"material_id": "mp-4002", "files": ["mp-4002/mp-4002-7.npy", "mp-4002/mp-4002-8.npy", ...]} + +Defaults: +- dataset_root: data/dataset +- manifest_dir: data/manifests +- splits: train=0.8, val=0.1, test=0.1 +- seed: 42 + +Usage: + # Run from src/trainer as working directory + python generate_manifests.py \ + --dataset-root data/dataset \ + --manifest-dir data/manifests \ + --train-ratio 0.8 --val-ratio 0.1 --test-ratio 0.1 \ + --seed 42 + +This script avoids scanning during training by precomputing manifests that link +to the .npy files grouped by material ID (directory name). +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from typing import Any, Dict + +# Assume cwd is src/trainer; import from local package +from dataset.manifest_utils import generate_manifests, ManifestStats + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Generate JSONL manifests for .npy datasets split by material ID." 
+ ) + p.add_argument( + "--dataset-root", + type=str, + default="data/dataset", + help="Root directory containing material subfolders (e.g., data/dataset/mp-4002/...).", + ) + p.add_argument( + "--manifest-dir", + type=str, + default="data/manifests", + help="Directory where train/val/test JSONL manifests will be written.", + ) + p.add_argument( + "--train-ratio", + type=float, + default=0.8, + help="Fraction of materials assigned to train split.", + ) + p.add_argument( + "--val-ratio", + type=float, + default=0.1, + help="Fraction of materials assigned to validation split.", + ) + p.add_argument( + "--test-ratio", + type=float, + default=0.1, + help="Fraction of materials assigned to test split.", + ) + p.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible material ID splitting.", + ) + p.add_argument( + "--quiet", + action="store_true", + help="Suppress non-error output.", + ) + return p.parse_args() + + +def _validate_ratios(train: float, val: float, test: float) -> None: + total = train + val + test + if abs(total - 1.0) > 1e-6: + raise SystemExit( + f"Error: Ratios must sum to 1.0, got {train}+{val}+{test}={total}" + ) + if min(train, val, test) < 0.0: + raise SystemExit("Error: Ratios must be non-negative.") + + +def _print_summary(manifest_dir: str, stats: ManifestStats) -> None: + summary: Dict[str, Any] = { + "manifest_dir": manifest_dir, + "manifests": { + "train": os.path.join(manifest_dir, "train.jsonl"), + "val": os.path.join(manifest_dir, "val.jsonl"), + "test": os.path.join(manifest_dir, "test.jsonl"), + }, + "materials": { + "train": stats.train_materials, + "val": stats.val_materials, + "test": stats.test_materials, + }, + "files": { + "train": stats.train_files, + "val": stats.val_files, + "test": stats.test_files, + }, + } + print(json.dumps(summary, indent=2)) + + +def main() -> None: + args = _parse_args() + _validate_ratios(args.train_ratio, args.val_ratio, args.test_ratio) + + # Resolve paths relative to 
project root if provided as relative paths + dataset_root = args.dataset_root + manifest_dir = args.manifest_dir + + stats = generate_manifests( + dataset_root=dataset_root, + manifest_dir=manifest_dir, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + test_ratio=args.test_ratio, + seed=args.seed, + ) + + if not args.quiet: + _print_summary(manifest_dir, stats) + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"[ERROR] {e}", file=sys.stderr) + sys.exit(1) diff --git a/src/trainer/infer_rruff.py b/src/trainer/infer_rruff.py new file mode 100644 index 0000000..0946b46 --- /dev/null +++ b/src/trainer/infer_rruff.py @@ -0,0 +1,350 @@ +import argparse +import os +from typing import Any, Dict, Optional, Sequence + +import torch +import yaml +import numpy as np +from tqdm import tqdm + +# Optional: use sklearn for confusion matrix if available; otherwise fall back to numpy implementation +try: + from sklearn.metrics import confusion_matrix as sk_confusion_matrix +except Exception: + sk_confusion_matrix = None + +import matplotlib +matplotlib.use("Agg") # non-interactive backend +import matplotlib.pyplot as plt + +# Expect PYTHONPATH=src; run with: python -m trainer.infer_rruff configs/trainer_convnext_paper.yaml --ckpt /path/to.ckpt +from dataset import NpyDataModule +from model.model import AlphaDiffractMultiscaleLightning + + +# ----------------------------- +# Config & datamodule builders (adapted from train_paper/debug_sanity) +# ----------------------------- + +def load_config(path: str) -> Dict[str, Any]: + if not os.path.isfile(path): + raise FileNotFoundError(f"Config file not found: {path}") + with open(path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + if not isinstance(cfg, dict): + raise ValueError(f"Config must be a mapping (YAML dict), got: {type(cfg)}") + return cfg + + +def _to_dtype(name: str) -> torch.dtype: + table = { + "float32": torch.float32, + "float64": torch.float64, + "float16": 
torch.float16, + "bfloat16": torch.bfloat16, + } + if name not in table: + raise ValueError(f"Unsupported dtype '{name}'. Allowed: {list(table.keys())}") + return table[name] + + +def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: + dataset_kwargs = { + "dtype": _to_dtype(cfg["dtype"]), + "mmap_mode": cfg["mmap_mode"], + "return_meta": True, + "validate_paths": cfg["validate_paths"], + "extract_labels": cfg["extract_labels"], + "allow_pickle": cfg["allow_pickle"], + "floor_at_zero": cfg["floor_at_zero"], + "normalize_log1p": cfg["normalize_log1p"], + } + labels_key_map = cfg.get("labels_key_map") + if labels_key_map is not None: + dataset_kwargs["labels_key_map"] = labels_key_map + + dm = NpyDataModule( + manifest_dir=cfg["manifest_dir"], + batch_size=cfg["batch_size"], + num_workers=cfg["num_workers"], + pin_memory=cfg["pin_memory"], + persistent_workers=cfg["persistent_workers"] and cfg["num_workers"] > 0, + dataset_kwargs=dataset_kwargs, + dataset_root=cfg["dataset_root"], + auto_generate_manifests=cfg["auto_generate_manifests"], + train_ratio=cfg["train_ratio"], + val_ratio=cfg["val_ratio"], + test_ratio=cfg["test_ratio"], + seed=cfg["seed"], + extra_val_file=cfg.get("extra_val_file"), # expect "rruff.jsonl" + ) + return dm + + +def build_model_class_from_cfg(cfg: Dict[str, Any]): + model_type = cfg.get("model_type", "convnext").lower() + if model_type == "multiscale": + return AlphaDiffractMultiscaleLightning + else: + raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale'.") + + +# ----------------------------- +# Checkpoint resolution helpers +# ----------------------------- + +def find_checkpoint(cfg: Dict[str, Any], candidate_override: Optional[str] = None) -> Optional[str]: + """ + Resolve a checkpoint path. Priority: + 1) Explicit override if provided and exists + 2) Last/best checkpoint under default_root_dir/checkpoints/** + 3) Heuristics: also check under src/trainer/outputs/... 
in case training ran from src/trainer + """ + if candidate_override and os.path.isfile(candidate_override): + return candidate_override + + # Helper to scan a base checkpoints dir + def _scan_dir(base: str) -> Optional[str]: + if not os.path.isdir(base): + return None + # Prefer epoch-best ckpt if present, else last.ckpt + best_ckpts: list[str] = [] + last_ckpts: list[str] = [] + for root, _dirs, files in os.walk(base): + for fn in files: + if fn.endswith(".ckpt"): + full = os.path.join(root, fn) + if "val_loss" in fn: + best_ckpts.append(full) + elif fn == "last.ckpt": + last_ckpts.append(full) + # Heuristic: choose first best, else first last + if best_ckpts: + # sort by epoch number if present + try: + best_ckpts.sort(key=lambda p: int(os.path.basename(p).split("epoch")[1].split("-")[0])) + except Exception: + pass + return best_ckpts[0] + if last_ckpts: + return last_ckpts[0] + return None + + # 2) default_root_dir from cfg (relative to CWD) + default_root = cfg.get("default_root_dir", "outputs/model") + ckpt_dir = os.path.join(default_root, "checkpoints") + found = _scan_dir(ckpt_dir) + if found: + return found + + # 3) try under src/trainer/outputs in case runs were launched from within src/trainer + alt_ckpt_dir = os.path.join("src", "trainer", "outputs", os.path.basename(default_root), "checkpoints") + found = _scan_dir(alt_ckpt_dir) + return found + + +# ----------------------------- +# Evaluation and confusion matrix +# ----------------------------- + +def compute_confusion(y_true: np.ndarray, y_pred: np.ndarray, labels: Sequence[int], normalize: bool) -> np.ndarray: + if sk_confusion_matrix is not None: + norm_opt = "true" if normalize else None + return sk_confusion_matrix(y_true, y_pred, labels=list(labels), normalize=norm_opt) + # Fallback: manual + L = len(labels) + label_to_idx = {lab: i for i, lab in enumerate(labels)} + cm = np.zeros((L, L), dtype=np.float64) + for t, p in zip(y_true, y_pred): + cm[label_to_idx[t], label_to_idx[p]] += 1 + if 
normalize: + row_sums = cm.sum(axis=1, keepdims=True) + row_sums[row_sums == 0] = 1.0 + cm = cm / row_sums + return cm + + +def plot_confusion(cm: np.ndarray, labels: Sequence[int], title: str, save_path: str) -> None: + size = max(6, min(20, int(len(labels) * 0.4))) # adapt figure size + fig, ax = plt.subplots(figsize=(size, size)) + im = ax.imshow(cm, cmap="Blues", interpolation="nearest", aspect="auto") + ax.set_title(title) + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Tick labels only if small enough + if len(labels) <= 30: + ax.set_xticks(range(len(labels))) + ax.set_yticks(range(len(labels))) + ax.set_xticklabels(labels, rotation=90) + ax.set_yticklabels(labels) + else: + ax.set_xticks([]) + ax.set_yticks([]) + + fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + plt.tight_layout() + os.makedirs(os.path.dirname(save_path), exist_ok=True) + fig.savefig(save_path, dpi=200) + plt.close(fig) + + +# ----------------------------- +# Main +# ----------------------------- + +def _resolve_loader_for_split(dm: NpyDataModule, split: str): + split = split.lower() + if split == "rruff": + # Get RRUFF loader: val_dataloader returns list [val_loader, extra_loader] when extra_val_file is set + dm.setup("validate") + val_dl = dm.val_dataloader() + if isinstance(val_dl, list): + if len(val_dl) < 2: + raise RuntimeError("extra_val_file set, but val_dataloader did not return the extra loader") + return val_dl[1] + else: + if getattr(dm, "extra_val_ds", None) is None: + raise RuntimeError("RRUFF dataset not available. 
Ensure 'extra_val_file' is set in the config (e.g., 'rruff.jsonl').") + from torch.utils.data import DataLoader + return DataLoader( + dm.extra_val_ds, + batch_size=dm.batch_size, + num_workers=dm.num_workers, + pin_memory=dm.pin_memory, + persistent_workers=dm.persistent_workers if dm.num_workers > 0 else False, + shuffle=False, + ) + elif split == "val": + dm.setup("validate") + val_dl = dm.val_dataloader() + if isinstance(val_dl, list): + return val_dl[0] + return val_dl + elif split == "test": + dm.setup("test") + return dm.test_dataloader() + elif split == "train": + dm.setup("fit") + return dm.train_dataloader() + else: + raise ValueError("split must be one of: rruff, val, test, train") + + +def run_inference_and_confusion(cfg_path: str, ckpt_path: Optional[str], task: str = "cs", normalize: bool = True, device: Optional[str] = None, out_dir: Optional[str] = None, split: str = "rruff") -> str: + cfg = load_config(cfg_path) + dm = build_datamodule_from_cfg(cfg) + dm.prepare_data() + + # Resolve which loader/split to evaluate + loader = _resolve_loader_for_split(dm, split) + + # Resolve checkpoint + ckpt_resolved = find_checkpoint(cfg, candidate_override=ckpt_path) + if not ckpt_resolved: + raise FileNotFoundError("Could not resolve a checkpoint. 
Pass --ckpt explicitly or ensure outputs/checkpoints exists.") + + # Build model class and load from checkpoint + ModelCls = build_model_class_from_cfg(cfg) + dev = device or ("cuda" if torch.cuda.is_available() else "cpu") + model = ModelCls.load_from_checkpoint(ckpt_resolved, map_location=dev) + model.eval() + model.to(dev) + + # Determine task specifics + if task not in ("cs", "sg"): + raise ValueError("task must be 'cs' or 'sg'") + num_classes = model.num_cs_classes if task == "cs" else model.num_sg_classes + logits_key = "cs_logits" if task == "cs" else "sg_logits" + label_key = "cs" if task == "cs" else "sg" + + # Collect predictions and targets + y_true_all: list[int] = [] + y_pred_all: list[int] = [] + + with torch.no_grad(): + for batch in tqdm(loader): + if not isinstance(batch, dict): + raise RuntimeError("Expected dict batches from NpyDataModule") + x = batch.get("x") + y = batch.get(label_key) + if x is None or y is None: + # Skip if label missing + continue + x = x.to(dev) + preds = model(x) + logits = preds[logits_key] + y_idx = model._to_index(y.to(dev), num_classes) + y_pred = logits.argmax(dim=1) + y_true_all.extend(y_idx.cpu().numpy().tolist()) + y_pred_all.extend(y_pred.cpu().numpy().tolist()) + + if len(y_true_all) == 0: + raise RuntimeError("No labeled samples found for the selected split/task") + + y_true_arr = np.asarray(y_true_all, dtype=np.int32) + y_pred_arr = np.asarray(y_pred_all, dtype=np.int32) + + # Use only classes present to keep plot size reasonable + present_labels = sorted(set(y_true_arr.tolist())) + cm = compute_confusion(y_true_arr, y_pred_arr, labels=present_labels, normalize=normalize) + + # Accuracy summary + acc = float((y_true_arr == y_pred_arr).mean()) + + # Output paths + out_base = out_dir or os.path.join("outputs", "eval") + os.makedirs(out_base, exist_ok=True) + split_tag = split.lower() + img_path = os.path.join(out_base, f"{split_tag}_confusion_{task}.png") + npy_path = os.path.join(out_base, 
f"{split_tag}_confusion_{task}.npy") + + # Save plot and raw matrix + plot_title = f"{split_tag.upper()} {task.upper()} confusion (normalize={'row' if normalize else 'none'}) | acc={acc:.3f} | classes={len(present_labels)}" + plot_confusion(cm, labels=present_labels, title=plot_title, save_path=img_path) + np.save(npy_path, cm) + + # Also save a text summary + summary_path = os.path.join(out_base, f"{split_tag}_confusion_{task}_summary.txt") + with open(summary_path, "w", encoding="utf-8") as f: + f.write(f"Checkpoint: {ckpt_resolved}\n") + f.write(f"Split: {split_tag}\n") + f.write(f"Task: {task}\n") + f.write(f"Accuracy: {acc:.4f}\n") + f.write(f"Classes present: {len(present_labels)}\n") + f.write(f"Labels: {present_labels}\n") + f.write(f"Image: {img_path}\n") + f.write(f"Matrix: {npy_path}\n") + + return img_path + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Run inference on a chosen split (RRUFF/val/test/train) and plot confusion matrix") + p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer_convnext_paper.yaml)") + p.add_argument("--ckpt", type=str, default=None, help="Path to Lightning checkpoint (.ckpt). 
If omitted, attempts auto-discovery.") + p.add_argument("--task", type=str, default="cs", choices=["cs", "sg"], help="Which head to evaluate: cs (7 classes) or sg (230 classes)") + p.add_argument("--split", type=str, default="rruff", choices=["rruff", "val", "test", "train"], help="Dataset split to evaluate") + p.add_argument("--no-normalize", action="store_true", help="Disable row-normalization in confusion matrix") + p.add_argument("--device", type=str, default=None, help="Override device (cuda/cpu)") + p.add_argument("--out-dir", type=str, default=None, help="Directory to save outputs (default: outputs/eval)") + return p.parse_args() + + +def main() -> None: + args = parse_args() + normalize = not args.no_normalize + img_path = run_inference_and_confusion( + cfg_path=args.config, + ckpt_path=args.ckpt, + task=args.task, + normalize=normalize, + device=args.device, + out_dir=args.out_dir, + split=args.split, + ) + print(f"Saved confusion matrix to: {img_path}") + + +if __name__ == "__main__": + main() diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py new file mode 100644 index 0000000..e4b80df --- /dev/null +++ b/src/trainer/model/model.py @@ -0,0 +1,530 @@ +from typing import Dict, Tuple, Optional, List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import pytorch_lightning as pl + + +# ----------------------------- +# Utility: DropPath (Stochastic Depth) +# ----------------------------- +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + # work with broadcastable noise shape + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor = random_tensor.floor() + return x.div(keep_prob) * random_tensor + + +class DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0): + super().__init__() + 
self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + +# ----------------------------- +# ConvNeXt 1D Block +# Follows ConvNeXt design adapted to 1D: +# depthwise conv -> LN (channels-last) -> PW-MLP (expand 4x) -> GELU -> PW-MLP (project) -> layer-scale gamma -> residual (+ DropPath) +# ----------------------------- +class ConvNeXtBlock1D(nn.Module): + def __init__( + self, + dim: int, + kernel_size: int, + drop_path: float, + layer_scale_init_value: float, + activation: nn.Module, + ): + super().__init__() + # depthwise 1D conv + self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding='same', groups=dim) + # pointwise MLP implemented by Linear on channels-last + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = activation() if isinstance(activation, type) else activation + self.pwconv2 = nn.Linear(4 * dim, dim) + # layer-scale + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None + # stochastic depth + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, C, L) + shortcut = x + x = self.dwconv(x) # (N, C, L) + x = x.permute(0, 2, 1) # (N, L, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = x * self.gamma + x = x.permute(0, 2, 1) # (N, C, L) + x = shortcut + self.drop_path(x) + return x + + +class ConvNextBlock1DAdaptor(nn.Module): + """ + Adaptor that wraps ConvNeXtBlock1D to match OG MultiscaleCNNBackbone behavior. + Handles channel adjustment, stride-based downsampling, and block application. 
+ """ + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + dropout: float, + use_batchnorm: bool, + activation: nn.Module, + layer_scale_init_value: float, + drop_path_rate: float, + block_type: str, + ): + super().__init__() + + # Optional pointwise conv for channel adjustment (if in_channels != out_channels) + if in_channels != out_channels: + act = activation() if isinstance(activation, type) else activation + self.pwconv = nn.Sequential(nn.Linear(in_channels, out_channels), act) + else: + self.pwconv = None + + # ConvNeXt block (only if block_type == "convnext") + if block_type == "convnext": + self.block = ConvNeXtBlock1D( + dim=out_channels, + kernel_size=kernel_size, + drop_path=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + activation=activation, + ) + else: + self.block = None + + # Optional stride-based pooling for downsampling + if stride > 1: + self.reduction_pool = nn.AvgPool1d(kernel_size=stride, stride=stride) + else: + self.reduction_pool = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, C_in, L) + + # Channel adjustment via pointwise conv (operates on channels-last) + if self.pwconv is not None: + x = x.permute(0, 2, 1) # (N, L, C_in) + x = self.pwconv(x) # (N, L, C_out) + x = x.permute(0, 2, 1) # (N, C_out, L) + + # Apply ConvNeXt block + if self.block is not None: + x = self.block(x) + + # Apply stride-based downsampling + if self.reduction_pool is not None: + x = self.reduction_pool(x) + + return x + + +# ----------------------------- +# Downsample: LN (channels-last) -> Conv1d (stride, channel increase) +# ----------------------------- + + +# ----------------------------- +# Backbone: ConvNeXt1D (generalized to N stages) +# ----------------------------- + + +# ----------------------------- +# Heads: simple MLP builders +# ----------------------------- +def make_mlp( + input_dim: int, + hidden_dims: Optional[Tuple[int, ...]], + output_dim: int, + 
dropout: float = 0.2, + output_activation: Optional[nn.Module] = None, +) -> nn.Module: + layers: List[nn.Module] = [] + last = input_dim + if hidden_dims is not None and len(hidden_dims) > 0: + for hd in hidden_dims: + layers.extend([nn.Linear(last, hd), nn.ReLU()]) + if dropout and dropout > 0: + layers.append(nn.Dropout(dropout)) + last = hd + layers.append(nn.Linear(last, output_dim)) + if output_activation is not None: + layers.append(output_activation) + return nn.Sequential(*layers) + + +# ----------------------------- +# Lightning Module: AlphaDiffractLightning +# ----------------------------- +BatchType = Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + Dict[str, torch.Tensor], +] + + + + +# ----------------------------- +# Mirrors alphadiffract.model.MultiscaleCNNBackbone behavior: +# - sequential conv stages with specified kernel_sizes and strides +# - optional average/max pooling between stages and at the end +# - output_type: 'gap' or 'flatten' +# ----------------------------- +class MultiscaleCNNBackbone1D(nn.Module): + def __init__( + self, + dim_in: int, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + strides: Tuple[int, ...], + dropout_rate: float, + ramped_dropout_rate: bool, + block_type: str, + pooling_type: str, + final_pool: bool, + use_batchnorm: bool, + activation: nn.Module, + output_type: str, + layer_scale_init_value: float, + drop_path_rate: float, + ): + super().__init__() + assert len(channels) == len(kernel_sizes) == len(strides), "channels, kernel_sizes, strides must match lengths" + self.dim_in = dim_in + self.output_type = output_type + + if ramped_dropout_rate: + dropout_per_stage = torch.linspace(0.0, dropout_rate, steps=len(channels)).tolist() + else: + dropout_per_stage = [dropout_rate] * len(channels) + + if pooling_type == "average": + pool_cls = nn.AvgPool1d + pool_kwargs = {"kernel_size": 3, "stride": 2} + elif pooling_type == "max": + pool_cls = nn.MaxPool1d + pool_kwargs = 
{"kernel_size": 2, "stride": 2} + else: + raise ValueError(f"Invalid pooling_type '{pooling_type}'") + + layers: List[nn.Module] = [] + in_ch = 1 + for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)): + stage_block = ConvNextBlock1DAdaptor( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=k, + stride=s, + dropout=dropout_per_stage[i], + use_batchnorm=use_batchnorm, + activation=activation, + layer_scale_init_value=layer_scale_init_value, + drop_path_rate=drop_path_rate, + block_type=block_type, + ) + layers.append(stage_block) + + # Inter-stage pooling + if i < len(channels) - 1 or final_pool: + layers.append(pool_cls(**pool_kwargs)) + + in_ch = out_ch + + self.net = nn.Sequential(*layers) + + # Determine output feature dimension + if self.output_type == "gap": + self.dim_output = channels[-1] + elif self.output_type == "flatten": + with torch.no_grad(): + dummy = torch.zeros(1, 1, self.dim_in) + out = self.net(dummy) + self.dim_output = int(out.shape[1] * out.shape[2]) + else: + raise ValueError(f"Invalid output_type '{self.output_type}'") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Accept (N, L) or (N, 1, L) + if x.ndim == 2: + x = x[:, None, :] + x = self.net(x) + if self.output_type == "gap": + x = x.mean(dim=-1) + else: # flatten + x = x.reshape(x.shape[0], -1) + return x + + +# ----------------------------- +# Lightning Module using MultiscaleCNNBackbone1D +# Same heads, losses, and metrics behavior as AlphaDiffractLightning +# ----------------------------- +class AlphaDiffractMultiscaleLightning(pl.LightningModule): + def __init__( + self, + dim_in: int, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + strides: Tuple[int, ...], + dropout_rate: float, + ramped_dropout_rate: bool, + block_type: str, + pooling_type: str, + final_pool: bool, + use_batchnorm: bool, + activation: nn.Module, + output_type: str, + layer_scale_init_value: float, + drop_path_rate: float, + + head_dropout: float, + cs_hidden: 
Optional[Tuple[int, ...]], + sg_hidden: Optional[Tuple[int, ...]], + lp_hidden: Optional[Tuple[int, ...]], + + num_cs_classes: int, + num_sg_classes: int, + num_lp_outputs: int, + + lp_bounds_min: Tuple[float, float, float, float, float, float], + lp_bounds_max: Tuple[float, float, float, float, float, float], + bound_lp_with_sigmoid: bool, + + lambda_cs: float, + lambda_sg: float, + lambda_lp: float, + + lr: float, + weight_decay: float, + use_adamw: bool, + steps_per_epoch: int, + ): + super().__init__() + self.save_hyperparameters() + + self.backbone = MultiscaleCNNBackbone1D( + dim_in=dim_in, + channels=channels, + kernel_sizes=kernel_sizes, + strides=strides, + dropout_rate=dropout_rate, + ramped_dropout_rate=ramped_dropout_rate, + block_type=block_type, + pooling_type=pooling_type, + final_pool=final_pool, + use_batchnorm=use_batchnorm, + activation=activation, + output_type=output_type, + layer_scale_init_value=layer_scale_init_value, + drop_path_rate=drop_path_rate, + ) + feat_dim = self.backbone.dim_output + + self.cs_head = make_mlp( + input_dim=feat_dim, + hidden_dims=cs_hidden, + output_dim=num_cs_classes, + dropout=head_dropout, + output_activation=None, + ) + self.sg_head = make_mlp( + input_dim=feat_dim, + hidden_dims=sg_hidden, + output_dim=num_sg_classes, + dropout=head_dropout, + output_activation=None, + ) + self.lp_head = make_mlp( + input_dim=feat_dim, + hidden_dims=lp_hidden, + output_dim=num_lp_outputs, + dropout=head_dropout, + output_activation=None, + ) + + self.ce = nn.CrossEntropyLoss() + self.mse = nn.MSELoss() + self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32)) + self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32)) + self.bound_lp_with_sigmoid = bound_lp_with_sigmoid + + self.lambda_cs = lambda_cs + self.lambda_sg = lambda_sg + self.lambda_lp = lambda_lp + self.lr = lr + self.weight_decay = weight_decay + self.use_adamw = use_adamw + self.steps_per_epoch = steps_per_epoch + + 
self.num_cs_classes = num_cs_classes + self.num_sg_classes = num_sg_classes + self.num_lp_outputs = num_lp_outputs + + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + feats = self.backbone(x) + cs_logits = self.cs_head(feats) + sg_logits = self.sg_head(feats) + lp = self.lp_head(feats) + if self.bound_lp_with_sigmoid: + lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min + return {"features": feats, "cs_logits": cs_logits, "sg_logits": sg_logits, "lp": lp} + + @staticmethod + def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: + if y.dim() > 1 and y.size(-1) > 1: + idx = y.argmax(dim=-1) + else: + idx = y.long() + with torch.no_grad(): + if idx.numel() > 0: + idx = idx.clamp(min=0, max=num_classes - 1) + return idx + + def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if isinstance(batch, (list, tuple)): + assert len(batch) >= 3, "Expected at least (x, cs, sg) in the batch tuple." + if len(batch) >= 4: + x, y_cs, y_sg, y_lp = batch[:4] + else: + x, y_cs, y_sg = batch[:3] + y_lp = None + elif isinstance(batch, dict): + x = batch.get("x", batch.get("xrd", batch.get("signal"))) + if x is None: + raise KeyError("Batch dict must contain 'x' or 'xrd' or 'signal'.") + y_cs = batch.get("cs") + y_sg = batch.get("sg") + y_lp = batch.get("lattice_params", batch.get("lp")) + if y_cs is None or y_sg is None: + raise KeyError("Batch dict must contain 'cs' and 'sg'. 'lattice_params' (or 'lp') is optional.") + else: + raise TypeError("Unsupported batch type. 
Use Tuple or Dict.") + return x, y_cs, y_sg, y_lp + + def _compute_losses_and_metrics( + self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: Optional[torch.Tensor] + ) -> Dict[str, torch.Tensor]: + cs_logits = preds["cs_logits"] + sg_logits = preds["sg_logits"] + lp_pred = preds["lp"] + y_cs_idx = self._to_index(y_cs, self.num_cs_classes) + y_sg_idx = self._to_index(y_sg, self.num_sg_classes) + if y_lp is not None: + y_lp = y_lp.float() + + loss_cs = self.ce(cs_logits, y_cs_idx) + loss_sg = self.ce(sg_logits, y_sg_idx) + loss_lp = self.mse(lp_pred, y_lp) if y_lp is not None else torch.tensor(0.0, device=cs_logits.device) + total_loss = self.lambda_cs * loss_cs + self.lambda_sg * loss_sg + self.lambda_lp * loss_lp + + with torch.no_grad(): + cs_acc = (cs_logits.argmax(dim=1) == y_cs_idx).float().mean() + sg_acc = (sg_logits.argmax(dim=1) == y_sg_idx).float().mean() + if y_lp is not None: + lp_mae = (lp_pred - y_lp).abs().mean() + lp_mse = F.mse_loss(lp_pred, y_lp) + else: + lp_mae = None + lp_mse = None + + return { + "loss_total": total_loss, + "loss_cs": loss_cs, + "loss_sg": loss_sg, + "loss_lp": loss_lp, + "cs_acc": cs_acc, + "sg_acc": sg_acc, + "lp_mae": lp_mae, + "lp_mse": lp_mse, + } + + def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + self.log("train/loss", out["loss_total"], prog_bar=True, on_step=True, on_epoch=True) + self.log("train/loss_cs", out["loss_cs"], on_step=True, on_epoch=True) + self.log("train/loss_sg", out["loss_sg"], on_step=True, on_epoch=True) + self.log("train/loss_lp", out["loss_lp"], on_step=True, on_epoch=True) + self.log("train/cs_acc", out["cs_acc"], prog_bar=True, on_step=True, on_epoch=True) + self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) + if out.get("lp_mae") is not None: + self.log("train/lp_mae", out["lp_mae"], 
on_step=True, on_epoch=True) + if out.get("lp_mse") is not None: + self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) + return out["loss_total"] + + def validation_step(self, batch: BatchType, batch_idx: int, dataloader_idx: Optional[int] = None) -> None: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + prefix = "val" if (dataloader_idx is None or dataloader_idx == 0) else "val_rruff" + self.log(f"{prefix}/loss", out["loss_total"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_cs", out["loss_cs"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_sg", out["loss_sg"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_lp", out["loss_lp"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/sg_acc", out["sg_acc"], on_epoch=True, add_dataloader_idx=False) + if out.get("lp_mae") is not None: + self.log(f"{prefix}/lp_mae", out["lp_mae"], on_epoch=True, add_dataloader_idx=False) + if out.get("lp_mse") is not None: + self.log(f"{prefix}/lp_mse", out["lp_mse"], on_epoch=True, add_dataloader_idx=False) + + def test_step(self, batch: BatchType, batch_idx: int) -> None: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + self.log("test/loss", out["loss_total"], prog_bar=True, on_epoch=True) + self.log("test/loss_cs", out["loss_cs"], on_epoch=True) + self.log("test/loss_sg", out["loss_sg"], on_epoch=True) + self.log("test/loss_lp", out["loss_lp"], on_epoch=True) + self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) + self.log("test/sg_acc", out["sg_acc"], on_epoch=True) + if out.get("lp_mae") is not None: + self.log("test/lp_mae", out["lp_mae"], on_epoch=True) + if out.get("lp_mse") is not 
None: + self.log("test/lp_mse", out["lp_mse"], on_epoch=True) + + def configure_optimizers(self): + params = self.parameters() + if self.use_adamw: + optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay) + else: + optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay) + + step_size_up = int(6 * self.steps_per_epoch) + + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=self.lr * 0.1, + max_lr=self.lr, + step_size_up=step_size_up, + cycle_momentum=False, + mode="triangular2", + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", # per-batch stepping, identical to OG + "name": "cyclic_lr", + }, + } diff --git a/src/trainer/train.py b/src/trainer/train.py new file mode 100644 index 0000000..f4b51e1 --- /dev/null +++ b/src/trainer/train.py @@ -0,0 +1,312 @@ +import argparse +import os +from typing import Dict, Any, Optional + +import torch +import yaml +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback +from pytorch_lightning.loggers import CSVLogger +try: + from pytorch_lightning.loggers import MLFlowLogger +except Exception: + MLFlowLogger = None + +from dataset import NpyDataModule +from model.model import AlphaDiffractMultiscaleLightning + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Train AlphaDiffract paper model (config-required)") + p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer.yaml)") + return p.parse_args() + + +def load_config(path: str) -> Dict[str, Any]: + if not os.path.isfile(path): + raise FileNotFoundError(f"Config file not found: {path}") + with open(path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + if not isinstance(cfg, dict): + raise ValueError(f"Config must be a mapping (YAML dict), got: {type(cfg)}") + return cfg + + +def 
_to_dtype(name: str) -> torch.dtype: + table = { + "float32": torch.float32, + "float64": torch.float64, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + if name not in table: + raise ValueError(f"Unsupported dtype '{name}'. Allowed: {list(table.keys())}") + return table[name] + + +def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: + data_cfg = cfg["data"] + loader_cfg = data_cfg["loader"] + prep_cfg = data_cfg["preprocessing"] + aug_cfg = data_cfg["augmentation"] + + # Dataset kwargs from config + dataset_kwargs = { + "dtype": _to_dtype(prep_cfg["dtype"]), + "mmap_mode": prep_cfg["mmap_mode"], + "return_meta": True, + "validate_paths": prep_cfg["validate_paths"], + "extract_labels": prep_cfg["extract_labels"], + "allow_pickle": prep_cfg["allow_pickle"], + "floor_at_zero": prep_cfg["floor_at_zero"], + "normalize_log1p": prep_cfg["normalize_log1p"], + "shift_labels": prep_cfg["shift_labels"], + } + labels_key_map = prep_cfg["labels_key_map"] + if labels_key_map is not None: + dataset_kwargs["labels_key_map"] = labels_key_map + + dm = NpyDataModule( + manifest_dir=data_cfg["manifest_dir"], + batch_size=loader_cfg["batch_size"], + num_workers=loader_cfg["num_workers"], + pin_memory=loader_cfg["pin_memory"], + persistent_workers=loader_cfg["persistent_workers"] and loader_cfg["num_workers"] > 0, + prefetch_factor=loader_cfg["prefetch_factor"], + train_file=loader_cfg["train_file"], + val_file=loader_cfg["val_file"], + test_file=loader_cfg["test_file"], + dataset_kwargs=dataset_kwargs, + dataset_root=data_cfg["dataset_root"], + auto_generate_manifests=data_cfg["auto_generate_manifests"], + train_ratio=data_cfg["train_ratio"], + val_ratio=data_cfg["val_ratio"], + test_ratio=data_cfg["test_ratio"], + seed=data_cfg["seed"], + extra_val_file=data_cfg["extra_val_file"], + # Optional noise augmentation: apply to training split only + noise_poisson_range=tuple(aug_cfg["noise_poisson_range"]) if aug_cfg["noise_poisson_range"] is not None else 
None, + noise_gaussian_range=tuple(aug_cfg["noise_gaussian_range"]) if aug_cfg["noise_gaussian_range"] is not None else None, + # Optional standardization after noise to match OG runs + standardize_to=tuple(aug_cfg["standardize_to"]) if aug_cfg["standardize_to"] is not None else None, + ) + return dm + + +def build_model_from_cfg(cfg: Dict[str, Any], steps_per_epoch: int): + model_cfg = cfg["model"] + backbone_cfg = model_cfg["backbone"] + heads_cfg = model_cfg["heads"] + tasks_cfg = model_cfg["tasks"] + loss_cfg = model_cfg["loss"] + optim_cfg = cfg["optimizer"] + + model_type = model_cfg["type"].lower() + + act_name = backbone_cfg["activation"].lower() + if act_name == "leaky_relu": + activation = torch.nn.LeakyReLU + elif act_name == "relu": + activation = torch.nn.ReLU + elif act_name == "gelu": + activation = torch.nn.GELU + else: + raise ValueError(f"Unsupported activation '{act_name}'") + + if model_type == "multiscale": + return AlphaDiffractMultiscaleLightning( + # Map OG-style multiscale CNN params from cfg (use dims as channels) + dim_in=backbone_cfg["dim_in"], + channels=tuple(backbone_cfg["dims"]), + kernel_sizes=tuple(backbone_cfg["kernel_sizes"]), + strides=tuple(backbone_cfg["strides"]), + dropout_rate=backbone_cfg["dropout_rate"], + ramped_dropout_rate=backbone_cfg["ramped_dropout_rate"], + block_type=backbone_cfg["block_type"], + pooling_type=backbone_cfg["pooling_type"], + final_pool=backbone_cfg["final_pool"], + use_batchnorm=backbone_cfg["use_batchnorm"], + activation=activation, + output_type=backbone_cfg["output_type"], + layer_scale_init_value=backbone_cfg["layer_scale_init_value"], + drop_path_rate=backbone_cfg["drop_path_rate"], + + head_dropout=heads_cfg["head_dropout"], + cs_hidden=tuple(heads_cfg["cs_hidden"]), + sg_hidden=tuple(heads_cfg["sg_hidden"]), + lp_hidden=tuple(heads_cfg["lp_hidden"]), + + num_cs_classes=tasks_cfg["num_cs_classes"], + num_sg_classes=tasks_cfg["num_sg_classes"], + num_lp_outputs=tasks_cfg["num_lp_outputs"], + + 
class ConfigArtifactLogger(Callback):
    """
    Minimal callback: if MLFlowLogger is the active logger, log the raw YAML config
    as an MLflow artifact under 'configs/'.
    """
    def __init__(self, raw_config_path: Optional[str]):
        # Path to the YAML config on disk; may be None (nothing to upload).
        self.raw_config_path = raw_config_path

    def on_fit_start(self, trainer, pl_module) -> None:
        # Guard: build_trainer_from_cfg may construct us with raw_config_path=None
        # (its parameter defaults to None); calling log_artifact with None would fail.
        if not self.raw_config_path:
            return
        logger = getattr(trainer, "logger", None)
        if MLFlowLogger is not None and isinstance(logger, MLFlowLogger):
            logger.experiment.log_artifact(logger.run_id, self.raw_config_path)


class RunCheckpointDirCallback(Callback):
    """
    On fit start, create a unique sub-folder for checkpoints and, if MLflow is the logger,
    log the folder path to MLflow so each run's checkpoints are easily discoverable.
    """
    def __init__(self, base_dir: Optional[str] = None):
        # Optional override for the base checkpoints directory.
        self.base_dir = base_dir

    def on_fit_start(self, trainer, pl_module) -> None:
        # Base directory: explicit override or <default_root_dir>/checkpoints.
        base_dir = self.base_dir or os.path.join(trainer.default_root_dir, "checkpoints")

        # Prefer the MLflow run_id as the unique sub-folder name when available.
        logger = getattr(trainer, "logger", None)
        run_suffix: Optional[str] = None
        if MLFlowLogger is not None and isinstance(logger, MLFlowLogger):
            run_suffix = getattr(logger, "run_id", None)

        if not run_suffix:
            # Fallback to timestamp-based folder if no MLflow run_id
            from datetime import datetime
            run_suffix = datetime.now().strftime("%Y%m%d_%H%M%S")

        run_ckpt_dir = os.path.join(base_dir, run_suffix)
        os.makedirs(run_ckpt_dir, exist_ok=True)

        # Redirect every ModelCheckpoint callback into the run-specific directory.
        # trainer.checkpoint_callback is itself one of trainer.callbacks, so the
        # loop covers it; the previous separate assignment was redundant.
        for cb in trainer.callbacks:
            if isinstance(cb, ModelCheckpoint):
                cb.dirpath = run_ckpt_dir

        # If MLflow is active, record the checkpoint directory on the run.
        if MLFlowLogger is not None and isinstance(logger, MLFlowLogger):
            try:
                # Log as a MLflow param and tag for easy discovery
                logger.experiment.log_param(logger.run_id, "checkpoint_dir", run_ckpt_dir)
                logger.experiment.set_tag(logger.run_id, "checkpoint_dir", run_ckpt_dir)
            except Exception:
                # Best-effort bookkeeping: never fail training over it.
                pass
class MlflowShutdownCallback(Callback):
    """
    Ensures MLflow runs are properly finalized on Ctrl-C (SIGINT) interruptions.
    """
    def on_exception(self, trainer, pl_module, exception):
        """Mark the active MLflow run as KILLED when training is interrupted."""
        if not isinstance(exception, KeyboardInterrupt):
            return
        logger = getattr(trainer, "logger", None)
        if MLFlowLogger is not None and isinstance(logger, MLFlowLogger):
            try:
                logger.experiment.set_terminated(logger.run_id, status="KILLED")
                print("\nMLflow run terminated with status KILLED")
            except Exception as e:
                print(f"\nWarning: Could not terminate MLflow run: {e}")


def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = None) -> Trainer:
    """Build the Lightning Trainer, checkpointing, logger, and callbacks from config."""
    trainer_cfg = cfg["trainer"]
    log_cfg = cfg["logging"]
    ckpt_cfg = cfg["checkpointing"]
    optim_cfg = cfg["optimizer"]

    checkpoint_cb = ModelCheckpoint(
        monitor=ckpt_cfg["monitor"],
        mode=ckpt_cfg["mode"],
        save_top_k=ckpt_cfg["save_top_k"],
        dirpath=os.path.join(trainer_cfg["default_root_dir"], "checkpoints"),
        filename="epoch{epoch:03d}-val_loss{val/loss:.4f}",
        save_last=True,
        every_n_epochs=ckpt_cfg["every_n_epochs"],
        auto_insert_metric_name=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    # Select the experiment logger named in the config ("csv" or "mlflow").
    logger = None
    logger_kind = log_cfg["logger"]
    if logger_kind == "csv":
        logger = CSVLogger(save_dir=trainer_cfg["default_root_dir"], name=log_cfg["csv_logger_name"])
    elif logger_kind == "mlflow":
        if MLFlowLogger is None:
            raise ImportError("MLFlowLogger requested but 'mlflow' is not installed. Install with `pip install mlflow`.")
        logger = MLFlowLogger(
            experiment_name=log_cfg["mlflow_experiment_name"],
            tracking_uri=log_cfg["mlflow_tracking_uri"],
            run_name=log_cfg["mlflow_run_name"],
        )

    callbacks = [
        checkpoint_cb,
        lr_monitor,
        ConfigArtifactLogger(raw_config_path),
        RunCheckpointDirCallback(),
        MlflowShutdownCallback(),
    ]
    return Trainer(
        default_root_dir=trainer_cfg["default_root_dir"],
        max_epochs=trainer_cfg["max_epochs"],
        accelerator=trainer_cfg["accelerator"],
        devices=trainer_cfg["devices"],
        precision=trainer_cfg["precision"],
        accumulate_grad_batches=trainer_cfg["accumulate_grad_batches"],
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=trainer_cfg["log_every_n_steps"],
        deterministic=trainer_cfg["deterministic"],
        benchmark=trainer_cfg["benchmark"],
        gradient_clip_val=optim_cfg["gradient_clip_val"],
        gradient_clip_algorithm=optim_cfg["gradient_clip_algorithm"],
        val_check_interval=0.5
    )


def main():
    """Entry point: load config, build data/model/trainer, fit, optionally test."""
    args = parse_args()
    cfg = load_config(args.config)

    seed_everything(cfg["data"]["seed"], workers=True)
    torch.set_float32_matmul_precision('high')

    dm = build_datamodule_from_cfg(cfg)
    # Explicitly set up the datamodule so steps_per_epoch can be derived
    # for the per-step LR schedule.
    dm.setup("fit")
    steps_per_epoch = len(dm.train_dataloader())

    model = build_model_from_cfg(cfg, steps_per_epoch=steps_per_epoch)
    trainer = build_trainer_from_cfg(cfg, raw_config_path=args.config)

    # Train (optionally resuming from a checkpoint given in the config).
    resume_from: Optional[str] = cfg["checkpointing"]["resume_from"]
    trainer.fit(model, datamodule=dm, ckpt_path=resume_from)

    # Evaluate the best checkpoint when requested.
    if cfg["checkpointing"]["test_after_train"]:
        ckpt_path = trainer.checkpoint_callback.best_model_path if trainer.checkpoint_callback else None
        trainer.test(model=model, datamodule=dm, ckpt_path=ckpt_path or "best")


if __name__ == "__main__":
    main()