From 024ceb0d7f008880fce26849bde26e8c24e2bad6 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Mon, 3 Nov 2025 10:35:55 -0600 Subject: [PATCH 01/18] feat: Add initial dataset implementation --- src/trainer/dataset/__init__.py | 39 +++++ src/trainer/dataset/datamodule.py | 220 +++++++++++++++++++++++++ src/trainer/dataset/dataset.py | 195 ++++++++++++++++++++++ src/trainer/dataset/manifest_utils.py | 223 ++++++++++++++++++++++++++ src/trainer/generate_manifests.py | 146 +++++++++++++++++ 5 files changed, 823 insertions(+) create mode 100644 src/trainer/dataset/__init__.py create mode 100644 src/trainer/dataset/datamodule.py create mode 100644 src/trainer/dataset/dataset.py create mode 100644 src/trainer/dataset/manifest_utils.py create mode 100644 src/trainer/generate_manifests.py diff --git a/src/trainer/dataset/__init__.py b/src/trainer/dataset/__init__.py new file mode 100644 index 0000000..890c077 --- /dev/null +++ b/src/trainer/dataset/__init__.py @@ -0,0 +1,39 @@ +""" +Dataset package for training. + +Assumptions: +- Import paths use the local 'dataset' package. +- Manifests are JSON Lines and may include an optional first-line meta header: + {"__meta__": {"version": 1, "base_dir": ""}} + When present, non-absolute file paths in records are resolved relative to base_dir. + base_dir itself may be relative to the manifest file's directory. This makes manifests + independent of the current working directory. +- Legacy manifests without the meta header remain supported; their file paths are used as-is. + +Exports: +- NpyManifestDataset: Map-style dataset loading .npy files listed in JSONL manifests. +- NpyDataModule: LightningDataModule wiring datasets and DataLoaders. +- generate_manifests: Utility to create train/val/test manifests split by material ID. +- ManifestStats: Summary dataclass for manifest generation. +""" + +from .dataset import NpyManifestDataset, default_manifest_paths +from .datamodule import NpyDataModule +from .manifest_utils import ( + generate_manifests, + ManifestStats, + scan_dataset_root, + split_materials, + write_jsonl_manifest, +) + +__all__ = [ + "NpyManifestDataset", + "default_manifest_paths", + "NpyDataModule", + "generate_manifests", + "ManifestStats", + "scan_dataset_root", + "split_materials", + "write_jsonl_manifest", +] diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py new file mode 100644 index 0000000..d25b591 --- /dev/null +++ b/src/trainer/dataset/datamodule.py @@ -0,0 +1,220 @@ +""" +LightningDataModule for .npy datasets using JSONL manifests. + +This module wires NpyManifestDataset to PyTorch Lightning and can optionally +auto-generate manifests from a dataset root if they are missing. + +Typical usage: + from pytorch_lightning import Trainer + from dataset.datamodule import NpyDataModule + + dm = NpyDataModule( + manifest_dir="data/manifests", + batch_size=64, + num_workers=8, + pin_memory=True, + persistent_workers=True, + # Optional: auto-generate manifests if missing + dataset_root="data/dataset", + auto_generate_manifests=True, + train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42, + # Dataset-specific kwargs + dataset_kwargs={"dtype": torch.float32, "mmap_mode": "r", "return_meta": True}, + ) + + trainer = Trainer(max_epochs=10, accelerator="auto", devices="auto") + trainer.fit(model, dm) + trainer.test(model, dm) + +Notes: +- Splitting is performed by material ID when generating manifests (never per-file). +- Manifests avoid scanning the entire dataset during training. 
+- Manifest format is JSON Lines with optional meta header: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} + {"material_id": "...", "files": ["rel/path/to/file.npy", ...]} +""" + +from __future__ import annotations + +import os +from typing import Any, Callable, Dict, Optional + +# Resolve project root for robust default paths independent of cwd (src/trainer) +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + +import torch +from torch.utils.data import DataLoader +try: + # Prefer importing Lightning; if not installed, this file will error at import time. + import pytorch_lightning as pl +except Exception as e: + raise RuntimeError( + "pytorch_lightning must be installed to use NpyDataModule. " + "Install with: pip install pytorch-lightning" + ) from e + +from .dataset import NpyManifestDataset, default_manifest_paths +from .manifest_utils import generate_manifests + + +class NpyDataModule(pl.LightningDataModule): + """ + LightningDataModule that reads train/val/test JSONL manifests and constructs DataLoaders. + + Args: + manifest_dir: Directory containing train.jsonl, val.jsonl, test.jsonl. + batch_size: Batch size for all splits (you can override per-split if needed). + num_workers: Number of DataLoader workers. + pin_memory: Enable DataLoader pin_memory (recommended for CUDA). + persistent_workers: Keep workers alive between epochs (speed-up for long runs). + collate_fn: Optional collate function (defaults to PyTorch's default). + dataset_cls: Dataset class to use (default: NpyManifestDataset). + dataset_kwargs: Additional kwargs forwarded to dataset_cls for all splits. + dataset_root: If provided and auto_generate_manifests=True, used to scan for materials. + auto_generate_manifests: If True, generate manifests in prepare_data if missing. + train_ratio, val_ratio, test_ratio: Ratios for splitting materials (must sum to 1.0). + seed: Random seed for reproducible material ID splits. + train_file, val_file, test_file: Manifest file names in manifest_dir (defaults provided). 
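+
+    Example (a minimal sketch; the custom file names below are illustrative
+    assumptions, not shipped defaults):
+
+        dm = NpyDataModule(
+            manifest_dir="data/manifests",
+            train_file="train_v2.jsonl",
+            val_file="val_v2.jsonl",
+            test_file="test_v2.jsonl",
+        )
+        dm.setup("fit")
+        loader = dm.train_dataloader()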
+ """ + + def __init__( + self, + manifest_dir: str = "data/manifests", + batch_size: int = 32, + num_workers: int = 4, + pin_memory: bool = True, + persistent_workers: bool = True, + collate_fn: Optional[Callable] = None, + dataset_cls: type = NpyManifestDataset, + dataset_kwargs: Optional[Dict[str, Any]] = None, + # Optional manifest auto-generation + dataset_root: Optional[str] = None, + auto_generate_manifests: bool = False, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int = 42, + # Custom manifest filenames (within manifest_dir) + train_file: str = "train.jsonl", + val_file: str = "val.jsonl", + test_file: str = "test.jsonl", + ) -> None: + super().__init__() + self.manifest_dir = manifest_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.collate_fn = collate_fn + + self.dataset_cls = dataset_cls + self.dataset_kwargs = dataset_kwargs or {} + + self.dataset_root = dataset_root + self.auto_generate_manifests = auto_generate_manifests + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = test_ratio + self.seed = seed + + self.train_file = train_file + self.val_file = val_file + self.test_file = test_file + + # Internal datasets + self.train_ds = None + self.val_ds = None + self.test_ds = None + + def _manifest_paths(self) -> Dict[str, str]: + # Allow overriding filenames while still respecting manifest_dir + base = self.manifest_dir + if not os.path.isabs(base): + base = os.path.join(PROJECT_ROOT, base) + paths = { + "train": os.path.join(base, self.train_file), + "val": os.path.join(base, self.val_file), + "test": os.path.join(base, self.test_file), + } + return paths + + def prepare_data(self) -> None: + """ + Single-process hook. Optionally generate manifests if they are missing. + Do NOT assign state like self.train_ds here. + """ + if not self.auto_generate_manifests: + return + + paths = self._manifest_paths() + all_exist = all(os.path.isfile(p) for p in paths.values()) + if all_exist: + return + + if self.dataset_root is None: + raise RuntimeError( + "auto_generate_manifests=True but dataset_root was not provided. " + "Provide dataset_root to scan and create manifests." + ) + + # Resolve absolute paths for dataset root and manifest dir + dataset_root_abs = self.dataset_root + if not os.path.isabs(dataset_root_abs): + dataset_root_abs = os.path.join(PROJECT_ROOT, dataset_root_abs) + manifest_dir_abs = os.path.dirname(paths["train"]) + + # Generate manifests in a reproducible way by material ID. + _stats = generate_manifests( + dataset_root=dataset_root_abs, + manifest_dir=manifest_dir_abs, + train_ratio=self.train_ratio, + val_ratio=self.val_ratio, + test_ratio=self.test_ratio, + seed=self.seed, + ) + # No state assignment here; stats can be logged by the caller if desired. + + def setup(self, stage: Optional[str] = None) -> None: + """ + Per-process hook. Instantiate datasets from manifest files. 
+ """ + paths = self._manifest_paths() + if self.train_ds is None: + self.train_ds = self.dataset_cls(paths["train"], **self.dataset_kwargs) + if self.val_ds is None: + self.val_ds = self.dataset_cls(paths["val"], **self.dataset_kwargs) + if self.test_ds is None: + self.test_ds = self.dataset_cls(paths["test"], **self.dataset_kwargs) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=False, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=False, + collate_fn=self.collate_fn, + ) diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py new file mode 100644 index 0000000..c8c42c3 --- /dev/null +++ b/src/trainer/dataset/dataset.py @@ -0,0 +1,195 @@ +""" +Dataset definition for loading .npy samples using precomputed JSONL manifests. + +Manifest format (JSON Lines): +- Optional first line meta header: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} +- Then one record per material ID: + {"material_id": "mp-4002", "files": ["mp-4002/mp-4002-7.npy", "mp-4002/mp-4002-8.npy", ...]} + +Path resolution: +- If meta header is present, non-absolute file paths are resolved relative to base_dir. + base_dir itself may be relative to the manifest file directory. +- If no meta header (legacy manifests), file paths are used as-is for backward compatibility. + +Notes: +- Splits are performed by material ID when manifests are generated (not at file level). +- This dataset reads the manifest and yields one sample per .npy file, with 'material_id' and 'path' metadata. +- Avoids scanning large trees by leveraging precomputed manifests under data/manifests. +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + + +def _read_jsonl_manifest(manifest_path: str) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + """ + Read a JSONL manifest file. + + Returns: + (records, meta): + records: [{"material_id": str, "files": [str, ...]}, ...] 
+ meta: Optional dict from a header line like {"__meta__": {"version": 1, "base_dir": "..."}} + """ + if not os.path.isfile(manifest_path): + raise FileNotFoundError(f"Manifest not found: {manifest_path}") + records: List[Dict[str, Any]] = [] + meta: Optional[Dict[str, Any]] = None + first_record_processed = False + with open(manifest_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rec = json.loads(line) + # Detect optional meta header on the first non-empty line + if not first_record_processed and isinstance(rec, dict) and "__meta__" in rec: + meta = rec["__meta__"] + first_record_processed = True + continue + + first_record_processed = True + # Basic validation + if "material_id" not in rec or "files" not in rec: + raise ValueError(f"Invalid manifest record without material_id/files: {rec}") + if not isinstance(rec["files"], list): + raise ValueError(f"Manifest 'files' must be a list: {rec}") + records.append(rec) + if not records: + raise RuntimeError(f"No records in manifest: {manifest_path}") + return records, meta + + +class NpyManifestDataset(Dataset): + """ + Map-style dataset that loads individual .npy files listed in a JSONL manifest. + + One sample per file. Metadata includes 'material_id' and 'path' to the .npy file. + + Args: + manifest_path: Path to JSONL manifest (e.g., data/manifests/train.jsonl). + transform: Optional callable applied to the loaded torch.Tensor. + dtype: Torch dtype to convert the loaded array into (default: torch.float32). + mmap_mode: NumPy memmap mode for np.load (default: 'r'). Set to None to disable memmap. + return_meta: If True, returns a dict with keys {'x', 'material_id', 'path'}. + If False, returns only the tensor. + validate_paths: If True, validate that file paths exist when building the index. + + Returns: + If return_meta: + Dict[str, Any]: {'x': Tensor, 'material_id': str, 'path': str} + Else: + Tensor + """ + + def __init__( + self, + manifest_path: str, + transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + dtype: torch.dtype = torch.float32, + mmap_mode: Optional[str] = "r", + return_meta: bool = True, + validate_paths: bool = False, + ) -> None: + super().__init__() + self.manifest_path = manifest_path + self.transform = transform + self.dtype = dtype + self.mmap_mode = mmap_mode + self.return_meta = return_meta + + # Build index: flatten per-material records into (resolved_file_path, material_id) pairs + records, meta = _read_jsonl_manifest(manifest_path) + + # Resolve base_dir from meta header if present + base_dir_abs: Optional[str] = None + if meta is not None and isinstance(meta, dict): + base_dir_val = meta.get("base_dir") + if isinstance(base_dir_val, str): + manifest_dir_abs = os.path.abspath(os.path.dirname(manifest_path)) + base_dir_abs = ( + base_dir_val + if os.path.isabs(base_dir_val) + else os.path.normpath(os.path.join(manifest_dir_abs, base_dir_val)) + ) + + items: List[Tuple[str, str]] = [] + for rec in records: + mid = rec["material_id"] + for fpath in rec["files"]: + # Resolve relative paths against base_dir_abs if provided; otherwise use as-is + resolved = ( + fpath + if os.path.isabs(fpath) or base_dir_abs is None + else os.path.normpath(os.path.join(base_dir_abs, fpath)) + ) + if validate_paths and not os.path.isfile(resolved): + # You may choose to ignore missing files instead of raising. 
+ raise FileNotFoundError(f"Missing file listed in manifest: {resolved}") + items.append((resolved, mid)) + + if not items: + raise RuntimeError(f"No file entries found in manifest: {manifest_path}") + + self._items = items + # Materials set for quick stats + self._materials = sorted({mid for _, mid in items}) + + def __len__(self) -> int: + return len(self._items) + + def __getitem__(self, idx: int): + fpath, mid = self._items[idx] + # Load numpy array; use mmap for reduced memory pressure on large datasets + arr = np.load(fpath, mmap_mode=self.mmap_mode) + # Ensure array is a standard ndarray (materialize memmap view if needed) + arr = np.asarray(arr) + # Convert to torch tensor with requested dtype + x = torch.from_numpy(arr) + if self.dtype is not None: + x = x.to(self.dtype) + + if self.transform is not None: + x = self.transform(x) + + if self.return_meta: + return {"x": x, "material_id": mid, "path": fpath} + else: + return x + + @property + def materials(self) -> List[str]: + """Sorted list of unique material IDs present in this dataset.""" + return self._materials + + def count_by_material(self) -> Dict[str, int]: + """Return a mapping from material_id to number of files for that material in this dataset.""" + counts: Dict[str, int] = {} + for _, mid in self._items: + counts[mid] = counts.get(mid, 0) + 1 + return counts + + +# Optional convenience to locate standard split manifests under data/manifests +def default_manifest_paths(manifest_dir: str = "data/manifests") -> Dict[str, str]: + """ + Returns standard manifest paths for train/val/test within a manifest directory. + + Example: + paths = default_manifest_paths("data/manifests") + train_ds = NpyManifestDataset(paths["train"]) + val_ds = NpyManifestDataset(paths["val"]) + test_ds = NpyManifestDataset(paths["test"]) + """ + return { + "train": os.path.join(manifest_dir, "train.jsonl"), + "val": os.path.join(manifest_dir, "val.jsonl"), + "test": os.path.join(manifest_dir, "test.jsonl"), + } diff --git a/src/trainer/dataset/manifest_utils.py b/src/trainer/dataset/manifest_utils.py new file mode 100644 index 0000000..4a2339f --- /dev/null +++ b/src/trainer/dataset/manifest_utils.py @@ -0,0 +1,223 @@ +""" +Utilities for generating and loading manifest files for .npy datasets. + +Manifest format (JSON Lines): +- Optional first line with manifest metadata: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} + +- Then one record per material ID: + {"material_id": "mp-4002", "files": ["mp-4002/mp-4002-7.npy", "mp-4002/mp-4002-8.npy", ...]} + +Notes: +- When the meta header with base_dir is present, file paths in subsequent records are interpreted + relative to base_dir. If a file path is absolute, it is used as-is. +- If no meta header is present (legacy manifests), file paths are interpreted as-is (backward compatible). +- Storing manifests under data/manifests allows fast dataset construction without scanning millions of files. 
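+
+Example (path resolution; assumes a manifest at data/manifests/train.jsonl):
+- With the header {"__meta__": {"version": 1, "base_dir": "../dataset"}},
+  base_dir resolves to data/dataset (relative to the manifest's directory), so a
+  record entry "mp-4002/mp-4002-7.npy" resolves to data/dataset/mp-4002/mp-4002-7.npy.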
+""" + +from __future__ import annotations + +import json +import os +import random +from dataclasses import dataclass +from typing import Dict, List, Tuple, Optional + + +@dataclass(frozen=True) +class ManifestStats: + train_materials: int + val_materials: int + test_materials: int + train_files: int + val_files: int + test_files: int + + +def _is_material_dir(path: str) -> bool: + """Heuristic: a material dir contains at least one .npy file.""" + if not os.path.isdir(path): + return False + try: + for entry in os.listdir(path): + if entry.endswith(".npy"): + return True + except Exception: + return False + return False + + +def scan_dataset_root(dataset_root: str) -> Dict[str, List[str]]: + """ + Scan dataset_root for material directories and collect their .npy file paths. + + Returns a mapping from material_id (dir name) to list of absolute file paths. + """ + if not os.path.isdir(dataset_root): + raise FileNotFoundError(f"Dataset root not found: {dataset_root}") + + materials: Dict[str, List[str]] = {} + for entry in sorted(os.listdir(dataset_root)): + material_dir = os.path.join(dataset_root, entry) + if not _is_material_dir(material_dir): + continue + material_id = os.path.basename(material_dir) + files: List[str] = [] + try: + for f in sorted(os.listdir(material_dir)): + if f.endswith(".npy"): + files.append(os.path.join(material_dir, f)) + except Exception: + # Skip unreadable directories + continue + if files: + materials[material_id] = files + + if not materials: + raise RuntimeError(f"No .npy files discovered under {dataset_root}") + return materials + + +def split_materials( + materials: Dict[str, List[str]], + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int = 42, +) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, List[str]]]: + """ + Split by material ID into train/val/test sets. + + - Ratios must sum to 1.0 within a small tolerance. + - Splitting is performed at the material level (directory name), never per-file. + """ + total = train_ratio + val_ratio + test_ratio + if abs(total - 1.0) > 1e-6: + raise ValueError(f"Ratios must sum to 1.0, got {train_ratio}+{val_ratio}+{test_ratio}={total}") + + ids = list(materials.keys()) + random.Random(seed).shuffle(ids) + + n = len(ids) + n_train = int(round(train_ratio * n)) + n_val = int(round(val_ratio * n)) + # Ensure all materials are assigned, adjust test count + n_test = n - n_train - n_val + if n_test < 0: + n_test = 0 + + train_ids = set(ids[:n_train]) + val_ids = set(ids[n_train : n_train + n_val]) + test_ids = set(ids[n_train + n_val :]) + + train = {mid: materials[mid] for mid in train_ids} + val = {mid: materials[mid] for mid in val_ids} + test = {mid: materials[mid] for mid in test_ids} + return train, val, test + + +def ensure_dir(path: str) -> None: + os.makedirs(path, exist_ok=True) + + +def write_jsonl_manifest( + manifest_path: str, + materials: Dict[str, List[str]], + base_dir: Optional[str] = None, +) -> int: + """ + Write a JSONL manifest. + + If base_dir is provided, a meta header is written as the first line: + {"__meta__": {"version": 1, "base_dir": base_dir}} + and file paths are written relative to base_dir (when possible). + + Args: + manifest_path: Destination JSONL file. + materials: Mapping from material_id to list of file paths (absolute or relative). + base_dir: Base directory to which file paths will be made relative. This string is stored + in the meta header. 
If it is a relative path, it is interpreted relative to + the manifest file's directory when reading. + + Returns: + Number of material entries written (excluding meta header). + """ + ensure_dir(os.path.dirname(manifest_path)) + + manifest_dir_abs = os.path.abspath(os.path.dirname(manifest_path)) + base_dir_abs = None + if base_dir is not None: + base_dir_abs = ( + base_dir if os.path.isabs(base_dir) else os.path.normpath(os.path.join(manifest_dir_abs, base_dir)) + ) + + count = 0 + with open(manifest_path, "w", encoding="utf-8") as f: + # Optional meta header + if base_dir is not None: + meta = {"__meta__": {"version": 1, "base_dir": base_dir}} + f.write(json.dumps(meta) + "\n") + + for material_id, files in sorted(materials.items()): + files_out: List[str] = [] + for p in files: + if base_dir_abs is None: + files_out.append(p) + else: + try: + # Prefer paths relative to base_dir_abs. If outside, relpath may contain .. which is still valid. + files_out.append(os.path.relpath(p, start=base_dir_abs)) + except Exception: + files_out.append(p) + rec = {"material_id": material_id, "files": files_out} + f.write(json.dumps(rec) + "\n") + count += 1 + return count + + +def generate_manifests( + dataset_root: str, + manifest_dir: str = "data/manifests", + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int = 42, +) -> ManifestStats: + """ + Generate train/val/test manifests from dataset_root, splitting by material ID. + + Writes: + - {manifest_dir}/train.jsonl + - {manifest_dir}/val.jsonl + - {manifest_dir}/test.jsonl + + Returns aggregate stats. + """ + materials = scan_dataset_root(dataset_root) + train, val, test = split_materials(materials, train_ratio, val_ratio, test_ratio, seed=seed) + + train_path = os.path.join(manifest_dir, "train.jsonl") + val_path = os.path.join(manifest_dir, "val.jsonl") + test_path = os.path.join(manifest_dir, "test.jsonl") + + # Store base_dir in header relative to manifest_dir for portability + manifest_dir_abs = os.path.abspath(manifest_dir) + dataset_root_abs = os.path.abspath(dataset_root) + base_dir_header = os.path.relpath(dataset_root_abs, start=manifest_dir_abs) + + _ = write_jsonl_manifest(train_path, train, base_dir=base_dir_header) + _ = write_jsonl_manifest(val_path, val, base_dir=base_dir_header) + _ = write_jsonl_manifest(test_path, test, base_dir=base_dir_header) + + def _file_count(m: Dict[str, List[str]]) -> int: + return sum(len(v) for v in m.values()) + + stats = ManifestStats( + train_materials=len(train), + val_materials=len(val), + test_materials=len(test), + train_files=_file_count(train), + val_files=_file_count(val), + test_files=_file_count(test), + ) + return stats diff --git a/src/trainer/generate_manifests.py b/src/trainer/generate_manifests.py new file mode 100644 index 0000000..8f7c4b8 --- /dev/null +++ b/src/trainer/generate_manifests.py @@ -0,0 +1,146 @@ +""" +CLI to generate JSONL manifests for .npy datasets split by material ID. 
+ +Manifest format (JSON Lines): +- Optional first line meta header: + {"__meta__": {"version": 1, "base_dir": "../dataset"}} +- Then one record per material ID: + {"material_id": "mp-4002", "files": ["mp-4002/mp-4002-7.npy", "mp-4002/mp-4002-8.npy", ...]} + +Defaults: +- dataset_root: data/dataset +- manifest_dir: data/manifests +- splits: train=0.8, val=0.1, test=0.1 +- seed: 42 + +Usage: + # Run from src/trainer as working directory + python generate_manifests.py \ + --dataset-root data/dataset \ + --manifest-dir data/manifests \ + --train-ratio 0.8 --val-ratio 0.1 --test-ratio 0.1 \ + --seed 42 + +This script avoids scanning during training by precomputing manifests that link +to the .npy files grouped by material ID (directory name). +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from typing import Any, Dict + +# Assume cwd is src/trainer; import from local package +from dataset.manifest_utils import generate_manifests, ManifestStats + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Generate JSONL manifests for .npy datasets split by material ID." + ) + p.add_argument( + "--dataset-root", + type=str, + default="data/dataset", + help="Root directory containing material subfolders (e.g., data/dataset/mp-4002/...).", + ) + p.add_argument( + "--manifest-dir", + type=str, + default="data/manifests", + help="Directory where train/val/test JSONL manifests will be written.", + ) + p.add_argument( + "--train-ratio", + type=float, + default=0.8, + help="Fraction of materials assigned to train split.", + ) + p.add_argument( + "--val-ratio", + type=float, + default=0.1, + help="Fraction of materials assigned to validation split.", + ) + p.add_argument( + "--test-ratio", + type=float, + default=0.1, + help="Fraction of materials assigned to test split.", + ) + p.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible material ID splitting.", + ) + p.add_argument( + "--quiet", + action="store_true", + help="Suppress non-error output.", + ) + return p.parse_args() + + +def _validate_ratios(train: float, val: float, test: float) -> None: + total = train + val + test + if abs(total - 1.0) > 1e-6: + raise SystemExit( + f"Error: Ratios must sum to 1.0, got {train}+{val}+{test}={total}" + ) + if min(train, val, test) < 0.0: + raise SystemExit("Error: Ratios must be non-negative.") + + +def _print_summary(manifest_dir: str, stats: ManifestStats) -> None: + summary: Dict[str, Any] = { + "manifest_dir": manifest_dir, + "manifests": { + "train": os.path.join(manifest_dir, "train.jsonl"), + "val": os.path.join(manifest_dir, "val.jsonl"), + "test": os.path.join(manifest_dir, "test.jsonl"), + }, + "materials": { + "train": stats.train_materials, + "val": stats.val_materials, + "test": stats.test_materials, + }, + "files": { + "train": stats.train_files, + "val": stats.val_files, + "test": stats.test_files, + }, + } + print(json.dumps(summary, indent=2)) + + +def main() -> None: + args = _parse_args() + _validate_ratios(args.train_ratio, args.val_ratio, args.test_ratio) + + # Resolve paths relative to project root if provided as relative paths + dataset_root = args.dataset_root + manifest_dir = args.manifest_dir + + stats = generate_manifests( + dataset_root=dataset_root, + manifest_dir=manifest_dir, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + test_ratio=args.test_ratio, + seed=args.seed, + ) + + if not args.quiet: + _print_summary(manifest_dir, stats) + + +if __name__ 
== "__main__": + try: + main() + except Exception as e: + print(f"[ERROR] {e}", file=sys.stderr) + sys.exit(1) From 896881deb4d705a64f1e9f0c6258248f63587354 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Mon, 3 Nov 2025 18:24:54 -0600 Subject: [PATCH 02/18] feat: Add initial running pytorch lightning trainer --- configs/trainer.yaml | 88 +++++ src/trainer/dataset/datamodule.py | 8 +- src/trainer/dataset/dataset.py | 116 ++++++- src/trainer/model/model.py | 539 ++++++++++++++++++++++++++++++ src/trainer/train_paper.py | 161 +++++++++ 5 files changed, 893 insertions(+), 19 deletions(-) create mode 100644 configs/trainer.yaml create mode 100644 src/trainer/model/model.py create mode 100644 src/trainer/train_paper.py diff --git a/configs/trainer.yaml b/configs/trainer.yaml new file mode 100644 index 0000000..74fb14a --- /dev/null +++ b/configs/trainer.yaml @@ -0,0 +1,88 @@ +# AlphaDiffract trainer configuration (paper-aligned defaults provided here) +# This file is required by src/trainer/train_paper.py. It contains all parameters with no script-side defaults. + +# --- Data / Manifests --- +manifest_dir: "../../data/manifests" +dataset_root: "../../data/dataset" # used when auto_generate_manifests is true +auto_generate_manifests: true +train_ratio: 0.8 +val_ratio: 0.1 +test_ratio: 0.1 +seed: 42 + +# --- DataLoader --- +batch_size: 64 # paper used 64 +num_workers: 8 +pin_memory: true +persistent_workers: true + +# --- Dataset label extraction (embedded in .npy/.npz) --- +validate_paths: false +extract_labels: true +allow_pickle: true +labels_key_map: + x: ["x", "signal", "xrd", "pattern"] + cs: ["cs", "crystal_system"] + sg: ["sg", "space_group"] + lattice_params: ["lattice_params", "lp"] +dtype: "float32" # one of: float32, float64, float16, bfloat16 +mmap_mode: "r" # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable + +# --- Model architecture --- +depths: [2, 2, 4, 2] +dims: [128, 256, 384, 560] +kernel_sizes: [101, 65, 51, 25] +strides: [5, 5, 5, 5] +drop_path_rate: 0.2 +layer_scale_init_value: 1.0e-6 + +# Heads +head_dropout: 0.2 +cs_hidden: [1024, 512] +sg_hidden: [2048, 1024] +lp_hidden: [512, 256] + +# Task sizes +num_cs_classes: 7 +num_sg_classes: 230 +num_lp_outputs: 6 + +# LP output bounds +lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +bound_lp_with_sigmoid: true + +# Loss weights +lambda_cs: 1.0 +lambda_sg: 1.0 +lambda_lp: 1.0 + +# Optional GEMD term on SG +gemd_mu: 0.0 +gemd_distance_matrix: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD + +# Optimizer +lr: 0.0002 # paper used 2e-4 +weight_decay: 0.01 # paper used 0.01 +use_adamw: true + +# --- Trainer settings --- +default_root_dir: "outputs/paper_model" +max_epochs: 100 +accumulate_grad_batches: 1 +precision: "32" # e.g., '32', '16-mixed', 'bf16-mixed' +accelerator: "auto" +devices: "auto" +log_every_n_steps: 50 +deterministic: false +benchmark: true + +# --- Checkpointing --- +monitor: "val/loss" +mode: "min" +save_top_k: 1 +every_n_epochs: 1 + +# --- Evaluation --- +resume_from: # e.g., "outputs/paper_model/checkpoints/epochXYZ.ckpt" +test_after_train: true diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index d25b591..61e0af9 100644 --- a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -39,8 +39,8 @@ import os from typing import Any, Callable, Dict, Optional -# Resolve project root for robust default paths independent of cwd (src/trainer) -PROJECT_ROOT = 
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +# Resolve base for relative paths as the current working directory (runtime CWD) +CWD_BASE = os.getcwd() import torch from torch.utils.data import DataLoader @@ -130,7 +130,7 @@ def _manifest_paths(self) -> Dict[str, str]: # Allow overriding filenames while still respecting manifest_dir base = self.manifest_dir if not os.path.isabs(base): - base = os.path.join(PROJECT_ROOT, base) + base = os.path.join(CWD_BASE, base) paths = { "train": os.path.join(base, self.train_file), "val": os.path.join(base, self.val_file), @@ -160,7 +160,7 @@ def prepare_data(self) -> None: # Resolve absolute paths for dataset root and manifest dir dataset_root_abs = self.dataset_root if not os.path.isabs(dataset_root_abs): - dataset_root_abs = os.path.join(PROJECT_ROOT, dataset_root_abs) + dataset_root_abs = os.path.join(CWD_BASE, dataset_root_abs) manifest_dir_abs = os.path.dirname(paths["train"]) # Generate manifests in a reproducible way by material ID. diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py index c8c42c3..794c97e 100644 --- a/src/trainer/dataset/dataset.py +++ b/src/trainer/dataset/dataset.py @@ -97,6 +97,9 @@ def __init__( mmap_mode: Optional[str] = "r", return_meta: bool = True, validate_paths: bool = False, + extract_labels: bool = False, + labels_key_map: Optional[Dict[str, List[str]]] = None, + allow_pickle: bool = True, ) -> None: super().__init__() self.manifest_path = manifest_path @@ -104,6 +107,33 @@ def __init__( self.dtype = dtype self.mmap_mode = mmap_mode self.return_meta = return_meta + self.extract_labels = extract_labels + self.allow_pickle = allow_pickle + # Default key mapping for extracting fields from embedded containers + # Simplified: single string keys, no search lists + self.labels_key_map = labels_key_map or { + "x": "dp", + "cs": "cs", + "sg": "sg", + "lattice_params": None, + "lp_a": "_cell_length_a", + "lp_b": "_cell_length_b", + "lp_c": "_cell_length_c", + "lp_alpha": "_cell_angle_alpha", + "lp_beta": "_cell_angle_beta", + "lp_gamma": "_cell_angle_gamma", + } + # Cache exact keys + self.k_x = self.labels_key_map.get("x") + self.k_cs = self.labels_key_map.get("cs") + self.k_sg = self.labels_key_map.get("sg") + self.k_lp = self.labels_key_map.get("lattice_params") + self.k_lp_a = self.labels_key_map.get("lp_a") + self.k_lp_b = self.labels_key_map.get("lp_b") + self.k_lp_c = self.labels_key_map.get("lp_c") + self.k_lp_alpha = self.labels_key_map.get("lp_alpha") + self.k_lp_beta = self.labels_key_map.get("lp_beta") + self.k_lp_gamma = self.labels_key_map.get("lp_gamma") # Build index: flatten per-material records into (resolved_file_path, material_id) pairs records, meta = _read_jsonl_manifest(manifest_path) @@ -147,22 +177,78 @@ def __len__(self) -> int: def __getitem__(self, idx: int): fpath, mid = self._items[idx] - # Load numpy array; use mmap for reduced memory pressure on large datasets - arr = np.load(fpath, mmap_mode=self.mmap_mode) - # Ensure array is a standard ndarray (materialize memmap view if needed) - arr = np.asarray(arr) - # Convert to torch tensor with requested dtype - x = torch.from_numpy(arr) - if self.dtype is not None: - x = x.to(self.dtype) - - if self.transform is not None: - x = self.transform(x) - - if self.return_meta: - return {"x": x, "material_id": mid, "path": fpath} + # Load numpy/npz; allow_pickle supports dict-like payloads saved in .npy + # Load with pickle-friendly mode; memmap disabled for object payloads + arr = np.load(fpath, 
mmap_mode=None, allow_pickle=self.allow_pickle) + + x_tensor: Optional[torch.Tensor] = None + y_cs_t: Optional[torch.Tensor] = None + y_sg_t: Optional[torch.Tensor] = None + y_lp_t: Optional[torch.Tensor] = None + + # Helper to resolve a key from a container using a list of candidate names + def _get_exact(container, key: str): + if key is None: + return None + if isinstance(container, dict): + return container.get(key) + return None + + # Simplified: only support .npy files with object dtype containing a dict payload + if isinstance(arr, np.ndarray) and arr.dtype == object: + payload = arr.item() + if not isinstance(payload, dict): + raise TypeError(f"Expected dict payload in object dtype .npy, got {type(payload)} for {fpath}") + x_np = _get_exact(payload, self.k_x) + if x_np is None: + raise KeyError(f"No data array found in npy {fpath} using key '{self.k_x}'") + x_np = np.asarray(x_np) + x_tensor = torch.from_numpy(x_np) + + if self.extract_labels: + y_cs = _get_exact(payload, self.k_cs) + y_sg = _get_exact(payload, self.k_sg) + y_lp = _get_exact(payload, self.k_lp) + if y_cs is not None: + y_cs_t = torch.as_tensor(np.asarray(y_cs)) + if y_sg is not None: + y_sg_t = torch.as_tensor(np.asarray(y_sg)) + if y_lp is not None: + y_lp_t = torch.as_tensor(np.asarray(y_lp)) + # Assemble lattice params from fixed keys if consolidated field not present + if y_lp_t is None: + a = _get_exact(payload, self.k_lp_a) + b = _get_exact(payload, self.k_lp_b) + c = _get_exact(payload, self.k_lp_c) + alpha = _get_exact(payload, self.k_lp_alpha) + beta = _get_exact(payload, self.k_lp_beta) + gamma = _get_exact(payload, self.k_lp_gamma) + if None not in (a, b, c, alpha, beta, gamma): + lp_np = np.asarray([a, b, c, alpha, beta, gamma], dtype=np.float64) + y_lp_t = torch.as_tensor(lp_np) else: - return x + raise TypeError(f"Unsupported file format: expected object dtype .npy, got dtype={getattr(arr, 'dtype', '?')} for {fpath}") + + # Cast dtype and apply transform to x only + if self.dtype is not None and x_tensor is not None: + x_tensor = x_tensor.to(self.dtype) + if self.transform is not None and x_tensor is not None: + x_tensor = self.transform(x_tensor) + + if not self.return_meta: + # Backward-compat: return only x + return x_tensor + + sample = {"x": x_tensor, "material_id": mid, "path": fpath} + # Attach labels if present/extracted + if self.extract_labels: + if y_cs_t is not None: + sample["cs"] = y_cs_t + if y_sg_t is not None: + sample["sg"] = y_sg_t + if y_lp_t is not None: + sample["lattice_params"] = y_lp_t + return sample @property def materials(self) -> List[str]: diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py new file mode 100644 index 0000000..e785d5a --- /dev/null +++ b/src/trainer/model/model.py @@ -0,0 +1,539 @@ +from typing import Dict, Tuple, Optional, List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +try: + import pytorch_lightning as pl +except ImportError as e: + raise ImportError( + "pytorch_lightning is required for this module. Please install with `pip install pytorch-lightning`." 
+ ) from e + + +# ----------------------------- +# Utility: DropPath (Stochastic Depth) +# ----------------------------- +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + # work with broadcastable noise shape + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor = random_tensor.floor() + return x.div(keep_prob) * random_tensor + + +class DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + +# ----------------------------- +# ConvNeXt 1D Block (self-contained) +# Follows ConvNeXt design adapted to 1D: depthwise conv -> LN -> PW-MLP (expand) -> GELU -> PW-MLP (project) -> gamma -> residual +# ----------------------------- +class ConvNeXtBlock1D(nn.Module): + def __init__( + self, + dim: int, + kernel_size: int = 7, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6, + ): + super().__init__() + # depthwise conv + padding = kernel_size // 2 + self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding, groups=dim) + # LayerNorm over channels - apply by permuting to (N, L, C) + self.norm = nn.LayerNorm(dim, eps=1e-6) + # pointwise MLP + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + # layer scale (gamma) as a learnable per-channel vector + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, C, L) + shortcut = x + x = self.dwconv(x) # (N, C, L) + x = x.permute(0, 2, 1) # (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = x * self.gamma + x = x.permute(0, 2, 1) # (N, C, L) + x = shortcut + self.drop_path(x) + return x + + +# ----------------------------- +# Backbone: ConvNeXt1D (self-contained) +# ----------------------------- +class ConvNeXt1DBackbonePaper(nn.Module): + """ + Self-contained 1D ConvNeXt backbone adapted for XRD: + - 4 stages with progressive downsampling via strided convs + - ConvNeXt blocks within each stage + - Global average pooling at end to produce a feature vector + Final dim_output is dims[-1]. Default dims end at 560 to match the paper. + """ + + def __init__( + self, + in_chans: int = 1, + depths: Tuple[int, int, int, int] = (2, 2, 4, 2), + dims: Tuple[int, int, int, int] = (128, 256, 384, 560), + kernel_sizes: Tuple[int, int, int, int] = (101, 65, 51, 25), + strides: Tuple[int, int, int, int] = (5, 5, 5, 5), + drop_path_rate: float = 0.2, + layer_scale_init_value: float = 1e-6, + ): + super().__init__() + assert len(depths) == 4 and len(dims) == 4 and len(kernel_sizes) == 4 and len(strides) == 4 + + self.dim_output = dims[-1] + + # Downsampling layers: stem + 3 transitions + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + nn.Conv1d(in_chans, dims[0], kernel_size=kernel_sizes[0], stride=strides[0], padding=0), + # channel-first LayerNorm is approximated with GroupNorm(1,..) 
or use channel-wise LN by permuting when needed + # Here we avoid extra permutes by using simple affine scaling with GroupNorm(1, C) + nn.GroupNorm(1, dims[0], eps=1e-6), + ) + self.downsample_layers.append(stem) + + for i in range(3): + down = nn.Sequential( + nn.GroupNorm(1, dims[i], eps=1e-6), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=kernel_sizes[i + 1], stride=strides[i + 1], padding=0), + ) + self.downsample_layers.append(down) + + # Stages with ConvNeXt blocks + self.stages = nn.ModuleList() + dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist() + cur = 0 + for i in range(4): + blocks = [ + ConvNeXtBlock1D( + dim=dims[i], + kernel_size=kernel_sizes[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + ) + for j in range(depths[i]) + ] + cur += depths[i] + self.stages.append(nn.Sequential(*blocks)) + + self.final_norm = nn.LayerNorm(dims[-1], eps=1e-6) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Accept (N, 8192) or (N, 1, 8192) + if x.ndim == 2: + x = x[:, None, :] + # 4 stages + for i in range(4): + x = self.downsample_layers[i](x) # (N, C_i, L_i) + x = self.stages[i](x) # (N, C_i, L_i) + # Global average pooling over length + x = x.mean(dim=-1) # (N, C_last) + x = self.final_norm(x) # LayerNorm over channels + return x # (N, dim_output) == (N, dims[-1]) == 560 by default + + +# ----------------------------- +# Heads: simple MLP builders +# ----------------------------- +def make_mlp( + input_dim: int, + hidden_dims: Optional[Tuple[int, ...]], + output_dim: int, + dropout: float = 0.2, + output_activation: Optional[nn.Module] = None, +) -> nn.Module: + layers: List[nn.Module] = [] + last = input_dim + if hidden_dims is not None and len(hidden_dims) > 0: + for hd in hidden_dims: + layers.extend([nn.Linear(last, hd), nn.ReLU()]) + if dropout and dropout > 0: + layers.append(nn.Dropout(dropout)) + last = hd + layers.append(nn.Linear(last, output_dim)) + if output_activation is not None: + layers.append(output_activation) + return nn.Sequential(*layers) + + +# ----------------------------- +# Lightning Module: AlphaDiffractLightning +# ----------------------------- +BatchType = Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + Dict[str, torch.Tensor], +] + + +class AlphaDiffractLightning(pl.LightningModule): + """ + PyTorch Lightning module for the AlphaDiffract model: + - 1D ConvNeXt backbone producing a 560-dim feature vector + - CS classifier head (7 classes) + - SG classifier head (230 classes) + - LP regressor head (6 outputs, optional bounding via sigmoid to [min, max]) + + Expected batch formats: + - Tuple: (x, y_cs, y_sg, y_lp) + - Dict keys: x or xrd or signal; cs, sg, lattice_params (or lp) + + Losses: + - CS: CrossEntropy + - SG: CrossEntropy + - LP: L1 (MAE) by default; bounded via sigmoid to given ranges if enabled + + Metrics logged: + - train/val/test: cs_acc, sg_acc, lp_mae, total_loss + """ + + def __init__( + self, + # Backbone args + depths: Tuple[int, int, int, int] = (2, 2, 4, 2), + dims: Tuple[int, int, int, int] = (128, 256, 384, 560), + kernel_sizes: Tuple[int, int, int, int] = (101, 65, 51, 25), + strides: Tuple[int, int, int, int] = (5, 5, 5, 5), + drop_path_rate: float = 0.2, + layer_scale_init_value: float = 1e-6, + + # Head dims + head_dropout: float = 0.2, + cs_hidden: Optional[Tuple[int, ...]] = (1024, 512), + sg_hidden: Optional[Tuple[int, ...]] = (2048, 1024), + lp_hidden: Optional[Tuple[int, ...]] = (512, 256), + + # Task sizes + num_cs_classes: int = 7, + 
num_sg_classes: int = 230, + num_lp_outputs: int = 6, + + # LP bounding + lp_bounds_min: Tuple[float, float, float, float, float, float] = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), + lp_bounds_max: Tuple[float, float, float, float, float, float] = (500.0, 500.0, 500.0, 180.0, 180.0, 180.0), + bound_lp_with_sigmoid: bool = True, + + # Loss weights + lambda_cs: float = 1.0, + lambda_sg: float = 1.0, + lambda_lp: float = 1.0, + # Optional GEMD for SG + gemd_mu: float = 0.0, + gemd_distance_matrix_path: Optional[str] = None, + + # Optimizer + lr: float = 2e-4, + weight_decay: float = 1e-2, + use_adamw: bool = True, + ): + super().__init__() + self.save_hyperparameters() + + # Backbone + self.backbone = ConvNeXt1DBackbonePaper( + in_chans=1, + depths=depths, + dims=dims, + kernel_sizes=kernel_sizes, + strides=strides, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ) + feat_dim = self.backbone.dim_output # should be 560 with default dims + + # Heads (produce logits for classification; no softmax here) + self.cs_head = make_mlp( + input_dim=feat_dim, + hidden_dims=cs_hidden, + output_dim=num_cs_classes, + dropout=head_dropout, + output_activation=None, + ) + self.sg_head = make_mlp( + input_dim=feat_dim, + hidden_dims=sg_hidden, + output_dim=num_sg_classes, + dropout=head_dropout, + output_activation=None, + ) + self.lp_head = make_mlp( + input_dim=feat_dim, + hidden_dims=lp_hidden, + output_dim=num_lp_outputs, + dropout=head_dropout, + output_activation=None, + ) + + # Losses + self.ce = nn.CrossEntropyLoss() + self.mse = nn.MSELoss() + + # LP bounds + self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32)) + self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32)) + + self.bound_lp_with_sigmoid = bound_lp_with_sigmoid + + # weights + self.lambda_cs = lambda_cs + self.lambda_sg = lambda_sg + self.lambda_lp = lambda_lp + + # Optimizer config + self.lr = lr + self.weight_decay = weight_decay + self.use_adamw = use_adamw + + # Task sizes + self.num_cs_classes = num_cs_classes + self.num_sg_classes = num_sg_classes + self.num_lp_outputs = num_lp_outputs + + # GEMD setup (optional) + self.gemd_mu = gemd_mu + self.register_buffer("gemd_D", torch.empty(0)) + if gemd_distance_matrix_path is not None: + D_np = np.load(gemd_distance_matrix_path) + D_t = torch.as_tensor(D_np, dtype=torch.float32) + if D_t.ndim != 2 or D_t.shape[0] != self.num_sg_classes or D_t.shape[1] != self.num_sg_classes: + raise ValueError("GEMD distance matrix must be of shape (num_sg_classes, num_sg_classes)") + self.register_buffer("gemd_D", D_t) + + # ----------------------------- + # Forward + # ----------------------------- + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Forward pass. + Returns a dict with: + - cs_logits: (N, 7) + - sg_logits: (N, 230) + - lp: (N, 6) bounded if enabled + - features: (N, 560) + """ + feats = self.backbone(x) # (N, 560) + cs_logits = self.cs_head(feats) + sg_logits = self.sg_head(feats) + lp = self.lp_head(feats) + + if self.bound_lp_with_sigmoid: + # Bound to [min, max] via sigmoid + lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min + + return { + "features": feats, + "cs_logits": cs_logits, + "sg_logits": sg_logits, + "lp": lp, + } + + # ----------------------------- + # Data parsing helpers + # ----------------------------- + @staticmethod + def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: + # Convert targets to class indices. 
Supports one-hot and integer labels. + if y.dim() > 1 and y.size(-1) > 1: + idx = y.argmax(dim=-1) + else: + idx = y.long() + # Normalize 1-based labels to 0-based if detected (min>=1 and max==num_classes) + with torch.no_grad(): + if idx.numel() > 0: + minv = int(idx.min().item()) + maxv = int(idx.max().item()) + if minv >= 1 and maxv == num_classes: + idx = idx - 1 + return idx + + @staticmethod + def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(batch, (list, tuple)): + assert len(batch) >= 4, "Expected at least (x, cs, sg, lp) in the batch tuple." + x, y_cs, y_sg, y_lp = batch[:4] + elif isinstance(batch, dict): + # Try common keys + x = batch.get("x", batch.get("xrd", batch.get("signal"))) + if x is None: + raise KeyError("Batch dict must contain 'x' or 'xrd' or 'signal'.") + y_cs = batch.get("cs") + y_sg = batch.get("sg") + y_lp = batch.get("lattice_params", batch.get("lp")) + if y_cs is None or y_sg is None or y_lp is None: + raise KeyError("Batch dict must contain 'cs', 'sg', and 'lattice_params' (or 'lp').") + else: + raise TypeError("Unsupported batch type. Use Tuple or Dict.") + + return x, y_cs, y_sg, y_lp + + # ----------------------------- + # Loss and metrics + # ----------------------------- + def _compute_losses_and_metrics( + self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: torch.Tensor + ) -> Dict[str, torch.Tensor]: + cs_logits = preds["cs_logits"] + sg_logits = preds["sg_logits"] + lp_pred = preds["lp"] + + # targets: convert one-hot to index for CE if necessary + y_cs_idx = self._to_index(y_cs, self.num_cs_classes) + y_sg_idx = self._to_index(y_sg, self.num_sg_classes) + # regression targets should be float + y_lp = y_lp.float() + + loss_cs = self.ce(cs_logits, y_cs_idx) + loss_sg = self.ce(sg_logits, y_sg_idx) + loss_lp = self.mse(lp_pred, y_lp) + + # Optional GEMD term + loss_gemd = torch.tensor(0.0, device=cs_logits.device) + sg_probs = torch.softmax(sg_logits, dim=1) + if self.gemd_mu > 0.0 and self.gemd_D.numel() > 0: + D_rows = self.gemd_D[y_sg_idx] + gemd_per_sample = (D_rows * sg_probs).sum(dim=1) + loss_gemd = gemd_per_sample.mean() + + total_loss = ( + self.lambda_cs * loss_cs + + self.lambda_sg * loss_sg + + self.lambda_lp * loss_lp + + self.gemd_mu * loss_gemd + ) + + # metrics + with torch.no_grad(): + cs_acc = (cs_logits.argmax(dim=1) == y_cs_idx).float().mean() + sg_acc = (sg_logits.argmax(dim=1) == y_sg_idx).float().mean() + lp_mae = (lp_pred - y_lp).abs().mean() + lp_mse = F.mse_loss(lp_pred, y_lp) + + return { + "loss_total": total_loss, + "loss_cs": loss_cs, + "loss_sg": loss_sg, + "loss_lp": loss_lp, + "loss_gemd": loss_gemd, + "cs_acc": cs_acc, + "sg_acc": sg_acc, + "lp_mae": lp_mae, + "lp_mse": lp_mse, + } + + # ----------------------------- + # Lightning hooks + # ----------------------------- + def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + + self.log("train/loss", out["loss_total"], prog_bar=True, on_step=True, on_epoch=True) + self.log("train/loss_cs", out["loss_cs"], on_step=True, on_epoch=True) + self.log("train/loss_sg", out["loss_sg"], on_step=True, on_epoch=True) + self.log("train/loss_lp", out["loss_lp"], on_step=True, on_epoch=True) + self.log("train/loss_gemd", out["loss_gemd"], on_step=True, on_epoch=True) + self.log("train/cs_acc", out["cs_acc"], prog_bar=True, 
on_step=True, on_epoch=True) + self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) + self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) + self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) + + return out["loss_total"] + + def validation_step(self, batch: BatchType, batch_idx: int) -> None: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + + self.log("val/loss", out["loss_total"], prog_bar=True, on_epoch=True) + self.log("val/loss_cs", out["loss_cs"], on_epoch=True) + self.log("val/loss_sg", out["loss_sg"], on_epoch=True) + self.log("val/loss_lp", out["loss_lp"], on_epoch=True) + self.log("val/loss_gemd", out["loss_gemd"], on_epoch=True) + self.log("val/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) + self.log("val/sg_acc", out["sg_acc"], on_epoch=True) + self.log("val/lp_mae", out["lp_mae"], on_epoch=True) + self.log("val/lp_mse", out["lp_mse"], on_epoch=True) + + def test_step(self, batch: BatchType, batch_idx: int) -> None: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + + self.log("test/loss", out["loss_total"], prog_bar=True, on_epoch=True) + self.log("test/loss_cs", out["loss_cs"], on_epoch=True) + self.log("test/loss_sg", out["loss_sg"], on_epoch=True) + self.log("test/loss_lp", out["loss_lp"], on_epoch=True) + self.log("test/loss_gemd", out["loss_gemd"], on_epoch=True) + self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) + self.log("test/sg_acc", out["sg_acc"], on_epoch=True) + self.log("test/lp_mae", out["lp_mae"], on_epoch=True) + self.log("test/lp_mse", out["lp_mse"], on_epoch=True) + + def configure_optimizers(self): + params = self.parameters() + if self.use_adamw: + optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay) + else: + optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay) + return optimizer + + +# ----------------------------- +# Example factory +# ----------------------------- +def build_alphadiffract_model_for_8192() -> AlphaDiffractLightning: + """ + Build the default model for 1x8192 XRD input, 560-dim features, and three heads. 
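+
+    Example (a quick shape check; the batch size of 4 is an arbitrary assumption):
+
+        model = build_alphadiffract_model_for_8192()
+        out = model(torch.randn(4, 1, 8192))
+        # out["cs_logits"]: (4, 7), out["sg_logits"]: (4, 230), out["lp"]: (4, 6)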
+ """ + return AlphaDiffractLightning( + depths=(2, 2, 4, 2), + dims=(128, 256, 384, 560), + kernel_sizes=(101, 65, 51, 25), + strides=(5, 5, 5, 5), + drop_path_rate=0.2, + layer_scale_init_value=1e-6, + + head_dropout=0.2, + cs_hidden=(1024, 512), + sg_hidden=(2048, 1024), + lp_hidden=(512, 256), + + num_cs_classes=7, + num_sg_classes=230, + num_lp_outputs=6, + + lp_bounds_min=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), + lp_bounds_max=(500.0, 500.0, 500.0, 180.0, 180.0, 180.0), + bound_lp_with_sigmoid=True, + + lambda_cs=1.0, + lambda_sg=1.0, + lambda_lp=1.0, + + lr=2e-4, + weight_decay=1e-2, + use_adamw=True, + ) diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py new file mode 100644 index 0000000..7910964 --- /dev/null +++ b/src/trainer/train_paper.py @@ -0,0 +1,161 @@ +import argparse +import os +from typing import Dict, Any, Optional + +import torch +import yaml +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor + +# Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`) +from dataset import NpyDataModule +from model.model import AlphaDiffractLightning + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Train AlphaDiffract paper model (config-required)") + # Require a config file path with no script-side defaults + p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer.yaml)") + return p.parse_args() + + +def load_config(path: str) -> Dict[str, Any]: + if not os.path.isfile(path): + raise FileNotFoundError(f"Config file not found: {path}") + with open(path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + if not isinstance(cfg, dict): + raise ValueError(f"Config must be a mapping (YAML dict), got: {type(cfg)}") + return cfg + + +def _to_dtype(name: str) -> torch.dtype: + table = { + "float32": torch.float32, + "float64": torch.float64, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + if name not in table: + raise ValueError(f"Unsupported dtype '{name}'. 
Allowed: {list(table.keys())}") + return table[name] + + +def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: + # Dataset kwargs from config + dataset_kwargs = { + "dtype": _to_dtype(cfg["dtype"]), + "mmap_mode": cfg["mmap_mode"], + "return_meta": True, + "validate_paths": cfg["validate_paths"], + "extract_labels": cfg["extract_labels"], + "allow_pickle": cfg["allow_pickle"], + } + labels_key_map = cfg.get("labels_key_map") + if labels_key_map is not None: + dataset_kwargs["labels_key_map"] = labels_key_map + + dm = NpyDataModule( + manifest_dir=cfg["manifest_dir"], + batch_size=cfg["batch_size"], + num_workers=cfg["num_workers"], + pin_memory=cfg["pin_memory"], + persistent_workers=cfg["persistent_workers"] and cfg["num_workers"] > 0, + dataset_kwargs=dataset_kwargs, + dataset_root=cfg["dataset_root"], + auto_generate_manifests=cfg["auto_generate_manifests"], + train_ratio=cfg["train_ratio"], + val_ratio=cfg["val_ratio"], + test_ratio=cfg["test_ratio"], + seed=cfg["seed"], + ) + return dm + + +def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: + model = AlphaDiffractLightning( + # Backbone + depths=tuple(cfg["depths"]), + dims=tuple(cfg["dims"]), + kernel_sizes=tuple(cfg["kernel_sizes"]), + strides=tuple(cfg["strides"]), + drop_path_rate=cfg["drop_path_rate"], + layer_scale_init_value=cfg["layer_scale_init_value"], + # Heads + head_dropout=cfg["head_dropout"], + cs_hidden=tuple(cfg["cs_hidden"]), + sg_hidden=tuple(cfg["sg_hidden"]), + lp_hidden=tuple(cfg["lp_hidden"]), + # Task sizes + num_cs_classes=cfg["num_cs_classes"], + num_sg_classes=cfg["num_sg_classes"], + num_lp_outputs=cfg["num_lp_outputs"], + # LP bounds and output handling + lp_bounds_min=tuple(cfg["lp_bounds_min"]), + lp_bounds_max=tuple(cfg["lp_bounds_max"]), + bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], + # Loss weights + lambda_cs=cfg["lambda_cs"], + lambda_sg=cfg["lambda_sg"], + lambda_lp=cfg["lambda_lp"], + # Optional GEMD + gemd_mu=cfg["gemd_mu"], + gemd_distance_matrix_path=cfg.get("gemd_distance_matrix"), + # Optimizer + lr=cfg["lr"], + weight_decay=cfg["weight_decay"], + use_adamw=cfg["use_adamw"], + ) + return model + + +def build_trainer_from_cfg(cfg: Dict[str, Any]) -> Trainer: + ckpt_cb = ModelCheckpoint( + monitor=cfg["monitor"], + mode=cfg["mode"], + save_top_k=cfg["save_top_k"], + dirpath=os.path.join(cfg["default_root_dir"], "checkpoints"), + filename="epoch{epoch:03d}-val_loss{val/loss:.4f}", + save_last=True, + every_n_epochs=cfg["every_n_epochs"], + auto_insert_metric_name=False, + ) + lr_cb = LearningRateMonitor(logging_interval="epoch") + + trainer = Trainer( + default_root_dir=cfg["default_root_dir"], + max_epochs=cfg["max_epochs"], + accelerator=cfg["accelerator"], + devices=cfg["devices"], + precision=cfg["precision"], + accumulate_grad_batches=cfg["accumulate_grad_batches"], + callbacks=[ckpt_cb, lr_cb], + log_every_n_steps=cfg["log_every_n_steps"], + deterministic=cfg["deterministic"], + benchmark=cfg["benchmark"], + ) + return trainer + + +def main(): + args = parse_args() + cfg = load_config(args.config) + + seed_everything(cfg["seed"], workers=True) + + dm = build_datamodule_from_cfg(cfg) + model = build_model_from_cfg(cfg) + trainer = build_trainer_from_cfg(cfg) + + # Train + resume_from: Optional[str] = cfg.get("resume_from") + trainer.fit(model, datamodule=dm, ckpt_path=resume_from) + + # Test best model if requested + if cfg["test_after_train"]: + ckpt_path = trainer.checkpoint_callback.best_model_path if trainer.checkpoint_callback else None + 
trainer.test(model=model, datamodule=dm, ckpt_path=ckpt_path or "best") + + +if __name__ == "__main__": + main() From a4ff801b19c2e4a9746b66cb0253d6a7d43ced97 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Mon, 3 Nov 2025 20:48:56 -0600 Subject: [PATCH 03/18] feat: Align model with paper and add mlflow logging --- .gitignore | 13 +++- configs/trainer.yaml | 47 +++++++----- src/trainer/model/model.py | 150 ++++++++++++++++++++----------------- src/trainer/train_paper.py | 19 +++++ 4 files changed, 143 insertions(+), 86 deletions(-) diff --git a/.gitignore b/.gitignore index e73e203..9d905f8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,17 @@ .env +__pycache__ # Development /sandbox /staging -/data \ No newline at end of file +/data +/original + +# Non-Docker Training Outputs +/src/trainer/outputs +/src/trainer/mlruns + +# Temp uv setup +.python-version +pyproject.toml +uv.lock \ No newline at end of file diff --git a/configs/trainer.yaml b/configs/trainer.yaml index 74fb14a..a588b3a 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -21,25 +21,31 @@ validate_paths: false extract_labels: true allow_pickle: true labels_key_map: - x: ["x", "signal", "xrd", "pattern"] - cs: ["cs", "crystal_system"] - sg: ["sg", "space_group"] - lattice_params: ["lattice_params", "lp"] + x: "dp" + cs: "cs" + sg: "sg" + lattice_params: null + lp_a: "_cell_length_a" + lp_b: "_cell_length_b" + lp_c: "_cell_length_c" + lp_alpha: "_cell_angle_alpha" + lp_beta: "_cell_angle_beta" + lp_gamma: "_cell_angle_gamma" dtype: "float32" # one of: float32, float64, float16, bfloat16 -mmap_mode: "r" # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable +mmap_mode: null # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable # --- Model architecture --- -depths: [2, 2, 4, 2] -dims: [128, 256, 384, 560] -kernel_sizes: [101, 65, 51, 25] -strides: [5, 5, 5, 5] -drop_path_rate: 0.2 +depths: [1, 1, 1] +dims: [80, 80, 80] +kernel_sizes: [100, 50, 25] +strides: [5, 5, 5] +drop_path_rate: 0.3 layer_scale_init_value: 1.0e-6 # Heads head_dropout: 0.2 -cs_hidden: [1024, 512] -sg_hidden: [2048, 1024] +cs_hidden: [2300, 1150] +sg_hidden: [2300, 1150] lp_hidden: [512, 256] # Task sizes @@ -66,14 +72,21 @@ lr: 0.0002 # paper used 2e-4 weight_decay: 0.01 # paper used 0.01 use_adamw: true +# --- Logging --- +logger: "csv" # 'csv' or 'mlflow' +csv_logger_name: "model_logs" +mlflow_experiment_name: "OpenAlphaDiffract" +mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) +mlflow_run_name: "paper_model_run" + # --- Trainer settings --- -default_root_dir: "outputs/paper_model" +default_root_dir: "outputs/model" max_epochs: 100 accumulate_grad_batches: 1 -precision: "32" # e.g., '32', '16-mixed', 'bf16-mixed' -accelerator: "auto" -devices: "auto" -log_every_n_steps: 50 +precision: "bf16-mixed" # e.g., '32', '16-mixed', 'bf16-mixed' +accelerator: "gpu" +devices: 1 +log_every_n_steps: 500 deterministic: false benchmark: true diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index e785d5a..bdd0993 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -37,8 +37,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- -# ConvNeXt 1D Block (self-contained) +# ConvNeXt 1D Block (paper-aligned) # Follows ConvNeXt design adapted to 1D: depthwise conv -> LN -> PW-MLP (expand) -> GELU -> PW-MLP (project) -> gamma -> residual +# Note: Paper specifies DropPath on the non-residual (stem) branch; we keep block DropPath disabled by default. 
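+# DropPath is standard per-sample stochastic depth. A minimal sketch of that
+# formulation, assuming drop probability p, training mode, and x of shape (N, C, L):
+#   keep = (torch.rand(x.shape[0], 1, 1, device=x.device) >= p).to(x.dtype)
+#   out = x * keep / (1.0 - p)   # rescale survivors so E[out] == x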
# ----------------------------- class ConvNeXtBlock1D(nn.Module): def __init__( @@ -60,6 +61,7 @@ def __init__( self.pwconv2 = nn.Linear(4 * dim, dim) # layer scale (gamma) as a learnable per-channel vector self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None + # For paper alignment, set default drop_path=0.0 (no stochastic depth in residual branch) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -79,80 +81,91 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- -# Backbone: ConvNeXt1D (self-contained) +# Backbone: ConvNeXt1D (paper-aligned: 3 blocks, avg pooling downsampling, final 560-dim projection) # ----------------------------- class ConvNeXt1DBackbonePaper(nn.Module): """ - Self-contained 1D ConvNeXt backbone adapted for XRD: - - 4 stages with progressive downsampling via strided convs - - ConvNeXt blocks within each stage - - Global average pooling at end to produce a feature vector - Final dim_output is dims[-1]. Default dims end at 560 to match the paper. + Paper-aligned 1D ConvNeXt backbone adapted for XRD: + - 3 ConvNeXt blocks (kernel sizes 100, 50, 25; channels 80 throughout) + - Average pooling downsamples after each block with stride 5 + - Final 1x1 pointwise conv projects channels to 560 + - Global average pooling at end to produce a 560-dim feature vector + + Default dims end at 560 to match the paper's feature dimension. """ def __init__( self, in_chans: int = 1, - depths: Tuple[int, int, int, int] = (2, 2, 4, 2), - dims: Tuple[int, int, int, int] = (128, 256, 384, 560), - kernel_sizes: Tuple[int, int, int, int] = (101, 65, 51, 25), - strides: Tuple[int, int, int, int] = (5, 5, 5, 5), - drop_path_rate: float = 0.2, + depths: Tuple[int, int, int] = (1, 1, 1), + dims: Tuple[int, int, int] = (80, 80, 80), + kernel_sizes: Tuple[int, int, int] = (100, 50, 25), + strides: Tuple[int, int, int] = (5, 5, 5), + drop_path_rate: float = 0.3, # applied to non-residual (stem) branch after pooling as per paper layer_scale_init_value: float = 1e-6, + final_dim: int = 560, ): super().__init__() - assert len(depths) == 4 and len(dims) == 4 and len(kernel_sizes) == 4 and len(strides) == 4 + assert len(depths) == 3 and len(dims) == 3 and len(kernel_sizes) == 3 and len(strides) == 3 - self.dim_output = dims[-1] + self.dim_output = final_dim - # Downsampling layers: stem + 3 transitions - self.downsample_layers = nn.ModuleList() - stem = nn.Sequential( - nn.Conv1d(in_chans, dims[0], kernel_size=kernel_sizes[0], stride=strides[0], padding=0), - # channel-first LayerNorm is approximated with GroupNorm(1,..) or use channel-wise LN by permuting when needed - # Here we avoid extra permutes by using simple affine scaling with GroupNorm(1, C) + # Stem: expand channels from 1 -> 80 using 1x1 conv (paper lists Block 1 input channels 1 and output 80; + # for depthwise conv consistency, we expand channels before the ConvNeXt block). 
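+        # With stride-5 average pooling after each of the 3 stages, the overall
+        # downsampling is 5^3 = 125: an 8192-length input shrinks
+        # 8192 -> 1638 -> 327 -> 65 before the 1x1 projection and global pooling.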
+ self.stem = nn.Sequential( + nn.Conv1d(in_chans, dims[0], kernel_size=1, stride=1, padding=0), nn.GroupNorm(1, dims[0], eps=1e-6), ) - self.downsample_layers.append(stem) + # 3 stages with ConvNeXt blocks (one block per stage by default) + self.blocks = nn.ModuleList() for i in range(3): - down = nn.Sequential( - nn.GroupNorm(1, dims[i], eps=1e-6), - nn.Conv1d(dims[i], dims[i + 1], kernel_size=kernel_sizes[i + 1], stride=strides[i + 1], padding=0), - ) - self.downsample_layers.append(down) - - # Stages with ConvNeXt blocks - self.stages = nn.ModuleList() - dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist() - cur = 0 - for i in range(4): - blocks = [ + stage_blocks = [ ConvNeXtBlock1D( dim=dims[i], kernel_size=kernel_sizes[i], - drop_path=dp_rates[cur + j], + drop_path=0.0, # paper applies DropPath to non-residual branch; disable in-block stochastic depth layer_scale_init_value=layer_scale_init_value, ) - for j in range(depths[i]) + for _ in range(depths[i]) ] - cur += depths[i] - self.stages.append(nn.Sequential(*blocks)) + self.blocks.append(nn.Sequential(*stage_blocks)) + + # Average pooling downsampling after each stage + self.pools = nn.ModuleList([ + nn.AvgPool1d(kernel_size=strides[i], stride=strides[i]) + for i in range(3) + ]) - self.final_norm = nn.LayerNorm(dims[-1], eps=1e-6) + # Apply DropPath to the non-residual (stem) branch post-pooling (as described in the paper) + self.stem_drops = nn.ModuleList([ + DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + for _ in range(3) + ]) + + # Final projection to 560 channels via 1x1 conv + self.head_proj = nn.Conv1d(dims[-1], final_dim, kernel_size=1, stride=1, padding=0) + self.final_norm = nn.LayerNorm(final_dim, eps=1e-6) def forward(self, x: torch.Tensor) -> torch.Tensor: # Accept (N, 8192) or (N, 1, 8192) if x.ndim == 2: x = x[:, None, :] - # 4 stages - for i in range(4): - x = self.downsample_layers[i](x) # (N, C_i, L_i) - x = self.stages[i](x) # (N, C_i, L_i) - # Global average pooling over length - x = x.mean(dim=-1) # (N, C_last) - x = self.final_norm(x) # LayerNorm over channels - return x # (N, dim_output) == (N, dims[-1]) == 560 by default + + # Stem expansion + x = self.stem(x) # (N, 80, L) + + # 3 blocks with avg pooling downsampling + for i in range(3): + x = self.blocks[i](x) # (N, 80, L) + x = self.pools[i](x) # (N, 80, L_i) + x = self.stem_drops[i](x) # DropPath on non-residual branch + + # Final projection and global average pooling + x = self.head_proj(x) # (N, 560, L_last) + x = x.mean(dim=-1) # (N, 560) + x = self.final_norm(x) # LayerNorm over features + return x # (N, 560) # ----------------------------- @@ -191,10 +204,10 @@ def make_mlp( class AlphaDiffractLightning(pl.LightningModule): """ PyTorch Lightning module for the AlphaDiffract model: - - 1D ConvNeXt backbone producing a 560-dim feature vector + - Paper-aligned 1D ConvNeXt backbone producing a 560-dim feature vector - CS classifier head (7 classes) - SG classifier head (230 classes) - - LP regressor head (6 outputs, optional bounding via sigmoid to [min, max]) + - LP regressor head (6 outputs, bounded via sigmoid to [min, max]) Expected batch formats: - Tuple: (x, y_cs, y_sg, y_lp) @@ -203,26 +216,26 @@ class AlphaDiffractLightning(pl.LightningModule): Losses: - CS: CrossEntropy - SG: CrossEntropy - - LP: L1 (MAE) by default; bounded via sigmoid to given ranges if enabled + - LP: MSE (mean squared error); bounded via sigmoid to given ranges if enabled Metrics logged: - - train/val/test: cs_acc, sg_acc, lp_mae, 
total_loss + - train/val/test: cs_acc, sg_acc, lp_mae, lp_mse, total_loss """ def __init__( self, - # Backbone args - depths: Tuple[int, int, int, int] = (2, 2, 4, 2), - dims: Tuple[int, int, int, int] = (128, 256, 384, 560), - kernel_sizes: Tuple[int, int, int, int] = (101, 65, 51, 25), - strides: Tuple[int, int, int, int] = (5, 5, 5, 5), - drop_path_rate: float = 0.2, + # Backbone args (paper-aligned) + depths: Tuple[int, int, int] = (1, 1, 1), + dims: Tuple[int, int, int] = (80, 80, 80), + kernel_sizes: Tuple[int, int, int] = (100, 50, 25), + strides: Tuple[int, int, int] = (5, 5, 5), + drop_path_rate: float = 0.3, layer_scale_init_value: float = 1e-6, - # Head dims + # Head dims (paper-aligned for CS/SG, LP unchanged) head_dropout: float = 0.2, - cs_hidden: Optional[Tuple[int, ...]] = (1024, 512), - sg_hidden: Optional[Tuple[int, ...]] = (2048, 1024), + cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), + sg_hidden: Optional[Tuple[int, ...]] = (2300, 1150), lp_hidden: Optional[Tuple[int, ...]] = (512, 256), # Task sizes @@ -260,8 +273,9 @@ def __init__( strides=strides, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, + final_dim=560, ) - feat_dim = self.backbone.dim_output # should be 560 with default dims + feat_dim = self.backbone.dim_output # 560 # Heads (produce logits for classification; no softmax here) self.cs_head = make_mlp( @@ -506,19 +520,19 @@ def configure_optimizers(self): # ----------------------------- def build_alphadiffract_model_for_8192() -> AlphaDiffractLightning: """ - Build the default model for 1x8192 XRD input, 560-dim features, and three heads. + Build the paper-aligned model for 1x8192 XRD input, 560-dim features, and three heads. """ return AlphaDiffractLightning( - depths=(2, 2, 4, 2), - dims=(128, 256, 384, 560), - kernel_sizes=(101, 65, 51, 25), - strides=(5, 5, 5, 5), - drop_path_rate=0.2, + depths=(1, 1, 1), + dims=(80, 80, 80), + kernel_sizes=(100, 50, 25), + strides=(5, 5, 5), + drop_path_rate=0.3, layer_scale_init_value=1e-6, head_dropout=0.2, - cs_hidden=(1024, 512), - sg_hidden=(2048, 1024), + cs_hidden=(2300, 1150), + sg_hidden=(2300, 1150), lp_hidden=(512, 256), num_cs_classes=7, diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index 7910964..e813bc2 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -6,6 +6,11 @@ import yaml from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.loggers import CSVLogger +try: + from pytorch_lightning.loggers import MLFlowLogger +except Exception: + MLFlowLogger = None # Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`) from dataset import NpyDataModule @@ -122,6 +127,19 @@ def build_trainer_from_cfg(cfg: Dict[str, Any]) -> Trainer: ) lr_cb = LearningRateMonitor(logging_interval="epoch") + # Configure logger from config + logger = None + if cfg["logger"] == "csv": + logger = CSVLogger(save_dir=cfg["default_root_dir"], name=cfg["csv_logger_name"]) + elif cfg["logger"] == "mlflow": + if MLFlowLogger is None: + raise ImportError("MLFlowLogger requested but 'mlflow' is not installed. 
Install with `pip install mlflow`.") + logger = MLFlowLogger( + experiment_name=cfg["mlflow_experiment_name"], + tracking_uri=cfg["mlflow_tracking_uri"], + run_name=cfg["mlflow_run_name"], + ) + trainer = Trainer( default_root_dir=cfg["default_root_dir"], max_epochs=cfg["max_epochs"], @@ -130,6 +148,7 @@ def build_trainer_from_cfg(cfg: Dict[str, Any]) -> Trainer: precision=cfg["precision"], accumulate_grad_batches=cfg["accumulate_grad_batches"], callbacks=[ckpt_cb, lr_cb], + logger=logger, log_every_n_steps=cfg["log_every_n_steps"], deterministic=cfg["deterministic"], benchmark=cfg["benchmark"], From e9958b4e0787e9cd271e291ef908bce414522599 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Mon, 3 Nov 2025 22:04:33 -0600 Subject: [PATCH 04/18] fix: Matmul precision and padding fix --- src/trainer/model/model.py | 10 ++-------- src/trainer/train_paper.py | 3 +++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index bdd0993..a248b2e 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -5,12 +5,7 @@ import torch.nn.functional as F import numpy as np -try: - import pytorch_lightning as pl -except ImportError as e: - raise ImportError( - "pytorch_lightning is required for this module. Please install with `pip install pytorch-lightning`." - ) from e +import pytorch_lightning as pl # ----------------------------- @@ -51,8 +46,7 @@ def __init__( ): super().__init__() # depthwise conv - padding = kernel_size // 2 - self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding, groups=dim) + self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding='same', groups=dim) # LayerNorm over channels - apply by permuting to (N, L, C) self.norm = nn.LayerNorm(dim, eps=1e-6) # pointwise MLP diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index e813bc2..0a918af 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -157,11 +157,14 @@ def build_trainer_from_cfg(cfg: Dict[str, Any]) -> Trainer: def main(): + args = parse_args() cfg = load_config(args.config) seed_everything(cfg["seed"], workers=True) + torch.set_float32_matmul_precision('high') + dm = build_datamodule_from_cfg(cfg) model = build_model_from_cfg(cfg) trainer = build_trainer_from_cfg(cfg) From ccfb9ce2b76c0e0c8c06d18cebcf7ebff1a40a4e Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Tue, 4 Nov 2025 13:36:59 -0600 Subject: [PATCH 05/18] fix: Change GELU to leakyRELU as per paper --- configs/trainer.yaml | 14 +++---- src/trainer/model/model.py | 75 ++++++++++++++------------------------ 2 files changed, 35 insertions(+), 54 deletions(-) diff --git a/configs/trainer.yaml b/configs/trainer.yaml index a588b3a..367009f 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -2,8 +2,8 @@ # This file is required by src/trainer/train_paper.py. It contains all parameters with no script-side defaults. 
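+# All keys in this file are read strictly as cfg["key"] by train_paper.py, so a
+# missing key fails fast with a KeyError at startup; only resume_from,
+# labels_key_map, and the GEMD distance-matrix path go through cfg.get(...).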
# --- Data / Manifests --- -manifest_dir: "../../data/manifests" -dataset_root: "../../data/dataset" # used when auto_generate_manifests is true +manifest_dir: "../../../data/manifests" +dataset_root: "../../../data/dataset" # used when auto_generate_manifests is true auto_generate_manifests: true train_ratio: 0.8 val_ratio: 0.1 @@ -11,7 +11,7 @@ test_ratio: 0.1 seed: 42 # --- DataLoader --- -batch_size: 64 # paper used 64 +batch_size: 1024 # paper used 64 num_workers: 8 pin_memory: true persistent_workers: true @@ -68,12 +68,12 @@ gemd_mu: 0.0 gemd_distance_matrix: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD # Optimizer -lr: 0.0002 # paper used 2e-4 +lr: 0.01 # paper used 2e-4 weight_decay: 0.01 # paper used 0.01 use_adamw: true # --- Logging --- -logger: "csv" # 'csv' or 'mlflow' +logger: "mlflow" # 'csv' or 'mlflow' csv_logger_name: "model_logs" mlflow_experiment_name: "OpenAlphaDiffract" mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) @@ -81,12 +81,12 @@ mlflow_run_name: "paper_model_run" # --- Trainer settings --- default_root_dir: "outputs/model" -max_epochs: 100 +max_epochs: 5 accumulate_grad_batches: 1 precision: "bf16-mixed" # e.g., '32', '16-mixed', 'bf16-mixed' accelerator: "gpu" devices: 1 -log_every_n_steps: 500 +log_every_n_steps: 100 deterministic: false benchmark: true diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index a248b2e..bf9abdd 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -32,9 +32,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- -# ConvNeXt 1D Block (paper-aligned) +# ConvNeXt 1D Block # Follows ConvNeXt design adapted to 1D: depthwise conv -> LN -> PW-MLP (expand) -> GELU -> PW-MLP (project) -> gamma -> residual -# Note: Paper specifies DropPath on the non-residual (stem) branch; we keep block DropPath disabled by default. 
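+# Data layout through the block, as a minimal sketch for input x of shape (N, C, L):
+#   h = dwconv(x)                        # (N, C, L), depthwise over length
+#   h = h.permute(0, 2, 1)               # (N, L, C): LayerNorm/Linear act on channels
+#   h = pwconv2(act(pwconv1(norm(h))))   # expand 4x, nonlinearity, project back
+#   h = (gamma * h).permute(0, 2, 1)     # per-channel layer scale, back to (N, C, L)
+#   out = x + drop_path(h)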
# ----------------------------- class ConvNeXtBlock1D(nn.Module): def __init__( @@ -45,17 +44,12 @@ def __init__( layer_scale_init_value: float = 1e-6, ): super().__init__() - # depthwise conv self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding='same', groups=dim) - # LayerNorm over channels - apply by permuting to (N, L, C) self.norm = nn.LayerNorm(dim, eps=1e-6) - # pointwise MLP self.pwconv1 = nn.Linear(dim, 4 * dim) - self.act = nn.GELU() + self.act = nn.LeakyReLU() self.pwconv2 = nn.Linear(4 * dim, dim) - # layer scale (gamma) as a learnable per-channel vector self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None - # For paper alignment, set default drop_path=0.0 (no stochastic depth in residual branch) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -75,17 +69,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- -# Backbone: ConvNeXt1D (paper-aligned: 3 blocks, avg pooling downsampling, final 560-dim projection) +# Backbone: ConvNeXt1D # ----------------------------- -class ConvNeXt1DBackbonePaper(nn.Module): +class ConvNeXt1DBackbone(nn.Module): """ - Paper-aligned 1D ConvNeXt backbone adapted for XRD: + ConvNeXt backbone adapted for XRD: - 3 ConvNeXt blocks (kernel sizes 100, 50, 25; channels 80 throughout) - Average pooling downsamples after each block with stride 5 - - Final 1x1 pointwise conv projects channels to 560 - - Global average pooling at end to produce a 560-dim feature vector + - Global average pooling at end to produce an 80-dim feature vector (final_pool=true) + - Output type: flatten - Default dims end at 560 to match the paper's feature dimension. + Default feature dim is 80 per config. """ def __init__( @@ -95,17 +89,14 @@ def __init__( dims: Tuple[int, int, int] = (80, 80, 80), kernel_sizes: Tuple[int, int, int] = (100, 50, 25), strides: Tuple[int, int, int] = (5, 5, 5), - drop_path_rate: float = 0.3, # applied to non-residual (stem) branch after pooling as per paper + dropout_rate: float = 0.3, layer_scale_init_value: float = 1e-6, - final_dim: int = 560, ): super().__init__() assert len(depths) == 3 and len(dims) == 3 and len(kernel_sizes) == 3 and len(strides) == 3 - self.dim_output = final_dim + self.dim_output = dims[-1] - # Stem: expand channels from 1 -> 80 using 1x1 conv (paper lists Block 1 input channels 1 and output 80; - # for depthwise conv consistency, we expand channels before the ConvNeXt block). 
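+        # The 1x1 stem conv expands in_chans -> dims[0] so the depthwise convs have
+        # channels to group over; GroupNorm(1, C) stands in for a channel-first LayerNorm.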
self.stem = nn.Sequential( nn.Conv1d(in_chans, dims[0], kernel_size=1, stride=1, padding=0), nn.GroupNorm(1, dims[0], eps=1e-6), @@ -118,8 +109,8 @@ def __init__( ConvNeXtBlock1D( dim=dims[i], kernel_size=kernel_sizes[i], - drop_path=0.0, # paper applies DropPath to non-residual branch; disable in-block stochastic depth - layer_scale_init_value=layer_scale_init_value, + drop_path=0.0, + ) for _ in range(depths[i]) ] @@ -131,15 +122,11 @@ def __init__( for i in range(3) ]) - # Apply DropPath to the non-residual (stem) branch post-pooling (as described in the paper) self.stem_drops = nn.ModuleList([ - DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else nn.Identity() for _ in range(3) ]) - # Final projection to 560 channels via 1x1 conv - self.head_proj = nn.Conv1d(dims[-1], final_dim, kernel_size=1, stride=1, padding=0) - self.final_norm = nn.LayerNorm(final_dim, eps=1e-6) def forward(self, x: torch.Tensor) -> torch.Tensor: # Accept (N, 8192) or (N, 1, 8192) @@ -153,13 +140,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i in range(3): x = self.blocks[i](x) # (N, 80, L) x = self.pools[i](x) # (N, 80, L_i) - x = self.stem_drops[i](x) # DropPath on non-residual branch + x = self.stem_drops[i](x) # Dropout on non-residual branch - # Final projection and global average pooling - x = self.head_proj(x) # (N, 560, L_last) - x = x.mean(dim=-1) # (N, 560) - x = self.final_norm(x) # LayerNorm over features - return x # (N, 560) + # Final global average pooling (final_pool=true) and flatten + x = x.mean(dim=-1) # (N, 80) + return x # (N, 80) # ----------------------------- @@ -176,7 +161,7 @@ def make_mlp( last = input_dim if hidden_dims is not None and len(hidden_dims) > 0: for hd in hidden_dims: - layers.extend([nn.Linear(last, hd), nn.ReLU()]) + layers.extend([nn.Linear(last, hd), nn.LeakyReLU()]) if dropout and dropout > 0: layers.append(nn.Dropout(dropout)) last = hd @@ -198,7 +183,6 @@ def make_mlp( class AlphaDiffractLightning(pl.LightningModule): """ PyTorch Lightning module for the AlphaDiffract model: - - Paper-aligned 1D ConvNeXt backbone producing a 560-dim feature vector - CS classifier head (7 classes) - SG classifier head (230 classes) - LP regressor head (6 outputs, bounded via sigmoid to [min, max]) @@ -218,16 +202,14 @@ class AlphaDiffractLightning(pl.LightningModule): def __init__( self, - # Backbone args (paper-aligned) depths: Tuple[int, int, int] = (1, 1, 1), dims: Tuple[int, int, int] = (80, 80, 80), kernel_sizes: Tuple[int, int, int] = (100, 50, 25), strides: Tuple[int, int, int] = (5, 5, 5), - drop_path_rate: float = 0.3, + dropout_rate: float = 0.3, layer_scale_init_value: float = 1e-6, - # Head dims (paper-aligned for CS/SG, LP unchanged) - head_dropout: float = 0.2, + head_dropout: float = 0.5, cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), sg_hidden: Optional[Tuple[int, ...]] = (2300, 1150), lp_hidden: Optional[Tuple[int, ...]] = (512, 256), @@ -259,17 +241,16 @@ def __init__( self.save_hyperparameters() # Backbone - self.backbone = ConvNeXt1DBackbonePaper( + self.backbone = ConvNeXt1DBackbone( in_chans=1, depths=depths, dims=dims, kernel_sizes=kernel_sizes, strides=strides, - drop_path_rate=drop_path_rate, + dropout_rate=dropout_rate, layer_scale_init_value=layer_scale_init_value, - final_dim=560, ) - feat_dim = self.backbone.dim_output # 560 + feat_dim = self.backbone.dim_output # 80 # Heads (produce logits for classification; no softmax here) self.cs_head = make_mlp( @@ -339,9 
+320,9 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - cs_logits: (N, 7) - sg_logits: (N, 230) - lp: (N, 6) bounded if enabled - - features: (N, 560) + - features: (N, 80) """ - feats = self.backbone(x) # (N, 560) + feats = self.backbone(x) # (N, 80) cs_logits = self.cs_head(feats) sg_logits = self.sg_head(feats) lp = self.lp_head(feats) @@ -514,17 +495,17 @@ def configure_optimizers(self): # ----------------------------- def build_alphadiffract_model_for_8192() -> AlphaDiffractLightning: """ - Build the paper-aligned model for 1x8192 XRD input, 560-dim features, and three heads. + Build the model for 1x8192 XRD input, 80-dim features, and three heads. """ return AlphaDiffractLightning( depths=(1, 1, 1), dims=(80, 80, 80), kernel_sizes=(100, 50, 25), strides=(5, 5, 5), - drop_path_rate=0.3, + dropout_rate=0.3, layer_scale_init_value=1e-6, - head_dropout=0.2, + head_dropout=0.5, cs_hidden=(2300, 1150), sg_hidden=(2300, 1150), lp_hidden=(512, 256), From c0854e0bd07a45e0c8a636bafbd6c1980eb5651a Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Tue, 4 Nov 2025 21:52:38 -0600 Subject: [PATCH 06/18] feat: Training system tweaks --- .gitignore | 1 + configs/trainer.yaml | 22 +++++++++++++--------- src/trainer/dataset/dataset.py | 12 +++++++++++- src/trainer/train_paper.py | 30 ++++++++++++++++++++++++------ 4 files changed, 49 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 9d905f8..6e54e0a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ /staging /data /original +og # Non-Docker Training Outputs /src/trainer/outputs diff --git a/configs/trainer.yaml b/configs/trainer.yaml index 367009f..340d42e 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -2,8 +2,8 @@ # This file is required by src/trainer/train_paper.py. It contains all parameters with no script-side defaults. 
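+# The floor_at_zero / normalize_log1p flags added below correspond to dataset-side
+# preprocessing applied to the signal before any user transform:
+#   torch.clamp(x, min=0) when floor_at_zero, then torch.log1p(x) when normalize_log1p.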
# --- Data / Manifests --- -manifest_dir: "../../../data/manifests" -dataset_root: "../../../data/dataset" # used when auto_generate_manifests is true +manifest_dir: "../../data/manifests" +dataset_root: "../../data/dataset" # used when auto_generate_manifests is true auto_generate_manifests: true train_ratio: 0.8 val_ratio: 0.1 @@ -11,7 +11,7 @@ test_ratio: 0.1 seed: 42 # --- DataLoader --- -batch_size: 1024 # paper used 64 +batch_size: 512 # paper used 64 num_workers: 8 pin_memory: true persistent_workers: true @@ -33,13 +33,15 @@ labels_key_map: lp_gamma: "_cell_angle_gamma" dtype: "float32" # one of: float32, float64, float16, bfloat16 mmap_mode: null # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable +floor_at_zero: False # Clamp negative counts to 0 before any normalization +normalize_log1p: false # If true, apply log1p(x) to compress dynamic range # --- Model architecture --- depths: [1, 1, 1] dims: [80, 80, 80] kernel_sizes: [100, 50, 25] strides: [5, 5, 5] -drop_path_rate: 0.3 +dropout_rate: 0.3 layer_scale_init_value: 1.0e-6 # Heads @@ -65,28 +67,30 @@ lambda_lp: 1.0 # Optional GEMD term on SG gemd_mu: 0.0 -gemd_distance_matrix: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD +gemd_distance_matrix_path: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD # Optimizer -lr: 0.01 # paper used 2e-4 +lr: 0.004 # paper used 2e-4 weight_decay: 0.01 # paper used 0.01 use_adamw: true +gradient_clip_val: 1.0 +gradient_clip_algorithm: "norm" # --- Logging --- logger: "mlflow" # 'csv' or 'mlflow' csv_logger_name: "model_logs" mlflow_experiment_name: "OpenAlphaDiffract" mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) -mlflow_run_name: "paper_model_run" +mlflow_run_name: "OpenAlphaDiffract_Run" # --- Trainer settings --- default_root_dir: "outputs/model" -max_epochs: 5 +max_epochs: 50 accumulate_grad_batches: 1 precision: "bf16-mixed" # e.g., '32', '16-mixed', 'bf16-mixed' accelerator: "gpu" devices: 1 -log_every_n_steps: 100 +log_every_n_steps: 50 deterministic: false benchmark: true diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py index 794c97e..69b3bc8 100644 --- a/src/trainer/dataset/dataset.py +++ b/src/trainer/dataset/dataset.py @@ -100,6 +100,8 @@ def __init__( extract_labels: bool = False, labels_key_map: Optional[Dict[str, List[str]]] = None, allow_pickle: bool = True, + floor_at_zero: bool = True, + normalize_log1p: bool = False, ) -> None: super().__init__() self.manifest_path = manifest_path @@ -109,6 +111,8 @@ def __init__( self.return_meta = return_meta self.extract_labels = extract_labels self.allow_pickle = allow_pickle + self.floor_at_zero = floor_at_zero + self.normalize_log1p = normalize_log1p # Default key mapping for extracting fields from embedded containers # Simplified: single string keys, no search lists self.labels_key_map = labels_key_map or { @@ -229,9 +233,15 @@ def _get_exact(container, key: str): else: raise TypeError(f"Unsupported file format: expected object dtype .npy, got dtype={getattr(arr, 'dtype', '?')} for {fpath}") - # Cast dtype and apply transform to x only + # Cast dtype and apply preprocessing/transform to x only if self.dtype is not None and x_tensor is not None: x_tensor = x_tensor.to(self.dtype) + # Ensure non-negative counts + if x_tensor is not None and self.floor_at_zero: + x_tensor = torch.clamp(x_tensor, min=0) + # Optional log1p normalization to compress peak variance + if x_tensor is not None and self.normalize_log1p: + x_tensor = torch.log1p(x_tensor) if 
self.transform is not None and x_tensor is not None: x_tensor = self.transform(x_tensor) diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index 0a918af..ecf327c 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -5,7 +5,7 @@ import torch import yaml from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback from pytorch_lightning.loggers import CSVLogger try: from pytorch_lightning.loggers import MLFlowLogger @@ -55,6 +55,8 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: "validate_paths": cfg["validate_paths"], "extract_labels": cfg["extract_labels"], "allow_pickle": cfg["allow_pickle"], + "floor_at_zero": cfg["floor_at_zero"], + "normalize_log1p": cfg["normalize_log1p"], } labels_key_map = cfg.get("labels_key_map") if labels_key_map is not None: @@ -84,7 +86,7 @@ def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: dims=tuple(cfg["dims"]), kernel_sizes=tuple(cfg["kernel_sizes"]), strides=tuple(cfg["strides"]), - drop_path_rate=cfg["drop_path_rate"], + dropout_rate=cfg["dropout_rate"], layer_scale_init_value=cfg["layer_scale_init_value"], # Heads head_dropout=cfg["head_dropout"], @@ -105,7 +107,7 @@ def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: lambda_lp=cfg["lambda_lp"], # Optional GEMD gemd_mu=cfg["gemd_mu"], - gemd_distance_matrix_path=cfg.get("gemd_distance_matrix"), + gemd_distance_matrix_path=cfg.get("gemd_distance_matrix_path"), # Optimizer lr=cfg["lr"], weight_decay=cfg["weight_decay"], @@ -114,7 +116,21 @@ def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: return model -def build_trainer_from_cfg(cfg: Dict[str, Any]) -> Trainer: +class ConfigArtifactLogger(Callback): + """ + Minimal callback: if MLFlowLogger is the active logger, log the raw YAML config + as an MLflow artifact under 'configs/'. 
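+    With any other logger (e.g., CSVLogger), on_fit_start is a no-op.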
+ """ + def __init__(self, raw_config_path: str): + self.raw_config_path = raw_config_path + + def on_fit_start(self, trainer, pl_module) -> None: + logger = getattr(trainer, "logger", None) + if MLFlowLogger is not None and isinstance(logger, MLFlowLogger): + logger.experiment.log_artifact(logger.run_id, self.raw_config_path) + + +def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = None) -> Trainer: ckpt_cb = ModelCheckpoint( monitor=cfg["monitor"], mode=cfg["mode"], @@ -147,11 +163,13 @@ def build_trainer_from_cfg(cfg: Dict[str, Any]) -> Trainer: devices=cfg["devices"], precision=cfg["precision"], accumulate_grad_batches=cfg["accumulate_grad_batches"], - callbacks=[ckpt_cb, lr_cb], + callbacks=[ckpt_cb, lr_cb, ConfigArtifactLogger(raw_config_path)], logger=logger, log_every_n_steps=cfg["log_every_n_steps"], deterministic=cfg["deterministic"], benchmark=cfg["benchmark"], + gradient_clip_val=cfg.get("gradient_clip_val", 0.0), + gradient_clip_algorithm=cfg.get("gradient_clip_algorithm", "norm"), ) return trainer @@ -167,7 +185,7 @@ def main(): dm = build_datamodule_from_cfg(cfg) model = build_model_from_cfg(cfg) - trainer = build_trainer_from_cfg(cfg) + trainer = build_trainer_from_cfg(cfg, raw_config_path=args.config) # Train resume_from: Optional[str] = cfg.get("resume_from") From d6edbdd98827eddb18e21279289a6ebbac3fe8b0 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Wed, 5 Nov 2025 10:35:52 -0600 Subject: [PATCH 07/18] feat: Experiment with full convnext feature extractor --- configs/norm_train.yaml | 105 +++++++++++++++++++++ configs/trainer.yaml | 19 ++-- src/trainer/model/model.py | 187 ++++++++++++++++++++++--------------- src/trainer/train_paper.py | 1 + 4 files changed, 230 insertions(+), 82 deletions(-) create mode 100644 configs/norm_train.yaml diff --git a/configs/norm_train.yaml b/configs/norm_train.yaml new file mode 100644 index 0000000..fb55418 --- /dev/null +++ b/configs/norm_train.yaml @@ -0,0 +1,105 @@ +# AlphaDiffract trainer configuration (paper-aligned defaults provided here) +# This file is required by src/trainer/train_paper.py. It contains all parameters with no script-side defaults. 
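+# Variant of configs/trainer.yaml that keeps the paper-style backbone (depths [1, 1, 1])
+# but enables input normalization (floor_at_zero + normalize_log1p) and logs to the
+# "OpenAlphaDiffractNorm" MLflow experiment.
+# Use with: python -m trainer.train_paper configs/norm_train.yaml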
+ +# --- Data / Manifests --- +manifest_dir: "../../data/manifests" +dataset_root: "../../data/dataset" # used when auto_generate_manifests is true +auto_generate_manifests: true +train_ratio: 0.8 +val_ratio: 0.1 +test_ratio: 0.1 +seed: 42 + +# --- DataLoader --- +batch_size: 512 # paper used 64 +num_workers: 8 +pin_memory: true +persistent_workers: true + +# --- Dataset label extraction (embedded in .npy/.npz) --- +validate_paths: false +extract_labels: true +allow_pickle: true +labels_key_map: + x: "dp" + cs: "cs" + sg: "sg" + lattice_params: null + lp_a: "_cell_length_a" + lp_b: "_cell_length_b" + lp_c: "_cell_length_c" + lp_alpha: "_cell_angle_alpha" + lp_beta: "_cell_angle_beta" + lp_gamma: "_cell_angle_gamma" +dtype: "float32" # one of: float32, float64, float16, bfloat16 +mmap_mode: null # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable +floor_at_zero: True # Clamp negative counts to 0 before any normalization +normalize_log1p: True # If true, apply log1p(x) to compress dynamic range + +# --- Model architecture --- +depths: [1, 1, 1] +dims: [80, 80, 80] +kernel_sizes: [100, 50, 25] +strides: [5, 5, 5] +dropout_rate: 0.3 +layer_scale_init_value: 1.0e-6 + +# Heads +head_dropout: 0.2 +cs_hidden: [2300, 1150] +sg_hidden: [2300, 1150] +lp_hidden: [512, 256] + +# Task sizes +num_cs_classes: 7 +num_sg_classes: 230 +num_lp_outputs: 6 + +# LP output bounds +lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +bound_lp_with_sigmoid: true + +# Loss weights +lambda_cs: 1.0 +lambda_sg: 1.0 +lambda_lp: 1.0 + +# Optional GEMD term on SG +gemd_mu: 0.0 +gemd_distance_matrix_path: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD + +# Optimizer +lr: 0.002 # paper used 2e-4 +weight_decay: 0.01 # paper used 0.01 +use_adamw: true +gradient_clip_val: 1.0 +gradient_clip_algorithm: "norm" + +# --- Logging --- +logger: "mlflow" # 'csv' or 'mlflow' +csv_logger_name: "model_logs" +mlflow_experiment_name: "OpenAlphaDiffractNorm" +mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) +mlflow_run_name: "OpenAlphaDiffract_Run" + +# --- Trainer settings --- +default_root_dir: "outputs/model" +max_epochs: 50 +accumulate_grad_batches: 1 +precision: "bf16-mixed" # e.g., '32', '16-mixed', 'bf16-mixed' +accelerator: "gpu" +devices: 1 +log_every_n_steps: 50 +deterministic: false +benchmark: true + +# --- Checkpointing --- +monitor: "val/loss" +mode: "min" +save_top_k: 1 +every_n_epochs: 1 + +# --- Evaluation --- +resume_from: # e.g., "outputs/paper_model/checkpoints/epochXYZ.ckpt" +test_after_train: true diff --git a/configs/trainer.yaml b/configs/trainer.yaml index 340d42e..d4125f7 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -11,7 +11,7 @@ test_ratio: 0.1 seed: 42 # --- DataLoader --- -batch_size: 512 # paper used 64 +batch_size: 256 # paper used 64 num_workers: 8 pin_memory: true persistent_workers: true @@ -33,16 +33,17 @@ labels_key_map: lp_gamma: "_cell_angle_gamma" dtype: "float32" # one of: float32, float64, float16, bfloat16 mmap_mode: null # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable -floor_at_zero: False # Clamp negative counts to 0 before any normalization -normalize_log1p: false # If true, apply log1p(x) to compress dynamic range +floor_at_zero: True # Clamp negative counts to 0 before any normalization +normalize_log1p: True # If true, apply log1p(x) to compress dynamic range # --- Model architecture --- -depths: [1, 1, 1] -dims: [80, 80, 80] -kernel_sizes: [100, 50, 25] -strides: [5, 5, 
5] +depths: [3, 3, 9, 3] +dims: [80, 160, 320, 640] +kernel_sizes: [7, 7, 7, 7] +strides: [4, 2, 2, 2] dropout_rate: 0.3 layer_scale_init_value: 1.0e-6 +drop_path_rate: 0.1 # Heads head_dropout: 0.2 @@ -70,7 +71,7 @@ gemd_mu: 0.0 gemd_distance_matrix_path: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD # Optimizer -lr: 0.004 # paper used 2e-4 +lr: 0.0001 # paper used 2e-4 weight_decay: 0.01 # paper used 0.01 use_adamw: true gradient_clip_val: 1.0 @@ -79,7 +80,7 @@ gradient_clip_algorithm: "norm" # --- Logging --- logger: "mlflow" # 'csv' or 'mlflow' csv_logger_name: "model_logs" -mlflow_experiment_name: "OpenAlphaDiffract" +mlflow_experiment_name: "OpenAlphaDiffract_ConvFUll" mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) mlflow_run_name: "OpenAlphaDiffract_Run" diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index bf9abdd..7e01f20 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -33,7 +33,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- # ConvNeXt 1D Block -# Follows ConvNeXt design adapted to 1D: depthwise conv -> LN -> PW-MLP (expand) -> GELU -> PW-MLP (project) -> gamma -> residual +# Follows ConvNeXt design adapted to 1D: +# depthwise conv -> LN (channels-last) -> PW-MLP (expand 4x) -> GELU -> PW-MLP (project) -> layer-scale gamma -> residual (+ DropPath) # ----------------------------- class ConvNeXtBlock1D(nn.Module): def __init__( @@ -44,12 +45,17 @@ def __init__( layer_scale_init_value: float = 1e-6, ): super().__init__() + # depthwise 1D conv self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding='same', groups=dim) + # LayerNorm in channels-last layout self.norm = nn.LayerNorm(dim, eps=1e-6) + # pointwise MLP implemented by Linear on channels-last self.pwconv1 = nn.Linear(dim, 4 * dim) - self.act = nn.LeakyReLU() + self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, dim) + # layer-scale self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None + # stochastic depth self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -69,82 +75,122 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- -# Backbone: ConvNeXt1D +# Downsample: LN (channels-last) -> Conv1d (stride, channel increase) +# ----------------------------- +class Downsample1D(nn.Module): + def __init__(self, in_dim: int, out_dim: int, stride: int): + super().__init__() + self.norm = nn.LayerNorm(in_dim, eps=1e-6) + # Use kernel_size = stride to mimic patch/downsample conv + self.conv = nn.Conv1d(in_dim, out_dim, kernel_size=stride, stride=stride) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, C, L) + x = x.permute(0, 2, 1) # (N, L, C) + x = self.norm(x) + x = x.permute(0, 2, 1) # (N, C, L) + x = self.conv(x) # (N, C_out, L/stride) + return x + + +# ----------------------------- +# Backbone: ConvNeXt1D (generalized to N stages) # ----------------------------- class ConvNeXt1DBackbone(nn.Module): """ - ConvNeXt backbone adapted for XRD: - - 3 ConvNeXt blocks (kernel sizes 100, 50, 25; channels 80 throughout) - - Average pooling downsamples after each block with stride 5 - - Global average pooling at end to produce an 80-dim feature vector (final_pool=true) - - Output type: flatten - - Default feature dim is 80 per config. 
+ ConvNeXt backbone adapted for 1D XRD signals, generalized to N stages: + - Stem: patchify Conv1d with stride=strides[0], out=dims[0] + LayerNorm + - Each stage: [depths[i] ConvNeXt blocks] using DWConv k=kernel_sizes[i], GELU, layer-scale + - Between stages: Downsample1D with stride=strides[i+1], increasing channels dims[i] -> dims[i+1] + - Final: global average pooling over length to produce dims[-1]-dim feature vector + + Notes: + - Stochastic depth (drop path) is linearly scheduled across all blocks from 0 to drop_path_rate. + - kernel_sizes can be domain-specific; ConvNeXt canonical uses k=7. + - dims should typically increase across stages (e.g., [80, 160, 320, 640]). """ def __init__( self, in_chans: int = 1, - depths: Tuple[int, int, int] = (1, 1, 1), - dims: Tuple[int, int, int] = (80, 80, 80), - kernel_sizes: Tuple[int, int, int] = (100, 50, 25), - strides: Tuple[int, int, int] = (5, 5, 5), + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] = (80, 160, 320, 640), + kernel_sizes: Tuple[int, ...] = (7, 7, 7, 7), + strides: Tuple[int, ...] = (4, 2, 2, 2), dropout_rate: float = 0.3, layer_scale_init_value: float = 1e-6, + drop_path_rate: float = 0.1, ): super().__init__() - assert len(depths) == 3 and len(dims) == 3 and len(kernel_sizes) == 3 and len(strides) == 3 + + n_stages = len(depths) + assert len(dims) == n_stages and len(kernel_sizes) == n_stages and len(strides) == n_stages, \ + "depths, dims, kernel_sizes, strides must have same length" self.dim_output = dims[-1] - self.stem = nn.Sequential( - nn.Conv1d(in_chans, dims[0], kernel_size=1, stride=1, padding=0), - nn.GroupNorm(1, dims[0], eps=1e-6), - ) + # Stem: patchify conv + LayerNorm (channels-last) + self.stem_conv = nn.Conv1d(in_chans, dims[0], kernel_size=strides[0], stride=strides[0]) + self.stem_norm = nn.LayerNorm(dims[0], eps=1e-6) - # 3 stages with ConvNeXt blocks (one block per stage by default) + # Stochastic depth schedule across all blocks + total_blocks = int(sum(depths)) + if total_blocks > 0 and drop_path_rate > 0.0: + dp_rates = np.linspace(0.0, drop_path_rate, total_blocks).tolist() + else: + dp_rates = [0.0] * max(total_blocks, 1) + + # Build stages self.blocks = nn.ModuleList() - for i in range(3): - stage_blocks = [ - ConvNeXtBlock1D( - dim=dims[i], - kernel_size=kernel_sizes[i], - drop_path=0.0, - + dp_idx = 0 + for i in range(n_stages): + stage_blocks: List[nn.Module] = [] + for _ in range(depths[i]): + stage_blocks.append( + ConvNeXtBlock1D( + dim=dims[i], + kernel_size=kernel_sizes[i], + drop_path=dp_rates[dp_idx] if dp_idx < len(dp_rates) else 0.0, + layer_scale_init_value=layer_scale_init_value, + ) ) - for _ in range(depths[i]) - ] + dp_idx += 1 self.blocks.append(nn.Sequential(*stage_blocks)) - # Average pooling downsampling after each stage - self.pools = nn.ModuleList([ - nn.AvgPool1d(kernel_size=strides[i], stride=strides[i]) - for i in range(3) - ]) + # Downsample transitions between stages + self.downsamples = nn.ModuleList() + for i in range(n_stages - 1): + self.downsamples.append( + Downsample1D(in_dim=dims[i], out_dim=dims[i + 1], stride=strides[i + 1]) + ) - self.stem_drops = nn.ModuleList([ + # Optional dropout after each downsample transition + self.down_drops = nn.ModuleList([ nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else nn.Identity() - for _ in range(3) + for _ in range(n_stages - 1) ]) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Accept (N, 8192) or (N, 1, 8192) + # Accept (N, L) or (N, 1, L) if x.ndim == 2: x = x[:, None, :] - # Stem 
expansion - x = self.stem(x) # (N, 80, L) - - # 3 blocks with avg pooling downsampling - for i in range(3): - x = self.blocks[i](x) # (N, 80, L) - x = self.pools[i](x) # (N, 80, L_i) - x = self.stem_drops[i](x) # Dropout on non-residual branch - - # Final global average pooling (final_pool=true) and flatten - x = x.mean(dim=-1) # (N, 80) - return x # (N, 80) + # Stem + x = self.stem_conv(x) # (N, C0, L0) + x = x.permute(0, 2, 1) # (N, L0, C0) + x = self.stem_norm(x) + x = x.permute(0, 2, 1) # (N, C0, L0) + + # Stages + downsample transitions + for i, stage in enumerate(self.blocks): + x = stage(x) # (N, Ci, Li) + if i < len(self.downsamples): + x = self.downsamples[i](x) # (N, Ci+1, Li+1) + x = self.down_drops[i](x) + + # Global average pooling over length + x = x.mean(dim=-1) # (N, dims[-1]) + return x # ----------------------------- @@ -190,24 +236,17 @@ class AlphaDiffractLightning(pl.LightningModule): Expected batch formats: - Tuple: (x, y_cs, y_sg, y_lp) - Dict keys: x or xrd or signal; cs, sg, lattice_params (or lp) - - Losses: - - CS: CrossEntropy - - SG: CrossEntropy - - LP: MSE (mean squared error); bounded via sigmoid to given ranges if enabled - - Metrics logged: - - train/val/test: cs_acc, sg_acc, lp_mae, lp_mse, total_loss """ def __init__( self, - depths: Tuple[int, int, int] = (1, 1, 1), - dims: Tuple[int, int, int] = (80, 80, 80), - kernel_sizes: Tuple[int, int, int] = (100, 50, 25), - strides: Tuple[int, int, int] = (5, 5, 5), + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] = (80, 160, 320, 640), + kernel_sizes: Tuple[int, ...] = (7, 7, 7, 7), + strides: Tuple[int, ...] = (4, 2, 2, 2), dropout_rate: float = 0.3, layer_scale_init_value: float = 1e-6, + drop_path_rate: float = 0.1, head_dropout: float = 0.5, cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), @@ -249,8 +288,9 @@ def __init__( strides=strides, dropout_rate=dropout_rate, layer_scale_init_value=layer_scale_init_value, + drop_path_rate=drop_path_rate, ) - feat_dim = self.backbone.dim_output # 80 + feat_dim = self.backbone.dim_output # dims[-1] # Heads (produce logits for classification; no softmax here) self.cs_head = make_mlp( @@ -317,12 +357,12 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ Forward pass. Returns a dict with: - - cs_logits: (N, 7) - - sg_logits: (N, 230) - - lp: (N, 6) bounded if enabled - - features: (N, 80) + - cs_logits: (N, num_cs_classes) + - sg_logits: (N, num_sg_classes) + - lp: (N, num_lp_outputs) bounded if enabled + - features: (N, dims[-1]) """ - feats = self.backbone(x) # (N, 80) + feats = self.backbone(x) # (N, dims[-1]) cs_logits = self.cs_head(feats) sg_logits = self.sg_head(feats) lp = self.lp_head(feats) @@ -495,15 +535,16 @@ def configure_optimizers(self): # ----------------------------- def build_alphadiffract_model_for_8192() -> AlphaDiffractLightning: """ - Build the model for 1x8192 XRD input, 80-dim features, and three heads. + Build the model for 1x8192 XRD input, ConvNeXt-style 4-stage backbone, and three heads. 
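+    With strides (4, 2, 2, 2) the overall downsampling is 4*2*2*2 = 32, so an
+    8192-length input leaves 256 positions before global average pooling.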
""" return AlphaDiffractLightning( - depths=(1, 1, 1), - dims=(80, 80, 80), - kernel_sizes=(100, 50, 25), - strides=(5, 5, 5), + depths=(3, 3, 9, 3), + dims=(80, 160, 320, 640), + kernel_sizes=(7, 7, 7, 7), + strides=(4, 2, 2, 2), dropout_rate=0.3, layer_scale_init_value=1e-6, + drop_path_rate=0.1, head_dropout=0.5, cs_hidden=(2300, 1150), diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index ecf327c..61d70b0 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -88,6 +88,7 @@ def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: strides=tuple(cfg["strides"]), dropout_rate=cfg["dropout_rate"], layer_scale_init_value=cfg["layer_scale_init_value"], + drop_path_rate=cfg["drop_path_rate"], # Heads head_dropout=cfg["head_dropout"], cs_hidden=tuple(cfg["cs_hidden"]), From f30439b36f57eb08798843dc07d312e18a15a39a Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Wed, 5 Nov 2025 21:33:34 -0600 Subject: [PATCH 08/18] Add wav2vec2 model. --- configs/trainer.yaml | 2 +- configs/trainer_wav2vec2.yaml | 121 ++++++ src/trainer/model/wav2vec2_model.py | 596 ++++++++++++++++++++++++++++ src/trainer/train_paper.py | 117 ++++-- 4 files changed, 800 insertions(+), 36 deletions(-) create mode 100644 configs/trainer_wav2vec2.yaml create mode 100644 src/trainer/model/wav2vec2_model.py diff --git a/configs/trainer.yaml b/configs/trainer.yaml index d4125f7..213e643 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -71,7 +71,7 @@ gemd_mu: 0.0 gemd_distance_matrix_path: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD # Optimizer -lr: 0.0001 # paper used 2e-4 +lr: 0.00015 # paper used 2e-4 weight_decay: 0.01 # paper used 0.01 use_adamw: true gradient_clip_val: 1.0 diff --git a/configs/trainer_wav2vec2.yaml b/configs/trainer_wav2vec2.yaml new file mode 100644 index 0000000..cbd9a35 --- /dev/null +++ b/configs/trainer_wav2vec2.yaml @@ -0,0 +1,121 @@ +# AlphaDiffract trainer configuration for Wav2Vec2-style backbone (8192-length signals) +# Use with: python -m trainer.train_paper configs/trainer_wav2vec2.yaml + +# --- Data / Manifests --- +manifest_dir: "../../data/manifests" +dataset_root: "../../data/dataset" +auto_generate_manifests: true +train_ratio: 0.8 +val_ratio: 0.1 +test_ratio: 0.1 +seed: 42 + +# --- DataLoader --- +# Large dataset (14M samples, 150k materials) — start with batch_size=256 and tune later +batch_size: 256 +num_workers: 8 +pin_memory: true +persistent_workers: true + +# --- Dataset label extraction (embedded in .npy/.npz) --- +validate_paths: false +extract_labels: true +allow_pickle: true +labels_key_map: + x: "dp" + cs: "cs" + sg: "sg" + lattice_params: null + lp_a: "_cell_length_a" + lp_b: "_cell_length_b" + lp_c: "_cell_length_c" + lp_alpha: "_cell_angle_alpha" + lp_beta: "_cell_angle_beta" + lp_gamma: "_cell_angle_gamma" +dtype: "float32" +mmap_mode: null +floor_at_zero: true # counts are non-negative +normalize_log1p: true # compress dynamic range for stability + +# --- Model selection --- +model_type: "wav2vec2" # choose between "convnext" and "wav2vec2" + +# --- Wav2Vec2-style backbone (defaults tailored for 8192 inputs and more tokens) --- +in_chans: 1 +d_model: 512 +n_heads: 8 +num_layers: 8 +ff_dim: 2048 +# Convolutional feature extractor — total stride = 32 -> ~256 tokens for length 8192 +conv_kernel_sizes: [10, 5, 3, 3, 3, 2] +conv_strides: [2, 2, 2, 2, 2, 1] +conv_dropout: 0.0 +pos_kernel_size: 129 +pos_dropout: 0.1 +encoder_dropout: 0.1 +layer_norm_first: true +token_pool: "cls" # "mean" 
or "cls" + +# Heads (reuse paper defaults) +head_dropout: 0.2 +cs_hidden: [2300, 1150] +sg_hidden: [2300, 1150] +lp_hidden: [512, 256] + +# Task sizes +num_cs_classes: 7 +num_sg_classes: 230 +num_lp_outputs: 6 + +# LP output bounds +lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +bound_lp_with_sigmoid: true + +# Loss weights +lambda_cs: 1.0 +lambda_sg: 1.0 +lambda_lp: 0.0 + +# Optional GEMD term on SG +gemd_mu: 0.0 +gemd_distance_matrix_path: + +# Optimizer +lr: 0.0005 # moderately higher base LR with warmup; tune as needed +weight_decay: 0.01 +use_adamw: true +gradient_clip_val: 1.0 +gradient_clip_algorithm: "norm" + +# --- Scheduler --- +warmup_steps: 5000 +cosine_t_max: 112000 + +# --- Logging --- +logger: "mlflow" # 'csv' or 'mlflow' +csv_logger_name: "model_logs_wav2vec2" +mlflow_experiment_name: "OpenAlphaDiffract_Wav2Vec2" +mlflow_tracking_uri: null +mlflow_run_name: "Wav2Vec2_Run" + +# --- Trainer settings --- +default_root_dir: "outputs/wav2vec2_model" +max_epochs: 50 +accumulate_grad_batches: 1 +precision: "bf16-mixed" # good default for H100/A100; switch to '16-mixed' if needed +accelerator: "gpu" +devices: 1 +log_every_n_steps: 50 +deterministic: false +benchmark: true + +# --- Checkpointing --- +monitor: "val/loss" +mode: "min" +save_top_k: 1 +every_n_epochs: 1 + +# --- Evaluation --- +resume_from: +test_after_train: true diff --git a/src/trainer/model/wav2vec2_model.py b/src/trainer/model/wav2vec2_model.py new file mode 100644 index 0000000..9c075b1 --- /dev/null +++ b/src/trainer/model/wav2vec2_model.py @@ -0,0 +1,596 @@ +from typing import Dict, Tuple, Optional, List, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl + + +# ============================= +# Wav2Vec2-style 1D backbone for 8192-length XRD signals +# ============================= + +def _same_padding(kernel_size: int) -> int: + # For odd kernel sizes, this approximates "same" padding in Conv1d + return kernel_size // 2 + + +class ConvFeatureExtractor1D(nn.Module): + """ + Wav2Vec2-style convolutional feature extractor for 1D signals. + + This stack reduces the raw sequence length to a token sequence length via strided 1D convolutions. + Defaults are chosen to produce an overall stride of 64 for input length 8192, i.e., ~128 tokens. + + Args: + in_chans: Input channels (1 for intensity-only). + conv_dim: Number of channels in conv feature maps (often equals d_model). + kernel_sizes: List of kernel sizes for each conv stage. + strides: List of strides for each conv stage (same length as kernel_sizes). + activation: Nonlinearity to apply after each conv (default: GELU). + dropout: Dropout applied after activation in each stage. + """ + def __init__( + self, + in_chans: int = 1, + conv_dim: int = 512, + kernel_sizes: Tuple[int, ...] = (10, 5, 3, 3, 3, 2), + strides: Tuple[int, ...] 
= (2, 2, 2, 2, 2, 2), + activation: Optional[nn.Module] = None, + dropout: float = 0.0, + ): + super().__init__() + assert len(kernel_sizes) == len(strides), "kernel_sizes and strides must have same length" + layers: List[nn.Module] = [] + c_in = in_chans + act = activation if activation is not None else nn.GELU() + for i, (k, s) in enumerate(zip(kernel_sizes, strides)): + layers.append(nn.Conv1d(c_in, conv_dim, kernel_size=k, stride=s, padding=_same_padding(k))) + layers.append(act) + if dropout and dropout > 0.0: + layers.append(nn.Dropout(p=dropout)) + c_in = conv_dim + self.net = nn.Sequential(*layers) + self.out_dim = conv_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, L) or (N, 1, L) + if x.ndim == 2: + x = x[:, None, :] + return self.net(x) # (N, C=conv_dim, T) + + +class PositionalConvEmbedding(nn.Module): + """ + Convolutional positional embedding as used in wav2vec2: + depthwise Conv1d over features with GELU and dropout, added to token sequence. + """ + def __init__(self, d_model: int, kernel_size: int = 128, groups: Optional[int] = None, dropout: float = 0.1): + super().__init__() + g = groups if groups is not None else d_model + self.pos_conv = nn.Conv1d( + d_model, d_model, kernel_size=kernel_size, padding='same', groups=g + ) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + # tokens: (N, T, D) + x = tokens.transpose(1, 2) # (N, D, T) + x = self.pos_conv(x) + x = self.activation(x) + x = x.transpose(1, 2) # (N, T, D) + return tokens + self.dropout(x) + + +class Wav2Vec2Backbone1D(nn.Module): + """ + Wav2Vec2-style backbone adapted for 1D XRD signals. + + Pipeline: + - Convolutional feature extractor: reduces length to token sequence (N, T, D) + - Convolutional positional embedding + - Transformer encoder stack (batch_first), dropout + - Feature pooling: mean over tokens -> (N, D) + + Defaults are chosen for 8192-length inputs to yield ~128 tokens (overall stride 64). + """ + def __init__( + self, + in_chans: int = 1, + d_model: int = 512, + n_heads: int = 8, + num_layers: int = 8, + ff_dim: Optional[int] = None, + conv_kernel_sizes: Tuple[int, ...] = (10, 5, 3, 3, 3, 2), + conv_strides: Tuple[int, ...] 
= (2, 2, 2, 2, 2, 2), + conv_dropout: float = 0.0, + pos_kernel_size: int = 128, + pos_dropout: float = 0.1, + encoder_dropout: float = 0.1, + layer_norm_first: bool = False, + token_pool: str = "mean", # "mean" or "cls" (mean by default) + ): + super().__init__() + ff_dim = ff_dim or (4 * d_model) + self.token_pool = token_pool + + # Conv feature extractor -> (N, D, T) + self.feature_extractor = ConvFeatureExtractor1D( + in_chans=in_chans, + conv_dim=d_model, + kernel_sizes=conv_kernel_sizes, + strides=conv_strides, + dropout=conv_dropout, + ) + + # Positional conv embedding + self.pos_embed = PositionalConvEmbedding(d_model=d_model, kernel_size=pos_kernel_size, dropout=pos_dropout) + # Pre-encoder token normalization (helps Transformer stability on non-centered inputs) + self.input_ln = nn.LayerNorm(d_model) + + # Transformer encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=n_heads, + dim_feedforward=ff_dim, + dropout=encoder_dropout, + activation="gelu", + batch_first=True, + norm_first=layer_norm_first, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + self.encoder_ln = nn.LayerNorm(d_model) + self.dim_output = d_model + + # Optional CLS token if desired + if self.token_pool == "cls": + self.cls = nn.Parameter(torch.zeros(1, 1, d_model)) + else: + self.register_parameter("cls", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, L) or (N, 1, L) + feats = self.feature_extractor(x) # (N, D, T) + tokens = feats.transpose(1, 2) # (N, T, D) + + if self.cls is not None: + cls_tok = self.cls.expand(tokens.size(0), -1, -1) # (N, 1, D) + tokens = torch.cat([cls_tok, tokens], dim=1) # (N, 1+T, D) + + tokens = self.pos_embed(tokens) # (N, T, D) + tokens = self.input_ln(tokens) # (N, T, D) + enc = self.encoder(tokens) # (N, T, D) + enc = self.encoder_ln(enc) # (N, T, D) + + if self.token_pool == "cls" and self.cls is not None: + pooled = enc[:, 0, :] # (N, D) + else: + pooled = enc.mean(dim=1) # (N, D) + return pooled + + +# ============================= +# Heads + Lightning module (mirrors existing AlphaDiffractLightning API) +# ============================= + +def make_mlp( + input_dim: int, + hidden_dims: Optional[Tuple[int, ...]], + output_dim: int, + dropout: float = 0.2, + output_activation: Optional[nn.Module] = None, +) -> nn.Module: + layers: List[nn.Module] = [] + last = input_dim + if hidden_dims is not None and len(hidden_dims) > 0: + for hd in hidden_dims: + layers.extend([nn.Linear(last, hd), nn.LeakyReLU()]) + if dropout and dropout > 0: + layers.append(nn.Dropout(dropout)) + last = hd + layers.append(nn.Linear(last, output_dim)) + if output_activation is not None: + layers.append(output_activation) + return nn.Sequential(*layers) + + +BatchType = Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + Dict[str, torch.Tensor], +] + + +class AlphaDiffractWav2Vec2Lightning(pl.LightningModule): + """ + AlphaDiffract variant with a Wav2Vec2-style 1D backbone. + + Same heads and multi-task losses as AlphaDiffractLightning: + - CS classifier head (7 classes) + - SG classifier head (230 classes) + - LP regressor head (6 outputs, optionally bounded to [min, max] via sigmoid) + """ + + def __init__( + self, + # Backbone (defaults tailored for 8192 input, ~128 tokens) + in_chans: int = 1, + d_model: int = 512, + n_heads: int = 8, + num_layers: int = 8, + ff_dim: Optional[int] = None, + conv_kernel_sizes: Tuple[int, ...] = (10, 5, 3, 3, 3, 2), + conv_strides: Tuple[int, ...] 
= (2, 2, 2, 2, 2, 2), # total stride 64 + conv_dropout: float = 0.0, + pos_kernel_size: int = 128, + pos_dropout: float = 0.1, + encoder_dropout: float = 0.1, + layer_norm_first: bool = False, + token_pool: str = "mean", + + # Heads + head_dropout: float = 0.5, + cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), + sg_hidden: Optional[Tuple[int, ...]] = (2300, 1150), + lp_hidden: Optional[Tuple[int, ...]] = (512, 256), + + # Task sizes + num_cs_classes: int = 7, + num_sg_classes: int = 230, + num_lp_outputs: int = 6, + + # LP bounding + lp_bounds_min: Tuple[float, float, float, float, float, float] = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), + lp_bounds_max: Tuple[float, float, float, float, float, float] = (500.0, 500.0, 500.0, 180.0, 180.0, 180.0), + bound_lp_with_sigmoid: bool = True, + + # Loss weights + lambda_cs: float = 1.0, + lambda_sg: float = 1.0, + lambda_lp: float = 1.0, + # Optional GEMD for SG + gemd_mu: float = 0.0, + gemd_distance_matrix_path: Optional[str] = None, + + # Optimizer + lr: float = 2e-4, + weight_decay: float = 1e-2, + use_adamw: bool = True, + # Scheduler (step-based warmup + cosine decay) + warmup_steps: int = 5000, + cosine_t_max: int = 112000, + ): + super().__init__() + self.save_hyperparameters() + + # Backbone + self.backbone = Wav2Vec2Backbone1D( + in_chans=in_chans, + d_model=d_model, + n_heads=n_heads, + num_layers=num_layers, + ff_dim=ff_dim, + conv_kernel_sizes=conv_kernel_sizes, + conv_strides=conv_strides, + conv_dropout=conv_dropout, + pos_kernel_size=pos_kernel_size, + pos_dropout=pos_dropout, + encoder_dropout=encoder_dropout, + layer_norm_first=layer_norm_first, + token_pool=token_pool, + ) + feat_dim = self.backbone.dim_output + + # Heads + self.cs_head = make_mlp( + input_dim=feat_dim, + hidden_dims=cs_hidden, + output_dim=num_cs_classes, + dropout=head_dropout, + output_activation=None, + ) + self.sg_head = make_mlp( + input_dim=feat_dim, + hidden_dims=sg_hidden, + output_dim=num_sg_classes, + dropout=head_dropout, + output_activation=None, + ) + self.lp_head = make_mlp( + input_dim=feat_dim, + hidden_dims=lp_hidden, + output_dim=num_lp_outputs, + dropout=head_dropout, + output_activation=None, + ) + + # Losses + self.ce = nn.CrossEntropyLoss() + self.mse = nn.MSELoss() + + # LP bounds + self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32)) + self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32)) + self.bound_lp_with_sigmoid = bound_lp_with_sigmoid + + # weights + self.lambda_cs = lambda_cs + self.lambda_sg = lambda_sg + self.lambda_lp = lambda_lp + + # Optimizer config + self.lr = lr + self.weight_decay = weight_decay + self.use_adamw = use_adamw + # Scheduler params from constructor (exposed via config) + self.warmup_steps = warmup_steps + self.cosine_t_max = cosine_t_max + + # Task sizes + self.num_cs_classes = num_cs_classes + self.num_sg_classes = num_sg_classes + self.num_lp_outputs = num_lp_outputs + + # GEMD setup (optional) + self.gemd_mu = gemd_mu + self.register_buffer("gemd_D", torch.empty(0)) + if gemd_distance_matrix_path is not None: + D_np = np.load(gemd_distance_matrix_path) + D_t = torch.as_tensor(D_np, dtype=torch.float32) + if D_t.ndim != 2 or D_t.shape[0] != self.num_sg_classes or D_t.shape[1] != self.num_sg_classes: + raise ValueError("GEMD distance matrix must be of shape (num_sg_classes, num_sg_classes)") + self.register_buffer("gemd_D", D_t) + + # ----------------------------- + # Forward + # ----------------------------- + def forward(self, x: torch.Tensor) -> Dict[str, 
torch.Tensor]: + feats = self.backbone(x) # (N, D) + cs_logits = self.cs_head(feats) + sg_logits = self.sg_head(feats) + lp = self.lp_head(feats) + + if self.bound_lp_with_sigmoid: + lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min + + return { + "features": feats, + "cs_logits": cs_logits, + "sg_logits": sg_logits, + "lp": lp, + } + + # ----------------------------- + # Data parsing helpers + # ----------------------------- + @staticmethod + def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Robustly convert labels to 0-based class indices: + - Supports one-hot and integer labels. + - If integer labels appear 1-based (min>=1 and max<=num_classes), shift to 0-based. + - Clamp to [0, num_classes-1] to avoid out-of-range targets. + """ + if y.dim() > 1 and y.size(-1) > 1: + idx = y.argmax(dim=-1) + else: + idx = y.long() + + with torch.no_grad(): + if idx.numel() > 0: + minv = int(idx.min().item()) + maxv = int(idx.max().item()) + # Shift 1-based labels to 0-based if detected + if minv >= 1 and maxv <= num_classes: + idx = idx - 1 + # Ensure valid range + idx = idx.clamp(min=0, max=num_classes - 1) + return idx + + @staticmethod + def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(batch, (list, tuple)): + assert len(batch) >= 4, "Expected at least (x, cs, sg, lp) in the batch tuple." + x, y_cs, y_sg, y_lp = batch[:4] + elif isinstance(batch, dict): + x = batch.get("x", batch.get("xrd", batch.get("signal"))) + if x is None: + raise KeyError("Batch dict must contain 'x' or 'xrd' or 'signal'.") + y_cs = batch.get("cs") + y_sg = batch.get("sg") + y_lp = batch.get("lattice_params", batch.get("lp")) + if y_cs is None or y_sg is None or y_lp is None: + raise KeyError("Batch dict must contain 'cs', 'sg', and 'lattice_params' (or 'lp').") + else: + raise TypeError("Unsupported batch type. 
Use Tuple or Dict.") + + return x, y_cs, y_sg, y_lp + + # ----------------------------- + # Loss and metrics + # ----------------------------- + def _compute_losses_and_metrics( + self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: torch.Tensor + ) -> Dict[str, torch.Tensor]: + cs_logits = preds["cs_logits"] + sg_logits = preds["sg_logits"] + lp_pred = preds["lp"] + + y_cs_idx = self._to_index(y_cs, self.num_cs_classes) + y_sg_idx = self._to_index(y_sg, self.num_sg_classes) + y_lp = y_lp.float() + + loss_cs = self.ce(cs_logits, y_cs_idx) + loss_sg = self.ce(sg_logits, y_sg_idx) + loss_lp = self.mse(lp_pred, y_lp) + + loss_gemd = torch.tensor(0.0, device=cs_logits.device) + sg_probs = torch.softmax(sg_logits, dim=1) + if self.gemd_mu > 0.0 and self.gemd_D.numel() > 0: + D_rows = self.gemd_D[y_sg_idx] + gemd_per_sample = (D_rows * sg_probs).sum(dim=1) + loss_gemd = gemd_per_sample.mean() + + total_loss = ( + self.lambda_cs * loss_cs + + self.lambda_sg * loss_sg + + self.lambda_lp * loss_lp + + self.gemd_mu * loss_gemd + ) + + with torch.no_grad(): + cs_acc = (cs_logits.argmax(dim=1) == y_cs_idx).float().mean() + sg_acc = (sg_logits.argmax(dim=1) == y_sg_idx).float().mean() + # Top-5 accuracy for SG to detect early improvements before top-1 moves + sg_top5 = ( + sg_logits.topk(5, dim=1).indices.eq(y_sg_idx.unsqueeze(1)).any(dim=1).float().mean() + ) + lp_mae = (lp_pred - y_lp).abs().mean() + lp_mse = F.mse_loss(lp_pred, y_lp) + + return { + "loss_total": total_loss, + "loss_cs": loss_cs, + "loss_sg": loss_sg, + "loss_lp": loss_lp, + "loss_gemd": loss_gemd, + "cs_acc": cs_acc, + "sg_acc": sg_acc, + "sg_top5": sg_top5, + "lp_mae": lp_mae, + "lp_mse": lp_mse, + } + + # ----------------------------- + # Lightning hooks + # ----------------------------- + def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + + self.log("train/loss", out["loss_total"], prog_bar=True, on_step=True, on_epoch=True) + self.log("train/loss_cs", out["loss_cs"], on_step=True, on_epoch=True) + self.log("train/loss_sg", out["loss_sg"], on_step=True, on_epoch=True) + self.log("train/loss_lp", out["loss_lp"], on_step=True, on_epoch=True) + self.log("train/loss_gemd", out["loss_gemd"], on_step=True, on_epoch=True) + self.log("train/cs_acc", out["cs_acc"], prog_bar=True, on_step=True, on_epoch=True) + self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) + self.log("train/sg_top5", out["sg_top5"], on_step=True, on_epoch=True) + self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) + self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) + + return out["loss_total"] + + def validation_step(self, batch: BatchType, batch_idx: int) -> None: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + + self.log("val/loss", out["loss_total"], prog_bar=True, on_epoch=True) + self.log("val/loss_cs", out["loss_cs"], on_epoch=True) + self.log("val/loss_sg", out["loss_sg"], on_epoch=True) + self.log("val/loss_lp", out["loss_lp"], on_epoch=True) + self.log("val/loss_gemd", out["loss_gemd"], on_epoch=True) + self.log("val/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) + self.log("val/sg_acc", out["sg_acc"], on_epoch=True) + self.log("val/sg_top5", out["sg_top5"], on_epoch=True) + self.log("val/lp_mae", out["lp_mae"], 
on_epoch=True) + self.log("val/lp_mse", out["lp_mse"], on_epoch=True) + + def test_step(self, batch: BatchType, batch_idx: int) -> None: + x, y_cs, y_sg, y_lp = self._extract_batch(batch) + preds = self(x) + out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) + + self.log("test/loss", out["loss_total"], prog_bar=True, on_epoch=True) + self.log("test/loss_cs", out["loss_cs"], on_epoch=True) + self.log("test/loss_sg", out["loss_sg"], on_epoch=True) + self.log("test/loss_lp", out["loss_lp"], on_epoch=True) + self.log("test/loss_gemd", out["loss_gemd"], on_epoch=True) + self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) + self.log("test/sg_acc", out["sg_acc"], on_epoch=True) + self.log("test/sg_top5", out["sg_top5"], on_epoch=True) + self.log("test/lp_mae", out["lp_mae"], on_epoch=True) + self.log("test/lp_mse", out["lp_mse"], on_epoch=True) + + def configure_optimizers(self): + params = self.parameters() + # AdamW with transformer-friendly betas/eps + if self.use_adamw: + optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay, betas=(0.9, 0.98), eps=1e-8) + else: + optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay, betas=(0.9, 0.98), eps=1e-8) + + # Step-based warmup then cosine decay (better for very large datasets) + warmup_steps = getattr(self, "warmup_steps", 20000) + cosine_t_max = getattr(self, "cosine_t_max", 200000) + + linear_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps) + cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max, eta_min=self.lr * 0.01) + scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[linear_warmup, cosine], milestones=[warmup_steps]) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", # apply scheduler every training step + "frequency": 1, + "name": "warmup_cosine_steps", + }, + } + + +# ----------------------------- +# Example factory for 1x8192 input with ~128 tokens +# ----------------------------- +def build_wav2vec2_alphadiffract_for_8192() -> AlphaDiffractWav2Vec2Lightning: + """ + Build a Wav2Vec2-style AlphaDiffract model for 1x8192 XRD input. 
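+    (Token math: six conv stages, each of stride 2, give an overall stride of 2**6 = 64,
+    so an 8192-sample pattern is reduced to roughly 8192 / 64 = 128 tokens.)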
+ + Design choices for large dataset (~14M samples): + - Overall conv stride 64 -> ~128 tokens for 8192 input + - d_model=512, n_heads=8, num_layers=8 (scalable up) + - Standard MLP heads reused from ConvNeXt variant + - Defaults suitable for supervised training; self-supervised pretraining pipeline not included here + """ + return AlphaDiffractWav2Vec2Lightning( + # Backbone sizing + in_chans=1, + d_model=512, + n_heads=8, + num_layers=8, + ff_dim=2048, + conv_kernel_sizes=(10, 5, 3, 3, 3, 2), + conv_strides=(2, 2, 2, 2, 2, 2), # total stride 64 + conv_dropout=0.0, + pos_kernel_size=128, + pos_dropout=0.1, + encoder_dropout=0.1, + layer_norm_first=False, + token_pool="mean", + + # Heads + head_dropout=0.5, + cs_hidden=(2300, 1150), + sg_hidden=(2300, 1150), + lp_hidden=(512, 256), + + num_cs_classes=7, + num_sg_classes=230, + num_lp_outputs=6, + + lp_bounds_min=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), + lp_bounds_max=(500.0, 500.0, 500.0, 180.0, 180.0, 180.0), + bound_lp_with_sigmoid=True, + + lambda_cs=1.0, + lambda_sg=1.0, + lambda_lp=1.0, + + # Optimizer defaults (tune as needed) + lr=2e-4, + weight_decay=1e-2, + use_adamw=True, + ) diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index 61d70b0..d40742e 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -15,6 +15,7 @@ # Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`) from dataset import NpyDataModule from model.model import AlphaDiffractLightning +from model.wav2vec2_model import AlphaDiffractWav2Vec2Lightning def parse_args() -> argparse.Namespace: @@ -80,41 +81,86 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: - model = AlphaDiffractLightning( - # Backbone - depths=tuple(cfg["depths"]), - dims=tuple(cfg["dims"]), - kernel_sizes=tuple(cfg["kernel_sizes"]), - strides=tuple(cfg["strides"]), - dropout_rate=cfg["dropout_rate"], - layer_scale_init_value=cfg["layer_scale_init_value"], - drop_path_rate=cfg["drop_path_rate"], - # Heads - head_dropout=cfg["head_dropout"], - cs_hidden=tuple(cfg["cs_hidden"]), - sg_hidden=tuple(cfg["sg_hidden"]), - lp_hidden=tuple(cfg["lp_hidden"]), - # Task sizes - num_cs_classes=cfg["num_cs_classes"], - num_sg_classes=cfg["num_sg_classes"], - num_lp_outputs=cfg["num_lp_outputs"], - # LP bounds and output handling - lp_bounds_min=tuple(cfg["lp_bounds_min"]), - lp_bounds_max=tuple(cfg["lp_bounds_max"]), - bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], - # Loss weights - lambda_cs=cfg["lambda_cs"], - lambda_sg=cfg["lambda_sg"], - lambda_lp=cfg["lambda_lp"], - # Optional GEMD - gemd_mu=cfg["gemd_mu"], - gemd_distance_matrix_path=cfg.get("gemd_distance_matrix_path"), - # Optimizer - lr=cfg["lr"], - weight_decay=cfg["weight_decay"], - use_adamw=cfg["use_adamw"], - ) - return model + model_type = cfg.get("model_type", "convnext").lower() + if model_type == "wav2vec2": + return AlphaDiffractWav2Vec2Lightning( + # Backbone (Wav2Vec2-style; defaults tailored for 8192-length signals and large dataset) + in_chans=cfg.get("in_chans", 1), + d_model=cfg.get("d_model", 512), + n_heads=cfg.get("n_heads", 8), + num_layers=cfg.get("num_layers", 8), + ff_dim=cfg.get("ff_dim", 2048), + conv_kernel_sizes=tuple(cfg.get("conv_kernel_sizes", (10, 5, 3, 3, 3, 2))), + conv_strides=tuple(cfg.get("conv_strides", (2, 2, 2, 2, 2, 2))), # overall stride 64 -> ~128 tokens + conv_dropout=cfg.get("conv_dropout", 0.0), + pos_kernel_size=cfg.get("pos_kernel_size", 128), + 
pos_dropout=cfg.get("pos_dropout", 0.1), + encoder_dropout=cfg.get("encoder_dropout", 0.1), + layer_norm_first=cfg.get("layer_norm_first", False), + token_pool=cfg.get("token_pool", "mean"), + # Heads (reuse paper defaults) + head_dropout=cfg["head_dropout"], + cs_hidden=tuple(cfg["cs_hidden"]), + sg_hidden=tuple(cfg["sg_hidden"]), + lp_hidden=tuple(cfg["lp_hidden"]), + # Task sizes + num_cs_classes=cfg["num_cs_classes"], + num_sg_classes=cfg["num_sg_classes"], + num_lp_outputs=cfg["num_lp_outputs"], + # LP bounds + lp_bounds_min=tuple(cfg["lp_bounds_min"]), + lp_bounds_max=tuple(cfg["lp_bounds_max"]), + bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], + # Loss weights + lambda_cs=cfg["lambda_cs"], + lambda_sg=cfg["lambda_sg"], + lambda_lp=cfg["lambda_lp"], + # Optional GEMD + gemd_mu=cfg["gemd_mu"], + gemd_distance_matrix_path=cfg.get("gemd_distance_matrix_path"), + # Optimizer + lr=cfg["lr"], + weight_decay=cfg["weight_decay"], + use_adamw=cfg["use_adamw"], + # Scheduler + warmup_steps=cfg.get("warmup_steps", 5000), + cosine_t_max=cfg.get("cosine_t_max", 112000), + ) + else: + return AlphaDiffractLightning( + # ConvNeXt1D backbone (paper defaults) + depths=tuple(cfg["depths"]), + dims=tuple(cfg["dims"]), + kernel_sizes=tuple(cfg["kernel_sizes"]), + strides=tuple(cfg["strides"]), + dropout_rate=cfg["dropout_rate"], + layer_scale_init_value=cfg["layer_scale_init_value"], + drop_path_rate=cfg["drop_path_rate"], + # Heads + head_dropout=cfg["head_dropout"], + cs_hidden=tuple(cfg["cs_hidden"]), + sg_hidden=tuple(cfg["sg_hidden"]), + lp_hidden=tuple(cfg["lp_hidden"]), + # Task sizes + num_cs_classes=cfg["num_cs_classes"], + num_sg_classes=cfg["num_sg_classes"], + num_lp_outputs=cfg["num_lp_outputs"], + # LP bounds and output handling + lp_bounds_min=tuple(cfg["lp_bounds_min"]), + lp_bounds_max=tuple(cfg["lp_bounds_max"]), + bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], + # Loss weights + lambda_cs=cfg["lambda_cs"], + lambda_sg=cfg["lambda_sg"], + lambda_lp=cfg["lambda_lp"], + # Optional GEMD + gemd_mu=cfg["gemd_mu"], + gemd_distance_matrix_path=cfg.get("gemd_distance_matrix_path"), + # Optimizer + lr=cfg["lr"], + weight_decay=cfg["weight_decay"], + use_adamw=cfg["use_adamw"], + ) class ConfigArtifactLogger(Callback): @@ -155,6 +201,7 @@ def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = experiment_name=cfg["mlflow_experiment_name"], tracking_uri=cfg["mlflow_tracking_uri"], run_name=cfg["mlflow_run_name"], + log_model=True, ) trainer = Trainer( From 0fdc5e8d5591ef9c6feb922c6e60815cff6a6f48 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Fri, 7 Nov 2025 14:43:33 -0600 Subject: [PATCH 09/18] feat: add individual folder checkpointing and rruff processing (temp) --- configs/trainer_wav2vec2.yaml | 18 ++-- configs/trainer_wav2vec2_nonorm.yaml | 129 +++++++++++++++++++++++++++ src/trainer/dataset/datamodule.py | 23 ++++- src/trainer/model/model.py | 68 ++++++++------ src/trainer/model/wav2vec2_model.py | 68 ++++++++------ src/trainer/train_paper.py | 48 +++++++++- 6 files changed, 295 insertions(+), 59 deletions(-) create mode 100644 configs/trainer_wav2vec2_nonorm.yaml diff --git a/configs/trainer_wav2vec2.yaml b/configs/trainer_wav2vec2.yaml index cbd9a35..9022a0e 100644 --- a/configs/trainer_wav2vec2.yaml +++ b/configs/trainer_wav2vec2.yaml @@ -3,6 +3,7 @@ # --- Data / Manifests --- manifest_dir: "../../data/manifests" +extra_val_file: "rruff_sim.jsonl" dataset_root: "../../data/dataset" auto_generate_manifests: true train_ratio: 0.8 @@ -12,7 +13,7 @@ seed: 
42 # --- DataLoader --- # Large dataset (14M samples, 150k materials) — start with batch_size=256 and tune later -batch_size: 256 +batch_size: 200 num_workers: 8 pin_memory: true persistent_workers: true @@ -37,6 +38,13 @@ mmap_mode: null floor_at_zero: true # counts are non-negative normalize_log1p: true # compress dynamic range for stability +# --- Noise augmentation (training split only; matches paper) --- +# If provided, noise is applied dynamically per-sample in the DataModule using the same +# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. +# Set ranges to None to disable. +noise_poisson_range: [1.0, 100.0] # λ_max ~ Uniform(1, 100) +noise_gaussian_range: [0.001, 0.1] # σ_rel ~ Uniform(1e-3, 1e-1) + # --- Model selection --- model_type: "wav2vec2" # choose between "convnext" and "wav2vec2" @@ -46,9 +54,9 @@ d_model: 512 n_heads: 8 num_layers: 8 ff_dim: 2048 -# Convolutional feature extractor — total stride = 32 -> ~256 tokens for length 8192 +# Convolutional feature extractor — total stride = 16 -> ~512 tokens for length 8192 conv_kernel_sizes: [10, 5, 3, 3, 3, 2] -conv_strides: [2, 2, 2, 2, 2, 1] +conv_strides: [2, 2, 2, 2, 1, 1] conv_dropout: 0.0 pos_kernel_size: 129 pos_dropout: 0.1 @@ -82,14 +90,14 @@ gemd_mu: 0.0 gemd_distance_matrix_path: # Optimizer -lr: 0.0005 # moderately higher base LR with warmup; tune as needed +lr: 0.0004 # moderately higher base LR with warmup; tune as needed weight_decay: 0.01 use_adamw: true gradient_clip_val: 1.0 gradient_clip_algorithm: "norm" # --- Scheduler --- -warmup_steps: 5000 +warmup_steps: 6000 cosine_t_max: 112000 # --- Logging --- diff --git a/configs/trainer_wav2vec2_nonorm.yaml b/configs/trainer_wav2vec2_nonorm.yaml new file mode 100644 index 0000000..4edfefd --- /dev/null +++ b/configs/trainer_wav2vec2_nonorm.yaml @@ -0,0 +1,129 @@ +# AlphaDiffract trainer configuration for Wav2Vec2-style backbone (8192-length signals) +# No log1p normalization (normalize_log1p: false) for A/B comparison against real-world RRUFF data +# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_wav2vec2_nonorm.yaml + +# --- Data / Manifests --- +manifest_dir: "../../data/manifests" +extra_val_file: "rruff_sim.jsonl" +dataset_root: "../../data/dataset" +auto_generate_manifests: true +train_ratio: 0.8 +val_ratio: 0.1 +test_ratio: 0.1 +seed: 42 + +# --- DataLoader --- +batch_size: 200 +num_workers: 8 +pin_memory: true +persistent_workers: true + +# --- Dataset label extraction (embedded in .npy/.npz) --- +validate_paths: false +extract_labels: true +allow_pickle: true +labels_key_map: + x: "dp" + cs: "cs" + sg: "sg" + lattice_params: null + lp_a: "_cell_length_a" + lp_b: "_cell_length_b" + lp_c: "_cell_length_c" + lp_alpha: "_cell_angle_alpha" + lp_beta: "_cell_angle_beta" + lp_gamma: "_cell_angle_gamma" +dtype: "float32" +mmap_mode: null +floor_at_zero: true # counts are non-negative +normalize_log1p: false # DISABLED for this run + +# --- Noise augmentation (training split only; matches paper) --- +# If provided, noise is applied dynamically per-sample in the DataModule using the same +# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. +# Set ranges to None to disable. 
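+# For example, a clean-data baseline would disable augmentation with:
+#   noise_poisson_range: null
+#   noise_gaussian_range: null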
+noise_poisson_range: [1.0, 100.0] +noise_gaussian_range: [0.001, 0.1] + +# --- Model selection --- +model_type: "wav2vec2" # choose between "convnext" and "wav2vec2" + +# --- Wav2Vec2-style backbone (defaults tailored for 8192 inputs and more tokens) --- +in_chans: 1 +d_model: 512 +n_heads: 8 +num_layers: 8 +ff_dim: 2048 +# Convolutional feature extractor — total stride = 16 -> ~512 tokens for length 8192 +conv_kernel_sizes: [10, 5, 3, 3, 3, 2] +conv_strides: [2, 2, 2, 2, 1, 1] +conv_dropout: 0.0 +pos_kernel_size: 129 +pos_dropout: 0.1 +encoder_dropout: 0.1 +layer_norm_first: true +token_pool: "cls" # "mean" or "cls" + +# Heads (reuse paper defaults) +head_dropout: 0.2 +cs_hidden: [2300, 1150] +sg_hidden: [2300, 1150] +lp_hidden: [512, 256] + +# Task sizes +num_cs_classes: 7 +num_sg_classes: 230 +num_lp_outputs: 6 + +# LP output bounds +lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +bound_lp_with_sigmoid: true + +# Loss weights +lambda_cs: 1.0 +lambda_sg: 1.0 +lambda_lp: 0.0 + +# Optional GEMD term on SG +gemd_mu: 0.0 +gemd_distance_matrix_path: + +# Optimizer +lr: 0.0004 +weight_decay: 0.01 +use_adamw: true +gradient_clip_val: 1.0 +gradient_clip_algorithm: "norm" + +# --- Scheduler --- +warmup_steps: 6000 +cosine_t_max: 112000 + +# --- Logging --- +logger: "mlflow" # 'csv' or 'mlflow' +csv_logger_name: "model_logs_wav2vec2_nonorm" +mlflow_experiment_name: "OpenAlphaDiffract_Wav2Vec2_NoNorm" +mlflow_tracking_uri: null +mlflow_run_name: "Wav2Vec2_Run_NoNorm" + +# --- Trainer settings --- +default_root_dir: "outputs/wav2vec2_model_nonorm" +max_epochs: 50 +accumulate_grad_batches: 1 +precision: "bf16-mixed" +accelerator: "gpu" +devices: 1 +log_every_n_steps: 50 +deterministic: false +benchmark: true + +# --- Checkpointing --- +monitor: "val/loss" +mode: "min" +save_top_k: 1 +every_n_epochs: 1 + +# --- Evaluation --- +resume_from: +test_after_train: true diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index 61e0af9..94379fe 100644 --- a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -98,6 +98,8 @@ def __init__( train_file: str = "train.jsonl", val_file: str = "val.jsonl", test_file: str = "test.jsonl", + # Optional: add a second validation manifest file (e.g., "rruff.jsonl") + extra_val_file: Optional[str] = None, ) -> None: super().__init__() self.manifest_dir = manifest_dir @@ -120,11 +122,13 @@ def __init__( self.train_file = train_file self.val_file = val_file self.test_file = test_file + self.extra_val_file = extra_val_file # Internal datasets self.train_ds = None self.val_ds = None self.test_ds = None + self.extra_val_ds = None def _manifest_paths(self) -> Dict[str, str]: # Allow overriding filenames while still respecting manifest_dir @@ -183,6 +187,11 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_ds = self.dataset_cls(paths["train"], **self.dataset_kwargs) if self.val_ds is None: self.val_ds = self.dataset_cls(paths["val"], **self.dataset_kwargs) + # Extra validation dataset (e.g., rruff) + if self.extra_val_ds is None and self.extra_val_file is not None: + manifest_dir_abs = os.path.dirname(paths["val"]) + extra_path = os.path.join(manifest_dir_abs, self.extra_val_file) + self.extra_val_ds = self.dataset_cls(extra_path, **self.dataset_kwargs) if self.test_ds is None: self.test_ds = self.dataset_cls(paths["test"], **self.dataset_kwargs) @@ -198,7 +207,7 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - 
return DataLoader( + val_loader = DataLoader( self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers, @@ -207,6 +216,18 @@ def val_dataloader(self) -> DataLoader: shuffle=False, collate_fn=self.collate_fn, ) + if getattr(self, "extra_val_ds", None) is not None: + extra_loader = DataLoader( + self.extra_val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + shuffle=False, + collate_fn=self.collate_fn, + ) + return [val_loader, extra_loader] + return val_loader def test_dataloader(self) -> DataLoader: return DataLoader( diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index 7e01f20..6a0c2c7 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -397,11 +397,14 @@ def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: idx = idx - 1 return idx - @staticmethod - def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if isinstance(batch, (list, tuple)): - assert len(batch) >= 4, "Expected at least (x, cs, sg, lp) in the batch tuple." - x, y_cs, y_sg, y_lp = batch[:4] + assert len(batch) >= 3, "Expected at least (x, cs, sg) in the batch tuple." + if len(batch) >= 4: + x, y_cs, y_sg, y_lp = batch[:4] + else: + x, y_cs, y_sg = batch[:3] + y_lp = None elif isinstance(batch, dict): # Try common keys x = batch.get("x", batch.get("xrd", batch.get("signal"))) @@ -410,8 +413,8 @@ def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch. y_cs = batch.get("cs") y_sg = batch.get("sg") y_lp = batch.get("lattice_params", batch.get("lp")) - if y_cs is None or y_sg is None or y_lp is None: - raise KeyError("Batch dict must contain 'cs', 'sg', and 'lattice_params' (or 'lp').") + if y_cs is None or y_sg is None: + raise KeyError("Batch dict must contain 'cs' and 'sg'. 'lattice_params' (or 'lp') is optional.") else: raise TypeError("Unsupported batch type. Use Tuple or Dict.") @@ -421,7 +424,7 @@ def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch. 
# Loss and metrics # ----------------------------- def _compute_losses_and_metrics( - self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: torch.Tensor + self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: Optional[torch.Tensor] ) -> Dict[str, torch.Tensor]: cs_logits = preds["cs_logits"] sg_logits = preds["sg_logits"] @@ -430,12 +433,13 @@ def _compute_losses_and_metrics( # targets: convert one-hot to index for CE if necessary y_cs_idx = self._to_index(y_cs, self.num_cs_classes) y_sg_idx = self._to_index(y_sg, self.num_sg_classes) - # regression targets should be float - y_lp = y_lp.float() + # regression targets should be float if present + if y_lp is not None: + y_lp = y_lp.float() loss_cs = self.ce(cs_logits, y_cs_idx) loss_sg = self.ce(sg_logits, y_sg_idx) - loss_lp = self.mse(lp_pred, y_lp) + loss_lp = self.mse(lp_pred, y_lp) if y_lp is not None else torch.tensor(0.0, device=cs_logits.device) # Optional GEMD term loss_gemd = torch.tensor(0.0, device=cs_logits.device) @@ -456,8 +460,12 @@ def _compute_losses_and_metrics( with torch.no_grad(): cs_acc = (cs_logits.argmax(dim=1) == y_cs_idx).float().mean() sg_acc = (sg_logits.argmax(dim=1) == y_sg_idx).float().mean() - lp_mae = (lp_pred - y_lp).abs().mean() - lp_mse = F.mse_loss(lp_pred, y_lp) + if y_lp is not None: + lp_mae = (lp_pred - y_lp).abs().mean() + lp_mse = F.mse_loss(lp_pred, y_lp) + else: + lp_mae = None + lp_mse = None return { "loss_total": total_loss, @@ -486,25 +494,31 @@ def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: self.log("train/loss_gemd", out["loss_gemd"], on_step=True, on_epoch=True) self.log("train/cs_acc", out["cs_acc"], prog_bar=True, on_step=True, on_epoch=True) self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) - self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) - self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) + if out.get("lp_mae") is not None: + self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) + if out.get("lp_mse") is not None: + self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) return out["loss_total"] - def validation_step(self, batch: BatchType, batch_idx: int) -> None: + def validation_step(self, batch: BatchType, batch_idx: int, dataloader_idx: Optional[int] = None) -> None: x, y_cs, y_sg, y_lp = self._extract_batch(batch) preds = self(x) out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - self.log("val/loss", out["loss_total"], prog_bar=True, on_epoch=True) - self.log("val/loss_cs", out["loss_cs"], on_epoch=True) - self.log("val/loss_sg", out["loss_sg"], on_epoch=True) - self.log("val/loss_lp", out["loss_lp"], on_epoch=True) - self.log("val/loss_gemd", out["loss_gemd"], on_epoch=True) - self.log("val/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) - self.log("val/sg_acc", out["sg_acc"], on_epoch=True) - self.log("val/lp_mae", out["lp_mae"], on_epoch=True) - self.log("val/lp_mse", out["lp_mse"], on_epoch=True) + prefix = "val" if (dataloader_idx is None or dataloader_idx == 0) else "val_rruff" + + self.log(f"{prefix}/loss", out["loss_total"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_cs", out["loss_cs"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_sg", out["loss_sg"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_lp", out["loss_lp"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_gemd", 
out["loss_gemd"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/sg_acc", out["sg_acc"], on_epoch=True, add_dataloader_idx=False) + if out.get("lp_mae") is not None: + self.log(f"{prefix}/lp_mae", out["lp_mae"], on_epoch=True, add_dataloader_idx=False) + if out.get("lp_mse") is not None: + self.log(f"{prefix}/lp_mse", out["lp_mse"], on_epoch=True, add_dataloader_idx=False) def test_step(self, batch: BatchType, batch_idx: int) -> None: x, y_cs, y_sg, y_lp = self._extract_batch(batch) @@ -518,8 +532,10 @@ def test_step(self, batch: BatchType, batch_idx: int) -> None: self.log("test/loss_gemd", out["loss_gemd"], on_epoch=True) self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) self.log("test/sg_acc", out["sg_acc"], on_epoch=True) - self.log("test/lp_mae", out["lp_mae"], on_epoch=True) - self.log("test/lp_mse", out["lp_mse"], on_epoch=True) + if out.get("lp_mae") is not None: + self.log("test/lp_mae", out["lp_mae"], on_epoch=True) + if out.get("lp_mse") is not None: + self.log("test/lp_mse", out["lp_mse"], on_epoch=True) def configure_optimizers(self): params = self.parameters() diff --git a/src/trainer/model/wav2vec2_model.py b/src/trainer/model/wav2vec2_model.py index 9c075b1..9df21ef 100644 --- a/src/trainer/model/wav2vec2_model.py +++ b/src/trainer/model/wav2vec2_model.py @@ -388,11 +388,14 @@ def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: idx = idx.clamp(min=0, max=num_classes - 1) return idx - @staticmethod - def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if isinstance(batch, (list, tuple)): - assert len(batch) >= 4, "Expected at least (x, cs, sg, lp) in the batch tuple." - x, y_cs, y_sg, y_lp = batch[:4] + assert len(batch) >= 3, "Expected at least (x, cs, sg) in the batch tuple." + if len(batch) >= 4: + x, y_cs, y_sg, y_lp = batch[:4] + else: + x, y_cs, y_sg = batch[:3] + y_lp = None elif isinstance(batch, dict): x = batch.get("x", batch.get("xrd", batch.get("signal"))) if x is None: @@ -400,8 +403,8 @@ def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch. y_cs = batch.get("cs") y_sg = batch.get("sg") y_lp = batch.get("lattice_params", batch.get("lp")) - if y_cs is None or y_sg is None or y_lp is None: - raise KeyError("Batch dict must contain 'cs', 'sg', and 'lattice_params' (or 'lp').") + if y_cs is None or y_sg is None: + raise KeyError("Batch dict must contain 'cs' and 'sg'. 'lattice_params' (or 'lp') is optional.") else: raise TypeError("Unsupported batch type. Use Tuple or Dict.") @@ -411,7 +414,7 @@ def _extract_batch(batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch. 
# Loss and metrics # ----------------------------- def _compute_losses_and_metrics( - self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: torch.Tensor + self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: Optional[torch.Tensor] ) -> Dict[str, torch.Tensor]: cs_logits = preds["cs_logits"] sg_logits = preds["sg_logits"] @@ -419,11 +422,12 @@ def _compute_losses_and_metrics( y_cs_idx = self._to_index(y_cs, self.num_cs_classes) y_sg_idx = self._to_index(y_sg, self.num_sg_classes) - y_lp = y_lp.float() + if y_lp is not None: + y_lp = y_lp.float() loss_cs = self.ce(cs_logits, y_cs_idx) loss_sg = self.ce(sg_logits, y_sg_idx) - loss_lp = self.mse(lp_pred, y_lp) + loss_lp = self.mse(lp_pred, y_lp) if y_lp is not None else torch.tensor(0.0, device=cs_logits.device) loss_gemd = torch.tensor(0.0, device=cs_logits.device) sg_probs = torch.softmax(sg_logits, dim=1) @@ -446,8 +450,12 @@ def _compute_losses_and_metrics( sg_top5 = ( sg_logits.topk(5, dim=1).indices.eq(y_sg_idx.unsqueeze(1)).any(dim=1).float().mean() ) - lp_mae = (lp_pred - y_lp).abs().mean() - lp_mse = F.mse_loss(lp_pred, y_lp) + if y_lp is not None: + lp_mae = (lp_pred - y_lp).abs().mean() + lp_mse = F.mse_loss(lp_pred, y_lp) + else: + lp_mae = None + lp_mse = None return { "loss_total": total_loss, @@ -478,26 +486,32 @@ def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: self.log("train/cs_acc", out["cs_acc"], prog_bar=True, on_step=True, on_epoch=True) self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) self.log("train/sg_top5", out["sg_top5"], on_step=True, on_epoch=True) - self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) - self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) + if out.get("lp_mae") is not None: + self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) + if out.get("lp_mse") is not None: + self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) return out["loss_total"] - def validation_step(self, batch: BatchType, batch_idx: int) -> None: + def validation_step(self, batch: BatchType, batch_idx: int, dataloader_idx: Optional[int] = None) -> None: x, y_cs, y_sg, y_lp = self._extract_batch(batch) preds = self(x) out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - self.log("val/loss", out["loss_total"], prog_bar=True, on_epoch=True) - self.log("val/loss_cs", out["loss_cs"], on_epoch=True) - self.log("val/loss_sg", out["loss_sg"], on_epoch=True) - self.log("val/loss_lp", out["loss_lp"], on_epoch=True) - self.log("val/loss_gemd", out["loss_gemd"], on_epoch=True) - self.log("val/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) - self.log("val/sg_acc", out["sg_acc"], on_epoch=True) - self.log("val/sg_top5", out["sg_top5"], on_epoch=True) - self.log("val/lp_mae", out["lp_mae"], on_epoch=True) - self.log("val/lp_mse", out["lp_mse"], on_epoch=True) + prefix = "val" if (dataloader_idx is None or dataloader_idx == 0) else "val_rruff" + + self.log(f"{prefix}/loss", out["loss_total"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_cs", out["loss_cs"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_sg", out["loss_sg"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_lp", out["loss_lp"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/loss_gemd", out["loss_gemd"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/cs_acc", out["cs_acc"], prog_bar=True, 
on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/sg_acc", out["sg_acc"], on_epoch=True, add_dataloader_idx=False) + self.log(f"{prefix}/sg_top5", out["sg_top5"], on_epoch=True, add_dataloader_idx=False) + if out.get("lp_mae") is not None: + self.log(f"{prefix}/lp_mae", out["lp_mae"], on_epoch=True, add_dataloader_idx=False) + if out.get("lp_mse") is not None: + self.log(f"{prefix}/lp_mse", out["lp_mse"], on_epoch=True, add_dataloader_idx=False) def test_step(self, batch: BatchType, batch_idx: int) -> None: x, y_cs, y_sg, y_lp = self._extract_batch(batch) @@ -512,8 +526,10 @@ def test_step(self, batch: BatchType, batch_idx: int) -> None: self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) self.log("test/sg_acc", out["sg_acc"], on_epoch=True) self.log("test/sg_top5", out["sg_top5"], on_epoch=True) - self.log("test/lp_mae", out["lp_mae"], on_epoch=True) - self.log("test/lp_mse", out["lp_mse"], on_epoch=True) + if out.get("lp_mae") is not None: + self.log("test/lp_mae", out["lp_mae"], on_epoch=True) + if out.get("lp_mse") is not None: + self.log("test/lp_mse", out["lp_mse"], on_epoch=True) def configure_optimizers(self): params = self.parameters() diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index d40742e..cfa1e14 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -76,6 +76,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: val_ratio=cfg["val_ratio"], test_ratio=cfg["test_ratio"], seed=cfg["seed"], + extra_val_file=cfg.get("extra_val_file"), ) return dm @@ -177,6 +178,50 @@ def on_fit_start(self, trainer, pl_module) -> None: logger.experiment.log_artifact(logger.run_id, self.raw_config_path) +class RunCheckpointDirCallback(Callback): + """ + On fit start, create a unique sub-folder for checkpoints and, if MLflow is the logger, + log the folder path to MLflow so each run's checkpoints are easily discoverable. 
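+    Falls back to a timestamp-based folder name when no MLflow run_id is available, and
+    rewrites the dirpath of every ModelCheckpoint callback during on_fit_start.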
+ """ + def __init__(self, base_dir: Optional[str] = None): + self.base_dir = base_dir + + def on_fit_start(self, trainer, pl_module) -> None: + # Determine base checkpoints directory + base_dir = self.base_dir or os.path.join(trainer.default_root_dir, "checkpoints") + + # Derive a unique run-specific subfolder + logger = getattr(trainer, "logger", None) + run_suffix: Optional[str] = None + if MLFlowLogger is not None and isinstance(logger, MLFlowLogger): + # Prefer the MLflow run_id when available + run_suffix = getattr(logger, "run_id", None) + + if not run_suffix: + # Fallback to timestamp-based folder if no MLflow run_id + from datetime import datetime + run_suffix = datetime.now().strftime("%Y%m%d_%H%M%S") + + run_ckpt_dir = os.path.join(base_dir, run_suffix) + os.makedirs(run_ckpt_dir, exist_ok=True) + + # Update any ModelCheckpoint callbacks to write into the run-specific directory + if hasattr(trainer, "checkpoint_callback") and trainer.checkpoint_callback is not None: + trainer.checkpoint_callback.dirpath = run_ckpt_dir + for cb in trainer.callbacks: + if isinstance(cb, ModelCheckpoint): + cb.dirpath = run_ckpt_dir + + # If MLflow is active, log the checkpoint directory path for this run + if MLFlowLogger is not None and isinstance(logger, MLFlowLogger): + try: + # Log as a MLflow param and tag for easy discovery + logger.experiment.log_param(logger.run_id, "checkpoint_dir", run_ckpt_dir) + logger.experiment.set_tag(logger.run_id, "checkpoint_dir", run_ckpt_dir) + except Exception: + pass + + def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = None) -> Trainer: ckpt_cb = ModelCheckpoint( monitor=cfg["monitor"], @@ -211,13 +256,14 @@ def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = devices=cfg["devices"], precision=cfg["precision"], accumulate_grad_batches=cfg["accumulate_grad_batches"], - callbacks=[ckpt_cb, lr_cb, ConfigArtifactLogger(raw_config_path)], + callbacks=[ckpt_cb, lr_cb, ConfigArtifactLogger(raw_config_path), RunCheckpointDirCallback()], logger=logger, log_every_n_steps=cfg["log_every_n_steps"], deterministic=cfg["deterministic"], benchmark=cfg["benchmark"], gradient_clip_val=cfg.get("gradient_clip_val", 0.0), gradient_clip_algorithm=cfg.get("gradient_clip_algorithm", "norm"), + val_check_interval=0.5 ) return trainer From d472d5f0859472b90cdacc6e35445976542d6d8c Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Fri, 7 Nov 2025 14:58:55 -0600 Subject: [PATCH 10/18] feat: Add paper noising implementation --- src/trainer/dataset/datamodule.py | 20 +- src/trainer/dataset/dataset.py | 66 +++++- src/trainer/infer_rruff.py | 348 ++++++++++++++++++++++++++++++ src/trainer/train_paper.py | 3 + 4 files changed, 430 insertions(+), 7 deletions(-) create mode 100644 src/trainer/infer_rruff.py diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index 94379fe..13ab178 100644 --- a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -37,7 +37,7 @@ from __future__ import annotations import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple # Resolve base for relative paths as the current working directory (runtime CWD) CWD_BASE = os.getcwd() @@ -53,7 +53,7 @@ "Install with: pip install pytorch-lightning" ) from e -from .dataset import NpyManifestDataset, default_manifest_paths +from .dataset import NpyManifestDataset, default_manifest_paths, make_poisson_gaussian_noise_transform from .manifest_utils import 
generate_manifests @@ -100,6 +100,9 @@ def __init__( test_file: str = "test.jsonl", # Optional: add a second validation manifest file (e.g., "rruff.jsonl") extra_val_file: Optional[str] = None, + # Optional noise augmentation for training split only + noise_poisson_range: Optional[Tuple[float, float]] = None, + noise_gaussian_range: Optional[Tuple[float, float]] = None, ) -> None: super().__init__() self.manifest_dir = manifest_dir @@ -124,6 +127,10 @@ def __init__( self.test_file = test_file self.extra_val_file = extra_val_file + # Optional noise augmentation params + self.noise_poisson_range = noise_poisson_range + self.noise_gaussian_range = noise_gaussian_range + # Internal datasets self.train_ds = None self.val_ds = None @@ -184,7 +191,14 @@ def setup(self, stage: Optional[str] = None) -> None: """ paths = self._manifest_paths() if self.train_ds is None: - self.train_ds = self.dataset_cls(paths["train"], **self.dataset_kwargs) + # Apply noise transform to training split only if ranges are provided + train_kwargs = dict(self.dataset_kwargs) + if self.noise_poisson_range is not None and self.noise_gaussian_range is not None: + train_kwargs = dict(train_kwargs) # shallow copy + train_kwargs["transform"] = make_poisson_gaussian_noise_transform( + self.noise_poisson_range, self.noise_gaussian_range + ) + self.train_ds = self.dataset_cls(paths["train"], **train_kwargs) if self.val_ds is None: self.val_ds = self.dataset_cls(paths["val"], **self.dataset_kwargs) # Extra validation dataset (e.g., rruff) diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py index 69b3bc8..e260194 100644 --- a/src/trainer/dataset/dataset.py +++ b/src/trainer/dataset/dataset.py @@ -28,6 +28,63 @@ import torch from torch.utils.data import Dataset +# Noise transform builder following the paper's two-step model +from typing import Tuple, Callable + +def make_poisson_gaussian_noise_transform(poisson_range: Tuple[float, float], gaussian_range: Tuple[float, float]) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Build a transform that applies the paper's two-step noise model: + 1) Poisson noise with mean lambda sampled uniformly from poisson_range, applied to dp/max(dp), + then rescaled by max(dp)/lambda. + 2) Gaussian noise with std sigma sampled uniformly from gaussian_range, applied after normalizing to [0,1]. + Finally renormalize by the new min/max and rescale back to the Poisson-perturbed range, clamp to non-negative. 
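+
+    Minimal usage sketch (ranges mirror the trainer configs; illustrative only):
+
+        transform = make_poisson_gaussian_noise_transform((1.0, 100.0), (0.001, 0.1))
+        noisy = transform(x)  # x: counts tensor of any shape; output keeps that shape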
+ """ + lam_lo, lam_hi = float(poisson_range[0]), float(poisson_range[1]) + sig_lo, sig_hi = float(gaussian_range[0]), float(gaussian_range[1]) + + def _transform(x: torch.Tensor) -> torch.Tensor: + # Preserve shape + orig_shape = x.shape + dp = x.reshape(-1) + # If all nonpositive, return as-is (clamped) + dp_max = torch.max(dp) + if dp_max <= 0: + return torch.clamp(x, min=0) + + # Sample parameters + lam = torch.empty(1).uniform_(lam_lo, lam_hi).item() + sigma = torch.empty(1).uniform_(sig_lo, sig_hi).item() + + # Poisson step + rate = torch.clamp(dp, min=0) / dp_max * lam + dp_pois = torch.poisson(rate) + dp_pois = dp_pois * dp_max / lam + + # Gaussian step on normalized intensity + dp_min = torch.min(dp_pois) + dp_max2 = torch.max(dp_pois) + denom = dp_max2 - dp_min + if denom > 0: + norm = (dp_pois - dp_min) / denom + else: + norm = torch.zeros_like(dp_pois) + dp_gauss = norm + torch.randn_like(norm) * sigma + + # Final renormalize and rescale back to original Poisson range + g_min = torch.min(dp_gauss) + g_max = torch.max(dp_gauss) + g_denom = g_max - g_min + if g_denom > 0: + norm2 = (dp_gauss - g_min) / g_denom + else: + norm2 = torch.zeros_like(dp_gauss) + dp_noisy = norm2 * denom + dp_min + + dp_noisy = torch.clamp(dp_noisy, min=0.0) + return dp_noisy.reshape(orig_shape) + + return _transform + def _read_jsonl_manifest(manifest_path: str) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: """ @@ -236,14 +293,15 @@ def _get_exact(container, key: str): # Cast dtype and apply preprocessing/transform to x only if self.dtype is not None and x_tensor is not None: x_tensor = x_tensor.to(self.dtype) - # Ensure non-negative counts + # Ensure non-negative counts (needed before Poisson rate computation) if x_tensor is not None and self.floor_at_zero: x_tensor = torch.clamp(x_tensor, min=0) - # Optional log1p normalization to compress peak variance - if x_tensor is not None and self.normalize_log1p: - x_tensor = torch.log1p(x_tensor) + # Apply noise/augmentation transform following the paper (before any log transforms) if self.transform is not None and x_tensor is not None: x_tensor = self.transform(x_tensor) + # Optional log1p normalization to compress peak variance (applied after noise) + if x_tensor is not None and self.normalize_log1p: + x_tensor = torch.log1p(x_tensor) if not self.return_meta: # Backward-compat: return only x diff --git a/src/trainer/infer_rruff.py b/src/trainer/infer_rruff.py new file mode 100644 index 0000000..398025b --- /dev/null +++ b/src/trainer/infer_rruff.py @@ -0,0 +1,348 @@ +import argparse +import os +from typing import Any, Dict, Optional, Sequence + +import torch +import yaml +import numpy as np +from tqdm import tqdm + +# Optional: use sklearn for confusion matrix if available; otherwise fall back to numpy implementation +try: + from sklearn.metrics import confusion_matrix as sk_confusion_matrix +except Exception: + sk_confusion_matrix = None + +import matplotlib +matplotlib.use("Agg") # non-interactive backend +import matplotlib.pyplot as plt + +# Expect PYTHONPATH=src; run with: python -m trainer.infer_rruff configs/trainer_wav2vec2.yaml --ckpt /path/to.ckpt +from dataset import NpyDataModule +from model.wav2vec2_model import AlphaDiffractWav2Vec2Lightning +from model.model import AlphaDiffractLightning + + +# ----------------------------- +# Config & datamodule builders (adapted from train_paper/debug_sanity) +# ----------------------------- + +def load_config(path: str) -> Dict[str, Any]: + if not os.path.isfile(path): + raise 
FileNotFoundError(f"Config file not found: {path}") + with open(path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + if not isinstance(cfg, dict): + raise ValueError(f"Config must be a mapping (YAML dict), got: {type(cfg)}") + return cfg + + +def _to_dtype(name: str) -> torch.dtype: + table = { + "float32": torch.float32, + "float64": torch.float64, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + if name not in table: + raise ValueError(f"Unsupported dtype '{name}'. Allowed: {list(table.keys())}") + return table[name] + + +def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: + dataset_kwargs = { + "dtype": _to_dtype(cfg["dtype"]), + "mmap_mode": cfg["mmap_mode"], + "return_meta": True, + "validate_paths": cfg["validate_paths"], + "extract_labels": cfg["extract_labels"], + "allow_pickle": cfg["allow_pickle"], + "floor_at_zero": cfg["floor_at_zero"], + "normalize_log1p": cfg["normalize_log1p"], + } + labels_key_map = cfg.get("labels_key_map") + if labels_key_map is not None: + dataset_kwargs["labels_key_map"] = labels_key_map + + dm = NpyDataModule( + manifest_dir=cfg["manifest_dir"], + batch_size=cfg["batch_size"], + num_workers=cfg["num_workers"], + pin_memory=cfg["pin_memory"], + persistent_workers=cfg["persistent_workers"] and cfg["num_workers"] > 0, + dataset_kwargs=dataset_kwargs, + dataset_root=cfg["dataset_root"], + auto_generate_manifests=cfg["auto_generate_manifests"], + train_ratio=cfg["train_ratio"], + val_ratio=cfg["val_ratio"], + test_ratio=cfg["test_ratio"], + seed=cfg["seed"], + extra_val_file=cfg.get("extra_val_file"), # expect "rruff.jsonl" + ) + return dm + + +def build_model_class_from_cfg(cfg: Dict[str, Any]): + model_type = cfg.get("model_type", "convnext").lower() + return AlphaDiffractWav2Vec2Lightning if model_type == "wav2vec2" else AlphaDiffractLightning + + +# ----------------------------- +# Checkpoint resolution helpers +# ----------------------------- + +def find_checkpoint(cfg: Dict[str, Any], candidate_override: Optional[str] = None) -> Optional[str]: + """ + Resolve a checkpoint path. Priority: + 1) Explicit override if provided and exists + 2) Last/best checkpoint under default_root_dir/checkpoints/** + 3) Heuristics: also check under src/trainer/outputs/... 
in case training ran from src/trainer
+    """
+    if candidate_override and os.path.isfile(candidate_override):
+        return candidate_override
+
+    # Helper to scan a base checkpoints dir
+    def _scan_dir(base: str) -> Optional[str]:
+        if not os.path.isdir(base):
+            return None
+        # Prefer epoch-best ckpt if present, else last.ckpt
+        best_ckpts: list[str] = []
+        last_ckpts: list[str] = []
+        for root, _dirs, files in os.walk(base):
+            for fn in files:
+                if fn.endswith(".ckpt"):
+                    full = os.path.join(root, fn)
+                    if "val_loss" in fn:
+                        best_ckpts.append(full)
+                    elif fn == "last.ckpt":
+                        last_ckpts.append(full)
+        # Heuristic: choose first best, else first last
+        if best_ckpts:
+            # sort by epoch number if present (tolerates names like "epoch=12-...")
+            try:
+                best_ckpts.sort(key=lambda p: int("".join(ch for ch in os.path.basename(p).split("epoch")[1].split("-")[0] if ch.isdigit()) or "-1"))
+            except Exception:
+                pass
+            return best_ckpts[0]
+        if last_ckpts:
+            return last_ckpts[0]
+        return None
+
+    # 2) default_root_dir from cfg (relative to CWD)
+    default_root = cfg.get("default_root_dir", "outputs/model")
+    ckpt_dir = os.path.join(default_root, "checkpoints")
+    found = _scan_dir(ckpt_dir)
+    if found:
+        return found
+
+    # 3) try under src/trainer/outputs in case runs were launched from within src/trainer
+    alt_ckpt_dir = os.path.join("src", "trainer", "outputs", os.path.basename(default_root), "checkpoints")
+    found = _scan_dir(alt_ckpt_dir)
+    return found
+
+
+# -----------------------------
+# Evaluation and confusion matrix
+# -----------------------------
+
+def compute_confusion(y_true: np.ndarray, y_pred: np.ndarray, labels: Sequence[int], normalize: bool) -> np.ndarray:
+    if sk_confusion_matrix is not None:
+        norm_opt = "true" if normalize else None
+        return sk_confusion_matrix(y_true, y_pred, labels=list(labels), normalize=norm_opt)
+    # Fallback: manual
+    L = len(labels)
+    label_to_idx = {lab: i for i, lab in enumerate(labels)}
+    cm = np.zeros((L, L), dtype=np.float64)
+    for t, p in zip(y_true, y_pred):
+        cm[label_to_idx[t], label_to_idx[p]] += 1
+    if normalize:
+        row_sums = cm.sum(axis=1, keepdims=True)
+        row_sums[row_sums == 0] = 1.0
+        cm = cm / row_sums
+    return cm
+
+
+def plot_confusion(cm: np.ndarray, labels: Sequence[int], title: str, save_path: str) -> None:
+    size = max(6, min(20, int(len(labels) * 0.4)))  # adapt figure size
+    fig, ax = plt.subplots(figsize=(size, size))
+    im = ax.imshow(cm, cmap="Blues", interpolation="nearest", aspect="auto")
+    ax.set_title(title)
+    ax.set_xlabel("Predicted")
+    ax.set_ylabel("True")
+
+    # Tick labels only if small enough
+    if len(labels) <= 30:
+        ax.set_xticks(range(len(labels)))
+        ax.set_yticks(range(len(labels)))
+        ax.set_xticklabels(labels, rotation=90)
+        ax.set_yticklabels(labels)
+    else:
+        ax.set_xticks([])
+        ax.set_yticks([])
+
+    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+    plt.tight_layout()
+    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
+    fig.savefig(save_path, dpi=200)
+    plt.close(fig)
+
+
+# -----------------------------
+# Main
+# -----------------------------
+
+def _resolve_loader_for_split(dm: NpyDataModule, split: str):
+    split = split.lower()
+    if split == "rruff":
+        # Get RRUFF loader: val_dataloader returns list [val_loader, extra_loader] when extra_val_file is set
+        dm.setup("validate")
+        val_dl = dm.val_dataloader()
+        if isinstance(val_dl, list):
+            if len(val_dl) < 2:
+                raise RuntimeError("extra_val_file set, but val_dataloader did not return the extra loader")
+            return val_dl[1]
+        else:
+            if getattr(dm, "extra_val_ds", None) is None:
+                raise RuntimeError("RRUFF dataset not available.
Ensure 'extra_val_file' is set in the config (e.g., 'rruff.jsonl').") + from torch.utils.data import DataLoader + return DataLoader( + dm.extra_val_ds, + batch_size=dm.batch_size, + num_workers=dm.num_workers, + pin_memory=dm.pin_memory, + persistent_workers=dm.persistent_workers if dm.num_workers > 0 else False, + shuffle=False, + ) + elif split == "val": + dm.setup("validate") + val_dl = dm.val_dataloader() + if isinstance(val_dl, list): + return val_dl[0] + return val_dl + elif split == "test": + dm.setup("test") + return dm.test_dataloader() + elif split == "train": + dm.setup("fit") + return dm.train_dataloader() + else: + raise ValueError("split must be one of: rruff, val, test, train") + + +def run_inference_and_confusion(cfg_path: str, ckpt_path: Optional[str], task: str = "cs", normalize: bool = True, device: Optional[str] = None, out_dir: Optional[str] = None, split: str = "rruff") -> str: + cfg = load_config(cfg_path) + dm = build_datamodule_from_cfg(cfg) + dm.prepare_data() + + # Resolve which loader/split to evaluate + loader = _resolve_loader_for_split(dm, split) + + # Resolve checkpoint + ckpt_resolved = find_checkpoint(cfg, candidate_override=ckpt_path) + if not ckpt_resolved: + raise FileNotFoundError("Could not resolve a checkpoint. Pass --ckpt explicitly or ensure outputs/checkpoints exists.") + + # Build model class and load from checkpoint + ModelCls = build_model_class_from_cfg(cfg) + dev = device or ("cuda" if torch.cuda.is_available() else "cpu") + model = ModelCls.load_from_checkpoint(ckpt_resolved, map_location=dev) + model.eval() + model.to(dev) + + # Determine task specifics + if task not in ("cs", "sg"): + raise ValueError("task must be 'cs' or 'sg'") + num_classes = model.num_cs_classes if task == "cs" else model.num_sg_classes + logits_key = "cs_logits" if task == "cs" else "sg_logits" + label_key = "cs" if task == "cs" else "sg" + + # Collect predictions and targets + y_true_all: list[int] = [] + y_pred_all: list[int] = [] + + with torch.no_grad(): + for batch in tqdm(loader): + if not isinstance(batch, dict): + raise RuntimeError("Expected dict batches from NpyDataModule") + x = batch.get("x") + y = batch.get(label_key) + if x is None or y is None: + # Skip if label missing + continue + x = x.to(dev) + preds = model(x) + logits = preds[logits_key] + y_idx = model._to_index(y.to(dev), num_classes) + y_pred = logits.argmax(dim=1) + y_true_all.extend(y_idx.cpu().numpy().tolist()) + y_pred_all.extend(y_pred.cpu().numpy().tolist()) + + if len(y_true_all) == 0: + raise RuntimeError("No labeled samples found for the selected split/task") + + y_true_arr = np.asarray(y_true_all, dtype=np.int32) + y_pred_arr = np.asarray(y_pred_all, dtype=np.int32) + + # Use only classes present to keep plot size reasonable + present_labels = sorted(set(y_true_arr.tolist())) + cm = compute_confusion(y_true_arr, y_pred_arr, labels=present_labels, normalize=normalize) + + # Accuracy summary + acc = float((y_true_arr == y_pred_arr).mean()) + + # Output paths + out_base = out_dir or os.path.join("outputs", "eval") + os.makedirs(out_base, exist_ok=True) + split_tag = split.lower() + img_path = os.path.join(out_base, f"{split_tag}_confusion_{task}.png") + npy_path = os.path.join(out_base, f"{split_tag}_confusion_{task}.npy") + + # Save plot and raw matrix + plot_title = f"{split_tag.upper()} {task.upper()} confusion (normalize={'row' if normalize else 'none'}) | acc={acc:.3f} | classes={len(present_labels)}" + plot_confusion(cm, labels=present_labels, title=plot_title, save_path=img_path) 
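+    # Save the matrix exactly as computed above (row-normalized when requested) so it can be reused without re-running inference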
+ np.save(npy_path, cm) + + # Also save a text summary + summary_path = os.path.join(out_base, f"{split_tag}_confusion_{task}_summary.txt") + with open(summary_path, "w", encoding="utf-8") as f: + f.write(f"Checkpoint: {ckpt_resolved}\n") + f.write(f"Split: {split_tag}\n") + f.write(f"Task: {task}\n") + f.write(f"Accuracy: {acc:.4f}\n") + f.write(f"Classes present: {len(present_labels)}\n") + f.write(f"Labels: {present_labels}\n") + f.write(f"Image: {img_path}\n") + f.write(f"Matrix: {npy_path}\n") + + return img_path + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Run inference on a chosen split (RRUFF/val/test/train) and plot confusion matrix") + p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer_wav2vec2.yaml)") + p.add_argument("--ckpt", type=str, default=None, help="Path to Lightning checkpoint (.ckpt). If omitted, attempts auto-discovery.") + p.add_argument("--task", type=str, default="cs", choices=["cs", "sg"], help="Which head to evaluate: cs (7 classes) or sg (230 classes)") + p.add_argument("--split", type=str, default="rruff", choices=["rruff", "val", "test", "train"], help="Dataset split to evaluate") + p.add_argument("--no-normalize", action="store_true", help="Disable row-normalization in confusion matrix") + p.add_argument("--device", type=str, default=None, help="Override device (cuda/cpu)") + p.add_argument("--out-dir", type=str, default=None, help="Directory to save outputs (default: outputs/eval)") + return p.parse_args() + + +def main() -> None: + args = parse_args() + normalize = not args.no_normalize + img_path = run_inference_and_confusion( + cfg_path=args.config, + ckpt_path=args.ckpt, + task=args.task, + normalize=normalize, + device=args.device, + out_dir=args.out_dir, + split=args.split, + ) + print(f"Saved confusion matrix to: {img_path}") + + +if __name__ == "__main__": + main() diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index cfa1e14..d441f19 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -77,6 +77,9 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: test_ratio=cfg["test_ratio"], seed=cfg["seed"], extra_val_file=cfg.get("extra_val_file"), + # Optional noise augmentation: apply to training split only + noise_poisson_range=tuple(cfg.get("noise_poisson_range")) if cfg.get("noise_poisson_range") is not None else None, + noise_gaussian_range=tuple(cfg.get("noise_gaussian_range")) if cfg.get("noise_gaussian_range") is not None else None, ) return dm From 021b824398af950a7f9d405e84731548c54321d7 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Mon, 10 Nov 2025 12:02:00 -0600 Subject: [PATCH 11/18] refactor: Refactor index shift into dataloader --- src/trainer/dataset/datamodule.py | 37 +++++++++++++++++++++++++++++ src/trainer/model/model.py | 13 +++++----- src/trainer/model/wav2vec2_model.py | 11 ++------- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index 13ab178..42336fe 100644 --- a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -44,6 +44,7 @@ import torch from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate try: # Prefer importing Lightning; if not installed, this file will error at import time. 
 import pytorch_lightning as pl
@@ -57,6 +58,34 @@
 from .manifest_utils import generate_manifests
 
 
+def _shift_one_based_collate(batch):
+    """
+    Collate function that applies PyTorch's default_collate, then unconditionally
+    shifts cs and sg labels by -1 (assumes 1-based input). Runs under torch.no_grad
+    so that no autograd graph is built.
+    """
+    collated = default_collate(batch)
+    with torch.no_grad():
+        def _shift(t):
+            return t - 1 if torch.is_tensor(t) else t
+
+        if isinstance(collated, dict):
+            if "cs" in collated:
+                collated["cs"] = _shift(collated["cs"])
+            if "sg" in collated:
+                collated["sg"] = _shift(collated["sg"])
+        elif isinstance(collated, (list, tuple)):
+            # Tuple-based batches: (x, cs, sg, [lp])
+            lst = list(collated)
+            if len(lst) >= 2:
+                lst[1] = _shift(lst[1])
+            if len(lst) >= 3:
+                lst[2] = _shift(lst[2])
+            collated = type(collated)(lst)
+    return collated
+
+
 class NpyDataModule(pl.LightningDataModule):
     """
     LightningDataModule that reads train/val/test JSONL manifests and constructs DataLoaders.
@@ -84,6 +113,7 @@ def __init__(
         num_workers: int = 4,
         pin_memory: bool = True,
         persistent_workers: bool = True,
+        prefetch_factor: Optional[int] = None,
         collate_fn: Optional[Callable] = None,
         dataset_cls: type = NpyManifestDataset,
         dataset_kwargs: Optional[Dict[str, Any]] = None,
@@ -110,7 +140,10 @@ def __init__(
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.persistent_workers = persistent_workers
+        self.prefetch_factor = prefetch_factor
         self.collate_fn = collate_fn
+        if self.collate_fn is None:
+            self.collate_fn = _shift_one_based_collate
         self.dataset_cls = dataset_cls
         self.dataset_kwargs = dataset_kwargs or {}
@@ -215,6 +248,7 @@ def train_dataloader(self) -> DataLoader:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
+            # prefetch_factor must stay None when num_workers == 0, or DataLoader raises
+            prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4) if self.num_workers > 0 else None,
             persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
             shuffle=True,
             collate_fn=self.collate_fn,
@@ -226,6 +260,7 @@ def val_dataloader(self) -> DataLoader:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
+            prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4) if self.num_workers > 0 else None,
             persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
             shuffle=False,
             collate_fn=self.collate_fn,
@@ -236,6 +271,7 @@
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
+            prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4) if self.num_workers > 0 else None,
             persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
             shuffle=False,
             collate_fn=self.collate_fn,
@@ -249,6 +285,7 @@ def test_dataloader(self) -> DataLoader:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
+            prefetch_factor=(self.prefetch_factor if self.prefetch_factor is not None else 4) if self.num_workers > 0 else None,
             persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
             shuffle=False,
             collate_fn=self.collate_fn,
diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py
index 6a0c2c7..2283e42 100644
--- a/src/trainer/model/model.py
+++ b/src/trainer/model/model.py
@@ -383,18 +383,19 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
     # -----------------------------
     @staticmethod
     def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor:
-        # Convert targets to class indices. Supports one-hot and integer labels.
+ """ + Convert labels to 0-based class indices: + - Supports one-hot and integer labels. + - Clamp to [0, num_classes-1] to avoid out-of-range targets. + Assumes labels are already 0-based. + """ if y.dim() > 1 and y.size(-1) > 1: idx = y.argmax(dim=-1) else: idx = y.long() - # Normalize 1-based labels to 0-based if detected (min>=1 and max==num_classes) with torch.no_grad(): if idx.numel() > 0: - minv = int(idx.min().item()) - maxv = int(idx.max().item()) - if minv >= 1 and maxv == num_classes: - idx = idx - 1 + idx = idx.clamp(min=0, max=num_classes - 1) return idx def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: diff --git a/src/trainer/model/wav2vec2_model.py b/src/trainer/model/wav2vec2_model.py index 9df21ef..ecf6e61 100644 --- a/src/trainer/model/wav2vec2_model.py +++ b/src/trainer/model/wav2vec2_model.py @@ -367,24 +367,17 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: @staticmethod def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: """ - Robustly convert labels to 0-based class indices: + Convert labels to 0-based class indices: - Supports one-hot and integer labels. - - If integer labels appear 1-based (min>=1 and max<=num_classes), shift to 0-based. - Clamp to [0, num_classes-1] to avoid out-of-range targets. + Assumes labels are already 0-based. """ if y.dim() > 1 and y.size(-1) > 1: idx = y.argmax(dim=-1) else: idx = y.long() - with torch.no_grad(): if idx.numel() > 0: - minv = int(idx.min().item()) - maxv = int(idx.max().item()) - # Shift 1-based labels to 0-based if detected - if minv >= 1 and maxv <= num_classes: - idx = idx - 1 - # Ensure valid range idx = idx.clamp(min=0, max=num_classes - 1) return idx From e071b29fdd860c75365ee2407fb867c1f979e7de Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Tue, 11 Nov 2025 21:40:00 -0600 Subject: [PATCH 12/18] fix: Fixes to align with paper --- configs/trainer_convnext_paper.yaml | 117 ++++++++++++++++++++++++++++ src/trainer/dataset/datamodule.py | 64 +++++++++++++-- src/trainer/model/model.py | 18 ++--- src/trainer/train_paper.py | 3 +- 4 files changed, 181 insertions(+), 21 deletions(-) create mode 100644 configs/trainer_convnext_paper.yaml diff --git a/configs/trainer_convnext_paper.yaml b/configs/trainer_convnext_paper.yaml new file mode 100644 index 0000000..f6fa3f4 --- /dev/null +++ b/configs/trainer_convnext_paper.yaml @@ -0,0 +1,117 @@ +# AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant) +# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml + +# --- Data / Manifests --- +manifest_dir: "../../data/manifests" +dataset_root: "../../data/dataset" +extra_val_file: "rruff.jsonl" +auto_generate_manifests: true +train_ratio: 0.8 +val_ratio: 0.1 +test_ratio: 0.1 +seed: 42 + +# --- DataLoader --- +batch_size: 256 # paper used 64 +num_workers: 8 +pin_memory: true +persistent_workers: true + +# --- Dataset label extraction (embedded in .npy/.npz) --- +validate_paths: false +extract_labels: true +allow_pickle: true +labels_key_map: + x: "dp" + cs: "cs" + sg: "sg" + lattice_params: null + lp_a: "_cell_length_a" + lp_b: "_cell_length_b" + lp_c: "_cell_length_c" + lp_alpha: "_cell_angle_alpha" + lp_beta: "_cell_angle_beta" + lp_gamma: "_cell_angle_gamma" +dtype: "float32" +mmap_mode: null +floor_at_zero: true +normalize_log1p: False # paper used log1p preprocessing + +# --- ConvNeXt (lightweight paper variant) --- +# 3 stages; one block per stage; large kernels; 
stride-5 downsampling +# NOTE: This implementation adapts ConvNeXt to 1D and uses global avg pooling. +depths: [1, 1, 1] +dims: [80, 80, 80] +kernel_sizes: [100, 50, 25] +strides: [5, 5, 5] +dropout_rate: 0.3 +layer_scale_init_value: 1.0e-6 +# Stochastic depth schedule across blocks; paper mentions 0.3 on stem branch +# We set an overall schedule up to 0.3 +drop_path_rate: 0.3 + +# Heads +head_dropout: 0.2 +cs_hidden: [2300, 1150] +sg_hidden: [2300, 1150] +lp_hidden: [512, 256] + +# Task sizes +num_cs_classes: 7 +num_sg_classes: 230 +num_lp_outputs: 6 + +# LP output bounds +lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +bound_lp_with_sigmoid: true + +# Loss weights +lambda_cs: 1.0 +lambda_sg: 0.0 +lambda_lp: 1.0 + +# Optional GEMD term on SG +gemd_mu: 0.0 +gemd_distance_matrix_path: null + +# Optimizer (paper): AdamW, lr=2e-4, wd=0.01 +lr: 0.0002 +weight_decay: 0.01 +use_adamw: true +gradient_clip_val: 1.0 +gradient_clip_algorithm: "norm" + +# --- Noise augmentation (training split only; matches paper) --- +# If provided, noise is applied dynamically per-sample in the DataModule using the same +# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. +# Set ranges to None to disable. +noise_poisson_range: [1.0, 100.0] +noise_gaussian_range: [0.001, 0.1] +# --- Logging --- +logger: "mlflow" +csv_logger_name: "model_logs_convnext_paper" +mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt" +mlflow_tracking_uri: null +mlflow_run_name: "ConvNeXt_Paper_Run" + +# --- Trainer settings --- +default_root_dir: "outputs/convnext_paper" +max_epochs: 50 +accumulate_grad_batches: 1 +precision: "bf16-mixed" # switch to "16-mixed" or "32" if needed +accelerator: "gpu" +devices: 1 +log_every_n_steps: 50 +deterministic: false +benchmark: true + +# --- Checkpointing --- +monitor: "val/loss" +mode: "min" +save_top_k: 1 +every_n_epochs: 1 + +# --- Evaluation --- +resume_from: +test_after_train: true diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index 42336fe..6b01dbc 100644 --- a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -58,6 +58,28 @@ from .manifest_utils import generate_manifests +def make_standardize_transform(std_range): + """ + Build a transform that min-max scales a tensor to the given [min, max] range per-sample. 
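+
+    A minimal sketch of the intended mapping (illustrative values, not a doctest):
+        t = make_standardize_transform((0.0, 100.0))
+        t(torch.tensor([1.0, 2.0, 3.0]))  # -> tensor([0., 50., 100.])
+    Constant inputs (max == min) map to the lower bound of the target range
+    rather than dividing by zero.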
+ """ + min_t, max_t = float(std_range[0]), float(std_range[1]) + + def _transform(x: torch.Tensor) -> torch.Tensor: + orig_shape = x.shape + dp = x.reshape(-1) + dp_min = torch.min(dp) + dp_max = torch.max(dp) + denom = dp_max - dp_min + if denom > 0: + norm = (dp - dp_min) / denom + else: + norm = torch.zeros_like(dp) + scaled = norm * (max_t - min_t) + min_t + return scaled.reshape(orig_shape) + + return _transform + + def _shift_one_based_collate(batch): """ @@ -133,6 +155,8 @@ def __init__( # Optional noise augmentation for training split only noise_poisson_range: Optional[Tuple[float, float]] = None, noise_gaussian_range: Optional[Tuple[float, float]] = None, + # Optional post-noise standardization to a fixed range (per-sample) + standardize_to: Optional[Tuple[float, float]] = None, ) -> None: super().__init__() self.manifest_dir = manifest_dir @@ -163,6 +187,7 @@ def __init__( # Optional noise augmentation params self.noise_poisson_range = noise_poisson_range self.noise_gaussian_range = noise_gaussian_range + self.standardize_to = standardize_to # Internal datasets self.train_ds = None @@ -224,23 +249,48 @@ def setup(self, stage: Optional[str] = None) -> None: """ paths = self._manifest_paths() if self.train_ds is None: - # Apply noise transform to training split only if ranges are provided + # Compose transforms for training split: noise first, then optional standardization train_kwargs = dict(self.dataset_kwargs) + transforms = [] if self.noise_poisson_range is not None and self.noise_gaussian_range is not None: - train_kwargs = dict(train_kwargs) # shallow copy - train_kwargs["transform"] = make_poisson_gaussian_noise_transform( - self.noise_poisson_range, self.noise_gaussian_range + transforms.append( + make_poisson_gaussian_noise_transform(self.noise_poisson_range, self.noise_gaussian_range) ) + if self.standardize_to is not None: + transforms.append(make_standardize_transform(self.standardize_to)) + + if len(transforms) > 0: + train_kwargs = dict(train_kwargs) # shallow copy + def _compose(ts): + def _f(x: torch.Tensor) -> torch.Tensor: + for t in ts: + x = t(x) + return x + return _f + train_kwargs["transform"] = _compose(transforms) + self.train_ds = self.dataset_cls(paths["train"], **train_kwargs) if self.val_ds is None: - self.val_ds = self.dataset_cls(paths["val"], **self.dataset_kwargs) + val_kwargs = dict(self.dataset_kwargs) + if self.standardize_to is not None: + val_kwargs = dict(val_kwargs) + val_kwargs["transform"] = make_standardize_transform(self.standardize_to) + self.val_ds = self.dataset_cls(paths["val"], **val_kwargs) # Extra validation dataset (e.g., rruff) if self.extra_val_ds is None and self.extra_val_file is not None: manifest_dir_abs = os.path.dirname(paths["val"]) extra_path = os.path.join(manifest_dir_abs, self.extra_val_file) - self.extra_val_ds = self.dataset_cls(extra_path, **self.dataset_kwargs) + extra_kwargs = dict(self.dataset_kwargs) + if self.standardize_to is not None: + extra_kwargs = dict(extra_kwargs) + extra_kwargs["transform"] = make_standardize_transform(self.standardize_to) + self.extra_val_ds = self.dataset_cls(extra_path, **extra_kwargs) if self.test_ds is None: - self.test_ds = self.dataset_cls(paths["test"], **self.dataset_kwargs) + test_kwargs = dict(self.dataset_kwargs) + if self.standardize_to is not None: + test_kwargs = dict(test_kwargs) + test_kwargs["transform"] = make_standardize_transform(self.standardize_to) + self.test_ds = self.dataset_cls(paths["test"], **test_kwargs) def train_dataloader(self) -> DataLoader: return 
DataLoader( diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index 2283e42..6ce7d0f 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -47,11 +47,9 @@ def __init__( super().__init__() # depthwise 1D conv self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding='same', groups=dim) - # LayerNorm in channels-last layout - self.norm = nn.LayerNorm(dim, eps=1e-6) # pointwise MLP implemented by Linear on channels-last self.pwconv1 = nn.Linear(dim, 4 * dim) - self.act = nn.GELU() + self.act = nn.LeakyReLU() self.pwconv2 = nn.Linear(4 * dim, dim) # layer-scale self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None @@ -63,7 +61,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.dwconv(x) # (N, C, L) x = x.permute(0, 2, 1) # (N, L, C) - x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) @@ -130,7 +127,7 @@ def __init__( self.dim_output = dims[-1] # Stem: patchify conv + LayerNorm (channels-last) - self.stem_conv = nn.Conv1d(in_chans, dims[0], kernel_size=strides[0], stride=strides[0]) + self.stem_conv = nn.Conv1d(in_chans, dims[0], kernel_size=1, stride=1) self.stem_norm = nn.LayerNorm(dims[0], eps=1e-6) # Stochastic depth schedule across all blocks @@ -157,12 +154,10 @@ def __init__( dp_idx += 1 self.blocks.append(nn.Sequential(*stage_blocks)) - # Downsample transitions between stages + # Downsample transitions between stages (average pooling to match OG) self.downsamples = nn.ModuleList() for i in range(n_stages - 1): - self.downsamples.append( - Downsample1D(in_dim=dims[i], out_dim=dims[i + 1], stride=strides[i + 1]) - ) + self.downsamples.append(nn.AvgPool1d(kernel_size=strides[i + 1], stride=strides[i + 1])) # Optional dropout after each downsample transition self.down_drops = nn.ModuleList([ @@ -176,10 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x[:, None, :] # Stem - x = self.stem_conv(x) # (N, C0, L0) - x = x.permute(0, 2, 1) # (N, L0, C0) - x = self.stem_norm(x) - x = x.permute(0, 2, 1) # (N, C0, L0) + x = self.stem_conv(x) # (N, C0, L) # Stages + downsample transitions for i, stage in enumerate(self.blocks): diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index d441f19..11fd2f4 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -80,6 +80,8 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: # Optional noise augmentation: apply to training split only noise_poisson_range=tuple(cfg.get("noise_poisson_range")) if cfg.get("noise_poisson_range") is not None else None, noise_gaussian_range=tuple(cfg.get("noise_gaussian_range")) if cfg.get("noise_gaussian_range") is not None else None, + # Optional standardization after noise to match OG runs + standardize_to=tuple(cfg.get("standardize_to")) if cfg.get("standardize_to") is not None else None, ) return dm @@ -249,7 +251,6 @@ def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = experiment_name=cfg["mlflow_experiment_name"], tracking_uri=cfg["mlflow_tracking_uri"], run_name=cfg["mlflow_run_name"], - log_model=True, ) trainer = Trainer( From b357670eedf6c232fdc99958d494e7fffe880b84 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Wed, 12 Nov 2025 15:23:38 -0600 Subject: [PATCH 13/18] fix: Refactor model implementation to closer fit paper --- configs/trainer_convnext_paper.yaml | 30 +- configs/trainer_wav2vec2_nonorm.yaml | 19 +- src/trainer/infer_rruff.py | 9 +- 
src/trainer/model/model.py | 438 ++++++++++++--------------- src/trainer/train_paper.py | 28 +- 5 files changed, 241 insertions(+), 283 deletions(-) diff --git a/configs/trainer_convnext_paper.yaml b/configs/trainer_convnext_paper.yaml index f6fa3f4..8526647 100644 --- a/configs/trainer_convnext_paper.yaml +++ b/configs/trainer_convnext_paper.yaml @@ -2,8 +2,8 @@ # Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml # --- Data / Manifests --- -manifest_dir: "../../data/manifests" -dataset_root: "../../data/dataset" +manifest_dir: "../../../ad_data/manifests" +dataset_root: "../../../ad_data/data/dataset" extra_val_file: "rruff.jsonl" auto_generate_manifests: true train_ratio: 0.8 @@ -12,7 +12,7 @@ test_ratio: 0.1 seed: 42 # --- DataLoader --- -batch_size: 256 # paper used 64 +batch_size: 64 # match OG run (64 per process) num_workers: 8 pin_memory: true persistent_workers: true @@ -36,6 +36,7 @@ dtype: "float32" mmap_mode: null floor_at_zero: true normalize_log1p: False # paper used log1p preprocessing +model_type: "multiscale" # --- ConvNeXt (lightweight paper variant) --- # 3 stages; one block per stage; large kernels; stride-5 downsampling @@ -48,10 +49,16 @@ dropout_rate: 0.3 layer_scale_init_value: 1.0e-6 # Stochastic depth schedule across blocks; paper mentions 0.3 on stem branch # We set an overall schedule up to 0.3 -drop_path_rate: 0.3 +drop_path_rate: 0.0 +ramped_dropout_rate: false +block_type: "convnext" +pooling_type: "average" +final_pool: true +use_batchnorm: false +output_type: "flatten" # Heads -head_dropout: 0.2 +head_dropout: 0.5 cs_hidden: [2300, 1150] sg_hidden: [2300, 1150] lp_hidden: [512, 256] @@ -63,12 +70,12 @@ num_lp_outputs: 6 # LP output bounds lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0] bound_lp_with_sigmoid: true # Loss weights lambda_cs: 1.0 -lambda_sg: 0.0 +lambda_sg: 1.0 lambda_lp: 1.0 # Optional GEMD term on SG @@ -88,6 +95,9 @@ gradient_clip_algorithm: "norm" # Set ranges to None to disable. 
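# In YAML the disabled value is `null` (parsed to Python's None). A minimal sketch of a
# run with augmentation off; the DataModule applies noise only when BOTH ranges are set:
#   noise_poisson_range: null
#   noise_gaussian_range: null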
noise_poisson_range: [1.0, 100.0] noise_gaussian_range: [0.001, 0.1] + +# Standardize after noise to match OG CLI (--standardize-to 0 100) +standardize_to: [0.0, 100.0] # --- Logging --- logger: "mlflow" csv_logger_name: "model_logs_convnext_paper" @@ -97,12 +107,12 @@ mlflow_run_name: "ConvNeXt_Paper_Run" # --- Trainer settings --- default_root_dir: "outputs/convnext_paper" -max_epochs: 50 +max_epochs: 100 accumulate_grad_batches: 1 -precision: "bf16-mixed" # switch to "16-mixed" or "32" if needed +precision: "32" # match OG (AMP disabled) accelerator: "gpu" devices: 1 -log_every_n_steps: 50 +log_every_n_steps: 200 deterministic: false benchmark: true diff --git a/configs/trainer_wav2vec2_nonorm.yaml b/configs/trainer_wav2vec2_nonorm.yaml index 4edfefd..2e778e6 100644 --- a/configs/trainer_wav2vec2_nonorm.yaml +++ b/configs/trainer_wav2vec2_nonorm.yaml @@ -3,9 +3,9 @@ # Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_wav2vec2_nonorm.yaml # --- Data / Manifests --- -manifest_dir: "../../data/manifests" -extra_val_file: "rruff_sim.jsonl" -dataset_root: "../../data/dataset" +manifest_dir: "../../../ad_data/manifest_original" +dataset_root: "../../../ad_data/data/dataset" +extra_val_file: "rruff.jsonl" auto_generate_manifests: true train_ratio: 0.8 val_ratio: 0.1 @@ -13,7 +13,7 @@ test_ratio: 0.1 seed: 42 # --- DataLoader --- -batch_size: 200 +batch_size: 180 num_workers: 8 pin_memory: true persistent_workers: true @@ -44,6 +44,7 @@ normalize_log1p: false # DISABLED for this run # Set ranges to None to disable. noise_poisson_range: [1.0, 100.0] noise_gaussian_range: [0.001, 0.1] +standardize_to: [0.0, 100.0] # --- Model selection --- model_type: "wav2vec2" # choose between "convnext" and "wav2vec2" @@ -57,15 +58,15 @@ ff_dim: 2048 # Convolutional feature extractor — total stride = 16 -> ~512 tokens for length 8192 conv_kernel_sizes: [10, 5, 3, 3, 3, 2] conv_strides: [2, 2, 2, 2, 1, 1] -conv_dropout: 0.0 +conv_dropout: 0.05 pos_kernel_size: 129 -pos_dropout: 0.1 -encoder_dropout: 0.1 +pos_dropout: 0.2 +encoder_dropout: 0.2 layer_norm_first: true token_pool: "cls" # "mean" or "cls" # Heads (reuse paper defaults) -head_dropout: 0.2 +head_dropout: 0.4 cs_hidden: [2300, 1150] sg_hidden: [2300, 1150] lp_hidden: [512, 256] @@ -91,7 +92,7 @@ gemd_distance_matrix_path: # Optimizer lr: 0.0004 -weight_decay: 0.01 +weight_decay: 0.02 use_adamw: true gradient_clip_val: 1.0 gradient_clip_algorithm: "norm" diff --git a/src/trainer/infer_rruff.py b/src/trainer/infer_rruff.py index 398025b..3e43a9e 100644 --- a/src/trainer/infer_rruff.py +++ b/src/trainer/infer_rruff.py @@ -20,7 +20,7 @@ # Expect PYTHONPATH=src; run with: python -m trainer.infer_rruff configs/trainer_wav2vec2.yaml --ckpt /path/to.ckpt from dataset import NpyDataModule from model.wav2vec2_model import AlphaDiffractWav2Vec2Lightning -from model.model import AlphaDiffractLightning +from model.model import AlphaDiffractMultiscaleLightning # ----------------------------- @@ -84,7 +84,12 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: def build_model_class_from_cfg(cfg: Dict[str, Any]): model_type = cfg.get("model_type", "convnext").lower() - return AlphaDiffractWav2Vec2Lightning if model_type == "wav2vec2" else AlphaDiffractLightning + if model_type == "wav2vec2": + return AlphaDiffractWav2Vec2Lightning + elif model_type == "multiscale": + return AlphaDiffractMultiscaleLightning + else: + raise ValueError(f"Unsupported model_type '{model_type}'. 
Expected 'wav2vec2' or 'multiscale'.") # ----------------------------- diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index 6ce7d0f..8a3fc62 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -74,115 +74,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ----------------------------- # Downsample: LN (channels-last) -> Conv1d (stride, channel increase) # ----------------------------- -class Downsample1D(nn.Module): - def __init__(self, in_dim: int, out_dim: int, stride: int): - super().__init__() - self.norm = nn.LayerNorm(in_dim, eps=1e-6) - # Use kernel_size = stride to mimic patch/downsample conv - self.conv = nn.Conv1d(in_dim, out_dim, kernel_size=stride, stride=stride) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x: (N, C, L) - x = x.permute(0, 2, 1) # (N, L, C) - x = self.norm(x) - x = x.permute(0, 2, 1) # (N, C, L) - x = self.conv(x) # (N, C_out, L/stride) - return x # ----------------------------- # Backbone: ConvNeXt1D (generalized to N stages) # ----------------------------- -class ConvNeXt1DBackbone(nn.Module): - """ - ConvNeXt backbone adapted for 1D XRD signals, generalized to N stages: - - Stem: patchify Conv1d with stride=strides[0], out=dims[0] + LayerNorm - - Each stage: [depths[i] ConvNeXt blocks] using DWConv k=kernel_sizes[i], GELU, layer-scale - - Between stages: Downsample1D with stride=strides[i+1], increasing channels dims[i] -> dims[i+1] - - Final: global average pooling over length to produce dims[-1]-dim feature vector - - Notes: - - Stochastic depth (drop path) is linearly scheduled across all blocks from 0 to drop_path_rate. - - kernel_sizes can be domain-specific; ConvNeXt canonical uses k=7. - - dims should typically increase across stages (e.g., [80, 160, 320, 640]). - """ - - def __init__( - self, - in_chans: int = 1, - depths: Tuple[int, ...] = (3, 3, 9, 3), - dims: Tuple[int, ...] = (80, 160, 320, 640), - kernel_sizes: Tuple[int, ...] = (7, 7, 7, 7), - strides: Tuple[int, ...] 
= (4, 2, 2, 2), - dropout_rate: float = 0.3, - layer_scale_init_value: float = 1e-6, - drop_path_rate: float = 0.1, - ): - super().__init__() - - n_stages = len(depths) - assert len(dims) == n_stages and len(kernel_sizes) == n_stages and len(strides) == n_stages, \ - "depths, dims, kernel_sizes, strides must have same length" - - self.dim_output = dims[-1] - - # Stem: patchify conv + LayerNorm (channels-last) - self.stem_conv = nn.Conv1d(in_chans, dims[0], kernel_size=1, stride=1) - self.stem_norm = nn.LayerNorm(dims[0], eps=1e-6) - - # Stochastic depth schedule across all blocks - total_blocks = int(sum(depths)) - if total_blocks > 0 and drop_path_rate > 0.0: - dp_rates = np.linspace(0.0, drop_path_rate, total_blocks).tolist() - else: - dp_rates = [0.0] * max(total_blocks, 1) - - # Build stages - self.blocks = nn.ModuleList() - dp_idx = 0 - for i in range(n_stages): - stage_blocks: List[nn.Module] = [] - for _ in range(depths[i]): - stage_blocks.append( - ConvNeXtBlock1D( - dim=dims[i], - kernel_size=kernel_sizes[i], - drop_path=dp_rates[dp_idx] if dp_idx < len(dp_rates) else 0.0, - layer_scale_init_value=layer_scale_init_value, - ) - ) - dp_idx += 1 - self.blocks.append(nn.Sequential(*stage_blocks)) - - # Downsample transitions between stages (average pooling to match OG) - self.downsamples = nn.ModuleList() - for i in range(n_stages - 1): - self.downsamples.append(nn.AvgPool1d(kernel_size=strides[i + 1], stride=strides[i + 1])) - - # Optional dropout after each downsample transition - self.down_drops = nn.ModuleList([ - nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else nn.Identity() - for _ in range(n_stages - 1) - ]) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Accept (N, L) or (N, 1, L) - if x.ndim == 2: - x = x[:, None, :] - - # Stem - x = self.stem_conv(x) # (N, C0, L) - - # Stages + downsample transitions - for i, stage in enumerate(self.blocks): - x = stage(x) # (N, Ci, Li) - if i < len(self.downsamples): - x = self.downsamples[i](x) # (N, Ci+1, Li+1) - x = self.down_drops[i](x) - - # Global average pooling over length - x = x.mean(dim=-1) # (N, dims[-1]) - return x # ----------------------------- @@ -199,7 +95,7 @@ def make_mlp( last = input_dim if hidden_dims is not None and len(hidden_dims) > 0: for hd in hidden_dims: - layers.extend([nn.Linear(last, hd), nn.LeakyReLU()]) + layers.extend([nn.Linear(last, hd), nn.ReLU()]) if dropout and dropout > 0: layers.append(nn.Dropout(dropout)) last = hd @@ -218,28 +114,149 @@ def make_mlp( ] -class AlphaDiffractLightning(pl.LightningModule): - """ - PyTorch Lightning module for the AlphaDiffract model: - - CS classifier head (7 classes) - - SG classifier head (230 classes) - - LP regressor head (6 outputs, bounded via sigmoid to [min, max]) - Expected batch formats: - - Tuple: (x, y_cs, y_sg, y_lp) - - Dict keys: x or xrd or signal; cs, sg, lattice_params (or lp) - """ +# ----------------------------- +# OG-style Multiscale CNN Backbone (1D) with ConvNeXt-like blocks +# Mirrors alphadiffract.model.MultiscaleCNNBackbone behavior: +# - sequential conv stages with specified kernel_sizes and strides +# - optional average/max pooling between stages and at the end +# - output_type: 'gap' or 'flatten' +# ----------------------------- +class MultiscaleCNNBackbone1D(nn.Module): def __init__( self, - depths: Tuple[int, ...] = (3, 3, 9, 3), - dims: Tuple[int, ...] = (80, 160, 320, 640), - kernel_sizes: Tuple[int, ...] = (7, 7, 7, 7), - strides: Tuple[int, ...] = (4, 2, 2, 2), + dim_in: int = 8192, + channels: Tuple[int, ...] 
= (80, 80, 80), + kernel_sizes: Tuple[int, ...] = (100, 50, 25), + strides: Tuple[int, ...] = (5, 5, 5), dropout_rate: float = 0.3, + ramped_dropout_rate: bool = False, + block_type: str = "convnext", + pooling_type: str = "average", + final_pool: bool = True, + use_batchnorm: bool = False, + activation: nn.Module = nn.LeakyReLU, + output_type: str = "flatten", layer_scale_init_value: float = 1e-6, - drop_path_rate: float = 0.1, + drop_path_rate: float = 0.0, + ): + super().__init__() + assert len(channels) == len(kernel_sizes) == len(strides), "channels, kernel_sizes, strides must match lengths" + self.dim_in = dim_in + self.output_type = output_type + + # Build per-stage dropout schedule + if ramped_dropout_rate: + dropout_per_stage = torch.linspace(0.0, dropout_rate, steps=len(channels)).tolist() + else: + dropout_per_stage = [dropout_rate] * len(channels) + + # Stochastic depth schedule for ConvNeXt-like blocks + total_blocks = len(channels) + if drop_path_rate > 0.0 and total_blocks > 0: + dp_rates = np.linspace(0.0, drop_path_rate, total_blocks).tolist() + else: + dp_rates = [0.0] * max(total_blocks, 1) + # Select pooling module + if pooling_type == "average": + pool_cls = nn.AvgPool1d + pool_kwargs = {"kernel_size": 3, "stride": 2} + elif pooling_type == "max": + pool_cls = nn.MaxPool1d + pool_kwargs = {"kernel_size": 2, "stride": 2} + else: + raise ValueError(f"Invalid pooling_type '{pooling_type}'") + + layers: List[nn.Module] = [] + in_ch = 1 + for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)): + stage_layers: List[nn.Module] = [] + + # Stage conv (stride-based downsampling) + stage_layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=s, padding=0, bias=not use_batchnorm)) + if use_batchnorm: + stage_layers.append(nn.BatchNorm1d(out_ch)) + # Activation and Dropout + act = activation() if isinstance(activation, type) else activation + stage_layers.append(act) + if dropout_per_stage[i] > 0.0: + stage_layers.append(nn.Dropout(p=float(dropout_per_stage[i]))) + + # Optional ConvNeXt-like refinement block operating at out_ch + if block_type == "convnext": + stage_layers.append( + ConvNeXtBlock1D( + dim=out_ch, + kernel_size=k, + drop_path=dp_rates[i] if i < len(dp_rates) else 0.0, + layer_scale_init_value=layer_scale_init_value, + ) + ) + elif block_type in ("single_conv", "double_conv"): + # Already performed the primary conv; keep as-is + pass + else: + raise ValueError(f"Invalid block_type '{block_type}'") + + layers.append(nn.Sequential(*stage_layers)) + + # Inter-stage pooling + if i < len(channels) - 1 or final_pool: + layers.append(pool_cls(**pool_kwargs)) + + in_ch = out_ch + + self.net = nn.Sequential(*layers) + + # Determine output feature dimension + if self.output_type == "gap": + self.dim_output = channels[-1] + elif self.output_type == "flatten": + with torch.no_grad(): + dummy = torch.zeros(1, 1, self.dim_in) + out = self.net(dummy) + self.dim_output = int(out.shape[1] * out.shape[2]) + else: + raise ValueError(f"Invalid output_type '{self.output_type}'") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Accept (N, L) or (N, 1, L) + if x.ndim == 2: + x = x[:, None, :] + x = self.net(x) + if self.output_type == "gap": + x = x.mean(dim=-1) + else: # flatten + x = x.reshape(x.shape[0], -1) + return x + + +# ----------------------------- +# Lightning Module using MultiscaleCNNBackbone1D +# Same heads, losses, and metrics behavior as AlphaDiffractLightning +# ----------------------------- +class 
AlphaDiffractMultiscaleLightning(pl.LightningModule): + def __init__( + self, + # Backbone params (OG-style) + dim_in: int = 8192, + channels: Tuple[int, ...] = (80, 80, 80), + kernel_sizes: Tuple[int, ...] = (100, 50, 25), + strides: Tuple[int, ...] = (5, 5, 5), + dropout_rate: float = 0.3, + ramped_dropout_rate: bool = False, + block_type: str = "convnext", + pooling_type: str = "average", + final_pool: bool = True, + use_batchnorm: bool = False, + activation: nn.Module = nn.LeakyReLU, + output_type: str = "flatten", + layer_scale_init_value: float = 1e-6, + drop_path_rate: float = 0.0, + + # Heads head_dropout: float = 0.5, cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), sg_hidden: Optional[Tuple[int, ...]] = (2300, 1150), @@ -259,9 +276,6 @@ def __init__( lambda_cs: float = 1.0, lambda_sg: float = 1.0, lambda_lp: float = 1.0, - # Optional GEMD for SG - gemd_mu: float = 0.0, - gemd_distance_matrix_path: Optional[str] = None, # Optimizer lr: float = 2e-4, @@ -272,19 +286,25 @@ def __init__( self.save_hyperparameters() # Backbone - self.backbone = ConvNeXt1DBackbone( - in_chans=1, - depths=depths, - dims=dims, + self.backbone = MultiscaleCNNBackbone1D( + dim_in=dim_in, + channels=channels, kernel_sizes=kernel_sizes, strides=strides, dropout_rate=dropout_rate, + ramped_dropout_rate=ramped_dropout_rate, + block_type=block_type, + pooling_type=pooling_type, + final_pool=final_pool, + use_batchnorm=use_batchnorm, + activation=activation, + output_type=output_type, layer_scale_init_value=layer_scale_init_value, drop_path_rate=drop_path_rate, ) - feat_dim = self.backbone.dim_output # dims[-1] + feat_dim = self.backbone.dim_output - # Heads (produce logits for classification; no softmax here) + # Heads self.cs_head = make_mlp( input_dim=feat_dim, hidden_dims=cs_hidden, @@ -307,22 +327,17 @@ def __init__( output_activation=None, ) - # Losses + # Losses and bounds self.ce = nn.CrossEntropyLoss() self.mse = nn.MSELoss() - - # LP bounds self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32)) self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32)) - self.bound_lp_with_sigmoid = bound_lp_with_sigmoid - # weights + # weights and optim config self.lambda_cs = lambda_cs self.lambda_sg = lambda_sg self.lambda_lp = lambda_lp - - # Optimizer config self.lr = lr self.weight_decay = weight_decay self.use_adamw = use_adamw @@ -332,55 +347,17 @@ def __init__( self.num_sg_classes = num_sg_classes self.num_lp_outputs = num_lp_outputs - # GEMD setup (optional) - self.gemd_mu = gemd_mu - self.register_buffer("gemd_D", torch.empty(0)) - if gemd_distance_matrix_path is not None: - D_np = np.load(gemd_distance_matrix_path) - D_t = torch.as_tensor(D_np, dtype=torch.float32) - if D_t.ndim != 2 or D_t.shape[0] != self.num_sg_classes or D_t.shape[1] != self.num_sg_classes: - raise ValueError("GEMD distance matrix must be of shape (num_sg_classes, num_sg_classes)") - self.register_buffer("gemd_D", D_t) - - # ----------------------------- - # Forward - # ----------------------------- def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - """ - Forward pass. 
- Returns a dict with: - - cs_logits: (N, num_cs_classes) - - sg_logits: (N, num_sg_classes) - - lp: (N, num_lp_outputs) bounded if enabled - - features: (N, dims[-1]) - """ - feats = self.backbone(x) # (N, dims[-1]) + feats = self.backbone(x) cs_logits = self.cs_head(feats) sg_logits = self.sg_head(feats) lp = self.lp_head(feats) - if self.bound_lp_with_sigmoid: - # Bound to [min, max] via sigmoid lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min + return {"features": feats, "cs_logits": cs_logits, "sg_logits": sg_logits, "lp": lp} - return { - "features": feats, - "cs_logits": cs_logits, - "sg_logits": sg_logits, - "lp": lp, - } - - # ----------------------------- - # Data parsing helpers - # ----------------------------- @staticmethod def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: - """ - Convert labels to 0-based class indices: - - Supports one-hot and integer labels. - - Clamp to [0, num_classes-1] to avoid out-of-range targets. - Assumes labels are already 0-based. - """ if y.dim() > 1 and y.size(-1) > 1: idx = y.argmax(dim=-1) else: @@ -399,7 +376,6 @@ def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, x, y_cs, y_sg = batch[:3] y_lp = None elif isinstance(batch, dict): - # Try common keys x = batch.get("x", batch.get("xrd", batch.get("signal"))) if x is None: raise KeyError("Batch dict must contain 'x' or 'xrd' or 'signal'.") @@ -410,46 +386,24 @@ def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, raise KeyError("Batch dict must contain 'cs' and 'sg'. 'lattice_params' (or 'lp') is optional.") else: raise TypeError("Unsupported batch type. Use Tuple or Dict.") - return x, y_cs, y_sg, y_lp - # ----------------------------- - # Loss and metrics - # ----------------------------- def _compute_losses_and_metrics( self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: Optional[torch.Tensor] ) -> Dict[str, torch.Tensor]: cs_logits = preds["cs_logits"] sg_logits = preds["sg_logits"] lp_pred = preds["lp"] - - # targets: convert one-hot to index for CE if necessary y_cs_idx = self._to_index(y_cs, self.num_cs_classes) y_sg_idx = self._to_index(y_sg, self.num_sg_classes) - # regression targets should be float if present if y_lp is not None: y_lp = y_lp.float() loss_cs = self.ce(cs_logits, y_cs_idx) loss_sg = self.ce(sg_logits, y_sg_idx) loss_lp = self.mse(lp_pred, y_lp) if y_lp is not None else torch.tensor(0.0, device=cs_logits.device) + total_loss = self.lambda_cs * loss_cs + self.lambda_sg * loss_sg + self.lambda_lp * loss_lp - # Optional GEMD term - loss_gemd = torch.tensor(0.0, device=cs_logits.device) - sg_probs = torch.softmax(sg_logits, dim=1) - if self.gemd_mu > 0.0 and self.gemd_D.numel() > 0: - D_rows = self.gemd_D[y_sg_idx] - gemd_per_sample = (D_rows * sg_probs).sum(dim=1) - loss_gemd = gemd_per_sample.mean() - - total_loss = ( - self.lambda_cs * loss_cs - + self.lambda_sg * loss_sg - + self.lambda_lp * loss_lp - + self.gemd_mu * loss_gemd - ) - - # metrics with torch.no_grad(): cs_acc = (cs_logits.argmax(dim=1) == y_cs_idx).float().mean() sg_acc = (sg_logits.argmax(dim=1) == y_sg_idx).float().mean() @@ -465,47 +419,37 @@ def _compute_losses_and_metrics( "loss_cs": loss_cs, "loss_sg": loss_sg, "loss_lp": loss_lp, - "loss_gemd": loss_gemd, "cs_acc": cs_acc, "sg_acc": sg_acc, "lp_mae": lp_mae, "lp_mse": lp_mse, } - # ----------------------------- - # Lightning hooks - # ----------------------------- def training_step(self, batch: BatchType, batch_idx: int) -> 
torch.Tensor: x, y_cs, y_sg, y_lp = self._extract_batch(batch) preds = self(x) out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - self.log("train/loss", out["loss_total"], prog_bar=True, on_step=True, on_epoch=True) self.log("train/loss_cs", out["loss_cs"], on_step=True, on_epoch=True) self.log("train/loss_sg", out["loss_sg"], on_step=True, on_epoch=True) self.log("train/loss_lp", out["loss_lp"], on_step=True, on_epoch=True) - self.log("train/loss_gemd", out["loss_gemd"], on_step=True, on_epoch=True) self.log("train/cs_acc", out["cs_acc"], prog_bar=True, on_step=True, on_epoch=True) self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) if out.get("lp_mae") is not None: self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) if out.get("lp_mse") is not None: self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) - return out["loss_total"] def validation_step(self, batch: BatchType, batch_idx: int, dataloader_idx: Optional[int] = None) -> None: x, y_cs, y_sg, y_lp = self._extract_batch(batch) preds = self(x) out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - prefix = "val" if (dataloader_idx is None or dataloader_idx == 0) else "val_rruff" - self.log(f"{prefix}/loss", out["loss_total"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) self.log(f"{prefix}/loss_cs", out["loss_cs"], on_epoch=True, add_dataloader_idx=False) self.log(f"{prefix}/loss_sg", out["loss_sg"], on_epoch=True, add_dataloader_idx=False) self.log(f"{prefix}/loss_lp", out["loss_lp"], on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/loss_gemd", out["loss_gemd"], on_epoch=True, add_dataloader_idx=False) self.log(f"{prefix}/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) self.log(f"{prefix}/sg_acc", out["sg_acc"], on_epoch=True, add_dataloader_idx=False) if out.get("lp_mae") is not None: @@ -517,12 +461,10 @@ def test_step(self, batch: BatchType, batch_idx: int) -> None: x, y_cs, y_sg, y_lp = self._extract_batch(batch) preds = self(x) out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - self.log("test/loss", out["loss_total"], prog_bar=True, on_epoch=True) self.log("test/loss_cs", out["loss_cs"], on_epoch=True) self.log("test/loss_sg", out["loss_sg"], on_epoch=True) self.log("test/loss_lp", out["loss_lp"], on_epoch=True) - self.log("test/loss_gemd", out["loss_gemd"], on_epoch=True) self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) self.log("test/sg_acc", out["sg_acc"], on_epoch=True) if out.get("lp_mae") is not None: @@ -536,43 +478,39 @@ def configure_optimizers(self): optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay) else: optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay) - return optimizer + # Compute steps per epoch to match OG scheduler semantics: + # step_size_up = 6 * iterations_per_epoch + steps_per_epoch = None + try: + if hasattr(self, "trainer") and self.trainer is not None: + total_steps = getattr(self.trainer, "estimated_stepping_batches", None) + max_epochs = getattr(self.trainer, "max_epochs", None) + if total_steps is not None and max_epochs is not None and max_epochs > 0: + steps_per_epoch = max(1, total_steps // max_epochs) + except Exception: + pass + + if steps_per_epoch is None: + # Fallback if trainer hooks are unavailable; use a conservative default + steps_per_epoch = 100 + + step_size_up = int(6 * steps_per_epoch) + + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + 
base_lr=self.lr * 0.1, + max_lr=self.lr, + step_size_up=step_size_up, + cycle_momentum=False, + mode="triangular2", + ) -# ----------------------------- -# Example factory -# ----------------------------- -def build_alphadiffract_model_for_8192() -> AlphaDiffractLightning: - """ - Build the model for 1x8192 XRD input, ConvNeXt-style 4-stage backbone, and three heads. - """ - return AlphaDiffractLightning( - depths=(3, 3, 9, 3), - dims=(80, 160, 320, 640), - kernel_sizes=(7, 7, 7, 7), - strides=(4, 2, 2, 2), - dropout_rate=0.3, - layer_scale_init_value=1e-6, - drop_path_rate=0.1, - - head_dropout=0.5, - cs_hidden=(2300, 1150), - sg_hidden=(2300, 1150), - lp_hidden=(512, 256), - - num_cs_classes=7, - num_sg_classes=230, - num_lp_outputs=6, - - lp_bounds_min=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - lp_bounds_max=(500.0, 500.0, 500.0, 180.0, 180.0, 180.0), - bound_lp_with_sigmoid=True, - - lambda_cs=1.0, - lambda_sg=1.0, - lambda_lp=1.0, - - lr=2e-4, - weight_decay=1e-2, - use_adamw=True, - ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", # per-batch stepping, identical to OG + "name": "cyclic_lr", + }, + } diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index 11fd2f4..3f3d396 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -14,7 +14,7 @@ # Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`) from dataset import NpyDataModule -from model.model import AlphaDiffractLightning +from model.model import AlphaDiffractMultiscaleLightning from model.wav2vec2_model import AlphaDiffractWav2Vec2Lightning @@ -86,7 +86,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: return dm -def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: +def build_model_from_cfg(cfg: Dict[str, Any]): model_type = cfg.get("model_type", "convnext").lower() if model_type == "wav2vec2": return AlphaDiffractWav2Vec2Lightning( @@ -132,17 +132,22 @@ def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: warmup_steps=cfg.get("warmup_steps", 5000), cosine_t_max=cfg.get("cosine_t_max", 112000), ) - else: - return AlphaDiffractLightning( - # ConvNeXt1D backbone (paper defaults) - depths=tuple(cfg["depths"]), - dims=tuple(cfg["dims"]), + elif model_type == "multiscale": + return AlphaDiffractMultiscaleLightning( + # Map OG-style multiscale CNN params from cfg (use dims as channels) + channels=tuple(cfg["dims"]), kernel_sizes=tuple(cfg["kernel_sizes"]), strides=tuple(cfg["strides"]), dropout_rate=cfg["dropout_rate"], + ramped_dropout_rate=cfg.get("ramped_dropout_rate", False), + block_type=cfg.get("block_type", "convnext"), + pooling_type=cfg.get("pooling_type", "average"), + final_pool=cfg.get("final_pool", True), + use_batchnorm=cfg.get("use_batchnorm", False), + output_type=cfg.get("output_type", "flatten"), layer_scale_init_value=cfg["layer_scale_init_value"], drop_path_rate=cfg["drop_path_rate"], - # Heads + # Heads (match OG JSON) head_dropout=cfg["head_dropout"], cs_hidden=tuple(cfg["cs_hidden"]), sg_hidden=tuple(cfg["sg_hidden"]), @@ -151,7 +156,7 @@ def build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: num_cs_classes=cfg["num_cs_classes"], num_sg_classes=cfg["num_sg_classes"], num_lp_outputs=cfg["num_lp_outputs"], - # LP bounds and output handling + # LP bounds lp_bounds_min=tuple(cfg["lp_bounds_min"]), lp_bounds_max=tuple(cfg["lp_bounds_max"]), bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], @@ -159,14 +164,13 @@ def 
build_model_from_cfg(cfg: Dict[str, Any]) -> AlphaDiffractLightning: lambda_cs=cfg["lambda_cs"], lambda_sg=cfg["lambda_sg"], lambda_lp=cfg["lambda_lp"], - # Optional GEMD - gemd_mu=cfg["gemd_mu"], - gemd_distance_matrix_path=cfg.get("gemd_distance_matrix_path"), # Optimizer lr=cfg["lr"], weight_decay=cfg["weight_decay"], use_adamw=cfg["use_adamw"], ) + else: + raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale' or 'wav2vec2'.") class ConfigArtifactLogger(Callback): From b6cdb491dfb956860914c32fd8612d16a17143b3 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Fri, 14 Nov 2025 10:56:00 -0600 Subject: [PATCH 14/18] feat: Pull in direct paper model modules --- configs/trainer_convnext_paper.yaml | 12 +-- src/trainer/model/model.py | 126 +++++++++++++++++++--------- src/trainer/train_paper.py | 19 ++++- 3 files changed, 111 insertions(+), 46 deletions(-) diff --git a/configs/trainer_convnext_paper.yaml b/configs/trainer_convnext_paper.yaml index 8526647..a8f0f2a 100644 --- a/configs/trainer_convnext_paper.yaml +++ b/configs/trainer_convnext_paper.yaml @@ -38,18 +38,18 @@ floor_at_zero: true normalize_log1p: False # paper used log1p preprocessing model_type: "multiscale" -# --- ConvNeXt (lightweight paper variant) --- +# --- ConvNeXt (OG-equivalent configuration) --- # 3 stages; one block per stage; large kernels; stride-5 downsampling -# NOTE: This implementation adapts ConvNeXt to 1D and uses global avg pooling. +# Matches OG multiscale_cnn_cls_regr_convnextBlock_LeakyReLU.json exactly depths: [1, 1, 1] dims: [80, 80, 80] kernel_sizes: [100, 50, 25] strides: [5, 5, 5] dropout_rate: 0.3 -layer_scale_init_value: 1.0e-6 -# Stochastic depth schedule across blocks; paper mentions 0.3 on stem branch -# We set an overall schedule up to 0.3 -drop_path_rate: 0.0 +# OG uses layer_scale_init_value=0 (disabled) +layer_scale_init_value: 0.0 +# OG uses constant drop_path_rate=0.3 (not ramped) +drop_path_rate: 0.3 ramped_dropout_rate: false block_type: "convnext" pooling_type: "average" diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index 8a3fc62..b13808e 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -43,13 +43,14 @@ def __init__( kernel_size: int = 7, drop_path: float = 0.0, layer_scale_init_value: float = 1e-6, + activation: nn.Module = nn.GELU, ): super().__init__() # depthwise 1D conv self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding='same', groups=dim) # pointwise MLP implemented by Linear on channels-last self.pwconv1 = nn.Linear(dim, 4 * dim) - self.act = nn.LeakyReLU() + self.act = activation() if isinstance(activation, type) else activation self.pwconv2 = nn.Linear(4 * dim, dim) # layer-scale self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim)) if layer_scale_init_value > 0 else None @@ -71,6 +72,75 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +# ----------------------------- +# ConvNeXt Block Adaptor for Multiscale CNN +# Matches OG ConvNextBlock1DAdaptorForMultiscaleCNN behavior +# ----------------------------- +class ConvNextBlock1DAdaptor(nn.Module): + """ + Adaptor that wraps ConvNeXtBlock1D to match OG MultiscaleCNNBackbone behavior. + Handles channel adjustment, stride-based downsampling, and block application. 
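+
+    Shape sketch (illustrative numbers): with in_channels=1, out_channels=80,
+    kernel_size=100, and stride=5, an input of (N, 1, 8192) stays (N, 80, 8192)
+    through the channel adjustment and the padding='same' ConvNeXt block, then
+    becomes (N, 80, 1638) after the stride-5 AvgPool1d.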
+ """ + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dropout: float = 0.0, + use_batchnorm: bool = False, + activation: nn.Module = nn.LeakyReLU, + layer_scale_init_value: float = 0.0, + drop_path_rate: float = 0.3, + block_type: str = "convnext", + ): + super().__init__() + + # Optional pointwise conv for channel adjustment (if in_channels != out_channels) + if in_channels != out_channels: + act = activation() if isinstance(activation, type) else activation + self.pwconv = nn.Sequential(nn.Linear(in_channels, out_channels), act) + else: + self.pwconv = None + + # ConvNeXt block (only if block_type == "convnext") + if block_type == "convnext": + self.block = ConvNeXtBlock1D( + dim=out_channels, + kernel_size=kernel_size, + drop_path=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + activation=activation, + ) + else: + self.block = None + + # Optional stride-based pooling for downsampling + if stride > 1: + self.reduction_pool = nn.AvgPool1d(kernel_size=stride, stride=stride) + else: + self.reduction_pool = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (N, C_in, L) + + # Channel adjustment via pointwise conv (operates on channels-last) + if self.pwconv is not None: + x = x.permute(0, 2, 1) # (N, L, C_in) + x = self.pwconv(x) # (N, L, C_out) + x = x.permute(0, 2, 1) # (N, C_out, L) + + # Apply ConvNeXt block + if self.block is not None: + x = self.block(x) + + # Apply stride-based downsampling + if self.reduction_pool is not None: + x = self.reduction_pool(x) + + return x + + # ----------------------------- # Downsample: LN (channels-last) -> Conv1d (stride, channel increase) # ----------------------------- @@ -138,8 +208,8 @@ def __init__( use_batchnorm: bool = False, activation: nn.Module = nn.LeakyReLU, output_type: str = "flatten", - layer_scale_init_value: float = 1e-6, - drop_path_rate: float = 0.0, + layer_scale_init_value: float = 0.0, + drop_path_rate: float = 0.3, ): super().__init__() assert len(channels) == len(kernel_sizes) == len(strides), "channels, kernel_sizes, strides must match lengths" @@ -152,13 +222,6 @@ def __init__( else: dropout_per_stage = [dropout_rate] * len(channels) - # Stochastic depth schedule for ConvNeXt-like blocks - total_blocks = len(channels) - if drop_path_rate > 0.0 and total_blocks > 0: - dp_rates = np.linspace(0.0, drop_path_rate, total_blocks).tolist() - else: - dp_rates = [0.0] * max(total_blocks, 1) - # Select pooling module if pooling_type == "average": pool_cls = nn.AvgPool1d @@ -172,35 +235,20 @@ def __init__( layers: List[nn.Module] = [] in_ch = 1 for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)): - stage_layers: List[nn.Module] = [] - - # Stage conv (stride-based downsampling) - stage_layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=s, padding=0, bias=not use_batchnorm)) - if use_batchnorm: - stage_layers.append(nn.BatchNorm1d(out_ch)) - # Activation and Dropout - act = activation() if isinstance(activation, type) else activation - stage_layers.append(act) - if dropout_per_stage[i] > 0.0: - stage_layers.append(nn.Dropout(p=float(dropout_per_stage[i]))) - - # Optional ConvNeXt-like refinement block operating at out_ch - if block_type == "convnext": - stage_layers.append( - ConvNeXtBlock1D( - dim=out_ch, - kernel_size=k, - drop_path=dp_rates[i] if i < len(dp_rates) else 0.0, - layer_scale_init_value=layer_scale_init_value, - ) - ) - elif block_type in ("single_conv", "double_conv"): - # Already performed the 
primary conv; keep as-is - pass - else: - raise ValueError(f"Invalid block_type '{block_type}'") - - layers.append(nn.Sequential(*stage_layers)) + # Build stage block matching OG ConvNextBlock1DAdaptorForMultiscaleCNN + stage_block = ConvNextBlock1DAdaptor( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=k, + stride=s, + dropout=dropout_per_stage[i], + use_batchnorm=use_batchnorm, + activation=activation, + layer_scale_init_value=layer_scale_init_value, + drop_path_rate=drop_path_rate, + block_type=block_type, + ) + layers.append(stage_block) # Inter-stage pooling if i < len(channels) - 1 or final_pool: diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index 3f3d396..36ddf71 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -6,6 +6,7 @@ import yaml from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback +import signal from pytorch_lightning.loggers import CSVLogger try: from pytorch_lightning.loggers import MLFlowLogger @@ -231,6 +232,22 @@ def on_fit_start(self, trainer, pl_module) -> None: pass +class MlflowShutdownCallback(Callback): + """ + Ensures MLflow runs are properly finalized on Ctrl-C (SIGINT) interruptions. + """ + def on_exception(self, trainer, pl_module, exception): + """Called when an exception occurs during training.""" + if isinstance(exception, KeyboardInterrupt): + logger = getattr(trainer, "logger", None) + if MLFlowLogger is not None and isinstance(logger, MLFlowLogger): + try: + logger.experiment.set_terminated(logger.run_id, status="KILLED") + print("\nMLflow run terminated with status KILLED") + except Exception as e: + print(f"\nWarning: Could not terminate MLflow run: {e}") + + def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = None) -> Trainer: ckpt_cb = ModelCheckpoint( monitor=cfg["monitor"], @@ -264,7 +281,7 @@ def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = devices=cfg["devices"], precision=cfg["precision"], accumulate_grad_batches=cfg["accumulate_grad_batches"], - callbacks=[ckpt_cb, lr_cb, ConfigArtifactLogger(raw_config_path), RunCheckpointDirCallback()], + callbacks=[ckpt_cb, lr_cb, ConfigArtifactLogger(raw_config_path), RunCheckpointDirCallback(), MlflowShutdownCallback()], logger=logger, log_every_n_steps=cfg["log_every_n_steps"], deterministic=cfg["deterministic"], From 1403e9914ea5e43671551f9604504a5b0ebc25d5 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Mon, 17 Nov 2025 21:37:07 -0600 Subject: [PATCH 15/18] feat: Remove wav2vec model for production --- configs/trainer_wav2vec2.yaml | 129 ------ configs/trainer_wav2vec2_nonorm.yaml | 130 ------ src/trainer/infer_rruff.py | 11 +- src/trainer/model/wav2vec2_model.py | 605 --------------------------- src/trainer/train_paper.py | 49 +-- 5 files changed, 6 insertions(+), 918 deletions(-) delete mode 100644 configs/trainer_wav2vec2.yaml delete mode 100644 configs/trainer_wav2vec2_nonorm.yaml delete mode 100644 src/trainer/model/wav2vec2_model.py diff --git a/configs/trainer_wav2vec2.yaml b/configs/trainer_wav2vec2.yaml deleted file mode 100644 index 9022a0e..0000000 --- a/configs/trainer_wav2vec2.yaml +++ /dev/null @@ -1,129 +0,0 @@ -# AlphaDiffract trainer configuration for Wav2Vec2-style backbone (8192-length signals) -# Use with: python -m trainer.train_paper configs/trainer_wav2vec2.yaml - -# --- Data / Manifests --- -manifest_dir: "../../data/manifests" -extra_val_file: "rruff_sim.jsonl" 
-dataset_root: "../../data/dataset" -auto_generate_manifests: true -train_ratio: 0.8 -val_ratio: 0.1 -test_ratio: 0.1 -seed: 42 - -# --- DataLoader --- -# Large dataset (14M samples, 150k materials) — start with batch_size=256 and tune later -batch_size: 200 -num_workers: 8 -pin_memory: true -persistent_workers: true - -# --- Dataset label extraction (embedded in .npy/.npz) --- -validate_paths: false -extract_labels: true -allow_pickle: true -labels_key_map: - x: "dp" - cs: "cs" - sg: "sg" - lattice_params: null - lp_a: "_cell_length_a" - lp_b: "_cell_length_b" - lp_c: "_cell_length_c" - lp_alpha: "_cell_angle_alpha" - lp_beta: "_cell_angle_beta" - lp_gamma: "_cell_angle_gamma" -dtype: "float32" -mmap_mode: null -floor_at_zero: true # counts are non-negative -normalize_log1p: true # compress dynamic range for stability - -# --- Noise augmentation (training split only; matches paper) --- -# If provided, noise is applied dynamically per-sample in the DataModule using the same -# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. -# Set ranges to None to disable. -noise_poisson_range: [1.0, 100.0] # λ_max ~ Uniform(1, 100) -noise_gaussian_range: [0.001, 0.1] # σ_rel ~ Uniform(1e-3, 1e-1) - -# --- Model selection --- -model_type: "wav2vec2" # choose between "convnext" and "wav2vec2" - -# --- Wav2Vec2-style backbone (defaults tailored for 8192 inputs and more tokens) --- -in_chans: 1 -d_model: 512 -n_heads: 8 -num_layers: 8 -ff_dim: 2048 -# Convolutional feature extractor — total stride = 16 -> ~512 tokens for length 8192 -conv_kernel_sizes: [10, 5, 3, 3, 3, 2] -conv_strides: [2, 2, 2, 2, 1, 1] -conv_dropout: 0.0 -pos_kernel_size: 129 -pos_dropout: 0.1 -encoder_dropout: 0.1 -layer_norm_first: true -token_pool: "cls" # "mean" or "cls" - -# Heads (reuse paper defaults) -head_dropout: 0.2 -cs_hidden: [2300, 1150] -sg_hidden: [2300, 1150] -lp_hidden: [512, 256] - -# Task sizes -num_cs_classes: 7 -num_sg_classes: 230 -num_lp_outputs: 6 - -# LP output bounds -lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] -bound_lp_with_sigmoid: true - -# Loss weights -lambda_cs: 1.0 -lambda_sg: 1.0 -lambda_lp: 0.0 - -# Optional GEMD term on SG -gemd_mu: 0.0 -gemd_distance_matrix_path: - -# Optimizer -lr: 0.0004 # moderately higher base LR with warmup; tune as needed -weight_decay: 0.01 -use_adamw: true -gradient_clip_val: 1.0 -gradient_clip_algorithm: "norm" - -# --- Scheduler --- -warmup_steps: 6000 -cosine_t_max: 112000 - -# --- Logging --- -logger: "mlflow" # 'csv' or 'mlflow' -csv_logger_name: "model_logs_wav2vec2" -mlflow_experiment_name: "OpenAlphaDiffract_Wav2Vec2" -mlflow_tracking_uri: null -mlflow_run_name: "Wav2Vec2_Run" - -# --- Trainer settings --- -default_root_dir: "outputs/wav2vec2_model" -max_epochs: 50 -accumulate_grad_batches: 1 -precision: "bf16-mixed" # good default for H100/A100; switch to '16-mixed' if needed -accelerator: "gpu" -devices: 1 -log_every_n_steps: 50 -deterministic: false -benchmark: true - -# --- Checkpointing --- -monitor: "val/loss" -mode: "min" -save_top_k: 1 -every_n_epochs: 1 - -# --- Evaluation --- -resume_from: -test_after_train: true diff --git a/configs/trainer_wav2vec2_nonorm.yaml b/configs/trainer_wav2vec2_nonorm.yaml deleted file mode 100644 index 2e778e6..0000000 --- a/configs/trainer_wav2vec2_nonorm.yaml +++ /dev/null @@ -1,130 +0,0 @@ -# AlphaDiffract trainer configuration for Wav2Vec2-style backbone (8192-length signals) -# No log1p normalization (normalize_log1p: false) 
for A/B comparison against real-world RRUFF data -# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_wav2vec2_nonorm.yaml - -# --- Data / Manifests --- -manifest_dir: "../../../ad_data/manifest_original" -dataset_root: "../../../ad_data/data/dataset" -extra_val_file: "rruff.jsonl" -auto_generate_manifests: true -train_ratio: 0.8 -val_ratio: 0.1 -test_ratio: 0.1 -seed: 42 - -# --- DataLoader --- -batch_size: 180 -num_workers: 8 -pin_memory: true -persistent_workers: true - -# --- Dataset label extraction (embedded in .npy/.npz) --- -validate_paths: false -extract_labels: true -allow_pickle: true -labels_key_map: - x: "dp" - cs: "cs" - sg: "sg" - lattice_params: null - lp_a: "_cell_length_a" - lp_b: "_cell_length_b" - lp_c: "_cell_length_c" - lp_alpha: "_cell_angle_alpha" - lp_beta: "_cell_angle_beta" - lp_gamma: "_cell_angle_gamma" -dtype: "float32" -mmap_mode: null -floor_at_zero: true # counts are non-negative -normalize_log1p: false # DISABLED for this run - -# --- Noise augmentation (training split only; matches paper) --- -# If provided, noise is applied dynamically per-sample in the DataModule using the same -# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. -# Set ranges to None to disable. -noise_poisson_range: [1.0, 100.0] -noise_gaussian_range: [0.001, 0.1] -standardize_to: [0.0, 100.0] - -# --- Model selection --- -model_type: "wav2vec2" # choose between "convnext" and "wav2vec2" - -# --- Wav2Vec2-style backbone (defaults tailored for 8192 inputs and more tokens) --- -in_chans: 1 -d_model: 512 -n_heads: 8 -num_layers: 8 -ff_dim: 2048 -# Convolutional feature extractor — total stride = 16 -> ~512 tokens for length 8192 -conv_kernel_sizes: [10, 5, 3, 3, 3, 2] -conv_strides: [2, 2, 2, 2, 1, 1] -conv_dropout: 0.05 -pos_kernel_size: 129 -pos_dropout: 0.2 -encoder_dropout: 0.2 -layer_norm_first: true -token_pool: "cls" # "mean" or "cls" - -# Heads (reuse paper defaults) -head_dropout: 0.4 -cs_hidden: [2300, 1150] -sg_hidden: [2300, 1150] -lp_hidden: [512, 256] - -# Task sizes -num_cs_classes: 7 -num_sg_classes: 230 -num_lp_outputs: 6 - -# LP output bounds -lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] -bound_lp_with_sigmoid: true - -# Loss weights -lambda_cs: 1.0 -lambda_sg: 1.0 -lambda_lp: 0.0 - -# Optional GEMD term on SG -gemd_mu: 0.0 -gemd_distance_matrix_path: - -# Optimizer -lr: 0.0004 -weight_decay: 0.02 -use_adamw: true -gradient_clip_val: 1.0 -gradient_clip_algorithm: "norm" - -# --- Scheduler --- -warmup_steps: 6000 -cosine_t_max: 112000 - -# --- Logging --- -logger: "mlflow" # 'csv' or 'mlflow' -csv_logger_name: "model_logs_wav2vec2_nonorm" -mlflow_experiment_name: "OpenAlphaDiffract_Wav2Vec2_NoNorm" -mlflow_tracking_uri: null -mlflow_run_name: "Wav2Vec2_Run_NoNorm" - -# --- Trainer settings --- -default_root_dir: "outputs/wav2vec2_model_nonorm" -max_epochs: 50 -accumulate_grad_batches: 1 -precision: "bf16-mixed" -accelerator: "gpu" -devices: 1 -log_every_n_steps: 50 -deterministic: false -benchmark: true - -# --- Checkpointing --- -monitor: "val/loss" -mode: "min" -save_top_k: 1 -every_n_epochs: 1 - -# --- Evaluation --- -resume_from: -test_after_train: true diff --git a/src/trainer/infer_rruff.py b/src/trainer/infer_rruff.py index 3e43a9e..0946b46 100644 --- a/src/trainer/infer_rruff.py +++ b/src/trainer/infer_rruff.py @@ -17,9 +17,8 @@ matplotlib.use("Agg") # non-interactive backend import matplotlib.pyplot as plt -# Expect PYTHONPATH=src; run with: 
python -m trainer.infer_rruff configs/trainer_wav2vec2.yaml --ckpt /path/to.ckpt +# Expect PYTHONPATH=src; run with: python -m trainer.infer_rruff configs/trainer_convnext_paper.yaml --ckpt /path/to.ckpt from dataset import NpyDataModule -from model.wav2vec2_model import AlphaDiffractWav2Vec2Lightning from model.model import AlphaDiffractMultiscaleLightning @@ -84,12 +83,10 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: def build_model_class_from_cfg(cfg: Dict[str, Any]): model_type = cfg.get("model_type", "convnext").lower() - if model_type == "wav2vec2": - return AlphaDiffractWav2Vec2Lightning - elif model_type == "multiscale": + if model_type == "multiscale": return AlphaDiffractMultiscaleLightning else: - raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'wav2vec2' or 'multiscale'.") + raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale'.") # ----------------------------- @@ -324,7 +321,7 @@ def run_inference_and_confusion(cfg_path: str, ckpt_path: Optional[str], task: s def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Run inference on a chosen split (RRUFF/val/test/train) and plot confusion matrix") - p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer_wav2vec2.yaml)") + p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer_convnext_paper.yaml)") p.add_argument("--ckpt", type=str, default=None, help="Path to Lightning checkpoint (.ckpt). If omitted, attempts auto-discovery.") p.add_argument("--task", type=str, default="cs", choices=["cs", "sg"], help="Which head to evaluate: cs (7 classes) or sg (230 classes)") p.add_argument("--split", type=str, default="rruff", choices=["rruff", "val", "test", "train"], help="Dataset split to evaluate") diff --git a/src/trainer/model/wav2vec2_model.py b/src/trainer/model/wav2vec2_model.py deleted file mode 100644 index ecf6e61..0000000 --- a/src/trainer/model/wav2vec2_model.py +++ /dev/null @@ -1,605 +0,0 @@ -from typing import Dict, Tuple, Optional, List, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import pytorch_lightning as pl - - -# ============================= -# Wav2Vec2-style 1D backbone for 8192-length XRD signals -# ============================= - -def _same_padding(kernel_size: int) -> int: - # For odd kernel sizes, this approximates "same" padding in Conv1d - return kernel_size // 2 - - -class ConvFeatureExtractor1D(nn.Module): - """ - Wav2Vec2-style convolutional feature extractor for 1D signals. - - This stack reduces the raw sequence length to a token sequence length via strided 1D convolutions. - Defaults are chosen to produce an overall stride of 64 for input length 8192, i.e., ~128 tokens. - - Args: - in_chans: Input channels (1 for intensity-only). - conv_dim: Number of channels in conv feature maps (often equals d_model). - kernel_sizes: List of kernel sizes for each conv stage. - strides: List of strides for each conv stage (same length as kernel_sizes). - activation: Nonlinearity to apply after each conv (default: GELU). - dropout: Dropout applied after activation in each stage. - """ - def __init__( - self, - in_chans: int = 1, - conv_dim: int = 512, - kernel_sizes: Tuple[int, ...] = (10, 5, 3, 3, 3, 2), - strides: Tuple[int, ...] 
= (2, 2, 2, 2, 2, 2), - activation: Optional[nn.Module] = None, - dropout: float = 0.0, - ): - super().__init__() - assert len(kernel_sizes) == len(strides), "kernel_sizes and strides must have same length" - layers: List[nn.Module] = [] - c_in = in_chans - act = activation if activation is not None else nn.GELU() - for i, (k, s) in enumerate(zip(kernel_sizes, strides)): - layers.append(nn.Conv1d(c_in, conv_dim, kernel_size=k, stride=s, padding=_same_padding(k))) - layers.append(act) - if dropout and dropout > 0.0: - layers.append(nn.Dropout(p=dropout)) - c_in = conv_dim - self.net = nn.Sequential(*layers) - self.out_dim = conv_dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x: (N, L) or (N, 1, L) - if x.ndim == 2: - x = x[:, None, :] - return self.net(x) # (N, C=conv_dim, T) - - -class PositionalConvEmbedding(nn.Module): - """ - Convolutional positional embedding as used in wav2vec2: - depthwise Conv1d over features with GELU and dropout, added to token sequence. - """ - def __init__(self, d_model: int, kernel_size: int = 128, groups: Optional[int] = None, dropout: float = 0.1): - super().__init__() - g = groups if groups is not None else d_model - self.pos_conv = nn.Conv1d( - d_model, d_model, kernel_size=kernel_size, padding='same', groups=g - ) - self.activation = nn.GELU() - self.dropout = nn.Dropout(dropout) - - def forward(self, tokens: torch.Tensor) -> torch.Tensor: - # tokens: (N, T, D) - x = tokens.transpose(1, 2) # (N, D, T) - x = self.pos_conv(x) - x = self.activation(x) - x = x.transpose(1, 2) # (N, T, D) - return tokens + self.dropout(x) - - -class Wav2Vec2Backbone1D(nn.Module): - """ - Wav2Vec2-style backbone adapted for 1D XRD signals. - - Pipeline: - - Convolutional feature extractor: reduces length to token sequence (N, T, D) - - Convolutional positional embedding - - Transformer encoder stack (batch_first), dropout - - Feature pooling: mean over tokens -> (N, D) - - Defaults are chosen for 8192-length inputs to yield ~128 tokens (overall stride 64). - """ - def __init__( - self, - in_chans: int = 1, - d_model: int = 512, - n_heads: int = 8, - num_layers: int = 8, - ff_dim: Optional[int] = None, - conv_kernel_sizes: Tuple[int, ...] = (10, 5, 3, 3, 3, 2), - conv_strides: Tuple[int, ...] 
= (2, 2, 2, 2, 2, 2), - conv_dropout: float = 0.0, - pos_kernel_size: int = 128, - pos_dropout: float = 0.1, - encoder_dropout: float = 0.1, - layer_norm_first: bool = False, - token_pool: str = "mean", # "mean" or "cls" (mean by default) - ): - super().__init__() - ff_dim = ff_dim or (4 * d_model) - self.token_pool = token_pool - - # Conv feature extractor -> (N, D, T) - self.feature_extractor = ConvFeatureExtractor1D( - in_chans=in_chans, - conv_dim=d_model, - kernel_sizes=conv_kernel_sizes, - strides=conv_strides, - dropout=conv_dropout, - ) - - # Positional conv embedding - self.pos_embed = PositionalConvEmbedding(d_model=d_model, kernel_size=pos_kernel_size, dropout=pos_dropout) - # Pre-encoder token normalization (helps Transformer stability on non-centered inputs) - self.input_ln = nn.LayerNorm(d_model) - - # Transformer encoder - encoder_layer = nn.TransformerEncoderLayer( - d_model=d_model, - nhead=n_heads, - dim_feedforward=ff_dim, - dropout=encoder_dropout, - activation="gelu", - batch_first=True, - norm_first=layer_norm_first, - ) - self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) - self.encoder_ln = nn.LayerNorm(d_model) - self.dim_output = d_model - - # Optional CLS token if desired - if self.token_pool == "cls": - self.cls = nn.Parameter(torch.zeros(1, 1, d_model)) - else: - self.register_parameter("cls", None) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x: (N, L) or (N, 1, L) - feats = self.feature_extractor(x) # (N, D, T) - tokens = feats.transpose(1, 2) # (N, T, D) - - if self.cls is not None: - cls_tok = self.cls.expand(tokens.size(0), -1, -1) # (N, 1, D) - tokens = torch.cat([cls_tok, tokens], dim=1) # (N, 1+T, D) - - tokens = self.pos_embed(tokens) # (N, T, D) - tokens = self.input_ln(tokens) # (N, T, D) - enc = self.encoder(tokens) # (N, T, D) - enc = self.encoder_ln(enc) # (N, T, D) - - if self.token_pool == "cls" and self.cls is not None: - pooled = enc[:, 0, :] # (N, D) - else: - pooled = enc.mean(dim=1) # (N, D) - return pooled - - -# ============================= -# Heads + Lightning module (mirrors existing AlphaDiffractLightning API) -# ============================= - -def make_mlp( - input_dim: int, - hidden_dims: Optional[Tuple[int, ...]], - output_dim: int, - dropout: float = 0.2, - output_activation: Optional[nn.Module] = None, -) -> nn.Module: - layers: List[nn.Module] = [] - last = input_dim - if hidden_dims is not None and len(hidden_dims) > 0: - for hd in hidden_dims: - layers.extend([nn.Linear(last, hd), nn.LeakyReLU()]) - if dropout and dropout > 0: - layers.append(nn.Dropout(dropout)) - last = hd - layers.append(nn.Linear(last, output_dim)) - if output_activation is not None: - layers.append(output_activation) - return nn.Sequential(*layers) - - -BatchType = Union[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], - Dict[str, torch.Tensor], -] - - -class AlphaDiffractWav2Vec2Lightning(pl.LightningModule): - """ - AlphaDiffract variant with a Wav2Vec2-style 1D backbone. - - Same heads and multi-task losses as AlphaDiffractLightning: - - CS classifier head (7 classes) - - SG classifier head (230 classes) - - LP regressor head (6 outputs, optionally bounded to [min, max] via sigmoid) - """ - - def __init__( - self, - # Backbone (defaults tailored for 8192 input, ~128 tokens) - in_chans: int = 1, - d_model: int = 512, - n_heads: int = 8, - num_layers: int = 8, - ff_dim: Optional[int] = None, - conv_kernel_sizes: Tuple[int, ...] = (10, 5, 3, 3, 3, 2), - conv_strides: Tuple[int, ...] 
= (2, 2, 2, 2, 2, 2), # total stride 64 - conv_dropout: float = 0.0, - pos_kernel_size: int = 128, - pos_dropout: float = 0.1, - encoder_dropout: float = 0.1, - layer_norm_first: bool = False, - token_pool: str = "mean", - - # Heads - head_dropout: float = 0.5, - cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), - sg_hidden: Optional[Tuple[int, ...]] = (2300, 1150), - lp_hidden: Optional[Tuple[int, ...]] = (512, 256), - - # Task sizes - num_cs_classes: int = 7, - num_sg_classes: int = 230, - num_lp_outputs: int = 6, - - # LP bounding - lp_bounds_min: Tuple[float, float, float, float, float, float] = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - lp_bounds_max: Tuple[float, float, float, float, float, float] = (500.0, 500.0, 500.0, 180.0, 180.0, 180.0), - bound_lp_with_sigmoid: bool = True, - - # Loss weights - lambda_cs: float = 1.0, - lambda_sg: float = 1.0, - lambda_lp: float = 1.0, - # Optional GEMD for SG - gemd_mu: float = 0.0, - gemd_distance_matrix_path: Optional[str] = None, - - # Optimizer - lr: float = 2e-4, - weight_decay: float = 1e-2, - use_adamw: bool = True, - # Scheduler (step-based warmup + cosine decay) - warmup_steps: int = 5000, - cosine_t_max: int = 112000, - ): - super().__init__() - self.save_hyperparameters() - - # Backbone - self.backbone = Wav2Vec2Backbone1D( - in_chans=in_chans, - d_model=d_model, - n_heads=n_heads, - num_layers=num_layers, - ff_dim=ff_dim, - conv_kernel_sizes=conv_kernel_sizes, - conv_strides=conv_strides, - conv_dropout=conv_dropout, - pos_kernel_size=pos_kernel_size, - pos_dropout=pos_dropout, - encoder_dropout=encoder_dropout, - layer_norm_first=layer_norm_first, - token_pool=token_pool, - ) - feat_dim = self.backbone.dim_output - - # Heads - self.cs_head = make_mlp( - input_dim=feat_dim, - hidden_dims=cs_hidden, - output_dim=num_cs_classes, - dropout=head_dropout, - output_activation=None, - ) - self.sg_head = make_mlp( - input_dim=feat_dim, - hidden_dims=sg_hidden, - output_dim=num_sg_classes, - dropout=head_dropout, - output_activation=None, - ) - self.lp_head = make_mlp( - input_dim=feat_dim, - hidden_dims=lp_hidden, - output_dim=num_lp_outputs, - dropout=head_dropout, - output_activation=None, - ) - - # Losses - self.ce = nn.CrossEntropyLoss() - self.mse = nn.MSELoss() - - # LP bounds - self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32)) - self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32)) - self.bound_lp_with_sigmoid = bound_lp_with_sigmoid - - # weights - self.lambda_cs = lambda_cs - self.lambda_sg = lambda_sg - self.lambda_lp = lambda_lp - - # Optimizer config - self.lr = lr - self.weight_decay = weight_decay - self.use_adamw = use_adamw - # Scheduler params from constructor (exposed via config) - self.warmup_steps = warmup_steps - self.cosine_t_max = cosine_t_max - - # Task sizes - self.num_cs_classes = num_cs_classes - self.num_sg_classes = num_sg_classes - self.num_lp_outputs = num_lp_outputs - - # GEMD setup (optional) - self.gemd_mu = gemd_mu - self.register_buffer("gemd_D", torch.empty(0)) - if gemd_distance_matrix_path is not None: - D_np = np.load(gemd_distance_matrix_path) - D_t = torch.as_tensor(D_np, dtype=torch.float32) - if D_t.ndim != 2 or D_t.shape[0] != self.num_sg_classes or D_t.shape[1] != self.num_sg_classes: - raise ValueError("GEMD distance matrix must be of shape (num_sg_classes, num_sg_classes)") - self.register_buffer("gemd_D", D_t) - - # ----------------------------- - # Forward - # ----------------------------- - def forward(self, x: torch.Tensor) -> Dict[str, 
torch.Tensor]: - feats = self.backbone(x) # (N, D) - cs_logits = self.cs_head(feats) - sg_logits = self.sg_head(feats) - lp = self.lp_head(feats) - - if self.bound_lp_with_sigmoid: - lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min - - return { - "features": feats, - "cs_logits": cs_logits, - "sg_logits": sg_logits, - "lp": lp, - } - - # ----------------------------- - # Data parsing helpers - # ----------------------------- - @staticmethod - def _to_index(y: torch.Tensor, num_classes: int) -> torch.Tensor: - """ - Convert labels to 0-based class indices: - - Supports one-hot and integer labels. - - Clamp to [0, num_classes-1] to avoid out-of-range targets. - Assumes labels are already 0-based. - """ - if y.dim() > 1 and y.size(-1) > 1: - idx = y.argmax(dim=-1) - else: - idx = y.long() - with torch.no_grad(): - if idx.numel() > 0: - idx = idx.clamp(min=0, max=num_classes - 1) - return idx - - def _extract_batch(self, batch: BatchType) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - if isinstance(batch, (list, tuple)): - assert len(batch) >= 3, "Expected at least (x, cs, sg) in the batch tuple." - if len(batch) >= 4: - x, y_cs, y_sg, y_lp = batch[:4] - else: - x, y_cs, y_sg = batch[:3] - y_lp = None - elif isinstance(batch, dict): - x = batch.get("x", batch.get("xrd", batch.get("signal"))) - if x is None: - raise KeyError("Batch dict must contain 'x' or 'xrd' or 'signal'.") - y_cs = batch.get("cs") - y_sg = batch.get("sg") - y_lp = batch.get("lattice_params", batch.get("lp")) - if y_cs is None or y_sg is None: - raise KeyError("Batch dict must contain 'cs' and 'sg'. 'lattice_params' (or 'lp') is optional.") - else: - raise TypeError("Unsupported batch type. Use Tuple or Dict.") - - return x, y_cs, y_sg, y_lp - - # ----------------------------- - # Loss and metrics - # ----------------------------- - def _compute_losses_and_metrics( - self, preds: Dict[str, torch.Tensor], y_cs: torch.Tensor, y_sg: torch.Tensor, y_lp: Optional[torch.Tensor] - ) -> Dict[str, torch.Tensor]: - cs_logits = preds["cs_logits"] - sg_logits = preds["sg_logits"] - lp_pred = preds["lp"] - - y_cs_idx = self._to_index(y_cs, self.num_cs_classes) - y_sg_idx = self._to_index(y_sg, self.num_sg_classes) - if y_lp is not None: - y_lp = y_lp.float() - - loss_cs = self.ce(cs_logits, y_cs_idx) - loss_sg = self.ce(sg_logits, y_sg_idx) - loss_lp = self.mse(lp_pred, y_lp) if y_lp is not None else torch.tensor(0.0, device=cs_logits.device) - - loss_gemd = torch.tensor(0.0, device=cs_logits.device) - sg_probs = torch.softmax(sg_logits, dim=1) - if self.gemd_mu > 0.0 and self.gemd_D.numel() > 0: - D_rows = self.gemd_D[y_sg_idx] - gemd_per_sample = (D_rows * sg_probs).sum(dim=1) - loss_gemd = gemd_per_sample.mean() - - total_loss = ( - self.lambda_cs * loss_cs - + self.lambda_sg * loss_sg - + self.lambda_lp * loss_lp - + self.gemd_mu * loss_gemd - ) - - with torch.no_grad(): - cs_acc = (cs_logits.argmax(dim=1) == y_cs_idx).float().mean() - sg_acc = (sg_logits.argmax(dim=1) == y_sg_idx).float().mean() - # Top-5 accuracy for SG to detect early improvements before top-1 moves - sg_top5 = ( - sg_logits.topk(5, dim=1).indices.eq(y_sg_idx.unsqueeze(1)).any(dim=1).float().mean() - ) - if y_lp is not None: - lp_mae = (lp_pred - y_lp).abs().mean() - lp_mse = F.mse_loss(lp_pred, y_lp) - else: - lp_mae = None - lp_mse = None - - return { - "loss_total": total_loss, - "loss_cs": loss_cs, - "loss_sg": loss_sg, - "loss_lp": loss_lp, - "loss_gemd": loss_gemd, - "cs_acc": cs_acc, - "sg_acc": sg_acc, - 
"sg_top5": sg_top5, - "lp_mae": lp_mae, - "lp_mse": lp_mse, - } - - # ----------------------------- - # Lightning hooks - # ----------------------------- - def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: - x, y_cs, y_sg, y_lp = self._extract_batch(batch) - preds = self(x) - out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - - self.log("train/loss", out["loss_total"], prog_bar=True, on_step=True, on_epoch=True) - self.log("train/loss_cs", out["loss_cs"], on_step=True, on_epoch=True) - self.log("train/loss_sg", out["loss_sg"], on_step=True, on_epoch=True) - self.log("train/loss_lp", out["loss_lp"], on_step=True, on_epoch=True) - self.log("train/loss_gemd", out["loss_gemd"], on_step=True, on_epoch=True) - self.log("train/cs_acc", out["cs_acc"], prog_bar=True, on_step=True, on_epoch=True) - self.log("train/sg_acc", out["sg_acc"], on_step=True, on_epoch=True) - self.log("train/sg_top5", out["sg_top5"], on_step=True, on_epoch=True) - if out.get("lp_mae") is not None: - self.log("train/lp_mae", out["lp_mae"], on_step=True, on_epoch=True) - if out.get("lp_mse") is not None: - self.log("train/lp_mse", out["lp_mse"], on_step=True, on_epoch=True) - - return out["loss_total"] - - def validation_step(self, batch: BatchType, batch_idx: int, dataloader_idx: Optional[int] = None) -> None: - x, y_cs, y_sg, y_lp = self._extract_batch(batch) - preds = self(x) - out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - - prefix = "val" if (dataloader_idx is None or dataloader_idx == 0) else "val_rruff" - - self.log(f"{prefix}/loss", out["loss_total"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/loss_cs", out["loss_cs"], on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/loss_sg", out["loss_sg"], on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/loss_lp", out["loss_lp"], on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/loss_gemd", out["loss_gemd"], on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/sg_acc", out["sg_acc"], on_epoch=True, add_dataloader_idx=False) - self.log(f"{prefix}/sg_top5", out["sg_top5"], on_epoch=True, add_dataloader_idx=False) - if out.get("lp_mae") is not None: - self.log(f"{prefix}/lp_mae", out["lp_mae"], on_epoch=True, add_dataloader_idx=False) - if out.get("lp_mse") is not None: - self.log(f"{prefix}/lp_mse", out["lp_mse"], on_epoch=True, add_dataloader_idx=False) - - def test_step(self, batch: BatchType, batch_idx: int) -> None: - x, y_cs, y_sg, y_lp = self._extract_batch(batch) - preds = self(x) - out = self._compute_losses_and_metrics(preds, y_cs, y_sg, y_lp) - - self.log("test/loss", out["loss_total"], prog_bar=True, on_epoch=True) - self.log("test/loss_cs", out["loss_cs"], on_epoch=True) - self.log("test/loss_sg", out["loss_sg"], on_epoch=True) - self.log("test/loss_lp", out["loss_lp"], on_epoch=True) - self.log("test/loss_gemd", out["loss_gemd"], on_epoch=True) - self.log("test/cs_acc", out["cs_acc"], prog_bar=True, on_epoch=True) - self.log("test/sg_acc", out["sg_acc"], on_epoch=True) - self.log("test/sg_top5", out["sg_top5"], on_epoch=True) - if out.get("lp_mae") is not None: - self.log("test/lp_mae", out["lp_mae"], on_epoch=True) - if out.get("lp_mse") is not None: - self.log("test/lp_mse", out["lp_mse"], on_epoch=True) - - def configure_optimizers(self): - params = self.parameters() - # AdamW with transformer-friendly betas/eps - if 
self.use_adamw: - optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay, betas=(0.9, 0.98), eps=1e-8) - else: - optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay, betas=(0.9, 0.98), eps=1e-8) - - # Step-based warmup then cosine decay (better for very large datasets) - warmup_steps = getattr(self, "warmup_steps", 20000) - cosine_t_max = getattr(self, "cosine_t_max", 200000) - - linear_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps) - cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max, eta_min=self.lr * 0.01) - scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[linear_warmup, cosine], milestones=[warmup_steps]) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": "step", # apply scheduler every training step - "frequency": 1, - "name": "warmup_cosine_steps", - }, - } - - -# ----------------------------- -# Example factory for 1x8192 input with ~128 tokens -# ----------------------------- -def build_wav2vec2_alphadiffract_for_8192() -> AlphaDiffractWav2Vec2Lightning: - """ - Build a Wav2Vec2-style AlphaDiffract model for 1x8192 XRD input. - - Design choices for large dataset (~14M samples): - - Overall conv stride 64 -> ~128 tokens for 8192 input - - d_model=512, n_heads=8, num_layers=8 (scalable up) - - Standard MLP heads reused from ConvNeXt variant - - Defaults suitable for supervised training; self-supervised pretraining pipeline not included here - """ - return AlphaDiffractWav2Vec2Lightning( - # Backbone sizing - in_chans=1, - d_model=512, - n_heads=8, - num_layers=8, - ff_dim=2048, - conv_kernel_sizes=(10, 5, 3, 3, 3, 2), - conv_strides=(2, 2, 2, 2, 2, 2), # total stride 64 - conv_dropout=0.0, - pos_kernel_size=128, - pos_dropout=0.1, - encoder_dropout=0.1, - layer_norm_first=False, - token_pool="mean", - - # Heads - head_dropout=0.5, - cs_hidden=(2300, 1150), - sg_hidden=(2300, 1150), - lp_hidden=(512, 256), - - num_cs_classes=7, - num_sg_classes=230, - num_lp_outputs=6, - - lp_bounds_min=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - lp_bounds_max=(500.0, 500.0, 500.0, 180.0, 180.0, 180.0), - bound_lp_with_sigmoid=True, - - lambda_cs=1.0, - lambda_sg=1.0, - lambda_lp=1.0, - - # Optimizer defaults (tune as needed) - lr=2e-4, - weight_decay=1e-2, - use_adamw=True, - ) diff --git a/src/trainer/train_paper.py b/src/trainer/train_paper.py index 36ddf71..ef337a0 100644 --- a/src/trainer/train_paper.py +++ b/src/trainer/train_paper.py @@ -16,7 +16,6 @@ # Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`) from dataset import NpyDataModule from model.model import AlphaDiffractMultiscaleLightning -from model.wav2vec2_model import AlphaDiffractWav2Vec2Lightning def parse_args() -> argparse.Namespace: @@ -89,51 +88,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: def build_model_from_cfg(cfg: Dict[str, Any]): model_type = cfg.get("model_type", "convnext").lower() - if model_type == "wav2vec2": - return AlphaDiffractWav2Vec2Lightning( - # Backbone (Wav2Vec2-style; defaults tailored for 8192-length signals and large dataset) - in_chans=cfg.get("in_chans", 1), - d_model=cfg.get("d_model", 512), - n_heads=cfg.get("n_heads", 8), - num_layers=cfg.get("num_layers", 8), - ff_dim=cfg.get("ff_dim", 2048), - conv_kernel_sizes=tuple(cfg.get("conv_kernel_sizes", (10, 5, 3, 3, 3, 2))), - conv_strides=tuple(cfg.get("conv_strides", (2, 2, 2, 2, 2, 2))), # overall 
stride 64 -> ~128 tokens - conv_dropout=cfg.get("conv_dropout", 0.0), - pos_kernel_size=cfg.get("pos_kernel_size", 128), - pos_dropout=cfg.get("pos_dropout", 0.1), - encoder_dropout=cfg.get("encoder_dropout", 0.1), - layer_norm_first=cfg.get("layer_norm_first", False), - token_pool=cfg.get("token_pool", "mean"), - # Heads (reuse paper defaults) - head_dropout=cfg["head_dropout"], - cs_hidden=tuple(cfg["cs_hidden"]), - sg_hidden=tuple(cfg["sg_hidden"]), - lp_hidden=tuple(cfg["lp_hidden"]), - # Task sizes - num_cs_classes=cfg["num_cs_classes"], - num_sg_classes=cfg["num_sg_classes"], - num_lp_outputs=cfg["num_lp_outputs"], - # LP bounds - lp_bounds_min=tuple(cfg["lp_bounds_min"]), - lp_bounds_max=tuple(cfg["lp_bounds_max"]), - bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], - # Loss weights - lambda_cs=cfg["lambda_cs"], - lambda_sg=cfg["lambda_sg"], - lambda_lp=cfg["lambda_lp"], - # Optional GEMD - gemd_mu=cfg["gemd_mu"], - gemd_distance_matrix_path=cfg.get("gemd_distance_matrix_path"), - # Optimizer - lr=cfg["lr"], - weight_decay=cfg["weight_decay"], - use_adamw=cfg["use_adamw"], - # Scheduler - warmup_steps=cfg.get("warmup_steps", 5000), - cosine_t_max=cfg.get("cosine_t_max", 112000), - ) - elif model_type == "multiscale": + if model_type == "multiscale": return AlphaDiffractMultiscaleLightning( # Map OG-style multiscale CNN params from cfg (use dims as channels) channels=tuple(cfg["dims"]), @@ -171,7 +126,7 @@ def build_model_from_cfg(cfg: Dict[str, Any]): use_adamw=cfg["use_adamw"], ) else: - raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale' or 'wav2vec2'.") + raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale'.") class ConfigArtifactLogger(Callback): From 536b58dba3cc1728f72ac6a78f5bf60078a9f413 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Tue, 18 Nov 2025 14:48:14 -0600 Subject: [PATCH 16/18] fix: Prune configs and renaming --- configs/norm_train.yaml | 105 ------------------- configs/trainer.yaml | 85 +++++++++------ configs/trainer_convnext_paper.yaml | 127 ----------------------- src/trainer/{train_paper.py => train.py} | 0 4 files changed, 53 insertions(+), 264 deletions(-) delete mode 100644 configs/norm_train.yaml delete mode 100644 configs/trainer_convnext_paper.yaml rename src/trainer/{train_paper.py => train.py} (100%) diff --git a/configs/norm_train.yaml b/configs/norm_train.yaml deleted file mode 100644 index fb55418..0000000 --- a/configs/norm_train.yaml +++ /dev/null @@ -1,105 +0,0 @@ -# AlphaDiffract trainer configuration (paper-aligned defaults provided here) -# This file is required by src/trainer/train_paper.py. It contains all parameters with no script-side defaults. 
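
Aside: the token counts quoted in the removed wav2vec2 code and configs above
("total stride 64 -> ~128 tokens", "total stride = 16 -> ~512 tokens") follow
directly from the product of the conv strides; a quick sanity check (the exact
count also depends on kernel sizes and padding, so these are approximate):

    from math import prod

    for strides in [(2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 1, 1)]:
        total = prod(strides)
        print(f"total stride {total} -> ~{8192 // total} tokens")
    # total stride 64 -> ~128 tokens
    # total stride 16 -> ~512 tokens
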
- -# --- Data / Manifests --- -manifest_dir: "../../data/manifests" -dataset_root: "../../data/dataset" # used when auto_generate_manifests is true -auto_generate_manifests: true -train_ratio: 0.8 -val_ratio: 0.1 -test_ratio: 0.1 -seed: 42 - -# --- DataLoader --- -batch_size: 512 # paper used 64 -num_workers: 8 -pin_memory: true -persistent_workers: true - -# --- Dataset label extraction (embedded in .npy/.npz) --- -validate_paths: false -extract_labels: true -allow_pickle: true -labels_key_map: - x: "dp" - cs: "cs" - sg: "sg" - lattice_params: null - lp_a: "_cell_length_a" - lp_b: "_cell_length_b" - lp_c: "_cell_length_c" - lp_alpha: "_cell_angle_alpha" - lp_beta: "_cell_angle_beta" - lp_gamma: "_cell_angle_gamma" -dtype: "float32" # one of: float32, float64, float16, bfloat16 -mmap_mode: null # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable -floor_at_zero: True # Clamp negative counts to 0 before any normalization -normalize_log1p: True # If true, apply log1p(x) to compress dynamic range - -# --- Model architecture --- -depths: [1, 1, 1] -dims: [80, 80, 80] -kernel_sizes: [100, 50, 25] -strides: [5, 5, 5] -dropout_rate: 0.3 -layer_scale_init_value: 1.0e-6 - -# Heads -head_dropout: 0.2 -cs_hidden: [2300, 1150] -sg_hidden: [2300, 1150] -lp_hidden: [512, 256] - -# Task sizes -num_cs_classes: 7 -num_sg_classes: 230 -num_lp_outputs: 6 - -# LP output bounds -lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] -bound_lp_with_sigmoid: true - -# Loss weights -lambda_cs: 1.0 -lambda_sg: 1.0 -lambda_lp: 1.0 - -# Optional GEMD term on SG -gemd_mu: 0.0 -gemd_distance_matrix_path: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD - -# Optimizer -lr: 0.002 # paper used 2e-4 -weight_decay: 0.01 # paper used 0.01 -use_adamw: true -gradient_clip_val: 1.0 -gradient_clip_algorithm: "norm" - -# --- Logging --- -logger: "mlflow" # 'csv' or 'mlflow' -csv_logger_name: "model_logs" -mlflow_experiment_name: "OpenAlphaDiffractNorm" -mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) -mlflow_run_name: "OpenAlphaDiffract_Run" - -# --- Trainer settings --- -default_root_dir: "outputs/model" -max_epochs: 50 -accumulate_grad_batches: 1 -precision: "bf16-mixed" # e.g., '32', '16-mixed', 'bf16-mixed' -accelerator: "gpu" -devices: 1 -log_every_n_steps: 50 -deterministic: false -benchmark: true - -# --- Checkpointing --- -monitor: "val/loss" -mode: "min" -save_top_k: 1 -every_n_epochs: 1 - -# --- Evaluation --- -resume_from: # e.g., "outputs/paper_model/checkpoints/epochXYZ.ckpt" -test_after_train: true diff --git a/configs/trainer.yaml b/configs/trainer.yaml index 213e643..a8f0f2a 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -1,9 +1,10 @@ -# AlphaDiffract trainer configuration (paper-aligned defaults provided here) -# This file is required by src/trainer/train_paper.py. It contains all parameters with no script-side defaults. 
+# AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant) +# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml # --- Data / Manifests --- -manifest_dir: "../../data/manifests" -dataset_root: "../../data/dataset" # used when auto_generate_manifests is true +manifest_dir: "../../../ad_data/manifests" +dataset_root: "../../../ad_data/data/dataset" +extra_val_file: "rruff.jsonl" auto_generate_manifests: true train_ratio: 0.8 val_ratio: 0.1 @@ -11,7 +12,7 @@ test_ratio: 0.1 seed: 42 # --- DataLoader --- -batch_size: 256 # paper used 64 +batch_size: 64 # match OG run (64 per process) num_workers: 8 pin_memory: true persistent_workers: true @@ -31,22 +32,33 @@ labels_key_map: lp_alpha: "_cell_angle_alpha" lp_beta: "_cell_angle_beta" lp_gamma: "_cell_angle_gamma" -dtype: "float32" # one of: float32, float64, float16, bfloat16 -mmap_mode: null # NumPy memmap mode: 'r', 'r+', 'w+', or null to disable -floor_at_zero: True # Clamp negative counts to 0 before any normalization -normalize_log1p: True # If true, apply log1p(x) to compress dynamic range +dtype: "float32" +mmap_mode: null +floor_at_zero: true +normalize_log1p: False # paper used log1p preprocessing +model_type: "multiscale" -# --- Model architecture --- -depths: [3, 3, 9, 3] -dims: [80, 160, 320, 640] -kernel_sizes: [7, 7, 7, 7] -strides: [4, 2, 2, 2] +# --- ConvNeXt (OG-equivalent configuration) --- +# 3 stages; one block per stage; large kernels; stride-5 downsampling +# Matches OG multiscale_cnn_cls_regr_convnextBlock_LeakyReLU.json exactly +depths: [1, 1, 1] +dims: [80, 80, 80] +kernel_sizes: [100, 50, 25] +strides: [5, 5, 5] dropout_rate: 0.3 -layer_scale_init_value: 1.0e-6 -drop_path_rate: 0.1 +# OG uses layer_scale_init_value=0 (disabled) +layer_scale_init_value: 0.0 +# OG uses constant drop_path_rate=0.3 (not ramped) +drop_path_rate: 0.3 +ramped_dropout_rate: false +block_type: "convnext" +pooling_type: "average" +final_pool: true +use_batchnorm: false +output_type: "flatten" # Heads -head_dropout: 0.2 +head_dropout: 0.5 cs_hidden: [2300, 1150] sg_hidden: [2300, 1150] lp_hidden: [512, 256] @@ -58,7 +70,7 @@ num_lp_outputs: 6 # LP output bounds lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -lp_bounds_max: [500.0, 500.0, 500.0, 180.0, 180.0, 180.0] +lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0] bound_lp_with_sigmoid: true # Loss weights @@ -68,30 +80,39 @@ lambda_lp: 1.0 # Optional GEMD term on SG gemd_mu: 0.0 -gemd_distance_matrix_path: # e.g., "path/to/space_group_distance_matrix.npy" to enable GEMD +gemd_distance_matrix_path: null -# Optimizer -lr: 0.00015 # paper used 2e-4 -weight_decay: 0.01 # paper used 0.01 +# Optimizer (paper): AdamW, lr=2e-4, wd=0.01 +lr: 0.0002 +weight_decay: 0.01 use_adamw: true gradient_clip_val: 1.0 gradient_clip_algorithm: "norm" +# --- Noise augmentation (training split only; matches paper) --- +# If provided, noise is applied dynamically per-sample in the DataModule using the same +# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. +# Set ranges to None to disable. 
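
Aside: a minimal per-sample sketch of the noise sequencing described in the
comment above (Poisson -> normalize -> add Gaussian -> renormalize -> rescale);
the config entries follow below. The helper is illustrative only — it is not
the DataModule's actual implementation, and the exact normalization details
are assumptions:

    import numpy as np

    def add_paper_style_noise(x: np.ndarray, rng: np.random.Generator,
                              lam_range=(1.0, 100.0), sigma_range=(1e-3, 1e-1),
                              rescale=(0.0, 100.0)) -> np.ndarray:
        lam = rng.uniform(*lam_range)                        # λ_max ~ Uniform(1, 100)
        scale = max(float(x.max()), 1e-12)
        x = rng.poisson(lam * x / scale).astype(np.float64)  # Poisson counting noise
        x = x / max(float(x.max()), 1e-12)                   # normalize to [0, 1]
        sigma = rng.uniform(*sigma_range)                    # σ_rel ~ Uniform(1e-3, 1e-1)
        x = x + rng.normal(0.0, sigma, size=x.shape)         # additive Gaussian
        x = (x - x.min()) / max(float(x.max() - x.min()), 1e-12)  # renormalize
        lo, hi = rescale                                     # rescale (standardize_to)
        return lo + (hi - lo) * x
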
+noise_poisson_range: [1.0, 100.0] +noise_gaussian_range: [0.001, 0.1] + +# Standardize after noise to match OG CLI (--standardize-to 0 100) +standardize_to: [0.0, 100.0] # --- Logging --- -logger: "mlflow" # 'csv' or 'mlflow' -csv_logger_name: "model_logs" -mlflow_experiment_name: "OpenAlphaDiffract_ConvFUll" -mlflow_tracking_uri: null # null uses MLflow default (file:./mlruns) -mlflow_run_name: "OpenAlphaDiffract_Run" +logger: "mlflow" +csv_logger_name: "model_logs_convnext_paper" +mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt" +mlflow_tracking_uri: null +mlflow_run_name: "ConvNeXt_Paper_Run" # --- Trainer settings --- -default_root_dir: "outputs/model" -max_epochs: 50 +default_root_dir: "outputs/convnext_paper" +max_epochs: 100 accumulate_grad_batches: 1 -precision: "bf16-mixed" # e.g., '32', '16-mixed', 'bf16-mixed' +precision: "32" # match OG (AMP disabled) accelerator: "gpu" devices: 1 -log_every_n_steps: 50 +log_every_n_steps: 200 deterministic: false benchmark: true @@ -102,5 +123,5 @@ save_top_k: 1 every_n_epochs: 1 # --- Evaluation --- -resume_from: # e.g., "outputs/paper_model/checkpoints/epochXYZ.ckpt" +resume_from: test_after_train: true diff --git a/configs/trainer_convnext_paper.yaml b/configs/trainer_convnext_paper.yaml deleted file mode 100644 index a8f0f2a..0000000 --- a/configs/trainer_convnext_paper.yaml +++ /dev/null @@ -1,127 +0,0 @@ -# AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant) -# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml - -# --- Data / Manifests --- -manifest_dir: "../../../ad_data/manifests" -dataset_root: "../../../ad_data/data/dataset" -extra_val_file: "rruff.jsonl" -auto_generate_manifests: true -train_ratio: 0.8 -val_ratio: 0.1 -test_ratio: 0.1 -seed: 42 - -# --- DataLoader --- -batch_size: 64 # match OG run (64 per process) -num_workers: 8 -pin_memory: true -persistent_workers: true - -# --- Dataset label extraction (embedded in .npy/.npz) --- -validate_paths: false -extract_labels: true -allow_pickle: true -labels_key_map: - x: "dp" - cs: "cs" - sg: "sg" - lattice_params: null - lp_a: "_cell_length_a" - lp_b: "_cell_length_b" - lp_c: "_cell_length_c" - lp_alpha: "_cell_angle_alpha" - lp_beta: "_cell_angle_beta" - lp_gamma: "_cell_angle_gamma" -dtype: "float32" -mmap_mode: null -floor_at_zero: true -normalize_log1p: False # paper used log1p preprocessing -model_type: "multiscale" - -# --- ConvNeXt (OG-equivalent configuration) --- -# 3 stages; one block per stage; large kernels; stride-5 downsampling -# Matches OG multiscale_cnn_cls_regr_convnextBlock_LeakyReLU.json exactly -depths: [1, 1, 1] -dims: [80, 80, 80] -kernel_sizes: [100, 50, 25] -strides: [5, 5, 5] -dropout_rate: 0.3 -# OG uses layer_scale_init_value=0 (disabled) -layer_scale_init_value: 0.0 -# OG uses constant drop_path_rate=0.3 (not ramped) -drop_path_rate: 0.3 -ramped_dropout_rate: false -block_type: "convnext" -pooling_type: "average" -final_pool: true -use_batchnorm: false -output_type: "flatten" - -# Heads -head_dropout: 0.5 -cs_hidden: [2300, 1150] -sg_hidden: [2300, 1150] -lp_hidden: [512, 256] - -# Task sizes -num_cs_classes: 7 -num_sg_classes: 230 -num_lp_outputs: 6 - -# LP output bounds -lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0] -bound_lp_with_sigmoid: true - -# Loss weights -lambda_cs: 1.0 -lambda_sg: 1.0 -lambda_lp: 1.0 - -# Optional GEMD term on SG -gemd_mu: 0.0 -gemd_distance_matrix_path: null - -# Optimizer (paper): AdamW, 
lr=2e-4, wd=0.01 -lr: 0.0002 -weight_decay: 0.01 -use_adamw: true -gradient_clip_val: 1.0 -gradient_clip_algorithm: "norm" - -# --- Noise augmentation (training split only; matches paper) --- -# If provided, noise is applied dynamically per-sample in the DataModule using the same -# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale. -# Set ranges to None to disable. -noise_poisson_range: [1.0, 100.0] -noise_gaussian_range: [0.001, 0.1] - -# Standardize after noise to match OG CLI (--standardize-to 0 100) -standardize_to: [0.0, 100.0] -# --- Logging --- -logger: "mlflow" -csv_logger_name: "model_logs_convnext_paper" -mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt" -mlflow_tracking_uri: null -mlflow_run_name: "ConvNeXt_Paper_Run" - -# --- Trainer settings --- -default_root_dir: "outputs/convnext_paper" -max_epochs: 100 -accumulate_grad_batches: 1 -precision: "32" # match OG (AMP disabled) -accelerator: "gpu" -devices: 1 -log_every_n_steps: 200 -deterministic: false -benchmark: true - -# --- Checkpointing --- -monitor: "val/loss" -mode: "min" -save_top_k: 1 -every_n_epochs: 1 - -# --- Evaluation --- -resume_from: -test_after_train: true diff --git a/src/trainer/train_paper.py b/src/trainer/train.py similarity index 100% rename from src/trainer/train_paper.py rename to src/trainer/train.py From 7d15173fa64c3cb958dcbb0d7c8be76cca5f6002 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Tue, 18 Nov 2025 17:44:58 -0600 Subject: [PATCH 17/18] refactor: Move to nested config structure for organization --- configs/trainer.yaml | 219 +++++++++++++++--------------- src/trainer/dataset/datamodule.py | 19 ++- src/trainer/dataset/dataset.py | 16 +-- src/trainer/model/model.py | 114 ++++++++-------- src/trainer/train.py | 198 ++++++++++++++++----------- 5 files changed, 296 insertions(+), 270 deletions(-) diff --git a/configs/trainer.yaml b/configs/trainer.yaml index a8f0f2a..7d41407 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -1,127 +1,124 @@ # AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant) # Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml -# --- Data / Manifests --- -manifest_dir: "../../../ad_data/manifests" -dataset_root: "../../../ad_data/data/dataset" -extra_val_file: "rruff.jsonl" -auto_generate_manifests: true -train_ratio: 0.8 -val_ratio: 0.1 -test_ratio: 0.1 -seed: 42 +data: + manifest_dir: "../../../ad_data/manifests" + dataset_root: "../../../ad_data/data/dataset" + extra_val_file: "rruff.jsonl" + auto_generate_manifests: true + train_ratio: 0.8 + val_ratio: 0.1 + test_ratio: 0.1 + seed: 42 -# --- DataLoader --- -batch_size: 64 # match OG run (64 per process) -num_workers: 8 -pin_memory: true -persistent_workers: true + loader: + # --- DataLoader --- + batch_size: 64 # match OG run (64 per process) + num_workers: 8 + pin_memory: true + persistent_workers: true + prefetch_factor: 2 + train_file: "train.jsonl" + val_file: "val.jsonl" + test_file: "test.jsonl" -# --- Dataset label extraction (embedded in .npy/.npz) --- -validate_paths: false -extract_labels: true -allow_pickle: true -labels_key_map: - x: "dp" - cs: "cs" - sg: "sg" - lattice_params: null - lp_a: "_cell_length_a" - lp_b: "_cell_length_b" - lp_c: "_cell_length_c" - lp_alpha: "_cell_angle_alpha" - lp_beta: "_cell_angle_beta" - lp_gamma: "_cell_angle_gamma" -dtype: "float32" -mmap_mode: null -floor_at_zero: true -normalize_log1p: False # paper used log1p preprocessing -model_type: "multiscale" 
+  preprocessing:
+    validate_paths: false
+    extract_labels: true
+    allow_pickle: true
+    labels_key_map:
+      x: "dp"
+      cs: "cs"
+      sg: "sg"
+      lattice_params: null
+      lp_a: "_cell_length_a"
+      lp_b: "_cell_length_b"
+      lp_c: "_cell_length_c"
+      lp_alpha: "_cell_angle_alpha"
+      lp_beta: "_cell_angle_beta"
+      lp_gamma: "_cell_angle_gamma"
+    dtype: "float32"
+    mmap_mode: null
+    floor_at_zero: true
+    normalize_log1p: false # note: the paper used log1p; disabled here to match the OG run

-# --- ConvNeXt (OG-equivalent configuration) ---
-# 3 stages; one block per stage; large kernels; stride-5 downsampling
-# Matches OG multiscale_cnn_cls_regr_convnextBlock_LeakyReLU.json exactly
-depths: [1, 1, 1]
-dims: [80, 80, 80]
-kernel_sizes: [100, 50, 25]
-strides: [5, 5, 5]
-dropout_rate: 0.3
-# OG uses layer_scale_init_value=0 (disabled)
-layer_scale_init_value: 0.0
-# OG uses constant drop_path_rate=0.3 (not ramped)
-drop_path_rate: 0.3
-ramped_dropout_rate: false
-block_type: "convnext"
-pooling_type: "average"
-final_pool: true
-use_batchnorm: false
-output_type: "flatten"
+  augmentation:
+    noise_poisson_range: [1.0, 100.0]
+    noise_gaussian_range: [0.001, 0.1]
+    standardize_to: [0.0, 100.0]

+model:
+  type: "multiscale"
+
+  backbone:
+    dim_in: 8192
+    dims: [80, 80, 80]
+    kernel_sizes: [100, 50, 25]
+    strides: [5, 5, 5]
+    dropout_rate: 0.3
+    layer_scale_init_value: 0.0
+    drop_path_rate: 0.3
+    ramped_dropout_rate: false
+    block_type: "convnext"
+    pooling_type: "average"
+    final_pool: true
+    use_batchnorm: false
+    activation: "leaky_relu"
+    output_type: "flatten"

-# Heads
-head_dropout: 0.5
-cs_hidden: [2300, 1150]
-sg_hidden: [2300, 1150]
-lp_hidden: [512, 256]
+  heads:
+    head_dropout: 0.5
+    cs_hidden: [2300, 1150]
+    sg_hidden: [2300, 1150]
+    lp_hidden: [512, 256]

-# Task sizes
-num_cs_classes: 7
-num_sg_classes: 230
-num_lp_outputs: 6
+  tasks:
+    num_cs_classes: 7
+    num_sg_classes: 230
+    num_lp_outputs: 6

-# LP output bounds
-lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
-lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
-bound_lp_with_sigmoid: true
+    lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+    lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
+    bound_lp_with_sigmoid: true

-# Loss weights
-lambda_cs: 1.0
-lambda_sg: 1.0
-lambda_lp: 1.0
+  loss:
+    lambda_cs: 1.0
+    lambda_sg: 1.0
+    lambda_lp: 1.0

-# Optional GEMD term on SG
-gemd_mu: 0.0
-gemd_distance_matrix_path: null
+    gemd_mu: 0.0
+    gemd_distance_matrix_path: null

-# Optimizer (paper): AdamW, lr=2e-4, wd=0.01
-lr: 0.0002
-weight_decay: 0.01
-use_adamw: true
-gradient_clip_val: 1.0
-gradient_clip_algorithm: "norm"
+optimizer:
+  lr: 0.0002
+  weight_decay: 0.01
+  use_adamw: true
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: "norm"

-# --- Noise augmentation (training split only; matches paper) ---
-# If provided, noise is applied dynamically per-sample in the DataModule using the same
-# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale.
-# Set ranges to None to disable.
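
Aside: a short sketch of how the nested layout above is consumed after this
refactor (the same access pattern the updated train.py uses); it assumes the
snippet runs from the repo root with PyYAML installed:

    import yaml

    with open("configs/trainer.yaml") as f:
        cfg = yaml.safe_load(f)

    loader_cfg = cfg["data"]["loader"]
    backbone_cfg = cfg["model"]["backbone"]
    print(loader_cfg["batch_size"])      # 64
    print(backbone_cfg["kernel_sizes"])  # [100, 50, 25]
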
-noise_poisson_range: [1.0, 100.0] -noise_gaussian_range: [0.001, 0.1] +optimizer: + lr: 0.0002 + weight_decay: 0.01 + use_adamw: true + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" -# Standardize after noise to match OG CLI (--standardize-to 0 100) -standardize_to: [0.0, 100.0] -# --- Logging --- -logger: "mlflow" -csv_logger_name: "model_logs_convnext_paper" -mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt" -mlflow_tracking_uri: null -mlflow_run_name: "ConvNeXt_Paper_Run" +trainer: + default_root_dir: "outputs/convnext_paper" + max_epochs: 100 + accumulate_grad_batches: 1 + precision: "32" # match OG (AMP disabled) + accelerator: "gpu" + devices: 1 + log_every_n_steps: 200 + deterministic: false + benchmark: true -# --- Trainer settings --- -default_root_dir: "outputs/convnext_paper" -max_epochs: 100 -accumulate_grad_batches: 1 -precision: "32" # match OG (AMP disabled) -accelerator: "gpu" -devices: 1 -log_every_n_steps: 200 -deterministic: false -benchmark: true +logging: + logger: "mlflow" + csv_logger_name: "model_logs_convnext_paper" + mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt" + mlflow_tracking_uri: null + mlflow_run_name: "ConvNeXt_Paper_Run" -# --- Checkpointing --- -monitor: "val/loss" -mode: "min" -save_top_k: 1 -every_n_epochs: 1 - -# --- Evaluation --- -resume_from: -test_after_train: true +checkpointing: + monitor: "val/loss" + mode: "min" + save_top_k: 1 + every_n_epochs: 1 + + resume_from: null + test_after_train: true diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index 6b01dbc..9e5f965 100644 --- a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -130,12 +130,15 @@ class NpyDataModule(pl.LightningDataModule): def __init__( self, - manifest_dir: str = "data/manifests", - batch_size: int = 32, - num_workers: int = 4, - pin_memory: bool = True, - persistent_workers: bool = True, - prefetch_factor: Optional[int] = None, + manifest_dir: str, + batch_size: int, + num_workers: int, + pin_memory: bool, + persistent_workers: bool, + prefetch_factor: Optional[int], + train_file: str, + val_file: str, + test_file: str, collate_fn: Optional[Callable] = None, dataset_cls: type = NpyManifestDataset, dataset_kwargs: Optional[Dict[str, Any]] = None, @@ -146,10 +149,6 @@ def __init__( val_ratio: float = 0.1, test_ratio: float = 0.1, seed: int = 42, - # Custom manifest filenames (within manifest_dir) - train_file: str = "train.jsonl", - val_file: str = "val.jsonl", - test_file: str = "test.jsonl", # Optional: add a second validation manifest file (e.g., "rruff.jsonl") extra_val_file: Optional[str] = None, # Optional noise augmentation for training split only diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py index e260194..ad7a557 100644 --- a/src/trainer/dataset/dataset.py +++ b/src/trainer/dataset/dataset.py @@ -149,16 +149,16 @@ class NpyManifestDataset(Dataset): def __init__( self, manifest_path: str, + dtype: torch.dtype, + mmap_mode: Optional[str], + return_meta: bool, + validate_paths: bool, + extract_labels: bool, + allow_pickle: bool, + floor_at_zero: bool, + normalize_log1p: bool, transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, - dtype: torch.dtype = torch.float32, - mmap_mode: Optional[str] = "r", - return_meta: bool = True, - validate_paths: bool = False, - extract_labels: bool = False, labels_key_map: Optional[Dict[str, List[str]]] = None, - allow_pickle: bool = True, - floor_at_zero: bool = True, - normalize_log1p: bool = False, ) -> 
None: super().__init__() self.manifest_path = manifest_path diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index b13808e..24d05f2 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -40,10 +40,10 @@ class ConvNeXtBlock1D(nn.Module): def __init__( self, dim: int, - kernel_size: int = 7, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-6, - activation: nn.Module = nn.GELU, + kernel_size: int, + drop_path: float, + layer_scale_init_value: float, + activation: nn.Module, ): super().__init__() # depthwise 1D conv @@ -72,10 +72,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# ----------------------------- -# ConvNeXt Block Adaptor for Multiscale CNN -# Matches OG ConvNextBlock1DAdaptorForMultiscaleCNN behavior -# ----------------------------- class ConvNextBlock1DAdaptor(nn.Module): """ Adaptor that wraps ConvNeXtBlock1D to match OG MultiscaleCNNBackbone behavior. @@ -86,13 +82,13 @@ def __init__( in_channels: int, out_channels: int, kernel_size: int, - stride: int = 1, - dropout: float = 0.0, - use_batchnorm: bool = False, - activation: nn.Module = nn.LeakyReLU, - layer_scale_init_value: float = 0.0, - drop_path_rate: float = 0.3, - block_type: str = "convnext", + stride: int, + dropout: float, + use_batchnorm: bool, + activation: nn.Module, + layer_scale_init_value: float, + drop_path_rate: float, + block_type: str, ): super().__init__() @@ -196,20 +192,20 @@ def make_mlp( class MultiscaleCNNBackbone1D(nn.Module): def __init__( self, - dim_in: int = 8192, - channels: Tuple[int, ...] = (80, 80, 80), - kernel_sizes: Tuple[int, ...] = (100, 50, 25), - strides: Tuple[int, ...] = (5, 5, 5), - dropout_rate: float = 0.3, - ramped_dropout_rate: bool = False, - block_type: str = "convnext", - pooling_type: str = "average", - final_pool: bool = True, - use_batchnorm: bool = False, - activation: nn.Module = nn.LeakyReLU, - output_type: str = "flatten", - layer_scale_init_value: float = 0.0, - drop_path_rate: float = 0.3, + dim_in: int, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + strides: Tuple[int, ...], + dropout_rate: float, + ramped_dropout_rate: bool, + block_type: str, + pooling_type: str, + final_pool: bool, + use_batchnorm: bool, + activation: nn.Module, + output_type: str, + layer_scale_init_value: float, + drop_path_rate: float, ): super().__init__() assert len(channels) == len(kernel_sizes) == len(strides), "channels, kernel_sizes, strides must match lengths" @@ -289,46 +285,46 @@ class AlphaDiffractMultiscaleLightning(pl.LightningModule): def __init__( self, # Backbone params (OG-style) - dim_in: int = 8192, - channels: Tuple[int, ...] = (80, 80, 80), - kernel_sizes: Tuple[int, ...] = (100, 50, 25), - strides: Tuple[int, ...] 
= (5, 5, 5), - dropout_rate: float = 0.3, - ramped_dropout_rate: bool = False, - block_type: str = "convnext", - pooling_type: str = "average", - final_pool: bool = True, - use_batchnorm: bool = False, - activation: nn.Module = nn.LeakyReLU, - output_type: str = "flatten", - layer_scale_init_value: float = 1e-6, - drop_path_rate: float = 0.0, + dim_in: int, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + strides: Tuple[int, ...], + dropout_rate: float, + ramped_dropout_rate: bool, + block_type: str, + pooling_type: str, + final_pool: bool, + use_batchnorm: bool, + activation: nn.Module, + output_type: str, + layer_scale_init_value: float, + drop_path_rate: float, # Heads - head_dropout: float = 0.5, - cs_hidden: Optional[Tuple[int, ...]] = (2300, 1150), - sg_hidden: Optional[Tuple[int, ...]] = (2300, 1150), - lp_hidden: Optional[Tuple[int, ...]] = (512, 256), + head_dropout: float, + cs_hidden: Optional[Tuple[int, ...]], + sg_hidden: Optional[Tuple[int, ...]], + lp_hidden: Optional[Tuple[int, ...]], # Task sizes - num_cs_classes: int = 7, - num_sg_classes: int = 230, - num_lp_outputs: int = 6, + num_cs_classes: int, + num_sg_classes: int, + num_lp_outputs: int, # LP bounding - lp_bounds_min: Tuple[float, float, float, float, float, float] = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0), - lp_bounds_max: Tuple[float, float, float, float, float, float] = (500.0, 500.0, 500.0, 180.0, 180.0, 180.0), - bound_lp_with_sigmoid: bool = True, + lp_bounds_min: Tuple[float, float, float, float, float, float], + lp_bounds_max: Tuple[float, float, float, float, float, float], + bound_lp_with_sigmoid: bool, # Loss weights - lambda_cs: float = 1.0, - lambda_sg: float = 1.0, - lambda_lp: float = 1.0, + lambda_cs: float, + lambda_sg: float, + lambda_lp: float, # Optimizer - lr: float = 2e-4, - weight_decay: float = 1e-2, - use_adamw: bool = True, + lr: float, + weight_decay: float, + use_adamw: bool, ): super().__init__() self.save_hyperparameters() diff --git a/src/trainer/train.py b/src/trainer/train.py index ef337a0..c8c2522 100644 --- a/src/trainer/train.py +++ b/src/trainer/train.py @@ -48,82 +48,111 @@ def _to_dtype(name: str) -> torch.dtype: def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: + data_cfg = cfg["data"] + loader_cfg = data_cfg["loader"] + prep_cfg = data_cfg["preprocessing"] + aug_cfg = data_cfg["augmentation"] + # Dataset kwargs from config dataset_kwargs = { - "dtype": _to_dtype(cfg["dtype"]), - "mmap_mode": cfg["mmap_mode"], + "dtype": _to_dtype(prep_cfg["dtype"]), + "mmap_mode": prep_cfg["mmap_mode"], "return_meta": True, - "validate_paths": cfg["validate_paths"], - "extract_labels": cfg["extract_labels"], - "allow_pickle": cfg["allow_pickle"], - "floor_at_zero": cfg["floor_at_zero"], - "normalize_log1p": cfg["normalize_log1p"], + "validate_paths": prep_cfg["validate_paths"], + "extract_labels": prep_cfg["extract_labels"], + "allow_pickle": prep_cfg["allow_pickle"], + "floor_at_zero": prep_cfg["floor_at_zero"], + "normalize_log1p": prep_cfg["normalize_log1p"], } - labels_key_map = cfg.get("labels_key_map") + labels_key_map = prep_cfg["labels_key_map"] if labels_key_map is not None: dataset_kwargs["labels_key_map"] = labels_key_map dm = NpyDataModule( - manifest_dir=cfg["manifest_dir"], - batch_size=cfg["batch_size"], - num_workers=cfg["num_workers"], - pin_memory=cfg["pin_memory"], - persistent_workers=cfg["persistent_workers"] and cfg["num_workers"] > 0, + manifest_dir=data_cfg["manifest_dir"], + batch_size=loader_cfg["batch_size"], + 
num_workers=loader_cfg["num_workers"], + pin_memory=loader_cfg["pin_memory"], + persistent_workers=loader_cfg["persistent_workers"] and loader_cfg["num_workers"] > 0, + prefetch_factor=loader_cfg["prefetch_factor"], + train_file=loader_cfg["train_file"], + val_file=loader_cfg["val_file"], + test_file=loader_cfg["test_file"], dataset_kwargs=dataset_kwargs, - dataset_root=cfg["dataset_root"], - auto_generate_manifests=cfg["auto_generate_manifests"], - train_ratio=cfg["train_ratio"], - val_ratio=cfg["val_ratio"], - test_ratio=cfg["test_ratio"], - seed=cfg["seed"], - extra_val_file=cfg.get("extra_val_file"), + dataset_root=data_cfg["dataset_root"], + auto_generate_manifests=data_cfg["auto_generate_manifests"], + train_ratio=data_cfg["train_ratio"], + val_ratio=data_cfg["val_ratio"], + test_ratio=data_cfg["test_ratio"], + seed=data_cfg["seed"], + extra_val_file=data_cfg["extra_val_file"], # Optional noise augmentation: apply to training split only - noise_poisson_range=tuple(cfg.get("noise_poisson_range")) if cfg.get("noise_poisson_range") is not None else None, - noise_gaussian_range=tuple(cfg.get("noise_gaussian_range")) if cfg.get("noise_gaussian_range") is not None else None, + noise_poisson_range=tuple(aug_cfg["noise_poisson_range"]) if aug_cfg["noise_poisson_range"] is not None else None, + noise_gaussian_range=tuple(aug_cfg["noise_gaussian_range"]) if aug_cfg["noise_gaussian_range"] is not None else None, # Optional standardization after noise to match OG runs - standardize_to=tuple(cfg.get("standardize_to")) if cfg.get("standardize_to") is not None else None, + standardize_to=tuple(aug_cfg["standardize_to"]) if aug_cfg["standardize_to"] is not None else None, ) return dm def build_model_from_cfg(cfg: Dict[str, Any]): - model_type = cfg.get("model_type", "convnext").lower() + model_cfg = cfg["model"] + backbone_cfg = model_cfg["backbone"] + heads_cfg = model_cfg["heads"] + tasks_cfg = model_cfg["tasks"] + loss_cfg = model_cfg["loss"] + optim_cfg = cfg["optimizer"] + + model_type = model_cfg["type"].lower() + + act_name = backbone_cfg["activation"].lower() + if act_name == "leaky_relu": + activation = torch.nn.LeakyReLU + elif act_name == "relu": + activation = torch.nn.ReLU + elif act_name == "gelu": + activation = torch.nn.GELU + else: + raise ValueError(f"Unsupported activation '{act_name}'") + if model_type == "multiscale": return AlphaDiffractMultiscaleLightning( # Map OG-style multiscale CNN params from cfg (use dims as channels) - channels=tuple(cfg["dims"]), - kernel_sizes=tuple(cfg["kernel_sizes"]), - strides=tuple(cfg["strides"]), - dropout_rate=cfg["dropout_rate"], - ramped_dropout_rate=cfg.get("ramped_dropout_rate", False), - block_type=cfg.get("block_type", "convnext"), - pooling_type=cfg.get("pooling_type", "average"), - final_pool=cfg.get("final_pool", True), - use_batchnorm=cfg.get("use_batchnorm", False), - output_type=cfg.get("output_type", "flatten"), - layer_scale_init_value=cfg["layer_scale_init_value"], - drop_path_rate=cfg["drop_path_rate"], - # Heads (match OG JSON) - head_dropout=cfg["head_dropout"], - cs_hidden=tuple(cfg["cs_hidden"]), - sg_hidden=tuple(cfg["sg_hidden"]), - lp_hidden=tuple(cfg["lp_hidden"]), - # Task sizes - num_cs_classes=cfg["num_cs_classes"], - num_sg_classes=cfg["num_sg_classes"], - num_lp_outputs=cfg["num_lp_outputs"], - # LP bounds - lp_bounds_min=tuple(cfg["lp_bounds_min"]), - lp_bounds_max=tuple(cfg["lp_bounds_max"]), - bound_lp_with_sigmoid=cfg["bound_lp_with_sigmoid"], - # Loss weights - lambda_cs=cfg["lambda_cs"], - 
lambda_sg=cfg["lambda_sg"], - lambda_lp=cfg["lambda_lp"], - # Optimizer - lr=cfg["lr"], - weight_decay=cfg["weight_decay"], - use_adamw=cfg["use_adamw"], + dim_in=backbone_cfg["dim_in"], + channels=tuple(backbone_cfg["dims"]), + kernel_sizes=tuple(backbone_cfg["kernel_sizes"]), + strides=tuple(backbone_cfg["strides"]), + dropout_rate=backbone_cfg["dropout_rate"], + ramped_dropout_rate=backbone_cfg["ramped_dropout_rate"], + block_type=backbone_cfg["block_type"], + pooling_type=backbone_cfg["pooling_type"], + final_pool=backbone_cfg["final_pool"], + use_batchnorm=backbone_cfg["use_batchnorm"], + activation=activation, + output_type=backbone_cfg["output_type"], + layer_scale_init_value=backbone_cfg["layer_scale_init_value"], + drop_path_rate=backbone_cfg["drop_path_rate"], + + head_dropout=heads_cfg["head_dropout"], + cs_hidden=tuple(heads_cfg["cs_hidden"]), + sg_hidden=tuple(heads_cfg["sg_hidden"]), + lp_hidden=tuple(heads_cfg["lp_hidden"]), + + num_cs_classes=tasks_cfg["num_cs_classes"], + num_sg_classes=tasks_cfg["num_sg_classes"], + num_lp_outputs=tasks_cfg["num_lp_outputs"], + + lp_bounds_min=tuple(tasks_cfg["lp_bounds_min"]), + lp_bounds_max=tuple(tasks_cfg["lp_bounds_max"]), + bound_lp_with_sigmoid=tasks_cfg["bound_lp_with_sigmoid"], + + lambda_cs=loss_cfg["lambda_cs"], + lambda_sg=loss_cfg["lambda_sg"], + lambda_lp=loss_cfg["lambda_lp"], + + lr=optim_cfg["lr"], + weight_decay=optim_cfg["weight_decay"], + use_adamw=optim_cfg["use_adamw"], ) else: raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale'.") @@ -204,45 +233,50 @@ def on_exception(self, trainer, pl_module, exception): def build_trainer_from_cfg(cfg: Dict[str, Any], raw_config_path: Optional[str] = None) -> Trainer: + trainer_cfg = cfg["trainer"] + log_cfg = cfg["logging"] + ckpt_cfg = cfg["checkpointing"] + optim_cfg = cfg["optimizer"] + ckpt_cb = ModelCheckpoint( - monitor=cfg["monitor"], - mode=cfg["mode"], - save_top_k=cfg["save_top_k"], - dirpath=os.path.join(cfg["default_root_dir"], "checkpoints"), + monitor=ckpt_cfg["monitor"], + mode=ckpt_cfg["mode"], + save_top_k=ckpt_cfg["save_top_k"], + dirpath=os.path.join(trainer_cfg["default_root_dir"], "checkpoints"), filename="epoch{epoch:03d}-val_loss{val/loss:.4f}", save_last=True, - every_n_epochs=cfg["every_n_epochs"], + every_n_epochs=ckpt_cfg["every_n_epochs"], auto_insert_metric_name=False, ) lr_cb = LearningRateMonitor(logging_interval="epoch") # Configure logger from config logger = None - if cfg["logger"] == "csv": - logger = CSVLogger(save_dir=cfg["default_root_dir"], name=cfg["csv_logger_name"]) - elif cfg["logger"] == "mlflow": + if log_cfg["logger"] == "csv": + logger = CSVLogger(save_dir=trainer_cfg["default_root_dir"], name=log_cfg["csv_logger_name"]) + elif log_cfg["logger"] == "mlflow": if MLFlowLogger is None: raise ImportError("MLFlowLogger requested but 'mlflow' is not installed. 
Install with `pip install mlflow`.") logger = MLFlowLogger( - experiment_name=cfg["mlflow_experiment_name"], - tracking_uri=cfg["mlflow_tracking_uri"], - run_name=cfg["mlflow_run_name"], + experiment_name=log_cfg["mlflow_experiment_name"], + tracking_uri=log_cfg["mlflow_tracking_uri"], + run_name=log_cfg["mlflow_run_name"], ) trainer = Trainer( - default_root_dir=cfg["default_root_dir"], - max_epochs=cfg["max_epochs"], - accelerator=cfg["accelerator"], - devices=cfg["devices"], - precision=cfg["precision"], - accumulate_grad_batches=cfg["accumulate_grad_batches"], + default_root_dir=trainer_cfg["default_root_dir"], + max_epochs=trainer_cfg["max_epochs"], + accelerator=trainer_cfg["accelerator"], + devices=trainer_cfg["devices"], + precision=trainer_cfg["precision"], + accumulate_grad_batches=trainer_cfg["accumulate_grad_batches"], callbacks=[ckpt_cb, lr_cb, ConfigArtifactLogger(raw_config_path), RunCheckpointDirCallback(), MlflowShutdownCallback()], logger=logger, - log_every_n_steps=cfg["log_every_n_steps"], - deterministic=cfg["deterministic"], - benchmark=cfg["benchmark"], - gradient_clip_val=cfg.get("gradient_clip_val", 0.0), - gradient_clip_algorithm=cfg.get("gradient_clip_algorithm", "norm"), + log_every_n_steps=trainer_cfg["log_every_n_steps"], + deterministic=trainer_cfg["deterministic"], + benchmark=trainer_cfg["benchmark"], + gradient_clip_val=optim_cfg["gradient_clip_val"], + gradient_clip_algorithm=optim_cfg["gradient_clip_algorithm"], val_check_interval=0.5 ) return trainer @@ -253,7 +287,7 @@ def main(): args = parse_args() cfg = load_config(args.config) - seed_everything(cfg["seed"], workers=True) + seed_everything(cfg["data"]["seed"], workers=True) torch.set_float32_matmul_precision('high') @@ -262,11 +296,11 @@ def main(): trainer = build_trainer_from_cfg(cfg, raw_config_path=args.config) # Train - resume_from: Optional[str] = cfg.get("resume_from") + resume_from: Optional[str] = cfg["checkpointing"]["resume_from"] trainer.fit(model, datamodule=dm, ckpt_path=resume_from) # Test best model if requested - if cfg["test_after_train"]: + if cfg["checkpointing"]["test_after_train"]: ckpt_path = trainer.checkpoint_callback.best_model_path if trainer.checkpoint_callback else None trainer.test(model=model, datamodule=dm, ckpt_path=ckpt_path or "best") From 255dfa7420796a4f4bb5fa73f89bed947ce6fd64 Mon Sep 17 00:00:00 2001 From: linked-liszt Date: Wed, 19 Nov 2025 11:06:43 -0600 Subject: [PATCH 18/18] Refactor: Continued cleanup --- configs/trainer.yaml | 6 ++-- src/trainer/dataset/datamodule.py | 54 +------------------------------ src/trainer/dataset/dataset.py | 6 ++++ src/trainer/model/model.py | 36 ++------------------- src/trainer/train.py | 15 +++++---- 5 files changed, 21 insertions(+), 96 deletions(-) diff --git a/configs/trainer.yaml b/configs/trainer.yaml index 7d41407..a3113f1 100644 --- a/configs/trainer.yaml +++ b/configs/trainer.yaml @@ -1,6 +1,3 @@ -# AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant) -# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml - data: manifest_dir: "../../../ad_data/manifests" dataset_root: "../../../ad_data/data/dataset" @@ -40,7 +37,8 @@ data: dtype: "float32" mmap_mode: null floor_at_zero: true - normalize_log1p: False # paper used log1p preprocessing + normalize_log1p: False + shift_labels: true augmentation: noise_poisson_range: [1.0, 100.0] diff --git a/src/trainer/dataset/datamodule.py b/src/trainer/dataset/datamodule.py index 9e5f965..4c28106 100644 --- 
a/src/trainer/dataset/datamodule.py +++ b/src/trainer/dataset/datamodule.py @@ -4,28 +4,6 @@ This module wires NpyManifestDataset to PyTorch Lightning and can optionally auto-generate manifests from a dataset root if they are missing. -Typical usage: - from pytorch_lightning import Trainer - from dataset.datamodule import NpyDataModule - - dm = NpyDataModule( - manifest_dir="data/manifests", - batch_size=64, - num_workers=8, - pin_memory=True, - persistent_workers=True, - # Optional: auto-generate manifests if missing - dataset_root="data/dataset", - auto_generate_manifests=True, - train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42, - # Dataset-specific kwargs - dataset_kwargs={"dtype": torch.float32, "mmap_mode": "r", "return_meta": True}, - ) - - trainer = Trainer(max_epochs=10, accelerator="auto", devices="auto") - trainer.fit(model, dm) - trainer.test(model, dm) - Notes: - Splitting is performed by material ID when generating manifests (never per-file). - Manifests avoid scanning the entire dataset during training. @@ -80,34 +58,6 @@ def _transform(x: torch.Tensor) -> torch.Tensor: return _transform - -def _shift_one_based_collate(batch): - """ - Collate function that uses PyTorch's default_collate, then unconditionally shifts - cs and sg labels by -1 (assumes 1-based input). Performed under torch.no_grad to avoid - constructing any graphs. - """ - collated = default_collate(batch) - with torch.no_grad(): - def _shift(t): - return t - 1 if torch.is_tensor(t) else t - - if isinstance(collated, dict): - if "cs" in collated: - collated["cs"] = _shift(collated["cs"]) - if "sg" in collated: - collated["sg"] = _shift(collated["sg"]) - elif isinstance(collated, (list, tuple)): - # Tuple-based batches: (x, cs, sg, [lp]) - lst = list(collated) - if len(lst) >= 2: - lst[1] = _shift(lst[1]) - if len(lst) >= 3: - lst[2] = _shift(lst[2]) - collated = type(collated)(lst) - return collated - - class NpyDataModule(pl.LightningDataModule): """ LightningDataModule that reads train/val/test JSONL manifests and constructs DataLoaders. 
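Moving the shift out of the collate path and into the dataset (shift_labels below) keeps batching on PyTorch's default_collate and turns the 1-based label convention into a per-sample preprocessing concern. For reference, a minimal sketch of the underlying requirement, assuming cs labels in 1..7 and sg labels in 1..230 as the task sizes suggest (illustrative, not the project code):

    import torch
    import torch.nn as nn

    num_sg_classes = 230
    logits = torch.randn(4, num_sg_classes)          # fake space-group logits
    sg_one_based = torch.tensor([1, 14, 194, 230])   # labels as stored (1-based)

    # Without the shift, target 230 is out of range for CrossEntropyLoss
    # (valid class indices are 0..229) and raises an error at loss time.
    sg_zero_based = sg_one_based - 1                 # what shift_labels=True now does
    loss = nn.CrossEntropyLoss()(logits, sg_zero_based)
    print(loss.item())
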
@@ -165,9 +115,7 @@ def __init__( self.persistent_workers = persistent_workers self.prefetch_factor = prefetch_factor self.collate_fn = collate_fn - if self.collate_fn is None: - self.collate_fn = _shift_one_based_collate - + self.dataset_cls = dataset_cls self.dataset_kwargs = dataset_kwargs or {} diff --git a/src/trainer/dataset/dataset.py b/src/trainer/dataset/dataset.py index ad7a557..f830210 100644 --- a/src/trainer/dataset/dataset.py +++ b/src/trainer/dataset/dataset.py @@ -157,6 +157,7 @@ def __init__( allow_pickle: bool, floor_at_zero: bool, normalize_log1p: bool, + shift_labels: bool, transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, labels_key_map: Optional[Dict[str, List[str]]] = None, ) -> None: @@ -170,6 +171,7 @@ def __init__( self.allow_pickle = allow_pickle self.floor_at_zero = floor_at_zero self.normalize_log1p = normalize_log1p + self.shift_labels = shift_labels # Default key mapping for extracting fields from embedded containers # Simplified: single string keys, no search lists self.labels_key_map = labels_key_map or { @@ -311,8 +313,12 @@ def _get_exact(container, key: str): # Attach labels if present/extracted if self.extract_labels: if y_cs_t is not None: + if self.shift_labels: + y_cs_t = y_cs_t - 1 sample["cs"] = y_cs_t if y_sg_t is not None: + if self.shift_labels: + y_sg_t = y_sg_t - 1 sample["sg"] = y_sg_t if y_lp_t is not None: sample["lattice_params"] = y_lp_t diff --git a/src/trainer/model/model.py b/src/trainer/model/model.py index 24d05f2..e4b80df 100644 --- a/src/trainer/model/model.py +++ b/src/trainer/model/model.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np import pytorch_lightning as pl @@ -183,7 +182,6 @@ def make_mlp( # ----------------------------- -# OG-style Multiscale CNN Backbone (1D) with ConvNeXt-like blocks # Mirrors alphadiffract.model.MultiscaleCNNBackbone behavior: # - sequential conv stages with specified kernel_sizes and strides # - optional average/max pooling between stages and at the end @@ -212,13 +210,11 @@ def __init__( self.dim_in = dim_in self.output_type = output_type - # Build per-stage dropout schedule if ramped_dropout_rate: dropout_per_stage = torch.linspace(0.0, dropout_rate, steps=len(channels)).tolist() else: dropout_per_stage = [dropout_rate] * len(channels) - # Select pooling module if pooling_type == "average": pool_cls = nn.AvgPool1d pool_kwargs = {"kernel_size": 3, "stride": 2} @@ -231,7 +227,6 @@ def __init__( layers: List[nn.Module] = [] in_ch = 1 for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)): - # Build stage block matching OG ConvNextBlock1DAdaptorForMultiscaleCNN stage_block = ConvNextBlock1DAdaptor( in_channels=in_ch, out_channels=out_ch, @@ -284,7 +279,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AlphaDiffractMultiscaleLightning(pl.LightningModule): def __init__( self, - # Backbone params (OG-style) dim_in: int, channels: Tuple[int, ...], kernel_sizes: Tuple[int, ...], @@ -300,36 +294,31 @@ def __init__( layer_scale_init_value: float, drop_path_rate: float, - # Heads head_dropout: float, cs_hidden: Optional[Tuple[int, ...]], sg_hidden: Optional[Tuple[int, ...]], lp_hidden: Optional[Tuple[int, ...]], - # Task sizes num_cs_classes: int, num_sg_classes: int, num_lp_outputs: int, - # LP bounding lp_bounds_min: Tuple[float, float, float, float, float, float], lp_bounds_max: Tuple[float, float, float, float, float, float], bound_lp_with_sigmoid: bool, - # Loss weights lambda_cs: float, lambda_sg: 
float, lambda_lp: float, - # Optimizer lr: float, weight_decay: float, use_adamw: bool, + steps_per_epoch: int, ): super().__init__() self.save_hyperparameters() - # Backbone self.backbone = MultiscaleCNNBackbone1D( dim_in=dim_in, channels=channels, @@ -348,7 +337,6 @@ def __init__( ) feat_dim = self.backbone.dim_output - # Heads self.cs_head = make_mlp( input_dim=feat_dim, hidden_dims=cs_hidden, @@ -371,22 +359,20 @@ def __init__( output_activation=None, ) - # Losses and bounds self.ce = nn.CrossEntropyLoss() self.mse = nn.MSELoss() self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32)) self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32)) self.bound_lp_with_sigmoid = bound_lp_with_sigmoid - # weights and optim config self.lambda_cs = lambda_cs self.lambda_sg = lambda_sg self.lambda_lp = lambda_lp self.lr = lr self.weight_decay = weight_decay self.use_adamw = use_adamw + self.steps_per_epoch = steps_per_epoch - # Task sizes self.num_cs_classes = num_cs_classes self.num_sg_classes = num_sg_classes self.num_lp_outputs = num_lp_outputs @@ -523,23 +509,7 @@ def configure_optimizers(self): else: optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay) - # Compute steps per epoch to match OG scheduler semantics: - # step_size_up = 6 * iterations_per_epoch - steps_per_epoch = None - try: - if hasattr(self, "trainer") and self.trainer is not None: - total_steps = getattr(self.trainer, "estimated_stepping_batches", None) - max_epochs = getattr(self.trainer, "max_epochs", None) - if total_steps is not None and max_epochs is not None and max_epochs > 0: - steps_per_epoch = max(1, total_steps // max_epochs) - except Exception: - pass - - if steps_per_epoch is None: - # Fallback if trainer hooks are unavailable; use a conservative default - steps_per_epoch = 100 - - step_size_up = int(6 * steps_per_epoch) + step_size_up = int(6 * self.steps_per_epoch) scheduler = torch.optim.lr_scheduler.CyclicLR( optimizer, diff --git a/src/trainer/train.py b/src/trainer/train.py index c8c2522..f4b51e1 100644 --- a/src/trainer/train.py +++ b/src/trainer/train.py @@ -6,21 +6,18 @@ import yaml from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback -import signal from pytorch_lightning.loggers import CSVLogger try: from pytorch_lightning.loggers import MLFlowLogger except Exception: MLFlowLogger = None -# Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`) from dataset import NpyDataModule from model.model import AlphaDiffractMultiscaleLightning def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Train AlphaDiffract paper model (config-required)") - # Require a config file path with no script-side defaults p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer.yaml)") return p.parse_args() @@ -63,6 +60,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: "allow_pickle": prep_cfg["allow_pickle"], "floor_at_zero": prep_cfg["floor_at_zero"], "normalize_log1p": prep_cfg["normalize_log1p"], + "shift_labels": prep_cfg["shift_labels"], } labels_key_map = prep_cfg["labels_key_map"] if labels_key_map is not None: @@ -95,7 +93,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule: return dm -def build_model_from_cfg(cfg: Dict[str, Any]): +def build_model_from_cfg(cfg: Dict[str, Any], steps_per_epoch: int): model_cfg = cfg["model"] 
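Passing steps_per_epoch explicitly, instead of probing trainer.estimated_stepping_batches inside configure_optimizers, makes step_size_up deterministic at construction time. A minimal sketch of the resulting CyclicLR wiring; the model, base_lr, and max_lr are placeholders, not values taken from this patch:

    import torch

    model = torch.nn.Linear(8, 2)                    # stand-in module
    steps_per_epoch = 125                            # e.g. len(train_dataloader)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=2e-5,                                # illustrative lower bound
        max_lr=2e-4,                                 # illustrative upper bound
        step_size_up=6 * steps_per_epoch,            # OG semantics: ramp up over 6 epochs
        cycle_momentum=False,                        # required for Adam-family optimizers
    )
    for _ in range(10):
        optimizer.step()
        scheduler.step()                             # step per batch, not per epoch
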
backbone_cfg = model_cfg["backbone"] heads_cfg = model_cfg["heads"] @@ -149,10 +147,11 @@ def build_model_from_cfg(cfg: Dict[str, Any]): lambda_cs=loss_cfg["lambda_cs"], lambda_sg=loss_cfg["lambda_sg"], lambda_lp=loss_cfg["lambda_lp"], - + lr=optim_cfg["lr"], weight_decay=optim_cfg["weight_decay"], use_adamw=optim_cfg["use_adamw"], + steps_per_epoch=steps_per_epoch, ) else: raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale'.") @@ -292,7 +291,11 @@ def main(): torch.set_float32_matmul_precision('high') dm = build_datamodule_from_cfg(cfg) - model = build_model_from_cfg(cfg) + # Explicitly setup datamodule to calculate steps_per_epoch + dm.setup("fit") + steps_per_epoch = len(dm.train_dataloader()) + + model = build_model_from_cfg(cfg, steps_per_epoch=steps_per_epoch) trainer = build_trainer_from_cfg(cfg, raw_config_path=args.config) # Train
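Running dm.setup("fit") before model construction is what makes len(dm.train_dataloader()) available to size the schedule. A toy sketch of that arithmetic; the dataset size and batch size are made up, and the sketch assumes accumulate_grad_batches=1 (with accumulation, the optimizer, and hence CyclicLR, steps once per accumulation window, so effective steps per epoch shrink by that factor):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    train_ds = TensorDataset(torch.randn(1000, 8192))    # stand-in for the .npy dataset
    train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=False)

    steps_per_epoch = len(train_dl)                      # ceil(1000 / 64) = 16 batches
    step_size_up = 6 * steps_per_epoch                   # feeds the scheduler above
    print(steps_per_epoch, step_size_up)
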